当前位置: 首页 > news >正文

获取非叶子节点的grad(retain_grad()、hook)【为了解决grad值是None的问题】

在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是:

  • retain_grad()
  • hook

不过我感觉“hook”比“retain_grad()”要麻烦.....,所以我感觉还是使用“retain_grad()”吧

1、retain_grad()

retain_grad()显式地保存非叶节点的梯度, 代价就是会增加显存的消耗(对比hook函数的方法则是在反向计算时直接打印, 因此不会增加显存消耗.)

使用方法:

直接在forward中对你想要输出gred的tensor“.retain_grad()”即可:tensor.retain_grad()

import torchdef forwrad(x, y, w1, w2):# 其中 x,y 为输入数据,w为该函数所需要的参数z_1 = torch.mm(w1, x)z_1.retain_grad()y_1 = torch.sigmoid(z_1)y_1.retain_grad()z_2 = torch.mm(w2, y_1)z_2.retain_grad()y_2 = torch.sigmoid(z_2)# y_2.retain_grad()loss = 1 / 2 * (((y_2 - y) ** 2).sum())return loss, z_1, y_1, z_2, y_2# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
# 反向
loss.backward()  # 反向传播,计算梯度print(loss.grad)print(y_2.grad)print(z_2.grad)# 输出结果是否是None,如果是None-->True
def is_none(obj):return obj is None
# 打印出非叶子结点的gred
print(is_none(z_1.grad))
print(is_none(y_2.grad))
print(z_2.grad)

注意:不要对保存梯度的变量做任何修改,例如:z_1, y_1, z_2, y_2,修改为gred_list = [z_1, y_1, z_2, y_2],然后输入梯度值,那是错误的,要直接一个一个输出,不要做任何操作

2、hook的使用

使用retain_grad会消耗额外的显存, 我们可以使用hook在反向计算的时候进行保存. 还是上面的例子, 我们使用hook来完成.

import torch# 我们可以定义一个hook来保存中间的变量
grads = {} # 存储节点名称与节点的grad
def save_grad(name):def hook(grad):grads[name] = gradreturn hookdef forwrad(x, y, w1, w2):# 其中 x,y 为输入数据,w为该函数所需要的参数z_1 = torch.mm(w1, x)y_1 = torch.sigmoid(z_1)z_2 = torch.mm(w2, y_1)y_2 = torch.sigmoid(z_2)loss = 1/2*(((y_2 - y)**2).sum())return loss, z_1, y_1, z_2, y_2# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向传播
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)# hook中间节点
z_1.register_hook(save_grad('z_1'))
y_1.register_hook(save_grad('y_1'))
z_2.register_hook(save_grad('z_2'))
y_2.register_hook(save_grad('y_2'))# 反向传播
loss.backward()
print(grads['z_1'])
print(grads['y_1'])
print(grads['z_2'])
print(grads['y_2'])

https://www.cnblogs.com/dxscode/p/16146470.html

pytorch | loss不收敛或者训练中梯度grad为None的问题_pytorch梯度为none_Rilkean heart的博客-CSDN博客

相关文章:

获取非叶子节点的grad(retain_grad()、hook)【为了解决grad值是None的问题】

在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是: retain_grad()hook 不过我感觉“hook”比“retain_grad()”要麻烦.....,所以我感觉还是…...

JMeter(八):响应断言详解

响应断言 :对服务器的响应进行断言校验 (1)应用范围: main sample and sub sample, main sample only , sub-sample only , jmeter variable 关于应用范围,我们大多数勾选“main sample only” 就足够了,因为我们一个请求,实质上只有一个请求。但是当我们发一个请求时,…...

【网络编程】IO复用的应用一:非阻塞connect

在connect连接中,若socket以非阻塞的方式进行连接,则系统内设置的TCP三次握手超时时间为0,所以它不会等待TCP三次握手完成,直接返回,错误为EINPROGRESS。   所以,我们可以通过判断connect时返回的错误码是…...

Spring注解开发,bean的作用范围及生命周期、Spring注解开发依赖注入

🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 c语言 数据结构 javaweb 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 Spring注解开发 一、注解开发定义Bean二、纯注解开发Bean三…...

C#设计模式之---原型模式

原型模式(Prototype Pattern) 原型模式(Prototype Pattern) 是用原型实例指定创建对象的种类,并且通过拷贝这些原型创建新的对象。原型模式是一种创建型设计模式。也就是用一个已经创建的实例作为原型,通过…...

STM32入门学习之外部中断

1.STM32的IO口可以作为外部中断输入口。本文通过按键按下作为外部中断的输入,点亮LED灯。在STM32的19个外部中断中,0-15为外部IO口的中断输入口。STM32的引脚分别对应着0-15的外部中断线。比如,外部中断线0对应着GPIOA.0-GPIOG.0,…...

Jenkins 配置maven和jdk

前提:服务器已经安装maven和jdk 一、在Jenkins中添加全局变量 系统管理–>系统配置–>全局属性–>环境变量 添加三个全局变量 JAVA_HOME、MAVEN_HOME、PATH 二、配置maven 系统管理–>全局工具配置–>maven–>新增 新增配置 三、配置JDK 在系统管…...

Leetcode | Binary search | 22. 74. 162. 33. 34. 153.

22. Generate Parentheses 要意识到只要还有左括号,就可以放到path里。只要右括号数量小于左括号,也可以放进去。就是valid的组合。recurse两次 74. Search a 2D Matrix 看成sorted list就好。直接用m*n表示最后一位的index,并且每次只需要 …...

生命在于折腾——面试问题汇总

这里面的问题都是我参加面试时候遇到的问题,大家就这样看吧。 一、个人情况 1、自我介绍 2、为什么离开上一家公司 3、有没有参加过HVV 4、介绍一下上家公司的项目 5、小程序和公众号渗透测试做过么 6、实习工资多少 7、有挖过漏洞么 二、基础知识 1、信息收集的…...

<Java>Map<String,Object>中解析Object类型数据为数组格式

背景&#xff1a; 前端&#xff1a;入参为字符串和数组类型&#xff1b;通过json字符串传给后台&#xff0c; 后台&#xff1a;后台通过工具解析为Map<String&#xff0c;Object>&#xff0c;然后需要解析出Map里面的数组值做操作&#xff1b; 需求&#xff1a; 入参&…...

别再分库分表了,试试TiDB!

什么是NewSQL 传统SQL的问题 升级服务器硬件 数据分片 NoSQL 的问题 优点 缺点 NewSQL 特性 NewSQL 的主要特性 三种SQL的对比 TiDB怎么来的 TiDB社区版和企业版 TIDB核心特性 水平弹性扩展 分布式事务支持 金融级高可用 实时 HTAP 云原生的分布式数据库 高度兼…...

Java进阶之Dump文件初体验

视频地址&#xff1a;https://www.bilibili.com/video/BV1Ak4y137oh 学习文章&#xff1a;https://d9bp4nr5ye.feishu.cn/wiki/VQoAwlzrXiLFZekuLIyc1uK5nqc 最近线上频繁的内存告警&#xff0c;同事A通过分析dump文件解决了这个问题&#xff0c;我当然是不会放过这种学习的机…...

基于扩展(EKF)和无迹卡尔曼滤波(UKF)的电力系统动态状态估计(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…...

曲线拟合(MATLAB拟合工具箱)位置前馈量计算(压力闭环控制应用)

利用PLC进行压力闭环控制的项目背景介绍请查看下面文章链接,这里不再赘述。 信捷PLC压力闭环控制应用(C语言完整PD、PID源代码)_RXXW_Dor的博客-CSDN博客闭环控制的系列文章,可以查看PID专栏的的系列文章,链接如下:张力控制之速度闭环(速度前馈量计算)_RXXW_Dor的博客-CSD…...

小程序使用echarts

参考文档&#xff1a;echarts官网、echarts-for-weixin 第一步引入组件库&#xff0c;可直接从echarts-for-weixin下载&#xff0c;也可以从echarts官网自定义生成&#xff0c;这里我们就不贴了组件库引入好后&#xff0c;就是页面引用啦&#xff0c;废话不多说&#xff0c;直…...

面向对象——封装

C面向对象的三大特性为&#xff1a;封装、继承、多态 C认为万事万物都皆为对象&#xff0c;对象上有其属性和行为 例如&#xff1a; ​ 人可以作为对象&#xff0c;属性有姓名、年龄、身高、体重…&#xff0c;行为有走、跑、跳、吃饭、唱歌… ​ 车也可以作为对象&#xf…...

【LeetCode】160.相交链表

题目 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点&#xff0c;返回 null 。 图示两个链表在节点 c1 开始相交&#xff1a; 题目数据 保证 整个链式结构中不存在环。 注意&#xff0c;函数返回结…...

【JWT的使用】

文章目录 前言1、用户登录1.1 JWTThreadLocal 2.1 代码实现2.1.1 ThreadLocal工具类2.2.2 定义拦截器2.2.3 注册拦截器 前言 1、用户登录 1.1 JWT JSON Web Token简称JWT&#xff0c;用于对应用程序上用户进行身份验证的标记。使用 JWTS 之后不需要保存用户的 cookie 或其他…...

Python获取音视频时长

Python获取音视频时长 Python获取音视频时长1、安装插件2、获取音视频时长.py3、打包exe4、下载地址 Python获取音视频时长 1、安装插件 pip install moviepy -i https://pypi.tuna.tsinghua.edu.cn/simple2、获取音视频时长.py 上代码&#xff1a;获取音视频时长.py # -*-…...

TCP四次握手为什么客户端等待的时间是2MSL

目录 什么是MSL从第三次握手开始分析总结 什么是MSL MSL是Maximum Segment Lifetime英文的缩写&#xff0c;中文可以译为“报文最大生存时间”&#xff0c;他是任何报文在网络上存在的最长时间&#xff0c;超过这个时间报文将被丢弃。 从第三次握手开始分析 第三次握手服务端…...

ChatGPT网络错误不是运气问题:用mtr追踪真实路径,定位ISP路由黑洞、中间盒QoS限速与WAF误拦截(附15分钟速查表)

更多请点击&#xff1a; https://codechina.net 第一章&#xff1a;ChatGPT网络错误不是运气问题&#xff1a;用mtr追踪真实路径&#xff0c;定位ISP路由黑洞、中间盒QoS限速与WAF误拦截&#xff08;附15分钟速查表&#xff09; ChatGPT连接失败常被归因为“服务器繁忙”或“网…...

量子机器学习噪声挑战与HPQS混合框架解析

1. 量子机器学习中的噪声挑战与HPQS解决方案量子机器学习(QML)作为量子计算与经典机器学习的交叉领域&#xff0c;正在重新定义我们处理复杂模式识别问题的方式。与传统机器学习不同&#xff0c;QML利用量子态的叠加和纠缠特性&#xff0c;理论上可以在某些特定任务上实现指数级…...

别再手动删了!用Notepad++正则表达式5分钟批量清理课程目录(附实战案例)

5分钟极简正则表达式实战&#xff1a;用Notepad智能清洗杂乱课程目录 每次整理网课资源时&#xff0c;最头疼的莫过于面对几十个类似03_Python基础--循环结构实战.mp4这样的文件名。手动一个个删除序号和分类不仅耗时&#xff0c;还容易出错。上周帮同事整理200多份培训视频时&…...

CPU核心存储架构:寄存器文件与SRAM的设计原理与应用对比

1. 项目概述&#xff1a;从“存储”到“访问”的核心差异在处理器设计的核心地带&#xff0c;有两个名字听起来很像、功能也似乎都是“存东西”的组件&#xff0c;却常常让刚入行的朋友感到困惑&#xff1a;Register File&#xff08;寄存器文件&#xff09;和 SRAM&#xff08…...

P6 马铃薯病害识别

&#x1f368; 本文为&#x1f517;365天深度学习训练营中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 个人总结&#xff1a;了解VGG由 5 组卷积池化块堆叠构成&#xff0c;依靠小尺寸卷积核逐层提取图像浅层、深层特征&#xff0c;最后通过全连接层完成分类。&…...

初创团队如何利用 Taotoken Token Plan 有效控制 AI 实验成本

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 初创团队如何利用 Taotoken Token Plan 有效控制 AI 实验成本 对于资源有限的初创团队而言&#xff0c;在产品原型和概念验证阶段&…...

自动化测试的最佳实践:这6个原则让你的测试脚本更稳定

在当前互联网行业快速迭代的开发模式下&#xff0c;自动化测试已经成为保障软件交付质量、提升测试效率的核心手段。据行业调研数据显示&#xff0c;成熟的互联网测试团队中&#xff0c;核心回归测试场景的自动化覆盖率已经超过80%&#xff0c;自动化测试承担了绝大部分重复性测…...

独立开发者如何利用Taotoken同时管理多个AI项目的模型调用

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 独立开发者如何利用Taotoken同时管理多个AI项目的模型调用 对于独立开发者而言&#xff0c;同时维护多个小型产品是常态。每个产品…...

大模型推理优化:激活稀疏性技术解析与实践

1. 大模型推理优化的核心挑战与机遇在自然语言处理领域&#xff0c;大型语言模型&#xff08;LLM&#xff09;的推理效率已成为制约其广泛应用的关键瓶颈。以GPT-3 175B为例&#xff0c;单次推理需要约350GB显存和数千亿次浮点运算&#xff0c;这对硬件资源提出了极高要求。传统…...

Supermask:冻结权重+二值掩码的神经网络子结构发现方法

1. 什么是 Supermasks&#xff1f;——不是“超级面具”&#xff0c;而是神经网络里的“先天直觉” 你有没有试过教一个刚学会走路的孩子认苹果&#xff1f;你不需要从零开始教他光谱分析、细胞结构或者植物分类学&#xff0c;只要拿个红彤彤的苹果在他眼前晃一晃&#xff0c;再…...