当前位置: 首页 > news >正文

【深度学习】Pytorch中实现交叉熵损失计算的方式总结

在PyTorch中,计算交叉熵损失主要有以下几种方式,它们针对不同的场景和需求有不同的实现方式和适用范围:

1. nn.CrossEntropyLoss

这是最常用且方便的方法,特别适用于多分类任务。nn.CrossEntropyLoss 实际上是同时完成了 softmax 函数和交叉熵损失的计算。它假设最后一层的输出没有经过归一化处理(不是概率形式),而是直接给出了各个类别的得分。该函数会自动计算每一样本对各类别的得分,应用softmax函数,然后计算交叉熵损失。

import torch
import torch.nn as nn# 假设 outputs 是模型的最后一层输出,shape 为 (batch_size, num_classes),targets 是 ground truth labels
outputs = torch.randn(100, 10)  # 对于10分类问题的100个样本的不归一化的预测值
targets = torch.randint(0, 10, (100,))  # 对应的真实类别loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, targets)
print(loss.item())

2. F.cross_entropy 函数

torch.nn.functional.cross_entropy 函数也是为了多分类问题设计的,但它接受的是 logits 或者已经经过 softmax 的概率。如果你的输出已经是经过 softmax 的概率,可以直接使用;否则,它会默认内部先执行 log_softmax

import torch.nn.functional as F# 假设 outputs 是未经 softmax 的 logits
outputs = torch.randn(100, 10)# 使用 F.cross_entropy 直接计算损失,无需单独进行 softmax
loss = F.cross_entropy(outputs, targets)
print(loss.item())

3. nn.BCEWithLogitsLoss 类(二分类问题)

对于二分类问题,尤其是sigmoid激活函数之后的结果,可以使用带Sigmoid的二元交叉熵损失函数,它同时完成 sigmoid 和 二元交叉熵损失的计算。

# 二分类问题,输出维度为 (batch_size, 1)
outputs = torch.randn(100, 1) # targets 是介于 [0, 1] 或 {-1, 1} 的值,表示正负样本
targets = torch.rand(100, 1) > 0.5  # 或者其他的二进制标签bce_loss = nn.BCEWithLogitsLoss()
loss = bce_loss(outputs, targets.float())
print(loss.item())

4. 手动计算交叉熵损失

当然,也可以手动组合 log_softmaxnll_loss 函数来计算交叉熵损失,这在特殊情况下可能会有用,比如需要对损失函数进行修改或者自定义的时候:

# 多分类问题,手动组合 log_softmax 和 nll_loss
output_logits = torch.randn(100, 10)
softmax_outputs = F.log_softmax(output_logits, dim=1)  # 计算 log_softmax
loss_manual = -torch.mean(torch.gather(softmax_outputs, 1, targets.unsqueeze(1)).squeeze())  # 使用 gather 和 mean 计算 NLL
assert torch.allclose(loss_manual, F.nll_loss(softmax_outputs, targets, reduction='mean'))  # 应该与 nll_loss 结果一致

在上述代码中,gather 函数用于从预测概率矩阵中按照目标标签索引出相应的对数概率,然后求平均得到最终的交叉熵损失。在多分类任务中,直接使用 F.nll_loss(log_softmax_outputs, targets) 是更加简洁的做法,等价于手动计算。而在二分类问题中,对应的手动计算方式则会涉及 sigmoidbinary_cross_entropy_with_logits 函数。

5. 补充说明

在交叉熵损失计算函数中:
L = − ∑ i = 1 n y i l o g ( S ( f θ ( x i ) ) ) L = -\sum_{i=1}^{n}{y_i}log(S(f_\theta(x_i))) L=i=1nyilog(S(fθ(xi)))
真实值 y i y_i yi可以是热编码后的结果,也可以不进行热编码。
虽然在Pytorch架构中,神经网络内流动的数据类型必须是float类型,但是Pytorch也提供了自动处理整数(int类型)标签的交叉熵损失函数(这里的“整数标签”指的是每个样本所属的真实类别,通常是一个从0开始的整数索引,对应着类别数量中的一个),这些函数会自动将整数标签转换为内部使用的one-hot编码格式,并计算交叉熵损失。
nn.CrossEntropyLoss为例,当输入给定的output是未经归一化的类别得分(logits),而target是整数标签时,这个损失函数会自动将整数标签转换为one-hot格式,然后再进行交叉熵损失的计算。这意味着用户不需要预先将目标标签转换为one-hot编码,损失函数内部会处理这样的转换过程。

import torch
import torch.nn as nn# 假设我们有一个批次的输出和对应的类别标签
outputs = torch.randn(64, 10)  # 这是一个批次的输出,共64个样本,10个类别
labels = torch.tensor([2, 7, 0, ..., 4], dtype=torch.long)  # 这是对应的整数类别标签loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, labels)print(f'Cross-entropy loss: {loss.item()}')

相关文章:

【深度学习】Pytorch中实现交叉熵损失计算的方式总结

在PyTorch中,计算交叉熵损失主要有以下几种方式,它们针对不同的场景和需求有不同的实现方式和适用范围: 1. nn.CrossEntropyLoss 类 这是最常用且方便的方法,特别适用于多分类任务。nn.CrossEntropyLoss 实际上是同时完成了 sof…...

机器学习:处理jira工单的分类问题

如何根据jira工单的category、reporter自动找到处理它的组呢?这是一个利用机器学习中knn算法的小实践. 目录 Knn算法 数据 示例 分割数据 选择Neighbors knn的优缺点 机器学习是一种技术,它的目的是给机器学习能力,让它们可以根据数据自己做决定,所以对于训练…...

后端常问面经之操作系统

请简要描述线程与进程的关系,区别及优缺点? 本质区别:进程是操作系统资源分配的基本单位,而线程是任务调度和执行的基本单位 在开销方面:每个进程都有独立的代码和数据空间(程序上下文),程序之…...

RK3568平台 iperf3测试网络性能

一.iperf3简介 iperf是一款开源的网络性能测试工具,主要用于测量TCP和UDP带宽性能。它可以在不同的操作系统上运行,包括Windows、Linux、macOS等。iperf具有简单易用、功能强大、高度可配置等特点,广泛应用于网络性能测试、网络故障诊断和网…...

Spring Boot中实现对特定URL的权限验证:拦截器、切面和安全框架的比较

引言: 在开发Web应用程序时,对特定URL进行权限验证是一项常见的需求。在Spring Boot中,我们有多种选择来实现这一目标,其中包括使用拦截器、切面和专门的安全框架(如Spring Security)。本文将比较这三种方式…...

【能源数据分析-00】能源领域数据集集锦(动态更新)

一、前言 大数据科学在能源领域的深度应用,已经深刻改变了这一行业的垂直格局。它为我们提供了宝贵的见解,帮助降低下游市场的成本,使石油生产商能够更好地应对市场繁荣期的需求。近期,石油价格的剧烈下跌给全球经济带来了沉重打…...

数据挖掘与机器学习 1. 绪论

于高山之巅,方见大河奔涌;于群峰之上,便觉长风浩荡 —— 24.3.24 一、数据挖掘和机器学习的定义 1.数据挖掘的狭义定义 背景:大数据时代——知识贫乏 数据挖掘的狭义定义: 数据挖掘就是从大量的、不完全的、有噪声的、…...

Matlab实现序贯变分模态分解(SVMD)

大家好,我是带我去滑雪! 序贯变分模态分解(SVMD) 是一种信号处理和数据分析方法。它可以将复杂信号分解为一系列模态函数,每个模态函数代表信号中的特定频率分量。 SVMD 的主要目标是提取信号中的不同频率分量并将其重构为原始信号。SVMD的基…...

云安全与云计算的关系

云计算又被称为网格计算,是分布式计算的一种,能够将大量的数据计算处理程序通过网络“云”分解成多个小程序,然后将这些小程序的结果反馈给用户。云计算主要就是能够解决任务分发,并进行计算结果的合并。 云安全则是我国企业创造的…...

WPF 界面变量绑定(通知界面变化)

1、继承属性变化接口 public partial class MainWindow : Window, INotifyPropertyChanged {// 通知界面属性发生变化public event PropertyChangedEventHandler PropertyChanged;private void RaisePropertyChanged(string propertyName){PropertyChangedEventHandler handle…...

eclipse导入svn项目

1、配置maven 2、用svn引入项目 3一直点击next,到最后选完成。...

Prompt提示工程上手指南:基础原理及实践(四)-检索增强生成(RAG)策略下的Prompt

前言 此篇文章已经是本系列的第四篇文章,意味着我们已经进入了Prompt工程的深水区,掌握的知识和技术都在不断提高,对于Prompt的技巧策略也不能只局限于局部运用而要适应LLM大模型的整体框架去进行改进休整。较为主流的LLM模型框架设计可以基…...

阿里云倚天云服务器怎么样?如何收费?

阿里云倚天云服务器CPU采用倚天710处理器,租用倚天服务器c8y、g8y和r8y可以享受优惠价格,阿里云服务器网aliyunfuwuqi.com整理倚天云服务器详细介绍、倚天710处理器性能测评、CIPU架构优势、倚天服务器使用场景及生态支持: 阿里云倚天云服务…...

海外社交营销为什么用云手机?不用普通手机?

海外社交营销作为企业拓展海外市场的重要手段,正日益受到企业的青睐。云手机以其成本效益和全球性特征,成为海外社交营销领域的得力助手。那么,究竟是什么特性使得越来越多的企业选择利用云手机进行海外社交营销呢?下文将对此进行…...

【Mysql数据库基础05】子查询 where、from、exists子查询、分页查询

where、from、exists子查询、分页查询 1 where子查询1.1 where后面的标量子查询1.1.1 having后的标量子查询 1.2 where后面的列子查询1.3 where后面的行子查询(了解即可) 2 from子查询3 exists子查询(相关子查询)4 分页查询5 联合…...

在Linux/Debian/Ubuntu上通过 Azure Data Studio 管理 SQL Server 2019

Microsoft 提供 Azure Data Studio,这是一种可在 Linux、macOS 和 Windows 上运行的跨平台数据库工具。 它提供与 SSMS 类似的功能,包括查询、脚本编写和可视化数据。 要在 Ubuntu 上安装 Azure Data Studio,可以按照以下步骤操作&#xff1…...

Java代码基础算法练习-搬砖问题-2024.03.25

任务描述: m块砖,n人搬,男搬4,女搬3,两个小孩抬一砖,要求一次全搬完,问男、 女、小孩各若干? 任务要求: 代码示例: package M0317_0331;import java.util.S…...

Tomcat调优

1、调整线程数 <Connector port"8080" maxHttpHeaderSize"8192"maxThreads"1900" minSpareThreads"250" maxSpareThreads"750"enableLookups"false" redirectPort"8443" acceptCount"100"…...

每日OJ题_栈①_力扣1047. 删除字符串中的所有相邻重复项

目录 力扣1047. 删除字符串中的所有相邻重复项 解析代码 力扣1047. 删除字符串中的所有相邻重复项 1047. 删除字符串中的所有相邻重复项 难度 简单 给出由小写字母组成的字符串 S&#xff0c;重复项删除操作会选择两个相邻且相同的字母&#xff0c;并删除它们。 在 S 上反…...

SQLServer SEQUENCE用法

SEQUENCE&#xff1a;数据库中的序列生成器 在数据库管理中&#xff0c;经常需要生成唯一且递增的数值序列&#xff0c;用于作为主键或其他需要唯一标识的列的值。为了实现这一功能&#xff0c;SQL Server 引入了 SEQUENCE 对象。SEQUENCE 是一个独立的数据库对象&#xff0c;用…...

18年产品经理生涯精华:从交付到规划,项目管理、解决方案、业务理解深度解析!

本期访谈只有1位老师&#xff0c;大海老师&#xff0c;18年工作经验&#xff0c;从干交付&#xff0c;到项目管理&#xff0c;再到资深技术专家、解决方案专家&#xff0c;目前做的更多的是业务规划、产品规划&#xff0c;是从一线实战走到真正的专家层面&#xff0c;老师分享的…...

ABAP开发必备:5种处理前导0的实战技巧(附SQL代码示例)

ABAP开发必备&#xff1a;5种处理前导0的实战技巧&#xff08;附SQL代码示例&#xff09; 在SAP ABAP开发中&#xff0c;物料号、供应商号等关键字段经常需要处理前导0的问题。这些看似简单的数字格式差异&#xff0c;却可能引发数据查询失败、报表统计错误等一系列"蝴蝶效…...

​​​​​​​巧用API接口,数据驱动提升店铺DSR评分

前言 DSR评分&#xff08;Detail Seller Rating&#xff0c;卖家服务评级系统&#xff09;是衡量电商店铺综合服务质量的核心指标&#xff0c;直接影响店铺排名、流量分配和买家信任度。传统的提升方式如加强客服培训、优化物流等固然重要&#xff0c;但在大数据时代&#xff0…...

Electron实战:将你的网页应用打包成桌面客户端

在当今数字化时代&#xff0c;网页应用已经渗透到我们工作和生活的方方面面。有时我们仍然需要一个桌面客户端来提供更稳定的运行环境、离线功能或更好的系统集成。Electron作为一个强大的跨平台框架&#xff0c;能够帮助开发者轻松将网页应用打包成桌面客户端。无论是开发效率…...

别再只会用‘Let‘s think step by step’了:DeepSeek-R1原生思维链的实战调优指南

别再只会用‘Let‘s think step by step’了&#xff1a;DeepSeek-R1原生思维链的实战调优指南 当你在深夜调试一个复杂的代码生成任务时&#xff0c;模型突然输出了一个完全不符合预期的结果。你盯着屏幕&#xff0c;反复检查自己的prompt——明明已经加上了经典的"Lets …...

嵌入式系统接口技术详解与应用实践

1. 嵌入式系统接口技术概述在嵌入式系统开发中&#xff0c;接口技术是连接处理器与外部设备的关键桥梁。作为一名嵌入式开发工程师&#xff0c;我经常需要根据项目需求选择合适的接口方案。本文将基于多年实战经验&#xff0c;深入解析各类嵌入式接口的工作原理、应用场景和选型…...

如何永久保存B站缓存视频?m4s-converter开源工具完整使用指南

如何永久保存B站缓存视频&#xff1f;m4s-converter开源工具完整使用指南 【免费下载链接】m4s-converter 一个跨平台小工具&#xff0c;将bilibili缓存的m4s格式音视频文件合并成mp4 项目地址: https://gitcode.com/gh_mirrors/m4/m4s-converter 你是否曾经遇到过这样的…...

嵌入式软件缺陷预防与设计规范实战指南

1. 嵌入式软件缺陷预防与设计规范作为一名在嵌入式领域摸爬滚打多年的工程师&#xff0c;我见过太多因为软件缺陷导致的灾难性后果。从航天器失联到医疗设备故障&#xff0c;这些事故背后往往都隐藏着本可以在设计阶段就规避的代码问题。今天我想分享的是&#xff1a;为什么一个…...

LSM6DS3TR-C驱动开发指南:寄存器配置与嵌入式IMU工程实践

1. JoyIT_LSM6DS3TR-C库深度解析&#xff1a;面向嵌入式工程师的LSM6DS3TR-C驱动开发指南LSM6DS3TR-C是意法半导体&#xff08;STMicroelectronics&#xff09;推出的超低功耗、高精度6轴惯性测量单元&#xff08;IMU&#xff09;&#xff0c;集成三轴加速度计与三轴陀螺仪&…...

并联混合动力船舶能量管理策略与SOC约束优化研究

并联混合动力船舶能量管理策略与SOC约束优化研究 摘要 本文针对并联混合动力船舶能量管理问题,基于等效燃油消耗最小化策略(ECMS),构建了包含柴油机、电动机、电池及船舶动力学系统的仿真模型。通过调整电池荷电状态(SOC)约束范围,分析其对燃油经济性、电池寿命及系统…...