【深度学习】Pytorch:加载自定义数据集
本教程将使用
flower_photos数据集演示如何在 PyTorch 中加载和导入自定义数据集。该数据集包含不同花种的图像,每种花的图像存储在以花名命名的子文件夹中。我们将深入讲解每个函数和对象的使用方法,使读者能够推广应用到其他数据集任务中。
flower_photos/
├── daisy/
│ ├── image1.jpg
│ ├── image2.jpg
└── rose/├── image1.jpg├── image2.jpg
...
环境配置
所需工具和库
pip install torch torchvision matplotlib
导入必要的库
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
import pathlib
数据集导入方法
定义数据转换
图像转换在计算机视觉任务中至关重要。通过 transforms 对象,我们可以实现图像大小调整、归一化、随机变换等预处理操作。
# 定义图像转换
transform = transforms.Compose([ transforms.Resize((150, 150)), # 调整图像大小为 150x150 transforms.ToTensor(), # 将图像转换为 PyTorch 张量 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化图像数据
]) # 数据路径
data_dir = r"E:\CodeSpace\Deep\data\flower_photos" # 使用 ImageFolder 加载数据
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform) # 计算训练集和测试集的样本数量(80%和20%的划分)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size # 随机划分数据集
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size]) # 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 获取类别名
class_names = full_dataset.classes
print("类别名:", class_names)
显示部分样本图像
可视化样本数据有助于理解数据集结构和数据质量。
# 定义函数以绘制样本图像
def plot_images(images, labels, class_names):plt.figure(figsize=(10, 10))for i in range(9): # 绘制前 9 张图像plt.subplot(3, 3, i + 1)img = images[i].permute(1, 2, 0) # 将张量维度从 (C, H, W) 转为 (H, W, C)plt.imshow(img * 0.5 + 0.5) # 反归一化处理,恢复到原始像素范围 [0, 1]plt.title(class_names[labels[i]]) # 显示类别标签plt.axis('off') # 去掉坐标轴# 获取部分样本数据用于展示
sample_images, sample_labels = next(iter(train_loader))
plot_images(sample_images, sample_labels, class_names)
自定义数据加载方法
当数据结构复杂或需要额外处理时,可以通过继承 torch.utils.data.Dataset 创建自定义数据加载类。
Dataset 类详解
Dataset 是 PyTorch 中的一个抽象类,用户需要实现以下核心方法:
__init__():初始化方法- 传入数据路径和转换方法。
- 加载所有图像路径并生成类别标签。
__len__():返回数据集大小- 指定数据集中样本数量。
__getitem__():根据索引获取样本数据- 加载指定位置的图像和标签,并进行必要的转换。
代码实现
class CustomFlowerDataset(torch.utils.data.Dataset):def __init__(self, data_dir, transform=None):# 初始化数据集路径和图像转换方法self.data_dir = pathlib.Path(data_dir)self.transform = transformself.image_paths = list(self.data_dir.glob('*/*.jpg')) # 获取所有图像路径self.label_names = sorted(item.name for item in self.data_dir.glob('*/') if item.is_dir())self.label_to_index = {name: idx for idx, name in enumerate(self.label_names)} # 将类别名映射为索引def __len__(self):# 返回数据集大小return len(self.image_paths)def __getitem__(self, idx):# 根据索引获取图像及其标签img_path = self.image_paths[idx]label = self.label_to_index[img_path.parent.name] # 通过父文件夹名获取标签image = Image.open(img_path).convert("RGB") # 确保图像是 RGB 模式if self.transform:image = self.transform(image) # 进行图像预处理return image, label# 使用自定义数据集
custom_dataset = CustomFlowerDataset(data_dir, transform=transform)
custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)
随机划分数据集
如果你还希望在这个自定义数据集上随机划分训练集和测试集,可以使用 torch.utils.data.random_split。以下是示例代码:
from torch.utils.data import random_split # 获取数据集长度
full_dataset = CustomFlowerDataset(data_dir, transform=transform) # 计算训练集和测试集的样本数量(80%和20%的划分)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size # 随机划分数据集
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size]) # 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) print(f"训练集大小: {len(train_dataset)}, 测试集大小: {len(test_dataset)}")
数据加载性能优化
num_workers参数:设置并行数据加载线程数。对于多核 CPU,可以显著提高数据加载效率。prefetch_factor参数:控制每个工作线程预取的批次数量。
custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True, num_workers=4, prefetch_factor=2)
Dataset 类扩展建议
- 支持多格式数据读取:通过扩展
__getitem__()来支持其他格式如 PNG、BMP。 - 数据过滤:在
__init__()中根据文件名或元数据筛选特定样本。 - 标签增强:为每个样本生成附加信息,例如图像的元数据或分布特征。
数据集的使用方法
遍历数据集
模型训练前需要遍历数据集以加载图像和标签:
for images, labels in custom_loader:# images 是图像张量,labels 是对应的类别标签print(f"图像张量大小: {images.shape}, 标签: {labels}")
模型输入
数据集加载完成后可直接用于模型训练:
import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络模型
model = nn.Sequential(nn.Flatten(), # 将输入张量展平成一维nn.Linear(150*150*3, 128), # 输入层到隐藏层的全连接层nn.ReLU(), # 激活函数nn.Linear(128, len(class_names)) # 输出层,类别数量等于花的种类数
)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam 优化器# 示例训练过程
for epoch in range(2): # 简单训练两轮for images, labels in custom_loader:outputs = model(images) # 前向传播计算输出loss = criterion(outputs, labels) # 计算损失optimizer.zero_grad() # 梯度清零loss.backward() # 反向传播计算梯度optimizer.step() # 更新模型参数print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
模型评估
加载后的数据集也可用于验证模型性能:
correct = 0
total = 0
model.eval() # 设置模型为评估模式
with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total
print(f"模型准确率: {accuracy:.2f}%")
方法对比与扩展
ImageFolder vs 自定义 Dataset
ImageFolder:适合简单目录结构,快速加载标准图像数据。- 自定义
Dataset:更适合复杂数据结构及自定义逻辑,例如多模态数据处理。
提高模型泛化能力
- 数据增强:通过
transforms.RandomHorizontalFlip()、transforms.ColorJitter()等方法增加数据多样性。 - 归一化技巧:根据数据集的特性调整
mean和std参数。
总结
本教程详细讲解了如何在 PyTorch 中加载和导入 flower_photos 数据集,结合不同方法的讲解使你能根据项目需求灵活选择适合的数据加载方案。同时,我们探讨了优化和扩展方法,希望这些内容能为你的深度学习项目提供有力支持。
相关文章:
【深度学习】Pytorch:加载自定义数据集
本教程将使用 flower_photos 数据集演示如何在 PyTorch 中加载和导入自定义数据集。该数据集包含不同花种的图像,每种花的图像存储在以花名命名的子文件夹中。我们将深入讲解每个函数和对象的使用方法,使读者能够推广应用到其他数据集任务中。 flower_ph…...
最近在盘gitlab.0.先review了一下docker
# 正文 本猿所在产品的代码是保存到了一个本地gitlab实例上,实例是别的同事搭建的。最近又又又想了解一下,而且已经盘了一些了,所以写写记录一下。因为这个事儿没太多的进度压力,索性写到哪儿算哪儿,只要是新了解到的…...
OA项目登录
导入依赖,下面的依赖是在这次OA登录中用到的 <!--web依赖--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.sprin…...
verilogHDL仿真详解
前言 Verilog HDL中提供了丰富的系统任务和系统函数,用于对仿真环境、文件操作、时间控制等进行操作。(后续会进行补充) 正文 一、verilogHDL仿真详解 timescale 1ns/1ps //时间单位为1ns,精度为1ps, //编译…...
基于http协议的天气爬虫
该系统将基于目前比较流行的网络爬虫技术, 对网站上的天气数据进行查询分析, 最终使客户能够通过简单的操作, 快速, 准确的获取目标天气数据。主要包括两部分的功能, 第一部分是天气数据查询, 包括时间段数…...
_STM32关于CPU超频的参考_HAL
MCU: STM32F407VET6 官方最高稳定频率:168MHz 工具:STM32CubeMX 本篇仅仅只是提供超频(默认指的是主频)的简单方法,并未涉及STM32超频极限等问题。原理很简单,通过设置锁相环的倍频系数达到不同的频率&am…...
C#,图论与图算法,任意一对节点之间最短距离的弗洛伊德·沃肖尔(Floyd Warshall)算法与源程序
一、弗洛伊德沃肖尔算法 Floyd-Warshall算法是图的最短路径算法。与Bellman-Ford算法或Dijkstra算法一样,它计算图中的最短路径。然而,Bellman Ford和Dijkstra都是单源最短路径算法。这意味着他们只计算来自单个源的最短路径。另一方面,Floy…...
AWS云计算概览(自用留存,整理中)
目录 一、云概念概览 (1)云计算简介 (2)云计算6大优势 (3)web服务 (4)AWS云采用框架(AWS CAF) 二、云经济学 & 账单 (1)定…...
1. npm 常用命令详解
npm 常用命令详解 npm(Node Package Manager)是 Node.js 的包管理工具,用于安装和管理 Node.js 应用中的依赖库。下面是 npm 的一些常用命令及其详细解释和示例代码。 镜像源 # 查询当前使用的镜像源 npm get registry# 设置为淘宝镜像源 …...
js:根据后端返回数据的最大值进行计算然后设置这个最大值为百分之百,其他的值除这个最大值
问: 现在tabData.value 接收到了后端返回的数据, [{text:人力,percentage:‘90’},{text:物品,percentage:‘20’},{text:物理,percentage:‘50’},{text:服务,percentageÿ…...
【Spring】@Size 无法拦截null的原因
问题复现 在构建 Web 服务时,我们一般都会对一个 HTTP 请求的 Body 内容进行校验,例如我们来看这样一个案例及对应代码。当开发一个学籍管理系统时,我们会提供了一个 API 接口去添加学生的相关信息,其对象定义参考下面的代码&…...
【Block总结】掩码窗口自注意力 (M-WSA)
摘要 论文链接:https://arxiv.org/pdf/2404.07846 论文标题:Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制,旨在解决传统自注意力方法在…...
用 HTML5 Canvas 和 JavaScript 实现雪花飘落特效
这篇文章将带您深入解析使用 HTML5 Canvas 和 JavaScript 实现动态雪花特效的代码原理。 1,效果展示 该效果模拟了雪花从天而降的动态场景,具有以下特点: 雪花数量、大小、透明度和下落速度随机。雪花会在屏幕底部重置到顶部,形成循环效果。随窗口大小动态调整,始终覆盖…...
【cocos creator】【ts】事件派发系统
触发使用: EventTool.emit(“onClick”) 需要监听的地方,onload调用: EventTool.on(“onClick”, this.onClickEvent, this) /**事件派发*/class EventTool {protected static _instance: EventTool null;public static get Instance(): Eve…...
《探索鸿蒙Next上开发人工智能游戏应用的技术难点》
在科技飞速发展的当下,鸿蒙Next系统为应用开发带来了新的机遇与挑战,开发一款运行在鸿蒙Next上的人工智能游戏应用更是备受关注。以下是在开发过程中可能会遇到的一些技术难点: 鸿蒙Next系统适配性 多设备协同:鸿蒙Next的一大特色…...
CSS | CSS实现两栏布局(左边定宽 右边自适应,左右成比自适应)
目录 一、左边定宽 右边自适应 1.浮动 2.利用浮动margin 3.定位margin 4.flex布局 5.table 布局 二、左右成比自适应 1:1 1flex布局 table布局 1:2 flex布局 <div class"father"><div class"left">左边自适应</div><div class"r…...
acwing_3195_有趣的数
acwing_3195_有趣的数 // // Created by HUAWEI on 2024/11/17. // #include<iostream> #include<cstring> #include<algorithm>#define int long longusing namespace std;const int N 1000 50; const int MOD 1e9 7; int C[N][N]; //组合数signed mai…...
Liunx-搭建安装VSOMEIP环境教程 执行 运行VSOMEIP示例demo
本文安装环境为Liunx,搭建安装VSOMEIP环境并运行基础例子。 1. 安装基础环境 使用apt-get来安装基础环境,受网络影响可以分开多次安装。环境好的也可以一次性执行。 sudo apt-get install gcc g sudo apt-get install cmake sudo apt-get install lib…...
Git | git revert命令详解
关注:CodingTechWork 引言 Git 是一个强大的版本控制工具,广泛应用于现代软件开发中。它为开发人员提供了多种功能来管理代码、协作开发和版本控制。在 Git 中,有时我们需要撤销或回退某些提交,而git revert 是一个非常有用的命令…...
ASP.NET Core 中,Cookie 认证在集群环境下的应用
在 ASP.NET Core 中,Cookie 认证在集群环境下的应用通常会遇到一些挑战。主要的问题是 Cookie 存储在客户端的浏览器中,而认证信息(比如 Session 或身份令牌)通常是保存在 Cookie 中,多个应用实例需要共享这些 Cookie …...
艾体宝洞察|语义搜索与关键词搜索?业务的抉择
包括我在内,不少人第一次做搜索功能时,都会觉得这是一件没什么技术含量的事:用户输入几个词,系统返回结果,不就行了吗? 但只要你真正做过搜索系统,尤其是参与过 RAG(Retrieval-Augme…...
芯片研发为什么总是延期?问题不在技术,在管理没闭环
一个芯片项目失败,事后复盘,技术问题往往只占一小部分。更多的时候,是计划没做好,执行过程没人盯,出了问题没人协调,最后交付的时候才发现跑偏了很久。这是行业里非常普遍的现象。法约尔在一百年前提出管理…...
LeetCode 最长回文子串:python 题解
一、核心问题及解决方案(按踩坑频率排序) 问题 1:误删他人持有锁——最基础也最易犯的漏洞 成因:释放锁时未做身份校验,直接执行 DEL 命令删除键。典型场景:服务 A 持有锁后,业务逻辑耗时超过锁…...
【RT-DETR涨点改进】TGRS 2026 | 全网独家创新、特征融合改进篇| 引入STSAM协同时空注意力融合模块,发论文热点创新,注意力能够互相引导强化边界和结构细节,增强目标检测高效涨点
一、本文介绍 🔥本文给大家介绍使用 STSAM协同时空注意力融合模块 改进RT-DETR网络模型,STSAM 是 空间域特征增强模块,通过全局跨时相注意力和局部坐标注意力的并行处理,能有效聚焦真实变化目标,强化边界和结构细节,同时兼顾训练稳定性,为后续浅层特征融合提供高质量特…...
DRM显示框架中的“导演”:深入理解CRTC如何协同Plane与Connector工作
DRM显示框架中的“导演”:深入理解CRTC如何协同Plane与Connector工作 想象一下,当你在电影院观看一部大片时,银幕上的每一帧画面都经过精心编排——主角的位置、特效的时机、放映机的同步,所有这些元素都需要一个核心指挥者来协调…...
2026 年电子邮件认证部署缺陷与安全风险治理研究
摘要 电子邮件作为网络攻击最主要入口,域名伪造与商业邮件欺诈(BEC)持续威胁机构安全。SPF、DKIM、DMARC 作为抵御邮件伪造的核心协议已提出十余年,但大量组织仍存在认知不足、配置错误、长期停留在监控模式等问题,导致…...
OpenClaw技能组合:Qwen2.5-VL-7B串联多个自动化任务流
OpenClaw技能组合:Qwen2.5-VL-7B串联多个自动化任务流 1. 为什么需要任务流串联 上周我需要完成一个市场竞品分析的周报,整个过程让我意识到手动操作的效率瓶颈。首先要在电商平台截图商品页面,然后用OCR工具提取价格信息,接着把…...
ChatGPT_JCM深色模式实现:保护眼睛的界面显示方案
ChatGPT_JCM深色模式实现:保护眼睛的界面显示方案 【免费下载链接】ChatGPT_JCM 项目地址: https://gitcode.com/gh_mirrors/ch/ChatGPT_JCM ChatGPT_JCM是一款功能强大的AI交互工具,其深色模式实现为用户提供了舒适的夜间使用体验,有…...
SEO_ 揭秘影响搜索引擎排名的核心SEO因素
SEO的核心因素解析:提升搜索引擎排名的关键路径 在当今数字化时代,搜索引擎优化(SEO)已经成为每个网站和企业获取有效流量的重要途径。究竟有哪些核心因素影响搜索引擎的排名呢?本文将深入探讨这些核心SEO因素&#x…...
保姆级教程:用ArduPilot给无人车/船配置避障(附MR72雷达、TFmini Plus参数)
保姆级教程:用ArduPilot为无人车/船配置毫米波与激光雷达避障系统 当你的无人车在野外自动巡航时突然检测到前方障碍物,是紧急刹车还是智能绕行?水面无人船在夜间航行如何避开漂浮物?本文将手把手带你完成从硬件选型到参数调优的全…...
