当前位置: 首页 > news >正文

【Pytorch】Visualization of Feature Maps(1)

在这里插入图片描述

学习参考来自

  • CNN可视化Convolutional Features
  • https://github.com/wmn7/ML_Practice/blob/master/2019_05_27/filter_visualizer.ipynb

文章目录

  • filter 的激活值


filter 的激活值

原理:找一张图片,使得某个 layer 的 filter 的激活值最大,这张图片就是能被这个 filter 所检测的对象。

来个案例,流程:

  1. 初始化一张图片, 56X56
  2. 使用预训练好的 VGG16 网络,固定网络参数;
  3. 若想可视化第 40 层 layer 的第 k 个 filter 的 conv, 我们设置 loss 函数为 (-1*神经元激活值);
  4. 梯度下降, 对初始图片进行更新;
  5. 对得到的图片X1.2, 得到新的图片,重复上面的步骤;

其中第五步比较关键,我们可以看到初始化的图片不是很大,只有56X56. 这是因为原文作者在实际做的时候发现,若初始图片较大,得到的特征的频率会较高,即没有现在这么好的显示效果。

import torch
from torch.autograd import Variable
from PIL import Image, ImageOps
import torchvision.transforms as transforms
import torchvision.models as modelsimport numpy as np
import cv2
from cv2 import resize
from matplotlib import pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")"initialize input image"
sz = 56
img = np.uint(np.random.uniform(150, 180, (3, sz, sz))) / 255  # (3, 56, 56)
img = torch.from_numpy(img[None]).float().to(device)  # (1, 3, 56, 56)"pretrained model"
model_vgg16 = models.vgg16_bn(pretrained=True).features.to(device).eval()
# downloading /home/xxx/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth, 500M+
# print(model_vgg16)
# print(len(list(model_vgg16.children())))  # 44
# print(list(model_vgg16.children()))"get the filter's output of one layer"
# 使用hook来得到网络中间层的输出
class SaveFeatures():def __init__(self, module):self.hook = module.register_forward_hook(self.hook_fn)def hook_fn(self, module, input, output):self.features = output.clone()def close(self):self.hook.remove()layer = 42
activations = SaveFeatures(list(model_vgg16.children())[layer])"backpropagation, setting hyper-parameters"
lr = 0.1
opt_steps = 25 # 迭代次数
filters = 265 # layer 42 的第 265 个 filter,使其激活值最大
upscaling_steps = 13 # 图像放大次数
blur = 3
upscaling_factor = 1.2 # 放大倍率"preprocessing of datasets"
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(device)
cnn_normalization_std = torch.tensor([0.299, 0.224, 0.225]).view(-1, 1, 1).to(device)"gradient descent"
for epoch in range(upscaling_steps):  # scale the image up up_scaling_steps timesimg = (img - cnn_normalization_mean) / cnn_normalization_stdimg[img > 1] = 1img[img < 0] = 0print("Image Shape1:", img.shape)img_var = Variable(img, requires_grad=True)  # convert image to Variable that requires grad"optimizer"optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)for n in range(opt_steps):optimizer.zero_grad()model_vgg16(img_var)  # forwardloss = -activations.features[0, filters].mean()  # max the activationsloss.backward()optimizer.step()"restore the image"print("Loss:", loss.cpu().detach().numpy())img = img_var * cnn_normalization_std + cnn_normalization_meanimg[img>1] = 1img[img<0] = 0img = img.data.cpu().numpy()[0].transpose(1,2,0)sz = int(upscaling_factor * sz)  # calculate new image sizeimg = cv2.resize(img, (sz, sz), interpolation=cv2.INTER_CUBIC)  # scale image upif blur is not None:img = cv2.blur(img, (blur, blur))  # blur image to reduce high frequency patternsprint("Image Shape2:", img.shape)img = torch.from_numpy(img.transpose(2, 0, 1)[None]).to(device)print("Image Shape3:", img.shape)print(str(epoch), ", Finished")print("="*10)activations.close()  # remove the hookimage = img.cpu().clone()
image = image.squeeze(0)
unloader = transforms.ToPILImage()image = unloader(image)
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
cv2.imwrite("res1.jpg", image)
torch.cuda.empty_cache()"""
Image Shape1: torch.Size([1, 3, 56, 56])
Loss: -6.0634975
Image Shape2: (67, 67, 3)
Image Shape3: torch.Size([1, 3, 67, 67])
0 , Finished
==========
Image Shape1: torch.Size([1, 3, 67, 67])
Loss: -7.8898916
Image Shape2: (80, 80, 3)
Image Shape3: torch.Size([1, 3, 80, 80])
1 , Finished
==========
Image Shape1: torch.Size([1, 3, 80, 80])
Loss: -8.730318
Image Shape2: (96, 96, 3)
Image Shape3: torch.Size([1, 3, 96, 96])
2 , Finished
==========
Image Shape1: torch.Size([1, 3, 96, 96])
Loss: -9.697872
Image Shape2: (115, 115, 3)
Image Shape3: torch.Size([1, 3, 115, 115])
3 , Finished
==========
Image Shape1: torch.Size([1, 3, 115, 115])
Loss: -10.190881
Image Shape2: (138, 138, 3)
Image Shape3: torch.Size([1, 3, 138, 138])
4 , Finished
==========
Image Shape1: torch.Size([1, 3, 138, 138])
Loss: -10.315895
Image Shape2: (165, 165, 3)
Image Shape3: torch.Size([1, 3, 165, 165])
5 , Finished
==========
Image Shape1: torch.Size([1, 3, 165, 165])
Loss: -9.73861
Image Shape2: (198, 198, 3)
Image Shape3: torch.Size([1, 3, 198, 198])
6 , Finished
==========
Image Shape1: torch.Size([1, 3, 198, 198])
Loss: -9.503629
Image Shape2: (237, 237, 3)
Image Shape3: torch.Size([1, 3, 237, 237])
7 , Finished
==========
Image Shape1: torch.Size([1, 3, 237, 237])
Loss: -9.488493
Image Shape2: (284, 284, 3)
Image Shape3: torch.Size([1, 3, 284, 284])
8 , Finished
==========
Image Shape1: torch.Size([1, 3, 284, 284])
Loss: -9.100454
Image Shape2: (340, 340, 3)
Image Shape3: torch.Size([1, 3, 340, 340])
9 , Finished
==========
Image Shape1: torch.Size([1, 3, 340, 340])
Loss: -8.699549
Image Shape2: (408, 408, 3)
Image Shape3: torch.Size([1, 3, 408, 408])
10 , Finished
==========
Image Shape1: torch.Size([1, 3, 408, 408])
Loss: -8.90135
Image Shape2: (489, 489, 3)
Image Shape3: torch.Size([1, 3, 489, 489])
11 , Finished
==========
Image Shape1: torch.Size([1, 3, 489, 489])
Loss: -8.838546
Image Shape2: (586, 586, 3)
Image Shape3: torch.Size([1, 3, 586, 586])
12 , Finished
==========Process finished with exit code 0
"""

得到特征图

请添加图片描述
网上找个图片测试下,看响应是不是最大

测试图片

请添加图片描述

import torch
from torch.autograd import Variable
from PIL import Image, ImageOps
import torchvision.transforms as transforms
import torchvision.models as modelsimport numpy as np
import cv2
from cv2 import resize
from matplotlib import pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class SaveFeatures():def __init__(self, module):self.hook = module.register_forward_hook(self.hook_fn)def hook_fn(self, module, input, output):self.features = output.clone()def close(self):self.hook.remove()size = (224, 224)
picture = Image.open("./bird.jpg").convert("RGB")
picture = ImageOps.fit(picture, size, Image.ANTIALIAS)loader = transforms.ToTensor()
picture = loader(picture).to(device)
print(picture.shape)cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).to(device)picture = (picture-cnn_normalization_mean) / cnn_normalization_stdmodel_vgg16 = models.vgg16_bn(pretrained=True).features.to(device).eval()
print(list(model_vgg16.children())[40])  # Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
print(list(model_vgg16.children())[41])  # BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
print(list(model_vgg16.children())[42])  # ReLU(inplace=True)layer = 42
filters = 265
activations = SaveFeatures(list(model_vgg16.children())[layer])with torch.no_grad():picture_var = Variable(picture[None])model_vgg16(picture_var)
activations.close()print(activations.features.shape)  # torch.Size([1, 512, 14, 14])# 画出每个 filter 的平均值
mean_act = [activations.features[0, i].mean().item() for i in range(activations.features.shape[1])]
plt.figure(figsize=(7,5))
act = plt.plot(mean_act, linewidth=2.)
extraticks = [filters]
ax = act[0].axes
ax.set_xlim(0, 500)
plt.axvline(x=filters, color="gray", linestyle="--")
ax.set_xlabel("feature map")
ax.set_ylabel("mane activation")
ax.set_xticks([0, 200, 400] + extraticks)
plt.show()"""
torch.Size([3, 224, 224])
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
ReLU(inplace=True)
torch.Size([1, 512, 14, 14])
"""

请添加图片描述

可以看到,265 特征图对该输入的相应最高

总结:实测了其他 layer 和 filter,画出来的直方图中,对应的 filter 相应未必是最高的,不过也很高,可能找的待测图片并不是最贴合设定 layer 的某个 filter 的特征。

相关文章:

【Pytorch】Visualization of Feature Maps(1)

学习参考来自 CNN可视化Convolutional Featureshttps://github.com/wmn7/ML_Practice/blob/master/2019_05_27/filter_visualizer.ipynb 文章目录 filter 的激活值 filter 的激活值 原理&#xff1a;找一张图片&#xff0c;使得某个 layer 的 filter 的激活值最大&#xff0c…...

js修改浏览器地址栏里url的方法

1、更新url某一参数的值 function updateQueryStringParameter(uri, key, value) {if (!value) { return uri }var re new RegExp("([?&])" key ".*?(&|$)", "i");var separator uri.indexOf(?) ! -1 ? "&" : &q…...

正则表达式(Java)(韩顺平笔记)

正则表达式&#xff08;Java&#xff09; 底层实现 package com.hspedu.RegExp;import java.util.regex.Matcher; import java.util.regex.Pattern;public class RegExp00 {public static void main(String[] args) {String content "1998年12月8日&#xff0c;第二代J…...

LLVM学习笔记(62)

4.4.3.3.2. 指令处理的设置 4.4.3.3.2.1. 目标机器相关设置 除了基类以外&#xff0c;X86TargetLowering构造函数本身也是一个庞然大物&#xff0c;我们必须要分段来看。V7.0做了不小的改动&#xff0c;改进了代码的结构&#xff0c;修改了一些指令的设置。 100 X86Targ…...

解决Spring Boot应用在Kubernetes上健康检查接口返回OUT_OF_SERVICE的问题

现象 在将Spring Boot应用部署到Kubernetes上时&#xff0c;健康检查接口/healthcheck返回的状态为{"status":"OUT_OF_SERVICE","groups":["liveness","readiness"]}&#xff0c;而期望的是返回正常的健康状态。值得注意的…...

Java对象逃逸

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、商业变现、人工智能等&#xff0c;希望大家多多支持。 未经允许不得转载 目录 一、导读二、概览三、相关知识3.1 逃逸…...

Greenplum的数据库年龄检查处理

概述 Greenplum是基于Postgresql数据库的分布式数据库&#xff0c;而PG数据库在事务及多版本并发控制的实现方式上很特别&#xff0c;采用的是递增事务id的方法&#xff0c;事务id大的事务&#xff0c;认为比较新&#xff0c;反之事务id小&#xff0c;认为比较旧。 事务id的上…...

[HCIE] IPSec-VPN (IKE自动模式)

概念&#xff1a; IKE&#xff1a;因特网密钥交换 实验目标&#xff1a;pc1与pc2互通 步骤1&#xff1a;R1与R3配置默认路由 R1&#xff1a; ip route-static 0.0.0.0 0.0.0.0 12.1.1.2 R2&#xff1a; ip route-static 0.0.0.0 0.0.0.0 23.1.1.2 步骤2&#xff1a;配ACL…...

Qt/QML编程学习之心得:一个Qt工程的学习笔记(九)

这里是关于如何使用Qt Widget开发,而Qt Quick/QML的开发是另一种方式。 1、.pro文件 加CONFIG += c++11,才可以使用Lamda表达式(一般用于connect的内嵌槽函数) 2、QWidget 这是Qt新增加的一个类,基类,窗口类,QMainWindow和QDialog都继承与它。 3、Main函数 QApplicati…...

c++ 课程笔记

105课: cpp文件分为 .h .cpp .cpp 文件 110课:124课 深拷贝 浅拷贝 自建拷贝构造解决浅拷贝释放new后堆区析构函数的问题 (浅拷贝 拷贝内存地址, 释放堆区时 导致源数据 释放时,该地址无数据?而报错) 浅拷贝: 拷贝了对方的值和 堆区内存地址(删除 影响原数据堆区) 深拷贝…...

ELK企业级日志分析平台——ES集群监控

启用xpack认证 官网&#xff1a;https://www.elastic.co/guide/en/elasticsearch/reference/7.6/configuring-tls.html#node-certificates 在elk1上生成证书 [rootelk1 ~]# cd /usr/share/elasticsearch/[rootelk1 elasticsearch]# bin/elasticsearch-certutil ca[rootelk1 ela…...

Twincat使用:EtherCAT通信扫描硬件设备链接PLC变量

EtherCAT通信采用主从架构&#xff0c;其中一个主站设备负责整个EtherCAT网络的管理和控制&#xff0c;而从站设备则负责在数据环网上传递数据。 主站设备可以是计算机、工控机、PLC等&#xff0c; 而从站设备可以是传感器、执行器、驱动器等。 EL3102:MDP5001_300_CF8D1684;…...

手机APP-MCP走蓝牙无线遥控智能安全帽~执法记录仪~拍照录像,并可做基础的配置,例如修改服务器IP以及配置WiFi等

手机APP-MCP走蓝牙无线遥控智能安全帽~执法记录仪~拍照录像,并可做基础的配置,例如修改服务器IP以及配置WiFi等 手机APP-MCP走蓝牙无线遥控智能安全帽~执法记录仪~拍照录像,并可做基础的配置,例如修改服务器IP以及配置WiFi等&#xff0c; AIoT万物智联&#xff0c;智能安全帽…...

网络互联与IP地址

目录 网络互联概述网络的定义与分类网络的定义网络的分类 OSI模型和DoD模型网络拓扑结构总线型拓扑结构星型拓扑结构环型拓扑结构 传输介质同轴电缆双绞线光纤 介质访问控制方式CSMA/CD令牌 网络设备网卡集线器交换机路由器总结 IP地址A、B、C类IP地址特殊地址形式 子网与子网掩…...

Android设计模式--模板方法模式

一&#xff0c;定义 定义一个操作中的算法的框架&#xff0c;而将一些步骤延迟到子类中&#xff0c;使得子类可以不改变一个算法的结构即可重定义该算法的某些特定步骤。 在面向对象的开发过程中&#xff0c;通常会遇到这样一个问题&#xff0c;我们知道一个算法所需的关键步…...

大语言模型——BERT和GPT的那些事儿

前言 自然语言处理是人工智能的一个分支。在自然语言处理领域&#xff0c;有两个相当著名的大语言模型——BERT和GPT。两个模型是同一年提出的&#xff0c;那一年BERT以不可抵挡之势&#xff0c;让整个人工智能届为之震动。据说当年BERT的影响力是GPT的十倍以上。而现在&#…...

Docker 命令详解

1. 容器生命周期管理 命令说明文档run创建一个新的容器并运行一个命令Docker run 命令start/stop/restart启动、停止、重启容器Docker start/stop/restart 命令kill杀掉一个运行中的容器Docker kill 命令rm删除一个或多个容器Docker rm 命令pause/unpause暂停 恢复容器中所有的…...

ios打包,证书获取

HBuilderX 打包ios界面&#xff1a; Bundle ID(AppID)&#xff1a; 又称应用ID&#xff0c;是每一个ios应用的唯一标识&#xff0c;就像一个人的身份证号码&#xff1b; 每开发一个新应用&#xff0c;首先都需要先去创建一个Bundle ID Bundle ID 格式&#xff1a; 一般为&…...

linux(nginx安装配置,tomcat服务命令操作)

首先进系统文件夹 /usr/lib/systemd/systemLs | grep mysql 查看带有命名有MySQL的文件夹修改tomcat.service文件复制jdk目录替换成我们的路径替换成我们的路径进入这个目录&#xff0c;把修改好的文件拖到我们的工具里面重新刷新系统 systemctl daemon-reload查看tomcat状态…...

jQuery_03 dom对象和jQuery对象的互相转换

dom对象和jQuery对象 dom对象 jQuery对象 在一个文件中同时存在两种对象 dom对象: 通过js中的document对象获取的对象 或者创建的对象 jQuery对象: 通过jQuery中的函数获取的对象。 为什么使用dom或jQuery对象呢&#xff1f; 目的是 要使用dom对象的函数或者属性 以及呢 要…...

网络编程(Modbus进阶)

思维导图 Modbus RTU&#xff08;先学一点理论&#xff09; 概念 Modbus RTU 是工业自动化领域 最广泛应用的串行通信协议&#xff0c;由 Modicon 公司&#xff08;现施耐德电气&#xff09;于 1979 年推出。它以 高效率、强健性、易实现的特点成为工业控制系统的通信标准。 包…...

RestClient

什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端&#xff0c;它允许HTTP与Elasticsearch 集群通信&#xff0c;而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级&#xff…...

shell脚本--常见案例

1、自动备份文件或目录 2、批量重命名文件 3、查找并删除指定名称的文件&#xff1a; 4、批量删除文件 5、查找并替换文件内容 6、批量创建文件 7、创建文件夹并移动文件 8、在文件夹中查找文件...

Admin.Net中的消息通信SignalR解释

定义集线器接口 IOnlineUserHub public interface IOnlineUserHub {/// 在线用户列表Task OnlineUserList(OnlineUserList context);/// 强制下线Task ForceOffline(object context);/// 发布站内消息Task PublicNotice(SysNotice context);/// 接收消息Task ReceiveMessage(…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

聊聊 Pulsar:Producer 源码解析

一、前言 Apache Pulsar 是一个企业级的开源分布式消息传递平台&#xff0c;以其高性能、可扩展性和存储计算分离架构在消息队列和流处理领域独树一帜。在 Pulsar 的核心架构中&#xff0c;Producer&#xff08;生产者&#xff09; 是连接客户端应用与消息队列的第一步。生产者…...

家政维修平台实战20:权限设计

目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系&#xff0c;主要是分成几个表&#xff0c;用户表我们是记录用户的基础信息&#xff0c;包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题&#xff0c;不同的角色&#xf…...

Java多线程实现之Callable接口深度解析

Java多线程实现之Callable接口深度解析 一、Callable接口概述1.1 接口定义1.2 与Runnable接口的对比1.3 Future接口与FutureTask类 二、Callable接口的基本使用方法2.1 传统方式实现Callable接口2.2 使用Lambda表达式简化Callable实现2.3 使用FutureTask类执行Callable任务 三、…...

三体问题详解

从物理学角度&#xff0c;三体问题之所以不稳定&#xff0c;是因为三个天体在万有引力作用下相互作用&#xff0c;形成一个非线性耦合系统。我们可以从牛顿经典力学出发&#xff0c;列出具体的运动方程&#xff0c;并说明为何这个系统本质上是混沌的&#xff0c;无法得到一般解…...

(转)什么是DockerCompose?它有什么作用?

一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用&#xff0c;而无需手动一个个创建和运行容器。 Compose文件是一个文本文件&#xff0c;通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...