从零开始实现 MobileViT 注意力机制——轻量级Transformer Vision Model 的新思路
从零开始实现 MobileViT 注意力机制——轻量级Transformer Vision Model 的新思路
近年来,计算机视觉领域中 Transformer 模型的崛起为图像处理带来了新的活力。特别是在 ViT(Vision Transformer)模型提出之后,Transformer 在图像分类、目标检测等任务上展示了超越 CNN 的潜力。然而,标准的 ViT 模型参数量大,计算复杂度高,难以在移动设备等资源受限的环境中部署。
最近,《MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer》 这篇论文提出了一种轻量化、通用且适合移动端的视觉变换器模型。该模型通过结合局部和全局特征的创新设计,在保持良好性能的同时,大大降低了计算资源的需求,为移动应用提供了新的解决方案。
本文将从零开始解读并实现 MobileViT 的核心注意力机制模块,帮助开发者理解这一轻量级视觉变换器的工作原理,从而在实际项目中灵活运用。
1. 背景:从 ViT 到 MobileViT
1.1 Vision Transformer (ViT) 简介
标准的 ViT 模型将整个图像划分为不重叠的 patches(块),并将其转换为序列输入到基于Transformer 的编码器中。这种方法虽然在性能上表现出色,但也带来了以下问题:
- 计算复杂度高:将图像分割成大量 patches 后进行序列操作,参数量和计算量急剧上升。
- 适用性有限:直接使用 Transformer 架构处理图像分辨率较高的场景时,资源消耗(如内存、算力)难以满足移动端的需求。
1.2 MobileViT 的创新思路
MobileViT 提出了一种折中的解决方案——结合 局部表示(Local Representation) 和 全局表示(Global Representation),以降低计算复杂度同时保持性能。其核心思想是:
- 在每个位置保留原始图像的局部特征信息。
- 通过 Transformer 模块提取和增强全局特征信息。
- 将局部和全局特征进行融合,生成最终的高质量视觉表征。
2. MobileViT 注意力机制模块实现解析
MobileViT 的核心模块是 MobileViTAttention。我们需要逐步解读其实现细节,并通过代码示例帮助读者理解其工作原理。
2.1 模块设计概述
- 输入:一个张量(Tensor),形状为
[batch_size, in_channel, height, width] - 输出:经过局部和全局特征融合后的张量,保持与输入相同的尺寸
模块主要包含以下几个部分:
- 局部特征提取:通过卷积操作提取每个位置的局部信息。
- 全局特征提取:使用 Transformer 模块对图像进行分块(patch)处理,并在序列空间中捕获长距离依赖关系。
- 特征融合:将局部和全局特征拼接后,通过轻量级的卷积操作生成最终输出。
以下是完整的 MobileViTAttention 类的实现代码:
import torch
from torch import nnclass MobileViT_Attention(nn.Module):def __init__(self, in_channels=3, kernel_size=3, patch_size=2, embed_dim=144):super().__init__()# 设置 patch 的大小(默认为7x7)self.ph, self.pw = patch_size, patch_size# 局部特征提取:通过卷积操作捕获局部上下文信息self.local_conv = nn.Sequential(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding=kernel_size//2, stride=1),nn.BatchNorm2d(in_channels),nn.ReLU(inplace=True))# 全局特征提取:将张量重排为 [batch_size, patch_height*patch_width, N_h*N_w, embed_dim]# Transformer 模块用于捕获全局上下文信息self.global_trans = Transformer(embed_dim=embed_dim,num_heads=16,num_transformer_layers=4)# 特征融合:将局部特征和全局特征拼接,并通过卷积操作生成最终输出self.fusion_conv = nn.Sequential(nn.Conv2d(in_channels*2, in_channels, kernel_size=kernel_size, padding=kernel_size//2, stride=1),nn.BatchNorm2d(in_channels),nn.ReLU(inplace=True))def forward(self, x):# 提取局部特征local_feats = self.local_conv(x) # 局部特征if len(local_feats.shape) == 4:B, C, H, W = local_feats.shapeelse:raise ValueError("Input tensor should have rank 4.")# 分割图像为 patch,并进行重排:从 [B, C, H, W] 到 [B, (H*W), C]# 每个 patch 的大小为 (patch_size, patch_size)patches = []for i in range(0, H, self.ph):for j in range(0, W, self.pw):patch = local_feats[:, :, i:i+self.ph, j:j+self.pw]patch = torch.flatten(patch, start_dim=2) # 打平patchpatches.append(patch)# 拼接所有的 patch,形成张量 [B, num_patches, C]x_patched = torch.stack(patches, dim=1)# 传递到 Transformer 中提取全局特征global_feats = self.global_trans(x_patched) # 全局上下文特征# 特征融合:将原始输入的局部特征与 Transformer 输出的全局特征拼接x_fused = torch.cat([local_feats, global_feats.unsqueeze(2).unsqueeze(3)], dim=1)return self.fusion_conv(x_fused) # 最终的特征输出class Transformer(nn.Module):def __init__(self, embed_dim=768, num_heads=12, num_transformer_layers=4):super().__init__()self.embedding = nn.Linear(embed_dim, embed_dim)self.layers =(nn.ModuleList([TransformerBlock(d_model=embed_dim, nhead=num_heads)for _ in range(num_transformer_layers)]))def forward(self, x):x = self.embedding(x)for layer in self.layers:x = layer(x)return x
3. 实现细节解读
3.1 局部特征提取
- 卷积操作:使用
nn.Conv2d在局部区域内捕获上下文信息。 - BN 和 ReLU:通过归一化和非线性激活,提升特征表达能力。
self.local_conv = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1),nn.BatchNorm2d(3),nn.ReLU(inplace=True)
)
3.2 全局特征提取(Transformer)
- 分块:将图像分割为
patch_size x patch_size的小块,每个块展开成一维向量。 - 序列建模:通过多层 Transformer Block 捕获长距离依赖。
class TransformerBlock(nn.Module):def __init__(self, d_model=768, nhead=12):super().__init__()self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead)self.dropout = nn.Dropout(0.1)def forward(self, x):out = self.self_attn(x, x, x)[0]return F.dropout(out, p=0.1, training=self.training)
3.3 特征融合
- 拼接:将局部特征和全局特征在通道维度上进行拼接。
- 卷积操作:通过轻量级的卷积操作生成最终输出。
self.fusion_conv = nn.Sequential(nn.Conv2d(3*2, 3, kernel_size=3, padding=1),nn.BatchNorm2d(3),nn.ReLU(inplace=True)
)
4. 模块的输入输出尺寸分析
输入
- 形状:
[batch_size, in_channels, height, width] - 示例:
[ batch_size: 4, in_channels: 3 (RGB), height: 224, width: 224 ]
输出
- 相同的尺寸
[batch_size, in_channels, height, width] - 经过局部和全局特征融合后,输出高质量的视觉表征。
5. 总结与展望
通过结合局部和全局特征提取,MobileViT 成功地在轻量级计算资源的基础上实现了高效的视觉信息处理。这一模块尤其适合应用于移动设备和嵌入式系统中,同时也可以作为其他视觉任务(如目标检测、图像分割)的高效特征提取模块。
未来的工作可以尝试以下方向:
- 优化 Transformer 模块:通过减少头数或简化注意力机制降低计算复杂度。
- 自适应 patch 大小:根据输入尺寸动态调整 patch 的大小,提升模型的灵活性。
- 多尺度融合:在更细粒度的尺度上结合特征信息,进一步提升性能。
希望通过对这一模块的解读和实现,能够帮助读者更好地理解和应用 MobileViT 模型,在实际项目中发挥其优势。
相关文章:
从零开始实现 MobileViT 注意力机制——轻量级Transformer Vision Model 的新思路
从零开始实现 MobileViT 注意力机制——轻量级Transformer Vision Model 的新思路 近年来,计算机视觉领域中 Transformer 模型的崛起为图像处理带来了新的活力。特别是在 ViT(Vision Transformer)模型提出之后,Transformer 在图像…...
揭秘大数据 | 22、软件定义存储
揭秘大数据 | 19、软件定义的世界-CSDN博客 揭秘大数据 | 20、软件定义数据中心-CSDN博客 揭秘大数据 | 21、软件定义计算-CSDN博客 老规矩,先把这个小系列的前三篇奉上。今天书接上文,接着叙软件定义存储的那些事儿。 软件定义存储源于VMware公司于…...
OpenCV 图形API(37)图像滤波-----分离过滤器函数sepFilter()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 应用一个可分离的线性滤波器到一个矩阵(图像)。 该函数对矩阵应用一个可分离的线性滤波器。也就是说,首先&a…...
flutter下载SDK环境配置步骤详解
目录 1.Flutter官网地址、SDK下载地址? 1.1 选择你电脑的系统 2.配置环境 3.解决环境报错 zsh:command not found:flutter 1.Flutter官网地址、SDK下载地址? flutter官网地址: URL 1.1 选择你电脑的系统 下载解压动目录就OK了 2.配置环境 1、打开命令行…...
数据结构与算法入门 Day 0:程序世界的基石与密码
🌟数据结构与算法入门 Day 0:程序世界的基石与密码🔑 ps:接受到了不少的私信反馈,说应该先把前置的知识内容做一个梳理,所以把昨天的文章删除了,重新开启今天的博文写作 Hey 小伙伴们ÿ…...
vscode终端运行windows服务器的conda出错
远程windows服务器可以运行,本地vscode不能。 打开vscode settings.json文件 添加conda所在路径...
Elasticsearch 查询排序报错总结
Elasticsearch 查询sort报错总结 文章目录 Elasticsearch 查询`sort`报错总结错误1、使用Es对 `sort` 进行排序字段类型的要求1.1、数值类型(如 `integer`、`long`、`float`、`double`)1.2、日期类型(如 `date`)1.3、字符串类型(如 `keyword`、`text`)1.4、布尔类型(`bo…...
“大湾区珠宝艺境花园”璀璨绽放第五届消博会
2025年4月13日,第五届中国国际消费品博览会(以下简称"消博会")重要主题活动——《大湾区珠宝艺境花园》启动仪式在海南国际会展中心2号馆隆重举行。由广东省金银珠宝玉器业厂商会组织带领粤港澳大湾区优秀珠宝品牌,以“…...
十、自动化函数+实战
Maven环境配置 1.设计测试用例 2.创建空项目 1)添加需要的依赖pom.xml <dependencies> <!-- 截图配置--><dependency><groupId>commons-io</groupId><artifactId>commons-io</artifactId><version>2.6</…...
Day09【基于jieba分词和RNN实现的简单中文分词】
基于jieba分词和RNN实现的中文分词 目标数据准备主程序预测效果 目标 本文基于给定的中文词表,将输入的文本基于jieba分词分割为若干个词,词的末尾对应的标签为1,中间部分对应的标签为0,同时将分词后的单词基于中文词表做初步序列…...
自动化测试——selenium
简介 Selenium 是一个广泛使用的自动化测试工具,主要用于 Web 应用程序的自动化测试。它能实现的功能是网页的自动化操作,例如自动抢票刷课等。同时你应该也见到过有些网站在打开之后并没有直接加载出网站的所有内容,比如一些图片等等&#x…...
java和python实现mqtt
说明: MQTT 异步通信系统功能文档 系统概述 本系统基于 MQTT 协议实现异步通信,包含三个核心组件: Broker(消息代理):负责消息的路由和转发。 Client(主客户端):定时发…...
5.9 《GPT-4调试+测试金字塔:构建高可靠系统的5大实战策略》
5.4 测试与调试:构建企业级质量的保障体系 关键词:测试金字塔模型、GPT-4调试助手、LangChain调试模式、异步任务验证 测试策略设计(测试金字塔实践) #mermaid-svg-RblGbJVMnCIShiCW {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill…...
Linux——进程通信
我们知道,进程具有独立性,各进程之间互不干扰,但我们为什么还要让其联系,建立通信呢?比如:数据传输,资源共享,通知某个事件,或控制某个进程。因此,让进程间建…...
学习笔记十三—— 理解 Rust 闭包:从语法到 impl Fn vs Box<dyn Fn>
🧠 理解 Rust 闭包:从语法到 impl Fn vs Box 📚 目录 闭包是什么?和普通函数有什么不同?闭包的语法长什么样?闭包“捕获变量”是什么意思?闭包和所有权的关系Fn、FnMut、FnOnce 三种闭包类型的…...
【免费参会合集】2025年生物制药行业展会会议表格整理
全文精心整理, 建议今年参会前都好好收藏着,记得点赞! 医药人非常吃资源,资源从何而来?作为一名从事医药行业的工作者,可以很负责任的告诉诸位,其中非常重要的一个渠道就是会议会展! 建议所有医…...
腾讯云开发+MCP:旅游规划攻略
1.登录注册好之后进入腾讯云开发 2.创建环境 4.创建好环境之后点击去开发 5.进入控制台后,选择AI,找到MCP 6.点击创建MCP Server 使用腾讯云开发创建MCP目前需要云开发入门版99/月,我没开通,所以没办法往下进行。...
银河麒麟系统 达梦8 安装 dlask 框架后端环境
适配的一套环境为 dmPython2.5.8 dmSQLAlchemy1.4.39 Flask2.0.3 Flask-Cors3.0.10 Flask-SQLAlchemy2.5.1 SQLAlchemy1.4.54 Werkzeug2.2.2其中 # sqlalchemy-dm1.4.39 通过dmdbms目录内文件进行源码安装 (MindSpore) [ma-user python]$pwd /home/syl/dmdbms/drivers/python…...
Cribl (实验) vpc-flow 数据抽样
先看文档: Firewall Logs: VPC Flow Logs, Cisco ASA, Etc. | Cribl Docs Firewall Logs: VPC Flow Logs, Cisco ASA, Etc. Recipe for Sampling Firewall Logs Firewall logs are another source of important operational (and security) data. Typical examples include Ama…...
Sklearn入门之数据预处理preprocessing
、 Sklearn全称:Scipy-toolkit Learn是 一个基于scipy实现的的开源机器学习库。它提供了大量的算法和工具,用于数据挖掘和数据分析,包括分类、回归、聚类等多种任务。本文我将带你了解并入门Sklearn下的preprocessing在机器学习中的基本用法。 获取方式…...
我想自己组装一台服务器,微调大模型通义千问2.5 Omni 72B,但是我是个人购买,资金非常有限,最省的方案
目录 🧠 首先我们要搞清楚几个核心点: 🎯 目标:微调 Qwen2.5-Omni-72B 🚨 现实问题:作为个人用户,72B 模型几乎无法负担全量微调 💸 全量微调硬件需求: ✅ 最省的个人方案:不组 72B,只训练 Qwen2.5-Omni-7B 或 14B 💡 推荐方案 A:个人桌面级多卡训练服…...
家用打印机性价比排名及推荐
文章目录 品牌性价比一、核心参数对比与场景适配二、技术类型深度解析三、不同场景选择 相关文章 品牌 性价比 一、核心参数对比与场景适配 兄弟T436W 优势: 微压电技术,打印头寿命长,堵头率低。 支持A4无边距和5G WiFi,适合照片…...
KWDB(Knowledge Worker Database)基础概念与原理完整指南
KWDB(Knowledge Worker Database)基础概念与原理完整指南—目录 前言一、背景1.1 知识工作者的痛点1.2 技术演进推动 二、定义与定位2.1 什么是KWDB?2.2 KWDB与传统数据库的对比与传统关系型数据库(如MySQL)的对比与分…...
数字电子技术基础(四十七)——使用Mutlisim软件来模拟74LS85芯片
目录 1 使用74LS85N芯片完成四位二进制数的比较 1.1原理介绍 1.2 器件选择 1.3 运行电路 2 使用74LS85N完成更多位的二进制比较 1 使用74LS85N芯片完成四位二进制数的比较 1.1原理介绍 对于74LS85 是一款 4 位数值比较器集成电路,用于比较两个 4 位二进制数&…...
关于STM32创建工程文件启动文件选择
注意启动文件只要选择这几个 而不是要把所有都选上...
LLC电路工作在容性区的风险
在t0时刻之前,Q6Q7导通,回路如下所示,此时A点电压是低压,B点电压是高压 在t0时刻时,谐振电流相位发生变换,在t1时刻,Q5,Q8导通,对于Q8MOS管来说,B点电压在Q6Q…...
Linux Kernel 6
clone 系统调用(The clone system call) 在 Linux 中,使用 clone() 系统调用来创建新的线程或进程。fork() 系统调用和 pthread_create() 函数都基于 clone() 的实现。 clone() 系统调用允许调用者决定哪些资源应该与父进程共享,…...
【开源项目】Excel手撕AI算法深入理解(四):AlphaFold、Autoencoder
项目源码地址:https://github.com/ImagineAILab/ai-by-hand-excel.git 一、AlphaFold AlphaFold 是 DeepMind 开发的突破性 AI 算法,用于预测蛋白质的三维结构。它的出现解决了生物学领域长达 50 年的“蛋白质折叠问题”,被《科学》杂志评为…...
第IV部分有效应用程序的设计模式
第IV部分有效应用程序的设计模式 第IV部分有效应用程序的设计模式第23章:应用程序用户界面的架构设计23.1设计考量23.2示例1:用于非分布式有界上下文的一个基于HTMLAF的、服务器端的UI23.3示例2:用于分布式有界上下文的一个基于数据API的客户端UI23.4要点第24章:CQRS:一种…...
如何编制实施项目管理章程
本文档概述了一个项目管理系统的实施计划,旨在通过统一的业务规范和技术架构,加强集团公司的业务管控,并规范业务管理。系统建设将遵循集团统一模板,确保各单位项目系统建设的标准化和一致性。 实施范围涵盖投资管理、立项管理、设计管理、进度管理等多个方面,支持项目全生…...
