【Pytorch】Visualization of Feature Maps(4)——Saliency Maps
学习参考来自
- Saliency Maps的原理与简单实现(使用Pytorch实现)
- https://github.com/wmn7/ML_Practice/tree/master/2019_07_08/Saliency%20Maps
Saliency Maps 原理
《Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps》(arXiv-2013)
A saliency map tells us the degree to which each pixel in the image affects the classification score for that image.
To compute it, we compute the gradient of the unnormalized score corresponding to the correct class (which is a scalar)
with respect to the pixels of the image. If the image has shape (3, H, W) then this gradient will also have shape (3, H, W);
for each pixel in the image, this gradient tells us the amount by which the classification score will change if the pixel
changes by a small amount. To compute the saliency map, we take the absolute value of this gradient, then take the maximum value over the 3 input channels; the final saliency map thus has shape (H, W) and all entries are non-negative.
Saliency Maps相当于是计算图像的每一个pixel是如何影响一个分类器的, 或者说分类器对图像中每一个pixel哪些认为是重要的.
会计算图像每一个像素点的梯度。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);对于图像中的每个像素点,
这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。
计算 saliency map 的时候,需要计算出梯度的绝对值,然后再取三个颜色通道的最大值;
因此最后的 saliency map的形状是(H, W)为一个通道的灰度图。
直接来代码,先载入些数据,用的是 cs231n 作业里面的 imagenet_val_25.npz
,含有 imagenet 数据中验证集的 25 张图片
import torch
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import ImageSQUEEZENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
SQUEEZENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)def load_imagenet_val(num=None):"""Load a handful of validation images from ImageNet.Inputs:- num: Number of images to load (max of 25)Returns:- X: numpy array with shape [num, 224, 224, 3]- y: numpy array of integer image labels, shape [num]- class_names: dict mapping integer label to class name"""imagenet_fn = 'imagenet_val_25.npz'if not os.path.isfile(imagenet_fn):print('file %s not found' % imagenet_fn)print('Run the following:')print('cd cs231n/datasets')print('bash get_imagenet_val.sh')assert False, 'Need to download imagenet_val_25.npz'f = np.load(imagenet_fn, allow_pickle=True)X = f['X'] # (25, 224, 224, 3)y = f['y'] # (25, )class_names = f['label_map'].item() # 999if num is not None:X = X[:num]y = y[:num]return X, y, class_names
图像的前处理,resize,变成向量,减均值除以方差
# 辅助函数
def preprocess(img, size=224):transform = T.Compose([T.Resize(size),T.ToTensor(),T.Normalize(mean=SQUEEZENET_MEAN.tolist(),std=SQUEEZENET_STD.tolist()),T.Lambda(lambda x: x[None]),])return transform(img)
数据集和实验的模型
链接:https://pan.baidu.com/s/1vb2Y0IiHdH_Fb9wibTta4Q?pwd=zuvw
提取码:zuvw
核心代码,计算 saliency maps
def compute_saliency_maps(X, y, model):"""X表示图片, y表示分类结果, model表示使用的分类模型Input : - X : Input images : Tensor of shape (N, 3, H, W)- y : Label for X : LongTensor of shape (N,)- model : A pretrained CNN that will be used to computer the saliency mapReturn :- saliency : A Tensor of shape (N, H, W) giving the saliency maps for the input images"""# 确保model是test模式model.eval()# 确保X是需要gradientX.requires_grad_() # 仅开启了输入图片的梯度saliency = Nonelogits = model.forward(X) # torch.Size([5, 1000]), 前向获取 logitslogits = logits.gather(1, y.view(-1, 1)).squeeze() # torch.Size([5]) 得到正确分类 logits (5张图片标签相应类别的 logits)logits.backward(torch.FloatTensor([1., 1., 1., 1., 1.])) # 只计算正确分类部分的loss(正确类别梯度为 1 回传)saliency = abs(X.grad.data) # 返回X的梯度绝对值大小, torch.Size([5, 3, 224, 224])saliency, _ = torch.max(saliency, dim=1) # torch.Size([5, 224, 224]),取 rgb 3通道的最大值return saliency.squeeze()
显示 saliency maps
def show_saliency_maps(X, y):# Convert X and y from numpy arrays to Torch TensorsX_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0) # torch.Size([5, 3, 224, 224])y_tensor = torch.LongTensor(y)# Compute saliency maps for images in Xsaliency = compute_saliency_maps(X_tensor, y_tensor, model)# Convert the saliency map from Torch Tensor to numpy array and show images# and saliency maps together.saliency = saliency.numpy()N = X.shape[0] # 5for i in range(N):plt.subplot(2, N, i + 1)plt.imshow(X[i])plt.axis('off')plt.title(class_names[y[i]])plt.subplot(2, N, N + i + 1)plt.imshow(saliency[i], cmap=plt.cm.hot)plt.axis('off')plt.gcf().set_size_inches(12, 5)plt.show()
下面开始调用,首先载入模型,使其梯度冻结,仅打开输入图片的梯度,这样反向传播的时候会更新图片,得到我们想要的 saliency maps
# Download and load the pretrained SqueezeNet model.
model = torchvision.models.squeezenet1_1(pretrained=True)# We don't want to train the model, so tell PyTorch not to compute gradients
# with respect to model parameters.
for param in model.parameters():param.requires_grad = False
加载一些图片看看,25 张中抽出来 5 张
X, y, class_names = load_imagenet_val(num=5) # X: (5, 224, 224, 3) | y: (5,) | class_names: 999"show images"plt.figure(figsize=(12, 6))
for i in range(5):plt.subplot(1, 5, i + 1)plt.imshow(X[i])plt.title(class_names[y[i]])plt.axis('off')
plt.gcf().tight_layout()
plt.show()
显示图片
把五张图片的 saliency maps 画出来
show_saliency_maps(X, y)
我把 25 张都画出来了
核心代码中涉及到了 gather 函数,下面来个简单的例子就明白了
# Example of using gather to select one entry from each row in PyTorch
# 用来返回matrix指定行某个位置的值
import torchdef gather_example():N, C = 4, 5s = torch.randn(N, C) # 随机生成 4 行 5 列的 tensory = torch.LongTensor([1, 2, 1, 3])print(s)print(y)print(torch.LongTensor(y).view(-1, 1))print(s.gather(1, y.view(-1, 1)).squeeze()) # 抽取每行相应的列数位置上的数值gather_example()"""
tensor([[ 0.8119, 0.2664, -1.4168, -0.1490, -0.0675],[ 0.5335, 0.6304, -0.7200, -0.0974, -0.9934],[-0.8305, 0.5189, 0.7359, 1.5875, 0.0505],[ 0.4335, -1.1389, -0.7771, 0.5779, 0.3515]])
tensor([1, 2, 1, 3])
tensor([[1],[2],[1],[3]])
tensor([ 0.2664, -0.7200, 0.5189, 0.5779])
"""
相关文章:

【Pytorch】Visualization of Feature Maps(4)——Saliency Maps
学习参考来自 Saliency Maps的原理与简单实现(使用Pytorch实现)https://github.com/wmn7/ML_Practice/tree/master/2019_07_08/Saliency%20Maps Saliency Maps 原理 《Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps》&…...

java第三十课
电商项目(前台): 登录接口 注册接口后台: 注册审核:建一个线程类 注意程序中的一个问题。 这里是 5 条记录,2 条记录显示应该是 3 页,实际操作过程 有审核机制,出现了数据记录动态变…...
Scala--2
package scala02object Scala07_typeCast {def main(args: Array[String]): Unit {// TODO 隐式转换// 自动转换val b: Byte 10var i: Int b 10val l: Long b 10 100Lval fl: Float b 10 100L 10.5fval d: Double b 10 100L 10.5f 20.00println(d.getClass…...

【SQL SERVER】定时任务
oracle是定时JOB,sqlserver是创建作业,通过sqlserver代理实现 先看SQL SERVER代理得服务有没有开 选择计算机右键——>管理——>服务与应用程序——>服务——>SQL server 代理 然后把SQL server 代理(MSSQLSERVER)启…...

MyBatis-Plus学习笔记(无脑cv即可)
1.MyBatis-Plus 1.1特性 无侵入:只做增强不做改变,引入它不会对现有工程产生影响,如丝般顺滑损耗小:启动即会自动注入基本 CURD,性能基本无损耗,直接面向对象操作强大的 CRUD 操作:内置通用 M…...
【VUE】watch 监听失效
如果你遇见了这个问题,那么尝试在 watch 函数中设置 { deep: true } 选项。这告诉 Vue 监听对象或数组内部的变化,就像下面这样: watch(()>chatStore.dataSources,(oldValue, newValue)>{// 监听执行逻辑 }, { deep: true })嗯&#x…...

python的异常处理批量执行网络设备的巡检命令
前言 在网络设备数量超过千台甚至上万台的大型企业网中,难免会遇到某些设备的管理IP地址不通,SSH连接失败的情况,设备数量越多,这种情况发生的概率越高。 这个时候如果你想用python批量配置所有的设备,就一定要注意这…...

react native 环境准备
一、必备安装 1、安装node 注意 Node 的版本应大于等于 16,安装完 Node 后建议设置 npm 镜像(淘宝源)以加速后面的过程(或使用科学上网工具)。 node下载地址:Download | Node.js设置淘宝源 npm config s…...

PGSQL(PostgreSQL)数据库安装教程
安装包下载 下载地址 下载后点击exe安装包 设置的data存储路径 设置密码 设置端口 安装完毕,配置PGSQL的ip远程连接,pg_hba.conf,postgresql.conf,需要更改这两个文件 pg_hba.conf 最后增加一行 host all all …...

识别和修复网站上损坏链接的最佳实践
如果您有一个网站,我们知道您花了很多时间在它上面,以使其成为最好的资源。如果你的链接不起作用,你的努力可能是徒劳的。您网站上的断开链接可能会以两种方式损害您的业务: 它们对企业来说是可怕的,因为当消费者点击…...

使用Navicat连接MySQL出现的一些错误
目录 一、错误一:防火墙未关闭 二、错误二:安全组问题 三、错误三:MySQL密码的加密方式 四、错误四:修改my.cnf配置文件 一、错误一:防火墙未关闭 #查看防火墙状态 firewall-cmd --state#关闭防…...

4G基站BBU、RRU、核心网设备
目录 前言 基站 核心网 信号传输 前言 移动运营商在建设4G基站的时候,除了建设一座铁塔之外,更重要的是建设搭载铁塔之上的移动通信设备,这篇博客主要介绍BBU,RRU以及机房的核心网等设备。 基站 一个基站有BBU,…...

iphone/安卓手机如何使用burp抓包
iphone 1. 电脑 ipconfig /all 获取电脑网卡ip: 192.168.31.10 2. 电脑burp上面打开设置,proxy,增加一条 192.168.31.10:8080 3. 4. 手机进入设置 -> Wi-Fi -> 找到HTTP代理选项,选择手动,192.168.31.10:8080 …...

springboot云HIS医院信息综合管理平台源码
满足基层医院机构各类业务需要的健康云HIS系统。该系统能帮助基层医院机构完成日常各类业务,提供病患挂号支持、病患问诊、电子病历、开药发药、会员管理、统计查询、医生站和护士站等一系列常规功能,能与公卫、PACS等各类外部系统融合,实现多…...

【视觉SLAM十四讲学习笔记】第三讲——四元数
专栏系列文章如下: 【视觉SLAM十四讲学习笔记】第一讲——SLAM介绍 【视觉SLAM十四讲学习笔记】第二讲——初识SLAM 【视觉SLAM十四讲学习笔记】第三讲——旋转矩阵 【视觉SLAM十四讲学习笔记】第三讲——旋转向量和欧拉角 本章将介绍视觉SLAM的基本问题之一&#x…...

Linux系统之部署Plik临时文件上传系统
Linux系统之部署Plik临时文件上传系统 一、Plik介绍1.1 Plik简介1.2 Plik特点 二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍 三、检查本地环境3.1 检查本地操作系统版本3.2 检查系统内核版本 四、下载Plik软件包4.1 创建下载目录4.2 下载Plik软件包4.3 查看下载的Plik软件…...

【EI征稿中#先投稿,先送审#】第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024)
第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024) 2024 3rd International Conference on Cyber Security, Artificial Intelligence and Digital Economy 第二届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2023&…...

『亚马逊云科技产品测评』活动征文|基于亚马逊云EC2搭建OA系统
授权声明:本篇文章授权活动官方亚马逊云科技文章转发、改写权,包括不限于在 Developer Centre, 知乎,自媒体平台,第三方开发者媒体等亚马逊云科技官方渠道 亚马逊EC2云服务器(Elastic Compute Cloud)是亚马…...

Mysql更新varchar存储的Josn数据
Mysql更新varchar存储的Josn数据 记录一次mysql操作varchar格式存储的json字符串数据 1、检查版本 -- 版本5.7以上才可以能执行json操作 select version(); 2、创建测试数据 -- 创建测试表及测试数据 CREATE TABLE test_json_table AS SELECT UUID(), {"test1": …...
JSON.stringify与JSON.parse详解与实践
目录 JSON.stringify 简介 主要用途: API 实践1: 实践2: JSON.parse 简介 API 实践1 实践2 JSON.stringify 简介 用于把JavaScript对象、数组、值、布尔值等序列化成字符串形式。 主要用途: 得到的数据通常有以下主…...
在软件开发中正确使用MySQL日期时间类型的深度解析
在日常软件开发场景中,时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志,到供应链系统的物流节点时间戳,时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库,其日期时间类型的…...

K8S认证|CKS题库+答案| 11. AppArmor
目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...
数据库分批入库
今天在工作中,遇到一个问题,就是分批查询的时候,由于批次过大导致出现了一些问题,一下是问题描述和解决方案: 示例: // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...
3403. 从盒子中找出字典序最大的字符串 I
3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...
什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南
文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果
HarmonyOS运动开发:如何用mpchart绘制运动配速图表
##鸿蒙核心技术##运动开发##Sensor Service Kit(传感器服务)# 前言 在运动类应用中,运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据,如配速、距离、卡路里消耗等,用户可以更清晰…...

R语言速释制剂QBD解决方案之三
本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...

MacOS下Homebrew国内镜像加速指南(2025最新国内镜像加速)
macos brew国内镜像加速方法 brew install 加速formula.jws.json下载慢加速 🍺 最新版brew安装慢到怀疑人生?别怕,教你轻松起飞! 最近Homebrew更新至最新版,每次执行 brew 命令时都会自动从官方地址 https://formulae.…...
BLEU评分:机器翻译质量评估的黄金标准
BLEU评分:机器翻译质量评估的黄金标准 1. 引言 在自然语言处理(NLP)领域,衡量一个机器翻译模型的性能至关重要。BLEU (Bilingual Evaluation Understudy) 作为一种自动化评估指标,自2002年由IBM的Kishore Papineni等人提出以来,…...
使用python进行图像处理—图像滤波(5)
图像滤波是图像处理中最基本和最重要的操作之一。它的目的是在空间域上修改图像的像素值,以达到平滑(去噪)、锐化、边缘检测等效果。滤波通常通过卷积操作实现。 5.1卷积(Convolution)原理 卷积是滤波的核心。它是一种数学运算,…...