当前位置: 首页 > article >正文

梯度消失和梯度爆炸的一些处理方法

在这里是记录一下梯度消失或梯度爆炸的一些处理技巧。全当学习总结了如有错误还请留言,在此感激不尽。
权重和梯度的更新公式如下:
w = w − η ⋅ ∇ w w = w - \eta \cdot \nabla w w=wηw

个人通俗的理解梯度消失就是网络模型在反向求导的时候出现梯度值太小的问题,然而学习率一般也很小这就会导致权重在进行更新的时候,几乎不会发生变化,导致模型不收敛。

如果梯度值在更新过程中变成非常大,会导致网络权重的大幅更新,并因此使网络变得不稳定。在极端情况下,权重的值变得非常大,以至于溢出导致出现NaN,导致模型不收敛。

以下一些方法可以解决或者缓解梯度消失或梯度爆炸的情况。

1)梯度裁剪(Gradient Clipping)

直接对梯度进行数值上的约束,当梯度的范数超过某一阈值时,将其裁剪到该阈值以内,以此来防止梯度爆炸。
torch中接口如下

nn.utils.clip_grad_norm_(model.parameters(), clip_value)  # 默认是2范数

clip_grad_norm_() 函数的主要参数包括:

  • parameters:一个包含需要裁剪梯度的参数的迭代器或者是生成器,通常我们会传入model.parameters()
  • max_norm:要裁剪到的最大范数值。如果梯度范数超过这个值,则会按比例缩放梯度。默认为无穷大。
  • norm_type:要计算的范数类型。通常设置为 2 表示 L2 范数(欧几里得范数),也可以设置为其他值如 1 表示 L1 范数。默认为 2。

clip_grad_norm_() 函数内部计算方式可以参考下式:

r e t u r n = c l i p _ v a l u e ∗ x g r a d ∣ ∣ x g r a d ∣ ∣ return = clip\_value * \frac{x_{grad}}{||x_{grad}||} return=clip_value∣∣xgrad∣∣xgrad
clip_value值确定一般是:

  1. 可视化梯度:在初步训练过程中,观察模型的梯度值是否出现过大,可以通过打印或可视化参数梯度的大小来判断。
  2. 经验范围:一些研究者和实践者建议将max_norm设置在 1 到 5 之间,或者梯度总和的初始平均值的某个倍数。

2)激活函数的选择

尽量少选择sigmoid和tanh等饱和性的激活函数,可以选择 ReLU及其变种等激活函数。

3)批量归一化(Batch Normalization, BN)

神经网络每一层学习到的分布都是无法预测的,前一层的输出即是下一层的输入,由于参数的更新,每一层的输入分布都在发生变化,导致网络很难收敛。如果说让一个batch的数据在网络结构中都服从同一种分布,将可以解决这个问题。

由于 BatchNorm 使得输入数据的分布更为集中和稳定,因此在网络中向传播时,梯度的变化受到输入数据分布变化的影响变小,有效地解决了梯度消失和梯度爆炸的问题。主要作用 稳定网络内部的分布。

4)残差连接(Residual Connections / ResNet)

通过跨层的连接方式,网络保留了原始信息的输入,有助于维持梯度的有效流动,在一定程度上可以抑制梯度爆炸的现象。

5)初始化策略

使用合理的权重初始化方法,比如xavier初始化或He初始化,这些方法可以根据网络层的输入和输出节点数量来调整初始权重的分布,从而减少梯度消失的可能性。

6)门控机制(Gate Mechanisms)

专门设计来解决梯度消失问题的,可参考LSTM或者GRU网络模型。通过门控机制精确的管理着信息的存储、遗忘和更新等。门控机制本身并不直接防止梯度过大但是结合Gradient Clipping 可以有效的避免梯度爆炸的发生。


http://www.kler.cn/news/274852.html

相关文章:

  • [Uni-app] 微信小程序的圆环进度条
  • web前端之多种方式实现switch滑块功能、动态设置css变量、after伪元素、选择器、has伪类
  • iPhone迎AI大革命:谷歌、OpenAI助苹果重塑智能巅峰
  • 室友打团太吵?一条命令断掉它的WiFi
  • 论文浅尝 | GPT-RE:基于大语言模型针对关系抽取的上下文学习
  • SpringMVC04 实现简单的留言墙功能
  • vue3之声明式和编程式导航
  • 远程安全访问JumpServer:使用cpolar内网穿透搭建固定公网地址
  • [python] ETL 工作流程 Prefect
  • 反射与串扰
  • Gin框架 源码解析
  • 图神经网络实战(5)——常用图数据集
  • 计算机毕业设计-基于python的旅游信息爬取以及数据分析
  • 【web世界探险家】HTML5 探索与实践
  • 【SpringBoot3+MyBatis-Plus】头条新闻项目实现CRUD登录注册
  • webpack5零基础入门-12搭建开发服务器
  • Docker 从0安装 nacos集群
  • Linux文件 profile、bashrc、bash_profile区别
  • vivado 物理优化约束、交互式物理优化
  • 权限维持小结
  • 【回溯、分治、Kadane】算法例题
  • 富格林:揭露黑幕套路安全规避风险
  • 嵌入式学习41-数据结构2
  • Java学习笔记:异常处理
  • UE5 GAS开发p30 创建UI HUD 血条
  • 【Anaconda】换源常用命令
  • 风丘电动汽车高压测试方案 助您高效应对车辆试验难题
  • Spark on Yarn安装配置
  • Java后台生成多个Excel并用Zip打包后(可以将excel文件放置到不同的目录)下载
  • JS+CSS3点击粒子烟花动画js特效