当前位置: 首页 > news >正文

CAM类激活映射 |神经网络可视化 | 热力图

文章目录

    • 前言:
    • 安装库:
    • 分类案例--ResNet50
    • 分割案例
      • AttributeError: ‘tuple‘ object has no attribute ‘cpu‘
      • RuntimeError: grad can be implicitly created only for scalar outputs
      • TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
      • 完整代码

前言:

本篇文章只是教程,不涉及原理,感兴趣可以自行搜索
如图,热力图可以很好的反映出网络究竟注意图片的哪一部分
在这里插入图片描述
github官方教程:
https://github.com/jacobgil/pytorch-grad-cam
参考博客:
https://blog.csdn.net/u014264373/article/details/85415921
https://blog.csdn.net/u014264373/article/details/116302678
但还是遇到了很多报错,解决过程记录如下:

安装库:

pip install grad-cam

分类案例–ResNet50

案例图片:
在这里插入图片描述

案例代码:
这个代码是可以跑通的,将图片保存到你本地,然后设置好路径即可。
(需要下载ResNet预训练模型)

from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
import torchvision
import torch
from matplotlib import pyplot as plt
import numpy as npdef myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens, size))if titles == False:titles = "0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor, heatmap=False, shape=(224, 224)):np_arr = tensor.detach().numpy()  # [0]# 对数据进行归一化if np_arr.max() > 1 or np_arr.min() < 0:np_arr = np_arr - np_arr.min()np_arr = np_arr / np_arr.max()# np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0] == 1:np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)np_arr = np_arr.transpose((1, 2, 0))return np_arrpath = "../examples/both.png"
bin_data = torchvision.io.read_file(path)  # 加载二进制数据
img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片
img = img.unsqueeze(0)  # 变成BCHW的数据,B==1; squeeze
input_tensor = torchvision.transforms.functional.resize(img, [224, 224])model = resnet50(pretrained=True)
target_layers = [model.layer4[-1]]  # 如果传入多个layer,cam输出结果将会取均值# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:# targets = [ClassifierOutputTarget(386), ClassifierOutputTarget(386)]  # 指定查看class_num为386的热力图targets = None  # 选定目标类别,如果不设置,则默认为分数最高的那一类# aug_smooth=True, eigen_smooth=True 使用图像增强是热力图变得更加平滑grayscale_cams = cam(input_tensor=input_tensor, targets=targets)  # targets=None 自动调用概率最大的类别显示for grayscale_cam, tensor in zip(grayscale_cams, input_tensor):# 将热力图结果与原图进行融合rgb_img = tensor2img(tensor)visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])

最后出来的结果应该就是这样一张图。
在这里插入图片描述

分割案例

如果上面的代码你跑通了
那么如何为自己的网络生成热力图呢?
有几个需要注意的点:(最后会附上完整代码)

首先,切换成你的网络了模型加载就不说了,这个自己搞好。
然后,你的网络是否是在gpu上跑的,如果是
输入数据要放gpu上

path = './test_img/yu.jpg'
bin_data = torchvision.io.read_file(path)  # 加载二进制数据
img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片
img = img.unsqueeze(0)  # 变成BCHW的数据,B==1 squeeze
img_tensor = torchvision.transforms.functional.resize(img, [352, 352])
img_tensor = img_tensor.cuda()   # 加一句这个

然后按照上面的代码,修改这一句,改成你要查看的层:

target_layers = [model.layer4[-1]]  # 如果传入多个layer,cam输出结果将会取均值

把这个改成你要的层,然后运行一下,可能会遇到报错:

AttributeError: ‘tuple‘ object has no attribute ‘cpu‘

如果出现这个报错,可以看下你的网络最终输出是几个特征。因为是自己写的网络,有的因为训练需要,最终返回的是多个结果。

print(model(x))

如果有多个结果,会被变成一个元组。后面需要转cpu,元组tuple没有.cpu的方法,所以报错。
解决方法:
先把你的网络包装一下,你返回了多个值,选择有用的那一个就行
我这里选择了多个输出的第一个,自己视情况而定

class SegmentationModelOutputWrapper(torch.nn.Module):def __init__(self, model):super(SegmentationModelOutputWrapper, self).__init__()self.model = modeldef forward(self, x):return self.model(x)[0]  # 我这里选择了多个输出的第一个,自己视情况而定model = NetWork()
model.load_state_dict(torch.load(opt.snap_path))
# 网络加载后先包装下  修改输出
model = SegmentationModelOutputWrapper(model)

然后再运行,可能会出现报错:

RuntimeError: grad can be implicitly created only for scalar outputs

这个问题的解决办法是:
你需要去到base_cam.py这个库文件里面去
第85行有一句loss.backward(retain_graph = True)
将其修改为loss.backward(torch.ones_like(loss),retain_graph=True)

在这里插入图片描述

参考链接:https://blog.csdn.net/weixin_44390884/article/details/127893163

还有一个报错:

TypeError: can’t convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

如果你的模型最后返回的特征是tensor的特征,那么需要对tensor2img做改动
在这里插入图片描述

np_arr = tensor.detach().numpy()  # [0]

修改为:

np_arr = tensor.cpu().detach().numpy()  # [0]

完整代码

import os
import torch
import argparse
import numpy as np
import imageio
import torchvision
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from matplotlib import pyplot as pltdef myimshows(imgs, titles=False, fname="test.jpg", size=6):lens = len(imgs)fig = plt.figure(figsize=(size * lens, size))if titles == False:titles = "0123456789"for i in range(1, lens + 1):cols = 100 + lens * 10 + iplt.xticks(())plt.yticks(())plt.subplot(cols)if len(imgs[i - 1].shape) == 2:plt.imshow(imgs[i - 1], cmap='Reds')else:plt.imshow(imgs[i - 1])plt.title(titles[i - 1])plt.xticks(())plt.yticks(())plt.savefig(fname, bbox_inches='tight')plt.show()def tensor2img(tensor, heatmap=False, shape=(224, 224)):np_arr = tensor.cpu().detach().numpy()  # [0]if np_arr.max() > 1 or np_arr.min() < 0:np_arr = np_arr - np_arr.min()np_arr = np_arr / np_arr.max()# np_arr=(np_arr*255).astype(np.uint8)if np_arr.shape[0] == 1:np_arr = np.concatenate([np_arr, np_arr, np_arr], axis=0)np_arr = np_arr.transpose((1, 2, 0))return np_arrif __name__ == '__main__':model = NetWork()model.load_state_dict(torch.load(opt.snap_path))# torchinfo.summary(model=model,input_size=(8, 3, 352, 352))# 包装下  修改输出model = SegmentationModelOutputWrapper(model)model.eval()path = './test_img/yu.jpg'bin_data = torchvision.io.read_file(path)  # 加载二进制数据img = torchvision.io.decode_image(bin_data) / 255  # 解码成CHW的图片img = img.unsqueeze(0)  # 变成BCHW的数据,B==1 squeezeimg_tensor = torchvision.transforms.functional.resize(img, [352, 352])img_tensor = img_tensor.cuda()target_layers = [model.model.ncd]targets = Nonewith GradCAM(model=model,target_layers=target_layers,use_cuda=True) as cam:grayscale_cams = cam(input_tensor=img_tensor,targets=targets,aug_smooth=True)# cam_image = show_cam_on_image(img_rgb, grayscale_cam, use_rgb=True)for grayscale_cam, tensor in zip(grayscale_cams, img_tensor):# 将热力图结果与原图进行融合rgb_img = tensor2img(tensor)visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)myimshows([rgb_img, grayscale_cam, visualization], ["image", "cam", "image + cam"])

相关文章:

CAM类激活映射 |神经网络可视化 | 热力图

文章目录前言&#xff1a;安装库&#xff1a;分类案例--ResNet50分割案例AttributeError: ‘tuple‘ object has no attribute ‘cpu‘RuntimeError: grad can be implicitly created only for scalar outputsTypeError: cant convert cuda:0 device type tensor to numpy. Use…...

RecyclerView+BaseRecyclerViewAdapterHelper显示不全只显示第一行item的解决问题

RecyclerViewBaseRecyclerViewAdapterHelper显示不全只显示第一行item&#xff0c;我懵了…&#xff0c;我不说多&#xff0c;直接说吧 先看一下适配器代码中的convert()方法&#xff1a; class MineRadioAdapter(layoutResId: Int R.layout.item_my_live) :BaseQuickAdapte…...

解决后端无法对前端的ajax请求重定向

本章目录&#xff1a; 问题描述 AJAX请求后端直接重定向失败解决方案 后端拦截请为响应头添加重定向标志后端拦截器为响应头添加重定向路径前端响应拦截器获取响应头数据&#xff0c;并通过location.href url 完成页面跳转一、问题描述 本来想在拦截器里设置未登录用户访问指…...

【Python】1分钟就能制作精美的框架图?太棒啦

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录前言一、准备二、基本使用与例子1.初始化与导出2.节点类型3.集群块4.自定义线的颜色与属性总结前言 Diagrams 是一个基于Python绘制云系统架构的模块&#xff0c;它能…...

淘宝必备的补单技巧及注意事项!

补单&#xff0c;是优化善后的s单。单只是模拟用户的购物习惯&#xff0c;而补单同时还要模拟整个店铺的综合数据&#xff0c;包括点击率、转化率等等&#xff0c;补到略高于同行、竞品的平均数据时&#xff0c;淘宝会判断为买家比较喜欢你的商品&#xff0c;从而给你更多推荐机…...

【实用篇】SpringCloud+RabbitMQ+Docker+Redis+搜索+分布式,系统详解springcloud分布式

文章目录一、服务拆分1.1 服务拆分Demo1.2 微服务远程调用二、Eureka2.1 Eureka原理2.2 Eureka-server服务搭建2.3 eureka-client服务注册2.4 eureka-client服务复制2.5 eureka服务发现三、Ribbon负载均衡3.1 负载均衡原理3.2 负载均衡策略3.3 自定义负载均衡策略3.4 饥饿加载与…...

私人飞机、公务机包机会成为富豪圈的主流出行方式吗?

从炫耀性消费到按需使用&#xff0c;私人飞机的消费群体正在被拓宽&#xff0c;但离“成为主流”还有一段距离。“时间就是金钱”为有钱人消费私人飞机提供合理动机&#xff0c;而这群高净值人群的数量增长则成为撑起市场基本面。据相关数据显示&#xff0c;2018年全球超级富豪…...

Oracle组织架构

组织架构 &#xff08;一&#xff09;业务组&#xff08;BG&#xff09; &#xff08;二&#xff09;法律实体&#xff08;LE&#xff09; &#xff08;三&#xff09;业务实体&#xff08;OU&#xff09; &#xff08;四&#xff09;库存组织&#xff08;INV&#xff09; …...

最小公倍数

目录 最小公倍数 程序设计 程序分析 最小公倍数 【问题描述】给定两个正整数,计算这两个数的最小公倍数。 【输入形式】输入包含多组测试数据,每组只有一行,包括两个不大于1000的正整数. 【输出形式】 对于每个测试用例,给出这两个数的最小公倍数,每个实例输出一行。…...

二叉树的后序遍历(力扣145)

目录 题目描述&#xff1a; 解法一&#xff1a;递归法 解法二&#xff1a;迭代法 解法三&#xff1a;Morris遍历 二叉树的后序遍历 题目描述&#xff1a; 给你一棵二叉树的根节点 root &#xff0c;返回其节点值的 后序遍历 。 示例 1&#xff1a; 输入&#xff1a;root …...

《Effective C++》读书纪实 -- 诸君同享

文章目录《Effective C》是一本经典的C编程指南&#xff0c;共包含50条C编程的最佳实践。 确定你的构造函数的行为 在构造函数中&#xff0c;应该尽可能地避免调用虚函数、非静态成员函数和虚基类的函数。 尽量使用const、enum、inline替换#define 使用const、enum、inline可以…...

【云原生】K8S-ConfigMap 实现应用和配置分离

文章目录前言ConfigMap 背景ConfigMap 创建方式ConfigMap 的使用使用 ConfigMap 的注意事项总结前言 Kubernetes 是目前最流行的容器编排系统之一&#xff0c;它提供了丰富的功能来支持容器化应用程序的管理和部署。 ConfigMap 是 Kubernetes 中重要的资源对象&#xff0c;用…...

java -测距工具(经纬度)

代码 /*** 测距工具* author qb*/ public class DistanceUtils {/*** 赤道半径*/private static final double EARTH_RADIUS 6378.137;private static double rad(double d) {return d * Math.PI / 180.0;}/*** Description : 通过经纬度获取距离(单位&#xff1a;米)* Group…...

postgres分区表的创建-基于继承

参考文档&#xff1a; http://postgres.cn/docs/12/ddl-partitioning.html 创建基于继承的分区表的步骤 1 创建父表 2 创建子表&#xff0c;从父表继承过来 3 创建函数及触发器&#xff0c;使插入的数据根据规则&#xff0c;插入到对应的子表中 -- 创建父表 CREATE TABLE a…...

Docker应用部署

文章目录Docker 应用部署一、部署MySQL二、部署Tomcat三、部署Nginx四、部署RedisDocker 应用部署 一、部署MySQL 搜索mysql镜像 docker search mysql拉取mysql镜像 docker pull mysql:5.6创建容器&#xff0c;设置端口映射、目录映射 # 在/root目录下创建mysql目录用于存…...

使用golang实现日志收集系统的logagent

整体架构 参考 七米老师的日志收集项目 主要用go实现logagent的部分&#xff0c;logagent的作用主要是实时监控日志追加的变化&#xff0c;并将变化发送到kafka中。 之前我们已经实现了 用go连接kafka并向其中发送数据&#xff0c;也实现了使用tail库监控日志追加操作。 我们…...

小红书点赞不显示怎么回事?小红书笔记评论被吞怎么办

小红书作为一个互联网产品&#xff0c;是一个软件。既然是软件就会有一定的程序漏洞&#xff0c;这是无法避免的。但是很多时候其实并不一定是漏洞的问题。今天就来和大家谈谈小红书点赞不显示怎么回事&#xff0c;小红书评论被吞又是怎么一回事&#xff0c;这些难道都是程序性…...

地址变换和缺页置换习题

1.设某进程页面的访问序列为4,3,2,1,4,3,5,4,3&#xff0c;2,1,5&#xff0c;当分配给该进程的内存页框数分别为3和4时&#xff0c;对于先进先出&#xff0c;最近最少使用&#xff0c;最佳页面置换算法&#xff0c;分别发生多少次缺页中断&#xff1f; 答&#xff1a; 分配的…...

PAT 乙级 1010 一元多项式求导(解题思路+AC代码)

题目&#xff1a; 设计函数求一元多项式的导数。&#xff08;注&#xff1a;xn&#xff08;n为整数&#xff09;的一阶导数为nxn−1。&#xff09; 输入格式: 以指数递降方式输入多项式非零项系数和指数&#xff08;绝对值均为不超过 1000 的整数&#xff09;。数字间以空格分…...

一维河流污染持续排放模拟(水污染扩散)

一、处理河道转换为geojson数据 以淮河为例处理示例数据&#xff1a; {"type": "FeatureCollection","features": [{"geometry": {"coordinates": [[[115.5803,34.4982],[115.5922,34.498],[115.6061,34.4994],[115.6203,…...

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板&#xff0c;载入页面后&#xff0c;会显示引导弹窗&#xff0c;适用于引导用户使用页面&#xff0c;点击完成后&#xff0c;会显示下一个引导弹窗&#xff0c;直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...

【杂谈】-递归进化:人工智能的自我改进与监管挑战

递归进化&#xff1a;人工智能的自我改进与监管挑战 文章目录 递归进化&#xff1a;人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管&#xff1f;3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...

RocketMQ延迟消息机制

两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数&#xff0c;对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后&#xf…...

51c自动驾驶~合集58

我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留&#xff0c;CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制&#xff08;CCA-Attention&#xff09;&#xff0c;…...

Leetcode 3576. Transform Array to All Equal Elements

Leetcode 3576. Transform Array to All Equal Elements 1. 解题思路2. 代码实现 题目链接&#xff1a;3576. Transform Array to All Equal Elements 1. 解题思路 这一题思路上就是分别考察一下是否能将其转化为全1或者全-1数组即可。 至于每一种情况是否可以达到&#xf…...

苍穹外卖--缓存菜品

1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得&#xff0c;如果用户端访问量比较大&#xff0c;数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据&#xff0c;减少数据库查询操作。 缓存逻辑分析&#xff1a; ①每个分类下的菜品保持一份缓存数据…...

html-<abbr> 缩写或首字母缩略词

定义与作用 <abbr> 标签用于表示缩写或首字母缩略词&#xff0c;它可以帮助用户更好地理解缩写的含义&#xff0c;尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时&#xff0c;会显示一个提示框。 示例&#x…...

rnn判断string中第一次出现a的下标

# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...

Webpack性能优化:构建速度与体积优化策略

一、构建速度优化 1、​​升级Webpack和Node.js​​ ​​优化效果​​&#xff1a;Webpack 4比Webpack 3构建时间降低60%-98%。​​原因​​&#xff1a; V8引擎优化&#xff08;for of替代forEach、Map/Set替代Object&#xff09;。默认使用更快的md4哈希算法。AST直接从Loa…...

MySQL 8.0 事务全面讲解

以下是一个结合两次回答的 MySQL 8.0 事务全面讲解&#xff0c;涵盖了事务的核心概念、操作示例、失败回滚、隔离级别、事务性 DDL 和 XA 事务等内容&#xff0c;并修正了查看隔离级别的命令。 MySQL 8.0 事务全面讲解 一、事务的核心概念&#xff08;ACID&#xff09; 事务是…...