当前位置: 首页 > news >正文

Transformer的PyTorch实现之若干问题探讨(二)

在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。

1.Transformer中decoder的流程

在论文《Attention is all you need》中,关于encoder及self attention有较为详细的论述,这也是网上很多教程在谈及transformer时候会重点讨论的部分。但是关于transformer的decoder部分,他的结构上与encoder实际非常像,但其中有一些巧妙的设计。本文会详细谈谈。首先给出一个完整transformer的结构图:
在这里插入图片描述

上图左侧为encoder部分,右侧为decoder部分。对于decoder部分,将enc_input经过multi head attention后得到的张量,以K,V送入decoder中。而decoder阶段的masked multi head attention需要解决如何将dec_input编码成Q。最终输出的logits实际是与Q的维度一致。对于Scaled Dot-Product Attention,其公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
在《Transformer的PyTorch实现之若干问题探讨(一)》中,decoder阶段,Q的维度为[2,8,6,64](2为batch size,8为head数,6为句子长度,64为向量长度),K的维度为[2,8,5,64],V的维度为[2,8,5,64]。其中, Q K T QK^T QKT的维度为[2,8,6,5] 的,可以理解每个查询张量Q对每个键值张K的注意力权重。之后乘以V,维度为[2,8,6,64]。可以看到最终的维度是根据查询张量Q来加权值向量V。Q就是dec_input经过masked multi head attention得来。那么,dec_input中实际是包含了所有的标签的。那么dec_input是如何mask掉不需要的token的呢?

2.Decoder中的self attention mask

class Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])def forward(self, dec_inputs, enc_inputs, enc_outputs):'''这三个参数对应的不是Q、K、V,dec_inputs是Q,enc_outputs是K和V,enc_inputs是用来计算padding mask的dec_inputs: [batch_size, tgt_len]enc_inpus: [batch_size, src_len]enc_outputs: [batch_size, src_len, d_model]'''dec_outputs = self.tgt_emb(dec_inputs)#词序号编码成向量dec_outputs = self.pos_emb(dec_outputs).cuda()#位置编码dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() #[2, 6, 6]dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() #[2, 6, 6],上三角矩阵# 将两个mask叠加,布尔值可以视为0和1,和大于0的位置是需要被mask掉的,赋为True,和为0的位置是有意义的为Falsedec_self_attn_mask = torch.gt((dec_self_attn_pad_mask +dec_self_attn_subsequence_mask), 0).cuda()# 这是co-attention部分,为啥传入的是enc_inputs而不是enc_outputs:enc_outputs是向量,这儿是需要通过词编码来判断是否需要mask掉dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) #[2, 6, 5]for layer in self.layers:dec_outputs = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)return dec_outputs # dec_outputs: [batch_size, tgt_len, d_model]

上述代码为Decoder部分。可以看到有两个mask:dec_self_attn_pad_mask(用于将dec_inputs中的P mask掉)与dec_self_attn_subsequence_mask(用于实现decoder的self attention)。这两个mask在后面会相加合并。这儿可以分别展示二者的值,其中:

dec_self_attn_pad_mask:
tensor([[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]],[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]]], device='cuda:0')#[2, 6, 6]
dec_self_attn_subsequence_mask:
tensor([[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]],[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)#[2, 6, 6]

可以看到,dec_self_attn_pad_mask全为false,这是因为dec_input中不包含P,而dec_self_attn_subsequence_mask为上三角矩阵,对于每个token,需要mask掉它之后的token(本代码中,为1或True的位置会被mask掉)。接下来进一步追问,为什么上三角矩阵就可以mask掉该token之后的token?具体是如何实现的呢?
对于前文的Scaled Dot-Product Attention公式,代码中的表述实际为:

    def forward(self, Q, K, V, attn_mask):'''Q: [batch_size, n_heads, len_q, d_k]K: [batch_size, n_heads, len_k, d_k]V: [batch_size, n_heads, len_v(=len_k), d_v] 全文两处用到注意力,一处是self attention,另一处是co attention,前者不必说,后者的k和v都是encoder的输出,所以k和v的形状总是相同的attn_mask: [batch_size, n_heads, seq_len, seq_len]'''# 1) 计算注意力分数QK^T/sqrt(d_k)scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]# 2)  进行 mask 和 softmax# mask为True的位置会被设为-1e9scores.masked_fill_(attn_mask, -1e9) # 把True设为-1e9attn = nn.Softmax(dim=-1)(scores)  # attn: [batch_size, n_heads, len_q, len_k]# 3) 乘V得到最终的加权和context = torch.matmul(attn, V)  # context: [batch_size, n_heads, len_q, d_v], [2, 8, 5, 64]'''得出的context是每个维度(d_1-d_v)都考虑了在当前维度(这一列)当前token对所有token的注意力后更新的新的值,换言之每个维度d是相互独立的,每个维度考虑自己的所有token的注意力,所以可以理解成1列扩展到多列返回的context: [batch_size, n_heads, len_q, d_v]本质上还是batch_size个句子,只不过每个句子中词向量维度512被分成了8个部分,分别由8个头各自看一部分,每个头算的是整个句子(一列)的512/8=64个维度,最后按列拼接起来'''return context # context: [batch_size, n_heads, len_q, d_v]

其中,Q,K,V的维度都是[2, 8, 6, 64], score的维度为[2, 8, 6, 6],即每个token之间的注意力分数。这儿取出一个batch中的一个head下的注意力分数a为例,a的维度为[6, 6],如图所示:
在这里插入图片描述

如上图所示,在得分score中,标黄的0.71和0.24分别是S与S,以及S与I的词向量相乘得到。由于I在S后面,所以需要通过mask将其置为负无穷大,而0.71需要保留,因为是S与S在同一个位置上。因此这个mask矩阵为上三角矩阵。

相关文章:

Transformer的PyTorch实现之若干问题探讨(二)

在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。 1.Transformer中decoder的流程 在论文《Attention is all you need》中&#xff0…...

解释Python中的GIL(全局解释器锁)及其影响。描述Python中的垃圾回收机制。Python中的类变量和实例变量有什么区别

解释Python中的GIL(全局解释器锁)及其影响 Python中的GIL(全局解释器锁)是CPython解释器中的一个机制,用于同步线程的执行。GIL确保任何时候只有一个线程在执行Python字节码。这意味着,即使在多核或多处理器…...

Appium使用初体验之参数配置,简单能够运行起来

一、服务器配置 Appium Server配置与Appium Server GUI(可视化客户端)中的配置对应,尤其是二者如果不在同一台机器上,那么就需要配置Appium Server GUI所在机器的IP(Appium Server GUI的HOST也需要配置本机IP&#xf…...

Java:JDK8新特性(Stream流)、File类、递归 --黑马笔记

一、JDK8新特性(Stream流) 接下来我们学习一个全新的知识,叫做Stream流(也叫Stream API)。它是从JDK8以后才有的一个新特性,是专业用于对集合或者数组进行便捷操作的。有多方便呢?我们用一个案…...

【Unity ShaderGraph】| 物体靠近时局部溶解,根据坐标控制溶解的位置【文末送书】

前言 【Unity ShaderGraph】| 物体靠近时局部溶解,根据坐标控制溶解的位置一、效果展示二、根据坐标控制溶解的位置,物体靠近局部溶解三、应用实例👑评论区抽奖送书 前言 本文将使用ShaderGraph制作一个根据坐标控制溶解的位置,物…...

测试OpenSIPS3.4.3的lua模块

这几天测试OpenSIPS3.4.3的lua模块,记录如下: 有bug,但能用 但现实世界就是这样,总是不完美的,发现之后马上提了issue 下面这段代码运行报错: function func1(msg) xlog("ERR","…...

【机器学习】数据清洗之处理缺失点

🎈个人主页:甜美的江 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:机器学习 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步…...

Linux 命令行的世界 :2.文件系统中跳转

我们需要学习的第一件事(除了打字之外)是如何在 Linux 文件系统中跳转。在这一章节中,我们将介绍以下命令:pwd 打印出当前工作目录名 cd 更改目录 ls 列出目录内容 Linux以分层目录结构来组织所有文件。这就意味着所有文件…...

R语言:箱线图绘制(添加平均值趋势线)

箱线图绘制 1. 写在前面2.箱线图绘制2.1 相关R包导入2.2 数据导入及格式转换2.3 ggplot绘图 1. 写在前面 今天有时间把之前使用过的一些代码和大家分享,其中箱线图绘制我认为是非常有用的一个部分。之前我是比较喜欢使用origin进行绘图,但是绘制的图不太…...

Open3D 模型切片

目录 一、算法原理1、算法过程2、主要函数二、代码实现三、结果展示1、原始数据2、切片结果本文由CSDN点云侠原创,原文链接。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫与GPT。 一、算法原理...

KtConnect 本地连接连接K8S工具

KT Connect简介 Kt Connect (Kubernetes Developer Tool)是一个阿里开源、轻量级的面向 Kubernetes 用户的开发测试环境治理辅助工具。其核心是通过建立本地到集群以及集群到本地的双向通道。 1.阿里开源,轻量级, 2. 安装快捷简单&#xf…...

【Java万花筒】数据的安全钥匙:Java的加密与保护方法

编码的盾牌:Java开发人员的安全性武器库 前言 在当今数字化时代,保护用户数据和信息的安全已成为开发人员的首要任务。无论是在Web应用程序开发还是安全测试中,加密和安全性都是至关重要的。本文将介绍六个Java库和工具,它们为开…...

【Java多线程案例】实现阻塞队列

1. 阻塞队列简介 1.1 阻塞队列概念 阻塞队列:是一种特殊的队列,具有队列"先进先出"的特性,同时相较于普通队列,阻塞队列是线程安全的,并且带有阻塞功能,表现形式如下: 当队列满时&…...

【制作100个unity游戏之24】unity制作一个3D动物AI生态系统游戏3(附项目源码)

最终效果 文章目录 最终效果系列目录前言随着地面法线旋转在地形上随机生成动物不同部位颜色不同最终效果源码完结系列目录 前言 欢迎来到【制作100个Unity游戏】系列!本系列将引导您一步步学习如何使用Unity开发各种类型的游戏。在这第24篇中,我们将探索如何用unity制作一…...

home work day5

第四章 堆与拷贝构造函数 一 、程序阅读题 1、给出下面程序输出结果。 #include <iostream.h> class example {int a; public: example(int b5){ab;} void print(){aa1;cout <<a<<"";} void print()const {cout<<a<<endl;} …...

c#安全-nativeAOT

文章目录 前记AOT测试反序列化Emit 前记 JIT\AOT JIT编译器&#xff08;Just-in-Time Complier&#xff09;,AOT编译器&#xff08;Ahead-of-Time Complier&#xff09;。 AOT测试 首先编译一段普通代码 using System; using System.Runtime.InteropServices; namespace co…...

【Java】案例:检测MySQL是否存在某数据库,没有则创建

1.代码 package hello; import java.sql.*;public class CeShi {//定义基本数据static final String JDBC_DRIVER "com.mysql.cj.jdbc.Driver";static final String DB_URL "jdbc:mysql://localhost/";static final String USER "your_username&q…...

内网渗透靶场02----Weblogic反序列化+域渗透

网络拓扑&#xff1a; 攻击机&#xff1a; Kali: 192.168.111.129 Win10: 192.168.111.128 靶场基本配置&#xff1a;web服务器双网卡机器&#xff1a; 192.168.111.80&#xff08;模拟外网&#xff09;10.10.10.80&#xff08;模拟内网&#xff09;域成员机器 WIN7PC192.168.…...

[嵌入式系统-9]:C语言程序调用汇编语言程序的三种方式

目录 1. 使用函数声明和函数调用&#xff1a; 2. 使用汇编内联&#xff08;Inline Assembly&#xff09;&#xff1a; 3. 使用汇编代码文件和链接器&#xff1a; C语言程序可以调用汇编程序的方式有多种&#xff0c;下面列举了几种常见的方式&#xff1a; 1. 使用函数声明和…...

备战蓝桥杯---搜索(完结篇)

再看一道不完全是搜索的题&#xff1a; 解法1&#xff1a;贪心并查集&#xff1a; 把冲突事件从大到小排&#xff0c;判断是否两个在同一集合&#xff0c;在的话就返回&#xff0c;不在的话就合并。 下面是AC代码&#xff1a; #include<bits/stdc.h> using namespace …...

告别Xshell!Mac上这款免费串口工具CoolTerm,固件调试日志记录真香了

告别Xshell&#xff01;Mac上这款免费串口工具CoolTerm&#xff0c;固件调试日志记录真香了 从Windows切换到Mac平台的嵌入式开发者&#xff0c;最头疼的莫过于找不到趁手的串口调试工具。Xshell和SecureCRT在Windows上堪称神器&#xff0c;但它们的Mac版本要么收费高昂&#…...

SmolVLA模型服务监控与告警体系搭建

SmolVLA模型服务监控与告警体系搭建 你刚把SmolVLA模型部署上线&#xff0c;看着它流畅地处理着第一批请求&#xff0c;心里总算踏实了点。但没过多久&#xff0c;问题就来了&#xff1a;半夜突然收到用户反馈说服务变慢了&#xff0c;你赶紧爬起来查&#xff0c;发现是GPU显存…...

5分钟搞定KEPserver V6配置:Java读取西门子PLC数据的保姆级教程

5分钟极速配置KEPserver V6与Java通信&#xff1a;西门子S7-1500数据采集实战指南 当工业现场的PLC数据需要与IT系统集成时&#xff0c;OPC技术栈往往是最直接的选择。但传统OPC配置过程繁琐的文档和复杂的依赖管理&#xff0c;常让工程师在项目初期耗费大量时间在环境搭建上。…...

RPCS3完全攻略:从零开始打造你的PC端PS3游戏中心

RPCS3完全攻略&#xff1a;从零开始打造你的PC端PS3游戏中心 【免费下载链接】rpcs3 PS3 emulator/debugger 项目地址: https://gitcode.com/GitHub_Trending/rp/rpcs3 还在为无法重温经典PS3游戏而烦恼吗&#xff1f;想要在电脑上体验《最后生还者》、《神秘海域》等索…...

PDF文本高效提取:用pdftotext实现秒级文档内容解析

PDF文本高效提取&#xff1a;用pdftotext实现秒级文档内容解析 【免费下载链接】pdftotext Simple PDF text extraction 项目地址: https://gitcode.com/gh_mirrors/pd/pdftotext 破解PDF提取痛点&#xff1a;为什么你需要专业工具&#xff1f; 每天面对数十份PDF文档却…...

计算机毕业设计springboot高校实验室安全巡检系统 基于SpringBoot的高校实验室智能安防监管平台 SpringBoot框架下高校实验楼安全隐患排查与预警系统

计算机毕业设计springboot高校实验室安全巡检系统4p1y5wo9 &#xff08;配套有源码 程序 mysql数据库 论文&#xff09; 本套源码可以在文本联xi,先看具体系统功能演示视频领取&#xff0c;可分享源码参考。随着高等教育规模的持续扩张&#xff0c;高校实验室数量与类型日益增多…...

为什么你的Java车载模块在-40℃冷启动失败?温度敏感型JIT编译失效分析与AOT预编译加固方案(ISO 26262 Part 6实证)

第一章&#xff1a;Java车载系统实时性优化技巧在车载嵌入式环境中&#xff0c;Java虚拟机&#xff08;JVM&#xff09;的默认行为往往难以满足毫秒级响应、确定性调度与低抖动等硬实时需求。尽管Java并非传统实时语言&#xff0c;但通过深度配置与架构约束&#xff0c;可显著提…...

如何7天免费使用Cursor Pro:无限制AI编程助手完整指南

如何7天免费使用Cursor Pro&#xff1a;无限制AI编程助手完整指南 【免费下载链接】cursor-free-vip [Support 0.45]&#xff08;Multi Language 多语言&#xff09;自动注册 Cursor Ai &#xff0c;自动重置机器ID &#xff0c; 免费升级使用Pro 功能: Youve reached your tri…...

OpenClaw多模态扩展:结合百川2-13B-4bits与OCR的图像信息处理流程

OpenClaw多模态扩展&#xff1a;结合百川2-13B-4bits与OCR的图像信息处理流程 1. 为什么需要多模态能力扩展&#xff1f; 上周我需要整理一批技术文档的截图&#xff0c;包含代码片段、错误日志和流程图。手动转录不仅耗时&#xff0c;还容易出错。这让我开始思考&#xff1a…...

Qwen3-Reranker-0.6B快速体验:搭建个人语义排序服务的简单方法

Qwen3-Reranker-0.6B快速体验&#xff1a;搭建个人语义排序服务的简单方法 1. 为什么你需要一个轻量级语义排序服务 在信息检索和问答系统中&#xff0c;语义排序&#xff08;Reranking&#xff09;是一个关键环节。想象一下&#xff0c;当用户输入一个问题后&#xff0c;系统…...