当前位置: 首页 > article >正文

PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化(中英双语)

PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化

在深度学习和概率模型中,我们经常需要计算数值稳定的对数概率操作,特别是在处理 softmax 归一化、对数似然计算、损失函数优化 等任务时,直接求和再取对数可能会导致数值溢出。torch.logsumexp 正是为了解决这一问题而设计的。

在本文中,我们将详细介绍:

  • torch.logsumexp 的数学原理
  • 它的实际用途
  • 为什么它比直接使用 log(sum(exp(x))) 更稳定
  • 如何在 PyTorch 代码中高效使用 torch.logsumexp

1. torch.logsumexp 是什么?

1.1 数学公式

torch.logsumexp(x, dim) 计算以下数学表达式:

log ⁡ ∑ i e x i \log \sum_{i} e^{x_i} logiexi

其中:

  • ( x i x_i xi ) 是输入张量中的元素,
  • dim 指定沿哪个维度执行计算。

1.2 为什么不直接计算 log(sum(exp(x)))

假设我们有一个很大的数值 ( x ),比如 x = 1000,如果直接计算:

import torchx = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp)  # 结果是 inf(溢出)

问题: exp(1000) 太大,超出了浮点数表示范围,导致溢出。

torch.logsumexp 解决方案:
log ⁡ ∑ i e x i = x max ⁡ + log ⁡ ∑ i e ( x i − x max ⁡ ) \log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{(x_i - x_{\max})} logiexi=xmax+logie(xixmax)

  • 核心思想:先减去最大值 ( x max ⁡ x_{\max} xmax )(防止指数爆炸),然后再计算指数和的对数。
  • 这样能避免溢出,提高数值稳定性。

使用 torch.logsumexp

log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable)  # 正常输出

它不会溢出,因为先减去了最大值,再进行 log 操作。


2. torch.logsumexp 的实际应用

2.1 用于计算 softmax

Softmax 计算公式:

softmax ( x i ) = e x i ∑ j e x j \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}} softmax(xi)=jexjexi

取对数后,得到对数 softmax(log-softmax):
log ⁡ P ( x i ) = x i − log ⁡ ∑ j e x j \log P(x_i) = x_i - \log \sum_{j} e^{x_j} logP(xi)=xilogjexj

PyTorch 代码:

import torchx = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
print(log_softmax_x)

这避免了指数溢出,比直接计算 torch.log(torch.sum(torch.exp(x))) 更稳定。


2.2 用于计算交叉熵损失

交叉熵(Cross-Entropy)计算:

L = − ∑ i y i log ⁡ P ( x i ) L = - \sum_{i} y_i \log P(x_i) L=iyilogP(xi)

其中 ( P ( x i ) P(x_i) P(xi) ) 通过 softmax 计算得到,而 torch.logsumexp 让 softmax 的分母计算更稳定。


2.3 在 Transformer 模型中的应用

GPT、BERT 等 Transformer 语言模型 训练过程中,我们通常会计算 token_log_probs,如下:

import torchlogits = torch.randn(4, 5)  # 假设 batch_size=4, vocab_size=5
input_ids = torch.tensor([1, 2, 3, 4])  # 假设真实的 token 位置# 计算每个 token 的对数概率
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_valuesprint(token_log_probs)

这里 torch.logsumexp(logits, dim=-1) 用于计算 softmax 分母的对数值,确保概率计算不会溢出。


3. torch.logsumexp 的性能优化

3.1 为什么 torch.logsumexplog(sum(exp(x))) 更快?

  • 避免额外存储 exp(x):如果先 exp(x),再 sum(),会生成一个额外的大张量,而 logsumexp 直接在 C++/CUDA 内部优化了计算。
  • 减少数值溢出:减少浮点数不必要的运算,防止梯度爆炸。

3.2 实测性能

import timex = torch.randn(1000000)start = time.time()
torch.logsumexp(x, dim=0)
end = time.time()
print(f"torch.logsumexp: {end - start:.6f} s")start = time.time()
torch.log(torch.sum(torch.exp(x)))
end = time.time()
print(f"log(sum(exp(x))): {end - start:.6f} s")

结果(示例):

torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s

torch.logsumexp 速度更快,并且避免了 exp(x) 可能导致的溢出。


4. 总结

  • torch.logsumexp(x, dim) 计算 log(sum(exp(x))),但使用数值稳定的方法,防止溢出。
  • 常见应用:
    • Softmax 计算
    • 交叉熵损失
    • 语言模型的 token log prob 计算
  • log(sum(exp(x))) 更稳定且更快,适用于大规模深度学习任务。

建议:
🚀 在涉及 log(sum(exp(x))) 计算时,尽量使用 torch.logsumexp,可以大幅提升数值稳定性和计算效率! 🚀

Understanding torch.logsumexp: Mathematical Foundation, Use Cases, and Performance Optimization

In deep learning, especially in probability models, computing logarithmic probabilities in a numerically stable way is crucial. Directly applying log(sum(exp(x))) can lead to numerical instability due to floating-point overflow. torch.logsumexp is designed to solve this problem efficiently.

In this article, we will cover:

  • The mathematical foundation of torch.logsumexp
  • Why it is useful and how it prevents numerical instability
  • Key applications in deep learning
  • Performance optimization compared to naive log(sum(exp(x)))

1. What is torch.logsumexp?

1.1 Mathematical Formula

torch.logsumexp(x, dim) computes the following function:

log ⁡ ∑ i e x i \log \sum_{i} e^{x_i} logiexi

where:

  • ( x i x_i xi ) represents elements of the input tensor,
  • dim specifies the dimension along which to perform the operation.

1.2 Why Not Directly Compute log(sum(exp(x)))?

Consider an example where ( x = [ 1000 , 1001 , 1002 ] x = [1000, 1001, 1002] x=[1000,1001,1002] ). If we naively compute:

import torchx = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp)  # Output: inf (overflow)

Problem:

  • exp(1000) is too large, exceeding the floating-point limit, causing an overflow.

Solution: Log-Sum-Exp Trick
To prevent overflow, torch.logsumexp applies the following transformation:

log ⁡ ∑ i e x i = x max ⁡ + log ⁡ ∑ i e ( x i − x max ⁡ ) \log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{(x_i - x_{\max})} logiexi=xmax+logie(xixmax)

where ( x max ⁡ x_{\max} xmax ) is the maximum value in ( x x x ).

  • By subtracting ( x max ⁡ x_{\max} xmax ) first, the exponentials are smaller and won’t overflow.

Example using torch.logsumexp:

log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable)  # Outputs a valid value without overflow

This is more numerically stable.


2. Key Applications of torch.logsumexp

2.1 Computing Softmax in Log Space

The Softmax function is defined as:

softmax ( x i ) = e x i ∑ j e x j \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}} softmax(xi)=jexjexi

Taking the log:

log ⁡ P ( x i ) = x i − log ⁡ ∑ j e x j \log P(x_i) = x_i - \log \sum_{j} e^{x_j} logP(xi)=xilogjexj

Using PyTorch:

import torchx = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
print(log_softmax_x)

This avoids computing exp(x), preventing numerical instability.


2.2 Cross-Entropy Loss Computation

Cross-entropy loss:

L = − ∑ i y i log ⁡ P ( x i ) L = - \sum_{i} y_i \log P(x_i) L=iyilogP(xi)

where ( P ( x i ) P(x_i) P(xi) ) is computed using Softmax.
Using torch.logsumexp, we avoid overflow in the denominator:

logits = torch.tensor([[2.0, 1.0, 0.1]])
logsumexp_values = torch.logsumexp(logits, dim=-1)
print(logsumexp_values)

This technique is used in torch.nn.CrossEntropyLoss.


2.3 Token Log Probabilities in Transformer Models

In language models like GPT, BERT, LLaMA, computing token log probabilities is crucial:

import torchlogits = torch.randn(4, 5)  # Simulated logits for 4 tokens, vocab size 5
input_ids = torch.tensor([1, 2, 3, 4])  # Token positions# Gather the logits corresponding to the actual tokens
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)# Compute log probabilities
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_valuesprint(token_log_probs)

Here, torch.logsumexp ensures stable probability computation by handling large exponentiations.


3. Performance Optimization

3.1 Why is torch.logsumexp Faster?

Instead of:

torch.log(torch.sum(torch.exp(x)))

which:

  1. Computes exp(x), creating an intermediate tensor.
  2. Sums the tensor.
  3. Computes log(sum(exp(x))).

torch.logsumexp:

  • Avoids unnecessary tensor storage.
  • Optimizes computation at the C++/CUDA level.
  • Improves numerical stability.

3.2 Performance Benchmark

import timex = torch.randn(1000000)start = time.time()
torch.logsumexp(x, dim=0)
end = time.time()
print(f"torch.logsumexp: {end - start:.6f} s")start = time.time()
torch.log(torch.sum(torch.exp(x)))
end = time.time()
print(f"log(sum(exp(x))): {end - start:.6f} s")

Results:

torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s

torch.logsumexp is significantly faster and more stable.


4. Summary

  • torch.logsumexp(x, dim) computes log(sum(exp(x))) safely, preventing overflow.
  • Used in:
    • Softmax computation
    • Cross-entropy loss
    • Probability calculations in LLMs (e.g., GPT, BERT)
  • More efficient than log(sum(exp(x))) due to internal optimizations.

🚀 Always prefer torch.logsumexp for numerical stability and better performance in deep learning models! 🚀

后记

2025年2月21日19点06分于上海。在GPT4o大模型辅助下完成。

相关文章:

PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化(中英双语)

PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化 在深度学习和概率模型中,我们经常需要计算数值稳定的对数概率操作,特别是在处理 softmax 归一化、对数似然计算、损失函数优化 等任务时,直接求和再取对数可能会导致…...

如何为自己的 PDF 文件添加密码?在线加密 PDF 文件其实更简单

随着信息泄露和数据安全问题的日益突出,保护敏感信息变得尤为重要。加密 PDF 文件是一种有效的手段,可以确保只有授权用户才能访问或修改文档内容。本文将详细介绍如何使用 CleverPDF 在线工具为你的 PDF 文件添加密码保护,确保其安全性。 为…...

华为昇腾910b服务器部署DeepSeek翻车现场

最近到祸一台HUAWEI Kunpeng 920 5250,先看看配置。之前是部署的讯飞大模型,发现资源利用率太低了。把5台减少到3台,就出了他 硬件配置信息 基本硬件信息 按照惯例先来看看配置。一共3块盘,500G的系统盘, 2块3T固态…...

hive—常用的函数整理

1、size(split(...))函数用于计算分割后字符串数组的长度 实例1):由客户编号列表计算客户编号个数 --数据准备 with tmp_test01 as ( select tag074445270 tag_id,202501busi_mon , 012399931003,012399931000 index_val union all select tag07444527…...

深入浅出机器学习:概念、算法与实践

目录 引言 机器学习的基本概念 什么是机器学习 机器学习的基本要素 机器学习的主要类型 监督学习(Supervised Learning) 无监督学习(Unsupervised Learning) 强化学习(Reinforcement Learning) 机器…...

Unity Mirror 多房间匹配

文章目录 一 、一些唠叨二 、案例位置三、多房间匹配代码解析四、关于MatchInterestManagement五、总结 一 、一些唠叨 最近使用Mirror开发了一款多人同时在线的肉鸽塔防游戏,其目的是巩固一下Mirror这个插件的熟练度,另一方面是想和身边的朋友一起玩一下自己开发的游戏. 但是…...

基于flask+vue框架的的医院预约挂号系统i1616(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。

系统程序文件列表 项目功能:用户,医生,科室信息,就诊信息,医院概况,挂号信息,诊断信息,取消挂号 开题报告内容 基于FlaskVue框架的医院预约挂号系统开题报告 一、研究背景与意义 随着医疗技术的不断进步和人们健康意识的日益增强,医院就诊量逐年增加。传统的现场…...

Rust编程语言入门教程(五)猜数游戏:生成、比较神秘数字并进行多次猜测

Rust 系列 🎀Rust编程语言入门教程(一)安装Rust🚪 🎀Rust编程语言入门教程(二)hello_world🚪 🎀Rust编程语言入门教程(三) Hello Cargo&#x1f…...

ubuntu部署小笔记-采坑

ubuntu部署小笔记 搭建前端控制端后端前端nginx反向代理使用ubuntu部署nextjs项目问题一 如何访问端口号配置后台运行该进程pm2 问题二 包体过大生产环境下所需文件 问题三 部署在vercel时出现的问题需要魔法访问后端api时,必须使用https协议电脑端访问正常&#xf…...

【代码审计】-Tenda AC 18 v15.03.05.05 /goform接口文档漏洞挖掘

路由器:Tenda AC 18 v15.03.05.05 固件下载地址:https://www.tenda.com.cn/material?keywordac18 1./goform/SetSpeedWan 接口文档: formSetSpeedWan函数中speed_di参数缓冲区溢出漏洞: 使用 binwalk -eM 解包固件&#xff0c…...

2025年02月21日Github流行趋势

项目名称:source-sdk-2013 项目地址url:https://github.com/ValveSoftware/source-sdk-2013项目语言:C历史star数:7343今日star数:929项目维护者:JoeLudwig, jorgenpt, narendraumate, sortie, alanedwarde…...

git 克隆及拉取github项目到本地微信开发者工具,微信开发者工具通过git commit、git push上传代码到github仓库

git 克隆及拉取github项目到本地微信开发者工具,微信开发者工具通过git commit、git push上传代码到github仓库 git 克隆及拉取github项目到本地 先在自己的用户文件夹新建一个项目文件夹,取名为项目名 例如这样 C:\Users\HP\yzj-再打开一个终端页面&…...

【算法基础】--前缀和

前缀和 一、一维前缀和示例模板[寻找数组的中心下标 ](https://leetcode.cn/problems/tvdfij/description/)除自身以外的数组乘积和可被k整除的子数组 一、一维前缀和 前缀和就是快速求出数组某一个连续区间内所有元素的和。 示例模板 已知一个数组arr,求前缀和 …...

统一的多摄像头3D感知框架!PETRv2论文精读

论文地址:PETRv2: A Unified Framework for 3D Perception from Multi-Camera Images 源代码:PETR 摘要 在本文中,我们提出了PETRv2,用于从多视角图像中进行3D感知的统一框架。基于PETR [24],PETRv2探索了时间建模的…...

【Linux】Linux 文件系统—— 探讨软链接(symbolic link)

ℹ️大家好,我是练小杰,周五又到了,明天应该就是牛马的休息日了吧!!😆 前天我们详细介绍了 硬链接的特点,现在继续探讨 软链接的特点,并且后续将添加更多相关知识噢,谢谢…...

快速排序_912. 排序数组(10中排序算法)

快速排序_912. 排序数组(10中排序算法) 1 快速排序(重点)报错代码超时代码修改官方题解快速排序 1:基本快速排序快速排序 2:双指针(指针对撞)快速排序快速排序 3:三指针快…...

DEMF模型赋能多模态图像融合,助力肺癌高效分类

目录 论文创新点 实验设计 1. 可视化的研究设计 2. 样本选取和数据处理 3. 集成分类模型 4. 实验结果 5. 可视化结果 图表总结 可视化知识图谱 在肺癌早期筛查中,计算机断层扫描(CT)和正电子发射断层扫描(PET)作为两种关键的影像学手段,分别提供了丰富的解剖结构…...

Linux-CentOS 7安装

Centos 7镜像:https://pan.baidu.com/s/1fkQHYT64RMFRGLZy1xnSWw 提取码: q2w2 VMware Workstation:https://pan.baidu.com/s/1JnRcDBIIOWGf6FnGY_0LgA 提取码: w2e2 1、打开vmware workstation 2、选择主界面的"创建新的虚拟机"或者点击左上…...

Android14(13)添加墨水屏手写API

软件平台:Android14 硬件平台:QCS6115 需求:特殊品类的产品墨水屏实现手写的功能,本来Android自带的Input这一套可以实现实时展示笔迹,但是由于墨水屏特性,达不到正常的彩屏刷新的帧率,因此使用…...

AI助力下的PPT革命:DeepSeek 与Kimi的高效创作实践

清华大学出品《DeepSeek:从入门到精通》分享 在忙碌的职场中,制作一份高质量的PPT往往需要投入大量时间和精力,尤其是在临近截止日期时。今天,我们将探索如何借助 AI 工具 —— DeepSeek 和 Kimi —— 让 PPT 制作变得既快捷又高…...

【opencv】图像基本操作

一.计算机眼中的图像 1.1 图像读取 cv2.IMREAD_COLOR:彩色图像 cv2.IMREAD_GRAYSCCALE:灰色图像 ①导包 import cv2 # opencv读取的格式是BGR import matplotlib.pyplot as plt import numpy as np %matplotlib inline ②读取图像 img cv2.imread(…...

帆软报表FineReport入门:简单报表制作[扩展|左父格|上父格]

FineReport帮助文档 - 全面的报表使用教程和学习资料 数据库连接 点击号>>JDBC 选择要连接的数据库>>填写信息>>点击测试连接 数据库SQLite是帆软的内置数据库, 里面有练习数据 选择此数据库后,点击测试连接即可 数据库查询 方法一: 在左下角的模板数据集…...

云手机如何进行经纬度修改

云手机如何进行经纬度修改 云手机修改经纬度的方法因不同服务商和操作方式有所差异,以下是综合多个来源的常用方法及注意事项: 通过ADB命令注入GPS数据(适用于技术用户) 1.连接云手机 使用ADB工具连接云手机服务器,…...

VUE中的组件加载方式

加载方式有哪些,及如何进行选择 常规的静态引入是在组件初始化时就加载所有依赖的组件,而懒加载则是等到组件需要被渲染的时候才加载。 对于大型应用,可能会有很多组件,如果一开始都加载,可能会影响首屏加载时间。如…...

天 锐 蓝盾终端安全管理系统:办公U盘拷贝使用管控限制

天 锐 蓝盾终端安全管理系统以终端安全为基石,深度融合安全、管理与维护三大要素,通过对桌面终端系统的精准把控,助力企业用户构筑起更为安全、稳固且可靠的网络运行环境。它实现了管理的标准化,有效破解终端安全管理难题&#xf…...

计算机网络之物理层——基于《计算机网络》谢希仁第八版

(꒪ꇴ꒪ ),Hello我是祐言QAQ我的博客主页:C/C语言,数据结构,Linux基础,ARM开发板,网络编程等领域UP🌍快上🚘,一起学习,让我们成为一个强大的攻城狮&#xff0…...

区块链中的递归长度前缀(RLP)序列化详解

文章目录 1. 什么是RLP序列化?2. RLP的设计目标与优势3. RLP处理的数据类型4. RLP编码规则详解字符串的编码规则列表的编码规则 5. RLP解码原理6. RLP在以太坊中的应用场景7. 编码示例分析8. 总结 1. 什么是RLP序列化? 递归长度前缀(RLP&…...

分布式简单理解

基本概念 应⽤(Application)/系统(System) 为了完成⼀整套服务的⼀个程序或者⼀组相互配合的程序群。⽣活例⼦类⽐:为了完成⼀项任 务,⽽搭建的由⼀个⼈或者⼀群相互配的⼈组成的团队。 模块(Module)/组件…...

记录:Docker 安装记录

今天在安装 ollama 时发现无法指定安装目录,而且它的命令行反馈内容很像 docker ,而且它下载的模型也是放在 C 盘,那么如果我 C 盘空间不足,就装不了 deepseek-r1:70b ,于是想起来之前安装 Docker 的时候也遇到过类似问…...

Leetcode 二叉树展开为链表

java solution class Solution {public void flatten(TreeNode root) {//首先设置递归终止条件if(root null) return;//分别递归处理左右子树,//递归需要先处理子问题(子树的拉平),然后才能处理当前问题(当前节点的指…...