当前位置: 首页 > news >正文

YOLOv5算法改进(4)— 添加CA注意力机制

前言:Hello大家好,我是小哥谈。注意力机制是近年来深度学习领域内的研究热点,可以帮助模型更好地关注重要的特征,从而提高模型的性能。在许多视觉任务中,输入数据通常由多个通道组成,例如图像中的RGB通道或视频中的时间序列帧。传统的卷积神经网络(CNN)在处理这些通道时通常是独立地对每个通道进行操作,忽略了通道之间的相互作用。CA注意力机制通过引入通道注意力来解决这个问题。它能够自动学习到不同通道之间的关联性和重要性,从而增强模型对输入数据的建模能力。具体来说,CA注意力机制通过计算每个通道的权重,使得模型能够更加关注重要的通道,并抑制不重要的通道。这样可以提高模型在处理多通道输入数据时的表达能力和性能。🌈 

前期回顾:

          YOLOv5算法改进(1)— 如何去改进YOLOv5算法

          YOLOv5算法改进(2)— 添加SE注意力机制

          YOLOv5算法改进(3)— 添加CBAM注意力机制

          目录

🚀1.论文

🚀2.CA注意力机制的原理及实现

🚀3.添加CA注意力机制的好处 

🚀4.添加CA注意力机制的方法

💥💥步骤1:在common.py中添加CA模块

💥💥步骤2:在yolo.py文件中加入类名

💥💥步骤3:创建自定义yaml文件

💥💥步骤4:修改yolov5s_CA.yaml文件 

💥💥步骤5:验证是否加入成功

💥💥步骤6:修改train.py中的'--cfg'默认参数

🚀5.添加C3_CA注意力机制的方法(在C3模块中添加)

💥💥步骤1:在common.py中添加CABottleneck和C3_CA模块

💥💥步骤2:在yolo.py文件里parse_model函数中加入类名

​💥💥步骤3:创建自定义yaml文件

​💥💥步骤4:验证是否加入成功

​💥💥步骤5:修改train.py中的'--cfg'默认参数 

🚀1.论文

目前,轻量级网络的注意力机制大都采用 SE 模块,仅考虑了通道间的信息,忽略了位置信息。尽管后来的 BAM 和 CBAM 尝试在降低通道数后通过卷积来提取位置注意力信息,但卷积只能提取局部关系,缺乏长距离关系提取的能力。为此,论文提出了新的高效注意力机制CA(coordinate attention),能够将横向和纵向的位置信息编码到 channel attention 中,使得移动网络能够关注大范围的位置信息又不会带来过多的计算量。🌴

论文题目:Coordinate Attention for Efficient Mobile Network Design

论文地址:https://arxiv.org/abs/2103.02907

代码实现:GitHub - houqb/CoordAttention: Code for our CVPR2021 paper coordinate attention 


🚀2.CA注意力机制的原理及实现

CA(Channel Attention)注意力机制是一种在深度学习中常用的注意力机制之一,用于增强模型对于不同通道(channel)之间的特征关联性。📚

其原理如下:👇

(1)输入特征经过卷积等操作得到中间特征表示。

(2)中间特征表示经过两个并行的操作:全局平均池化和全局最大池化,得到全局特征描述。

(3)全局特征描述通过两个全连接层生成注意力权重。

(4)注意力权重与中间特征表示相乘,得到加权后的特征表示。

(5)加权后的特征表示经过适当的调整(如残差连接)后,作为下一层的输入。

CA注意力的实现如图所示,可以认为分为两个并行阶段

将输入特征图分别在为宽度高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。假设输入进来的特征层的形状为[C, H, W],在经过宽方向的平均池化后,获得的特征层shape为[C, H, 1],此时我们将特征映射到了高维度上;在经过高方向的平均池化后,获得的特征层shape为[C, 1, W],此时我们将特征映射到了宽维度上。

然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[C, 1, H+W],利用卷积+标准化+激活函数获得特征。

之后再次分开为两个并行阶段,再将宽高分开成为:[C, 1, H][C, 1, W],之后进行转置。获得两个特征层[C, H, 1][C, 1, W]

然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况,乘上原有的特征就是CA注意力机制


🚀3.添加CA注意力机制的好处 

作者通过将位置信息嵌入到通道注意力中提出了一种新颖的移动网络注意力机制,将其称为“Coordinate Attention”。其为即插即用的注意力模块,能插入任何经典网络🍉

加入CA注意力机制的好处包括:

 (1)增强特征表达:CA注意力机制能够自适应地选择和调整不同通道的特征权重,从而更好地表达输入数据。它可以帮助模型发现和利用输入数据中重要的通道信息,提高特征的判别能力和区分性。

 (2)减少冗余信息:通过抑制不重要的通道,CA注意力机制可以减少输入数据中的冗余信息,提高模型对关键特征的关注度。这有助于降低模型的计算复杂度,并提高模型的泛化能力。

 (3)提升模型性能:加入CA注意力机制可以显著提高模型在多通道输入数据上的性能。它能够帮助模型更好地捕捉到通道之间的相关性和依赖关系,从而提高模型对输入数据的理解能力。

综上所述,加入CA注意力机制可以有效地增强模型对多通道输入数据的建模能力,提高模型性能和泛化能力。它在图像处理、视频分析等任务中具有重要的应用价值。🌿


🚀4.添加CA注意力机制的方法

💥💥步骤1:在common.py中添加CA模块

将下面的CA模块的代码复制粘贴到common.py文件的末尾。

# CA
class h_sigmoid(nn.Module):def __init__(self, inplace=True):super(h_sigmoid, self).__init__()self.relu = nn.ReLU6(inplace=inplace)def forward(self, x):return self.relu(x + 3) / 6
class h_swish(nn.Module):def __init__(self, inplace=True):super(h_swish, self).__init__()self.sigmoid = h_sigmoid(inplace=inplace)def forward(self, x):return x * self.sigmoid(x)class CoordAtt(nn.Module):def __init__(self, inp, oup, reduction=32):super(CoordAtt, self).__init__()self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))mip = max(8, inp // reduction)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.act = h_swish()self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)def forward(self, x):identity = xn, c, h, w = x.size()#c*1*Wx_h = self.pool_h(x)#c*H*1#C*1*hx_w = self.pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)#C*1*(h+w)y = self.conv1(y)y = self.bn1(y)y = self.act(y)x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)a_h = self.conv_h(x_h).sigmoid()a_w = self.conv_w(x_w).sigmoid()out = identity * a_w * a_hreturn out

具体如下图所示:

💥💥步骤2:在yolo.py文件中加入类名

首先在yolo.py文件中找到parse_model函数,然后将 CoordAtt 添加到这个注册表里。

💥💥步骤3:创建自定义yaml文件

models文件夹中复制yolov5s.yaml粘贴并命名为yolov5s_CA.yaml

💥💥步骤4:修改yolov5s_CA.yaml文件 

本步骤是修改yolov5s_CA.yaml,将CA模块添加到我们想添加的位置。在这里,我将[-1,1,CoordAtt,[1024]]添加到SPPF的上一层,即下图中所示位置。

说明:♨️♨️♨️

注意力机制可以加在Backbone、Neck、Head等部分,常见的有两种:一种是在主干的SPPF前面添加一层;二是将Backbone中的C3全部替换。不同的位置效果可能不同,需要我们去反复测试。

这里需要注意一个问题,当在网络中添加新的层之后,那么该层网络后面的层的编号会发生变化。原本Detect指定的是[17,20,23]层,所以,我们在添加了CA模块之后,也要对这里进行修改,即原来的17层,变成18层,原来的20层,变成21层,原来的23层,变成24层;所以这里需要改为[18,21,24]。同样的,Concat的系数也要修改,这样才能保持原来的网络结构不会发生特别大的改变,我们刚才把CA模块加到了第9层,所以第9层之后的编号都需要加1,这里我们把后面两个Concat的系数分别由[-1,14][-1,10]改为[-1,15][-1,11]。🌻

具体如下图所示:

💥💥步骤5:验证是否加入成功

yolo.py文件里,将配置改为我们刚才自定义的yolov5s_CA.yaml

 然后运行yolo.py,得到结果。

找到了CA模块,说明我们添加成功了。🎉🎉🎉

💥💥步骤6:修改train.py中的'--cfg'默认参数

train.py文件中找到 parse_opt函数,然后将第二行'--cfg'的default改为 'models/yolov5s_CA.yaml',然后就可以开始进行训练了。🎈🎈🎈


🚀5.添加C3_CA注意力机制的方法(在C3模块中添加)

上面是单独添加注意力层,接下来的方法是在C3模块中加入注意力层。这个策略是将CA注意力机制添加到Bottleneck,替换Backbone中所有的C3模块。🌳

💥💥步骤1:在common.py中添加CABottleneck和C3_CA模块

将下面的代码复制粘贴到common.py文件的末尾。

# CA
class h_sigmoid(nn.Module):def __init__(self, inplace=True):super(h_sigmoid, self).__init__()self.relu = nn.ReLU6(inplace=inplace)def forward(self, x):return self.relu(x + 3) / 6class h_swish(nn.Module):def __init__(self, inplace=True):super(h_swish, self).__init__()self.sigmoid = h_sigmoid(inplace=inplace)def forward(self, x):return x * self.sigmoid(x)class CABottleneck(nn.Module):# Standard bottleneckdef __init__(self, c1, c2, shortcut=True, g=1, e=0.5, ratio=32):  # ch_in, ch_out, shortcut, groups, expansionsuper().__init__()c_ = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c_, c2, 3, 1, g=g)self.add = shortcut and c1 == c2# self.ca=CoordAtt(c1,c2,ratio)self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))mip = max(8, c1 // ratio)self.conv1 = nn.Conv2d(c1, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.act = h_swish()self.conv_h = nn.Conv2d(mip, c2, kernel_size=1, stride=1, padding=0)self.conv_w = nn.Conv2d(mip, c2, kernel_size=1, stride=1, padding=0)def forward(self, x):x1 = self.cv2(self.cv1(x))n, c, h, w = x.size()# c*1*Wx_h = self.pool_h(x1)# c*H*1# C*1*hx_w = self.pool_w(x1).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)# C*1*(h+w)y = self.conv1(y)y = self.bn1(y)y = self.act(y)x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)a_h = self.conv_h(x_h).sigmoid()a_w = self.conv_w(x_w).sigmoid()out = x1 * a_w * a_h# out=self.ca(x1)*x1return x + out if self.add else outclass C3_CA(C3):# C3 module with CABottleneck()def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):super().__init__(c1, c2, n, shortcut, g, e)c_ = int(c2 * e)  # hidden channelsself.m = nn.Sequential(*(CABottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

💥💥步骤2:在yolo.py文件里parse_model函数中加入类名

yolo.py文件parse_model函数中,加入CABottleneckC3_CA这两个模块。

​💥💥步骤3:创建自定义yaml文件

按照上面的步骤创建yolov5s_C3_CA.yaml文件,替换4个C3模块。

​💥💥步骤4:验证是否加入成功

yolo.py文件里配置刚才我们自定义的yolov5s_C3_CA.yaml,然后运行。 

​💥💥步骤5:修改train.py中的'--cfg'默认参数 

train.py文件中找到parse_opt函数,然后将第二行'--cfg'的default改为 'models/yolov5s_C3_CA.yaml',然后就可以开始进行训练了。🎈🎈🎈


相关文章:

YOLOv5算法改进(4)— 添加CA注意力机制

前言:Hello大家好,我是小哥谈。注意力机制是近年来深度学习领域内的研究热点,可以帮助模型更好地关注重要的特征,从而提高模型的性能。在许多视觉任务中,输入数据通常由多个通道组成,例如图像中的RGB通道或…...

无涯教程-PHP - XML GET

XML Get已用于从xml文件获取节点值。以下示例显示了如何从xml获取数据。 Note.xml 是xml文件&#xff0c;可以通过php文件访问。 <SUBJECT><COURSE>Android</COURSE><COUNTRY>India</COUNTRY><COMPANY>LearnFk</COMPANY><PRICE…...

Spark Standalone环境搭建及测试

&#x1f947;&#x1f947;【大数据学习记录篇】-持续更新中~&#x1f947;&#x1f947; 篇一&#xff1a;Linux系统下配置java环境 篇二&#xff1a;hadoop伪分布式搭建&#xff08;超详细&#xff09; 篇三&#xff1a;hadoop完全分布式集群搭建&#xff08;超详细&#xf…...

【PHP】流程控制-ifswitchforwhiledo-whilecontinuebreak

文章目录 流程控制顺序结构分支结构if分支switch分支 循环结构for循环while循环do-while循环continue和break 流程控制 顺序结构&#xff1a;代码从上往下&#xff0c;顺序执行。&#xff08;代码执行的最基本结构&#xff09; 分支结构&#xff1a;给定一个条件&#xff0c;…...

Pytorch-day04-模型构建-checkpoint

PyTorch 模型构建 1、GPU配置2、数据预处理3、划分训练集、验证集、测试集4、选择模型5、设定损失函数&优化方法6、模型效果评估 #导入常用包 import os import numpy as np import torch from torch.utils.data import Dataset, DataLoader from torchvision.transfor…...

使用Xshell7控制多台服务同时安装ZK最新版集群服务

一: 环境准备: 主机名称 主机IP 节点 (集群内通讯端口|选举leader|cline端提供服务)端口 docker0 192.168.1.100 node-0 2888 | 3888 | 2181 docker1 192.168.1.101 node-1 2888 | 388…...

python numpy array dtype和astype类型转换的区别

Python3 本身对整数的支持做了提升&#xff0c;可以支持无限长度的整数&#xff1a;比如&#xff1a; b 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffPython的模块numpy array定义的数组在windows和MACOS上默认长度是…...

浮动属性样式

&#x1f353;浮动属性 属性名称中文注释备注float设置盒子浮动left左浮动&#xff0c;right右浮动&#xff0c;none不浮动clear清除浮动left清除左浮动&#xff0c;right清除右浮动&#xff0c;both左右浮动都清除&#xff08;注意&#xff1a;clear清除浮动一般只有作用在块…...

keepalived双机热备 (四十五)

一、概述 Keepalived 是一个基于 VRRP 协议来实现的 LVS 服务高可用方案&#xff0c;可以解决静态路由出现的单点故障问题。 原理 在一个 LVS 服务集群中通常有主服务器&#xff08;MASTER&#xff09;和备份服务器&#xff08;BACKUP&#xff09;两种角色的服务器…...

SpringBoot整合阿里云OSS,实现图片上传

在项目中&#xff0c;将图片等文件资源上传到阿里云的OSS&#xff0c;减少服务器压力。 项目中导入阿里云的SDK <dependency><groupId>com.aliyun.oss</groupId><artifactId>aliyun-sdk-oss</artifactId><version>3.10.2</version>…...

Dynaminc Programming相关

目录 3.1 最长回文子串&#xff08;中等&#xff09;&#xff1a;标志位 3.2 最大子数组和&#xff08;中等&#xff09;&#xff1a;动态规划 3.3 爬楼梯&#xff08;简单&#xff09;&#xff1a;动态规划 3.4 买卖股票的最佳时机&#xff08;简单&#xff09;&#xff1…...

使用 Elasticsearch 轻松进行中文文本分类

本文记录下使用 Elasticsearch 进行文本分类&#xff0c;当我第一次偶然发现 Elasticsearch 时&#xff0c;就被它的易用性、速度和配置选项所吸引。每次使用 Elasticsearch&#xff0c;我都能找到一种更为简单的方法来解决我一贯通过传统的自然语言处理 (NLP) 工具和技术来解决…...

MNN学习笔记(八):使用MNN推理Mediapipe模型

1.项目说明 最近需要用到一些mediapipe中的模型功能&#xff0c;于是尝试对mediapipe中的一些模型进行转换&#xff0c;并使用MNN进行推理&#xff1b;主要模型包括&#xff1a;图像分类、人脸检测及人脸关键点mesh、手掌检测及手势关键点、人体检测及人体关键点、图像嵌入特征…...

主力吸筹指标及其分析和使用说明

文章目录 主力吸筹指标指标代码分析使用说明使用配图主力吸筹指标 VAR1:=REF(LOW,1); VAR2:=SMA(MAX(LOW-VAR1,0),3,1)/SMA(ABS(LOW-VAR1),3,1)*100; VAR3:=EMA(VAR2,3); VAR4:=LLV(LOW,34); VAR5:=HHV(VAR3,34); VAR7:=EMA(IF(LOW<=VAR4,(VAR3+VAR5*2)/2,0),3); /*底线:0,…...

Python高光谱遥感数据处理与高光谱遥感机器学习方法教程

详情点击链接&#xff1a;Python高光谱遥感数据处理与高光谱遥感机器学习方法教程 第一&#xff1a;高光谱基础 一&#xff1a;高光谱遥感基本 01)高光谱遥感 02)光的波长 03)光谱分辨率 04)高光谱遥感的历史和发展 二&#xff1a;高光谱传感器与数据获取 01)高光谱遥感…...

【洛谷】P1678 烦恼的高考志愿

原题链接&#xff1a;https://www.luogu.com.cn/problem/P1678 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 将每个学校的分数线用sort()升序排序&#xff0c;再二分查找每个学校的分数线&#xff0c;通过二分找到每个同学估分附近的分数线。 最后…...

开机自启CPU设置定频

sudo apt-get install expect sudo apt-get install cpufrequtils具体步骤如下&#xff1a; 安装 cpufrequtils 工具 ⚫ sudo apt-get install cpufrequtils ⚫ 需要联网下载修改配置文件 ⚫ sudo vi /etc/init.d/cpufrequtils ⚫ 将 GOVERNOR“ondemand” 改为&#xff1a; &g…...

嵌入式Linux开发实操(十二):PWM接口开发

# 前言 使用pwm实现LED点灯,可以说是嵌入式系统的一个基本案例。那么嵌入式linux系统下又如何实现pwm点led灯呢? # PWM在嵌入式linux下的操作指令 实际使用效果如下,可以通过shell指令将开发板对应的LED灯点亮。 点亮3个LED,则分别使用pwm1、pwm2和pwm3。 # PWM引脚的硬…...

消息中间件介绍

消息队列已经逐渐成为企业IT系统内部通信的核心手段。它具有低耦合、可靠投递、广播、流量控制、最终一致性等一系列功能&#xff0c;成为异步RPC的主要手段之一。当今市面上有很多主流的消息中间件&#xff0c;如ActiveMQ、RabbitMQ&#xff0c;Kafka&#xff0c;还有阿里巴巴…...

[Unity] 基础的编程思想, 组件式开发

熟悉 C# 开发的朋友, 在刚进入 Unity 开发时, 不可避免的会有一些迷惑, 例如不清楚 Unity 自己的思想, 如何设计与架构一个应用程序之类的. 本篇文章简要的介绍一下 Unity 的基础编程思想. 独立 Unity 很少使用 C# 的标准库, 例如 C# 的网络, 事件驱动, 对象模型, 这些概念在 …...

地震勘探——干扰波识别、井中地震时距曲线特点

目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波&#xff1a;可以用来解决所提出的地质任务的波&#xff1b;干扰波&#xff1a;所有妨碍辨认、追踪有效波的其他波。 地震勘探中&#xff0c;有效波和干扰波是相对的。例如&#xff0c;在反射波…...

Linux链表操作全解析

Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表&#xff1f;1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…...

3.3.1_1 检错编码(奇偶校验码)

从这节课开始&#xff0c;我们会探讨数据链路层的差错控制功能&#xff0c;差错控制功能的主要目标是要发现并且解决一个帧内部的位错误&#xff0c;我们需要使用特殊的编码技术去发现帧内部的位错误&#xff0c;当我们发现位错误之后&#xff0c;通常来说有两种解决方案。第一…...

可靠性+灵活性:电力载波技术在楼宇自控中的核心价值

可靠性灵活性&#xff1a;电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中&#xff0c;电力载波技术&#xff08;PLC&#xff09;凭借其独特的优势&#xff0c;正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据&#xff0c;无需额外布…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中&#xff0c;接口是一种抽象类型&#xff0c;它定义了一组方法的集合&#xff1a; // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的&#xff1a; // 矩形结构体…...

ardupilot 开发环境eclipse 中import 缺少C++

目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...

《基于Apache Flink的流处理》笔记

思维导图 1-3 章 4-7章 8-11 章 参考资料 源码&#xff1a; https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...

大学生职业发展与就业创业指导教学评价

这里是引用 作为软工2203/2204班的学生&#xff0c;我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要&#xff0c;而您认真负责的教学态度&#xff0c;让课程的每一部分都充满了实用价值。 尤其让我…...

html-<abbr> 缩写或首字母缩略词

定义与作用 <abbr> 标签用于表示缩写或首字母缩略词&#xff0c;它可以帮助用户更好地理解缩写的含义&#xff0c;尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时&#xff0c;会显示一个提示框。 示例&#x…...

Unsafe Fileupload篇补充-木马的详细教程与木马分享(中国蚁剑方式)

在之前的皮卡丘靶场第九期Unsafe Fileupload篇中我们学习了木马的原理并且学了一个简单的木马文件 本期内容是为了更好的为大家解释木马&#xff08;服务器方面的&#xff09;的原理&#xff0c;连接&#xff0c;以及各种木马及连接工具的分享 文件木马&#xff1a;https://w…...