当前位置: 首页 > news >正文

大模型入门 ch 03:注意力机制

本文是github上的大模型教程LLMs-from-scratch的学习笔记,教程地址:教程链接

Chapter 3: Attention Mechanism

本文首先从固定参数的注意力机制说起,然后拓展到可以训练的注意力机制,然后加入掩码mask,最后拓展到多头注意力机制。


1. 注意力机制

一个句子中的每一个token,都会受到其他token的影响(这里先不考虑忽略未来的单词,掩码的问题后面再说),注意力机制可以让一个token收到其他token的影响,生成一个最终我们想要的embedding。即每个token有一个原始的embedding,通过注意力机制后,得到了一个新的embedding,这个embedding是结合了上下文语义得到的。

举个简单的例子,我们直接使用tokens的embedding之间两两点乘,得到互相之间的点乘结果,然后将点乘结果归一化,得到embeddings之间的注意力得分。

归一化一般使用softmax函数,通过取指数,除以求和得到归一化结果
torch.exp(x) / torch.exp(x).sum(dim=0)

得到token之间的相关权重后,我们就可以加权求和,得到每一个token的最终embedding。


2. 可以训练的注意力头

在上面的例子中,我们直接使用token对应的embedding来计算相关系数,以及最终的加权求和,这显然是不合理的,如果这样的话,那么我们只能训练token对应的词嵌入来学习模型,或者是一些全连接层,因此我们需要引入新的矩阵,来学习到更多的参数,这就是transformer的QKV矩阵。
QKV都是对原始的embedding做线性变换,得到新的向量,但是模型就可以通过训练QKV,学习海量知识。

在这里插入图片描述

QKV的维度不固定,可以与原始嵌入相同,也可以不同。总之,通过QKV三个矩阵,我们将原始token的embedding转换成了3个新的向量。

  • Query vector: q ( i ) = W q x ( i ) q^{(i)} = W_q \,x^{(i)} q(i)=Wqx(i)
  • Key vector: k ( i ) = W k x ( i ) k^{(i)} = W_k \,x^{(i)} k(i)=Wkx(i)
  • Value vector: v ( i ) = W v x ( i ) v^{(i)} = W_v \,x^{(i)} v(i)=Wvx(i)

可以使用矩阵乘法实现:

keys = inputs @ W_key 
values = inputs @ W_value

然后我们计算KQ之间的点积,作为两两token之间的关联度。为什么要用两个不一样的矩阵,我的猜测是,如果使用的是一个矩阵计算相似度,那么关于对角线对称的元素就会完全相同,但是使用两个不同的矩阵计算,就不会存在这样的情况,可以学习到的内容更多。

我们使用K和Q的点积得到了两两之间的注意力得分,同样使用softmax进行归一化,得到最终的注意力权重。

注意到没有直接对注意力得分softmax,而是除以维度的方根后再softmax,这是因为在计算注意力权重时,如果直接将Query和Key的点积结果用于softmax函数,当Key的维度较高时,点积的结果会变得非常大。这可能导致softmax函数在梯度下降过程中学习困难,因为大的数值会使softmax的梯度变得非常小(接近于0),这在数值稳定性上是一个问题,称为“梯度消失”。

最后一步,不再使用原始的embedding加权,我们使用V矩阵变换后的向量进行加权求和,得到结果向量。

代码如下

class SelfAttention_v2(nn.Module):def __init__(self, d_in, d_out, qkv_bias=False):super().__init__()self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)def forward(self, x):keys = self.W_key(x)queries = self.W_query(x)values = self.W_value(x)attn_scores = queries @ keys.Tattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)context_vec = attn_weights @ valuesreturn context_vec

3. 隐藏未来的单词

对语言任务来说,在训练模型的时候不能使用未来的文本来预测之前的文本。因此我们需要屏蔽未见文本对先前文本的影响。在我们计算得到注意力权重后,我们人为地将上三角矩阵的权重置为0。
有一种naive的方法,就是将上注意力权重都置为0后,重新对剩下的元素归一化。但是我们要介绍的是一般使用的方法:

我们在计算出注意力得分后,对右上角元素都赋值为负无穷大,负无穷大在经过softmax后就变为了0。

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)

最后为了防止过拟合,一般会使用dropout,对注意力权重矩阵进行随机丢弃,加强模型泛化性能。

总结以上的所有内容,我们现在就可以写出一个单头的注意力机制了,并且加入了对batch输入的处理:

class CausalAttention(nn.Module):def __init__(self, d_in, d_out, context_length,dropout, qkv_bias=False):super().__init__()self.d_out = d_outself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.dropout = nn.Dropout(dropout) # Newself.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # Newdef forward(self, x):b, num_tokens, d_in = x.shape # New batch dimension bkeys = self.W_key(x)queries = self.W_query(x)values = self.W_value(x)attn_scores = queries @ keys.transpose(1, 2) # Changed transposeattn_scores.masked_fill_(  # New, _ ops are in-placeself.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_sizeattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights) # Newcontext_vec = attn_weights @ valuesreturn context_vec

4. 多头注意力机制

我们已经实现了单个头的注意力机制,那么要实现多个头,就是使用多个不同的注意力头,各自对输入进行处理,然后将各自得到的输出$z_i$拼接起来,非常显而易见,我们有第一个最直白的写法:
class MultiHeadAttentionWrapper(nn.Module):def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):super().__init__()self.heads = nn.ModuleList([CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)])def forward(self, x):return torch.cat([head(x) for head in self.heads], dim=-1)

这是一种最简单直白的写法,直接声明num_heads个注意力单元,然后在前向传播的时候,依次调用这num_heads个注意力头,然后将输出拼接起来。(dim=-1代表最后一维拼接)


问题是,这样的话需要循环num_heads次得到结果,并且需要声明num_heads个注意力头,相信熟悉线性代数的朋友已经想到了,可以通过曾广矩阵来拓展注意力头。比如单个的注意力头是(d_in, d_out),那么有n个头的注意力机制就是(d_in, n*d_out)
假设输入是(tokens, d_in),那么(tokens, d_in) @ (d_in, n*d_out) --> (tokens, n*d_out),输出的结果完美得到了n个头对应的输出,我们只需要按照每d_out列拆开,得到n(tokens, d_out)的矩阵,就能还原出n个头对应的结果,进行后续的attention score计算。这样写起来虽然麻烦一些,但是效率更高。

class MultiHeadAttention(nn.Module):def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):super().__init__()assert (d_out % num_heads == 0), \"d_out must be divisible by num_heads"self.d_out = d_outself.num_heads = num_headsself.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dimself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputsself.dropout = nn.Dropout(dropout)self.register_buffer("mask",torch.triu(torch.ones(context_length, context_length),diagonal=1))def forward(self, x):b, num_tokens, d_in = x.shapekeys = self.W_key(x) # Shape: (b, num_tokens, d_out)queries = self.W_query(x)values = self.W_value(x)# We implicitly split the matrix by adding a `num_heads` dimension# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) values = values.view(b, num_tokens, self.num_heads, self.head_dim)queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)keys = keys.transpose(1, 2)queries = queries.transpose(1, 2)values = values.transpose(1, 2)# Compute scaled dot-product attention (aka self-attention) with a causal maskattn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head# Original mask truncated to the number of tokens and converted to booleanmask_bool = self.mask.bool()[:num_tokens, :num_tokens]# Use the mask to fill attention scoresattn_scores.masked_fill_(mask_bool, -torch.inf)attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights)# Shape: (b, num_tokens, num_heads, head_dim)context_vec = (attn_weights @ values).transpose(1, 2) # Combine heads, where self.d_out = self.num_heads * self.head_dimcontext_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)context_vec = self.out_proj(context_vec) # optional projectionreturn context_vec

不同于前者,将不同的注意力头分开计算,第二种方法直接扩展query,key和value矩阵的列数,将多个矩阵运算简化为一个矩阵运算,计算完再更改维度还原成一个个注意力头,效率更高。这样,我们就完成了一个完整的注意力机制。

相关文章:

大模型入门 ch 03:注意力机制

本文是github上的大模型教程LLMs-from-scratch的学习笔记,教程地址:教程链接 Chapter 3: Attention Mechanism 本文首先从固定参数的注意力机制说起,然后拓展到可以训练的注意力机制,然后加入掩码mask,最后…...

STM32点亮第一个LED

还有第二个,并轮换。 准备入门STM32,于是拿出了买到手至少2年的洋桃M1板子,STM32F103C8T6 配置有3个LED,3个按钮,RS232,RS485,CAN,有JTAG,有RTC电池,IO口引…...

[Linux]:动静态库

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:Linux学习 贝蒂的主页:Betty’s blog 1. 动静态库的介绍 一般而言,库分为动态库和静态库。 在Linux当中…...

windows 显示进程地址空间

windows 显示进程地址空间 windows 显示进程地址空间 文章目录 windows 显示进程地址空间显示进程地址空间 显示进程地址空间 /* 3-ProcessInfo.cpp 显示进程地址空间 */#include "..\\CommonFiles\\CmnHdr.h" #include "..\\CommonFiles\\Toolhelp.h"#i…...

Android 12 SystemUI下拉状态栏禁止QuickQSPanel展开

1.概述 遇到需求,QuickQSPanel首次下拉后展示快捷功能模块以后就是显示QuickQSPanel,而不展开QSPanel,接下来要从下滑手势下拉出状态栏分析功能实现。也就是直接是展开状态。 2、涉及核心类 frameworks\base\packages\SystemUI\src\com\and…...

二分思想与相关问题(下)

接下来详细讲解几道比较难的例题,仔细体会二分和其他概念混合在一起的趣味。 下面这道题涉及了“碎片拼接”的概念,很妙,也很难想。 P r o b l e m 5 Problem5 Problem5 同时运行N台电脑的最长时间 LeetCode2141 你有 n 台电脑。给你整数 n…...

【算法专题】搜索算法

二叉树剪枝 LCR 047. 二叉树剪枝 - 力扣(LeetCode) 本题要求我们将全部为0的二叉树去掉,也就是剪枝,当我们举一个具体的例子进行模拟时,会发现,只关注于对其中一个子树的根节点进行剪枝,由于我…...

B2064 斐波那契数列

题目描述 斐波那契数列是指这样的数列:数列的第一个和第二个数都为 11,接下来每个数都等于前面 22 个数之和。 给出一个正整数 aa,要求斐波那契数列中第 aa 个数是多少。 输入格式 第 11 行是测试数据的组数 nn,后面跟着 nn 行…...

Spark的介绍

一、分布式的思想 不管是数据也好,计算也好,都没有最大的电脑,而是多个小电脑组合而成。 存储:将3T的文件拆分成若干个小文件,例如每500M一个小文件,将这些小文件存储在不同的机器上 。 -- HDFS 计算&#…...

SpringBoot项目是如何启动

启动步骤 概念 运行main方法,初始化SpringApplication 从spring.factories读取listener ApplicationContentInitializer运行run方法读取环境变量,配置信息创建SpringApplication上下文预初始化上下文,将启动类作为配置类进行读取调用 refres…...

科技之光,照亮未来之路“2024南京国际人工智能展会”

全球科技产业的版图正以前所未有的速度重构,而位于中国东部沿海经济带的江浙沪地区,作为科技创新与产业升级的高地,始终站在这一浪潮的最前沿。2024年,这一区域的科技盛宴——“2024南京人工智能展会”即将在南京国际博览中心盛大…...

在深度学习计算机视觉的语义分割中,Boundary和Edge的区别是?

在深度学习中的计算机视觉任务中,语义分割中的 Boundary 和 Edge 其实有一些相似之处,但它们的定义和使用场景略有不同。下面是两者的区别: 1. Boundary(边界) 定义:Boundary 是指一个对象或区域的边界&a…...

【JAVA入门】Day41 - 字节缓冲流和字符缓冲流

【JAVA入门】Day41 - 字节缓冲流和字符缓冲流 文章目录 【JAVA入门】Day41 - 字节缓冲流和字符缓冲流一、缓冲流的体系结构二、字节缓冲流2.1 字节缓冲流提高效率的底层原理 三、字符缓冲流 在IO流体系中,FileInputStream,FileOutputStream,F…...

collocate join,bucket join,broadcast join,shuffle join对比分析

在分布式计算和大数据处理中,尤其是在使用像 Apache Spark、Hive 等大数据处理框架时,Join 操作是非常常见的。根据数据分布方式和执行机制,Join 操作可以分为不同的类型,如 Collocate Join、Bucket Join、Broadcast Join 和 Shuffle Join。以下是它们的详细对比分析: 1.…...

微信自动通过好友和自动拉人进群,微加机器人这个功能太好用了

又发现一个好用的功能,之前就想找一个这种工具,现在发现可以利用微加机器人的两个功能来实现,分别是加好友和关键词拉群 首先 微加机器人的专业版 > 功能 > 加好友设置 可以设置一个关键词通过,这样别人加好友的时候只需要输入制定内…...

R语言统计分析——功效分析3(相关、线性模型)

参考资料:R语言实战【第2版】 1、相关性 pwr.r.test()函数可以对相关性分析进行功效分析。格式如下: pwr.r.test(n, r, sig.level, power, alternative) 其中,n是观测数目,r是效应值(通过线性相关系数衡量&#xff0…...

Django创建模型

1、根据创建好应用模块 python manage.py startapp tests 2、在models文件里创建模型 from django.db import modelsfrom book.models import User# Create your models here. class Tests(models.Model):STATUS_CHOICES ((0, 启用),(1, 停用),# 更多状态...)add_time mode…...

盘点2024年大家都在用的短视频剪辑工具

你现在休息的时间是不是都靠短视频来消遣?看着看着你就会发现短视频制作好像我也可以了吧?这次我就介绍一些简单好操作的短视频剪辑工具。 1.FOXIT视频剪辑 连接直达>>https://www.pdf365.cn/foxitclip/ 短视频剪辑其实也不难,只需…...

“左侧文字横向”的QTabWidget

左侧用 QToolButton 组, 右侧用 QStackedWidget,信号槽绑定切换页面 可定制化高 QButtonGroup* btnGp new QButtonGroup(this);btnGp->addButton(ui->btn1, 0);btnGp->addButton(ui->btn2, 1);btnGp->addButton(ui->btn3, 2);connect…...

python学习之字符串操作

str python # 定义一个字符串变量 print(id(str))print(str) # 打印整个字符串 print(str[0:-1]) # 打印字符串第一个到倒数第二个字符(不包含倒数第一个字符) print(str[0]) # 打印字符串的第一个字符 print(str[2:5]) # 打印字符串第三到第…...

第7篇:【系统分析师】计算机网络

考点汇总 考点详情 1网络模型和协议:OSI/RM七层模型,网络标准和协议,TCP/IP协议族,端口 七层:应用层,表示层,会话层,传输层,网络层,数据链路层,…...

无人机培训机构组装调试技术详解

一、基础知识学习 在进入无人机组装调试领域之前,扎实的基础知识是不可或缺的。学员需掌握以下内容: 1. 无人机基本原理:了解无人机的飞行原理,包括升力、推力、重力和阻力等基本物理概念,以及无人机的飞行控制系统&…...

‌汽车的舒适进入功能是什么意思?

移动管家汽车的舒适进入系统是指无钥匙进入功能,它允许驾驶者在距离车辆一定范围内自动感应解锁车辆,并具备无钥匙启动功能‌。舒适进入系统的核心优势包括: ‌智能化操作‌:无需传统钥匙,通过智能感应实现车门解锁和…...

杂七杂八-系统环境安装

杂七杂八-系统&环境安装 1. 系统安装2. 环境安装 仅个人笔记使用,后续会根据自己遇到问题记录,感谢点赞关注 1. 系统安装 Windows安装linux子系统WSL2:使用windows系统跑linux程序(大模型)WSL VSCode:VSCode连接WSL实现高效…...

Redis高可用,Redis性能管理

文章目录 一,Redis高可用,Redis性能管理二,Redis持久化1.RDB持久化1.1触发条件(1)手动触发(2)自动触发 1.2 Redis 的 RDB 持久化配置1.3 RDB执行流程(1) 判断是否有其他持久化操作在执行(2) 父进…...

React项目中使用发布订阅模式

React项目中使用发布订阅模式 1.创建发布订阅器2.在组件中使用发布订阅器3. 订阅数据 发布订阅模式(也称观察者模式)是一种管理跨组件通信的有效方式,尤其是在不希望直接依赖于特定组件的情况下。这种模式允许一个对象(发布者&…...

buck boost Ldo 经典模型的默写

BUCK: BOOST: LDO: BUCK-BOOST:...

velero v1.14.1迁移kubernetes集群

1 概述 velero是vmware开源的一个备份和恢复工具,可作用于kubernetes集群下的任意对象和应用数据(PV上的数据)。github地址是https://github.com/vmware-tanzu/velero。 对于应用数据,可分文件级别的复制和块级别的复制。文件级…...

Qt Model/View之Model

在检查如何处理选择之前,您可能会发现检查模型/视图框架中使用的概念很有用。 基本概念 在模型/视图架构中,模型提供了一个标准接口,用于视图和委托访问数据。在Qt中,标准接口由QAbstractItemModel类定义。无论数据项如何存储在…...

如何在 Vue 3 中使用 Element Plus

在 Vue 3 中使用 Element Plus 是一个相对直接的过程,因为 Element Plus 是为 Vue 3 设计的 UI 组件库。以下是在 Vue 3 项目中集成和使用 Element Plus 的基本步骤: 1. 安装 Element Plus 首先,你需要在你的 Vue 3 项目中安装 Element Plu…...