当前位置: 首页 > news >正文

【深度学习】Pytorch中实现交叉熵损失计算的方式总结

在PyTorch中,计算交叉熵损失主要有以下几种方式,它们针对不同的场景和需求有不同的实现方式和适用范围:

1. nn.CrossEntropyLoss

这是最常用且方便的方法,特别适用于多分类任务。nn.CrossEntropyLoss 实际上是同时完成了 softmax 函数和交叉熵损失的计算。它假设最后一层的输出没有经过归一化处理(不是概率形式),而是直接给出了各个类别的得分。该函数会自动计算每一样本对各类别的得分,应用softmax函数,然后计算交叉熵损失。

import torch
import torch.nn as nn# 假设 outputs 是模型的最后一层输出,shape 为 (batch_size, num_classes),targets 是 ground truth labels
outputs = torch.randn(100, 10)  # 对于10分类问题的100个样本的不归一化的预测值
targets = torch.randint(0, 10, (100,))  # 对应的真实类别loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, targets)
print(loss.item())

2. F.cross_entropy 函数

torch.nn.functional.cross_entropy 函数也是为了多分类问题设计的,但它接受的是 logits 或者已经经过 softmax 的概率。如果你的输出已经是经过 softmax 的概率,可以直接使用;否则,它会默认内部先执行 log_softmax

import torch.nn.functional as F# 假设 outputs 是未经 softmax 的 logits
outputs = torch.randn(100, 10)# 使用 F.cross_entropy 直接计算损失,无需单独进行 softmax
loss = F.cross_entropy(outputs, targets)
print(loss.item())

3. nn.BCEWithLogitsLoss 类(二分类问题)

对于二分类问题,尤其是sigmoid激活函数之后的结果,可以使用带Sigmoid的二元交叉熵损失函数,它同时完成 sigmoid 和 二元交叉熵损失的计算。

# 二分类问题,输出维度为 (batch_size, 1)
outputs = torch.randn(100, 1) # targets 是介于 [0, 1] 或 {-1, 1} 的值,表示正负样本
targets = torch.rand(100, 1) > 0.5  # 或者其他的二进制标签bce_loss = nn.BCEWithLogitsLoss()
loss = bce_loss(outputs, targets.float())
print(loss.item())

4. 手动计算交叉熵损失

当然,也可以手动组合 log_softmaxnll_loss 函数来计算交叉熵损失,这在特殊情况下可能会有用,比如需要对损失函数进行修改或者自定义的时候:

# 多分类问题,手动组合 log_softmax 和 nll_loss
output_logits = torch.randn(100, 10)
softmax_outputs = F.log_softmax(output_logits, dim=1)  # 计算 log_softmax
loss_manual = -torch.mean(torch.gather(softmax_outputs, 1, targets.unsqueeze(1)).squeeze())  # 使用 gather 和 mean 计算 NLL
assert torch.allclose(loss_manual, F.nll_loss(softmax_outputs, targets, reduction='mean'))  # 应该与 nll_loss 结果一致

在上述代码中,gather 函数用于从预测概率矩阵中按照目标标签索引出相应的对数概率,然后求平均得到最终的交叉熵损失。在多分类任务中,直接使用 F.nll_loss(log_softmax_outputs, targets) 是更加简洁的做法,等价于手动计算。而在二分类问题中,对应的手动计算方式则会涉及 sigmoidbinary_cross_entropy_with_logits 函数。

5. 补充说明

在交叉熵损失计算函数中:
L = − ∑ i = 1 n y i l o g ( S ( f θ ( x i ) ) ) L = -\sum_{i=1}^{n}{y_i}log(S(f_\theta(x_i))) L=i=1nyilog(S(fθ(xi)))
真实值 y i y_i yi可以是热编码后的结果,也可以不进行热编码。
虽然在Pytorch架构中,神经网络内流动的数据类型必须是float类型,但是Pytorch也提供了自动处理整数(int类型)标签的交叉熵损失函数(这里的“整数标签”指的是每个样本所属的真实类别,通常是一个从0开始的整数索引,对应着类别数量中的一个),这些函数会自动将整数标签转换为内部使用的one-hot编码格式,并计算交叉熵损失。
nn.CrossEntropyLoss为例,当输入给定的output是未经归一化的类别得分(logits),而target是整数标签时,这个损失函数会自动将整数标签转换为one-hot格式,然后再进行交叉熵损失的计算。这意味着用户不需要预先将目标标签转换为one-hot编码,损失函数内部会处理这样的转换过程。

import torch
import torch.nn as nn# 假设我们有一个批次的输出和对应的类别标签
outputs = torch.randn(64, 10)  # 这是一个批次的输出,共64个样本,10个类别
labels = torch.tensor([2, 7, 0, ..., 4], dtype=torch.long)  # 这是对应的整数类别标签loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, labels)print(f'Cross-entropy loss: {loss.item()}')

相关文章:

【深度学习】Pytorch中实现交叉熵损失计算的方式总结

在PyTorch中,计算交叉熵损失主要有以下几种方式,它们针对不同的场景和需求有不同的实现方式和适用范围: 1. nn.CrossEntropyLoss 类 这是最常用且方便的方法,特别适用于多分类任务。nn.CrossEntropyLoss 实际上是同时完成了 sof…...

机器学习:处理jira工单的分类问题

如何根据jira工单的category、reporter自动找到处理它的组呢?这是一个利用机器学习中knn算法的小实践. 目录 Knn算法 数据 示例 分割数据 选择Neighbors knn的优缺点 机器学习是一种技术,它的目的是给机器学习能力,让它们可以根据数据自己做决定,所以对于训练…...

后端常问面经之操作系统

请简要描述线程与进程的关系,区别及优缺点? 本质区别:进程是操作系统资源分配的基本单位,而线程是任务调度和执行的基本单位 在开销方面:每个进程都有独立的代码和数据空间(程序上下文),程序之…...

RK3568平台 iperf3测试网络性能

一.iperf3简介 iperf是一款开源的网络性能测试工具,主要用于测量TCP和UDP带宽性能。它可以在不同的操作系统上运行,包括Windows、Linux、macOS等。iperf具有简单易用、功能强大、高度可配置等特点,广泛应用于网络性能测试、网络故障诊断和网…...

Spring Boot中实现对特定URL的权限验证:拦截器、切面和安全框架的比较

引言: 在开发Web应用程序时,对特定URL进行权限验证是一项常见的需求。在Spring Boot中,我们有多种选择来实现这一目标,其中包括使用拦截器、切面和专门的安全框架(如Spring Security)。本文将比较这三种方式…...

【能源数据分析-00】能源领域数据集集锦(动态更新)

一、前言 大数据科学在能源领域的深度应用,已经深刻改变了这一行业的垂直格局。它为我们提供了宝贵的见解,帮助降低下游市场的成本,使石油生产商能够更好地应对市场繁荣期的需求。近期,石油价格的剧烈下跌给全球经济带来了沉重打…...

数据挖掘与机器学习 1. 绪论

于高山之巅,方见大河奔涌;于群峰之上,便觉长风浩荡 —— 24.3.24 一、数据挖掘和机器学习的定义 1.数据挖掘的狭义定义 背景:大数据时代——知识贫乏 数据挖掘的狭义定义: 数据挖掘就是从大量的、不完全的、有噪声的、…...

Matlab实现序贯变分模态分解(SVMD)

大家好,我是带我去滑雪! 序贯变分模态分解(SVMD) 是一种信号处理和数据分析方法。它可以将复杂信号分解为一系列模态函数,每个模态函数代表信号中的特定频率分量。 SVMD 的主要目标是提取信号中的不同频率分量并将其重构为原始信号。SVMD的基…...

云安全与云计算的关系

云计算又被称为网格计算,是分布式计算的一种,能够将大量的数据计算处理程序通过网络“云”分解成多个小程序,然后将这些小程序的结果反馈给用户。云计算主要就是能够解决任务分发,并进行计算结果的合并。 云安全则是我国企业创造的…...

WPF 界面变量绑定(通知界面变化)

1、继承属性变化接口 public partial class MainWindow : Window, INotifyPropertyChanged {// 通知界面属性发生变化public event PropertyChangedEventHandler PropertyChanged;private void RaisePropertyChanged(string propertyName){PropertyChangedEventHandler handle…...

eclipse导入svn项目

1、配置maven 2、用svn引入项目 3一直点击next,到最后选完成。...

Prompt提示工程上手指南:基础原理及实践(四)-检索增强生成(RAG)策略下的Prompt

前言 此篇文章已经是本系列的第四篇文章,意味着我们已经进入了Prompt工程的深水区,掌握的知识和技术都在不断提高,对于Prompt的技巧策略也不能只局限于局部运用而要适应LLM大模型的整体框架去进行改进休整。较为主流的LLM模型框架设计可以基…...

阿里云倚天云服务器怎么样?如何收费?

阿里云倚天云服务器CPU采用倚天710处理器,租用倚天服务器c8y、g8y和r8y可以享受优惠价格,阿里云服务器网aliyunfuwuqi.com整理倚天云服务器详细介绍、倚天710处理器性能测评、CIPU架构优势、倚天服务器使用场景及生态支持: 阿里云倚天云服务…...

海外社交营销为什么用云手机?不用普通手机?

海外社交营销作为企业拓展海外市场的重要手段,正日益受到企业的青睐。云手机以其成本效益和全球性特征,成为海外社交营销领域的得力助手。那么,究竟是什么特性使得越来越多的企业选择利用云手机进行海外社交营销呢?下文将对此进行…...

【Mysql数据库基础05】子查询 where、from、exists子查询、分页查询

where、from、exists子查询、分页查询 1 where子查询1.1 where后面的标量子查询1.1.1 having后的标量子查询 1.2 where后面的列子查询1.3 where后面的行子查询(了解即可) 2 from子查询3 exists子查询(相关子查询)4 分页查询5 联合…...

在Linux/Debian/Ubuntu上通过 Azure Data Studio 管理 SQL Server 2019

Microsoft 提供 Azure Data Studio,这是一种可在 Linux、macOS 和 Windows 上运行的跨平台数据库工具。 它提供与 SSMS 类似的功能,包括查询、脚本编写和可视化数据。 要在 Ubuntu 上安装 Azure Data Studio,可以按照以下步骤操作&#xff1…...

Java代码基础算法练习-搬砖问题-2024.03.25

任务描述: m块砖,n人搬,男搬4,女搬3,两个小孩抬一砖,要求一次全搬完,问男、 女、小孩各若干? 任务要求: 代码示例: package M0317_0331;import java.util.S…...

Tomcat调优

1、调整线程数 <Connector port"8080" maxHttpHeaderSize"8192"maxThreads"1900" minSpareThreads"250" maxSpareThreads"750"enableLookups"false" redirectPort"8443" acceptCount"100"…...

每日OJ题_栈①_力扣1047. 删除字符串中的所有相邻重复项

目录 力扣1047. 删除字符串中的所有相邻重复项 解析代码 力扣1047. 删除字符串中的所有相邻重复项 1047. 删除字符串中的所有相邻重复项 难度 简单 给出由小写字母组成的字符串 S&#xff0c;重复项删除操作会选择两个相邻且相同的字母&#xff0c;并删除它们。 在 S 上反…...

SQLServer SEQUENCE用法

SEQUENCE&#xff1a;数据库中的序列生成器 在数据库管理中&#xff0c;经常需要生成唯一且递增的数值序列&#xff0c;用于作为主键或其他需要唯一标识的列的值。为了实现这一功能&#xff0c;SQL Server 引入了 SEQUENCE 对象。SEQUENCE 是一个独立的数据库对象&#xff0c;用…...

2025政务服务便民热线创新发展会议顺利召开,张晨博士受邀分享

5月28日&#xff0c;由新华社中国经济信息社、新华社广东分社联合主办的2025政务服务便民热线创新发展暨“人工智能热线”会议在广州举行。会议围绕“人工智能与新质热线”主题&#xff0c;邀请全国的12345政务服务便民热线主管部门负责人、省市热线负责人和专家学者&#xff0…...

GC1808:高性能24位立体声音频ADC芯片解析

1. 芯片简介 GC1808 是一款24位立体声音频模数转换器&#xff08;ADC&#xff09;&#xff0c;支持96kHz采样率&#xff0c;集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器&#xff0c;适用于家庭影院、蓝牙音箱等场景。 核心特性 高精度&#xff1a;24位分辨率&#xff0c;…...

什么是DevOps智能平台的核心功能?

在数字化转型的浪潮中&#xff0c;DevOps智能平台已成为企业提升研发效能、加速产品迭代的核心工具。然而&#xff0c;许多人对“DevOps智能平台”的理解仍停留在“自动化工具链”的表层概念。今天&#xff0c;我们从一个真实场景切入&#xff1a;假设你是某互联网公司的技术负…...

无字母数字webshell的命令执行

在Web安全领域&#xff0c;WebShell是一种常见的攻击手段&#xff0c;通过它攻击者可以远程执行服务器上的命令&#xff0c;获取敏感信息或控制系统。而无字母数字WebShell则是其中一种特殊形式&#xff0c;通过避免使用字母和数字字符&#xff0c;来绕过某些安全机制的检测。 …...

分词算法BBPE详解和Qwen的应用

一、TL&#xff1b;DR BPE有什么问题&#xff1a;依旧会遇到OOV问题&#xff0c;并且中文、日文这些大词汇表模型容易出现训练中未出现过的字符Byte-level BPE怎么解决&#xff1a;与BPE一样是高频字节进行合并&#xff0c;但BBPE是以UTF-8编码UTF-8编码字节序列而非字符序列B…...

11.MySQL事务管理详解

MySQL事务管理详解 文章目录 MySQL事务管理 事务的概念 事务的版本支持 事务的提交方式 事务的相关演示 事务的隔离级别 查看与设置隔离级别 读未提交&#xff08;Read Uncommitted&#xff09; 读提交&#xff08;Read Committed&#xff09; 可重复读&#xff08;Repeatabl…...

高效绘制业务流程图!专业模板免费下载

在复杂的业务流程管理中&#xff0c;可视化工具已成为提升效能的核心基础设施。为助力开发者、项目经理及业务架构师高效落地流程标准化&#xff0c;本文将为你精选5套开箱即用的专业流程图模板。这些模板覆盖跨部门协作、电商订单、客户服务等高频场景&#xff0c;具备以下核心…...

uv管理spaCy语言模型

本文记录如何在使用uv管理python项目dependencies时&#xff0c;把spaCy的模型也纳入其中. spaCy 一、spaCy简介 spaCy是一个开源的自然语言处理&#xff08;NLP&#xff09;库&#xff0c;它主要用于处理文本数据。它支持多种语言&#xff0c;包括英语、中文等。它是由Expl…...

c++ Base58编码解码

Base58 字符集 Base58 使用 58 个字符进行编码&#xff0c;字符集为&#xff1a;123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz。注意&#xff1a;0&#xff08;零&#xff09;、O&#xff08;大写字母O&#xff09;、I&#xff08;大写字母I&#xff09;和 l&a…...

前端项目eslint配置选项详细解析

文章目录 1. 前言2、错误级别3、常用规则4、目前项目使用的.eslintrc.js 1. 前言 ‌ESLint‌ 是一个可配置的 JavaScript 代码检查工具&#xff0c;旨在帮助开发者发现并修复代码中的潜在问题&#xff0c;包括语法错误、逻辑错误以及风格不一致等问题。以下是其核心功能和特点…...