12.12 深度学习-卷积的注意力机制-通道注意力SENet
# 告诉模型训练的时候 对某个东西 给予额外的注意 额外的权重参数 分配注意力
# 不重要的就抑制 降低权重参数 比如有些项目颜色重要 有些是形状重要
# 通道注意力 一般都要比较多的通道加注意力
# SENet
# 把上层的特征图 自动卷积为 1X1的通道数不变的特征图 然后给每一个通道乘一个权重 就分配了各个通道的注意力 把这个与原图残差回去 与原图融合 这样对比原图来说 形状 CHW都没变
# 注意力机制 可以即插即用 CHW都没变
import torch
import os
import torch.nn as nn
from torchvision.models import resnet18,ResNet18_Weights
from torchvision.models.resnet import _resnet,BasicBlock
path=os.path.dirname(__file__)
onnxpath=os.path.join(path,"assets/resnet_SE-Identity.onnx")
onnxpath=os.path.relpath(onnxpath)
class SENet1(nn.Module):
def __init__(self,inchannel,r=16):
super().__init__()
# 全局平均池化 把所以通道 整个通道进行平均池化
self.inchannel=inchannel
self.pool1=nn.AdaptiveAvgPool2d(1)
# 对全局平均池化后的结果 赋予每个通道的权重 不选择最大池化因为不是在突出最大的特征
# 这里不是直接一个全连接生成 权重 而是用两个全连接来生成 权重 第一个relu激活 第二个Sigmoid 为每一个通道生成一个0-1的权重
# 第一个全连接输出的通道数数量要缩小一下,不能直接传入多少就输出多少,不然参数量太多,第二个通道再输出回去就行
# 缩放因子
self.fc1=nn.Sequential(nn.Linear(self.inchannel,self.inchannel//r),nn.ReLU())
self.fc2=nn.Sequential(nn.Linear(self.inchannel//r,self.inchannel),nn.Sigmoid())
# fc1 用relu会信息丢失 保证inchannel//r 至少要32
# 用两层全连接可以增加注意力层的健壮性
def forward(self,x):
x1=self.pool1(x)
x1=x1.view(x1.shape[0],-1)
x1=self.fc1(x1)
x1=self.fc2(x1)
# 得到了每一个通道的权重
x1=x1.unsqueeze(2).unsqueeze(3)
# 与原来的相乘
return x*x1
def demo1():
torch.manual_seed(666)
img1=torch.rand(1,128,224,224)
senet1=SENet1(img1.shape[1],2)
res=senet1.forward(img1)
print(res.shape)
# 可以把SE模块加入到经典的CNN模型里面 有残差模块的在残差模块后面加入SE 残差模块的输出 当SE模块的输入
# 在卷积后的数据与原数据相加之前 把卷积的数据和 依靠卷积后的数据产生的SE模块的数据 相乘 然后再与原数据相加
# 这个要看源码 进行操作
# 也可以不在 残差后面 进行 有很多种插入SE的方式
# 要找到 网络的残差模块
def demo2():
# 把SE模块加入到ResNet18
# 继承一个BasicBlock类 对resnet18的残差模块进行一些重写
class BasicBlock_SE(BasicBlock):
def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1, base_width = 64, dilation = 1, norm_layer = None):
super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
self.se=SENet1(inplanes)# SE-Identity 加法 在 数据传进来的时候备份两份数据 一份卷积 一份加注意力SE模块 然后两个结果相加输出
def forward(self, x):
identity = x
identity=self.se(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(identity)
out += identity
out = self.relu(out)
return out
# self.se=SENet1(planes)# SE-POST 加法 在 残差模块彻底完成了后加注意力SE模块 然后结果输出
# def forward(self, x):
# identity = x
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# if self.downsample is not None:
# identity = self.downsample(x)
# out += identity
# out = self.relu(out)
# out=self.se(out)
# return out
# self.se=SENet1(inplanes)# SE-PRE 加法 在 残差模块卷积之前加注意力SE模块 然后结果输出
# def forward(self, x):
# identity = x
# out=self.se(x)
# out = self.conv1(out)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# if self.downsample is not None:
# identity = self.downsample(x)
# out += identity
# out = self.relu(out)
# return out
# self.se=SENet1(planes)# Standard_SE 加法 在 残差模块卷积h后加注意力SE模块 然后与原数据项加结果输出
# def forward(self, x):
# identity = x
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
# out = self.conv2(out)
# out = self.bn2(out)
# if self.downsample is not None:
# identity = self.downsample(x)
# out=self.se(out)
# out += identity
# out = self.relu(out)
# return out
def resnet18_SE(*, weights= None, progress: bool = True, **kwargs):
weights = ResNet18_Weights.verify(weights)
return _resnet(BasicBlock_SE, [2, 2, 2, 2], weights, progress, **kwargs)
model1=resnet18_SE()
x = torch.randn(1, 3, 224, 224)
# 导出onnx
torch.onnx.export(
model1,
x,
onnxpath,
verbose=True, # 输出转换过程
input_names=["input"],
output_names=["output"],
)
print("onnx导出成功")
# SE在模型的早期层并没有 起多大的作用 在后期层中加 SE机制效果明显 且参数更少
# SE在模型的早期层并没有 起多大的作用 在后期层中加 SE机制效果明显 且参数更少
# 改模型不仅需要 加 一个网络结构 而且也需要注意前向传播 有没有问题
def demo3(): # 在resnet18中的后期 层里面加 SE 前期层不加
class ResNet_SE_laye(ResNet):
def __init__(self, block, layers, num_classes = 1000, zero_init_residual = False, groups = 1, width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
super().__init__(block, layers, num_classes, zero_init_residual, groups, width_per_group, replace_stride_with_dilation, norm_layer)
def _layer_update_SE(self):
self.se=SENet1(self.layer3[1].conv2.out_channels,8)
self.layer3[1].conv2=nn.Sequential(self.layer3[1].conv2,self.se)
print(self.layer3)
pass
return self.layer3
def _resnet_SE_layer(
block,
layers,
weights,
progress: bool,
**kwargs,
):
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = ResNet_SE_laye(block, layers, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model
def resnet18_SE_layer(*, weights= None, progress: bool = True, **kwargs):
weights = ResNet18_Weights.verify(weights)
return _resnet_SE_layer(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
model=resnet18_SE_layer()
# print(model)
layer=model._layer_update_SE()
torch.onnx.export(layer,torch.rand(1,128,224,224),"layer.onnx")
pass
if __name__=="__main__":
# demo1()
# demo2()
pass
相关文章:

12.12 深度学习-卷积的注意力机制-通道注意力SENet
# 告诉模型训练的时候 对某个东西 给予额外的注意 额外的权重参数 分配注意力 # 不重要的就抑制 降低权重参数 比如有些项目颜色重要 有些是形状重要 # 通道注意力 一般都要比较多的通道加注意力 # SENet # 把上层的特征图 自动卷积为 1X1的通道数不变的特征图 然后给每一个…...

H5 scss 移动端的样式适配
在移动端样式的scss文件中,出现了这些变量 env() 与 constant() 设置安全区域,是css里IOS11新增的属性,webkit的css函数,用于设定安全区域与边界的距离,有4个预定义变量: safe-area-inset-left: 安全区域距…...

【JAVA】Java项目实战—移动端项目:天气查询APP
在移动互联网时代,天气查询应用程序(APP)是日常生活中不可或缺的一部分。无论是出门旅行、上班通勤,还是安排户外活动,获取实时天气信息都至关重要。Java作为一种强大且广泛使用的编程语言,特别适合用于开发…...

SpringBoot - 动态端口切换黑魔法
文章目录 关键技术点核心原理Code 关键技术点 利用 Spring Boot 内嵌 Servlet 容器 和 动态端口切换 的方式实现平滑更新的方案,关键技术点如下: Servlet 容器重新绑定端口:Spring Boot 使用 ServletWebServerFactory 动态设置新端口。零停…...

Java爬虫技术:挖掘淘宝数据的利器
在当今大数据时代,网络爬虫技术已经成为获取网络数据的重要手段。Java作为一种强大且灵活的编程语言,非常适合开发复杂的网络爬虫系统。本文将详细介绍Java爬虫能够爬取的淘宝数据类型,并提供具体的代码示例,帮助您快速入门并掌握…...

Chromium for Android 浏览器的编译和安装
Chromium for Android 浏览器的编译和安装 Chromium for Android 浏览器的编译和安装环境要求和配置Chromium for Android源码下载安装 depot_tools获取代码转换现有的Linux检出安装额外的构建依赖运行钩子 Chromium for Android源码编译设置编译环境 编译 ChromiumChromium fo…...

实景视频与模型叠加融合?
[视频GIS系列]无人机视频与与实景模型进行实时融合_无人机视频融合-CSDN博客文章浏览阅读1.5k次,点赞28次,收藏14次。将无人机视频与实景模型进行实时融合是一个涉及多个技术领域的复杂过程,主要包括无人机视频采集、实景模型构建、视频与模型…...

Scala的隐式类
package hfd //隐式类 //任务:给之前的BaseUser添加新的功能,但是不要直接去改代码 //思路:把BaseUser通过隐式转换,改成一个新类型,而这个新类型中有这新的方法 //implicit class一个隐式转换函数类 //作用࿱…...

常见软件设计模式介绍:三层架构、MVC、SSM、EDD、DDD
三层架构(View Service Dao) 三层架构是指:视图层 view(表现层),服务层 service(业务逻辑层),持久层 Dao(数据访问层) 表现层:直接跟前…...

Springboot技术栈常见问题及搭建步骤
一. SpringBoot介绍 1.1. 引言 为了使用SSM框架去开发, 准备SSM框架的模板配置 为了使Spring整合第三方框架, 单独的去编写xml文件 导致SSM项目后期xml文件特别多, 维护xml文件的成本是很高的 SSM工程部署也是很麻烦, 依赖第三方的容器 SSM开发方式很是笨重 1.2 SpringBoot …...

session 共享服务器
1.安装 kryo-3.0.3.jar asm-5.2.jar objenesis-2.6.jar reflectasm-1.11.9.jar minlog-1.3.1.jar kryo-serializers-0.45.jar msm-kryo-serializer-2.3.2.jar memcached-session-manager-tc9-2.3.2.jar spymemcached-2.12.3.jar memcached-session-manager-2.3.2.jar …...

vue2:v-for实现的el-radio-group选中时显示角标,并自定义选中按钮的字体颜色和背景色
项目中需要实现一组预定义查询,每一个查询按钮在选中时右上角显示一个角标,展示当前查询返回的数据条目。 1、text-color="#3785FF" fill="#E6EAF1" 处理选中时的字体颜色和背景色,如上图,分别为蓝色和浅灰色。 2、badge中:value="selectedRadio…...

【Linux】-学习笔记10
第八章、Linux下的火墙管理及优化 1.什么是防火墙 从功能角度来讲 防火墙是位于内部网和外部网之间的屏障,它按照系统管理员预先定义好的规则来控制数据包的进出 从功能实现角度来讲 火墙是系统内核上的一个模块netfilter(数据包过滤机制) …...

鸿蒙NEXT开发案例:九宫格随机
【引言】 在鸿蒙NEXT开发中,九宫格抽奖是一个常见且有趣的应用场景。通过九宫格抽奖,用户可以随机获得不同奖品,增加互动性和趣味性。本文将介绍如何使用鸿蒙开发框架实现九宫格抽奖功能,并通过代码解析展示实现细节。 【环境准…...

深度解析:RTC电路上的32.768KHz时钟的频偏及测试
1、什么是RTC RTC是Real-Time Clock(实时时钟)的缩写,通常在电子产品中,是用时钟电路(外部采用时钟芯片,比如AiP8563)或时钟模块(SOC内部包含了时钟模块,只需要外接32.768KHz晶振)来…...

Scala的泛型
需求:定义一个名为getMiddleEle 的方法用它来获取当前的列表的中间位置的值中间位置的下标 长度/2目标:getMiddleEle(List(1,2,3,4,5)) > 5/2 2 > 下标为2的元素是:3 getMiddleEle(List(1,2,3,4)) > 4/2 2 > 下标为2的元素是:3格式如下: 定义一个函数的格式:def…...

OpenGL ES详解——glUniform1i方法是否能用于设置纹理单元
glUniform1i 方法确实可以用于设置纹理单元(texture unit)。在OpenGL中,纹理单元是图形硬件的一部分,它允许你同时绑定多个纹理,并在着色器程序中通过uniform变量来选择使用哪个纹理。 通常,纹理单元通过整…...

探索 Janus-1.3B:一个统一的 Any-to-Any 多模态理解与生成模型
随着多模态技术的不断发展,越来越多的模型被提出以解决跨文本与图像等多种数据类型的任务。Janus-1.3B 是由 DeepSeek 推出的一个革命性的模型,它通过解耦视觉编码并采用统一的 Transformer 架构,带来了一个高度灵活的 any-to-any 多模态框架…...

论文信息搜集
系列博客目录 文章目录 系列博客目录1.秩典型相关分析及其在视觉搜索重排序中的应用《Rank canonical correlation analysis and its application in visual search reranking》2.利用边信息的规范秩估计在多维谐波恢复中的应用《Canonical Rank Estimation Using Side Informa…...

实操给自助触摸一体机接入大模型语音交互
本文以CSK6 大模型开发板串口触摸屏为例,实操讲解触摸一体机怎样快速增加大模型语音交互功能,使用户能够通过语音在一体机上查询信息、获取智能回答及实现更多互动功能等。 在本文方案中通过CSK6大模型语音开发板采集用户语音,将语音数据传输…...

图表的放大和刷新功能
正常图表渲染显示: // 漏斗ading动画 let myChartone; // 获取配置项 let optionone; // 获取漏斗的数据 let order; let pay_order; let pay_order_num; let pay_order_num_num; let optiones; // 漏斗渲染 function polt(data) {// 从名为data的对象中获取ordata属…...

SQLServer利用QQ邮箱做SMTP服务器发邮件
环境 Microsoft SQL Server 2019 (RTM) - 15.0.2000.5 (X64) SQL Server Management Studio 15.0.18384.0 SQL Server 管理对象 (SMO) 16.100.46367.54 Microsoft .NET Framework 4.0.30319.42000 操作系统 Windows Server2019 ———————————————— 前言…...

flutter 多文本,其中文本下划线往下移动
变态需求 flutter中再满足多行文本,文本内有多个样式,并且多个样式可触发事件的情况,将其中的一部分文本的下划线往下移 方式一: 实现 使用RichText组件,主要是看中里面的WidgetSpan可以穿child为一个widget 实现源…...

7.OPEN SQL
总学习目录请点击下面连接 SAP ABAP开发从0到入职,冷冬备战-CSDN博客 目录 编辑 1.OPEN-SQL 简单回顾 R3体系 OEPN-SQL 2.OPEN-SQL 读取数据 2.1Select 语句 select 1条数据 多条数据与into AS别名 2.2INTO 结构体 内表 例子 2.3FROM 选择动态表…...

Python轻松获取抖音视频播放量
现在在gpt的加持下写一些简单的代码还是很容易的,效率高,但是要有一点基础,不然有时候发现不了问题,这些都需要经验积累和实战,最好能和工作结合起来,不然很快一段时间就忘的干干净净了,下面就是…...

YOLOv8目标检测(三*)_最佳超参数训练
YOLOv8目标检测(一)_检测流程梳理:YOLOv8目标检测(一)_检测流程梳理_yolo检测流程-CSDN博客 YOLOv8目标检测(二)_准备数据集:YOLOv8目标检测(二)_准备数据集_yolov8 数据集准备-CSDN博客 YOLOv8目标检测(三)_训练模型:YOLOv8目标检测(三)_训…...

SpringBoot SPI
参考 https://blog.csdn.net/Peelarmy/article/details/106872570 https://javaguide.cn/java/basis/spi.html#%E4%BD%95%E8%B0%93-spi SPI SPI(service provider interface)是JDK提供的服务发现机制。以JDBC为例,JDK提供JDBC接口,在包java.sql.*。MY…...

uniappp配置导航栏自定义按钮(解决首次加载图标失败问题)
1.引入iconfont的图标,只保留这两个文件 2.App.vue引入到全局中 import "./static/fonts/iconfont.css"3.pages.json中配置text为图标对应的unicode {"path": "pages/invite/invite","style": {"h5": {"…...

【Apache paimon】-- 集成 hive3.1.3 异常
目录 1、场景再现 Step1:在 hive cli beeline 执行创建 hive paimon 表 Step2:使用 insert into 写入数据 Step3:抛出异常 2、原因分析 Step1:在 yarn resource manager 作业界面查询 hive sql mr job 的 yarn log Step2:搜索job 使用的 zstd jar 版本 Step3:定…...

基于docker部署Nacos最新版本-国内稳定镜像
介绍 当前微服务架构常用的配置中心,本文推荐的是阿里云开源的nacos,截止发布本文为止,最新的nacos稳定版本为2.4.3 拉取镜像 //这个是国内目前可以下载的成熟的nacos镜像仓库,默认的docker hub需要不断的翻墙才可以下载 docke…...