PyTorch中定义自己的数据集
文章目录
- 1. 简介
- 2. 查看PyTorch自带的数据集(可视化)
- 3. 准备材料
- 3.1 图片数据
- 3.2 标签数据
- 4. 方法
1. 简介
尽管PyTorch提供了许多自带的数据集,如MNIST、CIFAR-10、ImageNet等,但它们对于没有经验的用户来说,理解数据加载器的工作原理以及如何正确地配置数据加载器可能会有一定难度。 用户需要了解所使用的数据集,包括数据集的内容、结构、标签等信息。对于一些复杂的数据集,用户可能需要理解数据集的结构和标签的含义。通过定义自己的数据集类,您可以更好地控制数据的加载和处理过程,提高代码的灵活性、可读性和可维护性,同时更好地满足模型训练的需求。
2. 查看PyTorch自带的数据集(可视化)
为了更好的定义自己的数据集,我们首先查看PyTorch自带的数据集的内容,代码如下
# 导入所需的库
import matplotlib.pyplot as plt # 导入Matplotlib库,用于可视化
import torch # 导入PyTorch库
from torchvision.datasets import MNIST # 从torchvision中导入MNIST数据集
from torchvision import transforms # 导入transforms模块,用于数据预处理
import numpy as np # 导入NumPy库# 加载MNIST数据集
train_mnist_data = MNIST(root='./data', # 数据集存储路径train=True, # 加载训练集transform=transforms.Compose([transforms.Resize(size=(28, 28)), transforms.ToTensor()]), # 数据预处理操作download=True) # 如果数据集不存在,则自动下载# 设置要显示的样本数量
num_samples = 10# 创建包含多个子图的大图窗口
fig, axes = plt.subplots(1, num_samples, figsize=(10, 6))# 遍历选择要显示的样本
for i in range(num_samples):# 从数据集中获取图像数据和标签image, label = train_mnist_data[i]# 在子图中显示图像axes[i].imshow(image.squeeze().numpy(), cmap='gray') # 使用imshow函数显示图像,将张量转换为NumPy数组axes[i].set_title(f"Label: {label}") # 设置子图标题,显示图像对应的标签axes[i].axis('off') # 关闭坐标轴显示# 将图像保存为PNG格式的图片文件,文件名以图像的标签命名plt.imsave(f"./data/mnist_images/{label}.png", image.squeeze().numpy(), cmap='gray')# 显示图形窗口
plt.show()
这里,我们使用MNIST
类加载MNIST数据集。在加载数据集时,通过transform
参数指定了数据预处理操作,包括将图像大小调整为28x28像素,并将图像转换为张量。train=True
表示加载训练集,download=True
表示如果数据集不存在则自动下载到指定的路径。
接下来,我们选择一些样本进行可视化。我们在一个子图中显示了10个样本,每个样本对应一个数字图像和其对应的标签。通过循环遍历这些样本,从数据集中获取图像数据和标签,并使用Matplotlib的imshow()
函数将图像显示在子图中。
同时,使用imsave()
函数将每个图像保存为PNG格式的图片文件,文件名以标签命名。最后,使用plt.show()
显示图形窗口,显示图像的同时也会将图像保存到指定的路径中。这段代码的执行结果是显示10张MNIST数据集中的数字图像,并将这些图像保存到指定路径下。保存的图片如下所示
通过上面程序可以看到,数据集主要是由图片数据和对应的标签构成,那么我们就可以用这两个主要构成成分来构建自己的数据集。
3. 准备材料
3.1 图片数据
这里我们就用刚才保存的十张图片,即
当然,你也可以准备其它的图片,并给图片分别命名为“0.png, 1.png, …”。
这里,十张图片的相对路径为
imgs_path = "./data/mnist_images"
注:你们要根据自己存储的路径来给定。
3.2 标签数据
创建一个txt文件,为每一幅图片指定标签数据,如下所示
这里,txt文件的相对路径为
labels_path = "labels.txt"
4. 方法
在PyTorch中,您可以通过创建一个自定义的数据集类来定义自己的数据集。这个自定义类需要继承自torch.utils.data.Dataset
类,并且实现两个主要的方法:__len__
和 __getitem__
。__len__
方法应该返回数据集的长度,而__getitem__
方法则根据给定的索引返回数据集中的样本。
下面我们展示如何创建一个自定义的数据集类:
import os # 导入os模块,用于操作文件路径
from PIL import Image # 导入PIL库中的Image模块,用于图像处理
import torch # 导入PyTorch库
from torch.utils.data import Dataset # 从torch.utils.data模块导入Dataset类,用于定义自定义数据集
from torchvision import transforms # 导入transforms模块,用于数据预处理
import numpy as np # 导入NumPy库,用于数值处理
import matplotlib.pyplot as plt # 导入Matplotlib库,用于可视化class CustomDataset(Dataset):def __init__(self, image_dir, label_file, transform=None):super().__init__() # 调用父类的构造函数self.image_dir = image_dir # 图像数据的路径self.label_file = label_file # 标签文本的路径self.transform = transform # 数据预处理操作self.samples = self._load_samples() # 加载数据集样本信息def _load_samples(self):samples = [] # 存储样本信息的列表with open(self.label_file, 'r') as f: # 打开标签文本文件for line in f: # 逐行读取标签文本文件中的内容image_name, label = line.strip().split(',') # 根据逗号分隔每行内容,获取图像文件名和标签image_path = os.path.join(self.image_dir, image_name) # 拼接图像文件的完整路径samples.append((image_path, int(label))) # 将图像路径和标签组成元组,加入样本列表return samples # 返回样本列表def __len__(self):return len(self.samples) # 返回数据集样本的数量def __getitem__(self, index):image_path, label = self.samples[index] # 获取指定索引处的图像路径和标签image = Image.open(image_path).convert('L') # 打开图像文件并将其转换为灰度图像if self.transform: # 如果定义了数据预处理操作image = self.transform(image) # 对图像进行预处理操作return image, label # 返回预处理后的图像和标签# 设置图片数据路径和标签文本路径
image_dir = './data/mnist_images' # 图像数据的路径
label_file = 'labels.txt' # 标签文本的路径# 定义数据预处理操作,根据需要添加其他预处理操作
transform = transforms.Compose([transforms.Resize((28, 28)), # 调整图像大小transforms.ToTensor(), # 将图像转换为张量
])# 创建自定义数据集实例
custom_dataset = CustomDataset(image_dir, label_file, transform=transform)# 创建数据加载器
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=1, shuffle=False)# 遍历数据加载器中的每个批次数据
for batch_images, batch_labels in data_loader:# 使用squeeze()函数去除图像张量中的单维度,将图像数据转换为NumPy数组,并存储在变量image中image = batch_images.squeeze().numpy()# 使用imshow()函数显示图像,cmap='gray'指定使用灰度色彩映射plt.imshow(image, cmap='gray')# 设置图像标题,显示图像对应的标签,使用f-string格式化字符串,将batch_labels转换为Python标量并获取其值plt.title(f"Label: {batch_labels.item()}")# 关闭坐标轴显示,即不显示坐标轴plt.axis('off')# 显示图形窗口plt.show()
这段代码实现了加载自定义数据集,并使用 PyTorch 的 DataLoader 将数据加载成批次,然后逐批次地展示图像。
相关文章:

PyTorch中定义自己的数据集
文章目录 1. 简介2. 查看PyTorch自带的数据集(可视化)3. 准备材料3.1 图片数据3.2 标签数据 4. 方法 1. 简介 尽管PyTorch提供了许多自带的数据集,如MNIST、CIFAR-10、ImageNet等,但它们对于没有经验的用户来说,理解数据加载器的工作原理以及…...

助力数字农林业发展服务香榧智慧种植,基于YOLOv5全系列【n/s/m/l/x】参数模型开发构建香榧种植场景下香榧果实检测识别系统
作为一个生在北方但在南方居住多年的人,居然头一次听过香榧(fei)这种作物,而且这个字还不会念,查了以后才知道读音(fei),三声,这着实引起了我的好奇心,我相信…...

2024 年 4 月区块链游戏研报:市场低迷中活跃用户数创新高
2024 年 4 月区块链游戏研报 作者:stellafootprint.network 数据来源:GameFi 研究页面 2024 年 4 月,Web3 游戏领域在经历 3 月创纪录的表现后,迎来了显著波动。比特币自历史高位回调,月跌幅达到 10.4%。与此同时&a…...

排序(一)----冒泡排序,插入排序
前言 今天讲一些简单的排序,冒泡排序和插入排序,但是这两个排序时间复杂度较大,只是起到一定的学习作用,只需要了解并会使用就行,本文章是以升序为例子来介绍的 一冒泡排序 思路 冒泡排序是一种简单的排序算法,它重复地遍历要排序的序列,每次比较相邻…...

springcloud简单了解及上手
springcloud微服务框架简单上手 文章目录 springcloud微服务框架简单上手一、SpringCloud简单介绍1.1 单体架构1.2 分布式架构1.3 微服务 二、SpringCloud与SpringBoot的版本对应关系2022.x 分支2021.x 分支2.2.x 分支 三、Nacos注册中心3.1 认识和安装Nacos3.2 配置Nacos3.3 n…...
Halcon与深度学习框架结合进行图像分析
Halcon 是一款强大的机器视觉软件,而深度学习框架如 TensorFlow 或 PyTorch 在图像识别和分类任务中表现出色。结合两者的优势,可以实现复杂的图像分析任务。Halcon 负责图像预处理和特征提取,而深度学习框架则利用这些特征进行高级分析和识别…...

STL----push,insert,empalce
push_back和emplace_back的区别 #include <iostream> #include <vector>using namespace std; class testDemo { public:testDemo(int n) :num(n) {cout << "构造函数" << endl;}testDemo(const testDemo& other) :num(other.num) {cou…...

解决OpenHarmony设备开发Device Tools工具的QUICK ACCESS一直为空
今天重新安装了OpenHarmony设备开发的环境,在安装过程中,到了工程之后,QUICK ACCESS一直为空。如下图红色大方框的内容一开始没有。 解决方案: 在此记录我的原因,我的原因主要是deveco device tools的远程连接的是z…...
k8s拉起一个pod底层是如何运行的
在Kubernetes中,当你尝试启动一个Pod时,底层的运行方式是由Kubelet服务来管理的。以下是Pod启动过程的简化概述: Kubernetes API Server接收到创建Pod的请求。 API Server将Pod的元数据存储到etcd中,以便于Pod的调度和跟踪。 Sc…...

Java代理模式的实现详解
一、前言 1.1、说明 本文章是在学习mybatis框架源码的过程中,发现对于动态代理Mapper接口这一块的代理实现还是有些遗忘和陌生,因此在本文章中就Java实现代理模式的过程进行一个学习和总结。 1.2、参考文章 《设计模式》(第2版࿰…...

数据结构与算法===优先队列
文章目录 前言一、优先队列二、应用场景三、代码实现总结 前言 之前写过很多数据结构与算法相关的了,今天看一个新的数据结构,优先队列。优先队列类似队列,却又优先于队列,是堆实现的。接下来详细看看。 一、优先队列 优先队列一…...

HTML常用标签-超链接标签
超链接标签 点击后带有链接跳转的标签 ,也叫作a标签 href属性用于定义连接 href中可以使用绝对路径,以/开头,始终以一个固定路径作为基准路径作为出发点href中也可以使用相对路径,不以/开头,以当前文件所在路径为出发点href中也可以定义完整的URL target用于定义打开的方式 _b…...

财务管理|基于SprinBoot+vue的财务管理系统(源码+数据库+文档)
财务管理系统 目录 基于SprinBootvue的财务管理系统 一、前言 二、系统设计 三、系统功能设计 系统功能实现 1管理员功能模块 2员工功能模块 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取: 博主介绍࿱…...

快速学习SpringAi
Spring AI是AI工程师的一个应用框架,它提供了一个友好的API和开发AI应用的抽象,旨在简化AI应用的开发工序,例如开发一款基于ChatGPT的对话应用程序。通过使用Spring Ai使我们更简单直接使用chatgpt 1.创建项目 jdk17 引入依赖 2.依赖配置 …...

谈谈 Spring 的过滤器和拦截器
前言 我们在进行 Web 应用开发时,时常需要对请求进行拦截或处理,故 Spring 为我们提供了过滤器和拦截器来应对这种情况。那么两者之间有什么不同呢?本文将详细讲解两者的区别和对应的使用场景。 (本文的代码实现首先是基于 Sprin…...
请介绍下H264的多参考帧技术及其应用场景,并请说明下为什么要有多参考帧?
H.264(也称为H.264/AVC)的多参考帧机制是其编码效率和质量提升的关键部分。这个机制允许编码器在编码当前帧时,参考多个之前已编码的帧。这种多参考帧的方法为编码器提供了更多的选择,使其能够更准确地预测当前帧的内容࿰…...
第6章 Elasticsearch,分布式搜索引擎【仿牛客网社区论坛项目】
第6章 Elasticsearch,分布式搜索引擎【仿牛客网社区论坛项目】 前言推荐项目总结第6章 Elasticsearch,分布式搜索引擎1.Elasticsearch入门2.Spring整合ElasticsearchDiscussPostRepositoryDiscussPostControllerEventConsumer 3.开发社区搜索功能 最后 前…...
odoo 全局调整list_controller中默认方法(form_controller和kanban_controller等亦可以同样操作)
需求说明 工作中遇到需要调整odoo原生的tree hearder button显示逻辑,又不可以直接跳转odoo源码,故新加个js全局替换对应的方法,以实现对应功能的同时不影响后期odoo版本升级。 odoo 全局调整list_controller方法示例 创建一个js放到stati…...
大模型日报2024-05-13
大模型日报 2024-05-13 大模型资讯 谷歌推出Gemini生成式AI平台 摘要: 生成式人工智能正在改变我们与技术的互动方式。谷歌最近推出了名为Gemini的新平台,该平台代表了其在生成式AI领域的最新进展。Gemini平台集成了一系列先进的工具和功能,旨在为用户提…...
【使用Condition来模拟生产消费】
使用Condition来模拟生产消费 1. 关于ReentrantLock 和condition的认知?2.使用condition实现生产者-消费者1. 关于ReentrantLock 和condition的认知? /*Q: ReentrantLock是如何实现管理锁和线程的?A: ReentrantLock是并发包中 一个类,它实现了Lock接口,提供了比内置synch…...

对WWDC 2025 Keynote 内容的预测
借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...
Spring Boot面试题精选汇总
🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...

Docker 本地安装 mysql 数据库
Docker: Accelerated Container Application Development 下载对应操作系统版本的 docker ;并安装。 基础操作不再赘述。 打开 macOS 终端,开始 docker 安装mysql之旅 第一步 docker search mysql 》〉docker search mysql NAME DE…...

图解JavaScript原型:原型链及其分析 | JavaScript图解
忽略该图的细节(如内存地址值没有用二进制) 以下是对该图进一步的理解和总结 1. JS 对象概念的辨析 对象是什么:保存在堆中一块区域,同时在栈中有一块区域保存其在堆中的地址(也就是我们通常说的该变量指向谁&…...
数据库正常,但后端收不到数据原因及解决
从代码和日志来看,后端SQL查询确实返回了数据,但最终user对象却为null。这表明查询结果没有正确映射到User对象上。 在前后端分离,并且ai辅助开发的时候,很容易出现前后端变量名不一致情况,还不报错,只是单…...
【题解-洛谷】P10480 可达性统计
题目:P10480 可达性统计 题目描述 给定一张 N N N 个点 M M M 条边的有向无环图,分别统计从每个点出发能够到达的点的数量。 输入格式 第一行两个整数 N , M N,M N,M,接下来 M M M 行每行两个整数 x , y x,y x,y,表示从 …...
基于Java项目的Karate API测试
Karate 实现了可以只编写Feature 文件进行测试,但是对于熟悉Java语言的开发或是测试人员,可以通过编程方式集成 Karate 丰富的自动化和数据断言功能。 本篇快速介绍在Java Maven项目中编写和运行测试的示例。 创建Maven项目 最简单的创建项目的方式就是创建一个目录,里面…...

[C++错误经验]case语句跳过变量初始化
标题:[C错误经验]case语句跳过变量初始化 水墨不写bug 文章目录 一、错误信息复现二、错误分析三、解决方法 一、错误信息复现 write.cc:80:14: error: jump to case label80 | case 2:| ^ write.cc:76:20: note: crosses initialization…...
触发DMA传输错误中断问题排查
在STM32项目中,集成BLE模块后触发DMA传输错误中断(DMA2_Stream1_IRQHandler进入错误流程),但单独运行BLE模块时正常,表明问题可能源于原有线程与BLE模块的交互冲突。以下是逐步排查与解决方案: 一、问题根源…...

python3GUI--基于PyQt5+DeepSort+YOLOv8智能人员入侵检测系统(详细图文介绍)
文章目录 一.前言二.技术介绍1.PyQt52.DeepSort3.卡尔曼滤波4.YOLOv85.SQLite36.多线程7.入侵人员检测8.ROI区域 三.核心功能1.登录注册1.登录2.注册 2.主界面1.主界面简介2.数据输入3.参数配置4.告警配置5.操作控制台6.核心内容显示区域7.检…...