当前位置: 首页 > article >正文

GPT - 因果掩码(Causal Mask)

本节代码定义了一个函数 causal_mask,用于生成因果掩码(Causal Mask)。因果掩码通常用于自注意力机制中,以确保模型在解码时只能看到当前及之前的位置,而不能看到未来的信息。这种掩码在自然语言处理任务(如语言生成)中非常重要,因为它模拟了人类阅读或写作时的顺序性。

一、因果掩码(Causal Mask)代码实现

def causal_mask(x):mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0return mask
1. 输入参数
  • x:输入张量,通常是一个序列,形状为 (seq_len, d_model)(batch_size, seq_len, d_model)。这里的 seq_len 是序列的长度。

2. 生成掩码
mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0
  • torch.ones(x.shape[0], x.shape[0]):生成一个形状为 (seq_len, seq_len) 的全1矩阵。

  • torch.triu(..., diagonal=1):取该矩阵的上三角部分(包括对角线),其余部分设置为0。diagonal=1 表示从对角线的下一个位置开始取上三角部分。

  • == 0:将上三角部分(包括对角线)的值设置为 False,其余部分设置为 True。这样生成的掩码矩阵中,True 表示需要保留的注意力位置,False 表示需要被忽略的注意力位置。

3. 返回值
  • mask:生成的因果掩码,形状为 (seq_len, seq_len),是一个布尔张量。

示例

假设输入张量 x 的形状为 (5, d_model),即序列长度为5。那么:

x = torch.randn(5, d_model)  # 示例输入
mask = causal_mask(x)
print(mask)

输出的掩码矩阵 mask 将是:

tensor([[ True, False, False, False, False],[ True,  True, False, False, False],[ True,  True,  True, False, False],[ True,  True,  True,  True, False],[ True,  True,  True,  True,  True]])

作用

在自注意力机制中,因果掩码用于确保模型在计算注意力分数时,只能看到当前及之前的位置,而不能看到未来的信息。具体来说:

  • True:表示可以计算注意力分数。

  • False:表示需要被忽略,注意力分数会被设置为一个非常小的值(如 -1e9),从而在 softmax 归一化后,其权重趋近于0。

二、因果掩码如何使用?

1. 因果掩码的生成

因果掩码的生成函数如下:

def causal_mask(x):mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0return mask
  • 输入x 是一个张量,通常是一个序列的嵌入表示,形状为 (seq_len, d_model)(batch_size, seq_len, d_model)

  • 输出:生成一个布尔张量 mask,形状为 (seq_len, seq_len),其中上三角部分(包括对角线)为 True,其余部分为 False

2. 因果掩码的应用

因果掩码在 Poetry 数据集类中被应用,具体如下:

class Poetry(Dataset):def __init__(self, poetries, tokenizer: Tokenizer):self.poetries = poetriesself.tokenizer = tokenizerself.pad_id = self.tokenizer.vocab["[PAD]"]self.bos_id = self.tokenizer.vocab["[BOS]"]self.eos_id = self.tokenizer.vocab["[EOS]"]def __len__(self):return len(self.poetries)def __getitem__(self, idx):poetry = self.poetries[idx]poetry_ids = self.tokenizer.encode(poetry)input_ids = torch.tensor([self.bos_id] + poetry_ids)input_msk = causal_mask(input_ids)label_ids = torch.tensor(poetry_ids + [self.eos_id])return {"input_ids": input_ids,"input_msk": input_msk,"label_ids": label_ids}
  • __getitem__ 方法

    • 对于每首诗 poetry,将其编码为 poetry_ids

    • 在输入序列的开头添加 [BOS](开始标记符),生成 input_ids

    • 使用 causal_mask 函数生成因果掩码 input_msk

    • 在标签序列的末尾添加 [EOS](结束标记符),生成 label_ids

3. 因果掩码的传递

在训练过程中,因果掩码 input_msk 会被传递给模型的自注意力层。具体如下:

for epoch in range(epochs):for batch in tqdm(trainloader, desc="Training"):batch_input_ids = batch["input_ids"]batch_input_msk = batch["input_msk"]batch_label_ids = batch["label_ids"]output = model(batch_input_ids, batch_input_msk)loss = loss_fn(output.view(-1, len(vocab)), batch_label_ids.view(-1))loss.backward()optim.step()optim.zero_grad()
  • model(batch_input_ids, batch_input_msk)

    • batch_input_ids 是输入序列的嵌入表示。

    • batch_input_msk 是对应的因果掩码。

    • 模型在计算自注意力时,会使用 batch_input_msk 来确保解码器只能看到当前及之前的位置。

4. 因果掩码的作用

MultiHeadAttention 类中,因果掩码被应用到注意力分数矩阵中:

if attn_mask is not None:attn_mask = attn_mask.unsqueeze(1)atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)
  • attn_mask.unsqueeze(1)

    • 将掩码的形状从 (batch_size, seq_len, seq_len) 扩展为 (batch_size, 1, seq_len, seq_len)

  • masked_fill

    • 将掩码中为 False 的位置的注意力分数设置为 -1e9,确保这些位置的注意力权重趋近于0。

5. 生成诗歌时的因果掩码

在生成诗歌时,因果掩码同样被应用:

def generate_poetry(method="greedy", top_k=5):model.eval()with torch.no_grad():input_ids = torch.tensor(vocab["[BOS]"]).view(1, -1)while input_ids.shape[1] < seq_len:output = model(input_ids, None)probabilities = torch.softmax(output[:, -1, :], dim=-1)if method == "greedy":next_token_id = torch.argmax(probabilities, dim=-1)elif method == "top_k":top_k_probs, top_k_indices = torch.topk(probabilities[0], top_k)next_token_id = top_k_indices[torch.multinomial(top_k_probs, 1)]if next_token_id == vocab["[EOS]"]:breakinput_ids = torch.cat([input_ids, next_token_id.view(1, 1)], dim=1)return input_ids.squeeze()
  • model(input_ids, None)

    • 在生成诗歌时,输入序列 input_ids 会逐渐增长,但因果掩码是隐含的,因为模型的自注意力层会自动处理序列的顺序性。

    • 生成过程中,模型只能看到当前及之前的位置,这与训练时使用因果掩码的目的相同。



 

相关文章:

GPT - 因果掩码(Causal Mask)

本节代码定义了一个函数 causal_mask&#xff0c;用于生成因果掩码&#xff08;Causal Mask&#xff09;。因果掩码通常用于自注意力机制中&#xff0c;以确保模型在解码时只能看到当前及之前的位置&#xff0c;而不能看到未来的信息。这种掩码在自然语言处理任务&#xff08;如…...

SpringBoot 数据库MySql的读写分离 多数据源 Shardingsphere高并发优化

介绍 传统的 MySQL 架构中&#xff0c;所有的数据库操作&#xff08;包括读操作和写操作&#xff09;都在同一个数据库实例上进行。随着应用程序的规模增长&#xff0c;单一数据库实例可能会成为瓶颈&#xff0c;无法满足高并发的需求。为了优化性能&#xff0c;可以将数据库的…...

适合工程建筑行业的OA系统有什么推荐?

工程行业具有项目周期长、协作链条复杂等特性&#xff0c;传统管理模式下的 “人治”“纸质化” 弊端日益凸显。OA 系统作为数字化管理的核心载体&#xff0c;通过流程标准化、数据可视化&#xff0c;精准解决工程行业项目管理核心痛点。 泛微 e-office 深度聚焦工程场景&#…...

如何使用 DeepSeek 帮助自己的工作?

1. 信息检索 信息检索是获取特定信息的过程&#xff0c;尤其是在大量数据或文本中查找相关内容。这个过程应用广泛&#xff0c;从网页搜索引擎到数据库查询&#xff0c;再到企业内部信息系统。在使用 DeepSeek 或其它类似工具进行信息检索时&#xff0c;可以考虑以下几个重要方…...

python对mysql数据库的操作

现在遇到一个问题如何将数据批量的插入mysql数据库中 基础操作 import asyncio from config import config from mysql_pool import MysqlPoolclass MysqlLoop(object):def __init__(self):self.logger config.loggerself.pool MysqlPool()def loop_query(self, queries):lo…...

MFC案例:利用CFileDialog类选择多个文件的实验

在MFC项目中使用CFileDialog打开文件时&#xff0c;一般的使用场景是选择一个文件&#xff0c;今天我们做一个选择多个文件的实验&#xff0c;运行环境是VS2022。 实验目标&#xff1a;在基于对话框的MFC项目中&#xff0c;通过调用CFileDialog类对象&#xff0c;将选择…...

深入解析栈回溯技术:如何通过异常处理精准定位程序崩溃点

一、栈回溯 1.1 栈回溯的原理 调试程序时&#xff0c;经常发生这类错误&#xff1a; 1.读写某个地址&#xff0c;导致程序崩溃 2.调用某个空函数&#xff0c;导致程序崩溃在异常处理函数中&#xff0c;可以打印出”发生错误瞬间”的所有寄存器。 我们调试时&#xff0c;可以…...

封装uniapp request promise化

uniapp request 封装 一、 封装方法1. 使用 promis 封装 request2. 封装 api 在 api.js3.在要请求的页面 调用 api 一、 封装方法 1. 使用 promis 封装 request const BASE_URL 你的url接口 //比如 http://198.12.3.3/pzexport function request(config {}){let {url,dat…...

重构居家养老安全网:从 “被动响应” 到 “主动守护”

随着全球老龄化加剧&#xff0c;居家养老安全成为社会关注的核心议题。 传统养老模式依赖人工巡检或单一传感器&#xff0c;存在响应滞后、隐私泄露、场景覆盖不足等问题。 由此智绅科技应运而生&#xff0c;七彩喜智慧养老系统构筑居家养老安全网。 而物联网&#xff08;Io…...

深入理解 GLOG_minloglevel 与 GLOG_v:原理与使用示例

文章目录 深入理解 GLOG_minloglevel 与 GLOG_v&#xff1a;原理与使用示例1. GLOG_minloglevel&#xff1a;最低日志等级控制2. GLOG_v&#xff1a;控制 VLOG() 的详细输出等级3. GLOG_minloglevel 与 GLOG_v 的优先级关系4. 使用示例4.1 基础示例&#xff1a;不同日志等级4.2…...

Unity6下架中国区,团结引擎接棒:这是分裂,还是本地化的开始?

就在近日&#xff0c;一则消息在国内游戏开发圈内迅速传播开来&#xff1a;Unity 6 及其后续版本已在中国大陆及港澳地区下架。这意味着&#xff0c;未来中国用户将无法直接使用 Unity 最新的主线版本。而取而代之的&#xff0c;是由 Unity 中国主导推出的本地化产品 —— 团结…...

ESP8266水位监测以及温湿度数据采集

上面就是ESP8266的引脚图&#xff0c;水温检测使用的是水位监测传感器&#xff0c;温湿度测量使用的是DHT11&#xff0c;DHT11的反应时间是2秒&#xff0c;这里要注意。开发采用Arduino程序 1. 传感器初始化 功能&#xff1a;初始化DHT11温湿度传感器和串口通信。 代码实现&…...

国产信创数据库:PolarDB 分布式版 V2.0,支持集中分布式一体化

阿里云PolarDB数据库管理软件&#xff08;分布式版&#xff09;V2.0 &#xff0c;安全可靠的集中分布式一体化数据库管理软件。点此查看详情https://www.aliyun.com/activity/database/polardbx-v2?spma2c6h.13046898.publish-article.8.44146ffaE0lEWT 立即咨询专家&#xf…...

iOS按键精灵辅助工具在游戏开发中的创新应用

一、iOS自动化测试辅助工具 在移动游戏开发领域&#xff0c;iOS按键精灵类辅助工具不同于传统的安卓自动化方案&#xff0c;iOS环境下的自动化测试面临更严峻的技术挑战&#xff0c;但通过创新方法仍可实现精准控制。 # 基于图像识别的智能定位算法示例 def find_button(butt…...

淘宝 API 与 AI 图像识别整合:开启商品主图智能解析新时代

在电商蓬勃发展的当下&#xff0c;淘宝作为行业巨头&#xff0c;承载着海量的商品信息。如何让买家更高效地筛选心仪好物&#xff0c;让卖家精准把握商品展示要点&#xff1f;淘宝 API 与 AI 图像识别技术的整合为这一难题提供了创新性解法&#xff0c;实现对商品主图实时解析&…...

Axure PR 9 中继器 09 删除行

大家好&#xff0c;我是大明同学。 接着上期的内容&#xff0c;这期内容&#xff0c;我们来了解一下Axure中继器数据表删除行交互设计。 预览地址&#xff1a;https://vvlmqu.axshare.com 删除行 1.打开上期RP 文件&#xff0c;设计一个删除弹窗元件&#xff0c; 创建为动态面…...

HDCP(五)

HDCP 2.2 测试用例设计详解 基于HDCP 2.2 CTS v1.1规范及协议核心机制&#xff0c;以下从正常流程与异常场景两大方向拆解测试用例设计要点&#xff0c;覆盖认证、密钥管理、拓扑验证等关键环节&#xff1a; 1. 正常流程测试 1.1 单设备认证 • 测试目标&#xff1a;验证源设…...

商城APP打包教程

下载 HBuilderX 工具 HBuilderX支持插件拓展功能。App开发版已集成相关插件、开箱即用 根据自身电脑系统选择对应软件下载&#xff0c;建议选择APP开发版 2. 下载好软件安装后打开 建议直接在uniapp插件页面一键导入&#xff0c;正常情况下uniapp插件都是最新的&#xff0c;大家…...

Spring 框架的核心基础:IoC 和 AOP

一、IoC&#xff08;Inversion of Control&#xff0c;控制反转&#xff09; 定义&#xff1a; IoC&#xff08;Inversion of Control&#xff0c;控制反转&#xff09;&#xff0c;就是把对象创建和依赖关系的管理交给 Spring 容器&#xff0c;而不是由程序员手动去创建对象…...

SpringBoot 基础知识,HTTP 概述

1. 概述 1.1 Spring Spring 提供若干个子项目&#xff0c;每个项目用于完成特定功能 Spring 的若干个子项目都基于一个基础的框架&#xff1a;Spring Framework 框架类似于 房屋的地基 但 Spring Framework 配置繁琐&#xff0c;入门难度大 1.2 Spring Boot 于是&#xf…...

《网络管理》实践环节04:SNMP监控数据采集流程及SNMP协议详细分析

兰生幽谷&#xff0c;不为莫服而不芳&#xff1b; 君子行义&#xff0c;不为莫知而止休。 1 实验目标 1. 理解SNMP网络管理原理 2. 掌握SNMP服务器采集SNMP Agent数据的方法 3. 掌握SNMP报文发送和应答流程 4. 掌握典型GetResponsePDU数据结构分析的方法 4. 具备SNMP通信…...

RTX30系显卡运行Tensorflow 1.15 GPU版本

​ 30系显卡只支持cuda11.0及以上版本&#xff0c;但很多tensorflow项目用的仍然是1.1x版本&#xff0c;这些版本需要cuda10或者以下版本&#xff0c;这就导致在30系显卡上无法正常运1.1x版本的tensorflow&#xff0c;最近几天我也因为这个问题头疼不已&#xff0c;网上一番搜索…...

debian系统中文输入法失效解决

在 Debian 9.6 上无法切换中文输入法的问题通常与输入法框架&#xff08;如 Fcitx 或 IBus&#xff09;的配置或依赖缺失有关。以下是详细的解决步骤&#xff1a; 1. 安装中文语言包 确保系统已安装中文语言支持&#xff1a; sudo apt update sudo apt install locales sudo…...

《Uniapp-Vue 3-TS 实战开发》构建HTTP请求拦截器

引言 在 UniApp 结合 TypeScript 和 Vue3 的项目开发中&#xff0c;请求拦截器起着至关重要的作用。它能够在请求发送前和响应接收后对数据进行统一处理&#xff0c;极大地提高了代码的可维护性和功能性。本文将详细解析上述代码中请求拦截器的实现及其在 UniApp-Ts-Vue3 项目中…...

C#基础类型系统-接口

接口 - 定义多种类型的行为 接口包含非抽象 class 或 struct 必须实现的一组相关功能的定义。接口可以定义 static 方法&#xff0c;此类方法必须具有实现。接口可为成员定义默认实现。接口不能声明实例数据&#xff0c;如字段、自动实现的属性或类似属性的事件。C#不支持类的…...

StringTemplate修仙指南:字符串处理的“言出法随“大法

各位在字符串处理苦海中挣扎的道友们&#xff01;今天要解锁的是StringTemplate这门"言出法随"的绝学——用模板语法让字符串替换变得优雅如诗&#xff01;无论是代码生成、邮件模板还是动态SQL&#xff0c;都能一键搞定&#xff01;准备好告别String.format()的混沌…...

从PDF中提取表格:以GB/T2260—2007为例

文章目录 先说结论前因后果思路1、PDF2CSV2、PDF2MD → MD2CSV3、针对不同表格的两种思路1&#xff09; 竖形三线表2&#xff09;五元素为一组 还没结束批量处理1、分割markdown文档2、跳过另一种格式的文档 总结一下 先说结论 结论就是&#xff0c;博主用了一天的时间去研究如…...

初识MySQL · 复合查询(内外连接)

目录 前言&#xff1a; 基本查询回顾 笛卡尔积和子查询 笛卡尔积 内外连接 子查询 单行子查询 多行子查询 多列子查询 from中使用子查询 合并查询 前言&#xff1a; 在前文我们学习了MySQL的基本查询&#xff0c;就是简单的套用了select语句&#xff0c;最多不过是…...

电视剧角色扮演AI Agent中的大模型操作流程

电视剧角色扮演AI Agent中的大模型操作流程 在您描述的 “电视剧角色扮演 + 挑战任务” 的AI Agent场景中,大模型(如GPT-4、Claude等)需要完成多个关键操作,以下是详细的技术流程分解: 1. 用户输入处理阶段 操作:文本向量化(Embedding) 技术实现: 使用 文本嵌入模型…...

Android学习总结之数据结构篇

Java 的集合体系 Java 的集合框架主要分为两大接口体系&#xff1a;Collection 和 Map。以下是对这两大体系下常见集合类的介绍&#xff1a; Collection 体系 Collection 是单列集合的根接口&#xff0c;它有三个主要的子接口&#xff1a;List、Set 和 Queue。 List 接口&a…...