当前位置: 首页 > article >正文

别再让模型‘偏科’了!PyTorch实战:用BCEWithLogitsLoss的weight和pos_weight搞定二分类数据不平衡

破解二分类数据不平衡PyTorch中BCEWithLogitsLoss的加权艺术当你的二分类模型总是对少数类视而不见预测结果清一色偏向多数类时这不是模型在偷懒而是数据不平衡在作祟。医疗诊断中的罕见病例识别、金融领域的欺诈交易检测、工业质检中的缺陷产品筛查——这些场景下的数据往往呈现严重的类别失衡。本文将带你深入PyTorch的BCEWithLogitsLoss通过weight和pos_weight这两个杠杆让模型学会雨露均沾。1. 数据不平衡模型偏科的罪魁祸首想象你正在训练一个识别罕见病的诊断系统。医院提供的1000份病例中只有20份是阳性病例。即使模型将所有预测都输出为阴性也能达到98%的准确率——这个数字看似漂亮但对实际应用毫无价值。这就是典型的数据不平衡问题带来的评估陷阱。数据不平衡会导致三个致命影响评估指标失真准确率变得毫无意义需要依赖精确率、召回率、F1分数等更细致的指标梯度主导问题多数类样本产生的梯度在反向传播中占据主导地位决策边界偏移模型倾向于将样本预测为多数类以获得表面上的好成绩from sklearn.metrics import classification_report # 模拟一个严重不平衡的数据集 y_true [1]*20 [0]*980 # 20个正样本980个负样本 y_pred [0]*1000 # 模型全部预测为负类 print(classification_report(y_true, y_pred))输出结果会显示虽然准确率高达98%但正类的召回率为0——这正是我们需要解决的问题。2. BCEWithLogitsLoss的加权机制解析PyTorch的BCEWithLogitsLoss实际上在单个函数中完成了两步操作先对输出应用sigmoid函数将其压缩到[0,1]区间再计算二元交叉熵损失。其基础公式为$$ L -\frac{1}{N}\sum_{i1}^N [y_i\cdot\log(\sigma(x_i)) (1-y_i)\cdot\log(1-\sigma(x_i))] $$当引入weight参数后公式变为$$ L -\frac{1}{N}\sum_{i1}^N weight[y_i] \cdot [y_i\cdot\log(\sigma(x_i)) (1-y_i)\cdot\log(1-\sigma(x_i))] $$而pos_weight则是更简洁的实现方式它专门针对正类样本的权重进行调整$$ L -\frac{1}{N}\sum_{i1}^N [y_i\cdot pos_weight \cdot \log(\sigma(x_i)) (1-y_i)\cdot\log(1-\sigma(x_i))] $$2.1 weight参数的实战应用weight参数是一个长度为2的张量分别指定负类和正类的权重。一个经验法则是将权重设置为类别频率的倒数import torch import torch.nn as nn num_neg 980 # 负样本数 num_pos 20 # 正样本数 total num_neg num_pos # 计算类别权重 weight torch.tensor([total/num_neg, total/num_pos]) # 约为[1.02, 50.0] criterion nn.BCEWithLogitsLoss(weightweight)在实际项目中我们通常会在DataLoader中统计类别分布from collections import Counter def calculate_weights(dataset): class_counts Counter(dataset.targets) total sum(class_counts.values()) return torch.tensor([total/class_counts[0], total/class_counts[1]]) weights calculate_weights(train_dataset) criterion nn.BCEWithLogitsLoss(weightweights)2.2 pos_weight的便捷之道当只需要调整正类权重时pos_weight是更简洁的选择。它与weight的关系可以表示为pos_weight torch.tensor([pos_weight_value]) # 等价于 weight torch.tensor([1.0, pos_weight_value])医疗影像诊断的典型设置示例# 假设正负样本比例为1:50 pos_weight torch.tensor([50.0]) criterion nn.BCEWithLogitsLoss(pos_weightpos_weight)重要提示当同时指定weight和pos_weight时pos_weight会覆盖weight中关于正类的权重设置。3. 权重计算的高级策略基础的倒数频率加权有时过于激进可能导致模型对噪声样本过度敏感。下面介绍几种更精细的权重调节方法。3.1 平滑加权法在极端不平衡场景下(如1:1000)直接使用倒数会导致权重差异过大。可采用平方根或对数平滑import math # 平方根平滑 weight_neg math.sqrt(total / num_neg) weight_pos math.sqrt(total / num_pos) weights torch.tensor([weight_neg, weight_pos]) # 对数平滑 weight_neg math.log(total / num_neg) weight_pos math.log(total / num_pos) weights torch.tensor([weight_neg, weight_pos])3.2 有效样本数加权借鉴Decoupling论文中的方法考虑样本的有效数量beta 0.999 # 超参数通常取0.9, 0.99或0.999 eff_num_neg (1 - beta**num_neg) / (1 - beta) eff_num_pos (1 - beta**num_pos) / (1 - beta) weights torch.tensor([1/eff_num_neg, 1/eff_num_pos])3.3 类别权重对比表加权方法计算公式适用场景优点缺点倒数频率weight total / num_samples一般不平衡场景简单直接对极端不平衡可能过激平方根平滑sqrt(total / num_samples)极端不平衡(1:100)缓和权重差异需要调参对数平滑log(total / num_samples)数据分布高度倾斜更温和的权重调整可能调整不足有效样本数(1-beta^N)/(1-beta)长尾分布理论依据充分需要选择beta值4. 医疗诊断实战肺炎X光片分类让我们通过一个真实的医疗影像案例展示如何处理1:10的肺炎分类数据不平衡问题。4.1 数据准备与权重计算from torchvision.datasets import ImageFolder from torch.utils.data import DataLoader train_data ImageFolder(chest_xray/train) # 假设训练集分布为: 正常1341张肺炎3875张 num_neg 1341 # 正常(负类) num_pos 3875 # 肺炎(正类) total num_neg num_pos # 计算pos_weight pos_weight torch.tensor([num_neg / num_pos]) # 约0.346 # 等价于给负类更高权重 model CNN() # 自定义的卷积神经网络 criterion nn.BCEWithLogitsLoss(pos_weightpos_weight) optimizer torch.optim.Adam(model.parameters())4.2 训练循环中的关键实现def train_epoch(model, loader, criterion, optimizer): model.train() total_loss 0 for images, labels in loader: images images.to(device) labels labels.float().unsqueeze(1).to(device) optimizer.zero_grad() outputs model(images) loss criterion(outputs, labels) loss.backward() optimizer.step() total_loss loss.item() return total_loss / len(loader)4.3 评估指标的选择在医疗场景中我们通常更关注召回率避免漏诊和AUC值from sklearn.metrics import roc_auc_score, recall_score def evaluate(model, loader): model.eval() all_preds [] all_labels [] with torch.no_grad(): for images, labels in loader: images images.to(device) outputs model(images) preds torch.sigmoid(outputs).cpu() all_preds.extend(preds.numpy()) all_labels.extend(labels.numpy()) auc roc_auc_score(all_labels, all_preds) recall recall_score(all_labels, (np.array(all_preds) 0.5).astype(int)) return auc, recall5. 金融风控场景信用卡欺诈检测信用卡欺诈检测通常面临更极端的数据不平衡约1:1000这时需要更精细的权重调节策略。5.1 动态权重调整随着训练进行可以动态调整权重以应对模型性能变化class DynamicWeightBCE(nn.Module): def __init__(self, initial_pos_weight): super().__init__() self.pos_weight nn.Parameter(torch.tensor([initial_pos_weight])) def forward(self, input, target): return nn.functional.binary_cross_entropy_with_logits( input, target, pos_weightself.pos_weight)5.2 混淆矩阵监控实时监控混淆矩阵根据模型表现调整策略from sklearn.metrics import confusion_matrix def get_confusion_matrix(model, loader): model.eval() all_preds [] all_labels [] with torch.no_grad(): for data, labels in loader: outputs model(data) preds (torch.sigmoid(outputs) 0.5).int() all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) return confusion_matrix(all_labels, all_preds)5.3 阈值调整技巧在推理阶段可以调整分类阈值而非直接使用0.5def predict_with_threshold(model, inputs, threshold0.5): model.eval() with torch.no_grad(): outputs model(inputs) probs torch.sigmoid(outputs) return (probs threshold).int()最佳阈值可以通过PR曲线或业务需求确定from sklearn.metrics import precision_recall_curve precisions, recalls, thresholds precision_recall_curve(true_labels, pred_probs) # 根据业务需求选择阈值如保证召回率不低于90% optimal_idx np.argmax(recalls 0.9) optimal_threshold thresholds[optimal_idx]6. 组合拳加权损失与其他不平衡处理技术虽然加权损失效果显著但结合其他技术往往能获得更好效果。以下是几种常见组合策略6.1 加权损失焦点损失焦点损失(Focal Loss)通过降低易分类样本的权重进一步聚焦难样本class FocalBCEWithLogitsLoss(nn.Module): def __init__(self, alpha0.25, gamma2, pos_weightNone): super().__init__() self.alpha alpha self.gamma gamma self.pos_weight pos_weight def forward(self, inputs, targets): bce_loss nn.functional.binary_cross_entropy_with_logits( inputs, targets, reductionnone, pos_weightself.pos_weight) pt torch.exp(-bce_loss) focal_loss self.alpha * (1-pt)**self.gamma * bce_loss return focal_loss.mean()6.2 加权损失数据增强对少数类样本应用更激进的数据增强from torchvision import transforms # 对正类使用更强的增强 train_transform transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation(20), transforms.ColorJitter(0.1, 0.1, 0.1), transforms.ToTensor(), ]) # 在Dataset中根据标签应用不同增强 if label 1: # 正类 img transforms.RandomAffine(degrees0, translate(0.2,0.2))(img) img transforms.GaussianBlur(3)(img)6.3 加权损失模型架构调整修改网络最后层结构增强对少数类的识别能力class ImbalanceAwareHead(nn.Module): def __init__(self, in_features, bottleneck_dim128): super().__init__() self.bottleneck nn.Linear(in_features, bottleneck_dim) self.classifier nn.Linear(bottleneck_dim, 1) # 初始化分类器偏置反映类别先验 self.classifier.bias.data.fill_(-math.log((1-0.01)/0.01)) def forward(self, x): x self.bottleneck(x) return self.classifier(x)

相关文章:

别再让模型‘偏科’了!PyTorch实战:用BCEWithLogitsLoss的weight和pos_weight搞定二分类数据不平衡

破解二分类数据不平衡:PyTorch中BCEWithLogitsLoss的加权艺术 当你的二分类模型总是对少数类"视而不见",预测结果清一色偏向多数类时,这不是模型在偷懒,而是数据不平衡在作祟。医疗诊断中的罕见病例识别、金融领域的欺诈…...

国企领导:“现在都是 Agent自动开发了,你还在对话模式,太落后了!”我一点不慌:“这就去补,假期后见分晓!”领导露出满意的笑容。

马上假期了,我相信很多小伙伴肯定不会学习了,哦不,肯定不出去玩,要在家里学习 AI 对吧?(dog) 肯定的吧? 那在开始今天的内容之前,我也想问大家一下。 你平常更接近哪种…...

HPH内部构造大揭秘:三大系统配合节节通

今时,二零二六年四月三十日这一日,科技领域之内存在两件重大之事值得予以关注,其一乃是中国科学院所发布的“悟空”号暗物质卫星的最新成果,该成果揭示出了宇宙射线加速的关键机制;其二则是长三角区域的首台“华龙一号…...

让每一辆车快速拥抱AI!东软开启座舱AI Agent平权时代

2026年北京国际车展已释放出最明显的信号:座舱AI Agent正在加速落地。从用户体验侧来看,座舱交互系统最大的变化是从“会聊天”进化成“能干活”,座舱Agent变成了可精准了解用户需求,还能规划与执行的车内“私人助手”。这种进化&…...

VLC for Android:你的终极移动端万能媒体播放器解决方案

VLC for Android:你的终极移动端万能媒体播放器解决方案 【免费下载链接】vlc-android VLC for Android, Android TV and ChromeOS 项目地址: https://gitcode.com/gh_mirrors/vl/vlc-android 还在为手机无法播放某些视频格式而烦恼吗?或者经常遇…...

WWW 2026 利用知识图谱不但能够感知时间,还能“预判未来事件”?

01|研究背景:事件预测为什么需要“动态多模态”? 传统知识图谱通常关注结构化事实,例如: 主体 — 关系 — 客体 例如:Trump — LiveAt — White House 但现实世界中的事件并不是静止的。一个实体在不同时间…...

**大模型时代如何选对白酒?深度揭秘“晋善晋美”的技术创新与高性价比之道**

近年来,随着人工智能与大数据技术的飞速发展,白酒行业也悄然掀起了一场“数字化革命”。对于广大消费者而言,在信息爆炸的时代如何快速、精准地找到一家诚信白酒企业,并通过推荐白酒机构的权威背书,锁定一批高性价比白…...

CVE-2026-31431 Copy Fail:Linux 本地提权漏洞原理、影响面与排查修复建议

CVE-2026-31431 / Copy Fail 不是远程 RCE,攻击者需要先在目标机器上具备低权限代码执行能力。但这并不意味着它只是一个“小本地洞”。在容器节点、CI runner、共享开发机、跳板机、代码沙箱、Notebook、AI Agent 执行机这类环境里,“低权限代码执行”本…...

Vivado HLS 提供了 C++ 模板类 hls::stream<>

Vivado HLS 提供了 C 模板类 hls::stream<>&#xff0c;用于对流传输数据结构进行建模。 数据流在软件中&#xff08;以及在测试激励文件中进行 RTL 协同仿真期间&#xff09;作为无限队列来建模。在 C 中对数据流进行仿真 无需满足任意深度。数据流可在函数内部使用&…...

交大复旦 Bench2Drive-Speed:速度可控的自动驾驶评测基准

点击下方卡片&#xff0c;关注“自动驾驶之心”公众号戳我-> 领取自动驾驶近30个方向学习路线作者 | Yuqian Shao 等编辑 | 自动驾驶之心本文只做学术分享&#xff0c;如有侵权&#xff0c;联系删文>>自动驾驶前沿信息获取→自动驾驶之心知识星球导语端到端自动驾驶&a…...

[具身智能-509]:全局混乱下的局部有序:不要用战术的勤奋掩盖战略的懒惰

“在一个全局混乱的系统中&#xff0c;局部的有序是奢望。”很多初创团队容易陷入一种“伪忙碌”的状态&#xff1a;产品每天都在迭代新功能&#xff0c;销售每天都在疯狂打陌生电话&#xff0c;代码写得飞快&#xff0c;办公室灯火通明。但这往往是“全局混乱”的体现——因为…...

基于stm32ARM库函数的IIR二阶巴特沃斯低通滤波器--附完整代码

在嵌入式系统中使用ARM CMSIS-DSP库实现高效IIR低通滤波器 &#x1f3af; 引言&#xff1a;嵌入式系统中的实时信号处理挑战 在嵌入式系统开发中&#xff0c;信号处理往往面临双重挑战&#xff1a;既要保证实时性&#xff0c;又要在资源受限的环境下运行。今天&#xff0c;我…...

DHT11温湿度传感器核心技术解析

DHT11是一款数字式温湿度复合传感器&#xff0c;通过单总线协议与微控制器通信。其核心工作原理基于电阻式湿敏元件和NTC热敏电阻&#xff0c;内部集成了8位微处理器&#xff0c;负责将模拟信号转换为数字信号并校准输出。 1. 传感器特性与技术参数对比 特性DHT11备注温度测量…...

【无标题】滴滴答答滴滴答答滴滴答答滴滴答答滴滴答答

委屈委屈委屈恶趣味企鹅21...

阿里云百炼微调完整实战:从数据到部署

阿里云百炼微调完整实战&#xff1a;从数据到部署 目录 什么是模型微调微调 vs RAG&#xff1a;如何选择环境准备训练数据准备创建微调任务超参数配置详解模型部署LangChain 调用微调模型模型评测常见问题总结 一、什么是模型微调 模型微调&#xff08;Supervised Fine-Tun…...

工业数据转发实战:用NModbus4在WinForm中构建一个带UI的Modbus Slave服务器

工业数据转发实战&#xff1a;用NModbus4在WinForm中构建带UI的Modbus从站服务器 在工业自动化领域&#xff0c;数据采集与转发是连接现场设备与上层信息系统的关键环节。想象一下这样的场景&#xff1a;车间里的PLC控制器实时生成生产数据&#xff0c;而办公室的管理系统需要这…...

为什么特定场景只重试幂等请求,不重试非幂等请求?(幂等性Idempotence)因为重复非幂等请求会对系统产生重复的副作用

重试&#xff1a;仅幂等请求&#xff08;GET&#xff09;重试&#xff0c;最多 2 次&#xff0c;退避间隔 100ms 文章目录什么是幂等性&#xff1f;为什么只重试幂等请求&#xff1f;1. **避免重复副作用**2. **HTTP方法的幂等性分类**3. **实际风险示例**4. **安全重试机制**仅…...

终极指南:3分钟实现Adobe Illustrator到Photoshop的无损图层转换

终极指南&#xff1a;3分钟实现Adobe Illustrator到Photoshop的无损图层转换 【免费下载链接】ai-to-psd A script for prepare export of vector objects from Adobe Illustrator to Photoshop 项目地址: https://gitcode.com/gh_mirrors/ai/ai-to-psd 还在为AI文件转P…...

别再让ChatGLM说车轱辘话了!手把手教你用Hugging Face的LogitsProcessor解决LLM重复生成

彻底根治大模型复读机&#xff1a;Hugging Face LogitsProcessor实战指南 看着屏幕上不断重复的"这个问题很重要这个问题很重要这个问题很重要"&#xff0c;我第17次按下了终止键。作为某金融科技公司的AI产品经理&#xff0c;我们上线ChatGLM-6B后的用户投诉中&…...

对比使用Taotoken前后在模型选型与切换上的效率提升

使用 Taotoken 简化模型选型与切换的技术实践 1. 传统模型接入的痛点 在 Taotoken 平台出现之前&#xff0c;开发者接入不同大模型厂商的 API 需要面对一系列繁琐流程。每个厂商都有独立的注册流程、API Key 申请方式和文档体系。以常见的三个模型为例&#xff0c;开发者需要…...

Windows Server 2019上为Tesla T4配置CUDA 11.0和CUDNN 8.0.5的完整避坑指南

Windows Server 2019深度学习环境配置全攻略&#xff1a;Tesla T4CUDA 11.0实战指南 在企业级AI应用部署中&#xff0c;服务器环境配置往往是工程师面临的第一个挑战。不同于个人电脑的即插即用&#xff0c;Windows Server 2019特有的安全策略与系统架构&#xff0c;使得从驱动…...

Spark NLP:工业级分布式自然语言处理框架实战指南

1. 项目概述&#xff1a;当Spark遇上NLP&#xff0c;一个工业级文本处理框架的诞生如果你在数据科学或机器学习领域工作过一段时间&#xff0c;尤其是处理过海量文本数据&#xff0c;那你一定对两个词深有体会&#xff1a;一个是“慢”&#xff0c;另一个是“复杂”。传统的自然…...

springboot+vue3的旅游民宿预定管理系统的设计与实现

目录同行可拿货,招校园代理 ,本人源头供货商功能模块分析技术实现要点扩展功能建议项目技术支持源码获取详细视频演示 &#xff1a;文章底部获取博主联系方式&#xff01;同行可合作同行可拿货,招校园代理 ,本人源头供货商 功能模块分析 用户端功能 用户注册与登录&#xff…...

ScienceDecrypting:终极CAJ文档解密指南,3步实现科学文库文档永久保存

ScienceDecrypting&#xff1a;终极CAJ文档解密指南&#xff0c;3步实现科学文库文档永久保存 【免费下载链接】ScienceDecrypting 破解CAJViewer带有效期的文档&#xff0c;支持破解科学文库、标准全文数据库下载的文档。无损破解&#xff0c;保留文字和目录&#xff0c;解除有…...

内存带宽吃紧?GC风暴频发?R 4.5并行计算效率断崖式下降的5个反直觉元凶,今夜必须修复

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;R 4.5并行计算性能断崖的系统性归因 R 4.5版本在引入future与parallel包深度集成的同时&#xff0c;意外暴露了底层线程调度与内存管理的结构性矛盾。性能断崖并非单一缺陷所致&#xff0c;而是运行时环…...

springboot+vue3的婚礼场景规划系统设计与实现

目录同行可拿货,招校园代理 ,本人源头供货商功能模块分析技术实现要点扩展功能设计安全与兼容性项目技术支持源码获取详细视频演示 &#xff1a;文章底部获取博主联系方式&#xff01;同行可合作同行可拿货,招校园代理 ,本人源头供货商 功能模块分析 用户管理模块 注册与登录…...

3大核心方案:彻底解决DouyinLiveRecorder中PandaTV录制失败的终极指南

3大核心方案&#xff1a;彻底解决DouyinLiveRecorder中PandaTV录制失败的终极指南 【免费下载链接】DouyinLiveRecorder 可循环值守和多人录制的直播录制软件&#xff0c;支持抖音、TikTok、Youtube、快手、虎牙、斗鱼、B站、小红书、pandatv、sooplive、flextv、popkontv、twi…...

别再手动指定模型了!用Hugging Face的AutoModel和AutoProcessor,一行代码搞定BERT/GPT加载

一行代码解放生产力&#xff1a;Hugging Face AutoClass全解析 第一次接触Hugging Face Transformers库时&#xff0c;面对琳琅满目的模型类名——BertForSequenceClassification、RobertaTokenizer、GPT2LMHeadModel...你是否感到头晕目眩&#xff1f;每个项目开始前都要翻阅…...

Scala 方法与函数

Scala 方法与函数 引言 Scala 是一门多范式编程语言,它结合了面向对象和函数式编程的特性。在 Scala 中,方法和函数是构建程序的基本单元。本文将深入探讨 Scala 中的方法和函数,包括它们的定义、使用以及在实际编程中的应用。 方法与函数的定义 在 Scala 中,方法和函数…...

PaddlePaddle数据加载进阶:除了MNIST,你更应该掌握这几种内置数据集和高效采样技巧

PaddlePaddle数据加载进阶&#xff1a;除了MNIST&#xff0c;你更应该掌握这几种内置数据集和高效采样技巧 当你的深度学习模型在MNIST上轻松达到99%准确率时&#xff0c;是否曾思考过&#xff1a;数据加载环节可能正在成为整个训练流程的瓶颈&#xff1f;在真实工业场景中&…...