当前位置: 首页 > news >正文

带你从入门到精通——自然语言处理(五. Transformer中的自注意力机制和输入部分)

建议先阅读我之前的博客,掌握一定的自然语言处理前置知识后再阅读本文,链接如下:

带你从入门到精通——自然语言处理(一. 文本的基本预处理方法和张量表示)-CSDN博客

带你从入门到精通——自然语言处理(二. 文本数据分析、特征处理和数据增强)-CSDN博客

带你从入门到精通——自然语言处理(三. RNN扩展和LSTM)-CSDN博客

带你从入门到精通——自然语言处理(四. GRU和seq2seq模型)-CSDN博客

目录

五. Transformer中的自注意力机制和输入部分

5.1 自注意力机制

5.2 Transformer整体架构

5.3 输入部分

5.3.1 输入部分整体架构

5.3.2 嵌入层

5.3.2 位置编码器


五. Transformer中的自注意力机制和输入部分

        Transformer模型于2017年在Google的论文《Attention is All You Need》中首次被提出,transformer是一种基于自注意力机制(Self-Attention)seq2seq架构的深度学习模型。

5.1 自注意力机制

        传统的注意力机制中的Q、K、V向量三者是不同源的,通常Q向量来自解码器,而K、V向量来自编码器,这种注意力机制被称为一般注意力机制或者交叉注意力机制,而自注意力机制要求Q、K、V向量三者同源,即三者都来自编码器或者解码器。

        最早的自注意力机制的引入是应用到LSTM模型中的,LSTM模型没有编码器和解码器的概念,因此Q、K、V向量三者默认是同源的,为了方便这里使用RNN模型代替LSTM模型进行描述,其基本思想是一致的。

        首先初始化RNN模型的隐藏状态h0以及上下文向量c0(通常使用全0张量来进行初始化),传统的RNN模型使用隐藏状态h0和当前时间步的输入x来更新隐藏状态,但带有自注意力的RNN模型则使用上下文向量c0和当前时间步的输入x来更新隐藏状态,此后,使用上一个时间步的隐藏状态作为Q向量,此前所有时间步的上下文向量作为K向量,依次计算注意力分数(通常忽略初始的全0上下文向量c0,注意力分数的计算可以使用加性注意力、点积注意力等等),随后对所有注意力分数使用softmax函数进行归一化,并使用归一化后的注意力分数对所有V向量(V向量也为所有时间步的上下文向量,即K向量=V向量)做加权平均得到新的上下文向量,RNN模型使用这一新的上下文向量以及当前时间步的输入继续更新隐藏状态,依次往复。

        带有自注意力的RNN模型的架构如下:

5.2 Transformer整体架构

        Transformer整体架构图如下:

        Transformer模型可以分为四个部分:输入部分、编码器部分、解码器部分、输出部分,后文会详细介绍各个部分。

        Transformer模型主要有如下两个优势:

        并行计算:与传统的RNN及其变体不同,transformer模型使用自注意力机制并摒弃了序列化的计算过程,允许模型并行处理整个输入序列,有着更高的计算效率和更强的性能。

        捕捉长距离依赖:自注意力机制能够直接计算输入序列中任意两个元素之间的关系,从而更好地捕捉长距离依赖,缓解长程依赖问题。

5.3 输入部分

5.3.1 输入部分整体架构

        Transfomer输入部分包含:编码器源文本的嵌入层以及位置编码器、解码器目标文本的嵌入层以及位置编码器,即下图部分:

        Transformer模型的最终输入为:

        上述公式中的input_embedding是指输入文本每个token经过Embedding层后得到的低维稠密词向量,而positional_encoding则是输入文本中每个token的位置编码向量,两个向量有着相同的长度(在原论文中向量长度为512)。

5.3.2 嵌入层

        嵌入层(Embedding Layer)的作用是将输入文本中的每个token转换为一个固定长度的低维稠密词向量,便于模型更好地捕捉到词汇的语义信息和语法信息。

        嵌入层的代码实现如下:

class MyEmbedding(nn.Module):def __init__(self, vocab_size, embedding_size):super().__init__()self.vocab_size = vocab_sizeself.embedding_size = embedding_sizeself.ebd = nn.Embedding(vocab_size, embedding_size)def forward(self, x):# 扩大embedding后的词向量值return self.ebd(x) * math.sqrt(self.embedding_size)if __name__ == '__main__':ebd = MyEmbedding(5, 3)t = torch.randint(0, 5, (4,))print(ebd(t))'''
tensor([[-0.4648, -0.7602,  1.1441],[ 2.1027,  0.5997,  0.6691],[-0.6455,  0.0878,  2.3561],[-1.0119,  0.5721, -0.9876]], grad_fn=<MulBackward0>)'''

5.3.2 位置编码器

        RNN模型是依次输入各个token并进行编码,因此RNN模型能够直接感知输入序列中各个token之间的位置关系,而在transformer模型中,对于输入序列是并行进行编码的,因此它无法直接感知输入序列中各个token的位置关系,所以transformer中引入了位置编码器(Positional Encoding),位置编码器能够为embedding后的词向量引入该词在输入序列中位置信息。

        位置编码器能够将各个token在输入序列中的位置信息转换为一组向量,这些向量会与embedding后的词向量相加,在transformer中,位置编码的公式如下:

        上式中pos是token在输入序列中的实际位置(例如第1个token为0,第2个token为1,以此类推),i是词向量长度的下标索引,是词向量的长度,transformer中的位置编码方式属于绝对位置编码。

        因此pos=t时,该token的位置编码向量可以表示为:

        上述表达式中角频率w的取值为:,位置编码向量中的不同下标索引都对应了了一个不同的正余弦波。

        Transformer中的位置编码方法有以下三个特点:

        1. 每个token的位置编码向量的下标索引越大,其编码值所对应的sin和cos函数的角频率越小,这一特点保证了每个token的位置编码向量唯一。

        2. 位置编码向量的值是有界且连续的,这也是正余弦函数的特性,这一特点提高了模型的泛化能力,使模型能够更好处理长度和训练数据不一致的序列。

        3. 不同的位置编码向量可以通过线性变换得到,即有:,这里的T表示一个线性变换矩阵,具体的表达式如下:

        基于矩阵乘法和如下的三角函数的两角和公式,可以即可推导出上述表达式。

        上述表达式中的也被称为旋转矩阵这一特点使得位置编码向量不仅能表示一个token的绝对位置,还可以表示该token与其他token的相对位置。

        位置编码器的代码实现如下

class PositionalEncoding(nn.Module):def __init__(self, embedding_size, dropout_p=0.1, max_len=5000):super().__init__()self.dropout = nn.Dropout(dropout_p)# pe.shape = (max_len, embedding_size)pe = torch.zeros(max_len, embedding_size)# pos,shape = (max_len, 1)pos = torch.arange(0, max_len).unsqueeze(1)# idx.shape = (embedding_size // 2,)idx = torch.arange(0, embedding_size, 2, dtype=torch.float32)# 利用广播机制进行计算pe[:, ::2] = torch.sin(pos / (10000 ** (idx / embedding_size)))pe[:, 1::2] = torch.cos(pos / (10000 ** (idx / embedding_size)))# self.register_buffer用于将一个张量注册为模型的缓冲区(buffer)# 缓冲区中的数据和模型的参数类似,都会被保存到模型的状态字典中# 缓冲区中的数据不被视为可训练的参数,即不会在优化器更新模型参数时被更新。self.register_buffer('pe', pe)def forward(self, x):x = x + self.pe[:x.size(1)]return self.dropout(x)if __name__ == '__main__':# embedding_size必须为偶数ebd = MyEmbedding(5, 8)pe = PositionalEncoding(8)t = torch.randint(0, 5, (2, 4))print(pe(ebd(t)).shape)# torch.Size([2, 4, 8])

相关文章:

带你从入门到精通——自然语言处理(五. Transformer中的自注意力机制和输入部分)

建议先阅读我之前的博客&#xff0c;掌握一定的自然语言处理前置知识后再阅读本文&#xff0c;链接如下&#xff1a; 带你从入门到精通——自然语言处理&#xff08;一. 文本的基本预处理方法和张量表示&#xff09;-CSDN博客 带你从入门到精通——自然语言处理&#xff08;二…...

ubuntu挂载固态硬盘

Ubuntu 中挂载位于 /dev/sdc1 的固态硬盘&#xff0c;可以按照以下步骤操作&#xff1a; 步骤 1&#xff1a;确认分区信息 首先&#xff0c;确保设备 /dev/sdc1 存在且已正确分区&#xff1a; sudo fdisk -l /dev/sdc # 查看分区表 lsblk # 确认分区路…...

WPF+WebView 基础

1、基于.NET8&#xff0c;通过NuGet添加Microsoft.Web.WebView2。 2、MainWindow.xaml代码如下。 <Window x:Class"Demo.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/win…...

国内光子AI智能引擎:OptoChat AI在南京江北新区亮相

3月3日&#xff0c;从南京市投资促进局传来振奋人心的消息&#xff0c;南京江北新区的一家高科技企业——南京南智先进光电集成技术研究院有限公司&#xff08;简称“南智光电”&#xff09;&#xff0c;携手南京知满科技等合作伙伴&#xff0c;成功研发出国内首个光子AI智能引…...

vscode离线配置远程服务器

目录 一、前提 二、方法 2.1 查看vscode的commit_id 2.2 下载linux服务器安装包 2.3 安装包上传到远程服务器&#xff0c;并进行文件解压缩 三、常见错误 Failed to set up socket for dynamic port forward to remote port&#xff08;vscode报错解决方法&#xff09;-C…...

【安装】SQL Server 2005 安装及安装包

安装包 SQLEXPR.EXE&#xff1a;SQL Server 服务SQLServer2005_SSMSEE.msi&#xff1a;数据库管理工具&#xff0c;可以创建数据库&#xff0c;执行脚本等。SQLServer2005_SSMSEE_x64.msi&#xff1a;同上。这个是 64 位操作系统。 下载地址 https://www.microsoft.com/zh-c…...

使用Maven搭建Spring Boot框架

文章目录 前言1.环境准备2.创建SpringBoot项目3.配置Maven3.1 pom.xml文件3.2 添加其他依赖 4. 编写代码4.1 启动类4.2 控制器4.3 配置文件 5.运行项目6.打包与部署6.1 打包6.2 运行JAR文件 7.总结 前言 Spring Boot 是一个用于快速构建 Spring 应用程序的框架&#xff0c;它简…...

将docker容器打包为.tar包

1. 创建打包脚本 #!/bin/bash # 设置 -e 使得脚本在遇到错误时停止执行 set -e# 必要的参数 exported_container_name"needed_export_container_name_or_id" # 需要被导出的容器的名称或id image_save_name"my_custom_image_name:v25.03.03" # 镜像需…...

SYSTEM文件夹下的文件

sys文件夹下的.c和.h文件里的函数 最重要的倒数第二个 deley文件夹下的.c和.h文件 Systick工作原理 系统滴答定时器是在内核里的 每来一个时钟信号&#xff0c;计数器减一 F1系列时钟源是HCLK&#xff08;就是AHB总线上的时钟信号&#xff09; Systick控制寄存器 Systick重装…...

GPPT: Graph Pre-training and Prompt Tuning to Generalize Graph Neural Networks

GPPT: Graph Pre-training and Prompt Tuning to Generalize Graph Neural Networks KDD22 推荐指数&#xff1a;#paper/⭐⭐#​ 动机 本文探讨了图神经网络&#xff08;GNN&#xff09;在迁移学习中“预训练-微调”框架的局限性及改进方向。现有方法通过预训练&#xff08…...

【SegRNN 源码理解】PMF的多步并行预测

位置编码 elif self.dec_way "pmf":if self.channel_id:# m,d//2 -> 1,m,d//2 -> c,m,d//2# c,d//2 -> c,1,d//2 -> c,m,d//2# c,m,d -> cm,1,d -> bcm, 1, dpos_emb torch.cat([self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),self.cha…...

构建自己的AI客服【根据用户输入生成EL表达式】

要实现一个基于对话形式的AI客服系统&#xff0c;该系统能够提示用户输入必要的信息&#xff0c;并根据用户的输入生成相应的EL&#xff08;Expression Language&#xff09;表达式编排规则&#xff0c;您可以按照以下步骤进行设计和开发。本文将涵盖系统架构设计、关键技术选型…...

(50)[HGAME 2023 week2]before_main

[HGAME 2023 week2]before_main nss:3501 我们进入那个sub_12EB然后我们发现这个就是base64加密 我们取得qword_4020: 0CxWsOemvJq4zdk2V6QlArj9wnHbt1NfEX/3DhyPoBRLY8pK5FciZau7UMIgTSG 很显然这个是自定义映射base64.然后我们代入我们之前写的base64自定义映射代码 enc:A…...

机器学习数学基础:39.样本和隐含和残差协方差矩阵

假设我们研究学生的数学成绩、英语成绩和学习时间之间的关系。收集了100名学生这三项数据作为样本。 样本协方差矩阵 计算得到的样本协方差矩阵如下&#xff08;假设数据简化&#xff09;&#xff1a; [ V a r ( 数学 ) C o v ( 数学 , 英语 ) C o v ( 数学 , 学习时间 ) C …...

java之http传MultipartFile文件

【需求】前端请求后端做文件上传或者excel上传&#xff0c;后端不解析直接把MultipartFile传给第三方平台&#xff0c;通过http的方式该怎么写 import org.springframework.web.multipart.MultipartFile;import java.io.*; import java.net.HttpURLConnection; import java.ne…...

深入解析SpringMVC中Http响应的实现机制

在Web应用开发中&#xff0c;处理HTTP请求并返回相应的HTTP响应是核心任务之一。SpringMVC作为Java生态中广泛使用的Web框架&#xff0c;提供了灵活且强大的机制来处理HTTP请求和生成HTTP响应。本文将深入探讨SpringMVC中如何实现HTTP响应的返回&#xff0c;涵盖从控制器方法的…...

构建一个支持精度、范围和负数的-Vue-数字输入框

分析并实现一个支持精度、范围和负数控制的数字输入框。 背景 在很多业务中&#xff0c;我们经常需要使用数字输入框&#xff0c;通常这些输入框会涉及到数字校验&#xff0c;比如限制输入范围、设置小数精度、是否允许负数等。每次写表单时&#xff0c;都需要重复定义这些校…...

尚硅谷爬虫note14

一、scrapy scrapy&#xff1a;为爬取网站数据是&#xff0c;提取结构性数据而编写的应用框架 1. 安装 pip install scrapy 或者&#xff0c;国内源安装 pip install scrapy -i https&#xff1a;//pypi.douban.com/simple 2. 报错 报错1&#xff09;building ‘twisted.te…...

1438. 绝对差不超过限制的最长连续子数组

目录 一、题目二、思路2.1 解题思路2.2 代码尝试2.3 疑难问题2.4 代码复盘 三、解法四、收获4.1 心得4.2 举一反三 一、题目 二、思路 2.1 解题思路 滑动窗口 2.2 代码尝试 class Solution { public:int longestSubarray(vector<int>& nums, int limit) {int cou…...

ZCC5090EA适用于TYPE-C接口,集成30V OVP功能, 最大1.5A充电电流,带NTC及使能功能,双节锂电升压充电芯片替代CS5090EA

概要&#xff1a; ZCC5090EA是一款5V输入&#xff0c;最大1.5A充电电流&#xff0c;支 持双 节 锂 电 池 串 联 应 用 的 升 压 充 电 管 理 I C 。ZCC5090EA集成功率MOS&#xff0c;采用异步开关架构&#xff0c; 使其在应用时仅需极少的外围器件&#xff0c;可有效减少整体 …...

python打卡day49

知识点回顾&#xff1a; 通道注意力模块复习空间注意力模块CBAM的定义 作业&#xff1a;尝试对今天的模型检查参数数目&#xff0c;并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...

Debian系统简介

目录 Debian系统介绍 Debian版本介绍 Debian软件源介绍 软件包管理工具dpkg dpkg核心指令详解 安装软件包 卸载软件包 查询软件包状态 验证软件包完整性 手动处理依赖关系 dpkg vs apt Debian系统介绍 Debian 和 Ubuntu 都是基于 Debian内核 的 Linux 发行版&#xff…...

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?

论文网址&#xff1a;pdf 英文是纯手打的&#xff01;论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误&#xff0c;若有发现欢迎评论指正&#xff01;文章偏向于笔记&#xff0c;谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话&#xff1a; “利润不是赚出来的&#xff0c;是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业&#xff0c;很多企业看着销售不错&#xff0c;账上却没钱、利润也不见了&#xff0c;一翻库存才发现&#xff1a; 一堆卖不动的旧货…...

DBAPI如何优雅的获取单条数据

API如何优雅的获取单条数据 案例一 对于查询类API&#xff0c;查询的是单条数据&#xff0c;比如根据主键ID查询用户信息&#xff0c;sql如下&#xff1a; select id, name, age from user where id #{id}API默认返回的数据格式是多条的&#xff0c;如下&#xff1a; {&qu…...

C++ 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

基于matlab策略迭代和值迭代法的动态规划

经典的基于策略迭代和值迭代法的动态规划matlab代码&#xff0c;实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...

Pinocchio 库详解及其在足式机器人上的应用

Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库&#xff0c;专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性&#xff0c;并提供了一个通用的框架&…...

Linux C语言网络编程详细入门教程:如何一步步实现TCP服务端与客户端通信

文章目录 Linux C语言网络编程详细入门教程&#xff1a;如何一步步实现TCP服务端与客户端通信前言一、网络通信基础概念二、服务端与客户端的完整流程图解三、每一步的详细讲解和代码示例1. 创建Socket&#xff08;服务端和客户端都要&#xff09;2. 绑定本地地址和端口&#x…...

Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)

Aspose.PDF 限制绕过方案&#xff1a;Java 字节码技术实战分享&#xff08;仅供学习&#xff09; 一、Aspose.PDF 简介二、说明&#xff08;⚠️仅供学习与研究使用&#xff09;三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...