(ECCV2018)CBAM改进思路
论文链接:https://arxiv.org/abs/1807.06521
论文题目:CBAM: Convolutional Block Attention Module
会议:ECCV2018
论文方法
利用特征的通道间关系生成了一个通道注意图。 由于特征映射的每个通道被认为是一个特征检测器,通道注意力集中在给定输入图像的“什么”是有意义的。 为了有效地计算通道注意力,我们压缩了输入特征映射的空间维度。 对于空间信息的聚合,目前普遍采用平均池化方法。 除了之前的工作,我们认为最大池化收集了另一个关于不同对象特征的重要线索,以推断更精细的通道明智的注意力。 因此,作者同时使用平均池化和最大池化特征。
利用特征的空间间关系生成空间注意图。 与通道注意不同的是,空间注意关注的“在哪里”是信息部分,与通道注意是互补的。 为了计算空间注意力,首先沿着通道轴应用平均池化和最大池化操作,并将它们连接起来以生成有效的特征描述符。 沿着通道轴应用池操作可以有效地突出显示信息区域。 在连接的特征描述符上,应用卷积层生成空间注意映射Ms(F)∈RH×W,该映射编码强调或抑制的位置。
论文源代码
import torch
import torch.nn.functional as F
import torch.nn as nnclass ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_channels, in_channels // ratio, 1, bias=False),nn.ReLU(inplace=True),nn.Conv2d(in_channels // ratio, in_channels, 1, bias=False)) self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outout = self.sigmoid(out)return out * xclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avg_out, max_out], dim=1)out = self.sigmoid(self.conv1(out))return out * xclass CBAM(nn.Module):def __init__(self, in_channels, ratio=16, kernel_size=3):super(CBAM, self).__init__()self.channelattention = ChannelAttention(in_channels, ratio=ratio)self.spatialattention = SpatialAttention(kernel_size=kernel_size)def forward(self, x):x = self.channelattention(x)x = self.spatialattention(x)return x
改进思路
1.通道注意力独立分支与批归一化
使用独立的FC层处理平均池化和最大池化,增强表达能力。
在FC层之间加入批归一化,加速训练收敛。
class ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)# 独立的全连接层分支self.fc_avg = nn.Sequential(nn.Conv2d(in_channels, in_channels//ratio, 1, bias=False),nn.BatchNorm2d(in_channels//ratio), # 添加BNnn.ReLU(inplace=True),nn.Conv2d(in_channels//ratio, in_channels, 1, bias=False),nn.BatchNorm2d(in_channels) # 输出层也可以考虑BN)self.fc_max = nn.Sequential(nn.Conv2d(in_channels, in_channels//ratio, 1, bias=False),nn.BatchNorm2d(in_channels//ratio),nn.ReLU(inplace=True),nn.Conv2d(in_channels//ratio, in_channels, 1, bias=False),nn.BatchNorm2d(in_channels))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc_avg(self.avg_pool(x))max_out = self.fc_max(self.max_pool(x))out = self.sigmoid(avg_out + max_out)return x * out
2.空间注意力深度增强
使用多层卷积增加非线性。
引入残差连接提升梯度流动。
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()padding = kernel_size // 2self.conv = nn.Sequential(nn.Conv2d(2, 32, kernel_size, padding=padding, bias=False),nn.BatchNorm2d(32),nn.ReLU(inplace=True),nn.Conv2d(32, 1, kernel_size, padding=padding, bias=False), # 深层卷积nn.BatchNorm2d(1))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)cat = torch.cat([avg_out, max_out], dim=1)out = self.conv(cat) + cat.mean(dim=1, keepdim=True) # 残差连接return x * self.sigmoid(out)
3.动态比例调整、参数初始化优化、并行注意力融合
import torch
import torch.nn as nn
import torch.nn.functional as F# --------------------------
# 改进3:动态比例调整
# --------------------------
def get_ratio(in_channels, min_ratio=16):"""动态计算压缩比例,防止通道数过小时出现除零错误"""return max(in_channels // min_ratio, 4) # 保证最小分割比例为4# --------------------------
# 改进4:参数初始化优化
# --------------------------
def init_weights(m):"""He初始化 + 零偏置初始化"""if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# --------------------------
# 改进1/3:通道注意力(包含动态比例调整)
# --------------------------
class ChannelAttention(nn.Module):def __init__(self, in_channels):super().__init__()ratio = get_ratio(in_channels) # 动态计算ratioself.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Conv2d(in_channels, ratio, 1, bias=False),nn.BatchNorm2d(ratio),nn.ReLU(),nn.Conv2d(ratio, in_channels, 1, bias=False),nn.BatchNorm2d(in_channels))self.sigmoid = nn.Sigmoid()self.apply(init_weights) # 应用参数初始化def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))weight = self.sigmoid(avg_out + max_out)return x * weight# --------------------------
# 改进1:空间注意力
# --------------------------
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()padding = kernel_size // 2self.conv = nn.Sequential(nn.Conv2d(2, 32, kernel_size, padding=padding, bias=False),nn.BatchNorm2d(32),nn.ReLU(),nn.Conv2d(32, 1, kernel_size, padding=padding, bias=False),nn.BatchNorm2d(1))self.sigmoid = nn.Sigmoid()self.apply(init_weights) # 应用参数初始化def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)cat = torch.cat([avg_out, max_out], dim=1)weight = self.sigmoid(self.conv(cat))return x * weight# --------------------------
# 改进5:并行注意力融合
# --------------------------
class CBAM(nn.Module):def __init__(self, in_channels, kernel_size=7):super().__init__()self.ca = ChannelAttention(in_channels)self.sa = SpatialAttention(kernel_size)self.apply(init_weights) # 整个模块应用初始化def forward(self, x):# 并行计算通道注意力和空间注意力ca_out = self.ca(x) # 通道注意力分支sa_out = self.sa(x) # 空间注意力分支# 残差连接融合 (原始特征 + 通道特征 + 空间特征)return x + ca_out + sa_out
相关文章:

(ECCV2018)CBAM改进思路
论文链接:https://arxiv.org/abs/1807.06521 论文题目:CBAM: Convolutional Block Attention Module 会议:ECCV2018 论文方法 利用特征的通道间关系生成了一个通道注意图。 由于特征映射的每个通道被认为是一个特征检测器,通道…...
Python脚本,音频格式转换 和 视频格式转换
一、音频格式转换完整代码 from pydub import AudioSegment import osdef convert_audio(input_dir, output_dir, target_format):if not os.path.exists(output_dir):os.makedirs(output_dir)for filename in os.listdir(input_dir):if filename.endswith((.mp3, .wav, .ogg)…...

基于Spring Boot的高校就业招聘系统的设计与实现(LW+源码+讲解)
专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…...

强化学习(赵世钰版)-学习笔记(4.值迭代与策略迭代)
本章是整个课程中,算法与方法的第一章,应该是最简单的入门方法。 上一章讲到了贝尔曼最优方程,其目的是计算出最优状态值,从而确定对应的最优策略。 而压缩映射理论推出了迭代算法 对初始值V0赋一个随机的初始值,算法最…...

Cursor安装配置
1.安装 通过网盘分享的文件:Cursor Setup 0.45.11 - x64.exe 链接: 百度网盘 请输入提取码 提取码: 6juv 2. 配置 选择AI工具的语言 输入AI工具的语言为 "中文" ,输入完语言之后,直接点击 "Continue" 下一步&#x…...

相机几何:从三维世界到二维图像的映射
本系列课程将带领读者开启一场独特的三维视觉工程之旅。我们不再止步于教科书式的公式推导,而是聚焦于如何将抽象的数学原理转化为可落地的工程实践。通过解剖相机的光学特性、构建成像数学模型、解析坐标系转换链条,直至亲手实现参数标定代码࿰…...

【GoTeams】-5:引入Docker
本文目录 1. Dokcer-compose回顾下Docker知识编写docker-compose.yaml运行docker 2. 部署go服务编写dockerfile 1. Dokcer-compose 这里简单先用一下win版本的Docker,后期开发好了部署的时候再移植到服务器下进行docker部署。 输入命令docker-compose version 就可…...
基金股票期权期货投资方式对比
以下是基金、股票、期权和期货的详细对比分析,涵盖定义、核心特点、优势、劣势、适用场景及相互区别: 一、基金 定义 基金是通过集合投资者的资金,由专业管理人(基金经理)进行多元化投资的金融工具。根据投资标的可分…...

大模型AI平台DeepSeek 眼中的SQL2API平台:QuickAPI、dbapi 和 Magic API 介绍与对比
目录 1 QuickAPI 介绍 2 dbapi 介绍 3 Magic API 介绍 4 简单对比 5 总结 统一数据服务平台是一种低代码的方式,实现一般是通过SQL能直接生成数据API,同时能对产生的数据API进行全生命周期的管理,典型的SQL2API的实现模式。 以下是针对…...

K8S学习之基础十九:k8s的四层代理Service
K8S四层代理Service 四层负载均衡Service 在k8s中,访问pod可以通过ip端口的方式,但是pod是由生命 周期的,pod在重启的时候ip地址往往会发生变化,访问pod就需要新的ip地址,这样就会很麻烦,每次pod地址改变就…...

揭开AI-OPS 的神秘面纱 第六讲 AI 模型服务层 - 开源模型选型与应用 (时间序列场景|图神经网络场景)
时间序列场景 AI 模型服务层 - 开源模型选型与应用 (时间序列场景) 在 AI-Ops 中,时间序列数据分析主要应用于以下场景: 指标预测: 预测 Metrics 指标 (例如 CPU 使用率、内存使用率、网络流量、请求延迟等) 的未来趋势,用于容量规划、资源调度、异常检测等。异常检测: 检…...
在Dify中访问Gemini等模型代理设置指南
问题背景 Google Gemini模型可纯免费使用,且性能也相当不错,一般个人使用或研究足够。但在在国内访问,需设置代理。在Docker部署Dify时,虽然按官方文档介绍设置代理环境变量,但实测发现并不生效。我们通过研究试验解决…...

MySQL的安装以及数据库的基本配置
MySQL的安装及配置 MySQL的下载 选择想要安装的版本,点击Download下载 Mysql官网下载地址: https://downloads.mysql.com/archives/installer/ MySQL的安装 选择是自定义安装,所以直接选择“Custom”,点击“Next” …...
设备树的组成
根节点下含有 compatile 属性的子节点 含有特定 compatile 属性的节点的子节点 如果一个节点的 compatile 属性,它的值是这 4 者之一:"simple-bus","simple-mfd","isa","arm,amba-bus", 那 么 它 的 子结点 (…...

C++入门——输入输出、缺省参数
C入门——输入输出、缺省参数 一、C标准库——命名空间 std C标准库std是一个命名空间,全称为"standard",其中包括标准模板库(STL),输入输出系统,文件系统库,智能指针与内存管理&am…...

deepseek 本地部署
deepseek 本地部署 纯新手教学,手把手5分钟带你在本地部署一个私有的deepseek,再也不用受网络影响。流畅使用deepseek!!! 如果不想看文章,指路:Deep seek R1本地部署 小白超详细教程 ࿰…...

[网络爬虫] 动态网页抓取 — Selenium 入门操作
🌟想系统化学习爬虫技术?看看这个:[数据抓取] Python 网络爬虫 - 学习手册-CSDN博客 0x01:WebDriver 类基础属性 & 方法 为模仿用户真实操作浏览器的基本过程,Selenium 的 WebDriver 模块提供了一个 WebDriver 类…...
HTML 超链接(简单易懂较详细)
在 HTML 中,超链接是通过 <a> 标签(anchor tag)创建的。超链接允许用户通过点击文本、图像或其他元素跳转到另一个网页、文件或页面的特定部分。本文将详细介绍 HTML 超链接的语法、属性和应用场景。 一、基本语法 <a href"U…...
rpc和proto
rpc全称远程过程控制,说白了是一种对信息发送和接收的规则编写方法,来自google,这些规则会以protobuf代码存到proto文件里。我以autoGen中agent_worker.proto为例,大概长这样 syntax "proto3";package agents;option …...

OPENGLPG第九版学习 -颜色、像素和片元 PART1
文章目录 4.1 基本颜色理论4.2 缓存及其用途颜色缓存深度缓存 / z缓存 / z-buffer模板缓存 4.2.1 缓存的清除4.2.2 缓存的掩码 4.3 颜色与OpenGL4.3.1 颜色的表达与OpenGL4.3.2 平滑数据插值 4.4 片元的测试与操作4.4.1 剪切测试4.4.2 多重采样的片元操作4.4.3 模板测试模板查询…...

wordpress后台更新后 前端没变化的解决方法
使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…...

7.4.分块查找
一.分块查找的算法思想: 1.实例: 以上述图片的顺序表为例, 该顺序表的数据元素从整体来看是乱序的,但如果把这些数据元素分成一块一块的小区间, 第一个区间[0,1]索引上的数据元素都是小于等于10的, 第二…...

【kafka】Golang实现分布式Masscan任务调度系统
要求: 输出两个程序,一个命令行程序(命令行参数用flag)和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽,然后将消息推送到kafka里面。 服务端程序: 从kafka消费者接收…...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销,平衡网络负载,延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

Python:操作 Excel 折叠
💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

通过Wrangler CLI在worker中创建数据库和表
官方使用文档:Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后,会在本地和远程创建数据库: npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库: 现在,您的Cloudfla…...
JVM垃圾回收机制全解析
Java虚拟机(JVM)中的垃圾收集器(Garbage Collector,简称GC)是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象,从而释放内存空间,避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...
python如何将word的doc另存为docx
将 DOCX 文件另存为 DOCX 格式(Python 实现) 在 Python 中,你可以使用 python-docx 库来操作 Word 文档。不过需要注意的是,.doc 是旧的 Word 格式,而 .docx 是新的基于 XML 的格式。python-docx 只能处理 .docx 格式…...

Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南
🚀 C extern 关键字深度解析:跨文件编程的终极指南 📅 更新时间:2025年6月5日 🏷️ 标签:C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言🔥一、extern 是什么?&…...