当前位置: 首页 > article >正文

仅仅使用pytorch来手撕transformer架构(3):编码器模块和编码器类的实现和向前传播

仅仅使用pytorch来手撕transformer架构(2):编码器模块和编码器类的实现和向前传播

往期文章:
仅仅使用pytorch来手撕transformer架构(1):位置编码的类的实现和向前传播

最适合小白入门的Transformer介绍

仅仅使用pytorch来手撕transformer架构(2):多头注意力MultiHeadAttention类的实现和向前传播

# Transformer 编码器模块
class TransformerBlock(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = MultiHeadAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion * embed_size),nn.ReLU(),nn.Linear(forward_expansion * embed_size, embed_size),)self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return out# 编码器
class Encoder(nn.Module):def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = PositionalEncoding(embed_size, dropout, max_length)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion,)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_length = x.shapex = self.dropout(self.position_embedding(self.word_embedding(x)))for layer in self.layers:x = layer(x, x, x, mask)return x

1.编码器模块的实现

这段代码实现了一个Transformer编码器模块(Transformer Block),它是Transformer架构的核心组件之一。Transformer架构是一种基于自注意力机制(Self-Attention)的深度学习模型,广泛应用于自然语言处理(NLP)任务,如机器翻译、文本生成等。以下是对代码的详细解释:


1.1 类定义

class TransformerBlock(nn.Module):

TransformerBlock 是一个继承自 PyTorch 的 nn.Module 的类,表示一个Transformer编码器模块。nn.Module 是 PyTorch 中所有神经网络模块的基类,用于定义和管理神经网络的结构。


2.2 初始化方法

def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerBlock, self).__init__()self.attention = MultiHeadAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion * embed_size),nn.ReLU(),nn.Linear(forward_expansion * embed_size, embed_size),)self.dropout = nn.Dropout(dropout)
参数解释
  • embed_size: 嵌入向量的维度,表示每个词或标记(token)的特征维度。
  • heads: 多头注意力机制中的头数(Multi-Head Attention)。
  • dropout: Dropout比率,用于防止过拟合。
  • forward_expansion: 前馈网络(Feed-Forward Network, FFN)中隐藏层的扩展因子。
组件解释
  1. 多头注意力机制(MultiHeadAttention

    self.attention = MultiHeadAttention(embed_size, heads)
    

    这是Transformer的核心部分,实现了多头注意力机制。它允许模型在不同的表示子空间中学习信息。MultiHeadAttention 的具体实现没有在这段代码中给出,但通常它会将输入分为多个“头”,分别计算注意力权重,然后将结果拼接起来。

  2. 层归一化(LayerNorm

    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)
    

    层归一化(Layer Normalization)是一种归一化方法,用于稳定训练过程并加速收敛。它对每个样本的特征进行归一化,而不是像批量归一化(Batch Normalization)那样对整个批次进行归一化。

  3. 前馈网络(Feed-Forward Network

    self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion * embed_size),nn.ReLU(),nn.Linear(forward_expansion * embed_size, embed_size),
    )
    

    前馈网络是一个简单的两层全连接网络。它的作用是进一步处理多头注意力机制的输出。forward_expansion 参数控制隐藏层的大小,通常设置为一个较大的值(如4),表示隐藏层的维度是输入维度的4倍。

  4. Dropout

    self.dropout = nn.Dropout(dropout)
    

    Dropout 是一种正则化技术,通过随机丢弃一部分神经元的输出来防止过拟合。dropout 参数表示丢弃的概率。


3. 前向传播方法

def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)x = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return out
参数解释
  • value: 值向量,用于计算注意力权重后的加权求和。
  • key: 键向量,用于计算注意力权重。
  • query: 查询向量,用于计算注意力权重。
  • mask: 掩码,用于防止某些位置的信息泄露(如在自注意力中屏蔽未来信息)。
流程解释
  1. 多头注意力

    attention = self.attention(value, key, query, mask)
    

    首先,使用多头注意力机制计算注意力输出。valuekeyquery 是输入的三个部分,mask 用于控制哪些位置的信息可以被关注。

  2. 残差连接与层归一化

    x = self.dropout(self.norm1(attention + query))
    

    将注意力输出与输入的 query 进行残差连接(attention + query),然后通过层归一化(LayerNorm)。最后,应用 Dropout 防止过拟合。

  3. 前馈网络

    forward = self.feed_forward(x)
    

    将经过归一化的输出传递到前馈网络中进行进一步处理。

  4. 第二次残差连接与层归一化

    out = self.dropout(self.norm2(forward + x))
    

    将前馈网络的输出与之前的输出 x 进行残差连接,再次通过层归一化和 Dropout。

  5. 返回结果

    return out
    

    最终返回处理后的输出。


4. 总结

Transformer编码器模块,其核心包括:

  • 多头注意力机制(MultiHeadAttention)。
  • 残差连接(Residual Connection)。
  • 层归一化(LayerNorm)。
  • 前馈网络(Feed-Forward Network)。
  • Dropout 正则化。

这些组件共同作用,使得Transformer能够高效地处理序列数据,并在许多NLP任务中取得了优异的性能。

2.编码器的实现

这段代码定义了一个完整的 Transformer 编码器(Encoder),它是 Transformer 架构中的一个重要组成部分。编码器的作用是将输入序列(如源语言文本)转换为上下文表示,这些表示可以被解码器(Decoder)用于生成目标序列(如目标语言文本)。以下是对代码的详细解释:


1. 类定义

class Encoder(nn.Module):

Encoder 是一个继承自 PyTorch 的 nn.Module 的类,用于定义 Transformer 编码器的结构。nn.Module 是 PyTorch 中所有神经网络模块的基类,用于定义和管理神经网络的结构。


2. 初始化方法

def __init__(self, src_vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):super(Encoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = PositionalEncoding(embed_size, dropout, max_length)self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion,)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)
参数解释
  • src_vocab_size: 源语言词汇表的大小,即输入序列中可能的标记(token)数量。
  • embed_size: 嵌入向量的维度,表示每个词或标记的特征维度。
  • num_layers: 编码器中 Transformer 块(TransformerBlock)的数量。
  • heads: 多头注意力机制中的头数。
  • device: 运行设备(如 CPU 或 GPU)。
  • forward_expansion: 前馈网络(FFN)中隐藏层的扩展因子。
  • dropout: Dropout 比率,用于防止过拟合。
  • max_length: 输入序列的最大长度,用于位置编码。
组件解释
  1. 词嵌入(word_embedding

    self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
    

    词嵌入层将输入的标记(token)索引映射为固定维度的嵌入向量。src_vocab_size 是词汇表的大小,embed_size 是嵌入向量的维度。

  2. 位置编码(position_embedding

    self.position_embedding = PositionalEncoding(embed_size, dropout, max_length)
    

    位置编码层用于为每个标记添加位置信息,使得模型能够捕捉序列中的顺序关系。PositionalEncoding 的具体实现没有在这段代码中给出,但通常它会根据标记的位置生成一个固定维度的向量,并将其与词嵌入相加。

  3. Transformer 块(TransformerBlock

    self.layers = nn.ModuleList([TransformerBlock(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion,)for _ in range(num_layers)]
    )
    

    编码器由多个 Transformer 块组成。每个 Transformer 块包含多头注意力机制和前馈网络。num_layers 表示 Transformer 块的数量。

  4. Dropout

    self.dropout = nn.Dropout(dropout)
    

    Dropout 是一种正则化技术,通过随机丢弃一部分神经元的输出来防止过拟合。dropout 参数表示丢弃的概率。


3. 前向传播方法

def forward(self, x, mask):N, seq_length = x.shapex = self.dropout(self.position_embedding(self.word_embedding(x)))for layer in self.layers:x = layer(x, x, x, mask)return x
参数解释
  • x: 输入序列,形状为 (N, seq_length),其中 N 是批次大小,seq_length 是序列长度。
  • mask: 掩码,用于防止某些位置的信息泄露(如在自注意力中屏蔽未来信息)。
流程解释
  1. 词嵌入与位置编码

    x = self.dropout(self.position_embedding(self.word_embedding(x)))
    
    • 首先,将输入序列 x 通过词嵌入层(word_embedding)得到嵌入向量。
    • 然后,将嵌入向量与位置编码(position_embedding)相加,为每个标记添加位置信息。
    • 最后,应用 Dropout 防止过拟合。
  2. 逐层传递

    for layer in self.layers:x = layer(x, x, x, mask)
    
    • 输入序列 x 逐层传递到每个 Transformer 块中。在每个块中:
      • valuekeyquery 都是 x,因为这是自注意力机制(Self-Attention)。
      • mask 用于控制哪些位置的信息可以被关注。
    • 每个 Transformer 块的输出会作为下一层的输入。
  3. 返回结果

    return x
    
    • 最终返回编码器的输出,形状为 (N, seq_length, embed_size),表示每个位置的上下文表示。

4. 总结

Transformer 编码器,其主要功能包括:

  1. 词嵌入与位置编码:将输入标记转换为嵌入向量,并添加位置信息。
  2. 多层 Transformer 块:通过多头注意力机制和前馈网络逐层处理输入序列。
  3. 掩码机制:通过掩码控制注意力的范围,避免信息泄露。
  4. Dropout 正则化:防止过拟合。

编码器的输出是一个上下文表示,可以被解码器用于生成目标序列。这种架构在机器翻译、文本生成等任务中表现出色。

作者码字不易,觉得有用的话不妨点个赞吧,关注我,持续为您更新AI的优质内容。

相关文章:

仅仅使用pytorch来手撕transformer架构(3):编码器模块和编码器类的实现和向前传播

仅仅使用pytorch来手撕transformer架构(2):编码器模块和编码器类的实现和向前传播 往期文章: 仅仅使用pytorch来手撕transformer架构(1):位置编码的类的实现和向前传播 最适合小白入门的Transformer介绍 仅仅使用pytorch来手撕transformer…...

rust语言match模式匹配涉及转移所有权Error Case

struct S{data:String, }//注意:因为String默认是移动语义,从而决定结构体S也是移动语义,可采用(1)或(2)两种方法解决编译错误;关键思路:放弃获取结构体S的字段data的所有权,改为借用。fn process(s_ref:&a…...

小肥柴慢慢手写数据结构(C篇)(4-3 关于栈和队列的讨论)

小肥柴慢慢学习数据结构笔记(C篇)(4-3 关于栈和队列的讨论) 目录1 双端栈/队列2 栈与队列的相互转化2-1 栈转化成队列2-2 队列转化成栈 3 经典工程案例3-1 生产者和消费者模型(再次重温环形缓冲区)3-2 MapR…...

大模型在甲状腺癌诊疗全流程预测及方案制定中的应用研究

目录 一、引言 1.1 研究背景与意义 1.2 研究目的与创新点 1.3 国内外研究现状 二、大模型预测甲状腺癌的理论基础 2.1 甲状腺癌相关医学知识 2.2 大模型技术原理与特点 2.3 大模型在医疗领域的应用潜力 三、术前预测方案 3.1 预测模型构建 3.1.1 数据收集与预处理 …...

java-单列模式-final-继承-多态

内存存储区域 引用变量和普通变量引用变量放在栈中,基本数据类型的内容是在堆内存中。 对象放在堆内存中,其引用变量放在栈中,指向堆内存存放对象的地址。 静态变量放在静态区中,静态变量在程序的执行始中中分配一次,…...

Python:正则表达式

正则表达式的基础和应用 一、正则表达式核心语法(四大基石) 1. ​元字符(特殊符号)​ ​定位符 ^:匹配字符串开始位置 $:匹配字符串结束位置 \b:匹配单词边界​(如 \bword\b 匹配…...

网络通信中的带宽(Bandwidth)概念

在计算机网络中,带宽是指单位时间内可以传输的数据量,通常以比特每秒(bps)或字节每秒(Bps)为单位。 1. 理论计算 链路带宽:链路带宽是指网络链路的物理传输能力,通常由网络设备的规…...

基于杀伤链的勒索软件控制框架

40s说清楚勒索软件如何工作 基于杀伤链的勒索软件控制框架开发了4种缓解策略(预防、阻止、检测&响应、重建),覆盖18个控制域90项控制措施,以正确管理与勒索软件攻击杀伤链各阶段相关的风险。 注:本文节选出自《基于杀伤链的勒索软件防御指…...

Windows编程----结束进程

进程有启动就有终止,通过CreateProcess函数可以启动一个新的子进程,但是如何终结子进程呢?主要有四种方法: 通过主线程的入口函数(main函数、WinMain函数)的return关键字终止进程 一个应用程序只有一个入…...

三、Docker 集群管理与应用

(一)项目案例 1、准备主机 (1)关闭防火墙,或者开放TCP端口2377(用于集群管理通信)、TCP/UPD端口7946(用于节点之间的通信)、UDP端口4789(用于overlay网络流…...

无标签数据增强+高效注意力GAN:基于CARLA的夜间车辆检测精度跃升

目录 一、摘要 二、引言 三、框架 四、方法 生成合成夜间数据 昼夜图像风格转换 针对夜间图像的无标签数据增强技术 五、Coovally AI模型训练与应用平台 六、实验 数据 图像风格转换 夜间车辆检测和分类 结论 论文题目:ENHANCING NIGHTTIME VEHICLE D…...

SqlSugar 进阶之原生Sql操作与存储过程写法 【ORM框架】

系列文章目录 🎀🎀🎀 .NET开源 ORM 框架 SqlSugar 系列 🎀🎀🎀 文章目录 系列文章目录一、前言 🍃二、用法介绍三、方法列表四、使用案例五、调用存储过程六、in参数用法七、SqlServer带Go的脚…...

NO.33十六届蓝桥杯备战|函数|返回值|声明|调用|引用|函数重载(C++)

返回值 我们在设计的函数的时候,函数在经过计算后,有时候需要带回⼀些计算好的数据,这时候往往使⽤return 来返回,这⾥我们就讨论⼀下使⽤ return 返回。 return 后边可以是⼀个数值,也可以是⼀个表达式,…...

5G工业路由器赋能无人码头,港口物流智能化管理

全球贸易发展促使港口需提升运营效率,传统港口面临诸多难题,无人码头成为转型关键方向。5G 工业路由器为其提供有力通信支持,引领港口物流变革。 随着无人码头建设在全球兴起,如荷兰鹿特丹港、中国上海洋山港等。码头作业设备需实…...

机试准备第14天

首先进行树的学习。树的存储分为链式存储与顺序存储。完全二叉树是可以顺序存储的&#xff0c;将各个节点从上往下&#xff0c;从左往右存储。 第一题是找位置&#xff0c;好兄弟给的一道题&#xff0c;一遍过了。 #include <stdio.h> #include <map> #include &…...

【Academy】OAuth 2.0 身份验证漏洞 ------ OAuth 2.0 authentication vulnerabilities

OAuth 2.0 身份验证漏洞 ------ OAuth 2.0 authentication vulnerabilities 1. 什么是 OAuth&#xff1f;2. OAuth 2.0 是如何工作的&#xff1f;3. OAuth 授权类型3.1 OAuth 范围3.2 授权代码授权类型3.3 隐式授权类型 4. OAuth 身份验证4.1 识别 OAuth 身份验证4.2 侦察OAuth…...

有关Java中的多线程

学习目标 ● 掌握线程相关概念 ● 掌握线程的基本使用 ● 掌握线程池的使用 ● 了解解决线程安全方式 1.为什么要学习线程? ● 从1946年2月14日世界上第一台计算机在美国宾夕法尼亚大学诞生到今天&#xff0c;计算和处理的模式早已从单用户单任务的串行模式发展到了多用户多…...

【eNSP实战】配置交换机端口安全

拓扑图 目的&#xff1a;让交换机端口与主机mac绑定&#xff0c;防止私接主机。 主机PC配置不展示&#xff0c;按照图中配置即可。 开始配置之前&#xff0c;使用PC1 ping 一遍PC2、PC3、PC4、PC5&#xff0c;让交换机mac地址表刷新一下记录。 LSW1查看mac地址表 LSW1配置端…...

MAC-禁止百度网盘自动升级更新

通过终端禁用更新服务(推荐)​ 此方法直接移除百度网盘的自动更新组件,无需修改系统文件。 ​步骤: ​1.关闭百度网盘后台进程 按下 Command + Space → 输入「活动监视器」→ 搜索 BaiduNetdisk 或 UpdateAgent → 结束相关进程。 ​2.删除自动更新配置文件 打开终端…...

LLMs基础学习(一)概念、模型分类、主流开源框架介绍以及模型的预训练任务

文章目录 LLM基础学习&#xff08;一&#xff09;一、大语言模型&#xff08;LLMs&#xff09;的简单介绍定义与基本信息核心特点局限性参考的模型 二、大语言模型&#xff08;LLMs&#xff09;名称后 “175B”“60B”“540B” 等数字的含义数字代表模型参数数量具体示例参数数…...

【leetcode hot 100 24】两两交换链表中的节点

解法一&#xff1a;先判断链表是否为空&#xff0c;若为空则直接返回&#xff1b;否则用left和right指向第一个和第二个节点&#xff0c;当这两个节点非空时一直执行交换。其中先判断right.nextnull&#xff0c;说明链表为偶数且已经交换完break&#xff1b;再判断right.next.n…...

软件IIC和硬件IIC的主要区别,用标准库举例!

学习交流792125321&#xff0c;欢迎一起加入讨论&#xff01; 在学习iic的时候&#xff0c;我们经常会遇到软件 IC和硬件 IC,它两到底有什么区别呢&#xff1f; 软件 IC&#xff08;模拟 IC&#xff09;和硬件 IC&#xff08;外设 IC&#xff09;是两种实现 IC 总线通信的方式…...

Codeforces Round 1006 Div3 A-E

A 题目描述 夏目章人&#xff08;Natsume Akito&#xff09;刚刚在一个新世界苏醒&#xff0c;便立即收到了他的第一个任务&#xff01;系统为他提供了一个包含 n 个零的数组 a&#xff0c;以及两个整数 k 和 p。在每次操作中&#xff0c;章人需要选择两个整数 i 和 x&#x…...

4个 Vue 路由实现的过程

大家好&#xff0c;我是大澈&#xff01;一个喜欢结交朋友、喜欢编程技术和科技前沿的老程序员&#x1f468;&#x1f3fb;‍&#x1f4bb;&#xff0c;关注我&#xff0c;科技未来或许我能帮到你&#xff01; Vue 路由相信朋友们用的都很熟了&#xff0c;但是你知道 Vue 路由…...

git文件过大导致gitea仓库镜像推送失败问题解决(push failed: context deadline exceeded)

问题描述&#xff1a; 今天发现gitea仓库推送到某个镜像仓库的操作几个月前已经报错终止推送了&#xff0c;报错如下&#xff1a; 首先翻译报错提示可知是因为git仓库大小超过1G限制。检查本地.git文件&#xff0c;发现.git文件大小已达到1.13G。确定是.git文件过大导致&…...

简要分析NETLINK_ROUTE参数

NETLINK_ROUTE时Linux内核中Netlink协议族的一个子类型&#xff0c;专用于用户空间与内核网络子系统之间的通信&#xff0c;它是实现动态网络配置&#xff08;如路由表、网络接口、地址管理&#xff09;的核心机制&#xff0c;为现代网络管理工具&#xff08;如iproute2&#x…...

Java中default关键字

1. 在 switch 语句中作为默认分支 在 switch 语句里&#xff0c;default 用于定义当所有 case 标签的值都无法匹配 switch 表达式的值时要执行的代码块。它并非强制要求&#xff0c;但使用它可以增强代码的健壮性&#xff0c;处理未预见的情况。 public class SwitchDefaultE…...

怎么利用DeepSeek进行PCB设计?

最近在琢磨利用Deepseek改善PCB的细节设计&#xff0c;毕竟立创EDA里面没有集成DS&#xff0c;因此&#xff0c;如何让DS能识别图片成了重中之重。所幸最近腾讯元宝里面集成了R1的满血版&#xff0c;这个版本可以上传图片&#xff0c;于是让DS识别图片就可能了。 在原理图设计…...

详细介绍 Jupyter nbconvert 工具及其用法:如何将 Notebook 转换为 Python 脚本

nbconvert 是 Jupyter 提供的一个非常强大的工具&#xff0c;允许用户将 Jupyter Notebook 文件&#xff08;.ipynb&#xff09;转换成多种格式&#xff0c;包括 Python 脚本&#xff08;.py&#xff09;、HTML、PDF、LaTeX 等。你可以通过命令行来运行 nbconvert&#xff0c;也…...

windows上传uniapp打包的ipa文件到app store构建版本

uniapp是一个跨平台的框架&#xff0c;使用windows电脑也可以开发ios软件&#xff0c;因为uniapp的打包是在云端实现的&#xff0c;本地电脑无需使用mac电脑即可完成打包。 但是打包后的ipa文件需要上架到app store的构建版本上&#xff0c;没有mac电脑&#xff0c;又如何上架…...