当前位置: 首页 > article >正文

用PyTorch手把手实现BoTNet:把ResNet50的3x3卷积换成MHSA到底有多简单?

用PyTorch手把手实现BoTNet把ResNet50的3x3卷积换成MHSA到底有多简单如果你正在寻找一种既能保留CNN局部特征提取能力又能引入全局注意力机制的方法BoTNet可能是最优雅的解决方案之一。这个将ResNet中3x3卷积替换为多头自注意力(MHSA)的改动看似简单却效果显著。本文将用可运行的代码展示这一转换过程让你在10分钟内掌握核心实现技巧。1. 环境准备与基础理解在开始编码前我们需要明确几个关键概念。BoTNet的核心思想是在ResNet的Bottleneck块中用MHSA替代传统的3x3卷积。这种设计保留了CNN的层次结构同时在深层网络引入全局注意力机制。准备环境只需标准的PyTorch环境import torch import torch.nn as nn import torch.nn.functional as F为什么选择Bottleneck进行改造ResNet的Bottleneck结构天然适合插入注意力机制先通过1x1卷积降维减少计算量中间层处理核心特征这里是替换点最后1x1卷积恢复维度这种结构恰好与Transformer中的扩展-注意力-压缩流程相似。2. 实现核心MHSA模块让我们先构建最关键的MHSA层。与标准Transformer不同这里的实现需要处理2D特征图class MHSA(nn.Module): def __init__(self, n_dims, width14, height14, heads4): super().__init__() self.heads heads # 使用1x1卷积实现QKV投影 self.query nn.Conv2d(n_dims, n_dims, kernel_size1) self.key nn.Conv2d(n_dims, n_dims, kernel_size1) self.value nn.Conv2d(n_dims, n_dims, kernel_size1) # 相对位置编码参数 self.rel_h nn.Parameter(torch.randn([1, heads, n_dims//heads, 1, height])) self.rel_w nn.Parameter(torch.randn([1, heads, n_dims//heads, width, 1])) self.softmax nn.Softmax(dim-1) def forward(self, x): n_batch, C, width, height x.size() # 投影到QKV空间 q self.query(x).view(n_batch, self.heads, C//self.heads, -1) k self.key(x).view(n_batch, self.heads, C//self.heads, -1) v self.value(x).view(n_batch, self.heads, C//self.heads, -1) # 内容注意力 content_content torch.matmul(q.permute(0,1,3,2), k) # 位置注意力 content_position (self.rel_h self.rel_w).view(1, self.heads, C//self.heads, -1) content_position torch.matmul(content_position, q) # 合并注意力 energy content_content content_position attention self.softmax(energy) # 输出重构 out torch.matmul(v, attention.permute(0,1,3,2)) out out.view(n_batch, C, width, height) return out这段代码有几个关键设计点使用1x1卷积而非线性层实现QKV投影保留空间结构采用分解的相对位置编码将H×W的编码简化为(HW)形式注意力计算同时考虑内容相关性和位置相关性3. 改造Bottleneck模块现在我们可以改造标准的ResNet Bottleneck将中间的3x3卷积替换为MHSAclass Bottleneck(nn.Module): expansion 4 def __init__(self, in_planes, planes, stride1, heads4, mhsaFalse, resolutionNone): super().__init__() # 第一个1x1卷积降维 self.conv1 nn.Conv2d(in_planes, planes, kernel_size1, biasFalse) self.bn1 nn.BatchNorm2d(planes) # 核心修改点3x3卷积或MHSA if not mhsa: self.conv2 nn.Conv2d(planes, planes, kernel_size3, padding1, stridestride, biasFalse) else: self.conv2 nn.ModuleList([ MHSA(planes, widthint(resolution[0]), heightint(resolution[1]), headsheads) ]) if stride 2: # 处理下采样 self.conv2.append(nn.AvgPool2d(2, 2)) self.conv2 nn.Sequential(*self.conv2) self.bn2 nn.BatchNorm2d(planes) # 第三个1x1卷积升维 self.conv3 nn.Conv2d(planes, self.expansion*planes, kernel_size1, biasFalse) self.bn3 nn.BatchNorm2d(self.expansion*planes) # 捷径连接 self.shortcut nn.Sequential() if stride ! 1 or in_planes ! self.expansion*planes: self.shortcut nn.Sequential( nn.Conv2d(in_planes, self.expansion*planes, kernel_size1, stridestride), nn.BatchNorm2d(self.expansion*planes) ) def forward(self, x): out F.relu(self.bn1(self.conv1(x))) out F.relu(self.bn2(self.conv2(out))) out self.bn3(self.conv3(out)) out self.shortcut(x) out F.relu(out) return out改造时需特别注意MHSA不支持下采样需额外添加平均池化层保持原有的残差连接结构不变维持BatchNorm和ReLU的配置位置4. 构建完整BoTNet模型现在我们可以组装完整的BoTNet架构。通常只在最后几个阶段使用MHSA以平衡计算量和性能class BoTNet(nn.Module): def __init__(self, block, num_blocks, num_classes1000, resolution(224,224), heads4): super().__init__() self.in_planes 64 self.resolution list(resolution) # 初始卷积层 self.conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse) self.bn1 nn.BatchNorm2d(64) self.relu nn.ReLU(inplaceTrue) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) # 更新分辨率信息 for op in [self.conv1, self.maxpool]: if op.stride[0] 2: self.resolution[0] / 2 if len(op.stride) 1 and op.stride[1] 2: self.resolution[1] / 2 # 构建四个阶段 self.layer1 self._make_layer(block, 64, num_blocks[0], stride1) self.layer2 self._make_layer(block, 128, num_blocks[1], stride2) self.layer3 self._make_layer(block, 256, num_blocks[2], stride2) self.layer4 self._make_layer(block, 512, num_blocks[3], stride2, headsheads, mhsaTrue) # 分类头 self.avgpool nn.AdaptiveAvgPool2d((1,1)) self.fc nn.Sequential( nn.Dropout(0.3), nn.Linear(512*block.expansion, num_classes) ) def _make_layer(self, block, planes, num_blocks, stride1, heads4, mhsaFalse): strides [stride] [1]*(num_blocks-1) layers [] for stride in strides: layers.append(block(self.in_planes, planes, stride, heads, mhsa, self.resolution)) if stride 2: self.resolution [r//2 for r in self.resolution] self.in_planes planes * block.expansion return nn.Sequential(*layers) def forward(self, x): out self.relu(self.bn1(self.conv1(x))) out self.maxpool(out) out self.layer1(out) out self.layer2(out) out self.layer3(out) out self.layer4(out) out self.avgpool(out) out torch.flatten(out, 1) out self.fc(out) return out def BoTNet50(num_classes1000, resolution(224,224), heads4): return BoTNet(Bottleneck, [3,4,6,3], num_classesnum_classes, resolutionresolution, headsheads)关键设计选择仅在layer4最后阶段使用MHSA模块动态跟踪特征图分辨率变化保持与标准ResNet相同的宏观结构5. 训练技巧与性能对比将ResNet改造为BoTNet后训练策略需要相应调整学习率调整初始学习率可以比标准ResNet稍小约小2-5倍使用带warmup的学习率调度optimizer torch.optim.SGD(model.parameters(), lr0.01, momentum0.9) scheduler torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr0.01, steps_per_epochlen(train_loader), epochs100 )性能对比表格指标ResNet50BoTNet50变化幅度参数量(M)25.524.7↓3.1%FLOPs(G)4.15.8↑41.5%ImageNet Top-176.1%77.3%↑1.2%COCO mAP38.040.2↑2.2注意实际性能提升取决于具体任务和数据。在小规模数据集上可能需要减少MHSA的使用比例以避免过拟合。实际部署建议从最后阶段开始逐步替换先替换1个block观察效果对于小分辨率输入224x224可能不需要MHSA可以使用混合精度训练加速scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): outputs model(inputs) loss criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()通过以上步骤我们完成了从ResNet到BoTNet的改造。实际项目中这种改造通常能在目标检测和语义分割任务中获得更显著的提升因为这类任务更需要全局上下文信息。

相关文章:

用PyTorch手把手实现BoTNet:把ResNet50的3x3卷积换成MHSA到底有多简单?

用PyTorch手把手实现BoTNet:把ResNet50的3x3卷积换成MHSA到底有多简单? 如果你正在寻找一种既能保留CNN局部特征提取能力,又能引入全局注意力机制的方法,BoTNet可能是最优雅的解决方案之一。这个将ResNet中3x3卷积替换为多头自注意…...

FPGA时序分析避坑指南:从TimeQuest报错到正确添加SDC约束的完整流程

FPGA时序分析避坑指南:从TimeQuest报错到正确添加SDC约束的完整流程 第一次打开TimeQuest看到满屏红色警告时,那种手足无措的感觉我至今记忆犹新。时钟约束不生效、SDC文件加载失败、默认1GHz约束冲突——这些看似简单的问题背后,往往隐藏着F…...

Simulink数据导入导出全攻略:从MATLAB工作区交互到信号日志分析,提升仿真效率的5个技巧

Simulink数据流高效管理:构建闭环仿真工作流的5个核心策略 在工程仿真领域,数据就像血液一样贯穿整个系统建模的生命周期。每次打开Simulink模型时,我们都在与数据打交道——可能是来自实验室的实测数据需要导入作为激励源,也可能…...

告别ROS安装噩梦:用小鱼的一键脚本在Ubuntu 22.04上5分钟搞定ROS2 Humble

5分钟征服ROS2 Humble:小鱼一键脚本的极简安装哲学 第一次接触ROS时,我盯着官方文档里密密麻麻的依赖项和时不时报错的rosdep,差点以为自己在破解某种加密系统。直到发现小鱼的那个绿色终端界面——原来安装ROS可以像喝咖啡一样简单。这不是又…...

WeChatMsg:你的微信聊天记录永久保存与智能分析终极指南

WeChatMsg:你的微信聊天记录永久保存与智能分析终极指南 【免费下载链接】WeChatMsg 提取微信聊天记录,将其导出成HTML、Word、CSV文档永久保存,对聊天记录进行分析生成年度聊天报告 项目地址: https://gitcode.com/GitHub_Trending/we/WeC…...

玄机靶场:供应链安全-供应链应急-Part2 通关笔记

供应链安全-供应链应急-Part2 通关笔记 题目背景 本题是供应链安全应急响应的第二部分,主要考察对Gitea代码仓库和Jenkins持续集成环境的综合分析能力。黑客通过某种手段获取了开发者的Gitea Token,进而对多个代码仓库进行了恶意篡改,并在J…...

玄机靶场-2025数字中国 数据安全-溯源与取证 WP

玄机靶场-2025数字中国 数据安全-溯源与取证 WP 这道题是 2025 数字中国创新大赛数据安全赛道的原题,搬到玄机靶场上来了。主要考察磁盘数据恢复、加密驱动器解密和 Web 日志分析三块,题目一共 3 个步骤,难度中等,下面是完整解题过…...

三步解决Windows系统无法识别iPhone的终极方案:Apple-Mobile-Drivers-Installer深度指南

三步解决Windows系统无法识别iPhone的终极方案:Apple-Mobile-Drivers-Installer深度指南 【免费下载链接】Apple-Mobile-Drivers-Installer Powershell script to easily install Apple USB and Mobile Device Ethernet (USB Tethering) drivers on Windows! 项目…...

深入解析WeChatFerry:打造企业级微信机器人的5个核心技术要点

深入解析WeChatFerry:打造企业级微信机器人的5个核心技术要点 【免费下载链接】WeChatFerry 微信机器人,可接入DeepSeek、Gemini、ChatGPT、ChatGLM、讯飞星火、Tigerbot等大模型。微信 hook WeChat Robot Hook. 项目地址: https://gitcode.com/GitHub…...

基于DH参数的UR5机械臂PID轨迹跟踪控制及Simscape物理仿真:角度、速度、加速度与力...

UR5机械臂PID轨迹跟踪控制控制,六自由度机械臂simscape物理仿真,需要可以提供DH参数表,坐标系表示,三维模型,可以导出角度,角速度,角加速度以及力矩,误差曲线图机械臂轨迹跟踪这事儿…...

[1]锁相环 PLL 几个版本的matlab相位噪声拟合仿真代码,质量杠杠的,都是好东西

[1]锁相环 PLL 几个版本的matlab相位噪声拟合仿真代码,质量杠杠的,都是好东西 [2]锁相环matlab建模稳定性仿真,好几个版本 [3]锁相环2.4G小数分频 simulink建模仿真最近在折腾锁相环设计,发现手头这几个版本的Matlab相位噪声拟合…...

如何快速掌握ModTheSpire:杀戮尖塔模组加载器的终极配置指南

如何快速掌握ModTheSpire:杀戮尖塔模组加载器的终极配置指南 【免费下载链接】ModTheSpire External mod loader for Slay The Spire 项目地址: https://gitcode.com/gh_mirrors/mo/ModTheSpire 你是否厌倦了《杀戮尖塔》原版游戏内容?想要体验更…...

3步搭建NAS媒体库自动化管理系统:MoviePilot完整指南

3步搭建NAS媒体库自动化管理系统:MoviePilot完整指南 【免费下载链接】MoviePilot NAS媒体库自动化管理工具 项目地址: https://gitcode.com/gh_mirrors/mo/MoviePilot 在数字媒体时代,如何高效管理海量的电影和电视剧资源成为许多NAS用户的痛点。…...

保姆级教程:给你的YOLOv8模型“开天眼”,手把手集成CBAM/CA注意力模块(附完整代码)

YOLOv8模型增强实战:深度集成CBAM与CA注意力机制 在目标检测领域,YOLOv8以其卓越的平衡性——兼顾速度与精度——成为众多开发者的首选框架。然而,面对复杂场景时,原始模型可能对微小目标或遮挡物体表现不佳。这时,注意…...

m4s转MP4终极指南:3分钟学会B站缓存视频无损转换

m4s转MP4终极指南:3分钟学会B站缓存视频无损转换 【免费下载链接】m4s-converter 一个跨平台小工具,将bilibili缓存的m4s格式音视频文件合并成mp4 项目地址: https://gitcode.com/gh_mirrors/m4/m4s-converter 你是否曾经遇到过这样的情况&#x…...

从Modem到DTU:一个老电工的工业物联网设备选型避坑实录

从Modem到DTU:一个老电工的工业物联网设备选型避坑实录 记得去年夏天,厂里那条老生产线突然闹起了"罢工"。PLC控制柜里那台服役十年的无线Modem开始频繁掉线,每次故障都得爬上三米高的钢架桥检查设备。作为干了二十年的老电工&…...

小米电视去广告后,米家APP失灵了?教你一招两全其美(路由器Hosts规则详解)

小米电视去广告与米家APP兼容方案:路由器Hosts规则精细化管理指南 每次打开小米电视都要忍受漫长的开机广告?不少用户会选择通过修改路由器Hosts规则来屏蔽广告,但随之而来的往往是米家APP无法正常使用的尴尬。这种"拆东墙补西墙"的…...

BLE蓝牙模块型号,BLE蓝牙串口芯片应用

一、BLE蓝牙模块概述 传统串口设备升级无线通信功能时,往往需要重写底层驱动或修改上位机软件。而采用虚拟化串口技术的BLE蓝牙模块,通过将蓝牙连接模拟为本地COM口,使原有基于串口的上位机软件无需任何改动即可收发数据。这种“无感替换”能…...

别再死磕毕业论文!PaperXie 一键打通 “选题 - 定稿” 全流程,效率翻倍

paperxie-免费查重复率aigc检测/开题报告/毕业论文/智能排版/文献综述/期刊论文https://www.paperxie.cn/ai/dissertationhttps://www.paperxie.cn/ai/dissertation 毕业季的图书馆里,永远不缺对着空白文档发呆的大学生:选题改了八遍还被导师打回&#x…...

实测性能反超15%!C#工业上位机统信UOS+鲲鹏全栈移植指南(踩坑+优化+源码)

摘要 2026年是工业领域国产化替代的爆发年,统信UOS鲲鹏架构已成为政府、军工、能源等关键行业的标配。但90%的C#工业开发者都面临同一个难题:写了十几年的Windows上位机,怎么移植到Linux ARM64平台? 网上的教程要么碎片化&#xf…...

别再死磕毕业论文了!Paperxie 这波操作,把本科写作的 “坑” 全填上了

paperxie-免费查重复率aigc检测/开题报告/毕业论文/智能排版/文献综述/期刊论文https://www.paperxie.cn/ai/dissertationhttps://www.paperxie.cn/ai/dissertation 打开论文文档,盯着空白页面发呆;选题被导师打回 N 次,改到怀疑人生&#xf…...

别再死磕毕业论文!Paperxie 智能写作:大四生的「论文通关秘籍」

paperxie-免费查重复率aigc检测/开题报告/毕业论文/智能排版/文献综述/期刊论文https://www.paperxie.cn/ai/dissertationhttps://www.paperxie.cn/ai/dissertation 大四下学期的关键词,一半是毕业旅行、散伙饭,另一半却是改到崩溃的论文初稿、导师的红色…...

xrdp实战:构建企业级Linux远程桌面服务的3个关键决策

xrdp实战:构建企业级Linux远程桌面服务的3个关键决策 【免费下载链接】xrdp xrdp: an open source RDP server 项目地址: https://gitcode.com/gh_mirrors/xrd/xrdp xrdp作为开源RDP服务器,为Linux系统提供了Windows远程桌面协议的原生支持&#…...

Visual C++ Redistributable AIO:一站式解决Windows运行库依赖问题的架构设计与实施指南

Visual C Redistributable AIO:一站式解决Windows运行库依赖问题的架构设计与实施指南 【免费下载链接】vcredist AIO Repack for latest Microsoft Visual C Redistributable Runtimes 项目地址: https://gitcode.com/gh_mirrors/vc/vcredist Visual C Redi…...

STM32F302K8U6 + L6205D 驱动板实战:手把手教你搞定微型伺服电机FOC单电阻采样

STM32F302K8U6 L6205D 驱动板实战:微型伺服电机FOC单电阻采样全解析 在嵌入式电机控制领域,FOC(磁场定向控制)技术因其高效、精准的特性,正逐渐成为伺服电机控制的主流方案。本文将深入探讨基于STM32F302K8U6和L6205…...

Jimeng AI Studio新手指南:极简白色美学界面下的高效影像创作入门路径

Jimeng AI Studio新手指南:极简白色美学界面下的高效影像创作入门路径 1. 认识Jimeng AI Studio:你的极简影像创作终端 想象一下,你有一个想法,比如“一只戴着宇航员头盔的猫,在月球上喝咖啡”,你想立刻把…...

2025届最火的五大AI科研平台横评

Ai论文网站排名(开题报告、文献综述、降aigc率、降重综合对比) TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 当借助人工智能来生成文本之际,指令残留常常致使内容显得生硬,使逻辑…...

终极指南:如何用ModTheSpire为杀戮尖塔安装和管理模组

终极指南:如何用ModTheSpire为杀戮尖塔安装和管理模组 【免费下载链接】ModTheSpire External mod loader for Slay The Spire 项目地址: https://gitcode.com/gh_mirrors/mo/ModTheSpire ModTheSpire是专为《杀戮尖塔》设计的开源模组加载器,它能…...

2026 年 4 月 16 日 DataEase 发布 v2.10.21 LTS 版:新增技能体系、修复漏洞并优化多项功能

版本更新内容2026 年 4 月 16 日,人人可用的开源 BI 工具 DataEase 正式发布 v2.10.21 LTS 版本。在这一版本中,DataEase 推出了 Skills 技能体系,并进行了安全漏洞修复。在智能体方面,引入 DataEase Skills 技能体系;…...

Python实战:手把手教你解密并下载AES-128加密的M3U8视频流(附完整代码)

Python实战:手把手教你解密并下载AES-128加密的M3U8视频流(附完整代码) 最近在帮朋友处理一个在线教育平台的视频下载需求时,遇到了AES-128加密的M3U8视频流。这种加密方式在各大视频平台都很常见,但完整实现解密下载…...