12.12 深度学习-卷积的注意力机制-通道注意力SENet
# 告诉模型训练的时候 对某个东西 给予额外的注意 额外的权重参数 分配注意力
# 不重要的就抑制 降低权重参数 比如有些项目颜色重要 有些是形状重要
# 通道注意力 一般都要比较多的通道加注意力
# SENet
# 把上层的特征图 自动卷积为 1X1的通道数不变的特征图 然后给每一个通道乘一个权重 就分配了各个通道的注意力 把这个与原图残差回去 与原图融合 这样对比原图来说 形状 CHW都没变
# 注意力机制 可以即插即用 CHW都没变
import torch
import os
import torch.nn as nn
from torchvision.models import resnet18,ResNet18_Weights
from torchvision.models.resnet import _resnet,BasicBlock
path=os.path.dirname(__file__)
onnxpath=os.path.join(path,"assets/resnet_SE-Identity.onnx")
onnxpath=os.path.relpath(onnxpath)
class SENet1(nn.Module):
def __init__(self,inchannel,r=16):
super().__init__()
# 全局平均池化 把所以通道 整个通道进行平均池化
self.inchannel=inchannel
self.pool1=nn.AdaptiveAvgPool2d(1)
# 对全局平均池化后的结果 赋予每个通道的权重 不选择最大池化因为不是在突出最大的特征
# 这里不是直接一个全连接生成 权重 而是用两个全连接来生成 权重 第一个relu激活 第二个Sigmoid 为每一个通道生成一个0-1的权重
# 第一个全连接输出的通道数数量要缩小一下,不能直接传入多少就输出多少,不然参数量太多,第二个通道再输出回去就行
# 缩放因子
self.fc1=nn.Sequential(nn.Linear(self.inchannel,self.inchannel//r),nn.ReLU())
self.fc2=nn.Sequential(nn.Linear(self.inchannel//r,self.inchannel),nn.Sigmoid())
# fc1 用relu会信息丢失 保证inchannel//r 至少要32
# 用两层全连接可以增加注意力层的健壮性
def forward(self,x):
x1=self.pool1(x)
x1=x1.view(x1.shape[0],-1)
x1=self.fc1(x1)
x1=self.fc2(x1)
# 得到了每一个通道的权重
x1=x1.unsqueeze(2).unsqueeze(3)
# 与原来的相乘
return x*x1
def demo1():
torch.manual_seed(666)
img1=torch.rand(1,128,224,224)
senet1=SENet1(img1.shape[1],2)
res=senet1.forward(img1)
print(res.shape)
# 可以把SE模块加入到经典的CNN模型里面 有残差模块的在残差模块后面加入SE 残差模块的输出 当SE模块的输入
# 在卷积后的数据与原数据相加之前 把卷积的数据和 依靠卷积后的数据产生的SE模块的数据 相乘 然后再与原数据相加
# 这个要看源码 进行操作
# 也可以不在 残差后面 进行 有很多种插入SE的方式
# 要找到 网络的残差模块
def demo2():
# 把SE模块加入到ResNet18
# 继承一个BasicBlock类 对resnet18的残差模块进行一些重写
class BasicBlock_SE(BasicBlock):
def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1, base_width = 64, dilation = 1, norm_layer = None):
super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
self.se=SENet1(inplanes)# SE-Identity 加法 在 数据传进来的时候备份两份数据 一份卷积 一份加注意力SE模块 然后两个结果相加输出
def forward(self, x):
identity = x
identity=self.se(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(identity)
out += identity
out = self.relu(out)
return out
# self.se=SENet1(planes)# SE-POST 加法 在 残差模块彻底完成了后加注意力SE模块 然后结果输出
# def forward(self, x):
# identity = x
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# if self.downsample is not None:
# identity = self.downsample(x)
# out += identity
# out = self.relu(out)
# out=self.se(out)
# return out
# self.se=SENet1(inplanes)# SE-PRE 加法 在 残差模块卷积之前加注意力SE模块 然后结果输出
# def forward(self, x):
# identity = x
# out=self.se(x)
# out = self.conv1(out)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# if self.downsample is not None:
# identity = self.downsample(x)
# out += identity
# out = self.relu(out)
# return out
# self.se=SENet1(planes)# Standard_SE 加法 在 残差模块卷积h后加注意力SE模块 然后与原数据项加结果输出
# def forward(self, x):
# identity = x
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# if self.downsample is not None:
# identity = self.downsample(x)
# out=self.se(out)
# out += identity
# out = self.relu(out)
# return out
def resnet18_SE(*, weights= None, progress: bool = True, **kwargs):
weights = ResNet18_Weights.verify(weights)
return _resnet(BasicBlock_SE, [2, 2, 2, 2], weights, progress, **kwargs)
model1=resnet18_SE()
x = torch.randn(1, 3, 224, 224)
# 导出onnx
torch.onnx.export(
model1,
x,
onnxpath,
verbose=True, # 输出转换过程
input_names=["input"],
output_names=["output"],
)
print("onnx导出成功")
# SE在模型的早期层并没有 起多大的作用 在后期层中加 SE机制效果明显 且参数更少
# SE在模型的早期层并没有 起多大的作用 在后期层中加 SE机制效果明显 且参数更少
# 改模型不仅需要 加 一个网络结构 而且也需要注意前向传播 有没有问题
def demo3(): # 在resnet18中的后期 层里面加 SE 前期层不加
class ResNet_SE_laye(ResNet):
def __init__(self, block, layers, num_classes = 1000, zero_init_residual = False, groups = 1, width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer)
def _layer_update_SE(self):
self.se=SENet1(self.layer3[1].conv2.out_channels,8)
self.layer3[1].conv2=nn.Sequential(self.layer3[1].conv2,self.se)
print(self.layer3)
pass
return self.layer3
def _resnet_SE_layer(
block,
layers,
weights,
progress: bool,
**kwargs,
):
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = ResNet_SE_laye(block, layers, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
def resnet18_SE_layer(*, weights= None, progress: bool = True, **kwargs):
weights = ResNet18_Weights.verify(weights)
return _resnet_SE_layer(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
model=resnet18_SE_layer()
# print(model)
layer=model._layer_update_SE()
torch.onnx.export(layer,torch.rand(1,128,224,224),"layer.onnx")
pass
if __name__=="__main__":
# demo1()
# demo2()
pass
相关文章:
12.12 深度学习-卷积的注意力机制-通道注意力SENet
# 告诉模型训练的时候 对某个东西 给予额外的注意 额外的权重参数 分配注意力 # 不重要的就抑制 降低权重参数 比如有些项目颜色重要 有些是形状重要 # 通道注意力 一般都要比较多的通道加注意力 # SENet # 把上层的特征图 自动卷积为 1X1的通道数不变的特征图 然后给每一个…...

H5 scss 移动端的样式适配
在移动端样式的scss文件中,出现了这些变量 env() 与 constant() 设置安全区域,是css里IOS11新增的属性,webkit的css函数,用于设定安全区域与边界的距离,有4个预定义变量: safe-area-inset-left: 安全区域距…...
【JAVA】Java项目实战—移动端项目:天气查询APP
在移动互联网时代,天气查询应用程序(APP)是日常生活中不可或缺的一部分。无论是出门旅行、上班通勤,还是安排户外活动,获取实时天气信息都至关重要。Java作为一种强大且广泛使用的编程语言,特别适合用于开发…...

SpringBoot - 动态端口切换黑魔法
文章目录 关键技术点核心原理Code 关键技术点 利用 Spring Boot 内嵌 Servlet 容器 和 动态端口切换 的方式实现平滑更新的方案,关键技术点如下: Servlet 容器重新绑定端口:Spring Boot 使用 ServletWebServerFactory 动态设置新端口。零停…...
Java爬虫技术:挖掘淘宝数据的利器
在当今大数据时代,网络爬虫技术已经成为获取网络数据的重要手段。Java作为一种强大且灵活的编程语言,非常适合开发复杂的网络爬虫系统。本文将详细介绍Java爬虫能够爬取的淘宝数据类型,并提供具体的代码示例,帮助您快速入门并掌握…...
Chromium for Android 浏览器的编译和安装
Chromium for Android 浏览器的编译和安装 Chromium for Android 浏览器的编译和安装环境要求和配置Chromium for Android源码下载安装 depot_tools获取代码转换现有的Linux检出安装额外的构建依赖运行钩子 Chromium for Android源码编译设置编译环境 编译 ChromiumChromium fo…...

实景视频与模型叠加融合?
[视频GIS系列]无人机视频与与实景模型进行实时融合_无人机视频融合-CSDN博客文章浏览阅读1.5k次,点赞28次,收藏14次。将无人机视频与实景模型进行实时融合是一个涉及多个技术领域的复杂过程,主要包括无人机视频采集、实景模型构建、视频与模型…...
Scala的隐式类
package hfd //隐式类 //任务:给之前的BaseUser添加新的功能,但是不要直接去改代码 //思路:把BaseUser通过隐式转换,改成一个新类型,而这个新类型中有这新的方法 //implicit class一个隐式转换函数类 //作用࿱…...

常见软件设计模式介绍:三层架构、MVC、SSM、EDD、DDD
三层架构(View Service Dao) 三层架构是指:视图层 view(表现层),服务层 service(业务逻辑层),持久层 Dao(数据访问层) 表现层:直接跟前…...
Springboot技术栈常见问题及搭建步骤
一. SpringBoot介绍 1.1. 引言 为了使用SSM框架去开发, 准备SSM框架的模板配置 为了使Spring整合第三方框架, 单独的去编写xml文件 导致SSM项目后期xml文件特别多, 维护xml文件的成本是很高的 SSM工程部署也是很麻烦, 依赖第三方的容器 SSM开发方式很是笨重 1.2 SpringBoot …...

session 共享服务器
1.安装 kryo-3.0.3.jar asm-5.2.jar objenesis-2.6.jar reflectasm-1.11.9.jar minlog-1.3.1.jar kryo-serializers-0.45.jar msm-kryo-serializer-2.3.2.jar memcached-session-manager-tc9-2.3.2.jar spymemcached-2.12.3.jar memcached-session-manager-2.3.2.jar …...

vue2:v-for实现的el-radio-group选中时显示角标,并自定义选中按钮的字体颜色和背景色
项目中需要实现一组预定义查询,每一个查询按钮在选中时右上角显示一个角标,展示当前查询返回的数据条目。 1、text-color="#3785FF" fill="#E6EAF1" 处理选中时的字体颜色和背景色,如上图,分别为蓝色和浅灰色。 2、badge中:value="selectedRadio…...

【Linux】-学习笔记10
第八章、Linux下的火墙管理及优化 1.什么是防火墙 从功能角度来讲 防火墙是位于内部网和外部网之间的屏障,它按照系统管理员预先定义好的规则来控制数据包的进出 从功能实现角度来讲 火墙是系统内核上的一个模块netfilter(数据包过滤机制) …...

鸿蒙NEXT开发案例:九宫格随机
【引言】 在鸿蒙NEXT开发中,九宫格抽奖是一个常见且有趣的应用场景。通过九宫格抽奖,用户可以随机获得不同奖品,增加互动性和趣味性。本文将介绍如何使用鸿蒙开发框架实现九宫格抽奖功能,并通过代码解析展示实现细节。 【环境准…...

深度解析:RTC电路上的32.768KHz时钟的频偏及测试
1、什么是RTC RTC是Real-Time Clock(实时时钟)的缩写,通常在电子产品中,是用时钟电路(外部采用时钟芯片,比如AiP8563)或时钟模块(SOC内部包含了时钟模块,只需要外接32.768KHz晶振)来…...
Scala的泛型
需求:定义一个名为getMiddleEle 的方法用它来获取当前的列表的中间位置的值中间位置的下标 长度/2目标:getMiddleEle(List(1,2,3,4,5)) > 5/2 2 > 下标为2的元素是:3 getMiddleEle(List(1,2,3,4)) > 4/2 2 > 下标为2的元素是:3格式如下: 定义一个函数的格式:def…...
OpenGL ES详解——glUniform1i方法是否能用于设置纹理单元
glUniform1i 方法确实可以用于设置纹理单元(texture unit)。在OpenGL中,纹理单元是图形硬件的一部分,它允许你同时绑定多个纹理,并在着色器程序中通过uniform变量来选择使用哪个纹理。 通常,纹理单元通过整…...

探索 Janus-1.3B:一个统一的 Any-to-Any 多模态理解与生成模型
随着多模态技术的不断发展,越来越多的模型被提出以解决跨文本与图像等多种数据类型的任务。Janus-1.3B 是由 DeepSeek 推出的一个革命性的模型,它通过解耦视觉编码并采用统一的 Transformer 架构,带来了一个高度灵活的 any-to-any 多模态框架…...
论文信息搜集
系列博客目录 文章目录 系列博客目录1.秩典型相关分析及其在视觉搜索重排序中的应用《Rank canonical correlation analysis and its application in visual search reranking》2.利用边信息的规范秩估计在多维谐波恢复中的应用《Canonical Rank Estimation Using Side Informa…...

实操给自助触摸一体机接入大模型语音交互
本文以CSK6 大模型开发板串口触摸屏为例,实操讲解触摸一体机怎样快速增加大模型语音交互功能,使用户能够通过语音在一体机上查询信息、获取智能回答及实现更多互动功能等。 在本文方案中通过CSK6大模型语音开发板采集用户语音,将语音数据传输…...

23-Oracle 23 ai 区块链表(Blockchain Table)
小伙伴有没有在金融强合规的领域中遇见,必须要保持数据不可变,管理员都无法修改和留痕的要求。比如医疗的电子病历中,影像检查检验结果不可篡改行的,药品追溯过程中数据只可插入无法删除的特性需求;登录日志、修改日志…...
FastAPI 教程:从入门到实践
FastAPI 是一个现代、快速(高性能)的 Web 框架,用于构建 API,支持 Python 3.6。它基于标准 Python 类型提示,易于学习且功能强大。以下是一个完整的 FastAPI 入门教程,涵盖从环境搭建到创建并运行一个简单的…...
在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module
1、为什么要修改 CONNECT 报文? 多租户隔离:自动为接入设备追加租户前缀,后端按 ClientID 拆分队列。零代码鉴权:将入站用户名替换为 OAuth Access-Token,后端 Broker 统一校验。灰度发布:根据 IP/地理位写…...
关于 WASM:1. WASM 基础原理
一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...

QT: `long long` 类型转换为 `QString` 2025.6.5
在 Qt 中,将 long long 类型转换为 QString 可以通过以下两种常用方法实现: 方法 1:使用 QString::number() 直接调用 QString 的静态方法 number(),将数值转换为字符串: long long value 1234567890123456789LL; …...
高防服务器能够抵御哪些网络攻击呢?
高防服务器作为一种有着高度防御能力的服务器,可以帮助网站应对分布式拒绝服务攻击,有效识别和清理一些恶意的网络流量,为用户提供安全且稳定的网络环境,那么,高防服务器一般都可以抵御哪些网络攻击呢?下面…...

人机融合智能 | “人智交互”跨学科新领域
本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...

(一)单例模式
一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...
C语言中提供的第三方库之哈希表实现
一. 简介 前面一篇文章简单学习了C语言中第三方库(uthash库)提供对哈希表的操作,文章如下: C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...

Linux 下 DMA 内存映射浅析
序 系统 I/O 设备驱动程序通常调用其特定子系统的接口为 DMA 分配内存,但最终会调到 DMA 子系统的dma_alloc_coherent()/dma_alloc_attrs() 等接口。 关于 dma_alloc_coherent 接口详细的代码讲解、调用流程,可以参考这篇文章,我觉得写的非常…...