PyTorch从零开始实现Transformer
文章目录
- 自注意力
- Transformer块
- 编码器
- 解码器块
- 解码器
- 整个Transformer
- 参考来源
- 全部代码(可直接运行)
自注意力
计算公式

代码实现
class SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads*self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0] # the number of training examplesvalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) # 矩阵乘法,使用爱因斯坦标记法# queries shape: (N, query_len, heads, heads_dim)# keys shape: (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask==0, float("-1e20")) #Fills elements of self tensor with value where mask is Trueattention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim) # 矩阵乘法,使用爱因斯坦标记法einsum# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, head_dim)# after einsum (N, query_len, heads, head_dim) then flatten last two dimensionsout = self.fc_out(out)return out
Transformer块
我们把Transfomer块定义为如下图所示的结构,这个Transformer块在编码器和解码器中都有出现过。

代码实现
class TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),nn.ReLU(),nn.Linear(forward_expansion*embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return out
编码器
编码器结构如下所示,Inputs经过Input Embedding 和Positional Encoding之后,通过多个Transformer块

代码实现
class Encoder(nn.Module):def __init__(self, src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_lengh = x.shapepositions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return out
解码器块
解码器块结构如下图所示

代码实现
class DecoderBlock(nn.Module):def __init__(self, embed_size, heads, forward_expansion, dropout, device):super(DecoderBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm = nn.LayerNorm(embed_size)self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)self.dropout = nn.Dropout(dropout)def forward(self, x, value, key, src_mask, trg_mask):attention = self.attention(x, x, x, trg_mask)query = self.dropout(self.norm(attention + x))out = self.transformer_block(value, key, query, src_mask)return out
解码器
解码器块加上word embedding 和 positional embedding之后构成解码器

代码实现
class Decoder(nn.Module):def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):super(Decoder, self).__init__()self.device = deviceself.word_embedding = nn.Embedding(trg_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device)for _ in range(num_layers)])self.fc_out = nn.Linear(embed_size, trg_vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_out, src_mask, trg_mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))for layer in self.layers:x = layer(x, enc_out, enc_out, src_mask, trg_mask)out = self.fc_out(x)return out
整个Transformer

代码实现
class Transformer(nn.Module):def __init__(self,src_vocab_size, trg_vocab_size,src_pad_idx,trg_pad_idx,embed_size=256,num_layers=6,forward_expansion=4,heads=8,dropout=0,device="cuda",max_length=100):super(Transformer, self).__init__()self.encoder = Encoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length)self.decoder = Decoder(trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length)self.src_pad_idx = src_pad_idxself.trg_pad_idx = trg_pad_idxself.device = devicedef make_src_mask(self, src):src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)#(N, 1, 1, src_len)return src_mask.to(self.device)def make_trg_mask(self, trg):N, trg_len = trg.shapetrg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)return trg_mask.to(self.device)def forward(self, src, trg):src_mask = self.make_src_mask(src)trg_mask = self.make_trg_mask(trg)enc_src = self.encoder(src, src_mask)out = self.decoder(trg, enc_src, src_mask, trg_mask)return out
参考来源
[1] https://www.youtube.com/watch?v=U0s0f995w14
[2] https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
[3] https://arxiv.org/abs/1706.03762
[4] https://www.youtube.com/watch?v=pkVwUVEHmfI
全部代码(可直接运行)
import torch
import torch.nn as nnclass SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embed size needs to be div by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads*self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0] # the number of training examplesvalue_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split embedding into self.heads piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])# queries shape: (N, query_len, heads, heads_dim)# keys shape: (N, key_len, heads, heads_dim)# energy shape: (N, heads, query_len, key_len)if mask is not None:energy = energy.masked_fill(mask==0, float("-1e20")) #Fills elements of self tensor with value where mask is Trueattention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)# attention shape: (N, heads, query_len, key_len)# values shape: (N, value_len, heads, head_dim)# after einsum (N, query_len, heads, head_dim) then flatten last two dimensionsout = self.fc_out(out)return outclass TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion*embed_size),nn.ReLU(),nn.Linear(forward_expansion*embed_size, embed_size))self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return outclass Encoder(nn.Module):def __init__(self, src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_lengh = x.shapepositions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return outclass DecoderBlock(nn.Module):def __init__(self, embed_size, heads, forward_expansion, dropout, device):super(DecoderBlock, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm = nn.LayerNorm(embed_size)self.transformer_block = TransformerBlock(embed_size, heads, dropout, forward_expansion)self.dropout = nn.Dropout(dropout)def forward(self, x, value, key, src_mask, trg_mask):attention = self.attention(x, x, x, trg_mask)query = self.dropout(self.norm(attention + x))out = self.transformer_block(value, key, query, src_mask)return outclass Decoder(nn.Module):def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):super(Decoder, self).__init__()self.device = deviceself.word_embedding = nn.Embedding(trg_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([DecoderBlock(embed_size, heads, forward_expansion, dropout, device)for _ in range(num_layers)])self.fc_out = nn.Linear(embed_size, trg_vocab_size)self.dropout = nn.Dropout(dropout)def forward(self, x, enc_out, src_mask, trg_mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))for layer in self.layers:x = layer(x, enc_out, enc_out, src_mask, trg_mask)out = self.fc_out(x)return outclass Transformer(nn.Module):def __init__(self,src_vocab_size, trg_vocab_size,src_pad_idx,trg_pad_idx,embed_size=256,num_layers=6,forward_expansion=4,heads=8,dropout=0,device="cuda",max_length=100):super(Transformer, self).__init__()self.encoder = Encoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length)self.decoder = Decoder(trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length)self.src_pad_idx = src_pad_idxself.trg_pad_idx = trg_pad_idxself.device = devicedef make_src_mask(self, src):src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)#(N, 1, 1, src_len)return src_mask.to(self.device)def make_trg_mask(self, trg):N, trg_len = trg.shapetrg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)return trg_mask.to(self.device)def forward(self, src, trg):src_mask = self.make_src_mask(src)trg_mask = self.make_trg_mask(trg)enc_src = self.encoder(src, src_mask)out = self.decoder(trg, enc_src, src_mask, trg_mask)return outif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)src_pad_idx = 0trg_pad_idx = 0src_vocab_size = 10trg_vocab_size = 10model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)out = model(x, trg[:, :-1])print(out.shape)相关文章:
PyTorch从零开始实现Transformer
文章目录 自注意力Transformer块编码器解码器块解码器整个Transformer参考来源全部代码(可直接运行) 自注意力 计算公式 代码实现 class SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.e…...
运动蓝牙耳机什么牌子的好用、最好用的运动蓝牙耳机推荐
音乐是运动的灵魂,而一款优秀的运动耳机则是让音乐与我们的身体完美融合的关键。今天,我推荐五款备受运动爱好者喜爱的耳机,它们以卓越的音质、舒适的佩戴和出色的稳定性能脱颖而出,助你在运动中创造最佳状态。 1、NANK南卡Runne…...
HTTP、HTTPS协议详解
文章目录 HTTP是什么报文结构请求头部响应头部 工作原理用户点击一个URL链接后,浏览器和web服务器会执行什么http的版本持久连接和非持久连接无状态与有状态Cookie和Sessionhttp方法:get和post的区别 状态码 HTTPS是什么ssl如何搞到证书nginx中的部署 加…...
【算法与数据结构】222、LeetCode完全二叉树的节点个数
文章目录 一、题目二、一般遍历解法三、利用完全二叉树性质四、完整代码 所有的LeetCode题解索引,可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、一般遍历解法 思路分析:利用层序遍历,然后用num记录节点数量。其他的例如…...
登录和注册表单的11个HTML最佳实践
原文:11 HTML best practices for login & sign-up forms 原作者:Andrey Sitnik 翻译已获原文作者许可,禁止转载和商用 大多数网站都有登录或注册表单;它们是业务转换的关键部分。然而,即使是流行的站点也没有实现本文中提到的…...
Mysql删除历史数据
Mysql定时删除历史数据 实现 1.创建存储过程(函数) SQL DROP PROCEDURE IF EXISTS KeepDatasWith30Days CREATE PROCEDURE KeepDatasWith30Days() BEGINSELECT maxId:max(Id) FROM tableName WHERE CreateTime<DATE(DATE_SUB(NOW(),INTERVAL 31 D…...
Python—数据结构(一)
先放一张自己学习和整理归纳的思维导图,以便让大家都知道我自己的整体学习路线。 数据结构的学习路上内容枯燥,但坚持下来一定有很大的收获!加油💪🏻! 数据结构 数据的概念数据元素: 若干基本…...
离线环境安装flask依赖包
找到当前版本需要的所有依赖包,生产flask项目生成项目依赖包文件requirements.txt 1)在当前项目目录下 生成requirements文件:pip freeze >requirements.txt 执行requirements文件,安装依赖包:pip install -r requirements.t…...
ChatGPT与Claude对比分析
一 简介 1、ChatGPT: 访问地址:https://chat.openai.com/ 由OpenAI研发,2022年11月发布。基于 transformer 结构的大规模语言模型,包含1750亿参数。训练数据集主要是网页文本,聚焦于流畅的对话交互。对话风格友好,回复通顺灵活,富有创造性。存在一定的安全性问题,可…...
登录和注册页面 - 验证码功能的实现
目录 1. 生成验证码 2. 将本地验证码发布成 URL 3. 后端返回验证码的 URL 给前端 4. 前端将用户输入的验证码传给后端 5. 后端验证验证码 1. 生成验证码 使用hutool 工具生成验证码. 1.1 添加 hutool 验证码依赖 <!-- 验证码 --> <dependency><groupId…...
HDFS的文件块大小(重点)
HDFS 中的文件在物理上是分块存储 (Block ) , 块的大小可以通过配置参数( dfs.blocksize)来规定,默认大小在Hadoop2.x/3.x版本中是128M,1.x版本中是64M。 如果一个文件文件小于128M,该文件会占…...
深度学习(二)
目录 一、神经网络 整体架构: 架构细节: 神经元个数的影响: 神经网络过拟合解决: 卷积网络 整体架构: 卷积层 边缘填充 特征尺寸计算 池化层 特征图变化 递归神经网络 一、神经网络 整体架构: 图中分别为输入层、隐层1、隐层2、输出层 通过输入层输入某数值…...
无涯教程-jQuery - wrapInner( html )方法函数
wrapInner(html)方法使用HTML结构包装每个匹配元素(包括文本节点)的内部子内容。 wrapInner( html ) - 语法 selector.wrapInner( html ) 这是此方法使用的所有参数的描述- html - 将动态创建并环绕目标的HTML字符串。 wrapInner( html ) - 示例 以下是一个简单的示例…...
【unity之IMGUI实践】单例模式管理数据存储【二】
👨💻个人主页:元宇宙-秩沅 👨💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨💻 本文由 秩沅 原创 👨💻 收录于专栏:uni…...
【C++】开源:Linux端ALSA音频处理库
😏★,:.☆( ̄▽ ̄)/$:.★ 😏 这篇文章主要介绍Linux端ALSA音频处理库。 无专精则不能成,无涉猎则不能通。。——梁启超 欢迎来到我的博客,一起学习,共同进步。 喜欢的朋友可以关注一下,…...
【Linux | Shell】结构化命令2 - test命令、方括号测试条件、case命令
目录 一、概述二、test 命令2.1 test 命令2.2 方括号测试条件2.3 test 命令和测试条件可以判断的 3 类条件2.3.1 数值比较2.3.2 字符串比较 三、复合条件测试四、if-then 的高级特性五、case 命令 一、概述 上篇文章介绍了 if 语句相关知识。但 if 语句只能执行命令,…...
基于单片机的语音识别智能垃圾桶垃圾分类的设计与实现
功能介绍 以51单片机作为主控系统;液晶显示当前信息和状态;通过语音识别模块对当前垃圾种类进行语音识别; 通过蜂鸣器进行声光报警提醒垃圾桶已满;采用舵机控制垃圾桶打开关闭;超声波检测当前垃圾桶满溢程度࿱…...
最新版本docker 设置国内镜像源 加速办法
解决问题:加速 docker 设置国内镜像源 目录: 国内加速地址 修改方法 国内加速地址 1.Docker中国区官方镜像 https://registry.docker-cn.com 2.网易 http://hub-mirror.c.163.com 3.ustc https://docker.mirrors.ustc.edu.cn 4.中国科技大学 https://docker.mirrors…...
深度学习——LSTM解决分类问题
RNN基本介绍 概述 循环神经网络(Recurrent Neural Network,RNN)是一种深度学习模型,主要用于处理序列数据,如文本、语音、时间序列等具有时序关系的数据。 核心思想 RNN的关键思想是引入了循环结构,允许…...
three.js入门二:相机的zoom参数
环境: threejs:129 (在浏览器的控制台下输入: window.__THREE__即可查看版本)vscodewindowedge 透视相机或正交相机都有一个zoom参数,它可以用来将相机排到的内容在canvas上缩放显示。 注意:…...
告别手动编译:用Conda在Ubuntu 20.04上一键安装与管理SUMO交通仿真环境
告别手动编译:用Conda在Ubuntu 20.04上一键安装与管理SUMO交通仿真环境 在交通工程和智能驾驶研究领域,SUMO(Simulation of Urban MObility)作为开源的微观交通仿真工具,正被越来越多的研究者和开发者采用。然而&#…...
一文搞懂训练大模型的数据怎么准备!
谈到大模型,很多人第一反应都是模型参数大、算力强,但其实数据才是大模型真正的底座。没有足够大、足够干净的数据,再先进的模型也发挥不出威力。今天就从数据层面,把大模型训练的几个关键环节梳理清楚。 数据采集与清洗 大模型训…...
Git提交时Personal Access Token权限不足:如何正确配置workflow scope
1. 为什么Git提交会提示Personal Access Token权限不足? 最近在团队协作中遇到一个典型问题:当开发者尝试推送包含.github/workflows目录的代码到GitHub仓库时,系统突然报错refusing to allow a Personal Access Token to create or update w…...
为什么操作 UI 必须加 `lcd_mutex` 互斥锁?不用会怎样?
1. 先给结论(你必须记住) LVGL 所有界面操作(创建文字、按钮、刷新屏幕)都不是线程安全的。 意思是: 绝对不能有两个线程同时操作 LVGL 界面! 线程A:LVGL 主线程(一直在刷新屏幕&…...
告别Makefile!用Zig 0.10.0自带的构建系统搞定ARM裸机开发(附完整项目配置)
用Zig构建系统重塑ARM裸机开发:告别Makefile的终极指南 当你在凌晨三点盯着第47个Makefile规则调试链接器错误时,是否想过——嵌入式开发必须这么痛苦吗?Zig 0.10.0带来的不仅是一门新语言,更是一套彻底革新裸机开发工作流的构建系…...
Android音频输出流实战:从AudioFlinger到HAL层的完整调用链解析
Android音频输出流深度解析:从框架设计到硬件交互 1. Android音频系统架构概览 Android音频子系统采用分层设计,每一层都有明确的职责划分。理解这个架构是分析音频输出流的基础。 核心层级结构: 应用层:通过AudioTrack、MediaPla…...
Ubuntu 22.04 换源+Docker安装+镜像加速
Ubuntu 22.04 换源Docker安装镜像加速 前言 本文针对 Ubuntu 22.04 LTS 系统,先更换国内镜像源提升下载速度,再完成 Docker 引擎与 Compose 插件安装,最后配置 Docker 国内镜像加速,全程无报错、可直接复制执行,适配 V…...
BilibiliDown终极实战指南:解锁B站视频批量下载的完整方案
BilibiliDown终极实战指南:解锁B站视频批量下载的完整方案 【免费下载链接】BilibiliDown (GUI-多平台支持) B站 哔哩哔哩 视频下载器。支持稍后再看、收藏夹、UP主视频批量下载|Bilibili Video Downloader 😳 项目地址: https://gitcode.com/gh_mirro…...
HackTricks数字取证方法论:内存转储分析与恶意软件检测完全指南
HackTricks数字取证方法论:内存转储分析与恶意软件检测完全指南 【免费下载链接】hacktricks Welcome to the page where you will find each trick/technique/whatever I have learnt in CTFs, real life apps, and reading researches and news. 项目地址: http…...
OpenClaw隐私保护方案:ollama-QwQ-32B本地化数据处理流程
OpenClaw隐私保护方案:ollama-QwQ-32B本地化数据处理流程 1. 为什么需要本地化隐私保护方案 去年我在处理一份涉及客户隐私的市场分析报告时,遇到了一个棘手问题:当使用云端AI服务进行数据清洗和分析时,不得不将包含敏感字段的原…...
