物体检测-系列教程20:YOLOV5 源码解析10 (Model类前向传播、forward_once函数、_initialize_biases函数)
😎😎😎物体检测-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码
14、Model类
14.2 前向传播
def forward(self, x, augment=False, profile=False):if augment:img_size = x.shape[-2:] # height, widths = [1, 0.83, 0.67] # scalesf = [None, 3, None] # flips (2-ud, 3-lr)y = [] # outputsfor si, fi in zip(s, f):xi = scale_img(x.flip(fi) if fi else x, si)yi = self.forward_once(xi)[0] # forwardyi[..., :4] /= si # de-scaleif fi == 2:yi[..., 1] = img_size[0] - yi[..., 1] # de-flip udelif fi == 3:yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lry.append(yi)return torch.cat(y, 1), None # augmented inference, trainelse:return self.forward_once(x, profile) # single-scale inference, train
这段代码是forward
方法的实现,它定义了模型的前向传播过程,支持正常和增强两种推理模式:
- 前向传播函数,输入
x
,是否进行数据增强augment
,是否分析性能profile
- 是否使用数据增强
- img_size ,获取输入图像的长宽
- s,定义缩放尺度
- f,定义翻转模式,这里
None
表示不翻转,3
表示左右翻转 - y,初始化输出列表
- 使用zip函数将尺度因子列表s和翻转指示列表f组合起来,然后遍历每一对尺度因子和翻转指示
- xi,如果fi不为None,先根据fi的值对图像进行翻转,然后调用scale_img函数根据si的值缩放处理图像;否则直接调用scale_img函数根据si的值缩放处理图像
- yi,将xi进行一次前向传播,取第一个输出
- 对输出yi的前四个维度进行缩放调整,以恢复到原始的尺度。这通常是对边界框坐标的调整
- 如果使用了上下翻转
- 则调整y的坐标
- 如果使用了左右翻转
- 则调整x坐标
- 将处理后的输出添加到列表
- 将list y的所有输出按照第一个维度进行拼接
- 如果在当前循环中没有使用数据增强
- 直接进行一次正常的前向传播
前向传播方法,包括了一个可选的图像增强步骤。在增强模式下,通过对输入图像应用不同的尺度和翻转,生成多个变体,对每个变体单独进行前向传播,并对输出进行调整以适应原始图像的尺寸和方向,最后将所有变体的输出合并。这种方法可以增加模型的泛化能力,因为它让模型在训练时见到更多的数据变化。如果不进行图像增强,它将执行一次标准的前向传播。通过这种设计,模型可以更灵活地应对不同的输入和训练需求
14.3 forward_once函数
def forward_once(self, x, profile=False):y, dt = [], [] # outputsfor m in self.model:if m.f != -1: # if not from previous layerx = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]if profile:try:import thopo = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPSexcept:o = 0t = time_synchronized()for _ in range(10):_ = m(x)dt.append((time_synchronized() - t) * 100)print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))x = m(x) # runy.append(x if m.i in self.save else None) # save outputif profile:print('%.1fms total' % sum(dt))return x
- forward_once函数,输入和forward函一样
- y, dt ,初始化两个空列表,y用于存储每一层的输出,dt用于在性能分析模式下存储每一层的执行时间
- 遍历模型的每一层
- 如果当前层的输入不是来自上一层的输出
- 如果m.f是整数,则直接从y中获取对应的层输出作为输入。如果m.f是一个列表,则根据列表中的索引从y中选择输入,如果索引为-1,则使用原始输入x
- 是否开启性能分析模式
- try
- 导入thop库,用于计算浮点运算数(FLOPS)
- o,使用thop.profile计算当前层m的FLOPS,结果除以1E9转换为GigaFLOPS,并乘以2。这里假设thop.profile返回的是一个元组,其第一个元素是所需的FLOPS
- 如果尝试执行失败
- 则将o(FLOPS)设置为0
- t,调用time_synchronized函数,获取当前精确的时间
- 循环10次
- 为了稳定测量时间,通过多次执行减少偶然误差
- 调用time_synchronized函数计算执行当前层操作的总时间,并将其添加到dt列表中
- 打印当前层的FLOPS、参数数量、执行时间和层类型。为性能分析提供详细信息
- 执行当前层的前向传播,并更新x为该层的输出
- 如果当前层的索引m.i在保存列表self.save中,则将输出x保存到y列表中;否则,保存
None
. 这样做可以减少内存占用,只保存那些后续步骤中需要的层的输出 - 再次检查是否开启了性能分析模式。这个检查是为了在性能分析完成后打印总的执行时间
- 如果开启了性能分析,计算所有层执行时间的总和并打印。这提供了整个前向传播过程的总执行时间,帮助了解模型的性能瓶颈
- 返回最后一层的输出
14.4 _initialize_biases函数
def _initialize_biases(self, cf=None):m = self.model[-1] # Detect() modulefor mi, s in zip(m.m, m.stride): # fromb = mi.bias.data.view(m.na, -1).clone()obj_add = math.log(8 / (640 / s) ** 2) # 计算obj层需要增加的值cls_add = math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())b[:, 4] = b[:, 4] + obj_addb[:, 5:] = b[:, 5:] + cls_addmi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
- 初始化偏执的函数,接受一个可选的参数,这个参数用于根据数据集中各类别出现的频率来调整分类(cls)层的偏置
- m,获取模型中的最后一个模块,检测层(Detect模块),用于目标检测
- 遍历检测层中的每个子模块mi及其对应的步长stride,这里的步长是指输入图像被缩减的尺度,对目标尺寸预测非常关键
- b,获取子模块mi的偏置项,并将其重塑(reshape)成(m.na, -1)的形状,其中m.na是每个特征图位置预测的锚框数量。.clone()确保在修改b时不会影响原始的偏置值
- obj_add ,计算对象(obj)层偏置需要增加的值。这个公式基于假设每640像素的图像中有8个对象,并根据特征图的尺度(通过步长s计算)来调整。目的是调整检测层对于不同尺寸特征图上对象数量预测的偏置
- cls_add ,计算分类(cls)层偏置需要增加的值。如果没有提供类频率(cf为None),则使用一个基于类数量m.nc的固定公式。如果提供了类频率,那么使用类频率来计算每个类的偏置调整值,以此反映数据集中类别的分布
- 将计算出的对象层偏置调整值加到b的第4列上,这是因为在目标检测中,偏置项通常包括4个坐标偏置和一个对象存在的偏置,后者位于第5个位置(索引为4)
- 将计算出的分类层偏置调整值加到b的第5列及之后的所有列上,对应于每个类别的偏置
- 将调整后的偏置b重塑回原始形状并设置为mi的偏置,确保这些偏置在训练过程中可以被进一步调整(requires_grad=True)
14.5 其他辅助函数
def _print_biases(self):m = self.model[-1] # Detect() modulefor mi in m.m: # fromb = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
- 获取模型的最后一个模块,这里假设是一个目标检测模块(Detect模块)
- 遍历检测模块中的每个子模块mi
- 取得当前子模块mi的偏置,通过.detach()确保不会影响梯度计算,.view(m.na, -1)调整形状以匹配锚点数量m.na和偏置的其它维度,最后进行转置以便于处理
- 打印当前子模块卷积层的输入通道数和偏置的统计信息,包括前五个偏置的平均值和之后所有偏置的平均值
fuse函数,用于融合模型中的卷积层(Conv2d)和批归一化层(BatchNorm2d)
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layersprint('Fusing layers... ')for m in self.model.modules():if type(m) is Conv:m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatabilitym.conv = fuse_conv_and_bn(m.conv, m.bn) # update convm.bn = None # remove batchnormm.forward = m.fuseforward # update forwardself.info()return self
- 遍历模型中的所有模块
- 检查当前模块是否为卷积层
- 为了兼容PyTorch 1.6.0,清空非持久性缓冲区集合
- 使用fuse_conv_and_bn函数来融合当前卷积层和其后的批归一化层
- 将批归一化层设为None,表示移除批归一化层
- 更新模块的前向传播函数为融合后的版本
- 在完成融合后,调用info方法打印模型信息
- 返回更新后的模型实例
def info(self): # print model informationmodel_info(self)
调用一个model_info函数,传入当前模型实例,用于收集和打印模型的详细信息,如参数数量、层的类型等
相关文章:
物体检测-系列教程20:YOLOV5 源码解析10 (Model类前向传播、forward_once函数、_initialize_biases函数)
😎😎😎物体检测-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Pycharm中进行 本篇文章配套的代码资源已经上传 点我下载源码 14、Model类 14.2 前向传播 def forward(self, x, augmentFalse, profileFalse):if augm…...

贪吃蛇(C语言)步骤讲解
一:文章大概 使用C语言在windows环境的控制台中模拟实现经典小游戏 实现基本功能: 1.贪吃蛇地图绘制 2.蛇吃食物的功能(上,下,左,右方向控制蛇的动作) 3.蛇撞墙死亡 4.计算得分 5.蛇身加…...

MySQL 数据库表设计和优化
一、数据结构设计 正确的数据结构设计对数据库的性能是非常重要的。 在设计数据表时,尽量遵循一下几点: 将数据分解为合适的表,每个表都应该有清晰定义的目的,避免将过多的数据存储在单个表中。使用适当的数据类型来存储数据&…...

JavaScript进阶-高阶技巧
文章目录 高阶技巧深浅拷贝浅拷贝深拷贝 异常处理throw抛异常try/caych捕获异常debugger 处理thisthis指向改变this 性能优化防抖节流 高阶技巧 深浅拷贝 只针对引用类型 浅拷贝 拷贝对象后,里面的属性值是简单数据类型直接拷贝值,如果属性值是引用数…...
C语言中“#“和“##“的用法
1. 前言 # :把宏参数变为一个字符串, ##:把两个宏参数贴合在一起. 2. 一般用法 #include<stdio.h> #define toString(str) #str //转字符串 #define conStr(a,b) (a##b)//连接 int main() { printf(toString(12345)): //输出字符串&q…...
Linux命令-clock命令(用于调整 RTC 时间)
说明 clock命令用于调整 RTC 时间。 RTC 是电脑内建的硬件时间,执行这项指令可以显示现在时刻,调整硬件时钟的时间,将系统时间设成与硬件时钟之时间一致,或是把系统时间回存到硬件时钟。 语法 clock [--adjust][--debug][--dir…...
编程笔记 Golang基础 045 math包
编程笔记 Golang基础 045 math包 一、math包主要功能常量:函数:数值运算:三角函数:对数函数:随机数相关: 二、示例代码一三、示例代码二小结 Go 语言的标准库 math 提供了一系列基础数学函数和常量…...

[Java 探索者之路] 一个大厂都在用的分布式任务调度平台
分布式任务调度平台是一种能够在分布式计算环境中调度和管理任务的系统,在此环境下,各个任务可以在独立的节点上运行。它有助于提升资源利用率,增强系统扩展性以及提高系统对错误的容忍度。 文章目录 1. 分布式任务调度平台1. 基本概念1.1 任…...

基于JAVA springboot+mybatis智慧生活分享平台设计和实现
基于JAVA springbootmybatis智慧生活分享平台设计和实现 博主介绍:多年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文末…...

详细了解C++中的namespace命名空间
键盘敲烂,月薪过万,同学们,加油呀! 目录 键盘敲烂,月薪过万,同学们,加油呀! 一、命名空间的理解 二、::作用域运算符 三、命名空间(namespace&…...

#WEB前端(HTML属性)
1.实验:a,img 2.IDE:VSCODE 3.记录: a: href插入超链接 默认情况下在本窗口打开链接, target可以设置打开的窗口,parent在父窗口打开,blank新开串口打开,top在顶层串口打开,self为默认在本窗口打开 img: 插入图片 可以插…...
LeetCode---【和的操作】
目录 两数之和我的答案在b站up那里学到的【然后自己复写】 和为 K 的子数组在b站up那里学到的【然后自己复写】 三数之和在b站up那里学到的【然后自己复写】 两数相加【链表】我的半路答案:没有看到是链表在b站up那里学到的【复写失败后整理】 两数之和 我的答案 …...

Docker容器与虚拟化技术:OpenEuler 使用 docker-compose 部署 LNMP
目录 一、实验 1.环境 2.OpenEuler 部署 docker-compose 3.docker-compose 部署 LNMP 二、问题 1.ntpdate未找到命令 2.timedatectl 如何设置时区与时间同步 3.php网页显示时区不对 一、实验 1.环境 (1)主机 表1 主机 系统架构版本IP备注Lin…...

13-微服务初探-自研微服务框架
微服务初探 1. 架构变迁之路 1.1 单体架构 互联网早期,一般的网站应用流量较小,只需要一个应用,将所有的功能代码都部署在一起就可以,这样可以减少开发,部署和维护的成本。 比如说一个电商系统,里面包含…...

LeetCode——二叉树(Java)
二叉树 简介[简单] 144. 二叉树的前序遍历、94. 二叉树的中序遍历、145. 二叉树的后序遍历二叉树层序遍历[中等] 102. 二叉树的层序遍历[中等] 107. 二叉树的层序遍历 II[中等] 199. 二叉树的右视图[简单] 637. 二叉树的层平均值[中等] 429. N 叉树的层序遍历[中等] 515. 在每个…...

LDR6328芯片:智能家居时代的小家电充电革新者
在当今的智能家居时代,小家电的供电方式正变得越来越智能化和高效化。 利用PD(Power Delivery)芯片进行诱骗取电,为后端小家电提供稳定电压的技术,正逐渐成为行业的新宠。在这一领域,LDR6328芯片以其出色的…...

用node写后端环境运行时报错Port 3000 is already in use
解决方法:关闭之前运行的3000端口,操作如下 1.WindowR输入cmd确定,打开命令面板 2.查看本机端口详情 netstat -ano|findstr "3000" 3.清除3000端口 taskkill -pid 41640 -f 最后再重新npm start即可,这里要看你自己项目中package.joson的启动命令是什…...

Git 如何上传本地的所有分支
Git 如何上传本地的所有分支 比如一个本地 git 仓库里定义了两个远程分支,一个名为 origin, 一个名为 web 现在本地有一些分支是 web 远程仓库没有的分支,如何将本地所有分支都推送到 web 这个远程仓库上呢 git push web --all...

【airtest】自动化入门教程(一)AirtestIDE
目录 一、下载与安装 1、下载 2、安装 3、打开软件 二、web自动化配置 1、配置chrome浏览器 2、窗口勾选selenium window 三、新建项目(web) 1、新建一个Airtest项目 2、初始化代码 3、打开一个网页 四、恢复默认布局 五、新建项目…...

ChatGPT支持下的PyTorch机器学习与深度学习技术应用
近年来,随着AlphaGo、无人驾驶汽车、医学影像智慧辅助诊疗、ImageNet竞赛等热点事件的发生,人工智能迎来了新一轮的发展浪潮。尤其是深度学习技术,在许多行业都取得了颠覆性的成果。另外,近年来,Pytorch深度学习框架受…...
<6>-MySQL表的增删查改
目录 一,create(创建表) 二,retrieve(查询表) 1,select列 2,where条件 三,update(更新表) 四,delete(删除表…...

【力扣数据库知识手册笔记】索引
索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...

基于ASP.NET+ SQL Server实现(Web)医院信息管理系统
医院信息管理系统 1. 课程设计内容 在 visual studio 2017 平台上,开发一个“医院信息管理系统”Web 程序。 2. 课程设计目的 综合运用 c#.net 知识,在 vs 2017 平台上,进行 ASP.NET 应用程序和简易网站的开发;初步熟悉开发一…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

微服务商城-商品微服务
数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...
【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统
目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...
Mobile ALOHA全身模仿学习
一、题目 Mobile ALOHA:通过低成本全身远程操作学习双手移动操作 传统模仿学习(Imitation Learning)缺点:聚焦与桌面操作,缺乏通用任务所需的移动性和灵活性 本论文优点:(1)在ALOHA…...

华硕a豆14 Air香氛版,美学与科技的馨香融合
在快节奏的现代生活中,我们渴望一个能激发创想、愉悦感官的工作与生活伙伴,它不仅是冰冷的科技工具,更能触动我们内心深处的细腻情感。正是在这样的期许下,华硕a豆14 Air香氛版翩然而至,它以一种前所未有的方式&#x…...
【LeetCode】3309. 连接二进制表示可形成的最大数值(递归|回溯|位运算)
LeetCode 3309. 连接二进制表示可形成的最大数值(中等) 题目描述解题思路Java代码 题目描述 题目链接:LeetCode 3309. 连接二进制表示可形成的最大数值(中等) 给你一个长度为 3 的整数数组 nums。 现以某种顺序 连接…...
tomcat指定使用的jdk版本
说明 有时候需要对tomcat配置指定的jdk版本号,此时,我们可以通过以下方式进行配置 设置方式 找到tomcat的bin目录中的setclasspath.bat。如果是linux系统则是setclasspath.sh set JAVA_HOMEC:\Program Files\Java\jdk8 set JRE_HOMEC:\Program Files…...