当前位置: 首页 > news >正文

【深度学习】注意力机制(二)

本文介绍一些注意力机制的实现,包括EA/MHSA/SK/DA/EPSA。

【深度学习】注意力机制(一)

【深度学习】注意力机制(三)

目录

一、EA(External Attention)

二、Multi Head Self Attention

三、SK(Selective Kernel Networks)

四、DA(Dual Attention)

五、EPSA(Efficient Pyramid Squeeze Attention)


一、EA(External Attention)

EA可以关注全局的空间信息,论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import initclass External_attention(nn.Module):'''Arguments:c (int): The input and output channel number.'''def __init__(self, c):super(External_attention, self).__init__()self.conv1 = nn.Conv2d(c, c, 1)self.k = 64self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)        self.conv2 = nn.Sequential(nn.Conv2d(c, c, 1, bias=False),norm_layer(c))        for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, nn.Conv1d):n = m.kernel_size[0] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, _BatchNorm):m.weight.data.fill_(1)if m.bias is not None:m.bias.data.zero_()def forward(self, x):idn = xx = self.conv1(x)b, c, h, w = x.size()n = h*wx = x.view(b, c, h*w)   # b * c * n attn = self.linear_0(x) # b, k, nattn = F.softmax(attn, dim=-1) # b, k, nattn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) #  # b, k, nx = self.linear_1(attn) # b, c, nx = x.view(b, c, h, w)x = self.conv2(x)x = x + idnx = F.relu(x)return x

二、Multi Head Self Attention

注意力机制的经典,Transformer的基石。论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import initclass ScaledDotProductAttention(nn.Module):'''Scaled dot-product attention'''def __init__(self, d_model, d_k, d_v, h,dropout=.1):''':param d_model: Output dimensionality of the model:param d_k: Dimensionality of queries and keys:param d_v: Dimensionality of values:param h: Number of heads'''super(ScaledDotProductAttention, self).__init__()self.fc_q = nn.Linear(d_model, h * d_k)self.fc_k = nn.Linear(d_model, h * d_k)self.fc_v = nn.Linear(d_model, h * d_v)self.fc_o = nn.Linear(h * d_v, d_model)self.dropout=nn.Dropout(dropout)self.d_model = d_modelself.d_k = d_kself.d_v = d_vself.h = hself.init_weights()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):'''Computes:param queries: Queries (b_s, nq, d_model):param keys: Keys (b_s, nk, d_model):param values: Values (b_s, nk, d_model):param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.:param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).:return:'''b_s, nq = queries.shape[:2]nk = keys.shape[1]q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)if attention_weights is not None:att = att * attention_weightsif attention_mask is not None:att = att.masked_fill(attention_mask, -np.inf)att = torch.softmax(att, -1)att=self.dropout(att)out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)out = self.fc_o(out)  # (b_s, nq, d_model)return out

三、SK(Selective Kernel Networks)

SK是通道注意力机制。论文地址:论文连接

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDictclass SKAttention(nn.Module):def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):super().__init__()self.d=max(L,channel//reduction)self.convs=nn.ModuleList([])for k in kernels:self.convs.append(nn.Sequential(OrderedDict([('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),('bn',nn.BatchNorm2d(channel)),('relu',nn.ReLU())])))self.fc=nn.Linear(channel,self.d)self.fcs=nn.ModuleList([])for i in range(len(kernels)):self.fcs.append(nn.Linear(self.d,channel))self.softmax=nn.Softmax(dim=0)def forward(self, x):bs, c, _, _ = x.size()conv_outs=[]### splitfor conv in self.convs:conv_outs.append(conv(x))feats=torch.stack(conv_outs,0)#k,bs,channel,h,w### fuseU=sum(conv_outs) #bs,c,h,w### reduction channelS=U.mean(-1).mean(-1) #bs,cZ=self.fc(S) #bs,d### calculate attention weightweights=[]for fc in self.fcs:weight=fc(Z)weights.append(weight.view(bs,c,1,1)) #bs,channelattention_weughts=torch.stack(weights,0)#k,bs,channel,1,1attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1### fuseV=(attention_weughts*feats).sum(0)return V

四、DA(Dual Attention)

DA融合了通道注意力和空间注意力机制。论文:论文地址

如下图:

代码(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from model.attention.SelfAttention import ScaledDotProductAttention
from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttentionclass PositionAttentionModule(nn.Module):def __init__(self,d_model=512,kernel_size=3,H=7,W=7):super().__init__()self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1)def forward(self,x):bs,c,h,w=x.shapey=self.cnn(x)y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,cy=self.pa(y,y,y) #bs,h*w,creturn yclass ChannelAttentionModule(nn.Module):def __init__(self,d_model=512,kernel_size=3,H=7,W=7):super().__init__()self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)self.pa=SimplifiedScaledDotProductAttention(H*W,h=1)def forward(self,x):bs,c,h,w=x.shapey=self.cnn(x)y=y.view(bs,c,-1) #bs,c,h*wy=self.pa(y,y,y) #bs,c,h*wreturn yclass DAModule(nn.Module):def __init__(self,d_model=512,kernel_size=3,H=7,W=7):super().__init__()self.position_attention_module=PositionAttentionModule(d_model=512,kernel_size=3,H=7,W=7)self.channel_attention_module=ChannelAttentionModule(d_model=512,kernel_size=3,H=7,W=7)def forward(self,input):bs,c,h,w=input.shapep_out=self.position_attention_module(input)c_out=self.channel_attention_module(input)p_out=p_out.permute(0,2,1).view(bs,c,h,w)c_out=c_out.view(bs,c,h,w)return p_out+c_out

五、EPSA(Efficient Pyramid Squeeze Attention)

论文:论文地址

如下图:

代码如下(代码连接):

import torch.nn as nnclass SEWeightModule(nn.Module):def __init__(self, channels, reduction=16):super(SEWeightModule, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0)self.relu = nn.ReLU(inplace=True)self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0)self.sigmoid = nn.Sigmoid()def forward(self, x):out = self.avg_pool(x)out = self.fc1(out)out = self.relu(out)out = self.fc2(out)weight = self.sigmoid(out)return weightdef conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):"""standard convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,padding=padding, dilation=dilation, groups=groups, bias=False)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class PSAModule(nn.Module):def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):super(PSAModule, self).__init__()self.conv_1 = conv(inplans, planes//4, kernel_size=conv_kernels[0], padding=conv_kernels[0]//2,stride=stride, groups=conv_groups[0])self.conv_2 = conv(inplans, planes//4, kernel_size=conv_kernels[1], padding=conv_kernels[1]//2,stride=stride, groups=conv_groups[1])self.conv_3 = conv(inplans, planes//4, kernel_size=conv_kernels[2], padding=conv_kernels[2]//2,stride=stride, groups=conv_groups[2])self.conv_4 = conv(inplans, planes//4, kernel_size=conv_kernels[3], padding=conv_kernels[3]//2,stride=stride, groups=conv_groups[3])self.se = SEWeightModule(planes // 4)self.split_channel = planes // 4self.softmax = nn.Softmax(dim=1)def forward(self, x):batch_size = x.shape[0]x1 = self.conv_1(x)x2 = self.conv_2(x)x3 = self.conv_3(x)x4 = self.conv_4(x)feats = torch.cat((x1, x2, x3, x4), dim=1)feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])x1_se = self.se(x1)x2_se = self.se(x2)x3_se = self.se(x3)x4_se = self.se(x4)x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)attention_vectors = self.softmax(attention_vectors)feats_weight = feats * attention_vectorsfor i in range(4):x_se_weight_fp = feats_weight[:, i, :, :]if i == 0:out = x_se_weight_fpelse:out = torch.cat((x_se_weight_fp, out), 1)return out

相关文章:

【深度学习】注意力机制(二)

本文介绍一些注意力机制的实现,包括EA/MHSA/SK/DA/EPSA。 【深度学习】注意力机制(一) 【深度学习】注意力机制(三) 目录 一、EA(External Attention) 二、Multi Head Self Attention 三、…...

学习黑马vue

项目分析 项目下载地址:vue-admin-template-master: 学习黑马vue 项目下载后没有环境可参考我的篇文章,算是比较详细:vue安装与配置-CSDN博客 安装这两个插件可格式化代码,vscode这个软件是免费的,官网:…...

gdb本地调试版本移植至ARM-Linux系统

移植ncurses库 本文使用的ncurses版本为ncurses-5.9.tar.gz 下载地址:https://ftp.gnu.org/gnu/ncurses/ncurses-5.9.tar.gz 1. 将ncurses压缩包拷贝至Linux主机或使用wget命令下载并解压 tar-zxvf ncurses-5.9.tar.gz 2. 解压后进入到ncurses-5.9目录…...

《Linux C编程实战》笔记:实现自己的ls命令

关键函数的功能及说明 1.void display_attribute(struct stat buf,char *name) 函数功能:打印文件名为name的文件信息,如 含义分别为:文件的类型和访问权限,文件的链接数,文件的所有者,文件所有者所属的组…...

Python个人代码随笔(观看无益,请跳过)

异常抛错:一般来说,在程序中,遇到异常时,会从这一层逐层往外抛错,一直抛到最外层,由最外层把错误显示在用户终端。 try:raise ValueError("A value error...") except ValueError:print("V…...

Unity中实现ShaderToy卡通火(总结篇)

文章目录 前言一、把卡通火修改为后处理效果1、在Shader属性面板定义属性接收帧缓存纹理2、在片元着色器对其纹理采样后,与卡通火相加输出请添加图片描述 二、我们自定义卡通火1、修改 _CUTOFF 使卡通火显示在屏幕两侧2、使火附近屏幕偏红色 前言 在之前的文章中&a…...

等保2.0的变化

1法律地位得到确认 《中华人民共和国网络安全法》第21条规定“国家实行网络安全等级保护制度”,要求“网络运营者应当按照网络安全等级保护制度要求,履行安全保护义务”;第31条规定“对于国家关键信息基础设施,在网络安全等级保护…...

漏洞复现-网神SecGate3600防火墙敏感信息泄露漏洞(附漏洞检测脚本)

免责声明 文章中涉及的漏洞均已修复,敏感信息均已做打码处理,文章仅做经验分享用途,切勿当真,未授权的攻击属于非法行为!文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直接或者间接的…...

ArkTS入门

代码结构分析 struct Index{ } 「自定义组件:可复用的UI单元」 xxx 「装饰器:用来装饰类结构、方法、变量」 Entry 标记当前组件是入口组件(该组件可被独立访问,通俗来讲:它自己就是一个页面)Component 用…...

JS中for循环之退出循环

我为大家介绍一下退出循环的两种方法 1.continue 退出本次循环&#xff0c;一般用于排除或者跳过某一个选项的时候&#xff0c;可以使用continue for(let i 0;i<5;i){if(i 3){continue}// 跳过了3console.log(i) //0 1 2 4}2.break 退出整个for循环&#xff0c;一般用于…...

《Global illumination with radiance regression functions》

总结一下最近看的这篇结合神经网络的全局光照论文。 论文的主要思想是利用了神经网络的非线性特性去拟合全局光照中的间接光照部分&#xff0c;采用了基础的2层MLP去训练&#xff0c;最终能实现一些点光源、glossy材质的光照渲染。为了更好的理解、其输入输出表示如下。 首先…...

华南理工C++试卷

诚信应考 , 考试作弊将带来严重后果&#xff01; 《C程序设计试卷》 注意事项&#xff1a;1. 考前请将密封线内填写清楚&#xff1b; 2. 所有答案请答在试卷的答案栏上&#xff1b; 3&#xff0e;考试形式&#xff1a;闭卷 4. 本试卷共 五 大题&#xff0c;满分100分&#xff…...

0001.WIN7(64位)安装ADS1.2出现L6218错误

用了十多年的笔记本电脑系统出现问题&#xff0c;硬件升级重装以后安装ADS1.2。在编译代码的时候出现L6218错误。如下&#xff1a; 图片是从网上找的&#xff0c;我编译出错的界面没有保留下来。 首先&#xff0c;代码本身没有任何问题 &#xff0c;代码在win7(32位)下编译没有…...

HBuilderX 配置 夜神模拟器 详细图文教程

在电脑端查看App的效果&#xff0c;不用真机调试&#xff0c;下载一个模拟器就可以了 --- Nox Player&#xff0c;夜神模拟器&#xff0c;是一款 Android 模拟器。他的使用非常安全&#xff0c;最重要的是完全免费。 一. 安装模拟器 官网地址&#xff1a; (yeshen.com) 二.配…...

10、神秘的“位移主题”

神秘的“位移主题” 1、什么是位移主题2、位移主题的消息格式3、位移主题是怎么被创建的4、什么地方会用到位移主题5、位移主题的删除机制 本章主题是&#xff1a;Kafka 中的内部主题&#xff08;Internal Topic&#xff09;__consumer_offsets。 __consumer_offsets 在 Kafka …...

【Linux】dump命令使用

dump命令 dump命令用于备份文件系统。使用dump命令可以检查ext2/3/4文件系统上的文件&#xff0c;并确定哪些文件需要备份。这些文件复制到指定的磁盘、磁带或其他存储介质保管。 语法 dump [选项] [目录|文件系统] bash: dump: 未找到命令... 安装dump yum -y install …...

使用 TensorFlow 创建生产级机器学习模型(基于数据流编程的符号数学系统)——学习笔记

资源出处&#xff1a;初学者的 TensorFlow 2.0 教程 | TensorFlow Core (google.cn) 前言 对于新框架的学习&#xff0c;阅读官方文档是一种非常有效的方法。官方文档通常提供了关于框架的详细信息、使用方法和示例代码&#xff0c;可以帮助你快速了解和掌握框架的使用。 如…...

vue实现悬浮窗拖动的自定义指令

首先在自己的项目根目录下建一个 src --> config --> drag.js 然后在main.js中全局引入 //鼠标拖动 import drag from /config/drag; Vue.use(drag); drag.js文件相关代码 import Vue from vue; //使用Vue.directive()定义一个全局指令 //1.参数一&#xff1a;指令的…...

gitee(ssh)同步本地

一、什么是码云 gitee Git的”廉价平替” > 服务器在国内&#xff0c;运行不费劲 在国内也形成了一定的规模 git上的一些项目插件等在码云上也可以找得到 二、创建仓库 三、删除仓库 四、仓库与本地同步 > 建立公钥 五、把仓库同步到本地 六、在本地仓库中创建vue项目…...

Redis新数据类型-Bitmaps

目录 Bitmaps 简介 命令 1. setbit (1) 格式 (2) 实例 2. getbit (1) 格式 (2) 实例 3. bitcount (1) 格式 (2) 实例 4. bitop (1) 格式 (2) 实例 我的其他博客 Bitmaps 简介 Bitmaps 是 Redis 的一种新数据类型&#xff0c;它是一种用于存储位信息的数据结构&…...

基于算法竞赛的c++编程(28)结构体的进阶应用

结构体的嵌套与复杂数据组织 在C中&#xff0c;结构体可以嵌套使用&#xff0c;形成更复杂的数据结构。例如&#xff0c;可以通过嵌套结构体描述多层级数据关系&#xff1a; struct Address {string city;string street;int zipCode; };struct Employee {string name;int id;…...

地震勘探——干扰波识别、井中地震时距曲线特点

目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波&#xff1a;可以用来解决所提出的地质任务的波&#xff1b;干扰波&#xff1a;所有妨碍辨认、追踪有效波的其他波。 地震勘探中&#xff0c;有效波和干扰波是相对的。例如&#xff0c;在反射波…...

centos 7 部署awstats 网站访问检测

一、基础环境准备&#xff08;两种安装方式都要做&#xff09; bash # 安装必要依赖 yum install -y httpd perl mod_perl perl-Time-HiRes perl-DateTime systemctl enable httpd # 设置 Apache 开机自启 systemctl start httpd # 启动 Apache二、安装 AWStats&#xff0…...

聊聊 Pulsar:Producer 源码解析

一、前言 Apache Pulsar 是一个企业级的开源分布式消息传递平台&#xff0c;以其高性能、可扩展性和存储计算分离架构在消息队列和流处理领域独树一帜。在 Pulsar 的核心架构中&#xff0c;Producer&#xff08;生产者&#xff09; 是连接客户端应用与消息队列的第一步。生产者…...

pam_env.so模块配置解析

在PAM&#xff08;Pluggable Authentication Modules&#xff09;配置中&#xff0c; /etc/pam.d/su 文件相关配置含义如下&#xff1a; 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块&#xff0c;负责验证用户身份&am…...

服务器硬防的应用场景都有哪些?

服务器硬防是指一种通过硬件设备层面的安全措施来防御服务器系统受到网络攻击的方式&#xff0c;避免服务器受到各种恶意攻击和网络威胁&#xff0c;那么&#xff0c;服务器硬防通常都会应用在哪些场景当中呢&#xff1f; 硬防服务器中一般会配备入侵检测系统和预防系统&#x…...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量&#xff0c;这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理

引言 Bitmap&#xff08;位图&#xff09;是Android应用内存占用的“头号杀手”。一张1080P&#xff08;1920x1080&#xff09;的图片以ARGB_8888格式加载时&#xff0c;内存占用高达8MB&#xff08;192010804字节&#xff09;。据统计&#xff0c;超过60%的应用OOM崩溃与Bitm…...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代&#xff0c;加密货币作为一种新兴的金融现象&#xff0c;正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而&#xff0c;加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下&#xff0c;稳定…...

CSS设置元素的宽度根据其内容自动调整

width: fit-content 是 CSS 中的一个属性值&#xff0c;用于设置元素的宽度根据其内容自动调整&#xff0c;确保宽度刚好容纳内容而不会超出。 效果对比 默认情况&#xff08;width: auto&#xff09;&#xff1a; 块级元素&#xff08;如 <div>&#xff09;会占满父容器…...