当前位置: 首页 > article >正文

PyTorch实战:用CrossEntropyLoss的weight和label_smoothing解决类别不平衡与过拟合

PyTorch实战用CrossEntropyLoss的weight和label_smoothing解决类别不平衡与过拟合当你面对医学影像分类任务时数据集中正常样本占比90%而病变样本仅占10%。训练后的模型对所有样本都预测为正常类别准确率看似很高却完全无法识别关键病例——这是类别不平衡问题的典型表现。另一种情况是模型在训练集上准确率达到99%但在验证集上暴跌至60%这是过拟合在作祟。本文将手把手教你用PyTorch的nn.CrossEntropyLoss中的weight和label_smoothing参数解决这两大难题。1. 理解核心问题类别不平衡与过拟合1.1 类别不平衡的数学本质假设我们有一个三分类任务类别分布为类别样本数占比A90090%B909%C101%传统交叉熵损失函数会平等对待每个样本导致模型倾向于优化主导类别的预测准确率。从数学上看标准交叉熵损失为$$ L -\frac{1}{N}\sum_{i1}^N \log(p_{i,y_i}) $$其中$p_{i,y_i}$是样本i在其真实类别$y_i$上的预测概率。对于上述分布即使模型完全忽略B、C类损失值也能保持很低。1.2 过拟合的表现形式过拟合模型通常会出现以下特征训练损失持续下降而验证损失开始上升模型对训练样本的预测置信度极高softmax输出接近1.0在对抗样本或噪声数据上表现脆弱# 过拟合模型的典型输出示例 output model(train_data) print(torch.softmax(output, dim1)[:5]) # tensor([[0.9999, 0.0001], # [0.9997, 0.0003], # [0.9998, 0.0002], # [0.0001, 0.9999], # [0.0002, 0.9998]])2. 类别加权weight参数实战2.1 计算类别权重的三种方法PyTorch的weight参数需要传入一个长度为C类别数的张量。以下是常用计算方法逆频率加权class_counts torch.tensor([900, 90, 10]) weights 1.0 / class_counts weights weights / weights.sum() * len(weights) # 归一化 # tensor([0.0111, 0.1111, 1.0000])有效样本数加权适用于极端不平衡beta 0.999 effective_num 1.0 - torch.pow(beta, class_counts) weights (1.0 - beta) / effective_num # tensor([0.0011, 0.0110, 0.1054])平方根逆频率更平滑的加权weights 1.0 / torch.sqrt(class_counts) # tensor([0.0333, 0.1054, 0.3162])2.2 完整训练代码示例import torch import torch.nn as nn from torch.utils.data import DataLoader, WeightedRandomSampler # 假设我们有一个极度不平衡的数据集 dataset ... # 你的数据集 class_counts torch.tensor([900, 90, 10]) weights 1.0 / class_counts sample_weights weights[dataset.targets] # 使用加权采样器 sampler WeightedRandomSampler( weightssample_weights, num_sampleslen(dataset), replacementTrue ) # 定义带权重的损失函数 criterion nn.CrossEntropyLoss( weightweights.to(device), label_smoothing0.0 ) # 训练循环 for epoch in range(epochs): for inputs, targets in DataLoader(dataset, samplersampler): outputs model(inputs) loss criterion(outputs, targets) ...注意使用weight参数时建议同时配合WeightedRandomSampler进行样本重采样从数据加载层面进一步缓解不平衡问题。3. 标签平滑label_smoothing参数详解3.1 标签平滑的数学原理传统one-hot标签会强制模型对正确类别的预测概率接近1.0这容易导致模型过度自信泛化能力下降对对抗样本敏感标签平滑将原始标签$y$转换为$$ y (1 - \alpha) \cdot y \frac{\alpha}{C} $$其中$\alpha$是平滑系数通常0.1-0.2$C$是类别数。例如对于二分类原始标签[1, 0]→ 平滑后[0.95, 0.05]当α0.13.2 不同平滑系数的影响我们通过实验比较不同α值的效果α值训练准确率验证准确率测试集熵0.099.2%85.3%0.080.197.8%88.6%0.350.296.5%89.1%0.520.395.1%88.3%0.68实验表明α0.2时模型在验证集上表现最佳且预测分布更柔软熵值适中。3.3 实现代码对比# 传统硬标签训练 criterion nn.CrossEntropyLoss() output model(input) loss criterion(output, target) # target是类别索引 # 标签平滑训练PyTorch 1.10 criterion nn.CrossEntropyLoss(label_smoothing0.1) loss criterion(output, target) # 手动实现标签平滑 def smooth_one_hot(target, n_classes, smoothing0.0): assert 0 smoothing 1 with torch.no_grad(): target torch.empty_like(output).fill_( smoothing / (n_classes - 1) ).scatter_(1, target.unsqueeze(1), 1.0 - smoothing) return target smoothed_target smooth_one_hot(target, n_classes10, smoothing0.1) loss criterion(output, smoothed_target)4. 综合解决方案weight与label_smoothing联合使用4.1 参数组合策略当同时面对类别不平衡和过拟合问题时建议采用以下策略先确定最佳weight计算初始类别分布用逆频率或有效样本数方法得到基础权重在验证集上微调权重缩放因子通常0.5-2.0倍再调整label_smoothing从α0.1开始尝试观察验证集准确率和损失曲线以0.05为步长调整最终联合训练# 最优参数组合示例 best_weights torch.tensor([0.5, 1.8, 3.0]) # 对稀有类别赋予更高权重 best_alpha 0.15 criterion nn.CrossEntropyLoss( weightbest_weights, label_smoothingbest_alpha )4.2 完整训练案例以下是在COVID-19胸部X光分类中的应用实例import torch import torch.nn as nn import torch.optim as optim from torchvision import models # 数据准备 train_loader ... # 不平衡的COVID数据集 class_counts torch.tensor([1000, 200, 50]) # 正常/肺炎/COVID-19 # 模型与优化器 model models.resnet50(pretrainedTrue) model.fc nn.Linear(2048, 3) # 损失函数配置 weights 1.0 / torch.sqrt(class_counts) # 平方根逆频率加权 criterion nn.CrossEntropyLoss( weightweights.to(device), label_smoothing0.1 ) optimizer optim.AdamW(model.parameters(), lr1e-4) # 训练循环 for epoch in range(30): model.train() for inputs, targets in train_loader: optimizer.zero_grad() outputs model(inputs.to(device)) loss criterion(outputs, targets.to(device)) loss.backward() optimizer.step() # 验证逻辑 model.eval() with torch.no_grad(): # 计算各类别准确率...4.3 效果评估指标不要只看整体准确率要关注混淆矩阵from sklearn.metrics import confusion_matrix cm confusion_matrix(true_labels, preds) print(cm) # [[980 15 5] # [ 10 180 10] # [ 2 8 40]]类别特异性指标召回率对稀有类别最关键F1分数PR曲线下面积AUPRC模型校准度from sklearn.calibration import calibration_curve prob_true, prob_pred calibration_curve(true_labels, pred_probs, n_bins10) plt.plot(prob_pred, prob_true, s-)在实际医疗影像项目中这种组合策略将COVID-19类别的召回率从35%提升到了68%同时保持了其他类别的性能。模型对噪声和对抗攻击的鲁棒性也有显著改善——当向测试图像添加高斯噪声(σ0.1)时传统模型的准确率下降42%而使用weightlabel_smoothing的模型仅下降17%。

相关文章:

PyTorch实战:用CrossEntropyLoss的weight和label_smoothing解决类别不平衡与过拟合

PyTorch实战:用CrossEntropyLoss的weight和label_smoothing解决类别不平衡与过拟合 当你面对医学影像分类任务时,数据集中正常样本占比90%,而病变样本仅占10%。训练后的模型对所有样本都预测为正常类别,准确率看似很高却完全无法识…...

Display Driver Uninstaller (DDU):显卡驱动问题的终极解决方案

Display Driver Uninstaller (DDU):显卡驱动问题的终极解决方案 【免费下载链接】display-drivers-uninstaller Display Driver Uninstaller (DDU) a driver removal utility / cleaner utility 项目地址: https://gitcode.com/gh_mirrors/di/display-drivers-uni…...

C#怎么获取多显示器屏幕尺寸_C#如何适应不同分辨率【解析】

Screen.AllScreens 可获取所有显示器的 Bounds(含位置和宽高)及 WorkingArea,需配合 per-monitor V2 manifest 实现准确 DPI 感知,否则 Bounds 返回逻辑像素而非物理分辨率。怎么用 Screen.AllScreens 拿到所有显示器的尺寸直接遍…...

保姆级教程:在AutoDL上用vLLM一键部署GLM-4.1V-Thinking多模态大模型

云平台极速部署GLM-4.1V多模态模型实战指南 当我们需要快速验证一个视觉语言模型的实际表现时,云GPU平台往往是最便捷的选择。不同于本地部署需要折腾驱动和环境,云服务提供了开箱即用的计算资源,特别适合需要快速迭代的实验场景。今天我们就…...

临床医生也能懂的蛋白质组学:疾病标志物发现全流程解析

临床医生也能懂的蛋白质组学:疾病标志物发现全流程解析 在肝癌诊疗中,我们常遇到这样的困境:当超声发现肝脏占位时,患者往往已进入中晚期。而甲胎蛋白(AFP)作为传统标志物,其敏感性和特异性仅约…...

RC定时电路

RC定时电路 什么是RC定时电路 RC 定时电路(RC Timing Circuit)是利用电阻 R 和电容 C 的充放电特性来实现时间控制的基础电路. 核心原理是: 电容的电压不能突变, 通过电阻给电容充电 / 放电时, 电压会按指数规律变化, 这个过程的时间由时间常数 τ RC 决定. 电阻控制电流速…...

如何在 Divi 主题中禁用锚点链接的平滑滚动动画

本文介绍一种无需修改 Divi 核心文件的安全方式,通过重写 et_pb_smooth_scroll 全局函数,将锚点跳转强制设为瞬时定位(0ms 动画),彻底禁用默认的平滑滚动效果。 本文介绍一种无需修改 divi 核心文件的安全方式&am…...

若依框架集成百度地图组件的实战指南

1. 环境准备与基础配置 在开始集成百度地图组件之前,确保你已经完成以下准备工作。我遇到过不少开发者因为基础环境没配好,导致后续步骤频频报错的情况,所以这部分特别重要。 首先,你需要一个有效的百度地图开发者账号。登录百度地…...

mysql如何通过防火墙保护MySQL权限_MySQL网络层安全配置

MySQL 默认监听0.0.0.0:3306,必须通过bind-address限制监听地址、系统防火墙(ufw/firewalld)设置IP白名单、云平台安全组精确放行,并与MySQL用户host字段协同配置,四层防护缺一不可。MySQL 默认监听所有网卡&#xff0…...

【大模型版权保护实战指南】:20年AI工程专家亲授3大不可绕过的法律+技术双轨防护体系

第一章:大模型版权保护的工程化挑战与战略定位 2026奇点智能技术大会(https://ml-summit.org) 大模型版权保护已远超法律文本层面的权属界定,演变为融合数据溯源、训练过程审计、模型水印嵌入与推理行为可验证性的系统性工程问题。当千亿参数模型在跨机…...

AI基础设施运维黑盒曝光:实时监控127个关键指标、自动定位集群间token吞吐偏差>15%的根因分析流程

第一章:AI基础设施运维黑盒曝光:实时监控127个关键指标、自动定位集群间token吞吐偏差>15%的根因分析流程 2026奇点智能技术大会(https://ml-summit.org) 现代大模型推理集群已演变为多租户、跨地域、异构加速卡混合部署的复杂系统,传统…...

2026届毕业生推荐的AI写作神器横评

Ai论文网站排名(开题报告、文献综述、降aigc率、降重综合对比) TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 人工智能论文工具正渐渐变成学术写作方面极为重要的辅助办法,这般工具一般会整合…...

Apriltag tag36H11:视觉系统标定的高效解决方案

1. Apriltag tag36H11是什么? 如果你玩过机器人或者做过计算机视觉项目,大概率听说过Apriltag。简单来说,Apriltag就是一种特殊的二维码,但它的设计更适用于机器视觉系统。tag36H11是Apriltag家族中最常用的一个变种,…...

如何快速掌握OCAuxiliaryTools:黑苹果配置的终极图形化指南

如何快速掌握OCAuxiliaryTools:黑苹果配置的终极图形化指南 【免费下载链接】OCAuxiliaryTools Cross-platform GUI management tools for OpenCore(OCAT) 项目地址: https://gitcode.com/gh_mirrors/oc/OCAuxiliaryTools 你是否在为黑…...

【源码深度】Android 图片加载框架全解析|Glide、Picasso、Fresco、Coil 原理与优化|Android全栈体系150讲-18

...

零基础部署NaViL-9B:手把手教你搭建图文理解AI助手

零基础部署NaViL-9B:手把手教你搭建图文理解AI助手 1. 认识NaViL-9B多模态模型 NaViL-9B是由专业研究机构开发的原生多模态大语言模型,它不仅能像普通AI助手一样处理文本问答,还具备理解图片内容的独特能力。这意味着你可以上传一张照片&am…...

AI热力图赋能商场运营:实时监控与智能决策的技术实践

1. 为什么商场需要AI热力图技术? 每次逛商场时,你可能注意过有些区域总是挤满人,而有些角落却冷冷清清。作为商场管理者,最头疼的就是不知道顾客到底在哪里聚集、为什么聚集。传统的人工巡查方式就像蒙着眼睛捉迷藏——效率低还不…...

Lite-Avatar持续集成:GitHub Actions实践指南

Lite-Avatar持续集成:GitHub Actions实践指南 1. 引言 你是不是也遇到过这样的情况:每次修改Lite-Avatar项目代码后,都要手动运行测试、构建镜像、部署到服务器?不仅耗时耗力,还容易出错。特别是当团队协作时&#x…...

深度探索ComfyUI-BrushNet:解锁图像修复与内容替换的3种创新应用范式

深度探索ComfyUI-BrushNet:解锁图像修复与内容替换的3种创新应用范式 【免费下载链接】ComfyUI-BrushNet ComfyUI BrushNet nodes 项目地址: https://gitcode.com/gh_mirrors/co/ComfyUI-BrushNet ComfyUI-BrushNet作为AI图像编辑领域的前沿技术实现&#xf…...

【大模型工程化能效优化黄金法则】:20年实战总结的7大降耗策略,省电37%实测数据首次公开

第一章:大模型工程化中的能效优化策略 2026奇点智能技术大会(https://ml-summit.org) 大模型推理与训练的能耗问题已不再仅是运维成本考量,而是关乎碳中和承诺、边缘部署可行性及长期服务SLA稳定性的核心工程约束。在千卡级集群与百亿参数模型常态化落地…...

密评实战指南—从算法验证到电子签章的全流程解析

1. 密评实战入门:为什么需要密码应用安全性评估 最近帮某政务系统做上线前的安全检测时,发现他们的登录接口居然用MD5存储密码。这让我想起三年前某大型数据泄露事件,根源就是用了不安全的加密算法。密码应用安全性评估(简称密评…...

Windows系统优化新选择:Win11Debloat让你的电脑重获新生

Windows系统优化新选择:Win11Debloat让你的电脑重获新生 【免费下载链接】Win11Debloat A simple, lightweight PowerShell script that allows you to remove pre-installed apps, disable telemetry, as well as perform various other changes to declutter and …...

Cosmos-Reason1-7B实际效果:离散数学归纳法证明过程结构化输出

Cosmos-Reason1-7B实际效果:离散数学归纳法证明过程结构化输出 提示:本文所有演示均基于本地部署的Cosmos-Reason1-7B推理工具,无需联网,保护隐私 1. 工具简介:你的本地数学推理助手 Cosmos-Reason1-7B是一个专门为逻…...

从卡比到瓦豆鲁迪:用OpenGL层次建模和贴图复刻经典游戏角色的保姆级教程

从卡比到瓦豆鲁迪:用OpenGL层次建模和贴图复刻经典游戏角色的保姆级教程 1. 前言:为什么选择卡比作为OpenGL学习案例 在计算机图形学的学习过程中,3D角色建模一直是令人着迷又颇具挑战性的领域。而《星之卡比》系列中的角色以其简洁的几何造型…...

混合Copula模型(Clayton-Frank-Gumbel)代码深度解析与实战指南

混合copula 二维数据拟合得到相关结构参数与系数 主要针对常用的Clayton Frank Gumbel三种copula函数的组合,进行混合copula构建 Matlab代码实现一、代码定位与核心价值 1.1 应用场景 这套MATLAB代码专为二维变量依赖结构分析设计,核心应用于金融工程&am…...

从ResNet到VISA-Transformer:2026奇点大会公布的视觉理解技术演进路线图(含3级技术替代时间窗口与迁移风险清单)

第一章:2026奇点智能技术大会:大模型视觉理解 2026奇点智能技术大会(https://ml-summit.org) 多模态视觉理解范式的跃迁 本届大会首次系统展示基于世界模型(World Model)驱动的视觉理解框架,其核心突破在于将图像解析…...

终极指南:如何让Mac外接鼠标获得触控板般丝滑滚动体验

终极指南:如何让Mac外接鼠标获得触控板般丝滑滚动体验 【免费下载链接】Mos 一个用于在 macOS 上平滑你的鼠标滚动效果或单独设置滚动方向的小工具, 让你的滚轮爽如触控板 | A lightweight tool used to smooth scrolling and set scroll direction independently f…...

无感FOC电机三相控制高速吹风筒方案 FU6812L+FD2504S 电压AC220V 功率80W

无感FOC电机三相控制高速吹风筒方案 FU6812LFD2504S 电压AC220V 功率80W 最高转速20万RPM 方案优势:响应快、效率高、噪声低、成本低 控制方式:三相电机无感FOC 闭环方式:功率闭环,速度闭环 调速接口:按键调试 提供原理…...

2026奇点大会闭门报告泄露(含原始benchmark数据):多轮对话SOTA模型在长记忆场景下的5项隐性衰减指标

第一章:2026奇点智能技术大会:大模型多轮对话 2026奇点智能技术大会(https://ml-summit.org) 在2026奇点智能技术大会上,大模型多轮对话能力成为核心议题之一。与会研究者展示了新一代对话系统在长程上下文建模、意图漂移检测与跨轮记忆对齐…...

PyTorch 2.8镜像惊艳效果:RTX 4090D下Llama3-8B+Phi-3-Vision多模态推理展示

PyTorch 2.8镜像惊艳效果:RTX 4090D下Llama3-8BPhi-3-Vision多模态推理展示 1. 开篇:专业级深度学习环境 当谈到高性能深度学习环境时,PyTorch 2.8与RTX 4090D的组合堪称当前最强大的配置之一。这个经过深度优化的镜像不仅提供了开箱即用的…...