pytorch小记(十五):pytorch中 交叉熵损失详解:为什么logits比targets多一个维度?
pytorch小记(十五):pytorch中 交叉熵损失详解:为什么logits比targets多一个维度?
- PyTorch交叉熵损失详解:为什么logits比targets多一个维度?
- 一、前言:新手常见困惑
- 二、核心概念:从考试得分到概率分布
- 1. logits:原始得分矩阵
- 2. targets:正确答案索引
- 三、维度差异的本质原因
- 1. 分类任务的数学需求
- 2. 维度对照表
- 3. 错误用法解析
- 四、手把手计算交叉熵损失
- 1. 输入数据
- 2. 计算步骤
- 步骤1:Softmax归一化
- 步骤2:提取正确类别的概率
- 步骤3:计算交叉熵
- 五、设计哲学深度解析
- 1. 为何不直接使用概率?
- 2. 多任务场景对照表
- 六、常见问题解答
- Q1:二分类能否用形状[N]的logits?
- Q2:如何处理多标签分类?
- Q3:为什么我的loss计算很慢?
- 七、总结
PyTorch交叉熵损失详解:为什么logits比targets多一个维度?
关键词:PyTorch交叉熵损失、logits维度、分类任务原理、深度学习基础
一、前言:新手常见困惑
许多初学PyTorch的朋友在使用交叉熵损失函数时,都会对logits和targets的维度关系感到困惑。典型的报错场景如下:
# 正确用法
logits = torch.tensor([[1.2, -0.5], [0.3, 2.1]]) # 形状 [2, 2]
targets = torch.tensor([0, 1]) # 形状 [2]# 错误用法(触发维度错误)
logits_error = torch.tensor([0.5, 1.2]) # 形状 [2]
targets_error = torch.tensor([0, 1]) # 形状 [2]
loss = F.cross_entropy(logits_error, targets_error) # 报错!
本文将用生活实例+手把手计算的方式,带你彻底理解交叉熵损失的维度设计逻辑。
二、核心概念:从考试得分到概率分布
1. logits:原始得分矩阵
想象你正在参加一场有2道选择题的考试,每道题有A、B两个选项。模型对每个选项给出原始得分:
logits = torch.tensor([[-1.0, 1.0], # 第1题:A得-1分,B得1分[-0.5, 1.5], # 第2题:A得-0.5分,B得1.5分[-0.5, 1.5] # 第3题(新增):同上
])
- 形状[3, 2]:3个样本(题目),每个样本2个类别(选项)
- 物理意义:未经归一化的"信心分数",数值越大表示模型越倾向该选项
2. targets:正确答案索引
targets = torch.tensor([0, 1, 1])
# 含义:第1题正确答案是A(索引0),第2、3题是B(索引1)
- 形状[3]:3个样本各对应一个正确答案位置
三、维度差异的本质原因
1. 分类任务的数学需求
- 模型需要为每个可能的类别提供判断依据
- 即使正确答案只有一个,也必须比较所有选项的"证据强度"
2. 维度对照表
| 张量 | 形状 | 物理意义 |
|---|---|---|
logits | [N, C] | N个样本,每个样本C个类别的得分 |
targets | [N] | N个样本的正确类别索引(n在0~c-1之间) |
3. 错误用法解析
若logits与targets同维度:
logits_error = torch.tensor([0.2, 0.7, 0.5]) # 形状[3]
targets = torch.tensor([0, 1, 1]) # 形状[3]
此时模型无法判断:
- 每个数值对应哪个类别?
- 如何进行多类别比较?
四、手把手计算交叉熵损失
以具体例子演示计算全过程:
1. 输入数据
logits = torch.tensor([[-1.0, 1.0], [-0.5, 1.5],[-0.5, 1.5]
]) # 形状[3,2]
targets = torch.tensor([0, 1, 1]) # 形状[3]
2. 计算步骤
步骤1:Softmax归一化
将原始得分转换为概率分布(每行和为1):
第1个样本([-1.0, 1.0]):
exp(-1.0) = 0.3679
exp(1.0) = 2.7183
总合 = 0.3679 + 2.7183 = 3.0862
概率 = [0.3679/3.0863 ≈ 0.1192, 2.7183/3.0863 ≈ 0.8808]
第2个样本([-0.5, 1.5]):
exp(-0.5) ≈ 0.6065
exp(1.5) ≈ 4.4817
总合 = 0.6065 + 4.4817 ≈ 5.0882
概率 = [0.6065/5.0882 ≈ 0.1192, 4.4817/5.0882 ≈ 0.8808]
步骤2:提取正确类别的概率
根据targets索引:
样本1:取索引0 → 0.1192
样本2:取索引1 → 0.8808
样本3:取索引1 → 0.8808
步骤3:计算交叉熵
公式:loss = -平均(ln(正确概率))
loss = -(ln(0.1192) + ln(0.8808) + ln(0.8808)) / 3= -[(-2.127) + (-0.127) + (-0.127)] / 3≈ 0.7937
验证PyTorch计算结果:
print(loss.item()) # 输出 0.7937
五、设计哲学深度解析
1. 为何不直接使用概率?
- 数值稳定性:直接处理指数运算易导致溢出
- 梯度优化:logits的线性特性更利于反向传播
2. 多任务场景对照表
| 任务类型 | logits形状 | targets形状 | 损失函数 |
|---|---|---|---|
| 二分类(2个选项) | [N,2] | [N] | CrossEntropyLoss |
| 多标签分类 | [N,C] | [N,C] | BCEWithLogitsLoss |
| 回归任务 | [N] | [N] | MSELoss |
六、常见问题解答
Q1:二分类能否用形状[N]的logits?
可以,但需配合sigmoid:
# 二分类特例
logits = torch.tensor([0.8, -0.3]) # 形状[2]
prob = torch.sigmoid(logits) # 转换为概率
loss = F.binary_cross_entropy(prob, targets)
Q2:如何处理多标签分类?
当每个样本可能有多个正确标签时:
logits = torch.tensor([[1.2, -0.5], [0.3, 2.1]]) # 形状[2,2]
targets = torch.tensor([[1, 0], [0, 1]]) # 形状[2,2] (one-hot)
loss = F.binary_cross_entropy_with_logits(logits, targets)
Q3:为什么我的loss计算很慢?
- 检查是否误用了for循环逐个样本计算
- 正确的向量化计算可加速百倍以上
七、总结
理解logits与targets的维度差异,关键在于把握分类任务的本质需求:
- logits提供全类别的判断依据 → 需要二维结构
- targets只需指出正确位置 → 一维索引足矣
掌握这一设计哲学后,你就能:
✅ 正确构建分类模型的输出层
✅ 快速调试维度相关的错误
✅ 深入理解损失函数的工作原理
练习建议:在Jupyter Notebook中复现本文的计算示例,尝试修改logits值观察loss变化。
相关阅读:
- PyTorch官方文档:CrossEntropyLoss
如有疑问欢迎留言讨论!
相关文章:
pytorch小记(十五):pytorch中 交叉熵损失详解:为什么logits比targets多一个维度?
pytorch小记(十五):pytorch中 交叉熵损失详解:为什么logits比targets多一个维度? PyTorch交叉熵损失详解:为什么logits比targets多一个维度?一、前言:新手常见困惑二、核心概念&…...
利用zabbix自带key获取数据
获取数据的三种方法 1、链接模版 服务器系统自身的监控 CPU CPU使用率、CPU负载 内存 内存剩余量 硬盘 关键性硬盘的剩余量、IO 网卡 流量/IO(流入流量、流出流量、总流量、错误数据包流量) 进程数 用户数 2、利用zabbix自带的键值key 1)监…...
无人机数据处理系统设计要点与难点!
一、系统设计要点 无人机数据处理系统需要高效、可靠、低延迟地处理多源异构数据(如影像、传感器数据、位置信息等),同时支持实时分析和长期存储。以下是核心设计要点: 1.数据采集与预处理 多传感器融合:集成摄像头…...
最大异或对 The XOR Largest Pair
题目来自洛谷网站: 思路: 两个循环时间复杂度太高了,会超时。 我们可以先将读入的数字,插入到字典树中,从高位到低位。对每个数查询的时候,题目要求是最大的异或对,所以我们选择相反的路径&am…...
基于SpringBoot + Vue 的汽车租赁管理系统
技术介绍: ①:架构: B/S、MVC ②:系统环境:Windows/Mac ③:开发环境:IDEA、JDK1.8、Maven、Mysql ④:技术栈:Java、Mysql、SpringBoot、Mybatis、Vue 项目功能: 角色&am…...
基于DrissionPage的TB商品信息采集与可视化分析
一、项目背景 随着电子商务的快速发展,淘宝作为中国最大的电商平台之一,拥有海量的商品信息。这些数据对于市场分析、用户行为研究以及竞争情报收集具有重要意义。然而,由于淘宝的反爬虫机制和复杂的页面结构,直接获取商品信息并不容易。尤其是在电商行业高速发展的今天,商…...
电气、电子信息与通信工程的探索与应用
从传统定义来看,电气工程是现代科技领域的核心学科和关键学科。它涵盖了创造产生电气与电子系统的有关学科的总和。然而,随着科学技术的飞速发展,电气工程的概念已经远超出这一范畴。 电子信息工程则是将电子技术、通信技术、计算机技术等应…...
Python备赛笔记2
1.区间求和 题目描述 给定a1……an一共N个整数,有M次查询,每次需要查询区间【L,R】的和。 输入描述: 第一行包含两个数:N,M 第二行输入N个整数 接下来的M行,每行有两个整数,L R,中间用空格隔开&…...
HTML5 拖放(Drag and Drop)学习笔记
一、HTML5 拖放简介 HTML5 拖放(Drag and Drop)是HTML5标准的一部分,允许用户抓取一个对象并将其拖动到另一个位置。拖放功能在现代网页中非常常见,例如文件上传、任务管理、布局调整等场景。 HTML5 拖放功能支持以下浏览器&…...
Sass (Scss) 与 Less 的区别与选择
Sass 与 Less 的区别与选择 1. 语法差异2. 特性与支持3. 兼容性4. 选择建议 在前端开发中,CSS预处理器如Sass(Syntactically Awesome Stylesheets)和Less被广泛使用,它们通过引入变量、嵌套规则、混合、函数等特性,使C…...
Unity2022发布Webgl2微信小游戏部分真机黑屏
复现规律: Unity PlayerSetting中取消勾选ShowSplashScreen 分析: 在Unity中,Splash Screen(启动画面) 不仅是视觉上的加载动画,还承担了关键的引擎初始化、资源预加载和渲染环境准备等底层逻辑。禁用后导…...
记一次线上SQL死锁事故
一、 引言 SQL死锁是一个常见且复杂的并发控制问题。当多个事务在数据库中互相等待对方释放锁时,就会形成死锁,从而导致事务无法继续执行,影响系统的性能和可用性。死锁不仅会导致数据库操作的阻塞,增加延迟,还可能对…...
Java并发编程 什么是分布式锁 跟其他的锁有什么区别 底层原理 实战讲解
目录 一、分布式锁的定义与核心作用 二、分布式锁与普通锁的核心区别 三、分布式锁的底层原理与实现方式 1. 核心实现原理 2. 主流实现方案对比 3. 关键技术细节 四、典型问题与解决方案 五、总结 六、具体代码实现 一、分布式锁的定义与核心作用 分布式锁是一种在分布…...
【react】在react中async/await一般用来实现什么功能
目录 基本概念 工作原理 优点 注意事项 底层原理 实际应用场景 1. 数据获取 (API 请求) 2. 表单提交 3. 异步状态管理 4. 异步路由切换 5. 异步数据预加载 6. 第三方 API 调用 7. 文件上传/下载 8. 路由导航拦截 关键注意事项 基本概念 async 函数:用…...
Axure项目实战:智慧城市APP(六)市民互动(动态面板、显示与隐藏)
亲爱的小伙伴,在您浏览之前,烦请关注一下,在此深表感谢! 课程主题:市民互动 主要内容:动态面板、显示与隐藏交互应用 应用场景:AI产品交互、互动类应用 案例展示: 案例视频&am…...
为何服务器监听异常?
报错: 执行./RCF后出现监听异常--在切换网络后,由于前面没有退出./RCF执行状态;重新连接后,会出现服务器监听异常 原因如下: 由于刚开始登录内网,切换之后再重新登录内网,并且切换网络的过程中…...
1.认识Excel
一 Excel 可以用来做什么 二 提升技巧 1.数据太多 2.计算太累 3.提升数据的价值和意义 4.团队协作 三 学习目标 学习目标不是为了掌握所有的技能,追逐新功能。而是学知识来解决需求,如果之前的技能和新出的技能都可以解决问题,那不学新技能也…...
目标跟踪——deepsort算法详细阐述
deepsort 算法详解 Unmatched Tracks(未匹配的轨迹) 本质角色: 是已存在的轨迹在当前帧中“失联”的状态,即预测位置与检测结果不匹配。 生命周期阶段: 已初始化: 轨迹已存在多帧,可能携带历史信息(如外观特征、运动模型)。 未被观测到: 当前帧中未找到对应的检测框…...
AI Agent 是什么?从 Chatbot 到自动化 Agent(LangChain、AutoGPT、BabyAGI)
1. 引言:AI Agent 的演进 AI Agent(人工智能智能体)是 AI 发展的重要方向之一。早期的 AI 主要以 Chatbot 形式存在,如客服机器人、智能助手等,主要基于 NLP 技术进行任务处理。而随着大模型(LLM)能力的提升,AI Agent 逐步演进为能够自主执行任务的智能体,如 AutoGPT…...
ngx_http_core_root
定义在 src\http\ngx_http_core_module.c static char * ngx_http_core_root(ngx_conf_t *cf, ngx_command_t *cmd, void *conf) {ngx_http_core_loc_conf_t *clcf conf;ngx_str_t *value;ngx_int_t alias;ngx_uint_t …...
python康复日记-request库的使用,爬虫自动化测试
一,request的简单应用 #1请求地址 URLhttps://example.com/login #2参数表单 form_data {username: admin,password: secret } #3返回的响应对象response response requests.post(URL,dataform_data,timeout5 ) #4处理返回结果,这里直接打印返回网页的…...
光谱范围与颜色感知的关系
光谱范围与颜色感知是光学、生理学及技术应用交叉的核心课题,两者通过波长分布、人眼响应及技术处理共同决定人类对色彩的认知。以下是其关系的系统解析: 1.基础原理:光谱范围与可见光 光谱范围定义: 电磁波谱中能被特定…...
OpenCV vs MediaPipe:哪种方案更适合实时手势识别?
引言 手势识别是计算机视觉的重要应用,在人机交互(HCI)、增强现实(AR)、虚拟现实(VR)、智能家居控制、游戏等领域有广泛的应用。实现实时手势识别的技术方案主要有基于传统计算机视觉的方法&am…...
el-select下拉框,搜索时,若是匹配后的数据有且只有一条,则当失去焦点时,默认选中该条数据
1、使用指令 当所需功能只能通过直接的 DOM 操作来实现时,才应该使用自定义指令。可使用方法2封装成共用函数,但用指令他人复用时比较便捷。 <el-tablev-loading"tableLoading"border:data"tableList"default-expand-allrow-key…...
网络地址转换技术(2)
NAT的配置方法: (一)静态NAT的配置方法 进入接口视图配置NAT转换规则 Nat static global 公网地址 inside 私网地址 内网终端PC2(192.168.20.2/24)与公网路由器AR1的G0/0/1(11.22.33.1/24)做…...
Python正则表达式(一)
目录 一、正则表达式的基本概念 1、基本概念 2、正则表达式的特殊字符 二、范围符号和量词 1、范围符号 2、匹配汉字 3、量词 三、正则表达式函数 1、使用正则表达式: 2、re.match()函数 3、re.search()函数 4、findall()函数 5、re.finditer()函数 6…...
【TI MSPM0】PWM学习
一、样例展示 #include "ti_msp_dl_config.h"int main(void) {SYSCFG_DL_init();DL_TimerG_startCounter(PWM_0_INST);while (1) {__WFI();} } TimerG0输出一对边缘对齐的PWM信号 TimerG0会输出一对62.5Hz的边缘对齐的PWM信号在PA12和PA13引脚上,PA12被…...
MySQL: 创建两个关联的表,用联表sql创建一个新表
MySQL: 创建两个关联的表 建表思路 USERS 表:包含用户的基本信息,像 ID、NAME、EMAIL 等。v_card 表:存有虚拟卡的相关信息,如 type 和 amount。关联字段:USERS 表的 V_CARD 字段和 v_card 表的 v_card 字段用于建立…...
更改 vscode ! + table 默认生成的 html 初始化模板
vscode ! 快速成的 html 代码默认为: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>D…...
使用LVS的 NAT 模式实现 3 台RS的轮询访问
节点规划 1、配置RS RS的网络配置为NAT模式,三台RS的网关配置为192.168.10.8 1.1配置RS1 1.1.1修改主机名和IP地址 [rootlocalhost ~]# hostnamectl hostname rs1 [rootlocalhost ~]# nmcli c modify ens160 ipv4.method manual ipv4.addresses 192.168.10.7/24…...
