当前位置: 首页 > news >正文

【Pytorch】大语言模型中的CrossEntropyLoss

文章目录

  • 前言
  • 什么是CrossEntropyLoss
  • 语言模型中的CrossEntropyLoss
    • 计算loss的前期准备
    • CrossEntropyLoss的输入
    • CrossEntropyLoss的输出
  • 额外说明

前言

在大语言模型时代,我们常常使用交叉熵损失函数来计算loss,因此,理解该loss的计算流程有助于帮助我们对训练过程有更清晰的认知。本文从以下几个角度介绍nn.CrossEntropyLoss()

  • 使用该函数的前期准备:如何组织函数的输入(logits & labels)
  • 该函数流程
  • 常用参数
  • 该文章内容仅为个人理解,如有误解,欢迎讨论

什么是CrossEntropyLoss

这部分并不是本文的重点,我们仅介绍在语言模型的训练过程中,如何利用该loss

  • 相关信息可见:本人博客
  • 以及官网:CrossEntropyLoss官网

语言模型中的CrossEntropyLoss

计算loss的前期准备

huggingface-transformers源码中,我们在语言模型的forward中总是能看到这样一段函数。我们以LlamaForCausalLM为例:Llama源码

if labels is not None:# Shift so that tokens < n predict nshift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# Flatten the tokensloss_fct = CrossEntropyLoss()shift_logits = shift_logits.view(-1, self.config.vocab_size)shift_labels = shift_labels.view(-1)# Enable model parallelismshift_labels = shift_labels.to(shift_logits.device)loss = loss_fct(shift_logits, shift_labels)if not return_dict:output = (logits,) + outputs[1:]return (loss,) + output if loss is not None else output

对于Decoder-only模型,在训练时,我们的目标是next token prediction,任务流程如下

  • 假定我们是常规的问答任务,问题是“where is the capital of China“,label为“The capital is Beijing”。该任务的目标为,当输入为“where is the capital of China“时,

  • 我们对question和label进行拼接和tokenize化,一般转化结果 (tokenize忽略) 为:< bos > where is the capital of China < sep > The capital is Beijing < eos >

    • < bos>为句子开头的标志
    • < sep>用于分隔question和label,本质作用是,当模型看到时就知道:问题结束了,下一个token要输出答案了
    • < eos>为生成结束的标志
    • 假定每个词算一个token (忽略空格),那么输入一共有13个token
  • 这时我们将整个序列输入到模型中,模型在每个token的位置都生成一个向量,我们利用lm_head将最后一层的hidden state转化成词表大小的向量logits,用于后续利用Softmax确定每个token的概率

  • 现在模型有了输出logits,怎么计算loss?

    • 对比labels和logits之间的差异来计算loss

    • 现在一共有13个token,生成了13个logits,每个logits都是用于生成next token的。那么很直接的,我们来对比该logits生成的next token准不准就好了

      • 输入:< bos> where is the capital of China < sep> The capital is Beijing < eos>

      • 对比情况为:< sep>->The, The->capital, …, is->Beijing, Beijing->< eos>

        • < sep>对应位置要生成The,…, Beijing对应位置要输出< eos>
      • 我们可以将输入右移一位作为labels: where is the capital of China < sep> The capital is Beijing

        • 可以看到,对于输入来说, < eos>位置没有对应的需要生成的token,因此我们去掉该token
        • 对于labels,< bos>不需要生成,因此我们去掉该token
      • 因此,我们在计算loss时,对logits去尾,labels是输入掐头且右移一位

      • 在代码中对应

          shift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()
        

CrossEntropyLoss的输入

此时还不能直接将shift_logitsshift_labels进行对比,来计算loss。因为我们上面的操作只是为了<sep> The capital is BeijingThe capital is Beijing <eos>中的token能一一对应起来,对于其他部分生成的token,我们并没有要求(因为不是answer,不需要生成)

  • CrossEntropyLoss函数中有一个参数为ignore_idx默认值为-100。labels值设置为-100的位置不会计算loss
  • 因此我们将除了需要计算loss的位置 (最后5个位置)的labels都设置为-100
  • 最终,需要输入到CrossEntropyLoss中的inputs和labels为
    • inputs为: [, where, is, the, capital, of, China, < sep>, The, capital, is, Beijing]对应的logits
      • 注意:不需要进行Softmax,直接传logits即可,函数内部有更稳定的Softmax计算方式
    • labels为: [-100, -100, -100, -100, -100, -100, -100, The, capital, is, Beijing, < eos>]
    • 我们在训练时,构造输入和labels要注意构造为这种形式

CrossEntropyLoss的输出

默认情况下,输出为mean,即各个token计算得到loss的平均值(在token-level上平均,分母是token的个数)

import torch
import torch.nn as nn# 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3)
logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]])# 标签,其中第二个样本的标签为 ignore_index (-100)
labels = torch.tensor([0, -100, 2])# 定义 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()# 计算损失
loss = criterion(logits, labels)print(f"Loss: {loss}")
>>> Loss: 0.51058030128479
  • 常用参数:

    • reduction:控制loss的输出形式,共三种'none', 'mean', 'sum',默认为'mean'

      • mean: 每个token计算得到的loss的平均值

      • none: 直接返回每个token计算得到的loss

        • 例子:

          import torch
          import torch.nn as nn# 假设有 3 个类,logits 形状为 (batch_size=3, num_classes=3)
          logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3], [1.5, 0.5, 2.0]])# 标签,其中第二个样本的标签为 ignore_index (-100)
          labels = torch.tensor([0, -100, 2])# 定义 CrossEntropyLoss
          criterion = nn.CrossEntropyLoss(reduction='none')# 计算损失
          loss = criterion(logits, labels)print(f"Loss: {loss}")
          >>> Loss: tensor([0.4170, 0.0000, 0.6041])
          
      • sum: 所有token对应loss求和

额外说明

对最上面的代码补充说明

  shift_logits = shift_logits.view(-1, self.config.vocab_size)shift_labels = shift_labels.view(-1)
  • 训练数据往往是按batch组织的,shape为(batch_size, seq_len, vocab_size)
  • 我们将所有batch的token压缩为一个序列,计算整个序列的loss,这样比较方便

相关文章:

【Pytorch】大语言模型中的CrossEntropyLoss

文章目录 前言什么是CrossEntropyLoss语言模型中的CrossEntropyLoss计算loss的前期准备CrossEntropyLoss的输入CrossEntropyLoss的输出 额外说明 前言 在大语言模型时代&#xff0c;我们常常使用交叉熵损失函数来计算loss&#xff0c;因此&#xff0c;理解该loss的计算流程有助…...

安全热点问题

安全热点问题 1.DDOS2.补丁管理3.堡垒机管理4.加密机管理 1.DDOS 分布式拒绝服务攻击&#xff0c;是指黑客通过控制由多个肉鸡或服务器组成的僵尸网络&#xff0c;向目标发送大量看似合法的请求&#xff0c;从而占用大量网络资源使网络瘫痪&#xff0c;阻止用户对网络资源的正…...

C++——用选择法对10个数值进行排序。

没注释的源代码 #include <iostream> using namespace std; int main() { int i,j,min,a[11],temp; cout<<"请输入数组a的十个值&#xff1a;"<<endl; for(i1;i<10;i) { cin>>a[i]; } for(i1;i<9;…...

CSP-CCF★★★201909-2小明种苹果(续)★★★

一、问题描述 二、解答 关键&#xff1a;判断是否发生苹果掉落&#xff0c;使用flag[]数组来标记&#xff0c;1为掉落&#xff0c;0为没有掉落&#xff0c;这样也是为了后续比较连续三棵树是否掉落 误区&#xff1a;用最后一次正数&#xff08;即最后一次统计苹果个数&#x…...

硬件工程师笔试面试——变压器

目录 9、变压器 9.1 基础 变压器原理图 变压器实物图 9.1.1 概念 9.1.2 变压器组成结构 9.1.3 变压器原理 9.1.4 变压器的类型 9.1.5 应用领域 9.2 相关问题 9.2.1 变压器的工作原理是什么? 9.2.2 如何选择合适的变压器类型? 9.2.3 变压器在实际应用中,如何进行…...

Visual Studio Code( VS Code)倍速提高编程工作效率的免费的源代码编辑器

耕耘于编程二十多年&#xff0c;后端、前端、操作系统、数据库、脚本都做过&#xff0c;各种各样的编程工具&#xff0c;IDE开发环境都用过&#xff0c;但是让我感觉比较好用、容易上手、能够提高工作效率的开发工具还是VS Code&#xff0c;下面我就简单的介绍一下这个广泛使用…...

华为SMU02B1智能通信电源监控单元模块简介

华为SMU02B1是一款智能通信电源监控单元模块&#xff0c;专为5G嵌入式机框设计&#xff0c;它在通信电源管理领域扮演着重要角色。以下是对该产品的详细介绍&#xff1a; 一、产品概述 主要功能&#xff1a;华为SMU02B1能够监控和管理通信电源系统&#xff0c;提供站点监控功能…...

【刷题日记】15. 三数之和

15. 三数之和 两数之和可以用巧思也可以用map 三数之和会更加复杂一点&#xff0c;且这道题还需要考虑避免重复答案&#xff01; 思路&#xff1a; 特判&#xff1a;检如果nums 为 null 或长度小于 3直接返回空数组。排序&#xff1a;使用 sort对数组进行升序排序。就变成了…...

低级编程语言和高级编程语言

一.区分低级编程语言和高级编程语言的方法 1.低级编程语言 低级编程语言,并不是简单的编程语言,而是写起来很费事的编程语言,如所有编程语言的"祖宗":汇编语言,写起来极其麻烦,说不定一个 int a1; 它就得写好几行,甚至十几行 这样麻烦的编程语言为什么还没消失那,因…...

Spring Boot-API网关问题

****### Spring Boot API 网关问题分析与解决方案 在微服务架构中&#xff0c;API 网关扮演着非常重要的角色。它位于客户端和微服务之间&#xff0c;充当所有外部请求的入口&#xff0c;负责请求的路由、聚合、鉴权、限流等功能。Spring Boot 提供了多种方式实现 API 网关&am…...

三 auto占位符

3.1 重新定义的auto关键字 1.当用一个auto关键字声明多个变量的时候&#xff0c;编译器遵从由左往右的推导规则&#xff0c;以最左边的表达式推断auto的具体类型 int n 5; auto *pn &n, m 10;// 这里auto被推导为 int 所以int m 10;合理 auto *pns &n, m 10.0;/…...

tail: inotify 资源耗尽

解决方法&#xff1a; 增加可用的 inotify 监视器数量。可以通过修改系统配置文件来增加监视器数量限制。 临时增加&#xff08;直到下次重启&#xff09;&#xff1a;执行 echo 1048576 | sudo tee -a /proc/sys/fs/inotify/max_user_instances 和 echo 65536 | sudo tee -a /…...

什么是损失函数?常见的损失函数有哪些?

损失函数 什么是损失函数&#xff1f;损失函数作用如何设计损失函数常见的损失函数有哪些&#xff1f; 什么是损失函数&#xff1f; 损失函数&#xff08;Loss Function&#xff09;&#xff0c;也称为误差函数&#xff0c;是机器学习和深度学习中的一个重要概念。它用于衡量模…...

Python Web 开发中的国际化与本地化处理

Python Web 开发中的国际化与本地化处理 目录 &#x1f30d; Flask中的国际化与本地化处理&#x1f310; Django中的国际化与本地化处理&#x1f5e3;️ 多语言支持与翻译系统实现&#x1f552; 时区和日期的本地化处理 1. &#x1f30d; Flask中的国际化与本地化处理 Flask…...

android API、SDK与android版本

随着 Android 系统的不断更新&#xff0c;API Level 也会随之增加。每个新的 API Level 都引入了新的功能、改进旧的功能&#xff0c;或者弃用了旧的 API。开发者在开发应用时&#xff0c;需要指定目标 API Level&#xff0c;也就是应用最低支持的 Android 版本。 API Level 与…...

OpenHarmony(鸿蒙南向开发)——小型系统内核(LiteOS-A)【内核通信机制】下

往期知识点记录&#xff1a; 鸿蒙&#xff08;HarmonyOS&#xff09;应用层开发&#xff08;北向&#xff09;知识点汇总 鸿蒙&#xff08;OpenHarmony&#xff09;南向开发保姆级知识点汇总~ 子系统开发内核 轻量系统内核&#xff08;LiteOS-M&#xff09; 轻量系统内核&#…...

如何联系真正的开发者而非公司??

&#x1f3c6;本文收录于《全栈Bug调优(实战版)》专栏&#xff0c;主要记录项目实战过程中所遇到的Bug或因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&am…...

OpenCV运动分析和目标跟踪(1)累积操作函数accumulate()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 将一个图像添加到累积图像中。 该函数将 src 或其部分元素添加到 dst 中&#xff1a; dst ( x , y ) ← dst ( x , y ) src ( x , y ) if mask…...

source ~/.bash_profile有什么用

source ~/.bash_profile 是在 Unix/Linux 系统上用来重新加载用户的 Bash 配置文件 ~/.bash_profile 的命令。这条命令的作用是使得当前的 Bash 环境重新读取并应用 ~/.bash_profile 中的设置和变量定义。 作用&#xff1a; 1. 更新环境变量&#xff1a; ~/.bash_profile 是用户…...

【C++笔记】类和对象的深入理解(三)

【C笔记】类和对象的深入理解(三) &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;C笔记 文章目录 【C笔记】类和对象的深入理解(三)前言一.日期类的实现1.1声明和定义分离1.2日期类整数1.3日期类整数1.4日期类-整数1.5日期类-日期1.6复用对…...

万象视界灵坛实操手册:使用Prometheus+Grafana监控CLIP推理延迟、GPU利用率、QPS指标

万象视界灵坛实操手册&#xff1a;使用PrometheusGrafana监控CLIP推理延迟、GPU利用率、QPS指标 1. 监控系统概述 在现代AI应用部署中&#xff0c;实时监控系统性能指标是确保服务稳定运行的关键。对于万象视界灵坛这样的多模态智能感知平台&#xff0c;我们需要重点关注三个…...

M5Stamp C3 Mate LED驱动库:基于RMT的WS2812B精简控制方案

1. 项目概述M5StampC3LED 是专为 M5Stamp C3 Mate 模块设计的 LED 控制库&#xff0c;其本质是一个轻量级封装层&#xff0c;用于驱动板载的 Adafruit NeoPixel&#xff08;WS2812B 兼容&#xff09;RGB LED。该库不直接实现底层时序协议&#xff0c;而是基于 ESP-IDF 或 Ardui…...

大数据标注工具对比:2023年最值得推荐的5款工具

大数据标注工具对比&#xff1a;2023年最值得推荐的5款工具关键词&#xff1a;大数据标注工具、2023年推荐、工具对比、标注效率、标注质量摘要&#xff1a;本文聚焦于2023年大数据标注领域&#xff0c;详细对比了五款极具代表性的大数据标注工具。通过对它们的核心概念、算法原…...

TX12 + ExpressLRS 915MHz RC链路优化与EdgeTX固件升级实战

1. 为什么选择TX12搭配ExpressLRS 915MHz系统 玩无人机的朋友都知道&#xff0c;遥控链路就像风筝线&#xff0c;距离和稳定性直接决定飞行体验。我之前用2.4GHz的RadioLink套装&#xff0c;飞到500米就开始心跳加速——信号时断时续&#xff0c;每次返航都像在赌运气。换成TX1…...

TouchGal终极指南:一站式Galgame社区如何让玩家找到纯净交流空间

TouchGal终极指南&#xff1a;一站式Galgame社区如何让玩家找到纯净交流空间 【免费下载链接】kun-touchgal-next TouchGAL是立足于分享快乐的一站式Galgame文化社区, 为Gal爱好者提供一片净土! 项目地址: https://gitcode.com/gh_mirrors/ku/kun-touchgal-next 你是否曾…...

LinkSwift:重新定义网盘下载体验的八大平台直链解析工具

LinkSwift&#xff1a;重新定义网盘下载体验的八大平台直链解析工具 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 &#xff0c;支持 百度网盘 / 阿里云盘 / 中国移动云盘 / 天…...

Radiant Player媒体键集成:揭秘硬件控制背后的技术

Radiant Player媒体键集成&#xff1a;揭秘硬件控制背后的技术 【免费下载链接】radiant-player-mac :notes: Turn Google Play Music into a separate, beautiful application that integrates with your Mac. 项目地址: https://gitcode.com/gh_mirrors/ra/radiant-player-…...

安防弱电智能化VISIO图例实战指南:从入门到精通的设计技巧

1. VISIO在安防弱电设计中的核心价值 第一次接触安防弱电智能化设计时&#xff0c;我被各种复杂的系统连接关系搞得头晕眼花。直到发现VISIO这个神器&#xff0c;才真正体会到什么叫"一图胜千言"。不同于普通CAD软件&#xff0c;VISIO最大的优势在于它专为系统图设计…...

Minecraft启动器与游戏配置工具全攻略:从新手到大师的进阶指南

Minecraft启动器与游戏配置工具全攻略&#xff1a;从新手到大师的进阶指南 Minecraft启动器是每一位玩家进入方块世界的第一道门&#xff0c;而一款优秀的游戏配置工具则能让你的冒险之旅更加顺畅。本文将以玩家视角&#xff0c;带你深入了解如何利用PCL2-CE这款强大的开源工具…...

告别重复造轮子,用快马平台一键生成OpenClaw高效工具模块

最近在做一个机器人控制项目&#xff0c;需要集成OpenClaw机械爪模块。传统开发方式需要从零开始写大量重复代码&#xff0c;效率很低。后来尝试用InsCode(快马)平台生成核心模块&#xff0c;效果出乎意料的好。这里分享下具体实现思路和优化点&#xff1a; 安全初始化模块设计…...