当前位置: 首页 > news >正文

pytorch中的hook机制register_forward_hook

上篇文章主要介绍了hook钩子函数的大致使用流程,本篇文章主要介绍pytorch中的hook机制register_forward_hook,手动在forward之前注册hook,hook在forward执行以后被自动执行。

1、hook背景
Hook被成为钩子机制,pytorch中包含forward和backward两个钩子注册函数,用于获取forward和backward中输入和输出,按照自己不全面的理解,应该目的是“不改变网络的定义代码,也不需要在forward函数中return某个感兴趣层的输出,这样代码太冗杂了”。

2、源码阅读
register_forward_hook()函数必须在forward()函数调用之前被使用,因为该函数源码注释显示这个函数“ it will not have effect on forward since this is called after :func:forward is called”,也就是这个函数在forward()之后就没有作用了!):
作用:获取forward过程中每层的输入和输出,用于对比hook是不是正确记录。

def register_forward_hook(self, hook):r"""Registers a forward hook on the module.The hook will be called every time after :func:`forward` has computed an output.It should have the following signature::hook(module, input, output) -> None or modified outputThe hook can modify the output. It can modify the input inplace butit will not have effect on forward since this is called after:func:`forward` is called.Returns::class:`torch.utils.hooks.RemovableHandle`:a handle that can be used to remove the added hook by calling``handle.remove()``"""handle = hooks.RemovableHandle(self._forward_hooks)self._forward_hooks[handle.id] = hookreturn handle

3、定义一个用于测试hook的类
如果随机的初始化每个层,那么就无法测试出自己获取的输入输出是不是forward中的输入输出了,所以需要将每一层的权重和偏置设置为可识别的值(比如全部初始化为1)。网络包含两层(Linear有需要求导的参数被称为一个层,而ReLU没有需要求导的参数不被称作一层),init()中调用initialize函数对所有层进行初始化。

**注意:**在forward()函数返回各个层的输出,但是ReLU6没有返回,因为后续测试的时候不对这一层进行注册hook。

class TestForHook(nn.Module):def __init__(self):super().__init__()self.linear_1 = nn.Linear(in_features=2, out_features=2)self.linear_2 = nn.Linear(in_features=2, out_features=1)self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.initialize()def forward(self, x):linear_1 = self.linear_1(x)linear_2 = self.linear_2(linear_1)relu = self.relu(linear_2)relu_6 = self.relu6(relu)layers_in = (x, linear_1, linear_2)layers_out = (linear_1, linear_2, relu)return relu_6, layers_in, layers_outdef initialize(self):""" 定义特殊的初始化,用于验证是不是获取了权重"""self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))return True

4、定义hook函数
hook()函数是register_forward_hook()函数必须提供的参数,首先定义几个容器用于记录:
定义用于获取网络各层输入输出tensor的容器:

# 同时定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []
hook函数需要三个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数:

hook函数负责将获取的输入输出添加到feature列表中;并提供相应的module名字。

def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None

5、对需要的层注册hook
注册钩子必须在forward()函数被执行之前,也就是定义网络进行计算之前就要注册,下面的代码对网络除去ReLU6以外的层都进行了注册(也可以选定某些层进行注册):
注册钩子可以对某些层单独进行:

net = TestForHook()
net_chilren = net.children()
for child in net_chilren:if not isinstance(child, nn.ReLU6):child.register_forward_hook(hook=hook)

6、测试forward()返回的特征和hook记录的是否一致
6.1 测试forward()提供的输入输出特征

由于前面的forward()函数返回了需要记录的特征,这里可以直接测试:

out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)

输出如下:

*****forward return features*****
(tensor([[0.1000, 0.1000],[0.1000, 0.1000]]), tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****

6.2 hook记录的输入特征和输出特征
hook通过list结构进行记录,所以可以直接print。

测试features_in是否存储了输入:

print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)

得到和forward一样的结果:

*****hook record features*****
[(tensor([[0.1000, 0.1000],[0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>, 
<class 'torch.nn.modules.linear.Linear'>,<class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****

6.3 把hook记录的和forward做减法
如果害怕会有小数点后面的数值不一致,或者数据类型的不匹配,可以对hook记录的特征和forward记录的特征做减法:
测试forward返回的feautes_in是不是和hook记录的一致:

print("sub result'")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):print(forward_return-hook_record[0])

得到的全部都是0,说明hook没问题:

sub result
tensor([[0., 0.],[0., 0.]])
tensor([[0., 0.],[0., 0.]], grad_fn=<SubBackward0>)
tensor([[0.],[0.]], grad_fn=<SubBackward0>)

7、完整代码

import torch
import torch.nn as nnclass TestForHook(nn.Module):def __init__(self):super().__init__()self.linear_1 = nn.Linear(in_features=2, out_features=2)self.linear_2 = nn.Linear(in_features=2, out_features=1)self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.initialize()def forward(self, x):linear_1 = self.linear_1(x)linear_2 = self.linear_2(linear_1)relu = self.relu(linear_2)relu_6 = self.relu6(relu)layers_in = (x, linear_1, linear_2)layers_out = (linear_1, linear_2, relu)return relu_6, layers_in, layers_outdef initialize(self):""" 定义特殊的初始化,用于验证是不是获取了权重"""self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))return True# 定义用于获取网络各层输入输出tensor的容器,并定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []# hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None# 定义全部是1的输入:
x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])# 注册钩子可以对某些层单独进行:
net = TestForHook()
net_chilren = net.children()
for child in net_chilren:if not isinstance(child, nn.ReLU6):child.register_forward_hook(hook=hook)# 测试网络输出:
out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)# 测试features_in是不是存储了输入:
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)# 测试forward返回的feautes_in是不是和hook记录的一致:
print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):print(forward_return-hook_record[0])

相关文章:

pytorch中的hook机制register_forward_hook

上篇文章主要介绍了hook钩子函数的大致使用流程&#xff0c;本篇文章主要介绍pytorch中的hook机制register_forward_hook&#xff0c;手动在forward之前注册hook&#xff0c;hook在forward执行以后被自动执行。 1、hook背景 Hook被成为钩子机制&#xff0c;pytorch中包含forwa…...

使用Gin框架返回JSON、XML和HTML数据

简介 Gin是一个高性能的Go语言Web框架&#xff0c;它不仅提供了简洁的API&#xff0c;还支持快速的路由和中间件处理。在Web开发中&#xff0c;返回JSON、XML和HTML数据是非常常见的需求。本文将介绍如何使用Gin框架来返回这三种类型的数据。 环境准备 在开始之前&#xff0…...

网工内推 | 国企运维工程师,华为认证优先,最高年薪20w

01 上海陆家嘴物业管理有限公司 &#x1f537;招聘岗位&#xff1a;IT运维工程师 &#x1f537;岗位职责&#xff1a; 1、负责对公司软、硬件系统、周边设备、桌面系统、服务器、网络基础环境运行维护、故障排除。 2、负责对各部门软件操作、网络安全进行检查、指导。 3、负责…...

c# 使用异步函数实现线程的功能

c#程序执行时 想要拖动窗口 需要使用线程,但是使用线程 对操作前端窗体很不友好. 所以写了一个异步函数,网上搜了一下,貌似异步函数比线程 更加友好,更加现代 做这个功能的原因是 主要是想等程序执行完 走一个提示.用线程很难做到 using System; using System.Threading; usi…...

MySQL之MySQL server has gone away复现测试

测试MySQL server has gone away复现条件 环境情形一报错信息复现测试 情形二报错信息复现测试 环境 Python: 3.8/3.9 MySQL: 5.x 情形一 报错信息 File "/usr/local/lib/python3.6/dist-packages/MySQLdb/cursors.py", line 319, in _querydb.query(q)File "/…...

编程深水区之并发④:Web多线程

Node的灵感来源于Chrome&#xff0c;更是移植了V8引擎。在Node中能够实现的多线程&#xff0c;在Web环境中自然也可以。 一、浏览器是多进程和多线程的复杂应用 在本系列的第二章节&#xff0c;有提到现代浏览器是一个多进程和多线程的复杂应用。浏览器主进程统管全局&#xf…...

【实战指南】从提升AI知识库效果,从PDF转Markdown开始

经常有人抱怨AI知识库精确度不够、答非所问。我有时候想想&#xff0c;会觉得其实AI也挺冤的&#xff0c;因为很有可能不是它能力不行&#xff0c;而是你一开始给的文档就有问题&#xff0c;导致它提取文本有错误、不完整&#xff0c;那后边一连串的检索、生成怎么可能好呢&…...

Android 删除telephony的features

比如删除android.hardware.telephony.subscription 找到这个文件&#xff1a;frameworks/native/data/etc/android.hardware.telephony.subscription.xml <!-- This is the standard set of features for devices to support Telephony Subscription API. --> -<perm…...

Linux驱动开发—编写第一个最简单的驱动模块

文章目录 开发驱动准备工作1.正常运行的Linux系统的开发板2.内核源码树3.nfs挂载的rootfs4.得心趁手的IDE 第一个Hello world 驱动程序常见模块的操作命令模块的初始化和清理模块的版本信息模块中的各种宏 示例Hello World代码printk函数解析 使用MakeFile编译驱动模块使用insm…...

科普文:微服务之Spring Cloud 组件API网关Gateway

API网关是一个服务器&#xff0c;是系统的唯一入口。从面向对象设计的角度看&#xff0c;它与外观模式类似。API网关封装了系统内部架构&#xff0c;为每个客户端提供一个定制的API。它可能还具有其它职责&#xff0c;如身份验证、监控、负载均衡、缓存、请求分片与管理、静态响…...

Kubernetes中的CRI、CNI与CSI:深入理解云原生存储、网络与容器运行时

引言 随着云原生技术的飞速发展&#xff0c;Kubernetes&#xff08;简称K8s&#xff09;作为云原生应用的核心调度平台&#xff0c;其重要性日益凸显。K8s通过开放一系列接口&#xff0c;实现了高度的可扩展性和灵活性&#xff0c;其中CRI&#xff08;Container Runtime Inter…...

【数据结构】二叉搜索树(Java + 链表实现)

Hi~&#xff01;这里是奋斗的明志&#xff0c;很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~~ &#x1f331;&#x1f331;个人主页&#xff1a;奋斗的明志 &#x1f331;&#x1f331;所属专栏&#xff1a;数据结构、LeetCode专栏 &#x1f4da;本系…...

java Brotli压缩算法实现压缩、解压缩

在Java中实现Brotli压缩和解压缩&#xff0c;你可以使用org.brotlienc和org.brotlidec包中的类。以下是压缩和解压缩的基本步骤和示例代码&#xff1a; 压缩文件 创建FileInputStream以读取原始文件。创建BrotliOutputStream以写入压缩数据。读取原始文件并写入压缩流。关闭流…...

centos7.9 安装java相关组件

10.23.15.71 - 78 账户 admin IMES1 改为root再操作 $ sudo su root ($ su root) 下载包 /home/admin/download $ mkdir download $ chown -R admin:admin /home/admin/download 安装包 /data/local $ tar -sxvf jdk-11.0.23_linux-x64_bin.tar.gz -C /data/local $ mv jdk…...

在IntelliJ IDEA中,快速找到控制类(Controller类)中所有的方法,可以通过以下几种方式实现:

在IntelliJ IDEA中&#xff0c;快速找到控制类&#xff08;Controller类&#xff09;中所有的方法&#xff0c;可以通过以下几种方式实现&#xff1a; 1. 使用快捷键 Alt 7 操作说明&#xff1a;在IDEA中&#xff0c;按下Alt 7可以快速打开“Structure”窗口&#xff08;在…...

ChatGPT的强大之处:探究及与国内产品的对比

论文题目&#xff1a;ChatGPT的强大之处&#xff1a;探究及与国内产品的对比 摘要 ChatGPT作为一种广泛应用的人工智能语言模型&#xff0c;自发布以来迅速走红全球。本文旨在探讨ChatGPT是否真如其流行程度所示那般强大&#xff0c;并对比其与国内类似产品的优劣&#xff0c;深…...

MySql审计平台

安装方式&#xff1a; cookieY/Yearning: &#x1f433; A most popular sql audit platform for mysql (github.com) 对数据库的一系列后台操作 AI助手 - AI助手提供SQL优化建议&#xff0c;帮助用户优化SQL语句&#xff0c;以获得更好的性能。同时AI助手还提供文本到SQL的…...

深度学习6--深度神经网络

1.VGG网络 在图像分 类这个领域中&#xff0c;深度卷积网络一般由卷积模块和全连接模块组成。 (1)卷积模块包含卷积层、池化层、Dropout 层、激活函数等。普遍认为&#xff0c;卷积模块是对 图像特征的提取&#xff0c;并不是对图像进行分类。 (2)全连接模块跟在卷积模块之后&…...

有了Power BI还需要深入学习Excel图表制作吗?

Power BI和Excel都是微软公司的产品&#xff0c;但它们在数据分析和可视化方面有着不同的定位和功能。 Power BI是一个强大的商业分析工具&#xff0c;它提供了数据集成、数据建模、报告和仪表板的创建等功能。Power BI 特别适合处理大量数据&#xff0c;并且可以连接到多种数…...

WEB渗透Web突破篇-命令执行

命令执行 >curl http://0ox095.ceye.io/whoami >ping whoami.b182oj.ceye.io >ping %CD%.lfofz7.dnslog.cn & cmd /v /c "whoami > temp && certutil -encode temp temp2 && findstr /L /V "CERTIFICATE" temp2 > temp3 &…...

消费级GPU福音:百川2-13B-4bits+OpenClaw自动化测试报告

消费级GPU福音&#xff1a;百川2-13B-4bitsOpenClaw自动化测试报告 1. 为什么选择这个组合&#xff1f; 去年冬天&#xff0c;我盯着显卡监控软件里跳动的显存占用数字&#xff0c;突然意识到一个问题&#xff1a;大多数开源大模型对消费级GPU太不友好了。动辄20GB以上的显存…...

AI 为什么不认识 Excel 文件?——用 SpreadJS 与 GCExcel 打通 AI 与数据的鸿沟

在技术领域&#xff0c;我们常常被那些闪耀的、可见的成果所吸引。今天&#xff0c;这个焦点无疑是大语言模型技术。它们的流畅对话、惊人的创造力&#xff0c;让我们得以一窥未来的轮廓。然而&#xff0c;作为在企业一线构建、部署和维护复杂系统的实践者&#xff0c;我们深知…...

保姆级教程:用Vivado MIG IP核搞定DDR3读写仿真(附AXI4波形分析)

从零掌握Vivado MIG IP核&#xff1a;DDR3读写仿真与AXI4协议深度解析 刚接触Xilinx FPGA的工程师第一次打开MIG IP核配置界面时&#xff0c;往往会被密密麻麻的参数选项吓到——时钟设置、AXI接口、地址映射、时序约束&#xff0c;每个环节都可能成为项目推进的拦路虎。本文将…...

MATLAB代码:基于主从博弈的电热综合能源系统DE算法优化动态定价与能量管理

MATLAB代码&#xff1a;基于主从博弈的电热综合能源系统动态定价与能量管理 关键词&#xff1a;主从博弈 电热综合能源 动态定价 能量管理 仿真平台&#xff1a;MATLAB 平台 优势&#xff1a;代码具有一定的深度和创新性&#xff0c;注释清晰&#xff0c;非烂大街的代码&…...

HJ166 讨厌鬼进货

题目题解(40)讨论(20)排行 入门 通过率&#xff1a;61.91% 时间限制&#xff1a;1秒 空间限制&#xff1a;256M 知识点贪心 校招时部分企业笔试将禁止编程题跳出页面&#xff0c;为提前适应&#xff0c;练习时请使用在线自测&#xff0c;而非本地IDE。 描述 讨厌鬼需要采…...

UML/结构/创建/行为—计算机等级考试—软件设计师考前备忘录—东方仙盟

UML → 创建型 5 种 → 结构型 7 种 → 行为型 11 种每种&#xff1a;定义&#xff08;教材版&#xff09; 1 道真题选择题你直接复制进 Word&#xff0c;考前背这一篇就够。一、UML 核心考点&#xff08;上午选择 下午应用题&#xff09;1. 用例图&#xff08;Use Case&#…...

文墨共鸣大模型高效写作工具链:替代Typora的AI增强Markdown编辑体验

文墨共鸣大模型高效写作工具链&#xff1a;替代Typora的AI增强Markdown编辑体验 如果你也像我一样&#xff0c;常年和Markdown文档打交道&#xff0c;那你一定对Typora不陌生。它简洁、优雅&#xff0c;所见即所得的编辑体验&#xff0c;让它成为了许多写作者和技术博主的心头…...

告别鼠标手!用Python的keyboard库打造你的专属游戏/办公热键助手(附完整源码)

告别鼠标手&#xff01;用Python的keyboard库打造你的专属游戏/办公热键助手&#xff08;附完整源码&#xff09; 长时间盯着电脑屏幕&#xff0c;手腕因为频繁点击鼠标而酸痛不已&#xff1f;这种"鼠标手"的困扰几乎成了现代办公族和游戏玩家的标配。但你可能没意识…...

拯救你的网站兼容性:手把手教你用heic2any解决苹果图片上传问题

苹果用户图片上传难题的终极解决方案&#xff1a;前端HEIC转换实战指南 你是否遇到过这样的场景&#xff1a;精心设计的网站上传功能&#xff0c;在苹果用户面前却频频报错&#xff1f;后台服务器不断收到无法识别的图片格式&#xff0c;而用户则抱怨"明明能拍照片却上传…...

智能座舱音频革命:如何用AVB交换机+TSN协议打造零延迟车载音响系统?

智能座舱音频革命&#xff1a;AVB交换机与TSN协议构建毫秒级同步音响系统 当你在驾驶舱内播放一首交响乐时&#xff0c;前排低音炮与后排高音单元的时差超过10毫秒&#xff0c;人耳就能感知声场撕裂——这种体验在传统车载音频架构中几乎无法避免。随着智能座舱向"第三生活…...