当前位置: 首页 > article >正文

PyTorch深度学习的梯度消失和梯度爆炸的识别、解决和最佳实践

通过结合梯度监控、网络架构改进和优化策略,可以有效应对梯度消失/爆炸问题。建议在模型开发初期就加入梯度监控机制,这有助于快速定位问题层。对于超深网络(>50层),建议优先考虑使用预激活残差结构(ResNet-v2)。

一、梯度消失/爆炸原理

1. 问题成因

  • 梯度消失:反向传播时梯度值逐层指数级衰减(常见于Sigmoid/Tanh激活函数)
  • 梯度爆炸:反向传播时梯度值逐层指数级增长(常见于深层网络和不当初始化)

2. 数学原理
假设网络有L层,每层梯度计算为:
∂ L o s s ∂ W l = ∂ L o s s ∂ h L ∏ k = l L − 1 ( W k + 1 T ⊙ σ ′ ( h k ) ) \frac{\partial Loss}{\partial W_l} = \frac{\partial Loss}{\partial h_L} \prod_{k=l}^{L-1} (W_{k+1}^T \odot \sigma'(h_k)) WlLoss=hLLossk=lL1(Wk+1Tσ(hk))
当连乘积项趋向0时出现梯度消失,趋向无穷大时出现梯度爆炸。


二、问题识别与监控代码

使用梯度监控hook记录各层梯度分布:

import torch
import torch.nn as nn# 定义一个有梯度消失问题的网络
class ProblemNet(nn.Module):def __init__(self):super().__init__()self.layers = nn.Sequential(nn.Linear(784, 200),nn.Sigmoid(),nn.Linear(200, 200),nn.Sigmoid(),nn.Linear(200, 10))def forward(self, x):return self.layers(x)# 梯度监控hook
def register_grad_hook(model):grads = []def hook_fn(module, grad_input, grad_output):grad_mean = grad_output[0].abs().mean().item()grads.append(grad_mean)return Nonefor layer in model.layers:if isinstance(layer, nn.Linear):layer.register_full_backward_hook(hook_fn)return grads# 训练过程
model = ProblemNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 模拟数据
x = torch.randn(32, 784)
y = torch.randint(0,10,(32,))grads = register_grad_hook(model)  # 注册监控hookoptimizer.zero_grad()
output = model(x)
loss = criterion(output, y)
loss.backward()
optimizer.step()# 打印各层梯度均值
print("梯度均值监测:")
for i, g in enumerate(grads):print(f"Layer {i+1} grad mean: {g:.4e}")

典型输出(梯度消失):

梯度均值监测:
Layer 1 grad mean: 2.3432e-05
Layer 2 grad mean: 1.0784e-08
Layer 3 grad mean: 0.0000e+00

三、解决方案与改进代码

改进策略

  1. 激活函数改用ReLU
  2. 添加批归一化层
  3. 使用Xavier初始化
  4. 添加梯度裁剪
class ImprovedNet(nn.Module):def __init__(self):super().__init__()self.layers = nn.Sequential(nn.Linear(784, 200),nn.BatchNorm1d(200),nn.ReLU(inplace=True),nn.Linear(200, 200),nn.BatchNorm1d(200),nn.ReLU(inplace=True),nn.Linear(200, 10))self._init_weights()def _init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_normal_(m.weight)nn.init.constant_(m.bias, 0)def forward(self, x):return self.layers(x)# 使用梯度裁剪的优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # 梯度裁剪# 重新运行训练...

改进后的典型输出:

梯度均值监测:
Layer 1 grad mean: 3.1425e-02
Layer 2 grad mean: 2.8713e-02 
Layer 3 grad mean: 1.9564e-02

四、最佳实践建议

  1. 激活函数选择

    • 优先使用ReLU/Leaky ReLU(α=0.01)
    • 尝试Swish(x*sigmoid(βx))等新型激活函数
  2. 权重初始化

    # He初始化(ReLU适用)
    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')# Xavier初始化(Tanh适用)
    nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('tanh'))
    
  3. 梯度控制技术

    # 梯度裁剪(推荐值1.0-5.0)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)# 梯度累积(模拟大batch_size)
    accumulation_steps = 4
    loss = loss / accumulation_steps
    loss.backward()
    if (i+1) % accumulation_steps == 0:optimizer.step()optimizer.zero_grad()
    
  4. 架构改进

    # 添加残差连接
    class ResidualBlock(nn.Module):def __init__(self, in_dim):super().__init__()self.fc = nn.Sequential(nn.Linear(in_dim, in_dim),nn.BatchNorm1d(in_dim),nn.ReLU(),nn.Linear(in_dim, in_dim))def forward(self, x):return x + self.fc(x)
    
  5. 监控工具

    # 使用TensorBoard监控梯度分布
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()for name, param in model.named_parameters():if 'weight' in name and param.grad is not None:writer.add_histogram(f'grad/{name}', param.grad, global_step)
    

五、诊断流程图

训练异常 → 监控梯度 → if 梯度出现NaN: 检查学习率 → 添加梯度裁剪 → 检查数据归一化elif 梯度<1e-6: 改用ReLU → 添加残差连接 → 检查初始化方法else: 继续正常训练

相关文章:

PyTorch深度学习的梯度消失和梯度爆炸的识别、解决和最佳实践

通过结合梯度监控、网络架构改进和优化策略&#xff0c;可以有效应对梯度消失/爆炸问题。建议在模型开发初期就加入梯度监控机制&#xff0c;这有助于快速定位问题层。对于超深网络&#xff08;>50层&#xff09;&#xff0c;建议优先考虑使用预激活残差结构&#xff08;Res…...

Nginx1.19.2不适配OPENSSL3.0问题

Nginx 1.19.2 是较老的版本&#xff0c;而 Nginx 1.21 版本已经适配 OpenSSL 3.0&#xff0c;所以建议 升级 Nginx 到 1.25.0 或更高版本&#xff1a; wget http://nginx.org/download/nginx-1.25.0.tar.gz tar -xzf nginx-1.25.0.tar.gz cd nginx-1.25.0 ./configure --prefix…...

蓝桥杯 Excel地址

Excel地址 题目描述 Excel 单元格的地址表示很有趣&#xff0c;它使用字母来表示列号。 比如&#xff0c; A 表示第 1 列&#xff0c; B 表示第 2 列&#xff0c; Z 表示第 26 列&#xff0c; AA 表示第 27 列&#xff0c; AB 表示第 28 列&#xff0c; BA 表示第 53 列&#x…...

免费pdf格式转换工具

基本功能 - 支持单文件转换和批量转换两种模式 - 内置PDF文件预览功能 - 支持8种常见格式转换&#xff1a;Word、Excel、JPG/PNG图片、HTML、文本、PowerPoint和ePub 单文件转换功能 - 文件选择&#xff1a;支持浏览和选择单个PDF文件 - 输出位置&#xff1a;可自定义设置输出…...

I²C总线应用场景及1.8V与3.3V电压选择

以下是关于IC总线应用场景及1.8V与3.3V电压选择的详细分析: 一、IC总线的典型应用场景 1. 板内通信(主要场景) 描述:IC 最初设计是为电路板(PCB)上的芯片间短距离通信,尤其适用于集成度高的系统。典型器件: 传感器模块(如温湿度传感器BME280)。存储芯片(如EEPROM 2…...

css错峰布局/瀑布流样式(类似于快手样式)

当样式一侧比较高的时候会自动换行&#xff0c;尽量保持高度大概一致&#xff0c; 例&#xff1a; 一侧元素为5&#xff0c;另一侧元素为6 当为5的一侧过于高的时候&#xff0c;可能会变为4/7分部dom节点 如果不需要这样的话删除样式 flex-flow:column wrap; 设置父级dom样…...

Deepseek中的MoE架构的改造:动态可变参数激活的MoE混合专家架构(DVPA-MoE)的考虑

大家好,我是微学AI,今天给大家介绍一下动态可变参数激活MoE架构(Dynamic Variable Parameter-Activated MoE, DVPA-MoE)的架构与实际应用,本架构支持从7B到32B的等多档参数动态激活。该架构通过细粒度难度评估和分层专家路由,实现“小问题用小参数,大问题用大参数”的精…...

docker-compose Install reranker(fastgpt支持) GPU模式

前言BGE-重新排名器 与 embedding 模型不同&#xff0c;reranker 或 cross-encoder 使用 question 和 document 作为输入&#xff0c;直接输出相似性而不是 embedding。 为了平衡准确性和时间成本&#xff0c;cross-encoder 被广泛用于对其他简单模型检索到的前 k 个文档进行重…...

doris: MySQL

Doris JDBC Catalog 支持通过标准 JDBC 接口连接 MySQL 数据库。本文档介绍如何配置 MySQL 数据库连接。 使用须知​ 要连接到 MySQL 数据库&#xff0c;您需要 MySQL 5.7, 8.0 或更高版本 MySQL 数据库的 JDBC 驱动程序&#xff0c;您可以从 Maven 仓库下载最新或指定版本的…...

JVM参数调整

一、内存相关参数 1. 堆内存控制 -Xmx&#xff1a;最大堆内存&#xff08;如 -Xmx4g&#xff0c;默认物理内存1/4&#xff09;。-Xms&#xff1a;初始堆内存&#xff08;建议与-Xmx相等&#xff0c;避免动态扩容带来的性能波动&#xff09;。-Xmn&#xff1a;新生代大小&…...

【DeepSeek问答】访问QStandardItemModel::index(r,c)获取的空索引导致程序崩溃

好的&#xff0c;我现在来仔细思考一下用户的问题。用户在使用QStandardItemModel的setItem方法时&#xff0c;调用了setItem(4,6,item)&#xff0c;也就是在第4行第6列的位置设置了一个item。然后他们尝试通过index(3,6)来获取这个位置的项目&#xff0c;想知道会有什么后果。…...

基于websocket的多用户网页五子棋 --- 测试报告

目录 功能测试自动化测试性能测试 功能测试 1.登录注册页面 2.游戏大厅页面 3.游戏房间页面 自动化测试 1.使用脑图编写web自动化测试用例 2.创建自动化项目&#xff0c;根据用例通过selenium来实现脚本 根据脑图进行测试用例的编写&#xff1a; 每个页面一个测试类&am…...

在 macOS 上使用 CLion 进行 Google Test 单元测试

介绍 Google Test&#xff08;GTest&#xff09;是 Google 开源的 C 单元测试框架&#xff0c;它提供了简单易用的断言、测试夹具&#xff08;Fixtures&#xff09;和测试运行机制&#xff0c;使 C 开发者能够编写高效的单元测试。 本博客将介绍如何在 macOS 上使用 CLion 配…...

深度解码!清华大学第六弹《AIGC发展研究3.0版》

在Grok3与GPT-4.5相继发布之际&#xff0c;《AIGC发展研究3.0版》的重磅报告——这份长达200页的行业圣经&#xff0c;不仅预测了2025年AI技术爆发点&#xff0c;更将「天人合一」的东方智慧融入AI伦理建构&#xff0c;堪称数字时代的《道德经》。 文档&#xff1a;清华大学第…...

【论文笔记】Attentive Eraser

标题&#xff1a;Attentive Eraser: Unleashing Diffusion Model’s Object Removal Potential via Self-Attention Redirection Guidance Source&#xff1a;https://arxiv.org/pdf/2412.12974 收录&#xff1a;AAAI 25 作者单位&#xff1a;浙工商&#xff0c;字节&#…...

97k倍区间

97k倍区间 ⭐️难度&#xff1a;中等 &#x1f31f;考点&#xff1a;暴力&#xff0c;2017省赛 &#x1f4d6; &#x1f4da; import java.util.Scanner;public class Main {static int N 100010;public static void main(String[] args) {Scanner sc new Scanner(System.…...

cursor使用经验分享(java后端服务开发向)

前言 cursor是一款基于vscode&#xff0c;并集成AI能力的代码编辑器&#xff0c;其功能包括但不限于代码生成及补全、AI对话&#xff08;能够直接将代码环境作为上下文&#xff09;、即时应用建议等等&#xff0c;是一款面向未来的代码编辑器。 对于vscode&#xff0c;最先想…...

SpringBoot3—场景整合:AOT

一、AOT与JIT AOT&#xff1a;Ahead-of-Time&#xff08;提前编译&#xff09;&#xff1a;程序执行前&#xff0c;全部被编译成机器码 JIT&#xff1a;Just in Time&#xff08;即时编译&#xff09;: 程序边编译&#xff0c;边运行&#xff1b; 编译&#xff1a;源代码&am…...

蓝桥与力扣刷题(蓝桥 数字三角形)

题目&#xff1a; 上图给出了一个数字三角形。从三角形的顶部到底部有很多条不同的路径。对于每条路径&#xff0c;把路径上面的数加起来可以得到一个和&#xff0c;你的任务就是找到最大的和&#xff08;路径上的每一步只可沿左斜线向下或右斜线向下走&#xff09;。 输入描述…...

蓝桥试题:传球游戏(二维dp)

一、题目描述 上体育课的时候&#xff0c;小蛮的老师经常带着同学们一起做游戏。这次&#xff0c;老师带着同学们一起做传球游戏。 游戏规则是这样的&#xff1a;n 个同学站成一个圆圈&#xff0c;其中的一个同学手里拿着一个球&#xff0c;当老师吹哨子时开始传球&#xff0…...

游戏引擎学习第138天

仓库:https://gitee.com/mrxiao_com/2d_game_3 资产&#xff1a;game_hero_test_assets_003.zip 发布 我们的目标是展示游戏运行时的完整过程&#xff0c;从像素渲染到不使用GPU的方式&#xff0c;我们自己编写了渲染器并完成了所有的工作。今天我们开始了一些新的内容&#…...

Lab 3 Page Table

题目链接 我的问题&#xff1a; 1 每个进程的kernel stack是干啥的来着&#xff1f;在何时初始化的&#xff1f; 题目2&#xff1a;A kernel page table per process (hard) 1 一些题目要求 Your first job is to modify the kernel so that every process uses its own c…...

嵌入式学习L5D2-exec函数族和守护进程

exec函数族1 下面那个加了p环境变量就不用那个了。 输出的是系统 exec函数族2 后面不执行了 第二个参数瞎写也可以&#xff0c;但是要填 这里是说不想被替换&#xff0c;就在子进程里面执行这个。 守护进程概念 后台进程 守护进程是后台进程 一个fork了一个进程&#xff…...

洛谷P1091

题目如下 思路 谢谢观看...

行为模式---迭代器模式

概念 迭代器模式是设计模式的行为模式&#xff0c;它的主要设计思想是提供一个可以操作聚合对象&#xff08;容器或者复杂数据类型&#xff09;表示&#xff08;迭代器类&#xff09;。通过迭代器类去访问操作聚合对象可以隐藏内部表示&#xff0c;也可以使客户端可以统一处理…...

阿里云 DataWorks面试题集锦及参考答案

目录 简述阿里云 DataWorks 的核心功能模块及其在企业数据治理中的作用 简述 DataWorks 的核心功能模块及其应用场景 解释 DataWorks 中工作空间、项目、业务流程的三层逻辑关系 解释 DataWorks 中的 “节点”、“工作流” 和 “依赖关系” 设计 解释 DataWorks 中 “周期任…...

【五.LangChain技术与应用】【29.LangChain Agent小案例1:智能代理的实战应用】

“为什么我的Agent总是处理不好实时数据?”“如何让AI自己调用API查股票?” 这些困扰开发者的问题,今天咱们用一个真实案例来彻底解决。不聊虚的,直接上手教你怎么用LangChain Agent造一个会自己查股价、算指标、生成报告的股票分析助手。全程高能,代码可直接复制粘贴到项…...

TWind 的黑马点评随笔

TWind 的黑马点评随笔 ​ 目前是把黑马点评的技术部分完全做完了&#xff0c;不能说吃得饱饱&#xff0c;也算个半饱吧。 ​ 黑马点评严格来说不算项目&#xff0c;因为它给的前端过于垃圾&#xff0c;内容又重在Redis&#xff0c;所以称之为Redis练习貌似跟贴切。 ​ 尽管如…...

windows部署spleeter 版本2.4.0:分离音频的人声和背景音乐

windows部署spleeter 版本2.4.0&#xff1a;分离音频的人声和背景音乐 一、Spleeter 是什么&#xff1f; Spleeter 是由法国音乐流媒体公司 Deezer 开发并开源的一款基于深度学习的音频分离工具。它能够将音乐中的不同音轨&#xff08;如人声、鼓、贝斯、钢琴等&#xff09;分…...

dify + ollama + deepseek-r1+ stable-diffusion 构建绘画智能体

故事背景 stable-diffusion 集成进 dify 后&#xff0c;我们搭建一个小智能体&#xff0c;验证下文生图功能 业务流程 #mermaid-svg-6nSwwp69eMizP6bt {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-6nSwwp69eMiz…...