使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题
在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著。
1. 小尺寸图像如何加剧样本不均衡?
(1) 局部裁剪导致类别分布偏差
- 问题:遥感图像中某些类别(如道路、建筑)可能稀疏分布。小尺寸裁剪后,部分训练样本可能完全不含某些类别(例如一块纯农田的补丁),导致模型对这些类别缺乏学习机会。
- 示例:
- 原图中“道路”占比5%,若裁剪为
256x256的小图,部分小图中可能完全无道路像素。 - 极端情况下,某些类别可能仅在极少数小图中出现,形成“长尾分布”。
- 原图中“道路”占比5%,若裁剪为
(2) 批次内类别覆盖不足
- 问题:小尺寸图像的批训练(batch training)中,若单个批次内缺少某些类别,梯度更新会偏向多数类。
- 示例:若一个batch中80%的补丁以“植被”为主,模型会倾向于将模糊区域预测为植被。
(3) 像素级不平衡放大
- 问题:即使原图类别均衡,小尺寸裁剪可能导致局部像素比例失衡。
- 例如,原图中“水体”占10%,但某个小图中水体可能占90%(河流区域)或0%(干旱区域)。
2. 样本不均衡的典型影响
- 模型偏向多数类:对高频类别(如植被、背景)过拟合,低频类别(如车辆、道路)漏检。
- 边界模糊:模型对类别交界处的预测置信度低,导致分割边缘不连续。
- 评估指标失真:全局指标(如整体准确率)虚高,但关键类别(如灾害损毁区域)的IoU/F1值极低。
3. 针对小尺寸图像的解决方案
(1) 数据层面的优化
- 定向裁剪(Guided Cropping):
- 根据类别分布优先裁剪包含稀有类别的小图。
- 工具:使用滑动窗口统计每个候选补丁的类别比例,筛选包含目标类别的补丁。
- 过采样(Oversampling):
- 对包含稀有类别的小图增加采样概率。
- 例如:若某小图中含“道路”,则其在训练集中的出现次数增加3倍。
- 数据增强强化:
- 对小图中稀有类别区域进行针对性增强:
- 局部旋转、缩放、亮度调整(避免全局变换导致稀有目标失真)。
- 复制-粘贴增强(Copy-Paste):将稀有目标粘贴到其他背景中(如将车辆粘贴到农田补丁上)。
- 对小图中稀有类别区域进行针对性增强:
(2) 损失函数设计
- 加权交叉熵(Weighted Cross-Entropy):
- 根据类别像素频率反向加权,例如权重与类别频率成反比:
weight = 1 / (class_freq + epsilon) # 防止除零
- 根据类别像素频率反向加权,例如权重与类别频率成反比:
- Focal Loss:
- 抑制易分类样本(如背景)的损失贡献,聚焦难样本(如小目标):
loss = -α * (1 - p)^γ * log(p) # α平衡类别,γ聚焦难样本
- 抑制易分类样本(如背景)的损失贡献,聚焦难样本(如小目标):
- Dice Loss / Tversky Loss:
- 直接优化分割重叠区域(IoU),对类别不平衡更鲁棒:
Dice Loss = 1 - (2*|X∩Y|) / (|X| + |Y|) Tversky Loss = 1 - (|X∩Y|) / (|X∩Y| + α|X-Y| + β|Y-X|) # 调整α,β权衡假阳/假阴
- 直接优化分割重叠区域(IoU),对类别不平衡更鲁棒:
(3) 模型架构改进
- 上下文感知模块:
- 使用空洞卷积(Dilated Convolution)或注意力机制(如SE Block、Non-local Networks),增强模型对稀疏目标的捕捉能力。
- 多尺度特征融合:
- 通过金字塔池化(PSPNet)或U-Net++结构,融合不同尺度的特征,缓解因小尺寸输入丢失的上下文信息。
- 辅助监督(Auxiliary Supervision):
- 在中间层添加辅助损失函数,强制模型关注细粒度特征。
(4) 训练策略调整
- 小批次大迭代:
- 使用小batch size但增加迭代次数,确保稀有类别在多个epoch中被充分学习。
- 动态类别权重:
- 根据当前batch内的类别分布实时调整损失权重。
- 困难样本挖掘(Hard Example Mining):
- 在每个epoch后,筛选对稀有类别预测误差大的样本,下一轮训练中增加其采样概率。
4. 实验验证建议
- 监控类别指标:除了整体准确率,跟踪每个类别的IoU、F1-score。
- 可视化错误样本:检查模型在稀有类别上的失败案例,针对性优化数据或模型。
- 消融实验:对比不同损失函数、数据增强策略的效果。
小尺寸图像训练会放大样本不均衡问题,但通过定向数据采样、损失函数优化、模型结构改进三者结合,可显著缓解影响。关键是根据任务特点(如目标大小、类别分布)选择组合策略,例如:
- 稀疏小目标:Focal Loss + Copy-Paste增强 + 空洞卷积。
- 长尾分布:加权交叉熵 + 过采样 + 动态类别权重。
在 PyTorch 中,虽然没有直接解决语义分割样本不均衡的“万能模块”,但可以通过组合现有模块和社区成熟库高效实现解决方案。
1. 数据层面:加权采样与增强
(1) 加权随机采样(WeightedRandomSampler)
PyTorch 内置 WeightedRandomSampler,可对包含稀有类别的图像补丁过采样:
import numpy as npdef compute_weight_for_patch(patch):image, mask = patch# 假设 mask 是一个二维数组,每个像素值表示类别标签# 计算每个类别的像素数量class_counts = np.bincount(mask.flatten())# 计算总像素数量total_pixels = mask.size# 计算每个类别的比例class_ratios = class_counts / total_pixels# 计算所有类别的权重class_weights = 1.0 / (class_ratios + 1e-6) # 避免除以零,添加一个小的常数# 应用 sigmoid 函数class_weights = 1.0 / (1.0 + np.exp(-class_weights))# 计算样本的权重sample_weight = np.sum(class_weights)print("Total samples weights:", sample_weight)return class_weights
from torch.utils.data import WeightedRandomSampler# 假设 dataset 返回 (image, mask),且每个样本有一个权重 weight
weights = [compute_weight_for_patch(patch) for patch in dataset] # 根据补丁中稀有类别比例计算权重
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
(2) 数据增强库(Albumentations)
Albumentations 提供针对分割任务的增强,支持对特定类别区域增强:
import albumentations as Atransform = A.Compose([A.RandomCrop(256, 256),A.OneOf([A.RandomRotate90(),A.HorizontalFlip(),A.VerticalFlip()]),A.RandomBrightnessContrast(p=0.5),# 对特定类别区域增强(如仅增强“车辆”区域)A.RandomCropNearBBox(p=0.5, max_part_shift=0.3)
])
2. 损失函数:直接调用社区实现
(1) Focal Loss
使用 torchvision.ops 或第三方库:
# 使用 torchvision(需 0.10+ 版本)
from torchvision.ops import sigmoid_focal_lossloss = sigmoid_focal_loss(outputs, targets, alpha=0.25, gamma=2, reduction="mean")# 或自定义多类别 Focal Loss
class FocalLoss(nn.Module):def __init__(self, alpha=0.25, gamma=2):super().__init__()self.alpha = alphaself.gamma = gammadef forward(self, inputs, targets):ce_loss = F.cross_entropy(inputs, targets, reduction="none")pt = torch.exp(-ce_loss)loss = self.alpha * (1 - pt) ** self.gamma * ce_lossreturn loss.mean()
(2) Dice Loss
社区标准实现(或使用 segmentation_models_pytorch 库):
class DiceLoss(nn.Module):def __init__(self, smooth=1e-6):super().__init__()self.smooth = smoothdef forward(self, inputs, targets):inputs = F.softmax(inputs, dim=1)targets = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2)intersection = (inputs * targets).sum()union = inputs.sum() + targets.sum()dice = (2 * intersection + self.smooth) / (union + self.smooth)return 1 - dice
(3) 直接调用 segmentation_models_pytorch 损失函数
import segmentation_models_pytorch as smploss = smp.losses.DiceLoss(mode="multiclass", classes=[0, 1, 2]) # 指定关注类别
loss = smp.losses.FocalLoss(mode="multiclass", normalized=True) # 归一化版本
3. 模型层面:集成注意力与多尺度模块
(1) 使用预建模型库
segmentation_models_pytorch(SMP)提供即用的模型和模块:
import segmentation_models_pytorch as smpmodel = smp.Unet(encoder_name="resnet34",encoder_weights="imagenet",in_channels=3,classes=5,decoder_attention_type="scse", # 添加空间-通道注意力
)
(2) 空洞卷积(Dilated Convolution)
直接使用 PyTorch 的 Conv2d 实现:
class DilatedConvBlock(nn.Module):def __init__(self, in_channels, out_channels, dilation_rate=2):super().__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation_rate, dilation=dilation_rate)self.norm = nn.BatchNorm2d(out_channels)self.act = nn.ReLU()def forward(self, x):return self.act(self.norm(self.conv(x)))# 在 U-Net 的 decoder 中插入空洞卷积块
4. 类别权重计算工具
(1) 自动计算类别权重
from sklearn.utils.class_weight import compute_class_weight# 统计训练集所有像素的类别分布
class_counts = np.bincount(all_pixel_labels.flatten())
class_weights = compute_class_weight(class_weight="balanced", classes=np.arange(num_classes), y=all_pixel_labels.flatten()
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)# 在损失函数中使用
criterion = nn.CrossEntropyLoss(weight=class_weights)
5. 完整 Pipeline 示例
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import segmentation_models_pytorch as smp
import albumentations as A# 1. 定义数据集和采样器
dataset = YourDataset(transform=albumentations_transform)
weights = compute_patch_weights(dataset) # 根据补丁中目标类别比例计算
sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)# 2. 定义模型和损失
model = smp.Unet(encoder_name="resnet34", classes=5, decoder_attention_type="scse")
criterion = smp.losses.DiceLoss(mode="multiclass") + smp.losses.FocalLoss(mode="multiclass")# 3. 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(100):for images, masks in dataloader:outputs = model(images)loss = criterion(outputs, masks)loss.backward()optimizer.step()
关键工具总结
| 问题类型 | PyTorch 原生支持 | 推荐第三方库(直接调用) |
|---|---|---|
| 数据采样 | WeightedRandomSampler | Albumentations(定向增强) |
| 损失函数 | 自定义(需手写) | segmentation_models_pytorch.losses |
| 模型结构 | 手动添加模块(空洞卷积、注意力) | segmentation_models_pytorch 预建模型 |
| 类别权重计算 | sklearn.utils.class_weight | 内置自动统计工具(如 SMP 数据集类) |
注意事项
- 灵活组合策略:例如同时使用
WeightedRandomSampler和Focal Loss可能过度偏向少数类,需通过实验调整。 - 监控类别指标:使用
torchmetrics库计算每个类别的 IoU:from torchmetrics import JaccardIndex iou = JaccardIndex(num_classes=5, task="multiclass") iou.update(outputs, targets) print(f"IoU: {iou.compute()}") - 混合精度训练:使用
torch.cuda.amp加速训练,缓解显存压力:scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast():outputs = model(images)loss = criterion(outputs, masks) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
相关文章:
使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题
在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著。 1. 小尺寸图像如何加剧样本不均衡? (1) 局部裁剪导致类别分布偏差 问题:遥感图像中某些类别(如道路、建…...
0.91英寸OLED显示屏一种具有小尺寸、高分辨率、低功耗特性的显示器件
0.91英寸OLED显示屏是一种具有小尺寸、高分辨率、低功耗特性的显示器件。以下是对0.91英寸OLED显示屏的详细介绍: 一、基本参数 尺寸:0.91英寸分辨率:通常为128x32像素,意味着显示屏上有128列和32行的像素点,总共409…...
读书笔记--分布式服务架构对比及优势
本篇是在上一篇的基础上,主要对共享服务平台建设所依赖的分布式服务架构进行学习,主要记录和思考如下,供大家学习参考。随着企业各业务数字化转型工作的推进,之前在传统的单一系统(或单体应用)模式中&#…...
HTML5 新的 Input 类型详解
HTML5 引入了许多新的输入类型,极大地增强了表单的功能和用户体验。这些新的输入类型不仅提供了更好的输入控制,还支持内置的验证功能,减少了开发者手动编写验证逻辑的工作量。本文将全面介绍 HTML5 中新增的输入类型,并结合代码示…...
ESP32-CAM实验集(WebServer)
WebServer 效果图 已连接 web端 platformio.ini ; PlatformIO Project Configuration File ; ; Build options: build flags, source filter ; Upload options: custom upload port, speed and extra flags ; Library options: dependencies, extra library stor…...
Case逢无意难休——深度解析JAVA中case穿透问题
Case逢无意难休——深度解析JAVA中case穿透问题~ 不作溢美之词,不作浮夸文章,此文与功名进取毫不相关也!与大家共勉!! 更多文章:个人主页 系列文章:JAVA专栏 欢迎各位大佬来访哦~互三必回&#…...
Golang笔记——常用库context和runtime
大家好,这里是Good Note,关注 公主号:Goodnote,专栏文章私信限时Free。本文详细介绍Golang的常用库context和runtime,包括库的基本概念和基本函数的使用等。 文章目录 contextcontext 包的基本概念主要类型和函数1. **…...
2000-2020年各省第二产业增加值占GDP比重数据
2000-2020年各省第二产业增加值占GDP比重数据 1、时间:2000-2020年 2、来源:国家统计局、统计年鉴 3、指标:行政区划代码、地区名称、年份、第二产业增加值占GDP比重 4、范围:31省 5、指标解释:第二产业增加值占GDP比重…...
unity商店插件A* Pathfinding Project如何判断一个点是否在导航网格上?
需要使用NavGraph.IsPointOnNavmesh(Vector3 point) 如果点位于导航网的可步行部分,则为真。 如果一个点在可步行导航网表面之上或之下,在任何距离,如果它不在更近的不可步行节点之上 / 之下,则认为它在导航网上。 使用方法 Ast…...
Day24-【13003】短文,数据结构与算法开篇,什么是数据元素?数据结构有哪些类型?什么是抽象类型?
文章目录 13003数据结构与算法全书框架考试题型的分值分布如何? 本次内容概述绪论第一节概览什么是数据、数据元素,数据项,数据项的值?什么是数据结构?分哪两种集合形式(逻辑和存储)?…...
富文本 tinyMCE Vue2 组件使用简易教程
参考官方教程 TinyMCE Vue.js integration technical reference Vue2 项目需要使用 tinyMCE Vue2 组件(tinymce/tinymce-vue)的第 3 版 安装组件 npm install --save "tinymce/tinymce-vue^3" 编写组件调用 <template><Editorref"editor"v-m…...
强化学习在自动驾驶中的实现与挑战
强化学习在自动驾驶中的实现与挑战 自动驾驶技术作为当今人工智能领域的前沿之一,正通过各种方式改变我们的出行方式。而强化学习(Reinforcement Learning, RL),作为机器学习的一大分支,在自动驾驶的实现中扮演了至关重要的角色。它通过模仿人类驾驶员的决策过程,为车辆…...
记录 | MaxKB创建本地AI智能问答系统
目录 前言一、重建MaxKBStep1 复制路径Step2 删除MaxKBStep3 创建数据存储文件夹Step4 重建 二、创建知识库Step1 新建知识库Step2 下载测试所用的txtStep3 上传本地文档Step4 选择模型补充智谱的API Key如何获取 Step5 查看是否成功 三、创建应用Step1 新建应用Step2 配置AI助…...
特种作业操作之低压电工考试真题
1.下面( )属于顺磁性材料。 A. 铜 B. 水 C. 空气 答案:C 2.事故照明一般采用( )。 A. 日光灯 B. 白炽灯 C. 压汞灯 答案:B 3.人体同时接触带电设备或线路中的两相导体时,电流从一相通过人体流…...
[免费]基于Python的Django博客系统【论文+源码+SQL脚本】
大家好,我是java1234_小锋老师,看到一个不错的基于Python的Django博客系统,分享下哈。 项目视频演示 【免费】基于Python的Django博客系统 Python毕业设计_哔哩哔哩_bilibili 项目介绍 随着互联网技术的飞速发展,信息的传播与…...
Cannot resolve symbol ‘XXX‘ Maven 依赖问题的解决过程
一、问题描述 在使用 Maven 管理项目依赖时,遇到了一个棘手的问题。具体表现为:在 pom.xml 文件中导入了所需的依赖,并且在 IDE 中导入语句没有显示为红色(表示 IDE 没有提示依赖缺失),但是在实际使用这些依…...
我们需要有哪些知识体系,知识体系里面要有什么哪些内容?
01、管理知识体系的学习知识体系 主要内容: 1、知识管理框架的外部借鉴、和自身知识体系的搭建; 2、学习能力、思维逻辑能力等的塑造; 3、知识管理工具的使用; 4、学习资料的导入和查找资料的渠道; 5、深层关键的…...
什么是vue.js组件开发,我们需要做哪些准备工作?
Vue.js 是一个非常流行的前端框架,用于构建用户界面。组件开发是 Vue.js 的核心概念之一,通过将界面拆分为独立的组件,可以提高代码的可维护性和复用性。以下是一个详细的 Vue.js 组件开发指南,包括基础概念、开发流程和代码示例。 一、Vue.js 组件开发基础 1. 组件的基本…...
网络工程师 (3)指令系统基础
一、寻址方式 (一)指令寻址 顺序寻址:通过程序计数器(PC)加1,自动形成下一条指令的地址。这是计算机中最基本、最常用的寻址方式。 跳跃寻址:通过转移类指令直接或间接给出下一条指令的地址。跳…...
第4章 神经网络【1】——损失函数
4.1.从数据中学习 实际的神经网络中,参数的数量成千上万,因此,需要由数据自动决定权重参数的值。 4.1.1.数据驱动 数据是机器学习的核心。 我们的目标是要提取出特征量,特征量指的是从输入数据/图像中提取出的本质的数 …...
ES6从入门到精通:前言
ES6简介 ES6(ECMAScript 2015)是JavaScript语言的重大更新,引入了许多新特性,包括语法糖、新数据类型、模块化支持等,显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var…...
电脑插入多块移动硬盘后经常出现卡顿和蓝屏
当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...
CMake控制VS2022项目文件分组
我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...
中医有效性探讨
文章目录 西医是如何发展到以生物化学为药理基础的现代医学?传统医学奠基期(远古 - 17 世纪)近代医学转型期(17 世纪 - 19 世纪末)现代医学成熟期(20世纪至今) 中医的源远流长和一脉相承远古至…...
网站指纹识别
网站指纹识别 网站的最基本组成:服务器(操作系统)、中间件(web容器)、脚本语言、数据厍 为什么要了解这些?举个例子:发现了一个文件读取漏洞,我们需要读/etc/passwd,如…...
GitFlow 工作模式(详解)
今天再学项目的过程中遇到使用gitflow模式管理代码,因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存,无论是github还是gittee,都是一种基于git去保存代码的形式,这样保存代码…...
基于IDIG-GAN的小样本电机轴承故障诊断
目录 🔍 核心问题 一、IDIG-GAN模型原理 1. 整体架构 2. 核心创新点 (1) 梯度归一化(Gradient Normalization) (2) 判别器梯度间隙正则化(Discriminator Gradient Gap Regularization) (3) 自注意力机制(Self-Attention) 3. 完整损失函数 二…...
莫兰迪高级灰总结计划简约商务通用PPT模版
莫兰迪高级灰总结计划简约商务通用PPT模版,莫兰迪调色板清新简约工作汇报PPT模版,莫兰迪时尚风极简设计PPT模版,大学生毕业论文答辩PPT模版,莫兰迪配色总结计划简约商务通用PPT模版,莫兰迪商务汇报PPT模版,…...
Web后端基础(基础知识)
BS架构:Browser/Server,浏览器/服务器架构模式。客户端只需要浏览器,应用程序的逻辑和数据都存储在服务端。 优点:维护方便缺点:体验一般 CS架构:Client/Server,客户端/服务器架构模式。需要单独…...
给网站添加live2d看板娘
给网站添加live2d看板娘 参考文献: stevenjoezhang/live2d-widget: 把萌萌哒的看板娘抱回家 (ノ≧∇≦)ノ | Live2D widget for web platformEikanya/Live2d-model: Live2d model collectionzenghongtu/live2d-model-assets 前言 网站环境如下,文章也主…...
