当前位置: 首页 > article >正文

Python Day38

Task:
1.Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
2.Dataloader类
3.minist手写数据集的了解


1. Dataset 类的 __getitem____len__ 方法

在 PyTorch (或类似深度学习框架) 中,Dataset 是一个抽象基类,用于表示你的数据。它通常用于将原始数据(例如图像文件、文本文件、CSV 数据等)处理成模型可以直接消费的格式。

Dataset 类有两个核心的特殊方法,它们是 Python 的“魔法方法”:

  • __len__(self):

    • 作用: 这个方法必须返回数据集中样本的总数量。
    • 实现: 当你创建一个 Dataset 的子类时,你需要实现它来告诉 PyTorch 这个数据集有多大。
    • 用处: Dataloader 需要知道总长度才能正确地进行批处理、洗牌和分发数据。
    • 示例:
      class MyDataset(Dataset):def __init__(self, data_list):self.data = data_list # 假设data_list是你的数据源def __len__(self):return len(self.data) # 返回数据源的长度def __getitem__(self, idx):# ... 具体实现将在下面说明pass
      
  • __getitem__(self, idx):

    • 作用: 这个方法用于根据给定的索引 idx 返回数据集中的一个样本。
    • 实现: 这是最关键的部分。你需要在其中定义如何加载、预处理(如图像变换、文本编码)并返回一个样本及其对应的标签。
    • 返回类型: 通常,它返回一个元组或字典,其中包含一个数据样本和其对应的标签。例如 (image_tensor, label_tensor)
    • 用处: 当 Dataloader 需要获取一个批次的数据时,它会内部多次调用 __getitem__ 来收集单个样本。
    • 示例:
      import torch
      from torch.utils.data import Datasetclass CustomImageDataset(Dataset):def __init__(self, image_paths, labels, transform=None):self.image_paths = image_pathsself.labels = labelsself.transform = transform # 用于图像预处理的转换def __len__(self):return len(self.image_paths)def __getitem__(self, idx):img_path = self.image_paths[idx]label = self.labels[idx]# 假设这里是加载图像的逻辑 (实际会用Pillow等库)# 为了演示,我们创建一个虚拟图像tensorimage = torch.randn(3, 224, 224) # 3 channels, 224x224 pixelsif self.transform:image = self.transform(image) # 应用预处理return image, label # 返回图像张量和标签
      

总结 Dataset 的作用和特殊方法:

Dataset 类负责:

  1. 数据抽象: 将原始数据封装成一个可迭代、可索引的对象。
  2. 数据加载: 在 __getitem__ 中处理从文件系统或内存中加载单个数据项的逻辑。
  3. 数据预处理: 在 __getitem__ 中应用必要的预处理步骤(如归一化、裁剪、数据增强)。
  4. 提供索引: __len____getitem__ 使得数据集可以通过索引访问,并知道其总大小。

2. DataLoader

DataLoader 是 PyTorch 中一个非常强大的工具,它建立在 Dataset 之上,负责高效地加载和批处理数据。它的核心功能是:

  • 批处理 (Batching): 将单个样本组合成批次,这是深度学习训练的常用方式,因为它可以提高计算效率,并有助于梯度下降的稳定。
  • 洗牌 (Shuffling): 在每个 epoch 开始时随机打乱数据,以防止模型学习到数据中的顺序模式,并提高模型的泛化能力。
  • 多进程数据加载 (Multiprocessing Data Loading): 可以使用多个工作进程并行加载数据,从而减少数据加载成为训练瓶颈的可能性。
  • 内存固定 (Pin Memory): 可以将张量加载到 CUDA 固定内存中,这可以加快数据传输到 GPU 的速度。

DataLoader 的主要参数:

  • dataset: 必须是 torch.utils.data.Dataset 的实例。这是 DataLoader 从中获取数据的来源。
  • batch_size: 每个批次包含的样本数量。
  • shuffle: 布尔值,如果设置为 True,则在每个 epoch 开始时打乱数据。
  • num_workers: 用于数据加载的子进程数量。设置为 0 意味着数据将在主进程中加载。大于 0 会开启多进程,通常能加快加载速度,但也需要更多内存。
  • drop_last: 布尔值,如果设置为 True,则如果数据集大小不能被 batch_size 整除,则最后一个不完整的批次将被丢弃。
  • collate_fn: 可选参数,一个函数,用于如何将单个样本列表合并成一个批次。默认情况下,它会尝试堆叠张量。如果你有复杂的数据结构(如变长序列),你可能需要自定义这个函数。

DataLoader 的使用:

DataLoader 是一个可迭代对象。你可以直接在 for 循环中使用它来获取批次数据。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np# 假设我们有一个简单的Dataset
class SimpleDataset(Dataset):def __init__(self, num_samples=100):self.data = torch.randn(num_samples, 10) # 100个样本,每个样本10个特征self.labels = torch.randint(0, 2, (num_samples,)) # 100个标签,0或1def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx], self.labels[idx]# 创建数据集实例
my_dataset = SimpleDataset(num_samples=100)# 创建DataLoader实例
train_loader = DataLoader(dataset=my_dataset,batch_size=16,shuffle=True,num_workers=0) # 简单示例,不使用多进程# 迭代DataLoader获取批次数据
for epoch in range(5): # 假设训练5个epochprint(f"\nEpoch {epoch+1}")for batch_idx, (data, labels) in enumerate(train_loader):print(f"  Batch {batch_idx+1}: data shape = {data.shape}, labels shape = {labels.shape}")# 在这里执行模型的前向传播、计算损失、反向传播等训练步骤if batch_idx >= 2: # 只打印前3个批次,避免输出过多break

DataLoaderDataset 的协作:

  • DataLoader 接收一个 Dataset 对象。
  • DataLoader 需要一个批次数据时,它会:
    1. 如果 shuffle=True,它会首先打乱 Dataset 的索引。
    2. 它会选择 batch_size 个索引。
    3. 对于每个选定的索引,它会调用 Dataset__getitem__(idx) 方法来获取单个样本。
    4. 它将这些单个样本集合起来(默认通过 torch.stacktorch.cat),形成一个批次张量。
    5. 最终将批次张量返回给你的训练循环。

3. MNIST 手写数字数据集的了解

MNIST (Modified National Institute of Standards and Technology) 是一个经典的、广泛使用的计算机视觉数据集,被誉为“深度学习的 Hello World”。

主要特点:

  1. 内容: 包含大量手写数字的灰度图像。
  2. 类别: 10 个类别,对应数字 0 到 9。
  3. 图像大小: 每张图像都是 28x28 像素。
  4. 数据量:
    • 训练集: 60,000 张图像,用于训练模型。
    • 测试集: 10,000 张图像,用于评估模型的性能。
  5. 图像格式: 灰度图像,每个像素的值通常在 0 到 255 之间,表示像素亮度。

MNIST 的重要性:

  • 入门级: 简单且足够小,适合初学者学习深度学习的基本概念和 PyTorch 的使用。
  • 基准: 由于其标准化和广泛使用,它经常作为新算法和模型架构的初步测试基准。
  • 低计算需求: 训练一个在 MNIST 上表现良好的模型通常不需要强大的 GPU,普通 CPU 也能完成。

PyTorch 中使用 MNIST:

PyTorch 的 torchvision 库提供了方便的工具来下载和加载 MNIST 数据集。

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 1. 定义数据转换 (Transformations)
# MNIST图像是PIL.Image类型,需要转换为Tensor,并进行归一化。
# 归一化是常用的预处理步骤,将像素值缩放到一个特定范围(例如0到1,或-1到1)。
# 对于MNIST,通常是 (mean=0.1307, std=0.3081),这是根据整个MNIST数据集计算得出的。
transform = transforms.Compose([transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为FloatTensor,并除以255将像素值缩放到0-1transforms.Normalize((0.1307,), (0.3081,)) # 归一化,(mean,) (std,),对于灰度图像是单通道
])# 2. 下载并加载训练数据集
# root: 数据存放的根目录
# train=True: 获取训练集
# download=True: 如果数据不存在,则下载
# transform: 应用上述定义的转换
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)# 3. 下载并加载测试数据集
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 4. 创建 DataLoader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) # num_workers可以根据你的CPU核心数调整
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4) # 测试集通常不打乱# 5. 遍历训练数据 (示例)
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")for batch_idx, (data, target) in enumerate(train_loader):print(f"训练批次 {batch_idx+1}: data shape = {data.shape}, target shape = {target.shape}")# data.shape 会是 [batch_size, 1, 28, 28] (1是通道数,28x28是图像尺寸)# target.shape 会是 [batch_size]break # 只打印第一个批次# 6. 遍历测试数据 (示例)
for batch_idx, (data, target) in enumerate(test_loader):print(f"测试批次 {batch_idx+1}: data shape = {data.shape}, target shape = {target.shape}")break

相关文章:

Python Day38

Task: 1.Dataset类的__getitem__和__len__方法(本质是python的特殊方法) 2.Dataloader类 3.minist手写数据集的了解 1. Dataset 类的 __getitem__ 和 __len__ 方法 在 PyTorch (或类似深度学习框架) 中,Dataset 是一个抽象基类&a…...

DeepSeek R1 模型小版本升级,DeepSeek-R1-0528都更新了哪些新特性?

DeepSeek-R1‑0528 技术剖析:思维链再进化,推理性能飙升 目录 版本概览深度思考能力再升级基准测试成绩功能与体验更新API 变动与示例模型开源与下载结语 版本概览 DeepSeek 团队今日发布 DeepSeek‑R1‑0528 —— 基于 DeepSeek V3 Base(2…...

线路板厂家遇到的PCB元件放置的常见问题有哪些?

印刷电路板现在无处不在。尽管大多数人认为这是理所当然的,但工程师和设计师们充分意识到这些电路开发和生产背后的巨大努力。传统的PCB生产涉及复杂的机械和高昂的前期成本,因此必须将制造外包给专业工厂。 说到交货时间,你可能需要几周的时…...

【C/C++】无限长有序数组中查找特定元素

在无限长有序数组中查找特定元素&#xff0c;由于数组长度未知&#xff0c;需先定位搜索范围&#xff0c;再进行二分查找。以下是C实现&#xff1a; #include <iostream> #include <vector> #include <climits> using namespace std;// 假设数组访问函数&am…...

SQL正则表达式总结

这里写目录标题 一、元字符二、正则表达函数1、 regexp_like(x,pattern[,match_option])2、 regexp_instr(x,pattern[,start[,occurrence[,return_option[, match_option]]]]) 3、 REGEXP_SUBSTR(x,pattern[,start[,occurrence[, match_option]]]) 4、 REGEXP_REPLACE(x,patter…...

力扣经典算法篇-13-接雨水(较难,动态规划,加法转减法优化,双指针法)

1、题干 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图&#xff0c;计算按此排列的柱子&#xff0c;下雨之后能接多少雨水。 示例 1&#xff1a; 输入&#xff1a;height [0,1,0,2,1,0,1,3,2,1,2,1] 输出&#xff1a;6 解释&#xff1a;上面是由数组 [0,1,0,2,1,0,1,3…...

STM32 -- USB虚拟串口通信

本篇操作: 通过CubeMX Keil&#xff0c;配置STM32作为USB设备端&#xff0c;与电脑上位机进行通信&#xff08;CDC&#xff09;&#xff1b;通用带USB功能的 STM32 芯片 &#xff08;如F1、F4等&#xff0c;系统时钟配置不同&#xff0c;代码通用&#xff09;。 目录 一、 S…...

uni-app开发特殊社交APP

uni-app开发特殊社交APP 目录 1.展示APP功能 2.展示项目结构 3.关于我的GitHub 引言 博主最近自己在GitHub上面上传了一个关于社交软件的项目&#xff08;该项目早已开发完毕&#xff09;, 这个社交软件比较特殊, 被称之为blind-date&#xff0c; blind-date 是基于 uni-…...

Linux中Shell脚本的常用命令

一、设置主机名称 1、通过修改系统文件来修改主机名称 [rootsakura1 桌面]# vim /etc/hostname sakura /etc/hostname&#xff1a;Linux 系统中存储主机名的配置文件。修改完文件后&#xff0c;在当前的shell中是不生效的&#xff0c;需要关闭当前shell后重新开启才能看到效…...

RabbitMQ项目实战

先参考文章&#xff1a;&#xff08;必看&#xff09; 06-MQ基础_mq服务-CSDN博客 07-MQ高级&#xff08;幂等性&#xff09;-CSDN博客 https://cloud.iocoder.cn/message-queue/rabbitmq/#_2-0-%E5%BC%95%E5%85%A5%E4%BE%9D%E8%B5%96%E4%B8%8E%E9%85%8D%E7%BD%AE 1、Rabbi…...

安卓开发用到的设计模式(3)行为型模式

安卓开发用到的设计模式&#xff08;3&#xff09;行为型模式 文章目录 安卓开发用到的设计模式&#xff08;3&#xff09;行为型模式1. 命令模式&#xff08;Command Pattern&#xff09;2. 策略模式&#xff08;Strategy Pattern&#xff09;3. 观察者模式&#xff08;Observ…...

生成模型:从数据学习到创造的 AI 新范式

一、生成模型&#xff1a;定义与核心逻辑 生成模型是一类通过学习数据潜在分布来创造新样本的机器学习模型。其核心目标是构建数据的概率分布模型 P(X)&#xff0c;使生成的样本 X^ 与真实数据 X 具有相似的统计特征。 1.1 与判别模型的本质区别 维度生成模型判别模型核心目…...

尚硅谷redis7 90-92 redis集群分片之集群扩容

90 redis集群分片之集群扩容 三主三从不够用了&#xff0c;进行扩容变为4主4从 问题&#xff1a;1.新建两个redis实例&#xff0c;怎么加入原有集群&#xff1f;2.原有的槽位分3段&#xff0c;又加进来一个槽位怎么算&#xff1f; 新建6387、6388两个服务实例配置文件新建后启…...

RabbitMQ性能调优:关键技术、技巧与最佳实践

RabbitMQ作为一款高可靠、高扩展性的消息中间件&#xff0c;其性能表现直接影响到分布式系统的吞吐量和响应延迟。本文基于RabbitMQ官方文档和最佳实践&#xff0c;结合核心性能优化方向&#xff0c;详细探讨RabbitMQ性能调优的关键技术、技巧和策略。 通过以下优化策略&#…...

系统架构中的组织驱动:康威定律在系统设计中的应用

康威定律&#xff08;Conway’s Law&#xff09; 是由计算机科学家 Melvin Conway 在1967年提出的理论&#xff0c;其核心观点是&#xff1a;“系统的架构设计会不可避免地反映其开发组织的沟通结构。换句话说&#xff0c;软件系统的结构会与构建它的团队的组织结构高度相似。 …...

TypeScript 中高级类型 keyof 与 typeof的场景剖析。

文章目录 前言一、typeof&#xff1a;从值到类型的映射1. 核心概念2. 类型推导示例3. 常见用途 二、keyof&#xff1a;从类型到键的映射1. 核心概念2. 常见用途 三、typeof keyof&#xff1a;强强联合的实战场景1. 场景一&#xff1a;对象属性的安全访问2. 场景二&#xff1a;…...

Android LiveData 详解

一、LiveData 核心概念与特性 1.1 定义与基本功能 LiveData 是 Android Jetpack 架构组件中的一个可观察数据持有者类&#xff0c;其核心功能是实现数据与 UI 的响应式绑定。与传统观察者模式不同&#xff0c;LiveData 具有生命周期感知能力&#xff0c;能够自动根据观察者…...

为什么共现矩阵是高维稀疏的

为什么共现矩阵是高维稀疏的&#xff1f; 共现矩阵&#xff08;Co-occurrence Matrix&#xff09;的高维稀疏性是其固有特性&#xff0c;主要由以下原因导致&#xff1a; 1. 高维性的根本原因 词汇表大小决定维度&#xff1a; 共现矩阵的维度为 ( V \times V )&#xff0c;其…...

离散化算法的二分法应用

我们思考一个问题&#xff1a;其实这里的二分法回归本源也是基于下标映射的原理&#xff0c;只是实现是借助二分的形式。 在排序好的数组中对目标数值进行二分搜索&#xff0c;在 O(logn) 的时间复杂度内找到该数值是整体数据中的第几个。 具体的我们可以如下操作&#xff1a; …...

IntelliJ IDEA 中进行背景设置

&#x1f3a8; ​​一、全局主题切换​​ ​​操作路径​​ File → Settings → Appearance & Behavior → Appearance → Theme​​可选主题​​&#xff1a; ​​Darcula​​&#xff1a;深色模式&#xff08;默认暗黑主题&#xff09;​​IntelliJ Light​​&#xff…...

Dart语言学习指南「专栏简介」

Dart 是 Google 开发的一款开源通用编程语言&#xff0c;它不仅支持客户端和服务器端的应用开发&#xff0c;还因其与 Flutter 框架的深度集成&#xff0c;在移动端和 Web 开发中广受欢迎。Dart 适用于 Android 应用、iOS 应用、物联网&#xff08;IoT&#xff09;项目以及 Web…...

AWS之AI服务

目录 一、AWS AI布局 ​​1. 底层基础设施与芯片​​ ​​2. AI训练框架与平台​​ ​​3. 大模型与应用层​​ ​​4. 超级计算与网络​​ ​​与竞品对比​​ AI服务 ​​1. 机器学习平台​​ ​​2. 预训练AI服务​​ ​​3. 边缘与物联网AI​​ ​​4. 数据与AI…...

Docker 部署项目

使用 Docker 部署项目是一个很好的选择&#xff0c;可以避免服务器环境不兼容的问题&#xff0c;并且能够实现一致性和可移植性。我会给你一个详细的步骤&#xff0c;帮你从零开始理解 Docker&#xff0c;最终在服务器上部署 Roop 项目。 1. 安装 Docker 首先&#xff0c;你需…...

半导体厂房设计建造流程、方案和技术要点-江苏泊苏系统集成有限公司

半导体厂房设计建造流程、方案和技术要点-江苏泊苏系统集成有限公司 半导体厂房的设计建造是一项高度复杂、专业性极强的系统工程&#xff0c;涉及洁净室、微振动控制、电磁屏蔽、特殊气体/化学品管理等关键技术。 一、设计建造流程&#xff1a; 1.需求定义与可行性分析 &a…...

(c++)string的模拟实现

目录 1.构造函数 2.析构函数 3.扩容 1.reserve(扩容不初始化) 2.resize(扩容加初始化) 4.push_back 5.append 6. 运算符重载 1.一个字符 2.一个字符串 7 []运算符重载 8.find 1.找一个字符 2.找一个字符串 9.insert 1.插入一个字符 2.插入一个字符串 9.erase 10…...

一种通用图片红色印章去除的工具设计

朋友今天下午需要处理个事情&#xff0c;问我有没有什么好的办法能够去除&#xff0c;核心问题是要去除图片上的印章。记得以前处理过类似的需求&#xff0c;photoshop操作比较简单&#xff0c;本质是做运算。这种处理方式有很多&#xff0c;比如现在流行的大模型&#xff0c;一…...

企业应用AI对向量数据库选型思考

一、向量数据库概述 向量数据库是一种专门用于存储和检索高维向量数据的数据库系统&#xff0c;它能够高效地处理基于向量相似性的查询&#xff0c;如最近邻搜索等&#xff0c;在人工智能、机器学习等领域的应用中发挥着重要作用&#xff0c;为处理复杂的向量数据提供了有力的…...

时序数据库IoTDB安装学习经验分享

1. JDK安装问题 在安装IoTDB时&#xff0c;我遇到了“无法加载主类”的错误&#xff0c;这通常表明Java环境存在问题。尽管我能正确输出classpath和查询JDK版本&#xff0c;但问题依旧存在。经过查阅相关资料&#xff0c;我发现问题出在多余的classpath设置上。Java编译器和虚…...

RapidOCR集成PP-OCRv5_det mobile模型记录

该文章主要摘取记录RapidOCR集成PP-OCRv5_mobile_det记录&#xff0c;涉及模型转换&#xff0c;模型精度测试等步骤。原文请前往官方博客&#xff1a; https://rapidai.github.io/RapidOCRDocs/main/blog/2025/05/26/rapidocr%E9%9B%86%E6%88%90pp-ocrv5_det%E6%A8%A1%E5%9E%8B…...

当 Redis 作为缓存使用时,如何保证缓存数据与数据库(或其他服务的数据源)之间的一致性?

当 Redis 作为缓存使用时&#xff0c;保证缓存数据与数据库&#xff08;或其他数据源&#xff09;之间的一致性是一个核心挑战。通常&#xff0c;我们追求的是“最终一致性”&#xff0c;而不是“强一致性”&#xff0c;因为强一致性往往会牺牲性能和可用性&#xff0c;这与使用…...