当前位置: 首页 > news >正文

CAM类激活映射 |神经网络可视化 | 热力图

文章目录

    • 前言:
    • 安装库:
    • 分类案例--ResNet50
    • 分割案例
      • AttributeError: ‘tuple‘ object has no attribute ‘cpu‘
      • RuntimeError: grad can be implicitly created only for scalar outputs
      • TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
      • 完整代码

前言:

本篇文章只是教程,不涉及原理,感兴趣可以自行搜索
如图,热力图可以很好的反映出网络究竟注意图片的哪一部分
在这里插入图片描述
github官方教程:
https://github.com/jacobgil/pytorch-grad-cam
参考博客:
https://blog.csdn.net/u014264373/article/details/85415921
https://blog.csdn.net/u014264373/article/details/116302678
但还是遇到了很多报错,解决过程记录如下:

安装库:

pip install grad-cam

分类案例–ResNet50

案例图片:
在这里插入图片描述

案例代码:
这个代码是可以跑通的,将图片保存到你本地,然后设置好路径即可。
(需要下载ResNet预训练模型)

from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
import torchvision
import torch
from matplotlib import pyplot as plt
import numpy as npdef myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens, size))if titles == False:titles = "0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor, heatmap=False, shape=(224, 224)):np_arr = tensor.detach().numpy()  # [0]# 对数据进行归一化if np_arr.max() > 1 or np_arr.min() < 0:np_arr = np_arr - np_arr.min()np_arr = np_arr / np_arr.max()# np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0] == 1:np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)np_arr = np_arr.transpose((1, 2, 0))return np_arrpath = "../examples/both.png"
bin_data = torchvision.io.read_file(path)  # 加载二进制数据
img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片
img = img.unsqueeze(0)  # 变成BCHW的数据,B==1; squeeze
input_tensor = torchvision.transforms.functional.resize(img, [224, 224])model = resnet50(pretrained=True)
target_layers = [model.layer4[-1]]  # 如果传入多个layer,cam输出结果将会取均值# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:# targets = [ClassifierOutputTarget(386), ClassifierOutputTarget(386)]  # 指定查看class_num为386的热力图targets = None  # 选定目标类别,如果不设置,则默认为分数最高的那一类# aug_smooth=True, eigen_smooth=True 使用图像增强是热力图变得更加平滑grayscale_cams = cam(input_tensor=input_tensor, targets=targets)  # targets=None 自动调用概率最大的类别显示for grayscale_cam, tensor in zip(grayscale_cams, input_tensor):# 将热力图结果与原图进行融合rgb_img = tensor2img(tensor)visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])

最后出来的结果应该就是这样一张图。
在这里插入图片描述

分割案例

如果上面的代码你跑通了
那么如何为自己的网络生成热力图呢?
有几个需要注意的点:(最后会附上完整代码)

首先,切换成你的网络了模型加载就不说了,这个自己搞好。
然后,你的网络是否是在gpu上跑的,如果是
输入数据要放gpu上

path = './test_img/yu.jpg'
bin_data = torchvision.io.read_file(path)  # 加载二进制数据
img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片
img = img.unsqueeze(0)  # 变成BCHW的数据,B==1 squeeze
img_tensor = torchvision.transforms.functional.resize(img, [352, 352])
img_tensor = img_tensor.cuda()   # 加一句这个

然后按照上面的代码,修改这一句,改成你要查看的层:

target_layers = [model.layer4[-1]]  # 如果传入多个layer,cam输出结果将会取均值

把这个改成你要的层,然后运行一下,可能会遇到报错:

AttributeError: ‘tuple‘ object has no attribute ‘cpu‘

如果出现这个报错,可以看下你的网络最终输出是几个特征。因为是自己写的网络,有的因为训练需要,最终返回的是多个结果。

print(model(x))

如果有多个结果,会被变成一个元组。后面需要转cpu,元组tuple没有.cpu的方法,所以报错。
解决方法:
先把你的网络包装一下,你返回了多个值,选择有用的那一个就行
我这里选择了多个输出的第一个,自己视情况而定

class SegmentationModelOutputWrapper(torch.nn.Module):def __init__(self, model):super(SegmentationModelOutputWrapper, self).__init__()self.model = modeldef forward(self, x):return self.model(x)[0]  # 我这里选择了多个输出的第一个,自己视情况而定model = NetWork()
model.load_state_dict(torch.load(opt.snap_path))
# 网络加载后先包装下  修改输出
model = SegmentationModelOutputWrapper(model)

然后再运行,可能会出现报错:

RuntimeError: grad can be implicitly created only for scalar outputs

这个问题的解决办法是:
你需要去到base_cam.py这个库文件里面去
第85行有一句loss.backward(retain_graph = True)
将其修改为loss.backward(torch.ones_like(loss),retain_graph=True)

在这里插入图片描述

参考链接:https://blog.csdn.net/weixin_44390884/article/details/127893163

还有一个报错:

TypeError: can’t convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

如果你的模型最后返回的特征是tensor的特征,那么需要对tensor2img做改动
在这里插入图片描述

np_arr = tensor.detach().numpy()  # [0]

修改为:

np_arr = tensor.cpu().detach().numpy()  # [0]

完整代码

import os
import torch
import argparse
import numpy as np
import imageio
import torchvision
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from matplotlib import pyplot as pltdef myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens, size))if titles == False:titles = "0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor, heatmap=False, shape=(224, 224)):np_arr = tensor.cpu().detach().numpy()  # [0]if np_arr.max() > 1 or np_arr.min() < 0:np_arr = np_arr - np_arr.min()np_arr = np_arr / np_arr.max()# np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0] == 1:np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)np_arr = np_arr.transpose((1, 2, 0))return np_arrif __name__ == '__main__':model = NetWork()model.load_state_dict(torch.load(opt.snap_path))# torchinfo.summary(model=model,input_size=(8, 3, 352, 352))# 包装下  修改输出model = SegmentationModelOutputWrapper(model)model.eval()path = './test_img/yu.jpg'bin_data = torchvision.io.read_file(path)  # 加载二进制数据img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片img = img.unsqueeze(0)  # 变成BCHW的数据,B==1 squeezeimg_tensor = torchvision.transforms.functional.resize(img, [352, 352])img_tensor = img_tensor.cuda()target_layers = [model.model.ncd]targets = Nonewith GradCAM(model=model,target_layers=target_layers,use_cuda=True) as cam:grayscale_cams = cam(input_tensor=img_tensor,targets=targets,aug_smooth=True)# cam_image = show_cam_on_image(img_rgb, grayscale_cam, use_rgb=True)for grayscale_cam, tensor in zip(grayscale_cams, img_tensor):# 将热力图结果与原图进行融合rgb_img = tensor2img(tensor)visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])

相关文章:

CAM类激活映射 |神经网络可视化 | 热力图

文章目录前言&#xff1a;安装库&#xff1a;分类案例--ResNet50分割案例AttributeError: ‘tuple‘ object has no attribute ‘cpu‘RuntimeError: grad can be implicitly created only for scalar outputsTypeError: cant convert cuda:0 device type tensor to numpy. Use…...

RecyclerView+BaseRecyclerViewAdapterHelper显示不全只显示第一行item的解决问题

RecyclerViewBaseRecyclerViewAdapterHelper显示不全只显示第一行item&#xff0c;我懵了…&#xff0c;我不说多&#xff0c;直接说吧 先看一下适配器代码中的convert()方法&#xff1a; class MineRadioAdapter(layoutResId: Int R.layout.item_my_live) :BaseQuickAdapte…...

解决后端无法对前端的ajax请求重定向

本章目录&#xff1a; 问题描述 AJAX请求后端直接重定向失败解决方案 后端拦截请为响应头添加重定向标志后端拦截器为响应头添加重定向路径前端响应拦截器获取响应头数据&#xff0c;并通过location.href url 完成页面跳转一、问题描述 本来想在拦截器里设置未登录用户访问指…...

【Python】1分钟就能制作精美的框架图?太棒啦

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录前言一、准备二、基本使用与例子1.初始化与导出2.节点类型3.集群块4.自定义线的颜色与属性总结前言 Diagrams 是一个基于Python绘制云系统架构的模块&#xff0c;它能…...

淘宝必备的补单技巧及注意事项!

补单&#xff0c;是优化善后的s单。单只是模拟用户的购物习惯&#xff0c;而补单同时还要模拟整个店铺的综合数据&#xff0c;包括点击率、转化率等等&#xff0c;补到略高于同行、竞品的平均数据时&#xff0c;淘宝会判断为买家比较喜欢你的商品&#xff0c;从而给你更多推荐机…...

【实用篇】SpringCloud+RabbitMQ+Docker+Redis+搜索+分布式,系统详解springcloud分布式

文章目录一、服务拆分1.1 服务拆分Demo1.2 微服务远程调用二、Eureka2.1 Eureka原理2.2 Eureka-server服务搭建2.3 eureka-client服务注册2.4 eureka-client服务复制2.5 eureka服务发现三、Ribbon负载均衡3.1 负载均衡原理3.2 负载均衡策略3.3 自定义负载均衡策略3.4 饥饿加载与…...

私人飞机、公务机包机会成为富豪圈的主流出行方式吗?

从炫耀性消费到按需使用&#xff0c;私人飞机的消费群体正在被拓宽&#xff0c;但离“成为主流”还有一段距离。“时间就是金钱”为有钱人消费私人飞机提供合理动机&#xff0c;而这群高净值人群的数量增长则成为撑起市场基本面。据相关数据显示&#xff0c;2018年全球超级富豪…...

Oracle组织架构

组织架构 &#xff08;一&#xff09;业务组&#xff08;BG&#xff09; &#xff08;二&#xff09;法律实体&#xff08;LE&#xff09; &#xff08;三&#xff09;业务实体&#xff08;OU&#xff09; &#xff08;四&#xff09;库存组织&#xff08;INV&#xff09; …...

最小公倍数

目录 最小公倍数 程序设计 程序分析 最小公倍数 【问题描述】给定两个正整数,计算这两个数的最小公倍数。 【输入形式】输入包含多组测试数据,每组只有一行,包括两个不大于1000的正整数. 【输出形式】 对于每个测试用例,给出这两个数的最小公倍数,每个实例输出一行。…...

二叉树的后序遍历(力扣145)

目录 题目描述&#xff1a; 解法一&#xff1a;递归法 解法二&#xff1a;迭代法 解法三&#xff1a;Morris遍历 二叉树的后序遍历 题目描述&#xff1a; 给你一棵二叉树的根节点 root &#xff0c;返回其节点值的 后序遍历 。 示例 1&#xff1a; 输入&#xff1a;root …...

《Effective C++》读书纪实 -- 诸君同享

文章目录《Effective C》是一本经典的C编程指南&#xff0c;共包含50条C编程的最佳实践。 确定你的构造函数的行为 在构造函数中&#xff0c;应该尽可能地避免调用虚函数、非静态成员函数和虚基类的函数。 尽量使用const、enum、inline替换#define 使用const、enum、inline可以…...

【云原生】K8S-ConfigMap 实现应用和配置分离

文章目录前言ConfigMap 背景ConfigMap 创建方式ConfigMap 的使用使用 ConfigMap 的注意事项总结前言 Kubernetes 是目前最流行的容器编排系统之一&#xff0c;它提供了丰富的功能来支持容器化应用程序的管理和部署。 ConfigMap 是 Kubernetes 中重要的资源对象&#xff0c;用…...

java -测距工具(经纬度)

代码 /*** 测距工具* author qb*/ public class DistanceUtils {/*** 赤道半径*/private static final double EARTH_RADIUS 6378.137;private static double rad(double d) {return d * Math.PI / 180.0;}/*** Description : 通过经纬度获取距离(单位&#xff1a;米)* Group…...

postgres分区表的创建-基于继承

参考文档&#xff1a; http://postgres.cn/docs/12/ddl-partitioning.html 创建基于继承的分区表的步骤 1 创建父表 2 创建子表&#xff0c;从父表继承过来 3 创建函数及触发器&#xff0c;使插入的数据根据规则&#xff0c;插入到对应的子表中 -- 创建父表 CREATE TABLE a…...

Docker应用部署

文章目录Docker 应用部署一、部署MySQL二、部署Tomcat三、部署Nginx四、部署RedisDocker 应用部署 一、部署MySQL 搜索mysql镜像 docker search mysql拉取mysql镜像 docker pull mysql:5.6创建容器&#xff0c;设置端口映射、目录映射 # 在/root目录下创建mysql目录用于存…...

使用golang实现日志收集系统的logagent

整体架构 参考 七米老师的日志收集项目 主要用go实现logagent的部分&#xff0c;logagent的作用主要是实时监控日志追加的变化&#xff0c;并将变化发送到kafka中。 之前我们已经实现了 用go连接kafka并向其中发送数据&#xff0c;也实现了使用tail库监控日志追加操作。 我们…...

小红书点赞不显示怎么回事?小红书笔记评论被吞怎么办

小红书作为一个互联网产品&#xff0c;是一个软件。既然是软件就会有一定的程序漏洞&#xff0c;这是无法避免的。但是很多时候其实并不一定是漏洞的问题。今天就来和大家谈谈小红书点赞不显示怎么回事&#xff0c;小红书评论被吞又是怎么一回事&#xff0c;这些难道都是程序性…...

地址变换和缺页置换习题

1.设某进程页面的访问序列为4,3,2,1,4,3,5,4,3&#xff0c;2,1,5&#xff0c;当分配给该进程的内存页框数分别为3和4时&#xff0c;对于先进先出&#xff0c;最近最少使用&#xff0c;最佳页面置换算法&#xff0c;分别发生多少次缺页中断&#xff1f; 答&#xff1a; 分配的…...

PAT 乙级 1010 一元多项式求导(解题思路+AC代码)

题目&#xff1a; 设计函数求一元多项式的导数。&#xff08;注&#xff1a;xn&#xff08;n为整数&#xff09;的一阶导数为nxn−1。&#xff09; 输入格式: 以指数递降方式输入多项式非零项系数和指数&#xff08;绝对值均为不超过 1000 的整数&#xff09;。数字间以空格分…...

一维河流污染持续排放模拟(水污染扩散)

一、处理河道转换为geojson数据 以淮河为例处理示例数据&#xff1a; {"type": "FeatureCollection","features": [{"geometry": {"coordinates": [[[115.5803,34.4982],[115.5922,34.498],[115.6061,34.4994],[115.6203,…...

UDP(Echoserver)

网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法&#xff1a;netstat [选项] 功能&#xff1a;查看网络状态 常用选项&#xff1a; n 拒绝显示别名&#…...

【决胜公务员考试】求职OMG——见面课测验1

2025最新版&#xff01;&#xff01;&#xff01;6.8截至答题&#xff0c;大家注意呀&#xff01; 博主码字不易点个关注吧,祝期末顺利~~ 1.单选题(2分) 下列说法错误的是:&#xff08; B &#xff09; A.选调生属于公务员系统 B.公务员属于事业编 C.选调生有基层锻炼的要求 D…...

Device Mapper 机制

Device Mapper 机制详解 Device Mapper&#xff08;简称 DM&#xff09;是 Linux 内核中的一套通用块设备映射框架&#xff0c;为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程&#xff0c;并配以详细的…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”

2025年#高考 将在近日拉开帷幕&#xff0c;#AI 监考一度冲上热搜。当AI深度融入高考&#xff0c;#时间同步 不再是辅助功能&#xff0c;而是决定AI监考系统成败的“生命线”。 AI亮相2025高考&#xff0c;40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕&#xff0c;江西、…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

Mobile ALOHA全身模仿学习

一、题目 Mobile ALOHA&#xff1a;通过低成本全身远程操作学习双手移动操作 传统模仿学习&#xff08;Imitation Learning&#xff09;缺点&#xff1a;聚焦与桌面操作&#xff0c;缺乏通用任务所需的移动性和灵活性 本论文优点&#xff1a;&#xff08;1&#xff09;在ALOHA…...

GruntJS-前端自动化任务运行器从入门到实战

Grunt 完全指南&#xff1a;从入门到实战 一、Grunt 是什么&#xff1f; Grunt是一个基于 Node.js 的前端自动化任务运行器&#xff0c;主要用于自动化执行项目开发中重复性高的任务&#xff0c;例如文件压缩、代码编译、语法检查、单元测试、文件合并等。通过配置简洁的任务…...

BLEU评分:机器翻译质量评估的黄金标准

BLEU评分&#xff1a;机器翻译质量评估的黄金标准 1. 引言 在自然语言处理(NLP)领域&#xff0c;衡量一个机器翻译模型的性能至关重要。BLEU (Bilingual Evaluation Understudy) 作为一种自动化评估指标&#xff0c;自2002年由IBM的Kishore Papineni等人提出以来&#xff0c;…...

Qt 事件处理中 return 的深入解析

Qt 事件处理中 return 的深入解析 在 Qt 事件处理中&#xff0c;return 语句的使用是另一个关键概念&#xff0c;它与 event->accept()/event->ignore() 密切相关但作用不同。让我们详细分析一下它们之间的关系和工作原理。 核心区别&#xff1a;不同层级的事件处理 方…...

系统掌握PyTorch:图解张量、Autograd、DataLoader、nn.Module与实战模型

本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习视频及资料&#xff0c;尽在聚客AI学院。 本文通过代码驱动的方式&#xff0c;系统讲解PyTorch核心概念和实战技巧&#xff0c;涵盖张量操作、自动微分、数据加载、模型构建和训练全流程&#…...