当前位置: 首页 > news >正文

torch 的数据加载 Datasets DataLoaders

点赞收藏关注!
如需要转载,请注明出处!

torch的模型加载有两种方式:
Datasets & DataLoaders

torch本身可以提供两数据加载函数:
torch.utils.data.DataLoader()和torch.utils.data.Dataset()

其中torch.utils.data 是PyTorch提供的一个模块,用于处理和加载数据。该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。

加载函数后可以实现数据集代码与模型训练代码分离,以获得更好的可读性和模块化
Dataset定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。制作自己的数据集必须要实现三个函数:

  • init()函数在实例化Dataset对象时运行一次
  • len()函数返回数据集中样本的数量
  • getitem()函数的作用是:从给定索引 index,从数据集中加载并返回一个样本并将其转换为张量。
import torch
from torch.utils.data import Datasetclass CreateDataset(Dataset):def __init__(self, data):self.data = datadef __getitem__(self, index):# 根据索引获取样本return self.data[index]def __len__(self):# 返回数据集大小return len(self.data)# 创建数据集对象
data = [[255,255,255],[255,245,235],[225,226,227]]
dataset = CreateDataset(data)# 根据索引获取样本
sample = dataset[1]
print(sample)
# [255,245,235]

数据处理模块其他的功能:

  • TensorDataset: 继承自 Dataset 类,用于将张量数据打包成数据集。它接受多个张量作为输入,并按照第一个输入张量的大小来确定数据集的大小。对 tensor 进行打包,就好像 python 中的 zip 功能。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。
from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoadera = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99], [11, 22, 33], [44, 55, 66], [77, 88, 99]])
b = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
train_ids = TensorDataset(a, b)for x_train, y_label in train_ids:print(x_train, y_label)##############################################################################################
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
#tensor([11, 22, 33]) tensor(0)
#tensor([44, 55, 66]) tensor(1)
#tensor([77, 88, 99]) tensor(2)
  • DataLoader: 数据加载器类,用于批量加载数据集。它接受一个数据集对象作为输入,并提供多种数据加载和预处理的功能,如设置批量大小、多线程数据加载和数据打乱等。DataLoader中最重要的参数就是dataset,它决定了要装载的数据集。

  • Subset: 数据集的子集类,用于从数据集中选择指定的样本。定义了一个子集的索引列表indices,它可以根据需要进行调整。然后,我们使用Subset类创建了一个名为subset的子集对象,它接受两个参数:原始数据集dataset和子集的索引列表indices。

indices = [0, 2, 4]  # 子集的索引列表
subset = Subset(dataset, indices)
  • random_split: 将一个数据集随机划分为多个子集,可以指定划分的比例或指定每个子集的大小。
import torch
import torchvision
# from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
# 准备数据集
from torch import nn
from torch.utils.data import DataLoader# 定义训练的设备
device = torch.device("cuda")
#读取数据
data_transform = transforms.Compose([transforms.Resize(size=(224,224)),transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5, 0.5, 0.5])
])
full_dataset = ImageFolder(r'D:\PythonSpace\data\trainTest',transform = data_transform)
# length 数据集总长度
full_data_size = len(full_dataset)
print("总数据集的长度为:{}".format(full_data_size))
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size#在这里
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
#在这里train_data_size = len(train_dataset)
test_data_size = len(test_dataset)
# 如果train_data_size=10, 训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))
>>>
总数据集的长度为:100
训练数据集的长度为:80
测试数据集的长度为:20
  • ConcatDataset: 将多个数据集连接在一起形成一个更大的数据集。
#链接两个数据集
dataset = torch.utils.data.ConcatDataset([celeba_dataset, digiface_dataset]) 
#导入数据集
loader = torch.utils.data.DataLoader( dataset=dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=True, num_workers=cfg.n_workers)
  • get_worker_info: 获取当前数据加载器所在的进程信息。torch.utils.data.get_worker_info() 在worker进程中返回各种有用的信息(包括worker id、dataset replica、initial seed等),在main进程中返回None。用户可以在数据集代码和/或 worker_init_fn 中使用此函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。分片数据集特别有用。

如有帮助,点赞收藏关注!

相关文章:

torch 的数据加载 Datasets DataLoaders

点赞收藏关注! 如需要转载,请注明出处! torch的模型加载有两种方式: Datasets & DataLoaders torch本身可以提供两数据加载函数: torch.utils.data.DataLoader()和torch.utils.data.Datase…...

【Promise】某个异步方法执行结束后 在执行下面方法

使用Promise ,当 layer.msg(查询成功) 这个方法执行结束后 ,下面代码才会执行 let thas this async function showMessage() {await new Promise(resolve > layer.msg(查询成功, resolve));// 这里的代码将在 layer.msg 执行结束后执行thas.isGuaran…...

任意文件下载漏洞(CVE-2021-44983)

简介 CVE-2021-44983是Taocms内容管理系统中的一个安全漏洞,可以追溯到版本3.0.1。该漏洞主要源于在登录后台后,文件管理栏存在任意文件下载漏洞。简言之,这个漏洞可能让攻击者通过特定的请求下载系统中的任意文件,包括但不限于敏…...

C++(20):通过source_location实现日志函数

C++20中引入了std::source_location,用来描述函数调用的上下文信息。 其主要的成员函数如下: line():获取行号。column():获取列号。file_name():获取文件名。function_name():获取函数域名。#include <iostream> #include <string_view> #include <sour…...

【数据结构】树与二叉树(廿二):树和森林的遍历——后根遍历(递归算法PostOrder、非递归算法NPO)

文章目录 5.1 树的基本概念5.1.1 树的定义5.1.2 森林的定义5.1.3 树的术语 5.2 二叉树5.3 树5.3.1 树的存储结构1. 理论基础2. 典型实例3. Father链接结构4. 儿子链表链接结构5. 左儿子右兄弟链接结构 5.3.2 获取结点的算法5.3.3 树和森林的遍历1. 先根遍历&#xff08;递归、非…...

精通Nginx(17)-安全管控之防暴露、限制访问、防DDos攻击、防爬虫、防非法引用

安全是每个系统都需要考虑的关键因素,Nginx在这方面提供了丰富的功能,使我们可以就实际情形做很精细调整。这些功能包括防信息暴露、客户端访问限制、通讯加密、防DDos攻击、防爬虫、防非法引用及防非法域名请求等。 目录 防信息暴露 关闭版本号 关闭目录列表 客户端访问…...

STM32 Flash

FLASH简介 Flash是常用的用于存储数据的半导体器件&#xff0c;它具有容量大&#xff0c;可重复擦写&#xff0c;按“扇区/块”擦除、掉电后数据可继续保存的特性。 常见的FLASH主要有NOR FLASH和NAND FLASH两种类型。NOR和NAND是两种数字门电路&#xff0c;可以简单地认为FL…...

文件批量重命名技巧:图片文件名太长怎么办?告别手动改名方法

在日常生活中&#xff0c;常常会遇到文件名过长导致的问题。尤其是在处理大量图片文件时&#xff0c;过长的文件名可能会使得文件管理变得混乱不堪。现在来看下云炫文件管理器如何批量重命名&#xff0c;让图片文件名变得更简洁&#xff0c;提高工作效率。 操作1、在云炫文件…...

微信小程序手写滑动tab

微信小程序手写滑动tab index.wxml <view class"tab-bar"> <scroll-view scroll-x class"tab-scroll"> <block wx:for"{{tabs}}" wx:key"index"> <view class"tab-item {{currentIndex index ? acti…...

一文读懂如何安全地存储密码

目录 引言 明文存储 基本哈希存储 加盐哈希存储 适应性哈希算法 密码加密存储 小结 引言 密码是最常用的身份验证手段&#xff0c;既简单又高效。密码安全是网络安全的基石&#xff0c;对保护个人和组织信息的安全具有根本性的作用。然而有关密码泄漏的安全问题一再发生…...

【运维面试100问】(六)buffer和cache的区别

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…...

创建域名邮箱邮件地址的方法与步骤

如何创建域名邮箱邮件地址?使用Zoho Mail创建域名邮箱邮件地址的步骤简单易懂&#xff0c;操作便捷。从其他邮箱迁移到Zoho Mail的过程也相当顺畅&#xff0c;您可以轻松为所有员工创建具有企业邮箱域名的电子邮件地址。 步骤1&#xff1a;添加并验证您的域名 首先&#xff0c…...

Qt框架学习(1)

1.安装Qt官网 安装需注意的是&#xff0c;要安装开源版(有钱当我没说),而安装包都是一样的&#xff0c;主要是在注册账户时选择个人开发&#xff0c;而不要选公司&#xff0c;否则在安装时登录账号后会安装商业版Qt. 2.Qt中的快捷键 快捷键解释F4头文件和实现文件切换ShiftF…...

3D电路板在线渲染案例

从概念上讲,这是有道理的,因为PCB印制电路板上的走线从一个连接到下一个连接的路线基本上是平面的。 然而,我们生活在一个 3 维世界中,能够以这种方式可视化电路以及相应的组件,对于设计过程很有帮助。本文将介绍KiCad中基本的3D查看功能,以及如何使用NSDT 3DConvert在线…...

ResizeObserver loop limit exceeded报错解决方案

前言&#xff1a; 控制台没有报错&#xff0c;但是开发Vue项目过程中一直报ResizeObserver loop limit exceeded 错&#xff0c;找到以下解决方式。在main.js文件中重写 ResizeObserver 方法。 main.js文件 &#xff08;完整版&#xff09; import { createApp } from "v…...

【OpenCV实现图像:使用OpenCV进行图像处理之透视变换】

文章目录 概要计算公式举个栗子实际应用小结 概要 透视变换&#xff08;Perspective Transformation&#xff09;是一种图像处理中常用的变换手段&#xff0c;它用于将图像从一个视角映射到另一个视角&#xff0c;常被称为投影映射。透视变换可以用于矫正图像中的透视畸变&…...

Vue中学习笔记-数据代理

文章目录 前文提要数据代理的概念MVVM模型和Vue中的数据代理M&#xff0c;模型V&#xff0c;视图VM&#xff0c;视图模型 前文提要 本人仅做个人学习记录&#xff0c;如有错误&#xff0c;请多包涵 数据代理的概念 使用一个对象代理对另一个对象中属性的操作。 MVVM模型和Vu…...

IDEA 配置maven结合案例使用篇

1. 项目需求和结构分析 需求案例&#xff1a;搭建一个电商平台项目&#xff0c;该平台包括用户服务、订单服务、通用工具模块等。 项目架构&#xff1a; 用户服务&#xff1a;负责处理用户相关的逻辑&#xff0c;例如用户信息的管理、用户注册、登录等。 spring-context 6.0.…...

基于白鲸算法优化概率神经网络PNN的分类预测 - 附代码

基于白鲸算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于白鲸算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于白鲸优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;针对PNN神经网络的光滑…...

Android使用Kotlin利用Gson解析多层嵌套Json数据

文章目录 1、依赖2、解析 1、依赖 build.gradle(app)中加入 dependencies { implementation com.google.code.gson:gson:2.8.9 }2、解析 假设这是要解析Json数据 var responseStr "{"code": 200,"message": "操作成功","data&quo…...

谷歌浏览器插件

项目中有时候会用到插件 sync-cookie-extension1.0.0&#xff1a;开发环境同步测试 cookie 至 localhost&#xff0c;便于本地请求服务携带 cookie 参考地址&#xff1a;https://juejin.cn/post/7139354571712757767 里面有源码下载下来&#xff0c;加在到扩展即可使用FeHelp…...

VB.net复制Ntag213卡写入UID

本示例使用的发卡器&#xff1a;https://item.taobao.com/item.htm?ftt&id615391857885 一、读取旧Ntag卡的UID和数据 Private Sub Button15_Click(sender As Object, e As EventArgs) Handles Button15.Click轻松读卡技术支持:网站:Dim i, j As IntegerDim cardidhex, …...

Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)

参考官方文档&#xff1a;https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java&#xff08;供 Kotlin 使用&#xff09; 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...

rnn判断string中第一次出现a的下标

# coding:utf8 import torch import torch.nn as nn import numpy as np import random import json""" 基于pytorch的网络编写 实现一个RNN网络完成多分类任务 判断字符 a 第一次出现在字符串中的位置 """class TorchModel(nn.Module):def __in…...

Golang——9、反射和文件操作

反射和文件操作 1、反射1.1、reflect.TypeOf()获取任意值的类型对象1.2、reflect.ValueOf()1.3、结构体反射 2、文件操作2.1、os.Open()打开文件2.2、方式一&#xff1a;使用Read()读取文件2.3、方式二&#xff1a;bufio读取文件2.4、方式三&#xff1a;os.ReadFile读取2.5、写…...

C语言中提供的第三方库之哈希表实现

一. 简介 前面一篇文章简单学习了C语言中第三方库&#xff08;uthash库&#xff09;提供对哈希表的操作&#xff0c;文章如下&#xff1a; C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...

根目录0xa0属性对应的Ntfs!_SCB中的FileObject是什么时候被建立的----NTFS源代码分析--重要

根目录0xa0属性对应的Ntfs!_SCB中的FileObject是什么时候被建立的 第一部分&#xff1a; 0: kd> g Breakpoint 9 hit Ntfs!ReadIndexBuffer: f7173886 55 push ebp 0: kd> kc # 00 Ntfs!ReadIndexBuffer 01 Ntfs!FindFirstIndexEntry 02 Ntfs!NtfsUpda…...

CVPR2025重磅突破:AnomalyAny框架实现单样本生成逼真异常数据,破解视觉检测瓶颈!

本文介绍了一种名为AnomalyAny的创新框架&#xff0c;该方法利用Stable Diffusion的强大生成能力&#xff0c;仅需单个正常样本和文本描述&#xff0c;即可生成逼真且多样化的异常样本&#xff0c;有效解决了视觉异常检测中异常样本稀缺的难题&#xff0c;为工业质检、医疗影像…...

stm32wle5 lpuart DMA数据不接收

配置波特率9600时&#xff0c;需要使用外部低速晶振...

从物理机到云原生:全面解析计算虚拟化技术的演进与应用

前言&#xff1a;我的虚拟化技术探索之旅 我最早接触"虚拟机"的概念是从Java开始的——JVM&#xff08;Java Virtual Machine&#xff09;让"一次编写&#xff0c;到处运行"成为可能。这个软件层面的虚拟化让我着迷&#xff0c;但直到后来接触VMware和Doc…...