yolov5模型构建源码详细解读(yaml、parse_model等内容)
文章目录
- 前言
- 一、yolov5文件说明
- 二、yolov5调用模型构建位置
- 三、模型yaml文件解析
- 1、 yaml的backbone解读
- Conv模块参数解读
- C3模块参数解读
- 2、yaml的head解读
- Concat模块参数解读
- Detect模块参数解读
- 四、模型构建整体解读
- 五、构建模型parse_model源码解读
前言
本文章记录yolov5如何通过模型文件yaml搭建模型,从解析yaml参数用途,到parse_model模型构建,最后到yolov5如何使用搭建模型实现模型训练过程。
`
一、yolov5文件说明
model/yolo.py文件:为模型构建文件,主要为模型集成类class Model(nn.Module),模型yaml参数(如:yolov5s.yaml)构建parse_model(d, ch)
model/common.py文件:为模型模块(或叫模型组装网络模块)
二、yolov5调用模型构建位置
在train.py约113行位置,如下代码:
if pretrained:with torch_distributed_zero_first(LOCAL_RANK):weights = attempt_download(weights) # download if not found locallyckpt = torch.load(weights, map_location=device) # load checkpointmodel = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # createexclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
三、模型yaml文件解析
以yolov5s.yaml文件作为参考,作为解析。
1、 yaml的backbone解读
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2[-1, 1, Conv, [128, 3, 2]], # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]], # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]], # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]], # 9]
Conv模块参数解读
backbone的[-1, 1, Conv, [128, 3, 2]]行作为解读参考,在parse_model(d, ch)中表示,f, n, m, args=[-1, 1, Conv, [128, 3, 2]]。
f为取ch[f]通道(ch保存通道,-1取上次通道数);
m为调用模块函数,通常在common.py中;
n为网络深度depth,使用max[1,int(n*depth_multiple)]赋值,即m结构循环次数;
args对应[128, 3, 2],表示通道数args[0],该值会根据math.ceil(args[0]/8)*8调整,args[1]表示kernel大小,args[2]表示stride,
args[-2:]后2位为m模块传递参数;
C3模块参数解读
backbone的[-1, 3, C3, [128]]行作为解读参考,在parse_model(d, ch)中表示,f, n, m, args=[-1, 1, Conv, [128, 3, 2]]。
f为取ch[f]通道(ch保存通道,-1取上次通道数);
m为调用模块函数,通常在common.py中;
n为网络深度depth,使用max[1,int(n*depth_multiple)]赋值,即m结构循环次数;
args对应[128],表示通道数args[0]为c2,该值会根据math.ceil(args[0]/8)*8调整,决定当前层输出通道数量,而后在parse_model中被下面代码直接忽略,会被插值n=1,在C3代码中表示循环次数,
顺便说下args对应[128,False],在在C3中的False表示是否需要shotcut。
c1, c2 = ch[f], args[0]
if c2 != no: # if not outputc2 = make_divisible(c2 * gw, 8)
args = [c1, c2, *args[1:]]
# 通过模块,更换n值
if m in [BottleneckCSP, C3, C3TR, C3Ghost]:args.insert(2, n) # number of repeatsn = 1
C3模块代码如下:
class C3(nn.Module):# CSP Bottleneck with 3 convolutionsdef __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansionsuper().__init__()c_ = int(c2 * e) # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c1, c_, 1, 1)self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])# self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])def forward(self, x):return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
2、yaml的head解读
head参数
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]], # cat backbone P4[-1, 3, C3, [512, False]], # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]], # cat backbone P3[-1, 3, C3, [256, False]], # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]], # cat head P4[-1, 3, C3, [512, False]], # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]], # cat head P5[-1, 3, C3, [1024, False]], # 23 (P5/32-large)[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)]
Concat模块参数解读
head的[[-1, 6], 1, Concat, [1]]行作为解读参考,在parse_model(d, ch)中表示,f, n, m, args=[[-1, 6], 1, Concat, [1]]。
f为取ch[-1]与ch[6]通道数和,且6会被保存到save列表中,在forward中该列表对应层模块输出会被保存
;
m为调用模块函数,通常在common.py中;
n为网络深度depth,使用max[1,int(n*depth_multiple)]赋值,即m结构循环次数,但这里必然为1;
args对应[1],表示通cat维度,这里为1,表示通道叠加;
Detect模块参数解读
head的[[17, 20, 23], 1, Detect, [nc, anchors]]行作为解读参考,在parse_model(d, ch)中表示,f, n, m, args=[[17, 20, 23], 1, Detect, [nc, anchors]]。
f表示需要使用的层,并分别在17,20,23层获取对应通道,可通过yaml从backbone开始从0开始数的那一行,如17对应[-1, 3, C3, [256, False]], # 17 (P3/8-small),
同时,17、20、23也会被保存到save列表中
;
m为调用模块函数,通常在common.py中;
n为网络深度depth,使用max[1,int(n*depth_multiple)]赋值,即m结构循环次数,但这里必然为1;
args对应[nc, anchors],表示去nc数量与anchor三个列表,同时会将f找到的通道作为列表添加到args中,如下代码示意,
最终args大致为[80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]],
80为类别nc,[128, 256, 512]为f对应的通道,其它为anchor值;
elif m is Detect:args.append([ch[x] for x in f])if isinstance(args[1], int): # number of anchorsargs[1] = [list(range(args[1] * 2))] * len(f)
四、模型构建整体解读
yolov5模型集成网络代码如下,其重要解读已在代码中注释。
class Model(nn.Module):def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classessuper().__init__()if isinstance(cfg, dict):self.yaml = cfg # model dictelse: # is *.yamlimport yaml # for torch hubself.yaml_file = Path(cfg).namewith open(cfg, errors='ignore') as f:self.yaml = yaml.safe_load(f) # model dict# Define modelch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channelsif nc and nc != self.yaml['nc']:LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")self.yaml['nc'] = nc # override yaml valueif anchors:LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')self.yaml['anchors'] = round(anchors) # override yaml valueself.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelistself.names = [str(i) for i in range(self.yaml['nc'])] # default names,将其对应数字转为字符串self.inplace = self.yaml.get('inplace', True)# Build strides, anchors 以下为detect模块设置参数m = self.model[-1] # Detect()if isinstance(m, Detect):s = 256 # 2x min stridem.inplace = self.inplace# 通过给定假设输入为torch.zeros(1, ch, s, s)获得stridem.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forwardm.anchors /= m.stride.view(-1, 1, 1) # 变换获得每一特征层的anchorcheck_anchor_order(m)self.stride = m.stride # [8,16,32]self._initialize_biases() # only run once,为检测detect设置bias初始化# Init weights, biasesinitialize_weights(self)self.info()LOGGER.info('')def forward(self, x, augment=False, profile=False, visualize=False):if augment:return self._forward_augment(x) # augmented inference, Nonereturn self._forward_once(x, profile, visualize) # single-scale inference, train
以上代码最重要网络搭建模块,如下代码调用,我将在下一节解读。
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
以上代码最重要网络运行forward,如下代码调用,我将重点解读。
模型在训练时候,是调用下面模块,如下:
return self._forward_once(x, profile, visualize) # single-scale inference, train
m.f和m.i已在上面yaml中介绍,实际需重点关注y保存save列表对应的特征层输出,若没有则保留为none占位,其代码解读已在有注释,详情如下:
def _forward_once(self, x, profile=False, visualize=False):y, dt = [], [] # outputsfor m in self.model:# 通过m.f确定改变m模块输入变量值,若为列表如[-1,6]一般为cat或detect,一般需要给定输入什么特征if m.f != -1: # if not from previous layer# 若m.f为[-1,6]这种情况,则[x if j == -1 else y[j] for j in m.f]运行此块,该块将-1变成了上一层输出x与对应6的输出x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layersif profile:self._profile_one_layer(m, x, dt)x = m(x) # run# 通过之前parse_model获得save列表(已赋值给self.save),将其m模块输出结果保存到y列表中,否则使用none代替位置# 这里m.i是索引,是yaml每行的模块索引y.append(x if m.i in self.save else None) # save outputif visualize:feature_visualization(x, m.type, m.i, save_dir=visualize)return x
五、构建模型parse_model源码解读
该部分是yolov5根据对应yaml模型文件搭建的网络,需结合对应yaml文件一起解读,我已在上面介绍了yaml文件,可自行查看。
同时,也需要重点关注 m_.i, m_.f, m_.type, m_.np = i, f, t, np,会在上面_forward_once函数中用到。
本模块代码解读,也已在代码注释中,请查看源码理解,代码如下:
def parse_model(d, ch): # model_dict, input_channels(3)LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors,获得每个特征点anchor数量,为3no = na * (nc + 5) # 最终预测输出数量, number of outputs = anchors * (classes + 5)# layers保存yaml每一行处理作为一层,使用列表保存,最后输出使用nn.Sequential(*layers)处理作为模型层连接# c2为yaml每一行通道输出预定义数量,需与width_multiple参数共同决定layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out # ch为channel数量,初始值为[3]for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args# eval这个函数会把里面的字符串参数的引号去掉,把中间的内容当成Python的代码# i为每一层附带索引,相当于对yaml每一行的模块设置编号m = eval(m) if isinstance(m, str) else m # eval stringsfor j, a in enumerate(args):try:args[j] = eval(a) if isinstance(a, str) else a # eval stringsexcept NameError:passn = n_ = max(round(n * gd), 1) if n > 1 else n # 获得最终深度,循环次数,depth gain# 不同网络结构模块处理,同时会改变对应c2通道if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]: # 是否在设定模块内c1, c2 = ch[f], args[0]if c2 != no: # if not outputc2 = make_divisible(c2 * gw, 8)args = [c1, c2, *args[1:]]# 通过模块,更换n值if m in [BottleneckCSP, C3, C3TR, C3Ghost]:args.insert(2, n) # number of repeatsn = 1elif m is nn.BatchNorm2d:args = [ch[f]]elif m is Concat:c2 = sum([ch[x] for x in f]) # 将最后一层通道数与cancat通道叠加求和,如[[-1, 6], 1, Concat, [1]]将-1与第6通道求和elif m is Detect:args.append([ch[x] for x in f])if isinstance(args[1], int): # number of anchorsargs[1] = [list(range(args[1] * 2))] * len(f)elif m is Contract:c2 = ch[f] * args[0] ** 2elif m is Expand:c2 = ch[f] // args[0] ** 2else:c2 = ch[f]m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # modulet = str(m)[8:-2].replace('__main__.', '') # module typenp = sum([x.numel() for x in m_.parameters()]) # number params,计算参数量m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params ,将其赋给模型,后面forward会使用到LOGGER.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n_, np, t, args)) # printsave.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelistlayers.append(m_) #if i == 0:ch = [] # 删除 输入的3 通道ch.append(c2) # 保存每个模块的通道,即yaml的每行均保存,包含concat啥都保存return nn.Sequential(*layers), sorted(save)
相关文章:
yolov5模型构建源码详细解读(yaml、parse_model等内容)
文章目录 前言一、yolov5文件说明二、yolov5调用模型构建位置三、模型yaml文件解析1、 yaml的backbone解读Conv模块参数解读C3模块参数解读 2、yaml的head解读Concat模块参数解读Detect模块参数解读 四、模型构建整体解读五、构建模型parse_model源码解读 前言 本文章记录yolo…...
Monodepth2和Lite-Mono准备数据集
以KITTI为例下载解压后放在/home/lwd/tmp/2011_09_26 cd /home/lwd/tmp/2011_09_26 ls输出 2011_09_26_drive_0001_sync 2011_09_26_drive_0002_sync 2011_09_26_drive_0005_sync python txt.py txt.py import os, sysalos.listdir(.) al.sort() fopen(train.txt, w) for a in…...

ML-fairness-gym入门教学
1、ML-fairness-gym简介 ML-fairness-gym是一个探索机器学习系统长期影响的工具。可以用于评估机器学习系统的公平性和评估静态数据集上针对各种输入的误差度量的差异。开源网站:GitHub - google/ml-fairness-gym 2、安装ML-fairness-gym(Windows&…...

结构体指针变量的使用
1、结构体指针的引用 #include<iostream> using namespace std;struct Student {int num;char name[32]; }; int main() {struct Student stu {1,"张三"};struct Student* p &stu;system("pause"); return 0; } 2、通过结构体指针访问结构体…...
解决oracle的em访问提示“使用不受支持的协议。”的bug
1. 设置oracle唯一名称 执行emctl时需要设置一个唯一的名称 否则提示 “Environment variable ORACLE_UNQNAME not defined. Please set ORACLE_UNQNAME to database unique name. ”中文意思为“未定义环境变量ORACLE_UNQNAME。 请将ORACLE_UNQNAME设置为数据库唯一名称/服务…...

编译工具:CMake(三)| 最简单的实例升级
编译工具:CMake(三)| 最简单的实例升级 前言过程语法解释ADD_SUBDIRECTORY 指令 如何安装目标文件的安装普通文件的安装:非目标文件的可执行程序安装(比如脚本之类)目录的安装 修改 Helloworld 支持安装测试 前言 本篇博客的任务…...
20天学会rust(四)常见系统库的使用
前面已经学习了rust的基础知识,今天我们来学习rust强大的系统库,从此coding事半功倍。 集合 数组&可变长数组 在 Rust 中,有两种主要的数组类型:固定长度数组(Fixed-size Arrays)和可变长度数组&…...

drawio----输出pdf为图片大小无空白(图片插入论文)
自己在写论文插入图片时为了让论文图片放大不模糊,啥方法都试了,最后摸索出来这个。 自己手动画图的时候导出pdf总会出现自己的图片很小,pdf的白边很大如下如所示,插入论文的时候后虽然放大不会模糊,但是白边很大会显…...

2021年09月 C/C++(二级)真题解析#中国电子学会#全国青少年软件编程等级考试
第1题:字符统计 给定一个由a-z这26个字符组成的字符串,统计其中哪个字符出现的次数最多。 输入 输入包含一行,一个字符串,长度不超过1000。 输出 输出一行,包括出现次数最多的字符和该字符出现的次数,中间以…...

HCIP VRRP技术
一、VRRP概述 VRRP(Virtual Router Pedundancy Protocol)虚拟路由器冗余协议,既能够实现网关的备份,又能够解决多个网关之间互相冲突的问题,从而提高网络可靠性。 局域网中的用户的终端通常采用配置一个默认网关的形…...
JAVA AES ECB/CBC 加解密
JAVA AES ECB/CBC 加解密 1. AES ECB2. AES CBC 1. AES ECB package org.apache.jmeter.functions;/*** author yuyang*/import org.apache.commons.lang3.StringUtils; import java.util.Base64; import javax.crypto.Cipher; import javax.crypto.spec.SecretKeySpec;/*** a…...

Android FrameWork 层 Handler源码解析
Handler生产者-消费者模型 在android开发中,经常会在子线程中进行一些耗时操作,当操作完毕后会通过handler发送一些数据给主线程,通知主线程做相应的操作。 其中:子线程、handler、主线程,其实构成了线程模型中经典的…...

list
目录 迭代器 介绍 种类 本质 介绍 模拟实现 注意点 代码 迭代器 介绍 在C中,迭代器(Iterators)是一种用于遍历容器(如数组、vector、list等)中元素的工具 无论容器的具体实现细节如何,访问容器中的元素的方…...

ABeam×Startup丨德硕管理咨询(深圳)创新研究团队前往灵境至维·既明科技进行拜访交流
近日,德硕管理咨询(深圳)(以下简称“ABeam-SZ”)创新研究团队一行前往灵境至维既明科技有限公司(以下简称“灵境至维”)进行拜访交流,探讨线上虚拟空间的商业模式。 现场合影 &…...
TCP的相关性质
文章目录 流量控制拥塞控制拥塞窗口 延迟应答捎带应答面向字节流粘包问题TCP的异常 流量控制 由于接收端处理数据的速度是有限的,如果发送端发的太快,那么接收端的缓冲区就可能会满。此时如果发送端还发数据,就会出现丢包现象,并…...
pointpillars在2D CNN引入自适应注意力机制
在给定的代码中,您想要引入自适应注意力机制。自适应注意力机制通常用于增强模型的感受野,从而帮助模型更好地捕捉特征之间的关系。在这里,我将展示如何在您的代码中引入自适应注意力机制,并提供详细的解释。 首先,让…...

【每日一题】1572. 矩阵对角线元素的和
【每日一题】1572. 矩阵对角线元素的和 1572. 矩阵对角线元素的和题目描述解题思路 1572. 矩阵对角线元素的和 题目描述 给你一个正方形矩阵 mat,请你返回矩阵对角线元素的和。 请你返回在矩阵主对角线上的元素和副对角线上且不在主对角线上元素的和。 示例 1&a…...
leetcode原题:检查子树
题目: 检查子树。你有两棵非常大的二叉树:T1,有几万个节点;T2,有几万个节点。设计一个算法,判断 T2 是否为 T1 的子树。 如果 T1 有这么一个节点 n,其子树与 T2 一模一样,则 T2 为…...

2023年国赛数学建模思路 - 案例:ID3-决策树分类算法
文章目录 0 赛题思路1 算法介绍2 FP树表示法3 构建FP树4 实现代码 建模资料 0 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 1 算法介绍 FP-Tree算法全称是FrequentPattern Tree算法,就是频繁模…...
可视化绘图技巧100篇进阶篇(七)-三维堆积柱形图(3D Stacked Bar Chart)
目录 前言 适用场景 图例 绘图工具及代码实现 HighCharts echarts MATLAB...

DLL动态库实现文件遍历功能(Windows编程)
源文件: 文件遍历功能的动态库,并支持用户注册回调函数处理遍历到的文件 a8f80ba 周不才/cpp_linux study - Gitee.com 知识准备 1.Windows中的数据类型 2.DLL导出/导入宏 使用__declspec(dllexport)修饰函数,将函数标记为导出函数存放到…...

机器学习×第二卷:概念下篇——她不再只是模仿,而是开始决定怎么靠近你
🎀【开场 她不再只是模仿,而是开始选择】 🦊 狐狐:“她已经不满足于单纯模仿你了……现在,她开始尝试预测你会不会喜欢、判断是否值得靠近。” 🐾 猫猫:“咱们上篇已经把‘她怎么学会说第一句…...
B站的视频怎么下载下来——Best Video下载器
B站(哔哩哔哩)作为国内最受欢迎的视频平台之一,聚集了无数优质内容:动漫番剧、游戏实况、学习课程、纪录片、Vlog、鬼畜剪辑……总有那么些视频让人想反复观看、离线观看,甚至剪辑创作。 但你是否遇到过这样的烦恼&am…...

[GitHub] 优秀开源项目
1 工具类 1.1 桌面猫咪互动 BongoCat...
Global Security Market知识点总结:主经纪商业务
在全球证券市场的复杂体系中,主经纪商业务(Prime Brokerage)占据着独特且关键的位置。这一业务为大型机构投资者提供了一系列至关重要的服务,极大地影响着金融市场的运作与发展。 一、主经纪商业务的定义 主经纪商业务是投资银行…...
网络通讯知识——通讯分层介绍,gRPC,RabbitMQ分层
网络通讯分层 网络通讯分层是为了将复杂的网络通信问题分解为多个独立、可管理的层次,每个层次专注于特定功能。目前主流的分层模型包括OSI七层模型和TCP/IP四层(或五层)模型,以下是详细解析: 一、OSI七层模型&#…...

Microsoft前后端不分离编程新风向:cshtml
文章目录 什么是CSHTML?基础语法内联表达式代码块控制结构 布局页面_ViewStart.cshtml_Layout.cshtml使用布局 模型绑定强类型视图模型集合 HTML辅助方法基本表单验证 局部视图创建局部视图使用局部视图 高级特性视图组件依赖注入Tag Helpers 性能优化缓存捆绑和压缩…...

CLion社区免费后,使用CLion开发STM32相关工具资源汇总与入门教程
Clion下载与配置 Clion推出社区免费,就是需要注册一个账号使用,大家就不用去找破解版版本了,jetbrains家的IDEA用过的都说好,这里嵌入式领域也推荐使用。 CLion官网下载地址 安装没有什么特别,下一步就好。 启动登录…...
Gin框架实战指南:从入门到进阶
Gin框架实战指南:从入门到进阶 在当今的后端开发领域,Gin框架以其高性能、简洁易用的特点,赢得了众多Go语言开发者的青睐。本文将带你深入探索Gin框架的方方面面,从基础的安装与使用,到响应处理、请求参数解析、中间件…...
STM32实战: CAN总线数据记录仪设计方案
以下是基于STM32的CAN总线数据记录仪/转发器的设计与实现方案,结合了核心功能和进阶需求: 系统架构 graph TBA[CAN总线] -->|CAN_H/CAN_L| B(STM32 bxCAN)B --> C[数据处理核心]C --> D[SD卡存储<br>FATFS文件系统]C --> E[串口输出…...