当前位置: 首页 > news >正文

PyTorch DataLoader 学习

1. DataLoader的核心概念

DataLoader是PyTorch中一个重要的类,用于将数据集(dataset)和数据加载器(sampler)结合起来,以实现批量数据加载和处理。它可以高效地处理数据加载、多线程加载、批处理和数据增强等任务。

核心参数

  • dataset: 数据集对象,必须是继承自torch.utils.data.Dataset的类。
  • batch_size: 每个批次的大小。
  • shuffle: 是否在每个epoch开始时打乱数据。
  • sampler: 定义数据加载顺序的对象,通常与shuffle互斥。
  • num_workers: 使用多少个子进程加载数据。
  • collate_fn: 如何将单个样本合并为一个批次的函数。
  • pin_memory: 是否将数据加载到CUDA固定内存中。

2. 基本使用方法

定义数据集类

首先定义一个数据集类,该类需要继承自torch.utils.data.Dataset并实现__len____getitem__方法。

import torch
from torch.utils.data import Dataset, DataLoaderclass CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):sample = {'data': self.data[idx], 'label': self.labels[idx]}return sample# 创建一些示例数据
data = torch.randn(100, 3, 64, 64)  # 100个样本,每个样本为3x64x64的图像
labels = torch.randint(0, 2, (100,))  # 100个标签,0或1dataset = CustomDataset(data, labels)

创建DataLoader

使用自定义数据集类创建DataLoader对象。

batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

迭代DataLoader

遍历DataLoader获取批量数据。

for batch in dataloader:data, labels = batch['data'], batch['label']print(data.shape, labels.shape)

3. 进阶技巧

自定义collate_fn

如果需要自定义如何将样本合并为批次,可以定义自己的collate_fn函数。

def custom_collate_fn(batch):data = [item['data'] for item in batch]labels = [item['label'] for item in batch]return {'data': torch.stack(data), 'label': torch.tensor(labels)}dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

使用Sampler

Sampler定义了数据加载的顺序。可以自定义一个Sampler来实现更复杂的数据加载策略。

from torch.utils.data import Samplerclass CustomSampler(Sampler):def __init__(self, data_source):self.data_source = data_sourcedef __iter__(self):return iter(range(len(self.data_source)))def __len__(self):return len(self.data_source)custom_sampler = CustomSampler(dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=custom_sampler, num_workers=2)

数据增强

在图像处理中,数据增强(Data Augmentation)是提高模型泛化能力的一种有效方法。可以使用torchvision.transforms进行数据增强。

import torchvision.transforms as transformstransform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])dataset = CustomDataset(data, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

4. 实战示例:CIFAR-10数据集

以下是使用CIFAR-10数据集的完整示例代码,包括数据加载、数据增强和模型训练。

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10# 定义数据增强和标准化
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])# 加载训练和测试数据集
trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testset = CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# 定义简单的卷积神经网络
import torch.nn as nn
import torch.nn.functional as Fclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.fc1 = nn.Linear(64 * 8 * 8, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 8 * 8)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 创建模型、定义损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
for epoch in range(10):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 100}')running_loss = 0.0print('Finished Training')# 测试模型
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

5. 数据加载加速技巧

使用多进程数据加载

通过设置num_workers参数,可以启用多进程数据加载,加速数据读取过程。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

使用pin_memory

如果使用GPU进行训练,将pin_memory设置为True可以加速数据传输。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

预取数据

使用prefetch_factor参数来预取数据,以减少数据加载等待时间。

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2)

6. 处理不规则数据

在某些情况下,数据样本可能不规则,例如变长序列。可以使用自定义的collate_fn来处理这种数据。

def custom_collate_fn(batch):batch = sorted(batch, key=lambda x: len(x['data']), reverse=True)data = [item['data'] for item in batch]labels = [item['label'] for item in batch]data_padded = torch.nn.utils.rnn.pad_sequence(data, batch_first=True)labels = torch.tensor(labels)return {'data': data_padded, 'label': labels}dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

7. 使用中应注意的问题

数据加载效率

设置num_workers

  • 多线程数据加载: num_workers参数决定了用于数据加载的子进程数量。合理设置num_workers可以显著提升数据加载速度。一般来说,设置为CPU核心数的一半或等于核心数是一个不错的选择,但需要根据具体情况进行调整。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

使用pin_memory

  • 固定内存: 当使用GPU进行训练时,将pin_memory设置为True可以加速数据从CPU传输到GPU的速度。固定内存使得数据可以直接从页面锁定内存复制到GPU内存。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)

预取数据

  • 预取因子: 使用prefetch_factor参数来预取数据,以减少数据加载等待时间。默认情况下,预取因子为2。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, prefetch_factor=2)

数据集与DataLoader的兼容性

正确实现 __getitem____len__

  • 数据集类的实现: 确保自定义数据集类正确实现了__getitem____len__方法,确保DataLoader能够正确地索引和迭代数据。
class CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, idx):sample = {'data': self.data[idx], 'label': self.labels[idx]}return sample

数据增强与预处理

数据增强

  • 变换操作: 在图像处理中,数据增强可以提高模型的泛化能力。可以使用torchvision.transforms进行数据增强和标准化。
import torchvision.transforms as transformstransform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])dataset = CustomDataset(data, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

数据加载过程中的内存问题

避免内存泄漏

  • 防止内存泄漏: 在使用DataLoader时,尤其是多进程加载时,注意内存泄漏问题。确保在训练过程中及时释放不再使用的数据。

合理设置batch_size

  • 批次大小: 根据GPU显存和内存大小合理设置batch_size。过大可能导致内存不足,过小可能导致计算效率低。
batch_size = 64  # 根据实际情况调整
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

数据顺序与随机性

shufflesampler

  • 数据随机性: 在训练集上使用shuffle=True,可以在每个epoch开始时打乱数据,防止模型过拟合。
  • 使用Sampler: 对于特殊的数据加载顺序需求,可以自定义Sampler。
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

数据不一致性

自定义collate_fn

  • 处理变长序列:在处理变长序列或不规则数据时,自定义collate_fn函数,确保每个批次的数据能够正确合并。
def custom_collate_fn(batch):data = [item['data'] for item in batch]labels = [item['label'] for item in batch]return {'data': torch.stack(data), 'label': torch.tensor(labels)}dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)

数据加载调试

调试与错误处理

  • 调试: 在数据加载过程中,可以打印或检查部分数据样本,确保数据预处理和加载过程正确无误。
  • 错误处理: 使用try-except块捕捉并处理数据加载中的异常,防止程序崩溃。
for i, data in enumerate(dataloader, 0):try:inputs, labels = data['data'], data['label']# 数据处理和训练代码except Exception as e:print(f"Error loading data at batch {i}: {e}")

性能优化

数据加载性能

  • Profile数据加载: 使用profiling工具(如PyTorch的torch.utils.bottleneck)分析数据加载和训练过程中的性能瓶颈,进行相应优化。
import torch.utils.bottleneck# 在命令行运行以下命令进行性能分析
# python -m torch.utils.bottleneck <script.py>

相关文章:

PyTorch DataLoader 学习

1. DataLoader的核心概念 DataLoader是PyTorch中一个重要的类&#xff0c;用于将数据集&#xff08;dataset&#xff09;和数据加载器&#xff08;sampler&#xff09;结合起来&#xff0c;以实现批量数据加载和处理。它可以高效地处理数据加载、多线程加载、批处理和数据增强…...

TCP传输控制协议二

TCP 是 TCP/IP 模型中的传输层一个最核心的协议&#xff0c;不仅如此&#xff0c;在整个 4 层模型中&#xff0c;它都是核心的协议&#xff0c;要不然模型怎么会叫做 TCP/IP 模型呢。 它向下使用网络层的 IP 协议&#xff0c;向上为 FTP、SMTP、POP3、SSH、Telnet、HTTP 等应用…...

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(五)-同时支持无人机和eMBB用户数据传输的用例

引言 本文是3GPP TR 22.829 V17.1.0技术报告&#xff0c;专注于无人机&#xff08;UAV&#xff09;在3GPP系统中的增强支持。文章提出了多个无人机应用场景&#xff0c;分析了相应的能力要求&#xff0c;并建议了新的服务级别要求和关键性能指标&#xff08;KPIs&#xff09;。…...

使用F1C200S从零制作掌机之debian文件系统完善NES

一、模拟器源码 源码&#xff1a;https://files.cnblogs.com/files/twzy/arm-NES-linux-master.zip 二、文件系统 文件系统&#xff1a;debian bullseye 使用builtroot2018构建的文件系统&#xff0c;使用InfoNES模拟器存在bug&#xff0c;搞不定&#xff0c;所以放弃&…...

Vue 3 与 TypeScript:最佳实践详解

大家好,我是CodeQi! 很多人问我为什么要用TypeScript? 因为 Vue3 喜欢它! 开个玩笑... 在我们开始探索 Vue 3 和 TypeScript 最佳实践之前,让我们先打个比方。 如果你曾经尝试过在没有 GPS 的情况下开车到一个陌生的地方,你可能会知道那种迷失方向的感觉。 而 Typ…...

PyMysql error : Packet Sequence Number Wrong - got 1 expected 0

文章目录 错误一错误原因解决方案 错误二原因解决方案 我自己知道的&#xff0c;这类问题有两类原因&#xff0c;两种解决方案。 错误一 错误原因 pymysql的主进程启动的connect无法给子进程中使用&#xff0c;所以读取大批量数据时最后容易出现了此类问题。 解决方案 换成…...

MVC 生成验证码

在mvc 出现之前 生成验证码思路 在一个html页面上&#xff0c;生成一个验证码&#xff0c;在把这个页面嵌入到需要验证码的页面中。 JS生成验证码 <script type"text/javascript">jQuery(function ($) {/**生成一个随机数**/function randomNum(min, max) {…...

OSPF.综合实验

1、首先将各个网段基于172.16.0.0 16 进行划分 1.1、划分为4个大区域 172.16.0.0 18 172.16.64.0 18 172.16.128.0 18 172.16.192.0 18 四个网段 划分R4 划分area2 划分area3 划分area1 2、进行IP配置 如图使用配置指令进行配置 ip address x.x.x.x /x 并且将缺省路由…...

云计算【第一阶段(29)】远程访问及控制

一、ssh远程管理 1.1、ssh (secureshell)协议 是一种安全通道协议对通信数据进行了加密处理&#xff0c;用于远程管理功能SSH 协议对通信双方的数据传输进行了加密处理&#xff0c;其中包括用户登录时输入的用户口令&#xff0c;建立在应用层和传输层基础上的安全协议。SSH客…...

2024前端面试真题【CSS篇】

盒子模型 盒子模型&#xff1a;box-sizing&#xff0c;描述了文档中的元素如何生成矩形盒子&#xff0c;并通过这些盒子的布局来组织和设计网页。包含content、padding、margin、border四个部分。 分类 W3C盒子模型&#xff08;content-box&#xff09;&#xff1a;标准盒子模…...

python中设置代码格式,函数编写指南,类的编程风格

4.6 设置代码格式 随着你编写的程序越来越长&#xff0c;确保代码格式一致变得尤为重要。花时间让代码尽可能易于阅读&#xff0c;这不仅有助于你理解程序的功能&#xff0c;也能帮助他人理解你的代码。 为了保证所有人的代码结构大致一致&#xff0c;Python程序员遵循一系列…...

CentOS 8升级gcc版本

1、查看gcc版本 gcc -v发现gcc版本为8.x.x&#xff0c;而跑某个项目的finetune需要gcc-9&#xff0c;之前搜索过很多更新gcc版本的方式&#xff0c;例如https://blog.csdn.net/xunye_dream/article/details/108918316?spm1001.2014.3001.5506&#xff0c;但执行指令 sudo yu…...

Kafka基础入门篇(深度好文)

Kafka简介 Kafka 是一个高吞吐量的分布式的基于发布/订阅模式的消息队列&#xff08;Message Queue&#xff09;&#xff0c;主要应用与大数据实时处理领域。 1. 以时间复杂度为O(1)的方式提供消息持久化能力。 2. 高吞吐率。&#xff08;Kafka 的吞吐量是MySQL 吞吐量的30…...

C++之复合资料型态KU网址第二部V蒐NAY3989

结构 结构可存放不同资料型态的数值&#xff0c;例如 #include <iostream>struct Demo {int member1;char *member2;float member3; };int main() {Demo d;d.member1 19823;d.member2 "203";d.member3 3.011;std::cout << "member1: " &l…...

乡镇集装箱生活污水处理设备处理效率高

乡镇集装箱生活污水处理设备处理效率高 乡镇集装箱生活污水处理设备优势 结构紧凑&#xff1a;集装箱式设计减少了占地面积&#xff0c;便于在土地资源紧张的乡镇地区部署。 安装方便&#xff1a;设备出厂前已完成组装和调试&#xff0c;现场只需进行简单的连接和调试即可投入使…...

计算机网络高频面试题

从输入URL到展现页面的全过程&#xff1a; 用户在浏览器中输入URL。浏览器解析URL&#xff0c;确定协议、主机名和路径。浏览器查找本地DNS缓存&#xff0c;如果没有找到&#xff0c;向DNS服务器发起查询请求。DNS服务器解析主机名&#xff0c;返回IP地址。浏览器使用IP地址建立…...

进程通信(1):无名管道(pipe)

无名管道(pipe)用来具有亲缘关系的进程之间进行单向通信。半双工的通信方式&#xff0c;数据只能单向流动。 管道以字节流的方式通信&#xff0c;数据格式由用户自行定义。 无名管道多用于父子进程间通信&#xff0c;也可用于其他亲缘关系进程间通信。 因为父进程调用fork函…...

YOLOv10改进 | 损失函数篇 | SlideLoss、FocalLoss、VFLoss分类损失函数助力细节涨点(全网最全)

一、本文介绍 本文给大家带来的是分类损失 SlideLoss、VFLoss、FocalLoss损失函数&#xff0c;我们之前看那的那些IoU都是边界框回归损失&#xff0c;和本文的修改内容并不冲突&#xff0c;所以大家可以知道损失函数分为两种一种是分类损失另一种是边界框回归损失&#xff0c;…...

【数组、特殊矩阵的压缩存储】

目录 一、数组1.1、一维数组1.1.1 、一维数组的定义方式1.1.2、一维数组的数组名 1.2、二维数组1.2.1、二维数组的定义方式1.2.2、二维数组的数组名 二、对称矩阵的压缩存储三、三角矩阵的压缩存储四、三对角矩阵的压缩存储五、稀疏矩阵的压缩存储 一、数组 概述&#xff1a;数…...

Flat Ads:金融APP海外广告投放素材的优化指南

在当今全球化的数字营销环境中,金融APP的海外营销推广已成为众多金融机构与开发者最为关注的环节之一。面对不同地域、文化及用户习惯的挑战,如何优化广告素材,以吸引目标受众的注意并促成有效转化,成为了广告主们亟待解决的问题。 作为领先的全球化营销推广平台,Flat Ads凭借…...

龙虎榜——20250610

上证指数放量收阴线&#xff0c;个股多数下跌&#xff0c;盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型&#xff0c;指数短线有调整的需求&#xff0c;大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的&#xff1a;御银股份、雄帝科技 驱动…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;百货中心供应链管理系统被用户普遍使用&#xff0c;为方…...

DockerHub与私有镜像仓库在容器化中的应用与管理

哈喽&#xff0c;大家好&#xff0c;我是左手python&#xff01; Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库&#xff0c;用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用&#xff0c;可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器&#xff0c;能够帮助开发者更好地管理复杂的依赖关系&#xff0c;而 GraphQL 则是一种用于 API 的查询语言&#xff0c;能够提…...

JDK 17 新特性

#JDK 17 新特性 /**************** 文本块 *****************/ python/scala中早就支持&#xff0c;不稀奇 String json “”" { “name”: “Java”, “version”: 17 } “”"; /**************** Switch 语句 -> 表达式 *****************/ 挺好的&#xff…...

【开发技术】.Net使用FFmpeg视频特定帧上绘制内容

目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法&#xff0c;当前调用一个医疗行业的AI识别算法后返回…...

【Go语言基础【13】】函数、闭包、方法

文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数&#xff08;函数作为参数、返回值&#xff09; 三、匿名函数与闭包1. 匿名函数&#xff08;Lambda函…...

CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝

目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为&#xff1a;一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...

【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看

文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...

django blank 与 null的区别

1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是&#xff0c;要注意以下几点&#xff1a; Django的表单验证与null无关&#xff1a;null参数控制的是数据库层面字段是否可以为NULL&#xff0c;而blank参数控制的是Django表单验证时字…...