获取非叶子节点的grad(retain_grad()、hook)【为了解决grad值是None的问题】
在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是:
- retain_grad()
- hook
不过我感觉“hook”比“retain_grad()”要麻烦.....,所以我感觉还是使用“retain_grad()”吧
1、retain_grad()
retain_grad()显式地保存非叶节点的梯度, 代价就是会增加显存的消耗(对比hook函数的方法则是在反向计算时直接打印, 因此不会增加显存消耗.)
使用方法:
直接在forward中对你想要输出gred的tensor“.retain_grad()”即可:tensor.retain_grad()
import torchdef forwrad(x, y, w1, w2):# 其中 x,y 为输入数据,w为该函数所需要的参数z_1 = torch.mm(w1, x)z_1.retain_grad()y_1 = torch.sigmoid(z_1)y_1.retain_grad()z_2 = torch.mm(w2, y_1)z_2.retain_grad()y_2 = torch.sigmoid(z_2)# y_2.retain_grad()loss = 1 / 2 * (((y_2 - y) ** 2).sum())return loss, z_1, y_1, z_2, y_2# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
# 反向
loss.backward() # 反向传播,计算梯度print(loss.grad)print(y_2.grad)print(z_2.grad)# 输出结果是否是None,如果是None-->True
def is_none(obj):return obj is None
# 打印出非叶子结点的gred
print(is_none(z_1.grad))
print(is_none(y_2.grad))
print(z_2.grad)
注意:不要对保存梯度的变量做任何修改,例如:z_1, y_1, z_2, y_2,修改为gred_list = [z_1, y_1, z_2, y_2],然后输入梯度值,那是错误的,要直接一个一个输出,不要做任何操作
2、hook的使用
使用retain_grad会消耗额外的显存, 我们可以使用hook在反向计算的时候进行保存. 还是上面的例子, 我们使用hook来完成.
import torch# 我们可以定义一个hook来保存中间的变量
grads = {} # 存储节点名称与节点的grad
def save_grad(name):def hook(grad):grads[name] = gradreturn hookdef forwrad(x, y, w1, w2):# 其中 x,y 为输入数据,w为该函数所需要的参数z_1 = torch.mm(w1, x)y_1 = torch.sigmoid(z_1)z_2 = torch.mm(w2, y_1)y_2 = torch.sigmoid(z_2)loss = 1/2*(((y_2 - y)**2).sum())return loss, z_1, y_1, z_2, y_2# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向传播
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)# hook中间节点
z_1.register_hook(save_grad('z_1'))
y_1.register_hook(save_grad('y_1'))
z_2.register_hook(save_grad('z_2'))
y_2.register_hook(save_grad('y_2'))# 反向传播
loss.backward()
print(grads['z_1'])
print(grads['y_1'])
print(grads['z_2'])
print(grads['y_2'])
https://www.cnblogs.com/dxscode/p/16146470.html
pytorch | loss不收敛或者训练中梯度grad为None的问题_pytorch梯度为none_Rilkean heart的博客-CSDN博客
相关文章:
获取非叶子节点的grad(retain_grad()、hook)【为了解决grad值是None的问题】
在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是: retain_grad()hook 不过我感觉“hook”比“retain_grad()”要麻烦.....,所以我感觉还是…...

JMeter(八):响应断言详解
响应断言 :对服务器的响应进行断言校验 (1)应用范围: main sample and sub sample, main sample only , sub-sample only , jmeter variable 关于应用范围,我们大多数勾选“main sample only” 就足够了,因为我们一个请求,实质上只有一个请求。但是当我们发一个请求时,…...
【网络编程】IO复用的应用一:非阻塞connect
在connect连接中,若socket以非阻塞的方式进行连接,则系统内设置的TCP三次握手超时时间为0,所以它不会等待TCP三次握手完成,直接返回,错误为EINPROGRESS。 所以,我们可以通过判断connect时返回的错误码是…...

Spring注解开发,bean的作用范围及生命周期、Spring注解开发依赖注入
🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 c语言 数据结构 javaweb 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 Spring注解开发 一、注解开发定义Bean二、纯注解开发Bean三…...
C#设计模式之---原型模式
原型模式(Prototype Pattern) 原型模式(Prototype Pattern) 是用原型实例指定创建对象的种类,并且通过拷贝这些原型创建新的对象。原型模式是一种创建型设计模式。也就是用一个已经创建的实例作为原型,通过…...

STM32入门学习之外部中断
1.STM32的IO口可以作为外部中断输入口。本文通过按键按下作为外部中断的输入,点亮LED灯。在STM32的19个外部中断中,0-15为外部IO口的中断输入口。STM32的引脚分别对应着0-15的外部中断线。比如,外部中断线0对应着GPIOA.0-GPIOG.0,…...

Jenkins 配置maven和jdk
前提:服务器已经安装maven和jdk 一、在Jenkins中添加全局变量 系统管理–>系统配置–>全局属性–>环境变量 添加三个全局变量 JAVA_HOME、MAVEN_HOME、PATH 二、配置maven 系统管理–>全局工具配置–>maven–>新增 新增配置 三、配置JDK 在系统管…...

Leetcode | Binary search | 22. 74. 162. 33. 34. 153.
22. Generate Parentheses 要意识到只要还有左括号,就可以放到path里。只要右括号数量小于左括号,也可以放进去。就是valid的组合。recurse两次 74. Search a 2D Matrix 看成sorted list就好。直接用m*n表示最后一位的index,并且每次只需要 …...
生命在于折腾——面试问题汇总
这里面的问题都是我参加面试时候遇到的问题,大家就这样看吧。 一、个人情况 1、自我介绍 2、为什么离开上一家公司 3、有没有参加过HVV 4、介绍一下上家公司的项目 5、小程序和公众号渗透测试做过么 6、实习工资多少 7、有挖过漏洞么 二、基础知识 1、信息收集的…...

<Java>Map<String,Object>中解析Object类型数据为数组格式
背景: 前端:入参为字符串和数组类型;通过json字符串传给后台, 后台:后台通过工具解析为Map<String,Object>,然后需要解析出Map里面的数组值做操作; 需求: 入参&…...

别再分库分表了,试试TiDB!
什么是NewSQL 传统SQL的问题 升级服务器硬件 数据分片 NoSQL 的问题 优点 缺点 NewSQL 特性 NewSQL 的主要特性 三种SQL的对比 TiDB怎么来的 TiDB社区版和企业版 TIDB核心特性 水平弹性扩展 分布式事务支持 金融级高可用 实时 HTAP 云原生的分布式数据库 高度兼…...

Java进阶之Dump文件初体验
视频地址:https://www.bilibili.com/video/BV1Ak4y137oh 学习文章:https://d9bp4nr5ye.feishu.cn/wiki/VQoAwlzrXiLFZekuLIyc1uK5nqc 最近线上频繁的内存告警,同事A通过分析dump文件解决了这个问题,我当然是不会放过这种学习的机…...

基于扩展(EKF)和无迹卡尔曼滤波(UKF)的电力系统动态状态估计(Matlab代码实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...
曲线拟合(MATLAB拟合工具箱)位置前馈量计算(压力闭环控制应用)
利用PLC进行压力闭环控制的项目背景介绍请查看下面文章链接,这里不再赘述。 信捷PLC压力闭环控制应用(C语言完整PD、PID源代码)_RXXW_Dor的博客-CSDN博客闭环控制的系列文章,可以查看PID专栏的的系列文章,链接如下:张力控制之速度闭环(速度前馈量计算)_RXXW_Dor的博客-CSD…...
小程序使用echarts
参考文档:echarts官网、echarts-for-weixin 第一步引入组件库,可直接从echarts-for-weixin下载,也可以从echarts官网自定义生成,这里我们就不贴了组件库引入好后,就是页面引用啦,废话不多说,直…...
面向对象——封装
C面向对象的三大特性为:封装、继承、多态 C认为万事万物都皆为对象,对象上有其属性和行为 例如: 人可以作为对象,属性有姓名、年龄、身高、体重…,行为有走、跑、跳、吃饭、唱歌… 车也可以作为对象…...

【LeetCode】160.相交链表
题目 给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点,返回 null 。 图示两个链表在节点 c1 开始相交: 题目数据 保证 整个链式结构中不存在环。 注意,函数返回结…...
【JWT的使用】
文章目录 前言1、用户登录1.1 JWTThreadLocal 2.1 代码实现2.1.1 ThreadLocal工具类2.2.2 定义拦截器2.2.3 注册拦截器 前言 1、用户登录 1.1 JWT JSON Web Token简称JWT,用于对应用程序上用户进行身份验证的标记。使用 JWTS 之后不需要保存用户的 cookie 或其他…...
Python获取音视频时长
Python获取音视频时长 Python获取音视频时长1、安装插件2、获取音视频时长.py3、打包exe4、下载地址 Python获取音视频时长 1、安装插件 pip install moviepy -i https://pypi.tuna.tsinghua.edu.cn/simple2、获取音视频时长.py 上代码:获取音视频时长.py # -*-…...
TCP四次握手为什么客户端等待的时间是2MSL
目录 什么是MSL从第三次握手开始分析总结 什么是MSL MSL是Maximum Segment Lifetime英文的缩写,中文可以译为“报文最大生存时间”,他是任何报文在网络上存在的最长时间,超过这个时间报文将被丢弃。 从第三次握手开始分析 第三次握手服务端…...
KubeSphere 容器平台高可用:环境搭建与可视化操作指南
Linux_k8s篇 欢迎来到Linux的世界,看笔记好好学多敲多打,每个人都是大神! 题目:KubeSphere 容器平台高可用:环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

SpringBoot-17-MyBatis动态SQL标签之常用标签
文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...
【磁盘】每天掌握一个Linux命令 - iostat
目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...

Python爬虫(一):爬虫伪装
一、网站防爬机制概述 在当今互联网环境中,具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类: 身份验证机制:直接将未经授权的爬虫阻挡在外反爬技术体系:通过各种技术手段增加爬虫获取数据的难度…...

MySQL 8.0 OCP 英文题库解析(十三)
Oracle 为庆祝 MySQL 30 周年,截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始,将英文题库免费公布出来,并进行解析,帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...
用鸿蒙HarmonyOS5实现中国象棋小游戏的过程
下面是一个基于鸿蒙OS (HarmonyOS) 的中国象棋小游戏的实现代码。这个实现使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chinesechess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├──…...

C++_哈希表
本篇文章是对C学习的哈希表部分的学习分享 相信一定会对你有所帮助~ 那咱们废话不多说,直接开始吧! 一、基础概念 1. 哈希核心思想: 哈希函数的作用:通过此函数建立一个Key与存储位置之间的映射关系。理想目标:实现…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现指南针功能
指南针功能是许多位置服务应用的基础功能之一。下面我将详细介绍如何在HarmonyOS 5中使用DevEco Studio实现指南针功能。 1. 开发环境准备 确保已安装DevEco Studio 3.1或更高版本确保项目使用的是HarmonyOS 5.0 SDK在项目的module.json5中配置必要的权限 2. 权限配置 在mo…...

goreplay
1.github地址 https://github.com/buger/goreplay 2.简单介绍 GoReplay 是一个开源的网络监控工具,可以记录用户的实时流量并将其用于镜像、负载测试、监控和详细分析。 3.出现背景 随着应用程序的增长,测试它所需的工作量也会呈指数级增长。GoRepl…...

使用VMware克隆功能快速搭建集群
自己搭建的虚拟机,后续不管是学习java还是大数据,都需要集群,java需要分布式的微服务,大数据Hadoop的计算集群,如果从头开始搭建虚拟机会比较费时费力,这里分享一下如何使用克隆功能快速搭建一个集群 先把…...