yolov8源码解读Detect层
yolov8源码解读Detect层
- Detect层解读
- 网络各层解读及detect层后的处理
关于网络的backbone,head,以及detect层后处理,可以参考文章结尾博主的文章。
Detect层解读
先贴一下全部代码,下面一一解读。
class Detect(nn.Module):"""YOLOv8 Detect head for detection models."""dynamic = False # force grid reconstructionexport = False # export modeshape = Noneanchors = torch.empty(0) # initstrides = torch.empty(0) # initdef __init__(self, nc=80, ch=()):"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""super().__init__()self.nc = nc # number of classesself.nl = len(ch) # number of detection layersself.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)self.no = nc + self.reg_max * 4 # number of outputs per anchorself.stride = torch.zeros(self.nl) # strides computed during buildc2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channelsself.cv2 = nn.ModuleList(nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()def forward(self, x):"""Concatenates and returns predicted bounding boxes and class probabilities."""shape = x[0].shape # BCHW# print(">>>>", x[0].shape)# print(">>>>", x[1].shape)# print(">>>>", x[2].shape)for i in range(self.nl):x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)if self.training:return xelif self.dynamic or self.shape != shape:self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))self.shape = shapex_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV opsbox = x_cat[:, :self.reg_max * 4]cls = x_cat[:, self.reg_max * 4:]else:box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.stridesif self.export and self.format in ('tflite', 'edgetpu'):# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695img_h = shape[2] * self.stride[0]img_w = shape[3] * self.stride[0]img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)dbox /= img_size# print(cls.shape)y = torch.cat((dbox, cls.sigmoid()), 1)# print(y.shape)return y if self.export else (y, x)
dynamic = False #这个属性指示网格(通常是特征图上的锚框网格)是否需要动态地重建export = False #这个属性用于指示模型是否处于导出模式。shape = None # 用于存储输入图像或特征图的尺寸。anchors = torch.empty(0) # 创建了一个空的PyTorch张量strides = torch.empty(0)
步长(strides)是卷积神经网络中特征图相对于输入图像的缩小比例。
例如,如果步长是32,那么一个32x32像素的区域在特征图上就对应一个单元。
和anchors一样,这里的torch.empty(0)表示步长尚未初始化。
def __init__(self, nc=80, ch=()):"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""super().__init__()self.nc = nc # number of classesself.nl = len(ch) # number of detection layersself.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)self.no = nc + self.reg_max * 4 # number of outputs per anchorself.stride = torch.zeros(self.nl) # strides computed during buildc2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channelsself.cv2 = nn.ModuleList(nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
nc:类别数
nl:检测层的数量,目标检测中为3。
ch:传入的图片通道尺寸,在yolov8n,图片大小为640*640时。这里的ch为(256,128,64)
no:两个卷积再拼接后输出通道数,为4×reg_max+nc
c2,c3:计算卷积层的通道数。
cv2,cv3:定义的卷积操作,以输出有关类别和选框的特征图。
dfl:通过将分布式的概率分布转化为单一的预测值
class DFL(nn.Module):def __init__(self, c1=16):"""Initialize a convolutional layer with a given number of input channels."""super().__init__()self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)x = torch.arange(c1, dtype=torch.float)self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))self.c1 = c1def forward(self, x):"""Applies a transformer layer on input tensor 'x' and returns a tensor."""b, c, a = x.shape # batch, channels, anchorsreturn self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
self.conv:创建了一个输入通道为16,输出为1,没有偏置项,不需要进行梯度更新的卷积层。
这样的权重设置实际上模拟了一个积分过程,将卷积操作变成了加权求和的形式。
x:1到15的整数。
self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)):
这里使用nn.Parameter将重塑后的张量设置为模型的参数,并且参数不会被更新。
假设前向传播中,x的形状为(1, 64, 8400),下面解释下forword中的变化。
1,x.view(b, 4, self.c1, a): 这个操作是对x的形状进行重塑。self.c1是16(因为输入通道数是64,即4*self.c1),那么a是8400(代表了所有锚点的数量)。b是批次大小,这里为1。所以x.view(b, 4, self.c1, a)将x从(1, 64, 8400)重塑为(1, 4, 16, 8400)。在这个形状中,我们得到了每个锚点的每个坐标轴(x, y, 宽度, 高度)上的16个预测值(可能代表某种概率分布)。
2,transpose(2, 1): 这个操作交换第二维和第三维。在应用transpose之后,张量的形状变为(1, 16, 4, 8400)。这样做的目的是让每组概率分布的16个预测值连续地排列在一起,为后面的softmax运算做准备。
3,softmax(1): softmax函数应用于第一维(现在是16个预测值的这一维)。softmax确保了这16个值之和为1,转换为一个有效的概率分布,表示每个预测值的可能性。
4,self.conv(…): 这个操作将配置好的卷积层应用在进行了softmax操作的张量上。由于卷积层的权重已被设置为从0到15的整数,并且不更新权重(不进行梯度下降优化),这个步骤实际上是在计算期望值。卷积层将每个离散的概率值乘以其相应的索引(也就是权重),然后对结果进行求和,得到该坐标的预测值。
5,view(b, 4, a): 最后一步是将张量的形状从卷积操作后的(1, 1, 4, 8400)转换回(1, 4, 8400)。这样确保了最终的输出张量与每个坐标轴的预测值(x, y, 宽度, 高度)和所有锚点的数量对齐。
总的来说,dfl层就是对预测的坐标求加权期望值。将(1,64,8400)先变为(1,16,4,8400),然后对这16个通道求加权期望,变为(1,4,8400)即这8400个锚点中的每一个锚点,x,y,width,hight的加权平均值。
接下来是前向传播的过程。打印传入的x形状,发现通道数是64,128,256。
原因:Detect层接受15,18,21层的输入。原本通道数是1024,512,256。但是yolov8n还需要乘0.25。
经过cv2,通道数变为64,经过cv3通道数变为nc,我这里nc为2(二分类)。在经过cat拼接,在通道维度上拼接,所以x[i]的通道数变为66。
如果处于训练模式,就直接返回x。
否则执行下面的代码,将特征图列表x(1×66×40×40,1×66×80×80,1×66×20×20)传递给make_anchors()函数。make_anchors函数用于生成锚点(anchors),它通常用在目标检测网络中。每个锚点代表了特征图上的一个点,可以用来预测相对于该点的边界框。strides是这些特征图相对于原始图像的下采样步长。简单来说,生成了8400个锚点(40×40+80×80+20×20),变量为anchors,形状为1×2×8400)。同时生成了8400个步长,变量为strides,形状为1×8400。参数0.5表示每个锚点处于每个像素块的中央。
if self.training:return xelif self.dynamic or self.shape != shape:self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))self.shape = shape
将xi按照2维度进行拼接,xi分别为1×66×40×40,1×66×80×80,1×66×20×20。拼接后的x_cat为1×66×8400
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
这段代码就是把x_cat进行拆分。box形状为1×64×8400,包含每个边界框的回归参数。cls形状为1×2×8400,会包含类别预测,2是因为我这里类别为2。
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV opsbox = x_cat[:, :self.reg_max * 4]cls = x_cat[:, self.reg_max * 4:]else:box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dfl层就是对预测的坐标求加权期望值。将(1,64,8400)先变为(1,16,4,8400),然后对这16个通道求加权期望,变为(1,4,8400)即这8400个锚点中的每一个锚点,x1,y1,x2,y2的加权平均值。dist2bbox()函数的作用是将锚点x1,y1,x2,y2转换为x,y,width,hight的形式。最后在乘以步长,还原到原图的大小比例。
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
此代码片段的作用是在模型导出为 Tensorflow Lite (tflite) 或 Edge TPU 兼容格式时,对预测框 (dbox) 进行归一化处理。
if self.export and self.format in ('tflite', 'edgetpu'):# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695img_h = shape[2] * self.stride[0]img_w = shape[3] * self.stride[0]img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)dbox /= img_size
此时y的形状为1×66×8400。代表有8400个锚点,每个锚点包含坐标框的x,y,width,hight,以及类别得分信息。
y = torch.cat((dbox, cls.sigmoid()), 1)
返回值,至此detect层结束。完整的预测,后续还需要进行一些处理。如进行非极大抑制,对这8400个锚点进行筛选
return y if self.export else (y, x)
另外,最后的bias_init()函数用于初始化一个目标检测模型中的Detect层的偏置。确保在训练开始时偏置值是基于合理假设的。这种方法的目标是为模型提供一个好的起点,并有助于加速训练过程中的收敛。
def bias_init(self):"""Initialize Detect() biases, WARNING: requires stride availability."""m = self # self.model[-1] # Detect() module# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequencyfor a, b, s in zip(m.cv2, m.cv3, m.stride): # froma[-1].bias.data[:] = 1.0 # boxb[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
网络各层解读及detect层后的处理
关于backbone,head层,以及detect层参考下面博主的文章,讲的非常好。
链接: Yolov 8源码超详细逐行解读+ 网络结构细讲(自我用的小白笔记)
链接: 最细致讲解yolov8模型推理完整代码–(前处理,后处理)
相关文章:

yolov8源码解读Detect层
yolov8源码解读Detect层 Detect层解读网络各层解读及detect层后的处理 关于网络的backbone,head,以及detect层后处理,可以参考文章结尾博主的文章。 Detect层解读 先贴一下全部代码,下面一一解读。 class Detect(nn.Module):"""YOLOv8 …...

(AtCoder Beginner Contest 341)(A - D)
比赛地址 : Tasks - Toyota Programming Contest 2024#2(AtCoder Beginner Contest 341) A . Print 341 模拟就好了 , 先放一个 1 , 然后放 n 个 01 ; #include<bits/stdc.h> #define IOS ios::sync_with_stdio(0);cin.tie(0);cout…...
python Flask与微信小程序 统计管理
common/models/stat/StatDailyMember.py DROP TABLE IF EXISTS stat_daily_member;CREATE TABLE stat_daily_member (id int(11) unsigned NOT NULL AUTO_INCREMENT,date date NOT NULL COMMENT 日期,member_id int(11) NOT NULL DEFAULT 0 COMMENT 会员id,total_shared_count …...
光伏企业助力乡村振兴
光伏是一种利用太阳能生产电能的发电技术,属于可再生能源。近年来我国的光伏企业发展迅速,已经称霸全球,同时也为乡村振兴贡献了力量。 一、光伏企业助力乡村 1.推动农业发展 光伏发电和农业种植、畜牧、渔业、水产等有机结合,…...

root MUSIC 算法补充说明
root MUSIC 算法补充说明 多项式求根root MUSIC 算法原理如何从 2 M − 2 2M-2 2M−2 个根中确定 K K K 个根从复数域上观察 2 M − 2 2M-2 2M−2 个根的分布 这篇笔记是上一篇关于 root MUSIC 笔记的补充。 多项式求根 要理解 root MUSIC 算法,需要理解多项式求…...

关于Django的中间件使用说明。
目录 1.中间件2. 为什么要中间件?3. 具体使用中间件3.1 中间件所在的位置:在django的settings.py里面的MIDDLEWARE。3.2 中间件的创建3.3 中间件的使用 4. 展示成果 1.中间件 中间件的大概解释:在浏览器在请求服务器的时候,首先要…...
Chapter 8 - 15. Congestion Management in TCP Storage Networks
User Actions After learning the states of queue utilization, the following are the actions that admins and operators can take while using TCP transport for storage traffic. 了解了队列利用率的状态后,管理员和操作员在使用 TCP 传输存储流量时可以采取以下措施。…...

前端技巧之svg精灵图svg-sprite-loader
首先说明精灵图的必要性,其可以让我们只需要向服务器请求一次图片资源,就能加载很多图片,即能够减轻http请求造成的服务器压力。 然后这里要说明的是这个插件是webpack上面的,所以在vue2中比较好用,如果在vue3中&…...

IO线程-day2
1> 使用fread和fwrite完成两个文件的拷贝 程序: #define MAXSIZE 1024 #include<myhead.h>int main(int argc, char const *argv[]) {FILE *srcfpNULL;FILE *destfpNULL;if(!(srcfpfopen("pm.bmp","r")))PRINT_ERR("");if…...

Spring Boot 笔记 024 登录页面
1.1 登录接口 //导入request.js请求工具 import request from /utils/request.js//提供调用注册接口的函数 export const userRegisterService (registerData)>{//借助于UrlSearchParams完成传递const params new URLSearchParams()for(let key in registerData){params.a…...

09_Java集合
一、Java集合框架概述 一方面, 面向对象语言对事物的体现都是以对象的形式,为了方便对多个对象的操作,就要对对象进行存储。另一方面,使用Array存储对象方面具有一些弊端,而Java 集合就像一种容器,可以动态…...

HCIA-HarmonyOS设备开发认证V2.0-3.2.轻量系统内核基础-软件定时器
目录 一、软件定时器基本概念二、软件定时器运行机制三、软件定时器状态四、软件定时器模式五、软件定时器开发流程六、软件定时器使用说明七、软件定时器接口八、代码分析(待续...)坚持就有收获 一、软件定时器基本概念 软件定时器,是基于系…...

考研证件照可以自己用手机拍吗?考研证件照p过可以通过审核吗?考研证件照有什么要求
一、考研证件照可以自己用手机拍吗 现在的智能手机相机技术先进,大多都配备了高像素摄像头,使得自拍照片的质量有了大幅提升。相较于传统的证件照拍摄,使用手机自拍考研证件照理论上是可行的。然而,考研证件照需要满足一定的规定…...

win10 环境下Python 3.8按装fastapi paddlepaddle 进行图片文字识别1
###按装 用conda 创建python 3.8的环境,可参看本人python下的其它文章。 在pycharm开发环境下按装相关的模块: pip install -i https://pypi.tuna.tsinghua.edu.cn/simple fastapi pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "uvi…...
json字符串的处理
json字符串的处理 【1】解析json字符串(1)如果json格式字符串 ,最外层 是 中括号,表示数组,就使用方法(2)如果json格式字符串,最外层是 大括号,表示对象,就是…...
Java基础String常见的编程练习
1.对字符串数组进行排序 package javalianxi;import java.util.Arrays; import java.util.Comparator;public class Test1 {public static void main(String[] args) {String[] array { "cd", "CD", "bc", "AB", "ab", &q…...
代码随想录算法训练营(回溯5)| 491.递增子序列 46.全排列 47.全排列 II
491.递增子序列 本题和大家刚做过的 90.子集II 非常像,但又很不一样,很容易掉坑里。 题目链接/文章讲解 视频讲解 46.全排列 本题重点感受一下,排列问题 与 组合问题,组合总和,子集问题的区别。 为什么排列问题不用…...

专业140+总分420+南京信息工程大学811信号与系统考研经验南信大电子信息与通信工程,真题,大纲,参考书
今年顺利被南信大电子信息录取,初试420,专业811信号与系统140(Jenny老师辅导班上140很多,真是大佬云集),今年应该是南信大电子信息最卷的一年,复试线比往年提高了很多,录取平均分380…...

一元函数微分学【高数笔记】
1. 什么是微分?什么是微商? 2. 什么是函数的微分? 3. 在函数的微分中,有什么样的关系? 4. 一元函数的微分运用在什么题型中? 5. 什么是一元函数?...
(16)Hive——企业调优经验
前言 本篇文章主要整理hive-3.1.2版本的企业调优经验,有误请指出~ 一、性能评估和优化 1.1 Explain查询计划 使用explain命令可以分析查询计划,查看计划中的资源消耗情况,定位潜在的性能问题,并进行相应的优化。 explain执行计划…...

wordpress后台更新后 前端没变化的解决方法
使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…...
三维GIS开发cesium智慧地铁教程(5)Cesium相机控制
一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点: 路径验证:确保相对路径.…...
渲染学进阶内容——模型
最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明
AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...

PL0语法,分析器实现!
简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...
WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)
一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解,适合用作学习或写简历项目背景说明。 🧠 一、概念简介:Solidity 合约开发 Solidity 是一种专门为 以太坊(Ethereum)平台编写智能合约的高级编…...
Spring AI 入门:Java 开发者的生成式 AI 实践之路
一、Spring AI 简介 在人工智能技术快速迭代的今天,Spring AI 作为 Spring 生态系统的新生力量,正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务(如 OpenAI、Anthropic)的无缝对接&…...

QT: `long long` 类型转换为 `QString` 2025.6.5
在 Qt 中,将 long long 类型转换为 QString 可以通过以下两种常用方法实现: 方法 1:使用 QString::number() 直接调用 QString 的静态方法 number(),将数值转换为字符串: long long value 1234567890123456789LL; …...

企业如何增强终端安全?
在数字化转型加速的今天,企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机,到工厂里的物联网设备、智能传感器,这些终端构成了企业与外部世界连接的 “神经末梢”。然而,随着远程办公的常态化和设备接入的爆炸式…...

Linux --进程控制
本文从以下五个方面来初步认识进程控制: 目录 进程创建 进程终止 进程等待 进程替换 模拟实现一个微型shell 进程创建 在Linux系统中我们可以在一个进程使用系统调用fork()来创建子进程,创建出来的进程就是子进程,原来的进程为父进程。…...