当前位置: 首页 > article >正文

day43 python Grad-CAM

目录

一、为什么需要 Grad-CAM?

二、Grad-CAM 的原理

三、Grad-CAM 的实现

1. 模块钩子(Module Hooks)

2. Grad-CAM 的实现代码

四、学习总结


在深度学习领域,神经网络模型常常被视为“黑盒”,因为其复杂的内部结构和难以理解的决策过程。然而,随着模型可解释性研究的不断深入,Grad-CAM(Gradient-weighted Class Activation Mapping)作为一种强大的可视化工具,为我们打开了一扇窥探模型决策机制的窗口。

一、为什么需要 Grad-CAM?

在实际的深度学习项目中,我们常常面临这样的问题:模型的预测结果虽然准确,但其背后的决策依据却难以捉摸。例如,在图像分类任务中,模型是如何从一张复杂的图片中识别出特定的类别?它关注了图片的哪些区域?这些问题的答案对于理解模型的行为、优化模型性能以及发现潜在的偏差至关重要。Grad-CAM 正是为了解决这些问题而诞生的。它通过可视化模型对输入图像的关注区域,帮助我们直观地理解模型的决策过程。这种可视化的热力图不仅能够增强我们对模型的信任,还能在模型出现偏差时,提供线索以便我们进行调整和优化。

二、Grad-CAM 的原理

Grad-CAM 的核心思想是利用卷积神经网络(CNN)中卷积层的特征图(Feature Map)和对应的梯度信息,生成类激活映射(Class Activation Mapping)。具体来说,它通过以下步骤实现:

  1. 选择目标层:通常选择最后一个卷积层作为目标层,因为这一层的特征图包含了丰富的语义信息。

  2. 前向传播:将输入图像通过模型进行前向传播,获取目标层的特征图。

  3. 反向传播:对目标类别进行反向传播,计算目标层的梯度。

  4. 生成热力图:将梯度信息与特征图结合,生成热力图。热力图中的高亮区域表示模型在预测目标类别时关注的区域。

Grad-CAM 的关键在于,它利用梯度信息来衡量每个特征图通道对目标类别的贡献程度,并通过对特征图进行加权求和,生成最终的热力图。

三、Grad-CAM 的实现

为了实现 Grad-CAM,我们需要借助 PyTorch 的 hook 机制。hook 是一种强大的工具,允许我们在不修改模型结构的情况下,动态地获取或修改中间层的输出或梯度。

1. 模块钩子(Module Hooks)

模块钩子分为前向钩子(register_forward_hook)和反向钩子(register_backward_hook)。前向钩子用于获取模块的输入和输出,而反向钩子用于获取模块的梯度信息。

以下是一个简单的示例,展示如何使用模块钩子获取卷积层的输出和梯度:

import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.conv = nn.Conv2d(1, 2, kernel_size=3, padding=1)self.relu = nn.ReLU()self.fc = nn.Linear(2 * 4 * 4, 10)def forward(self, x):x = self.conv(x)x = self.relu(x)x = x.view(-1, 2 * 4 * 4)x = self.fc(x)return xmodel = SimpleModel()# 定义前向钩子
def forward_hook(module, input, output):print("前向钩子被调用!")print(f"输入形状: {input[0].shape}")print(f"输出形状: {output.shape}")# 注册前向钩子
hook_handle = model.conv.register_forward_hook(forward_hook)# 创建输入并执行前向传播
x = torch.randn(1, 1, 4, 4)
output = model(x)# 移除钩子
hook_handle.remove()

通过上述代码,我们可以在卷积层的前向传播过程中获取其输入和输出。类似地,我们可以通过反向钩子获取梯度信息。

2. Grad-CAM 的实现代码

接下来,我们将实现 Grad-CAM 的完整代码。我们将使用 CIFAR-10 数据集,并基于一个简单的 CNN 模型进行实验。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image# 定义一个简单的 CNN 模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(128 * 4 * 4, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = self.pool(F.relu(self.conv3(x)))x = x.view(-1, 128 * 4 * 4)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型并加载预训练权重
model = SimpleCNN()
model.load_state_dict(torch.load('cifar10_cnn.pth'))
model.eval()# Grad-CAM 类
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = Noneself.register_hooks()def register_hooks(self):def forward_hook(module, input, output):self.activations = output.detach()def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):model_output = self.model(input_image)if target_class is None:target_class = torch.argmax(model_output, dim=1).item()self.model.zero_grad()one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)gradients = self.gradientsactivations = self.activationsweights = torch.mean(gradients, dim=(2, 3), keepdim=True)cam = torch.sum(weights * activations, dim=1, keepdim=True)cam = F.relu(cam)cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_class# 选择一张测试图像并生成 Grad-CAM 热力图
image, label = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())[102]
input_tensor = image.unsqueeze(0)grad_cam = GradCAM(model, model.conv3)
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可视化结果
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(image.permute(1, 2, 0).numpy())
plt.title(f"原始图像: {label}")
plt.axis('off')plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM 热力图: {pred_class}")
plt.axis('off')plt.subplot(1, 3, 3)
img = image.permute(1, 2, 0).numpy()
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')plt.tight_layout()
plt.show()

四、学习总结

通过本次实验,我对 Grad-CAM 的原理和实现有了更深入的理解。Grad-CAM 不仅能够帮助我们可视化模型的决策过程,还能在模型出现偏差时提供线索。例如,在实验中,我们发现模型在识别“青蛙”类别时,主要关注了图像的腿部和头部区域。这表明模型确实能够捕捉到关键的语义特征,但也提醒我们在数据标注和模型训练过程中需要注意潜在的偏差。

@浙大疏锦行

相关文章:

day43 python Grad-CAM

目录 一、为什么需要 Grad-CAM? 二、Grad-CAM 的原理 三、Grad-CAM 的实现 1. 模块钩子(Module Hooks) 2. Grad-CAM 的实现代码 四、学习总结 在深度学习领域,神经网络模型常常被视为“黑盒”,因为其复杂的内部结…...

在 Ubuntu 上挂载其他硬盘的步骤

一、查看当前磁盘信息 打开终端,执行: lsblk 这个命令会列出所有的块设备(包括硬盘和分区)。比如输出可能如下: NAME MAJ:MIN RM SIZE RO TYPE MOUNTPOINT sda 8:0 0 1.8T 0 disk └─sda1 8:1 0 …...

SQL的查询优化

1. 查询优化器 1.1. SQL语句执行需要经历的环节 解析阶段:语法分析和语义检查,确保语句正确;优化阶段:通过优化器生成查询计划;执行阶段:由执行器根据查询计划实际执行操作。 1.2. 查询优化器 查询优化器…...

MCU如何从向量表到中断服务

目录 1、中断向量表 2、编写中断服务例程 中断处理的核心是中断向量表(IVT),它是一个存储中断服务例程(ISR)地址的内存结构。当中断发生时,MCU通过IVT找到对应的ISR地址并跳转执行。本文将深入探讨MCU&am…...

物联网基础概念

入行物联网两年半,想写点东西记录踩过的坑,能让自己反省的同时,也希望能帮到其他小伙伴。 本人仍是小白,请看客朋友谨慎参考。另,本人主要从事智慧用电和智慧医疗行业,其他行业不一定有参考性。 以下所有内…...

Linux线程同步实战:多线程程序的同步与调度

个人主页:chian-ocean 文章专栏-Linux Linux线程同步实战:多线程程序的同步与调度 个人主页:chian-ocean文章专栏-Linux 前言:为什么要实现线程同步线程饥饿(Thread Starvation)示例:抢票问题 …...

【MySQL】事务及隔离性

目录 一、什么是事务 (一)概念 (二)事务的四大属性 (三)事务的作用 (四)事务的提交方式 二、事务的启动、回滚与提交 (一)事务的启动、回滚与提交 &am…...

Leetcode 3566. Partition Array into Two Equal Product Subsets

Leetcode 3566. Partition Array into Two Equal Product Subsets 1. 解题思路2. 代码实现 题目链接:3566. Partition Array into Two Equal Product Subsets 1. 解题思路 这一题我的实现还是比较暴力的,首先显而易见的,若要满足题目要求&…...

yolo目标检测助手:具有模型预测、图像标注功能

在人工智能浪潮席卷各行各业的今天,计算机视觉模型(如 YOLO)已成为目标检测领域的标杆。然而,模型的强大能力需要直观的界面和便捷的工具才能充分发挥其演示、验证与迭代优化的价值。为此,我开发了一款基于 WPF 的桌面…...

传统数据表设计与Prompt驱动设计的范式对比:以NBA投篮数据表为例

引言:数据表设计方法的演进 在数据库设计领域,传统的数据表设计方法与新兴的Prompt驱动设计方法代表了两种截然不同的思维方式。本文将以NBA赛季投篮数据表(shots)的设计为例,深入探讨这两种方法的差异、优劣及适用场景。随着AI技术在数据领…...

2022 RoboCom 世界机器人开发者大赛(睿抗 caip) -高职组(国赛)解题报告 | 科学家

前言 题解 2022 RoboCom 世界机器人开发者大赛(睿抗 caip) -高职组&#xff08;国赛&#xff09;。 最后一题还考验能力&#xff0c;需要找到合适的剪枝。 RC-v1 智能管家 分值: 20分 签到题&#xff0c;map的简单实用 #include <bits/stdc.h>using namespace std;int…...

WIN11 Docker Desktop 安装问题解决

windows version 打开windows 命令行&#xff0c;执行 ver显示 Microsoft Windows [版本 10.0.26100.4061]安装docker desktop 后&#xff0c;启动出问题&#xff0c;可以按下面步骤解决 安装 virtual machine plateform 开始 —》 控制面板 ----》程序 ----》启动或关闭w…...

网站服务器出现异常的原因是什么?

网站时企业和个人用户进行提供信息和服务的重要平台&#xff0c;随着时间的推移&#xff0c;网站服务器出现异常情况也是常见的问题之一&#xff0c;这可能会导致网站无法正常访问或者是运行缓慢&#xff0c;会严重影响到用户的体验感&#xff0c;本文就来介绍一下网站服务器出…...

Python实例题:Python3实现图片转彩色字符

目录 Python实例题 题目 代码实现 实现原理 图像预处理&#xff1a; 灰度值计算&#xff1a; 字符映射&#xff1a; 彩色输出&#xff1a; 关键代码解析 1. 字符映射和灰度计算 2. 图像模式输出 3. 命令行参数处理 使用说明 基本用法&#xff08;终端输出&#x…...

同一机器下通过HTTP域名访问其他服务器进程返回504问题记录

我这边项目的服务器有好几个类型节点&#xff0c;每个节点为一个进程&#xff0c;不同节点间通过HTTP来通讯&#xff0c;当前这几个类型的节点都部署在同一台机器上&#xff0c;然后我再测试某个节点到另一个节点的http通讯时&#xff0c;发现一个奇怪的现象&#xff1a; 1. 我…...

基于物联网(IoT)的电动汽车(EVs)智能诊断

我是穿拖鞋的汉子&#xff0c;魔都中坚持长期主义的汽车电子工程师。 老规矩&#xff0c;分享一段喜欢的文字&#xff0c;避免自己成为高知识低文化的工程师&#xff1a; 做到欲望极简&#xff0c;了解自己的真实欲望&#xff0c;不受外在潮流的影响&#xff0c;不盲从&#x…...

JDBC+HTML+AJAX实现登陆和单表的CRUD

JDBCHTMLAJAX实现登陆和单表的CRUD 导入maven依赖 <?xml version"1.0" encoding"UTF-8"?><project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocatio…...

Leetcode 3568. Minimum Moves to Clean the Classroom

Leetcode 3568. Minimum Moves to Clean the Classroom 1. 解题思路2. 代码实现 题目链接&#xff1a;3568. Minimum Moves to Clean the Classroom 1. 解题思路 这一题我的核心思路就是广度优先遍历遍历剪枝。 显然&#xff0c;我们可以给出一个广度优先遍历来给出所有可能…...

Kafka多线程Consumer

Apache Kafka作为一款分布式流处理平台&#xff0c;以其高吞吐量和可扩展性在大数据处理领域占据了重要地位。在实际应用中&#xff0c;为了提升数据处理的效率和灵活性&#xff0c;我们常常需要采用多线程的方式来消费Kafka中的数据。本文将通过一个案例分析&#xff0c;详细探…...

从零开始的git学习

基本概念&#xff1a;修改记录 1、每个修改记录都有对应的id 2、当发现修改有问题时&#xff0c;可以进行回滚操作。 3、回滚的本质是一次新的更新以复原修改。但是如果不是针对最新记录进行回滚&#xff0c;会出现冲突。 这里需要举例说明 基本概念&#xff1a;分支 1、分支…...

【C++】位图详解(一文彻底搞懂位图的使用方法与底层原理)

&#x1f308; 个人主页&#xff1a;谁在夜里看海. &#x1f525; 个人专栏&#xff1a;《C系列》《Linux系列》 ⛰️ 天高地阔&#xff0c;欲往观之。 目录 1.位图的概念 2.位图的使用方法 定义与创建 设置和清除 位访问和检查 转换为其他格式 3.位图的使用场景 1.快速…...

Spring Boot 整合 JdbcTemplate,JdbcTemplate 与 MyBatis 的区别

DAY29.1 Java核心基础 Spring Boot 整合 JdbcTemplate JdbcTemplate是一个轻量级JDBC封装的组件 JdbcTemplate 是 Spring 自带的JDBC的封装&#xff0c;和Mybatis类似&#xff0c;需要自己封装sql语句 JdbcTemplate 帮助我们来连接数据库&#xff0c;SQL的执行&#xff0c;…...

sass基础语法

Sass&#xff08;Syntactically Awesome Style Sheets&#xff09;是一种 CSS 预处理器&#xff0c;提供了比原生 CSS 更强大、更灵活的语法功能。它有两种语法格式&#xff1a; Sass&#xff08;缩进语法&#xff0c;.sass 文件&#xff09;SCSS&#xff08;CSS-like 语法&am…...

【EF Core】 EF Core 批量操作的进化之路——从传统变更跟踪到无跟踪更新

文章目录 前言一、批量操作&#xff08;Rang&#xff09;1.1 AddRange()1.2 UpdateRange()1.3 AttachRange()1.4 RemoveRange() 二、Range操作的底层优化2.1 EF Core 7 前举步维艰2.2 EF Core 7后焕然一新 三、无跟踪的批量更新与删除3.1 ExecuteUpdate3.2 ExecuteDelete3.3 状…...

[Go] Option选项设计模式 — — 编程方式基础入门

[Go] Option选项设计模式 — — 编程方式基础入门 全部代码地址&#xff0c;欢迎⭐️ Github&#xff1a;https://github.com/ziyifast/ziyifast-code_instruction/tree/main/go-demo/go-option 1 介绍 在 Go 开发中&#xff0c;我们经常遇到需要处理多参数配置的场景。传统方…...

Vue 项目命名规范指南

&#x1f4da; Vue 项目命名规范指南&#xff08;适用于 Vue 3 Pinia Vue Router&#xff09; 目的&#xff1a;统一命名风格&#xff0c;提升可读性、可维护性和团队协作效率。 一、通用原则 类型命名风格示例变量camelCaseuserName, isLoading常量UPPER_SNAKE_CASEMAX_RET…...

【笔记】开源通用人工智能代理 Suna 部署全流程准备清单(Windows 系统)

#工作记录 一、基础工具与环境 开发工具 Git 或 GitHub Desktop&#xff08;代码管理&#xff09;Docker Desktop&#xff08;需启用 WSL2&#xff0c;容器化部署&#xff09;Python 3.11&#xff08;推荐版本&#xff0c;需添加到系统环境变量&#xff09;Node.js LTS&#xf…...

海康工业相机SDK二次开发(VS+QT+海康SDK+C++)

前言 工业相机在现代制造和工业自动化中扮演了至关重要的角色&#xff0c;尤其是在高精度、高速度检测中。海康威视工业相机以其性能稳定、图像质量高、兼容性强而受到广泛青睐。特别是搞机器视觉的小伙伴们跟海康打交道肯定不在少数&#xff0c;笔者在平常项目中跟海康相关人…...

前端面试准备-5

1.Node.js中的process.nectTick()有什么作用 将一个回调函数插入到当前执行栈的尾部&#xff0c;在下一次事件轮询之前调用这个回调函数 2.什么是Node.js中的事件发射器&#xff0c;作用是什么&#xff0c;如何使用 提供一种机制&#xff0c;可以创建、触发和监听自定义事件…...

Spring Boot 启动流程深度解析:从源码到实践

Spring Boot 启动流程深度解析&#xff1a;从源码到实践 Spring Boot 作为 Java 开发的主流框架&#xff0c;其 “约定大于配置” 的理念极大提升了开发效率。本文将从源码层面深入解析 Spring Boot 的启动流程&#xff0c;并通过代码示例展示其工作机制。 一、Spring Boot 启…...