理解 logits_to_keep = logits_to_keep + 1 在 _get_per_token_logps 中的作用
理解 logits_to_keep = logits_to_keep + 1 在 _get_per_token_logps 中的作用
source: anaconda3/envs/xxx/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):# We add 1 to `logits_to_keep` because the last logits of the sequence is later excludedlogits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits # (B, L, V)logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token predinput_ids = input_ids[:, -logits_to_keep:]# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.# See https://github.com/huggingface/trl/issues/2770logits = logits[:, -logits_to_keep:]# Compute the log probabilities for the input tokens.token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)# use a loop to reduce memory peaklogsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])token_log_probs = token_logits - logsumexp_values # log_softmax = logits - log(sum(exp(logits)))return token_log_probs
在 _get_per_token_logps 这个函数中,logits_to_keep 控制了要保留的 logits 数量,用于计算每个 token 的对数概率。
但这里有一个关键点:
logits_to_keep = logits_to_keep + 1
为什么需要加 1?
因为在 Transformer 语言模型(如 GPT)中,模型的 logits 预测的是下一个 token,所以如果我们只保留 logits_to_keep 个 logits,数量是不够的。
为了确保对齐,我们先多取一个 logits,然后再手动丢弃最后一个 logits,这样 logits 和 input_ids 就能正确对齐。
1. 为什么需要 logits_to_keep + 1?
1.1 自回归模型的 logits 预测的是下一个 token
在 Transformer 语言模型中,模型的 logits 形状通常是:
logits.shape = (B, L, V)
其中:
B:batch_sizeL:序列长度V:词表大小(vocab size)
模型在生成 logits 时,每个 logits[i] 实际上是用于预测下一个 token,而不是当前 token:
logits[:, 0, :] -> 用于预测 input_ids[:, 1]
logits[:, 1, :] -> 用于预测 input_ids[:, 2]
...
logits[:, L-1, :] -> 用于预测 input_ids[:, L](即下一个 token)
但 input_ids 只包含当前 token,并不包含 “下一个 token” 的真实值,因此我们需要手动去掉最后一个 logits,让它和 input_ids 对齐。
2. 代码执行步骤
2.1 假设 input_ids.shape = (1, 5)
假设 logits_to_keep = 3,那么:
logits_to_keep + 1 = 4,即多取一个logits。- 模型返回的
logits.shape = (1, 6, V),因为logits_to_keep+1=4,再加上可能的 padding,会得到 6 个logits。
2.2 关键代码
步骤 1:调用模型
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
).logits
此时 logits.shape = (B, L, Vocab),其中 L = logits_to_keep + 1。
步骤 2:删除最后一个 logits
logits = logits[:, :-1, :]
这样 logits 的形状就变成 (B, L-1, V),让它正确对应 input_ids[:, -logits_to_keep:]。
步骤 3:对齐 input_ids
input_ids = input_ids[:, -logits_to_keep:]
这里 input_ids[:, -logits_to_keep:] 取的是最后 logits_to_keep 个 token,确保 logits 和 input_ids 一一对应。
3. 示例代码
3.1 假设 input_ids = [5, 8, 2, 3, 9],logits_to_keep = 3
① logits_to_keep + 1 让模型生成 4 个 logits
logits.shape = (1, 5, V) # 5 个 token,分别预测下一个 token
| Token | 真实 input_ids | logits 预测 |
|---|---|---|
| 1 | 5 | 用于预测 8 |
| 2 | 8 | 用于预测 2 |
| 3 | 2 | 用于预测 3 |
| 4 | 3 | 用于预测 9 |
| 5 | 9 | (无用,预测下一个 token) |
② 手动删除最后一个 logits
logits = logits[:, :-1, :] # 丢弃最后一个预测
最终 logits 形状:
logits.shape = (1, 4, V) # 只保留前 4 个 logits
这样 logits 和 input_ids[:, -logits_to_keep:] 对齐:
logits 对应 input_ids = [8, 2, 3]
4. 如果不加 +1 会发生什么?
如果 logits_to_keep 不加 1,那么:
logits数量比input_ids少 1 个,导致维度对不上。- 计算
log_probs时logits.gather(dim=-1, index=input_ids.unsqueeze(-1))会报错,或者索引到错误的 logits。
5. 结论
| 步骤 | 目的 |
|---|---|
logits_to_keep + 1 | 获取一个额外的 logits,避免数据对不齐 |
logits[:, :-1, :] | 删除最后一个 logits,确保与 input_ids 对齐 |
input_ids[:, -logits_to_keep:] | 选取最后 logits_to_keep 个 token 计算 log_probs |
核心逻辑
✅ 因为 logits 预测的是下一个 token,所以要多取 1 个,然后手动删除最后一个。
✅ 这样 logits 和 input_ids 维度对齐,确保计算正确的 log_probs。
🚀 理解这个逻辑对于实现 Transformer 语言模型的 loss 计算至关重要! 🚀
如果 logits_to_keep 不加 +1 会发生什么?
假设:
input_ids = [5, 8, 2, 3, 9]logits_to_keep = 3logits.shape = (B, L, V), 其中L=5,表示 5 个 token,每个 token 的logits是一个Vocab大小的概率分布。
1. 正确做法(logits_to_keep + 1)
如果 logits_to_keep + 1:
logits_to_keep = 3 + 1 = 4- 让模型输出 4 个
logits,即:logits.shape = (1, 4, V) - 然后 删除最后一个
logits(logits[:, :-1, :]),得到:logits.shape = (1, 3, V) # 3 个 logits,对应 input_ids 的最后 3 个 token - 此时
logits和input_ids[:, -3:] = [8, 2, 3]维度匹配,可以正确计算log_probs。
2. 错误示例(如果不加 +1)
如果不加 +1,直接 logits_to_keep = 3,那么:
- 模型只会返回
3个logits:logits.shape = (1, 3, V) # 只保留 3 个 logits - 然后
logits[:, :-1, :]会让logits变成:logits.shape = (1, 2, V) # 只有 2 个 logits - 但
input_ids[:, -logits_to_keep:]仍然是:input_ids[:, -3:] = [8, 2, 3] # 3 个 token - 这样,
gather操作:
将会报错,因为token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)logits.shape = (1, 2, V),但input_ids.shape = (1, 3),维度不匹配!
错误示例代码
import torch# 模拟 logits (batch_size=1, sequence_length=2, vocab_size=5)
logits = torch.tensor([[[2.0, 1.0, 0.5, -1.0, 0.2], # logit for token 8[0.1, -0.5, 2.2, 1.5, 0.0]] # logit for token 2
]) # shape = (1, 2, 5)# input_ids 仍然有 3 个 token
input_ids = torch.tensor([[8, 2, 3]]) # shape = (1, 3)# 错误的 gather 操作
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
错误信息:
RuntimeError: Expected index with dimension 3, but got dimension 4 for input tensor.
这个错误表明 logits 只有 2 个 token,而 input_ids 仍然有 3 个 token,导致 gather 操作失败!
3. 错误情况总结
| 情况 | logits.shape (B, L, V) | input_ids.shape (B, L) | 是否匹配? |
|---|---|---|---|
正确:logits_to_keep + 1 后删掉最后一个 logits | (1, 3, V) | (1, 3) | ✅ 匹配 |
错误:不加 +1 | (1, 2, V) | (1, 3) | ❌ 不匹配,报错 |
🔴 结论:
如果不加 +1,最终 logits 会比 input_ids 少 1 个 token,导致 gather 操作失败,无法正确计算 log_probs。
4. 关键结论
logits_to_keep + 1确保logits先比input_ids多一个,然后删掉最后一个logits,使两者对齐。- 不加
+1,最终logits比input_ids少 1 个,导致gather维度错误,代码会报错。 - 在自回归模型中,
logits预测的是下一个 token,所以要手动调整,以确保logits和input_ids一一对应。
🚀 正确理解 logits_to_keep + 1 是构建 Transformer 语言模型损失计算的关键! 🚀
如果不加 +1,可以不执行 logits = logits[:, :-1, :] 吗?
不可以!如果不加 +1,并且 不执行 logits[:, :-1, :] 这个操作,最终 logits 和 input_ids 的对齐仍然会出问题,导致错误的 token 对数概率计算。
1. 代码逻辑分析
1.1 logits_to_keep + 1 的作用
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
).logits # (B, L, V)
logits_to_keep + 1让模型输出比logits_to_keep多 1 个logits。- 这样,
logits.shape = (B, L+1, V)(多 1 个 token 预测的logits)。
1.2 logits[:, :-1, :] 的作用
logits = logits[:, :-1, :] # (B, L-1, V)
- 这一步 删除最后一个
logits,确保logits只用于计算input_ids对应 token 的概率。 - 如果不执行这一步,则
logits.shape = (B, L, V),这就会导致logits比input_ids多 1 个 token,维度不匹配。
1.3 input_ids[:, -logits_to_keep:] 作用
input_ids = input_ids[:, -logits_to_keep:]
- 这一步 只保留
logits_to_keep个 token 的input_ids,确保input_ids和logits维度匹配。
2. 如果不加 +1,但仍然执行 logits[:, :-1, :],会发生什么?
如果 logits_to_keep 没有 +1,但仍然执行:
logits = logits[:, :-1, :]
logits数量会比input_ids少 1 个。logits.shape = (B, logits_to_keep - 1, V)。input_ids.shape = (B, logits_to_keep)。
这会导致:
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
报错,因为 logits.shape[1] 和 input_ids.shape[1] 不匹配!
3. 如果不加 +1,并且不执行 logits[:, :-1, :],会发生什么?
假设 logits_to_keep = 3,并且不加 +1,那么:
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep).logits
logits.shape = (B, logits_to_keep, V)。input_ids[:, -logits_to_keep:]仍然是(B, logits_to_keep)。logits和input_ids维度看似匹配,但实际上错位了!
错位的原因:
logits[:, i, :]对应的是input_ids[:, i+1](预测的是下一个 token),而不是input_ids[:, i]!- 这会导致
gather取到错误的 logits,计算的log_probs也是错的。
示例
假设:
input_ids = [[5, 8, 2, 3, 9]] # 长度 5
logits_to_keep = 3
如果 不加 +1,且不 logits[:, :-1, :]:
logits[:, 0, :] # 实际预测 input_ids[:, 1] (8)
logits[:, 1, :] # 实际预测 input_ids[:, 2] (2)
logits[:, 2, :] # 实际预测 input_ids[:, 3] (3) ❌ 但被错误匹配到 input_ids[:, 2]
最终 gather 取到的是错位的 logits!
4. 结论
| 情况 | logits.shape | input_ids.shape | 结果 |
|---|---|---|---|
正确:加 +1 并执行 logits[:, :-1, :] | (B, logits_to_keep, V) | (B, logits_to_keep) | ✅ 匹配正确 |
错误:不加 +1,但仍然执行 logits[:, :-1, :] | (B, logits_to_keep - 1, V) | (B, logits_to_keep) | ❌ 维度不匹配,gather 报错 |
错误:不加 +1,且不执行 logits[:, :-1, :] | (B, logits_to_keep, V) | (B, logits_to_keep) | ❌ 错位,计算错误的 log_probs |
核心总结
logits_to_keep + 1让logits先多 1 个,再删掉最后 1 个,以正确对齐input_ids。- 如果不
+1,但仍然logits[:, :-1, :],最终logits比input_ids少 1 个,导致gather失败。 - 如果不
+1,且不logits[:, :-1, :],最终logits和input_ids看似匹配,但会错位,计算错误的log_probs。
🚀 正确理解 logits_to_keep + 1 是确保 Transformer 语言模型 log_prob 计算正确的关键! 🚀
后记
2025年2月21日19点32分于上海。在GPT4o大模型辅助下完成。
相关文章:
理解 logits_to_keep = logits_to_keep + 1 在 _get_per_token_logps 中的作用
理解 logits_to_keep logits_to_keep 1 在 _get_per_token_logps 中的作用 source: anaconda3/envs/xxx/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):# We add 1 to logi…...
论文笔记-WSDM2025-ColdLLM
论文笔记-WSDM2025-Large Language Model Simulator for Cold-Start Recommendation ColdLLM:用于冷启动推荐的大语言模型模拟器摘要1.引言2.前言3.方法3.1整体框架3.1.1行为模拟3.1.2嵌入优化 3.2耦合漏斗ColdLLM3.2.1过滤模拟3.2.2精炼模拟 3.3模拟器训练3.3.1LLM…...
DeepSeek与AI幻觉
AI幻觉(AI Hallucination) 是指人工智能系统(尤其是生成式模型,如大型语言模型或图像生成模型)在输出内容时,生成与事实不符、逻辑混乱或完全虚构的信息的现象。这种现象类似于人类的“幻觉”,即…...
Linux 命令大全完整版(09)
4. 压缩与解压缩命令 ar 功能说明:建立或修改备存文件,或是从备存文件中抽取文件。语法:ar[-dmpqrtx][cfosSuvV][a<成员文件>][b<成员文件>][i<成员文件>][备存文件][成员文件]补充说明:可让您集合许多文件&a…...
deepseek_清华大学指导手册_pdf_1-5
deepseek_清华大学指导手册_pdf_1-5 无套路,无需关注,无需登录,无需app,直接下载: 下载地址 文件列表: 001_清华大学_DeepSeek从入门到精通.pdf 002_清华大学_DeepSeek如何赋能职场应用.pdf 003_清华大学…...
深度学习-127-LangGraph之基础知识(四)自定义状态添加额外字段的聊天机器人
文章目录 1 自定义状态2 自定义工具2.1 完善工具human_assistance2.2 浏览器工具baidu_search3 聊天机器人3.1 绑定工具的聊天模型3.2 聊天机器人(带记忆)4 调用图4.1 调用工具时中断4.2 人工提供信息恢复4.3 查询存储的状态4.4 手动更新状态5 参考附录使用LangGraph,在状态中…...
自定义实现简版状态机
状态机(State Machine)是一种用于描述系统行为的数学模型,广泛应用于计算机科学、工程和自动化等领域。它通过定义系统的状态、事件和转移来模拟系统的动态行为。 基本概念 状态(State):系统在某一时刻的特…...
基于 Python Django 的校园互助平台(附源码,文档)
博主介绍:✌Java徐师兄、7年大厂程序员经历。全网粉丝13w、csdn博客专家、掘金/华为云等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇🏻 不…...
Python pip 缓存清理:全面方法与操作指南
在使用 Python 的 pip 进行包安装时,pip 会将下载的包缓存起来,以加快后续相同包的安装速度。不过,随着时间推移,缓存会占用大量磁盘空间,这时你可以对其进行清理。下面为你介绍不同操作系统下清理 pip 缓存的方法。 …...
Windows系统第一次运行C语言程序,环境配置,软件安装等遇到的坑及解决方法
明确需要编辑器和编译器,并选择自己要用什么(我选的编辑器是VSCode:Visual Studio Code;编译器是gcc)下载VSCode并配置环境变量(这里没啥问题),安装C/C的拓展安装Cygwin,…...
Python开发Django面试题及参考答案
目录 Django 的请求生命周期是怎样的? Django 的 MTV 架构中的各个组件分别是什么? Django 的 URL 路由是如何工作的? Django 的视图函数和视图类有什么区别? Django 的模板系统是如何渲染 HTML 的? Django 的 ORM 是如何工作的? Django 的中间件是什么?它的作用是…...
PyTorch v2.6 Overview
PyTorch v2.6 Overview Python APILibraries PyTorch 是一个优化的张量库,用于使用 GPU 和 CPU 进行深度学习。 Python API 序号API名称解释1torchPyTorch 核心库(中文:火炬)PyTorch 的核心库,提供了张量操作、自动求导等基础功能。2torch.nn神经网络模…...
智慧废品回收小程序php+uniapp
废品回收小程序:数字化赋能环保,开启资源循环新时代 城市垃圾治理难题,废品回收小程序成破局关键 随着城市化进程加速与消费水平提升,我国生活垃圾总量逐年攀升,年均增速达5%-8%,其中超30%为可回收物。然…...
【p-camera-h5】 一款开箱即用的H5相机插件,支持拍照、录像、动态水印与样式高度定制化。
【开源推荐】p-camera-h5:一款轻量级H5相机插件开发实践 一、插件背景 在Web开发中,原生摄像头功能的集成往往面临以下痛点: 浏览器兼容性问题视频流与水印叠加实现复杂移动端适配困难功能定制成本高 为此,p-camera-h5 —— 一…...
python~http的请求参数中携带map
背景 调试 http GET请求的 map 参数,链路携带参数一直有问题,最终采用如下方式携带map 解决 user{"demo":"true","info":"王者"}url encode之后的效果如下所示 user%7B%22demo%22:%22true%22,%22info%22:%22…...
网页版的俄罗斯方块
1、新建一个txt文件 2、打开后将代码复制进去保存 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>俄…...
创建虚拟环境以及配置对应的项目依赖
文章目录 首先创建一个虚拟环境,创建一个名字为myenv,并且版本为xxx的虚拟环境 conda create --name myenv pythonxxx激活虚拟环境 conda activate myenv下载所需的依赖,如果有requirements.txt文件 pip install -r requirements.txt容易出现的错误&a…...
网络安全第三次练习
一、实验拓扑 二、实验要求 配置真实DNS服务信息,创建虚拟服务,配置DNS透明代理功能 三、需求分析 1.创建用户并配置认证策略 2.安全策略划分接口 3.ip与策略配置 四、实验步骤 1.划分安全策略接口 2.创建用户并进行策略认证 3.配置安全策略 4.NAT配…...
写大论文的word版本格式整理,实现自动生成目录、参考文献序号、公式序号、图表序号
前情提要:最近开始写大论文,发现由于内容很多导致用老方法一个一个改的话超级麻烦,需要批量自动化处理,尤其是序号,在不断有增添删减的情况时序号手动调整很慢也容易出错,所以搞一个格式总结,记…...
STM32——HAL库开发笔记22(定时器3—呼吸灯实验)(参考来源:b站铁头山羊)
本文利用前几节所学知识来实现一个呼吸灯实验:两颗led灯交替呼吸。 一、STM32CubeMX配置 step1:配置调试接口 step2:配置定时器 定时器1位于APB2总线上,如上图所示。 step3:配置时基单元 按照下图配置 时钟来源配置…...
GPU和FPGA的区别
GPU(Graphics Processing Unit,图形处理器)和 FPGA(Field-Programmable Gate Array,现场可编程门阵列)不是同一种硬件。 我的理解是,虽然都可以用于并行计算,但是GPU是纯计算的硬件…...
vue3页面显示tiff图片
浏览器网页一般不直接支持tiff图片的显示,需要用到tiff.js这个库,首先安装tiff.js,使用命令 npm install tiff.js安装。 首先,引入相关库 import axios from axios; import { ref } from vue; import {TIFF } from tiff.js 在v…...
玩转 Java 与 Python 交互,JEP 库来助力
文章目录 玩转 Java 与 Python 交互,JEP 库来助力一、背景介绍二、JEP 库是什么?三、如何安装 JEP 库?四、JEP 库的简单使用方法五、JEP 库的实际应用场景场景 1:数据处理场景 2:机器学习场景 3:科学计算场…...
【单片机毕业设计14-基于stm32c8t6的智能宠物养护舱系统设计】
【单片机毕业设计14-基于stm32c8t6的智能宠物养护舱系统设计】 前言一、功能介绍二、硬件部分三、软件部分总结 前言 🔥这里是小殷学长,单片机毕业设计篇14-基于stm32c8t6的智能宠物养护舱系统设计 🧿创作不易,拒绝白嫖可私 一、功…...
ASUS/华硕天选4 Plus 锐龙版 FA507X FA707X 原厂Win11 22H2专业版系统 工厂文件 带ASUS Recovery恢复
华硕工厂文件恢复系统 ,安装结束后带隐藏分区,带一键恢复,以及机器所有的驱动和软件。 支持型号:FA507XU FA507XV FA507XQ FA507XJ FA507XI, FA707XV, FA707XU, FA707XQ, FA707XJ, FA707XI, FA707XIN 系统版本:Windo…...
从头再来!社招找工作——算法题复习九:动态规划
从头再来!社招找工作——算法题复习九:动态规划 动态规划斐波那数列跳台阶跳台阶/爬楼梯最小花费跳台阶 最长公共子序列矩阵矩阵路线总数矩阵路线总数有障碍物矩阵的最小路径和三角形的最小路径和 买卖股票的最佳时机(T1天 / 当日不可卖&…...
检测服务端口是否开放的常用方法
检测服务端口是否开放的常用方法 文章目录 检测服务端口是否开放的常用方法背景使用nc命令使用 telnet 命令使用 curl 命令使用 openssl 命令使用 Python 脚本,socket连接使用 bash 内建命令:使用 nmap:总结 背景 有时候需要测试网络是否连通,端口是否开放…...
23贪心算法
分发饼干 class Solution { public:int findContentChildren(vector<int>& g, vector<int>& s) {int i0,j0;int count0;sort(s.begin(),s.end());sort(g.begin(),g.end());while(i<g.size()&&j<s.size()){if(g[i]<s[j]){i;j;count;}else…...
网站快速收录:如何优化网站404页面?
优化网站404页面是提升用户体验和SEO效果的重要一环。以下是一些优化404页面的建议: 一、设计友好的404页面 简洁明了的提示信息:使用清晰的语言告诉用户该页面不存在或已被删除,避免使用过于技术化的术语。 提供导航链接:在40…...
DevEco Studio常用快捷键以及如何跟AndroidStudio的保持同步
DevEco Studio快捷键 DevEco Studio是华为推出的用于开发HarmonyOS应用的集成开发环境,它提供了丰富的快捷键以提高开发效率,以下为你详细介绍不同操作场景下的常用快捷键: 通用操作快捷键 操作描述Windows/Linux 快捷键Mac 快捷键打开设置窗…...
