PyTorch深度学习的梯度消失和梯度爆炸的识别、解决和最佳实践
通过结合梯度监控、网络架构改进和优化策略,可以有效应对梯度消失/爆炸问题。建议在模型开发初期就加入梯度监控机制,这有助于快速定位问题层。对于超深网络(>50层),建议优先考虑使用预激活残差结构(ResNet-v2)。
一、梯度消失/爆炸原理
1. 问题成因:
- 梯度消失:反向传播时梯度值逐层指数级衰减(常见于Sigmoid/Tanh激活函数)
- 梯度爆炸:反向传播时梯度值逐层指数级增长(常见于深层网络和不当初始化)
2. 数学原理:
假设网络有L层,每层梯度计算为:
∂ L o s s ∂ W l = ∂ L o s s ∂ h L ∏ k = l L − 1 ( W k + 1 T ⊙ σ ′ ( h k ) ) \frac{\partial Loss}{\partial W_l} = \frac{\partial Loss}{\partial h_L} \prod_{k=l}^{L-1} (W_{k+1}^T \odot \sigma'(h_k)) ∂Wl∂Loss=∂hL∂Lossk=l∏L−1(Wk+1T⊙σ′(hk))
当连乘积项趋向0时出现梯度消失,趋向无穷大时出现梯度爆炸。
二、问题识别与监控代码
使用梯度监控hook记录各层梯度分布:
import torch
import torch.nn as nn# 定义一个有梯度消失问题的网络
class ProblemNet(nn.Module):def __init__(self):super().__init__()self.layers = nn.Sequential(nn.Linear(784, 200),nn.Sigmoid(),nn.Linear(200, 200),nn.Sigmoid(),nn.Linear(200, 10))def forward(self, x):return self.layers(x)# 梯度监控hook
def register_grad_hook(model):grads = []def hook_fn(module, grad_input, grad_output):grad_mean = grad_output[0].abs().mean().item()grads.append(grad_mean)return Nonefor layer in model.layers:if isinstance(layer, nn.Linear):layer.register_full_backward_hook(hook_fn)return grads# 训练过程
model = ProblemNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 模拟数据
x = torch.randn(32, 784)
y = torch.randint(0,10,(32,))grads = register_grad_hook(model) # 注册监控hookoptimizer.zero_grad()
output = model(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()# 打印各层梯度均值
print("梯度均值监测:")
for i, g in enumerate(grads):print(f"Layer {i+1} grad mean: {g:.4e}")
典型输出(梯度消失):
梯度均值监测:
Layer 1 grad mean: 2.3432e-05
Layer 2 grad mean: 1.0784e-08
Layer 3 grad mean: 0.0000e+00
三、解决方案与改进代码
改进策略:
- 激活函数改用ReLU
- 添加批归一化层
- 使用Xavier初始化
- 添加梯度裁剪
class ImprovedNet(nn.Module):def __init__(self):super().__init__()self.layers = nn.Sequential(nn.Linear(784, 200),nn.BatchNorm1d(200),nn.ReLU(inplace=True),nn.Linear(200, 200),nn.BatchNorm1d(200),nn.ReLU(inplace=True),nn.Linear(200, 10))self._init_weights()def _init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_normal_(m.weight)nn.init.constant_(m.bias, 0)def forward(self, x):return self.layers(x)# 使用梯度裁剪的优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # 梯度裁剪# 重新运行训练...
改进后的典型输出:
梯度均值监测:
Layer 1 grad mean: 3.1425e-02
Layer 2 grad mean: 2.8713e-02
Layer 3 grad mean: 1.9564e-02
四、最佳实践建议
-
激活函数选择:
- 优先使用ReLU/Leaky ReLU(α=0.01)
- 尝试Swish(x*sigmoid(βx))等新型激活函数
-
权重初始化:
# He初始化(ReLU适用) nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')# Xavier初始化(Tanh适用) nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('tanh')) -
梯度控制技术:
# 梯度裁剪(推荐值1.0-5.0) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)# 梯度累积(模拟大batch_size) accumulation_steps = 4 loss = loss / accumulation_steps loss.backward() if (i+1) % accumulation_steps == 0:optimizer.step()optimizer.zero_grad() -
架构改进:
# 添加残差连接 class ResidualBlock(nn.Module):def __init__(self, in_dim):super().__init__()self.fc = nn.Sequential(nn.Linear(in_dim, in_dim),nn.BatchNorm1d(in_dim),nn.ReLU(),nn.Linear(in_dim, in_dim))def forward(self, x):return x + self.fc(x) -
监控工具:
# 使用TensorBoard监控梯度分布 from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter()for name, param in model.named_parameters():if 'weight' in name and param.grad is not None:writer.add_histogram(f'grad/{name}', param.grad, global_step)
五、诊断流程图
训练异常 → 监控梯度 → if 梯度出现NaN: 检查学习率 → 添加梯度裁剪 → 检查数据归一化elif 梯度<1e-6: 改用ReLU → 添加残差连接 → 检查初始化方法else: 继续正常训练
相关文章:
PyTorch深度学习的梯度消失和梯度爆炸的识别、解决和最佳实践
通过结合梯度监控、网络架构改进和优化策略,可以有效应对梯度消失/爆炸问题。建议在模型开发初期就加入梯度监控机制,这有助于快速定位问题层。对于超深网络(>50层),建议优先考虑使用预激活残差结构(Res…...
Nginx1.19.2不适配OPENSSL3.0问题
Nginx 1.19.2 是较老的版本,而 Nginx 1.21 版本已经适配 OpenSSL 3.0,所以建议 升级 Nginx 到 1.25.0 或更高版本: wget http://nginx.org/download/nginx-1.25.0.tar.gz tar -xzf nginx-1.25.0.tar.gz cd nginx-1.25.0 ./configure --prefix…...
蓝桥杯 Excel地址
Excel地址 题目描述 Excel 单元格的地址表示很有趣,它使用字母来表示列号。 比如, A 表示第 1 列, B 表示第 2 列, Z 表示第 26 列, AA 表示第 27 列, AB 表示第 28 列, BA 表示第 53 列&#x…...
免费pdf格式转换工具
基本功能 - 支持单文件转换和批量转换两种模式 - 内置PDF文件预览功能 - 支持8种常见格式转换:Word、Excel、JPG/PNG图片、HTML、文本、PowerPoint和ePub 单文件转换功能 - 文件选择:支持浏览和选择单个PDF文件 - 输出位置:可自定义设置输出…...
I²C总线应用场景及1.8V与3.3V电压选择
以下是关于IC总线应用场景及1.8V与3.3V电压选择的详细分析: 一、IC总线的典型应用场景 1. 板内通信(主要场景) 描述:IC 最初设计是为电路板(PCB)上的芯片间短距离通信,尤其适用于集成度高的系统。典型器件: 传感器模块(如温湿度传感器BME280)。存储芯片(如EEPROM 2…...
css错峰布局/瀑布流样式(类似于快手样式)
当样式一侧比较高的时候会自动换行,尽量保持高度大概一致, 例: 一侧元素为5,另一侧元素为6 当为5的一侧过于高的时候,可能会变为4/7分部dom节点 如果不需要这样的话删除样式 flex-flow:column wrap; 设置父级dom样…...
Deepseek中的MoE架构的改造:动态可变参数激活的MoE混合专家架构(DVPA-MoE)的考虑
大家好,我是微学AI,今天给大家介绍一下动态可变参数激活MoE架构(Dynamic Variable Parameter-Activated MoE, DVPA-MoE)的架构与实际应用,本架构支持从7B到32B的等多档参数动态激活。该架构通过细粒度难度评估和分层专家路由,实现“小问题用小参数,大问题用大参数”的精…...
docker-compose Install reranker(fastgpt支持) GPU模式
前言BGE-重新排名器 与 embedding 模型不同,reranker 或 cross-encoder 使用 question 和 document 作为输入,直接输出相似性而不是 embedding。 为了平衡准确性和时间成本,cross-encoder 被广泛用于对其他简单模型检索到的前 k 个文档进行重…...
doris: MySQL
Doris JDBC Catalog 支持通过标准 JDBC 接口连接 MySQL 数据库。本文档介绍如何配置 MySQL 数据库连接。 使用须知 要连接到 MySQL 数据库,您需要 MySQL 5.7, 8.0 或更高版本 MySQL 数据库的 JDBC 驱动程序,您可以从 Maven 仓库下载最新或指定版本的…...
JVM参数调整
一、内存相关参数 1. 堆内存控制 -Xmx:最大堆内存(如 -Xmx4g,默认物理内存1/4)。-Xms:初始堆内存(建议与-Xmx相等,避免动态扩容带来的性能波动)。-Xmn:新生代大小&…...
【DeepSeek问答】访问QStandardItemModel::index(r,c)获取的空索引导致程序崩溃
好的,我现在来仔细思考一下用户的问题。用户在使用QStandardItemModel的setItem方法时,调用了setItem(4,6,item),也就是在第4行第6列的位置设置了一个item。然后他们尝试通过index(3,6)来获取这个位置的项目,想知道会有什么后果。…...
基于websocket的多用户网页五子棋 --- 测试报告
目录 功能测试自动化测试性能测试 功能测试 1.登录注册页面 2.游戏大厅页面 3.游戏房间页面 自动化测试 1.使用脑图编写web自动化测试用例 2.创建自动化项目,根据用例通过selenium来实现脚本 根据脑图进行测试用例的编写: 每个页面一个测试类&am…...
在 macOS 上使用 CLion 进行 Google Test 单元测试
介绍 Google Test(GTest)是 Google 开源的 C 单元测试框架,它提供了简单易用的断言、测试夹具(Fixtures)和测试运行机制,使 C 开发者能够编写高效的单元测试。 本博客将介绍如何在 macOS 上使用 CLion 配…...
深度解码!清华大学第六弹《AIGC发展研究3.0版》
在Grok3与GPT-4.5相继发布之际,《AIGC发展研究3.0版》的重磅报告——这份长达200页的行业圣经,不仅预测了2025年AI技术爆发点,更将「天人合一」的东方智慧融入AI伦理建构,堪称数字时代的《道德经》。 文档:清华大学第…...
【论文笔记】Attentive Eraser
标题:Attentive Eraser: Unleashing Diffusion Model’s Object Removal Potential via Self-Attention Redirection Guidance Source:https://arxiv.org/pdf/2412.12974 收录:AAAI 25 作者单位:浙工商,字节&#…...
97k倍区间
97k倍区间 ⭐️难度:中等 🌟考点:暴力,2017省赛 📖 📚 import java.util.Scanner;public class Main {static int N 100010;public static void main(String[] args) {Scanner sc new Scanner(System.…...
cursor使用经验分享(java后端服务开发向)
前言 cursor是一款基于vscode,并集成AI能力的代码编辑器,其功能包括但不限于代码生成及补全、AI对话(能够直接将代码环境作为上下文)、即时应用建议等等,是一款面向未来的代码编辑器。 对于vscode,最先想…...
SpringBoot3—场景整合:AOT
一、AOT与JIT AOT:Ahead-of-Time(提前编译):程序执行前,全部被编译成机器码 JIT:Just in Time(即时编译): 程序边编译,边运行; 编译:源代码&am…...
蓝桥与力扣刷题(蓝桥 数字三角形)
题目: 上图给出了一个数字三角形。从三角形的顶部到底部有很多条不同的路径。对于每条路径,把路径上面的数加起来可以得到一个和,你的任务就是找到最大的和(路径上的每一步只可沿左斜线向下或右斜线向下走)。 输入描述…...
蓝桥试题:传球游戏(二维dp)
一、题目描述 上体育课的时候,小蛮的老师经常带着同学们一起做游戏。这次,老师带着同学们一起做传球游戏。 游戏规则是这样的:n 个同学站成一个圆圈,其中的一个同学手里拿着一个球,当老师吹哨子时开始传球࿰…...
游戏引擎学习第138天
仓库:https://gitee.com/mrxiao_com/2d_game_3 资产:game_hero_test_assets_003.zip 发布 我们的目标是展示游戏运行时的完整过程,从像素渲染到不使用GPU的方式,我们自己编写了渲染器并完成了所有的工作。今天我们开始了一些新的内容&#…...
Lab 3 Page Table
题目链接 我的问题: 1 每个进程的kernel stack是干啥的来着?在何时初始化的? 题目2:A kernel page table per process (hard) 1 一些题目要求 Your first job is to modify the kernel so that every process uses its own c…...
嵌入式学习L5D2-exec函数族和守护进程
exec函数族1 下面那个加了p环境变量就不用那个了。 输出的是系统 exec函数族2 后面不执行了 第二个参数瞎写也可以,但是要填 这里是说不想被替换,就在子进程里面执行这个。 守护进程概念 后台进程 守护进程是后台进程 一个fork了一个进程ÿ…...
洛谷P1091
题目如下 思路 谢谢观看...
行为模式---迭代器模式
概念 迭代器模式是设计模式的行为模式,它的主要设计思想是提供一个可以操作聚合对象(容器或者复杂数据类型)表示(迭代器类)。通过迭代器类去访问操作聚合对象可以隐藏内部表示,也可以使客户端可以统一处理…...
阿里云 DataWorks面试题集锦及参考答案
目录 简述阿里云 DataWorks 的核心功能模块及其在企业数据治理中的作用 简述 DataWorks 的核心功能模块及其应用场景 解释 DataWorks 中工作空间、项目、业务流程的三层逻辑关系 解释 DataWorks 中的 “节点”、“工作流” 和 “依赖关系” 设计 解释 DataWorks 中 “周期任…...
【五.LangChain技术与应用】【29.LangChain Agent小案例1:智能代理的实战应用】
“为什么我的Agent总是处理不好实时数据?”“如何让AI自己调用API查股票?” 这些困扰开发者的问题,今天咱们用一个真实案例来彻底解决。不聊虚的,直接上手教你怎么用LangChain Agent造一个会自己查股价、算指标、生成报告的股票分析助手。全程高能,代码可直接复制粘贴到项…...
TWind 的黑马点评随笔
TWind 的黑马点评随笔 目前是把黑马点评的技术部分完全做完了,不能说吃得饱饱,也算个半饱吧。 黑马点评严格来说不算项目,因为它给的前端过于垃圾,内容又重在Redis,所以称之为Redis练习貌似跟贴切。 尽管如…...
windows部署spleeter 版本2.4.0:分离音频的人声和背景音乐
windows部署spleeter 版本2.4.0:分离音频的人声和背景音乐 一、Spleeter 是什么? Spleeter 是由法国音乐流媒体公司 Deezer 开发并开源的一款基于深度学习的音频分离工具。它能够将音乐中的不同音轨(如人声、鼓、贝斯、钢琴等)分…...
dify + ollama + deepseek-r1+ stable-diffusion 构建绘画智能体
故事背景 stable-diffusion 集成进 dify 后,我们搭建一个小智能体,验证下文生图功能 业务流程 #mermaid-svg-6nSwwp69eMizP6bt {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-6nSwwp69eMiz…...
