【深度学习】Pytorch中实现交叉熵损失计算的方式总结
在PyTorch中,计算交叉熵损失主要有以下几种方式,它们针对不同的场景和需求有不同的实现方式和适用范围:
1. nn.CrossEntropyLoss 类
这是最常用且方便的方法,特别适用于多分类任务。nn.CrossEntropyLoss 实际上是同时完成了 softmax 函数和交叉熵损失的计算。它假设最后一层的输出没有经过归一化处理(不是概率形式),而是直接给出了各个类别的得分。该函数会自动计算每一样本对各类别的得分,应用softmax函数,然后计算交叉熵损失。
import torch
import torch.nn as nn# 假设 outputs 是模型的最后一层输出,shape 为 (batch_size, num_classes),targets 是 ground truth labels
outputs = torch.randn(100, 10) # 对于10分类问题的100个样本的不归一化的预测值
targets = torch.randint(0, 10, (100,)) # 对应的真实类别loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, targets)
print(loss.item())
2. F.cross_entropy 函数
torch.nn.functional.cross_entropy 函数也是为了多分类问题设计的,但它接受的是 logits 或者已经经过 softmax 的概率。如果你的输出已经是经过 softmax 的概率,可以直接使用;否则,它会默认内部先执行 log_softmax。
import torch.nn.functional as F# 假设 outputs 是未经 softmax 的 logits
outputs = torch.randn(100, 10)# 使用 F.cross_entropy 直接计算损失,无需单独进行 softmax
loss = F.cross_entropy(outputs, targets)
print(loss.item())
3. nn.BCEWithLogitsLoss 类(二分类问题)
对于二分类问题,尤其是sigmoid激活函数之后的结果,可以使用带Sigmoid的二元交叉熵损失函数,它同时完成 sigmoid 和 二元交叉熵损失的计算。
# 二分类问题,输出维度为 (batch_size, 1)
outputs = torch.randn(100, 1) # targets 是介于 [0, 1] 或 {-1, 1} 的值,表示正负样本
targets = torch.rand(100, 1) > 0.5 # 或者其他的二进制标签bce_loss = nn.BCEWithLogitsLoss()
loss = bce_loss(outputs, targets.float())
print(loss.item())
4. 手动计算交叉熵损失
当然,也可以手动组合 log_softmax 和 nll_loss 函数来计算交叉熵损失,这在特殊情况下可能会有用,比如需要对损失函数进行修改或者自定义的时候:
# 多分类问题,手动组合 log_softmax 和 nll_loss
output_logits = torch.randn(100, 10)
softmax_outputs = F.log_softmax(output_logits, dim=1) # 计算 log_softmax
loss_manual = -torch.mean(torch.gather(softmax_outputs, 1, targets.unsqueeze(1)).squeeze()) # 使用 gather 和 mean 计算 NLL
assert torch.allclose(loss_manual, F.nll_loss(softmax_outputs, targets, reduction='mean')) # 应该与 nll_loss 结果一致
在上述代码中,gather 函数用于从预测概率矩阵中按照目标标签索引出相应的对数概率,然后求平均得到最终的交叉熵损失。在多分类任务中,直接使用 F.nll_loss(log_softmax_outputs, targets) 是更加简洁的做法,等价于手动计算。而在二分类问题中,对应的手动计算方式则会涉及 sigmoid 和 binary_cross_entropy_with_logits 函数。
5. 补充说明
在交叉熵损失计算函数中:
L = − ∑ i = 1 n y i l o g ( S ( f θ ( x i ) ) ) L = -\sum_{i=1}^{n}{y_i}log(S(f_\theta(x_i))) L=−i=1∑nyilog(S(fθ(xi)))
真实值 y i y_i yi可以是热编码后的结果,也可以不进行热编码。
虽然在Pytorch架构中,神经网络内流动的数据类型必须是float类型,但是Pytorch也提供了自动处理整数(int类型)标签的交叉熵损失函数(这里的“整数标签”指的是每个样本所属的真实类别,通常是一个从0开始的整数索引,对应着类别数量中的一个),这些函数会自动将整数标签转换为内部使用的one-hot编码格式,并计算交叉熵损失。
以nn.CrossEntropyLoss为例,当输入给定的output是未经归一化的类别得分(logits),而target是整数标签时,这个损失函数会自动将整数标签转换为one-hot格式,然后再进行交叉熵损失的计算。这意味着用户不需要预先将目标标签转换为one-hot编码,损失函数内部会处理这样的转换过程。
import torch
import torch.nn as nn# 假设我们有一个批次的输出和对应的类别标签
outputs = torch.randn(64, 10) # 这是一个批次的输出,共64个样本,10个类别
labels = torch.tensor([2, 7, 0, ..., 4], dtype=torch.long) # 这是对应的整数类别标签loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(outputs, labels)print(f'Cross-entropy loss: {loss.item()}')
相关文章:
【深度学习】Pytorch中实现交叉熵损失计算的方式总结
在PyTorch中,计算交叉熵损失主要有以下几种方式,它们针对不同的场景和需求有不同的实现方式和适用范围: 1. nn.CrossEntropyLoss 类 这是最常用且方便的方法,特别适用于多分类任务。nn.CrossEntropyLoss 实际上是同时完成了 sof…...
机器学习:处理jira工单的分类问题
如何根据jira工单的category、reporter自动找到处理它的组呢?这是一个利用机器学习中knn算法的小实践. 目录 Knn算法 数据 示例 分割数据 选择Neighbors knn的优缺点 机器学习是一种技术,它的目的是给机器学习能力,让它们可以根据数据自己做决定,所以对于训练…...
后端常问面经之操作系统
请简要描述线程与进程的关系,区别及优缺点? 本质区别:进程是操作系统资源分配的基本单位,而线程是任务调度和执行的基本单位 在开销方面:每个进程都有独立的代码和数据空间(程序上下文),程序之…...
RK3568平台 iperf3测试网络性能
一.iperf3简介 iperf是一款开源的网络性能测试工具,主要用于测量TCP和UDP带宽性能。它可以在不同的操作系统上运行,包括Windows、Linux、macOS等。iperf具有简单易用、功能强大、高度可配置等特点,广泛应用于网络性能测试、网络故障诊断和网…...
Spring Boot中实现对特定URL的权限验证:拦截器、切面和安全框架的比较
引言: 在开发Web应用程序时,对特定URL进行权限验证是一项常见的需求。在Spring Boot中,我们有多种选择来实现这一目标,其中包括使用拦截器、切面和专门的安全框架(如Spring Security)。本文将比较这三种方式…...
【能源数据分析-00】能源领域数据集集锦(动态更新)
一、前言 大数据科学在能源领域的深度应用,已经深刻改变了这一行业的垂直格局。它为我们提供了宝贵的见解,帮助降低下游市场的成本,使石油生产商能够更好地应对市场繁荣期的需求。近期,石油价格的剧烈下跌给全球经济带来了沉重打…...
数据挖掘与机器学习 1. 绪论
于高山之巅,方见大河奔涌;于群峰之上,便觉长风浩荡 —— 24.3.24 一、数据挖掘和机器学习的定义 1.数据挖掘的狭义定义 背景:大数据时代——知识贫乏 数据挖掘的狭义定义: 数据挖掘就是从大量的、不完全的、有噪声的、…...
Matlab实现序贯变分模态分解(SVMD)
大家好,我是带我去滑雪! 序贯变分模态分解(SVMD) 是一种信号处理和数据分析方法。它可以将复杂信号分解为一系列模态函数,每个模态函数代表信号中的特定频率分量。 SVMD 的主要目标是提取信号中的不同频率分量并将其重构为原始信号。SVMD的基…...
云安全与云计算的关系
云计算又被称为网格计算,是分布式计算的一种,能够将大量的数据计算处理程序通过网络“云”分解成多个小程序,然后将这些小程序的结果反馈给用户。云计算主要就是能够解决任务分发,并进行计算结果的合并。 云安全则是我国企业创造的…...
WPF 界面变量绑定(通知界面变化)
1、继承属性变化接口 public partial class MainWindow : Window, INotifyPropertyChanged {// 通知界面属性发生变化public event PropertyChangedEventHandler PropertyChanged;private void RaisePropertyChanged(string propertyName){PropertyChangedEventHandler handle…...
eclipse导入svn项目
1、配置maven 2、用svn引入项目 3一直点击next,到最后选完成。...
Prompt提示工程上手指南:基础原理及实践(四)-检索增强生成(RAG)策略下的Prompt
前言 此篇文章已经是本系列的第四篇文章,意味着我们已经进入了Prompt工程的深水区,掌握的知识和技术都在不断提高,对于Prompt的技巧策略也不能只局限于局部运用而要适应LLM大模型的整体框架去进行改进休整。较为主流的LLM模型框架设计可以基…...
阿里云倚天云服务器怎么样?如何收费?
阿里云倚天云服务器CPU采用倚天710处理器,租用倚天服务器c8y、g8y和r8y可以享受优惠价格,阿里云服务器网aliyunfuwuqi.com整理倚天云服务器详细介绍、倚天710处理器性能测评、CIPU架构优势、倚天服务器使用场景及生态支持: 阿里云倚天云服务…...
海外社交营销为什么用云手机?不用普通手机?
海外社交营销作为企业拓展海外市场的重要手段,正日益受到企业的青睐。云手机以其成本效益和全球性特征,成为海外社交营销领域的得力助手。那么,究竟是什么特性使得越来越多的企业选择利用云手机进行海外社交营销呢?下文将对此进行…...
【Mysql数据库基础05】子查询 where、from、exists子查询、分页查询
where、from、exists子查询、分页查询 1 where子查询1.1 where后面的标量子查询1.1.1 having后的标量子查询 1.2 where后面的列子查询1.3 where后面的行子查询(了解即可) 2 from子查询3 exists子查询(相关子查询)4 分页查询5 联合…...
在Linux/Debian/Ubuntu上通过 Azure Data Studio 管理 SQL Server 2019
Microsoft 提供 Azure Data Studio,这是一种可在 Linux、macOS 和 Windows 上运行的跨平台数据库工具。 它提供与 SSMS 类似的功能,包括查询、脚本编写和可视化数据。 要在 Ubuntu 上安装 Azure Data Studio,可以按照以下步骤操作࿱…...
Java代码基础算法练习-搬砖问题-2024.03.25
任务描述: m块砖,n人搬,男搬4,女搬3,两个小孩抬一砖,要求一次全搬完,问男、 女、小孩各若干? 任务要求: 代码示例: package M0317_0331;import java.util.S…...
Tomcat调优
1、调整线程数 <Connector port"8080" maxHttpHeaderSize"8192"maxThreads"1900" minSpareThreads"250" maxSpareThreads"750"enableLookups"false" redirectPort"8443" acceptCount"100"…...
每日OJ题_栈①_力扣1047. 删除字符串中的所有相邻重复项
目录 力扣1047. 删除字符串中的所有相邻重复项 解析代码 力扣1047. 删除字符串中的所有相邻重复项 1047. 删除字符串中的所有相邻重复项 难度 简单 给出由小写字母组成的字符串 S,重复项删除操作会选择两个相邻且相同的字母,并删除它们。 在 S 上反…...
SQLServer SEQUENCE用法
SEQUENCE:数据库中的序列生成器 在数据库管理中,经常需要生成唯一且递增的数值序列,用于作为主键或其他需要唯一标识的列的值。为了实现这一功能,SQL Server 引入了 SEQUENCE 对象。SEQUENCE 是一个独立的数据库对象,用…...
C++初阶-list的底层
目录 1.std::list实现的所有代码 2.list的简单介绍 2.1实现list的类 2.2_list_iterator的实现 2.2.1_list_iterator实现的原因和好处 2.2.2_list_iterator实现 2.3_list_node的实现 2.3.1. 避免递归的模板依赖 2.3.2. 内存布局一致性 2.3.3. 类型安全的替代方案 2.3.…...
MFC内存泄露
1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...
【入坑系列】TiDB 强制索引在不同库下不生效问题
文章目录 背景SQL 优化情况线上SQL运行情况分析怀疑1:执行计划绑定问题?尝试:SHOW WARNINGS 查看警告探索 TiDB 的 USE_INDEX 写法Hint 不生效问题排查解决参考背景 项目中使用 TiDB 数据库,并对 SQL 进行优化了,添加了强制索引。 UAT 环境已经生效,但 PROD 环境强制索…...
【android bluetooth 框架分析 04】【bt-framework 层详解 1】【BluetoothProperties介绍】
1. BluetoothProperties介绍 libsysprop/srcs/android/sysprop/BluetoothProperties.sysprop BluetoothProperties.sysprop 是 Android AOSP 中的一种 系统属性定义文件(System Property Definition File),用于声明和管理 Bluetooth 模块相…...
华为OD机考-机房布局
import java.util.*;public class DemoTest5 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseSystem.out.println(solve(in.nextLine()));}}priv…...
C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)
名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...
为什么要创建 Vue 实例
核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...
Python 训练营打卡 Day 47
注意力热力图可视化 在day 46代码的基础上,对比不同卷积层热力图可视化的结果 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pypl…...
uniapp 集成腾讯云 IM 富媒体消息(地理位置/文件)
UniApp 集成腾讯云 IM 富媒体消息全攻略(地理位置/文件) 一、功能实现原理 腾讯云 IM 通过 消息扩展机制 支持富媒体类型,核心实现方式: 标准消息类型:直接使用 SDK 内置类型(文件、图片等)自…...
命令行关闭Windows防火墙
命令行关闭Windows防火墙 引言一、防火墙:被低估的"智能安检员"二、优先尝试!90%问题无需关闭防火墙方案1:程序白名单(解决软件误拦截)方案2:开放特定端口(解决网游/开发端口不通)三、命令行极速关闭方案方法一:PowerShell(推荐Win10/11)方法二:CMD命令…...
