当前位置: 首页 > news >正文

LLM - Make Causal Mask 构造因果关系掩码

目录

一.引言

二.make_causal_mask

1.完整代码

2.Torch.full

3.torch.view

4.torch.masked_fill_

5.past_key_values_length

6.Test Main

三.总结


一.引言

Causal Mask 主要用于限定模型的可视范围,防止模型看到未来的数据。在具体应用中,Causal Mask 可将所有未来的 token 设置为零,从注意力机制中屏蔽掉这些令牌,使得模型在进行预测时只能关注过去和当前的 token,并确保模型仅基于每个时间步骤可用的信息进行预测。

在 Transformer 模型中,Multihead Attention 中的 Causal Mask 就是用于解决这个问题,以实现模型对于输入序列的正确处理。下面是 Causal 的可视化示例,在实践中其呈现倒三角形状:

全文为 'I love eating lunch.' ,对于 'love' 而言其只能看到 'I',不能看到未来的 'eating'、'lunch'。

二.make_causal_mask

为了方便后续示例的展示,这里选择较小的参数,batch_size = 2,target_length = 4。

1.完整代码

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):"""Make causal mask used for bi-directional self-attention."""bsz, tgt_len = input_ids_shapemask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)mask_cond = torch.arange(mask.size(-1), device=device)mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)mask = mask.to(dtype)if past_key_values_length > 0:mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

代码的 Input 主要就是一个二维的 input_ids_shape,分别为 batch_size 和 target_length,dtype 和 device 在这里比较好理解,还有就是最后的 past_key_values_length,用于补齐,这个也比较简单。Output 则是 (batch_size, 1, target_length, target_length) 的 Causal Mask,其中 Msak 的矩阵 target_length x target_length 就是上面所示的倒三角形状。

2.Torch.full

函数介绍

该函数用于创建一个具有指定填充值的新张量。该函数的语法如下:

torch.full(size, fill_value, *, dtype=None, device=None, requires_grad=False)

参数说明:

  • size:张量的形状,可以是一个整数或者一个元组,例如:(3, 3) 或 3。
  • fill_value:张量的填充值。
  • dtype:张量的数据类型,默认为None,即根据输入的数据类型推断。
  • device:张量所在的设备,默认为None,即根据输入的设备推断。
  • requires_grad:是否需要计算梯度,默认为False。

该函数返回一个与指定形状相同且所有元素都被设置为指定填充值的新张量。

函数使用

mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))

这里 torch.finfo(dtype).min 为对应 torch.dtype 类型的最小值,以 bfloat16 为例:

print(torch.tensor(torch.finfo(dtype).min))
=> tensor(-3.3895e+38)

而这一步 mask 的操作就是生成一个 tgt_len x tgt_len 的充满 min 元素的方阵:

3.torch.view

函数介绍

在 PyTorch 库中,view 函数用于改变一个张量(Tensor)的形状(shape)。它返回一个新的张量,其元素与原始张量相同,但形状(shape)已被改变。view 函数的行为非常类似于 NumPy的 reshape 函数。它会返回一个与原始张量共享数据但具有不同形状的新的张量。如果给定的形状与原始张量的元素总数不匹配,则会引发错误。

import torch  x = torch.randn(4, 5)  # 创建一个4x5的随机张量  
y = x.view(20)  # 改变形状为20的一维张量  
z = x.view(-1, 10)  # 改变形状为10的一维张量,第一维度由其他维度决定

函数使用

mask_cond = torch.arange(mask.size(-1))

mask_cond 是一个1维向量:

(mask_cond + 1).view(mask.size(-1), 1)

这一步相当于在 mask_cond 基础上先加常量再 reshape:

4.torch.masked_fill_

函数介绍

在 PyTorch 库中,masked_fill_() 函数是一个张量(Tensor)方法,用于将张量中的指定区域填充为特定值。此函数需要一个掩码(mask)作为输入,该掩码应与原张量具有相同的形状。掩码中的 True 值表示需要填充的区域,False 值表示需要保留的原始值。

torch.Tensor.masked_fill_(mask, value)

参数说明: 

  • mask (Bool tensor) - 掩码张量,用于指定需要填充的区域。
  • value (float) - 填充的值。

示例

假设我们有一个 3x3 的张量,我们想要将所有大于 5 的元素替换为 -1。我们可以使用该函数来实现这个目标。

import torch  # 创建一个3x3的张量  
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # 创建一个掩码,其中大于5的元素为True,其余为False  
mask = x > 5  # 使用masked_fill_函数将大于5的元素替换为-1  
x.masked_fill_(mask, -1)  print(x)

输出:

tensor([[ 1,  2,  3],  [ 4,  5,  6],  [-1, -1, -1]])

函数使用

mask_cond < (mask_cond + 1).view(mask.size(-1), 1)

传入函数的 mask 如下,呈倒三角形态,其中 True 的部分填充新值,False 部分保持不变: 

mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

 根据 mask 对 target x target 的方阵进行填充 0 得到我们上面提到的倒三角:

5.past_key_values_length

在 PyTorch 中,past_key_values_length 是一个参数,用于指定在使用 Transformer 模型时,过去键值缓存(past key-value cache)的长度。该参数通常与 Transformer 模型中的自注意力机制(self-attention mechanism)一起使用。在过去键值缓存中,模型保存了过去的键和值向量,以便在生成序列时重复使用它们。这些过去的键和值向量可以用于计算自注意力分数,从而提高生成序列的效率。较大的past_key_values_length可以增加模型的表现力,但也会增加计算量和内存消耗。因此,需要根据具体任务和资源限制来选择合适的值。

if past_key_values_length > 0:mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)

这里定义 past_key_values_length = 1,代码逻辑就是在原有的 tgt x tgt 方阵前补 past_key_values_length 个 0:

6.Test Main

mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
if __name__ == '__main__':batch_size = 2target_length = 4input_shape = (batch_size, target_length)data_type = torch.bfloat16causal_mask = _make_causal_mask(input_shape, data_type, 1)print(causal_mask)print(causal_mask.shape)

=>        pask_key_length = 1填充后,tgt x tgt 变为 tgt x (1 + tgt)

=>        通过 None + expand 的组合,tgt x (1 + tgt) 变为 bsz x 1 x tgt x (1 + tgt)

三.总结

新系列博文一方面是阅读 HF 上 LLM 模型实现的源码,了解对应知识的实现过程。另一方面是之前很多同学主要接触 TF 1.x、TF 2.x 以及 Estimator 和 Keras 这一类深度学习工具,趁此机会也能熟悉 Torch 的使用方法。

相关文章:

LLM - Make Causal Mask 构造因果关系掩码

目录 一.引言 二.make_causal_mask 1.完整代码 2.Torch.full 3.torch.view 4.torch.masked_fill_ 5.past_key_values_length 6.Test Main 三.总结 一.引言 Causal Mask 主要用于限定模型的可视范围&#xff0c;防止模型看到未来的数据。在具体应用中&#xff0c;Caus…...

Python函数式编程(一)概念和itertools

Python函数式编程是一种编程范式&#xff0c;它强调使用纯函数来处理数据。函数是程序的基本构建块&#xff0c;并且尽可能避免或最小化可变状态和副作用。在函数式编程中&#xff0c;函数被视为一等公民&#xff0c;可以像值一样传递和存储。 函数式编程概念 编程语言支持通…...

Guava限流器原理浅析

文章目录 基本知识限流器的类图使用示例 原理解析限流整体流程问题驱动1、限流器创建的时候会初始化令牌吗&#xff1f;2、令牌是如何放到桶里的&#xff1f;3、如果要获取的令牌数大于桶里的令牌数会怎么样4、令牌数量的更新会有并发问题吗 总结 实际工作中难免有限流的场景。…...

第四十二章 持久对象和SQL - 用于创建持久类和表的选项

文章目录 第四十二章 持久对象和SQL - 用于创建持久类和表的选项用于创建持久类和表的选项访问数据 第四十二章 持久对象和SQL - 用于创建持久类和表的选项 用于创建持久类和表的选项 要创建持久类及其对应的 SQL 表&#xff0c;可以执行以下任一操作&#xff1a; 使用 IDE …...

集合-ArrayList源码分析(面试)

系列文章目录 1.集合-Collection-CSDN博客​​​​​​ 2.集合-List集合-CSDN博客 3.集合-ArrayList源码分析(面试)_喜欢吃animal milk的博客-CSDN博客 目录 系列文章目录 前言 一 . 什么是ArrayList? 二 . ArrayList集合底层原理 总结 前言 大家好,今天给大家讲一下Arra…...

跨类型文本文件,反序列化与类型转换的思考

文章目录 应用场景序列化 - 对象替换原内容&#xff0c;方便使用编写程序取得结果数组 序列化 - JSON 应用场景 在编写热更新的时候&#xff0c;我发现了一个古早的 ini 文件&#xff0c;记录了许多有用的数据 由于使用的语言年份较新&#xff0c;没有办法较好地对 ini 文件的…...

ubuntu20安装nvidia驱动

1. 查看显卡型号 lspci | grep -i nvidia 我的输出&#xff1a; 01:00.0 VGA compatible controller: NVIDIA Corporation GP104 [GeForce GTX 1080] (rev a1) 01:00.1 Audio device: NVIDIA Corporation GP104 High Definition Audio Controller (rev a1) 07:00.0 VGA comp…...

gma 2 成书计划

随着 gma 2 整体构建完成。下一步计划针对库内所有功能完成一个用户指南&#xff08;非网站&#xff09;。 封皮 主要章节 章节完成度相关链接第 1 章 GMA 概述已完成第 2 章 地理空间数据操作已完成第 3 章 坐标参考系统已完成第 4 章 地理空间制图已完成第 5 章 数学运算模…...

从零手搓一个【消息队列】项目设计、需求分析、模块划分、目录结构

文章目录 一、需求分析1, 项目简介2, BrokerServer 核心概念3, BrokerServer 提供的核心 API4, 交换机类型5, 持久化存储6, 网络通信7, TCP 连接的复用8, 需求分析小结 二、模块划分三、目录结构 提示&#xff1a;是正在努力进步的小菜鸟一只&#xff0c;如有大佬发现文章欠佳之…...

【Spring Cloud】深入探索 Nacos 注册中心的原理,服务的注册与发现,服务分层模型,负载均衡策略,微服务的权重设置,环境隔离

文章目录 前言一、初识 Nacos 注册中心1.1 什么是 Nacos1.2 Nacos 的安装&#xff0c;配置&#xff0c;启动 二、服务的注册与发现三、Nacos 服务分层模型3.1 Nacos 的服务分级存储模型3.2 服务跨集群调用问题3.3 服务集群属性设置3.4 修改负载均衡策略为集群策略 四、根据服务…...

No156.精选前端面试题,享受每天的挑战和学习

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云课上架的前后端实战课程《Vue.js 和 Egg.js 开发企业级健康管理项目》、《带你从入…...

如何在PIL图像和PyTorch Tensor之间进行相互转换,使用pytorch进行PIL和tensor之间的数据转换

目录 引言PIL简介PyTorch和Torchvision简介PIL转换为TensorTensor转换为PIL实例代码和解释结论参考文献 &#x1f4dd; 引言 在计算机视觉领域&#xff0c;使用图像处理库对图像进行预处理是非常常见的。其中&#xff0c;Python Imaging Library&#xff08;PIL&#xff09;以…...

STM32F4X UCOSIII任务消息队列

STM32F4X UCOSIII任务消息队列 任务消息队列和内核消息队列对比内核消息队列内核消息队列 UCOSIII任务消息队列API任务消息队列发送函数任务消息队列接收函数 UCOSIII任务消息队列例程 之前的章节中讲解过消息队列这个机制&#xff0c;UCOSIII除了有内核消息队列之外&#xff0…...

8个居家兼职,帮助自己在家搞副业

越来越多的人开始追求居家工作的机会&#xff0c;无论是为了获得更多收入以改善生活质量&#xff0c;还是为了更好地平衡工作和家庭的关系&#xff0c;居家兼职已成为一种趋势。而在家中从事副业不仅能够为我们带来额外的收入&#xff0c;更重要的是&#xff0c;它可以让我们在…...

管理与系统思维

技术管理者不仅仅需要做事情&#xff0c;还需要以系统思维的方式推动组织变革&#xff0c;从而帮助团队和个人做到更好。原文: Management and Systems Thinking 图片来源: Dall-E "除非管理者考虑到组织的系统性&#xff0c;否则大多数提高绩效的努力都将注定失败。"…...

电死人的是电流还是电压?

先说答案&#xff0c;是电流。 这个有两个派别&#xff0c;一个是电流派&#xff0c;一个是电压派。 举个例子&#xff0c;拿我们的头发或者指甲之类的高电阻物质去接触高压&#xff0c;你会发现基本没有什么作用&#xff1b;还有就是冬天我们脱毛衣的时候&#xff0c;噼里啪啦…...

mac 编译问题记录

1、mac 编译提示 Unsupported option ‘--no-pie‘ Linux 上用 --no-pie mac 上用 -no-pie 2、mac 找不到 malloc.h 使用 #include <sys/malloc.h> Mac上使用malloc函数报错_mac malloc.h-CSDN博客...

centos 7.9同时安装JDK1.8和openjdk11两个版本

1.使用的原因 在服务器上&#xff0c;有些情况因为有一些系统比较老&#xff0c;所以需要使用JDK8版本&#xff0c;但随着时间的发展&#xff0c;新的软件出来&#xff0c;一般都会使用比较新的JDK版本。所以就出现了我们标题的需求&#xff0c;一个系统内同时安装两个不同的版…...

【JavaEE】HTML

JavaWeb HTML 超文本标记语言 超文本&#xff1a;文本、声音、图片、视频、表格、连接标记&#xff1a;有许许多多的标签组成 vscode开发工具搭建 因为我使用的IDEA是社区版&#xff0c;代码高亮补全缩进都有些问题&#xff0c;使用vscode是最好的选择~ 安装 Visual Stu…...

【数据结构--八大排序】之堆排序

&#x1f490; &#x1f338; &#x1f337; &#x1f340; &#x1f339; &#x1f33b; &#x1f33a; &#x1f341; &#x1f343; &#x1f342; &#x1f33f; &#x1f344;&#x1f35d; &#x1f35b; &#x1f364; &#x1f4c3;个人主页 &#xff1a;阿然成长日记 …...

MVC 数据库

MVC 数据库 引言 在软件开发领域,Model-View-Controller(MVC)是一种流行的软件架构模式,它将应用程序分为三个核心组件:模型(Model)、视图(View)和控制器(Controller)。这种模式有助于提高代码的可维护性和可扩展性。本文将深入探讨MVC架构与数据库之间的关系,以…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具

文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...

剑指offer20_链表中环的入口节点

链表中环的入口节点 给定一个链表&#xff0c;若其中包含环&#xff0c;则输出环的入口节点。 若其中不包含环&#xff0c;则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战

“&#x1f916;手搓TuyaAI语音指令 &#x1f60d;秒变表情包大师&#xff0c;让萌系Otto机器人&#x1f525;玩出智能新花样&#xff01;开整&#xff01;” &#x1f916; Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制&#xff08;TuyaAI…...

html-<abbr> 缩写或首字母缩略词

定义与作用 <abbr> 标签用于表示缩写或首字母缩略词&#xff0c;它可以帮助用户更好地理解缩写的含义&#xff0c;尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时&#xff0c;会显示一个提示框。 示例&#x…...

JVM虚拟机:内存结构、垃圾回收、性能优化

1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...

腾讯云V3签名

想要接入腾讯云的Api&#xff0c;必然先按其文档计算出所要求的签名。 之前也调用过腾讯云的接口&#xff0c;但总是卡在签名这一步&#xff0c;最后放弃选择SDK&#xff0c;这次终于自己代码实现。 可能腾讯云翻新了接口文档&#xff0c;现在阅读起来&#xff0c;清晰了很多&…...

CSS | transition 和 transform的用处和区别

省流总结&#xff1a; transform用于变换/变形&#xff0c;transition是动画控制器 transform 用来对元素进行变形&#xff0c;常见的操作如下&#xff0c;它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...

WebRTC从入门到实践 - 零基础教程

WebRTC从入门到实践 - 零基础教程 目录 WebRTC简介 基础概念 工作原理 开发环境搭建 基础实践 三个实战案例 常见问题解答 1. WebRTC简介 1.1 什么是WebRTC&#xff1f; WebRTC&#xff08;Web Real-Time Communication&#xff09;是一个支持网页浏览器进行实时语音…...

android RelativeLayout布局

<?xml version"1.0" encoding"utf-8"?> <RelativeLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_parent"android:gravity&…...