当前位置: 首页 > news >正文

【深度学习】Pytorch中实现交叉熵损失计算的方式总结

在PyTorch中,计算交叉熵损失主要有以下几种方式,它们针对不同的场景和需求有不同的实现方式和适用范围:

1. nn.CrossEntropyLoss

这是最常用且方便的方法,特别适用于多分类任务。nn.CrossEntropyLoss 实际上是同时完成了 softmax 函数和交叉熵损失的计算。它假设最后一层的输出没有经过归一化处理(不是概率形式),而是直接给出了各个类别的得分。该函数会自动计算每一样本对各类别的得分,应用softmax函数,然后计算交叉熵损失。

import torch
import torch.nn as nn# 假设 outputs 是模型的最后一层输出,shape 为 (batch_size, num_classes),targets 是 ground truth labels
outputs = torch.randn(100, 10)  # 对于10分类问题的100个样本的不归一化的预测值
targets = torch.randint(0, 10, (100,))  # 对应的真实类别loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, targets)
print(loss.item())

2. F.cross_entropy 函数

torch.nn.functional.cross_entropy 函数也是为了多分类问题设计的,但它接受的是 logits 或者已经经过 softmax 的概率。如果你的输出已经是经过 softmax 的概率,可以直接使用;否则,它会默认内部先执行 log_softmax

import torch.nn.functional as F# 假设 outputs 是未经 softmax 的 logits
outputs = torch.randn(100, 10)# 使用 F.cross_entropy 直接计算损失,无需单独进行 softmax
loss = F.cross_entropy(outputs, targets)
print(loss.item())

3. nn.BCEWithLogitsLoss 类(二分类问题)

对于二分类问题,尤其是sigmoid激活函数之后的结果,可以使用带Sigmoid的二元交叉熵损失函数,它同时完成 sigmoid 和 二元交叉熵损失的计算。

# 二分类问题,输出维度为 (batch_size, 1)
outputs = torch.randn(100, 1) # targets 是介于 [0, 1] 或 {-1, 1} 的值,表示正负样本
targets = torch.rand(100, 1) > 0.5  # 或者其他的二进制标签bce_loss = nn.BCEWithLogitsLoss()
loss = bce_loss(outputs, targets.float())
print(loss.item())

4. 手动计算交叉熵损失

当然,也可以手动组合 log_softmaxnll_loss 函数来计算交叉熵损失,这在特殊情况下可能会有用,比如需要对损失函数进行修改或者自定义的时候:

# 多分类问题,手动组合 log_softmax 和 nll_loss
output_logits = torch.randn(100, 10)
softmax_outputs = F.log_softmax(output_logits, dim=1)  # 计算 log_softmax
loss_manual = -torch.mean(torch.gather(softmax_outputs, 1, targets.unsqueeze(1)).squeeze())  # 使用 gather 和 mean 计算 NLL
assert torch.allclose(loss_manual, F.nll_loss(softmax_outputs, targets, reduction='mean'))  # 应该与 nll_loss 结果一致

在上述代码中,gather 函数用于从预测概率矩阵中按照目标标签索引出相应的对数概率,然后求平均得到最终的交叉熵损失。在多分类任务中,直接使用 F.nll_loss(log_softmax_outputs, targets) 是更加简洁的做法,等价于手动计算。而在二分类问题中,对应的手动计算方式则会涉及 sigmoidbinary_cross_entropy_with_logits 函数。

5. 补充说明

在交叉熵损失计算函数中:
L = − ∑ i = 1 n y i l o g ( S ( f θ ( x i ) ) ) L = -\sum_{i=1}^{n}{y_i}log(S(f_\theta(x_i))) L=i=1nyilog(S(fθ(xi)))
真实值 y i y_i yi可以是热编码后的结果,也可以不进行热编码。
虽然在Pytorch架构中,神经网络内流动的数据类型必须是float类型,但是Pytorch也提供了自动处理整数(int类型)标签的交叉熵损失函数(这里的“整数标签”指的是每个样本所属的真实类别,通常是一个从0开始的整数索引,对应着类别数量中的一个),这些函数会自动将整数标签转换为内部使用的one-hot编码格式,并计算交叉熵损失。
nn.CrossEntropyLoss为例,当输入给定的output是未经归一化的类别得分(logits),而target是整数标签时,这个损失函数会自动将整数标签转换为one-hot格式,然后再进行交叉熵损失的计算。这意味着用户不需要预先将目标标签转换为one-hot编码,损失函数内部会处理这样的转换过程。

import torch
import torch.nn as nn# 假设我们有一个批次的输出和对应的类别标签
outputs = torch.randn(64, 10)  # 这是一个批次的输出,共64个样本,10个类别
labels = torch.tensor([2, 7, 0, ..., 4], dtype=torch.long)  # 这是对应的整数类别标签loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, labels)print(f'Cross-entropy loss: {loss.item()}')

相关文章:

【深度学习】Pytorch中实现交叉熵损失计算的方式总结

在PyTorch中,计算交叉熵损失主要有以下几种方式,它们针对不同的场景和需求有不同的实现方式和适用范围: 1. nn.CrossEntropyLoss 类 这是最常用且方便的方法,特别适用于多分类任务。nn.CrossEntropyLoss 实际上是同时完成了 sof…...

机器学习:处理jira工单的分类问题

如何根据jira工单的category、reporter自动找到处理它的组呢?这是一个利用机器学习中knn算法的小实践. 目录 Knn算法 数据 示例 分割数据 选择Neighbors knn的优缺点 机器学习是一种技术,它的目的是给机器学习能力,让它们可以根据数据自己做决定,所以对于训练…...

后端常问面经之操作系统

请简要描述线程与进程的关系,区别及优缺点? 本质区别:进程是操作系统资源分配的基本单位,而线程是任务调度和执行的基本单位 在开销方面:每个进程都有独立的代码和数据空间(程序上下文),程序之…...

RK3568平台 iperf3测试网络性能

一.iperf3简介 iperf是一款开源的网络性能测试工具,主要用于测量TCP和UDP带宽性能。它可以在不同的操作系统上运行,包括Windows、Linux、macOS等。iperf具有简单易用、功能强大、高度可配置等特点,广泛应用于网络性能测试、网络故障诊断和网…...

Spring Boot中实现对特定URL的权限验证:拦截器、切面和安全框架的比较

引言: 在开发Web应用程序时,对特定URL进行权限验证是一项常见的需求。在Spring Boot中,我们有多种选择来实现这一目标,其中包括使用拦截器、切面和专门的安全框架(如Spring Security)。本文将比较这三种方式…...

【能源数据分析-00】能源领域数据集集锦(动态更新)

一、前言 大数据科学在能源领域的深度应用,已经深刻改变了这一行业的垂直格局。它为我们提供了宝贵的见解,帮助降低下游市场的成本,使石油生产商能够更好地应对市场繁荣期的需求。近期,石油价格的剧烈下跌给全球经济带来了沉重打…...

数据挖掘与机器学习 1. 绪论

于高山之巅,方见大河奔涌;于群峰之上,便觉长风浩荡 —— 24.3.24 一、数据挖掘和机器学习的定义 1.数据挖掘的狭义定义 背景:大数据时代——知识贫乏 数据挖掘的狭义定义: 数据挖掘就是从大量的、不完全的、有噪声的、…...

Matlab实现序贯变分模态分解(SVMD)

大家好,我是带我去滑雪! 序贯变分模态分解(SVMD) 是一种信号处理和数据分析方法。它可以将复杂信号分解为一系列模态函数,每个模态函数代表信号中的特定频率分量。 SVMD 的主要目标是提取信号中的不同频率分量并将其重构为原始信号。SVMD的基…...

云安全与云计算的关系

云计算又被称为网格计算,是分布式计算的一种,能够将大量的数据计算处理程序通过网络“云”分解成多个小程序,然后将这些小程序的结果反馈给用户。云计算主要就是能够解决任务分发,并进行计算结果的合并。 云安全则是我国企业创造的…...

WPF 界面变量绑定(通知界面变化)

1、继承属性变化接口 public partial class MainWindow : Window, INotifyPropertyChanged {// 通知界面属性发生变化public event PropertyChangedEventHandler PropertyChanged;private void RaisePropertyChanged(string propertyName){PropertyChangedEventHandler handle…...

eclipse导入svn项目

1、配置maven 2、用svn引入项目 3一直点击next,到最后选完成。...

Prompt提示工程上手指南:基础原理及实践(四)-检索增强生成(RAG)策略下的Prompt

前言 此篇文章已经是本系列的第四篇文章,意味着我们已经进入了Prompt工程的深水区,掌握的知识和技术都在不断提高,对于Prompt的技巧策略也不能只局限于局部运用而要适应LLM大模型的整体框架去进行改进休整。较为主流的LLM模型框架设计可以基…...

阿里云倚天云服务器怎么样?如何收费?

阿里云倚天云服务器CPU采用倚天710处理器,租用倚天服务器c8y、g8y和r8y可以享受优惠价格,阿里云服务器网aliyunfuwuqi.com整理倚天云服务器详细介绍、倚天710处理器性能测评、CIPU架构优势、倚天服务器使用场景及生态支持: 阿里云倚天云服务…...

海外社交营销为什么用云手机?不用普通手机?

海外社交营销作为企业拓展海外市场的重要手段,正日益受到企业的青睐。云手机以其成本效益和全球性特征,成为海外社交营销领域的得力助手。那么,究竟是什么特性使得越来越多的企业选择利用云手机进行海外社交营销呢?下文将对此进行…...

【Mysql数据库基础05】子查询 where、from、exists子查询、分页查询

where、from、exists子查询、分页查询 1 where子查询1.1 where后面的标量子查询1.1.1 having后的标量子查询 1.2 where后面的列子查询1.3 where后面的行子查询(了解即可) 2 from子查询3 exists子查询(相关子查询)4 分页查询5 联合…...

在Linux/Debian/Ubuntu上通过 Azure Data Studio 管理 SQL Server 2019

Microsoft 提供 Azure Data Studio,这是一种可在 Linux、macOS 和 Windows 上运行的跨平台数据库工具。 它提供与 SSMS 类似的功能,包括查询、脚本编写和可视化数据。 要在 Ubuntu 上安装 Azure Data Studio,可以按照以下步骤操作&#xff1…...

Java代码基础算法练习-搬砖问题-2024.03.25

任务描述: m块砖,n人搬,男搬4,女搬3,两个小孩抬一砖,要求一次全搬完,问男、 女、小孩各若干? 任务要求: 代码示例: package M0317_0331;import java.util.S…...

Tomcat调优

1、调整线程数 <Connector port"8080" maxHttpHeaderSize"8192"maxThreads"1900" minSpareThreads"250" maxSpareThreads"750"enableLookups"false" redirectPort"8443" acceptCount"100"…...

每日OJ题_栈①_力扣1047. 删除字符串中的所有相邻重复项

目录 力扣1047. 删除字符串中的所有相邻重复项 解析代码 力扣1047. 删除字符串中的所有相邻重复项 1047. 删除字符串中的所有相邻重复项 难度 简单 给出由小写字母组成的字符串 S&#xff0c;重复项删除操作会选择两个相邻且相同的字母&#xff0c;并删除它们。 在 S 上反…...

SQLServer SEQUENCE用法

SEQUENCE&#xff1a;数据库中的序列生成器 在数据库管理中&#xff0c;经常需要生成唯一且递增的数值序列&#xff0c;用于作为主键或其他需要唯一标识的列的值。为了实现这一功能&#xff0c;SQL Server 引入了 SEQUENCE 对象。SEQUENCE 是一个独立的数据库对象&#xff0c;用…...

STM32F4基本定时器使用和原理详解

STM32F4基本定时器使用和原理详解 前言如何确定定时器挂载在哪条时钟线上配置及使用方法参数配置PrescalerCounter ModeCounter Periodauto-reload preloadTrigger Event Selection 中断配置生成的代码及使用方法初始化代码基本定时器触发DCA或者ADC的代码讲解中断代码定时启动…...

基于数字孪生的水厂可视化平台建设:架构与实践

分享大纲&#xff1a; 1、数字孪生水厂可视化平台建设背景 2、数字孪生水厂可视化平台建设架构 3、数字孪生水厂可视化平台建设成效 近几年&#xff0c;数字孪生水厂的建设开展的如火如荼。作为提升水厂管理效率、优化资源的调度手段&#xff0c;基于数字孪生的水厂可视化平台的…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾&#xff1a; 在上一篇中&#xff0c;我们成功地为应用集成了数据库&#xff0c;并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了&#xff01;但是&#xff0c;如果你仔细审视那些 API&#xff0c;会发现它们还很“粗糙”&#xff1a;有…...

[Java恶补day16] 238.除自身以外数组的乘积

给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法&#xff0c;且在 O(n) 时间复杂度…...

Python 训练营打卡 Day 47

注意力热力图可视化 在day 46代码的基础上&#xff0c;对比不同卷积层热力图可视化的结果 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pypl…...

pycharm 设置环境出错

pycharm 设置环境出错 pycharm 新建项目&#xff0c;设置虚拟环境&#xff0c;出错 pycharm 出错 Cannot open Local Failed to start [powershell.exe, -NoExit, -ExecutionPolicy, Bypass, -File, C:\Program Files\JetBrains\PyCharm 2024.1.3\plugins\terminal\shell-int…...

【安全篇】金刚不坏之身:整合 Spring Security + JWT 实现无状态认证与授权

摘要 本文是《Spring Boot 实战派》系列的第四篇。我们将直面所有 Web 应用都无法回避的核心问题&#xff1a;安全。文章将详细阐述认证&#xff08;Authentication) 与授权&#xff08;Authorization的核心概念&#xff0c;对比传统 Session-Cookie 与现代 JWT&#xff08;JS…...

深度解析云存储:概念、架构与应用实践

在数据爆炸式增长的时代&#xff0c;传统本地存储因容量限制、管理复杂等问题&#xff0c;已难以满足企业和个人的需求。云存储凭借灵活扩展、便捷访问等特性&#xff0c;成为数据存储领域的主流解决方案。从个人照片备份到企业核心数据管理&#xff0c;云存储正重塑数据存储与…...

今日行情明日机会——20250609

上证指数放量上涨&#xff0c;接近3400点&#xff0c;个股涨多跌少。 深证放量上涨&#xff0c;但有个小上影线&#xff0c;相对上证走势更弱。 2025年6月9日涨停股主要行业方向分析&#xff08;基于最新图片数据&#xff09; 1. 医药&#xff08;11家涨停&#xff09; 代表标…...