Transformer的PyTorch实现之若干问题探讨(一)
《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。
1.自定义数据中enc_input、dec_input及dec_output的区别
博文中给出了两对德语翻译成英语的例子:
# S: decoding input 的起始符
# E: decoding output 的结束符
# P:意为padding,如果当前句子短于本batch的最长句子,那么用这个符号填补缺失的单词
sentence = [# enc_input dec_input dec_output['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]
初看会对这其中的enc_input、dec_input及dec_output三个句子的作用不太理解,此处作详细解释:
-enc_input是模型需要翻译的输入句子,
-dec_input是用于指导模型开始翻译过程的信号
-dec_output是模型训练时的目标输出,模型的目标是使其产生的输出尽可能接近dec_output,即为翻译真实标签。他们在transformer block中的位置如下:
在使用Transformer进行翻译的时候,需要在Encoder端输入enc_input编码的向量,在decoder端最初只输入起始符S,然后让Transformer网络预测下一个token。
我们知道Transformer架构在进行预测时,每次推理时会获得下一个token,因此推理不是并行的,需要输出多少个token,理论上就要推理多少次。那么,在训练阶段,也需要像预测那样根据之前的输出预测下一个token,然而再所引出dec_output中对应的token做损失吗?实际并不是这样,如果真是这样做,就没有办法并行训练了。
实际我认为Transformer的并行应该是有两个层次:
(1)不同batch在训练和推理时是否可以实现并行?
(2)一个batch是否能并行得把所有的token推理出来?
Tranformer在训练时实现了上述的(1)(2),而推理时(1)(2)都没有实现。Transformer的推理似乎很难实现并行,原因是如果一次性推理两句话,那么如何保证这两句话一样长?难道有一句已经结束了,另一句没有结束,需要不断的把结束符E送入继续预测下一个结束符吗?此外,Transformer在预测下一个token时必须前面的token已经预测出来了,如果第i-1个token都没有,是无法得到第i个token。因此推理的时候都是逐句话预测,逐token预测。这儿实际也是我认为是transformer结构需要改进的地方。这样才可以提高transformer的推理效率。
2.Transformer的训练流程
此处给出博文中附带的非常简洁的Transformer训练代码:
from torch import optim
from model import *model = Transformer().cuda()
model.train()
# 损失函数,忽略为0的类别不对其计算loss(因为是padding无意义)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)# 训练开始
for epoch in range(1000):for enc_inputs, dec_inputs, dec_outputs in loader:'''enc_inputs: [batch_size, src_len] [2,5]dec_inputs: [batch_size, tgt_len] [2,6]dec_outputs: [batch_size, tgt_len] [2,6]'''enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() # [2, 6], [2, 6], [2, 6]outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]loss = criterion(outputs, dec_outputs.view(-1)) # 将dec_outputs展平成一维张量# 更新权重optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item()}')
torch.save(model, f'MyTransformer_temp.pth')
这段代码非常简洁,可以看到输入的是batch为2的样本,送入Transformer网络中直接logits算损失。Transformer在训练时实际上使用了一个策略叫teacher forcing。要解释这个策略的意义,以本博文给出的样本为例,对于输入的样本:
ich mochte ein bier
在进行训练时,当我们给出起始符S,接下来应该预测出:
I
那训练时,有了SI后,则应该预测出
want
那么问题来了,如I就预测错了,假如预测成了a,那么在预测want时,还应该使用Sa来预测吗?当然不是,即使预测错了,也应该用对应位置正确的tokenSI去预测下一个token,这就是teacher forcing。
那么transformer是如何实现这样一个teacher forcing的机制的呢?且听下回分解。
相关文章:

Transformer的PyTorch实现之若干问题探讨(一)
《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。 1.自定义数据中enc_input、dec_input及dec_output的区别 博文中给出了两对德语翻译成英语的例子: # S: de…...
系统参数SystemParameters.MinimumHorizontalDragDistance
SystemParameters.MinimumHorizontalDragDistance 是一个系统参数,它表示在拖放操作中鼠标水平移动的最小距离。 当用户按下鼠标左键并开始移动鼠标时,系统会检查鼠标的水平移动距离是否超过了 SystemParameters.MinimumHorizontalDragDistance。只有当…...
平屋顶安装光伏需要注意哪些事项?
我国对于房屋建设的屋顶形式,主要有平屋顶、斜屋顶、曲面屋顶和多波式折板屋顶等。今天来讲讲在平屋顶安装光伏,需要注意的事项。 1.屋顶结构:在安装光伏系统之前,需要对屋顶结构进行评估,确保屋顶能够承受光伏系统的…...

《Git 简易速速上手小册》第7章:处理大型项目(2024 最新版)
文章目录 7.1 Git Large File Storage (LFS)7.1.1 基础知识讲解7.1.2 重点案例:在 Python 项目中使用 Git LFS 管理数据集7.1.3 拓展案例 1:使用 Git LFS 管理大型静态资源7.1.4 拓展案例 2:优化现有项目中的大文件管理 7.2 性能优化技巧7.2.…...

从0开始学Docker ---Docker安装教程
Docker安装教程 本安装教程参考Docker官方文档,地址如下: https://docs.docker.com/engine/install/centos/ 1.卸载旧版 首先如果系统中已经存在旧的Docker,则先卸载: yum remove docker \docker-client \docker-client-latest…...

嵌入式学习之Linux入门篇笔记——15,Linux编写第一个自己的命令
配套视频学习链接:http://【【北京迅为】嵌入式学习之Linux入门篇】 https://www.bilibili.com/video/BV1M7411m7wT/?p4&share_sourcecopy_web&vd_sourcea0ef2c4953d33a9260910aaea45eaec8 1.什么是命令? 命令就是可执行程序。 比如 ls -a…...

【C语言】SYSCALL_DEFINE3(socket, int, family, int, type, int, protocol)
一、SYSCALL_DEFINE3与系统调用 在Linux操作系统中,为了从用户空间跳转到内核空间执行特定的内核级操作,使用了一种机制叫做"系统调用"(System Call)。系统调用是操作系统提供给程序员访问和使用内核功能的接口。例如&…...

C++实现鼠标点击和获取鼠标位置(编译环境visual studio 2022)
1环境说明 2获取鼠标位置的接口 void GetMouseCurPoint() {POINT mypoint;for (int i 0; i < 100; i){GetCursorPos(&mypoint);//获取鼠标当前所在位置printf("% ld, % ld \n", mypoint.x, mypoint.y);Sleep(1000);} } 3操作鼠标左键和右键的接口 void Mo…...

Matplotlib绘制炫酷散点图:从二维到三维,再到散点图矩阵的完整指南与实战【第58篇—python:Matplotlib绘制炫酷散点图】
文章目录 Matplotlib绘制炫酷散点图:二维、三维和散点图矩阵的参数说明与实战引言二维散点图三维散点图散点图矩阵二维散点图进阶:辅助线、注释和子图三维散点图进阶:动画效果和交互性散点图矩阵进阶:调整样式和添加密度图总结与展…...

Docker-Learn(一)使用Dockerfile创建Docker镜像
1.创建并运行容器 编写Dockerfile,文件名字就是为Dockerfile 在自己的工作工作空间当中新建文件,名字为Docerfile vim Dockerfile写入以下内容: # 使用一个基础镜像 FROM ubuntu:latest # 设置工作目录 WORKDIR /app # 复制当前目…...

问题:银行账号建立以后,一般需要维护哪些设置,不包括() #学习方法#经验分享
问题:银行账号建立以后,一般需要维护哪些设置,不包括() A.维护结算科目对照 B.期初余额初始化刷 C.自定义转账定义 D.对账单初始化 参考答案如图所示...

教授LLM思考和行动:ReAct提示词工程
ReAct:论文主页 原文链接:Teaching LLMs to Think and Act: ReAct Prompt Engineering 在人类从事一项需要多个步骤的任务时,而步骤和步骤之间,或者说动作和动作之间,往往会有一个推理过程。让LLM把内心独白说出来&am…...

FPGA_工程_按键控制的基于Rom数码管显示
一 信号 框图: 其中 key_filter seg_595_dynamic均为已有模块,直接例化即可使用,rom_8*256模块,调用rom ip实现。Rom_ctrl模块需要重新编写。 波形图: 二 代码 module key_fliter #(parameter CNT_MAX 24d9_999_99…...

WordPress Plugin HTML5 Video Player SQL注入漏洞复现(CVE-2024-1061)
0x01 产品简介 WordPress和WordPress plugin都是WordPress基金会的产品。WordPress是一套使用PHP语言开发的博客平台。该平台支持在PHP和MySQL的服务器上架设个人博客网站。WordPress plugin是一个应用插件。 0x02 漏洞概述 WordPress Plugin HTML5 Video Player 插件 get_v…...
【Kotlin】Kotlin基本数据类型
1 变量声明 var a : Int // 声明整数类型变量 var b : Int 1 // 声明整数类型变量, 同时赋初值为1 var c 1 // 声明整数类型变量, 同时赋初值为1 val d 1 // 声明整数类型常量, 值为1(后面不能改变d的值) 变量命名规范如下。 变量名可以由字母、数字、下划线(_…...

UDP端口探活的那些细节
一 背景 商业客户反馈用categraf的net_response插件配置了udp探测, 遇到报错了,如图 udp是无连接的,无法用建立连接的形式判断端口。 插件最初的设计是需要配置udp的发送字符,并且配置期望返回的字符串, [[instances]] targets…...
拦截器配置,FeignClient根据业务规则实现微服务动态路由
文章目录 业务场景拦截器用法Open Feign介绍 业务场景 我们服务使用Spring Cloud微服务架构,使用Spring Cloud Gateway 作为网关,使用 Spring Cloud OpenFeign 作为服务间通信方式我们现在做的信控平台,主要功能之一就是对路口信号机进行管控…...

预测模型:MATLAB线性回归
1. 线性回归模型的基本原理 线性回归是统计学中用来预测连续变量之间关系的一种方法。它假设变量之间存在线性关系,可以通过一个或多个自变量(预测变量)来预测因变量(响应变量)的值。基本的线性回归模型可以表示为&…...

【人工智能】神奇的Embedding:文本变向量,大语言模型智慧密码解析(10)
什么是嵌入? OpenAI 的文本嵌入衡量文本字符串的相关性。嵌入通常用于: Search 搜索(结果按与查询字符串的相关性排序)Clustering 聚类(文本字符串按相似性分组)Recommendations 推荐(推荐具有…...

Redis + Lua 实现分布式限流器
文章目录 Redis Lua 限流实现1. 导入依赖2. 配置application.properties3. 配置RedisTemplate实例4. 定义限流类型枚举类5. 自定义注解6. 切面代码实现7. 控制层实现8. 测试 相比 Redis事务, Lua脚本的优点: 减少网络开销:使用Lua脚本&…...
生成xcframework
打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...
连锁超市冷库节能解决方案:如何实现超市降本增效
在连锁超市冷库运营中,高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术,实现年省电费15%-60%,且不改动原有装备、安装快捷、…...

HTML 列表、表格、表单
1 列表标签 作用:布局内容排列整齐的区域 列表分类:无序列表、有序列表、定义列表。 例如: 1.1 无序列表 标签:ul 嵌套 li,ul是无序列表,li是列表条目。 注意事项: ul 标签里面只能包裹 li…...

k8s业务程序联调工具-KtConnect
概述 原理 工具作用是建立了一个从本地到集群的单向VPN,根据VPN原理,打通两个内网必然需要借助一个公共中继节点,ktconnect工具巧妙的利用k8s原生的portforward能力,简化了建立连接的过程,apiserver间接起到了中继节…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)
本文把滑坡位移序列拆开、筛优质因子,再用 CNN-BiLSTM-Attention 来动态预测每个子序列,最后重构出总位移,预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵(S…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制
在数字化浪潮席卷全球的今天,数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具,在大规模数据获取中发挥着关键作用。然而,传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时,常出现数据质…...

Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习) 一、Aspose.PDF 简介二、说明(⚠️仅供学习与研究使用)三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...

推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材)
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材) 这个项目能干嘛? 使用 gemini 2.0 的 api 和 google 其他的 api 来做衍生处理 简化和优化了文生图和图生图的行为(我的最主要) 并且有一些目标检测和切割(我用不到) 视频和 imagefx 因为没 a…...

LabVIEW双光子成像系统技术
双光子成像技术的核心特性 双光子成像通过双低能量光子协同激发机制,展现出显著的技术优势: 深层组织穿透能力:适用于活体组织深度成像 高分辨率观测性能:满足微观结构的精细研究需求 低光毒性特点:减少对样本的损伤…...

Ubuntu Cursor升级成v1.0
0. 当前版本低 使用当前 Cursor v0.50时 GitHub Copilot Chat 打不开,快捷键也不好用,当看到 Cursor 升级后,还是蛮高兴的 1. 下载 Cursor 下载地址:https://www.cursor.com/cn/downloads 点击下载 Linux (x64) ,…...