深度学习应用技巧4-模型融合:投票法、加权平均法、集成模型法

大家好,我是微学AI,今天给大家介绍一下,深度学习中的模型融合。它是将多个深度学习模型或其预测结果结合起来,以提高模型整体性能的一种技术。

深度学习中的模型融合技术,也叫做集成学习,是指同时使用多个模型来进行预测或分类,将它们的结果结合起来,从而获得更准确、更鲁棒的结果。这种方法能够弥补单个模型的不足之处,提高模型的性能。 

常见的深度学习模型包括卷积神经网络(CNN)、循环神经网络(RNN)等,在实际应用中,通常会使用多个模型来解决同一个任务。然而,单独使用每个模型可能会存在过拟合、欠拟合、训练时间长等一些问题。这时,模型融合技术就派上用场了。 对于模型融合技术,其主要的思想是结合多个模型的优点,减少缺点,从而提高整体的性能。

一、模型融合技巧主要包括以下几个方面

1. 投票法: 对多个相同类型的模型进行训练,最后通过投票的方式选择输出结果最多的类别作为最终的预测结果。在实践中,通常会使用奇数个模型,以避免出现相同数量的投票结果。

2. 加权平均法: 对多个相同类型的模型的输出结果进行加权平均。采用加权平均法融合的模型,可以根据效果不同,分配不同的权重。

3. 集成多种不同类型的模型: 在深度学习中,常常会使用不同类型的模型,如 CNN、 RNN 、 LSTM 等,将它们进行集成,综合利用不同模型的优点,进一步提高系统的性能。

4. 提前停止模型训练: 训练多个模型时,如果其中一个模型已经达到了最优的状态,可以停止继续训练,以达到快速训练过程和提高融合效果的目的。

模型融合主要提高了深度学习模型的表现力和泛化性能,更好的解决了过拟合等问题。在选择模型融合技巧时,可以根据具体实际应用选择不同的融合方法,灵活运用各种方法,从而得到更好的模型效果。

应用场景想象:

想象你正在参加一个模型比赛,需要对一些数据进行预测。你仅使用了一个模型进行预测,但是你发现这个模型可能存在某些缺陷,导致预测结果不够准确,于是你想到使用模型融合技术。 你开始集成三个不同的模型,每一个模型都有自己的特点和优缺点。为了得到集成模型的预测结果,你可以采用堆叠法,即把三个模型的预测结果作为输入特征,再训练一个新的模型进行预测。通过堆叠,你得到了一个更加准确的预测结果。 如果你觉得模型之间存在评估的差异,你也可以使用加权平均法来集成模型。你可以根据每一个模型的表现,为它们分配一定的权重,然后根据这些权重,对它们输出的结果进行加权平均,从而得到更加精确的预测结果。 最后,你也可以使用投票法,这种方法是将多个模型的预测结果进行投票,选择获得最多票数的结果作为最终预测结果。它会适用于模型数量较多的情况,即使其中某个模型出现了不准确的情况,也不会对最终结果有太大的影响。 通过这些集成模型的方法,你可以在深度学习中获得更为准确、稳定的预测结果,从而提升模型的性能和可靠性。

二、模型融合代码样例

下面我将利用投票法、加权平均法和集成模型法三种方式对CNN网络进行融合。

import numpy as np
from sklearn.metrics import accuracy_score
from tensorflow import keras
from tensorflow.keras import layers

# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# 将像素值缩放到 0-1 范围内
train_images = train_images.astype('float32') / 255.
test_images = test_images.astype('float32') / 255.

# 将标签转换为 one-hot 编码
train_labels = keras.utils.to_categorical(train_labels, 10)
test_labels = keras.utils.to_categorical(test_labels, 10)

# 定义 CNN 模型
def create_model():
    model = keras.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    return model

# 创建多个 CNN 模型
models = []
for i in range(3):
    model = create_model()
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    model.fit(train_images, train_labels, epochs=5, batch_size=128, verbose=1)
    models.append(model)

1.投票法:对测试集进行预测,并取预测结果的众数作为最终预测结果

predictions = []
for model in models:
    predictions.append(model.predict(test_images))
y_pred = np.argmax(np.round(np.mean(predictions, axis=0)), axis=1)
y_true = np.argmax(test_labels, axis=1)

# 计算准确率
acc = accuracy_score(y_true, y_pred)
print("使用投票法进行模型融合的准确率:", acc)

2.加权平均法:为每个模型定义一个权重,并将预测结果加权平均

weights = [0.2, 0.3, 0.5]
predictions = []
for i, model in enumerate(models):
    prediction = model.predict(test_images)
    predictions.append(weights[i] * prediction)
y_pred = np.argmax(np.sum(predictions, axis=0), axis=1)
y_true = np.argmax(test_labels, axis=1)

# 计算准确率
acc = accuracy_score(y_true, y_pred)
print("使用加权平均法进行模型融合的准确率:", acc)

3.模型集成法:将多个 CNN 模型堆叠在一起,将其看作一个更强大的模型,并对测试集进行预测

inputs = keras.Input(shape=(28, 28, 1))
outputs = [model(inputs) for model in models]
y = layers.Average()(outputs)
ensemble_model = keras.Model(inputs=inputs, outputs=y)
ensemble_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练集成模型
ensemble_model.fit(train_images, train_labels, epochs=5, batch_size=64)

# 测试集成模型
predictions = ensemble_model.predict(test_images)
y_pred = np.argmax(predictions, axis=1)
y_true = np.argmax(test_labels, axis=1)

# 计算准确率
acc = accuracy_score(y_true, y_pred)
print("使用模型集成方法进行模型融合的准确率:", acc)

最后我们可以看三种方法的准确率,在实战案例中根据业务需求进行方法的选择。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/2681.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

cjson文件格式介绍

cjson是一种轻量级的JSON解析库,它支持将JSON格式的数据转换为C语言中的数据结构,同时也支持将C语言中的数据结构转换为JSON格式的数据。cjson的文件格式是指在使用cjson库时,将JSON格式的数据存储在文件中,然后通过cjson库读取文…

音视频开发—MediaCodec 解码H264/H265码流视频

使用MediaCodec目的 MediaCodec是Android底层多媒体框架的一部分,通常与MediaExtractor、MediaMuxer、AudioTrack结合使用,可以编码H264、H265、AAC、3gp等常见的音视频格式 MediaCodec工作原理是处理输入数据以产生输出数据 MediaCodec工作流程 Med…

SpringBoot 结合RabbitMQ与Redis实现商品的并发下单【SpringBoot系列12】

SpringCloud 大型系列课程正在制作中,欢迎大家关注与提意见。 程序员每天的CV 与 板砖,也要知其所以然,本系列课程可以帮助初学者学习 SpringBooot 项目开发 与 SpringCloud 微服务系列项目开发 1 项目准备 SpringBoot 整合 RabbitMQ 消息队…

Linux下的指令(常用的指令,以及案例展示)

目录 一:模式切换图 二:vi和vim相关操作 三:开机、重启用户的登录注销 四:用户管理(切换、添加、删除、查询) 4.1 基本管理 4.2 用户组​ 五:实用指令 5.1 指定运行级别 5.2 找回运行密…

QEMU启动ARM32 Linux内核

目录前言前置知识ARM Versatile Express开发板简介ARM处理器家族简介安装qemu-system-arm安装交叉编译工具交叉编译ARM32 Linux内核交叉编译ARM32 Busybox使用busybox制作initramfs使用QEMU启动ARM32 Linux内核模拟vexpress-a9开发板模拟vexpress-a15开发板参考前言 本文介绍采…

Thread类的基本用法

Thread类的基本用法🔎1.线程创建🌻继承Thread类🌼继承Thread重写run()方法🌼继承Thread匿名内部类🌻实现Runnable接口🌼实现Runnable接口重写run()方法🌼实现Runnable接口匿名内部类&#x1f33…

spring5(四):IOC 操作 Bean 管理(基于注解方式)

IOC操作Bean管理(基于xml方式)前言一、注解1、概述二、入门案例1、Bean 的创建2、Bean的自动装配2.1 Autowired2、Qualifie3、Resource4、Value3、扫描组件3.1 配置文件版3.2 注解版4、测试前言 本博主将用CSDN记录软件开发求学之路上亲身所得与所学的心…

SQL优化13连问,收藏好!

1.日常工作中,你是怎么优化SQL的? 大家可以从这几个维度回答这个问题: 分析慢查询日志 使用explain查看执行计划 索引优化 深分页优化 避免全表扫描 避免返回不必要的数据(如select具体字段而不是select*) 使用…

Docker常规安装简介

总体步骤 搜索镜像拉取镜像查看镜像启动镜像,服务端口映射停止容器移除容器 案例 安装tomcat docker hub上面查找tomcat镜像,docker search tomcat从docker hub上拉取tomcat镜像到本地 docker pull tomcatdocker images查看是否有拉取到的tomcat 使用tomcat镜像创…

windows微服务部署

windows部署一.nginx部署1.nginx 官网下载2. 配置nginx3.配置nigix 防止nigix刷新404不生效二.配置redis部署成服务1.在系统配置中 配置为系统变量2.打开快捷登录服务管理#3. 开启redis三.windows部署jar包一.nginx部署 1.nginx 官网下载 地址 官网地址 安装 windows版本 可安…

天猫2月咖啡行业数据分析(咖啡品牌销量排行)

随着人们消费水平的提高以及休闲、办公等场景化的需要,咖啡已成为越来越多人日常生活中的必需品,咖啡行业的市场规模也在不断扩大。并且,随着咖啡品牌不断发力线上赛道,咖啡的电商化之路也越来越成熟,而与此同时&#…

STM32MP157-QT-串口调试助手设计

文章目录前言STM32MP157串口调试助手widget.uipro文件widget.h头文件槽函数成员声明widget.cpp头文件扫描串口并添加到下拉列表串口配置参数获取配置参数打开、关闭串口读取数据信号读数据函数代码发送数据清空接收、发送区发送新行定时发送移植安装含编译 Qt 应用程序的交叉编…

PH电极酸碱度检测

最近做了一个项目是关于PH电极测酸碱度的一个仪器。 简单地说:玻璃电极是一种氢离子选择性电极,相当于一个对玻璃膜两侧氢离子浓度差异能产生附加电势差的“盐桥”,一般的盐桥是为了消除浓差电势或者液体接触电势这种附加电势差,玻…

I2C和SPI总线以及通信

通讯属性 概括 Serial/parallel 串行/并行Synchronous/asynchronous 同步/异步Point-to-point / bus 点对点 总线Half-duplex/full-duplex 半双工/全双工Master-slave/ equal partners 主从/对等single-ending / differential 单端/差分 点对点和总线 点对点通讯 只有两个通…

Springboot Long类型数据太长返回给前端,精度丢失问题 复现、解决

前言 惯例,收到兄弟求救,关于long类型丢失精度的问题: 存在一个初学者不会,就会有第二个初学者不会,所以我出手。 正文 不多说,开搞。 如题, 后端返回的数据 给到 前端, Long类型数…

DevOps系列文章 - K8S构建Jenkins持续集成平台

k8s安装直接跳过,用Kubeadm安装也比较简单安装和配置 NFSNFS简介NFS(Network File System),它最大的功能就是可以通过网络,让不同的机器、不同的操作系统可以共享彼此的文件。我们可以利用NFS共享Jenkins运行的配置文件…

一文了解Jackson注解@JsonFormat及失效解决

背景 项目中使用WRITE_DATES_AS_TIMESTAMPS: true转换日期格式为时间戳未生效。如下: spring:jackson:time-zone: Asia/Shanghaiserialization:WRITE_DATES_AS_TIMESTAMPS: true尝试是否关于时间的注解是否会生效,使用JsonForma和JsonFiled均失效。 常…

功能测试转型测试开发年薪27W,又一名功能测试摆脱点点点,进了大厂

咱们直接开门见山,没错我的粉丝向我投来了喜报,从功能测试转型测试开发,进入大厂,摆脱最初级的点点点功能测试,拿到高薪,遗憾的是,这名粉丝因为个人原因没有经过指导就去面试了,否则…

字节跳动测试岗面试记:二面被按地上血虐,所幸Offer已到手...

在互联网做了几年之后,去大厂“镀镀金”是大部分人的首选。大厂不仅待遇高、福利好,更重要的是,它是对你专业能力的背书,大厂工作背景多少会给你的简历增加几分竞争力。 但说实话,想进大厂还真没那么容易。最近面试字…

webpack——使用、分析打包代码

世上本无nodejs js最初是在前端浏览器上运行的语言,js代码一旦脱离了浏览器环境,就无法被运行。直到nodejs的出现,我们在电脑上配置了node环境,就可以让js代码脱离浏览器,在node环境中运行。 浏览器不支持模块化 nodej…
最新文章