PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化(中英双语)
PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化
在深度学习和概率模型中,我们经常需要计算数值稳定的对数概率操作,特别是在处理 softmax 归一化、对数似然计算、损失函数优化 等任务时,直接求和再取对数可能会导致数值溢出。torch.logsumexp 正是为了解决这一问题而设计的。
在本文中,我们将详细介绍:
torch.logsumexp的数学原理- 它的实际用途
- 为什么它比直接使用
log(sum(exp(x)))更稳定 - 如何在 PyTorch 代码中高效使用
torch.logsumexp
1. torch.logsumexp 是什么?
1.1 数学公式
torch.logsumexp(x, dim) 计算以下数学表达式:
log ∑ i e x i \log \sum_{i} e^{x_i} logi∑exi
其中:
- ( x i x_i xi ) 是输入张量中的元素,
dim指定沿哪个维度执行计算。
1.2 为什么不直接计算 log(sum(exp(x)))?
假设我们有一个很大的数值 ( x ),比如 x = 1000,如果直接计算:
import torchx = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp) # 结果是 inf(溢出)
问题: exp(1000) 太大,超出了浮点数表示范围,导致溢出。
torch.logsumexp 解决方案:
log ∑ i e x i = x max + log ∑ i e ( x i − x max ) \log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{(x_i - x_{\max})} logi∑exi=xmax+logi∑e(xi−xmax)
- 核心思想:先减去最大值 ( x max x_{\max} xmax )(防止指数爆炸),然后再计算指数和的对数。
- 这样能避免溢出,提高数值稳定性。
使用 torch.logsumexp:
log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable) # 正常输出
它不会溢出,因为先减去了最大值,再进行 log 操作。
2. torch.logsumexp 的实际应用
2.1 用于计算 softmax
Softmax 计算公式:
softmax ( x i ) = e x i ∑ j e x j \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}} softmax(xi)=∑jexjexi
取对数后,得到对数 softmax(log-softmax):
log P ( x i ) = x i − log ∑ j e x j \log P(x_i) = x_i - \log \sum_{j} e^{x_j} logP(xi)=xi−logj∑exj
PyTorch 代码:
import torchx = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
print(log_softmax_x)
这避免了指数溢出,比直接计算 torch.log(torch.sum(torch.exp(x))) 更稳定。
2.2 用于计算交叉熵损失
交叉熵(Cross-Entropy)计算:
L = − ∑ i y i log P ( x i ) L = - \sum_{i} y_i \log P(x_i) L=−i∑yilogP(xi)
其中 ( P ( x i ) P(x_i) P(xi) ) 通过 softmax 计算得到,而 torch.logsumexp 让 softmax 的分母计算更稳定。
2.3 在 Transformer 模型中的应用
在 GPT、BERT 等 Transformer 语言模型 训练过程中,我们通常会计算 token_log_probs,如下:
import torchlogits = torch.randn(4, 5) # 假设 batch_size=4, vocab_size=5
input_ids = torch.tensor([1, 2, 3, 4]) # 假设真实的 token 位置# 计算每个 token 的对数概率
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_valuesprint(token_log_probs)
这里 torch.logsumexp(logits, dim=-1) 用于计算 softmax 分母的对数值,确保概率计算不会溢出。
3. torch.logsumexp 的性能优化
3.1 为什么 torch.logsumexp 比 log(sum(exp(x))) 更快?
- 避免额外存储
exp(x):如果先exp(x),再sum(),会生成一个额外的大张量,而logsumexp直接在 C++/CUDA 内部优化了计算。 - 减少数值溢出:减少浮点数不必要的运算,防止梯度爆炸。
3.2 实测性能
import timex = torch.randn(1000000)start = time.time()
torch.logsumexp(x, dim=0)
end = time.time()
print(f"torch.logsumexp: {end - start:.6f} s")start = time.time()
torch.log(torch.sum(torch.exp(x)))
end = time.time()
print(f"log(sum(exp(x))): {end - start:.6f} s")
结果(示例):
torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s
torch.logsumexp 速度更快,并且避免了 exp(x) 可能导致的溢出。
4. 总结
torch.logsumexp(x, dim)计算log(sum(exp(x))),但使用数值稳定的方法,防止溢出。- 常见应用:
- Softmax 计算
- 交叉熵损失
- 语言模型的 token log prob 计算
- 比
log(sum(exp(x)))更稳定且更快,适用于大规模深度学习任务。
建议:
🚀 在涉及 log(sum(exp(x))) 计算时,尽量使用 torch.logsumexp,可以大幅提升数值稳定性和计算效率! 🚀
Understanding torch.logsumexp: Mathematical Foundation, Use Cases, and Performance Optimization
In deep learning, especially in probability models, computing logarithmic probabilities in a numerically stable way is crucial. Directly applying log(sum(exp(x))) can lead to numerical instability due to floating-point overflow. torch.logsumexp is designed to solve this problem efficiently.
In this article, we will cover:
- The mathematical foundation of
torch.logsumexp - Why it is useful and how it prevents numerical instability
- Key applications in deep learning
- Performance optimization compared to naive
log(sum(exp(x)))
1. What is torch.logsumexp?
1.1 Mathematical Formula
torch.logsumexp(x, dim) computes the following function:
log ∑ i e x i \log \sum_{i} e^{x_i} logi∑exi
where:
- ( x i x_i xi ) represents elements of the input tensor,
dimspecifies the dimension along which to perform the operation.
1.2 Why Not Directly Compute log(sum(exp(x)))?
Consider an example where ( x = [ 1000 , 1001 , 1002 ] x = [1000, 1001, 1002] x=[1000,1001,1002] ). If we naively compute:
import torchx = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp) # Output: inf (overflow)
Problem:
exp(1000)is too large, exceeding the floating-point limit, causing an overflow.
Solution: Log-Sum-Exp Trick
To prevent overflow, torch.logsumexp applies the following transformation:
log ∑ i e x i = x max + log ∑ i e ( x i − x max ) \log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{(x_i - x_{\max})} logi∑exi=xmax+logi∑e(xi−xmax)
where ( x max x_{\max} xmax ) is the maximum value in ( x x x ).
- By subtracting ( x max x_{\max} xmax ) first, the exponentials are smaller and won’t overflow.
Example using torch.logsumexp:
log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable) # Outputs a valid value without overflow
This is more numerically stable.
2. Key Applications of torch.logsumexp
2.1 Computing Softmax in Log Space
The Softmax function is defined as:
softmax ( x i ) = e x i ∑ j e x j \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}} softmax(xi)=∑jexjexi
Taking the log:
log P ( x i ) = x i − log ∑ j e x j \log P(x_i) = x_i - \log \sum_{j} e^{x_j} logP(xi)=xi−logj∑exj
Using PyTorch:
import torchx = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
print(log_softmax_x)
This avoids computing exp(x), preventing numerical instability.
2.2 Cross-Entropy Loss Computation
Cross-entropy loss:
L = − ∑ i y i log P ( x i ) L = - \sum_{i} y_i \log P(x_i) L=−i∑yilogP(xi)
where ( P ( x i ) P(x_i) P(xi) ) is computed using Softmax.
Using torch.logsumexp, we avoid overflow in the denominator:
logits = torch.tensor([[2.0, 1.0, 0.1]])
logsumexp_values = torch.logsumexp(logits, dim=-1)
print(logsumexp_values)
This technique is used in torch.nn.CrossEntropyLoss.
2.3 Token Log Probabilities in Transformer Models
In language models like GPT, BERT, LLaMA, computing token log probabilities is crucial:
import torchlogits = torch.randn(4, 5) # Simulated logits for 4 tokens, vocab size 5
input_ids = torch.tensor([1, 2, 3, 4]) # Token positions# Gather the logits corresponding to the actual tokens
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)# Compute log probabilities
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_valuesprint(token_log_probs)
Here, torch.logsumexp ensures stable probability computation by handling large exponentiations.
3. Performance Optimization
3.1 Why is torch.logsumexp Faster?
Instead of:
torch.log(torch.sum(torch.exp(x)))
which:
- Computes
exp(x), creating an intermediate tensor. - Sums the tensor.
- Computes
log(sum(exp(x))).
torch.logsumexp:
- Avoids unnecessary tensor storage.
- Optimizes computation at the C++/CUDA level.
- Improves numerical stability.
3.2 Performance Benchmark
import timex = torch.randn(1000000)start = time.time()
torch.logsumexp(x, dim=0)
end = time.time()
print(f"torch.logsumexp: {end - start:.6f} s")start = time.time()
torch.log(torch.sum(torch.exp(x)))
end = time.time()
print(f"log(sum(exp(x))): {end - start:.6f} s")
Results:
torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s
torch.logsumexp is significantly faster and more stable.
4. Summary
torch.logsumexp(x, dim)computeslog(sum(exp(x)))safely, preventing overflow.- Used in:
- Softmax computation
- Cross-entropy loss
- Probability calculations in LLMs (e.g., GPT, BERT)
- More efficient than
log(sum(exp(x)))due to internal optimizations.
🚀 Always prefer torch.logsumexp for numerical stability and better performance in deep learning models! 🚀
后记
2025年2月21日19点06分于上海。在GPT4o大模型辅助下完成。
相关文章:
PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化(中英双语)
PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化 在深度学习和概率模型中,我们经常需要计算数值稳定的对数概率操作,特别是在处理 softmax 归一化、对数似然计算、损失函数优化 等任务时,直接求和再取对数可能会导致…...
如何为自己的 PDF 文件添加密码?在线加密 PDF 文件其实更简单
随着信息泄露和数据安全问题的日益突出,保护敏感信息变得尤为重要。加密 PDF 文件是一种有效的手段,可以确保只有授权用户才能访问或修改文档内容。本文将详细介绍如何使用 CleverPDF 在线工具为你的 PDF 文件添加密码保护,确保其安全性。 为…...
华为昇腾910b服务器部署DeepSeek翻车现场
最近到祸一台HUAWEI Kunpeng 920 5250,先看看配置。之前是部署的讯飞大模型,发现资源利用率太低了。把5台减少到3台,就出了他 硬件配置信息 基本硬件信息 按照惯例先来看看配置。一共3块盘,500G的系统盘, 2块3T固态…...
hive—常用的函数整理
1、size(split(...))函数用于计算分割后字符串数组的长度 实例1):由客户编号列表计算客户编号个数 --数据准备 with tmp_test01 as ( select tag074445270 tag_id,202501busi_mon , 012399931003,012399931000 index_val union all select tag07444527…...
深入浅出机器学习:概念、算法与实践
目录 引言 机器学习的基本概念 什么是机器学习 机器学习的基本要素 机器学习的主要类型 监督学习(Supervised Learning) 无监督学习(Unsupervised Learning) 强化学习(Reinforcement Learning) 机器…...
Unity Mirror 多房间匹配
文章目录 一 、一些唠叨二 、案例位置三、多房间匹配代码解析四、关于MatchInterestManagement五、总结 一 、一些唠叨 最近使用Mirror开发了一款多人同时在线的肉鸽塔防游戏,其目的是巩固一下Mirror这个插件的熟练度,另一方面是想和身边的朋友一起玩一下自己开发的游戏. 但是…...
基于flask+vue框架的的医院预约挂号系统i1616(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。
系统程序文件列表 项目功能:用户,医生,科室信息,就诊信息,医院概况,挂号信息,诊断信息,取消挂号 开题报告内容 基于FlaskVue框架的医院预约挂号系统开题报告 一、研究背景与意义 随着医疗技术的不断进步和人们健康意识的日益增强,医院就诊量逐年增加。传统的现场…...
Rust编程语言入门教程(五)猜数游戏:生成、比较神秘数字并进行多次猜测
Rust 系列 🎀Rust编程语言入门教程(一)安装Rust🚪 🎀Rust编程语言入门教程(二)hello_world🚪 🎀Rust编程语言入门教程(三) Hello Cargo…...
ubuntu部署小笔记-采坑
ubuntu部署小笔记 搭建前端控制端后端前端nginx反向代理使用ubuntu部署nextjs项目问题一 如何访问端口号配置后台运行该进程pm2 问题二 包体过大生产环境下所需文件 问题三 部署在vercel时出现的问题需要魔法访问后端api时,必须使用https协议电脑端访问正常…...
【代码审计】-Tenda AC 18 v15.03.05.05 /goform接口文档漏洞挖掘
路由器:Tenda AC 18 v15.03.05.05 固件下载地址:https://www.tenda.com.cn/material?keywordac18 1./goform/SetSpeedWan 接口文档: formSetSpeedWan函数中speed_di参数缓冲区溢出漏洞: 使用 binwalk -eM 解包固件,…...
2025年02月21日Github流行趋势
项目名称:source-sdk-2013 项目地址url:https://github.com/ValveSoftware/source-sdk-2013项目语言:C历史star数:7343今日star数:929项目维护者:JoeLudwig, jorgenpt, narendraumate, sortie, alanedwarde…...
git 克隆及拉取github项目到本地微信开发者工具,微信开发者工具通过git commit、git push上传代码到github仓库
git 克隆及拉取github项目到本地微信开发者工具,微信开发者工具通过git commit、git push上传代码到github仓库 git 克隆及拉取github项目到本地 先在自己的用户文件夹新建一个项目文件夹,取名为项目名 例如这样 C:\Users\HP\yzj-再打开一个终端页面&…...
【算法基础】--前缀和
前缀和 一、一维前缀和示例模板[寻找数组的中心下标 ](https://leetcode.cn/problems/tvdfij/description/)除自身以外的数组乘积和可被k整除的子数组 一、一维前缀和 前缀和就是快速求出数组某一个连续区间内所有元素的和。 示例模板 已知一个数组arr,求前缀和 …...
统一的多摄像头3D感知框架!PETRv2论文精读
论文地址:PETRv2: A Unified Framework for 3D Perception from Multi-Camera Images 源代码:PETR 摘要 在本文中,我们提出了PETRv2,用于从多视角图像中进行3D感知的统一框架。基于PETR [24],PETRv2探索了时间建模的…...
【Linux】Linux 文件系统—— 探讨软链接(symbolic link)
ℹ️大家好,我是练小杰,周五又到了,明天应该就是牛马的休息日了吧!!😆 前天我们详细介绍了 硬链接的特点,现在继续探讨 软链接的特点,并且后续将添加更多相关知识噢,谢谢…...
快速排序_912. 排序数组(10中排序算法)
快速排序_912. 排序数组(10中排序算法) 1 快速排序(重点)报错代码超时代码修改官方题解快速排序 1:基本快速排序快速排序 2:双指针(指针对撞)快速排序快速排序 3:三指针快…...
DEMF模型赋能多模态图像融合,助力肺癌高效分类
目录 论文创新点 实验设计 1. 可视化的研究设计 2. 样本选取和数据处理 3. 集成分类模型 4. 实验结果 5. 可视化结果 图表总结 可视化知识图谱 在肺癌早期筛查中,计算机断层扫描(CT)和正电子发射断层扫描(PET)作为两种关键的影像学手段,分别提供了丰富的解剖结构…...
Linux-CentOS 7安装
Centos 7镜像:https://pan.baidu.com/s/1fkQHYT64RMFRGLZy1xnSWw 提取码: q2w2 VMware Workstation:https://pan.baidu.com/s/1JnRcDBIIOWGf6FnGY_0LgA 提取码: w2e2 1、打开vmware workstation 2、选择主界面的"创建新的虚拟机"或者点击左上…...
Android14(13)添加墨水屏手写API
软件平台:Android14 硬件平台:QCS6115 需求:特殊品类的产品墨水屏实现手写的功能,本来Android自带的Input这一套可以实现实时展示笔迹,但是由于墨水屏特性,达不到正常的彩屏刷新的帧率,因此使用…...
AI助力下的PPT革命:DeepSeek 与Kimi的高效创作实践
清华大学出品《DeepSeek:从入门到精通》分享 在忙碌的职场中,制作一份高质量的PPT往往需要投入大量时间和精力,尤其是在临近截止日期时。今天,我们将探索如何借助 AI 工具 —— DeepSeek 和 Kimi —— 让 PPT 制作变得既快捷又高…...
【opencv】图像基本操作
一.计算机眼中的图像 1.1 图像读取 cv2.IMREAD_COLOR:彩色图像 cv2.IMREAD_GRAYSCCALE:灰色图像 ①导包 import cv2 # opencv读取的格式是BGR import matplotlib.pyplot as plt import numpy as np %matplotlib inline ②读取图像 img cv2.imread(…...
帆软报表FineReport入门:简单报表制作[扩展|左父格|上父格]
FineReport帮助文档 - 全面的报表使用教程和学习资料 数据库连接 点击号>>JDBC 选择要连接的数据库>>填写信息>>点击测试连接 数据库SQLite是帆软的内置数据库, 里面有练习数据 选择此数据库后,点击测试连接即可 数据库查询 方法一: 在左下角的模板数据集…...
云手机如何进行经纬度修改
云手机如何进行经纬度修改 云手机修改经纬度的方法因不同服务商和操作方式有所差异,以下是综合多个来源的常用方法及注意事项: 通过ADB命令注入GPS数据(适用于技术用户) 1.连接云手机 使用ADB工具连接云手机服务器,…...
VUE中的组件加载方式
加载方式有哪些,及如何进行选择 常规的静态引入是在组件初始化时就加载所有依赖的组件,而懒加载则是等到组件需要被渲染的时候才加载。 对于大型应用,可能会有很多组件,如果一开始都加载,可能会影响首屏加载时间。如…...
天 锐 蓝盾终端安全管理系统:办公U盘拷贝使用管控限制
天 锐 蓝盾终端安全管理系统以终端安全为基石,深度融合安全、管理与维护三大要素,通过对桌面终端系统的精准把控,助力企业用户构筑起更为安全、稳固且可靠的网络运行环境。它实现了管理的标准化,有效破解终端安全管理难题…...
计算机网络之物理层——基于《计算机网络》谢希仁第八版
(꒪ꇴ꒪ ),Hello我是祐言QAQ我的博客主页:C/C语言,数据结构,Linux基础,ARM开发板,网络编程等领域UP🌍快上🚘,一起学习,让我们成为一个强大的攻城狮࿰…...
区块链中的递归长度前缀(RLP)序列化详解
文章目录 1. 什么是RLP序列化?2. RLP的设计目标与优势3. RLP处理的数据类型4. RLP编码规则详解字符串的编码规则列表的编码规则 5. RLP解码原理6. RLP在以太坊中的应用场景7. 编码示例分析8. 总结 1. 什么是RLP序列化? 递归长度前缀(RLP&…...
分布式简单理解
基本概念 应⽤(Application)/系统(System) 为了完成⼀整套服务的⼀个程序或者⼀组相互配合的程序群。⽣活例⼦类⽐:为了完成⼀项任 务,⽽搭建的由⼀个⼈或者⼀群相互配的⼈组成的团队。 模块(Module)/组件…...
记录:Docker 安装记录
今天在安装 ollama 时发现无法指定安装目录,而且它的命令行反馈内容很像 docker ,而且它下载的模型也是放在 C 盘,那么如果我 C 盘空间不足,就装不了 deepseek-r1:70b ,于是想起来之前安装 Docker 的时候也遇到过类似问…...
Leetcode 二叉树展开为链表
java solution class Solution {public void flatten(TreeNode root) {//首先设置递归终止条件if(root null) return;//分别递归处理左右子树,//递归需要先处理子问题(子树的拉平),然后才能处理当前问题(当前节点的指…...
