当前位置: 首页 > news >正文

深度学习——常见注意力机制

1.SENet

SENet属于通道注意力机制。2017年提出,是imageNet最后的冠军

SENet采用的方法是对于特征层赋予权值。

重点在于如何赋权

1.将输入信息的所有通道平均池化。
2.平均池化后进行两次全连接,第一次全连接链接的神经元较少,第二次全连接神经元数和通道数一致
3.将Sigmoid的值固定为0-1之间
4.将权值和特征层相乘。

在这里插入图片描述

import torch
import torch.nn as nn
import mathclass se_block(nn.Module):def __init__(self, channel, ratio=16):super(se_block, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // ratio, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // ratio, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y

2.ECANet

细心的人会发现,全连接其实是一个非常耗费算力的东西,对于边缘设备的压力非常大,所以ECANet觉得SENet并不需要那么多的全连接,我们直接在GAP后做一维卷积,而后取sigmoid为0-1来获取权值即可。

ECANet认为SE的全通道信息捕获是多此一举,而卷积就有很好的跨通道信息获取能力。
在这里插入图片描述

class eca_block(nn.Module):def __init__(self, channel, b=1, gamma=2):super(eca_block, self).__init__()kernel_size = int(abs((math.log(channel, 2) + b) / gamma))kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) self.sigmoid = nn.Sigmoid()def forward(self, x):y = self.avg_pool(x)y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)y = self.sigmoid(y)return x * y.expand_as(x)

4.GCNet

GCNet是我们项目的模型中使用的一种注意力机制

GCNet主要借鉴了SENet和NLNet的优点,主要基于NLNet,把NLNet的计算量削减了数倍

先看他是怎么用NLNet的

NLNet原公式
在这里插入图片描述
改进后的NLNet公式
在这里插入图片描述

改进的区别就是去掉了Wz系数。

Wz系数的削减主要是对图像中的观察得出的创意。
在这里插入图片描述
作者说,attention map在不同位置上计算的结果几乎一致,那么我们只需要计算一次然后共享attention map应该也可以获得很好的效果,并且计算量可以下降到1/(W*H)。

Simple NL Block和NL Block的结构对比如图所示,并且经过文章的实验表明,简化后的性能与原本的性能相当。
在这里插入图片描述

接着,作者基于S-NLNet和SENet的有点提出了GCNet

(1) 相比于SNL,SNL中的transform的1x1卷积在res5中是2048x1x1x2048,其计算量较大,所以借鉴SE的方法,加入压缩因子,为了更好的优化,还加入了layernorm。
(2)相比于SE,一方面是提取的全局信息更加充分(其实在后续的实验中说服力不是很强,单独avg pooling+add,只掉了0.3个点,但是更加简洁),另一方面则是加号和乘号的区别,而且在实验结果上,加号比乘号有显著的优势。

import torch
import torch.nn as nn
import torchvisionclass GlobalContextBlock(nn.Module):def __init__(self,inplanes,ratio,pooling_type='att',fusion_types=('channel_add', )):super(GlobalContextBlock, self).__init__()assert pooling_type in ['avg', 'att']assert isinstance(fusion_types, (list, tuple))valid_fusion_types = ['channel_add', 'channel_mul']assert all([f in valid_fusion_types for f in fusion_types])assert len(fusion_types) > 0, 'at least one fusion should be used'self.inplanes = inplanesself.ratio = ratioself.planes = int(inplanes * ratio)self.pooling_type = pooling_typeself.fusion_types = fusion_typesif pooling_type == 'att':self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)self.softmax = nn.Softmax(dim=2)else:self.avg_pool = nn.AdaptiveAvgPool2d(1)if 'channel_add' in fusion_types:self.channel_add_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_add_conv = Noneif 'channel_mul' in fusion_types:self.channel_mul_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_mul_conv = Nonedef spatial_pool(self, x):batch, channel, height, width = x.size()if self.pooling_type == 'att':input_x = x# [N, C, H * W]input_x = input_x.view(batch, channel, height * width)# [N, 1, C, H * W]input_x = input_x.unsqueeze(1)# [N, 1, H, W]context_mask = self.conv_mask(x)# [N, 1, H * W]context_mask = context_mask.view(batch, 1, height * width)# [N, 1, H * W]context_mask = self.softmax(context_mask)# [N, 1, H * W, 1]context_mask = context_mask.unsqueeze(-1)# [N, 1, C, 1]context = torch.matmul(input_x, context_mask)# [N, C, 1, 1]context = context.view(batch, channel, 1, 1)else:# [N, C, 1, 1]context = self.avg_pool(x)return contextdef forward(self, x):# [N, C, 1, 1]context = self.spatial_pool(x)out = xif self.channel_mul_conv is not None:# [N, C, 1, 1]channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))out = out * channel_mul_termif self.channel_add_conv is not None:# [N, C, 1, 1]channel_add_term = self.channel_add_conv(context)out = out + channel_add_termreturn outif __name__=='__main__':model = GlobalContextBlock(inplanes=16, ratio=0.25)print(model)input = torch.randn(1, 16, 64, 64)out = model(input)print(out.shape)

4.CA注意力机制

CA机制也是和之前的GCNet一样对两个已有注意力(SENet和CBAM)进行了改进。

CA提出

1.SENet作为通道注意力机制,侧重通道之前的依赖关系,忽略了空间特征的作用。
2.CBAM可以一定程度弥补,但是CBAM对于长程依赖有待改进。

经过融合改进后,CA机制有以下优点

1、不仅考虑了通道信息,还考虑了方向相关的位置信息。
2、足够的灵活和轻量,能够简单的插入到轻量级网络的核心模块中。

CA机制的算法流程图如下

在这里插入图片描述
1.CA机制为了避免将空间特征全都压缩到通道中,放弃了全局平均池化,转为分别对x和y方向进行
别生成尺寸为C ∗ H ∗ 1 和C ∗ 1 ∗ W 的attention map
在这里插入图片描述
2.将生成的两个attention map进行池化,然后concat,然后进行F1操作(利用1*1卷积核进行降维,如SE注意力中操作)和激活操作,生成特征图f

在这里插入图片描述
这图怎么这么大?
3.沿着空间维度,再将f进行split操作,分别得到h和w的特征图后再用1 × 1卷积进行升维度操作,结合sigmoid激活函数得到最后的注意力向量gh和gw

代码

class CoordAtt(nn.Module):def __init__(self, inp, oup, groups=32):super(CoordAtt, self).__init__()self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))mip = max(8, inp // groups)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.conv2 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.conv3 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.relu = h_swish()def forward(self, x):identity = xn,c,h,w = x.size()x_h = self.pool_h(x)x_w = self.pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)y = self.conv1(y)y = self.bn1(y)y = self.relu(y) x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)x_h = self.conv2(x_h).sigmoid()x_w = self.conv3(x_w).sigmoid()x_h = x_h.expand(-1, -1, h, w)x_w = x_w.expand(-1, -1, h, w)y = identity * x_w * x_hreturn y

明日:ODConv,数据结构复习,套磁老师

相关文章:

深度学习——常见注意力机制

1.SENet SENet属于通道注意力机制。2017年提出,是imageNet最后的冠军 SENet采用的方法是对于特征层赋予权值。 重点在于如何赋权 1.将输入信息的所有通道平均池化。 2.平均池化后进行两次全连接,第一次全连接链接的神经元较少,第二次全连…...

Python 进阶(七):高级文件操作(shutil 模块)

❤️ 博客主页:水滴技术 🌸 订阅专栏:Python 入门核心技术 🚀 支持水滴:点赞👍 收藏⭐ 留言💬 文章目录 1. 简介2. 常用函数2.1 复制文件2.2 复制目录2.3 移动文件或目录2.4 删除文件或目录2.…...

保留网络:大型语言模型的Transformer继任者

原文信息 原文题目:《Retentive Network: A Successor to Transformer for Large Language Models》 原文引用:Sun Y, Dong L, Huang S, et al. Retentive Network: A Successor to Transformer for Large Language Models[J]. arXiv preprint arXiv:2…...

算法通关村第二关——反转链表青铜笔记

LeetCode 206.反转链表 建立虚拟结点辅助翻转 public ListNode reverseList(ListNode head) {ListNode ans new ListNode(-1);ListNode cur head;while(cur!null){ListNode curNext cur.next;cur.next ans.next;ans.next cur;cur curNext;}return ans.next; }不带虚拟头…...

【Linux】——线程安全

目录 关于线程进程的问题 可重入与线程安全 常见的线程安全的情况 常见的不可重入的情况 常见的可重入的情况 可重入与线程安全区别 可重入与线程安全联系 Linux线程互斥 进程线程间的互斥相关概念 互斥量mutex 互斥量mutex常用接口 互斥量改造抢票系统 互斥量的原…...

[React]生命周期

前言 学习React,生命周期很重要,我们了解完生命周期的各个组件,对写高性能组件会有很大的帮助. Ract生命周期 React 生命周期分为三种状态 1. 初始化 2.更新 3.销毁 初始化 1、getDefaultProps() 设置默认的props,也可以用duf…...

【2023】Redis实现消息队列的方式汇总以及代码实现

Redis实现消息队列的方式汇总以及代码实现 前言开始前准备1、添加依赖2、添加配置的Bean 具体实现一、从最简单的开始:List 队列代码实现 二、发布订阅模式:Pub/Sub1、使用RedisMessageListenerContainer实现订阅2、还可以使用redisTemplate实现订阅 三、…...

ARM裸机-10

1、X210开发板和光盘资料 1.1、配置信息 CPU:三星S5PV210 内存:512M DDR2 SDRAM Flash:4GB iBand LCD:7寸,分辨率800x480 触摸屏:电容触摸屏 2、X210开发板硬件手册 3、X210开发板刷系统 3.1、什么是刷…...

「C/C++」C/C++指针详解

✨博客主页何曾参静谧的博客📌文章专栏「C/C」C/C程序设计📚全部专栏「UG/NX」NX二次开发「UG/NX」BlockUI集合「VS」Visual Studio「QT」QT5程序设计「C/C」C/C程序设计「Win」Windows程序设计「算法」数据结构与算法「File」数据文件格式 目录 一、术语…...

提高电脑寿命的维护技巧与方法分享

在维护电脑运行方面,我有一些自己觉得非常有用的技巧和方法。下面我将分享一些我常用的维护技巧,并解释为什么我会选择这样做以及这样做的好处。 首先,我经常清理我的电脑内部的灰尘。电脑内部的灰尘会影响散热效果,导致电脑发热…...

React常见面试题

React常见面试题 一、React中的样式管理有哪些方法 内联样式:对象,作用于当前组件普通样式表: 作用于全局,文件名是:xxx.scssCSS模块:类似Vue的scoped, 文件名需是:xxx.module.scs…...

C++中数据的输入输出介绍

C中数据的输入输出介绍 C中数据的输入输出涉及到的文件 <iostream>&#xff1a;这是C标准库中最常用的头文件之一&#xff0c;包含了进行标准输入输出操作的类和对象&#xff0c;如std::cin、std::cout、std::endl等。 <iomanip>&#xff1a;该头文件提供了一些用…...

0101日志-运维-mysql

1 错误日志 错误日志&#xff08;Error Log&#xff09;&#xff1a;错误日志记录了MySQL引擎在运行过程中出现的错误和异常情况。这些错误可能包括启动和关闭问题、数据库崩溃、权限问题等。错误日志对于排查和解决MySQL引擎问题非常有帮助。 改日志默认开启&#xff0c;默认存…...

LabVIEW使用灰度和边缘检测进行视频滤波

LabVIEW使用灰度和边缘检测进行视频滤波 数字图像处理&#xff08;DIP&#xff09;是真实和连续世界的离散表示。除此之外&#xff0c;这种数字图像在通信、医学、遥感、地震学、工业自动化、机器人、航空航天和教育等领域变得非常重要。计算机技术越来越需要视频图像的数字图…...

SpringBoot整合WebService

SpringBoot整合WebService WebService是一个比较旧的远程调用通信框架&#xff0c;现在企业项目中用的比较少&#xff0c;因为它逐步被SpringCloud所取代&#xff0c;它的优势就是能够跨语言平台通信&#xff0c;所以还有点价值&#xff0c;下面来看看如何在SpringBoot项目中使…...

【LangChain】向量存储之FAISS

LangChain学习文档 【LangChain】向量存储(Vector stores)【LangChain】向量存储之FAISS 概要 Facebook AI 相似性搜索&#xff08;Faiss&#xff09;是一个用于高效相似性搜索和密集向量聚类的库。它包含的算法可以搜索任意大小的向量集&#xff0c;甚至可能无法容纳在 RAM 中…...

小研究 - 主动式微服务细粒度弹性缩放算法研究(三)

微服务架构已成为云数据中心的基本服务架构。但目前关于微服务系统弹性缩放的研究大多是基于服务或实例级别的水平缩放&#xff0c;忽略了能够充分利用单台服务器资源的细粒度垂直缩放&#xff0c;从而导致资源浪费。为此&#xff0c;本文设计了主动式微服务细粒度弹性缩放算法…...

驱动开发相关内容复盘

并发与竞争 并发 ​ 多个“用户”同时访问同一个共享资源。 竞争 并发和竞争的处理方法 处理并发和竞争的机制&#xff1a;原子操作、自旋锁、信号量和互斥体。 1、原子操作 ​ 原子操作就是指不能再进一步分割的操作&#xff0c;一般原子操作用于变量或者位操作。 ​ …...

2.2 身份鉴别与访问控制

数据参考&#xff1a;CISP官方 目录 身份鉴别基础基于实体所知的鉴别基于实体所有的鉴别基于实体特征的鉴别访问控制基础访问控制模型 一、身份鉴别基础 1、身份鉴别的概念 标识 实体身份的一种计算机表达每个实体与计算机内部的一个身份表达绑定信息系统在执行操作时&a…...

C++ 注释

程序的注释是解释性语句&#xff0c;您可以在 C 代码中包含注释&#xff0c;这将提高源代码的可读性。所有的编程语言都允许某种形式的注释。 C 支持单行注释和多行注释。注释中的所有字符会被 C 编译器忽略。 C 注释一般有两种&#xff1a; // - 一般用于单行注释。 /* … …...

支付宝秘钥模式说明

1 python服务器需要使用 PKCS1格式2 秘钥格式是不带头尾的&#xff0c;中间的纯字符串...

InvokeAI工具函数库:10个核心工具方法与实用辅助函数详解

InvokeAI工具函数库&#xff1a;10个核心工具方法与实用辅助函数详解 【免费下载链接】InvokeAI Invoke is a leading creative engine for Stable Diffusion models, empowering professionals, artists, and enthusiasts to generate and create visual media using the late…...

ChatGPT_JCM深色模式实现:保护眼睛的界面显示方案

ChatGPT_JCM深色模式实现&#xff1a;保护眼睛的界面显示方案 【免费下载链接】ChatGPT_JCM 项目地址: https://gitcode.com/gh_mirrors/ch/ChatGPT_JCM ChatGPT_JCM是一款功能强大的AI交互工具&#xff0c;其深色模式实现为用户提供了舒适的夜间使用体验&#xff0c;有…...

AI辅助开发:描述需求,快马AI自动生成旅行商问题算法与可视化

最近在做一个旅行商问题(TSP)的算法项目&#xff0c;想试试用AI辅助开发能有多高效。结果发现InsCode(快马)平台的AI功能真的帮了大忙&#xff0c;整个过程特别顺畅。这里分享一下我的体验。 需求分析阶段 刚开始我只是简单描述了需求&#xff1a;"需要一个用模拟退火算…...

SEO_10个提升网站排名的实用SEO技巧分享(370 )

SEO:10个提升网站排名的实用SEO技巧分享 在当今的互联网时代&#xff0c;一个网站的成功离不开搜索引擎优化&#xff08;SEO&#xff09;。SEO不仅仅是一套技术&#xff0c;更是一种思维方式。本文将详细分享十个实用的SEO技巧&#xff0c;帮助你提升网站的排名&#xff0c;吸…...

linux系统中简单统计java项目代码行数信息

新建脚本文件&#xff08;最好在项目根目录下&#xff09;&#xff1a;count_java.shvi count_java.sh编辑内容&#xff1a;按一下键盘上的i键&#xff0c;屏幕左下角会出现 -- INSERT --&#xff0c;输入一下内容&#xff1a; #!/bin/bash find . -name "*.java" -p…...

3步快速上手!终极缠论量化工具:基于TradingView本地SDK的几何交易可视化完整指南

3步快速上手&#xff01;终极缠论量化工具&#xff1a;基于TradingView本地SDK的几何交易可视化完整指南 【免费下载链接】chanvis 基于TradingView本地SDK的可视化前后端代码&#xff0c;适用于缠论量化研究&#xff0c;和其他的基于几何交易的量化研究。 缠论量化 摩尔缠论 缠…...

告别教材下载烦恼:国家中小学智慧教育平台电子课本解析工具如何实现3分钟高效获取

告别教材下载烦恼&#xff1a;国家中小学智慧教育平台电子课本解析工具如何实现3分钟高效获取 【免费下载链接】tchMaterial-parser 国家中小学智慧教育平台 电子课本下载工具&#xff0c;帮助您从智慧教育平台中获取电子课本的 PDF 文件网址并进行下载&#xff0c;让您更方便地…...

源代码之下的硅基启示录——Claude Code“核泄漏”事件的深度剖析与时代回响

引言 公元2026年3月30日&#xff0c;一个看似平常的春日&#xff0c;硅基世界却迎来了一场史无前例的地震。 一家以“安全”为最高信条的AI公司&#xff0c;以一种最荒诞的方式&#xff0c;亲手打开了潘多拉的魔盒。Anthropic&#xff0c;这家估值高达3800亿美元的AI新贵&#…...

嵌入式开发必备:三大代码对比工具深度评测

1. 代码对比工具概述作为一名嵌入式开发工程师&#xff0c;我每天都要处理大量的代码修改和版本对比工作。在多年的开发实践中&#xff0c;我发现选择合适的代码对比工具能极大提升工作效率。虽然Beyond Compare是业内公认的标杆产品&#xff0c;但实际工作中我们还有更多选择&…...