带你从入门到精通——自然语言处理(五. Transformer中的自注意力机制和输入部分)
建议先阅读我之前的博客,掌握一定的自然语言处理前置知识后再阅读本文,链接如下:
带你从入门到精通——自然语言处理(一. 文本的基本预处理方法和张量表示)-CSDN博客
带你从入门到精通——自然语言处理(二. 文本数据分析、特征处理和数据增强)-CSDN博客
带你从入门到精通——自然语言处理(三. RNN扩展和LSTM)-CSDN博客
带你从入门到精通——自然语言处理(四. GRU和seq2seq模型)-CSDN博客
目录
五. Transformer中的自注意力机制和输入部分
5.1 自注意力机制
5.2 Transformer整体架构
5.3 输入部分
5.3.1 输入部分整体架构
5.3.2 嵌入层
5.3.2 位置编码器
五. Transformer中的自注意力机制和输入部分
Transformer模型于2017年在Google的论文《Attention is All You Need》中首次被提出,transformer是一种基于自注意力机制(Self-Attention)和seq2seq架构的深度学习模型。
5.1 自注意力机制
传统的注意力机制中的Q、K、V向量三者是不同源的,通常Q向量来自解码器,而K、V向量来自编码器,这种注意力机制被称为一般注意力机制或者交叉注意力机制,而自注意力机制要求Q、K、V向量三者同源,即三者都来自编码器或者解码器。
最早的自注意力机制的引入是应用到LSTM模型中的,LSTM模型没有编码器和解码器的概念,因此Q、K、V向量三者默认是同源的,为了方便这里使用RNN模型代替LSTM模型进行描述,其基本思想是一致的。
首先初始化RNN模型的隐藏状态h0以及上下文向量c0(通常使用全0张量来进行初始化),传统的RNN模型使用隐藏状态h0和当前时间步的输入x来更新隐藏状态,但带有自注意力的RNN模型则使用上下文向量c0和当前时间步的输入x来更新隐藏状态,此后,使用上一个时间步的隐藏状态作为Q向量,此前所有时间步的上下文向量作为K向量,依次计算注意力分数(通常忽略初始的全0上下文向量c0,注意力分数的计算可以使用加性注意力、点积注意力等等),随后对所有注意力分数使用softmax函数进行归一化,并使用归一化后的注意力分数对所有V向量(V向量也为所有时间步的上下文向量,即K向量=V向量)做加权平均得到新的上下文向量,RNN模型使用这一新的上下文向量以及当前时间步的输入继续更新隐藏状态,依次往复。
带有自注意力的RNN模型的架构如下:
5.2 Transformer整体架构
Transformer整体架构图如下:
Transformer模型可以分为四个部分:输入部分、编码器部分、解码器部分、输出部分,后文会详细介绍各个部分。
Transformer模型主要有如下两个优势:
并行计算:与传统的RNN及其变体不同,transformer模型使用自注意力机制并摒弃了序列化的计算过程,允许模型并行处理整个输入序列,有着更高的计算效率和更强的性能。
捕捉长距离依赖:自注意力机制能够直接计算输入序列中任意两个元素之间的关系,从而更好地捕捉长距离依赖,缓解长程依赖问题。
5.3 输入部分
5.3.1 输入部分整体架构
Transfomer输入部分包含:编码器源文本的嵌入层以及位置编码器、解码器目标文本的嵌入层以及位置编码器,即下图部分:
Transformer模型的最终输入为:
上述公式中的input_embedding是指输入文本每个token经过Embedding层后得到的低维稠密词向量,而positional_encoding则是输入文本中每个token的位置编码向量,两个向量有着相同的长度(在原论文中向量长度为512)。
5.3.2 嵌入层
嵌入层(Embedding Layer)的作用是将输入文本中的每个token转换为一个固定长度的低维稠密词向量,便于模型更好地捕捉到词汇的语义信息和语法信息。
嵌入层的代码实现如下:
class MyEmbedding(nn.Module):def __init__(self, vocab_size, embedding_size):super().__init__()self.vocab_size = vocab_sizeself.embedding_size = embedding_sizeself.ebd = nn.Embedding(vocab_size, embedding_size)def forward(self, x):# 扩大embedding后的词向量值return self.ebd(x) * math.sqrt(self.embedding_size)if __name__ == '__main__':ebd = MyEmbedding(5, 3)t = torch.randint(0, 5, (4,))print(ebd(t))'''
tensor([[-0.4648, -0.7602, 1.1441],[ 2.1027, 0.5997, 0.6691],[-0.6455, 0.0878, 2.3561],[-1.0119, 0.5721, -0.9876]], grad_fn=<MulBackward0>)'''
5.3.2 位置编码器
RNN模型是依次输入各个token并进行编码,因此RNN模型能够直接感知输入序列中各个token之间的位置关系,而在transformer模型中,对于输入序列是并行进行编码的,因此它无法直接感知输入序列中各个token的位置关系,所以transformer中引入了位置编码器(Positional Encoding),位置编码器能够为embedding后的词向量引入该词在输入序列中位置信息。
位置编码器能够将各个token在输入序列中的位置信息转换为一组向量,这些向量会与embedding后的词向量相加,在transformer中,位置编码的公式如下:
上式中pos是token在输入序列中的实际位置(例如第1个token为0,第2个token为1,以此类推),i是词向量长度的下标索引,是词向量的长度,transformer中的位置编码方式属于绝对位置编码。
因此pos=t时,该token的位置编码向量可以表示为:
上述表达式中角频率w的取值为:,位置编码向量中的不同下标索引都对应了了一个不同的正余弦波。
Transformer中的位置编码方法有以下三个特点:
1. 每个token的位置编码向量的下标索引越大,其编码值所对应的sin和cos函数的角频率越小,这一特点保证了每个token的位置编码向量唯一。
2. 位置编码向量的值是有界且连续的,这也是正余弦函数的特性,这一特点提高了模型的泛化能力,使模型能够更好处理长度和训练数据不一致的序列。
3. 不同的位置编码向量可以通过线性变换得到,即有:,这里的T表示一个线性变换矩阵,具体的表达式如下:
基于矩阵乘法和如下的三角函数的两角和公式,可以即可推导出上述表达式。
上述表达式中的也被称为旋转矩阵,这一特点使得位置编码向量不仅能表示一个token的绝对位置,还可以表示该token与其他token的相对位置。
位置编码器的代码实现如下
class PositionalEncoding(nn.Module):def __init__(self, embedding_size, dropout_p=0.1, max_len=5000):super().__init__()self.dropout = nn.Dropout(dropout_p)# pe.shape = (max_len, embedding_size)pe = torch.zeros(max_len, embedding_size)# pos,shape = (max_len, 1)pos = torch.arange(0, max_len).unsqueeze(1)# idx.shape = (embedding_size // 2,)idx = torch.arange(0, embedding_size, 2, dtype=torch.float32)# 利用广播机制进行计算pe[:, ::2] = torch.sin(pos / (10000 ** (idx / embedding_size)))pe[:, 1::2] = torch.cos(pos / (10000 ** (idx / embedding_size)))# self.register_buffer用于将一个张量注册为模型的缓冲区(buffer)# 缓冲区中的数据和模型的参数类似,都会被保存到模型的状态字典中# 缓冲区中的数据不被视为可训练的参数,即不会在优化器更新模型参数时被更新。self.register_buffer('pe', pe)def forward(self, x):x = x + self.pe[:x.size(1)]return self.dropout(x)if __name__ == '__main__':# embedding_size必须为偶数ebd = MyEmbedding(5, 8)pe = PositionalEncoding(8)t = torch.randint(0, 5, (2, 4))print(pe(ebd(t)).shape)# torch.Size([2, 4, 8])
相关文章:

带你从入门到精通——自然语言处理(五. Transformer中的自注意力机制和输入部分)
建议先阅读我之前的博客,掌握一定的自然语言处理前置知识后再阅读本文,链接如下: 带你从入门到精通——自然语言处理(一. 文本的基本预处理方法和张量表示)-CSDN博客 带你从入门到精通——自然语言处理(二…...

ubuntu挂载固态硬盘
Ubuntu 中挂载位于 /dev/sdc1 的固态硬盘,可以按照以下步骤操作: 步骤 1:确认分区信息 首先,确保设备 /dev/sdc1 存在且已正确分区: sudo fdisk -l /dev/sdc # 查看分区表 lsblk # 确认分区路…...

WPF+WebView 基础
1、基于.NET8,通过NuGet添加Microsoft.Web.WebView2。 2、MainWindow.xaml代码如下。 <Window x:Class"Demo.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/win…...

国内光子AI智能引擎:OptoChat AI在南京江北新区亮相
3月3日,从南京市投资促进局传来振奋人心的消息,南京江北新区的一家高科技企业——南京南智先进光电集成技术研究院有限公司(简称“南智光电”),携手南京知满科技等合作伙伴,成功研发出国内首个光子AI智能引…...

vscode离线配置远程服务器
目录 一、前提 二、方法 2.1 查看vscode的commit_id 2.2 下载linux服务器安装包 2.3 安装包上传到远程服务器,并进行文件解压缩 三、常见错误 Failed to set up socket for dynamic port forward to remote port(vscode报错解决方法)-C…...
【安装】SQL Server 2005 安装及安装包
安装包 SQLEXPR.EXE:SQL Server 服务SQLServer2005_SSMSEE.msi:数据库管理工具,可以创建数据库,执行脚本等。SQLServer2005_SSMSEE_x64.msi:同上。这个是 64 位操作系统。 下载地址 https://www.microsoft.com/zh-c…...

使用Maven搭建Spring Boot框架
文章目录 前言1.环境准备2.创建SpringBoot项目3.配置Maven3.1 pom.xml文件3.2 添加其他依赖 4. 编写代码4.1 启动类4.2 控制器4.3 配置文件 5.运行项目6.打包与部署6.1 打包6.2 运行JAR文件 7.总结 前言 Spring Boot 是一个用于快速构建 Spring 应用程序的框架,它简…...

将docker容器打包为.tar包
1. 创建打包脚本 #!/bin/bash # 设置 -e 使得脚本在遇到错误时停止执行 set -e# 必要的参数 exported_container_name"needed_export_container_name_or_id" # 需要被导出的容器的名称或id image_save_name"my_custom_image_name:v25.03.03" # 镜像需…...

SYSTEM文件夹下的文件
sys文件夹下的.c和.h文件里的函数 最重要的倒数第二个 deley文件夹下的.c和.h文件 Systick工作原理 系统滴答定时器是在内核里的 每来一个时钟信号,计数器减一 F1系列时钟源是HCLK(就是AHB总线上的时钟信号) Systick控制寄存器 Systick重装…...

GPPT: Graph Pre-training and Prompt Tuning to Generalize Graph Neural Networks
GPPT: Graph Pre-training and Prompt Tuning to Generalize Graph Neural Networks KDD22 推荐指数:#paper/⭐⭐# 动机 本文探讨了图神经网络(GNN)在迁移学习中“预训练-微调”框架的局限性及改进方向。现有方法通过预训练(…...

【SegRNN 源码理解】PMF的多步并行预测
位置编码 elif self.dec_way "pmf":if self.channel_id:# m,d//2 -> 1,m,d//2 -> c,m,d//2# c,d//2 -> c,1,d//2 -> c,m,d//2# c,m,d -> cm,1,d -> bcm, 1, dpos_emb torch.cat([self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),self.cha…...
构建自己的AI客服【根据用户输入生成EL表达式】
要实现一个基于对话形式的AI客服系统,该系统能够提示用户输入必要的信息,并根据用户的输入生成相应的EL(Expression Language)表达式编排规则,您可以按照以下步骤进行设计和开发。本文将涵盖系统架构设计、关键技术选型…...

(50)[HGAME 2023 week2]before_main
[HGAME 2023 week2]before_main nss:3501 我们进入那个sub_12EB然后我们发现这个就是base64加密 我们取得qword_4020: 0CxWsOemvJq4zdk2V6QlArj9wnHbt1NfEX/3DhyPoBRLY8pK5FciZau7UMIgTSG 很显然这个是自定义映射base64.然后我们代入我们之前写的base64自定义映射代码 enc:A…...
机器学习数学基础:39.样本和隐含和残差协方差矩阵
假设我们研究学生的数学成绩、英语成绩和学习时间之间的关系。收集了100名学生这三项数据作为样本。 样本协方差矩阵 计算得到的样本协方差矩阵如下(假设数据简化): [ V a r ( 数学 ) C o v ( 数学 , 英语 ) C o v ( 数学 , 学习时间 ) C …...
java之http传MultipartFile文件
【需求】前端请求后端做文件上传或者excel上传,后端不解析直接把MultipartFile传给第三方平台,通过http的方式该怎么写 import org.springframework.web.multipart.MultipartFile;import java.io.*; import java.net.HttpURLConnection; import java.ne…...
深入解析SpringMVC中Http响应的实现机制
在Web应用开发中,处理HTTP请求并返回相应的HTTP响应是核心任务之一。SpringMVC作为Java生态中广泛使用的Web框架,提供了灵活且强大的机制来处理HTTP请求和生成HTTP响应。本文将深入探讨SpringMVC中如何实现HTTP响应的返回,涵盖从控制器方法的…...
构建一个支持精度、范围和负数的-Vue-数字输入框
分析并实现一个支持精度、范围和负数控制的数字输入框。 背景 在很多业务中,我们经常需要使用数字输入框,通常这些输入框会涉及到数字校验,比如限制输入范围、设置小数精度、是否允许负数等。每次写表单时,都需要重复定义这些校…...

尚硅谷爬虫note14
一、scrapy scrapy:为爬取网站数据是,提取结构性数据而编写的应用框架 1. 安装 pip install scrapy 或者,国内源安装 pip install scrapy -i https://pypi.douban.com/simple 2. 报错 报错1)building ‘twisted.te…...

1438. 绝对差不超过限制的最长连续子数组
目录 一、题目二、思路2.1 解题思路2.2 代码尝试2.3 疑难问题2.4 代码复盘 三、解法四、收获4.1 心得4.2 举一反三 一、题目 二、思路 2.1 解题思路 滑动窗口 2.2 代码尝试 class Solution { public:int longestSubarray(vector<int>& nums, int limit) {int cou…...
ZCC5090EA适用于TYPE-C接口,集成30V OVP功能, 最大1.5A充电电流,带NTC及使能功能,双节锂电升压充电芯片替代CS5090EA
概要: ZCC5090EA是一款5V输入,最大1.5A充电电流,支 持双 节 锂 电 池 串 联 应 用 的 升 压 充 电 管 理 I C 。ZCC5090EA集成功率MOS,采用异步开关架构, 使其在应用时仅需极少的外围器件,可有效减少整体 …...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...

K8S认证|CKS题库+答案| 11. AppArmor
目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...

使用分级同态加密防御梯度泄漏
抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...

什么是库存周转?如何用进销存系统提高库存周转率?
你可能听说过这样一句话: “利润不是赚出来的,是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业,很多企业看着销售不错,账上却没钱、利润也不见了,一翻库存才发现: 一堆卖不动的旧货…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成
厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...
MySQL用户和授权
开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务: test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

中医有效性探讨
文章目录 西医是如何发展到以生物化学为药理基础的现代医学?传统医学奠基期(远古 - 17 世纪)近代医学转型期(17 世纪 - 19 世纪末)现代医学成熟期(20世纪至今) 中医的源远流长和一脉相承远古至…...

MySQL 知识小结(一)
一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库,分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷,但是文件存放起来数据比较冗余,用二进制能够更好管理咱们M…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)
RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发,后来由Pivotal Software Inc.(现为VMware子公司)接管。RabbitMQ 是一个开源的消息代理和队列服务器,用 Erlang 语言编写。广泛应用于各种分布…...
提升移动端网页调试效率:WebDebugX 与常见工具组合实践
在日常移动端开发中,网页调试始终是一个高频但又极具挑战的环节。尤其在面对 iOS 与 Android 的混合技术栈、各种设备差异化行为时,开发者迫切需要一套高效、可靠且跨平台的调试方案。过去,我们或多或少使用过 Chrome DevTools、Remote Debug…...