基于transformer的解码decode目标检测框架(修改DETR源码)
提示:transformer结构的目标检测解码器,包含loss计算,附有源码
文章目录
- 前言
- 一、main函数代码解读
- 1、整体结构认识
- 2、main函数代码解读
- 3、源码链接
- 二、decode模块代码解读
- 1、decoded的TransformerDec模块代码解读
- 2、decoded的TransformerDecoder模块代码解读
- 3、decoded的DecoderLayer模块代码解读
- 三、decode模块训练demo代码解读
- 1、解码数据输入格式
- 2、解码训练demo代码解读
- 四、decode模块预测demo代码解读
- 1、预测数据输入格式
- 2、解码预测demo代码解读
- 五、losses模块代码解读
- 1、matcher初始化
- 2、二分匹配matcher代码解读
- 3、num_classes参数解读
- 4、losses的demo代码解读
前言
最近重温DETR模型,越发感觉detr模型结构精妙之处,不同于anchor base 与anchor free设计,直接利用100框给出预测结果,使用可学习learn query深度查找,使用二分匹配方式训练模型。为此,我基于detr源码提取解码decode、loss计算等系列模块,并重构、修改、整合一套解码与loss实现的框架,该框架可适用任何backbone特征提取接我框架,实现完整训练与预测,我也有相应demo指导使用我的框架。那么,接下来,我将完整介绍该框架源码。同时,我将此源码进行开源,并上传github中,供读者参考。
一、main函数代码解读
1、整体结构认识
在介绍main函数代码前,我先说下整体框架结构,该框架包含2个文件夹,一个losses文件夹,用于处理loss计算,一个是obj_det文件,用于transformer解码模块,该模块源码修改于detr模型,也包含main.py,该文件是整体解码与loss计算demo示意代码,如下图。

2、main函数代码解读
该代码实际是我随机创造了标签target数据与backbone特征提取数据及位置编码数据,使其能正常运行的demo,其代码如下:
import torch
from obj_det.transformer_obj import TransformerDec
from losses.matcher import HungarianMatcher
from losses.loss import SetCriterionif __name__ == '__main__':Model = TransformerDec(d_model=256, output_intermediate_dec=True, num_classes=4)num_classes = 4 # 类别+1matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2) # 二分匹配不同任务分配的权重losses = ['labels', 'boxes', 'cardinality'] # 计算loss的任务weight_dict = {'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2} # 为dert最后一个设置权重criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=0.1, losses=losses)# 下面使用iter,我构造了虚拟模型编码数据与数据加载标签数据src = torch.rand((391, 2, 256))pos_embed = torch.ones((391, 1, 256))# 创造真实target数据target1 = {'boxes':torch.rand((5,4)),'labels':torch.tensor([1,3,2,1,2])}target2 = {'boxes': torch.rand((3, 4)), 'labels': torch.tensor([1, 1, 2])}target = [target1, target2]res = Model(src, pos_embed)losses = criterion(res, target)print(losses)
如下图:

3、源码链接
源码链接:点击这里
二、decode模块代码解读
该模块主要是使用transform方式对backbone提取特征的解码,主要使用learn query等相关trike与transform解码方式内容。
我主要介绍TransformerDec、TransformerDecoder、DecoderLayer模块,为依次被包含关系,或说成后者是前者组成部分。
1、decoded的TransformerDec模块代码解读
该类大意是包含了learn query嵌入、解码transform模块调用、head头预测logit与boxes等内容,是实现解码与预测内容,该模块参数或解释已有注释,读者可自行查看,其代码如下:
class TransformerDec(nn.Module):'''d_model=512, 使用多少维度表示,实际为编码输出表达维度nhead=8, 有多少个头num_queries=100, 目标查询数量,可学习querynum_decoder_layers=6, 解码循环层数dim_feedforward=2048, 类似FFN的2个nn.Linear变化dropout=0.1,activation="relu",normalize_before=False,解码结构使用2种方式,默认False使用post解码结构output_intermediate_dec=False, 若为True保存中间层解码结果(即:每个解码层结果保存),若False只保存最后一次结果,训练为True,推理为Falsenum_classes: num_classes数量与数据格式有关,若类别id=1表示第一类,则num_classes=实际类别数+1,若id=0表示第一个,则num_classes=实际类别数额外说明,coco类别id是1开始的,假如有三个类,名称为[dog,cat,pig],batch=2,那么参数num_classes=4,表示3个类+1个背景,模型输出src_logits=[2,100,5]会多出一个预测,target_classes设置为[2,100],其值为4(该值就是背景,而有类别值为1、2、3),那么target_classes中没有值为0,我理解模型不对0类做任何操作,是个无效值,模型只对1、2、3、4进行loss计算,然4为背景会比较多,作者使用权重0.1避免其背景过度影响。forward return: 返回字典,包含{'pred_logits':[], # 为列表,格式为[b,100,num_classes+2]'pred_boxes':[], # 为列表,格式为[b,100,4]'aux_outputs'[{},...] # 为列表,元素为字典,每个字典为{'pred_logits':[],'pred_boxes':[]},格式与上相同}'''def __init__(self, d_model=512, nhead=8, num_queries=100, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False, output_intermediate_dec=False, num_classes=1):super().__init__()self.num_queries = num_queriesself.query_embed = nn.Embedding(num_queries, d_model) # 与编码输出表达维度一致self.output_intermediate_dec = output_intermediate_decdecoder_layer = DecoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)decoder_norm = nn.LayerNorm(d_model)self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers相关文章:
基于transformer的解码decode目标检测框架(修改DETR源码)
提示:transformer结构的目标检测解码器,包含loss计算,附有源码 文章目录 前言一、main函数代码解读1、整体结构认识2、main函数代码解读3、源码链接二、decode模块代码解读1、decoded的TransformerDec模块代码解读2、decoded的TransformerDecoder模块代码解读3、decoded的De…...
Java SE 学习笔记(十七)—— 单元测试、反射
目录 1 单元测试1.1 单元测试概述1.2 单元测试快速入门1.3 JUnit 常用注解 2 反射2.1 反射概述2.2 获取类对象2.3 获取构造器对象2.4 获取成员变量对象2.5 获取常用方法对象2.6 反射的作用2.6.1 绕过编译阶段为集合添加数据2.6.2 通用框架的底层原理 1 单元测试 1.1 单元测试概…...
HNU-计算机网络-实验1-应用协议与数据包分析实验(Wireshark)
计算机网络 课程基础实验一 应用协议与数据包分析实验(Wireshark) 计科210X 甘晴void 202108010XXX 一、实验目的: 通过本实验,熟练掌握Wireshark的操作和使用,学习对HTTP协议进行分析。 二、实验内容 2.1 HTTP 协议简介 HTTP 是超文本…...
【深度学习】快速制作图像标签数据集以及训练
快速制作图像标签数据集以及训练 制作DataSet 先从网络收集十张图片 每种十张 定义dataSet和dataloader import glob import torch from torch.utils import data from PIL import Image import numpy as np from torchvision import transforms import matplotlib.pyplot…...
Spring Boot Web MVC
文章目录 一、Spring Boot Web MVC 概念二、状态码三、其他注解四、响应操作 一、Spring Boot Web MVC 概念 Spring Web MVC 是⼀个 Web 框架,一开始就包含在Spring 框架里。 1. MVC 定义 软件⼯程中的⼀种软件架构设计模式,它把软件系统分为模型、视…...
设置防火墙
1.RHEL7中的防火墙类型 防火墙只能同时使用一张,firewall底层调用的还是lptables的服务: firewalld:默认 ,基于不同的区域做规则 iptables: RHEL6使用,基于链表 Ip6tables Ebtables 2.防火墙的配置方式 查看防火墙状态: rootlinuxidc -]#systemct…...
3.Docker的客户端指令学习与实战
1.Docker的命令 1.1 启动Docker(systemctl start docker) systemctl start docker1.2 查看docker的版本信息(docker version) docker version1.3 显示docker系统范围的信息(docker info) docker info1.4…...
【微服务开篇-RestTemplate服务调用、Eureka注册中心、Nacos注册中心】
本篇用到的资料:https://gitee.com/Allengan/cloud-demo.githttps://gitee.com/Allengan/cloud-demo.git 目录 1.认识微服务 1.1.单体架构 1.2.分布式架构 1.3.微服务 1.4.SpringCloud 1.5.总结 2.服务拆分和远程调用 2.1.服务拆分原则 2.2.服务拆分示例 …...
python if和while的区别有哪些
python if和while的区别有哪些?下面给大家具体介绍: 1、用法 while和if本身就用法不同,一个是循环语句,一个是判断语句。 2、运行模式 if 只做判断,判断一次之后,便不会再回来了。 while 的话…...
Unity计时器
using UnityEngine; using System.Collections;public class Timer : MonoBehaviour {public float duration 1.0f; // 定时器持续时间public bool isLooping false; // 是否循环public bool isPaused false; // 是否暂停计时器private float currentDuration 0.0f; // 当前…...
Unity热更新介绍
打包函数 BuildPipeline.BuildAssetBundles("AssetBundles", BuildAssetBundleOptions.ChunkBasedCompression, BuildTarget.Android);打包策略和方案 按文件夹打包:Bundle数量少,首次下载块,但是后期更新补丁大按文件打包&#…...
在虚拟机centos7中部署docker+jenkins最新稳定版
在虚拟机centos7中部署dockerjenkins最新稳定版 查看端口是否被占用 lsof -i:80 查看运行中容器 docker ps 查看所有容器 docker ps -a 删除容器 docker rm 镜像/容器名称 强制删除 docker rmi -f 镜像名 查看当前目录 pwd 查看当前目录下所有文件名称 ls 赋予权限 chown 777 …...
nodejs express vue 点餐外卖系统源码
开发环境及工具: nodejs,vscode(webstorm),大于mysql5.5 技术说明: nodejs express vue elementui 功能介绍: 用户端: 登录注册 首页显示搜索菜品,轮播图…...
微信小程序导入js使用时候报错
我是引入weapp库时候,导入js会报错。 需要在小程序开发工具里面配置 就可以了。...
相机存储卡被格式化了怎么恢复?数据恢复办法分享!
随着时代的发展,相机被越来越多的用户所使用,这也意味着更多的用户面临着相机数据丢失的问题,很多用户在使用相机的过程中,都出现过不小心格式化相机存储卡的情况,里面的数据也将一并消失,相机存储卡被格式…...
Firefox修改缓存目录的方法
打开Firefox,在地址栏输入“about:config” 查找是否有 browser.cache.disk.parent_directory,如果没有就新建一个同名的字符串,然后修改值为你要存放Firefox浏览器缓存的目录地址(E:\FirefoxCacheFiles) 然后重新…...
maven子模块无法导入jar包问题
明明本地仓库有jar包 maven子模块无法导入jar包,然后放到父项目的pom.xml则可以导入 可以试试更新仓库后,引入成功...
ardupilot开发 --- 代码解析 篇
0. 前言 根据SITL的断点调试和自己阅读代码的一些理解,写一点自己的注释,有什么不恰当的地方请各位读者不吝赐教。 1. GCS::update_send 线程 主动向MavLink system发送消息包。 1.1 不断向地面站发送飞机状态数据 msg_attitude: msg_location: n…...
C++引用概述
变量名实质上是一段连续存储空间的别名,是一个标号(门牌号),程序中通过变量来申请并命 名内存空间,通过变量的名字可以使用存储空间。引用是 C中新增加的概念,引用可以看作 一个已定义变量的别名。 引用的语法: Type&…...
精准努力,提升自己的核心竞争力——中国人民大学与加拿大女王大学金融硕士
步入职场,相信大家都想成为职场的宠儿。经过一番摸爬滚打后,在职场稳固了地位。但想叱咤职场,还需要精准努力,提升自己的核心竞争力。中国人民大学与加拿大女王大学金融硕士项目为你补给能量。 任何资产都有贬值的风险࿰…...
未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?
编辑:陈萍萍的公主一点人工一点智能 未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战,在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...
解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八
现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet,点击确认后如下提示 最终上报fail 解决方法 内核升级导致,需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...
12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...
让AI看见世界:MCP协议与服务器的工作原理
让AI看见世界:MCP协议与服务器的工作原理 MCP(Model Context Protocol)是一种创新的通信协议,旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天,MCP正成为连接AI与现实世界的重要桥梁。…...
QT: `long long` 类型转换为 `QString` 2025.6.5
在 Qt 中,将 long long 类型转换为 QString 可以通过以下两种常用方法实现: 方法 1:使用 QString::number() 直接调用 QString 的静态方法 number(),将数值转换为字符串: long long value 1234567890123456789LL; …...
HarmonyOS运动开发:如何用mpchart绘制运动配速图表
##鸿蒙核心技术##运动开发##Sensor Service Kit(传感器服务)# 前言 在运动类应用中,运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据,如配速、距离、卡路里消耗等,用户可以更清晰…...
Python网页自动化Selenium中文文档
1. 安装 1.1. 安装 Selenium Python bindings 提供了一个简单的API,让你使用Selenium WebDriver来编写功能/校验测试。 通过Selenium Python的API,你可以非常直观的使用Selenium WebDriver的所有功能。 Selenium Python bindings 使用非常简洁方便的A…...
高考志愿填报管理系统---开发介绍
高考志愿填报管理系统是一款专为教育机构、学校和教师设计的学生信息管理和志愿填报辅助平台。系统基于Django框架开发,采用现代化的Web技术,为教育工作者提供高效、安全、便捷的学生管理解决方案。 ## 📋 系统概述 ### 🎯 系统定…...
【Linux】Linux安装并配置RabbitMQ
目录 1. 安装 Erlang 2. 安装 RabbitMQ 2.1.添加 RabbitMQ 仓库 2.2.安装 RabbitMQ 3.配置 3.1.启动和管理服务 4. 访问管理界面 5.安装问题 6.修改密码 7.修改端口 7.1.找到文件 7.2.修改文件 1. 安装 Erlang 由于 RabbitMQ 是用 Erlang 编写的,需要先安…...
【若依】框架项目部署笔记
参考【SpringBoot】【Vue】项目部署_no main manifest attribute, in springboot-0.0.1-sn-CSDN博客 多一个redis安装 准备工作: 压缩包下载:http://download.redis.io/releases 1. 上传压缩包,并进入压缩包所在目录,解压到目标…...
