当前位置: 首页 > news >正文

LLM - Make Causal Mask 构造因果关系掩码

目录

一.引言

二.make_causal_mask

1.完整代码

2.Torch.full

3.torch.view

4.torch.masked_fill_

5.past_key_values_length

6.Test Main

三.总结


一.引言

Causal Mask 主要用于限定模型的可视范围,防止模型看到未来的数据。在具体应用中,Causal Mask 可将所有未来的 token 设置为零,从注意力机制中屏蔽掉这些令牌,使得模型在进行预测时只能关注过去和当前的 token,并确保模型仅基于每个时间步骤可用的信息进行预测。

在 Transformer 模型中,Multihead Attention 中的 Causal Mask 就是用于解决这个问题,以实现模型对于输入序列的正确处理。下面是 Causal 的可视化示例,在实践中其呈现倒三角形状:

全文为 'I love eating lunch.' ,对于 'love' 而言其只能看到 'I',不能看到未来的 'eating'、'lunch'。

二.make_causal_mask

为了方便后续示例的展示,这里选择较小的参数,batch_size = 2,target_length = 4。

1.完整代码

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):"""Make causal mask used for bi-directional self-attention."""bsz, tgt_len = input_ids_shapemask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)mask_cond = torch.arange(mask.size(-1), device=device)mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)mask = mask.to(dtype)if past_key_values_length > 0:mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

代码的 Input 主要就是一个二维的 input_ids_shape,分别为 batch_size 和 target_length,dtype 和 device 在这里比较好理解,还有就是最后的 past_key_values_length,用于补齐,这个也比较简单。Output 则是 (batch_size, 1, target_length, target_length) 的 Causal Mask,其中 Msak 的矩阵 target_length x target_length 就是上面所示的倒三角形状。

2.Torch.full

函数介绍

该函数用于创建一个具有指定填充值的新张量。该函数的语法如下:

torch.full(size, fill_value, *, dtype=None, device=None, requires_grad=False)

参数说明:

  • size:张量的形状,可以是一个整数或者一个元组,例如:(3, 3) 或 3。
  • fill_value:张量的填充值。
  • dtype:张量的数据类型,默认为None,即根据输入的数据类型推断。
  • device:张量所在的设备,默认为None,即根据输入的设备推断。
  • requires_grad:是否需要计算梯度,默认为False。

该函数返回一个与指定形状相同且所有元素都被设置为指定填充值的新张量。

函数使用

mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))

这里 torch.finfo(dtype).min 为对应 torch.dtype 类型的最小值,以 bfloat16 为例:

print(torch.tensor(torch.finfo(dtype).min))
=> tensor(-3.3895e+38)

而这一步 mask 的操作就是生成一个 tgt_len x tgt_len 的充满 min 元素的方阵:

3.torch.view

函数介绍

在 PyTorch 库中,view 函数用于改变一个张量(Tensor)的形状(shape)。它返回一个新的张量,其元素与原始张量相同,但形状(shape)已被改变。view 函数的行为非常类似于 NumPy的 reshape 函数。它会返回一个与原始张量共享数据但具有不同形状的新的张量。如果给定的形状与原始张量的元素总数不匹配,则会引发错误。

import torch  x = torch.randn(4, 5)  # 创建一个4x5的随机张量  
y = x.view(20)  # 改变形状为20的一维张量  
z = x.view(-1, 10)  # 改变形状为10的一维张量,第一维度由其他维度决定

函数使用

mask_cond = torch.arange(mask.size(-1))

mask_cond 是一个1维向量:

(mask_cond + 1).view(mask.size(-1), 1)

这一步相当于在 mask_cond 基础上先加常量再 reshape:

4.torch.masked_fill_

函数介绍

在 PyTorch 库中,masked_fill_() 函数是一个张量(Tensor)方法,用于将张量中的指定区域填充为特定值。此函数需要一个掩码(mask)作为输入,该掩码应与原张量具有相同的形状。掩码中的 True 值表示需要填充的区域,False 值表示需要保留的原始值。

torch.Tensor.masked_fill_(mask, value)

参数说明: 

  • mask (Bool tensor) - 掩码张量,用于指定需要填充的区域。
  • value (float) - 填充的值。

示例

假设我们有一个 3x3 的张量,我们想要将所有大于 5 的元素替换为 -1。我们可以使用该函数来实现这个目标。

import torch  # 创建一个3x3的张量  
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # 创建一个掩码,其中大于5的元素为True,其余为False  
mask = x > 5  # 使用masked_fill_函数将大于5的元素替换为-1  
x.masked_fill_(mask, -1)  print(x)

输出:

tensor([[ 1,  2,  3],  [ 4,  5,  6],  [-1, -1, -1]])

函数使用

mask_cond < (mask_cond + 1).view(mask.size(-1), 1)

传入函数的 mask 如下,呈倒三角形态,其中 True 的部分填充新值,False 部分保持不变: 

mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

 根据 mask 对 target x target 的方阵进行填充 0 得到我们上面提到的倒三角:

5.past_key_values_length

在 PyTorch 中,past_key_values_length 是一个参数,用于指定在使用 Transformer 模型时,过去键值缓存(past key-value cache)的长度。该参数通常与 Transformer 模型中的自注意力机制(self-attention mechanism)一起使用。在过去键值缓存中,模型保存了过去的键和值向量,以便在生成序列时重复使用它们。这些过去的键和值向量可以用于计算自注意力分数,从而提高生成序列的效率。较大的past_key_values_length可以增加模型的表现力,但也会增加计算量和内存消耗。因此,需要根据具体任务和资源限制来选择合适的值。

if past_key_values_length > 0:mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)

这里定义 past_key_values_length = 1,代码逻辑就是在原有的 tgt x tgt 方阵前补 past_key_values_length 个 0:

6.Test Main

mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
if __name__ == '__main__':batch_size = 2target_length = 4input_shape = (batch_size, target_length)data_type = torch.bfloat16causal_mask = _make_causal_mask(input_shape, data_type, 1)print(causal_mask)print(causal_mask.shape)

=>        pask_key_length = 1填充后,tgt x tgt 变为 tgt x (1 + tgt)

=>        通过 None + expand 的组合,tgt x (1 + tgt) 变为 bsz x 1 x tgt x (1 + tgt)

三.总结

新系列博文一方面是阅读 HF 上 LLM 模型实现的源码,了解对应知识的实现过程。另一方面是之前很多同学主要接触 TF 1.x、TF 2.x 以及 Estimator 和 Keras 这一类深度学习工具,趁此机会也能熟悉 Torch 的使用方法。

相关文章:

LLM - Make Causal Mask 构造因果关系掩码

目录 一.引言 二.make_causal_mask 1.完整代码 2.Torch.full 3.torch.view 4.torch.masked_fill_ 5.past_key_values_length 6.Test Main 三.总结 一.引言 Causal Mask 主要用于限定模型的可视范围&#xff0c;防止模型看到未来的数据。在具体应用中&#xff0c;Caus…...

Python函数式编程(一)概念和itertools

Python函数式编程是一种编程范式&#xff0c;它强调使用纯函数来处理数据。函数是程序的基本构建块&#xff0c;并且尽可能避免或最小化可变状态和副作用。在函数式编程中&#xff0c;函数被视为一等公民&#xff0c;可以像值一样传递和存储。 函数式编程概念 编程语言支持通…...

Guava限流器原理浅析

文章目录 基本知识限流器的类图使用示例 原理解析限流整体流程问题驱动1、限流器创建的时候会初始化令牌吗&#xff1f;2、令牌是如何放到桶里的&#xff1f;3、如果要获取的令牌数大于桶里的令牌数会怎么样4、令牌数量的更新会有并发问题吗 总结 实际工作中难免有限流的场景。…...

第四十二章 持久对象和SQL - 用于创建持久类和表的选项

文章目录 第四十二章 持久对象和SQL - 用于创建持久类和表的选项用于创建持久类和表的选项访问数据 第四十二章 持久对象和SQL - 用于创建持久类和表的选项 用于创建持久类和表的选项 要创建持久类及其对应的 SQL 表&#xff0c;可以执行以下任一操作&#xff1a; 使用 IDE …...

集合-ArrayList源码分析(面试)

系列文章目录 1.集合-Collection-CSDN博客​​​​​​ 2.集合-List集合-CSDN博客 3.集合-ArrayList源码分析(面试)_喜欢吃animal milk的博客-CSDN博客 目录 系列文章目录 前言 一 . 什么是ArrayList? 二 . ArrayList集合底层原理 总结 前言 大家好,今天给大家讲一下Arra…...

跨类型文本文件,反序列化与类型转换的思考

文章目录 应用场景序列化 - 对象替换原内容&#xff0c;方便使用编写程序取得结果数组 序列化 - JSON 应用场景 在编写热更新的时候&#xff0c;我发现了一个古早的 ini 文件&#xff0c;记录了许多有用的数据 由于使用的语言年份较新&#xff0c;没有办法较好地对 ini 文件的…...

ubuntu20安装nvidia驱动

1. 查看显卡型号 lspci | grep -i nvidia 我的输出&#xff1a; 01:00.0 VGA compatible controller: NVIDIA Corporation GP104 [GeForce GTX 1080] (rev a1) 01:00.1 Audio device: NVIDIA Corporation GP104 High Definition Audio Controller (rev a1) 07:00.0 VGA comp…...

gma 2 成书计划

随着 gma 2 整体构建完成。下一步计划针对库内所有功能完成一个用户指南&#xff08;非网站&#xff09;。 封皮 主要章节 章节完成度相关链接第 1 章 GMA 概述已完成第 2 章 地理空间数据操作已完成第 3 章 坐标参考系统已完成第 4 章 地理空间制图已完成第 5 章 数学运算模…...

从零手搓一个【消息队列】项目设计、需求分析、模块划分、目录结构

文章目录 一、需求分析1, 项目简介2, BrokerServer 核心概念3, BrokerServer 提供的核心 API4, 交换机类型5, 持久化存储6, 网络通信7, TCP 连接的复用8, 需求分析小结 二、模块划分三、目录结构 提示&#xff1a;是正在努力进步的小菜鸟一只&#xff0c;如有大佬发现文章欠佳之…...

【Spring Cloud】深入探索 Nacos 注册中心的原理,服务的注册与发现,服务分层模型,负载均衡策略,微服务的权重设置,环境隔离

文章目录 前言一、初识 Nacos 注册中心1.1 什么是 Nacos1.2 Nacos 的安装&#xff0c;配置&#xff0c;启动 二、服务的注册与发现三、Nacos 服务分层模型3.1 Nacos 的服务分级存储模型3.2 服务跨集群调用问题3.3 服务集群属性设置3.4 修改负载均衡策略为集群策略 四、根据服务…...

No156.精选前端面试题,享受每天的挑战和学习

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云课上架的前后端实战课程《Vue.js 和 Egg.js 开发企业级健康管理项目》、《带你从入…...

如何在PIL图像和PyTorch Tensor之间进行相互转换,使用pytorch进行PIL和tensor之间的数据转换

目录 引言PIL简介PyTorch和Torchvision简介PIL转换为TensorTensor转换为PIL实例代码和解释结论参考文献 &#x1f4dd; 引言 在计算机视觉领域&#xff0c;使用图像处理库对图像进行预处理是非常常见的。其中&#xff0c;Python Imaging Library&#xff08;PIL&#xff09;以…...

STM32F4X UCOSIII任务消息队列

STM32F4X UCOSIII任务消息队列 任务消息队列和内核消息队列对比内核消息队列内核消息队列 UCOSIII任务消息队列API任务消息队列发送函数任务消息队列接收函数 UCOSIII任务消息队列例程 之前的章节中讲解过消息队列这个机制&#xff0c;UCOSIII除了有内核消息队列之外&#xff0…...

8个居家兼职,帮助自己在家搞副业

越来越多的人开始追求居家工作的机会&#xff0c;无论是为了获得更多收入以改善生活质量&#xff0c;还是为了更好地平衡工作和家庭的关系&#xff0c;居家兼职已成为一种趋势。而在家中从事副业不仅能够为我们带来额外的收入&#xff0c;更重要的是&#xff0c;它可以让我们在…...

管理与系统思维

技术管理者不仅仅需要做事情&#xff0c;还需要以系统思维的方式推动组织变革&#xff0c;从而帮助团队和个人做到更好。原文: Management and Systems Thinking 图片来源: Dall-E "除非管理者考虑到组织的系统性&#xff0c;否则大多数提高绩效的努力都将注定失败。"…...

电死人的是电流还是电压?

先说答案&#xff0c;是电流。 这个有两个派别&#xff0c;一个是电流派&#xff0c;一个是电压派。 举个例子&#xff0c;拿我们的头发或者指甲之类的高电阻物质去接触高压&#xff0c;你会发现基本没有什么作用&#xff1b;还有就是冬天我们脱毛衣的时候&#xff0c;噼里啪啦…...

mac 编译问题记录

1、mac 编译提示 Unsupported option ‘--no-pie‘ Linux 上用 --no-pie mac 上用 -no-pie 2、mac 找不到 malloc.h 使用 #include <sys/malloc.h> Mac上使用malloc函数报错_mac malloc.h-CSDN博客...

centos 7.9同时安装JDK1.8和openjdk11两个版本

1.使用的原因 在服务器上&#xff0c;有些情况因为有一些系统比较老&#xff0c;所以需要使用JDK8版本&#xff0c;但随着时间的发展&#xff0c;新的软件出来&#xff0c;一般都会使用比较新的JDK版本。所以就出现了我们标题的需求&#xff0c;一个系统内同时安装两个不同的版…...

【JavaEE】HTML

JavaWeb HTML 超文本标记语言 超文本&#xff1a;文本、声音、图片、视频、表格、连接标记&#xff1a;有许许多多的标签组成 vscode开发工具搭建 因为我使用的IDEA是社区版&#xff0c;代码高亮补全缩进都有些问题&#xff0c;使用vscode是最好的选择~ 安装 Visual Stu…...

【数据结构--八大排序】之堆排序

&#x1f490; &#x1f338; &#x1f337; &#x1f340; &#x1f339; &#x1f33b; &#x1f33a; &#x1f341; &#x1f343; &#x1f342; &#x1f33f; &#x1f344;&#x1f35d; &#x1f35b; &#x1f364; &#x1f4c3;个人主页 &#xff1a;阿然成长日记 …...

Antares LoRaWAN库深度解析:嵌入式LoRaWAN MAC层实现指南

1. Antares LoRaWAN 库深度技术解析&#xff1a;面向嵌入式工程师的 LoRaWAN MAC 层实现指南 1.1 库定位与工程价值 Antares LoRaWAN 是一个专为 Arduino 生态设计的轻量级 LoRaWAN MAC 层实现库&#xff0c;其核心价值不在于功能堆砌&#xff0c;而在于 可理解性、可调试性与…...

Halcon HImage转Bitmap性能大比拼:实测unsafe方案比安全方案快30倍的背后原因

Halcon HImage转Bitmap性能优化实战&#xff1a;从30倍差距到工业级解决方案 在工业视觉检测和实时图像处理领域&#xff0c;毫秒级的性能差异可能意味着生产线能否稳定运行。最近在为一个汽车零部件检测系统做性能优化时&#xff0c;我意外发现Halcon的HImage转Bitmap操作竟成…...

Windows下BERTopic安装避坑指南:解决hdbscan报错(附Python 3.8环境配置)

Windows下BERTopic安装避坑指南&#xff1a;解决hdbscan报错&#xff08;附Python 3.8环境配置&#xff09; 第一次在Windows上安装BERTopic时&#xff0c;那个红色的hdbscan报错信息让我盯着屏幕发了十分钟呆。作为一款强大的主题建模工具&#xff0c;BERTopic的安装本不该如此…...

Edge浏览器专属:B站直播实时字幕插件开发全记录(附源码下载)

Edge浏览器实现B站直播实时字幕的技术解析与实战 作为一名长期关注Web语音技术的开发者&#xff0c;我最近在Edge浏览器上成功实现了一个B站直播实时字幕插件。这个项目的核心价值在于解决了无字幕直播场景下的信息获取难题——根据用户反馈&#xff0c;超过68%的观众会在没有字…...

新手福音:利用快马平台生成你的第一个数学公式编辑器入门项目

最近在自学前端开发&#xff0c;一直想尝试做个数学公式编辑器来练手。作为一个完全的新手&#xff0c;从零开始写这种项目确实有点无从下手。不过我发现用InsCode(快马)平台可以很轻松地生成基础代码框架&#xff0c;再根据自己的需求调整完善&#xff0c;特别适合像我这样的初…...

深入解析FOC电机控制:从理论到实践的无传感器实现

1. 无传感器FOC控制的核心原理 磁场定向控制&#xff08;FOC&#xff09;本质上是在模拟直流电机的控制方式。想象一下小时候玩的四驱车——直流电机通过改变电压就能直接控制转速&#xff0c;简单粗暴。但三相交流电机就像个傲娇的艺术家&#xff0c;需要我们把三相电流"…...

P15801 [GESP202603 六级] 完全二叉树

[GESP202603 六级] 完全二叉树 https://www.bilibili.com/video/BV1jQAEz3Eir/ 1.4满二叉树与完全二叉树 https://www.bilibili.com/video/BV1T44y1P7Xx/ 数据结构合集 - 二叉树&完全二叉树(定义, 性质) https://www.bilibili.com/video/BV1eQ3RzxEoS/ 202603GESP六级C第2题…...

Mermaid在线编辑器:开源可视化工具的图表创作革命

Mermaid在线编辑器&#xff1a;开源可视化工具的图表创作革命 【免费下载链接】mermaid-live-editor Edit, preview and share mermaid charts/diagrams. New implementation of the live editor. 项目地址: https://gitcode.com/GitHub_Trending/me/mermaid-live-editor …...

开源编解码工具技术选型与实战指南:跨场景应用的H.264解决方案

开源编解码工具技术选型与实战指南&#xff1a;跨场景应用的H.264解决方案 【免费下载链接】openh264 Open Source H.264 Codec 项目地址: https://gitcode.com/gh_mirrors/op/openh264 一、价值定位&#xff1a;为什么开源编解码工具是技术选型的最优解 在视频技术快…...

Atomics探究(四)-- atomic flag

本篇将研究atomic_flag相关函数底层汇编指令,以及与其他原子操作函数进行比较,探讨其存在的意义。 1、标准描述: 2、定义 gcc 头文件中定义如下 typedef _Atomic struct { #if __GCC_ATOMIC_TEST_AND_SET_TRUEVAL == 1_Bool __val; #elseunsigned char __val; #endif } at…...