第58步 深度学习图像识别:Transformer可视化(Pytorch)
一、写在前面
(1)pytorch_grad_cam库
这一期补上基于基于Transformer框架可视化的教程和代码,使用的是pytorch_grad_cam库,以Bottleneck Transformer模型为例。
(2)算法分类
pytorch_grad_cam库中包含的主要方法有以下几种:
GradCAM: 这是最基本的方法。GradCAM(Gradient-weighted Class Activation Mapping)通过取网络最后一个卷积层的特征图,然后对这些特征图进行加权求和,得到类别激活图。加权的系数是网络最后一个卷积层特征图对应类别的梯度的全局平均池化值。
GradCAMPlusPlus: 这是在GradCAM的基础上的改进。GradCAM++不仅计算了类别相对于特征图的梯度,还计算了二阶和三阶导数。这使得GradCAM++在某些情况下可以获得更细粒度的解释。
ScoreCAM: ScoreCAM采用了不同的策略。它对于每个特征图都生成一个类似的激活图,并将所有这些激活图加权求和。权重是每个特征图对应的类别分数。
AblationCAM: AblationCAM是基于Ablation-based的方法。它首先对每个特征图进行遮挡(或移除),然后看类别得分如何改变。这些改变被用来生成类别激活图。
XGradCAM: 这是GradCAM的另一个扩展。XGradCAM考虑了激活和梯度之间的空间关系,以生成更详细的类别激活图。
EigenCAM: 它基于主成分分析 (PCA) 的方法,利用协方差矩阵的特征向量和特征值来表示激活图。
FullGrad: FullGrad是一个对输入,权重和偏差的特征重要性进行全局分解的方法。
以上方法都在解释深度学习模型的决策,可以帮助理解模型关注的区域和特征。在选择使用哪种方法时,可以根据需求和实验效果进行选择。
二、Transformer可视化实战
继续使用胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。
(a)Bottleneck Transformer建模
######################################导入包###################################
# 导入必要的包
import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as npwarnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 设置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")################################导入数据集#####################################
import torch
from torchvision import datasets, transforms
import os# 数据集路径
data_dir = "./MTB"# 图像的大小
img_height = 256
img_width = 256# 数据预处理
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(img_height),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((img_height, img_width)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 加载数据集
full_dataset = datasets.ImageFolder(data_dir)# 获取数据集的大小
full_size = len(full_dataset)
train_size = int(0.7 * full_size) # 假设训练集占80%
val_size = full_size - train_size # 验证集的大小# 随机分割数据集
torch.manual_seed(0) # 设置随机种子以确保结果可重复
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])# 将数据增强应用到训练集
train_dataset.dataset.transform = data_transforms['train']# 创建数据加载器
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes###############################定义模型################################
# 导入必要的库
import torch.nn as nn
import timm# 定义Bottleneck Transformer模型
model = timm.create_model('botnet26t_256', pretrained=True) # 你可以选择适合你需求的BotNet版本
num_ftrs = model.feature_info[-1]['num_chs']# 根据分类任务修改最后一层
model.head.fc = nn.Linear(num_ftrs, len(class_names))# 将模型移至指定设备
model = model.to(device)# 打印模型摘要
print(model)#############################编译模型#########################################
# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.Adam(model.parameters())# 定义学习率调度器
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)# 开始训练模型
num_epochs = 2# 初始化记录器
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':model.train() # 设置模型为训练模式else:model.eval() # 设置模型为评估模式running_loss = 0.0running_corrects = 0# 遍历数据for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 零参数梯度optimizer.zero_grad()# 前向with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 只在训练模式下进行反向和优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()# 记录每个epoch的loss和accuracyif phase == 'train':train_loss_history.append(epoch_loss)train_acc_history.append(epoch_acc)else:val_loss_history.append(epoch_loss)val_acc_history.append(epoch_acc)print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))print()# 保存模型
torch.save(model.state_dict(), 'botnet_dit_model.pth')
(b)使用GradCAM可视化
在跑之前,得先安装git;然后用git安装pytorch_grad_cam:
安装git容易,无脑输入:
conda install git
安装pytorch_grad_cam也不难:
git clone https://github.com/jacobgil/pytorch-grad-cam.git
cd pytorch-grad-cam
pip install .
然后码代码:
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import timm# 代码1中的函数
def myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens,size))if titles == False:titles="0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor,heatmap=False,shape=(256,256)):np_arr=tensor.detach().numpy()#[0]#对数据进行归一化if np_arr.max()>1 or np_arr.min()<0:np_arr=np_arr-np_arr.min()np_arr=np_arr/np_arr.max()#np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0]==1:# 如果是灰度图像,复制三个通道以创建一个RGB图像np_arr=np.concatenate([np_arr,np_arr,np_arr],axis=0)np_arr=np_arr.transpose((1,2,0))return np_arr# 加载模型
model = timm.create_model('botnet26t_256', pretrained=False)# 更改全连接层以匹配你的类别数
num_ftrs = model.head.fc.in_features
model.head.fc = nn.Linear(num_ftrs, 2) # 假设你的类别数为2model.load_state_dict(torch.load('botnet_dit_model.pth', map_location=device))# 模型转移到相应设备
model = model.to(device)# 你的图像路径
image_path = './MTB/Tuberculosis/Tuberculosis-203.png'# 加载图像
image = Image.open(image_path).convert("RGB")# 使用代码1中定义的图像转换
input_image = data_transforms['val'](image).unsqueeze(0).to(device)# 使用GradCAM
target_layer = model.stages[2][0].conv3_1x1.bn.drop
with GradCAM(model=model, target_layers=[target_layer], use_cuda=torch.cuda.is_available()) as cam:target = [ClassifierOutputTarget(1)] # 修改为你的目标类别grayscale_cam = cam(input_tensor=input_image, targets=target)#将热力图结果与原图进行融合rgb_img=tensor2img(input_image.cpu().squeeze())visualization = show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
myimshows([rgb_img, grayscale_cam[0], visualization],["image","cam","image + cam"])
结果输出如下:

红色区域就是模型认为的“可疑区域”,也就是说模型根据这些区域判断它是Tuberculosis的主要依据。
几个注意事项:
(a)问:代码:‘target = [ClassifierOutputTarget(0)] # 修改为你的目标类别’,这个怎么解释?此外,0和1分别代表什么呢?
答:第一小问:一般来说,ClassifierOutputTarget(0)中的0代表的是你希望将注意力图(CAM)生成针对的类别标签。例如,如果你的两个类别是猫和狗,且在训练数据集中猫的标签是0,狗的标签是1,那么ClassifierOutputTarget(0)将生成猫的注意力图,而ClassifierOutputTarget(1)将生成狗的注意力图。
第二小问:在 PyTorch 中,使用 ImageFolder 函数或类似的数据加载器加载数据时,类别名称列表(class_names)的顺序将决定了类别标签的分配。这意味着类别名称列表的索引将作为类别的标签。在我们的例子中,class_names = ['Normal', 'Tuberculosis'],"Normal" 的索引是 0,所以它的标签是 0;"Tuberculosis" 的索引是 1,所以它的标签是 1。所以ClassifierOutputTarget(0) 将生成"Normal"类别的注意力图,ClassifierOutputTarget(1) 将生成"Tuberculosis"类别的注意力图。
(b)问:代码:‘target_layer = model.stages[2][0].conv3_1x1.conv’,如何选择输出的层?怎么知道模型中有哪些层?
答:第一小问:一般来说,卷积层或者重复结构的最后一层(如 ResNet 中的每个残差块的最后一层)是可行的目标层,因为这些层能保留空间信息,而全连接层则不行,因为它们不再保留空间信息。
第二小问:通过下面代码打印出模型中所有层次的名称:
#打印出模型中所有层次的名称
for name, module in model.named_modules():
print(name)
输出如下:

或者打印出模型的顶层子模块:
#打印模型的顶层子模块
for name, module in model.named_children():print(name)
输出就四个:
stem
stages
final_conv
head
接下来,展示几个层的写法,大家自行体会:
stem.conv2.conv :target_layer = model.stem.conv2.conv
stages.3.1.conv1_1x1:target_layer = model.stages[3][1].conv1_1x1
final_conv:target_layer = model.final_conv
应该找到规律了吧,不详细解释了。每一层输出是不一样的,例如上面三层输出依次如下:



(c)问:如何改用其他7种方法来替代GradCAM?
答:很简单,来到这个代码段:
with GradCAM(model=model, target_layers=[target_layer], use_cuda=torch.cuda.is_available()) as cam:target = [ClassifierOutputTarget(0)] # 修改为你的目标类别grayscale_cam = cam(input_tensor=input_image, targets=target)#将热力图结果与原图进行融合rgb_img=tensor2img(input_image.cpu().squeeze())visualization = show_cam_on_image(rgb_img, grayscale_cam[0], use_rgb=True)
myimshows([rgb_img, grayscale_cam[0], visualization],["image","cam","image + cam"])
只需要把GradCAM分别换成GradCAMPlusPlus、ScoreCAM、AblationCAM、XGradCAM、EigenCAM以及FullGrad即可,简单粗暴。
三、写在后面
除了Transformer,pytorch_grad_cam库也可以用在之前提到的CNN的模型上,大家可自行探索哈。
四、数据
链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf
提取码:x3jf
相关文章:
第58步 深度学习图像识别:Transformer可视化(Pytorch)
一、写在前面 (1)pytorch_grad_cam库 这一期补上基于基于Transformer框架可视化的教程和代码,使用的是pytorch_grad_cam库,以Bottleneck Transformer模型为例。 (2)算法分类 pytorch_grad_cam库中包含的…...
angular实现全局组件
之前我们实现全局组件的第一种方式。我们是在定义了组件的时候通过在declares:[component],然后exports出该组件。最后在页面中每次导入该组件,而这次我们将采用另一种方式来实现 1 新建公用组件: navbreadcrumbnavbreadcrumb.component.htmlnavbreadc…...
Spring编程模型(范式)
面向对象编程 契约接口:Aware aware:意识到的 契约接口(Aware)是Spring框架中的一个特性,它允许Bean对象意识到它们所在的环境并与之进行交互,用于提供特定的功能或信息给Bean对象。这些接口通常作为回调接口,在Bean初始化过程…...
Golang GORM 单表删除
删除只有一个操作,delete。也是先找到再去删除。 可以删除单条记录,也可以删除多条记录。 var s Studentdb.Debug().Delete(&s, "age ?", 100)fmt.Println(s)[15.878ms] [rows:1] DELETE FROM student WHERE age 100var s Studentdb.De…...
Windows 下 MySQL 源码学习环境搭建步骤【建议收藏】
【建议收藏】Windows 下如何安装最新版 MySQL 源码学习的调试环境步骤。 作者:芬达 《芬达的数据库学习笔记》公众号作者,开源爱好者,擅长 MySQL、ansible。 本文来源:原创投稿 爱可生开源社区出品,原创内容未经授权不…...
redis总复习
springboot基于redisson实现看门狗锁:Springboot基于Redisson实现Redis分布式可重入锁【案例到源码分析】_springboot redission lock_AP0906424的博客-CSDN博客 springboot基于redis实现设置缓存和过期时间的代码?包括key的设计 https://mbd.baidu.com/ug_share…...
[LeetCode - Python]844. 比较;含退格的字符串(Easy);415. 字符串相加(Easy)
1.题目 844. 比较含退格的字符串(Easy) 1.代码: class Solution:def backspaceCompare(self, s: str, t: str) -> bool:# 暴力法s list(s)t list(t)M 0N 0for i in range(len(s)):i -M if s[i] # :if i > 0 :s.pop(i)s.pop(i-…...
机器学习深度学习——NLP实战(自然语言推断——注意力机制实现)
👨🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——NLP实战(自然语言推断——数据集) 📚订阅专栏:机器学习&…...
mac垃圾清理软件有哪些
随着使用时间的增加,mac系统会产生一些垃圾文件,影响系统的性能和稳定性。为了保持mac系统的高效,用户需要定期使用mac垃圾清理软件来清理系统缓存、日志、语言包等无用文件。CleanMyMac是一款功能强大的mac垃圾清理软件,它可以帮…...
8.18 校招 内推 面经
绿泡泡: neituijunsir 交流裙,内推/实习/校招汇总表格 1、校招 | 小米集团2024届全球校园招聘正式启动(内推) 校招 | 小米集团2024届全球校园招聘正式启动(内推) 2、2023校招总结--软件测试岗位 - 2 2…...
docker的web管理平台docker.ui
docker.ui安装 docker run --name docker.ui \ -p 8999:8999 \ --restartalways \ -v /var/run/docker.sock:/var/run/docker.sock \ -d joinsunsoft/docker.ui参数说明: docker run:启动container–name:容器命名–restartalwaysÿ…...
20230822 Windows上使用find_package引入OpenCV报错
报错信息 打开Cmake项目时,find_package 报错: Found OpenCV Windows Pack but it has no binaries compatible with yourconfiguration.You should manually point CMake variable OpenCV_DIR to your build of OpenCVlibrary.原因 大概率原项目是在 …...
MySQL下载安装配置
天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…...
3D WEB轻量化引擎HOOPS产品助力NAPA打造船舶设计软件平台
NAPA(Naval Architectural PAckage,船舶建筑包),来自芬兰的船舶设计软件供应商,致力于提供世界领先的船舶设计、安全及运营的解决方案和数据分析服务。NAPA拥有超过30年的船舶设计经验,年营业额超过2560万欧…...
lesson9: C++多线程
1.线程库 1.1 thread类的简单介绍 C11 中引入了对 线程的支持 了,使得 C 在 并行编程时 不需要依赖第三方库 而且在原子操作中还引入了 原子类 的概念。要使用标准库中的线程,必须包含 < thread > 头文件 函数名 功能 thread() 构造一个线程对象…...
安卓修改SwitchCompat色值
SwitchCompat控件色值跟系统设置的主题有关,但是主题效果不是能轻易就能改的,因为涉及到整个APP的样式。网上方案基本都是通过修改style文件来改变色值,经过多次尝试修改最终觉得单独修改控件色值比较好。 一、控件属性 //修改开关色值就是最…...
pytorch内存泄漏
问题描述: 内存泄漏积累过多最终会导致内存溢出,当内存占用过大,进程会被killed掉。 解决过程: 在代码的运行阶段输出内存占用量,观察在哪一块存在内存剧烈增加或者显存异常变化的情况。但是在这个过程中要分级确认…...
20230821-字符串相乘-给树命名(unordered_map)
字符串相乘 有两个非负整数字符串num1,num2,计算num1和num2所表达整数的乘积,结果以字符串形式存储。注意:不能通过强制转换方法解题。 示例1: 输入: "4", "3" 输出: "12" …...
[Go版]算法通关村第十二关黄金——字符串冲刺题
目录 题目:最长公共前缀解法1:纵向对比-循环内套循环写法复杂度:时间复杂度 O ( n ∗ m ) O(n*m) O(n∗m)、空间复杂度 O ( 1 ) O(1) O(1)Go代码 解法2:横向对比-两两对比(类似合并K个数组、合并K个链表)复…...
neovim为工作区添加本地clangd配置
1 背景 尝试使用neovim开发stm32,使用clangd作为LSP提供代码补全等功能。 2 思路 使用stm32cubeMX生成一个基于makefile的stm32工程。 使用bear或compiledb基于makefile生成compile_commands.json文件。 为clangd配置--query-driver选项,使其使用arm…...
【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15
缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下: struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...
基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销,平衡网络负载,延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...
大型活动交通拥堵治理的视觉算法应用
大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...
el-switch文字内置
el-switch文字内置 效果 vue <div style"color:#ffffff;font-size:14px;float:left;margin-bottom:5px;margin-right:5px;">自动加载</div> <el-switch v-model"value" active-color"#3E99FB" inactive-color"#DCDFE6"…...
【android bluetooth 框架分析 04】【bt-framework 层详解 1】【BluetoothProperties介绍】
1. BluetoothProperties介绍 libsysprop/srcs/android/sysprop/BluetoothProperties.sysprop BluetoothProperties.sysprop 是 Android AOSP 中的一种 系统属性定义文件(System Property Definition File),用于声明和管理 Bluetooth 模块相…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)
宇树机器人多姿态起立控制强化学习框架论文解析 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一) 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...
【论文阅读28】-CNN-BiLSTM-Attention-(2024)
本文把滑坡位移序列拆开、筛优质因子,再用 CNN-BiLSTM-Attention 来动态预测每个子序列,最后重构出总位移,预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵(S…...
AspectJ 在 Android 中的完整使用指南
一、环境配置(Gradle 7.0 适配) 1. 项目级 build.gradle // 注意:沪江插件已停更,推荐官方兼容方案 buildscript {dependencies {classpath org.aspectj:aspectjtools:1.9.9.1 // AspectJ 工具} } 2. 模块级 build.gradle plu…...
Selenium常用函数介绍
目录 一,元素定位 1.1 cssSeector 1.2 xpath 二,操作测试对象 三,窗口 3.1 案例 3.2 窗口切换 3.3 窗口大小 3.4 屏幕截图 3.5 关闭窗口 四,弹窗 五,等待 六,导航 七,文件上传 …...
NPOI Excel用OLE对象的形式插入文件附件以及插入图片
static void Main(string[] args) {XlsWithObjData();Console.WriteLine("输出完成"); }static void XlsWithObjData() {// 创建工作簿和单元格,只有HSSFWorkbook,XSSFWorkbook不可以HSSFWorkbook workbook new HSSFWorkbook();HSSFSheet sheet (HSSFSheet)workboo…...
