【保姆级教程|YOLOv8添加注意力机制】【2】在C2f结构中添加ShuffleAttention注意力机制并训练
《博主简介》
小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~
👍感谢小伙伴们点赞、关注!
《------往期经典推荐------》
一、AI应用软件开发实战专栏【链接】
| 项目名称 | 项目名称 |
|---|---|
| 1.【人脸识别与管理系统开发】 | 2.【车牌识别与自动收费管理系统开发】 |
| 3.【手势识别系统开发】 | 4.【人脸面部活体检测系统开发】 |
| 5.【图片风格快速迁移软件开发】 | 6.【人脸表表情识别系统】 |
| 7.【YOLOv8多目标识别与自动标注软件开发】 | 8.【基于YOLOv8深度学习的行人跌倒检测系统】 |
| 9.【基于YOLOv8深度学习的PCB板缺陷检测系统】 | 10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统】 |
| 11.【基于YOLOv8深度学习的安全帽目标检测系统】 | 12.【基于YOLOv8深度学习的120种犬类检测与识别系统】 |
| 13.【基于YOLOv8深度学习的路面坑洞检测系统】 | 14.【基于YOLOv8深度学习的火焰烟雾检测系统】 |
| 15.【基于YOLOv8深度学习的钢材表面缺陷检测系统】 | 16.【基于YOLOv8深度学习的舰船目标分类检测系统】 |
| 17.【基于YOLOv8深度学习的西红柿成熟度检测系统】 | 18.【基于YOLOv8深度学习的血细胞检测与计数系统】 |
| 19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统】 | 20.【基于YOLOv8深度学习的水稻害虫检测与识别系统】 |
| 21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统】 | 22.【基于YOLOv8深度学习的路面标志线检测与识别系统】 |
| 22.【基于YOLOv8深度学习的智能小麦害虫检测识别系统】 | 23.【基于YOLOv8深度学习的智能玉米害虫检测识别系统】 |
| 24.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统】 | 25.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统】 |
| 26.【基于YOLOv8深度学习的人脸面部表情识别系统】 | 27.【基于YOLOv8深度学习的苹果叶片病害智能诊断系统】 |
二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】
《------正文------》
## 搜索C2f源码位置并新建C2f类在项目目录中全局搜索class c2f即可找到c2f的源码位置。然后打开源码位置,进行相应修改。源码路径为:ultralytics/nn/modules/block.py
在原文件中直接copy一份c2f类的源码,然后命名为c2f_Attention,如下所示:

在不同文件导入新建的C2f类
在ultralytics/nn/modules/block.py顶部,all中添加刚才创建的类的名称:c2f_Attention,如下图所示:

同样需要在ultralytics/nn/modules/__init__.py文件,相应位置导入刚出创建的c2f_Attention类。如下图:

还需要在ultralytics/nn/tasks.py中导入创建的c2f_Attention类,,如下图:

在parse_model解析函数中添加C2f类
在ultralytics/nn/tasks.py的parse_model解析网络结构的函数中,加入c2f_Attention类,如下图:

创建新的配置文件c2f_att_yolov8.yaml
在ultralytics/cfg/models/v8目录下新建c2f_att_yolov8.yaml配置文件,内容如下:
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPss: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPsm: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPsl: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8- [-1, 6, C2f_Attention, [256, True]]- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16- [-1, 6, C2f_Attention, [512, True]]- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32- [-1, 3, C2f_Attention, [1024, True]]- [-1, 1, SPPF, [1024, 5]] # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]] # cat backbone P4- [-1, 3, C2f, [512]] # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]] # cat backbone P3- [-1, 3, C2f, [256]] # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]] # cat head P4- [-1, 3, C2f, [512]] # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]] # cat head P5- [-1, 3, C2f, [1024]] # 21 (P5/32-large)- [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
新的c2f_att_yolov8.yaml配置文件与原yolov8.yaml文件的对比如下:

在C2f中添加注意力:ShuffleAttention
注意:对于有通道数参数的注意力机制,其输入通道数为其上层的输出通道数。这个注意力添加的位置有关。
在路径ultralytics/nn下新建注意力模块,ShuffleAttention.py文件。内容如下:
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameterclass ShuffleAttention(nn.Module):def __init__(self, channel=512, reduction=16, G=8):super().__init__()self.G = Gself.channel = channelself.avg_pool = nn.AdaptiveAvgPool2d(1)self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sigmoid = nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.size()# group into subfeaturesx = x.view(b * self.G, -1, h, w) # bs*G,c//G,h,w# channel_splitx_0, x_1 = x.chunk(2, dim=1) # bs*G,c//(2*G),h,w# channel attentionx_channel = self.avg_pool(x_0) # bs*G,c//(2*G),1,1x_channel = self.cweight * x_channel + self.cbias # bs*G,c//(2*G),1,1x_channel = x_0 * self.sigmoid(x_channel)# spatial attentionx_spatial = self.gn(x_1) # bs*G,c//(2*G),h,wx_spatial = self.sweight * x_spatial + self.sbias # bs*G,c//(2*G),h,wx_spatial = x_1 * self.sigmoid(x_spatial) # bs*G,c//(2*G),h,w# concatenate along channel axisout = torch.cat([x_channel, x_spatial], dim=1) # bs*G,c//G,h,wout = out.contiguous().view(b, -1, h, w)# channel shuffleout = self.channel_shuffle(out, 2)return out
在ultralytics/nn/tasks.py中导入,并修改在parse_model解析网络结构的函数中,添加解析代码:


注意力不同位置添加方法
在ultralytics/nn/modules/block.py中的c2f_Attention类中代码相应位置添加注意力机制:
1 . 方式一:在self.cv1后面添加注意力机制

2.方式二:在self.cv2后面添加注意力机制

3.方式三:在c2f的bottleneck中添加注意力机制,将Bottleneck类,复制一份,并命名为Bottleneck_Attention,然后,在Bottleneck_Attention的cv2后面添加注意力机制,同时修改C2f_Attention类别中的Bottleneck为Bottleneck_Attention。如下图所示:

加载配置文件并训练
加载c2f_att_yolov8.yaml配置文件,并运行train.py训练代码:
#coding:utf-8
from ultralytics import YOLOif __name__ == '__main__':model = YOLO('ultralytics/cfg/models/v8/c2f_att_yolov8.yaml')model.load('yolov8n.pt') # loading pretrain weightsmodel.train(data='datasets/TomatoData/data.yaml', epochs=150, batch=2)
注意观察,打印出的网络结构是否正常修改,如下图所示:

【源码免费获取】
为了小伙伴们能够,更好的学习实践,本文已将所有代码、示例数据集、论文等相关内容打包上传,供小伙伴们学习。获取方式如下:
关注下方名片G-Z-H:【阿旭算法与机器学习】,发送【yolov8改进】即可免费获取

结束语
关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!
觉得不错的小伙伴,感谢点赞、关注加收藏哦!
相关文章:
【保姆级教程|YOLOv8添加注意力机制】【2】在C2f结构中添加ShuffleAttention注意力机制并训练
《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…...
Hive聚合函数详细讲解
Hive中的聚合函数用于在数据上进行计算并返回单个值,这些值通常是基于一组行或列的汇总。以下是您提到的聚合函数的详细讲解,包括案例和使用注意事项: SUM() 功能:计算某列的总和。语法:SUM(column)案例:SELECT SUM(salary) FROM employees;注意事项:通常用于数值型列。…...
方案:如何列出 Jira 中授予用户的所有权限
文章目录 概述解决方案REST API数据库 概述 为了进行故障排除或某些管理任务,我们可能想知道给定用户拥有的所有权限。 Jira 通过其 UI 提供权限助手和类似工具,但对于所有权限的列表,我们只能通过作为用户本身进行身份验证的 REST API 请求…...
Flutter-Web从0到部署上线(实践+埋坑)
本文字数:7743字 预计阅读时间:60分钟 01 前言 首先说明一下,这篇文章是给具备Flutter开发经验的客户端同学看的。Flutter 的诞生虽然来自 Google 的 Chrome 团队,但大家都知道 Flutter 最先支持的平台是 Android 和 iOSÿ…...
Redis键值设计
文章目录 1.优雅的key2.拒绝BigKey2.1.什么是BigKey2.2.BigKey的危害2.3.如何发现BigKey2.4.如何删除BigKey 3.恰当的数据类型 1.优雅的key 2.拒绝BigKey 2.1.什么是BigKey 2.2.BigKey的危害 2.3.如何发现BigKey scan扫描示例代码 final static int STR_MAX_LEN 10 * 1024;fi…...
CSS 下载进度条
<template><view class=btn>下载中</view></template><script></script><style>/* 设置整个页面的样式 */body {width: 100vw; /* 页面宽度为视口宽度 */background: #000000; /* 背景颜色为白色 */display: flex; /* 使用 flex…...
风力发电防雷监测浪涌保护器的应用解决方案
风力发电是一种利用风能转化为电能的可再生能源技术,具有清洁、环保、低碳的优点,是应对全球气候变化和能源危机的重要途径之一。然而,风力发电也面临着一些技术和经济的挑战,其中之一就是雷电的威胁。由于风力发电机组通常位于高…...
开发安全之:Database access control
Overview 如果没有适当的 access control,就会执行一个包含用户控制主键的 SQL 指令,从而允许攻击者访问未经授权的记录。 Details Database access control 错误在以下情况下发生: 1. 数据从一个不可信赖的数据源进入程序。 2. 这个数据用…...
VMware Vsphere 日志:用户 dcui@127.0.01已以vMware-client/6.5.0 的身份登录
一、事件截图: 二、解决办法 原因: 三、解决办法 1.开启锁定模式 2.操作 1、从清单中选择您的 ESXi 主机,然后转至管理 > 设置 > 安全配置文件,然后单击锁定模式的编辑按钮 2、在打开的锁定模式窗口中,选中启…...
全面了解网络性能监测:从哪些方面进行监测?
目录 摘要 引言 CPU内存监控 磁盘监控 网络监控 GPU监控 帧率监控 总结 摘要 本文介绍了网络性能监测的重要性,并详细介绍了一款名为克魔助手的应用开发工具,该工具提供了丰富的性能监控功能,包括CPU、内存、磁盘、网络等指标的实时监…...
智能分析网关V4基于AI视频智能分析技术的周界安全防范方案
一、背景分析 随着科技的不断进步,AI视频智能检测技术已经成为周界安全防范的一种重要手段。A智能分析网关V4基于深度学习和计算机视觉技术,可以通过多种AI周界防范算法,实时、精准地监测人员入侵行为,及时发现异常情况并发出警报…...
小红书再不赚钱就晚了
2023年12月,小红书COO柯南公开表态,五年前,自己当时还非常坚定地表态小红书不要做电商。“那时候,我是站在社区的视角,但现在,我开始负责电商业务了。”此前,小红书已经整合了电商业务与直播业务…...
思腾云计算三大业务:算力租赁、服务器托管、思腾公有云
1、算力租赁业务 裸金属服务器就是传统物理服务器的升级版,也可以说是介于物理服务器和云主机之间的一种形态。既具备传统物理服务器卓越性能,又具备云主机一样的便捷管理平台,兼具了双方的优点,在满足核心应用场景对高性能及稳定…...
HarmonyOS之sqlite数据库的使用
从API Version 9开始,鸿蒙开发中sqlite使用新接口ohos.data.relationalStore 但是 relationalStore在 getRdbStore操作时,在预览模式运行或者远程模拟器运行都会报错,导致无法使用。查了一圈说只有在真机上可以正常使用,因此这里…...
网络抓包命令tcpdump
网络抓包命令tcpdump "tcpdump -i any -nn -vv tcp port 9095 -s 0 -w dump.cap"命令是一个网络抓包命令,用于捕获流经指定网络接口的TCP协议、端口号为9095的网络数据包,并将这些数据包写入到名为"dump.cap"的文件中。 具体参数解…...
iTMSTransporter上传ipa文件
背景 uni-app云打包之后生成的ipa包需要上传到app store上,applicationloader和香蕉云编都收费,转用iTMSTransporter上传 环境:mac 第1步 选择安装目录 xcode-select --switch /Applications/Xcode.app/Contents/Developer第2步 下载 xcru…...
2024华数杯国际数学建模A题思路模型详解
2024华数杯国际数学建模A题思路论文:1.17上午第一时间持续更新,详细内容见文末名片 建立一个模型来描述放射性废水在海水中的扩散速率和方向,考虑到涉及的物理过程和环境因素的复杂性,我们通常会使用一个简化的扩散模型作为起点…...
JS-定时器-间歇函数(一)
• 定时器函数介绍 定时器函数在开发中的使用场景 网页中经常会需要一种功能:每隔一段时间需要自动执行一段代码,不需要我们手动去触发例如:网页中的倒计时要实现这种需求,需要定时器函数定时器函数有两种,今天我先讲…...
AttributeError: module ‘openai‘ has no attribute ‘error‘解决方案
大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…...
每日一记:一个windows的bat脚本工具集
最近在工作上遇到要校验文件的问题,例如,下载了一个文件之后,通过查看文件的md5来校验文件是否完整,这个动作在linux上很简单,但在windows上也不难,可以通过 certutil 命令实现,该命令通常可用于…...
微信小程序逆向工程终极指南:wxappUnpacker完整实战解析
微信小程序逆向工程终极指南:wxappUnpacker完整实战解析 【免费下载链接】wxappUnpacker forked from https://github.com/qwerty472123/wxappUnpacker 项目地址: https://gitcode.com/gh_mirrors/wxappu/wxappUnpacker 微信小程序逆向工程是安全研究人员和技…...
JMeter接口测试实战:从鉴权验证到故障注入的工程化落地
1. 为什么接口测试不能只靠“点点点”——JMeter不是高级版Postman,而是工程化验证的起点很多人第一次接触JMeter,是在开发甩来一个接口文档后,下意识打开Postman填URL、选Method、点Send,看到返回200就松一口气:“通了…...
捡垃圾实战:让ESXi 7.0 U3识别老古董Mellanox ConnectX-2 10G网卡(附驱动修改全流程)
老硬件焕新:ESXi 7.0 U3下Mellanox ConnectX-2网卡驱动改造指南 在二手市场以几十元价格淘到的Mellanox ConnectX-2 10G双口网卡,性能依然强劲,却因为官方停止支持而无法在现代虚拟化平台上使用。本文将带你深入探索如何通过驱动改造…...
告别手动分割!用Python脚本一键生成VOC数据集所需的train.txt和val.txt
告别手动分割!用Python脚本一键生成VOC数据集所需的train.txt和val.txt 在计算机视觉项目中,数据集的准备往往是耗时最长的环节之一。特别是当我们需要按照VOC格式整理数据集时,手动分割训练集、验证集不仅效率低下,还容易引入人为…...
前端html字体包体积压缩,网站工程下字体压缩裁剪工具
整个网站项目如果字体包体积太大就会影响其加载速度,字体加载完会让页面字体突然变换。做一个工具他会自动检索网站上所有展现给用户的字符,然后原地裁剪字体。来解决这个问题。实现效果如下: 执行py文件以后,在网站字体文件所在目…...
内连接,左连接,右连接怎么区别开来?
区分这三种连接其实非常简单,核心就在于看**“谁的数据必须全部保留,谁的数据没有匹配就要被过滤掉”**。 为了让你彻底搞懂,我们可以把 user 表(用户)和 orders 表(订单)想象成两个班级&#x…...
Android多媒体开发避坑:深入理解DMABUF机制与RK3588上的常见泄漏点
Android多媒体开发中的DMABUF机制解析与RK3588内存泄漏实战指南 在RK3588这类高性能芯片上开发视频编解码、相机等多媒体应用时,追求零拷贝性能优化往往会引入DMABUF的使用。然而,这种看似完美的解决方案背后隐藏着复杂的内存管理陷阱。本文将带您深入理…...
AI学习 Newsletter 的手工感设计:从断点驱动到可追溯实践
1. 项目概述:这不是一份 newsletter,而是一份 AI 社区共建的实践手记 “Learn AI Together — Towards AI Community Newsletter #14”——看到这个标题,你第一反应可能是:又一份 AI 领域的资讯汇总?点开看看最新论文…...
别再被‘一亿像素’忽悠了!聊聊手机CMOS尺寸、像素和Remosaic那些事儿
手机CMOS尺寸、像素与成像质量的真相:别再被数字游戏迷惑 每次打开手机厂商的发布会,总能看到各种令人眼花缭乱的参数轰炸——"一亿像素"、"超大底传感器"、"超清画质"。这些营销术语让普通消费者一头雾水,甚至…...
