当前位置: 首页 > news >正文

手撕Transformer编码器:从Self-Attention到Positional Encoding的PyTorch逐行实现


Transformer 编码器深度解读 + 代码实战


1. 编码器核心作用

Transformer 编码器的核心任务是将输入序列(如文本、语音)转换为富含上下文语义的高维特征表示。它通过多层自注意力(Self-Attention)和前馈网络(FFN),逐步建模全局依赖关系,解决传统RNN/CNN的长距离依赖缺陷。


2. 编码器单层结构详解

每层编码器包含以下模块(附 PyTorch 代码):

2.1 多头自注意力(Multi-Head Self-Attention)
class MultiHeadAttention(nn.Module):def __init__(self, embed_size, heads):super().__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // heads# 线性变换层生成 Q, K, Vself.to_qkv = nn.Linear(embed_size, embed_size * 3)  # 同时生成 Q/K/Vself.scale = self.head_dim ** -0.5  # 缩放因子# 输出线性层self.to_out = nn.Linear(embed_size, embed_size)def forward(self, x, mask=None):batch_size, seq_len, _ = x.shape# 生成 Q, K, V 并分割多头qkv = self.to_qkv(x).chunk(3, dim=-1)  # 拆分为 [Q, K, V]q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)# 计算注意力分数 (QK^T / sqrt(d_k))attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale# 掩码(编码器通常不需要,但保留接口)if mask is not None:attn = attn.masked_fill(mask == 0, -1e10)# Softmax 归一化attn = torch.softmax(attn, dim=-1)# 加权求和out = torch.einsum('bhij,bhjd->bhid', attn, v)out = out.reshape(batch_size, seq_len, self.embed_size)# 输出线性变换return self.to_out(out)

代码解析

  • nn.Linear 生成 Q/K/V 矩阵,通过 chunk 分割。
  • einsum 实现高效矩阵运算,计算注意力分数。
  • 支持掩码(虽编码器通常不用,但为兼容性保留)。

2.2 前馈网络(Feed-Forward Network)
class FeedForward(nn.Module):def __init__(self, embed_size, expansion=4):super().__init__()self.net = nn.Sequential(nn.Linear(embed_size, embed_size * expansion),  # 扩展维度nn.GELU(),  # 更平滑的激活函数(比ReLU效果更好)nn.Linear(embed_size * expansion, embed_size)   # 压缩回原维度)def forward(self, x):return self.net(x)

代码解析

  • 典型结构:扩展维度(如512→2048)→激活→压缩回原维度。
  • 使用 GELU 替代 ReLU(现代Transformer的常见选择)。

2.3 残差连接 + 层归一化(Add & Norm)
class TransformerEncoderLayer(nn.Module):def __init__(self, embed_size, heads, dropout=0.1):super().__init__()self.attn = MultiHeadAttention(embed_size, heads)self.ffn = FeedForward(embed_size)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.dropout = nn.Dropout(dropout)def forward(self, x):# 自注意力子层attn_out = self.attn(x)x = x + self.dropout(attn_out)  # 残差连接x = self.norm1(x)# 前馈子层ffn_out = self.ffn(x)x = x + self.dropout(ffn_out)   # 残差连接x = self.norm2(x)return x

代码解析

  • 每个子层后执行 x = x + dropout(sublayer(x)),再层归一化。
  • 残差连接确保梯度稳定,层归一化加速收敛。

3. 位置编码(Positional Encoding)
class PositionalEncoding(nn.Module):def __init__(self, embed_size, max_len=5000):super().__init__()pe = torch.zeros(max_len, embed_size)position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, embed_size, 2) * (-math.log(10000.0)/embed_size)pe[:, 0::2] = torch.sin(position * div_term)  # 偶数位置pe[:, 1::2] = torch.cos(position * div_term)  # 奇数位置self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, embed_size)def forward(self, x):return x + self.pe[:, :x.size(1)]  # 自动广播到 (batch_size, seq_len, embed_size)

代码解析

  • 通过正弦/余弦函数编码绝对位置。
  • register_buffer 将位置编码注册为模型常量(不参与训练)。

4. 完整编码器实现
class TransformerEncoder(nn.Module):def __init__(self, vocab_size, embed_size, layers, heads, dropout=0.1):super().__init__()self.embedding = nn.Embedding(vocab_size, embed_size)self.pos_encoding = PositionalEncoding(embed_size)self.layers = nn.ModuleList([TransformerEncoderLayer(embed_size, heads, dropout)for _ in range(layers)])def forward(self, x):# 输入x形状: (batch_size, seq_len)x = self.embedding(x)  # (batch_size, seq_len, embed_size)x = self.pos_encoding(x)for layer in self.layers:x = layer(x)return x  # (batch_size, seq_len, embed_size)

5. 实战测试

# 参数设置
vocab_size = 10000  # 假设词表大小
embed_size = 512    # 嵌入维度
layers = 6          # 编码器层数
heads = 8           # 注意力头数# 初始化模型
encoder = TransformerEncoder(vocab_size, embed_size, layers, heads)# 模拟输入(batch_size=32, seq_len=50)
x = torch.randint(0, vocab_size, (32, 50))  # 随机生成句子# 前向传播
output = encoder(x)
print(output.shape)  # 预期输出: torch.Size([32, 50, 512])

相关文章:

手撕Transformer编码器:从Self-Attention到Positional Encoding的PyTorch逐行实现

Transformer 编码器深度解读 代码实战 1. 编码器核心作用 Transformer 编码器的核心任务是将输入序列(如文本、语音)转换为富含上下文语义的高维特征表示。它通过多层自注意力(Self-Attention)和前馈网络(FFN&#x…...

Webpack和Vite插件的开发与使用

在现代开发中一般各公司都有自己的监控平台,对前端而言如果浏览器报错的话就可以通过埋点收集错误日志,再结合sourcemap文件可以帮助我们定位到错误代码,帮助我们排查问题。这里就记录一下之前在webpack和vite两个环境中的插件开发&#xff0…...

HTTP的状态码

HTTP 状态码 当浏览者访问一个网页时,浏览者的浏览器会向网页所在服务器发出请求。当浏览器接收并显示网页前,此网页所在的服务器会返回一个包含 HTTP 状态码的信息头(server header)用以响应浏览器的请求。 常见的HTTP状态码 …...

Python函数-装饰器

装饰器 写好的函数,不做任何修改,就可以改变执行内容,在其头或尾部加入新的流程代码本质上就是使用函数嵌套,在内部嵌套定义的函数中调用原函数,从而可读在前或后加入新的代码使用的关键: 将原函数作为参数…...

【数据可视化-17】基于pyecharts的印度犯罪数据可视化分析

🧑 博主简介:曾任某智慧城市类企业算法总监,目前在美国市场的物流公司从事高级算法工程师一职,深耕人工智能领域,精通python数据挖掘、可视化、机器学习等,发表过AI相关的专利并多次在AI类比赛中获奖。CSDN…...

HTTP请求报文头和相应报文头

一、HTTP请求报文头 HTTP请求报文由请求行、请求头和请求体组成。请求头包含客户端向服务器发送的附加信息。 1.1 请求行 格式: 方法 请求URI HTTP/版本示例: GET /index.html HTTP/1.1   方法: 请求类型,如GET、POST、PUT、DELETE等。   请求URI: 请求的资源…...

19.4.9 数据库方式操作Excel

版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 本节所说的操作Excel操作是讲如何把Excel作为数据库来操作。 通过COM来操作Excel操作,请参看第21.2节 在第19.3.4节【…...

BFS 走迷宫

#include<bits/stdc.h> using namespace std; int a[100][100],v[100][100];//访问数组 n,m<100 struct point {int x;int y;int step; }; queue<point> r;//申请队列 int dx[4]{0,1,0,-1};//四个方向 右下左上 int dy[4]{1,0,-1,0}; int main() { /* 5 4 1 …...

【Linux系统】—— 简易进度条的实现

【Linux系统】—— 简易进度条的实现 1 回车和换行2 缓冲区3 进度条的准备代码4 第一版进度条5 第二版进度条 1 回车和换行 先问大家一个问题&#xff1a;回车换行是什么&#xff0c;或者说回车和换行是同一个概念吗&#xff1f;   可能大家对回车换行有一定的误解&#xff0…...

Qt 中使用 SQLite 数据库的完整指南

SQLite 是一款轻量级、嵌入式的关系型数据库&#xff0c;无需独立的服务器进程&#xff0c;数据以文件形式存储&#xff0c;非常适合桌面和移动端应用的本地数据管理。Qt 通过 Qt SQL 模块提供了对 SQLite 的原生支持&#xff0c;开发者可以轻松实现数据库的增删改查、事务处理…...

数智化时代的工单管理:从流程驱动到数据驱动-亿发

在数智化时代&#xff0c;工单管理系统已从简单的任务分发工具演变为企业运营的智能中枢。传统工单系统关注流程的线性推进&#xff0c;而现代工单管理系统则强调数据的全生命周期管理&#xff0c;通过智能算法实现工单的自动分配、优先级判定和效能优化。这种转变不仅提升了运…...

Large Language Model Distilling Medication Recommendation Model

摘要&#xff1a;药物推荐是智能医疗系统的一个重要方面&#xff0c;因为它涉及根据患者的特定健康需求开具最合适的药物。不幸的是&#xff0c;目前使用的许多复杂模型往往忽视医疗数据的细微语义&#xff0c;而仅仅严重依赖于标识信息。此外&#xff0c;这些模型在处理首次就…...

floodfill算法系列一>被围绕的区域

目录 整体思想&#xff1a;代码设计&#xff1a;代码呈现&#xff1a; 整体思想&#xff1a; 代码设计&#xff1a; 代码呈现&#xff1a; class Solution {int m,n;int[] dx {0,0,-1,1};int[] dy {-1,1,0,0};public void solve(char[][] board) {m board.length;n board[…...

Redis 01 02章——入门概述与安装配置

一、入门概述 &#xff08;1&#xff09;是什么 Redis&#xff1a;REmote Dictionary Server&#xff08;远程字典服务器&#xff09;官网解释&#xff1a;Remote Dictionary Server(远程字典服务)是完全开源的&#xff0c;使用ANSIC语言编写遵守BSD协议&#xff0c;是一个高…...

windows基于cpu安装pytorch运行faster-whisper-large-v3实现语音转文字

1.创建虚拟环境 conda create -n faster-whisper python3.10 conda activate faster-whisper 2.安装cpu版本的pytorch pip3 install torch torchvision torchaudio -i https://pypi.tuna.tsinghua.edu.cn/simple 3.验证pytorch安装结果 (faster-whisper) H:\big-model\faste…...

AI大模型(如GPT、BERT等)可以通过自然语言处理(NLP)和机器学习技术,显著提升测试效率

在软件测试中,AI大模型(如GPT、BERT等)可以通过自然语言处理(NLP)和机器学习技术,显著提升测试效率。以下是几个具体的应用场景及对应的代码实现示例: 1. 自动生成测试用例 AI大模型可以根据需求文档或用户故事自动生成测试用例。 代码示例(使用 OpenAI GPT API): …...

【Prometheus】prometheus黑盒监控balckbox全面解析与应用实战

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,阿里云开发者社区专家博主,CSDN全栈领域优质创作者,掘金优秀博主,51CTO博客专家等。 🏆《博客》:Python全…...

CSS实现单行、多行文本溢出显示省略号(…)

在网页设计中&#xff0c;我们常常遇到这样的情况&#xff1a;文本内容太长&#xff0c;无法完全显示在一个固定的区域内。为了让界面看起来更整洁&#xff0c;我们可以使用省略号&#xff08;…&#xff09;来表示内容溢出。这不仅能提升用户体验&#xff0c;还能避免内容溢出…...

服务器中部署大模型DeepSeek-R1 | 本地部署DeepSeek-R1大模型 | deepseek-r1部署详细教程

0. 部署前的准备 首先我们需要足够算力的机器&#xff0c;这里我在vultr中租了有一张A16显卡一共16GB显存的服务器作为演示。部署的模型参数为14b的。如果需要部署满血版本671b的&#xff0c;需要更大的算力支持&#xff0c;这里由于是个人资金有限&#xff0c;就演示14b的部署…...

元学习之孪生网络Siamese Network

简介&#xff1a;元学习是一种思想&#xff0c;一般以神经网络作为特征嵌入的工具&#xff0c;实现对数据特征的提取&#xff0c;然后通过构造某种指标以引导优化器对模型参数进行优化。而最小化距离是最常见的学习目标&#xff0c;这就是熟知的度量学习&#xff0c;度量学习里…...

地震勘探——干扰波识别、井中地震时距曲线特点

目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波&#xff1a;可以用来解决所提出的地质任务的波&#xff1b;干扰波&#xff1a;所有妨碍辨认、追踪有效波的其他波。 地震勘探中&#xff0c;有效波和干扰波是相对的。例如&#xff0c;在反射波…...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框&#xff0c;很难让人不联想到SQL注入&#xff0c;但提示都说了不是SQL注入&#xff0c;所以就不往这方面想了 ​ 先查看一下网页源码&#xff0c;发现一段JavaScript代码&#xff0c;有一个关键类ctfs…...

渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止

<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet&#xff1a; https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...

在四层代理中还原真实客户端ngx_stream_realip_module

一、模块原理与价值 PROXY Protocol 回溯 第三方负载均衡&#xff08;如 HAProxy、AWS NLB、阿里 SLB&#xff09;发起上游连接时&#xff0c;将真实客户端 IP/Port 写入 PROXY Protocol v1/v2 头。Stream 层接收到头部后&#xff0c;ngx_stream_realip_module 从中提取原始信息…...

Python爬虫(一):爬虫伪装

一、网站防爬机制概述 在当今互联网环境中&#xff0c;具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类&#xff1a; 身份验证机制&#xff1a;直接将未经授权的爬虫阻挡在外反爬技术体系&#xff1a;通过各种技术手段增加爬虫获取数据的难度…...

DBAPI如何优雅的获取单条数据

API如何优雅的获取单条数据 案例一 对于查询类API&#xff0c;查询的是单条数据&#xff0c;比如根据主键ID查询用户信息&#xff0c;sql如下&#xff1a; select id, name, age from user where id #{id}API默认返回的数据格式是多条的&#xff0c;如下&#xff1a; {&qu…...

[Java恶补day16] 238.除自身以外数组的乘积

给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法&#xff0c;且在 O(n) 时间复杂度…...

代理篇12|深入理解 Vite中的Proxy接口代理配置

在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南&#xff1a;计算机基础与源码原理深度解析 第一轮提问&#xff1a;基础概念问题 1. 请解释什么是进程和线程的区别&#xff1f; 面试官&#xff1a;进程是程序的一次执行过程&#xff0c;是系统进行资源分配和调度的基本单位&#xff1b;而线程是进程中的…...

GitHub 趋势日报 (2025年06月06日)

&#x1f4ca; 由 TrendForge 系统生成 | &#x1f310; https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日获星趋势图 今日获星趋势图 590 cognee 551 onlook 399 project-based-learning 348 build-your-own-x 320 ne…...