当前位置: 首页 > news >正文

《python深度学习》笔记(二十):神经网络的解释方法之CAM、Grad-CAM、Grad-CAM++、LayerCAM

原理优点缺点
GAP将多维特征映射降维为一个固定长度的特征向量①减少了模型的参数量;②保留更多的空间位置信息;③可并行计算,计算效率高;④具有一定程度的不变性①可能导致信息的损失;②忽略不同尺度的空间信息
CAM利用最后一个卷积层的特征图×权重(用GAP代替全连接层,重新训练,经过GAP分类后概率最大的神经元的权重效果已经很不错需要修改原模型的结构,导致需要重新训练该模型,大大限制了使用场景,如果模型已经上线了,或着训练的成本非常高,我们几乎是不可能为了它重新训练的。
Grad-CAM最后一个卷积层的特征图×权重(通过对特征图梯度的全局平均来计算权重①解决了CAM的缺点,适用于任何卷积神经网络;②利用特征图的梯度,可视化结果更准确和精细
Grad-CAM++1. 定位更准确
2. 更适合同类多目标的情况

目录

GAP全局平均池化

CAM

Grad-CAM 

Grad-CAM++


GAP全局平均池化

论文:Network In Network

GAP (Global Average Pooling,全局平均池化),在上述论文中提出,用于避免全连接层的过拟合问题。全局平均池化就是对整个特征映射应用平均池化。

图1:将原本h × w × d的三维特征图,具体大小为6 × 6 × 3,经过GAP池化为1 × 1 × 3 输出值。也就是每一个channel的h × w 平均池化为一个值。特征图经过 GAP 处理后每一个特征图包含了不同类别的信息。 

GAP平均池化的操作步骤如下:

  1. 经过卷积操作和激活函数后,得到最后一个卷积层的特征图。
  2. 对每个通道的特征图进行平均池化,即计算每个通道上所有元素的平均值。这将每个通道的特征图转化为一个标量值。
  3. 将每个通道的标量值组合成一个特征向量。这些标量值的顺序与通道的顺序相同。
  4. 最终得到的特征向量可以作为分类器的输入,用于进行图像分类。

CAM

论文:Learning Deep Features for Discriminative Localization

原理:利用最后一个卷积层的特征图与经过GAP分类后概率最大的神经元权重进行叠加。

图2:解释了在CNN中使用全局平均池化(GAP)生成类激活映射(CAM)的过程:

经过最后一层卷积操作之后,得到的特征图包含多个channel,如图1中的不同颜色的3个channel,也就是在GAP之前所对应的不同的channel特征图,f_{k}就表示第k个channel的特征图。然后经过GAP处理后每个channel的特征图包含了不同类别的信息,w_{k}就表示分类概率最大的神经元(图2黑色神经元)所对应连接的第k个神经元的权重。

Grad-CAM 

Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization (arxiv.org)

Grad-CAM的前身是 CAM,CAM 的基本的思想是求分类网络某一类别得分对高维特征图 (卷积层的输出) 的偏导数,从而可以得到该高维特征图每个通道对该类别得分的权值;而高维特征图的激活信息 (正值) 又代表了卷积神经网络的所感兴趣的信息,加权后使用热力图呈现得到 CAM。

原理:Grad-CAM的关键思想是将输出类别的梯度(相对于特定卷积层的输出)与该层的输出相乘,然后取平均,得到一个“粗糙”的热力图。这个热力图可以被放大并叠加到原始图像上,以显示模型在分类时最关注的区域。

具体步骤如下:

  1. 选择网络的最后一个卷积层,因为它既包含了高级特征,也保留了空间信息。
  2. 前向传播图像到网络,得到你想解释的类别的得分。
  3. 计算此得分相对于我们选择的卷积层输出的梯度。
  4. 对于该卷积层的每个通道,使用上述梯度的全局平均值对该通道进行加权。
  5. 结果是一个与卷积层的空间维度相同的加权热力图。

因为热力图关心的是对分类有正面影响的特征,所以在线性组合的技术上加上了ReLU,以移除负值 。

w_{k}^{c}第 k 个特征图对应于类别 c 的权重,
A^{k}表示:第 k 个特征图,
Z表示特征图的像素个数,
y^{c}表示: 第c类得分的梯度,
A_{ij}^{k}表示: 第 k个特征图中坐标( i , j )位置处的像素值;

Grad-CAM代码:

import torch
import cv2
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Imageclass GradCAM:def __init__(self, model, target_layer):self.model = model  # 要进行Grad-CAM处理的模型self.target_layer = target_layer  # 要进行特征可视化的目标层self.feature_maps = None  # 存储特征图self.gradients = None  # 存储梯度# 为目标层添加钩子,以保存输出和梯度target_layer.register_forward_hook(self.save_feature_maps)target_layer.register_backward_hook(self.save_gradients)def save_feature_maps(self, module, input, output):"""保存特征图"""self.feature_maps = output.detach()def save_gradients(self, module, grad_input, grad_output):"""保存梯度"""self.gradients = grad_output[0].detach()def generate_cam(self, image, class_idx=None):"""生成CAM热力图"""# 将模型设置为评估模式self.model.eval()# 正向传播output = self.model(image)if class_idx is None:class_idx = torch.argmax(output).item()# 清空所有梯度self.model.zero_grad()# 对目标类进行反向传播one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32)one_hot[0][class_idx] = 1output.backward(gradient=one_hot.cuda(), retain_graph=True)# 获取平均梯度和特征图pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])activation = self.feature_maps.squeeze(0)for i in range(activation.size(0)):activation[i, :, :] *= pooled_gradients[i]# 创建热力图heatmap = torch.mean(activation, dim=0).squeeze().cpu().numpy()heatmap = np.maximum(heatmap, 0)heatmap /= torch.max(heatmap)heatmap = cv2.resize(heatmap, (image.size(3), image.size(2)))heatmap = np.uint8(255 * heatmap)heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)# 将热力图叠加到原始图像上original_image = self.unprocess_image(image.squeeze().cpu().numpy())superimposed_img = heatmap * 0.4 + original_imagesuperimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)return heatmap, superimposed_imgdef unprocess_image(self, image):"""反预处理图像,将其转回原始图像"""mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = (((image.transpose(1, 2, 0) * std) + mean) * 255).astype(np.uint8)return imagedef visualize_gradcam(model, input_image_path, target_layer):"""可视化Grad-CAM热力图"""# 加载图像img = Image.open(input_image_path)preprocess = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])input_tensor = preprocess(img).unsqueeze(0).cuda()# 创建GradCAMgradcam = GradCAM(model, target_layer)heatmap, result = gradcam.generate_cam(input_tensor)# 显示图像和热力图plt.figure(figsize=(10,10))plt.subplot(1,2,1)plt.imshow(heatmap)plt.title('热力图')plt.axis('off')plt.subplot(1,2,2)plt.imshow(result)plt.title('叠加后的图像')plt.axis('off')plt.show()# 以下是示例代码,显示如何使用上述代码。
# 首先,你需要加载你的模型和权重。
# model = resnet20()
# model.load_state_dict(torch.load("path_to_your_weights.pth"))
# model.to('cuda')# 然后,调用`visualize_gradcam`函数来查看结果。
# visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])

 Grad-CAM++

Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks (arxiv.org) 

相关文章:

《python深度学习》笔记(二十):神经网络的解释方法之CAM、Grad-CAM、Grad-CAM++、LayerCAM

原理优点缺点GAP将多维特征映射降维为一个固定长度的特征向量①减少了模型的参数量;②保留更多的空间位置信息;③可并行计算,计算效率高;④具有一定程度的不变性①可能导致信息的损失;②忽略不同尺度的空间信息CAM利用…...

Python中文件copy模块shutil

高级的 文件、文件夹、压缩包 处理模块 shutil.copyfileobj(fsrc, fdst[, length])将文件内容拷贝到另一个文件中 import shutil shutil.copyfileobj(open(old.xml,r), open(new.xml, w)) shutil.copyfile(src, dst)拷贝文件 shutil.copyfile(f1.log, f2.log) #目标文件无需…...

机器学习快速入门教程 Scikit-Learn实现

机器学习是什么? 机器学习是一帮计算机科学家想让计算机像人一样思考所研发出来的计算机理论。他们曾经说过,人和计算机其实本没有差别,同样都是一大批互相连接的信息传递和存储元素所组成的系统。所以有了这样的想法,加上他们得天独厚的数学功底,机器学习的前身也就孕育而生…...

【向生活低头】win7打印机共享给win11使用,win11无法连接问题的解决

打印机是跟win7的电脑连接的,然后试了很多方法,win11都没法添加该打印机去使用。 网上的方法乱七八糟啥都有,但试了以后,发现基本没什么用。 刚刚发现知乎上的一个回答是有用的,这里做记录以备后用。 1.打开控制面板的…...

HarmonyOS鸿蒙原生应用开发设计- 元服务(原子化服务)图标

HarmonyOS设计文档中,为大家提供了独特的元服务图标,开发者可以根据需要直接引用。 开发者直接使用官方提供的元服务图标内容,既可以符合HarmonyOS原生应用的开发上架运营规范,又可以防止使用别人的元服务图标侵权意外情况等&…...

rhcsa-vim

命令行的三种模式 将ets下的passwd文件复制到普通用户下面 编辑模式的快捷方式 a--光标后插入 A--行尾插入 o--光标所在上一行插入 O--光标所在上一行插入 i--光标前插入 I--行首插入 s--删除光标所在位然后进行插入模式 S--删除光标所在行然后进行插入 命令模式的快捷…...

Rocky9 上安装 redis-dump 和redis-load 命令

一、安装依赖环境 1、依赖包 dnf -y install perl gcc gcc-c zlib-devel2、编译openssl 1.X ### 下载编译 wget https://www.openssl.org/source/openssl-1.1.1t.tar.gz tar xf openssl-1.1.1t.tar.gz cd openssl-1.1.1t ./config --prefix/usr/local/openssl make make ins…...

Azure机器学习 - 使用与Azure集成的Visual Studio Code实战教程

本文介绍如何启动远程连接到 Azure 机器学习计算实例的 Visual Studio Code。 借助 Azure 机器学习资源的强大功能,使用 VS Code 作为集成开发环境 (IDE)。 在VS Code中将计算实例设置为远程 Jupyter Notebook 服务器。 关注TechLead,分享AI全维度知识。…...

内网渗透-域信息收集

域环境 虚拟机应用:vmware17 域控主机:win2008 2r 域成员主机:win2008 2r win7 一.域用户和本地用户区别 使用本地用户安装程序时,可以直接安装 使用域用户安装程序时,需要输入域控管理员的账号密码才能安装。总结…...

三国志14信息查询小程序(历史武将信息一览)制作更新过程02-基本架构

0,前期准备 (1)一台有公网IP的云服务器,服务器上安装MySQL数据库,启用IIS服务。出入端口号配置运行(服务器和平台都要配置),IIS服务器上安装SSL证书 (2)域名…...

【51单片机】LED与独立按键(学习笔记)

一、点亮一个LED 1、LED介绍 LED:发光二极管 补:电阻读数 102 > 10 00 1k 473 > 47 000 2、Keil的使用 1、新建工程:Project > New Project Ctrl Shift N :新建文件夹 2、选型号:Atmel-AT89C52 3、xxx…...

package.json(2)

发布配置 和npm 项目包发布相关的配置。 private private 字段可以防止我们意外地将私有库发布到 npm 服务器。只需要将该字段设置为 true: "private": true preferGlobal preferGlobal 字段表示当用户不把该模块安装为全局模块时,如果设…...

Docker(2)——Docker镜像的基本命令

目录 一、简介 二、基本命令 1. Docker命令官方文档 2. 展示镜像 3. 搜索镜像 4. 下载镜像 5. 删除镜像 一、简介 本篇文章是Docker专栏的第二章,主要用于介绍Docker镜像的一些基本命令 二、基本命令 1. Docker命令官方文档 本篇博客仅记录常用的Docker镜…...

IT技术发展背景下的就业趋势:哪个领域最受欢迎?

IT技术发展背景下的就业趋势:哪个领域最受欢迎? 随着科技的不断进步和互联网的普及,IT行业正以惊人的速度蓬勃发展。在这个数字化时代,IT技术已经渗透到各个行业和领域中,为人们带来了巨大的便利和机遇。那么&#xf…...

日本移动支付Merpay QA团队的自动化现状

Merpay是日本最大的网购平台之一Mercari的无现金支付系统。Merpay 的主要功能是让用户在 Mercari的网站上购物,也可以在日本的许多实体店和餐厅使用它,也可以理解为日本的“支付宝”。以下为Merpay QA 团队在自动化方面的一些思考: 这几年&am…...

EasyExcel复杂表头数据导入

目录 表头示例导入代码数据导出 表头示例 导入代码 Overridepublic void importExcel(InputStream inputStream) {ItemExcelListener itemExcelListener new ItemExcelListener();EasyExcel.read(inputStream, ImportItem.class, itemExcelListener).headRowNumber(2).sheet()…...

【Redis】Redis安装教程基本操作语法

【Redis】Redis安装教程&基本操作语法 一、Redis简介1.1.什么是Redis1.2.Redis与传统数据库的区别主要 二、Linux安装Redis2.1.安装Redis2.2.解压安装包2.3.解压后执行安装gcc2.4.编译Redis2.5.修改Redis为守护进程2.6.启动Redis服务2.7.配置密码且外部连接2.8.重启服务器2…...

spring-boot-autoconfigure.jar/META-INF/spring.factories介绍

spring-boot-autoconfigure.jar/META-INF/spring.factories是Spring Boot自动配置的核心文件,它包含了各种自动配置类的注册信息。这个文件是Spring Boot根据应用程序的依赖关系和配置文件中的条件注解,自动加载和配置所需的Bean的依据。 在spring.fact…...

vue3视频大小适配浏览器窗口大小

目标:按浏览器窗口的大小,平铺视频,来适配屏幕的大小。 考虑使用 DPlayer.js、video.js、vue-video-player等视频插件,但报了各种各样的错;试过使用 js 对视频进行同比例放大,再判断其与窗口的大小取最小值…...

Nignx安装负载均衡动静分离以及Linux前端项目部署将域名映射到特定IP地址

目录 一、nginx简介 1.1 定义 1.2 背景 1.3 作用 二、nginx搭载负载均衡提供前后分离后台接口数据 2.1 nginx安装 2.1.1 下载依赖 2.1.2 下载并解压安装包 2.1.3 安装nginx 2.1.4 启动nginx服务 2.2 tomcat负载均衡 2.2.1 负载均衡所需服务器准备 2.2.2 配置修改 …...

Plist编辑软件 PlistEdit Pro mac中文版功能介绍

PlistEdit Pro mac是一款功能强大的Plist文件编辑软件。Plist文件是苹果公司开发的一种XML文件格式,用于存储应用程序的配置信息和数据。PlistEdit Pro可以帮助用户轻松地编辑和管理Plist文件。 PlistEdit Pro具有直观的用户界面和丰富的功能。用户可以使用该软件打…...

CSS3网页布局基础

CSS布局始于第2个版本,CSS 2.1把布局分为3种模型:常规流、浮动、绝对定位。CSS 3推出更多布局方案:多列布局、弹性盒、模板层、网格定位、网格层、浮动盒等。本章重点介绍CSS 2.1标准的3种布局模型,它们获得所有浏览器的全面、一致…...

【npm run dev 报错:error:0308010C:digital envelope routines::unsupported】

问题原因: nodejs版本太高(nodejs v17版本发布了openSSL3.0对短发和密钥大小增加了更为严格的限制,nodejs v17之前版本没有影响,但之后的版本会出现这个错误,物品的node版本是20.9.0) 解决方式&#xff1…...

Vue3.0 this,ref , $parent,$root组件通信 :VCA

1...

天猫商品评论API接口(评论内容|日期|买家昵称|追评内容|评论图片|评论视频..)

要获取天猫商品评论接口,您需要使用天猫开放平台提供的API接口。以下是一些可能有用的步骤: 注册并登录天猫开放平台,获取开发者账号。在开发者中心创建一个应用,获取应用的App Key和App Secret。使用天猫开放平台的API接口&…...

redis数据库简介

Redis是什么 Redis是现在最受欢迎的NoSQL数据库之一,Redis是一个使用ANSI C编写的开源、包含多种数据结构、支持网络、基于内存、可选持久性的键值对存储数据库,其具备如下特性: 基于内存运行,性能高效支持分布式,理…...

数据结构 - ArrayList - 动态修改的数组

目录 实现一个通用的顺序表 总结 包装类 装箱 / 装包 和 拆箱 / 拆包 ArrayList 与 顺序表 ArrayList基础功能演示 add 和 addAll ,添加元素功能 ArrayList的扩容机制 来看一下,下面的代码是否存在缺陷 模拟实现 ArrayList add 功能 add ind…...

python爬虫实战——今日头条新闻数据获取

大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 如果有什么疑惑/资料需要的可以点击文章末尾名片领取源码 第三方库: requests >>> pip install requests 第三方模块安装: win R 输入cmd 输入安装命令 pip install 模块名 (如果你觉得安装速度比较慢, 你…...

ardupilot开发 --- gdb 篇

环境 win11 vscode 1.81.0 wsl2 ardupilot 利用gdb工具在vsCode中实现 Ardupilot SITL的断点调试 优点:可在vsCode中实现断点调试。 参考文献:https://ardupilot.org/dev/docs/debugging-with-gdb-using-vscode.html 安装gdb工具 打开wsl&#xff0…...

在Vue项目中定义全局变量

在Vue项目中我们需要使用许多的变量来维护数据的流向和状态,这些变量可以是本地变量、组件变量、父子组件变量等,但这些变量都是有局限性的。在一些场景中,可能需要在多个组件中共享某个变量,此时全局变量就派上了用场。 定义全局…...