当前位置: 首页 > news >正文

【Transformer】手撕Attention

import torch  
from torch import nn   
import torch.functional as F  
import mathX = torch.randn(16,64,512) # B,T,Dd_model = 512 # 模型的维度  
n_head = 8 # 注意力头的数量

多头注意力机制

在这里插入图片描述

class multi_head_attention(nn.Module):  def __init__(self, d_model, n_head):  # 调用父类构造函数  super(multi_head_attention, self).__init__()  # 保存注意力头的数量和模型的维度  self.n_head = n_head  self.d_model = d_model  # 定义查询(Q)、键(K)、值(V)的线性变换层  self.w_q = nn.Linear(d_model, d_model)  # 输入d_model维度,输出d_model维度  self.w_k = nn.Linear(d_model, d_model)  # 输入d_model维度,输出d_model维度  self.w_v = nn.Linear(d_model, d_model)  # 输入d_model维度,输出d_model维度  self.w_o = nn.Linear(d_model, d_model)  # 输出线性变换层,用来做一个线形缩放  # 定义softmax函数,用于计算注意力得分的归一化  self.softmax = nn.Softmax(dim=-1)  # softmax会在最后一维(dim=-1)上操作  def forward(self, q, k, v):  # 获取输入查询(q),键(k),值(v)的形状  B, T, D = q.shape  # B: batch size, T: sequence length, D: feature dimension (d_model)  # 每个注意力头的维度  n_d = self.d_model // self.n_head  # 每个头的维度(d_model / n_head)  # 将输入的q、k、v通过各自的线性变换层映射到新的空间  q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)  # 将q, k, v 按头数进行拆分(reshape),并转置使得各头的计算可以并行  # q, k, v的形状变为 (B, T, n_head, n_d),然后转置变为 (B, n_head, T, n_d)        q = q.view(B, T, self.n_head, n_d).transpose(1, 2)  # (B, n_head, T, n_d)  k = k.view(B, T, self.n_head, n_d).transpose(1, 2)  # (B, n_head, T, n_d)  v = v.view(B, T, self.n_head, n_d).transpose(1, 2)  # (B, n_head, T, n_d)  # 计算缩放点积注意力(scaled dot-product attention)  score = q @ k.transpose(2, 3) / math.sqrt(n_d)  # (B, n_head, T, T)  # score是查询q与键k之间的相似度矩阵,进行缩放以防止数值过大  # 生成一个下三角矩阵,用于实现自注意力中的"masking",屏蔽未来的信息  mask = torch.tril(torch.ones(T, T, dtype=bool))  # 生成一个下三角的布尔矩阵  # 使用mask进行屏蔽,mask为0的位置会被填充为一个非常大的负值(-10000)  score = score.masked_fill(mask == 0, -10000)  # 把mask == 0的位置置为-10000  # 对score进行softmax归一化处理,得到注意力权重  score = self.softmax(score)  # (B, n_head, T, T)  # 将注意力权重与值(v)相乘,得到加权后的值  score = score @ v  # (B, n_head, T, n_d)  # 将多个头的结果合并(concatenate),并通过线性层进行映射  # 首先将score的维度变为 (B, T, n_head * n_d),然后通过w_o进行线性变换  x_concate = score.transpose(1, 2).contiguous().view(B, T, self.d_model)  # (B, T, d_model)  x_output = self.w_o(x_concate)  # (B, T, d_model)  # 返回最终的输出  return x_output  attn = multi_head_attention(d_model, n_head)  
Y = attn(X,X,X)  
print(Y.shape)

层归一化

# layer norm  
class layer_norm(nn.Module):  def __init__(self, d_model, eps = 1e-12):  super(layer_norm, self).__init__()  self.gamma = nn.Parameter(torch.ones(d_model))  self.beta = nn.Parameter(torch.zeros(d_model))  self.eps = eps  def forward(self, x):  mean = x.mean(-1, keepdim = True)  var = x.var(-1, unbiased=False, keepdim = True)  out = (x - mean) / torch.sqrt(var + self.eps)  out = self.gamma * out + self.beta  return out  d_model = 512  
X = torch.randn(2,5,512) # 2句话, 5个token,词向量512  
ln = layer_norm(d_model)  
print("d_model: ", d_model)  
print(f"ln gamma: {ln.gamma.shape}")  
print(f"ln beta: {ln.beta.shape}")  
Y_ln = ln(X)  
print(Y_ln.shape)

这段代码实现了一个多头注意力机制(Multi-Head Attention),这是Transformer模型中的核心组件之一。多头注意力机制允许模型在处理序列数据时,同时关注序列中不同位置的信息,并且可以从不同的子空间中学习到不同的特征表示。

层归一化

代码解读

1. 初始化部分 (__init__ 方法)
def __init__(self, d_model, n_head):super(multi_head_attention, self).__init__()self.n_head = n_headself.d_model = d_modelself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.softmax = nn.Softmax(dim=-1)
  • d_model:模型的维度,即输入向量的维度。
  • n_head:注意力头的数量。
  • w_q, w_k, w_v:分别是对查询(Query)、键(Key)、值(Value)进行线性变换的层,将输入映射到新的空间。
  • w_o:输出线性变换层,用于将多个头的输出合并并映射回原始维度。
  • softmax:用于对注意力得分进行归一化。
2. 前向传播部分 (forward 方法)
def forward(self, q, k, v):B, T, D = q.shapen_d = self.d_model // self.n_headq, k, v = self.w_q(q), self.w_k(k), self.w_v(v)q = q.view(B, T, self.n_head, n_d).transpose(1, 2)k = k.view(B, T, self.n_head, n_d).transpose(1, 2)v = v.view(B, T, self.n_head, n_d).transpose(1, 2)
  • B:批量大小(batch size)。
  • T:序列长度(sequence length)。
  • D:特征维度(feature dimension),即 d_model
  • n_d:每个注意力头的维度,等于 d_model / n_head
  • q, k, v:通过线性变换层映射到新的空间后,再按头数进行拆分和转置,以便并行计算。
  • [[q = q.view(B, T, self.n_head, n_d).transpose(1, 2)]]
    score = q @ k.transpose(2, 3) / math.sqrt(n_d)mask = torch.tril(torch.ones(T, T, dtype=bool))score = score.masked_fill(mask == 0, -10000)score = self.softmax(score)
  • score:计算查询 q 和键 k 之间的相似度矩阵,并进行缩放(防止数值过大)。
  • mask:生成一个下三角矩阵,用于屏蔽未来的信息(在自注意力机制中,当前时间步只能看到之前的时间步)。
  • score:通过 mask 屏蔽未来的信息,并对得分进行 softmax 归一化,得到注意力权重。
  • [[mask = torch.tril(torch.ones(T, T, dtype=bool))]]
    score = score @ vx_concate = score.transpose(1, 2).contiguous().view(B, T, self.d_model)x_output = self.w_o(x_concate)return x_output
  • score @ v:将注意力权重与值 v 相乘,得到加权后的值。
  • x_concate:将多个头的输出合并(concatenate),并通过 w_o 进行线性变换,得到最终的输出。
  • [[x_concate = score.transpose(1, 2).contiguous().view(B, T, self.d_model)]]
3. 使用示例
attn = multi_head_attention(d_model, n_head)
Y = attn(X, X, X)
print(Y.shape)
  • attn:创建一个多头注意力机制的实例。
  • Y = attn(X, X, X):将输入 X 分别作为查询、键、值传入多头注意力机制,得到输出 Y
  • print(Y.shape):输出 Y 的形状,通常与输入 X 的形状相同,即 (B, T, d_model)

总结

这段代码实现了一个完整的多头注意力机制,包括线性变换、缩放点积注意力、掩码处理、softmax归一化、多头结果的合并和最终的线性变换。多头注意力机制是Transformer模型的核心组件,广泛应用于自然语言处理、计算机视觉等领域。

这段代码实现了一个层归一化(Layer Normalization)模块

层归一化是深度学习中常用的一种归一化技术,用于稳定训练过程并加速收敛


1. 初始化部分 (__init__ 方法)

def __init__(self, d_model, eps=1e-12):super(layer_norm, self).__init__()self.gamma = nn.Parameter(torch.ones(d_model))self.beta = nn.Parameter(torch.zeros(d_model))self.eps = eps
  • d_model:输入特征的维度(即词向量的维度)。
  • gammabeta
  • gamma可学习缩放参数,初始值为全 1,形状为 (d_model,)
    • beta可学习偏移参数,初始值为全 0,形状为 (d_model,)
    • 这两个参数用于对归一化后的数据进行缩放和偏移,以增强模型的表达能力。
  • eps:一个小常数,用于防止分母为零的情况,通常设置为 1e-12

2. 前向传播部分 (forward 方法)

def forward(self, x):mean = x.mean(-1, keepdim=True)var = x.var(-1, unbiased=False, keepdim=True)out = (x - mean) / torch.sqrt(var + self.eps)out = self.gamma * out + self.betareturn out
  • 输入 x:假设 x 的形状为 (B, T, d_model),其中:

    • B 是批量大小(batch size)。
    • T 是序列长度(sequence length)。
    • d_model 是特征维度(即词向量的维度)。
  • 步骤 1:计算均值和方差

    • mean = x.mean(-1, keepdim=True):沿着最后一个维度(d_model)计算均值,形状为 (B, T, 1)
    • var = x.var(-1, unbiased=False, keepdim=True):沿着最后一个维度计算方差,形状为 (B, T, 1)
    • unbiased=False 表示计算方差时不使用无偏估计(即除以 n 而不是 n-1)。
  • 步骤 2:归一化

    • out = (x - mean) / torch.sqrt(var + self.eps):对输入 x 进行归一化,减去均值并除以标准差(加上 eps 防止除零)。
  • 步骤 3:缩放和偏移

    • out = self.gamma * out + self.beta:对归一化后的数据进行缩放和偏移,gammabeta 是可学习的参数。
  • 输出 out:形状与输入 x 相同,为 (B, T, d_model)


3. 代码运行示例

d_model = 512
X = torch.randn(2, 5, 512)  # 2句话, 5个token,词向量512
ln = layer_norm(d_model)
print("d_model: ", d_model)
print(f"ln gamma: {ln.gamma.shape}")
print(f"ln beta: {ln.beta.shape}")
Y_ln = ln(X)
print(Y_ln.shape)
  • 输入 X:形状为 (2, 5, 512),表示 2 个句子,每个句子有 5 个 token,每个 token 的词向量维度为 512。
  • ln.gammaln.beta
    • ln.gamma 的形状为 (512,)
    • ln.beta 的形状为 (512,)
  • 输出 Y_ln:形状与输入 X 相同,为 (2, 5, 512)

4. 层归一化的作用

  • 稳定训练:通过对每个样本的特征进行归一化,减少内部协变量偏移(Internal Covariate Shift),从而稳定训练过程。
  • 加速收敛:归一化后的数据分布更加稳定,有助于加速模型的收敛。
  • 增强表达能力:通过可学习的参数 gammabeta,模型可以学习到适合当前任务的归一化方式。

5. 与批量归一化(Batch Normalization)的区别

  • 批量归一化:沿着批量维度(B)计算均值和方差,适用于批量较大的情况。
  • 层归一化:沿着特征维度(d_model)计算均值和方差,适用于序列数据(如 NLP 中的句子)或批量较小的情况。

6. 总结

  • 这段代码实现了一个层归一化模块,对输入的特征进行归一化,并通过可学习的参数 gammabeta 进行缩放和偏移。
  • 层归一化在 Transformer 等模型中广泛应用,用于稳定训练和加速收敛。
  • 输入形状为 (B, T, d_model),输出形状与输入相同。

相关文章:

【Transformer】手撕Attention

import torch from torch import nn import torch.functional as F import mathX torch.randn(16,64,512) # B,T,Dd_model 512 # 模型的维度 n_head 8 # 注意力头的数量多头注意力机制 class multi_head_attention(nn.Module): def __init__(self, d_model, n_hea…...

844.比较含退格的字符串

目录 题目思路解法收获 题目 给定 s 和 t 两个字符串,当它们分别被输入到空白的文本编辑器后,如果两者相等,返回 true 。# 代表退格字符。 注意:如果对空文本输入退格字符,文本继续为空。 思路 如何解退格之后left…...

图书管理系统 Axios 源码__编辑图书

目录 功能概述: 代码实现(index.js): 代码解析: 图书管理系统中,删除图书功能是核心操作之一。下是基于 HTML、Bootstrap、JavaScript 和 Axios 实现的删除图书功能的详细介绍。 功能概述: …...

LabVIEW纤维集合体微电流测试仪

LabVIEW开发纤维集合体微电流测试仪。该设备精确测量纤维材料在特定电压下的电流变化,以分析纤维的结构、老化及回潮率等属性,对于纤维材料的科学研究及质量控制具有重要意义。 ​ 项目背景 在纤维材料的研究与应用中,电学性能是评估其性能…...

Commander 一款命令行自定义命令依赖

一、安装 commander 插件 npm install commander 二、基本用法 1. 创建一个简单的命令行程序 创建一个 JavaScript 文件,例如 mycli.js,并添加以下代码: // 引入 commander 模块并获取 program 对象。const { program } require("…...

Day24 洛谷普及2004(内涵前缀和与差分算法)

零基础洛谷刷题记录 Day01 2024.11.18 Day02 2024.11.25 Day03 2024.11.26 Day04 2024.11.28 Day05 2024.11.29 Day06 2024 12.02 Day07 2024.12.03 Day08 2024 12 05 Day09 2024.12.07 Day10 2024.12.09 Day11 2024.12.10 Day12 2024.12.12 Day13 2024.12.16 Day14 2024.12.1…...

遗传算法与深度学习实战(33)——WGAN详解与实现

遗传算法与深度学习实战(33)——WGAN详解与实现 0. 前言1. 训练生成对抗网络的挑战2. GAN 优化问题2.1 梯度消失2.2 模式崩溃 2.3 无法收敛3 Wasserstein GAN3.1 Wasserstein 损失3.2 使用 Wasserstein 损失改进 DCGAN 小结系列链接 0. 前言 原始的生成…...

gitlab云服务器配置

目录 1、关闭防火墙 2、安装gitlab 3、修改配置 4、查看版本 GitLab终端常用命令 5、访问 1、关闭防火墙 firewall-cmd --state 检查防火墙状态 systemctl stop firewalld.service 停止防火墙 2、安装gitlab xftp中导入安装包 [rootgitlab ~]#mkdir -p /service/tool…...

SAP SD学习笔记27 - 请求计划(开票计划)之1 - 定期请求(定期开票)

上两章讲了贩卖契约(框架协议)的概要,以及贩卖契约中最为常用的 基本契约 - 数量契约和金额契约。 SAP SD学习笔记26 - 贩卖契约(框架协议)的概要,基本契约 - 数量契约_sap 框架协议-CSDN博客 SAP SD学习笔记27 - 贩卖契约(框架…...

HTML DOM 修改 HTML 内容

HTML DOM 修改 HTML 内容 引言 HTML DOM(文档对象模型)是浏览器内部用来解析和操作HTML文档的一种机制。通过DOM,我们可以轻松地修改HTML文档的结构、样式和行为。本文将详细介绍如何使用HTML DOM来修改HTML内容,包括元素的增删改查、属性修改以及事件处理等。 1. HTML …...

基于VMware的ubuntu与vscode建立ssh连接

1.首先安装openssh服务 sudo apt update sudo apt install openssh-server -y 2.启动并检查ssh服务状态 到这里可以按q退出 之后输入命令 : ip a 红色挡住的部分就是我们要的地址,这里就不展示了哈 3.配置vscode 打开vscode 搜索并安装:…...

Flutter Candies 一桶天下

| | | | | | | | 入魔的冬瓜 最近刚入桶的兄弟,有责任心的开发者,对自己的项目会不断进行优化,达到最完美的状态 自定义日历组件 主要功能 支持公历,农历,节气,传统节日,常用节假日 …...

maven如何不把依赖的jar打包到同一个jar?

spring boot项目打jar包部署: 经过以下步骤, 最终会形成maven依赖的多个jar(包括lib下添加的)、 我们编写的程序代码打成一个jar,将程序jar与 依赖jar分开,便于管理: success: 最终…...

HTML5 技术深度解读:本地存储与地理定位的最佳实践

系列文章目录 01-从零开始学 HTML:构建网页的基本框架与技巧 02-HTML常见文本标签解析:从基础到进阶的全面指南 03-HTML从入门到精通:链接与图像标签全解析 04-HTML 列表标签全解析:无序与有序列表的深度应用 05-HTML表格标签全面…...

AIGC技术中常提到的 “嵌入转换到同一个向量空间中”该如何理解

在AIGC(人工智能生成内容)技术中,“嵌入转换到同一个向量空间中”是一个核心概念,其主要目的是将不同类型的输入数据(如文本、图像、音频等)映射到一个统一的连续向量空间中,从而实现数据之间的…...

【机器学习理论】朴素贝叶斯网络

基础知识: 先验概率:对某个事件发生的概率的估计。可以是基于历史数据的估计,可以由专家知识得出等等。一般是单独事件概率。 后验概率:指某件事已经发生,计算事情发生是由某个因素引起的概率。一般是一个条件概率。 …...

Docker 部署 GLPI(IT 资产管理软件系统)

GLPI 简介 GLPI open source tool to manage Helpdesk and IT assets GLPI stands for Gestionnaire Libre de Parc Informatique(法语 资讯设备自由软件 的缩写) is a Free Asset and IT Management Software package, that provides ITIL Service De…...

【Vaadin flow 实战】第5讲-使用常用UI组件绘制页面元素

vaadin flow官方提供的UI组件文档地址是 https://vaadin.com/docs/latest/components这里,我简单实战了官方提供的一些免费的UI组件,使用案例如下: Accordion 手风琴 Accordion 手风琴效果组件 Accordion 手风琴-测试案例代码 Slf4j PageT…...

强化学习 DAY1:什么是 RL、马尔科夫决策、贝尔曼方程

第一部分 RL基础:什么是RL与MRP、MDP 1.1 入门强化学习所需掌握的基本概念 1.1.1 什么是强化学习:依据策略执行动作-感知状态-得到奖励 强化学习里面的概念、公式,相比ML/DL特别多,初学者刚学RL时,很容易被接连不断…...

理解神经网络:Brain.js 背后的核心思想

温馨提示 这篇文章篇幅较长,主要是为后续内容做铺垫和说明。如果你觉得文字太多,可以: 先收藏,等后面文章遇到不懂的地方再回来查阅。直接跳读,重点关注加粗或高亮的部分。放心,这种“文字轰炸”不会常有的,哈哈~ 感谢你的耐心阅读!😊 欢迎来到 brain.js 的学习之旅!…...

【Docker】dockerfile识别当前构建的镜像平台

在编写dockerfile的时候,可能会遇到需要针对不同平台进行不同操作的时候,这需要我们对dockerfile进行针对性修改。 比如opencv的依赖项libjasper-dev在ubuntu18.04上就需要根据不同的平台做不同的处理,关于这个库的安装在另外一篇博客里面有…...

【VM】VirtualBox安装CentOS8虚拟机

阅读本文前,请先根据 VirtualBox软件安装教程 安装VirtualBox虚拟机软件。 1. 下载centos8系统iso镜像 可以去两个地方下载,推荐跟随本文的操作用阿里云的镜像 centos官网:https://www.centos.org/download/阿里云镜像:http://…...

【C++篇】哈希表

目录 一,哈希概念 1.1,直接定址法 1.2,哈希冲突 1.3,负载因子 二,哈希函数 2.1,除法散列法 /除留余数法 2.2,乘法散列法 2.3,全域散列法 三,处理哈希冲突 3.1&…...

Java篇之继承

目录 一. 继承 1. 为什么需要继承 2. 继承的概念 3. 继承的语法 4. 访问父类成员 4.1 子类中访问父类的成员变量 4.2 子类中访问父类的成员方法 5. super关键字 6. super和this关键字 7. 子类构造方法 8. 代码块的执行顺序 9. protected访问修饰限定符 10. 继承方式…...

边缘检测算法(candy)

人工智能例子汇总:AI常见的算法和例子-CSDN博客 Canny 边缘检测的步骤 1. 灰度转换 如果输入的是彩色图像,则需要先转换为 灰度图像,因为边缘检测通常在单通道图像上进行。 2. 高斯滤波(Gaussian Blur) 由于边缘…...

设计模式Python版 组合模式

文章目录 前言一、组合模式二、组合模式实现方式三、组合模式示例四、组合模式在Django中的应用 前言 GOF设计模式分三大类: 创建型模式:关注对象的创建过程,包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式…...

dfs枚举问题

碎碎念:要开始刷算法题备战蓝桥杯了,一切的开头一定是dfs 定义 枚举问题就是咱数学上学到的,从n个数里面选m个数,有三种题型(来自Acwing) 从 1∼n 这 n个整数中随机选取任意多个,输出所有可能的选择方案。 把 1∼n这…...

【开源免费】基于SpringBoot+Vue.JS社区智慧养老监护管理平台(JAVA毕业设计)

本文项目编号 T 163 ,文末自助获取源码 \color{red}{T163,文末自助获取源码} T163,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…...

安全防护前置

就业概述 网络安全工程师/安全运维工程师/安全工程师 安全架构师/安全专员/研究院(数学要好) 厂商工程师(售前/售后) 系统集成工程师(所有计算机知识都要会一点) 学习目标 前言 网络安全事件 蠕虫病毒--&…...

高性能消息队列Disruptor

定义一个事件模型 之后创建一个java类来使用这个数据模型。 /* <h1>事件模型工程类&#xff0c;用于生产事件消息</h1> */ no usages public class EventMessageFactory implements EventFactory<EventMessage> { Overridepublic EventMessage newInstance(…...