GPT - 因果掩码(Causal Mask)
本节代码定义了一个函数 causal_mask,用于生成因果掩码(Causal Mask)。因果掩码通常用于自注意力机制中,以确保模型在解码时只能看到当前及之前的位置,而不能看到未来的信息。这种掩码在自然语言处理任务(如语言生成)中非常重要,因为它模拟了人类阅读或写作时的顺序性。

一、因果掩码(Causal Mask)代码实现
def causal_mask(x):mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0return mask
1. 输入参数
-
x:输入张量,通常是一个序列,形状为(seq_len, d_model)或(batch_size, seq_len, d_model)。这里的seq_len是序列的长度。
2. 生成掩码
mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0
-
torch.ones(x.shape[0], x.shape[0]):生成一个形状为(seq_len, seq_len)的全1矩阵。 -
torch.triu(..., diagonal=1):取该矩阵的上三角部分(包括对角线),其余部分设置为0。diagonal=1表示从对角线的下一个位置开始取上三角部分。 -
== 0:将上三角部分(包括对角线)的值设置为False,其余部分设置为True。这样生成的掩码矩阵中,True表示需要保留的注意力位置,False表示需要被忽略的注意力位置。
3. 返回值
-
mask:生成的因果掩码,形状为(seq_len, seq_len),是一个布尔张量。
示例
假设输入张量 x 的形状为 (5, d_model),即序列长度为5。那么:
x = torch.randn(5, d_model) # 示例输入
mask = causal_mask(x)
print(mask)
输出的掩码矩阵 mask 将是:
tensor([[ True, False, False, False, False],[ True, True, False, False, False],[ True, True, True, False, False],[ True, True, True, True, False],[ True, True, True, True, True]])
作用
在自注意力机制中,因果掩码用于确保模型在计算注意力分数时,只能看到当前及之前的位置,而不能看到未来的信息。具体来说:
-
True:表示可以计算注意力分数。 -
False:表示需要被忽略,注意力分数会被设置为一个非常小的值(如-1e9),从而在 softmax 归一化后,其权重趋近于0。
二、因果掩码如何使用?
1. 因果掩码的生成
因果掩码的生成函数如下:
def causal_mask(x):mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0return mask
-
输入:
x是一个张量,通常是一个序列的嵌入表示,形状为(seq_len, d_model)或(batch_size, seq_len, d_model)。 -
输出:生成一个布尔张量
mask,形状为(seq_len, seq_len),其中上三角部分(包括对角线)为True,其余部分为False。
2. 因果掩码的应用
因果掩码在 Poetry 数据集类中被应用,具体如下:
class Poetry(Dataset):def __init__(self, poetries, tokenizer: Tokenizer):self.poetries = poetriesself.tokenizer = tokenizerself.pad_id = self.tokenizer.vocab["[PAD]"]self.bos_id = self.tokenizer.vocab["[BOS]"]self.eos_id = self.tokenizer.vocab["[EOS]"]def __len__(self):return len(self.poetries)def __getitem__(self, idx):poetry = self.poetries[idx]poetry_ids = self.tokenizer.encode(poetry)input_ids = torch.tensor([self.bos_id] + poetry_ids)input_msk = causal_mask(input_ids)label_ids = torch.tensor(poetry_ids + [self.eos_id])return {"input_ids": input_ids,"input_msk": input_msk,"label_ids": label_ids}
-
__getitem__方法:-
对于每首诗
poetry,将其编码为poetry_ids。 -
在输入序列的开头添加
[BOS](开始标记符),生成input_ids。 -
使用
causal_mask函数生成因果掩码input_msk。 -
在标签序列的末尾添加
[EOS](结束标记符),生成label_ids。
-
3. 因果掩码的传递
在训练过程中,因果掩码 input_msk 会被传递给模型的自注意力层。具体如下:
for epoch in range(epochs):for batch in tqdm(trainloader, desc="Training"):batch_input_ids = batch["input_ids"]batch_input_msk = batch["input_msk"]batch_label_ids = batch["label_ids"]output = model(batch_input_ids, batch_input_msk)loss = loss_fn(output.view(-1, len(vocab)), batch_label_ids.view(-1))loss.backward()optim.step()optim.zero_grad()
-
model(batch_input_ids, batch_input_msk):-
batch_input_ids是输入序列的嵌入表示。 -
batch_input_msk是对应的因果掩码。 -
模型在计算自注意力时,会使用
batch_input_msk来确保解码器只能看到当前及之前的位置。
-
4. 因果掩码的作用
在 MultiHeadAttention 类中,因果掩码被应用到注意力分数矩阵中:
if attn_mask is not None:attn_mask = attn_mask.unsqueeze(1)atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)
-
attn_mask.unsqueeze(1):-
将掩码的形状从
(batch_size, seq_len, seq_len)扩展为(batch_size, 1, seq_len, seq_len)。
-
-
masked_fill:-
将掩码中为
False的位置的注意力分数设置为-1e9,确保这些位置的注意力权重趋近于0。
-
5. 生成诗歌时的因果掩码
在生成诗歌时,因果掩码同样被应用:
def generate_poetry(method="greedy", top_k=5):model.eval()with torch.no_grad():input_ids = torch.tensor(vocab["[BOS]"]).view(1, -1)while input_ids.shape[1] < seq_len:output = model(input_ids, None)probabilities = torch.softmax(output[:, -1, :], dim=-1)if method == "greedy":next_token_id = torch.argmax(probabilities, dim=-1)elif method == "top_k":top_k_probs, top_k_indices = torch.topk(probabilities[0], top_k)next_token_id = top_k_indices[torch.multinomial(top_k_probs, 1)]if next_token_id == vocab["[EOS]"]:breakinput_ids = torch.cat([input_ids, next_token_id.view(1, 1)], dim=1)return input_ids.squeeze()
-
model(input_ids, None):-
在生成诗歌时,输入序列
input_ids会逐渐增长,但因果掩码是隐含的,因为模型的自注意力层会自动处理序列的顺序性。 -
生成过程中,模型只能看到当前及之前的位置,这与训练时使用因果掩码的目的相同。
-
相关文章:
GPT - 因果掩码(Causal Mask)
本节代码定义了一个函数 causal_mask,用于生成因果掩码(Causal Mask)。因果掩码通常用于自注意力机制中,以确保模型在解码时只能看到当前及之前的位置,而不能看到未来的信息。这种掩码在自然语言处理任务(如…...
SpringBoot 数据库MySql的读写分离 多数据源 Shardingsphere高并发优化
介绍 传统的 MySQL 架构中,所有的数据库操作(包括读操作和写操作)都在同一个数据库实例上进行。随着应用程序的规模增长,单一数据库实例可能会成为瓶颈,无法满足高并发的需求。为了优化性能,可以将数据库的…...
适合工程建筑行业的OA系统有什么推荐?
工程行业具有项目周期长、协作链条复杂等特性,传统管理模式下的 “人治”“纸质化” 弊端日益凸显。OA 系统作为数字化管理的核心载体,通过流程标准化、数据可视化,精准解决工程行业项目管理核心痛点。 泛微 e-office 深度聚焦工程场景&#…...
如何使用 DeepSeek 帮助自己的工作?
1. 信息检索 信息检索是获取特定信息的过程,尤其是在大量数据或文本中查找相关内容。这个过程应用广泛,从网页搜索引擎到数据库查询,再到企业内部信息系统。在使用 DeepSeek 或其它类似工具进行信息检索时,可以考虑以下几个重要方…...
python对mysql数据库的操作
现在遇到一个问题如何将数据批量的插入mysql数据库中 基础操作 import asyncio from config import config from mysql_pool import MysqlPoolclass MysqlLoop(object):def __init__(self):self.logger config.loggerself.pool MysqlPool()def loop_query(self, queries):lo…...
MFC案例:利用CFileDialog类选择多个文件的实验
在MFC项目中使用CFileDialog打开文件时,一般的使用场景是选择一个文件,今天我们做一个选择多个文件的实验,运行环境是VS2022。 实验目标:在基于对话框的MFC项目中,通过调用CFileDialog类对象,将选择…...
深入解析栈回溯技术:如何通过异常处理精准定位程序崩溃点
一、栈回溯 1.1 栈回溯的原理 调试程序时,经常发生这类错误: 1.读写某个地址,导致程序崩溃 2.调用某个空函数,导致程序崩溃在异常处理函数中,可以打印出”发生错误瞬间”的所有寄存器。 我们调试时,可以…...
封装uniapp request promise化
uniapp request 封装 一、 封装方法1. 使用 promis 封装 request2. 封装 api 在 api.js3.在要请求的页面 调用 api 一、 封装方法 1. 使用 promis 封装 request const BASE_URL 你的url接口 //比如 http://198.12.3.3/pzexport function request(config {}){let {url,dat…...
重构居家养老安全网:从 “被动响应” 到 “主动守护”
随着全球老龄化加剧,居家养老安全成为社会关注的核心议题。 传统养老模式依赖人工巡检或单一传感器,存在响应滞后、隐私泄露、场景覆盖不足等问题。 由此智绅科技应运而生,七彩喜智慧养老系统构筑居家养老安全网。 而物联网(Io…...
深入理解 GLOG_minloglevel 与 GLOG_v:原理与使用示例
文章目录 深入理解 GLOG_minloglevel 与 GLOG_v:原理与使用示例1. GLOG_minloglevel:最低日志等级控制2. GLOG_v:控制 VLOG() 的详细输出等级3. GLOG_minloglevel 与 GLOG_v 的优先级关系4. 使用示例4.1 基础示例:不同日志等级4.2…...
Unity6下架中国区,团结引擎接棒:这是分裂,还是本地化的开始?
就在近日,一则消息在国内游戏开发圈内迅速传播开来:Unity 6 及其后续版本已在中国大陆及港澳地区下架。这意味着,未来中国用户将无法直接使用 Unity 最新的主线版本。而取而代之的,是由 Unity 中国主导推出的本地化产品 —— 团结…...
ESP8266水位监测以及温湿度数据采集
上面就是ESP8266的引脚图,水温检测使用的是水位监测传感器,温湿度测量使用的是DHT11,DHT11的反应时间是2秒,这里要注意。开发采用Arduino程序 1. 传感器初始化 功能:初始化DHT11温湿度传感器和串口通信。 代码实现&…...
国产信创数据库:PolarDB 分布式版 V2.0,支持集中分布式一体化
阿里云PolarDB数据库管理软件(分布式版)V2.0 ,安全可靠的集中分布式一体化数据库管理软件。点此查看详情https://www.aliyun.com/activity/database/polardbx-v2?spma2c6h.13046898.publish-article.8.44146ffaE0lEWT 立即咨询专家…...
iOS按键精灵辅助工具在游戏开发中的创新应用
一、iOS自动化测试辅助工具 在移动游戏开发领域,iOS按键精灵类辅助工具不同于传统的安卓自动化方案,iOS环境下的自动化测试面临更严峻的技术挑战,但通过创新方法仍可实现精准控制。 # 基于图像识别的智能定位算法示例 def find_button(butt…...
淘宝 API 与 AI 图像识别整合:开启商品主图智能解析新时代
在电商蓬勃发展的当下,淘宝作为行业巨头,承载着海量的商品信息。如何让买家更高效地筛选心仪好物,让卖家精准把握商品展示要点?淘宝 API 与 AI 图像识别技术的整合为这一难题提供了创新性解法,实现对商品主图实时解析&…...
Axure PR 9 中继器 09 删除行
大家好,我是大明同学。 接着上期的内容,这期内容,我们来了解一下Axure中继器数据表删除行交互设计。 预览地址:https://vvlmqu.axshare.com 删除行 1.打开上期RP 文件,设计一个删除弹窗元件, 创建为动态面…...
HDCP(五)
HDCP 2.2 测试用例设计详解 基于HDCP 2.2 CTS v1.1规范及协议核心机制,以下从正常流程与异常场景两大方向拆解测试用例设计要点,覆盖认证、密钥管理、拓扑验证等关键环节: 1. 正常流程测试 1.1 单设备认证 • 测试目标:验证源设…...
商城APP打包教程
下载 HBuilderX 工具 HBuilderX支持插件拓展功能。App开发版已集成相关插件、开箱即用 根据自身电脑系统选择对应软件下载,建议选择APP开发版 2. 下载好软件安装后打开 建议直接在uniapp插件页面一键导入,正常情况下uniapp插件都是最新的,大家…...
Spring 框架的核心基础:IoC 和 AOP
一、IoC(Inversion of Control,控制反转) 定义: IoC(Inversion of Control,控制反转),就是把对象创建和依赖关系的管理交给 Spring 容器,而不是由程序员手动去创建对象…...
SpringBoot 基础知识,HTTP 概述
1. 概述 1.1 Spring Spring 提供若干个子项目,每个项目用于完成特定功能 Spring 的若干个子项目都基于一个基础的框架:Spring Framework 框架类似于 房屋的地基 但 Spring Framework 配置繁琐,入门难度大 1.2 Spring Boot 于是…...
《网络管理》实践环节04:SNMP监控数据采集流程及SNMP协议详细分析
兰生幽谷,不为莫服而不芳; 君子行义,不为莫知而止休。 1 实验目标 1. 理解SNMP网络管理原理 2. 掌握SNMP服务器采集SNMP Agent数据的方法 3. 掌握SNMP报文发送和应答流程 4. 掌握典型GetResponsePDU数据结构分析的方法 4. 具备SNMP通信…...
RTX30系显卡运行Tensorflow 1.15 GPU版本
30系显卡只支持cuda11.0及以上版本,但很多tensorflow项目用的仍然是1.1x版本,这些版本需要cuda10或者以下版本,这就导致在30系显卡上无法正常运1.1x版本的tensorflow,最近几天我也因为这个问题头疼不已,网上一番搜索…...
debian系统中文输入法失效解决
在 Debian 9.6 上无法切换中文输入法的问题通常与输入法框架(如 Fcitx 或 IBus)的配置或依赖缺失有关。以下是详细的解决步骤: 1. 安装中文语言包 确保系统已安装中文语言支持: sudo apt update sudo apt install locales sudo…...
《Uniapp-Vue 3-TS 实战开发》构建HTTP请求拦截器
引言 在 UniApp 结合 TypeScript 和 Vue3 的项目开发中,请求拦截器起着至关重要的作用。它能够在请求发送前和响应接收后对数据进行统一处理,极大地提高了代码的可维护性和功能性。本文将详细解析上述代码中请求拦截器的实现及其在 UniApp-Ts-Vue3 项目中…...
C#基础类型系统-接口
接口 - 定义多种类型的行为 接口包含非抽象 class 或 struct 必须实现的一组相关功能的定义。接口可以定义 static 方法,此类方法必须具有实现。接口可为成员定义默认实现。接口不能声明实例数据,如字段、自动实现的属性或类似属性的事件。C#不支持类的…...
StringTemplate修仙指南:字符串处理的“言出法随“大法
各位在字符串处理苦海中挣扎的道友们!今天要解锁的是StringTemplate这门"言出法随"的绝学——用模板语法让字符串替换变得优雅如诗!无论是代码生成、邮件模板还是动态SQL,都能一键搞定!准备好告别String.format()的混沌…...
从PDF中提取表格:以GB/T2260—2007为例
文章目录 先说结论前因后果思路1、PDF2CSV2、PDF2MD → MD2CSV3、针对不同表格的两种思路1) 竖形三线表2)五元素为一组 还没结束批量处理1、分割markdown文档2、跳过另一种格式的文档 总结一下 先说结论 结论就是,博主用了一天的时间去研究如…...
初识MySQL · 复合查询(内外连接)
目录 前言: 基本查询回顾 笛卡尔积和子查询 笛卡尔积 内外连接 子查询 单行子查询 多行子查询 多列子查询 from中使用子查询 合并查询 前言: 在前文我们学习了MySQL的基本查询,就是简单的套用了select语句,最多不过是…...
电视剧角色扮演AI Agent中的大模型操作流程
电视剧角色扮演AI Agent中的大模型操作流程 在您描述的 “电视剧角色扮演 + 挑战任务” 的AI Agent场景中,大模型(如GPT-4、Claude等)需要完成多个关键操作,以下是详细的技术流程分解: 1. 用户输入处理阶段 操作:文本向量化(Embedding) 技术实现: 使用 文本嵌入模型…...
Android学习总结之数据结构篇
Java 的集合体系 Java 的集合框架主要分为两大接口体系:Collection 和 Map。以下是对这两大体系下常见集合类的介绍: Collection 体系 Collection 是单列集合的根接口,它有三个主要的子接口:List、Set 和 Queue。 List 接口&a…...
