当前位置: 首页 > news >正文

CAM类激活映射 |神经网络可视化 | 热力图

文章目录

    • 前言:
    • 安装库:
    • 分类案例--ResNet50
    • 分割案例
      • AttributeError: ‘tuple‘ object has no attribute ‘cpu‘
      • RuntimeError: grad can be implicitly created only for scalar outputs
      • TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
      • 完整代码

前言:

本篇文章只是教程,不涉及原理,感兴趣可以自行搜索
如图,热力图可以很好的反映出网络究竟注意图片的哪一部分
在这里插入图片描述
github官方教程:
https://github.com/jacobgil/pytorch-grad-cam
参考博客:
https://blog.csdn.net/u014264373/article/details/85415921
https://blog.csdn.net/u014264373/article/details/116302678
但还是遇到了很多报错,解决过程记录如下:

安装库:

pip install grad-cam

分类案例–ResNet50

案例图片:
在这里插入图片描述

案例代码:
这个代码是可以跑通的,将图片保存到你本地,然后设置好路径即可。
(需要下载ResNet预训练模型)

from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
import torchvision
import torch
from matplotlib import pyplot as plt
import numpy as npdef myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens, size))if titles == False:titles = "0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor, heatmap=False, shape=(224, 224)):np_arr = tensor.detach().numpy()  # [0]# 对数据进行归一化if np_arr.max() > 1 or np_arr.min() < 0:np_arr = np_arr - np_arr.min()np_arr = np_arr / np_arr.max()# np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0] == 1:np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)np_arr = np_arr.transpose((1, 2, 0))return np_arrpath = "../examples/both.png"
bin_data = torchvision.io.read_file(path)  # 加载二进制数据
img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片
img = img.unsqueeze(0)  # 变成BCHW的数据,B==1; squeeze
input_tensor = torchvision.transforms.functional.resize(img, [224, 224])model = resnet50(pretrained=True)
target_layers = [model.layer4[-1]]  # 如果传入多个layer,cam输出结果将会取均值# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:# targets = [ClassifierOutputTarget(386), ClassifierOutputTarget(386)]  # 指定查看class_num为386的热力图targets = None  # 选定目标类别,如果不设置,则默认为分数最高的那一类# aug_smooth=True, eigen_smooth=True 使用图像增强是热力图变得更加平滑grayscale_cams = cam(input_tensor=input_tensor, targets=targets)  # targets=None 自动调用概率最大的类别显示for grayscale_cam, tensor in zip(grayscale_cams, input_tensor):# 将热力图结果与原图进行融合rgb_img = tensor2img(tensor)visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])

最后出来的结果应该就是这样一张图。
在这里插入图片描述

分割案例

如果上面的代码你跑通了
那么如何为自己的网络生成热力图呢?
有几个需要注意的点:(最后会附上完整代码)

首先,切换成你的网络了模型加载就不说了,这个自己搞好。
然后,你的网络是否是在gpu上跑的,如果是
输入数据要放gpu上

path = './test_img/yu.jpg'
bin_data = torchvision.io.read_file(path)  # 加载二进制数据
img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片
img = img.unsqueeze(0)  # 变成BCHW的数据,B==1 squeeze
img_tensor = torchvision.transforms.functional.resize(img, [352, 352])
img_tensor = img_tensor.cuda()   # 加一句这个

然后按照上面的代码,修改这一句,改成你要查看的层:

target_layers = [model.layer4[-1]]  # 如果传入多个layer,cam输出结果将会取均值

把这个改成你要的层,然后运行一下,可能会遇到报错:

AttributeError: ‘tuple‘ object has no attribute ‘cpu‘

如果出现这个报错,可以看下你的网络最终输出是几个特征。因为是自己写的网络,有的因为训练需要,最终返回的是多个结果。

print(model(x))

如果有多个结果,会被变成一个元组。后面需要转cpu,元组tuple没有.cpu的方法,所以报错。
解决方法:
先把你的网络包装一下,你返回了多个值,选择有用的那一个就行
我这里选择了多个输出的第一个,自己视情况而定

class SegmentationModelOutputWrapper(torch.nn.Module):def __init__(self, model):super(SegmentationModelOutputWrapper, self).__init__()self.model = modeldef forward(self, x):return self.model(x)[0]  # 我这里选择了多个输出的第一个,自己视情况而定model = NetWork()
model.load_state_dict(torch.load(opt.snap_path))
# 网络加载后先包装下  修改输出
model = SegmentationModelOutputWrapper(model)

然后再运行,可能会出现报错:

RuntimeError: grad can be implicitly created only for scalar outputs

这个问题的解决办法是:
你需要去到base_cam.py这个库文件里面去
第85行有一句loss.backward(retain_graph = True)
将其修改为loss.backward(torch.ones_like(loss),retain_graph=True)

在这里插入图片描述

参考链接:https://blog.csdn.net/weixin_44390884/article/details/127893163

还有一个报错:

TypeError: can’t convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

如果你的模型最后返回的特征是tensor的特征,那么需要对tensor2img做改动
在这里插入图片描述

np_arr = tensor.detach().numpy()  # [0]

修改为:

np_arr = tensor.cpu().detach().numpy()  # [0]

完整代码

import os
import torch
import argparse
import numpy as np
import imageio
import torchvision
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from matplotlib import pyplot as pltdef myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens, size))if titles == False:titles = "0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor, heatmap=False, shape=(224, 224)):np_arr = tensor.cpu().detach().numpy()  # [0]if np_arr.max() > 1 or np_arr.min() < 0:np_arr = np_arr - np_arr.min()np_arr = np_arr / np_arr.max()# np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0] == 1:np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)np_arr = np_arr.transpose((1, 2, 0))return np_arrif __name__ == '__main__':model = NetWork()model.load_state_dict(torch.load(opt.snap_path))# torchinfo.summary(model=model,input_size=(8, 3, 352, 352))# 包装下  修改输出model = SegmentationModelOutputWrapper(model)model.eval()path = './test_img/yu.jpg'bin_data = torchvision.io.read_file(path)  # 加载二进制数据img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片img = img.unsqueeze(0)  # 变成BCHW的数据,B==1 squeezeimg_tensor = torchvision.transforms.functional.resize(img, [352, 352])img_tensor = img_tensor.cuda()target_layers = [model.model.ncd]targets = Nonewith GradCAM(model=model,target_layers=target_layers,use_cuda=True) as cam:grayscale_cams = cam(input_tensor=img_tensor,targets=targets,aug_smooth=True)# cam_image = show_cam_on_image(img_rgb, grayscale_cam, use_rgb=True)for grayscale_cam, tensor in zip(grayscale_cams, img_tensor):# 将热力图结果与原图进行融合rgb_img = tensor2img(tensor)visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])

相关文章:

CAM类激活映射 |神经网络可视化 | 热力图

文章目录前言&#xff1a;安装库&#xff1a;分类案例--ResNet50分割案例AttributeError: ‘tuple‘ object has no attribute ‘cpu‘RuntimeError: grad can be implicitly created only for scalar outputsTypeError: cant convert cuda:0 device type tensor to numpy. Use…...

RecyclerView+BaseRecyclerViewAdapterHelper显示不全只显示第一行item的解决问题

RecyclerViewBaseRecyclerViewAdapterHelper显示不全只显示第一行item&#xff0c;我懵了…&#xff0c;我不说多&#xff0c;直接说吧 先看一下适配器代码中的convert()方法&#xff1a; class MineRadioAdapter(layoutResId: Int R.layout.item_my_live) :BaseQuickAdapte…...

解决后端无法对前端的ajax请求重定向

本章目录&#xff1a; 问题描述 AJAX请求后端直接重定向失败解决方案 后端拦截请为响应头添加重定向标志后端拦截器为响应头添加重定向路径前端响应拦截器获取响应头数据&#xff0c;并通过location.href url 完成页面跳转一、问题描述 本来想在拦截器里设置未登录用户访问指…...

【Python】1分钟就能制作精美的框架图?太棒啦

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录前言一、准备二、基本使用与例子1.初始化与导出2.节点类型3.集群块4.自定义线的颜色与属性总结前言 Diagrams 是一个基于Python绘制云系统架构的模块&#xff0c;它能…...

淘宝必备的补单技巧及注意事项!

补单&#xff0c;是优化善后的s单。单只是模拟用户的购物习惯&#xff0c;而补单同时还要模拟整个店铺的综合数据&#xff0c;包括点击率、转化率等等&#xff0c;补到略高于同行、竞品的平均数据时&#xff0c;淘宝会判断为买家比较喜欢你的商品&#xff0c;从而给你更多推荐机…...

【实用篇】SpringCloud+RabbitMQ+Docker+Redis+搜索+分布式,系统详解springcloud分布式

文章目录一、服务拆分1.1 服务拆分Demo1.2 微服务远程调用二、Eureka2.1 Eureka原理2.2 Eureka-server服务搭建2.3 eureka-client服务注册2.4 eureka-client服务复制2.5 eureka服务发现三、Ribbon负载均衡3.1 负载均衡原理3.2 负载均衡策略3.3 自定义负载均衡策略3.4 饥饿加载与…...

私人飞机、公务机包机会成为富豪圈的主流出行方式吗?

从炫耀性消费到按需使用&#xff0c;私人飞机的消费群体正在被拓宽&#xff0c;但离“成为主流”还有一段距离。“时间就是金钱”为有钱人消费私人飞机提供合理动机&#xff0c;而这群高净值人群的数量增长则成为撑起市场基本面。据相关数据显示&#xff0c;2018年全球超级富豪…...

Oracle组织架构

组织架构 &#xff08;一&#xff09;业务组&#xff08;BG&#xff09; &#xff08;二&#xff09;法律实体&#xff08;LE&#xff09; &#xff08;三&#xff09;业务实体&#xff08;OU&#xff09; &#xff08;四&#xff09;库存组织&#xff08;INV&#xff09; …...

最小公倍数

目录 最小公倍数 程序设计 程序分析 最小公倍数 【问题描述】给定两个正整数,计算这两个数的最小公倍数。 【输入形式】输入包含多组测试数据,每组只有一行,包括两个不大于1000的正整数. 【输出形式】 对于每个测试用例,给出这两个数的最小公倍数,每个实例输出一行。…...

二叉树的后序遍历(力扣145)

目录 题目描述&#xff1a; 解法一&#xff1a;递归法 解法二&#xff1a;迭代法 解法三&#xff1a;Morris遍历 二叉树的后序遍历 题目描述&#xff1a; 给你一棵二叉树的根节点 root &#xff0c;返回其节点值的 后序遍历 。 示例 1&#xff1a; 输入&#xff1a;root …...

《Effective C++》读书纪实 -- 诸君同享

文章目录《Effective C》是一本经典的C编程指南&#xff0c;共包含50条C编程的最佳实践。 确定你的构造函数的行为 在构造函数中&#xff0c;应该尽可能地避免调用虚函数、非静态成员函数和虚基类的函数。 尽量使用const、enum、inline替换#define 使用const、enum、inline可以…...

【云原生】K8S-ConfigMap 实现应用和配置分离

文章目录前言ConfigMap 背景ConfigMap 创建方式ConfigMap 的使用使用 ConfigMap 的注意事项总结前言 Kubernetes 是目前最流行的容器编排系统之一&#xff0c;它提供了丰富的功能来支持容器化应用程序的管理和部署。 ConfigMap 是 Kubernetes 中重要的资源对象&#xff0c;用…...

java -测距工具(经纬度)

代码 /*** 测距工具* author qb*/ public class DistanceUtils {/*** 赤道半径*/private static final double EARTH_RADIUS 6378.137;private static double rad(double d) {return d * Math.PI / 180.0;}/*** Description : 通过经纬度获取距离(单位&#xff1a;米)* Group…...

postgres分区表的创建-基于继承

参考文档&#xff1a; http://postgres.cn/docs/12/ddl-partitioning.html 创建基于继承的分区表的步骤 1 创建父表 2 创建子表&#xff0c;从父表继承过来 3 创建函数及触发器&#xff0c;使插入的数据根据规则&#xff0c;插入到对应的子表中 -- 创建父表 CREATE TABLE a…...

Docker应用部署

文章目录Docker 应用部署一、部署MySQL二、部署Tomcat三、部署Nginx四、部署RedisDocker 应用部署 一、部署MySQL 搜索mysql镜像 docker search mysql拉取mysql镜像 docker pull mysql:5.6创建容器&#xff0c;设置端口映射、目录映射 # 在/root目录下创建mysql目录用于存…...

使用golang实现日志收集系统的logagent

整体架构 参考 七米老师的日志收集项目 主要用go实现logagent的部分&#xff0c;logagent的作用主要是实时监控日志追加的变化&#xff0c;并将变化发送到kafka中。 之前我们已经实现了 用go连接kafka并向其中发送数据&#xff0c;也实现了使用tail库监控日志追加操作。 我们…...

小红书点赞不显示怎么回事?小红书笔记评论被吞怎么办

小红书作为一个互联网产品&#xff0c;是一个软件。既然是软件就会有一定的程序漏洞&#xff0c;这是无法避免的。但是很多时候其实并不一定是漏洞的问题。今天就来和大家谈谈小红书点赞不显示怎么回事&#xff0c;小红书评论被吞又是怎么一回事&#xff0c;这些难道都是程序性…...

地址变换和缺页置换习题

1.设某进程页面的访问序列为4,3,2,1,4,3,5,4,3&#xff0c;2,1,5&#xff0c;当分配给该进程的内存页框数分别为3和4时&#xff0c;对于先进先出&#xff0c;最近最少使用&#xff0c;最佳页面置换算法&#xff0c;分别发生多少次缺页中断&#xff1f; 答&#xff1a; 分配的…...

PAT 乙级 1010 一元多项式求导(解题思路+AC代码)

题目&#xff1a; 设计函数求一元多项式的导数。&#xff08;注&#xff1a;xn&#xff08;n为整数&#xff09;的一阶导数为nxn−1。&#xff09; 输入格式: 以指数递降方式输入多项式非零项系数和指数&#xff08;绝对值均为不超过 1000 的整数&#xff09;。数字间以空格分…...

一维河流污染持续排放模拟(水污染扩散)

一、处理河道转换为geojson数据 以淮河为例处理示例数据&#xff1a; {"type": "FeatureCollection","features": [{"geometry": {"coordinates": [[[115.5803,34.4982],[115.5922,34.498],[115.6061,34.4994],[115.6203,…...

深度解析:Performance-Fish如何通过四级缓存架构实现《环世界》400%性能优化

深度解析&#xff1a;Performance-Fish如何通过四级缓存架构实现《环世界》400%性能优化 【免费下载链接】Performance-Fish Performance Mod for RimWorld 项目地址: https://gitcode.com/gh_mirrors/pe/Performance-Fish Performance-Fish是《环世界》&#xff08;Rim…...

Applite:告别命令行!macOS软件管理的图形化终极解决方案

Applite&#xff1a;告别命令行&#xff01;macOS软件管理的图形化终极解决方案 【免费下载链接】Applite User-friendly GUI macOS application for Homebrew Casks 项目地址: https://gitcode.com/gh_mirrors/ap/Applite 还在为Homebrew复杂的命令行操作而头疼吗&…...

猫抓扩展完整指南:三步掌握浏览器视频嗅探与下载技巧

猫抓扩展完整指南&#xff1a;三步掌握浏览器视频嗅探与下载技巧 【免费下载链接】cat-catch 猫抓 浏览器资源嗅探扩展 / cat-catch Browser Resource Sniffing Extension 项目地址: https://gitcode.com/GitHub_Trending/ca/cat-catch 猫抓&#xff08;Cat-Catch&#…...

XHS-Downloader:小红书内容采集与管理的全栈解决方案

XHS-Downloader&#xff1a;小红书内容采集与管理的全栈解决方案 【免费下载链接】XHS-Downloader 小红书&#xff08;XiaoHongShu、RedNote&#xff09;链接提取/作品采集工具&#xff1a;提取账号发布、收藏、点赞、专辑作品链接&#xff1b;提取搜索结果作品、用户链接&…...

从零构建专属大语言模型:Self-LLM开源项目全流程实践指南

1. 项目概述与核心价值最近在开源社区里&#xff0c;一个名为datawhalechina/self-llm的项目引起了我的注意。乍一看&#xff0c;这像是一个关于大语言模型&#xff08;LLM&#xff09;的仓库&#xff0c;但“self”这个前缀又让人浮想联翩。经过一段时间的深入研究和实践&…...

天学网口碑好不好?2026年最新用户实测反馈给你答案

作为深耕教育数字化落地领域5年的从业者&#xff0c;最近后台收到不少公立校电教组老师、学生家长的提问&#xff1a;主打AI英语教学的天学网口碑到底怎么样&#xff1f;刚好我们团队刚做完2026年第一季度的英语教育数字化工具落地效果调研&#xff0c;结合一手实测数据给大家客…...

火灾动力学模拟实战:如何用FDS构建精准的火灾预测系统

火灾动力学模拟实战&#xff1a;如何用FDS构建精准的火灾预测系统 【免费下载链接】fds Fire Dynamics Simulator 项目地址: https://gitcode.com/gh_mirrors/fd/fds 你是否曾面临这样的困境&#xff1a;当设计一栋大型商业建筑时&#xff0c;如何科学评估火灾时的人员疏…...

ComfyUI-Manager终极指南:3步掌握AI绘画插件管理技巧

ComfyUI-Manager终极指南&#xff1a;3步掌握AI绘画插件管理技巧 【免费下载链接】ComfyUI-Manager ComfyUI-Manager is an extension designed to enhance the usability of ComfyUI. It offers management functions to install, remove, disable, and enable various custom…...

Docker Compose编排微服务

Docker Compose编排微服务 引言 Docker Compose是Docker官方提供的容器编排工具&#xff0c;用于定义和运行多容器Docker应用。通过Compose&#xff0c;可以使用YAML文件定义服务、网络、数据卷等资源&#xff0c;然后通过简单的命令启动和停止整个应用。Docker Compose特别适合…...

知乎API完全指南:用Python轻松获取知乎数据的5个核心技巧

知乎API完全指南&#xff1a;用Python轻松获取知乎数据的5个核心技巧 【免费下载链接】zhihu-api Zhihu API for Humans 项目地址: https://gitcode.com/gh_mirrors/zh/zhihu-api 在当今数据驱动的时代&#xff0c;知乎数据采集和Python API开发已成为获取高质量中文知识…...