人工智能图像分割之Mask2former源码解读
环境搭建:
(1)首先本代码是下载的mmdetection-2022.9的,所以它的版本要配置好,本源码配置例如mmcv1.7,python3.7,pytorch1.13,cuda11.7。pytorch与python,cuda版本匹配可参考:https://www.jb51.net/python/3308342lx.htm。
(2)还有一个是先要安装一个vs2022版本或vs2019,其中确保工作负载下"使用C++的桌面开发"的内容基本安装上
(3)数据集就用coco集,在参数配置中指定目录,例如../configs/mask2former目录下的.py文件
一.Backbone获取多层级特征
(1)进入train.py读取各个配置参数,构建模型调用Mask2Former类,如下图:




有了model(Mask2Former),datasets(COCO数据集)和cfg后就传入下面的方法:

(2)进入Mask2Former模块,然后调用MaskFormer模块,如下图:


(3)进入MaskFormer类中,并进入def forward_train这个方法中,如下图:


然后
进入父类BaseDetector的forward_train方法:

回到maskformer.py中,如下图:







(4)进入head层,因为mask2former_head的父类是maskformer_head,所以会先调父类中的forward_train方法后再进入到了mask2former_head中的forward方法中,mask2former_head.py这个类很重要,如下图:





二.多层级采样点初始化构建



对输入的每一张特征图加上位置编码,这个FOR中的特征图是256


上图中得到的level_embed就是按每个层级取出的值,它是一个256维的向量,下面用view转成四维向量后与位置编码(pos_embed)做加法。

然后调用point_generator中的single_level_grid_priors方法,如下图:


shift_x与shift_y都是32的一维矩阵,

做成网络后就是32*32=1024的一维矩阵,然后做合并得到棋盘上的位置

返回到msdeformattn_pixel_decoder.py中,如下图:


它是维度转变后1024在最前面的


三.多层级输入特征序列创建方法




上图中的所说的是进入transformer.py中的DetrTransformerEncoder类下的forward方法。
四.偏移量与权重计算并转换
还是在transformer.py中,对每一层(例如selfattention,bn,norm,ffn,全连接)进行for,如果是self_attn

执行selfattention层时,进入到了multi_scale_deform_attn.py中,

进入到multi_scale_deform_attn.py中的forward方法:


下面做了多头注意力机制,变成8头:

重点来了,对query做sampling_offsets与attention_weights方法后并调用view变形得到相应的每个点的偏移量与权重(softmax),偏移量到时候做采样时用到,最后还要对这二个值做乘法,如下图:

其中sampling_offsets这个是全连接层,如下图:


由这4个采样点找它们的偏移量,每个偏移量都有x,y二个值组成。

为什么这里权重值是96呢?因为每个偏移量是一个点,它只是这个点有x,y组成,但权重是指这个点的权重,所以就是上面偏移量输出的192/2=96了(也等以8*(3*4)=8*12)。


上图中的levels是指层级数,points是指采样点数。


五.Encoder特征构建方法实例
偏移量有了,现在我们执行特征计算操作了,而原来特征不准,现在要把新的偏移完的准确的特征拿到手,所以这里也要做特征的偏移。还是在multi_scale_deform_attn.py中,

上图中对齐特征是要重新进行采样的,这里的特征偏移范围是[-1,1]。





做完这一步后返回到transformer.py这里,发现就是做了一个self_attn层的操作,得到query值,如下图所示:



做完ffn全连接后再做norm,就算做完一个层级采样了,其它层级采样也是一样按这样流程的,最后把每一层级执行完后就得到特征值,返回到msdeformattn_pixel_decoder.py中,存到了“(3)多层级输入特征序列创建方法”所说的memory变量(它存的是编码完后特征)中,如下图:

总结一下这里encoder是做了什么:其实就是和可变形detr是一样的,就是对展开的序列提特征,我们是希望它是多层级,多头注意力机制和加上可变形的位置偏移,这样可以得到我们序列更好的特征。

还是在msdeformattn_pixel_decoder.py中,总共3个层级,每个层级的大小(图像上的点的个数)分别是1024(它是由32*32得到),4096(它是由64*64得到)与16384(它是由128*128得到)如下图:






等下就用这个y用来预测一下是每个点是前景还是背景?



六.query要预测的任务解读
返回到mask2former_head.py类中,至此self.pixel_decoder方法调用结束(主要是transformer编码这一块),准备decoder解码了。





上图的query_feat中初始化的100是指decoder中会找100样东西,例下图:




下面开始调用forward_head方法:






这个会找到这100个的前景与背景分别是那些?
七.Decoder中的AttentionMask方法




上图中调用sigmoid()后的值是在0至1之间范围之间。




返回上面3张图的预测值




现在又跳回到transformer.py的baseTransformer当中,



八.损失模块输入参数分析
返回到mask2former_head.py中,如下图:



有了cls_pred_list,mask_pred_list这二个结果(10层)后,接下来就去计算每一层的损失函数结果了。返回到maskformer_head.py文件中,如下图:

上图中的81=80+1,其中80是实际类别,1表示背景。


gt_masks是指标注的信息。

九.标签分配策略解读
进入到mask2former_head.py中的loss_single方法中,它是对每一层进行处理(共10层),



get_targets方法目的就是找正负样本,这个方法是在maskformer_head.py中,对于每一张图像它的正负样本是不一样的。

调用上图的multi_apply会进入到mask_hungarian_assigner.py中的assign方法,这里主要进行标签分配,如下图:


上图100个-1的值当中,如果与标签匹配上就修改里面-1值。

上图中标签分配考虑的三方面其实就对应分类,mask,iou这三方面的损失。
十.正样本筛选损失计算
还是在mask_hungarian_assigner.py中,

上图中的gt_labels是10个标签。调用cls_cost方法后会进入到match_cost.py中,如下图:

上图中调用了softmax方法后即把cls_pred变成概率值了(0------1之间的值),
同时cls_cost的第2个维度是10了,即返回值变成了(100,10),cls_cost的值是负数来的,因为前面加了负号,如下图:


mask_hungarian_assigner.py中计算完类别损失后,现在计算mask损失,如下图:

上图中的12544就当作服从256*256正态分布的随机采样吧,调用上图的mask_cost方法后进入下图的类中:


上图中求pos与neg损失时,都调用了binary_cross_entropy_with_logits,首先我们知道二元交叉熵(Binary cross entropy)是二分类中常用的损失函数,它可以衡量两个概率分布的距离,二元交叉熵越小,分布越相似。相比F.binary_cross_entropy函数,F.binary_cross_entropy_with_logits函数在内部使用了sigmoid函数,也就是F.binary_cross_entropy_with_logits = sigmoid + F.binary_cross_entropy。


上图中因为它是有12544个采样点,它是累加后求平均,所以要除以12544,最后得到一个(100,10)的矩阵返回回去。为什么它是返回(100,10)?是因为它是100个query都分别与10个类别去计算得到的。
十一.标签分类匹配结果分析

这个mask算出来的cost值是正数来的。下面开始计算dice_cost(类似iou重合比例计算)

跳转到match_cost.py中,先拉长,再按dice系数公式得到被除数

Dice 系数可以计算两个字符串的相似度:Dice(s1,s2) = 2*comm(s1,s2)/(leng(s1)+leng(s2))。其中,comm(s1,s2) 是s1和s2中相同字符的个数; leng(s1)、leng(s2)是字符串s1、s2 的长度。
返回后得到如下图:

计算出三个损失值后就进行累加,如下图:


而上图中matched_col_inds的10个索引值是对应gt标签中的索引位置,matched_row_inds是100个中匹配到损失最小的对应的10个索引。这时我们就认为通过与标签匹配完成了,得到正样本了。

返回后回到了mask2former_head.py中,调用assign方法进行标签匹配就结束了,如下图:

然后调用sample主要是用来取正,负样本的索引值,如下图:






上图中mask_targets是只考虑匹配上的10个。到此时,第一张图片就处理完了,然后按上面的逻辑做第二张图片的处理了。
十二.最终损失计算流程
这时(9)中所说的调用get_targets方法就结束了。前面算的损失是关于标签分配的。而下面将算实际的损失了。


上图中labels这个标签是更新了正样本索引后的标签值,80是默认的初始化的负样本索引。上图中这三个都是200,那就可以做交叉熵损失了,调下图的方法self.loss_cls方法。




上图中二个极端值是指概率值(上图画的y轴值)要么靠近0,要么靠近1,这样就表示确定性越大(高),而不确定性越强的是概率值靠近0.5那种。
十三.汇总所有损失完成迭代
get_uncertainty方法返回就得到不确定性的点索引,如下图


在分割中也许是边界点比较模拟二可(即可能是背景也可能是前景),它的不确定性就比较高。


这里三个loss_cls,loss_dice,loss_mask做完返回值,重复迭代10次求这三个值。
相关文章:
人工智能图像分割之Mask2former源码解读
环境搭建: (1)首先本代码是下载的mmdetection-2022.9的,所以它的版本要配置好,本源码配置例如mmcv1.7,python3.7,pytorch1.13,cuda11.7。pytorch与python,cuda版本匹配可参考:https://www.jb51.net/python/3308342lx.htm。 (2)还有一个是先要安装一个vs2022版本或…...
uniapp 编译生成鸿蒙正式app步骤
1,在最新版本DevEco-Studio工具新建一个空项目并生成p12和csr文件(构建-生成私钥和证书请求文件) 2,华为开发者平台 根据上面生成的csr文件新增cer和p7b文件,分发布和测试 3,在最新版本DevEco-Studio工具 文…...
2024最新版Java面试题及答案,【来自于各大厂】
发现网上很多Java面试题都没有答案,所以花了很长时间搜集整理出来了这套Java面试题大全~ 篇幅限制就只能给大家展示小册部分内容了,需要完整版的及Java面试宝典小伙伴点赞转发,关注我后在【翻到最下方,文尾点击名片】即可免费获取…...
Excel 融合 deepseek
效果展示 代码实现 Function QhBaiDuYunAIReq(question, _Optional Authorization "Bearer ", _Optional Qhurl "https://qianfan.baidubce.com/v2/chat/completions")Dim XMLHTTP As ObjectDim url As Stringurl Qhurl 这里替换为你实际的URLDim postD…...
【填坑】新能源汽车三电设计之常用半导体器件系统性介绍
# 在新能源汽车的三电(电池、电机、电控)系统中,半导体器件扮演着至关重要的角色。它们如同系统的“大脑”和“神经末梢”,精确地控制着电能的流向与转换,确保新能源汽车高效、稳定且安全地运行。今天,就让…...
SpringCloud面试题----Nacos和Eureka的区别
功能特性 服务发现 Nacos:支持基于 DNS 和 RPC 的服务发现,提供了更为灵活的服务发现机制,能满足不同场景下的服务发现需求。Eureka:主要基于 HTTP 的 RESTful 接口进行服务发现,客户端通过向 Eureka Server 发送 HT…...
21.2.6 字体和边框
版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 通过设置Rang.Font对象的几个成员就可以修改字体,设置Range.Borders就可以修改边框样式。 【例 21.6】【项目ÿ…...
正则表达式进阶(二)——零宽断言详解:\b \B \K \z \A
在正则表达式中,零宽断言是一种非常强大的工具,能够在不消费字符的情况下对匹配位置进行约束。除了环视(lookahead 和 lookbehind)以外,还有一些常用的零宽断言,它们用于处理边界、字符串的开头和结尾等特殊…...
OpenFeign远程调用返回的是List<T>类型的数据
在使用 OpenFeign 进行远程调用时,如果接口返回的是 List 类型的数据,可以通过以下方式处理: 直接定义返回类型为List Feign 默认支持 JSON 序列化/反序列化,如果服务端返回的是 List的JSON格式数据,可以直接在 Feig…...
三维模拟-机械臂自翻车
机械仿真 前言效果图后续 前言 最近在研究Unity机械仿真,用Unity实现其运动学仿真展示的功能,发现一个好用的插件“MGS-Machinery-master”,完美的解决了Unity关节定义缺少液压缸伸缩关节功能,内置了多个场景,讲真的&…...
网络安全治理架构图 网络安全管理架构
网站安全攻防战 XSS攻击 防御手段: - 消毒。 因为恶意脚本中有一些特殊字符,可以通过转义的方式来进行防范 - HttpOnly 对cookie添加httpOnly属性则脚本不能修改cookie。就能防止恶意脚本篡改cookie 注入攻击 SQL注入攻击需要攻击者对数据库结构有所…...
调用deepseek的API接口使用,对话,json化,产品化
背景 最近没咋用chatgpt了,deepseek-r1推理模型写代码质量是很高。deepseek其输出内容的质量和效果在国产的模型里面来说确实算是最强的,并且成本低,它的API接口生态也做的非常好,和OpenAI完美兼容。所以我们这一期来学一下怎么调…...
omegaconf库使用实践
最近在重构RapidOCR仓库代码,使其更加优雅的同时,具有扩展性。无意从他人源码中发现omegaconf库。 omegaconf OmegaConf是一个用于处理配置文件和命令行参数的Python库。它支持YAML、JSON、INI等多种配置文件格式,提供了配置合并、类型安全…...
STM32 USART1 串口调试打印,映射printf函数
该代码可以在freertos中正常运行,你可以进行更多细节优化 PA9(TX) PA10(RX) #include "usart.h"// 解决串口死机问题 #pragma import(__use_no_semihosting) struct __FILE { int handle; }; // 标准库需要的支持函数 FILE __…...
DeepSeek大模型本地部署实战
1. 下载并安装Ollama 打开浏览器:使用你常用的浏览器(如Chrome、Firefox等)访问Ollama的官方网站。无需特殊网络环境,直接搜索“Ollama”即可找到。 登录与下载:进入Ollama官网后,点击右上角的“Download…...
【学习总结|DAY037】Linux 项目部署
引言 在当今的软件开发领域,Linux 以其安全、稳定、免费且开源的特性,成为项目部署的首选操作系统。无论是 Java 项目,还是各类开发、测试、生产环境中的软件安装,Linux 都占据着重要地位。本文将结合我今天所学内容,…...
Spring Boot Actuator使用
说明:本文介绍Spring Boot Actuator的使用,关于Spring Boot Actuator介绍,下面这篇博客写得很好,珠玉在前,我就不多介绍了。 Spring Boot Actuator 简单使用 项目里引入下面这个依赖 <!--Spring Boot Actuator依…...
[css] 黑白主题切换
link动态引入 类名切换 css滤镜 var 类名切换 v-bind css预处理器mixin类名切换 【前端知识分享】CSS主题切换方案...
阿里云专有云网络架构学习
阿里云专有云网络架构 叶脊(spine-leaf)网络和传统三层网络拓扑对比 阿里云网络架构V3拓扑角色介绍推荐设备设备组网举例带外管理网络带外网和带内网对比设备介绍 安全网络设备介绍 参考 后续更新流量分析叶脊(spine-leaf)网络和传…...
【AIGC】冷启动数据与多阶段训练在 DeepSeek 中的作用
博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: AIGC | ChatGPT 文章目录 💯前言💯冷启动数据的作用冷启动数据设计 💯多阶段训练的作用阶段 1:冷启动微调阶段 2:推理导向强化学习(RL࿰…...
在SIP路由中,常见的对接方式
好的,我已将应用场景和案例分为两列。修改后的表格如下: 对接方式描述应用场景案例注册 (REGISTER)用于用户注册,将用户位置(如IP地址)与其用户名进行绑定。用户通过发送REGISTER请求将自己注册到SIP服务器。注册过程…...
GenAI + 电商:从单张图片生成可动态模拟的3D服装
在当今数字化时代,电子商务和虚拟现实技术的结合正在改变人们的购物体验。特别是在服装行业,消费者越来越期待能够通过虚拟试衣来预览衣服的效果,而无需实际穿戴。Dress-1-to-3 技术框架正是为此而生,它利用生成式AI模型(GenAI)和物理模拟技术,将一张普通的穿衣照片转化…...
harmonyOS生命周期详述
harmonyOS的生命周期分为app(应用)的生命周期和页面的生命周期函数两部分 应用的生命周期-app应用 在app.js中写逻辑,具体有哪些生命周期函数呢,请看下图: onCreated()、onShow()、onHide()、onDestroy()这五部分 页面及组件生命周期 着重说下onShow和onHide,分别代表是不是…...
记一次调整磁盘分区大小的经验
背景 redhat 6 系统 根目录挂载的逻辑卷满了,系统都不能正常运行了 但是/home目录挂载的另外一个逻辑卷却占用只有4% 所以想把/home挂的逻辑卷分一部分给/ 挂的逻辑卷 备份 先把系统整盘备份一下,用clonezilla做一个磁盘镜像,免得失误了搞…...
css:怎么设置图片不变形
问: main元素中有一个img元素,这个img src‘/assets/images/tupian.png’css设置了img元素width:50% height:50%但是图片变形了,我应该怎么设置保持图片样式不变形 回答: 为了确保图片在调整大小时不变形࿰…...
软件测试就业
文章目录 2.6 初识一、软件测试理论二、软件的生产过程三、软件测试概述四、软件测试目的五、软件开发与软件测试的区别?六、学习内容 2.7 理解一、软件测试的定义二、软件测试的生命周期三、软件测试的原则四、软件测试分类五、软件的开发与测试模型1.软件开发模型…...
【Pandas】pandas Series sum
Pandas2.2 Series Computations descriptive stats 方法描述Series.abs()用于计算 Series 中每个元素的绝对值Series.all()用于检查 Series 中的所有元素是否都为 True 或非零值(对于数值型数据)Series.any()用于检查 Series 中是否至少有一个元素为 T…...
后缀表达式(蓝桥杯19I)
有减于号时 假设有n个大于0从大到小的数,加减符号数为n-1:a,b,c,d,。。。。。,e sum求最大:(max )-(min ) a - (e - ( ) -())( ( )( ) ( ) 。。。。 ) 当序列中有负数时: a -&am…...
问题大集04-浏览器阻止从 本地 发起的跨域请求,因为服务器的响应头 Access-Control-Allow-Origin 设置为通配符 *
1、问题 localhost/:1 Access to XMLHttpRequest at xxx(请求) from origin http://localhost:xxx(本地) has been blocked by CORS policy: The value of the Access-Control-Allow-Origin header in the response must not be t…...
mac环境下,ollama+deepseek+cherry studio+chatbox本地部署
春节期间,deepseek迅速火爆全网,然后回来上班,我就浅浅的学习一下,然后这里总结一下,我学习中,总结的一些知识点吧,分享给大家。具体的深度安装部署,这里不做赘述,因为网…...
