当前位置: 首页 > news >正文

获取非叶子节点的grad(retain_grad()、hook)【为了解决grad值是None的问题】

在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是:

  • retain_grad()
  • hook

不过我感觉“hook”比“retain_grad()”要麻烦.....,所以我感觉还是使用“retain_grad()”吧

1、retain_grad()

retain_grad()显式地保存非叶节点的梯度, 代价就是会增加显存的消耗(对比hook函数的方法则是在反向计算时直接打印, 因此不会增加显存消耗.)

使用方法:

直接在forward中对你想要输出gred的tensor“.retain_grad()”即可:tensor.retain_grad()

import torchdef forwrad(x, y, w1, w2):# 其中 x,y 为输入数据,w为该函数所需要的参数z_1 = torch.mm(w1, x)z_1.retain_grad()y_1 = torch.sigmoid(z_1)y_1.retain_grad()z_2 = torch.mm(w2, y_1)z_2.retain_grad()y_2 = torch.sigmoid(z_2)# y_2.retain_grad()loss = 1 / 2 * (((y_2 - y) ** 2).sum())return loss, z_1, y_1, z_2, y_2# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)
# 反向
loss.backward()  # 反向传播,计算梯度print(loss.grad)print(y_2.grad)print(z_2.grad)# 输出结果是否是None,如果是None-->True
def is_none(obj):return obj is None
# 打印出非叶子结点的gred
print(is_none(z_1.grad))
print(is_none(y_2.grad))
print(z_2.grad)

注意:不要对保存梯度的变量做任何修改,例如:z_1, y_1, z_2, y_2,修改为gred_list = [z_1, y_1, z_2, y_2],然后输入梯度值,那是错误的,要直接一个一个输出,不要做任何操作

2、hook的使用

使用retain_grad会消耗额外的显存, 我们可以使用hook在反向计算的时候进行保存. 还是上面的例子, 我们使用hook来完成.

import torch# 我们可以定义一个hook来保存中间的变量
grads = {} # 存储节点名称与节点的grad
def save_grad(name):def hook(grad):grads[name] = gradreturn hookdef forwrad(x, y, w1, w2):# 其中 x,y 为输入数据,w为该函数所需要的参数z_1 = torch.mm(w1, x)y_1 = torch.sigmoid(z_1)z_2 = torch.mm(w2, y_1)y_2 = torch.sigmoid(z_2)loss = 1/2*(((y_2 - y)**2).sum())return loss, z_1, y_1, z_2, y_2# 测试代码
x = torch.tensor([[1.0]])
y = torch.tensor([[1.0], [0.0]])
w1 = torch.tensor([[1.0], [2.0]], requires_grad=True)
w2 = torch.tensor([[3.0, 4.0], [5.0, 6.0]], requires_grad=True)
# 正向传播
loss, z_1, y_1, z_2, y_2 = forwrad(x, y, w1, w2)# hook中间节点
z_1.register_hook(save_grad('z_1'))
y_1.register_hook(save_grad('y_1'))
z_2.register_hook(save_grad('z_2'))
y_2.register_hook(save_grad('y_2'))# 反向传播
loss.backward()
print(grads['z_1'])
print(grads['y_1'])
print(grads['z_2'])
print(grads['y_2'])

https://www.cnblogs.com/dxscode/p/16146470.html

pytorch | loss不收敛或者训练中梯度grad为None的问题_pytorch梯度为none_Rilkean heart的博客-CSDN博客

相关文章:

获取非叶子节点的grad(retain_grad()、hook)【为了解决grad值是None的问题】

在调试过程中, 有时候我们需要对中间变量梯度进行监控, 以确保网络的有效性, 这个时候我们需要打印出非叶节点的梯度, 为了实现这个目的, 我们可以通过两种手段进行, 分别是: retain_grad()hook 不过我感觉“hook”比“retain_grad()”要麻烦.....,所以我感觉还是…...

JMeter(八):响应断言详解

响应断言 :对服务器的响应进行断言校验 (1)应用范围: main sample and sub sample, main sample only , sub-sample only , jmeter variable 关于应用范围,我们大多数勾选“main sample only” 就足够了,因为我们一个请求,实质上只有一个请求。但是当我们发一个请求时,…...

【网络编程】IO复用的应用一:非阻塞connect

在connect连接中,若socket以非阻塞的方式进行连接,则系统内设置的TCP三次握手超时时间为0,所以它不会等待TCP三次握手完成,直接返回,错误为EINPROGRESS。   所以,我们可以通过判断connect时返回的错误码是…...

Spring注解开发,bean的作用范围及生命周期、Spring注解开发依赖注入

🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 c语言 数据结构 javaweb 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 Spring注解开发 一、注解开发定义Bean二、纯注解开发Bean三…...

C#设计模式之---原型模式

原型模式(Prototype Pattern) 原型模式(Prototype Pattern) 是用原型实例指定创建对象的种类,并且通过拷贝这些原型创建新的对象。原型模式是一种创建型设计模式。也就是用一个已经创建的实例作为原型,通过…...

STM32入门学习之外部中断

1.STM32的IO口可以作为外部中断输入口。本文通过按键按下作为外部中断的输入,点亮LED灯。在STM32的19个外部中断中,0-15为外部IO口的中断输入口。STM32的引脚分别对应着0-15的外部中断线。比如,外部中断线0对应着GPIOA.0-GPIOG.0,…...

Jenkins 配置maven和jdk

前提:服务器已经安装maven和jdk 一、在Jenkins中添加全局变量 系统管理–>系统配置–>全局属性–>环境变量 添加三个全局变量 JAVA_HOME、MAVEN_HOME、PATH 二、配置maven 系统管理–>全局工具配置–>maven–>新增 新增配置 三、配置JDK 在系统管…...

Leetcode | Binary search | 22. 74. 162. 33. 34. 153.

22. Generate Parentheses 要意识到只要还有左括号,就可以放到path里。只要右括号数量小于左括号,也可以放进去。就是valid的组合。recurse两次 74. Search a 2D Matrix 看成sorted list就好。直接用m*n表示最后一位的index,并且每次只需要 …...

生命在于折腾——面试问题汇总

这里面的问题都是我参加面试时候遇到的问题,大家就这样看吧。 一、个人情况 1、自我介绍 2、为什么离开上一家公司 3、有没有参加过HVV 4、介绍一下上家公司的项目 5、小程序和公众号渗透测试做过么 6、实习工资多少 7、有挖过漏洞么 二、基础知识 1、信息收集的…...

<Java>Map<String,Object>中解析Object类型数据为数组格式

背景&#xff1a; 前端&#xff1a;入参为字符串和数组类型&#xff1b;通过json字符串传给后台&#xff0c; 后台&#xff1a;后台通过工具解析为Map<String&#xff0c;Object>&#xff0c;然后需要解析出Map里面的数组值做操作&#xff1b; 需求&#xff1a; 入参&…...

别再分库分表了,试试TiDB!

什么是NewSQL 传统SQL的问题 升级服务器硬件 数据分片 NoSQL 的问题 优点 缺点 NewSQL 特性 NewSQL 的主要特性 三种SQL的对比 TiDB怎么来的 TiDB社区版和企业版 TIDB核心特性 水平弹性扩展 分布式事务支持 金融级高可用 实时 HTAP 云原生的分布式数据库 高度兼…...

Java进阶之Dump文件初体验

视频地址&#xff1a;https://www.bilibili.com/video/BV1Ak4y137oh 学习文章&#xff1a;https://d9bp4nr5ye.feishu.cn/wiki/VQoAwlzrXiLFZekuLIyc1uK5nqc 最近线上频繁的内存告警&#xff0c;同事A通过分析dump文件解决了这个问题&#xff0c;我当然是不会放过这种学习的机…...

基于扩展(EKF)和无迹卡尔曼滤波(UKF)的电力系统动态状态估计(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…...

曲线拟合(MATLAB拟合工具箱)位置前馈量计算(压力闭环控制应用)

利用PLC进行压力闭环控制的项目背景介绍请查看下面文章链接,这里不再赘述。 信捷PLC压力闭环控制应用(C语言完整PD、PID源代码)_RXXW_Dor的博客-CSDN博客闭环控制的系列文章,可以查看PID专栏的的系列文章,链接如下:张力控制之速度闭环(速度前馈量计算)_RXXW_Dor的博客-CSD…...

小程序使用echarts

参考文档&#xff1a;echarts官网、echarts-for-weixin 第一步引入组件库&#xff0c;可直接从echarts-for-weixin下载&#xff0c;也可以从echarts官网自定义生成&#xff0c;这里我们就不贴了组件库引入好后&#xff0c;就是页面引用啦&#xff0c;废话不多说&#xff0c;直…...

面向对象——封装

C面向对象的三大特性为&#xff1a;封装、继承、多态 C认为万事万物都皆为对象&#xff0c;对象上有其属性和行为 例如&#xff1a; ​ 人可以作为对象&#xff0c;属性有姓名、年龄、身高、体重…&#xff0c;行为有走、跑、跳、吃饭、唱歌… ​ 车也可以作为对象&#xf…...

【LeetCode】160.相交链表

题目 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点&#xff0c;返回 null 。 图示两个链表在节点 c1 开始相交&#xff1a; 题目数据 保证 整个链式结构中不存在环。 注意&#xff0c;函数返回结…...

【JWT的使用】

文章目录 前言1、用户登录1.1 JWTThreadLocal 2.1 代码实现2.1.1 ThreadLocal工具类2.2.2 定义拦截器2.2.3 注册拦截器 前言 1、用户登录 1.1 JWT JSON Web Token简称JWT&#xff0c;用于对应用程序上用户进行身份验证的标记。使用 JWTS 之后不需要保存用户的 cookie 或其他…...

Python获取音视频时长

Python获取音视频时长 Python获取音视频时长1、安装插件2、获取音视频时长.py3、打包exe4、下载地址 Python获取音视频时长 1、安装插件 pip install moviepy -i https://pypi.tuna.tsinghua.edu.cn/simple2、获取音视频时长.py 上代码&#xff1a;获取音视频时长.py # -*-…...

TCP四次握手为什么客户端等待的时间是2MSL

目录 什么是MSL从第三次握手开始分析总结 什么是MSL MSL是Maximum Segment Lifetime英文的缩写&#xff0c;中文可以译为“报文最大生存时间”&#xff0c;他是任何报文在网络上存在的最长时间&#xff0c;超过这个时间报文将被丢弃。 从第三次握手开始分析 第三次握手服务端…...

RAG系统的需求分析

这个是一个基于私有知识库的智能对话平台&#xff0c;允许用户上传文档构建专属知识库&#xff0c;并通过自然语言交互的方式查询和获取知识。它结合了大语言模型和向量检索技术&#xff0c;让用户通过对话的形式与自己的知识库进行高效交互应用场景个人用户场景:学习助手&…...

深度解析:数据仓库——定义、核心架构与企业核心价值

深度解析&#xff1a;数据仓库——定义、核心架构与企业核心价值一、引言二、定义&#xff1a;什么是数据仓库&#xff1f;2.1 标准定义2.2 核心四大特征&#xff08;数据仓库基石&#xff09;三、架构流程&#xff1a;数据仓库的标准工作流程&#xff08;带流程图&#xff09;…...

AI正冲击金融岗!高薪职业如何守住饭碗?金融人转行AI指南

AI技术正全面冲击金融行业&#xff0c;初级分析师、风控专员、客服等中低端认知劳动密集型岗位面临被替代风险。但高端投行、深度研究、资源型和创新型岗位短期内仍安全。金融人转型AI有独特优势&#xff0c;如数据敏感性、业务理解力等。转型路径包括AI应用专家、金融科技产品…...

大麦抢票自动化工具:技术赋能下的抢票效率革命

大麦抢票自动化工具&#xff1a;技术赋能下的抢票效率革命 【免费下载链接】DamaiHelper 大麦网演唱会演出抢票脚本。 项目地址: https://gitcode.com/gh_mirrors/dama/DamaiHelper 在热门演出门票抢购场景中&#xff0c;用户常常面临手动操作反应迟缓、重复劳动效率低下…...

微服务架构的陷阱:我们是如何从拆分成“微”麻烦的

对于软件测试从业者而言&#xff0c;微服务架构的兴起既带来了前所未有的挑战&#xff0c;也揭示了隐藏在水面之下的诸多陷阱。从单体应用向微服务转型&#xff0c;初衷是为了提升系统的灵活性、可维护性和团队的交付效率。然而&#xff0c;在实践中&#xff0c;许多团队却发现…...

源代码之下的硅基启示录——Claude Code“核泄漏”事件的深度剖析与时代回响

引言 公元2026年3月30日&#xff0c;一个看似平常的春日&#xff0c;硅基世界却迎来了一场史无前例的地震。 一家以“安全”为最高信条的AI公司&#xff0c;以一种最荒诞的方式&#xff0c;亲手打开了潘多拉的魔盒。Anthropic&#xff0c;这家估值高达3800亿美元的AI新贵&#…...

YOLO12快速部署指南:Gradio界面已配好,启动就能用

YOLO12快速部署指南&#xff1a;Gradio界面已配好&#xff0c;启动就能用 1. 为什么选择YOLO12镜像 YOLO12作为2025年最新发布的目标检测模型&#xff0c;带来了革命性的注意力为中心架构。这个预配置好的镜像让您无需任何复杂操作&#xff0c;就能立即体验最先进的目标检测技…...

零基础玩转mxbai-embed-large-v1:6大核心功能实战,从向量化到摘要生成

零基础玩转mxbai-embed-large-v1&#xff1a;6大核心功能实战&#xff0c;从向量化到摘要生成 1. 引言&#xff1a;为什么选择mxbai-embed-large-v1&#xff1f; mxbai-embed-large-v1是当前自然语言处理领域的一颗新星&#xff0c;这款多功能句子嵌入模型在MTEB基准测试中表…...

新手入门福音:用快马AI生成你的第一个Python版游戏账号管理工具

作为一个刚接触Python编程的新手&#xff0c;最近想尝试开发一个简单的游戏账号管理工具。这个需求其实挺常见的&#xff0c;比如我平时玩多个游戏&#xff0c;账号密码经常记混&#xff0c;如果能有个小工具统一管理就方便多了。在朋友的推荐下&#xff0c;我尝试用InsCode(快…...

APDS9960手势传感器驱动开发与嵌入式实战

1. APDS9960手势传感器库技术解析与嵌入式工程实践APDS9960是一款由Broadcom&#xff08;原Avago&#xff09;推出的集成环境光、颜色、接近度及手势识别功能的多模态光学传感器芯片。其核心价值在于将传统分立式光感方案&#xff08;如独立ALSProximityGesture模块&#xff09…...