当前位置: 首页 > article >正文

pytorch中dataloader自定义数据集

前言

在深度学习中我们需要使用自己的数据集做训练,因此需要将自定义的数据和标签加载到pytorch里面的dataloader里,也就是自实现一个dataloader。

数据集处理

以花卉识别项目为例,我们分别做出图片的训练集和测试集,训练集的标签和测试集的标签

flower_data/
├── train_filelist/
│   ├── image_0001.jpg
│   └── ...
├── val_filelist/
│   ├── image_1001.jpg
│   └── ...
├── train.txt  # 格式:文件名 标签
└── val.txt

 数据目录的组织方式如上所示。

首先看图片的处理。图片只要做好编号放在同一个文件夹里就好了。

再看标签的处理。标签处理我们自己规定了一种形式,就是图像文件的名称+空格+分类标签。

可以看到前面第一列数据是图像名称,第二列数据是图像的分组,同样的数字为一组。比如分组为0的图像就是同一种花朵。

自定义dataset

源码

import os.path
import numpy as np
import torch
from PIL import Image  # 从PIL库导入Image类
from torch.utils.data import Datasetclass FlowerDataSet(Dataset):"""花朵分类任务数据集类,继承自torch的Dataset类"""def __init__(self, root_dir, ann_file, transform=None):"""初始化数据集实例Args:root_dir (str): 数据集根目录路径ann_file (str): 标注文件路径transform (callable, optional): 数据预处理变换函数"""self.ann_file = ann_fileself.root_dir = root_dir# 加载图片路径与标签的映射字典 {文件名: 标签}self.image_label = self.load_annotations()# 构建完整图片路径列表 [root_dir/文件名1, ...]self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]# 构建标签列表 [标签1, 标签2, ...]self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名为lbl避免与导入的label冲突self.transform = transformdef __len__(self):"""返回数据集样本数量"""return len(self.image)def __getitem__(self, index):"""获取单个样本数据Args:index (int): 样本索引Returns:tuple: (预处理后的图像数据, 对应的标签)"""# 打开图片文件image = Image.open(self.image[index])# 获取对应标签label = self.label[index]# 应用数据预处理if self.transform:image = self.transform(image)# 将标签转换为torch张量label = torch.from_numpy(np.array(label))return image, labeldef load_annotations(self):"""加载标注文件,解析图片文件名和标签的映射关系Returns:dict: {图片文件名: 对应标签} 的字典"""data_infos = {}with open(self.ann_file) as f:# 读取所有行并分割,每行格式应为 "文件名 标签"samples = [x.strip().split(' ') for x in f.readlines()]for filename, label in samples:# 将标签转换为int64类型的numpy数组data_infos[filename] = np.array(label, dtype=np.int64)return data_infos

解析

1、将标签数据进行读取,组成一个哈希表,哈希表的键是图像的文件名称,哈希表的值是分组标签。

    def load_annotations(self):"""加载标注文件,解析图片文件名和标签的映射关系Returns:dict: {图片文件名: 对应标签} 的字典"""data_infos = {}with open(self.ann_file) as f:# 读取所有行并分割,每行格式应为 "文件名 标签"samples = [x.strip().split(' ') for x in f.readlines()]for filename, label in samples:# 将标签转换为int64类型的numpy数组data_infos[filename] = np.array(label, dtype=np.int64)return data_infos

上面的代码里,在录入标签的时候使用数组进行记录,这是为了兼容多标签的场景。如果不考虑兼容问题,仅考虑在单标签场景下的简单实现,可以用下面的代码:

def load_annotations(self):data_infos = {}with open(self.ann_file) as f:for line in f:filename, label = line.strip().split()  # 直接解包data_infos[filename] = int(label)        # 存为 Python 整数return data_infos# 在 __getitem__ 中直接转为张量
label = torch.tensor(self.labels[index], dtype=torch.long)

2、遍历哈希表,将文件名和标签分别存在两个数组里。这里注意,为了方便后面dataloader按照batch去读取图片,这里要将图片的全路径加到文件名里。

        # 构建完整图片路径列表 [root_dir/文件名1, ...]self.image = [os.path.join(self.root_dir, img) for img in list(self.image_label.keys())]# 构建标签列表 [标签1, 标签2, ...]self.label = [lbl for lbl in list(self.image_label.values())]  # 重命名为lbl避免与导入的label冲突

3、在dataloader向显卡/cpu加载数据的时候会调用getitem方法。比如一个batch里有64个数据,dataloader就会调用64次该方法,将64组图片和标签全部获取后交给运算单元去处理。

    def __getitem__(self, index):"""获取单个样本数据Args:index (int): 样本索引Returns:tuple: (预处理后的图像数据, 对应的标签)"""# 打开图片文件image = Image.open(self.image[index])# 获取对应标签label = self.label[index]# 应用数据预处理if self.transform:image = self.transform(image)# 将标签转换为torch张量label = torch.from_numpy(np.array(label))return image, label

测试dataloader

import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from dataloader import FlowerDataSet  # 假设你的数据集类在dataloader.py中def denormalize(image_tensor):"""将归一化的图像张量转换为可显示的格式"""mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])image = image_tensor.numpy().transpose((1, 2, 0))  # 转换维度顺序image = std * image + mean  # 反归一化image = np.clip(image, 0, 1)  # 限制像素值范围return imagedef test_dataloader():# 定义数据预处理data_transforms = {'train': transforms.Compose([transforms.Resize(64),transforms.RandomRotation(45),transforms.CenterCrop(64),transforms.RandomHorizontalFlip(p=0.5),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'valid': transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 检查文件路径是否存在print("[1/5] 检查文件路径...")required_files = {'train_txt': './flower_data/train.txt','val_txt': './flower_data/val.txt','train_dir': './flower_data/train_filelist','val_dir': './flower_data/val_filelist'}for name, path in required_files.items():if not os.path.exists(path):print(f"❌ 文件/目录不存在: {path}")returnprint(f"✅ {name}: {path} 存在")# 初始化数据集print("\n[2/5] 加载数据集...")try:train_dataset = FlowerDataSet(root_dir=required_files['train_dir'],ann_file=required_files['train_txt'],transform=data_transforms['train'])val_dataset = FlowerDataSet(root_dir=required_files['val_dir'],ann_file=required_files['val_txt'],transform=data_transforms['valid'])print("✅ 数据集加载成功")except Exception as e:print(f"❌ 数据集加载失败: {str(e)}")return# 打印数据集信息print("\n[3/5] 数据集统计:")print(f"训练集样本数: {len(train_dataset)}")print(f"验证集样本数: {len(val_dataset)}")# 检查单个样本print("\n[4/5] 检查单个样本:")sample_idx = 0try:img, label = train_dataset[sample_idx]print(f"图像张量形状: {img.shape} (应接近 torch.Size([3, 64, 64]))")print(f"标签类型: {type(label)} (应为 torch.Tensor)")print(f"标签值: {label.item()} (应为整数)")except Exception as e:print(f"❌ 样本检查失败: {str(e)}")# 可视化样本print("\n[5/5] 可视化训练集样本...")try:plt.figure(figsize=(8, 8))img_show = denormalize(img)plt.imshow(img_show)plt.title(f"Label: {label.item()}")plt.axis('off')plt.show()except Exception as e:print(f"❌ 可视化失败: {str(e)}")# 检查DataLoaderprint("\n[附加] 检查DataLoader:")train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)for loader, name in [(train_loader, '训练集'), (val_loader, '验证集')]:print(f"\n{name} DataLoader测试:")try:batch = next(iter(loader))images, labels = batchprint(f"批次图像形状: {images.shape} (应接近 [batch, 3, 64, 64])")print(f"批次标签示例: {labels[:5].numpy()}")print(f"像素值范围: [{images.min():.3f}, {images.max():.3f}]")except Exception as e:print(f"❌ {name} DataLoader错误: {str(e)}")if __name__ == '__main__':test_dataloader()

在测试代码中,分别测试了文件路径,dataset是否正常创建,dataset样本数量,dataset样本格式,dataset数据可视化,dataloader数据样式。

在打印日志的时候需要注意,dataset和dataloader里面的变量都是张量形式的,所以需要转换成python标量再打印。比如从dataset里取出的标签label是一个一维张量,需要通过label.item()进行转换。

 在遍历的时候为了简化代码,将两个dataloader放在同一个循环语句中处理,并且通过增加name变量来区分两个dataloader。

for loader, name in [(train_loader, '训练集'), (val_loader, '验证集')]:

相关文章:

pytorch中dataloader自定义数据集

前言 在深度学习中我们需要使用自己的数据集做训练,因此需要将自定义的数据和标签加载到pytorch里面的dataloader里,也就是自实现一个dataloader。 数据集处理 以花卉识别项目为例,我们分别做出图片的训练集和测试集,训练集的标…...

SQL Server:触发器

在 SQL Server Management Studio (SSMS) 中查看数据库触发器的方法如下: 方法一:通过对象资源管理器 连接到 SQL Server 打开 SSMS,连接到目标数据库所在的服务器。 定位到数据库 在左侧的 对象资源管理器 中,展开目标数据库&a…...

标题:利用 Rork 打造定制旅游计划应用程序:一步到位的指南

引言: 在数字化时代,旅游计划应用程序已经成为旅行者不可或缺的工具。但开发一个定制的旅游应用可能需要耗费大量时间与精力。好消息是,Rork 提供了一种快捷且智能的解决方案,让你能轻松实现创意。以下是使用 Rork 创建一个定制旅…...

WebSocket原理详解(二)

WebSocket原理详解(一)-CSDN博客 目录 1.WebSocket协议的帧数据详解 1.1.帧结构 1.2.生成数据帧 2.WebSocket协议控制帧结构详解 2.1.关闭帧 2.2.ping帧 2.3.pong帧 3.WebSocket心跳机制 1.WebSocket协议的帧数据详解 1.1.帧结构 WebSocket客户端与服务器通信的最小单…...

计算声音信号波形的谐波

计算声音信号波形的谐波 1、效果 2、定义 在振动分析中,谐波通常指的是信号中频率是基频整数倍的成分。基频是振动的主要频率,而谐波可能由机械系统中的非线性因素引起。 3、流程 1. 信号生成:生成或加载振动信号数据(模拟或实际数据)。 2. 预处理:预处理数据,如去噪…...

RepoReporter 仿照`TortoiseSVN`项目监视器,能够同时支持SVN和Git仓库

RepoReporter 项目地址 RepoReporter 一个仓库监视器,仿照TortoiseSVN项目监视器,能够同时支持SVN和Git仓库。 工作和学习会用到很多的仓库,每天都要花费大量的时间在频繁切换文件夹来查看日志上。 Git 的 GUI 工具琳琅满目,Git…...

C++多线程的性能优化

高效线程池设计与工作窃取算法实现解析 1. 引言 现代多核处理器环境下,线程池技术是提高程序并发性能的重要手段。本文解析一个采用工作窃取(Work Stealing)算法的高效线程池实现,通过详细代码分析和性能测试展示其优势。 2. 线程池核心设计 2.1 类结…...

【TS学习】(19)TS中的类

在 TypeScript 中,类(Class) 是面向对象编程的核心结构,用于封装数据和行为。TypeScript 的类继承了 JavaScript 的类特性,并增加了类型系统和高级功能的支持(如访问修饰符、存取器和装饰器)。 …...

UI设计系统:如何构建一套高效的设计规范?

UI设计系统:如何构建一套高效的设计规范? 1. 色彩系统的建立与应用 色彩系统是设计系统的基础之一,它不仅影响界面的整体美感,还对用户体验有着深远的影响。首先,设计师需要定义主色调、辅助色和强调色,并…...

深度学习--softmax回归

回归可以用于预测多少的问题,预测房屋出售价格,棒球队可能获胜的的常数或者患者住院的天数。 事实上,我们也对分类问题感兴趣,不是问 多少,而是问哪一个 1 某个电子邮件是否属于垃圾邮件 2 某个用户可能注册还是不注册…...

【计算机网络】记录一次校园网无法上网的解决方法

问题现象 环境:实训室教室内时间:近期突然出现 (推测是学校在施工,部分设备可能出现问题)症状: 连接校园网 SWXY-WIFI 后: 连接速度极慢偶发无 IP 分配(DHCP 失败)即使分…...

Java关于抽象类和抽象方法

引入抽象: 在之前把不同类中的共有成员变量和成员方法提取到父类中叫做继承。然后对于成员方法在不同子类中有不同的内容,对这些方法重新书写叫做重写;不过如果有的子类没有用继承的方法,用别的名字对这个方法命名的话&#xff0…...

第二十一章:Python-Plotly库实现数据动态可视化

Plotly是一个强大的Python可视化库,支持创建高质量的静态、动态和交互式图表。它特别擅长于绘制三维图形,能够直观地展示复杂的数据关系。本文将介绍如何使用Plotly库实现函数的二维和三维可视化,并提供一些优美的三维函数示例。资源绑定附上…...

LeetCode 热题 100_打家劫舍(83_198_中等_C++)(动态规划)

LeetCode 热题 100_打家劫舍(83_198) 题目描述:输入输出样例:题解:解题思路:思路一(动态规划(一维dp数组)):思路二(动态规划&#xff…...

C语言复习--assert断言

assert.h 头⽂件定义了宏 assert() ,⽤于在运⾏时确保程序符合指定条件,如果不符合,就报错终止运行。这个宏常常被称为“断⾔”。 assert(p ! NULL); 代码在程序运⾏到这⼀⾏语句时,验证变量 p 是否等于 NULL 。如果确实不等于 NU…...

嵌入式软件设计规范框架(MISRA-C 2012增强版)

以下是一份基于MISRA-C的嵌入式软件设计规范(完整技术文档框架),包含编码规范、安全设计原则和工程实践要求: 嵌入式软件设计规范(MISRA-C 2012增强版) 一、编码基础规范 1.1 文件组织 头文件保护 /* 示…...

系统思考反馈

最近交付的都是一些持续性的项目,越来越感觉到,系统思考和第五项修炼不只是简单的一门课程,它们能真正融入到我们的日常工作和业务中,帮助我们用更清晰的思维方式解决复杂问题,推动团队协作,激发创新。 特…...

【C++】vector常用方法总结

📝前言: 在C中string常用方法总结中我们讲述了string的常见用法,vector中许多接口与string类似,作者水平有限,所以这篇文章我们主要通过读vector官方文档的方式来学习vector中一些较为常见的重要用法。 🎬个…...

Burpsuite 伪造 IP

可以用于绕过 IP 封禁检测,用来暴力、绕过配额限制。 也可以用来做 ff98sha 出的校赛题,要求用 129 个 /8 网段的 IP 地址访问同一个 domain 插件 - IPRotate 原理:利用云服务商的反向代理服务。把反向代理的域名指向到目标 ip 即可。 http…...

12.小节

1.认识 QLabel 类,能够在界面上显示字符串. 通过 setText 来设置的.参数 QString (Qt 中把 C 里的很多容器类, 进行了重新封装.历史原因) c叫法容器类,java叫法集合类 2.内存泄露,文件资源泄露 3.对象树。Qt 中通过对象树,来统一的释放界面的…...

大模型专题10 —LangGraph高级教程:构建支持网页搜索+人工干预的可追溯对话系统

在本教程中,我们将使用 LangGraph 构建一个支持聊天机器人,该机器人能够: ✅ 通过搜索网络回答常见问题 ✅ 在多次调用之间保持对话状态 ✅ 将复杂查询路由给人工进行审核 ✅ 使用自定义状态来控制其行为 ✅ 进行回溯并探索替代的对话路径 我们将从一个基础的聊天机器人开…...

2025年数智化电商产业带发展研究报告260+份汇总解读|附PDF下载

原文链接:https://tecdat.cn/?p41286 在数字技术与实体经济深度融合的当下,数智化产业带正成为经济发展的关键引擎。 从云南鲜花产业带的直播热销到深圳3C数码的智能转型,数智化正重塑产业格局。2023年数字经济规模突破53.9万亿元&#xff…...

Linux中常用服务器监测命令(性能测试监控服务器实用指令)

1.查看进程 ps -ef|grep 进程名以下指令需要先安装:sysstat,安装指令: yum install sysstat2.查看CPU使用情况(间隔1s打印一个,打印6次) sar -u 1 63.#查看内存使用(间隔1s打印一个,打印6次) sar -r 1 6...

SQL:CASE WHEN使用详解

文章目录 1. 数据转换与映射2. 动态条件筛选3. 多条件分组统计4. 数据排名与分级5. 处理空值与默认值6. 动态排序 CASE WHEN 语句在 SQL 中是一个非常强大且灵活的工具,除了常规的条件判断外,还有很多巧妙的用法,以下为你详细总结&#xff1a…...

linux内核`fixmap`和`memblock`有什么不同?

Linux内核中的fixmap和memblock是两个不同层次的内存管理机制,分别用于不同的场景和阶段。以下是它们的核心区别和联系: 功能与作用 memblock 物理内存管理: memblock是内核启动早期的物理内存分配器,在伙伴系统(Budd…...

基于 GEE 的区域降水数据可视化:从数据处理到等值线绘制

目录 1 引言 2 代码功能概述 3 代码详细解析 3.1 几何对象处理与地图显示 3.2 加载 CHIRPS 降水数据 3.3 筛选不同时间段的降水数据 3.4 绘制降水时间序列图 3.5 计算并可视化短期和长期降水总量 3.6 绘制降水等值线图 4 总结 5 完整代码 6 运行结果 1 引言 在气象…...

曲线拟合 | Matlab基于贝叶斯多项式的曲线拟合

效果一览 代码功能 代码功能简述 目标:实现贝叶斯多项式曲线拟合,动态展示随着数据点逐步增加,模型后验分布的更新过程。 核心步骤: 数据生成:在区间[0,1]生成带噪声的正弦曲线作为训练数据。 参数设置&#xff1a…...

Qt6调试项目找不到Bluetooth Component蓝牙组件

错误如图所示 Failed to find required Qt component "Bluetooth" 解决方法:搜索打开Qt maintenance tool 工具 打开后,找到这个Qt Connectivity,勾选上就能解决该错误...

JAVA- 锁机制介绍 进程锁

进程锁 基于文件的锁基于Socket的锁数据库锁分布式锁基于Redis的分布式锁基于ZooKeeper的分布式锁 实际工作中都是集群部署,通过负载均衡多台服务器工作,所以存在多个进程并发执行情况,而在每台服务器中又存在多个线程并发的情况,…...

Java Spring Boot 与前端结合打造图书管理系统:技术剖析与实现

目录 运行展示引言系统整体架构后端技术实现后端代码文件前端代码文件1. 项目启动与配置2. 实体类设计3. 控制器设计4. 异常处理 前端技术实现1. 页面布局与样式2. 交互逻辑 系统功能亮点1. 分页功能2. 搜索与筛选功能3. 图书操作功能 总结 运行展示 引言 本文将详细剖析一个基…...