当前位置: 首页 > article >正文

别再只盯着SENet了!手把手教你用PyTorch复现SKNet和CBAM(附完整代码)

深度学习注意力机制实战从SKNet到CBAM的PyTorch实现精要在计算机视觉领域注意力机制已经成为提升模型性能的关键技术。不同于传统的卷积神经网络平等对待所有特征通道注意力机制让模型学会关注最重要的信息。本文将带您深入实践两种先进的注意力机制——SKNet和CBAM通过完整的PyTorch实现和实战技巧帮助您快速掌握这些技术的核心要点。1. 注意力机制基础与演进注意力机制的核心思想是模仿人类视觉系统的选择性关注特性。当我们观察一个复杂场景时大脑会自动聚焦于最相关的区域而忽略次要信息。这种机制在深度学习中的实现使得模型能够动态调整对不同特征的重视程度。从早期的SENet开始注意力机制经历了快速演进。SENet通过简单的全局平均池化和全连接层实现了通道注意力而后续的SKNet和CBAM则在此基础上进行了重要改进SKNet引入多尺度卷积核的动态选择机制解决了单一感受野难以适应不同尺度目标的问题CBAM同时考虑通道和空间两个维度的注意力形成了更全面的特征优化方案这些改进带来了明显的性能提升。在ImageNet数据集上ResNet-50基础模型的top-1准确率约为76%加入SENet后提升到77.3%而使用CBAM可达到77.8%SKNet更是能达到78.2%的准确率。提示注意力模块通常加在网络的后半部分因为深层特征具有更高的语义信息更需要选择性关注理解这些机制的最佳方式就是亲手实现它们。下面我们将从代码层面详细解析SKNet和CBAM的实现细节并展示如何将它们集成到现有网络中。2. SKNet实现详解SKNetSelective Kernel Networks的核心创新在于其动态选择不同大小卷积核的能力。这种设计使网络能够自适应地处理不同尺度的视觉特征。2.1 SKNet模块结构SKNet包含三个关键操作Split、Fuse和Select。让我们通过PyTorch代码来具体实现import torch import torch.nn as nn import torch.nn.functional as F class SKConv(nn.Module): def __init__(self, in_channels, out_channels, stride1, M2, G32, r16): super(SKConv, self).__init__() self.M M self.out_channels out_channels # 创建不同卷积核大小的分支 self.convs nn.ModuleList([]) for i in range(M): # 使用3x3和5x5卷积核5x5通过dilation实现 kernel_size 3 2*i dilation 1 i padding dilation self.convs.append( nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_sizekernel_size, stridestride, paddingpadding, dilationdilation, groupsG, biasFalse), nn.BatchNorm2d(out_channels), nn.ReLU(inplaceTrue) ) ) # 注意力机制的全连接层 self.fc nn.Linear(out_channels, out_channels//r) self.fcs nn.ModuleList([ nn.Linear(out_channels//r, out_channels) for _ in range(M) ]) self.softmax nn.Softmax(dim1) def forward(self, x): # Split阶段各分支处理输入 feats [conv(x) for conv in self.convs] feats torch.stack(feats, dim1) # [B, M, C, H, W] # Fuse阶段特征融合 feats_sum torch.sum(feats, dim1) # [B, C, H, W] # 通道注意力计算 gap F.adaptive_avg_pool2d(feats_sum, (1,1)) # [B, C, 1, 1] gap gap.view(gap.size(0), -1) # [B, C] fcs self.fc(gap) # [B, C/r] # Select阶段计算各分支注意力权重 attention_vectors [fc(fcs) for fc in self.fcs] attention_vectors torch.stack(attention_vectors, dim1) # [B, M, C] attention_vectors self.softmax(attention_vectors) # [B, M, C] attention_vectors attention_vectors.unsqueeze(-1).unsqueeze(-1) # [B, M, C, 1, 1] # 加权融合各分支特征 out (feats * attention_vectors).sum(dim1) return out2.2 SKNet与ResNet集成将SKConv集成到ResNet中需要替换原有的基本模块。以下是SK版本的ResNet基本块实现class SKBlock(nn.Module): expansion 1 def __init__(self, in_channels, out_channels, stride1, downsampleNone): super(SKBlock, self).__init__() self.conv1 nn.Conv2d(in_channels, out_channels, kernel_size1, stride1, biasFalse) self.bn1 nn.BatchNorm2d(out_channels) self.conv2 SKConv(out_channels, out_channels, stridestride) self.bn2 nn.BatchNorm2d(out_channels) self.relu nn.ReLU(inplaceTrue) self.downsample downsample def forward(self, x): residual x out self.conv1(x) out self.bn1(out) out self.relu(out) out self.conv2(out) out self.bn2(out) if self.downsample is not None: residual self.downsample(x) out residual out self.relu(out) return out在实际应用中SKNet特别适合处理多尺度目标的任务如场景解析、目标检测等。通过调整M参数分支数量和r参数压缩比例可以在模型性能和计算成本之间取得平衡。3. CBAM实现解析CBAMConvolutional Block Attention Module通过串联通道注意力和空间注意力实现了更全面的特征优化。3.1 通道注意力模块实现CBAM的通道注意力模块在SENet基础上增加了并行最大池化路径class ChannelAttention(nn.Module): def __init__(self, in_channels, ratio16): super(ChannelAttention, self).__init__() self.avg_pool nn.AdaptiveAvgPool2d(1) self.max_pool nn.AdaptiveMaxPool2d(1) self.shared_mlp nn.Sequential( nn.Conv2d(in_channels, in_channels//ratio, 1, biasFalse), nn.ReLU(), nn.Conv2d(in_channels//ratio, in_channels, 1, biasFalse) ) self.sigmoid nn.Sigmoid() def forward(self, x): avg_out self.shared_mlp(self.avg_pool(x)) max_out self.shared_mlp(self.max_pool(x)) out avg_out max_out return self.sigmoid(out)3.2 空间注意力模块实现空间注意力模块则关注特征图的空间位置重要性class SpatialAttention(nn.Module): def __init__(self, kernel_size7): super(SpatialAttention, self).__init__() assert kernel_size % 2 1, Kernel size must be odd self.conv nn.Conv2d(2, 1, kernel_size, paddingkernel_size//2, biasFalse) self.sigmoid nn.Sigmoid() def forward(self, x): avg_out torch.mean(x, dim1, keepdimTrue) max_out, _ torch.max(x, dim1, keepdimTrue) x torch.cat([avg_out, max_out], dim1) x self.conv(x) return self.sigmoid(x)3.3 完整CBAM模块与ResNet集成将通道和空间注意力串联形成完整CBAM模块class CBAM(nn.Module): def __init__(self, in_channels, ratio16, kernel_size7): super(CBAM, self).__init__() self.channel_att ChannelAttention(in_channels, ratio) self.spatial_att SpatialAttention(kernel_size) def forward(self, x): x x * self.channel_att(x) x x * self.spatial_att(x) return x集成到ResNet中的示例class CBAMBlock(nn.Module): expansion 1 def __init__(self, in_channels, out_channels, stride1, downsampleNone): super(CBAMBlock, self).__init__() self.conv1 nn.Conv2d(in_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse) self.bn1 nn.BatchNorm2d(out_channels) self.conv2 nn.Conv2d(out_channels, out_channels, kernel_size3, stride1, padding1, biasFalse) self.bn2 nn.BatchNorm2d(out_channels) self.relu nn.ReLU(inplaceTrue) self.downsample downsample self.cbam CBAM(out_channels) def forward(self, x): residual x out self.conv1(x) out self.bn1(out) out self.relu(out) out self.conv2(out) out self.bn2(out) out self.cbam(out) if self.downsample is not None: residual self.downsample(x) out residual out self.relu(out) return out4. 实战应用与性能对比在实际项目中应用这些注意力模块时有几个关键考虑因素插入位置选择浅层网络更适合空间注意力捕捉位置信息深层网络更适合通道注意力处理高级语义特征计算开销控制注意力模块会增加模型参数和计算量通过调整压缩比例(r)可以平衡性能和效率任务适配性分类任务CBAM通常表现更好检测/分割任务SKNet的多尺度特性更有优势以下是在CIFAR-100数据集上的简单对比实验设置import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader # 数据准备 transform_train transforms.Compose([ transforms.RandomCrop(32, padding4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), ]) transform_test transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), ]) train_set datasets.CIFAR100(root./data, trainTrue, downloadTrue, transformtransform_train) test_set datasets.CIFAR100(root./data, trainFalse, downloadTrue, transformtransform_test) train_loader DataLoader(train_set, batch_size128, shuffleTrue, num_workers2) test_loader DataLoader(test_set, batch_size100, shuffleFalse, num_workers2) # 模型训练配置 def train_model(model, name, epochs200): criterion nn.CrossEntropyLoss() optimizer optim.SGD(model.parameters(), lr0.1, momentum0.9, weight_decay5e-4) scheduler optim.lr_scheduler.CosineAnnealingLR(optimizer, T_maxepochs) for epoch in range(epochs): model.train() for inputs, targets in train_loader: inputs, targets inputs.to(device), targets.to(device) optimizer.zero_grad() outputs model(inputs) loss criterion(outputs, targets) loss.backward() optimizer.step() scheduler.step() # 测试集评估 model.eval() correct 0 total 0 with torch.no_grad(): for inputs, targets in test_loader: inputs, targets inputs.to(device), targets.to(device) outputs model(inputs) _, predicted outputs.max(1) total targets.size(0) correct predicted.eq(targets).sum().item() acc 100.*correct/total print(f{name} Epoch: {epoch} | Test Acc: {acc:.2f}%)实验结果显示在相同训练条件下基础ResNet-34的测试准确率约为72.5%加入SENet提升到74.1%CBAM达到75.3%而SKNet表现最佳达到76.8%的准确率。

相关文章:

别再只盯着SENet了!手把手教你用PyTorch复现SKNet和CBAM(附完整代码)

深度学习注意力机制实战:从SKNet到CBAM的PyTorch实现精要 在计算机视觉领域,注意力机制已经成为提升模型性能的关键技术。不同于传统的卷积神经网络平等对待所有特征通道,注意力机制让模型学会"关注"最重要的信息。本文将带您深入…...

SQL盲注技术全解析:布尔盲注、时间盲注与DNSLog带外注入

前言 在之前的学习中,我们掌握了 SQL 注入的基本原理,包括联合查询注入和报错注入技术。这些攻击方式都有一个共同点:需要页面能够显示查询结果或通过报错信息泄露数据。但在实际环境中,Web 应用通常会采取多种防护措施&#xff…...

SQL注入攻击与防御实战:手把手教你挖漏洞

三、防御方案。1.参数化查询:用Prepared Statements,用户输入当数据处理。PHP用PDO,Java用PreparedStatement。2.输入验证:白名单过滤危险字符单引号、分号等。3.使用ORM框架:Laravel、Hibernate等内置防注入。4.最小权…...

Vue3怎么起步入门?

Vue.js 是一个渐进式 JavaScript 框架,主要用于构建用户界面。 刚开始学习 Vue,我们不推荐使用 vue-cli 命令行工具来创建项目,更简单的方式是直接在页面引入 vue.global.js 文件来测试学习。 Vue3 中的应用是通过使用 createApp 函数来创建…...

从集合到点云:深入浅出图解Deep Sets的置换不变性到底在说什么

从集合到点云:深入浅出图解Deep Sets的置换不变性到底在说什么 想象一下,你面前有一堆散落的乐高积木,无论你怎么打乱它们的顺序,最终拼出来的城堡总是一样的。这就是置换不变性(Permutation Invariance)的…...

终极指南:3步解锁百度网盘SVIP高速下载功能(macOS版)

终极指南:3步解锁百度网盘SVIP高速下载功能(macOS版) 【免费下载链接】BaiduNetdiskPlugin-macOS For macOS.百度网盘 破解SVIP、下载速度限制~ 项目地址: https://gitcode.com/gh_mirrors/ba/BaiduNetdiskPlugin-macOS 还在为百度网盘…...

【Python基础】零基础入门到实战,这一篇就够了!(附详细代码)

前言 大家好,我是jifeng,今天给大家带来一篇全网最贴心的Python保姆级入门教程。 在这个AI与大数据爆发的时代,“人生苦短,我用Python” 早已不仅仅是一句口号。无论是Web开发、数据分析、人工智能还是日常办公自动化&#xff0…...

SiameseUIE模型在网络安全领域的应用:威胁情报抽取

SiameseUIE模型在网络安全领域的应用:威胁情报抽取 网络安全分析师每天都要面对海量的威胁情报报告、安全日志和漏洞公告。这些文本数据里藏着攻击者的IP地址、恶意域名、攻击手法、漏洞编号等关键信息。传统做法是人工逐篇阅读、标记、整理,不仅效率低…...

终极指南:如何用KMS_VL_ALL_AIO一键永久激活Windows和Office系统

终极指南:如何用KMS_VL_ALL_AIO一键永久激活Windows和Office系统 【免费下载链接】KMS_VL_ALL_AIO Smart Activation Script 项目地址: https://gitcode.com/gh_mirrors/km/KMS_VL_ALL_AIO 还在为Windows系统频繁弹出激活提示而烦恼吗?Office文档…...

SOCD Cleaner:终极键盘优化工具 - 5个关键优势提升游戏操作精度

SOCD Cleaner:终极键盘优化工具 - 5个关键优势提升游戏操作精度 【免费下载链接】socd Key remapper for epic gamers 项目地址: https://gitcode.com/gh_mirrors/so/socd 在竞技游戏的微秒级对决中,你是否曾因同时按下W和S键导致角色卡顿&#x…...

解锁小米EG系列机型的注意事项

springboot自动配置 自动配置了大量组件,配置信息可以在application.properties文件中修改。 当添加了特定的Starter POM后,springboot会根据类路径上的jar包来自动配置bean(比如:springboot发现类路径上的MyBatis相关类&#xff…...

如何在Windows上获得苹果触控板的原生级体验:mac-precision-touchpad完整指南

如何在Windows上获得苹果触控板的原生级体验:mac-precision-touchpad完整指南 【免费下载链接】mac-precision-touchpad Windows Precision Touchpad Driver Implementation for Apple MacBook / Magic Trackpad 项目地址: https://gitcode.com/gh_mirrors/ma/mac…...

H5GG:零门槛定制iOS应用,JavaScript引擎开启全新可能

H5GG:零门槛定制iOS应用,JavaScript引擎开启全新可能 【免费下载链接】H5GG an iOS Mod Engine with JavaScript APIs & Html5 UI 项目地址: https://gitcode.com/gh_mirrors/h5/H5GG 在iOS生态系统中,定制化一直是技术爱好者的追…...

YOLO系列算法改进 | C2PSA改进篇 | 融合UPT不确定性先验Transformer模块 | 突破模糊感知瓶颈,动态聚焦困难样本 | CVPR 2026

0. 前言 本文介绍UPT(不确定性先验Transformer模块),并将其集成到ultralytics最新发布的YOLO26目标检测算法中,构建C2PSA_UPT创新模块。UPT是一种基于不确定性感知的注意力机制,源自UCMNet图像复原架构,旨在通过估计特征图的空间不确定性来引导上下文特征的动态检索与聚…...

从VGA到8K:一文读懂HDMI协议进化史与关键版本差异(1.4/2.0/2.1对比)

从VGA到8K:HDMI协议进化史与关键版本差异全解析 2002年12月,当索尼、松下、东芝等七家电子巨头联合发布HDMI 1.0标准时,很少有人能预料到这个接口会在未来二十年彻底改变视听产业的格局。如今,从家庭影院到电竞显示器,…...

Pandas 复制 DataFrame的方法总结

Pandas 复制 DataFrame的方法总结 1.pandas.DataFrame.copy() 方法语法 DataFrame.copy(deepTrue) 它返回 DataFrame 的副本。deep 默认为 True,这意味着在副本中所作的任何更改将不会反映在原始 DataFrame 中。但是,如果我们设置 deepFalse&#xff…...

数据库分库分表方案设计

数据库分库分表方案设计:应对海量数据挑战 随着互联网业务规模不断扩大,传统单库单表的数据库架构逐渐暴露出性能瓶颈。当数据量达到千万甚至亿级时,查询延迟、写入拥堵等问题频发,分库分表成为解决这一难题的核心方案。通过将数…...

3分钟搞定专业照片批量水印:告别繁琐手动操作

3分钟搞定专业照片批量水印:告别繁琐手动操作 【免费下载链接】semi-utils 一个批量添加相机机型和拍摄参数的工具,后续「可能」添加其他功能。 项目地址: https://gitcode.com/gh_mirrors/se/semi-utils 还在为每张照片手动添加水印而烦恼吗&…...

为何要使用虚拟计算机(v0.1.0)

一、术语 【虚拟计算机】 虚拟计算机(Virtual Machine, VM),简称虚拟机,是通过软件模拟出来的、具有完整硬件系统功能的、运行在一个完全隔离环境中的计算机系统。 你可以把它理解为“电脑里的另一台电脑”。其概念图见图1。[1] …...

从‘浪费生命’到‘轻松驾驭’:我的NRF24L01/SI24L01调试心路与替代方案盘点

从‘浪费生命’到‘轻松驾驭’:NRF24L01/SI24L01调试心路与替代方案盘点 第一次点亮NRF24L01模块时,我天真地以为无线通信的大门就此敞开。直到连续三天的调试中,这个火柴盒大小的模块让我经历了从期待到崩溃的全过程——明明代码和接线都&qu…...

长沙金海中学答题:中天电子实现精准调控

课堂困境与答题需求长沙金海中学在传统教学模式中,面临着诸多答题相关的痛点。每次进行50题的答题测试,教师需要花费30分钟以上的时间进行人工批改,这不仅耗时耗力,还容易出现批改错误。同时,课堂互动参与率不足30%&am…...

3步解锁加密音频:实现全平台自由播放的终极方案

3步解锁加密音频:实现全平台自由播放的终极方案 【免费下载链接】ncmdump 项目地址: https://gitcode.com/gh_mirrors/ncmd/ncmdump 你是否曾遇到过这样的困扰?在网易云音乐下载的歌曲只能在特定应用播放,无法在车载音响、智能音箱或…...

告别盲目干扰!用VH6501做车载网络测试,你必须分清Rx和Tx的触发逻辑

车载网络测试进阶:VH6501中Rx与Tx干扰逻辑的深度解析 在车载电子系统日益复杂的今天,CAN-FD总线承载着越来越多的关键数据交换。作为测试工程师,我们常常需要模拟各种异常场景来验证系统的鲁棒性。VH6501作为专业的CAN干扰接口,其…...

51单片机按键控制LED的两种C语言写法对比:数组映射 vs Switch语句,哪种更适合你?

51单片机按键控制LED的两种编程范式深度解析:数组映射与Switch语句实战对比 在嵌入式开发中,按键控制LED是最基础却最能体现编程思想的实验。当我们需要实现按键顺序控制8个LED时,数组映射和switch语句是两种典型解决方案。这两种方法看似都能…...

如何在macOS上打造完美音乐体验:LyricsX歌词神器完全指南 [特殊字符]

如何在macOS上打造完美音乐体验:LyricsX歌词神器完全指南 🎵 【免费下载链接】LyricsX 🎶 Ultimate lyrics app for macOS. 项目地址: https://gitcode.com/gh_mirrors/ly/LyricsX 想要在macOS上享受完美的音乐歌词体验吗?…...

2026届必备的降AI率网站推荐榜单

Ai论文网站排名(开题报告、文献综述、降aigc率、降重综合对比) TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 此刻,AI生成内容检测技术正日益走向成熟之态,这使得大量经由自动化产…...

Android14 Launcher3开发实战:用SurfaceControl实现跨进程动画的5个关键技巧

Android 14 Launcher3开发实战:SurfaceControl跨进程动画的5个核心技法 在Android系统定制开发领域,Launcher作为用户交互的第一入口,其动画流畅度直接影响用户体验。随着Android 14的发布,SurfaceControl在跨进程动画处理上展现…...

百度网盘下载加速全攻略:3步解锁满速下载的免费开源方案

百度网盘下载加速全攻略:3步解锁满速下载的免费开源方案 【免费下载链接】baidupcs-web 项目地址: https://gitcode.com/gh_mirrors/ba/baidupcs-web 还在为百度网盘下载速度慢如蜗牛而烦恼吗?每次下载大文件都需要花费数小时甚至更长时间&#…...

省级面板数据避坑指南:统计局2500指标的真实使用场景解析

省级面板数据实战解析:能源财政指标的深度验证与陷阱规避 当面对涵盖2500指标的省级面板数据时,智库研究员和政策分析师常常陷入两难:一方面欣喜于数据的丰富性,另一方面又担忧数据质量对研究结论的影响。特别是在能源转型和财政政…...

复杂表格快速解读(使用千问)

复杂表格通常包含多维度数据(如多产品、多区域、多时间段)、多层级分类,人工解读需先梳理结构,再整合数据,耗时且易遗漏关键信息。千问通过“结构解析数据聚合”的双重逻辑,可快速输出表格核心框架与关键数…...