当前位置: 首页 > news >正文

YOLOv6-4.0部分代码阅读笔记-effidehead_lite.py

effidehead_lite.py

yolov6\models\heads\effidehead_lite.py

目录

effidehead_lite.py

1.所需的库和模块

2.class Detect(nn.Module): 

3.def build_effidehead_layer(channels_list, num_anchors, num_classes, num_layers): 


1.所需的库和模块

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from yolov6.layers.common import DPBlock
from yolov6.assigners.anchor_generator import generate_anchors
from yolov6.utils.general import dist2bbox

2.class Detect(nn.Module): 

class Detect(nn.Module):# 高效分离头。# 利用硬件感知设计,使用混合通道方法对解耦头进行优化。'''Efficient Decoupled HeadWith hardware-aware degisn, the decoupled head is optimized withhybridchannels methods.'''def __init__(self, num_classes=80, num_layers=3, inplace=True, head_layers=None):  # detection layer    检测层super().__init__()assert head_layers is not Noneself.nc = num_classes  # number of classes    类别数量self.no = num_classes + 5  # number of outputs per anchor    每个锚点的输出数量self.nl = num_layers  # number of detection layers    检测层数self.grid = [torch.zeros(1)] * num_layersself.prior_prob = 1e-2self.inplace = inplacestride = [8, 16, 32] if num_layers == 3 else [8, 16, 32, 64] # strides computed during build    构建期间计算的步长self.stride = torch.tensor(stride)self.grid_cell_offset = 0.5self.grid_cell_size = 5.0# Init decouple head    初始化解耦头self.stems = nn.ModuleList()self.cls_convs = nn.ModuleList()self.reg_convs = nn.ModuleList()self.cls_preds = nn.ModuleList()self.reg_preds = nn.ModuleList()# Efficient decoupled head layers    高效解耦的头部层for i in range(num_layers):idx = i*5self.stems.append(head_layers[idx])self.cls_convs.append(head_layers[idx+1])self.reg_convs.append(head_layers[idx+2])self.cls_preds.append(head_layers[idx+3])self.reg_preds.append(head_layers[idx+4])# 它用于初始化神经网络中特定层的偏置(biases)。这个方法特别针对于类别预测层( self.cls_preds )和边界框回归预测层( self.reg_preds )的偏置和权重进行初始化。# 接受 self 作为参数,表示类的实例。def initialize_biases(self):# 遍历所有类别预测层, self.cls_preds 是一个包含卷积层的列表。for conv in self.cls_preds:# 获取当前卷积层的偏置,并将其展平为一维张量。b = conv.bias.view(-1, )# 使用逻辑斯谛分布的公式来初始化偏置值。这里的 self.prior_prob 是一个先验概率,通常用于目标检测中表示目标存在的概率。这个公式确保了在开始训练时,模型对目标的存在与否持中立态度。b.data.fill_(-math.log((1 - self.prior_prob) / self.prior_prob))# 将初始化后的偏置值重新设置为卷积层的偏置,并确保它们是可训练的( requires_grad=True )。conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)# 获取当前卷积层的权重。w = conv.weight# 将权重初始化为0。w.data.fill_(0.)# 将权重重新设置为卷积层的权重,并确保它们是可训练的。conv.weight = torch.nn.Parameter(w, requires_grad=True)# 遍历所有边界框回归预测层, self.reg_preds 是一个包含卷积层的列表。for conv in self.reg_preds:# 获取当前卷积层的偏置,并将其展平为一维张量。b = conv.bias.view(-1, )# 将偏置值初始化为1.0,这是因为在边界框回归中,我们通常希望预测的边界框与真实边界框的中心点对齐。b.data.fill_(1.0)# 将初始化后的偏置值重新设置为卷积层的偏置,并确保它们是可训练的。conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)# 获取当前卷积层的权重。w = conv.weight# 将权重初始化为0。w.data.fill_(0.)# 将权重重新设置为卷积层的权重,并确保它们是可训练的。conv.weight = torch.nn.Parameter(w, requires_grad=True)def forward(self, x):if self.training:cls_score_list = []reg_distri_list = []for i in range(self.nl):x[i] = self.stems[i](x[i])cls_x = x[i]reg_x = x[i]cls_feat = self.cls_convs[i](cls_x)cls_output = self.cls_preds[i](cls_feat)reg_feat = self.reg_convs[i](reg_x)reg_output = self.reg_preds[i](reg_feat)cls_output = torch.sigmoid(cls_output)cls_score_list.append(cls_output.flatten(2).permute((0, 2, 1)))reg_distri_list.append(reg_output.flatten(2).permute((0, 2, 1)))cls_score_list = torch.cat(cls_score_list, axis=1)reg_distri_list = torch.cat(reg_distri_list, axis=1)return x, cls_score_list, reg_distri_listelse:cls_score_list = []reg_dist_list = []# def generate_anchors(feats, fpn_strides, grid_cell_size=5.0, grid_cell_offset=0.5,  device='cpu', is_eval=False, mode='af'):# -> 根据特征生成锚点。# -> return anchor_points, stride_tensor  /  return anchors, anchor_points, num_anchors_list, stride_tensoranchor_points, stride_tensor = generate_anchors(x, self.stride, self.grid_cell_size, self.grid_cell_offset, device=x[0].device, is_eval=True, mode='af')for i in range(self.nl):b, _, h, w = x[i].shapel = h * wx[i] = self.stems[i](x[i])cls_x = x[i]reg_x = x[i]cls_feat = self.cls_convs[i](cls_x)cls_output = self.cls_preds[i](cls_feat)reg_feat = self.reg_convs[i](reg_x)reg_output = self.reg_preds[i](reg_feat)cls_output = torch.sigmoid(cls_output)cls_score_list.append(cls_output.reshape([b, self.nc, l]))reg_dist_list.append(reg_output.reshape([b, 4, l]))cls_score_list = torch.cat(cls_score_list, axis=-1).permute(0, 2, 1)reg_dist_list = torch.cat(reg_dist_list, axis=-1).permute(0, 2, 1)# def dist2bbox(distance, anchor_points, box_format='xyxy'): -> 将距离(ltrb)转换为盒子(xywh或xyxy)。 -> return bboxpred_bboxes = dist2bbox(reg_dist_list, anchor_points, box_format='xywh')pred_bboxes *= stride_tensorreturn torch.cat([pred_bboxes,torch.ones((b, pred_bboxes.shape[1], 1), device=pred_bboxes.device, dtype=pred_bboxes.dtype),cls_score_list],axis=-1)

3.def build_effidehead_layer(channels_list, num_anchors, num_classes, num_layers): 

def build_effidehead_layer(channels_list, num_anchors, num_classes, num_layers):head_layers = nn.Sequential(# stem0DPBlock(in_channel=channels_list[0],out_channel=channels_list[0],kernel_size=5,stride=1),# cls_conv0DPBlock(in_channel=channels_list[0],out_channel=channels_list[0],kernel_size=5,stride=1),# reg_conv0DPBlock(in_channel=channels_list[0],out_channel=channels_list[0],kernel_size=5,stride=1),# cls_pred0nn.Conv2d(in_channels=channels_list[0],out_channels=num_classes * num_anchors,kernel_size=1),# reg_pred0nn.Conv2d(in_channels=channels_list[0],out_channels=4 * num_anchors,kernel_size=1),# stem1DPBlock(in_channel=channels_list[1],out_channel=channels_list[1],kernel_size=5,stride=1),# cls_conv1DPBlock(in_channel=channels_list[1],out_channel=channels_list[1],kernel_size=5,stride=1),# reg_conv1DPBlock(in_channel=channels_list[1],out_channel=channels_list[1],kernel_size=5,stride=1),# cls_pred1nn.Conv2d(in_channels=channels_list[1],out_channels=num_classes * num_anchors,kernel_size=1),# reg_pred1nn.Conv2d(in_channels=channels_list[1],out_channels=4 * num_anchors,kernel_size=1),# stem2DPBlock(in_channel=channels_list[2],out_channel=channels_list[2],kernel_size=5,stride=1),# cls_conv2DPBlock(in_channel=channels_list[2],out_channel=channels_list[2],kernel_size=5,stride=1),# reg_conv2DPBlock(in_channel=channels_list[2],out_channel=channels_list[2],kernel_size=5,stride=1),# cls_pred2nn.Conv2d(in_channels=channels_list[2],out_channels=num_classes * num_anchors,kernel_size=1),# reg_pred2nn.Conv2d(in_channels=channels_list[2],out_channels=4 * num_anchors,kernel_size=1))if num_layers == 4:head_layers.add_module('stem3',# stem3DPBlock(in_channel=channels_list[3],out_channel=channels_list[3],kernel_size=5,stride=1))head_layers.add_module('cls_conv3',# cls_conv3DPBlock(in_channel=channels_list[3],out_channel=channels_list[3],kernel_size=5,stride=1))head_layers.add_module('reg_conv3',# reg_conv3DPBlock(in_channel=channels_list[3],out_channel=channels_list[3],kernel_size=5,stride=1))head_layers.add_module('cls_pred3',# cls_pred3nn.Conv2d(in_channels=channels_list[3],out_channels=num_classes * num_anchors,kernel_size=1))head_layers.add_module('reg_pred3',# reg_pred3nn.Conv2d(in_channels=channels_list[3],out_channels=4 * num_anchors,kernel_size=1))return head_layers

相关文章:

YOLOv6-4.0部分代码阅读笔记-effidehead_lite.py

effidehead_lite.py yolov6\models\heads\effidehead_lite.py 目录 effidehead_lite.py 1.所需的库和模块 2.class Detect(nn.Module): 3.def build_effidehead_layer(channels_list, num_anchors, num_classes, num_layers): 1.所需的库和模块 import torch import t…...

重学SpringBoot3-整合 Elasticsearch 8.x (一)客户端方式

更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 这里写目录标题 1. 为什么选择 Elasticsearch?2. Spring Boot 3 和 Elasticsearch 8.x 的集成概述2.1 准备工作2.2 添加依赖 3. Elasticsearch 客户端配置方式…...

极简实现酷炫动效:Flutter隐式动画指南第三篇自定义Flutter隐式动画

目录 前言 一、TweenAnimationBuilder 二、使用TweenAnimationBuilder实现的一些动画效果 1.调整透明度的动画 2.稍微复杂点的组合动画 3.数字跳动的动画效果 前言 上两节博客分别介绍了Flutter中的隐式动画的基础知识以及使用隐式动画实现的一些动画效果。当系统提供的隐…...

无人机维护保养、部件修理更换技术详解

无人机作为一种精密的航空设备,其维护保养和部件修理更换是确保飞行安全、延长使用寿命的重要环节。以下是对无人机维护保养、部件修理更换技术的详细解析: 一、无人机维护保养技术 1. 基础构造理解: 熟悉无人机的基本构造,包括…...

xilinx vitis 更换硬件平台——ZYNQ学习笔记5

1、重新生成硬件信息 2、选择带有bit信息 3、设施路径和名字 4、打开更新硬件选项 5、选择新的硬件信息 6、打开系统工程界面 7、复位硬件信息 更新完毕...

vscode makfile编译c程序

编译工具安装 为了在 Windows 上安装 GCC,您需要安装 MinGW-w64。 MinGW-w64 是一个开源项目,它为 Windows 系统提供了一个完整的 GCC 工具链,支持编译生成 32 位和 64 位的 Windows 应用程序。 1. 下载MinGW-w64源代码,如图点…...

【学术论文投稿】探索嵌入式硬件设计:揭秘智能设备的心脏

【IEEE出版】第六届国际科技创新学术交流大会暨通信、信息系统与软件工程学术会议(CISSE 2024)_艾思科蓝_学术一站式服务平台 更多学术会议论文投稿请看:https://ais.cn/u/nuyAF3 目录 引言 嵌入式系统简介 嵌入式硬件设计的组成部分 设…...

JavaScript 概述

### JavaScript 概述 JavaScript 是一种广泛使用的编程语言,它最初由 Netscape 公司的 Brendan Eich 在1995年创建,目的是为网页添加交互性。随着时间的发展,JavaScript 已经从一个简单的脚本语言演变成了一种功能强大的编程语言,…...

2024年10月个人工作生活总结

本文为 2024年10月工作生活总结。 研发编码 一个证书过期问题记录 某天,现场反馈某服务无法使用问题,经同事排查,是因为服务证书过期导致的。原来,证书的有效期设置为5年,这个月刚好到期。 虽然这个问题与自己无直接…...

uniapp ,微信小程序,滚动(下滑,上拉)到底部加载下一页内容

前言 小程序的内容基本都是滑动到底部加载下一页,这个一般都没有什么好用的组件来用,我看vant和uniapp的插件里最多只有个分页,没有滚动到底部加载下一页。再次做个记录。 效果预览 下滑到底部若是有下一页,则会自动加载下一页&…...

MySQL中的日志类型有哪些?binlog、redolog和undolog的作用和区别是什么?

简介: MySQL中有六种日志文件,分别是:重做日志(redo log)、回滚日志(undo log)、二进制日志(binlog)、错误日志(errorlog)、慢查询日志&#xff0…...

【uni-app】创建自定义模板

1. 步骤 打开自定义模板文件夹 在此文件夹下创建模板文件(注意后缀名) 重新点击“新建页面” 即可看到新建的模板 2. 注意事项 创建的模板必须文件类型对应(vue模板就创建*.vue文件, uvue模板就创建*.uvue文件)...

Cesium移动Primitive位置

与传统的Entity实体不同,Primitive作为一种自定义基本图元,几何形状、材质和其他属性均由使用者定义,在需要绘制大量静态几何图形的高效渲染场景中更为适用。 Primitive的移动涉及到矩阵变换,并不像Entity那样给它替换一个新的坐…...

安卓13默认连接wifi热点 android13默认连接wifi

总纲 android13 rom 开发总纲说明 文章目录 1.前言2.问题分析3.代码分析4.代码修改5.编译6.彩蛋1.前言 有时候我们需要让固件里面内置好,相关的wifi的ssid和密码,让固件起来就可以连接wifi,不用在手动操作。 2.问题分析 这个功能,使用普通的安卓代码就可以实现了。 3.代…...

parted 磁盘分区

目录 磁盘格式磁盘分区文件系统挂载使用扩展 - parted、fdisk、gdisk 区别 磁盘格式 parted /dev/vdcmklabel gpt # 设置磁盘格式为GPT p # 打印磁盘信息此时磁盘格式设置完成! 磁盘分区 开始分区: mkpart data_mysql # 分区名&…...

第三百零八节 Log4j教程 - Log4j日志到数据库

Log4j教程 - Log4j日志到数据库 我们可以使用log4j API通过使用org.apache.log4j.jdbc.JDBCAppender对象将信息记录到数据库中。 下表列出了JDBCAppender的配置属性。 属性描述bufferSize设置缓冲区大小。默认大小为1。driverJDBC驱动程序类。默认为sun.jdbc.odbc.JdbcOdbcDr…...

ai智能语音电销机器人可以做哪些事情?

AI智能语音电销机器人是结合人工智能技术进行自动化电话销售和客户互动的工具,能够完成一系列任务,有助于提升销售效果、优化客户体验和提高工作效率。以下是AI智能语音电销机器人可以做的一些主要事情: 1. 自动拨号 AI语音电销机器人可以自…...

CleanShot X - Mac(苹果电脑)专业截图录屏软件

CleanShot X 不仅提供了基础的截图功能,更内置了强大的图片编辑器,让你能轻松添加标注、形状、文本……以及将多个截图进行合并。 无论是为社交媒体制作图文,还是制作专业的产品 / 教程演示,CleanShot X 都能满足你的需求。 软件…...

Kafka 客户端工具使用分享【offsetexplorer】

前言: 前面我们使用 Spring Boot 继承 Kafka 完成了消息发送,有朋友会问 Kafka 有没有好用的客户端工具,RabbitMQ、RocketMQ 都有自己的管理端,那 Kafka 如何去查看发送出去的消息呢? 本篇我们就来分享一个好用的工具…...

uni-app 下拉刷新、 上拉触底(列表信息)、 上滑加载(短视频) 一键搞定

一、下拉刷新 1. 首先找到pages.json中 给需要进行下拉刷新的页面设置可以下拉刷新 2. 然后在需要实现下拉刷新的script标签内添加 导入onPullDownRefresh import {onPullDownRefresh} from dcloudio/uni-app 下拉刷新触发的事件 onPullDownRefresh(()> {console.log(正…...

基于FPGA的PID算法学习———实现PID比例控制算法

基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容:参考网站: PID算法控制 PID即:Proportional(比例)、Integral(积分&…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建(全平台详解) 在开始使用 React Native 开发移动应用之前,正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南,涵盖 macOS 和 Windows 平台的配置步骤,如何在 Android 和 iOS…...

STM32标准库-DMA直接存储器存取

文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA(Direct Memory Access)直接存储器存取 DMA可以提供外设…...

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别 直接训练提示词嵌入向量的核心区别 您提到的代码: prompt_embedding = initial_embedding.clone().requires_grad_(True) optimizer = torch.optim.Adam([prompt_embedding...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中,新增了一个本地验证码接口 /code,使用函数式路由(RouterFunction)和 Hutool 的 Circle…...

蓝桥杯 冶炼金属

原题目链接 🔧 冶炼金属转换率推测题解 📜 原题描述 小蓝有一个神奇的炉子用于将普通金属 O O O 冶炼成为一种特殊金属 X X X。这个炉子有一个属性叫转换率 V V V,是一个正整数,表示每 V V V 个普通金属 O O O 可以冶炼出 …...

springboot整合VUE之在线教育管理系统简介

可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生,小白用户,想学习知识的 有点基础,想要通过项…...

Razor编程中@Html的方法使用大全

文章目录 1. 基础HTML辅助方法1.1 Html.ActionLink()1.2 Html.RouteLink()1.3 Html.Display() / Html.DisplayFor()1.4 Html.Editor() / Html.EditorFor()1.5 Html.Label() / Html.LabelFor()1.6 Html.TextBox() / Html.TextBoxFor() 2. 表单相关辅助方法2.1 Html.BeginForm() …...

uniapp 字符包含的相关方法

在uniapp中,如果你想检查一个字符串是否包含另一个子字符串,你可以使用JavaScript中的includes()方法或者indexOf()方法。这两种方法都可以达到目的,但它们在处理方式和返回值上有所不同。 使用includes()方法 includes()方法用于判断一个字…...

什么是VR全景技术

VR全景技术,全称为虚拟现实全景技术,是通过计算机图像模拟生成三维空间中的虚拟世界,使用户能够在该虚拟世界中进行全方位、无死角的观察和交互的技术。VR全景技术模拟人在真实空间中的视觉体验,结合图文、3D、音视频等多媒体元素…...