当前位置: 首页 > article >正文

Seq2Seq - Dataset 类

本节代码定义了一个 CMN 类,它继承自 PyTorch 的 Dataset 类,用于处理英文和中文的平行语料库。这个类的主要作用是将文本数据转换为模型可以处理的格式,并进行必要的填充操作,以确保所有序列的长度一致。

⭐重写Dataset类是模型训练的重中之重请务必掌握!

重写时格式固定为三件套 __init__  __len__ __getitem__重点记忆! 

1. 类定义

class CMN(Dataset):def __init__(self, en_corpus, cn_corpus, en_tokenizer: Tokenizer, cn_tokenizer: Tokenizer, seq_len):self.en_corpus = en_corpusself.cn_corpus = cn_corpusself.en_tokenizer = en_tokenizerself.cn_tokenizer = cn_tokenizerself.seq_len = seq_lenself.pad_id = self.en_tokenizer.vocab["[PAD]"]self.bos_id = self.en_tokenizer.vocab["[BOS]"]self.eos_id = self.en_tokenizer.vocab["[EOS]"]
参数
  • en_corpus:英文语料库,是一个字符串列表。

  • cn_corpus:中文语料库,是一个字符串列表。

  • en_tokenizer:英文分词器,用于将英文文本转换为索引。

  • cn_tokenizer:中文分词器,用于将中文文本转换为索引。

  • seq_len:序列的最大长度,用于填充或截断序列。

属性
  • self.pad_id:填充标记 [PAD] 的索引。

  • self.bos_id:序列开始标记 [BOS] 的索引。

  • self.eos_id:序列结束标记 [EOS] 的索引。

2. 数据集长度(__len__

def __len__(self):return len(self.en_corpus)
  • 功能:返回数据集的长度,即语料库中句子的数量。

  • 返回值:数据集的长度。

3. 获取数据项(__getitem__

def __getitem__(self, idx):en_text = self.en_corpus[idx]cn_text = self.cn_corpus[idx]en_ids = self.en_tokenizer.encode(en_text)cn_ids = self.cn_tokenizer.encode(cn_text)encoder_input = self.pad_to_seq_len([self.bos_id] + en_ids)  # batch * seq_lendecoder_input = self.pad_to_seq_len([self.bos_id] + cn_ids)labels = self.pad_to_seq_len(cn_ids + [self.eos_id])return {"encoder_input": encoder_input,"decoder_input": decoder_input,"labels": labels,"en_text": en_text,"cn_text": cn_text}

CMN 类的 __getitem__ 方法中,代码的主要目的是将英文和中文文本转换为模型可以处理的格式,并进行必要的填充操作,以确保所有序列的长度一致。以下是对 __getitem__ 方法中各个部分的详细解释:

1. 获取文本
en_text = self.en_corpus[idx]
cn_text = self.cn_corpus[idx]
  • 功能:从语料库中获取索引为 idx 的英文句子 en_text 和中文句子 cn_text

  • 目的:为每个索引提供一对对应的英文和中文句子,用于后续的编码和解码。

2. 文本编码
en_ids = self.en_tokenizer.encode(en_text)
cn_ids = self.cn_tokenizer.encode(cn_text)
  • 功能:将英文和中文句子分别通过对应的分词器编码为索引列表。

  • 目的:将文本转换为模型可以处理的数值形式。分词器将每个字符(或单词)映射为词汇表中的索引。

3. 构建输入序列
encoder_input = self.pad_to_seq_len([self.bos_id] + en_ids)  # batch * seq_len
decoder_input = self.pad_to_seq_len([self.bos_id] + cn_ids)
  • 功能:构建编码器和解码器的输入序列。

  • 目的

    • 编码器输入:在英文索引列表的开头添加 [BOS] 标记,表示序列的开始。然后对序列进行填充或截断,使其长度为 seq_len

    • 解码器输入:在中文索引列表的开头添加 [BOS] 标记,表示序列的开始。同样进行填充或截断,使其长度为 seq_len

  • 为什么这样写

    • [BOS] 标记:在序列的开头添加 [BOS] 标记,是为了让模型知道序列的开始位置。这对于模型理解序列的起始点非常重要,尤其是在解码阶段。

    • 填充或截断:为了确保所有序列的长度一致,需要对序列进行填充或截断。填充是通过添加 [PAD] 标记来实现的,截断则是直接截取序列的前 seq_len 个元素。

4. 构建目标序列(标签)
labels = self.pad_to_seq_len(cn_ids + [self.eos_id])
  • 功能:构建解码器的目标序列(标签)。

  • 目的:为目标序列添加 [EOS] 标记,表示序列的结束。然后进行填充或截断,使其长度为 seq_len

  • 为什么这样写

    • [EOS] 标记:在目标序列的末尾添加 [EOS] 标记,是为了让模型知道序列的结束位置。这对于模型在解码阶段生成完整的序列非常重要。

    • 填充或截断:同样是为了确保所有序列的长度一致,需要对目标序列进行填充或截断。

5. 返回结果
return {"encoder_input": encoder_input,"decoder_input": decoder_input,"labels": labels,"en_text": en_text,"cn_text": cn_text
}
  • 功能:返回一个字典,包含以下内容:

    • "encoder_input":编码器的输入序列。

    • "decoder_input":解码器的输入序列。

    • "labels":解码器的目标序列。

    • "en_text":原始英文句子。

    • "cn_text":原始中文句子。

  • 目的:提供模型训练所需的所有输入和目标数据,同时保留原始文本以便后续验证和调试。

6. 填充序列(pad_to_seq_len
def pad_to_seq_len(self, x):pad_num = self.seq_len - len(x)return torch.tensor(x + [self.pad_id] * pad_num)
  • 功能:将一个索引列表填充或截断到指定的序列长度 seq_len

  • 目的:确保所有序列的长度一致,以便模型可以批量处理。

  • 为什么这样写

    • 填充:如果序列长度小于 seq_len,则在末尾添加 [PAD] 标记,直到长度达到 seq_len

    • 截断:如果序列长度大于 seq_len,则直接截取前 seq_len 个元素。

    • 转换为张量:将填充或截断后的列表转换为 PyTorch 张量,以便模型可以直接使用。

4. 填充序列(pad_to_seq_len

def pad_to_seq_len(self, x):pad_num = self.seq_len - len(x)return torch.tensor(x + [self.pad_id] * pad_num)
功能
  • 将一个索引列表填充或截断到指定的序列长度 seq_len

过程
  1. 计算填充数量

    • pad_num 是目标长度 seq_len 与当前列表长度的差值。

    • 如果 pad_num 为正数,则需要填充;如果为负数,则需要截断。

  2. 填充或截断

    • 如果 pad_num 为正数,将 [self.pad_id] 重复 pad_num 次,添加到列表的末尾。

    • 如果 pad_num 为负数,直接截断列表的末尾部分。

  3. 返回结果

    • 将填充或截断后的列表转换为 PyTorch 张量并返回。

示例

假设 seq_len=10x=[2, 3, 4],调用 pad_to_seq_len(x) 的结果:

pad_num = 10 - 3 = 7
result = [2, 3, 4] + [0, 0, 0, 0, 0, 0, 0]  # 假设 pad_id=0
torch.tensor([2, 3, 4, 0, 0, 0, 0, 0, 0, 0])

5.  CMN 类实现了以下功能:

  1. 数据读取:从语料库中读取英文和中文句子。

  2. 文本编码:将文本转换为索引列表。

  3. 序列填充:将索引列表填充或截断到指定长度。

  4. 构建输入和标签:为编码器和解码器构建输入序列和目标序列。

这些步骤是构建 Seq2Seq 模型中数据预处理的重要环节,确保了数据可以被模型有效处理。

需复现完整代码如下:

class CMN(Dataset):def __init__(self, en_corpus, cn_corpus, en_tokenizer: Tokenizer, cn_tokenizer: Tokenizer, seq_len):self.en_corpus = en_corpusself.cn_corpus = cn_corpusself.en_tokenizer = en_tokenizerself.cn_tokenizer = cn_tokenizerself.seq_len = seq_lenself.pad_id = self.en_tokenizer.vocab["[PAD]"]self.bos_id = self.en_tokenizer.vocab["[BOS]"]self.eos_id = self.en_tokenizer.vocab["[EOS]"]def __len__(self):return len(self.en_corpus)def __getitem__(self, idx):en_text = self.en_corpus[idx]cn_text = self.cn_corpus[idx]en_ids = self.en_tokenizer.encode(en_text)cn_ids = self.cn_tokenizer.encode(cn_text)encoder_input = self.pad_to_seq_len([self.bos_id] + en_ids) #batch * seq_lendecoder_input = self.pad_to_seq_len([self.bos_id] + cn_ids)labels = self.pad_to_seq_len(cn_ids + [self.eos_id])return {"encoder_input": encoder_input,"decoder_input": decoder_input,"labels": labels,"en_text": en_text,"cn_text": cn_text}def pad_to_seq_len(self, x):pad_num = self.seq_len - len(x)return torch.tensor(x + [self.pad_id] * pad_num)

相关文章:

Seq2Seq - Dataset 类

本节代码定义了一个 CMN 类,它继承自 PyTorch 的 Dataset 类,用于处理英文和中文的平行语料库。这个类的主要作用是将文本数据转换为模型可以处理的格式,并进行必要的填充操作,以确保所有序列的长度一致。 ⭐重写Dataset类是模型训…...

学习OpenCV C++版

OpenCV C 1 数据载入、显示与保存1.1 概念1.2 Mat 类构造与赋值1.3 Mat 类的赋值1.4 Mat 类支持的运算1.5 图像的读取与显示1.6 视频加载与摄像头调用1.7 数据保存 参考:《OpenCV4快速入门》作者冯 振 郭延宁 吕跃勇 1 数据载入、显示与保存 1.1 概念 Mat 类 : Ma…...

echarts图表相关

echarts图表相关 echarts官网折线图实际开发场景一: echarts官网 echarts官网 折线图 实际开发场景一: 只有一条折线,一半实线,一半虚线。 option {tooltip: {trigger: "axis",formatter: (params: any) > {const …...

idea自动部署jar包到服务器Alibaba Cloud Toolkit

安装插件:Alibaba Cloud Toolkit 配置服务器: 服务器配置: 项目启动Shell脚本命令: projectpd-otb.jar echo 根据项目名称查询对应的pid pid$(pgrep -f $project); echo $pid echo 杀掉对应的进程,如果pid不存在,则不执行 if [ …...

奥利司他

https://m.baidu.com/bh/m/detail/ar_9900965142893895938 奥利司他(四氢脂抑素)是一种众所周知的胰腺和胃脂肪酶不可逆抑制剂 生物活性:奥利司他(四氢脂抑素)是一种众所周知的胰腺和胃脂肪酶不可逆抑制剂。奥利司…...

Element Plus 图标使用方式整理

Element Plus 图标使用方式整理 以下是 Element Plus 图标的所有使用方式&#xff0c;包含完整代码示例和总结表格&#xff1a; 1. 按需引入图标组件 适用场景&#xff1a;仅需少量图标时&#xff0c;按需导入减少打包体积 示例代码&#xff1a; <template><div>…...

链路聚合+vrrp

1.链路聚合 作用注意事项将多个物理接口&#xff08;线路&#xff09;逻辑上绑定在一起形成一条逻辑链路&#xff0c;起到叠加带宽的作用1.聚合接口必须转发速率一致。2.聚合设备两端必须一致 配置命令 方法一 [Huawei]interface Eth-Trunk 0----先创建聚合接口&#xff0c;…...

Dynamics 365 Business Central Register Customer Payment 客户付款登记

#Dynamics 365 BC ERP# #D365 ERP# #Navision 前言 在实施过程&#xff0c;经常给客户介绍的 给客户付款一般用Payment Journal. 在客户熟悉系统运行后&#xff0c;往往会推荐客户使用Register Customer Payment.用这个function 工作会快很多&#xff0c;但出错的机会也比较大…...

Odoo免费开源ERP:企业销售过程中出现的问题

在企业未上线Odoo免费开源ERP时&#xff0c;企业销售过程中会存在失误。比如&#xff0c;许多销售订单都有如下问题&#xff1a;不当的定价、向客户过多地询问、处理订单延误、错过发货日期等。这些问题源于企业三个未集成的信息系统&#xff1a;销售管理系统、库存系统和财务系…...

手撕unique_ptr 和 shareed_ptr

文章目录 unique_ptrshared_ptr unique_ptr template<class T> class Unique_ptr { private:T* ptrNULL; public://1、删除默认的拷贝构造函数Unique_ptr(Unique_ptr& u) delete;//2、删除默认的复制构造Unique_ptr& operator(Unique_ptr& u) delete; …...

工会考试的重点内容是什么

工会考试的内容通常涵盖以下几个方面&#xff1a; 1、政治理论&#xff1a; 主要考查考生对马克思主义基本原理、中国特色社会主义理论体系、党的基本路线、方针、政策等方面的掌握程度。题型通常包括选择题、判断题和论述题。 2、法律法规&#xff1a; 这部分主要涉及国家…...

网络稳定性--LCA+最大生成树+bfs1/dfs1找最小边

1.最大生成树去除重边&#xff0c;只要最大的边成树 2.LCA查最近公共祖先&#xff0c;然后询问的lca(x,y)ff,分别从x,y向上找最小边 3.bfs1/dfs1就是2.中向上找的具体实现 #include<bits/stdc.h> using namespace std; #define N 100011 typedef long long ll; typede…...

混合并行技术在医疗AI领域的应用分析(代码版)

混合并行技术(专家并行/张量并行/数据并行)通过多维度的计算资源分配策略,显著提升了医疗AI大模型的训练效率与推理性能。以下结合技术原理与医疗场景实践,从策略分解、技术对比、编排优化及典型案例等维度展开分析: 一、混合并行技术:突破单卡算力限制 1. 并行策略三维分…...

【C++面向对象】封装(上):探寻构造函数的幽微之境

每文一诗 &#x1f4aa;&#x1f3fc; 我本将心向明月&#xff0c;奈何明月照沟渠 —— 元/高明《琵琶记》 译文&#xff1a;我本是以真诚的心来对待你&#xff0c;就像明月一样纯洁无瑕&#xff1b;然而&#xff0c;你却像沟渠里的污水一样&#xff0c;对这份心意无动于衷&a…...

每日算法-250409

这是我今天的算法学习记录。 2187. 完成旅途的最少时间 题目描述 思路 二分查找 解题过程 为什么可以使用二分查找&#xff1f; 问题的关键在于寻找一个最小的时间 t&#xff0c;使得在时间 t 内所有公交车完成的总旅途次数 sum 大于等于 totalTrips。 我们可以观察到时间的单…...

如何实现文本回复Ai ChatGPT DeepSeek 式文字渐显效果?前端技术详解(附完整代码)

个人开发的塔罗牌占卜小程序&#xff1a;【问问塔罗牌】 快来瞧瞧吧&#xff01; 一、核心实现原理 我们通过三步实现这个效果&#xff1a; 逐字渲染&#xff1a;通过 JavaScript 定时添加字符 透明度动画&#xff1a;CSS 实现淡入效果 光标动画&#xff1a;伪元素 CSS 动画…...

CompletableFuture高级模式详解

目录 CompletableFuture高级模式详解 1. CompletableFuture基础概念 1.1 什么是CompletableFuture? 1.2 异步编程基础 1.3 CompletableFuture与Future的对比 2. 创建CompletableFuture 2.1 基本创建方法 2.2 使用异步方法创建 2.3 指定执行器 3. 转换和链式操作 3.…...

【AI开源大模型工具链ModelEngine】【01】应用框架-源码编译运行

ModelEngine提供从数据处理、知识生成&#xff0c;到模型微调和部署&#xff0c;以及RAG&#xff08;Retrieval Augmented Generation&#xff09;应用开发的AI训推全流程工具链。 GitCode开源地址&#xff1a;https://gitcode.com/ModelEngineGitee开源地址&#xff1a;https…...

linux下截图工具的选择

方案一 gnome插件Screenshot Tool&#xff08;截屏&#xff09; ksnip&#xff08;图片标注&#xff09; gnome setting设置图片的默认打开方式为ksnip就可以快捷的将Screenshot Tool截屏的图片打开进行标记了。 但是最近我发现Screenshot Tool的延迟截图功能是有问题的&…...

每天记录一道Java面试题---day36

事务的基本特性和隔离级别 回答重点 事务基本特性ACID分别是&#xff1a; - 原子性指的是一个事务中的操作要么全部成功&#xff0c;要么全部失败。 - 一致性指的是数据库总是一个一致性的状态转换到另一个一致性的状态。比如A转账给B100块钱&#xff0c;假设A只有 90块&…...

Qt音频采集:QAudioInput详解与示例

1. 简介 QAudioInput是Qt Multimedia模块中用于音频采集的核心类&#xff0c;能够从麦克风等输入设备实时获取原始音频数据&#xff08;PCM格式&#xff09;。本文将通过原理讲解和代码示例&#xff0c;帮助开发者快速掌握音频采集的核心技术。 2. 核心功能 支持多种音频格式&…...

rkmpp 解码 精简mpi_dec_test.c例程

rkmpp 解码流程&#xff08;除 MPP_VIDEO_CodingMJPEG 之外&#xff09; 源码 输入h264码流 输出nv12文件 /** Copyright 2015 Rockchip Electronics Co. LTD** Licensed under the Apache License, Version 2.0 (the "License");* you may not use this file exce…...

怎么构造思维链数据?思维链提示工程的五大原则

我来为您翻译这篇关于思维链提示工程的文章&#xff0c;采用通俗易懂的中文表达&#xff1a; 思维链(CoT)提示工程是生成式AI(GenAI)中一种强大的方法&#xff0c;它能让模型通过逐步推理来解决复杂任务。通过构建引导模型思考过程的提示&#xff0c;思维链能提高输出的准确性…...

网络安全之-信息收集

域名收集 域名注册信息 站长之家 https://whois.chinaz.com/ whois 查询的相关网站有:中国万网域名WHOIS信息查询地址: https://whois.aliyun.com/西部数码域名WHOIS信息查询地址: https://whois.west.cn/新网域名WHOIS信息查询地址: http://whois.xinnet.com/domain/whois/in…...

JdbcTemplate基本使用

JdbcTemplate概述 它是spring框架中提供的一个对象&#xff0c;是对原始繁琐的JdbcAPI对象的简单封装。spring框架为我们提供了很多的操作模板类。例如:操作关系型数据的JdbcTemplate和MbernateTemplate&#xff0c;操作nosql数据库的RedisTemplate&#xff0c;操作消息队列的…...

pnpm 中 Next.js 模块无法找到问题解决

问题概述 项目在使用 pnpm 管理依赖时,出现了 “Cannot find module ‘next/link’ or its corresponding type declarations” 的错误。这是因为 pnpm 的软链接机制在某些情况下可能导致模块路径解析问题。 问题诊断 通过命令 pnpm list next 确认项目已安装 Next.js 15.2.…...

openEuler24.03 LTS下安装Spark

目录 安装模式介绍 下载Spark 安装Local模式 前提条件 解压安装包 简单使用 安装Standalone模式 前提条件 集群规划 解压安装包 配置Spark 配置Spark-env.sh 配置workers 分发到其他机器 启动集群 简单使用 关闭集群 安装YARN模式 前提条件 解压安装包 配…...

蓝桥杯真题——接龙序列

蓝桥杯2023年第十四届省赛真题-接龙数列 题目描述 对于一个长度为 K 的整数数列&#xff1a;A1, A2, . . . , AK&#xff0c;我们称之为接龙数列当且仅当 Ai 的首位数字恰好等于 Ai−1 的末位数字 (2 ≤ i ≤ K)。 例如 12, 23, 35, 56, 61, 11 是接龙数列&#xff1b;12, 2…...

使用 DeepSeek API 实现新闻文章地理位置检测与地图可视化

使用 DeepSeek API 实现新闻文章地理位置检测与地图可视化 | Implementing News Article Location Detection and Map Visualization with DeepSeek API 作者&#xff1a;zhutoutoutousan | Author: zhutoutoutousan 发布时间&#xff1a;2025-04-08 | Published: 2025-04-08 标…...

如何精准控制大模型的推理深度

论文标题 ThinkEdit: Interpretable Weight Editing to Mitigate Overly Short Thinking in Reasoning Models 论文地址 https://arxiv.org/pdf/2503.22048 代码地址 https://github.com/Trustworthy-ML-Lab/ThinkEdit 作者背景 加州大学圣迭戈分校 动机 链式推理能显…...