当前位置: 首页 > news >正文

使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题

在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著


1. 小尺寸图像如何加剧样本不均衡?

(1) 局部裁剪导致类别分布偏差
  • 问题:遥感图像中某些类别(如道路、建筑)可能稀疏分布。小尺寸裁剪后,部分训练样本可能完全不含某些类别(例如一块纯农田的补丁),导致模型对这些类别缺乏学习机会。
  • 示例
    • 原图中“道路”占比5%,若裁剪为 256x256 的小图,部分小图中可能完全无道路像素。
    • 极端情况下,某些类别可能仅在极少数小图中出现,形成“长尾分布”。
(2) 批次内类别覆盖不足
  • 问题:小尺寸图像的批训练(batch training)中,若单个批次内缺少某些类别,梯度更新会偏向多数类。
  • 示例:若一个batch中80%的补丁以“植被”为主,模型会倾向于将模糊区域预测为植被。
(3) 像素级不平衡放大
  • 问题:即使原图类别均衡,小尺寸裁剪可能导致局部像素比例失衡。
    • 例如,原图中“水体”占10%,但某个小图中水体可能占90%(河流区域)或0%(干旱区域)。

2. 样本不均衡的典型影响

  • 模型偏向多数类:对高频类别(如植被、背景)过拟合,低频类别(如车辆、道路)漏检。
  • 边界模糊:模型对类别交界处的预测置信度低,导致分割边缘不连续。
  • 评估指标失真:全局指标(如整体准确率)虚高,但关键类别(如灾害损毁区域)的IoU/F1值极低。

3. 针对小尺寸图像的解决方案

(1) 数据层面的优化
  • 定向裁剪(Guided Cropping)
    • 根据类别分布优先裁剪包含稀有类别的小图。
    • 工具:使用滑动窗口统计每个候选补丁的类别比例,筛选包含目标类别的补丁。
  • 过采样(Oversampling)
    • 对包含稀有类别的小图增加采样概率。
    • 例如:若某小图中含“道路”,则其在训练集中的出现次数增加3倍。
  • 数据增强强化
    • 对小图中稀有类别区域进行针对性增强:
      • 局部旋转、缩放、亮度调整(避免全局变换导致稀有目标失真)。
      • 复制-粘贴增强(Copy-Paste):将稀有目标粘贴到其他背景中(如将车辆粘贴到农田补丁上)。
(2) 损失函数设计
  • 加权交叉熵(Weighted Cross-Entropy)
    • 根据类别像素频率反向加权,例如权重与类别频率成反比:
      weight = 1 / (class_freq + epsilon)  # 防止除零
      
  • Focal Loss
    • 抑制易分类样本(如背景)的损失贡献,聚焦难样本(如小目标):
      loss = -α * (1 - p)^γ * log(p)  # α平衡类别,γ聚焦难样本
      
  • Dice Loss / Tversky Loss
    • 直接优化分割重叠区域(IoU),对类别不平衡更鲁棒:
      Dice Loss = 1 - (2*|X∩Y|) / (|X| + |Y|)
      Tversky Loss = 1 - (|X∩Y|) / (|X∩Y| + α|X-Y| + β|Y-X|)  # 调整α,β权衡假阳/假阴
      
(3) 模型架构改进
  • 上下文感知模块
    • 使用空洞卷积(Dilated Convolution)或注意力机制(如SE Block、Non-local Networks),增强模型对稀疏目标的捕捉能力。
  • 多尺度特征融合
    • 通过金字塔池化(PSPNet)或U-Net++结构,融合不同尺度的特征,缓解因小尺寸输入丢失的上下文信息。
  • 辅助监督(Auxiliary Supervision)
    • 在中间层添加辅助损失函数,强制模型关注细粒度特征。
(4) 训练策略调整
  • 小批次大迭代
    • 使用小batch size但增加迭代次数,确保稀有类别在多个epoch中被充分学习。
  • 动态类别权重
    • 根据当前batch内的类别分布实时调整损失权重。
  • 困难样本挖掘(Hard Example Mining)
    • 在每个epoch后,筛选对稀有类别预测误差大的样本,下一轮训练中增加其采样概率。

4. 实验验证建议

  • 监控类别指标:除了整体准确率,跟踪每个类别的IoU、F1-score。
  • 可视化错误样本:检查模型在稀有类别上的失败案例,针对性优化数据或模型。
  • 消融实验:对比不同损失函数、数据增强策略的效果。

小尺寸图像训练会放大样本不均衡问题,但通过定向数据采样、损失函数优化、模型结构改进三者结合,可显著缓解影响。关键是根据任务特点(如目标大小、类别分布)选择组合策略,例如:

  • 稀疏小目标:Focal Loss + Copy-Paste增强 + 空洞卷积。
  • 长尾分布:加权交叉熵 + 过采样 + 动态类别权重。

在 PyTorch 中,虽然没有直接解决语义分割样本不均衡的“万能模块”,但可以通过组合现有模块社区成熟库高效实现解决方案。


1. 数据层面:加权采样与增强

(1) 加权随机采样(WeightedRandomSampler)

PyTorch 内置 WeightedRandomSampler,可对包含稀有类别的图像补丁过采样:

import numpy as npdef compute_weight_for_patch(patch):image, mask = patch# 假设 mask 是一个二维数组,每个像素值表示类别标签# 计算每个类别的像素数量class_counts = np.bincount(mask.flatten())# 计算总像素数量total_pixels = mask.size# 计算每个类别的比例class_ratios = class_counts / total_pixels# 计算所有类别的权重class_weights = 1.0 / (class_ratios + 1e-6)  # 避免除以零,添加一个小的常数# 应用 sigmoid 函数class_weights = 1.0 / (1.0 + np.exp(-class_weights))# 计算样本的权重sample_weight = np.sum(class_weights)print("Total samples weights:", sample_weight)return class_weights
from torch.utils.data import WeightedRandomSampler# 假设 dataset 返回 (image, mask),且每个样本有一个权重 weight
weights = [compute_weight_for_patch(patch) for patch in dataset]  # 根据补丁中稀有类别比例计算权重
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
(2) 数据增强库(Albumentations)

Albumentations 提供针对分割任务的增强,支持对特定类别区域增强:

import albumentations as Atransform = A.Compose([A.RandomCrop(256, 256),A.OneOf([A.RandomRotate90(),A.HorizontalFlip(),A.VerticalFlip()]),A.RandomBrightnessContrast(p=0.5),# 对特定类别区域增强(如仅增强“车辆”区域)A.RandomCropNearBBox(p=0.5, max_part_shift=0.3)
])

2. 损失函数:直接调用社区实现

(1) Focal Loss

使用 torchvision.ops 或第三方库:

# 使用 torchvision(需 0.10+ 版本)
from torchvision.ops import sigmoid_focal_lossloss = sigmoid_focal_loss(outputs, targets, alpha=0.25, gamma=2, reduction="mean")# 或自定义多类别 Focal Loss
class FocalLoss(nn.Module):def __init__(self, alpha=0.25, gamma=2):super().__init__()self.alpha = alphaself.gamma = gammadef forward(self, inputs, targets):ce_loss = F.cross_entropy(inputs, targets, reduction="none")pt = torch.exp(-ce_loss)loss = self.alpha * (1 - pt) ** self.gamma * ce_lossreturn loss.mean()
(2) Dice Loss

社区标准实现(或使用 segmentation_models_pytorch 库):

class DiceLoss(nn.Module):def __init__(self, smooth=1e-6):super().__init__()self.smooth = smoothdef forward(self, inputs, targets):inputs = F.softmax(inputs, dim=1)targets = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2)intersection = (inputs * targets).sum()union = inputs.sum() + targets.sum()dice = (2 * intersection + self.smooth) / (union + self.smooth)return 1 - dice
(3) 直接调用 segmentation_models_pytorch 损失函数
import segmentation_models_pytorch as smploss = smp.losses.DiceLoss(mode="multiclass", classes=[0, 1, 2])  # 指定关注类别
loss = smp.losses.FocalLoss(mode="multiclass", normalized=True)   # 归一化版本

3. 模型层面:集成注意力与多尺度模块

(1) 使用预建模型库

segmentation_models_pytorch(SMP)提供即用的模型和模块:

import segmentation_models_pytorch as smpmodel = smp.Unet(encoder_name="resnet34",encoder_weights="imagenet",in_channels=3,classes=5,decoder_attention_type="scse",  # 添加空间-通道注意力
)
(2) 空洞卷积(Dilated Convolution)

直接使用 PyTorch 的 Conv2d 实现:

class DilatedConvBlock(nn.Module):def __init__(self, in_channels, out_channels, dilation_rate=2):super().__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation_rate, dilation=dilation_rate)self.norm = nn.BatchNorm2d(out_channels)self.act = nn.ReLU()def forward(self, x):return self.act(self.norm(self.conv(x)))# 在 U-Net 的 decoder 中插入空洞卷积块

4. 类别权重计算工具

(1) 自动计算类别权重
from sklearn.utils.class_weight import compute_class_weight# 统计训练集所有像素的类别分布
class_counts = np.bincount(all_pixel_labels.flatten())
class_weights = compute_class_weight(class_weight="balanced", classes=np.arange(num_classes), y=all_pixel_labels.flatten()
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)# 在损失函数中使用
criterion = nn.CrossEntropyLoss(weight=class_weights)

5. 完整 Pipeline 示例

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import segmentation_models_pytorch as smp
import albumentations as A# 1. 定义数据集和采样器
dataset = YourDataset(transform=albumentations_transform)
weights = compute_patch_weights(dataset)  # 根据补丁中目标类别比例计算
sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)# 2. 定义模型和损失
model = smp.Unet(encoder_name="resnet34", classes=5, decoder_attention_type="scse")
criterion = smp.losses.DiceLoss(mode="multiclass") + smp.losses.FocalLoss(mode="multiclass")# 3. 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(100):for images, masks in dataloader:outputs = model(images)loss = criterion(outputs, masks)loss.backward()optimizer.step()

关键工具总结

问题类型PyTorch 原生支持推荐第三方库(直接调用)
数据采样WeightedRandomSamplerAlbumentations(定向增强)
损失函数自定义(需手写)segmentation_models_pytorch.losses
模型结构手动添加模块(空洞卷积、注意力)segmentation_models_pytorch 预建模型
类别权重计算sklearn.utils.class_weight内置自动统计工具(如 SMP 数据集类)

注意事项

  1. 灵活组合策略:例如同时使用 WeightedRandomSamplerFocal Loss 可能过度偏向少数类,需通过实验调整。
  2. 监控类别指标:使用 torchmetrics 库计算每个类别的 IoU:
    from torchmetrics import JaccardIndex
    iou = JaccardIndex(num_classes=5, task="multiclass")
    iou.update(outputs, targets)
    print(f"IoU: {iou.compute()}")
    
  3. 混合精度训练:使用 torch.cuda.amp 加速训练,缓解显存压力:
    scaler = torch.cuda.amp.GradScaler()
    with torch.cuda.amp.autocast():outputs = model(images)loss = criterion(outputs, masks)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    

相关文章:

使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题

在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著。 1. 小尺寸图像如何加剧样本不均衡? (1) 局部裁剪导致类别分布偏差 问题:遥感图像中某些类别(如道路、建…...

0.91英寸OLED显示屏一种具有小尺寸、高分辨率、低功耗特性的显示器件

0.91英寸OLED显示屏是一种具有小尺寸、高分辨率、低功耗特性的显示器件。以下是对0.91英寸OLED显示屏的详细介绍: 一、基本参数 尺寸:0.91英寸分辨率:通常为128x32像素,意味着显示屏上有128列和32行的像素点,总共409…...

读书笔记--分布式服务架构对比及优势

本篇是在上一篇的基础上,主要对共享服务平台建设所依赖的分布式服务架构进行学习,主要记录和思考如下,供大家学习参考。随着企业各业务数字化转型工作的推进,之前在传统的单一系统(或单体应用)模式中&#…...

HTML5 新的 Input 类型详解

HTML5 引入了许多新的输入类型,极大地增强了表单的功能和用户体验。这些新的输入类型不仅提供了更好的输入控制,还支持内置的验证功能,减少了开发者手动编写验证逻辑的工作量。本文将全面介绍 HTML5 中新增的输入类型,并结合代码示…...

ESP32-CAM实验集(WebServer)

WebServer 效果图 已连接 web端 platformio.ini ; PlatformIO Project Configuration File ; ; Build options: build flags, source filter ; Upload options: custom upload port, speed and extra flags ; Library options: dependencies, extra library stor…...

Case逢无意难休——深度解析JAVA中case穿透问题

Case逢无意难休——深度解析JAVA中case穿透问题~ 不作溢美之词,不作浮夸文章,此文与功名进取毫不相关也!与大家共勉!! 更多文章:个人主页 系列文章:JAVA专栏 欢迎各位大佬来访哦~互三必回&#…...

Golang笔记——常用库context和runtime

大家好,这里是Good Note,关注 公主号:Goodnote,专栏文章私信限时Free。本文详细介绍Golang的常用库context和runtime,包括库的基本概念和基本函数的使用等。 文章目录 contextcontext 包的基本概念主要类型和函数1. **…...

2000-2020年各省第二产业增加值占GDP比重数据

2000-2020年各省第二产业增加值占GDP比重数据 1、时间:2000-2020年 2、来源:国家统计局、统计年鉴 3、指标:行政区划代码、地区名称、年份、第二产业增加值占GDP比重 4、范围:31省 5、指标解释:第二产业增加值占GDP比重…...

unity商店插件A* Pathfinding Project如何判断一个点是否在导航网格上?

需要使用NavGraph.IsPointOnNavmesh(Vector3 point) 如果点位于导航网的可步行部分,则为真。 如果一个点在可步行导航网表面之上或之下,在任何距离,如果它不在更近的不可步行节点之上 / 之下,则认为它在导航网上。 使用方法 Ast…...

Day24-【13003】短文,数据结构与算法开篇,什么是数据元素?数据结构有哪些类型?什么是抽象类型?

文章目录 13003数据结构与算法全书框架考试题型的分值分布如何? 本次内容概述绪论第一节概览什么是数据、数据元素,数据项,数据项的值?什么是数据结构?分哪两种集合形式(逻辑和存储)&#xff1f…...

富文本 tinyMCE Vue2 组件使用简易教程

参考官方教程 TinyMCE Vue.js integration technical reference Vue2 项目需要使用 tinyMCE Vue2 组件(tinymce/tinymce-vue)的第 3 版 安装组件 npm install --save "tinymce/tinymce-vue^3" 编写组件调用 <template><Editorref"editor"v-m…...

强化学习在自动驾驶中的实现与挑战

强化学习在自动驾驶中的实现与挑战 自动驾驶技术作为当今人工智能领域的前沿之一,正通过各种方式改变我们的出行方式。而强化学习(Reinforcement Learning, RL),作为机器学习的一大分支,在自动驾驶的实现中扮演了至关重要的角色。它通过模仿人类驾驶员的决策过程,为车辆…...

记录 | MaxKB创建本地AI智能问答系统

目录 前言一、重建MaxKBStep1 复制路径Step2 删除MaxKBStep3 创建数据存储文件夹Step4 重建 二、创建知识库Step1 新建知识库Step2 下载测试所用的txtStep3 上传本地文档Step4 选择模型补充智谱的API Key如何获取 Step5 查看是否成功 三、创建应用Step1 新建应用Step2 配置AI助…...

特种作业操作之低压电工考试真题

1.下面&#xff08; &#xff09;属于顺磁性材料。 A. 铜 B. 水 C. 空气 答案&#xff1a;C 2.事故照明一般采用&#xff08; &#xff09;。 A. 日光灯 B. 白炽灯 C. 压汞灯 答案&#xff1a;B 3.人体同时接触带电设备或线路中的两相导体时&#xff0c;电流从一相通过人体流…...

[免费]基于Python的Django博客系统【论文+源码+SQL脚本】

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的基于Python的Django博客系统&#xff0c;分享下哈。 项目视频演示 【免费】基于Python的Django博客系统 Python毕业设计_哔哩哔哩_bilibili 项目介绍 随着互联网技术的飞速发展&#xff0c;信息的传播与…...

Cannot resolve symbol ‘XXX‘ Maven 依赖问题的解决过程

一、问题描述 在使用 Maven 管理项目依赖时&#xff0c;遇到了一个棘手的问题。具体表现为&#xff1a;在 pom.xml 文件中导入了所需的依赖&#xff0c;并且在 IDE 中导入语句没有显示为红色&#xff08;表示 IDE 没有提示依赖缺失&#xff09;&#xff0c;但是在实际使用这些依…...

我们需要有哪些知识体系,知识体系里面要有什么哪些内容?

01、管理知识体系的学习知识体系 主要内容&#xff1a; 1、知识管理框架的外部借鉴、和自身知识体系的搭建&#xff1b; 2、学习能力、思维逻辑能力等的塑造&#xff1b; 3、知识管理工具的使用&#xff1b; 4、学习资料的导入和查找资料的渠道&#xff1b; 5、深层关键的…...

什么是vue.js组件开发,我们需要做哪些准备工作?

Vue.js 是一个非常流行的前端框架,用于构建用户界面。组件开发是 Vue.js 的核心概念之一,通过将界面拆分为独立的组件,可以提高代码的可维护性和复用性。以下是一个详细的 Vue.js 组件开发指南,包括基础概念、开发流程和代码示例。 一、Vue.js 组件开发基础 1. 组件的基本…...

网络工程师 (3)指令系统基础

一、寻址方式 &#xff08;一&#xff09;指令寻址 顺序寻址&#xff1a;通过程序计数器&#xff08;PC&#xff09;加1&#xff0c;自动形成下一条指令的地址。这是计算机中最基本、最常用的寻址方式。 跳跃寻址&#xff1a;通过转移类指令直接或间接给出下一条指令的地址。跳…...

第4章 神经网络【1】——损失函数

4.1.从数据中学习 实际的神经网络中&#xff0c;参数的数量成千上万&#xff0c;因此&#xff0c;需要由数据自动决定权重参数的值。 4.1.1.数据驱动 数据是机器学习的核心。 我们的目标是要提取出特征量&#xff0c;特征量指的是从输入数据/图像中提取出的本质的数 …...

<6>-MySQL表的增删查改

目录 一&#xff0c;create&#xff08;创建表&#xff09; 二&#xff0c;retrieve&#xff08;查询表&#xff09; 1&#xff0c;select列 2&#xff0c;where条件 三&#xff0c;update&#xff08;更新表&#xff09; 四&#xff0c;delete&#xff08;删除表&#xf…...

电脑插入多块移动硬盘后经常出现卡顿和蓝屏

当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时&#xff0c;可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案&#xff1a; 1. 检查电源供电问题 问题原因&#xff1a;多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

C++ 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

以光量子为例,详解量子获取方式

光量子技术获取量子比特可在室温下进行。该方式有望通过与名为硅光子学&#xff08;silicon photonics&#xff09;的光波导&#xff08;optical waveguide&#xff09;芯片制造技术和光纤等光通信技术相结合来实现量子计算机。量子力学中&#xff0c;光既是波又是粒子。光子本…...

使用Matplotlib创建炫酷的3D散点图:数据可视化的新维度

文章目录 基础实现代码代码解析进阶技巧1. 自定义点的大小和颜色2. 添加图例和样式美化3. 真实数据应用示例实用技巧与注意事项完整示例(带样式)应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

SQL慢可能是触发了ring buffer

简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南&#xff1a;计算机基础与源码原理深度解析 第一轮提问&#xff1a;基础概念问题 1. 请解释什么是进程和线程的区别&#xff1f; 面试官&#xff1a;进程是程序的一次执行过程&#xff0c;是系统进行资源分配和调度的基本单位&#xff1b;而线程是进程中的…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)

一、OpenBCI_GUI 项目概述 &#xff08;一&#xff09;项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台&#xff0c;其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言&#xff0c;首次接触 OpenBCI 设备时&#xff0c;往…...

02.运算符

目录 什么是运算符 算术运算符 1.基本四则运算符 2.增量运算符 3.自增/自减运算符 关系运算符 逻辑运算符 &&&#xff1a;逻辑与 ||&#xff1a;逻辑或 &#xff01;&#xff1a;逻辑非 短路求值 位运算符 按位与&&#xff1a; 按位或 | 按位取反~ …...