当前位置: 首页 > news >正文

Pytorch 自学笔记(三):利用自定义文本数据集构建Dataset和DataLoader

Pytorch 自学笔记(三)

  • 1. Dataset与DataLoader
    • 1.1 torch.utils.data.Dataset
    • 1.2 torch.utils.data.DataLoader

Pytorch 自学笔记系列的第三篇。针对Pytorch的Dataset和DataLoader进行简单的介绍,同时,介绍如何使用自定义文本数据集构建Dataset和DataLoader,以实现数据集的随机采样与batch加载。(注:文中代码使用python3.7和pytorch1.7.1编写)

1. Dataset与DataLoader

1.1 torch.utils.data.Dataset

torch.utils.data.Dataset是pytorch中定义的数据集抽象类,pytorch中任何的数据集类都必须继承并重写这个类,其源码如下:

class Dataset(Generic[T_co]):r"""An abstract class representing a :class:`Dataset`.All datasets that represent a map from keys to data samples should subclassit. All subclasses should overwrite :meth:`__getitem__`, supporting fetching adata sample for a given key. Subclasses could also optionally overwrite:meth:`__len__`, which is expected to return the size of the dataset by many:class:`~torch.utils.data.Sampler` implementations and the default optionsof :class:`~torch.utils.data.DataLoader`... note:::class:`~torch.utils.data.DataLoader` by default constructs a indexsampler that yields integral indices.  To make it work with a map-styledataset with non-integral indices/keys, a custom sampler must be provided."""def __getitem__(self, index) -> T_co:raise NotImplementedErrordef __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':return ConcatDataset([self, other])# No `def __len__(self)` default?# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]# in pytorch/torch/utils/data/sampler.py

而任何继承torch.utils.data.Dataset的数据集类,必须重写__getitem__方法,可以选择性重写__len__方法(若要以该数据集类构建torch.utils.data.Sampler或者torch.utils.data.DataLoader,则必须重写__len__方法)。__getitem__方法的作用为,利用index获得数据集中该index对应的样本(这就要求该数据类中必须维持一个可以;),而__len__方法的作用为返回数据集的样本数量。一个torch.utils.data.Dataset子类的样例如下:

from torch.utils.data import Dataset
import pandas as pdclass MyDataset(Dataset):def __init__(self, csv_file, txt_file, root_dir, other_file):self.csv_data = pd.read_csv(csv_file)with open(txt_file, 'r') as f:data_list = f.readlines()# 可利用索引下标进行取值的成员变量,list类型self.txt_data = data_listself.root_dir = root_dir# 返回数据集的样本数量def __len__(self):return len(self.csv_data)# 返回数据集中索引为idx的样本def __getitem__(self, index):data = (self.csv_data[index], self.txt_data[index])return data

利用自定义的Dataset子类,可以将我们的数据集定义我们需要的数据类,然后通过迭代的方式利用index下标索引来获取数据集中的每一条样本数据。而数据集的batch取样和取样时的shuffle,则需要利用torch.utils.data.DataLoader来实现。

1.2 torch.utils.data.DataLoader

首先需要明确一点,Dataset和DataLoader本质上都是iterable(可迭代对象),都可以实现数据集的迭代访问。而 torch.utils.data.DataLoader相当于是Dataset(数据集)和Sampler(采样器)的组合,即可以在Dataset上进行迭代的自定义采样。同时,DataLoader还支持单进程或多进程加载,自定义加载顺序以及可选的自动批处理(整理)和memory pinning,它还支持 map风格的数据集对象,其参数具体解释如下(参数说明参考了这篇文章,并按照pytorch1.7.1的文档进行了修改):

  1. dataset(Dataset): 传入的数据集类
  2. batch_size(int, optional): 每个batch有多少个样本
  3. shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序(即随机采样)
  4. sampler(Sampler or Iterable, optional): 自定义从数据集中取样本的策略;如果指定这个参数,那么shuffle必须为False;该值可以为任何实现了__len__函数的Iterable对象
  5. batch_sampler(Sampler or Iterable, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再指定了(互斥——Mutually exclusive)
  6. num_workers (int, optional):这个参数决定了有几个进程来处理data loading;0意味着所有的数据都会被load进主进程(默认为0)
  7. collate_fn (callable, optional): 一个函数,该函数的作用是将一个由样本构成的batch_size大小的list转换成mini-batch,该函数的输出即为迭代时获得的batch
  8. pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中
  9. drop_last (bool, optional):如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了;如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点
  10. timeout(numeric, optional):如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了;这个numeric应总是大于等于0;默认为0
  11. worker_init_fn (callable, optional): 每个进程的初始化函数 If not None, this will be called on eachworker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
  12. prefetch_factor (int, optional, keyword-only arg):每个进程预先加载的样本数量。该值2意味着所有的进程预先加载了2 * num_workers个样本(默认为2)
  13. persistent_workers (bool, optional) :如果为True,则迭代完一次数据集后,DataLoader将不会关闭工作进程;这样可以使Worker Dataset实例保持活动状态(默认为False)

利用上一节定义的MyDataset数据集类可以构建一个DataLoader对象:

from torch.utils.data import DataLoadermy_data_loader = DataLoader(myDataset, batch_size=32, shuffle=True)

相关文章:

Pytorch 自学笔记(三):利用自定义文本数据集构建Dataset和DataLoader

Pytorch 自学笔记(三) 1. Dataset与DataLoader1.1 torch.utils.data.Dataset1.2 torch.utils.data.DataLoader Pytorch 自学笔记系列的第三篇。针对Pytorch的Dataset和DataLoader进行简单的介绍,同时,介绍如何使用自定义文本数据集…...

QT 使用QSqlTableModel对数据库进行创建,插入,显示

文章目录 效果图概述功能点代码分析初始数据插入数据数据显示 总结 效果图 概述 本案例用于对数据库中的数据进行显示等其他操作,其他表格筛选,过滤等功能可看此博客 框架:数据模型使用QSqlTableModel,视图使用QTableView&#x…...

如何学习Transformer架构

Transformer架构自提出以来,在自然语言处理领域引发了革命性的变化。作为一种基于注意力机制的模型,Transformer解决了传统序列模型在并行化和长距离依赖方面的局限性。本文将探讨Transformer论文《Attention is All You Need》与Hugging Face Transform…...

浅谈云计算22 | Kubernetes容器编排引擎

Kubernetes容器编排引擎 一、Kubernetes管理对象1.1 Kubernetes组件和架构1.2 主要管理对象类型 二、Kubernetes 服务2.1 服务的作用与原理2.2 服务类型 三、Kubernetes网络管理3.1 网络模型与目标3.2 网络组件3.2.1 kube-proxy3.2.2 网络插件 3.3 网络通信流程 四、Kubernetes…...

计算 SAMOut V3 在将词汇表从1万 增加到6千万的情况下能够减少多少参数

当我们将词汇表从 60,000,000(六千万)减少到 10,000 时,实际上是在缩小模型的词嵌入层及其共享的语言模型头(LM Head)的规模。这将导致参数量显著减少。我们可以通过以下步骤来计算具体的参数减少量。 参数量减少计算…...

03.选择排序

一、题目思路 选择排序是一种简单直观的排序算法。它的工作原理是:首先在未排序序列中找到最小(或最大)元素,存放到排序序列的起始位置,然后,再从剩余未排序元素中继续寻找最小(或最大&#xff…...

02_登录窗口

新建场景 重命名为GameRoot 双击GameRoot进入新场景 同样摄像机清除格式 删除平行光并关闭渲染灯光的天空盒 新建空节点重命名为GameRoot GameRoot为游戏的根节点 在整个游戏中都不会被删除 在游戏的根节点下创建UI的根节点Canvas 创建一个空节点 作为UI根节点下的 登录场景UI…...

NodeJS | 搭建本地/公网服务器 live-server 的使用与安装

目录 介绍 安装 live-server 安装方法 安装后的验证 环境变量问题 Node.js 环境变量未配置正确 全局安装的 live-server 路径未添加到环境变量 运行测试 默认访问主界面 访问文件 报错信息与解决 问题一:未知命令 问题二:拒绝脚本 公网配置…...

SystemUI 实现音量条同步功能

需求:SystemUI 实现音量条同步功能 具体问题 以前在SystemUI 下拉框添加了音量条控制,目前发现在SystemUI下拉框显示状态的情况下, 按键或者底部虚拟导航点击音量加减时候,SystemUI音量条不更新。 如下图:两个Syste…...

嵌入式知识点总结 C/C++ 专题提升(一)-关键字

针对于嵌入式软件杂乱的知识点总结起来,提供给读者学习复习对下述内容的强化。 目录 1.C语言宏中"#“和"##"的用法 1.1.(#)字符串化操作符 1.2.(##)符号连接操作符 2.关键字volatile有什么含意?并举出三个不同的例子? 2.1.并行设备的硬件寄存…...

基础入门-传输加密数据格式编码算法密文存储代码混淆逆向保护安全影响

知识点: 1、传输格式&传输数据-类型&编码&算法 2、密码存储&代码混淆-不可逆&非对称性 一、演示案例-传输格式&传输数据-类型&编码&算法 传输格式 JSON XML WebSockets HTML 二进制 自定义 WebSockets:聊天交互较常…...

几个Linux系统安装体验(续): 统信桌面系统

本文介绍统信桌面系统(uos)的安装。 下载 下载地址: https://www.chinauos.com/resource/download-professional 下载文件:本文下载文件名称为uos-desktop-20-professional-1070-amd64.iso。 下载注意事项:可直接下…...

算法日记6.StarryCoding P52:我们都需要0(异或)

一、题目 二、题解: 1、对于这道题,题意为让我们寻找一个数x使得 b[i]a[i]^x, 并且b[1]^b[2]^b[3]^ b[4]^b[5]....0 2、我们把b[i]给拆开,可以得到 3、又因为^满足结合律,因此,可以把括号给拆开 4、接着…...

【网络协议】RFC3164-The BSD syslog Protocol

引言 Syslog常被称为系统日志或系统记录,是一种标准化的协议,用于网络设备、服务器和应用程序向中央Syslog服务器发送日志消息。互联网工程任务组(IETF)发布的RFC 3164,专门定义了BSD Syslog协议的规范和实现方式。通…...

SpringCloud -根据服务名获取服务运行实例并进行负载均衡

Nacos注册中心 每个服务启动之后都要向注册中心发送服务注册请求&#xff0c;注册中心可以和各个注册客户端自定义协议实现服务注册和发现。 pom.xml <dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-starter-alibaba-na…...

CentOS 安装Redis

1. 安装 Redis 安装 EPEL 仓库&#xff08;对于 CentOS/RHEL 系统&#xff09;&#xff1a; 首先安装 EPEL 仓库&#xff0c;因为 Redis 存在于 EPEL 仓库中&#xff1a; yum install epel-release安装 Redis 数据库&#xff1a; yum install redis2. 修改 Redis 配置文件 …...

Linux网络 TCP socket

TCP简介 TCP&#xff08;Transmission Control Protocol&#xff09;是一种面向连接的、可靠的、基于字节流的传输层通信协议。它位于OSI模型的第四层&#xff0c;主要为应用层提供数据传输服务。TCP通过三次握手建立连接&#xff0c;确保数据在发送和接收过程中的准确性和顺序…...

(一)相机标定——四大坐标系的介绍、对应转换、畸变原理以及OpenCV完整代码实战(C++版)

一、四大坐标系介绍 1&#xff0c;世界坐标系 从这个世界&#xff08;world&#xff09;的视角来看物体 世界坐标系是3D空间坐标&#xff0c;每个点的位置用 ( X w , Y w , Z w ) (X_w,Y_w,Z_w) (Xw​,Yw​,Zw​)表示 2&#xff0c;相机坐标系 相机本身具有一个坐标系&…...

【Linux网络编程】高效I/O--I/O的五种类型

目录 I/O的概念 网络通信的本质 I/O的本质 高效I/O 五种I/O模型 阻塞I/O 非阻塞I/O 信号驱动I/O 多路转接/多路复用I/O 异步I/O 非阻塞I/O的实现 I/O的概念 网络通信的本质 网络通信的本质其实就是I/O I&#xff1a;表示input(输入)O&#xff1a;表示ou…...

企业级NoSQL数据库Redis

1.浏览器缓存过期机制 1.1 最后修改时间 last-modified 浏览器缓存机制是优化网页加载速度和减少服务器负载的重要手段。以下是关于浏览器缓存过期机制、Last-Modified 和 ETag 的详细讲解&#xff1a; 一、Last-Modified 头部 定义&#xff1a;Last-Modified 表示服务器上资源…...

简易版抽奖活动的设计技术方案

1.前言 本技术方案旨在设计一套完整且可靠的抽奖活动逻辑,确保抽奖活动能够公平、公正、公开地进行,同时满足高并发访问、数据安全存储与高效处理等需求,为用户提供流畅的抽奖体验,助力业务顺利开展。本方案将涵盖抽奖活动的整体架构设计、核心流程逻辑、关键功能实现以及…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

在rocky linux 9.5上在线安装 docker

前面是指南&#xff0c;后面是日志 sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo sudo dnf install docker-ce docker-ce-cli containerd.io -y docker version sudo systemctl start docker sudo systemctl status docker …...

【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密

在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...

AtCoder 第409​场初级竞赛 A~E题解

A Conflict 【题目链接】 原题链接&#xff1a;A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串&#xff0c;只有在同时为 o 时输出 Yes 并结束程序&#xff0c;否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...

如何将联系人从 iPhone 转移到 Android

从 iPhone 换到 Android 手机时&#xff0c;你可能需要保留重要的数据&#xff0c;例如通讯录。好在&#xff0c;将通讯录从 iPhone 转移到 Android 手机非常简单&#xff0c;你可以从本文中学习 6 种可靠的方法&#xff0c;确保随时保持连接&#xff0c;不错过任何信息。 第 1…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾&#xff1a; 在上一篇中&#xff0c;我们成功地为应用集成了数据库&#xff0c;并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了&#xff01;但是&#xff0c;如果你仔细审视那些 API&#xff0c;会发现它们还很“粗糙”&#xff1a;有…...

Linux-07 ubuntu 的 chrome 启动不了

文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了&#xff0c;报错如下四、启动不了&#xff0c;解决如下 总结 问题原因 在应用中可以看到chrome&#xff0c;但是打不开(说明&#xff1a;原来的ubuntu系统出问题了&#xff0c;这个是备用的硬盘&a…...

C++.OpenGL (10/64)基础光照(Basic Lighting)

基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...

三体问题详解

从物理学角度&#xff0c;三体问题之所以不稳定&#xff0c;是因为三个天体在万有引力作用下相互作用&#xff0c;形成一个非线性耦合系统。我们可以从牛顿经典力学出发&#xff0c;列出具体的运动方程&#xff0c;并说明为何这个系统本质上是混沌的&#xff0c;无法得到一般解…...