PyTorch处理数据--Dataset和DataLoader
在 PyTorch 中,Dataset 和 DataLoader 是处理数据的核心工具。它们的作用是将数据高效地加载到模型中,支持批量处理、多线程加速和数据增强等功能。
一、Dataset:数据集的抽象
Dataset 是一个抽象类,用于表示数据集的接口。你需要继承 torch.utils.data.Dataset 并实现以下两个方法:
__len__(): 返回数据集的总样本数。__getitem__(idx): 根据索引idx返回一个样本(数据和标签)。
示例:自定义 Dataset
import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels, transform=None):self.data = dataself.labels = labelsself.transform = transform # 数据预处理/增强函数def __len__(self):return len(self.data)def __getitem__(self, idx):sample = {"data": self.data[idx], "label": self.labels[idx]}if self.transform:sample = self.transform(sample)return sample
使用场景
- 加载图像、文本、表格数据等。
- 支持数据预处理(如归一化、裁剪)和数据增强(如随机翻转)。
二、 DataLoader:高效加载数据
DataLoader 负责将 Dataset 包装成一个可迭代对象,支持批量加载、多线程加速和数据打乱。
基本用法
from torch.utils.data import DataLoader# 假设 dataset 是你的 CustomDataset 实例
data_loader = DataLoader(dataset,batch_size=32, # 批量大小shuffle=True, # 是否打乱数据(训练时建议开启)num_workers=4, # 多线程加载数据的进程数drop_last=False # 是否丢弃最后不足一个 batch 的数据
)
遍历 DataLoader
for batch in data_loader:data = batch["data"] # 形状:[batch_size, ...]labels = batch["label"] # 形状:[batch_size]# 将数据送入模型训练...
三、pytorch内置数据集
PyTorch 提供了一系列内置数据集,这些数据集可以直接用于训练模型。这些数据集涵盖了多种领域,如图像、文本、音频等。以下是一些常用的PyTorch内置数据集:
图像数据集
-
MNIST: 手写数字数据集,包含0到9的手写数字图片。
from torchvision import datasets mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform) -
CIFAR10/CIFAR100: 包含彩色图片的数据集,CIFAR10有60000张32x32的彩色图片,分为10个类别;CIFAR100类似但有100个类别。
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) -
ImageNet: 包含超过1400万张图片的非常庞大的数据集,常用于图像识别和分类任务。
import torchvision.datasets as datasets imagenet_train = datasets.ImageNet(root='./data', split='train', download=True) -
STL10: 一个用于计算机视觉研究的小型图像数据集,包含96x96的彩色图片。
stl10_train = datasets.STL10(root='./data', split='train', download=True) -
SVHN: 包含数字图片的数据集,与MNIST类似但包含更多实际场景的图片。
svhn_train = datasets.SVHN(root='./data', split='train', download=True, transform=transform)
文本数据集
1.Text8: 一个用于自然语言处理的小型文本数据集。
from torchtext.datasets import Text8
text8_train = Text8(split=('train',))
2. AG_NEWS: 包含新闻文章的文本数据集,分为4个类别。
from torchtext.datasets import AG_NEWS
ag_news_train = AG_NEWS(split=('train',))
音频数据集
1. Speech Commands: 一个用于语音识别的数据集,包含约65,000个单词发音的音频文件。
from torchaudio.datasets import SPEECHCOMMANDS
speech_commands = SPEECHCOMMANDS(root="./data", download=True)
使用方法
要使用这些数据集,首先需要导入torchvision(对于图像数据集)、torchtext(对于文本数据集)或torchaudio(对于音频数据集),然后使用其提供的类来加载数据。通常还包括一些数据预处理步骤,例如转换(transforms)。
import torchvision.transforms as transforms
from torchvision import datasetstransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
四、完整代码示例
步骤 1:创建数据集
import numpy as np
from torch.utils.data import Dataset, DataLoader# 生成示例数据(假设是 10 个样本,每个样本是长度为 5 的向量)
data = np.random.randn(10, 5)
labels = np.random.randint(0, 2, size=(10,)) # 二分类标签class MyDataset(Dataset):def __init__(self, data, labels):self.data = torch.tensor(data, dtype=torch.float32)self.labels = torch.tensor(labels, dtype=torch.long)def __len__(self):return len(self.data)def __getitem__(self, idx):return {"data": self.data[idx],"label": self.labels[idx]}dataset = MyDataset(data, labels)
步骤 2:创建 DataLoader
data_loader = DataLoader(dataset,batch_size=2,shuffle=True,num_workers=2
)
步骤 3:使用 DataLoader 训练模型
model = ... # 你的模型
optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch in data_loader:x = batch["data"]y = batch["label"]# 前向传播outputs = model(x)loss = loss_fn(outputs, y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()
五、常见问题解决
(1)数据格式不匹配
- 问题:
DataLoader返回的数据形状与模型输入不匹配。 - 解决:检查
Dataset的__getitem__返回的数据类型和形状,确保与模型输入一致。
(2)多线程加载卡顿
- 问题:设置
num_workers>0时程序卡死或报错。 - 解决:在 Windows 系统中,多线程可能需要将代码放在
if __name__ == "__main__":块中运行。
(3)数据增强
- 使用
torchvision.transforms中的工具(如RandomCrop、RandomHorizontalFlip)对图像数据进行增强:from torchvision import transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5]), ])
(4)内存不足
- 对于大型数据集,使用
torch.utils.data.DataLoader的persistent_workers=True(PyTorch 1.7+)或优化数据加载逻辑。
六、高级功能
- 分布式训练:使用
torch.utils.data.distributed.DistributedSampler配合多 GPU。 - 预加载数据:使用
torch.utils.data.TensorDataset直接加载 Tensor 数据。 - 自定义采样器:通过
sampler参数控制数据采样顺序(如平衡类别采样)。
相关文章:
PyTorch处理数据--Dataset和DataLoader
在 PyTorch 中,Dataset 和 DataLoader 是处理数据的核心工具。它们的作用是将数据高效地加载到模型中,支持批量处理、多线程加速和数据增强等功能。 一、Dataset:数据集的抽象 Dataset 是一个抽象类,用于表示数据集的接口。你…...
【Linux】POSIX信号量与基于环形队列的生产消费者模型
目录 一、POSIX信号量: 接口: 二、基于环形队列的生产消费者模型 环形队列: 单生产单消费实现代码: RingQueue.hpp: main.cc: 多生产多消费实现代码: RingQueue.hpp: main.…...
Spring Boot 连接 MySQL 配置参数详解
Spring Boot 连接 MySQL 配置参数详解 前言参数及含义常用参数及讲解和示例useUnicode 参数说明: 完整配置示例注意事项 前言 在 Spring Boot 中使用 Druid 连接池配置 MySQL 数据库连接时,URL 中 ? 后面的参数用于指定连接的各种属性。以下是常见参数…...
[linux] linux基本指令 + shell + 文件权限
目录 1. Linux的认识 1.1. Linux的应用场景 1.2. Linux的版本问题 1.3. 操作系统的认识 1.4. 常用快捷键 2. 常用指令介绍 2.1. ADD 2.1.1. touch [file] 2.1.1.1. 文件的属性信息 2.1.2. mkdir [directory] 2.1.3. cp [file/directory] 2.1.4. echo [file] 2.1.4.…...
查看进程文件描述符的限制
查看进程文件描述符限制 rootgb:/home/gb/Monitor-Device-Mgr/Monitor-Device-Mgr/bin# ps -ef |grep Monitor-Device-Mgr root 3976 2380 59 11:10 pts/2 00:00:06 ./Monitor-Device-Mgr root 4010 2395 0 11:10 pts/3 00:00:00 grep --colorauto Monito…...
Python实现小红书app版爬虫
简介:由于数据需求的日益增大,小红书网页版已经不能满足我们日常工作的需求,为此,小编特地开发了小红书手机版算法,方便大家获取更多的数据,提升工作效率。 手机版接口主要包括:搜素࿰…...
【docker】docker-compose安装RabbitMQ
docker-compose安装RabbitMQ 1、配置docker-compose.yml文件(docker容器里面的目录请勿修改)2、启动mq3、访问mq4、查看服务器映射目录5、踩坑5.1、权限不足 1、配置docker-compose.yml文件(docker容器里面的目录请勿修改) versi…...
playwright-go实战:自动化登录测试
1.新建项目 打开Goland新建项目playwright-go-demo 项目初始化完成后打开终端输入命令: #安装项目依赖 go get -u github.com/playwright-community/playwright-go #安装浏览器 go run github.com/playwright-community/playwright-go/cmd/playwrightlatest insta…...
LeetCode hot 100 每日一题(13)——73. 矩阵置零
这是一道难度为中等的题目,让我们来看看题目描述: 给定一个 _m_ x _n_ 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 提示: m matrix.lengthn matrix[0].length1 < m, n …...
CEF 给交互函数, 添加控制台是否显示交互参数log开关
CEF 控制台添加一函数,枚举 注册的供前端使用的CPP交互函数有哪些 CEF 多进程模式时,注入函数,获得交互信息-CSDN博客 这两篇文章,介绍了注入函数,在控制台中显示 各自提供的交互函数信息。 有些场景下,我们还需要更详细的信息,比如想知道 彼此传递的参数, 如果每次调…...
云端存储新纪元:SAN架构驱动的智能网盘解决方案
一、企业存储的"不可能三角"破局 1.1 传统存储架构的困局 性能瓶颈:NAS架构在1000并发访问时延迟飙升300%容量限制:传统RAID扩容需停机维护,PB级存储扩展耗时超48小时成本矛盾:全闪存阵列每TB成本高达$3000࿰…...
PVE 安装黑苹果 MacOS
背景 我需要一台黑苹果,登录我不常用苹果账号。 方法 The Definitive Guide to Running MacOS in ProxmoxRunning a MacOS 15 Sequoia VM in ProxMox VE及视频 按照第二个的视频一步一步配置,第一个链接提供了不同版本OS...
Unity URP自定义Shader支持RenderLayer
前言: 当我们想用一个灯光只对特定的物体造成影响,而不对其余物体造成影响时,我们就需要设置相对应的LightLayer,但是这在URP12.0是存在的,在之后就不存在LightLayer这一功能,URP将其隐藏而改成了RenderLa…...
Axure项目实战:智慧城市APP(完整交互汇总版)
亲爱的小伙伴,在您浏览之前,烦请关注一下,在此深表感谢! 课程主题:智慧城市APP 主要内容:主功能(社保查询、医疗信息、公交查询等)、活动、消息、我的页面汇总 应用场景ÿ…...
LVS-DR模式配置脚本
LVS-DR模式配置脚本 实验环境,需要4台虚拟机 IP说明172.25.254.101客户端172.25.254.102负载均衡器DS172.25.254.103真实服务器RS172.25.254.104真实服务器RSVIP:172.25.254.255/32 系统必须有ipvsadm和ifconfig命令 dnf install ipvsadm dnf install n…...
树状数组 3 :区间修改,区间查询
【题目描述】 这是一道模板题。 给定数列 a[1],a[2],…,a[n],你需要依次进行q个操作,操作有两类: 1lrx:给定 l,r,x对于所有 i∈[l,r],将a[i]加上x(换言之,将 a[l],a[l1],…a[r] 分别加上 x&a…...
架构思维:预约抢茅子架构设计
文章目录 案例:预约抢茅子复杂度分析商品预约阶段等待抢购阶段商品抢购阶段订单支付阶段 技术方案商品预约阶段一、基于 Redis 单节点的分布式锁方案1. 核心流程2. 关键设计点 二、Redis 单节点方案的局限性1. 单点故障风险2. 主从切换问题 三、多节点 Redis 实现高…...
使用 gone.WrapFunctionProvider 快速接入第三方服务
项目地址:https://github.com/gone-io/gone 本文中源代码: esexamples/es 文章目录 1. gone.WrapFunctionProvider 简介2. 配置注入实现3. 实战示例:Elasticsearch 集成4. 使用方式5. 最佳实践6. 总结 在如何给Gone框架编写Goner组件…...
基于SpringBoot+Vue的在教务管理(课程管理)系统+LW示例
1.项目介绍 系统角色:管理员、学生、教师功能模块:管理员(学院管理、专业管理、班级管理、学生管理、教师管理、课程管理、选课修改)、教师(授课查询、教师课表、成绩录入)、学生(选修课程、学…...
gitee 常用指令
1.拉取代码 // http git clone http.........// https git clone https......... 2. 设置自己账户和密码 ----- 绑定git git config --global user.name "你的用户名"git config --global user.email "你的邮箱" 3. 上传本地代码至git git initgit r…...
etcd性能测试
etcd性能测试 本文参考官方文档完成etcd性能测试,提供etcd官方推荐的性能测试方案。 1. 理解性能:延迟与吞吐量 etcd 提供稳定、持续的高性能。有两个因素决定性能:延迟和吞吐量。延迟是完成一项操作所花费的时间。吞吐量是在某个时间段内…...
JIRA/Xray测试管理工具的最佳实践:从基础到高阶的全场景指南
引言:测试管理的数字化转型与工具价值 在数字化时代,软件质量已成为企业竞争力的核心指标。然而,传统的测试管理方式——如Excel记录用例、邮件沟通缺陷、手动执行回归测试——已无法满足快速迭代的敏捷开发需求。据统计,全球因测…...
ubuntu桌面图标异常——主目录下的所有文件(如文档、下载等)全部显示在桌面
ubuntu桌面图标异常 问题现象问题根源系统级解决方案方法一:全局修改(推荐多用户环境)方法二:单用户修改(推荐个人环境)操作验证与调试避坑指南扩展知识参考文档问题现象 主目录文件异常显示 用户主目录(如/home/user/)下的所有文件(如文档、下载等)全部显示在桌面,…...
AIP-191 文件和目录结构
编号191原文链接https://google.aip.dev/191状态批准创建日期2019-07-25更新日期2019-07-25 统一的文件和目录结构,虽然在技术上差别不大,但可以让用户和审查者更容易阅读API界面定义。 指南 注意 以下指南适合于使用protobuf定义的API,例如…...
sql结尾加刷题
找了一下mysql对extractvalue()、updatexml()函数的官方介绍https://dev.mysql.com/doc/refman/5.7/en/xml-functions.html#function_extractvalue ExtractValue(xml_frag, xpath_expr) 知识点 解释一下这两个参数xml_frag,是xml标记片段,第二个参数…...
Linux学习笔记(应用篇三)
基于I.MX6ULL-MINI开发板 LED学习GPIO应用编程输入设备 开发板中所有的设备(对象)都会在/sys/devices 体现出来,是 sysfs 文件系统中最重要的目录结构 /sys下的子目录说明/sys/devices这是系统中所有设备存放的目录,也就是系统中…...
LLM动态Shape实现原理与核心技术
LLM动态Shape实现原理与核心技术 目录 LLM动态Shape实现原理与核心技术1. **动态Shape核心原理**2. **实现方法与关键技术**3. **示例:vLLM处理动态长度输入**4. **动态Shape vs 静态Shape对比**5. **性能优化案例**总结`SamplingParams` 是什么常见参数及作用使用示例1. 动态…...
MyBatis 语法不支持 having 节点
MyBatis 不支持 having 节点 比如在 GROUP BY 之后添加了 HAVING 子句,其内容为SUM(vsbsad.business_income) > 0,该子句会对分组后的 SUM(vsbsad.business_income) 结果进行过滤,仅保留求和结果不为负数的分组记录。但是试过不支持。可把…...
【redis】事务详解,相关命令multi、exec、discard 与 watch 的原理
文章目录 什么是事务原子性一致性持久性隔离性 优势与 MySQL 对比用处 事务相关命令开启事务——MULTI执行事务——EXEC放弃当前事务——DISCARD监控某个 key——WATCH作用场景使用方法实现原理 事务总结 什么是事务 MySQL 事务: 原子性:把多个操作&am…...
数据库基础知识点(系列七)
视图和索引相关的语句 1.引入视图的主要目的是什么? 答:数据库的基本表是按照数据库设计人员的观点设计的,并不一定符合用户的需求。SQL Server 2008可以根据用户需求重新定义表的数据结构,这种数据结构就是视图。视图是关系数据…...
