当前位置: 首页 > article >正文

别再混用了!PyTorch实战:CrossEntropyLoss和BCEWithLogitsLoss到底怎么选?(附MNIST与多标签分类代码)

PyTorch损失函数实战指南CrossEntropyLoss与BCEWithLogitsLoss的精准选择当你面对一个分类问题时选择正确的损失函数往往决定了模型的成败。PyTorch提供了多种损失函数但CrossEntropyLoss和BCEWithLogitsLoss是最容易混淆的两个。本文将带你深入理解它们的差异并通过实际代码演示如何在不同场景下做出明智选择。1. 理解分类任务的基本类型在深度学习中分类任务主要分为两种基本类型单标签分类Multi-class Classification每个样本只能属于一个类别。例如MNIST手写数字识别一张图片只能是0-9中的一个数字。多标签分类Multi-label Classification每个样本可以同时属于多个类别。例如图像中可能同时包含猫、狗和树等多个标签。这两种任务在数据处理和模型设计上有本质区别而损失函数的选择正是基于这些差异。2. CrossEntropyLoss深度解析nn.CrossEntropyLoss是PyTorch中处理单标签分类任务的首选损失函数。它实际上是Softmax激活和负对数似然损失(NLLLoss)的组合。2.1 数学原理CrossEntropyLoss的计算公式为$$ \text{loss}(x, class) -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) -x[class] \log\left(\sum_j \exp(x[j])\right) $$其中x是模型的原始输出logitsclass是目标类别索引import torch import torch.nn as nn # 创建模型输出和目标 logits torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.0, 0.3]]) targets torch.tensor([0, 2]) # 每个样本一个类别索引 # 计算损失 loss_fn nn.CrossEntropyLoss() loss loss_fn(logits, targets) print(fCrossEntropyLoss: {loss.item():.4f})2.2 输入输出要求模型输出不需要经过Softmax处理直接使用原始logits形状为[batch_size, num_classes]目标标签类别的索引形状为[batch_size]每个值是0到num_classes-1之间的整数注意CrossEntropyLoss内部会自动应用Softmax因此不要在模型最后一层添加Softmax激活这会导致数值不稳定。2.3 MNIST实战示例让我们用经典的MNIST数据集演示CrossEntropyLoss的实际应用import torchvision from torchvision import transforms from torch.utils.data import DataLoader # 数据准备 transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) train_dataset torchvision.datasets.MNIST( root./data, trainTrue, downloadTrue, transformtransform) test_dataset torchvision.datasets.MNIST( root./data, trainFalse, downloadTrue, transformtransform) train_loader DataLoader(train_dataset, batch_size64, shuffleTrue) test_loader DataLoader(test_dataset, batch_size1000, shuffleFalse) # 定义简单CNN模型 class MNISTModel(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(1, 32, 3, 1) self.conv2 nn.Conv2d(32, 64, 3, 1) self.dropout nn.Dropout(0.25) self.fc1 nn.Linear(9216, 128) self.fc2 nn.Linear(128, 10) def forward(self, x): x self.conv1(x) x nn.ReLU()(x) x self.conv2(x) x nn.ReLU()(x) x nn.MaxPool2d(2)(x) x self.dropout(x) x torch.flatten(x, 1) x self.fc1(x) x nn.ReLU()(x) x self.fc2(x) return x model MNISTModel() criterion nn.CrossEntropyLoss() optimizer torch.optim.Adam(model.parameters()) # 训练循环 for epoch in range(5): model.train() for batch_idx, (data, target) in enumerate(train_loader): optimizer.zero_grad() output model(data) loss criterion(output, target) loss.backward() optimizer.step() # 测试 model.eval() correct 0 with torch.no_grad(): for data, target in test_loader: output model(data) pred output.argmax(dim1, keepdimTrue) correct pred.eq(target.view_as(pred)).sum().item() print(fEpoch {epoch}: Accuracy {100. * correct / len(test_loader.dataset):.2f}%)3. BCEWithLogitsLoss全面剖析nn.BCEWithLogitsLoss是处理多标签分类任务的利器它结合了Sigmoid激活和二元交叉熵损失提供了数值稳定性。3.1 数学原理BCEWithLogitsLoss的计算公式为$$ \text{loss}(x, y) -\frac{1}{n}\sum_i [y_i \cdot \log\sigma(x_i) (1-y_i)\cdot \log(1-\sigma(x_i))] $$其中x是模型的原始输出logitsy是目标概率0或1σ是Sigmoid函数# BCEWithLogitsLoss示例 logits torch.tensor([[0.8, -0.5], [1.2, -1.0]]) targets torch.tensor([[1.0, 0.0], [1.0, 1.0]]) # 多标签 loss_fn nn.BCEWithLogitsLoss() loss loss_fn(logits, targets) print(fBCEWithLogitsLoss: {loss.item():.4f})3.2 输入输出要求模型输出不需要经过Sigmoid处理直接使用原始logits形状为[batch_size, num_classes]目标标签每个类别独立的概率形状与输出相同[batch_size, num_classes]值为0.0或1.0提示虽然称为二元交叉熵但它可以完美处理多标签问题只需为每个类别独立计算损失。3.3 多标签分类实战让我们创建一个模拟的多标签分类任务import numpy as np # 创建模拟多标签数据集 class MultiLabelDataset(torch.utils.data.Dataset): def __init__(self, num_samples1000, num_features20, num_classes5): self.data torch.randn(num_samples, num_features) # 随机生成多标签每个样本可能有多个1 self.labels torch.randint(0, 2, (num_samples, num_classes)).float() def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx], self.labels[idx] dataset MultiLabelDataset() dataloader DataLoader(dataset, batch_size32, shuffleTrue) # 定义简单模型 class MultiLabelModel(nn.Module): def __init__(self, input_size, num_classes): super().__init__() self.linear1 nn.Linear(input_size, 64) self.linear2 nn.Linear(64, num_classes) def forward(self, x): x nn.ReLU()(self.linear1(x)) x self.linear2(x) # 注意没有最后的Sigmoid return x model MultiLabelModel(20, 5) criterion nn.BCEWithLogitsLoss() optimizer torch.optim.Adam(model.parameters()) # 训练循环 for epoch in range(10): model.train() for data, target in dataloader: optimizer.zero_grad() output model(data) loss criterion(output, target) loss.backward() optimizer.step() print(fEpoch {epoch}: Loss {loss.item():.4f})4. 决策流程图如何选择正确的损失函数在实际项目中你可以按照以下流程图做出选择确定问题类型每个样本只能属于一个类别 →CrossEntropyLoss每个样本可以属于多个类别 →BCEWithLogitsLoss检查标签格式单标签形状[batch_size]的类别索引多标签形状[batch_size, num_classes]的0/1矩阵模型输出处理CrossEntropyLoss最后一层无激活函数BCEWithLogitsLoss最后一层无Sigmoid特殊情况处理二分类问题两种损失函数都可以使用但BCEWithLogitsLoss通常更直接类别不平衡考虑添加weight参数或使用pos_weight下表总结了两种损失函数的关键区别特性CrossEntropyLossBCEWithLogitsLoss适用任务单标签分类多标签分类目标标签形状[batch_size][batch_size, num_classes]模型输出处理无需Softmax无需Sigmoid内部激活函数SoftmaxSigmoid数学计算多类交叉熵二元交叉熵求和典型应用场景MNIST、CIFAR分类多标签图像分类、推荐系统5. 高级技巧与常见陷阱5.1 处理类别不平衡在实际数据中我们经常会遇到类别不平衡问题。两种损失函数都提供了解决方案# CrossEntropyLoss处理类别不平衡 class_weights torch.tensor([1.0, 2.0, 1.5]) # 为每个类别设置权重 criterion nn.CrossEntropyLoss(weightclass_weights) # BCEWithLogitsLoss处理正样本稀少 pos_weight torch.tensor([5.0]) # 正样本权重 criterion nn.BCEWithLogitsLoss(pos_weightpos_weight)5.2 数值稳定性技巧虽然PyTorch的这两种损失函数已经优化了数值稳定性但在极端情况下仍需注意避免在模型最后一层手动添加Softmax或Sigmoid对于BCEWithLogitsLoss可以使用torch.sigmoid将输出转换为概率时添加eps防止数值溢出probs torch.sigmoid(output).clamp(min1e-7, max1-1e-7)5.3 混合任务处理有时我们会遇到同时包含单标签和多标签的任务。这种情况下可以将单标签转换为多标签的one-hot形式统一使用BCEWithLogitsLoss对单标签部分添加约束如确保每行只有一个1# 将单标签转换为多标签形式 single_labels torch.tensor([0, 2, 1]) # 3个样本3个类别 multi_labels torch.zeros(3, 3) multi_labels[torch.arange(3), single_labels] 15.4 自定义损失函数在某些特殊场景下你可能需要自定义损失函数。例如实现带掩码的多标签损失def masked_bce_with_logits(output, target, mask): loss nn.BCEWithLogitsLoss(reductionnone)(output, target) loss loss * mask # 应用掩码 return loss.mean() # 只计算掩码部分的均值6. 性能优化建议在实际项目中损失函数的选择和实现会显著影响训练效率和模型性能批量处理优化确保数据加载时已经正确批量化使用pin_memoryTrue加速GPU数据传输混合精度训练两种损失函数都支持AMP自动混合精度可以显著减少显存使用并加速训练scaler torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast(): output model(data) loss criterion(output, target) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()分布式训练兼容性两种损失函数都完全支持DDP分布式数据并行无需特殊处理即可在多GPU环境下工作ONNX/TensorRT导出CrossEntropyLoss在导出时通常被移除推理时只需要logitsBCEWithLogitsLoss同样不需要在推理图中保留# 导出模型示例 dummy_input torch.randn(1, 1, 28, 28) torch.onnx.export(model, dummy_input, model.onnx, input_names[input], output_names[output])7. 实际项目经验分享在真实项目中使用这两种损失函数时有几个容易踩的坑值得注意标签编码错误是最常见的问题。曾经在一个多标签项目中误将0/1标签编码为1/2导致模型完全无法收敛。忘记调整输出层。有次在修改单标签为多标签模型时保留了最后的Softmax层结果损失值出现NaN。评估指标不匹配。多标签分类不能直接使用准确率应该考虑精确率、召回率或F1分数。学习率设置差异。实践中发现BCEWithLogitsLoss通常需要比CrossEntropyLoss更小的学习率。

相关文章:

别再混用了!PyTorch实战:CrossEntropyLoss和BCEWithLogitsLoss到底怎么选?(附MNIST与多标签分类代码)

PyTorch损失函数实战指南:CrossEntropyLoss与BCEWithLogitsLoss的精准选择 当你面对一个分类问题时,选择正确的损失函数往往决定了模型的成败。PyTorch提供了多种损失函数,但CrossEntropyLoss和BCEWithLogitsLoss是最容易混淆的两个。本文将带…...

Pyenv vs Miniconda vs Anaconda:Python环境管理实战对比

1. Python环境管理工具全景概览 刚接触Python开发时,最让我头疼的就是环境配置问题。同一个项目在不同电脑上跑出不同结果,安装包时各种依赖报错,这些经历相信很多开发者都遇到过。Python环境管理工具就是为解决这些问题而生的,它…...

Fluent Python Console实战指南:解锁PyFluent-Core的GUI交互新体验

1. Fluent Python Console初探:当仿真遇上交互式编程 第一次在Fluent里敲下Python命令时,那种感觉就像在汽车方向盘旁边发现了隐藏的飞行模式按钮。作为从2023 R1版本开始引入的Beta功能,Fluent Python Console彻底改变了我们与仿真软件交互的…...

Python 快速上手 Telegram Bot:从零到一的实战指南

1. 为什么选择Python开发Telegram Bot? Telegram Bot就像是你安插在Telegram里的一个24小时待命的智能助手。它能自动回复消息、处理订单、推送新闻,甚至陪你玩文字游戏。而Python凭借其简洁的语法和丰富的库生态,成为了开发Telegram Bot的首…...

RMBG-2.0部署案例:跨境电商独立站商品图自动化处理流水线

RMBG-2.0部署案例:跨境电商独立站商品图自动化处理流水线 1. 项目背景与需求 跨境电商独立站每天需要处理大量商品图片,其中背景移除是最基础也是最耗时的环节。传统的人工抠图方式存在几个明显问题: 时间成本高:一张商品图手动…...

SUNFLOWER MATCH LAB植物匹配实验室Python入门实战:从零开始部署与调用

SUNFLOWER MATCH LAB植物匹配实验室Python入门实战:从零开始部署与调用 你是不是也对那些能识别花草树木的AI应用感到好奇?看到别人用几行代码就能让电脑认出图片里的植物,自己也想试试,但又担心Python基础不够,环境配…...

前端开发趋势分析

前端开发趋势分析:探索未来技术方向 在数字化浪潮的推动下,前端开发作为连接用户与产品的桥梁,正经历着前所未有的变革。从静态页面到动态交互,再到如今的全栈化与智能化,前端技术不断突破边界。本文将分析当前前端开…...

AI绘画神器FLUX.1-dev:Docker快速部署指南,开箱即用体验惊艳画质

AI绘画神器FLUX.1-dev:Docker快速部署指南,开箱即用体验惊艳画质 1. 引言:为什么选择FLUX.1-dev旗舰版? 如果你正在寻找一款能够生成影院级画质的AI绘画工具,FLUX.1-dev旗舰版绝对值得尝试。这个基于Docker的解决方案…...

Youtu-Parsing快速开始:单图片模式、批量处理模式、输出格式详解

Youtu-Parsing快速开始:单图片模式、批量处理模式、输出格式详解 1. 项目概述 Youtu-Parsing是腾讯优图实验室推出的专业文档解析模型,基于Youtu-LLM-2B构建,能够智能识别文档中的多种元素: 文本内容:精准OCR文字识…...

3分钟搞定智慧树自动刷课:解放双手的学习加速器终极指南

3分钟搞定智慧树自动刷课:解放双手的学习加速器终极指南 【免费下载链接】zhihuishu 智慧树刷课插件,自动播放下一集、1.5倍速度、无声 项目地址: https://gitcode.com/gh_mirrors/zh/zhihuishu 还在为智慧树平台繁琐的网课学习而烦恼吗&#xff…...

2025届学术党必备的十大AI辅助写作神器推荐榜单

Ai论文网站排名(开题报告、文献综述、降aigc率、降重综合对比) TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 对于知网AI检测系统的降重策略,要从文本特征着手。其一,把短句合并成…...

AI让Verilog入门不再劝退,但芯片工程师真的轻松了吗?

还记得第一次写Verilog的感觉吗&#xff1f;明明只是想让一个LED灯闪烁&#xff0c;却要先声明一堆wire、reg&#xff0c;搞清楚阻塞赋值和非阻塞赋值的区别&#xff0c;再纠结always块里该用还是<。现在的情况完全不同了。新入行的工程师可以直接对AI说&#xff1a;"帮…...

2025最权威的AI写作平台推荐榜单

Ai论文网站排名&#xff08;开题报告、文献综述、降aigc率、降重综合对比&#xff09; TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 想要降低被检测出是AIGC&#xff08;也就是人工智能生成内容&#xff09;的概率&#xff0c;…...

Hunyuan-MT-7B翻译模型体验分享:简单易用的多语言翻译工具

Hunyuan-MT-7B翻译模型体验分享&#xff1a;简单易用的多语言翻译工具 1. 模型概览与核心优势 Hunyuan-MT-7B是腾讯混元团队推出的开源多语言翻译模型&#xff0c;凭借70亿参数的紧凑架构实现了专业级的翻译质量。这个模型最吸引人的特点是它能在消费级显卡上流畅运行&#x…...

使用VSCode远程开发并调试Qwen3.5-4B模型调用代码

使用VSCode远程开发并调试Qwen3.5-4B模型调用代码 1. 前言&#xff1a;为什么需要远程开发&#xff1f; 当你开始接触大模型开发时&#xff0c;可能会遇到一个常见问题&#xff1a;本地电脑性能不足&#xff0c;无法流畅运行像Qwen3.5-4B这样的模型。这时候&#xff0c;远程开…...

为什么你的INT4模型崩了?:SITS2026实测17个开源大模型量化表现,独家发布「量化鲁棒性评分卡」(含Qwen2、Phi-3、DeepSeek-V2全量数据)

第一章&#xff1a;SITS2026分享&#xff1a;大模型量化压缩技术 2026奇点智能技术大会(https://ml-summit.org) 大模型量化压缩已成为部署百亿参数级语言模型至边缘设备与推理服务集群的关键路径。在SITS2026现场&#xff0c;多家研究团队展示了基于混合精度、通道感知与校准…...

Qwen3Guard-Gen-WEB快速体验:网页界面一键审核内容安全

Qwen3Guard-Gen-WEB快速体验&#xff1a;网页界面一键审核内容安全 1. 为什么选择Qwen3Guard-Gen-WEB&#xff1f; 1.1 内容安全审核的痛点 在AI应用开发过程中&#xff0c;内容安全审核往往成为项目落地的最后一道障碍。传统方案面临三大挑战&#xff1a; 技术门槛高&…...

S2-Pro YOLOv11目标检测结果分析与报告生成

S2-Pro YOLOv11目标检测结果分析与报告生成 1. 计算机视觉项目的后期处理痛点 在完成目标检测模型的训练和部署后&#xff0c;很多开发者都会遇到一个共同的问题&#xff1a;如何高效处理和分析模型输出的检测结果。传统的做法是手动查看每张图片的检测框&#xff0c;统计各类…...

C++集成指南:高性能调用LongCat-Image-Edit核心算法

C集成指南&#xff1a;高性能调用LongCat-Image-Edit核心算法 最近在折腾一个图像处理项目&#xff0c;需要把动物图片编辑功能集成到C后端服务里。一开始用Python接口调用LongCat-Image-Edit&#xff0c;效果确实不错&#xff0c;但性能瓶颈很快就出现了——批量处理时速度跟…...

别再死记硬背了!用一张图+实战命令,彻底搞懂STP/RSTP/MSTP的选举过程

一张拓扑图五条命令&#xff1a;动态拆解生成树协议选举全流程 刚接触生成树协议时&#xff0c;我总被各种选举规则绕得头晕——桥ID、路径开销、端口优先级这些概念像天书一样。直到导师在白板上画了个简单的三角形拓扑&#xff0c;用不同颜色标注出阻塞端口&#xff0c;突然一…...

文脉定序系统效果对比评测:与传统BM25算法的性能较量

文脉定序系统效果对比评测&#xff1a;与传统BM25算法的性能较量 最近在折腾一个技术文档的智能检索项目&#xff0c;发现一个挺有意思的现象&#xff1a;很多朋友一提到搜索排序&#xff0c;脑子里蹦出来的第一个词还是“BM25”。这算法确实经典&#xff0c;像信息检索领域的…...

Ollama本地大模型新玩法:PasteMD剪贴板美化工具深度体验

Ollama本地大模型新玩法&#xff1a;PasteMD剪贴板美化工具深度体验 1. 为什么PasteMD是文本处理的革命性工具 在日常工作中&#xff0c;我们经常遇到这样的困扰&#xff1a; 从会议录音转写的文字稿杂乱无章&#xff0c;关键信息淹没在大量口语化表达中复制粘贴的代码片段丢失…...

MTools优化升级:开启GPU加速,让AI编程和文档生成更快更稳

MTools优化升级&#xff1a;开启GPU加速&#xff0c;让AI编程和文档生成更快更稳 1. 工具升级亮点&#xff1a;GPU加速全面支持 MTools最新版本带来了革命性的性能提升&#xff0c;通过全面支持GPU加速&#xff0c;让AI编程和文档生成的速度和稳定性都达到了新高度。这次升级…...

434649494

4546465484...

Phi-3-mini-128k-instruct在WSL2中的部署详解:Windows开发者的福音

Phi-3-mini-128k-instruct在WSL2中的部署详解&#xff1a;Windows开发者的福音 如果你是一名Windows开发者&#xff0c;想体验最新的AI模型&#xff0c;但又不想折腾双系统或者虚拟机&#xff0c;那今天这篇文章就是为你准备的。我们一起来聊聊怎么在Windows自带的WSL2里&…...

Harmonyos在语文教学中应用-6. 口令指令执行器(对应:口语交际:我说你做)

6. 口令指令执行器(对应:口语交际:我说你做) 功能介绍: 辅助《我说你做》口语交际的工具。应用内置语音识别功能,当教师或同学发出指令(如“举起右手”、“摸摸耳朵”)时,系统识别语音并在屏幕上显示对应的动作图标或文字。这帮助学生听懂指令并做出反应,锻炼听力和…...

丹青幻境效果展示:‘一袭青衣,倚楼听雨’12轮不同机缘下的意境变化

丹青幻境效果展示&#xff1a;‘一袭青衣&#xff0c;倚楼听雨’12轮不同机缘下的意境变化 你有没有想过&#xff0c;一句诗、一个画面&#xff0c;能变幻出多少种不同的模样&#xff1f; “一袭青衣&#xff0c;倚楼听雨”&#xff0c;这八个字在我脑海里盘旋了很久。它像一…...

Chandra OCR科研复现教程:olmOCR基准测试环境搭建与83.1分结果验证

Chandra OCR科研复现教程&#xff1a;olmOCR基准测试环境搭建与83.1分结果验证 4 GB显存即可运行&#xff0c;83分OCR精度&#xff0c;表格/手写/公式一次搞定&#xff0c;输出直接是Markdown 1. 项目背景与核心价值 Chandra是Datalab.to在2025年10月开源的"布局感知&quo…...

手把手教程:基于Qwen2.5-VL的Chord视觉定位模型,快速部署与实战体验

手把手教程&#xff1a;基于Qwen2.5-VL的Chord视觉定位模型&#xff0c;快速部署与实战体验 1. 项目概述 Chord视觉定位模型是基于Qwen2.5-VL多模态大模型构建的智能视觉定位服务。它能理解自然语言描述&#xff0c;在图像中精确定位目标对象并返回边界框坐标&#xff0c;无需…...

Qwen3-ASR-1.7B实战:智能客服语音转文字方案落地解析

Qwen3-ASR-1.7B实战&#xff1a;智能客服语音转文字方案落地解析 1. 引言&#xff1a;智能客服的语音识别挑战 在智能客服系统中&#xff0c;语音识别(ASR)技术承担着将客户语音转化为可处理文本的关键任务。然而传统ASR方案在实际落地时常常面临三大挑战&#xff1a; 多语言…...