当前位置: 首页 > news >正文

第N2周:NLP中的数据集构建

对于初学者,NLP中最烦人的问题之一就数据集的构建问题,处理不好就会引起shape问题(各种由于shape错乱导致的问题)。这里给出一个模版,大家可根据这个模版来构建。

torch.utils.data是PyTorch中用于数据加载和预处理的模块。其中包括Dataset和DataLoader两个类,它们通常结合使用来加载和处理数据。

一、Dataset
torch.utils.data.Dataset是一个抽象类,用于表示数据集。它需要用户自己实现两个方法:__ len__ 和__getitem__。其中,__len__方法返回数据集的大小,__getitem__方法用于根据给定的索引返回一个数据样本。

以下是一个简单的示例,展示了如何定义一个数据集:

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Datasetclass MyDataset(Dataset):def __init__(self, texts, labels):self.texts  = textsself.labels = labelsdef __len__(self):return len(self.labels)def __getitem__(self, idx):texts  = self.texts[idx]labels = self.labels[idx]return texts, labels

在这个示例中,MyDataset继承了torch.utils.data.Dataset类,并实现了__len__和__getitem__方法。__len__方法返回数据集的大小,这里使用了Python内置函数len。__getitem__方法根据给定的索引返回一个数据样本,这里返回的是数据列表中对应的元素。

二、DataLoader

torch.utils.data.DataLoader是PyTorch中一个重要的类,用于高效加载数据集。它可以处理数据的批次化、打乱顺序、多线程数据加载等功能。
以下是一个简单的示例:

# 假设我们有以下三个样本,分别由不同数量的单词索引组成
text_data = [torch.tensor([1, 2, 3, 4], dtype=torch.long),  # 样本1torch.tensor([4, 3, 2], dtype=torch.long),     # 样本2torch.tensor([1, 2], dtype=torch.long)         # 样本3
]# 对应的标签
labels = torch.tensor([1, 0, 1], dtype=torch.float)# 创建数据集和数据加载器
my_dataset  = MyDataset(text_data, labels)
data_loader = DataLoader(my_dataset, batch_size=2, shuffle=True, collate_fn=lambda x: x)for batch in data_loader:print(batch)

代码输出

[(tensor([4, 3, 2]), tensor(0.)), (tensor([1, 2]), tensor(1.))]
[(tensor([1, 2, 3, 4]), tensor(1.))]

在这个示例中,我们首先创建了一个MyDataset实例my_dataset,它包含了一个整数列表。然后,我们使用DataLoader类创建了一个数据加载器data_loader,它将data_loader作为输入,并将数据分成大小为4的批次,并对数据进行随机化。最后,遍历data_loader,并打印出每个批次的数据。

三、DataLoader参数讲解

函数原型:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)

常用的参数:

1.dataset:一个数据集对象,必须实现__len__和__getitem__方法。
2.batch_size:每个batch的大小。
3.shuffle:是否对数据进行洗牌(随机打乱)。
4.sampler:一个数据采样器,用于对数据进行自定义采样。
5.batch_sampler:一个batch采样器,用于对batch进行自定义采样。
6.num_workers:用于数据加载的子进程数量。默认值为0,表示在主进程中加载数据。
7.collate_fn:用于将一个batch的数据合并成一个张量或者元组。
8.pin_memory:是否将数据存储在pin memory中(锁定物理内存,用于GPU加速数据传输),默认值为False。
9.drop_last:如果数据不能完全分成batch,是否删除最后一批数据。默认为False。
10.timeout:当数据加载器陷入死锁时,等待数据准备的最大秒数。默认值为0,表示无限等待。
11.worker_init_fn:用于每个数据加载器进程的初始化函数。可以用来设置特定的随机种子。
12.multiprocessing_context:用于创建数据加载器子进程的上下文。

以上是torch.utils.data.DataLoader中一些常用的参数,使用时根据实际情况选择相应的参数组合。

sampler参数详解:

sampler是一个用于指定数据集采样方式的类,它控制DataLoader如何从数据集中选取样本。PyTorch提供了多种Sampler类,例如RandomSampler和SequentialSampler,分别用于随机采样和顺序采样。

以下是一个示例:

from torch.utils.data.sampler import RandomSamplermy_sampler = RandomSampler(my_dataset)my_dataloader = data.DataLoader(my_dataset, batch_size=4, shuffle=False, sampler=my_sampler)

在这个示例中,我们使用RandomSampler类来指定随机采样方式,然后将其传递给DataLoader的sampler参数。这将覆盖默认的shuffle参数,使数据集按照sampler指定的采样方式进行

四、自定义Dataset类

除了使用torchvision.datasets中提供的数据集,我们还可以使用torch.utils.data.Dataset类来自定义自己的数据集。自定义数据集需要实现__len__和__getitem__方法。

●__init__: 用来初始化数据集
●__len__:方法返回数据集中样本的数量
●__getitem__:给定索引值,返回该索引值对应的数据;它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问

class MyDataset(data.Dataset):def __init__(self, data_path):self.data_list = torch.load(data_path)def __len__(self):return len(self.data_list)def __getitem__(self, index):x = self.data_list[index][0]y = self.data_list[index][1]return x, y

在这个示例中,MyDataset类继承自torch.utils.data.Dataset类,实现了__len__和__getitem__方法。MyDataset类的构造函数接受一个数据路径作为参数,数据集被保存为一个由数据-标签对组成的列表。

五、自定义Sampler类

除了使用torch.utils.data.sampler中提供的采样器,我们还可以使用Sampler类来自定义自己的采样器。自定义采样器需要实现__iter__和__len__方法。

●__iter__方法返回一个迭代器,用于遍历数据集中的样本索引。
●__len__方法返回数据集中样本的数量。

以下是一个示例:

class MySampler(Sampler):def __init__(self, data_source):self.data_source = data_sourcedef __iter__(self):return iter(range(len(self.data_source)))def __len__(self):return len(self.data_source)

在这个示例中,MySampler类继承自torch.utils.data.sampler.Sampler类,实现了__iter__和__len__方法。

六、自定义Transform类

除了使用torchvision.transforms中提供的变换,我们还可以使用transforms模块中的Compose类来自定义自己的变换。Compose类将多个变换组合在一起,并按照顺序应用它们。

以下是一个示例:

class MyTransform(object):def __call__(self, x):x = self.crop(x)x = self.to_tensor(x)return xdef crop(self, x):# 这里实现裁剪变换# .......return xdef to_tensor(self, x):# 这里实现张量化变换# .......return xmy_transform = transforms.Compose([MyTransform()
])# 创建数据集和数据加载器
my_dataset    = MyDataset(data_path)
my_dataloader = DataLoader(my_dataset, batch_size=32, shuffle=True, num_workers=4)# 遍历数据集
for batch in my_dataloader:# 在这里处理数据批次pass

在这个示例中,MyTransform类实现了一个自定义的变换,它将裁剪和张量化两个变换组合在一起。transforms.Compose将这个自定义变换组合成一个变换序列,并在数据集中的每个样本上应用这个序列。

相关文章:

第N2周:NLP中的数据集构建

对于初学者,NLP中最烦人的问题之一就数据集的构建问题,处理不好就会引起shape问题(各种由于shape错乱导致的问题)。这里给出一个模版,大家可根据这个模版来构建。 torch.utils.data是PyTorch中用于数据加载和预处理的…...

AI助力浮雕创作!万物皆可浮雕?Stable Diffusion AI绘画【浮雕艺术】之文生浮雕!

前言 对于浮雕艺术,其实并不了解。但有幸能和“细辛”前辈结识,对浮雕有了简单的了解,浮雕图案的传统方式是先由画师画出图,然后由雕刻师雕刻。画师画图归为浮雕的设计阶段,画师会绘制出浮雕的设计图,‌这为…...

你觉得大模型时代该出现什么?

大模型的概念都火了两年了,之前各种媒体吹嘘大模型的出现是类似“蒸汽机时代”、“iPhone时刻”等等。那为什么我们期待的结果都没出现呢?咱们先一起回顾下历史。 1、蒸汽机时代 1.1、蒸汽机历史 许多人都在讨论大模型时代好像只是概念在火&#xff0…...

JS【详解】事件委托

事件委托的简介 事件委托(Event Delegation)是 JS 处理事件的一种技术:不直接在目标元素上设置事件监听器,而是在其父元素或祖先元素上设置监听器,然后利用事件冒泡机制来捕获和处理事件。 事件委托的好处 减少内存占用…...

谈对象系列:C++类和对象

文章目录 一、类的定义1.1类定义的格式类的两种定义方法结构体: 1.2访问限定符1.3类域 二、实例化2.1变量的声明和定义2.2类的大小计算空类的大小(面试): 三、this指针小考题 一、类的定义 1.1类定义的格式 使用class关键字&…...

设计模式20-备忘录模式

设计模式20-备忘录 动机定义与结构定义结构 C代码推导优缺点应用场景总结备忘录模式和序列化备忘录模式1. **动机**2. **实现方式**3. **应用场景**4. **优点**5. **缺点** 序列化1. **动机**2. **实现方式**3. **应用场景**4. **优点**5. **缺点** 对比总结 动机 在软件构建过…...

绘制echarts-liquidfill水球图

文章目录 一、效果图二、步骤1.安装插件2.引入2.主要代码2.素材图片 总结 一、效果图 二、步骤 1.安装插件 npm install echarts npm install echarts-liquidfillecharts5的版本与echarts-liquidfill3兼容,echarts4的版本与echarts-liquidfill2兼容,安装的时候需要…...

应急响应:D盾的简单使用.

什么是应急响应. 一个组织为了 应对 各种网络安全 意外事件 的发生 所做的准备 以及在 事件发生后 所采取的措施 。说白了就是别人攻击你了,你怎么把这个攻击还原,看看别人是怎么攻击的,然后你如何去处理,这就是应急响应。 D盾功…...

c语言第14天笔记

通过指针引用数组 数组元素的指针 数组指针:数组中的第一个元素的地址,也就是数组的首地址。 指针数组:用来存放数组元素地址的数组,称之为指针数组。 注意:虽然我们定义了一个指针变量接收了数组地址,但…...

服装行业QMS中的来料检验:常见问题解析与解决策略

在服装行业的来料检验过程中,常会遇到一系列问题,这些问题可能影响到原材料的质量,进而影响最终产品的品质。以下将详细介绍来料检验的常见问题及相应的解决方法: 一、常见问题 外观瑕疵 问题描述:原材料表面存在污渍…...

健身动作AI识别,仰卧起坐计数(含UI界面)

用Python和Mediapipe打造,让你的运动效果一目了然! 【技术揭秘】 利用Mediapipe的人体姿态估计,实时捕捉关键点,精确识别动作。 每一帧的关键点坐标和角度都被详细记录,为动作分析提供数据支持。 支持自定义动作训练&a…...

GitHub开源金融系统:Actual

Actual:电子金融,本地优先,自由开源- 精选真开源,释放新价值。 概览 Actual的创新之处在于其对个人财务管理的全面考虑,它不仅仅是一个简单的记账工具,而是一个综合性的理财解决方案。它的本地优先设计意味…...

【学习笔记】Day 7

一、进度概述 1、DL-FWI基础入门培训笔记 2、inversionnet_train 试运行——未成功 二、详情 1、InversionNet: 深度学习实现的反演 InversionNet构建了一个具有编码器-解码器结构的卷积神经网络,以模拟地震数据与地下速度结构的对应关系。 (一…...

网络中特殊的 IP 地址

特殊网络 IP 127.0.0.1 127.0.0.1 是本机回送地址,发送到 127.0.0.1 的数据或者从 127.0.0.1 返回的数据只会在本机进行传输, 而不进行外部网络传输。 主要有以下两个作用: 测试本机网络 当我们可以 ping 通 127.0.0.1 的时候, 则说明本机的网卡以及 tc…...

ASP 表单处理入门指南

ASP 表单处理入门指南 简介 ASP(Active Server Pages)是一种由微软开发的服务器端脚本环境,用于动态生成交互性网页。它允许开发者结合HTML、VBScript或JScript脚本语言来创建和运行动态网页或Web应用程序。本文将重点介绍如何使用ASP来处理表单数据,包括表单的创建、数据…...

极米RS10Plus性价比高吗?7款4-6K价位投影仪测评哪款最好

通常家庭想买个投影仪都会选择4-6K这个价位段的投影仪,3K以下的投影配置太低,6K以上的价格略高,4-6K价位段的中高端投影仪正好满足大部分家庭的使用需求。正好极米投影在8月份上新了一款Plus版本的长焦投影:极米RS10Plus&#xff…...

RocketMQ怎么对文件进行读写的?

RocketMQ 对文件的读写主要依赖于其底层的存储机制,核心组件是 CommitLog 和 ConsumeQueue,并且通过 MappedFile 类来进行高效的文件操作。以下是 RocketMQ 文件读写的详细介绍: 1. CommitLog CommitLog 是 RocketMQ 的核心存储文件&#x…...

智慧宠物护理:智能听诊器引领健康监测新潮流

在宠物健康科技的浪潮中,智能听诊器的诞生标志着宠物健康管理迈向了智能化的新纪元。广州坎普利智能信息科技有限公司的创新产品,正为宠物主人和他们的毛茸茸伙伴带来前所未有的关怀体验。 创新特点 这款智能听诊器,以其前沿科技和人性化设…...

SRE工程师第2天:我只要截图功能 而不是打开微信

大家好,我是watchpoints 别想太多,只管去提问,所有问题,都会有答案 watchpoints是我github用户名 , 也是我的wechat 用户名,如果我有讲不明白 欢迎提问 什么是SRE(Site Reliability Engineer) 和…...

【RunnerGo】离线安装成功版本

目录 一、下载 二、解压安装包 三、修改安装配置 3.1 编辑修改安装参数(我没有改,默认安装即可) 3.2 安装目录结构说明 四、执行安装 五、检查服务并使用 六、访问 前言:最近在调研一个新工具,发现RunnerGo&…...

【网络】每天掌握一个Linux命令 - iftop

在Linux系统中,iftop是网络管理的得力助手,能实时监控网络流量、连接情况等,帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...

Spark 之 入门讲解详细版(1)

1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处&…...

【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器

一.自适应梯度算法Adagrad概述 Adagrad(Adaptive Gradient Algorithm)是一种自适应学习率的优化算法,由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率,适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...

8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂

蛋白质结合剂(如抗体、抑制肽)在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上,高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术,但这类方法普遍面临资源消耗巨大、研发周期冗长…...

Java如何权衡是使用无序的数组还是有序的数组

在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...

华为云Flexus+DeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建

华为云FlexusDeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建 前言 如今大模型其性能出色,华为云 ModelArts Studio_MaaS大模型即服务平台华为云内置了大模型,能助力我们轻松驾驭 DeepSeek-V3/R1,本文中将分享如何…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...

LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》

这段 Python 代码是一个完整的 知识库数据库操作模块,用于对本地知识库系统中的知识库进行增删改查(CRUD)操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 📘 一、整体功能概述 该模块…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配

目录 一、C 内存的基本概念​ 1.1 内存的物理与逻辑结构​ 1.2 C 程序的内存区域划分​ 二、栈内存分配​ 2.1 栈内存的特点​ 2.2 栈内存分配示例​ 三、堆内存分配​ 3.1 new和delete操作符​ 4.2 内存泄漏与悬空指针问题​ 4.3 new和delete的重载​ 四、智能指针…...