Transformer的PyTorch实现之若干问题探讨(一)
《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。
1.自定义数据中enc_input、dec_input及dec_output的区别
博文中给出了两对德语翻译成英语的例子:
# S: decoding input 的起始符
# E: decoding output 的结束符
# P:意为padding,如果当前句子短于本batch的最长句子,那么用这个符号填补缺失的单词
sentence = [# enc_input dec_input dec_output['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]
初看会对这其中的enc_input、dec_input及dec_output三个句子的作用不太理解,此处作详细解释:
-enc_input是模型需要翻译的输入句子,
-dec_input是用于指导模型开始翻译过程的信号
-dec_output是模型训练时的目标输出,模型的目标是使其产生的输出尽可能接近dec_output,即为翻译真实标签。他们在transformer block中的位置如下:

在使用Transformer进行翻译的时候,需要在Encoder端输入enc_input编码的向量,在decoder端最初只输入起始符S,然后让Transformer网络预测下一个token。
我们知道Transformer架构在进行预测时,每次推理时会获得下一个token,因此推理不是并行的,需要输出多少个token,理论上就要推理多少次。那么,在训练阶段,也需要像预测那样根据之前的输出预测下一个token,然而再所引出dec_output中对应的token做损失吗?实际并不是这样,如果真是这样做,就没有办法并行训练了。
实际我认为Transformer的并行应该是有两个层次:
(1)不同batch在训练和推理时是否可以实现并行?
(2)一个batch是否能并行得把所有的token推理出来?
Tranformer在训练时实现了上述的(1)(2),而推理时(1)(2)都没有实现。Transformer的推理似乎很难实现并行,原因是如果一次性推理两句话,那么如何保证这两句话一样长?难道有一句已经结束了,另一句没有结束,需要不断的把结束符E送入继续预测下一个结束符吗?此外,Transformer在预测下一个token时必须前面的token已经预测出来了,如果第i-1个token都没有,是无法得到第i个token。因此推理的时候都是逐句话预测,逐token预测。这儿实际也是我认为是transformer结构需要改进的地方。这样才可以提高transformer的推理效率。
2.Transformer的训练流程
此处给出博文中附带的非常简洁的Transformer训练代码:
from torch import optim
from model import *model = Transformer().cuda()
model.train()
# 损失函数,忽略为0的类别不对其计算loss(因为是padding无意义)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)# 训练开始
for epoch in range(1000):for enc_inputs, dec_inputs, dec_outputs in loader:'''enc_inputs: [batch_size, src_len] [2,5]dec_inputs: [batch_size, tgt_len] [2,6]dec_outputs: [batch_size, tgt_len] [2,6]'''enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() # [2, 6], [2, 6], [2, 6]outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]loss = criterion(outputs, dec_outputs.view(-1)) # 将dec_outputs展平成一维张量# 更新权重optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item()}')
torch.save(model, f'MyTransformer_temp.pth')
这段代码非常简洁,可以看到输入的是batch为2的样本,送入Transformer网络中直接logits算损失。Transformer在训练时实际上使用了一个策略叫teacher forcing。要解释这个策略的意义,以本博文给出的样本为例,对于输入的样本:
ich mochte ein bier
在进行训练时,当我们给出起始符S,接下来应该预测出:
I
那训练时,有了SI后,则应该预测出
want
那么问题来了,如I就预测错了,假如预测成了a,那么在预测want时,还应该使用Sa来预测吗?当然不是,即使预测错了,也应该用对应位置正确的tokenSI去预测下一个token,这就是teacher forcing。
那么transformer是如何实现这样一个teacher forcing的机制的呢?且听下回分解。
相关文章:
Transformer的PyTorch实现之若干问题探讨(一)
《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。 1.自定义数据中enc_input、dec_input及dec_output的区别 博文中给出了两对德语翻译成英语的例子: # S: de…...
系统参数SystemParameters.MinimumHorizontalDragDistance
SystemParameters.MinimumHorizontalDragDistance 是一个系统参数,它表示在拖放操作中鼠标水平移动的最小距离。 当用户按下鼠标左键并开始移动鼠标时,系统会检查鼠标的水平移动距离是否超过了 SystemParameters.MinimumHorizontalDragDistance。只有当…...
平屋顶安装光伏需要注意哪些事项?
我国对于房屋建设的屋顶形式,主要有平屋顶、斜屋顶、曲面屋顶和多波式折板屋顶等。今天来讲讲在平屋顶安装光伏,需要注意的事项。 1.屋顶结构:在安装光伏系统之前,需要对屋顶结构进行评估,确保屋顶能够承受光伏系统的…...
《Git 简易速速上手小册》第7章:处理大型项目(2024 最新版)
文章目录 7.1 Git Large File Storage (LFS)7.1.1 基础知识讲解7.1.2 重点案例:在 Python 项目中使用 Git LFS 管理数据集7.1.3 拓展案例 1:使用 Git LFS 管理大型静态资源7.1.4 拓展案例 2:优化现有项目中的大文件管理 7.2 性能优化技巧7.2.…...
从0开始学Docker ---Docker安装教程
Docker安装教程 本安装教程参考Docker官方文档,地址如下: https://docs.docker.com/engine/install/centos/ 1.卸载旧版 首先如果系统中已经存在旧的Docker,则先卸载: yum remove docker \docker-client \docker-client-latest…...
嵌入式学习之Linux入门篇笔记——15,Linux编写第一个自己的命令
配套视频学习链接:http://【【北京迅为】嵌入式学习之Linux入门篇】 https://www.bilibili.com/video/BV1M7411m7wT/?p4&share_sourcecopy_web&vd_sourcea0ef2c4953d33a9260910aaea45eaec8 1.什么是命令? 命令就是可执行程序。 比如 ls -a…...
【C语言】SYSCALL_DEFINE3(socket, int, family, int, type, int, protocol)
一、SYSCALL_DEFINE3与系统调用 在Linux操作系统中,为了从用户空间跳转到内核空间执行特定的内核级操作,使用了一种机制叫做"系统调用"(System Call)。系统调用是操作系统提供给程序员访问和使用内核功能的接口。例如&…...
C++实现鼠标点击和获取鼠标位置(编译环境visual studio 2022)
1环境说明 2获取鼠标位置的接口 void GetMouseCurPoint() {POINT mypoint;for (int i 0; i < 100; i){GetCursorPos(&mypoint);//获取鼠标当前所在位置printf("% ld, % ld \n", mypoint.x, mypoint.y);Sleep(1000);} } 3操作鼠标左键和右键的接口 void Mo…...
Matplotlib绘制炫酷散点图:从二维到三维,再到散点图矩阵的完整指南与实战【第58篇—python:Matplotlib绘制炫酷散点图】
文章目录 Matplotlib绘制炫酷散点图:二维、三维和散点图矩阵的参数说明与实战引言二维散点图三维散点图散点图矩阵二维散点图进阶:辅助线、注释和子图三维散点图进阶:动画效果和交互性散点图矩阵进阶:调整样式和添加密度图总结与展…...
Docker-Learn(一)使用Dockerfile创建Docker镜像
1.创建并运行容器 编写Dockerfile,文件名字就是为Dockerfile 在自己的工作工作空间当中新建文件,名字为Docerfile vim Dockerfile写入以下内容: # 使用一个基础镜像 FROM ubuntu:latest # 设置工作目录 WORKDIR /app # 复制当前目…...
问题:银行账号建立以后,一般需要维护哪些设置,不包括() #学习方法#经验分享
问题:银行账号建立以后,一般需要维护哪些设置,不包括() A.维护结算科目对照 B.期初余额初始化刷 C.自定义转账定义 D.对账单初始化 参考答案如图所示...
教授LLM思考和行动:ReAct提示词工程
ReAct:论文主页 原文链接:Teaching LLMs to Think and Act: ReAct Prompt Engineering 在人类从事一项需要多个步骤的任务时,而步骤和步骤之间,或者说动作和动作之间,往往会有一个推理过程。让LLM把内心独白说出来&am…...
FPGA_工程_按键控制的基于Rom数码管显示
一 信号 框图: 其中 key_filter seg_595_dynamic均为已有模块,直接例化即可使用,rom_8*256模块,调用rom ip实现。Rom_ctrl模块需要重新编写。 波形图: 二 代码 module key_fliter #(parameter CNT_MAX 24d9_999_99…...
WordPress Plugin HTML5 Video Player SQL注入漏洞复现(CVE-2024-1061)
0x01 产品简介 WordPress和WordPress plugin都是WordPress基金会的产品。WordPress是一套使用PHP语言开发的博客平台。该平台支持在PHP和MySQL的服务器上架设个人博客网站。WordPress plugin是一个应用插件。 0x02 漏洞概述 WordPress Plugin HTML5 Video Player 插件 get_v…...
【Kotlin】Kotlin基本数据类型
1 变量声明 var a : Int // 声明整数类型变量 var b : Int 1 // 声明整数类型变量, 同时赋初值为1 var c 1 // 声明整数类型变量, 同时赋初值为1 val d 1 // 声明整数类型常量, 值为1(后面不能改变d的值) 变量命名规范如下。 变量名可以由字母、数字、下划线(_…...
UDP端口探活的那些细节
一 背景 商业客户反馈用categraf的net_response插件配置了udp探测, 遇到报错了,如图 udp是无连接的,无法用建立连接的形式判断端口。 插件最初的设计是需要配置udp的发送字符,并且配置期望返回的字符串, [[instances]] targets…...
拦截器配置,FeignClient根据业务规则实现微服务动态路由
文章目录 业务场景拦截器用法Open Feign介绍 业务场景 我们服务使用Spring Cloud微服务架构,使用Spring Cloud Gateway 作为网关,使用 Spring Cloud OpenFeign 作为服务间通信方式我们现在做的信控平台,主要功能之一就是对路口信号机进行管控…...
预测模型:MATLAB线性回归
1. 线性回归模型的基本原理 线性回归是统计学中用来预测连续变量之间关系的一种方法。它假设变量之间存在线性关系,可以通过一个或多个自变量(预测变量)来预测因变量(响应变量)的值。基本的线性回归模型可以表示为&…...
【人工智能】神奇的Embedding:文本变向量,大语言模型智慧密码解析(10)
什么是嵌入? OpenAI 的文本嵌入衡量文本字符串的相关性。嵌入通常用于: Search 搜索(结果按与查询字符串的相关性排序)Clustering 聚类(文本字符串按相似性分组)Recommendations 推荐(推荐具有…...
Redis + Lua 实现分布式限流器
文章目录 Redis Lua 限流实现1. 导入依赖2. 配置application.properties3. 配置RedisTemplate实例4. 定义限流类型枚举类5. 自定义注解6. 切面代码实现7. 控制层实现8. 测试 相比 Redis事务, Lua脚本的优点: 减少网络开销:使用Lua脚本&…...
【kafka】Golang实现分布式Masscan任务调度系统
要求: 输出两个程序,一个命令行程序(命令行参数用flag)和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽,然后将消息推送到kafka里面。 服务端程序: 从kafka消费者接收…...
pam_env.so模块配置解析
在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...
Python爬虫(一):爬虫伪装
一、网站防爬机制概述 在当今互联网环境中,具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类: 身份验证机制:直接将未经授权的爬虫阻挡在外反爬技术体系:通过各种技术手段增加爬虫获取数据的难度…...
NFT模式:数字资产确权与链游经济系统构建
NFT模式:数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新:构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议:基于LayerZero协议实现以太坊、Solana等公链资产互通,通过零知…...
Spring AI与Spring Modulith核心技术解析
Spring AI核心架构解析 Spring AI(https://spring.io/projects/spring-ai)作为Spring生态中的AI集成框架,其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似,但特别为多语…...
以光量子为例,详解量子获取方式
光量子技术获取量子比特可在室温下进行。该方式有望通过与名为硅光子学(silicon photonics)的光波导(optical waveguide)芯片制造技术和光纤等光通信技术相结合来实现量子计算机。量子力学中,光既是波又是粒子。光子本…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...
排序算法总结(C++)
目录 一、稳定性二、排序算法选择、冒泡、插入排序归并排序随机快速排序堆排序基数排序计数排序 三、总结 一、稳定性 排序算法的稳定性是指:同样大小的样本 **(同样大小的数据)**在排序之后不会改变原始的相对次序。 稳定性对基础类型对象…...
GitHub 趋势日报 (2025年06月06日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 590 cognee 551 onlook 399 project-based-learning 348 build-your-own-x 320 ne…...
「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案
在移动互联网营销竞争白热化的当下,推客小程序系统凭借其裂变传播、精准营销等特性,成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径,助力开发者打造具有市场竞争力的营销工具。 一、系统核心功能架构&…...
