当前位置: 首页 > news >正文

【保姆级教程|YOLOv8添加注意力机制】【2】在C2f结构中添加ShuffleAttention注意力机制并训练

《博主简介》

小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~
👍感谢小伙伴们点赞、关注!

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于YOLOv8深度学习的行人跌倒检测系统】
9.【基于YOLOv8深度学习的PCB板缺陷检测系统】10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统】
11.【基于YOLOv8深度学习的安全帽目标检测系统】12.【基于YOLOv8深度学习的120种犬类检测与识别系统】
13.【基于YOLOv8深度学习的路面坑洞检测系统】14.【基于YOLOv8深度学习的火焰烟雾检测系统】
15.【基于YOLOv8深度学习的钢材表面缺陷检测系统】16.【基于YOLOv8深度学习的舰船目标分类检测系统】
17.【基于YOLOv8深度学习的西红柿成熟度检测系统】18.【基于YOLOv8深度学习的血细胞检测与计数系统】
19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统】20.【基于YOLOv8深度学习的水稻害虫检测与识别系统】
21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统】22.【基于YOLOv8深度学习的路面标志线检测与识别系统】
22.【基于YOLOv8深度学习的智能小麦害虫检测识别系统】23.【基于YOLOv8深度学习的智能玉米害虫检测识别系统】
24.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统】25.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统】
26.【基于YOLOv8深度学习的人脸面部表情识别系统】27.【基于YOLOv8深度学习的苹果叶片病害智能诊断系统】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】

《------正文------》

## 搜索C2f源码位置并新建C2f类

在项目目录中全局搜索class c2f即可找到c2f的源码位置。然后打开源码位置,进行相应修改。源码路径为:ultralytics/nn/modules/block.py

在原文件中直接copy一份c2f类的源码,然后命名为c2f_Attention,如下所示:

在这里插入图片描述

在不同文件导入新建的C2f类

ultralytics/nn/modules/block.py顶部,all中添加刚才创建的类的名称:c2f_Attention,如下图所示:

在这里插入图片描述

同样需要在ultralytics/nn/modules/__init__.py文件,相应位置导入刚出创建的c2f_Attention类。如下图:

在这里插入图片描述

还需要在ultralytics/nn/tasks.py中导入创建的c2f_Attention类,,如下图:

在这里插入图片描述

parse_model解析函数中添加C2f类

ultralytics/nn/tasks.pyparse_model解析网络结构的函数中,加入c2f_Attention类,如下图:
在这里插入图片描述

创建新的配置文件c2f_att_yolov8.yaml

ultralytics/cfg/models/v8目录下新建c2f_att_yolov8.yaml配置文件,内容如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f_Attention, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f_Attention, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f_Attention, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

新的c2f_att_yolov8.yaml配置文件与原yolov8.yaml文件的对比如下:

在这里插入图片描述

在C2f中添加注意力:ShuffleAttention

注意:对于有通道数参数的注意力机制,其输入通道数为其上层的输出通道数。这个注意力添加的位置有关。

在路径ultralytics/nn下新建注意力模块,ShuffleAttention.py文件。内容如下:

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameterclass ShuffleAttention(nn.Module):def __init__(self, channel=512, reduction=16, G=8):super().__init__()self.G = Gself.channel = channelself.avg_pool = nn.AdaptiveAvgPool2d(1)self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sigmoid = nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.size()# group into subfeaturesx = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w# channel_splitx_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w# channel attentionx_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1x_channel = x_0 * self.sigmoid(x_channel)# spatial attentionx_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,wx_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,wx_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w# concatenate along channel axisout = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,wout = out.contiguous().view(b, -1, h, w)# channel shuffleout = self.channel_shuffle(out, 2)return out

ultralytics/nn/tasks.py中导入,并修改在parse_model解析网络结构的函数中,添加解析代码:

在这里插入图片描述

在这里插入图片描述

注意力不同位置添加方法

ultralytics/nn/modules/block.py中的c2f_Attention类中代码相应位置添加注意力机制:

1 . 方式一:在self.cv1后面添加注意力机制

在这里插入图片描述

2.方式二:在self.cv2后面添加注意力机制

在这里插入图片描述

3.方式三:在c2fbottleneck中添加注意力机制,将Bottleneck类,复制一份,并命名为Bottleneck_Attention,然后,在Bottleneck_Attention的cv2后面添加注意力机制,同时修改C2f_Attention类别中的BottleneckBottleneck_Attention。如下图所示:

在这里插入图片描述

加载配置文件并训练

加载c2f_att_yolov8.yaml配置文件,并运行train.py训练代码:

#coding:utf-8
from ultralytics import YOLOif __name__ == '__main__':model = YOLO('ultralytics/cfg/models/v8/c2f_att_yolov8.yaml')model.load('yolov8n.pt') # loading pretrain weightsmodel.train(data='datasets/TomatoData/data.yaml', epochs=150, batch=2)

注意观察,打印出的网络结构是否正常修改,如下图所示:
在这里插入图片描述

【源码免费获取】

为了小伙伴们能够,更好的学习实践,本文已将所有代码、示例数据集、论文等相关内容打包上传,供小伙伴们学习。获取方式如下:

关注下方名片G-Z-H:【阿旭算法与机器学习】,发送【yolov8改进】即可免费获取

在这里插入图片描述


结束语

关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

觉得不错的小伙伴,感谢点赞、关注加收藏哦!

相关文章:

【保姆级教程|YOLOv8添加注意力机制】【2】在C2f结构中添加ShuffleAttention注意力机制并训练

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…...

Hive聚合函数详细讲解

Hive中的聚合函数用于在数据上进行计算并返回单个值,这些值通常是基于一组行或列的汇总。以下是您提到的聚合函数的详细讲解,包括案例和使用注意事项: SUM() 功能:计算某列的总和。语法:SUM(column)案例:SELECT SUM(salary) FROM employees;注意事项:通常用于数值型列。…...

方案:如何列出 Jira 中授予用户的所有权限

文章目录 概述解决方案REST API数据库 概述 为了进行故障排除或某些管理任务,我们可能想知道给定用户拥有的所有权限。 Jira 通过其 UI 提供权限助手和类似工具,但对于所有权限的列表,我们只能通过作为用户本身进行身份验证的 REST API 请求…...

Flutter-Web从0到部署上线(实践+埋坑)

本文字数:7743字 预计阅读时间:60分钟 01 前言 首先说明一下,这篇文章是给具备Flutter开发经验的客户端同学看的。Flutter 的诞生虽然来自 Google 的 Chrome 团队,但大家都知道 Flutter 最先支持的平台是 Android 和 iOS&#xff…...

Redis键值设计

文章目录 1.优雅的key2.拒绝BigKey2.1.什么是BigKey2.2.BigKey的危害2.3.如何发现BigKey2.4.如何删除BigKey 3.恰当的数据类型 1.优雅的key 2.拒绝BigKey 2.1.什么是BigKey 2.2.BigKey的危害 2.3.如何发现BigKey scan扫描示例代码 final static int STR_MAX_LEN 10 * 1024;fi…...

CSS 下载进度条

<template><view class=btn>下载中</view></template><script></script><style>/* 设置整个页面的样式 */body {width: 100vw; /* 页面宽度为视口宽度 */background: #000000; /* 背景颜色为白色 */display: flex; /* 使用 flex…...

风力发电防雷监测浪涌保护器的应用解决方案

风力发电是一种利用风能转化为电能的可再生能源技术&#xff0c;具有清洁、环保、低碳的优点&#xff0c;是应对全球气候变化和能源危机的重要途径之一。然而&#xff0c;风力发电也面临着一些技术和经济的挑战&#xff0c;其中之一就是雷电的威胁。由于风力发电机组通常位于高…...

开发安全之:Database access control

Overview 如果没有适当的 access control&#xff0c;就会执行一个包含用户控制主键的 SQL 指令&#xff0c;从而允许攻击者访问未经授权的记录。 Details Database access control 错误在以下情况下发生&#xff1a; 1. 数据从一个不可信赖的数据源进入程序。 2. 这个数据用…...

VMware Vsphere 日志:用户 dcui@127.0.01已以vMware-client/6.5.0 的身份登录

一、事件截图&#xff1a; 二、解决办法 原因&#xff1a; 三、解决办法 1.开启锁定模式 2.操作 1、从清单中选择您的 ESXi 主机&#xff0c;然后转至管理 > 设置 > 安全配置文件&#xff0c;然后单击锁定模式的编辑按钮 2、在打开的锁定模式窗口中&#xff0c;选中启…...

全面了解网络性能监测:从哪些方面进行监测?

目录 摘要 引言 CPU内存监控 磁盘监控 网络监控 GPU监控 帧率监控 总结 摘要 本文介绍了网络性能监测的重要性&#xff0c;并详细介绍了一款名为克魔助手的应用开发工具&#xff0c;该工具提供了丰富的性能监控功能&#xff0c;包括CPU、内存、磁盘、网络等指标的实时监…...

智能分析网关V4基于AI视频智能分析技术的周界安全防范方案

一、背景分析 随着科技的不断进步&#xff0c;AI视频智能检测技术已经成为周界安全防范的一种重要手段。A智能分析网关V4基于深度学习和计算机视觉技术&#xff0c;可以通过多种AI周界防范算法&#xff0c;实时、精准地监测人员入侵行为&#xff0c;及时发现异常情况并发出警报…...

小红书再不赚钱就晚了

2023年12月&#xff0c;小红书COO柯南公开表态&#xff0c;五年前&#xff0c;自己当时还非常坚定地表态小红书不要做电商。“那时候&#xff0c;我是站在社区的视角&#xff0c;但现在&#xff0c;我开始负责电商业务了。”此前&#xff0c;小红书已经整合了电商业务与直播业务…...

思腾云计算三大业务:算力租赁、服务器托管、思腾公有云

1、算力租赁业务 裸金属服务器就是传统物理服务器的升级版&#xff0c;也可以说是介于物理服务器和云主机之间的一种形态。既具备传统物理服务器卓越性能&#xff0c;又具备云主机一样的便捷管理平台&#xff0c;兼具了双方的优点&#xff0c;在满足核心应用场景对高性能及稳定…...

HarmonyOS之sqlite数据库的使用

从API Version 9开始&#xff0c;鸿蒙开发中sqlite使用新接口ohos.data.relationalStore 但是 relationalStore在 getRdbStore操作时&#xff0c;在预览模式运行或者远程模拟器运行都会报错&#xff0c;导致无法使用。查了一圈说只有在真机上可以正常使用&#xff0c;因此这里…...

网络抓包命令tcpdump

网络抓包命令tcpdump "tcpdump -i any -nn -vv tcp port 9095 -s 0 -w dump.cap"命令是一个网络抓包命令&#xff0c;用于捕获流经指定网络接口的TCP协议、端口号为9095的网络数据包&#xff0c;并将这些数据包写入到名为"dump.cap"的文件中。 具体参数解…...

iTMSTransporter上传ipa文件

背景 uni-app云打包之后生成的ipa包需要上传到app store上&#xff0c;applicationloader和香蕉云编都收费&#xff0c;转用iTMSTransporter上传 环境&#xff1a;mac 第1步 选择安装目录 xcode-select --switch /Applications/Xcode.app/Contents/Developer第2步 下载 xcru…...

2024华数杯国际数学建模A题思路模型详解

2024华数杯国际数学建模A题思路论文&#xff1a;1.17上午第一时间持续更新&#xff0c;详细内容见文末名片 建立一个模型来描述放射性废水在海水中的扩散速率和方向&#xff0c;考虑到涉及的物理过程和环境因素的复杂性&#xff0c;我们通常会使用一个简化的扩散模型作为起点…...

JS-定时器-间歇函数(一)

• 定时器函数介绍 定时器函数在开发中的使用场景 网页中经常会需要一种功能&#xff1a;每隔一段时间需要自动执行一段代码&#xff0c;不需要我们手动去触发例如&#xff1a;网页中的倒计时要实现这种需求&#xff0c;需要定时器函数定时器函数有两种&#xff0c;今天我先讲…...

AttributeError: module ‘openai‘ has no attribute ‘error‘解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…...

每日一记:一个windows的bat脚本工具集

最近在工作上遇到要校验文件的问题&#xff0c;例如&#xff0c;下载了一个文件之后&#xff0c;通过查看文件的md5来校验文件是否完整&#xff0c;这个动作在linux上很简单&#xff0c;但在windows上也不难&#xff0c;可以通过 certutil 命令实现&#xff0c;该命令通常可用于…...

智慧医疗能源事业线深度画像分析(上)

引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...

51c自动驾驶~合集58

我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留&#xff0c;CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制&#xff08;CCA-Attention&#xff09;&#xff0c;…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用

1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...

Mac下Android Studio扫描根目录卡死问题记录

环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中&#xff0c;提示一个依赖外部头文件的cpp源文件需要同步&#xff0c;点…...

代理篇12|深入理解 Vite中的Proxy接口代理配置

在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...

ABAP设计模式之---“简单设计原则(Simple Design)”

“Simple Design”&#xff08;简单设计&#xff09;是软件开发中的一个重要理念&#xff0c;倡导以最简单的方式实现软件功能&#xff0c;以确保代码清晰易懂、易维护&#xff0c;并在项目需求变化时能够快速适应。 其核心目标是避免复杂和过度设计&#xff0c;遵循“让事情保…...

CSS设置元素的宽度根据其内容自动调整

width: fit-content 是 CSS 中的一个属性值&#xff0c;用于设置元素的宽度根据其内容自动调整&#xff0c;确保宽度刚好容纳内容而不会超出。 效果对比 默认情况&#xff08;width: auto&#xff09;&#xff1a; 块级元素&#xff08;如 <div>&#xff09;会占满父容器…...

七、数据库的完整性

七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...

Python Ovito统计金刚石结构数量

大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...

第八部分:阶段项目 6:构建 React 前端应用

现在&#xff0c;是时候将你学到的 React 基础知识付诸实践&#xff0c;构建一个简单的前端应用来模拟与后端 API 的交互了。在这个阶段&#xff0c;你可以先使用模拟数据&#xff0c;或者如果你的后端 API&#xff08;阶段项目 5&#xff09;已经搭建好&#xff0c;可以直接连…...