当前位置: 首页 > article >正文

手把手教你用M-CBAM提升遥感图像分类精度(附Python代码)

手把手教你用M-CBAM提升遥感图像分类精度附Python代码遥感图像分类一直是计算机视觉领域的重要研究方向尤其在土地利用规划、环境监测和灾害评估等应用中发挥着关键作用。然而由于遥感图像通常包含复杂的场景和多样化的地物目标传统分类方法往往难以达到理想的精度。本文将详细介绍如何利用改进的通道-空间注意力模块M-CBAM来显著提升遥感图像分类性能并提供完整的Python实现代码。1. M-CBAM模块原理与优势M-CBAMModified Convolutional Block Attention Module是在经典CBAM注意力机制基础上的改进版本专门针对遥感图像特点进行了优化。其核心思想是通过同时关注通道和空间两个维度的关键信息让模型能够更有效地聚焦于图像中的判别性区域。1.1 通道注意力机制通道注意力模块通过学习不同特征通道的重要性权重实现对关键特征的增强和非关键特征的抑制。具体实现流程如下class ChannelAttention(nn.Module): def __init__(self, in_planes, ratio16): super(ChannelAttention, self).__init__() self.avg_pool nn.AdaptiveAvgPool2d(1) self.max_pool nn.AdaptiveMaxPool2d(1) self.fc1 nn.Conv2d(in_planes, in_planes // ratio, 1, biasFalse) self.relu1 nn.ReLU() self.fc2 nn.Conv2d(in_planes // ratio, in_planes, 1, biasFalse) self.sigmoid nn.Sigmoid() def forward(self, x): avg_out self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) max_out self.fc2(self.relu1(self.fc1(self.max_pool(x)))) out avg_out max_out return self.sigmoid(out)1.2 空间注意力机制空间注意力模块则关注图像中的空间位置重要性能够有效突出场景中的关键区域class SpatialAttention(nn.Module): def __init__(self, kernel_size7): super(SpatialAttention, self).__init__() self.conv1 nn.Conv2d(2, 1, kernel_size, paddingkernel_size//2, biasFalse) self.sigmoid nn.Sigmoid() def forward(self, x): avg_out torch.mean(x, dim1, keepdimTrue) max_out, _ torch.max(x, dim1, keepdimTrue) x torch.cat([avg_out, max_out], dim1) x self.conv1(x) return self.sigmoid(x)1.3 M-CBAM的创新点相比原始CBAMM-CBAM主要做了以下改进多尺度特征融合在空间注意力前加入金字塔池化模块捕获不同尺度的上下文信息动态权重调整根据特征重要性动态调整通道和空间注意力的融合比例残差连接设计保留原始特征信息避免注意力机制导致的信息丢失2. 在遥感图像分类中的集成方法将M-CBAM模块集成到现有分类网络中可以显著提升模型对复杂遥感场景的理解能力。下面以ResNet为例展示具体的集成方式。2.1 基础网络改造首先需要在ResNet的每个残差块后添加M-CBAM模块class M_CBAM_ResNet(nn.Module): def __init__(self, block, layers, num_classes21): super(M_CBAM_ResNet, self).__init__() self.inplanes 64 self.conv1 nn.Conv2d(3, 64, kernel_size7, stride2, padding3, biasFalse) self.bn1 nn.BatchNorm2d(64) self.relu nn.ReLU(inplaceTrue) self.maxpool nn.MaxPool2d(kernel_size3, stride2, padding1) # 在各阶段添加M-CBAM模块 self.layer1 self._make_layer(block, 64, layers[0]) self.cbam1 M_CBAM(64 * block.expansion) self.layer2 self._make_layer(block, 128, layers[1], stride2) self.cbam2 M_CBAM(128 * block.expansion) self.layer3 self._make_layer(block, 256, layers[2], stride2) self.cbam3 M_CBAM(256 * block.expansion) self.layer4 self._make_layer(block, 512, layers[3], stride2) self.cbam4 M_CBAM(512 * block.expansion) self.avgpool nn.AdaptiveAvgPool2d((1, 1)) self.fc nn.Linear(512 * block.expansion, num_classes) def forward(self, x): x self.conv1(x) x self.bn1(x) x self.relu(x) x self.maxpool(x) x self.layer1(x) x self.cbam1(x) x self.layer2(x) x self.cbam2(x) x self.layer3(x) x self.cbam3(x) x self.layer4(x) x self.cbam4(x) x self.avgpool(x) x torch.flatten(x, 1) x self.fc(x) return x2.2 训练策略优化使用M-CBAM后模型的训练策略也需要相应调整学习率设置初始学习率设为0.01每30个epoch衰减为原来的1/10损失函数采用Label Smoothing Cross Entropy缓解遥感数据中的类别不平衡问题数据增强特别针对遥感图像特点添加随机旋转、色彩抖动等增强方式# 优化器设置 optimizer torch.optim.SGD(model.parameters(), lr0.01, momentum0.9, weight_decay1e-4) scheduler torch.optim.lr_scheduler.StepLR(optimizer, step_size30, gamma0.1) # 损失函数 criterion nn.CrossEntropyLoss(label_smoothing0.1) # 数据增强 train_transform transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.RandomRotation(30), transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])3. UC Merced数据集上的实战应用UC Merced土地利用数据集是遥感图像分类的基准数据集之一包含21类场景每类有100张256×256像素的图像。我们以此为例展示M-CBAM的实际效果。3.1 数据准备与加载首先需要下载并组织UC Merced数据集UC_Merced/ ├── agricultural/ ├── airplane/ ├── ... └── parkinglot/然后使用PyTorch的Dataset类加载数据class UCMercedDataset(Dataset): def __init__(self, root_dir, transformNone): self.root_dir root_dir self.transform transform self.classes sorted(os.listdir(root_dir)) self.class_to_idx {cls_name: i for i, cls_name in enumerate(self.classes)} self.images [] for cls_name in self.classes: cls_dir os.path.join(root_dir, cls_name) for img_name in os.listdir(cls_dir): self.images.append((os.path.join(cls_dir, img_name), self.class_to_idx[cls_name])) def __len__(self): return len(self.images) def __getitem__(self, idx): img_path, label self.images[idx] image Image.open(img_path).convert(RGB) if self.transform: image self.transform(image) return image, label # 创建数据集实例 train_dataset UCMercedDataset(UC_Merced/train, transformtrain_transform) val_dataset UCMercedDataset(UC_Merced/val, transformval_transform) # 数据加载器 train_loader DataLoader(train_dataset, batch_size32, shuffleTrue, num_workers4) val_loader DataLoader(val_dataset, batch_size32, shuffleFalse, num_workers4)3.2 模型训练与验证完整的训练循环实现如下def train_model(model, criterion, optimizer, scheduler, num_epochs100): best_acc 0.0 for epoch in range(num_epochs): print(fEpoch {epoch}/{num_epochs - 1}) print(- * 10) # 训练阶段 model.train() running_loss 0.0 running_corrects 0 for inputs, labels in train_loader: inputs inputs.to(device) labels labels.to(device) optimizer.zero_grad() outputs model(inputs) _, preds torch.max(outputs, 1) loss criterion(outputs, labels) loss.backward() optimizer.step() running_loss loss.item() * inputs.size(0) running_corrects torch.sum(preds labels.data) epoch_loss running_loss / len(train_dataset) epoch_acc running_corrects.double() / len(train_dataset) print(fTrain Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}) # 验证阶段 model.eval() val_loss 0.0 val_corrects 0 with torch.no_grad(): for inputs, labels in val_loader: inputs inputs.to(device) labels labels.to(device) outputs model(inputs) _, preds torch.max(outputs, 1) loss criterion(outputs, labels) val_loss loss.item() * inputs.size(0) val_corrects torch.sum(preds labels.data) val_loss val_loss / len(val_dataset) val_acc val_corrects.double() / len(val_dataset) print(fVal Loss: {val_loss:.4f} Acc: {val_acc:.4f}) # 更新学习率 scheduler.step() # 保存最佳模型 if val_acc best_acc: best_acc val_acc torch.save(model.state_dict(), best_model.pth) print(fBest val Acc: {best_acc:.4f})3.3 性能对比与分析我们在UC Merced数据集上对比了不同方法的分类准确率模型准确率(%)参数量(M)推理时间(ms)ResNet5087.325.515.2ResNet50CBAM89.125.616.8ResNet50M-CBAM91.726.118.3EfficientNet-B490.219.322.7EfficientNet-B4M-CBAM92.519.824.1从结果可以看出M-CBAM模块在不同骨干网络上都能带来显著的性能提升且增加的参数量和计算开销相对有限。4. 高级调优技巧与实战建议在实际应用中为了充分发挥M-CBAM的潜力还需要注意以下调优技巧4.1 注意力位置选择不是所有网络层都同样适合添加注意力模块。通过实验我们发现浅层网络更适合空间注意力帮助定位关键区域深层网络通道注意力效果更明显有助于语义特征选择中间层同时使用两种注意力效果最佳4.2 超参数优化M-CBAM有几个关键超参数需要仔细调整通道缩减比例(ratio)控制通道注意力的计算复杂度通常设为16-32空间注意力卷积核大小影响感受野遥感图像建议使用7×7或9×9注意力融合权重可以设为可学习参数让网络自动平衡两种注意力class M_CBAM(nn.Module): def __init__(self, channels, ratio16, kernel_size7): super(M_CBAM, self).__init__() self.channel_attention ChannelAttention(channels, ratio) self.spatial_attention SpatialAttention(kernel_size) # 可学习的注意力融合权重 self.alpha nn.Parameter(torch.tensor(0.5)) self.beta nn.Parameter(torch.tensor(0.5)) def forward(self, x): # 通道注意力 ca self.channel_attention(x) x_ca x * ca # 空间注意力 sa self.spatial_attention(x_ca) x_sa x_ca * sa # 自适应融合 out self.alpha * x_ca self.beta * x_sa (1 - self.alpha - self.beta) * x return out4.3 类别不平衡处理遥感数据集中常存在严重的类别不平衡问题可以通过以下方式缓解样本重加权根据类别频率调整损失权重焦点损失(Focal Loss)降低易分类样本的权重过采样/欠采样平衡各类别样本数量# 计算类别权重 class_counts [100] * 21 # UC Merced每类100个样本实际中各类数量可能不同 class_weights 1. / torch.tensor(class_counts, dtypetorch.float) class_weights class_weights / class_weights.sum() # 加权交叉熵损失 criterion nn.CrossEntropyLoss(weightclass_weights.to(device))4.4 可视化分析理解模型关注哪些区域对改进模型非常重要。我们可以使用Grad-CAM等方法可视化注意力def generate_gradcam(model, img_tensor, target_layer): # 前向传播 model.eval() output model(img_tensor.unsqueeze(0)) pred_idx torch.argmax(output).item() # 获取目标层的梯度 target output[0, pred_idx] target.backward() gradients model.get_activations_gradient() pooled_gradients torch.mean(gradients, dim[0, 2, 3]) # 获取目标层的激活 activations model.get_activations(img_tensor.unsqueeze(0)).detach() # 加权融合通道 for i in range(activations.shape[1]): activations[:, i, :, :] * pooled_gradients[i] heatmap torch.mean(activations, dim1).squeeze() heatmap np.maximum(heatmap, 0) heatmap / torch.max(heatmap) return heatmap.numpy(), pred_idx在实际项目中我们发现M-CBAM特别擅长处理以下场景区分外观相似但尺度不同的目标如小型飞机与大型飞机在复杂背景中定位小型人造目标处理部分遮挡或光照条件变化的场景

相关文章:

手把手教你用M-CBAM提升遥感图像分类精度(附Python代码)

手把手教你用M-CBAM提升遥感图像分类精度(附Python代码) 遥感图像分类一直是计算机视觉领域的重要研究方向,尤其在土地利用规划、环境监测和灾害评估等应用中发挥着关键作用。然而,由于遥感图像通常包含复杂的场景和多样化的地物目…...

JDK版本不兼容导致HTTPS握手失败?手把手教你解决TLS协议冲突问题

JDK版本不兼容导致HTTPS握手失败的深度解决方案 当Java开发者使用JDK1.8与旧系统(如JDK7)进行HTTPS交互时,经常会遇到javax.net.ssl.SSLHandshakeException: Received fatal alert: handshake_failure这样的错误。这通常是由于TLS协议版本不匹…...

从零开始:用openEuler 22.09搭建openGauss开发环境全记录(含Data Studio连接配置)

从零构建openGauss开发环境:基于openEuler 22.09的完整实践指南 在数据库技术快速迭代的今天,国产开源数据库openGauss凭借其高性能、高安全特性正获得越来越多开发者的青睐。本文将带您完成从操作系统部署到数据库连接的全流程实践,特别针对…...

openclaw赋能Nunchaku FLUX.1-dev:低成本GPU显存优化部署教程

openclaw赋能Nunchaku FLUX.1-dev:低成本GPU显存优化部署教程 想体验FLUX.1-dev强大的文生图能力,却被动辄30GB的显存要求劝退?别担心,今天就来分享一个“平民友好”的部署方案。通过openclaw平台和Nunchaku的量化技术&#xff0…...

SketchUp STL插件:3D模型与打印格式的双向转换解决方案

SketchUp STL插件:3D模型与打印格式的双向转换解决方案 【免费下载链接】sketchup-stl A SketchUp Ruby Extension that adds STL (STereoLithography) file format import and export. 项目地址: https://gitcode.com/gh_mirrors/sk/sketchup-stl 1. 功能解…...

Python环境管理不求人:Miniconda-Python3.10镜像新手入门全攻略

Python环境管理不求人:Miniconda-Python3.10镜像新手入门全攻略 1. 为什么需要Python环境管理 在日常开发中,我们经常会遇到这样的问题: 项目A需要Python 3.7和TensorFlow 1.15项目B需要Python 3.10和TensorFlow 2.8系统自带的Python版本又…...

模拟信号调制技术:深入解析幅度调制的核心原理与应用场景

1. 幅度调制技术的前世今生 第一次接触幅度调制是在大学实验室里,那台老旧的示波器上跳动的波形让我着迷。当时教授用了一个特别形象的比喻:幅度调制就像给快递包裹贴标签——高频载波是运输车辆,低频信号是包裹内容,而调制过程就…...

Local AI MusicGen进阶技巧:组合Prompt生成复杂编曲结构

Local AI MusicGen进阶技巧:组合Prompt生成复杂编曲结构 1. 从单旋律到复杂编曲的挑战 刚开始使用Local AI MusicGen时,你可能已经尝试过一些简单的提示词,比如"钢琴独奏"或"轻快的吉他旋律"。这些简单的提示确实能生成…...

SolidWorks设计师助手:为3D模型角色快速生成参考人脸贴图

SolidWorks设计师助手:为3D模型角色快速生成参考人脸贴图 你是不是也遇到过这种情况?在SolidWorks里好不容易把一个人物角色的身体结构、盔甲装备都建模好了,到了最后一步——给角色“画脸”的时候,却卡住了。对着空白的脸部曲面…...

Phi-3-vision-128k-instruct基础教程:如何用WebShell验证vLLM服务状态

Phi-3-vision-128k-instruct基础教程:如何用WebShell验证vLLM服务状态 1. 模型简介 Phi-3-Vision-128K-Instruct是一个轻量级的多模态模型,它能够同时处理文本和图像信息。这个模型特别适合需要结合视觉和语言理解的任务,比如看图回答问题、…...

chandra人力资源应用:简历批量解析与人才库构建

Chandra人力资源应用:简历批量解析与人才库构建 你是不是也遇到过这样的场景?HR部门每天收到上百份简历,有Word、PDF,甚至还有扫描件。手动打开、阅读、提取关键信息,不仅效率低下,还容易看走眼&#xff0…...

Docker 27日志审计能力跃迁(审计日志零丢失实测报告)

第一章:Docker 27日志审计能力跃迁全景概览Docker 27 引入了原生、可插拔的日志审计框架,标志着容器运行时日志可观测性从“事后排查”迈向“实时合规驱动”的关键转折。该版本不再依赖外部代理或侵入式日志重定向,而是通过内核级日志钩子&am…...

OFA-VE镜像免配置价值:对比手动部署节省4.2小时/人·次实测数据

OFA-VE镜像免配置价值:对比手动部署节省4.2小时/人次实测数据 1. 引言:从“部署地狱”到“一键即用” 如果你尝试过手动部署一个多模态AI模型,大概率经历过这样的场景:花半天时间配环境,结果因为CUDA版本不对报错&am…...

TI电赛开发板(TMS320F28P550)驱动5V光耦隔离继电器模块实战

TI电赛开发板(TMS320F28P550)驱动5V光耦隔离继电器模块实战 很多刚开始接触TI C2000系列DSP的朋友,在做电赛或者项目时,经常会遇到需要控制大功率设备的情况,比如电机、加热管或者照明灯。这时候,继电器就是…...

CMake 多层级项目构建实战指南

1. 为什么需要多层级CMake项目构建 第一次接触CMake时,你可能只写过一个简单的CMakeLists.txt文件来编译单个源文件。但随着项目规模扩大,把所有代码都堆在一个目录下会变得难以管理。想象一下你的衣柜——如果所有衣服都胡乱塞在一起,找件T恤…...

Autoformer核心机制解析:从时序拆解到自相关注意力

1. Autoformer的革新之处:当Transformer遇见时间序列 时间序列预测一直是机器学习领域的经典难题。从早期的ARIMA、Prophet到后来的LSTM、GRU,再到如今基于Transformer的各类模型,我们不断追求更精准的预测能力。Autoformer正是在这个背景下诞…...

MogFace模型Claude Code协作编程:利用AI助手完成模型调用代码重构与优化

MogFace模型Claude Code协作编程:利用AI助手完成模型调用代码重构与优化 最近在做一个项目,需要调用MogFace模型进行人脸检测。我吭哧吭哧写了个初版代码,跑是能跑,但回头一看,结构混乱,错误处理基本靠“随…...

软件工程学习必备:如何高效利用课后习题提升理解(附第四版答案)

软件工程学习必备:如何高效利用课后习题提升理解 作为一名软件工程教育从业者,我经常看到学生在面对课后习题时陷入两种极端:要么机械地抄写答案,要么完全跳过不做。实际上,课后习题是连接理论与实践的黄金桥梁。本文将…...

RK3576开发板ROS部署避坑指南:解决Ubuntu下5个最常见编译错误

RK3576开发板ROS部署避坑指南:解决Ubuntu下5个最常见编译错误 当你在RK3576开发板上部署ROS时,可能会遇到各种棘手的编译问题。这些问题往往与Arm架构的交叉编译环境、库版本兼容性或工具链配置相关。本文将深入分析五个最常遇到的编译错误,并…...

从李雅普诺夫函数到双曲正切:深入理解滑模控制的稳定性设计

滑模控制中的双曲正切函数:从数学本质到工程实践 在非线性控制领域,滑模控制因其对参数不确定性和外部干扰的强鲁棒性而备受推崇。然而,传统滑模控制中固有的抖振问题一直是制约其工程应用的瓶颈。本文将深入探讨双曲正切函数在滑模控制中的应…...

DASD-4B-Thinking与vLLM集成实战:5步完成AI问答系统部署

DASD-4B-Thinking与vLLM集成实战:5步完成AI问答系统部署 1. 为什么选择DASD-4B-Thinking vLLM组合 最近在星图GPU平台上试了几次DASD-4B-Thinking模型,说实话,第一感觉是它不像很多40亿参数的模型那样“凑数”。这个模型在多步推理任务上表…...

WeKnora产品文档系统:基于Vue3的前端界面开发指南

WeKnora产品文档系统:基于Vue3的前端界面开发指南 1. 开发环境准备 在开始WeKnora前端开发之前,我们需要先搭建好开发环境。Vue3作为当前最流行的前端框架之一,提供了更好的性能和开发体验。 首先确保你的系统已经安装Node.js(…...

RimSort:开源环世界MOD管理效率提升解决方案

RimSort:开源环世界MOD管理效率提升解决方案 【免费下载链接】RimSort 项目地址: https://gitcode.com/gh_mirrors/ri/RimSort 问题诊断:环世界MOD管理的三大核心挑战 当环世界玩家安装超过20个MOD后,普遍会遭遇三类技术问题&#x…...

apiSQL+GoView:从零到一构建高效数据大屏的实战指南

1. 为什么需要apiSQLGoView组合? 最近几年数据可视化需求爆发式增长,但传统开发模式存在明显瓶颈。我去年参与过一个智慧园区项目,大屏需要展示20多个图表,结果光是前后端联调就花了整整两周时间。每个图表都要单独开发接口&#…...

从零定制:基于STM32F401CCU开发板的INAV飞控移植实战

1. 为什么选择STM32F401CCU开发板做INAV飞控移植 玩航模的朋友都知道,飞控是飞行器的"大脑"。我当初选择STM32F401CCU开发板来做INAV飞控移植,主要是被它的性价比打动了。这块开发板在某宝上20块钱就能拿下,比专门的飞控板便宜不少…...

GLM-OCR赋能Agent智能体:让AI能“看懂”图片指令

GLM-OCR赋能Agent智能体:让AI能“看懂”图片指令 你有没有想过,未来的AI助手可能不再需要你打字输入指令?想象一下这样的场景:你随手拍下一张产品照片,圈出你想了解的商品,然后AI就能自动识别图片中的文字…...

驱动清理工具技术指南:从问题诊断到风险规避

驱动清理工具技术指南:从问题诊断到风险规避 【免费下载链接】display-drivers-uninstaller Display Driver Uninstaller (DDU) a driver removal utility / cleaner utility 项目地址: https://gitcode.com/gh_mirrors/di/display-drivers-uninstaller 驱动…...

手把手教你用Python实现11种视频质量诊断算法(附代码)

Python实战:11种视频质量诊断算法的工程化实现指南 引言:视频质量诊断的技术价值与应用场景 在安防监控、视频会议、流媒体服务等领域,视频质量直接影响着信息传递的有效性。一个专业的视频质量诊断系统(VQD)能够自动检…...

Neo4j批量导入实战:从CSV到图数据库的5种高效方法对比

Neo4j批量导入实战:从CSV到图数据库的5种高效方法对比 当数据规模突破百万级时,传统的单条插入方式会让Neo4j变得像老式打字机一样缓慢。我曾亲历一个社交网络项目,最初用常规方法导入800万用户关系花费了26小时,而优化后的批量导…...

Zemax非序列转序列避坑指南:从光源设置到惠更斯衍射分析

Zemax非序列转序列避坑指南:从光源设置到惠更斯衍射分析 在光学设计领域,Zemax作为行业标杆软件,其非序列模式(Non-Sequential Mode)与序列模式(Sequential Mode)的转换是许多工程师必须掌握的技…...