当前位置: 首页 > article >正文

PyTorch 的 nn.NLLLoss:负对数似然损失全解析

PyTorch 的 nn.NLLLoss:负对数似然损失全解析

在 PyTorch 的损失函数家族中,nn.NLLLoss(Negative Log Likelihood Loss,负对数似然损失)是一个不太起眼但非常重要的成员。它经常跟 LogSoftmax 搭配出现,尤其在分类任务中扮演关键角色。今天我们就来聊聊 nn.NLLLoss 的数学原理、使用方法,以及它适用的场景,带你彻底搞懂这个损失函数。

1. 什么是负对数似然损失?

先从名字拆解:

  • 似然(Likelihood):在统计学中,似然表示“给定模型参数时,观察到数据的概率”。对数似然(Log Likelihood)是它的对数形式,常用于简化计算。
  • 负对数似然(Negative Log Likelihood, NLL):把对数似然取负数,作为损失函数,目标是最小化它。

在机器学习中,负对数似然损失通常用来衡量模型预测的概率分布与真实标签的差距,尤其是在分类任务中。

数学公式

假设我们有一个多分类任务,有 ( C C C ) 个类别。对于一个样本:

  • ( y ^ \hat{y} y^ ) 是模型输出的概率分布,比如经过 Softmax 或 LogSoftmax 处理后的结果。
  • ( y y y ) 是真实类别,用索引表示(比如 2 表示第 2 类)。

nn.NLLLoss 的公式是:

NLL = − 1 N ∑ i = 1 N log ⁡ ( y ^ i , y i ) \text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log(\hat{y}_{i, y_i}) NLL=N1i=1Nlog(y^i,yi)

  • ( N N N ):样本数量(batch size)。
  • ( y i y_i yi ):第 ( i i i ) 个样本的真实类别索引。
  • ( y ^ i , y i \hat{y}_{i, y_i} y^i,yi ):第 ( i i i ) 个样本在真实类别 ( y i y_i yi ) 上的预测概率(对数值)。

简单来说,nn.NLLLoss 取预测概率的对数(已经由 LogSoftmax 计算好),然后取负号,只关心正确类别的概率值。

2. 为什么搭配 LogSoftmax

你可能会注意到,nn.NLLLoss 的文档里总是提到“通常与 LogSoftmax 搭配使用”。这是为什么?

  • 模型输出:神经网络的最后一层通常输出未归一化的 logits(比如 [1.0, 2.0, 0.5]),而不是概率。
  • Softmax:将 logits 转为概率分布,比如 [0.2, 0.5, 0.3],满足 ( ∑ y ^ = 1 \sum \hat{y} = 1 y^=1)。公式是:
    y ^ j = e z j ∑ k = 1 C e z k \hat{y}_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}} y^j=k=1Cezkezj
  • LogSoftmax:在 Softmax 基础上取对数,输出的是对数概率,比如 [-1.6, -0.7, -1.2]。公式是:
    log ⁡ ( y ^ j ) = z j − log ⁡ ( ∑ k = 1 C e z k ) \log(\hat{y}_j) = z_j - \log(\sum_{k=1}^{C} e^{z_k}) log(y^j)=zjlog(k=1Cezk)

nn.NLLLoss 要求输入是对数概率(log probabilities),而不是原始概率。所以:

  • 如果你直接给它 Softmax 后的概率,会出错,因为它期待的是 ( log ⁡ ( y ^ ) \log(\hat{y}) log(y^))。
  • LogSoftmax 处理后,输入正好符合要求,计算时直接取负号即可。
3. 代码使用示例

我们来看一个简单的例子,展示 nn.NLLLossLogSoftmax 的搭配:

import torch
import torch.nn as nn# 假设一个 3 分类任务,batch_size = 2
logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]])  # 原始 logits
target = torch.tensor([1, 2])  # 真实类别索引,0~2# 定义 LogSoftmax 和 NLLLoss
log_softmax = nn.LogSoftmax(dim=1)  # dim=1 表示在类别维度上归一化
loss_fn = nn.NLLLoss()# 计算损失
log_probs = log_softmax(logits)  # 先转为对数概率
loss = loss_fn(log_probs, target)
print("NLL Loss:", loss.item())

运行过程

  1. logits[batch_size, num_classes] 的张量,表示每个样本在每个类别上的得分。
  2. nn.LogSoftmax 把 logits 转为对数概率,比如 [[-1.9, -0.9, -2.4], [-2.3, -1.9, -0.4]]
  3. nn.NLLLoss 提取每个样本在真实类别上的对数概率(比如第一个样本取 -0.9,第二个取 -0.4),取负并平均。

输出可能是 1.15,具体值取决于输入。

4. 与 nn.CrossEntropyLoss 的关系

你可能听说过 nn.CrossEntropyLoss,它也很常见。实际上:

  • nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss
    PyTorch 把这两步合二为一,直接接受 logits 作为输入,内部自动完成 LogSoftmax 和 NLL 计算。具体过程可以参考笔者的另一篇博客:Pytorch为什么 nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss?

代码对比:

# 用 nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item())  # 与上面结果相同
  • 区别
    • nn.NLLLoss:输入是对数概率,需手动加 LogSoftmax
    • nn.CrossEntropyLoss:输入是 logits,自动处理。
5. 使用场景

nn.NLLLoss 适用于以下场景:

  • 多分类任务:比如图像分类(CIFAR-10 的 10 类)、文本分类。
  • 需要分离 Softmax 的情况
    • 你想在模型里显式控制 LogSoftmax 的位置,而不是交给损失函数。
    • 调试时单独检查对数概率的值。
  • 概率输出的模型:如果你的模型已经输出对数概率(比如某些预训练模型),直接用 nn.NLLLoss 更高效。

典型例子

  • 一个简单的 CNN 分类器:
    class SimpleCNN(nn.Module):def __init__(self):super().__init__()self.conv = nn.Conv2d(1, 16, 3)self.fc = nn.Linear(16 * 26 * 26, 10)  # 假设 28x28 输入self.log_softmax = nn.LogSoftmax(dim=1)def forward(self, x):x = self.conv(x)x = x.view(x.size(0), -1)x = self.fc(x)return self.log_softmax(x)model = SimpleCNN()
    loss_fn = nn.NLLLoss()
    
    这里模型输出对数概率,搭配 nn.NLLLoss 计算损失。
6. 注意事项
  • 输入形状
    • 输入:[batch_size, num_classes](对数概率)。
    • 目标:[batch_size](类别索引)。
  • 目标类型:必须是整数(long 类型),不能是 one-hot 或浮点数。
  • 数值稳定性LogSoftmax 比单独的 Softmax + log 更稳定,因为它避免了溢出问题。
7. 小结:nn.NLLLoss 的核心
  • 数学原理:计算正确类别对数概率的负值,最小化它等价于最大化似然。
  • 使用方式:搭配 LogSoftmax,输入对数概率,输出标量损失。
  • 场景:多分类任务,尤其是需要显式控制概率计算时。
  • CrossEntropyLoss 的关系:前者是后者的组成部分,功能更模块化。

nn.NLLLoss 就像一个“半成品”,需要你自己搭配 LogSoftmax,但这也给了你更多灵活性。相比直接用 nn.CrossEntropyLoss,它更适合喜欢拆解步骤或调试模型的开发者。

8. 调试小技巧
  • 检查输入:打印 log_probs 确保是对数概率(负值)。
  • 验证目标:确保 target 是整数,且范围在 [0, num_classes-1]
  • 对比结果:用 nn.CrossEntropyLoss 验证是否一致。

希望这篇博客让你对 nn.NLLLoss 有了全面认识!

后记

2025年2月28日18点59分于上海,在Grok3大模型辅助下完成。

相关文章:

PyTorch 的 nn.NLLLoss:负对数似然损失全解析

PyTorch 的 nn.NLLLoss:负对数似然损失全解析 在 PyTorch 的损失函数家族中,nn.NLLLoss(Negative Log Likelihood Loss,负对数似然损失)是一个不太起眼但非常重要的成员。它经常跟 LogSoftmax 搭配出现,尤…...

《数据安全架构设计与实战》的目录

《数据安全架构设计与实战》的目录如下1: 第一部分 安全架构基础 第 1 章 架构第 2 章 安全架构 2.1 什么是安全2.2 为什么使用 “数据安全” 这个术语2.3 什么是安全架构2.4 安全架构 5A 方法论2.5 安全架构 5A 与 CIA 的关系 第二部分 产品安全架构 …...

python-leetcode-删除并获得点数

740. 删除并获得点数 - 力扣(LeetCode) 解法 1:动态规划(O(n) 时间,O(n) 空间) class Solution:def deleteAndEarn(self, nums: List[int]) -> int:if not nums:return 0# 统计每个数的贡献points Cou…...

助力DeepSeek私有化部署服务:让企业AI落地更简单、更安全

在数字化转型的浪潮中,越来越多的企业选择私有化部署AI技术,以保障数据安全、提升业务效率并实现自主可控。DeepSeek作为行业领先的AI开源技术,其技术可以支持企业私有化部署,企业需要一站式服务私有化部署,涵盖硬件采…...

Android JNI开发指南

了解Android JNI开发。首先,我需要回忆一下JNI的基本概念和作用。JNI允许Java代码与本地代码(如C/C)交互,这在需要高性能或者使用现有C/C库时非常有用。 接下来,我应该整理开发步骤。用户可能想知道如何开始&#xff…...

【每天认识一个漏洞】url重定向

🌝博客主页:菜鸟小羊 💖专栏:Linux探索之旅 | 网络安全的神秘世界 | 专接本 | 每天学会一个渗透测试工具 常见应用场景 主要是业务逻辑中需要进行跳转的地方。比如登录处、注册处、访问用户信息、订单信息、加入购物车、分享、收…...

纯代码实战--用Deepseek+SQLite+Ollama搭建数据库助手

如何用Python调用本地模型实现DeepSeek提示词模板:一步步教你高效解决13种应用场景 从零到一:纯代码联合PyQt5、Ollama、Deepseek打造简易版智能聊天助手 用外接知识库武装大模型:基于Deepseek、Ollama、LangChain的RAG实战解析 纯代码实战–…...

2025 最新版鸿蒙 HarmonyOS 开发工具安装使用指南

为保证 DevEco Studio 正常运行,建议电脑配置满足如下要求: Windows 系统 操作系统:Windows10 64 位、Windows11 64 位内存:16GB 及以上硬盘:100GB 及以上分辨率:1280*800 像素及以上 macOS 系统 操作系统…...

日期时间 API

日期时间 API (java.time 包),旨在解决旧版 java.util.Date 和 java.util.Calendar 存在的一些设计缺陷,比如线程不安全、时区处理不一致等问题。新 API 基于 ISO 8601 标准,更加直观、简洁,且支持时区和区域设置。主要类有&#…...

AI数字人开发,引领科技新潮流

引言 随着人工智能技术的迅猛发展,AI 数字人在影视娱乐、客户服务、教育及医疗等多个领域展现出巨大的潜力。本文旨在为开发者提供一份详细的 AI 数字人系统开发指南,涵盖从基础架构到实现细节的各个方面,包括人物建模、动作生成、语音交互、…...

领域驱动设计:事件溯源架构简介

概述 事件溯源架构通常由3种应用设计模式组成,分别是:事件驱动(Event Driven),事件溯源(Event Source)、CQRS(读写分离)。这三种应用设计模式常见于领域驱动设计(DDD)中,但它们本身是一种应用设计的思想,不仅仅局限于DDD,每一种模式都可以单独拿出来使用。 E…...

自定义类加载器国密版本冲突

自定义类加载器国密版本冲突 对接三方接口经常使用到国密加密包(bcprov),此时系统已经引入了1.5版本,而三方提供的sdk中引用了1.6版版本,两个版本有冲突,如果系统加载到1.5版本的将会加密异常(各种奇怪的异…...

‌Debian 包版本号比较规则详解

1 版本号组成结构 Debian 版本号格式为:[epoch:]upstream_version[-debian_revision] 示例‌:2:1.18.3~betadfsg1-5b1 组件说明比较优先级‌Epoch‌冒号前的数字 (2:)最高‌Upstream‌主版本 (1.18.3~betadfsg1)中‌Debian修订号‌减号后的部分 (5)最…...

STM32之影子寄存器

预分频寄存器计数到一半的时候,改变预分频值,此时不会立即生效,会等到计数完成,再从影子寄存器即预分频缓冲器里装载修改的预分频值。 如上图,第一行是内部时钟72M,第二行是时钟使能,高电平启动…...

x64汇编下过程参数解析

简介 好久没上博客, 突然发现我的粉丝数变2700了, 真是这几个月涨的粉比我之前好几年的都多, 于是心血来潮来写一篇, 记录一下x64下的调用约定(这里的调用约定只针对windows平台) Windows下的x64程序的调用约定有别于x86下的"stdcall调用约定"以及"cdecl调用约…...

Blender调整最佳渲染清晰度

1.渲染采样调高 512 2.根据需要 开启AO ,开启辉光 , 开启 屏幕空间反射 3.调高分辨率 4096x4096 100% 分辨率是清晰度的关键 , 分辨率不高 , 你其他参数调再高都没用 4.世界环境开启体积散射 , 可以增强氛围感 5.三点打光法 放在模型和相机45夹角上 白模 白模带线条 成品...

TSMaster【第二十篇:华山论剑——知识图谱全览】

(三维思维导图「独孤九剑总诀式」技能树「经脉贯通」检测系统未来技术「武学秘境」预测) 【武侠场景导入】光明顶秘道惊变 明教光明顶密道中,张无忌面对错综复杂的甬道体系,以乾坤大挪移心法贯通九阳神功与太极拳剑,终成武林至尊。今时今日,三电工程师面对庞杂的TSMaste…...

神经性手抖是一种常见的症状

神经性手抖是一种常见的症状,表现为手部无意识或不受控制地颤抖。为了预防神经性手抖,我们可以采取以下几种方法: 1. 放松身心:压力和焦虑是导致神经性手抖的常见原因之一。因此,学会放松身心是预防手抖的关键。可以通…...

前端项目打包生成 JS 文件的核心步骤

前端项目打包生成 JS 文件的过程通常涉及以下核心步骤,以主流工具(如 Webpack、Vite、Rollup 等)为例: 一、项目准备阶段 项目结构 源代码目录(如 src/)包含 JS/TS、CSS、图片等资源配置文件(pa…...

金融支付行业技术侧重点

1. 合规问题 第三方支付系统的平稳运营,严格遵循《非银行支付机构监督管理条例》的各项条款是基础与前提,其中第十八条的规定堪称重中之重,是支付机构必须牢牢把握的关键准则。 第十八条明确指出,非银行支付机构需构建起必要且独…...

支付宝 IoT 设备入门宝典(下)设备经营篇

上篇介绍了支付宝 IoT 设备管理,但除了这些基础功能外,商户还可以利用设备进行一些运营动作,让设备更好的帮助自己,本篇就会以设备经营为中心,介绍常见的设备相关能力和问题解决方案。如果对上篇感兴趣,可以…...

mac电脑中使用无线诊断.app查看连接的Wi-Fi带宽

问题 需要检查连接到的Wi-Fi的AP硬件支持的带宽。 步骤 1.按住 Option 键,然后点击屏幕顶部的Wi-Fi图标;2.从下拉菜单中选择 “打开无线诊断”(Open Wireless Diagnostics);3.你可能会看到一个提示窗口,…...

企业微信里可以使用的企业内刊制作工具,FLBOOK

如何让员工及时了解公司动态、行业资讯、学习专业知识,并有效沉淀企业文化?一份高质量的企业内刊是不可或缺的。现在让我来教你该怎么制作企业内刊吧 1.登录与上传 访问FLBOOK官网,注册账号后上传排版好的文档 2.选择模板 FLBOOK提供了丰富的…...

网络变压器的主要电性参数与测试方法(2)

Hqst盈盛(华强盛)电子导读:网络变压器的主要电性参数与测试方法(2).. 今天我们继续来看看网络变压器的2个主要电性参数与它的测试方法: 1. 线圈间分布电容Cp:线圈间杂散静电容 测试条件:100KHz/0.1…...

深度学习笔记17-马铃薯病害识别(VGG-16复现)

目录 一、 前期准备 1. 设置GPU 2. 导入数据 二、手动搭建VGG-16模型 1. 搭建模型 三、 训练模型 1. 编写训练函数 3. 编写测试函数 4. 正式训练 四、 结果可视化 1. Loss与Accuracy图 2. 指定图片进行预测 3. 模型评估 前言 🍨 本文为🔗365天深度学习训…...

#7 Diffusion for beginners

DDPM的原理讲解视频:DDPM explain,就是口音一言难尽 还有大佬从零开始搭建模型代码的视频:DDPM implementation,相当震撼,代码我从来都是粗粗的看个大概了事,大佬直接手撕 一个很好的资源集合网站:https://diff-usion.github.io/Awesome-Diffusion-Models/ 今天学习一段…...

【计算机网络】TCP三次握手,四次挥手以及SYN,ACK,seq,以及握手次数理解

TCP三次握手图解 描述 第一次握手:客户端请求建立连接,发送同步报文(SYN1),同时随机一个seqx作为初始序列号,进入SYN_SENT状态,等待服务器确认 第二次握手:服务端收到请求报文,如果同意建立连接…...

【Java】System 类

目录 静态字段标准输入输出流相关 常用静态方法数组操作时间操作系统操作属性操作安全管理 其他方法 System 类位于 java.lang 包下,是一个 final 类,意味着它不能被继承。并且其所有构造方法都是私有的,这使得我们无法创建 System 类的实例&…...

Zookeeper(80)Zookeeper的常见问题有哪些?

Zookeeper作为分布式系统的协调服务,常见的问题主要集中在配置、性能、连接管理、数据一致性和节点故障等方面。以下是一些常见问题及其详细解决方法和代码示例。 1. 配置问题 问题描述 配置不当可能导致 Zookeeper 集群无法正常启动或运行效率低下。 解决方法 …...

docker通用技术介绍

docker通用技术介绍 1.docker介绍 1.1 基本概念 docker是一个开源的容器化平台,用于快速构建、打包、部署和运行应用程序。它通过容器化技术将应用及其依赖环境(如代码、库、系统工具等)打包成一个标准化、轻量级的独立单元,实…...