深度学习-梯度消失/爆炸产生的原因、解决方法
在深度学习模型中,梯度消失和梯度爆炸现象是限制深层神经网络有效训练的主要问题之一,这两个现象从本质上来说是由链式求导过程中梯度的缩小或增大引起的。特别是在深层网络中,若初始梯度在反向传播过程中逐层被放大或缩小,最后导致前几层的权重更新停滞(梯度消失)或异常增大(梯度爆炸),影响模型的有效训练和收敛。接下来,我们从网络深度、激活函数的选择等方面深入分析其成因,并探讨解决这些问题的主流方法。
1. 梯度消失与梯度爆炸的成因
(1)网络深度
在深层神经网络中,每层网络的输出需要通过链式法则依次向前层传递梯度。对于N层网络,梯度会以每层的权重导数值的乘积进行传递。如果网络层数较多,且每层权重的初始值较小,则连乘的结果会逐渐趋于零,导致梯度逐层减小,这即是梯度消失的现象。反之,如果每层权重的初始值较大,则连乘结果会不断增大,出现梯度爆炸。
(2)激活函数的选择 激活函数的选择直接影响到梯度在反向传播中的衰减或放大,尤其是早期的Sigmoid和Tanh激活函数。
- Sigmoid函数:Sigmoid将输入压缩到0到1的范围内,但在0附近的梯度会快速趋近于零,这种“饱和效应”会导致反向传播的梯度迅速衰减,产生梯度消失现象。
- Tanh函数:Tanh虽然比Sigmoid有较大的梯度值区间(-1到1),但在极值区间也会出现梯度趋于零的情况。
- ReLU函数:ReLU(Rectified Linear Unit)虽在正区间表现良好,但在负值区间恒为零,会导致部分神经元的输出始终为零,称为“神经元死亡”,影响梯度传递。
2. 解决梯度消失与爆炸的方法
(1)优化权重初始化策略
- Xavier初始化:适合Sigmoid和Tanh激活函数。它将权重初始化为均值为0、方差为
2/(输入神经元数 + 输出神经元数)
的值,确保输出的分布尽量均匀,防止梯度消失或爆炸。 - He初始化:专为ReLU和其变种设计,将权重初始化为均值为0、方差为
2/输入神经元数
,使正向和反向传播中梯度保持在合理范围,减轻梯度消失的现象。
(2)激活函数的优化
- ReLU (Rectified Linear Unit):ReLU的导数在正区间为1,能够减轻梯度消失问题。然而,负区间梯度为0会导致“神经元死亡”。为此,引入了多种ReLU的变体:
- Leaky ReLU:在负区间引入一个小的斜率(如0.01)而非直接置零,有效缓解神经元死亡现象。
- Parametric ReLU (PReLU):进一步改进了Leaky ReLU,使负区间的斜率可以学习优化,以适应不同任务的数据分布。
- ELU (Exponential Linear Unit):在负区间以指数形式衰减,而非恒为0,有助于提高网络的收敛速度和稳定性。
- Swish函数:由Google提出,定义为
x * sigmoid(x)
,允许负数并对输入进行平滑处理,取得了较好的梯度稳定性。
(3)使用正则化技术
- 梯度裁剪(Gradient Clipping):在反向传播中限制梯度的最大值(例如,将超过某阈值的梯度强制设为该阈值)。这种方法通常用于防止梯度爆炸,在RNN和LSTM模型中常用。
- 权重正则化:通过L1和L2正则化对模型参数进行约束。L2正则化通过在损失函数中加入权重平方和作为惩罚项,使得过大的权重更新得以抑制,防止梯度爆炸。
- Layer Normalization:Layer Normalization在每一层对每个神经元的输出进行归一化操作,以确保梯度稳定性,特别适用于循环神经网络(RNN)等任务。
(4)引入新型网络结构
- 残差网络(Residual Networks, ResNet):引入残差连接(skip connections),让信息绕过中间的隐藏层直接传到输出层,确保梯度信息在深层网络中可以顺利传递,极大减轻了梯度消失问题,使得上百层的深层网络得以训练成功。
- 批标准化(Batch Normalization, BN):在每个小批量数据上进行标准化处理,将激活值归一化为均值为0、方差为1的分布。BN不仅稳定了梯度流动,且能提高模型的收敛速度和精度,是现代神经网络中常用的标准技术。
- 长短期记忆网络(LSTM):LSTM(Long Short-Term Memory)结构是为解决循环神经网络中梯度消失问题设计的。LSTM单元通过内部的“遗忘门”、“输入门”和“输出门”机制,控制记忆的更新和遗忘过程。这种机制使得梯度可以有效保留并传播,防止了长期依赖关系中的梯度消失问题,LSTM广泛应用于自然语言处理和时间序列任务。
(5)优化算法的改进
- 自适应优化算法(如Adam和RMSprop):自适应学习率优化算法如Adam、RMSprop等根据梯度的一阶和二阶矩估计动态调整学习率,使得梯度更新在每一层得到较好的适应,能在一定程度上减轻梯度消失与爆炸的问题。
- 学习率调度器(Learning Rate Scheduler):在训练过程中动态调整学习率,初期使用较大学习率快速搜索全局最优,随后逐渐减小学习率以精细化模型参数,避免梯度爆炸或振荡。
(6)其他增强训练的策略
- 早停(Early Stopping):在检测到模型的验证误差持续不变或增大时,提前停止训练,防止梯度爆炸带来的过拟合问题。
- 预训练与微调:通过在相似任务上进行预训练来获得初始参数,再对目标任务进行微调。该策略能为深层网络提供较好的初始点,避免梯度消失或爆炸带来的收敛困难问题。
- 正则化参数搜索:对于不同层次的神经元选择合适的正则化参数,特别是L2正则化和Dropout正则化,有助于保持网络的泛化能力与梯度稳定性。
3. 代码示例
以下是实现梯度剪切和Batch Normalization的示例代码:
import torch
import torch.nn as nn
import torch.optim as optim# 一个简单的全连接神经网络
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(784, 512)self.bn1 = nn.BatchNorm1d(512) # 使用Batch Normalizationself.relu = nn.ReLU()self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.fc1(x)x = self.bn1(x) # 在第一个全连接层后添加BNx = self.relu(x)x = self.fc2(x)return x# 创建模型和优化器
model = SimpleNN()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 模拟训练循环
for data, target in dataloader:optimizer.zero_grad()output = model(data)loss = nn.CrossEntropyLoss()(output, target)loss.backward()# 梯度剪切torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # 设定梯度最大阈值为1.0optimizer.step()
/*
模型的第一层全连接后加入Batch Normalization,以减少梯度的偏移,提高梯度在深层网络中传播稳定性。
使用梯度剪切函数clip_grad_norm_防止梯度爆炸,通过设定梯度的最大阈值,更新参数时避免数值不稳定。
*/
相关文章:
深度学习-梯度消失/爆炸产生的原因、解决方法
在深度学习模型中,梯度消失和梯度爆炸现象是限制深层神经网络有效训练的主要问题之一,这两个现象从本质上来说是由链式求导过程中梯度的缩小或增大引起的。特别是在深层网络中,若初始梯度在反向传播过程中逐层被放大或缩小,最后导…...
MVC(Model-View-Controller)模式概述
MVC(Model-View-Controller)是一种设计模式,最初由 Trygve Reenskaug 在 1970 年代提出,并在 Smalltalk 编程环境中得到了广泛应用。MVC 模式旨在实现用户界面和业务逻辑的分离,以增强应用程序的可维护性、可扩展性和复…...

数据结构 —— 红黑树
目录 1. 初识红黑树 1.1 红黑树的概念 1.2 红⿊树的规则 1.3 红黑树如何确保最长路径不超过最短路径的2倍 1.4 红黑树的效率:O(logN) 2. 红黑树的实现 2.1 红黑树的基础结构框架 2.2 红黑树的插⼊ 2.2.1 情况1:变色 2.2.2 情况2:单旋变色 2.2…...
《功能高分子学报》
《功能高分子学报》 中国标准连续出版物号:CN 31-1633/O6,国际标准连续出版物号:ISSN 1008-9357,邮发代号:4-629,刊期:双月刊。 《功能高分子学报》主要刊登功能高分子和其他高分子领域具有创新意义的学术…...

Linux特种文件系统--tmpfs文件系统
tmpfs类似于RamDisk(只能使用物理内存),使用虚拟内存(简称VM)子系统的页面存储文件。tmpfs完全依赖VM,遵循子系统的整体调度策略。说白了tmpfs跟普通进程差不多,使用的都是某种形式的虚拟内存&a…...

《基于STMF103的FreeRTOS内核移植》
目录 1.FreeRTOS资料下载与出处 1.1官网下载,网址:www.freertos.org 1.2在正点原子官网,任意STM32F1的开发板资料A盘里, 2.FreeRTOS移植重要文件讲解 2.1 FreeRTOS与FreeRTOS-Plus文件夹 2.2 Demo、Lincence、Source ●Demo文件…...
一七二、Vue3性能优化方式
Vue 3 的性能优化相较于 Vue 2 有了显著提升,利用新特性和改进方法可以更高效地构建和优化应用。以下是 Vue 3 的常见性能优化方法及示例。 1. 使用组合式 API (Composition API) Vue 3 引入的组合式 API,通过逻辑拆分和复用来实现更高效的代码组织和性…...

软件测试--BUG篇
博主主页: 码农派大星. 数据结构专栏:Java数据结构 数据库专栏:MySQL数据库 JavaEE专栏:JavaEE 软件测试专栏:软件测试 关注博主带你了解更多知识 目录 1. 软件测试的⽣命周期 2. BUG 1. BUG 的概念 2. 描述bug的要素 3.bug级别 4.bug的⽣命周期 5 与开发产⽣争执怎…...
Scikit-learn和Keras简介
一,Scikit-learn是一个开源的机器学习库,用于Python编程语言。它建立在NumPy、SciPy和matplotlib这些科学计算库之上,提供了简单有效的数据挖掘和数据分析工具。Scikit-learn库包含了许多用于分类、回归、聚类和降维的算法,包括支…...

python在word的页脚插入页码
1、插入简易页码 import win32com.client as win32 from win32com.client import constants import osdoc_app win32.gencache.EnsureDispatch(Word.Application)#打开word应用程序 doc_app.Visible Truedoc doc_app.Documents.Add() footer doc.Sections(1).Footers(cons…...
Java面试题十四
一、Java中的JNI(Java Native Interface)是什么?它有什么用途? Java中的JNI(Java Native Interface)是Java提供的一种编程框架,它允许Java代码与本地(Native)代码&#x…...
yarn : 无法加载文件,未对文件 进行数字签名。无法在当前系统上运行该脚本。
执行这个命令时报错:yarn --registryhttps://registry.npm.taobao.org yarn : 无法加载文件 C:\Users\Administrator\AppData\Roaming\npm\yarn.ps1。未对文件 C:\Users\Administ rator\AppData\Roaming\npm\yarn.ps1 进行数字签名。无法在当前系统上运行该脚本。有…...

Hadoop——HDFS
什么是HDFS HDFS(Hadoop Distributed File System)是Apache Hadoop的核心组件之一,是一个分布式文件系统,专门设计用于在大规模集群上存储和管理海量数据。它的设计目标是提供高吞吐量的数据访问和容错能力,以支持大数…...

计算机的一些基础知识
文章目录 编程语言 程序 所谓程序,就是 一组指令 以及 这组指令要处理的数据。狭义上来说,程序对我们来说,通常表现为一组文件。 程序 指令 指令要处理的数据。 编程语言发展 机器语言:0、1 二进制构成汇编语言:…...

学习RocketMQ(记录了个人艰难学习RocketMQ的笔记)
一、部署单点RocketMQ Docker 部署 RocketMQ (图文并茂超详细)_docker 部署rocketmq-CSDN博客 这个博主讲的很好,可食用,替大家实践了一遍 二、原理篇 为什么使用RocketMQ: 为什么选择RocketMQ | RocketMQ 关于一些原理,感觉…...

【设计模式】策略模式定义及其实现代码示例
文章目录 一、策略模式1.1 策略模式的定义1.2 策略模式的参与者1.3 策略模式的优点1.4 策略模式的缺点1.5 策略模式的使用场景 二、策略模式简单实现2.1 案例描述2.2 实现代码 三、策略模式的代码优化3.1 优化思路3.2 抽象策略接口3.3 上下文3.4 具体策略实现类3.5 测试 参考资…...

list与iterator的之间的区别,如何用斐波那契数列探索yield
问题 list与iterator的之间的区别是什么?如何用斐波那契数列探索yield? 2 方法 将数据转换成list,通过对list索引和切片操作,以及可以进行添加、删除和修改元素。 iterator是一种对象,用于遍历可迭代对象(如列表、元组…...

抖音店铺数据也就是抖店,如何使用小店数据集来挖掘价值?
抖音商家现在基本达到二百多万家抖店,有一些公司可能会根据开放的数据研究行业分布、GMV等等,就像是也出了专业的一些平台如“蝉妈妈”、“达多多”,对我来说受限制就是难受。 当然也有很多大型合法的数据平台有抖店数据集,但…...

KubeVirt 安装和配置 Windows虚拟机
本文将将介绍如何安装 KubeVirt 和使用 KubeVirt 配置 Windows 虚拟机。 前置条件 准备 Ubuntu 操作系统,一定要安装图形化界面。 安装 Docker(最新版本) 安装 libvirt 和 TigerVNC: apt install libvirt-daemon-system libvir…...

CM API方式设置YARN队列资源
简述 对于CDH版本我们可以参考Fayson的文章,本次是CDP7.1.7 CM7.4.4 ,下面只演示一个设置队列容量百分比的示例,其他请参考cloudera官网。 获取cookies文件 生成cookies.txt文件 curl -i -k -v -c cookies.txt -u admin:admin http://192.168.242.100:7180/api/v44/clusters …...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
在软件开发中正确使用MySQL日期时间类型的深度解析
在日常软件开发场景中,时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志,到供应链系统的物流节点时间戳,时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库,其日期时间类型的…...
day52 ResNet18 CBAM
在深度学习的旅程中,我们不断探索如何提升模型的性能。今天,我将分享我在 ResNet18 模型中插入 CBAM(Convolutional Block Attention Module)模块,并采用分阶段微调策略的实践过程。通过这个过程,我不仅提升…...

srs linux
下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935,SRS管理页面端口是8080,可…...

什么是Ansible Jinja2
理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析
Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...
A2A JS SDK 完整教程:快速入门指南
目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库ÿ…...

Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)
引言 工欲善其事,必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后,我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集,就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...

【C++】纯虚函数类外可以写实现吗?
1. 答案 先说答案,可以。 2.代码测试 .h头文件 #include <iostream> #include <string>// 抽象基类 class AbstractBase { public:AbstractBase() default;virtual ~AbstractBase() default; // 默认析构函数public:virtual int PureVirtualFunct…...

Ubuntu系统复制(U盘-电脑硬盘)
所需环境 电脑自带硬盘:1块 (1T) U盘1:Ubuntu系统引导盘(用于“U盘2”复制到“电脑自带硬盘”) U盘2:Ubuntu系统盘(1T,用于被复制) !!!建议“电脑…...