当前位置: 首页 > news >正文

Transformer 代码剖析1 - 数据处理 (pytorch实现)

引言

Transformer 架构自《Attention Is All You Need》论文发表以来,在自然语言处理领域引起了巨大的变革。它摒弃了传统的循环结构,完全基于注意力机制,显著提高了处理序列数据的效率和性能。本文将通过对一个具体的项目代码结构进行详细分析,带领大家深入了解 Transformer 模型的数据处理部分。

项目结构概述

首先,让我们来看看项目的整体结构:(参考项目代码)

transformer-master
├── paper\
│   └── attention is all you need.pdf
├── image\
├── models\
│   ├── __init__.py
│   ├── blocks\
│   │   ├── __init__.py
│   │   ├── decoder_layer.py
│   │   └── encoder_layer.py
│   ├── embedding\
│   │   ├── __init__.py
│   │   ├── positional_encoding.py
│   │   ├── token_embeddings.py
│   │   └── transformer_embedding.py
│   ├── layers\
│   │   ├── __init__.py
│   │   ├── layer_norm.py
│   │   ├── multi_head_attention.py
│   │   ├── position_wise_feedforward.py
│   │   └── scaled_dot_product_attention.py
│   └── model\
│       ├── __init__.py
│       ├── encoder.py
│       ├── decoder.py
│       └── transformer.py
├── saved\ 
├── util\
│   ├── __init__.py
│   ├── bleu.py
│   ├── data_loader.py
│   ├── epoch_timer.py
│   └── tokenizer.py
├── conf.py
├── data.py
├── graph.py
├── train.py
└── README.md

这个项目结构清晰,包含了论文、模型模块、数据处理、训练脚本等部分。其中,data.py 文件负责数据的处理和准备,是模型训练的基础。

数据处理流程

数据处理流程图

Vocabulary and Indexes
Data Preparation
Initialization
Get source pad index
Get target pad index
Get target sos index
Get source vocab size
Get target vocab size
Make dataset
Build vocabulary
Create data iterators
Import configurations
Import DataLoader
Import Tokenizer
Create Tokenizer instance
Create DataLoader instance

从流程图中可以看出,数据处理主要分为三个阶段:初始化、数据准备和获取词汇表及索引。

数据处理代码及解析

"""
@author : Hyunwoong
@when : 2019-10-29
@homepage : https://github.com/gusdnd852
"""
from conf import *
from util.data_loader import DataLoader
from util.tokenizer import Tokenizertokenizer = Tokenizer()
loader = DataLoader(ext=('.en', '.de'),tokenize_en=tokenizer.tokenize_en,tokenize_de=tokenizer.tokenize_de,init_token='<sos>',eos_token='<eos>')train, valid, test = loader.make_dataset()
loader.build_vocab(train_data=train, min_freq=2)
train_iter, valid_iter, test_iter = loader.make_iter(train, valid, test,batch_size=batch_size,device=device)src_pad_idx = loader.source.vocab.stoi['<pad>']
trg_pad_idx = loader.target.vocab.stoi['<pad>']
trg_sos_idx = loader.target.vocab.stoi['<sos>']enc_voc_size = len(loader.source.vocab)
dec_voc_size = len(loader.target.vocab)

1. 导入模块和配置

from conf import *
from util.data_loader import DataLoader
from util.tokenizer import Tokenizer
  • from conf import *:从 conf 模块导入所有内容,通常包含全局配置,如路径、参数等,方便在整个项目中使用统一的配置。
  • from util.data_loader import DataLoader:导入 DataLoader 类,它负责数据的加载和处理,包括数据集的划分、迭代器的创建等。
  • from util.tokenizer import Tokenizer:导入 Tokenizer 类,用于文本的分词和编码,将文本转换为模型可以处理的形式。

2. 初始化分词器和数据加载器

tokenizer = Tokenizer()
loader = DataLoader(ext=('.en', '.de'),tokenize_en=tokenizer.tokenize_en,tokenize_de=tokenizer.tokenize_de,init_token='<sos>',eos_token='<eos>')
  • 创建 Tokenizer 实例,用于后续的分词操作。
  • 创建 DataLoader 实例,指定:
    • ext:数据文件的扩展名,这里指定为英语(.en)和德语(.de),表示处理的是英德双语数据。
    • tokenize_entokenize_de:分别指定英语和德语的分词函数,这些函数来自 Tokenizer 实例,确保不同语言的文本能正确分词。
    • init_tokeneos_token:分别指定序列的开始和结束标记,这里使用 <sos><eos>,方便模型识别序列的边界。

3. 加载和划分数据集

train, valid, test = loader.make_dataset()

调用 DataLoader 实例的 make_dataset 方法,加载数据并划分为训练集、验证集和测试集。训练集用于模型的训练,验证集用于调整模型的超参数,测试集用于评估模型的最终性能。

4. 构建词汇表

loader.build_vocab(train_data=train, min_freq=2)

调用 DataLoader 实例的 build_vocab 方法,基于训练数据构建词汇表。min_freq=2 表示词汇表中只包含出现频率至少为 2 的单词,这样可以过滤掉一些罕见的单词,减少词汇表的大小,提高模型的训练效率。

5. 创建数据迭代器

train_iter, valid_iter, test_iter = loader.make_iter(train, valid, test,batch_size=batch_size,device=device)

调用 DataLoader 实例的 make_iter 方法,为训练集、验证集和测试集创建数据迭代器。这些迭代器将数据分批加载到模型中,batch_size 指定了每批数据的大小,device 指定了数据加载到的设备(如 CPU 或 GPU)。分批处理数据可以减少内存的使用,提高训练的效率。

6. 获取特殊标记的索引

src_pad_idx = loader.source.vocab.stoi['<pad>']
trg_pad_idx = loader.target.vocab.stoi['<pad>']
trg_sos_idx = loader.target.vocab.stoi['<sos>']

从源语言和目标语言的词汇表中获取 <pad>(填充标记)和 <sos>(序列开始标记)的索引。这些索引在后续的模型训练和推理中会被用到,具体意义如下:

  • <pad> 索引:在处理不同长度的序列时,为了能够将它们批量输入到神经网络中,通常需要对序列进行填充或截断,使其具有相同的长度。<pad> 标记用于填充较短的序列,使其与最长序列的长度一致。在模型训练过程中,通过识别 <pad> 标记的索引,模型可以忽略这些填充部分的影响,只关注实际有效的序列内容。
  • <sos> 索引:<sos> 标记用于标识序列的开始位置,这对于生成式模型(如文本生成、机器翻译中的解码器)尤为重要。它告诉模型从哪里开始生成或解码序列。在解码过程中,模型通常会从 <sos> 标记开始,逐步生成后续的单词或字符,直到遇到序列结束标记(如 <eos>)。

以下是一个简短的例程,展示了如何在模型训练中使用这些特殊标记的索引:

假设我们已经有一个训练好的seq2seq模型,以及相应的数据加载器
#...(模型和数据加载器的初始化代码省略)...获取特殊标记的索引
src_pad_idx = loader.source.vocab.stoi['<pad>']
trg_pad_idx = loader.target.vocab.stoi['<pad>']
trg_sos_idx = loader.target.vocab.stoi['<sos>']在训练循环中
for batch in train_iter:src_seq, trg_seq = batch# 忽略填充部分的影响(具体实现取决于模型)#...(处理src_seq和trg_seq,可能包括掩码操作等)...# 解码器从<sos>标记开始生成序列decoder_output = model.decode(encoder_output, start_token=trg_sos_idx)#...(计算损失、更新模型参数等)...

7. 获取词汇表大小

enc_voc_size = len(loader.source.vocab)
dec_voc_size = len(loader.target.vocab)

获取源语言和目标语言词汇表的大小,这些大小通常用于定义模型中的嵌入层大小。嵌入层将单词转换为向量表示,词汇表的大小决定了嵌入层的输入维度。

Transformer模型数据处理的基本流程,包括数据的加载、分词、划分、构建词汇表以及创建数据迭代器等操作。在实际应用中,我们可以根据具体的任务和数据特点对这些步骤进行调整和优化。

相关文章:

Transformer 代码剖析1 - 数据处理 (pytorch实现)

引言 Transformer 架构自《Attention Is All You Need》论文发表以来&#xff0c;在自然语言处理领域引起了巨大的变革。它摒弃了传统的循环结构&#xff0c;完全基于注意力机制&#xff0c;显著提高了处理序列数据的效率和性能。本文将通过对一个具体的项目代码结构进行详细分…...

Python异常处理面试题及参考答案

目录 什么是 Python 中的异常?程序为什么需要异常处理机制? 解释 BaseException 和 Exception 的区别 Python 的异常处理与传统的错误代码返回机制相比有哪些优势? 列出至少 5 个 Python 内置异常类型并说明触发场景 语法错误 (SyntaxError) 与运行时异常 (Runtime Erro…...

Python多线程知多少

目录 目标 Python版本 官方文档 概述 线程 守护线程 线程同步 事件对象&#xff08;Event Object&#xff09; 实战 创建线程的基本语法 阻塞线程 守护线程 线程同步的方法 互斥锁&#xff08;排他锁&#xff09; 信号量&#xff08;Semaphore&#xff09; 事件…...

C++ Qt常见面试题(8):C++ Qt中的线程同步与互斥

在C++ Qt中,线程同步和互斥通常通过 QMutex 和 QMutexLocker 来实现。线程同步确保多个线程不会同时访问共享资源,而互斥机制通过锁定一个资源,确保在任何给定时刻只有一个线程能够访问它。 以下是一个使用 QMutex 来同步和互斥访问共享资源的详细示例代码: 1. 使用 QMut…...

数字内容个性化推荐的关键是什么?

智能算法交互体系构建 构建数字内容体验的智能推荐系统&#xff0c;本质上是实现数据驱动与算法响应的动态协同。其核心在于建立多维度用户数据与机器学习模型的深度交互链路——通过实时采集用户点击、停留时长、交互路径等行为特征&#xff0c;结合设备属性、场景状态等上下…...

DeepSeek-OpenSourceWeek-第三天-Release of DeepGEMM

DeepGEMM:这是一款专为高效的 FP8(8 位浮点)通用矩阵乘法(GEMMs)而开发的尖端库。GEMMs 是许多 AI 工作负载(尤其是深度学习)中的基本操作。 特点: 支持稠密和 MoE GEMMs:它可以处理标准的稠密矩阵乘法以及混合专家(MoE)模型中使用的矩阵乘法。MoE 是一种神经网络架…...

LeetCode 1472.设计浏览器历史记录:一个数组完成模拟,单次操作均O(1)

【LetMeFly】1472.设计浏览器历史记录&#xff1a;一个数组完成模拟&#xff0c;单次操作均O(1) 力扣题目链接&#xff1a;https://leetcode.cn/problems/design-browser-history/ 你有一个只支持单个标签页的 浏览器 &#xff0c;最开始你浏览的网页是 homepage &#xff0c…...

AI+游戏,正在进行时!

2月&#xff0c;DeepSeek引领的AI浪潮对游戏行业造成了巨大冲击。 2月17日马斯克在社交平台宣布&#xff0c;xAI将成立一家AI游戏工作室&#xff0c;高调宣布两大核心理念&#xff0c;打破大公司的垄断&#xff0c;利用AI重构游戏体验。随后的新闻中还表示&#xff0c;团队计划…...

贪心算法精品题

1.找钱问题 本题的贪心策略在于我们希望就可能的保留作用大的5元 class Solution { public:bool lemonadeChange(vector<int>& bills) {std::map<int ,int> _map;for(auto ch:bills){if(ch 5) _map[ch];else if(ch 10){if(_map[5] 0) return false;else{_m…...

sql server 复制从备份初始化数据

参考 &#xff1a; 从备份初始化订阅&#xff08;事务&#xff09; - SQL Server | Microsoft Learn sql server 复制默认是用快照初始化数据的&#xff0c;也支持从备份初始化数据&#xff0c;参考如上...

【蓝桥杯】1.k倍区间

前缀和 #include <iostream> using namespace std; const int N100010; long long a[N]; int cnt[N]; int main(){int n, m;cnt[0] 1;cin >> n >> m;long long res 0;for(int i 1; i < n; i){scanf("%d", &a[i]);a[i] a[i-1];res cnt…...

Qt互斥锁(QMutex)的使用、QMutexLocker的使用

Qt互斥锁【QMutex】的使用、QMutexLocker的使用 Chapter1 Qt互斥锁(QMutex)的使用、QMutexLocker的使用一、QMutexLocker和QMutex实现示例图二、QMutex和QMutexLocker的关系&#xff08;个人理解&#xff09;三、QMutex使用和QMutexLocker使用1.QMutex的使用2.QMutexLocker的使…...

具身智能(Embodied AI)的物理交互基准测试:构建真实世界的智能体评估体系

文章目录 引言:从虚拟到物理的智能跃迁一、具身智能测试体系设计1.1 评估维度矩阵1.2 测试平台技术栈二、核心测试场景构建2.1 基础运动能力测试集2.2 复杂操作任务设计三、物理仿真引擎关键技术3.1 高精度接触力学模型3.2 传感器噪声模拟四、评估指标体系4.1 量化指标公式4.2…...

Javaweb后端数据库多表关系一对多,外键,一对一

多表关系 一对多 多的表里&#xff0c;要有一表里的主键 外键 多的表上&#xff0c;添加外键 一对一 多对多 案例...

鸿蒙 ArkUI 实现敲木鱼小游戏

敲木鱼是一款具有禅意的趣味小游戏&#xff0c;本文将通过鸿蒙 ArkUI 框架的实现代码&#xff0c;逐步解析其核心技术点&#xff0c;包括动画驱动、状态管理、音效震动反馈等。 一、架构设计与工程搭建 1.1 项目结构解析 完整项目包含以下核心模块&#xff1a; ├── entry…...

cv2.solvePnP 报错 求相机位姿

目录 报错信息及解决&#xff1a; cv2.solvePnP 使用例子&#xff1a; 设置初始值效果也不好 cv2.projectPoints 函数效果不好 报错信息及解决&#xff1a; File "/shared_disk/users/lbg/project/human_4d/nlf_pose/render_demo_pkl2_cal.py", line 236, in <…...

Linux实操——在服务器上直接从百度网盘下载(/上传)文件

Linux Linux实操——在服务器上直接从百度网盘下载&#xff08;/上传&#xff09;文件 文章目录 Linux前言一、下载并安装bypy工具二、认证并授权网盘账号三、将所需文件转移至目的文件夹下四、下载文件五、上传文件六、更换绑定的百度云盘账户 前言 最近收到一批很大的数据&…...

2004-2024年光刻机系统及性能研究领域国内外发展历史、差距、研究难点热点、进展突破及下一个十年研究热点方向2025.2.27

一.光刻机概述 1.1 定义与原理 光刻机是 集成电路芯片制造的核心设备 ,其工作原理基于 光学成像和化学反应 。它通过 曝光系统 将掩模版上的图形精确地转移到涂覆于硅片表面的光刻胶上。这个过程涉及复杂的物理和化学反应,主要包括以下几个步骤: 涂胶 :在硅片表面均匀涂抹…...

请求Geoserver的WTMS服务返回200不返回图片问题-跨域导致

今天碰到个奇怪问题&#xff0c;改了个页面标题再打包布署GeoServer发现调用WTMS服务失败&#xff0c;请求返回状态码200&#xff0c;返回包大小0&#xff0c;使用postman模拟请求是可以正常返回图片的。 跟之前版本对比如下&#xff1a; 正常Response请求: HTTP/1.1 200X-Fr…...

ubuntu配置jmeter

1.前提准备 系统 ubuntu server 22.04 前提条件&#xff1a;服务器更新apt与安装lrzsz&#xff1a;更新apt&#xff1a; sudo apt update安装lrzsz: 命令行下的上传下载文件工具 sudo apt install lrzszsudo apt install zip2.安装jemeter 2.1.下载jdk17 输入命令&#xf…...

云计算——弹性云计算器(ECS)

弹性云服务器&#xff1a;ECS 概述 云计算重构了ICT系统&#xff0c;云计算平台厂商推出使得厂家能够主要关注应用管理而非平台管理的云平台&#xff0c;包含如下主要概念。 ECS&#xff08;Elastic Cloud Server&#xff09;&#xff1a;即弹性云服务器&#xff0c;是云计算…...

逻辑回归:给不确定性划界的分类大师

想象你是一名医生。面对患者的检查报告&#xff08;肿瘤大小、血液指标&#xff09;&#xff0c;你需要做出一个**决定性判断**&#xff1a;恶性还是良性&#xff1f;这种“非黑即白”的抉择&#xff0c;正是**逻辑回归&#xff08;Logistic Regression&#xff09;** 的战场&a…...

从零实现STL哈希容器:unordered_map/unordered_set封装详解

本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说&#xff0c;直接开始吧&#xff01; 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...

【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)

升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点&#xff0c;但无自动故障转移能力&#xff0c;Master宕机后需人工切换&#xff0c;期间消息可能无法读取。Slave仅存储数据&#xff0c;无法主动升级为Master响应请求&#xff…...

Angular微前端架构:Module Federation + ngx-build-plus (Webpack)

以下是一个完整的 Angular 微前端示例&#xff0c;其中使用的是 Module Federation 和 npx-build-plus 实现了主应用&#xff08;Shell&#xff09;与子应用&#xff08;Remote&#xff09;的集成。 &#x1f6e0;️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...

Redis的发布订阅模式与专业的 MQ(如 Kafka, RabbitMQ)相比,优缺点是什么?适用于哪些场景?

Redis 的发布订阅&#xff08;Pub/Sub&#xff09;模式与专业的 MQ&#xff08;Message Queue&#xff09;如 Kafka、RabbitMQ 进行比较&#xff0c;核心的权衡点在于&#xff1a;简单与速度 vs. 可靠与功能。 下面我们详细展开对比。 Redis Pub/Sub 的核心特点 它是一个发后…...

云原生安全实战:API网关Kong的鉴权与限流详解

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关&#xff08;API Gateway&#xff09; API网关是微服务架构中的核心组件&#xff0c;负责统一管理所有API的流量入口。它像一座…...

Android写一个捕获全局异常的工具类

项目开发和实际运行过程中难免会遇到异常发生&#xff0c;系统提供了一个可以捕获全局异常的工具Uncaughtexceptionhandler&#xff0c;它是Thread的子类&#xff08;就是package java.lang;里线程的Thread&#xff09;。本文将利用它将设备信息、报错信息以及错误的发生时间都…...

CppCon 2015 学习:REFLECTION TECHNIQUES IN C++

关于 Reflection&#xff08;反射&#xff09; 这个概念&#xff0c;总结一下&#xff1a; Reflection&#xff08;反射&#xff09;是什么&#xff1f; 反射是对类型的自我检查能力&#xff08;Introspection&#xff09; 可以查看类的成员变量、成员函数等信息。反射允许枚…...

6.9本日总结

一、英语 复习默写list11list18&#xff0c;订正07年第3篇阅读 二、数学 学习线代第一讲&#xff0c;写15讲课后题 三、408 学习计组第二章&#xff0c;写计组习题 四、总结 明天结束线代第一章和计组第二章 五、明日计划 英语&#xff1a;复习l默写sit12list17&#…...