当前位置: 首页 > news >正文

获取非叶子节点的grad(retain_grad()、hook)【为了解决grad值是None的问题】

在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是:

  • retain_grad()
  • hook

不过我感觉“hook”比“retain_grad()”要麻烦.....,所以我感觉还是使用“retain_grad()”吧

1、retain_grad()

retain_grad()显式地保存非叶节点的梯度, 代价就是会增加显存的消耗(对比hook函数的方法则是在反向计算时直接打印, 因此不会增加显存消耗.)

使用方法:

直接在forward中对你想要输出gred的tensor“.retain_grad()”即可:tensor.retain_grad()

import torchdef forwrad(x, y, w1, w2):# 其中 x,y 为输入数据,w为该函数所需要的参数z_1 = torch.mm(w1, x)z_1.retain_grad()y_1 = torch.sigmoid(z_1)y_1.retain_grad()z_2 = torch.mm(w2, y_1)z_2.retain_grad()y_2 = torch.sigmoid(z_2)# y_2.retain_grad()loss = 1 / 2 * (((y_2 - y) ** 2).sum())return loss, z_1, y_1, z_2, y_2# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
# 反向
loss.backward()  # 反向传播,计算梯度print(loss.grad)print(y_2.grad)print(z_2.grad)# 输出结果是否是None,如果是None-->True
def is_none(obj):return obj is None
# 打印出非叶子结点的gred
print(is_none(z_1.grad))
print(is_none(y_2.grad))
print(z_2.grad)

注意:不要对保存梯度的变量做任何修改,例如:z_1, y_1, z_2, y_2,修改为gred_list = [z_1, y_1, z_2, y_2],然后输入梯度值,那是错误的,要直接一个一个输出,不要做任何操作

2、hook的使用

使用retain_grad会消耗额外的显存, 我们可以使用hook在反向计算的时候进行保存. 还是上面的例子, 我们使用hook来完成.

import torch# 我们可以定义一个hook来保存中间的变量
grads = {} # 存储节点名称与节点的grad
def save_grad(name):def hook(grad):grads[name] = gradreturn hookdef forwrad(x, y, w1, w2):# 其中 x,y 为输入数据,w为该函数所需要的参数z_1 = torch.mm(w1, x)y_1 = torch.sigmoid(z_1)z_2 = torch.mm(w2, y_1)y_2 = torch.sigmoid(z_2)loss = 1/2*(((y_2 - y)**2).sum())return loss, z_1, y_1, z_2, y_2# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向传播
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)# hook中间节点
z_1.register_hook(save_grad('z_1'))
y_1.register_hook(save_grad('y_1'))
z_2.register_hook(save_grad('z_2'))
y_2.register_hook(save_grad('y_2'))# 反向传播
loss.backward()
print(grads['z_1'])
print(grads['y_1'])
print(grads['z_2'])
print(grads['y_2'])

https://www.cnblogs.com/dxscode/p/16146470.html

pytorch | loss不收敛或者训练中梯度grad为None的问题_pytorch梯度为none_Rilkean heart的博客-CSDN博客

相关文章:

获取非叶子节点的grad(retain_grad()、hook)【为了解决grad值是None的问题】

在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是: retain_grad()hook 不过我感觉“hook”比“retain_grad()”要麻烦.....,所以我感觉还是…...

JMeter(八):响应断言详解

响应断言 :对服务器的响应进行断言校验 (1)应用范围: main sample and sub sample, main sample only , sub-sample only , jmeter variable 关于应用范围,我们大多数勾选“main sample only” 就足够了,因为我们一个请求,实质上只有一个请求。但是当我们发一个请求时,…...

【网络编程】IO复用的应用一:非阻塞connect

在connect连接中,若socket以非阻塞的方式进行连接,则系统内设置的TCP三次握手超时时间为0,所以它不会等待TCP三次握手完成,直接返回,错误为EINPROGRESS。   所以,我们可以通过判断connect时返回的错误码是…...

Spring注解开发,bean的作用范围及生命周期、Spring注解开发依赖注入

🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 c语言 数据结构 javaweb 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 Spring注解开发 一、注解开发定义Bean二、纯注解开发Bean三…...

C#设计模式之---原型模式

原型模式(Prototype Pattern) 原型模式(Prototype Pattern) 是用原型实例指定创建对象的种类,并且通过拷贝这些原型创建新的对象。原型模式是一种创建型设计模式。也就是用一个已经创建的实例作为原型,通过…...

STM32入门学习之外部中断

1.STM32的IO口可以作为外部中断输入口。本文通过按键按下作为外部中断的输入,点亮LED灯。在STM32的19个外部中断中,0-15为外部IO口的中断输入口。STM32的引脚分别对应着0-15的外部中断线。比如,外部中断线0对应着GPIOA.0-GPIOG.0,…...

Jenkins 配置maven和jdk

前提:服务器已经安装maven和jdk 一、在Jenkins中添加全局变量 系统管理–>系统配置–>全局属性–>环境变量 添加三个全局变量 JAVA_HOME、MAVEN_HOME、PATH 二、配置maven 系统管理–>全局工具配置–>maven–>新增 新增配置 三、配置JDK 在系统管…...

Leetcode | Binary search | 22. 74. 162. 33. 34. 153.

22. Generate Parentheses 要意识到只要还有左括号,就可以放到path里。只要右括号数量小于左括号,也可以放进去。就是valid的组合。recurse两次 74. Search a 2D Matrix 看成sorted list就好。直接用m*n表示最后一位的index,并且每次只需要 …...

生命在于折腾——面试问题汇总

这里面的问题都是我参加面试时候遇到的问题,大家就这样看吧。 一、个人情况 1、自我介绍 2、为什么离开上一家公司 3、有没有参加过HVV 4、介绍一下上家公司的项目 5、小程序和公众号渗透测试做过么 6、实习工资多少 7、有挖过漏洞么 二、基础知识 1、信息收集的…...

<Java>Map<String,Object>中解析Object类型数据为数组格式

背景&#xff1a; 前端&#xff1a;入参为字符串和数组类型&#xff1b;通过json字符串传给后台&#xff0c; 后台&#xff1a;后台通过工具解析为Map<String&#xff0c;Object>&#xff0c;然后需要解析出Map里面的数组值做操作&#xff1b; 需求&#xff1a; 入参&…...

别再分库分表了,试试TiDB!

什么是NewSQL 传统SQL的问题 升级服务器硬件 数据分片 NoSQL 的问题 优点 缺点 NewSQL 特性 NewSQL 的主要特性 三种SQL的对比 TiDB怎么来的 TiDB社区版和企业版 TIDB核心特性 水平弹性扩展 分布式事务支持 金融级高可用 实时 HTAP 云原生的分布式数据库 高度兼…...

Java进阶之Dump文件初体验

视频地址&#xff1a;https://www.bilibili.com/video/BV1Ak4y137oh 学习文章&#xff1a;https://d9bp4nr5ye.feishu.cn/wiki/VQoAwlzrXiLFZekuLIyc1uK5nqc 最近线上频繁的内存告警&#xff0c;同事A通过分析dump文件解决了这个问题&#xff0c;我当然是不会放过这种学习的机…...

基于扩展(EKF)和无迹卡尔曼滤波(UKF)的电力系统动态状态估计(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…...

曲线拟合(MATLAB拟合工具箱)位置前馈量计算(压力闭环控制应用)

利用PLC进行压力闭环控制的项目背景介绍请查看下面文章链接,这里不再赘述。 信捷PLC压力闭环控制应用(C语言完整PD、PID源代码)_RXXW_Dor的博客-CSDN博客闭环控制的系列文章,可以查看PID专栏的的系列文章,链接如下:张力控制之速度闭环(速度前馈量计算)_RXXW_Dor的博客-CSD…...

小程序使用echarts

参考文档&#xff1a;echarts官网、echarts-for-weixin 第一步引入组件库&#xff0c;可直接从echarts-for-weixin下载&#xff0c;也可以从echarts官网自定义生成&#xff0c;这里我们就不贴了组件库引入好后&#xff0c;就是页面引用啦&#xff0c;废话不多说&#xff0c;直…...

面向对象——封装

C面向对象的三大特性为&#xff1a;封装、继承、多态 C认为万事万物都皆为对象&#xff0c;对象上有其属性和行为 例如&#xff1a; ​ 人可以作为对象&#xff0c;属性有姓名、年龄、身高、体重…&#xff0c;行为有走、跑、跳、吃饭、唱歌… ​ 车也可以作为对象&#xf…...

【LeetCode】160.相交链表

题目 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点&#xff0c;返回 null 。 图示两个链表在节点 c1 开始相交&#xff1a; 题目数据 保证 整个链式结构中不存在环。 注意&#xff0c;函数返回结…...

【JWT的使用】

文章目录 前言1、用户登录1.1 JWTThreadLocal 2.1 代码实现2.1.1 ThreadLocal工具类2.2.2 定义拦截器2.2.3 注册拦截器 前言 1、用户登录 1.1 JWT JSON Web Token简称JWT&#xff0c;用于对应用程序上用户进行身份验证的标记。使用 JWTS 之后不需要保存用户的 cookie 或其他…...

Python获取音视频时长

Python获取音视频时长 Python获取音视频时长1、安装插件2、获取音视频时长.py3、打包exe4、下载地址 Python获取音视频时长 1、安装插件 pip install moviepy -i https://pypi.tuna.tsinghua.edu.cn/simple2、获取音视频时长.py 上代码&#xff1a;获取音视频时长.py # -*-…...

TCP四次握手为什么客户端等待的时间是2MSL

目录 什么是MSL从第三次握手开始分析总结 什么是MSL MSL是Maximum Segment Lifetime英文的缩写&#xff0c;中文可以译为“报文最大生存时间”&#xff0c;他是任何报文在网络上存在的最长时间&#xff0c;超过这个时间报文将被丢弃。 从第三次握手开始分析 第三次握手服务端…...

Langgraph实战--自定义embeding

概述 在Langgraph中我想使用第三方的embeding接口来实现文本的embeding。但目前langchain只提供了两个类&#xff0c;一个是AzureOpenAIEmbeddings&#xff0c;一个是&#xff1a;OpenAIEmbeddings。通过ChatOpenAI无法使用第三方的接口&#xff0c;例如&#xff1a;硅基流平台…...

2023年ASOC SCI2区TOP,随机跟随蚁群优化算法RFACO,深度解析+性能实测

目录 1.摘要2.连续蚁群优化算法ACOR3.随机跟随策略4.结果展示5.参考文献6.代码获取7.算法辅导应用定制读者交流 1.摘要 连续蚁群优化是一种基于群体的启发式搜索算法&#xff08;ACOR&#xff09;&#xff0c;其灵感来源于蚁群的路径寻找行为&#xff0c;具有结构简单、控制参…...

限流算法java实现

参考教程&#xff1a;2小时吃透4种分布式限流算法 1.计数器限流 public class CounterLimiter {// 开始时间private static long startTime System.currentTimeMillis();// 时间间隔&#xff0c;单位为msprivate long interval 1000L;// 限制访问次数private int limitCount…...

锁的艺术:深入浅出讲解乐观锁与悲观锁

在多线程和分布式系统中&#xff0c;数据一致性是一个核心问题。锁机制作为解决并发冲突的重要手段&#xff0c;被广泛应用于各种场景。乐观锁和悲观锁是两种常见的锁策略&#xff0c;它们在设计理念、实现方式和适用场景上各有特点。本文将深入探讨乐观锁和悲观锁的原理、实现…...

使用python实现奔跑的线条效果

效果&#xff0c;展示&#xff08;视频效果展示&#xff09;&#xff1a; 奔跑的线条 from turtle import * import time t1Turtle() t2Turtle() t3Turtle() t1.hideturtle() t2.hideturtle() t3.hideturtle() t1.pencolor("red") t2.pencolor("green") t3…...

CSS6404L 在物联网设备中的应用优势:低功耗高可靠的存储革新与竞品对比

物联网设备对存储芯片的需求聚焦于低功耗、小尺寸、高可靠性与传输效率&#xff0c;Cascadeteq 的 CSS6404L 64Mb Quad-SPI Pseudo-SRAM 凭借差异化技术特性&#xff0c;在同类产品中展现显著优势。以下从核心特性及竞品对比两方面解析其应用价值。 一、CSS6404L 核心产品特性…...

C语言中的数据类型(二)--结构体

在之前我们已经探讨了C语言中的自定义数据类型和数组&#xff0c;链接如下&#xff1a;C语言中的数据类型&#xff08;上&#xff09;_c语言数据类型-CSDN博客 目录 一、结构体的声明 二、结构体变量的定义和初始化 三、结构体成员的访问 3.1 结构体成员的直接访问 3.2 结…...

Server2003 B-1 Windows操作系统渗透

任务环境说明&#xff1a; 服务器场景&#xff1a;Server2003&#xff08;开放链接&#xff09; 服务器场景操作系统&#xff1a;Windows7 1.通过本地PC中渗透测试平台Kali对服务器场景Windows进行系统服务及版本扫描渗透测试&#xff0c;并将该操作显示结果中Telnet服务对应的…...

机器学习复习3--模型的选择

选择合适的机器学习模型是机器学习项目成功的关键一步。这通常不是一个一蹴而就的过程&#xff0c;而是需要综合考虑多个因素&#xff0c;并进行实验和评估。 1. 理解问题本质 这是模型选择的首要步骤。需要清晰地定义试图解决的问题类型&#xff1a; 监督学习 : 数据集包含…...

算法-多条件排序

1、数对排序的使用 pair<ll,ll> a[31];//cmp为比较规则 ll cmp(pair<ll,ll>a,pair<ll,ll>b){if(a.first!b.first)return a.first>b.first;else return a.second<b.second; }//按照比较规则进行排序 sort(a1,a31,cmp); 2、具体例题 输入样例&#xff1…...