当前位置: 首页 > news >正文

【深度学习】Pytorch:加载自定义数据集

本教程将使用 flower_photos 数据集演示如何在 PyTorch 中加载和导入自定义数据集。该数据集包含不同花种的图像,每种花的图像存储在以花名命名的子文件夹中。我们将深入讲解每个函数和对象的使用方法,使读者能够推广应用到其他数据集任务中。

flower_photos/
├── daisy/
│   ├── image1.jpg
│   ├── image2.jpg
└── rose/├── image1.jpg├── image2.jpg
...

环境配置

所需工具和库

pip install torch torchvision matplotlib

导入必要的库

import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
import pathlib

数据集导入方法

定义数据转换

图像转换在计算机视觉任务中至关重要。通过 transforms 对象,我们可以实现图像大小调整、归一化、随机变换等预处理操作。

# 定义图像转换  
transform = transforms.Compose([  transforms.Resize((150, 150)),  # 调整图像大小为 150x150  transforms.ToTensor(),  # 将图像转换为 PyTorch 张量  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化图像数据  
])  # 数据路径  
data_dir = r"E:\CodeSpace\Deep\data\flower_photos"  # 使用 ImageFolder 加载数据  
full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)  # 计算训练集和测试集的样本数量(80%和20%的划分)  
train_size = int(0.8 * len(full_dataset))  
test_size = len(full_dataset) - train_size  # 随机划分数据集  
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])  # 创建数据加载器  
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  # 获取类别名  
class_names = full_dataset.classes  
print("类别名:", class_names)

显示部分样本图像

可视化样本数据有助于理解数据集结构和数据质量。

# 定义函数以绘制样本图像
def plot_images(images, labels, class_names):plt.figure(figsize=(10, 10))for i in range(9):  # 绘制前 9 张图像plt.subplot(3, 3, i + 1)img = images[i].permute(1, 2, 0)  # 将张量维度从 (C, H, W) 转为 (H, W, C)plt.imshow(img * 0.5 + 0.5)  # 反归一化处理,恢复到原始像素范围 [0, 1]plt.title(class_names[labels[i]])  # 显示类别标签plt.axis('off')  # 去掉坐标轴# 获取部分样本数据用于展示
sample_images, sample_labels = next(iter(train_loader))
plot_images(sample_images, sample_labels, class_names)

自定义数据加载方法

当数据结构复杂或需要额外处理时,可以通过继承 torch.utils.data.Dataset 创建自定义数据加载类。

Dataset 类详解

Dataset 是 PyTorch 中的一个抽象类,用户需要实现以下核心方法:

  1. __init__():初始化方法
    • 传入数据路径和转换方法。
    • 加载所有图像路径并生成类别标签。
  2. __len__():返回数据集大小
    • 指定数据集中样本数量。
  3. __getitem__():根据索引获取样本数据
    • 加载指定位置的图像和标签,并进行必要的转换。

代码实现

class CustomFlowerDataset(torch.utils.data.Dataset):def __init__(self, data_dir, transform=None):# 初始化数据集路径和图像转换方法self.data_dir = pathlib.Path(data_dir)self.transform = transformself.image_paths = list(self.data_dir.glob('*/*.jpg'))  # 获取所有图像路径self.label_names = sorted(item.name for item in self.data_dir.glob('*/') if item.is_dir())self.label_to_index = {name: idx for idx, name in enumerate(self.label_names)}  # 将类别名映射为索引def __len__(self):# 返回数据集大小return len(self.image_paths)def __getitem__(self, idx):# 根据索引获取图像及其标签img_path = self.image_paths[idx]label = self.label_to_index[img_path.parent.name]  # 通过父文件夹名获取标签image = Image.open(img_path).convert("RGB")  # 确保图像是 RGB 模式if self.transform:image = self.transform(image)  # 进行图像预处理return image, label# 使用自定义数据集
custom_dataset = CustomFlowerDataset(data_dir, transform=transform)
custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)

随机划分数据集

如果你还希望在这个自定义数据集上随机划分训练集和测试集,可以使用 torch.utils.data.random_split。以下是示例代码:

from torch.utils.data import random_split  # 获取数据集长度  
full_dataset = CustomFlowerDataset(data_dir, transform=transform)  # 计算训练集和测试集的样本数量(80%和20%的划分)  
train_size = int(0.8 * len(full_dataset))  
test_size = len(full_dataset) - train_size  # 随机划分数据集  
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])  # 创建数据加载器  
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  print(f"训练集大小: {len(train_dataset)}, 测试集大小: {len(test_dataset)}")  

数据加载性能优化

  • num_workers 参数:设置并行数据加载线程数。对于多核 CPU,可以显著提高数据加载效率。
  • prefetch_factor 参数:控制每个工作线程预取的批次数量。
custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True, num_workers=4, prefetch_factor=2)

Dataset 类扩展建议

  1. 支持多格式数据读取:通过扩展 __getitem__() 来支持其他格式如 PNG、BMP。
  2. 数据过滤:在 __init__() 中根据文件名或元数据筛选特定样本。
  3. 标签增强:为每个样本生成附加信息,例如图像的元数据或分布特征。

数据集的使用方法

遍历数据集

模型训练前需要遍历数据集以加载图像和标签:

for images, labels in custom_loader:# images 是图像张量,labels 是对应的类别标签print(f"图像张量大小: {images.shape}, 标签: {labels}")

模型输入

数据集加载完成后可直接用于模型训练:

import torch.nn as nn
import torch.optim as optim# 定义一个简单的神经网络模型
model = nn.Sequential(nn.Flatten(),  # 将输入张量展平成一维nn.Linear(150*150*3, 128),  # 输入层到隐藏层的全连接层nn.ReLU(),  # 激活函数nn.Linear(128, len(class_names))  # 输出层,类别数量等于花的种类数
)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失适用于多分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器# 示例训练过程
for epoch in range(2):  # 简单训练两轮for images, labels in custom_loader:outputs = model(images)  # 前向传播计算输出loss = criterion(outputs, labels)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新模型参数print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

模型评估

加载后的数据集也可用于验证模型性能:

correct = 0
total = 0
model.eval()  # 设置模型为评估模式
with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total
print(f"模型准确率: {accuracy:.2f}%")

方法对比与扩展

ImageFolder vs 自定义 Dataset

  • ImageFolder:适合简单目录结构,快速加载标准图像数据。
  • 自定义 Dataset:更适合复杂数据结构及自定义逻辑,例如多模态数据处理。

提高模型泛化能力

  • 数据增强:通过 transforms.RandomHorizontalFlip()transforms.ColorJitter() 等方法增加数据多样性。
  • 归一化技巧:根据数据集的特性调整 meanstd 参数。

总结

本教程详细讲解了如何在 PyTorch 中加载和导入 flower_photos 数据集,结合不同方法的讲解使你能根据项目需求灵活选择适合的数据加载方案。同时,我们探讨了优化和扩展方法,希望这些内容能为你的深度学习项目提供有力支持。

相关文章:

【深度学习】Pytorch:加载自定义数据集

本教程将使用 flower_photos 数据集演示如何在 PyTorch 中加载和导入自定义数据集。该数据集包含不同花种的图像,每种花的图像存储在以花名命名的子文件夹中。我们将深入讲解每个函数和对象的使用方法,使读者能够推广应用到其他数据集任务中。 flower_ph…...

最近在盘gitlab.0.先review了一下docker

# 正文 本猿所在产品的代码是保存到了一个本地gitlab实例上,实例是别的同事搭建的。最近又又又想了解一下,而且已经盘了一些了,所以写写记录一下。因为这个事儿没太多的进度压力,索性写到哪儿算哪儿,只要是新了解到的…...

OA项目登录

导入依赖,下面的依赖是在这次OA登录中用到的 <!--web依赖--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.sprin…...

verilogHDL仿真详解

前言 Verilog HDL中提供了丰富的系统任务和系统函数&#xff0c;用于对仿真环境、文件操作、时间控制等进行操作。&#xff08;后续会进行补充&#xff09; 正文 一、verilogHDL仿真详解 timescale 1ns/1ps //时间单位为1ns&#xff0c;精度为1ps&#xff0c; //编译…...

基于http协议的天气爬虫

该系统将基于目前比较流行的网络爬虫技术&#xff0c; 对网站上的天气数据进行查询分析&#xff0c; 最终使客户能够通过简单的操作&#xff0c; 快速&#xff0c; 准确的获取目标天气数据。主要包括两部分的功能&#xff0c; 第一部分是天气数据查询&#xff0c; 包括时间段数…...

_STM32关于CPU超频的参考_HAL

MCU: STM32F407VET6 官方最高稳定频率&#xff1a;168MHz 工具&#xff1a;STM32CubeMX 本篇仅仅只是提供超频&#xff08;默认指的是主频&#xff09;的简单方法&#xff0c;并未涉及STM32超频极限等问题。原理很简单&#xff0c;通过设置锁相环的倍频系数达到不同的频率&am…...

C#,图论与图算法,任意一对节点之间最短距离的弗洛伊德·沃肖尔(Floyd Warshall)算法与源程序

一、弗洛伊德沃肖尔算法 Floyd-Warshall算法是图的最短路径算法。与Bellman-Ford算法或Dijkstra算法一样&#xff0c;它计算图中的最短路径。然而&#xff0c;Bellman Ford和Dijkstra都是单源最短路径算法。这意味着他们只计算来自单个源的最短路径。另一方面&#xff0c;Floy…...

AWS云计算概览(自用留存,整理中)

目录 一、云概念概览 &#xff08;1&#xff09;云计算简介 &#xff08;2&#xff09;云计算6大优势 &#xff08;3&#xff09;web服务 &#xff08;4&#xff09;AWS云采用框架&#xff08;AWS CAF&#xff09; 二、云经济学 & 账单 &#xff08;1&#xff09;定…...

1. npm 常用命令详解

npm 常用命令详解 npm&#xff08;Node Package Manager&#xff09;是 Node.js 的包管理工具&#xff0c;用于安装和管理 Node.js 应用中的依赖库。下面是 npm 的一些常用命令及其详细解释和示例代码。 镜像源 # 查询当前使用的镜像源 npm get registry# 设置为淘宝镜像源 …...

js:根据后端返回数据的最大值进行计算然后设置这个最大值为百分之百,其他的值除这个最大值

问&#xff1a; 现在tabData.value 接收到了后端返回的数据&#xff0c; [{text:人力,percentage&#xff1a;‘90’}&#xff0c;{text:物品,percentage&#xff1a;‘20’}&#xff0c;{text:物理,percentage&#xff1a;‘50’}&#xff0c;{text:服务,percentage&#xff…...

【Spring】@Size 无法拦截null的原因

问题复现 在构建 Web 服务时&#xff0c;我们一般都会对一个 HTTP 请求的 Body 内容进行校验&#xff0c;例如我们来看这样一个案例及对应代码。当开发一个学籍管理系统时&#xff0c;我们会提供了一个 API 接口去添加学生的相关信息&#xff0c;其对象定义参考下面的代码&…...

【Block总结】掩码窗口自注意力 (M-WSA)

摘要 论文链接&#xff1a;https://arxiv.org/pdf/2404.07846 论文标题&#xff1a;Transformer-Based Blind-Spot Network for Self-Supervised Image Denoising Masked Window-Based Self-Attention (M-WSA) 是一种新颖的自注意力机制&#xff0c;旨在解决传统自注意力方法在…...

用 HTML5 Canvas 和 JavaScript 实现雪花飘落特效

这篇文章将带您深入解析使用 HTML5 Canvas 和 JavaScript 实现动态雪花特效的代码原理。 1,效果展示 该效果模拟了雪花从天而降的动态场景,具有以下特点: 雪花数量、大小、透明度和下落速度随机。雪花会在屏幕底部重置到顶部,形成循环效果。随窗口大小动态调整,始终覆盖…...

【cocos creator】【ts】事件派发系统

触发使用&#xff1a; EventTool.emit(“onClick”) 需要监听的地方&#xff0c;onload调用&#xff1a; EventTool.on(“onClick”, this.onClickEvent, this) /**事件派发*/class EventTool {protected static _instance: EventTool null;public static get Instance(): Eve…...

《探索鸿蒙Next上开发人工智能游戏应用的技术难点》

在科技飞速发展的当下&#xff0c;鸿蒙Next系统为应用开发带来了新的机遇与挑战&#xff0c;开发一款运行在鸿蒙Next上的人工智能游戏应用更是备受关注。以下是在开发过程中可能会遇到的一些技术难点&#xff1a; 鸿蒙Next系统适配性 多设备协同&#xff1a;鸿蒙Next的一大特色…...

CSS | CSS实现两栏布局(左边定宽 右边自适应,左右成比自适应)

目录 一、左边定宽 右边自适应 1.浮动 2.利用浮动margin 3.定位margin 4.flex布局 5.table 布局 二、左右成比自适应 1:1 1flex布局 table布局 1:2 flex布局 <div class"father"><div class"left">左边自适应</div><div class"r…...

acwing_3195_有趣的数

acwing_3195_有趣的数 // // Created by HUAWEI on 2024/11/17. // #include<iostream> #include<cstring> #include<algorithm>#define int long longusing namespace std;const int N 1000 50; const int MOD 1e9 7; int C[N][N]; //组合数signed mai…...

Liunx-搭建安装VSOMEIP环境教程 执行 运行VSOMEIP示例demo

本文安装环境为Liunx&#xff0c;搭建安装VSOMEIP环境并运行基础例子。 1. 安装基础环境 使用apt-get来安装基础环境&#xff0c;受网络影响可以分开多次安装。环境好的也可以一次性执行。 sudo apt-get install gcc g sudo apt-get install cmake sudo apt-get install lib…...

Git | git revert命令详解

关注&#xff1a;CodingTechWork 引言 Git 是一个强大的版本控制工具&#xff0c;广泛应用于现代软件开发中。它为开发人员提供了多种功能来管理代码、协作开发和版本控制。在 Git 中&#xff0c;有时我们需要撤销或回退某些提交&#xff0c;而git revert 是一个非常有用的命令…...

ASP.NET Core 中,Cookie 认证在集群环境下的应用

在 ASP.NET Core 中&#xff0c;Cookie 认证在集群环境下的应用通常会遇到一些挑战。主要的问题是 Cookie 存储在客户端的浏览器中&#xff0c;而认证信息&#xff08;比如 Session 或身份令牌&#xff09;通常是保存在 Cookie 中&#xff0c;多个应用实例需要共享这些 Cookie …...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;百货中心供应链管理系统被用户普遍使用&#xff0c;为方…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用&#xff1a;实现组件通用属性的渐变过渡效果&#xff0c;提升用户体验。支持属性&#xff1a;width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项&#xff1a; 布局类属性&#xff08;如宽高&#xff09;变化时&#…...

大语言模型如何处理长文本?常用文本分割技术详解

为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...

零基础设计模式——行为型模式 - 责任链模式

第四部分&#xff1a;行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习&#xff01;行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想&#xff1a;使多个对象都有机会处…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

站群服务器的应用场景都有哪些?

站群服务器主要是为了多个网站的托管和管理所设计的&#xff0c;可以通过集中管理和高效资源的分配&#xff0c;来支持多个独立的网站同时运行&#xff0c;让每一个网站都可以分配到独立的IP地址&#xff0c;避免出现IP关联的风险&#xff0c;用户还可以通过控制面板进行管理功…...

[ACTF2020 新生赛]Include 1(php://filter伪协议)

题目 做法 启动靶机&#xff0c;点进去 点进去 查看URL&#xff0c;有 ?fileflag.php说明存在文件包含&#xff0c;原理是php://filter 协议 当它与包含函数结合时&#xff0c;php://filter流会被当作php文件执行。 用php://filter加编码&#xff0c;能让PHP把文件内容…...

Rust 开发环境搭建

环境搭建 1、开发工具RustRover 或者vs code 2、Cygwin64 安装 https://cygwin.com/install.html 在工具终端执行&#xff1a; rustup toolchain install stable-x86_64-pc-windows-gnu rustup default stable-x86_64-pc-windows-gnu ​ 2、Hello World fn main() { println…...

【SpringBoot自动化部署】

SpringBoot自动化部署方法 使用Jenkins进行持续集成与部署 Jenkins是最常用的自动化部署工具之一&#xff0c;能够实现代码拉取、构建、测试和部署的全流程自动化。 配置Jenkins任务时&#xff0c;需要添加Git仓库地址和凭证&#xff0c;设置构建触发器&#xff08;如GitHub…...