python打卡day42
Grad-CAM与Hook函数
知识点回顾
- 回调函数
- lambda函数
- hook函数的模块钩子和张量钩子
- Grad-CAM的示例
在深度学习中,我们经常需要查看或修改模型中间层的输出或梯度,但标准的前向传播和反向传播过程通常是一个黑盒,很难直接访问中间层的信息。PyTorch 提供了一种强大的工具——hook 函数,它允许我们在不修改模型结构的情况下,获取或修改中间层的信息。常用场景如下:
- 调试与可视化中间层输出
- 特征提取:如在图像分类模型中提取高层语义特征用于下游任务
- 梯度分析与修改: 在训练过程中,对某些层进行梯度裁剪或缩放,以改变模型训练的动态
- 模型压缩:在推理阶段对特定层的输出应用掩码(如剪枝后的模型权重掩码),实现轻量化推理
1、回调函数
Hook本质是回调函数,所以我们先介绍一下回调函数。回调函数是作为参数传递给其他函数的函数,其目的是在某个特定事件发生时被调用执行。这种机制允许代码在运行时动态指定需要执行的逻辑,其中回调函数作为参数传入,所以在定义的时候一般用callback来命名
在 PyTorch 的 Hook API 中,回调参数通常命名为 hook,PyTorch 的 Hook 机制基于其动态计算图系统:
- 当你注册一个 Hook 时,PyTorch 会在计算图的特定节点(如模块或张量)上添加一个回调函数
- 当计算图执行到该节点时(前向或反向传播),自动触发对应的 Hook 函数
- Hook 函数可以访问或修改流经该节点的数据(如输入、输出或梯度)
2、lambda函数
在hook中常常用到lambda函数,它是一种匿名函数(没有正式名称的函数),最大特点是用完即弃,无需提前命名和定义。它的语法形式非常简约,仅需一行即可完成定义,格式:lambda 参数列表: 表达式
- 参数列表:可以是单个参数、多个参数或无参数
- 表达式:函数的返回值(无需 return 语句,表达式结果直接返回)
举个例子
# 定义匿名函数:计算平方
square = lambda x: x ** 2# 调用
print(square(5)) # 输出: 25
3、hook函数
PyTorch 提供了两种主要的 hook:
- Module Hooks(模块钩子):用于监听整个模块的输入和输出
- Tensor Hooks:用于监听张量的梯度
(1)模块钩子
允许我们在模块的输入或输出经过时进行监听。PyTorch 提供了两种模块钩子:
- register_forward_hook:在模块的前向传播完成后立即被调用,这个函数可以访问模块的输入和输出,但不能修改
- register_backward_hook:在反向传播过程中被调用的,可以用来获取或修改梯度信息
前向钩子举个例子
# 创建模型实例
model = SimpleModel()# 创建一个列表用于存储中间层的输出
conv_outputs = []# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):print(f"钩子被调用!模块类型: {type(module)}")print(f"输入形状: {input[0].shape}") # input是一个元组,对应 (image, label)print(f"输出形状: {output.shape}")# 保存卷积层的输出用于后续分析# 使用detach()避免追踪梯度,防止内存泄漏conv_outputs.append(output.detach())# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)# 执行前向传播 - 此时会自动触发钩子函数
output = model(x)# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()
反向钩子
# 定义一个存储梯度的列表
conv_gradients = []# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):print(f"反向钩子被调用!模块类型: {type(module)}")print(f"输入梯度数量: {len(grad_input)}")print(f"输出梯度数量: {len(grad_output)}")# 保存梯度供后续分析conv_gradients.append((grad_input, grad_output))# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()# 释放钩子
hook_handle.remove()
(2)张量钩子
PyTorch 还提供了张量钩子,允许我们直接监听和修改张量的梯度。张量钩子有两种:
- register_hook:用于监听张量的梯度
- register_full_backward_hook:用于在完整的反向传播过程中监听张量的梯度
# 创建一个需要计算梯度的张量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3# 定义一个钩子函数,用于修改梯度
def tensor_hook(grad):print(f"原始梯度: {grad}")# 修改梯度,例如将梯度减半return grad / 2# 在y上注册钩子
hook_handle = y.register_hook(tensor_hook)# 计算梯度,梯度会从z反向传播经过y到x,此时调用钩子函数
z.backward()print(f"x的梯度: {x.grad}")# 释放钩子
hook_handle.remove()
4、Grad-CAM
一个可视化算法,通过梯度信息用热力图显示图片中哪些区域让CNN做出了某个分类决定(比如为什么认为这是“猫”),原理:
- 梯度计算:看最后几层特征图的梯度,哪个特征图对预测“猫”的贡献大
- 加权融合:把重要的特征图合并成一张热力图(重要区域更亮)
- 叠加显示:把热力图盖在原图上,一眼看出猫的脸/耳朵等关键部位被高亮了
# Grad-CAM实现
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注册钩子,用于获取目标层的前向传播输出和反向传播梯度self.register_hooks()def register_hooks(self):# 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)def forward_hook(module, input, output):self.activations = output.detach()# 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()# 在目标层注册前向钩子和反向钩子self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):# 前向传播,得到模型输出model_output = self.model(input_image)if target_class is None:# 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别target_class = torch.argmax(model_output, dim=1).item()# 清除模型梯度,避免之前的梯度影响self.model.zero_grad()# 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)# 获取之前保存的目标层的梯度和激活值gradients = self.gradientsactivations = self.activations# 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果cam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响cam = F.relu(cam)# 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_classidx = 102 # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")# 转换图像以便可视化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.5, 0.5, 0.5])std = np.array([0.5, 0.5, 0.5])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可视化
plt.figure(figsize=(12, 4))# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()
@浙大疏锦行
相关文章:

python打卡day42
Grad-CAM与Hook函数 知识点回顾 回调函数lambda函数hook函数的模块钩子和张量钩子Grad-CAM的示例 在深度学习中,我们经常需要查看或修改模型中间层的输出或梯度,但标准的前向传播和反向传播过程通常是一个黑盒,很难直接访问中间层的信息。PyT…...

XMOS以全新智能音频及边缘AI技术亮相广州国际专业灯光音响展
全球领先的边缘AI和智能音频解决方案提供商XMOS于5月27-30日亮相第23届广州国际专业灯光、音响展览会(prolight sound Guangzhou,以下简称“广州展”,XMOS展位号:5.2A66)。在本届展会上,XMOS将展出先进的音…...

Playwright 测试框架 - Node.js
🚀超全实战:基于 Playwright + Node.js 的自动化测试项目教程【附源码】 📌 本文适合自动化测试入门者 & 前端测试实战者。从零开始手把手教你搭建一个 Playwright + Node.js 项目,涵盖配置、测试用例编写、运行与调试、报告生成以及实用进阶技巧。建议收藏!👍 �…...

机器学习有监督学习sklearn实战二:六种算法对鸢尾花(Iris)数据集进行分类和特征可视化
本项目代码在个人github链接:https://github.com/KLWU07/Machine-learning-Project-practice 六种分类算法分别为逻辑回归LR、线性判别分析LDA、K近邻KNN、决策树CART、朴素贝叶斯NB、支持向量机SVM。 一、项目代码描述 1.数据准备和分析可视化 加载鸢尾花数据集&…...

vr中风--数据处理模型搭建与训练2
位置http://localhost:8888/notebooks/Untitled1-Copy1.ipynb # -*- coding: utf-8 -*- """ MUSED-I康复评估系统(增强版) 包含:多通道sEMG数据增强、混合模型架构、标准化处理 """ import numpy as np impor…...

鸿蒙next系统以后会取代安卓吗?
点击上方关注 “终端研发部” 设为“星标”,和你一起掌握更多数据库知识 官方可没说过取代谁谁,三足鼎立不好吗?三分天下,并立共存。 鸿蒙基于Linux,有人说套壳;ios/macos基于Unix,说它ios开源了…...

PolyGen:一个用于 3D 网格的自回归生成模型 论文阅读
[2002.10880] PolyGen:一个用于 3D 网格的自回归生成模型 --- [2002.10880] PolyGen: An Autoregressive Generative Model of 3D Meshes 图 2:PolyGen 首先生成网格顶点(左侧),然后基于这些顶点生成网格面࿰…...
约瑟夫问题 洛谷 - P1996
Description n个人围成一圈,从第一个人开始报数,数到 m 的人出列,再由下一个人重新从 1 开始报数,数到 m 的人再出圈,依次类推,直到所有的人都出圈,请输出依次出圈人的编号。 注意:本题和《深…...

系统思考:成长与投资不足
最近认识了一位95后年轻创业者,短短2年时间,他的公司从十几个人发展到几百人,规模迅速扩大。随着团队壮大,用户池也在持续扩大,但令人困惑的是,业绩增长却没有明显提升,甚至人效持续下滑。尽管公…...

快手可灵视频V1.6模型API如何接入免费AI开源项目工具
全球领先的视频生成大模型:可灵是首个效果对标 Sora 、面向用户开放的视频生成大模型,目前在国内及国际上均处于领先地位。快手视频生成大模型“可灵”(Kling),是全球首个真正用户可用的视频生成大模型,自面…...

数学建模期末速成 最短路径
关键词:Dijkstra算法 Floyd算法 例题 已知有6个村庄,各村的小学生人数如表所列,各村庄间的距离如图所示。现在计划建造一所医院和一所小学,问医院应建在哪个村庄才能使最远村庄的人到医院看病所走的路最短?又问小学建…...
【Netty系列】实现HTTP文件服务器
目录 一、完整代码实现 1. Maven依赖 (pom.xml) 2. 主启动类 (FileServer.java) 3. 通道初始化类 (FileServerInitializer.java) 4. 核心业务处理器 (FileServerHandler.java) 二、代码关键解释 1. 架构分层 2. 安全防护机制 3. 文件传输优化 4. 目录列表生成 三、运…...

Java开发经验——阿里巴巴编码规范实践解析7
摘要 本文主要解析了阿里巴巴 Java 开发中的 SQL 编码规范,涉及 SQL 查询优化、索引建立、字符集选择、分页查询处理、外键与存储过程的使用等多个方面,旨在帮助开发者提高代码质量和数据库操作性能,避免常见错误和性能陷阱。 1. 【强制】业…...

权威认证与质量保障:第三方检测在科技成果鉴定测试中的核心作用
科技成果鉴定测试是衡量科研成果技术价值与应用潜力的关键环节,其核心目标在于通过科学验证确保成果的可靠性、创新性和市场适配性。第三方检测机构凭借其独立性、专业性和权威性,成为科技成果鉴定测试的核心支撑主体。本文从测试流程、第三方检测的价值…...
混和效应模型在医学分析中的应用
混合效应模型(Mixed Effects Model),又称多层模型或随机效应模型,因其能同时分析固定效应(群体平均趋势)和随机效应(个体或组间差异),在医学研究中广泛应用于处理具有层次…...
架构分享|三层存储架构加速云端大模型推理
作者简介 Nilesh Agarwal,Inferless 联合创始人&CTO 关于Inferless Inferless :无服务器 GPU 推理无需管理服务器即可扩展机器学习推理,轻松部署复杂的自定义模型。获得Sequoia、Antler 和 Blume Ventures 的支持。 大语言模型(LLM&a…...

Perforce P4产品简介:无限扩展+全球协作+安全管控+工具集成(附下载)
本产品简介由Perforce中国授权合作伙伴——龙智编辑整理,旨在带您快速了解Perforce P4版本控制系统的强大之处。 世界级无限可扩展的版本控制系统 Perforce P4(原Helix Core)是业界领先的版本控制平台,备受19家全球Top20 AAA级游…...

网络协议入门:TCP/IP五层模型如何实现全球数据传输?
🔍 开发者资源导航 🔍🏷️ 博客主页: 个人主页📚 专栏订阅: JavaEE全栈专栏 内容: 网络初识什么是网络?关键概念认识协议五元组 协议分层OSI七层模型TCP/IP五层(四层&…...

Docker安装Redis集群(3主3从+动态扩容、缩容)保姆级教程含踩坑及安装中遇到的问题解决
前言 部署集群前,我们需要先掌握Redis分布式存储的核心算法。了解这些算法能帮助我们在实际工作中做出合理选择,同时清晰认识各方案的优缺点。 一、分布式存储算法 我们通过一道大厂面试题来进行阐述。 如下:1-2亿条数据需要缓存ÿ…...

企业级 AI 开发新范式:Spring AI 深度解析与实践
一、Spring AI 的核心架构与设计哲学 1.1 技术定位与价值主张 Spring AI 作为 Spring 生态系统的重要组成部分,其核心使命是将人工智能能力无缝注入企业级 Java 应用。它通过标准化的 API 抽象和 Spring Boot 的自动装配机制,让开发者能够以熟悉的 Spr…...

如何用docker部署ELK?
环境: ELK 8.8.0 Ubuntu20.04 问题描述: 如何用docker部署ELK? 解决方案: 一、环境准备 (一)主机设置 安装 Docker Engine :版本需为 18.06.0 或更新。可通过命令 docker --version 检查…...

Redis最佳实践——安全与稳定性保障之高可用架构详解
全面详解 Java 中 Redis 在电商应用的高可用架构设计 一、高可用架构核心模型 1. 多层级高可用体系 #mermaid-svg-Ffzq72Onkv7wgNKQ {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-Ffzq72Onkv7wgNKQ .error-icon{f…...

【Python 算法零基础 4.排序 ⑥ 快速排序】
既有锦绣前程可奔赴,亦有往日岁月可回首 —— 25.5.25 选择排序回顾 ① 遍历数组:从索引 0 到 n-1(n 为数组长度)。 ② 每轮确定最小值:假设当前索引 i 为最小值索引 min_index。从 i1 到 n-1 遍历,若找到…...
Java面试实战:从Spring Boot到微服务与AI的全栈挑战
场景一:初步了解和基本技术问题 面试官:我们先从基础开始,谢先生,你能简单介绍一下你在Java SE上的经验吗? 谢飞机:当然!Java就像是我的老朋友,尤其是8和11版本。我用它们做过很多…...

Go 即时通讯系统:日志模块重构,并从main函数开始
重构logger 上次写的logger.go过于繁琐,有很多没用到的功能;重构后只提供了简洁的日志接口,支持日志轮转、多级别日志记录等功能,并采用单例模式确保全局只有一个日志实例 全局变量 var (once sync.Once // 用于实现…...
CppCon 2014 学习:Exception-Safe Coding
以下是你提到的内容(例如 “Exception-Safe Coding 理解” 和 “Easier to Read!” 等)翻译成中文并进一步解释: 承诺:理解异常安全(Exception-Safe Coding) 什么是异常安全? 异常安全是指&a…...

MYSQL MGR高可用
1,MYSQL MGR高可用是什么 简单来说,MySQL MGR 的核心目标就是:确保数据库服务在部分节点(服务器)发生故障时,整个数据库集群依然能够继续提供读写服务,最大限度地减少停机时间。 2. 核心优势 v…...

阿里通义实验室突破空间音频新纪元!OmniAudio让360°全景视频“声”临其境
在虚拟现实和沉浸式娱乐快速发展的今天,视觉体验已经远远不够,声音的沉浸感成为打动用户的关键。然而,传统的视频配音技术往往停留在“平面”的音频层面,难以提供真正的空间感。阿里巴巴通义实验室(Qwen Lab࿰…...

异步上传石墨文件进度条前端展示记录(采用Redis中String数据结构实现-苏东坡版本)
昔者,有客临门,亟需自石墨文库中撷取卷帙若干。此等文册,非止一卷,乃累牍连篇,亟需批量转置。然吾辈虑及用户体验,当效东坡"腹有诗书气自华"之雅意,使操作如行云流水,遂定…...

处理知识库文件_编写powershell脚本文件_批量转换其他格式文件到pdf文件---人工智能工作笔记0249
最近在做部门知识库,选用的dify,作为rag的工具,但是经过多个对比,最后发现, 比较好用的是,纳米搜索,但是可惜纳米搜索无法在内网使用,无法把知识库放到本地,导致 有信息…...