当前位置: 首页 > news >正文

大模型入门 ch 03:注意力机制

本文是github上的大模型教程LLMs-from-scratch的学习笔记,教程地址:教程链接

Chapter 3: Attention Mechanism

本文首先从固定参数的注意力机制说起,然后拓展到可以训练的注意力机制,然后加入掩码mask,最后拓展到多头注意力机制。


1. 注意力机制

一个句子中的每一个token,都会受到其他token的影响(这里先不考虑忽略未来的单词,掩码的问题后面再说),注意力机制可以让一个token收到其他token的影响,生成一个最终我们想要的embedding。即每个token有一个原始的embedding,通过注意力机制后,得到了一个新的embedding,这个embedding是结合了上下文语义得到的。

举个简单的例子,我们直接使用tokens的embedding之间两两点乘,得到互相之间的点乘结果,然后将点乘结果归一化,得到embeddings之间的注意力得分。

归一化一般使用softmax函数,通过取指数,除以求和得到归一化结果
torch.exp(x) / torch.exp(x).sum(dim=0)

得到token之间的相关权重后,我们就可以加权求和,得到每一个token的最终embedding。


2. 可以训练的注意力头

在上面的例子中,我们直接使用token对应的embedding来计算相关系数,以及最终的加权求和,这显然是不合理的,如果这样的话,那么我们只能训练token对应的词嵌入来学习模型,或者是一些全连接层,因此我们需要引入新的矩阵,来学习到更多的参数,这就是transformer的QKV矩阵。
QKV都是对原始的embedding做线性变换,得到新的向量,但是模型就可以通过训练QKV,学习海量知识。

在这里插入图片描述

QKV的维度不固定,可以与原始嵌入相同,也可以不同。总之,通过QKV三个矩阵,我们将原始token的embedding转换成了3个新的向量。

  • Query vector: q ( i ) = W q x ( i ) q^{(i)} = W_q \,x^{(i)} q(i)=Wqx(i)
  • Key vector: k ( i ) = W k x ( i ) k^{(i)} = W_k \,x^{(i)} k(i)=Wkx(i)
  • Value vector: v ( i ) = W v x ( i ) v^{(i)} = W_v \,x^{(i)} v(i)=Wvx(i)

可以使用矩阵乘法实现:

keys = inputs @ W_key 
values = inputs @ W_value

然后我们计算KQ之间的点积,作为两两token之间的关联度。为什么要用两个不一样的矩阵,我的猜测是,如果使用的是一个矩阵计算相似度,那么关于对角线对称的元素就会完全相同,但是使用两个不同的矩阵计算,就不会存在这样的情况,可以学习到的内容更多。

我们使用K和Q的点积得到了两两之间的注意力得分,同样使用softmax进行归一化,得到最终的注意力权重。

注意到没有直接对注意力得分softmax,而是除以维度的方根后再softmax,这是因为在计算注意力权重时,如果直接将Query和Key的点积结果用于softmax函数,当Key的维度较高时,点积的结果会变得非常大。这可能导致softmax函数在梯度下降过程中学习困难,因为大的数值会使softmax的梯度变得非常小(接近于0),这在数值稳定性上是一个问题,称为“梯度消失”。

最后一步,不再使用原始的embedding加权,我们使用V矩阵变换后的向量进行加权求和,得到结果向量。

代码如下

class SelfAttention_v2(nn.Module):def __init__(self, d_in, d_out, qkv_bias=False):super().__init__()self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)def forward(self, x):keys = self.W_key(x)queries = self.W_query(x)values = self.W_value(x)attn_scores = queries @ keys.Tattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)context_vec = attn_weights @ valuesreturn context_vec

3. 隐藏未来的单词

对语言任务来说,在训练模型的时候不能使用未来的文本来预测之前的文本。因此我们需要屏蔽未见文本对先前文本的影响。在我们计算得到注意力权重后,我们人为地将上三角矩阵的权重置为0。
有一种naive的方法,就是将上注意力权重都置为0后,重新对剩下的元素归一化。但是我们要介绍的是一般使用的方法:

我们在计算出注意力得分后,对右上角元素都赋值为负无穷大,负无穷大在经过softmax后就变为了0。

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)

最后为了防止过拟合,一般会使用dropout,对注意力权重矩阵进行随机丢弃,加强模型泛化性能。

总结以上的所有内容,我们现在就可以写出一个单头的注意力机制了,并且加入了对batch输入的处理:

class CausalAttention(nn.Module):def __init__(self, d_in, d_out, context_length,dropout, qkv_bias=False):super().__init__()self.d_out = d_outself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.dropout = nn.Dropout(dropout) # Newself.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # Newdef forward(self, x):b, num_tokens, d_in = x.shape # New batch dimension bkeys = self.W_key(x)queries = self.W_query(x)values = self.W_value(x)attn_scores = queries @ keys.transpose(1, 2) # Changed transposeattn_scores.masked_fill_(  # New, _ ops are in-placeself.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_sizeattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights) # Newcontext_vec = attn_weights @ valuesreturn context_vec

4. 多头注意力机制

我们已经实现了单个头的注意力机制,那么要实现多个头,就是使用多个不同的注意力头,各自对输入进行处理,然后将各自得到的输出$z_i$拼接起来,非常显而易见,我们有第一个最直白的写法:
class MultiHeadAttentionWrapper(nn.Module):def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):super().__init__()self.heads = nn.ModuleList([CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)])def forward(self, x):return torch.cat([head(x) for head in self.heads], dim=-1)

这是一种最简单直白的写法,直接声明num_heads个注意力单元,然后在前向传播的时候,依次调用这num_heads个注意力头,然后将输出拼接起来。(dim=-1代表最后一维拼接)


问题是,这样的话需要循环num_heads次得到结果,并且需要声明num_heads个注意力头,相信熟悉线性代数的朋友已经想到了,可以通过曾广矩阵来拓展注意力头。比如单个的注意力头是(d_in, d_out),那么有n个头的注意力机制就是(d_in, n*d_out)
假设输入是(tokens, d_in),那么(tokens, d_in) @ (d_in, n*d_out) --> (tokens, n*d_out),输出的结果完美得到了n个头对应的输出,我们只需要按照每d_out列拆开,得到n(tokens, d_out)的矩阵,就能还原出n个头对应的结果,进行后续的attention score计算。这样写起来虽然麻烦一些,但是效率更高。

class MultiHeadAttention(nn.Module):def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):super().__init__()assert (d_out % num_heads == 0), \"d_out must be divisible by num_heads"self.d_out = d_outself.num_heads = num_headsself.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dimself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputsself.dropout = nn.Dropout(dropout)self.register_buffer("mask",torch.triu(torch.ones(context_length, context_length),diagonal=1))def forward(self, x):b, num_tokens, d_in = x.shapekeys = self.W_key(x) # Shape: (b, num_tokens, d_out)queries = self.W_query(x)values = self.W_value(x)# We implicitly split the matrix by adding a `num_heads` dimension# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) values = values.view(b, num_tokens, self.num_heads, self.head_dim)queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)keys = keys.transpose(1, 2)queries = queries.transpose(1, 2)values = values.transpose(1, 2)# Compute scaled dot-product attention (aka self-attention) with a causal maskattn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head# Original mask truncated to the number of tokens and converted to booleanmask_bool = self.mask.bool()[:num_tokens, :num_tokens]# Use the mask to fill attention scoresattn_scores.masked_fill_(mask_bool, -torch.inf)attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights)# Shape: (b, num_tokens, num_heads, head_dim)context_vec = (attn_weights @ values).transpose(1, 2) # Combine heads, where self.d_out = self.num_heads * self.head_dimcontext_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)context_vec = self.out_proj(context_vec) # optional projectionreturn context_vec

不同于前者,将不同的注意力头分开计算,第二种方法直接扩展query,key和value矩阵的列数,将多个矩阵运算简化为一个矩阵运算,计算完再更改维度还原成一个个注意力头,效率更高。这样,我们就完成了一个完整的注意力机制。

相关文章:

大模型入门 ch 03:注意力机制

本文是github上的大模型教程LLMs-from-scratch的学习笔记,教程地址:教程链接 Chapter 3: Attention Mechanism 本文首先从固定参数的注意力机制说起,然后拓展到可以训练的注意力机制,然后加入掩码mask,最后…...

STM32点亮第一个LED

还有第二个,并轮换。 准备入门STM32,于是拿出了买到手至少2年的洋桃M1板子,STM32F103C8T6 配置有3个LED,3个按钮,RS232,RS485,CAN,有JTAG,有RTC电池,IO口引…...

[Linux]:动静态库

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:Linux学习 贝蒂的主页:Betty’s blog 1. 动静态库的介绍 一般而言,库分为动态库和静态库。 在Linux当中…...

windows 显示进程地址空间

windows 显示进程地址空间 windows 显示进程地址空间 文章目录 windows 显示进程地址空间显示进程地址空间 显示进程地址空间 /* 3-ProcessInfo.cpp 显示进程地址空间 */#include "..\\CommonFiles\\CmnHdr.h" #include "..\\CommonFiles\\Toolhelp.h"#i…...

Android 12 SystemUI下拉状态栏禁止QuickQSPanel展开

1.概述 遇到需求,QuickQSPanel首次下拉后展示快捷功能模块以后就是显示QuickQSPanel,而不展开QSPanel,接下来要从下滑手势下拉出状态栏分析功能实现。也就是直接是展开状态。 2、涉及核心类 frameworks\base\packages\SystemUI\src\com\and…...

二分思想与相关问题(下)

接下来详细讲解几道比较难的例题,仔细体会二分和其他概念混合在一起的趣味。 下面这道题涉及了“碎片拼接”的概念,很妙,也很难想。 P r o b l e m 5 Problem5 Problem5 同时运行N台电脑的最长时间 LeetCode2141 你有 n 台电脑。给你整数 n…...

【算法专题】搜索算法

二叉树剪枝 LCR 047. 二叉树剪枝 - 力扣(LeetCode) 本题要求我们将全部为0的二叉树去掉,也就是剪枝,当我们举一个具体的例子进行模拟时,会发现,只关注于对其中一个子树的根节点进行剪枝,由于我…...

B2064 斐波那契数列

题目描述 斐波那契数列是指这样的数列:数列的第一个和第二个数都为 11,接下来每个数都等于前面 22 个数之和。 给出一个正整数 aa,要求斐波那契数列中第 aa 个数是多少。 输入格式 第 11 行是测试数据的组数 nn,后面跟着 nn 行…...

Spark的介绍

一、分布式的思想 不管是数据也好,计算也好,都没有最大的电脑,而是多个小电脑组合而成。 存储:将3T的文件拆分成若干个小文件,例如每500M一个小文件,将这些小文件存储在不同的机器上 。 -- HDFS 计算&#…...

SpringBoot项目是如何启动

启动步骤 概念 运行main方法,初始化SpringApplication 从spring.factories读取listener ApplicationContentInitializer运行run方法读取环境变量,配置信息创建SpringApplication上下文预初始化上下文,将启动类作为配置类进行读取调用 refres…...

科技之光,照亮未来之路“2024南京国际人工智能展会”

全球科技产业的版图正以前所未有的速度重构,而位于中国东部沿海经济带的江浙沪地区,作为科技创新与产业升级的高地,始终站在这一浪潮的最前沿。2024年,这一区域的科技盛宴——“2024南京人工智能展会”即将在南京国际博览中心盛大…...

在深度学习计算机视觉的语义分割中,Boundary和Edge的区别是?

在深度学习中的计算机视觉任务中,语义分割中的 Boundary 和 Edge 其实有一些相似之处,但它们的定义和使用场景略有不同。下面是两者的区别: 1. Boundary(边界) 定义:Boundary 是指一个对象或区域的边界&a…...

【JAVA入门】Day41 - 字节缓冲流和字符缓冲流

【JAVA入门】Day41 - 字节缓冲流和字符缓冲流 文章目录 【JAVA入门】Day41 - 字节缓冲流和字符缓冲流一、缓冲流的体系结构二、字节缓冲流2.1 字节缓冲流提高效率的底层原理 三、字符缓冲流 在IO流体系中,FileInputStream,FileOutputStream,F…...

collocate join,bucket join,broadcast join,shuffle join对比分析

在分布式计算和大数据处理中,尤其是在使用像 Apache Spark、Hive 等大数据处理框架时,Join 操作是非常常见的。根据数据分布方式和执行机制,Join 操作可以分为不同的类型,如 Collocate Join、Bucket Join、Broadcast Join 和 Shuffle Join。以下是它们的详细对比分析: 1.…...

微信自动通过好友和自动拉人进群,微加机器人这个功能太好用了

又发现一个好用的功能,之前就想找一个这种工具,现在发现可以利用微加机器人的两个功能来实现,分别是加好友和关键词拉群 首先 微加机器人的专业版 > 功能 > 加好友设置 可以设置一个关键词通过,这样别人加好友的时候只需要输入制定内…...

R语言统计分析——功效分析3(相关、线性模型)

参考资料:R语言实战【第2版】 1、相关性 pwr.r.test()函数可以对相关性分析进行功效分析。格式如下: pwr.r.test(n, r, sig.level, power, alternative) 其中,n是观测数目,r是效应值(通过线性相关系数衡量&#xff0…...

Django创建模型

1、根据创建好应用模块 python manage.py startapp tests 2、在models文件里创建模型 from django.db import modelsfrom book.models import User# Create your models here. class Tests(models.Model):STATUS_CHOICES ((0, 启用),(1, 停用),# 更多状态...)add_time mode…...

盘点2024年大家都在用的短视频剪辑工具

你现在休息的时间是不是都靠短视频来消遣?看着看着你就会发现短视频制作好像我也可以了吧?这次我就介绍一些简单好操作的短视频剪辑工具。 1.FOXIT视频剪辑 连接直达>>https://www.pdf365.cn/foxitclip/ 短视频剪辑其实也不难,只需…...

“左侧文字横向”的QTabWidget

左侧用 QToolButton 组, 右侧用 QStackedWidget,信号槽绑定切换页面 可定制化高 QButtonGroup* btnGp new QButtonGroup(this);btnGp->addButton(ui->btn1, 0);btnGp->addButton(ui->btn2, 1);btnGp->addButton(ui->btn3, 2);connect…...

python学习之字符串操作

str python # 定义一个字符串变量 print(id(str))print(str) # 打印整个字符串 print(str[0:-1]) # 打印字符串第一个到倒数第二个字符(不包含倒数第一个字符) print(str[0]) # 打印字符串的第一个字符 print(str[2:5]) # 打印字符串第三到第…...

DockerHub与私有镜像仓库在容器化中的应用与管理

哈喽,大家好,我是左手python! Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库,用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括:采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中,设置任务排序规则尤其重要,因为它让看板视觉上直观地体…...

基于Docker Compose部署Java微服务项目

一. 创建根项目 根项目&#xff08;父项目&#xff09;主要用于依赖管理 一些需要注意的点&#xff1a; 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件&#xff0c;否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

MySQL用户和授权

开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务&#xff1a; test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

SQL慢可能是触发了ring buffer

简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...

Ubuntu Cursor升级成v1.0

0. 当前版本低 使用当前 Cursor v0.50时 GitHub Copilot Chat 打不开&#xff0c;快捷键也不好用&#xff0c;当看到 Cursor 升级后&#xff0c;还是蛮高兴的 1. 下载 Cursor 下载地址&#xff1a;https://www.cursor.com/cn/downloads 点击下载 Linux (x64) &#xff0c;…...

如何应对敏捷转型中的团队阻力

应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中&#xff0c;明确沟通敏捷转型目的尤为关键&#xff0c;团队成员只有清晰理解转型背后的原因和利益&#xff0c;才能降低对变化的…...

JDK 17 序列化是怎么回事

如何序列化&#xff1f;其实很简单&#xff0c;就是根据每个类型&#xff0c;用工厂类调用。逐个完成。 没什么漂亮的代码&#xff0c;只有有效、稳定的代码。 代码中调用toJson toJson 代码 mapper.writeValueAsString ObjectMapper DefaultSerializerProvider 一堆实…...

车载诊断架构 --- ZEVonUDS(J1979-3)简介第一篇

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 做到欲望极简,了解自己的真实欲望,不受外在潮流的影响,不盲从,不跟风。把自己的精力全部用在自己。一是去掉多余,凡事找规律,基础是诚信;二是…...