当前位置: 首页 > news >正文

【保姆级教程|YOLOv8添加注意力机制】【2】在C2f结构中添加ShuffleAttention注意力机制并训练

《博主简介》

小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~
👍感谢小伙伴们点赞、关注!

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于YOLOv8深度学习的行人跌倒检测系统】
9.【基于YOLOv8深度学习的PCB板缺陷检测系统】10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统】
11.【基于YOLOv8深度学习的安全帽目标检测系统】12.【基于YOLOv8深度学习的120种犬类检测与识别系统】
13.【基于YOLOv8深度学习的路面坑洞检测系统】14.【基于YOLOv8深度学习的火焰烟雾检测系统】
15.【基于YOLOv8深度学习的钢材表面缺陷检测系统】16.【基于YOLOv8深度学习的舰船目标分类检测系统】
17.【基于YOLOv8深度学习的西红柿成熟度检测系统】18.【基于YOLOv8深度学习的血细胞检测与计数系统】
19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统】20.【基于YOLOv8深度学习的水稻害虫检测与识别系统】
21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统】22.【基于YOLOv8深度学习的路面标志线检测与识别系统】
22.【基于YOLOv8深度学习的智能小麦害虫检测识别系统】23.【基于YOLOv8深度学习的智能玉米害虫检测识别系统】
24.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统】25.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统】
26.【基于YOLOv8深度学习的人脸面部表情识别系统】27.【基于YOLOv8深度学习的苹果叶片病害智能诊断系统】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】

《------正文------》

## 搜索C2f源码位置并新建C2f类

在项目目录中全局搜索class c2f即可找到c2f的源码位置。然后打开源码位置,进行相应修改。源码路径为:ultralytics/nn/modules/block.py

在原文件中直接copy一份c2f类的源码,然后命名为c2f_Attention,如下所示:

在这里插入图片描述

在不同文件导入新建的C2f类

ultralytics/nn/modules/block.py顶部,all中添加刚才创建的类的名称:c2f_Attention,如下图所示:

在这里插入图片描述

同样需要在ultralytics/nn/modules/__init__.py文件,相应位置导入刚出创建的c2f_Attention类。如下图:

在这里插入图片描述

还需要在ultralytics/nn/tasks.py中导入创建的c2f_Attention类,,如下图:

在这里插入图片描述

parse_model解析函数中添加C2f类

ultralytics/nn/tasks.pyparse_model解析网络结构的函数中,加入c2f_Attention类,如下图:
在这里插入图片描述

创建新的配置文件c2f_att_yolov8.yaml

ultralytics/cfg/models/v8目录下新建c2f_att_yolov8.yaml配置文件,内容如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f_Attention, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f_Attention, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f_Attention, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 9], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

新的c2f_att_yolov8.yaml配置文件与原yolov8.yaml文件的对比如下:

在这里插入图片描述

在C2f中添加注意力:ShuffleAttention

注意:对于有通道数参数的注意力机制,其输入通道数为其上层的输出通道数。这个注意力添加的位置有关。

在路径ultralytics/nn下新建注意力模块,ShuffleAttention.py文件。内容如下:

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameterclass ShuffleAttention(nn.Module):def __init__(self, channel=512, reduction=16, G=8):super().__init__()self.G = Gself.channel = channelself.avg_pool = nn.AdaptiveAvgPool2d(1)self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))self.sigmoid = nn.Sigmoid()def init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):init.constant_(m.weight, 1)init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)@staticmethoddef channel_shuffle(x, groups):b, c, h, w = x.shapex = x.reshape(b, groups, -1, h, w)x = x.permute(0, 2, 1, 3, 4)# flattenx = x.reshape(b, -1, h, w)return xdef forward(self, x):b, c, h, w = x.size()# group into subfeaturesx = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w# channel_splitx_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w# channel attentionx_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1x_channel = x_0 * self.sigmoid(x_channel)# spatial attentionx_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,wx_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,wx_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w# concatenate along channel axisout = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,wout = out.contiguous().view(b, -1, h, w)# channel shuffleout = self.channel_shuffle(out, 2)return out

ultralytics/nn/tasks.py中导入,并修改在parse_model解析网络结构的函数中,添加解析代码:

在这里插入图片描述

在这里插入图片描述

注意力不同位置添加方法

ultralytics/nn/modules/block.py中的c2f_Attention类中代码相应位置添加注意力机制:

1 . 方式一:在self.cv1后面添加注意力机制

在这里插入图片描述

2.方式二:在self.cv2后面添加注意力机制

在这里插入图片描述

3.方式三:在c2fbottleneck中添加注意力机制,将Bottleneck类,复制一份,并命名为Bottleneck_Attention,然后,在Bottleneck_Attention的cv2后面添加注意力机制,同时修改C2f_Attention类别中的BottleneckBottleneck_Attention。如下图所示:

在这里插入图片描述

加载配置文件并训练

加载c2f_att_yolov8.yaml配置文件,并运行train.py训练代码:

#coding:utf-8
from ultralytics import YOLOif __name__ == '__main__':model = YOLO('ultralytics/cfg/models/v8/c2f_att_yolov8.yaml')model.load('yolov8n.pt') # loading pretrain weightsmodel.train(data='datasets/TomatoData/data.yaml', epochs=150, batch=2)

注意观察,打印出的网络结构是否正常修改,如下图所示:
在这里插入图片描述

【源码免费获取】

为了小伙伴们能够,更好的学习实践,本文已将所有代码、示例数据集、论文等相关内容打包上传,供小伙伴们学习。获取方式如下:

关注下方名片G-Z-H:【阿旭算法与机器学习】,发送【yolov8改进】即可免费获取

在这里插入图片描述


结束语

关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

觉得不错的小伙伴,感谢点赞、关注加收藏哦!

相关文章:

【保姆级教程|YOLOv8添加注意力机制】【2】在C2f结构中添加ShuffleAttention注意力机制并训练

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…...

Hive聚合函数详细讲解

Hive中的聚合函数用于在数据上进行计算并返回单个值,这些值通常是基于一组行或列的汇总。以下是您提到的聚合函数的详细讲解,包括案例和使用注意事项: SUM() 功能:计算某列的总和。语法:SUM(column)案例:SELECT SUM(salary) FROM employees;注意事项:通常用于数值型列。…...

方案:如何列出 Jira 中授予用户的所有权限

文章目录 概述解决方案REST API数据库 概述 为了进行故障排除或某些管理任务,我们可能想知道给定用户拥有的所有权限。 Jira 通过其 UI 提供权限助手和类似工具,但对于所有权限的列表,我们只能通过作为用户本身进行身份验证的 REST API 请求…...

Flutter-Web从0到部署上线(实践+埋坑)

本文字数:7743字 预计阅读时间:60分钟 01 前言 首先说明一下,这篇文章是给具备Flutter开发经验的客户端同学看的。Flutter 的诞生虽然来自 Google 的 Chrome 团队,但大家都知道 Flutter 最先支持的平台是 Android 和 iOS&#xff…...

Redis键值设计

文章目录 1.优雅的key2.拒绝BigKey2.1.什么是BigKey2.2.BigKey的危害2.3.如何发现BigKey2.4.如何删除BigKey 3.恰当的数据类型 1.优雅的key 2.拒绝BigKey 2.1.什么是BigKey 2.2.BigKey的危害 2.3.如何发现BigKey scan扫描示例代码 final static int STR_MAX_LEN 10 * 1024;fi…...

CSS 下载进度条

<template><view class=btn>下载中</view></template><script></script><style>/* 设置整个页面的样式 */body {width: 100vw; /* 页面宽度为视口宽度 */background: #000000; /* 背景颜色为白色 */display: flex; /* 使用 flex…...

风力发电防雷监测浪涌保护器的应用解决方案

风力发电是一种利用风能转化为电能的可再生能源技术&#xff0c;具有清洁、环保、低碳的优点&#xff0c;是应对全球气候变化和能源危机的重要途径之一。然而&#xff0c;风力发电也面临着一些技术和经济的挑战&#xff0c;其中之一就是雷电的威胁。由于风力发电机组通常位于高…...

开发安全之:Database access control

Overview 如果没有适当的 access control&#xff0c;就会执行一个包含用户控制主键的 SQL 指令&#xff0c;从而允许攻击者访问未经授权的记录。 Details Database access control 错误在以下情况下发生&#xff1a; 1. 数据从一个不可信赖的数据源进入程序。 2. 这个数据用…...

VMware Vsphere 日志:用户 dcui@127.0.01已以vMware-client/6.5.0 的身份登录

一、事件截图&#xff1a; 二、解决办法 原因&#xff1a; 三、解决办法 1.开启锁定模式 2.操作 1、从清单中选择您的 ESXi 主机&#xff0c;然后转至管理 > 设置 > 安全配置文件&#xff0c;然后单击锁定模式的编辑按钮 2、在打开的锁定模式窗口中&#xff0c;选中启…...

全面了解网络性能监测:从哪些方面进行监测?

目录 摘要 引言 CPU内存监控 磁盘监控 网络监控 GPU监控 帧率监控 总结 摘要 本文介绍了网络性能监测的重要性&#xff0c;并详细介绍了一款名为克魔助手的应用开发工具&#xff0c;该工具提供了丰富的性能监控功能&#xff0c;包括CPU、内存、磁盘、网络等指标的实时监…...

智能分析网关V4基于AI视频智能分析技术的周界安全防范方案

一、背景分析 随着科技的不断进步&#xff0c;AI视频智能检测技术已经成为周界安全防范的一种重要手段。A智能分析网关V4基于深度学习和计算机视觉技术&#xff0c;可以通过多种AI周界防范算法&#xff0c;实时、精准地监测人员入侵行为&#xff0c;及时发现异常情况并发出警报…...

小红书再不赚钱就晚了

2023年12月&#xff0c;小红书COO柯南公开表态&#xff0c;五年前&#xff0c;自己当时还非常坚定地表态小红书不要做电商。“那时候&#xff0c;我是站在社区的视角&#xff0c;但现在&#xff0c;我开始负责电商业务了。”此前&#xff0c;小红书已经整合了电商业务与直播业务…...

思腾云计算三大业务:算力租赁、服务器托管、思腾公有云

1、算力租赁业务 裸金属服务器就是传统物理服务器的升级版&#xff0c;也可以说是介于物理服务器和云主机之间的一种形态。既具备传统物理服务器卓越性能&#xff0c;又具备云主机一样的便捷管理平台&#xff0c;兼具了双方的优点&#xff0c;在满足核心应用场景对高性能及稳定…...

HarmonyOS之sqlite数据库的使用

从API Version 9开始&#xff0c;鸿蒙开发中sqlite使用新接口ohos.data.relationalStore 但是 relationalStore在 getRdbStore操作时&#xff0c;在预览模式运行或者远程模拟器运行都会报错&#xff0c;导致无法使用。查了一圈说只有在真机上可以正常使用&#xff0c;因此这里…...

网络抓包命令tcpdump

网络抓包命令tcpdump "tcpdump -i any -nn -vv tcp port 9095 -s 0 -w dump.cap"命令是一个网络抓包命令&#xff0c;用于捕获流经指定网络接口的TCP协议、端口号为9095的网络数据包&#xff0c;并将这些数据包写入到名为"dump.cap"的文件中。 具体参数解…...

iTMSTransporter上传ipa文件

背景 uni-app云打包之后生成的ipa包需要上传到app store上&#xff0c;applicationloader和香蕉云编都收费&#xff0c;转用iTMSTransporter上传 环境&#xff1a;mac 第1步 选择安装目录 xcode-select --switch /Applications/Xcode.app/Contents/Developer第2步 下载 xcru…...

2024华数杯国际数学建模A题思路模型详解

2024华数杯国际数学建模A题思路论文&#xff1a;1.17上午第一时间持续更新&#xff0c;详细内容见文末名片 建立一个模型来描述放射性废水在海水中的扩散速率和方向&#xff0c;考虑到涉及的物理过程和环境因素的复杂性&#xff0c;我们通常会使用一个简化的扩散模型作为起点…...

JS-定时器-间歇函数(一)

• 定时器函数介绍 定时器函数在开发中的使用场景 网页中经常会需要一种功能&#xff1a;每隔一段时间需要自动执行一段代码&#xff0c;不需要我们手动去触发例如&#xff1a;网页中的倒计时要实现这种需求&#xff0c;需要定时器函数定时器函数有两种&#xff0c;今天我先讲…...

AttributeError: module ‘openai‘ has no attribute ‘error‘解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…...

每日一记:一个windows的bat脚本工具集

最近在工作上遇到要校验文件的问题&#xff0c;例如&#xff0c;下载了一个文件之后&#xff0c;通过查看文件的md5来校验文件是否完整&#xff0c;这个动作在linux上很简单&#xff0c;但在windows上也不难&#xff0c;可以通过 certutil 命令实现&#xff0c;该命令通常可用于…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应&#xff0c;这是一种非线性光学现象&#xff0c;主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场&#xff0c;对材料产生非线性响应&#xff0c;可能…...

【kafka】Golang实现分布式Masscan任务调度系统

要求&#xff1a; 输出两个程序&#xff0c;一个命令行程序&#xff08;命令行参数用flag&#xff09;和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽&#xff0c;然后将消息推送到kafka里面。 服务端程序&#xff1a; 从kafka消费者接收…...

STM32+rt-thread判断是否联网

一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

&#x1f31f; 什么是 MCP&#xff1f; 模型控制协议 (MCP) 是一种创新的协议&#xff0c;旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议&#xff0c;它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

【配置 YOLOX 用于按目录分类的图片数据集】

现在的图标点选越来越多&#xff0c;如何一步解决&#xff0c;采用 YOLOX 目标检测模式则可以轻松解决 要在 YOLOX 中使用按目录分类的图片数据集&#xff08;每个目录代表一个类别&#xff0c;目录下是该类别的所有图片&#xff09;&#xff0c;你需要进行以下配置步骤&#x…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章&#xff1f;AI自动生成&#xff0c;效率提升10倍&#xff01; 支持多语言、自动配图、定时发布&#xff0c;让内容创作更轻松&#xff01; AI内容生成 → 不想每天写文章&#xff1f;AI一键生成高质量内容&#xff01;多语言支持 → 跨境电商必备&am…...

GitHub 趋势日报 (2025年06月08日)

&#x1f4ca; 由 TrendForge 系统生成 | &#x1f310; https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日获星趋势图 今日获星趋势图 884 cognee 566 dify 414 HumanSystemOptimization 414 omni-tools 321 note-gen …...

聊一聊接口测试的意义有哪些?

目录 一、隔离性 & 早期测试 二、保障系统集成质量 三、验证业务逻辑的核心层 四、提升测试效率与覆盖度 五、系统稳定性的守护者 六、驱动团队协作与契约管理 七、性能与扩展性的前置评估 八、持续交付的核心支撑 接口测试的意义可以从四个维度展开&#xff0c;首…...

Java 二维码

Java 二维码 **技术&#xff1a;**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...

Pinocchio 库详解及其在足式机器人上的应用

Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库&#xff0c;专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性&#xff0c;并提供了一个通用的框架&…...