当前位置: 首页 > news >正文

使用 PyTorch 自定义数据集并划分训练、验证与测试集

使用 PyTorch 自定义数据集并划分训练、验证与测试集

在图像分类等任务中,通常需要将原始训练数据进一步划分为训练集和验证集,以便在训练过程中评估模型的性能。下面将详细介绍如何组织数据与注释文件、如何分割训练集和验证集,以及如何基于自定义 Dataset 类构建 DataLoader 以加速模型训练与评估。

一、数据准备

1.1 文件结构

假设你的数据目录结构如下所示:

data/
├── train_data/
│   ├── img1.png
│   ├── img2.png
│   ├── img3.png
│   └── ...
├── test_data/
│   ├── img101.png
│   ├── img102.png
│   ├── img103.png
│   └── ...
├── train_annotations.csv
└── test_annotations.csv

注意:这里将 train_annotations.csvtest_annotations.csv 文件单独放在 data/ 目录下,而不放在各自图片的子文件夹中。这样当图片数量非常多时,我们也能快速找到并管理这两个 CSV 文件。

1.2 注释文件(CSV)格式示例

train_annotations.csvtest_annotations.csv 中,一般会包含两列或更多列信息,但最关键的通常是 图片文件名(filename)和 标签(label)。格式示例如下:

train_annotations.csv

filename,label
img1.png,0
img2.png,1
img3.png,0
...

test_annotations.csv

filename,label
img101.png,0
img102.png,1
img103.png,0
...
  • filename 列表示图像的文件名,需要与 train_data/test_data/ 文件夹下的文件一一对应。
  • label 列表示图像所对应的类别或标签,可以是整数,也可以是字符串,比如 catdog 等。训练时通常会将字符串映射到整数标签或独热编码。

二、将训练数据划分为训练集和验证集

在进行模型训练前,往往需要将原始训练数据(以下简称 “总训练集”)拆分成 训练集(train) 和 验证集(val)。这里我们使用 scikit-learn 提供的 train_test_split 函数来完成这一步骤。

import pandas as pd
from sklearn.model_selection import train_test_split# 读取原始训练集的注释文件(此时还未拆分)
train_annotations = pd.read_csv('data/train_annotations.csv')# 按 80%:20% 的比例拆分为 新的训练集(train_df) 和 验证集(val_df)
train_df, val_df = train_test_split(train_annotations, test_size=0.2, random_state=42, stratify=train_annotations['label']
)# 将拆分后的注释文件保存为新的 CSV 文件
train_df.to_csv('data/train_split.csv', index=False)
val_df.to_csv('data/val_split.csv', index=False)

关键参数说明:

  • test_size=0.2:表示将 20% 的样本作为验证集,其余 80% 作为新的训练集。
  • random_state=42:让划分结果可复现,方便后续对比不同实验结果。
  • stratify=train_annotations['label']:在划分时保持各类别在训练和验证集中相同比例,这在分类任务中尤为重要。

执行完以上步骤后,你的 data 目录下会多出两个新的注释文件:

data/
├── train_data/
│   ├── ...
├── test_data/
│   ├── ...
├── train_annotations.csv   # 原始,总训练集注释
├── train_split.csv         # 新的,训练集注释
└── val_split.csv           # 新的,验证集注释

三、自定义 Dataset

PyTorch 提供了 torch.utils.data.Dataset 作为数据集的抽象基类。我们可以通过继承并重写其中的方法,来实现灵活的数据加载逻辑。

下面的 CustomImageDataset 类支持通过 CSV 文件(包括你在上一步生成的 train_split.csv, val_split.csv 等)来读取图像与标签,并在取样本时进行必要的预处理操作。

import os
import pandas as pd
from torch.utils.data import Dataset
from PIL import Imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):"""初始化数据集。参数:annotations_file (str): CSV 文件路径,包含 (filename, label) 等信息img_dir (str): 存放图像的文件夹路径transform (callable, optional): 对图像进行转换和增强的函数或 transforms 组合target_transform (callable, optional): 对标签进行转换的函数"""self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):"""返回整个数据集的样本数量。"""return len(self.img_labels)def __getitem__(self, idx):"""根据索引 idx 获取单个样本。返回:(image, label) 其中 image 可以是一个 PIL 图像或 Tensor,label 可以是整数或字符串"""# 1. 获取图像文件名与对应的标签img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])label = self.img_labels.iloc[idx, 1]# 2. 读取图像并转换为 RGB 模式(如果是灰度则可用 'L')image = Image.open(img_path).convert('RGB')# 3. 对图像和标签进行必要的变换if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label

四、创建训练集、验证集、测试集对应的 DataLoader

有了自定义 Dataset 后,就可以利用 PyTorch 自带的 DataLoader 来进行批量数据加载、随机打乱以及多线程读取数据等工作。以下示例展示了如何分别实例化 训练集验证集测试集Dataset 对象,并为每个对象创建 DataLoader

from torchvision import transforms
from torch.utils.data import DataLoader# 定义训练、验证/测试时所需的数据变换
train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),  # 数据增强transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])val_test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])# 实例化训练集 (train_dataset)
train_dataset = CustomImageDataset(annotations_file='data/train_split.csv',  # 注意这里不再是 data/train_annotations.csvimg_dir='data/train_data',transform=train_transform
)# 实例化验证集 (val_dataset)
val_dataset = CustomImageDataset(annotations_file='data/val_split.csv',img_dir='data/train_data',transform=val_test_transform
)# 实例化测试集 (test_dataset)
test_dataset = CustomImageDataset(annotations_file='data/test_annotations.csv',img_dir='data/test_data',transform=val_test_transform
)# 构建 DataLoader
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True,      # 训练时常使用 shuffle=True 来打乱顺序num_workers=4,     # 根据 CPU 核心数进行调整drop_last=True     # 避免最后一个 batch 样本数不足时带来的问题
)val_loader = DataLoader(val_dataset,batch_size=64,shuffle=False,num_workers=4,drop_last=False
)test_loader = DataLoader(test_dataset,batch_size=64,shuffle=False,num_workers=4,drop_last=False
)

通过使用 DataLoader,你就可以在训练和验证过程中以 (batch)为单位获取数据,从而显著提升训练速度,并方便进行数据增强、随机打乱等操作。

五、完整示例脚本

下面给出一个相对完整的示例脚本,整合了数据拆分、自定义数据集加载以及构建 DataLoader 的主要流程。如果你愿意,可以将这些步骤拆分到不同的 Python 文件中,以保持项目结构清晰。

import os
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms# ========== 1. 数据集拆分函数 ========== #
def split_train_val(annotations_file, output_train_file, output_val_file, test_size=0.2, random_state=42):df = pd.read_csv(annotations_file)train_df, val_df = train_test_split(df, test_size=test_size, random_state=random_state, stratify=df['label'])train_df.to_csv(output_train_file, index=False)val_df.to_csv(output_val_file, index=False)# ========== 2. 定义自定义 Dataset 类 ========== #
class CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])label = self.img_labels.iloc[idx, 1]image = Image.open(img_path).convert('RGB')if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label# ========== 3. 执行划分并创建训练/验证/测试集 ========== #
# 假设原始的训练集标注文件位于 data/train_annotations.csv
split_train_val(annotations_file='data/train_annotations.csv',output_train_file='data/train_split.csv',output_val_file='data/val_split.csv',test_size=0.2,random_state=42
)train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])val_test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])train_dataset = CustomImageDataset(annotations_file='data/train_split.csv',img_dir='data/train_data',transform=train_transform
)val_dataset = CustomImageDataset(annotations_file='data/val_split.csv',img_dir='data/train_data',transform=val_test_transform
)test_dataset = CustomImageDataset(annotations_file='data/test_annotations.csv',img_dir='data/test_data',transform=val_test_transform
)train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True,num_workers=4,drop_last=True
)val_loader = DataLoader(val_dataset,batch_size=64,shuffle=False,num_workers=4,drop_last=False
)test_loader = DataLoader(test_dataset,batch_size=64,shuffle=False,num_workers=4,drop_last=False
)# ========== 4. 简单测试:读取一个 batch ========== #
for images, labels in train_loader:print(images.shape, labels.shape)break

六、在训练循环中使用验证集

构建好训练、验证和测试集的 DataLoader 之后,你就可以在模型训练过程中使用验证集来评估模型性能;并在完全训练结束后,对测试集进行最终评估。以下是一个最简化的示例,演示如何在每个 epoch 后进行验证:

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的神经网络
class SimpleNN(nn.Module):def __init__(self, num_classes=10):super(SimpleNN, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(224*224*3, 128)  # 根据输入图像大小进行调整self.relu = nn.ReLU()self.fc2 = nn.Linear(128, num_classes)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.fc2(x)return x# 初始化模型、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNN(num_classes=2).to(device)  # 假设有 2 个类别
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练过程
num_epochs = 5
for epoch in range(num_epochs):# 1. 训练阶段model.train()running_loss = 0.0for images, labels in train_loader:images = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()avg_train_loss = running_loss / len(train_loader)# 2. 验证阶段model.eval()correct = 0total = 0val_loss = 0.0with torch.no_grad():for images, labels in val_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()avg_val_loss = val_loss / len(val_loader)val_accuracy = 100.0 * correct / totalprint(f'Epoch [{epoch+1}/{num_epochs}], 'f'Train Loss: {avg_train_loss:.4f}, 'f'Val Loss: {avg_val_loss:.4f}, 'f'Val Accuracy: {val_accuracy:.2f}%')

输出示例

Epoch [1/5], Train Loss: 1.2034, Val Loss: 0.4567, Val Accuracy: 85.32%
Epoch [2/5], Train Loss: 0.9876, Val Loss: 0.3987, Val Accuracy: 88.45%
...

总结

  1. 数据组织:将大量图片与注释文件分开存储(如 train_annotations.csvtest_annotations.csv 单独放在 data/ 目录下),可以在图片数量庞大时更方便地管理和检索。
  2. 数据集拆分:使用 train_test_split 将原始训练集拆分为训练集与验证集,以便在训练过程中监控模型的过拟合情况。
  3. 自定义 Dataset:通过继承 Dataset 并重写 __getitem____len__,可以灵活处理任意格式的数据,并在读入时执行预处理/增强操作。
  4. 构建 DataLoader:使用 PyTorch 的 DataLoader 可以轻松实现批量读取、并行加速、随机打乱等功能,大幅提升训练效率。
  5. 验证与测试:在每个 epoch 后对验证集进行评估可以及时发现过拟合和调参问题;最终对测试集进行评估可以获得模型的实际泛化性能。

相关文章:

使用 PyTorch 自定义数据集并划分训练、验证与测试集

使用 PyTorch 自定义数据集并划分训练、验证与测试集 在图像分类等任务中,通常需要将原始训练数据进一步划分为训练集和验证集,以便在训练过程中评估模型的性能。下面将详细介绍如何组织数据与注释文件、如何分割训练集和验证集,以及如何基于…...

VSCode 插件

VSCode 插件 1. GitHub Copilot - AI 代码助手 功能:根据上下文提供实时代码补全,支持自然语言转代码,提供符合现代编程规范的建议。进阶技巧: 使用快捷键 Alt ] 切换多个建议。写注释时,描述业务逻辑而不是具体实现…...

Windows使用AutoHotKey解决鼠标键连击现象(解决鼠标连击、单击变双击的故障)

注:罗技鼠标,使用久了之后会出现连击现象,如果刚好过保了,可以考虑使用软件方案解决连击现象: 以下是示例AutoHotKey脚本,实现了调用XButton1用于关闭窗口(以及WinW,XButton2也导向…...

Linux 环境(Ubuntu)部署 Hadoop 环境

前置准备 准备三台机器 cat /etc/hosts 192.168.1.7 hadoop-master 192.168.1.11 hadoop-slave01 192.168.1.12 hadoop-slave02Linux 环境 cat /etc/os-release PRETTY_NAME"Ubuntu 24.10" NAME"Ubuntu" VERSION_ID"24.10" VERSION"24.…...

如何在Windows 11 WSL2 Ubuntu 环境下安装和配置perf性能分析工具?

在Windows 11 WSL2 Ubuntu 环境下完整安装和配置perf性能分析工具 一、背景二、准备工作三、获取并编译Linux内核源码四、安装和配置perf五、测试perf六、总结 一、背景 由于WSL2使用的是微软定制的内核,并非标准的Ubuntu内核,因此直接使用apt安装linux…...

Docker运维高级容器技术知识点总结

1、虚拟机部署和容器化部署的区别是什么&#xff1f; 1、技术基础&#xff1a; <1>.虚拟化技术在物理硬件上创建虚拟机&#xff0c;每台虚拟机运行自己完整的操作系统、从而实现资源隔离。 <2>.容器化技术&#xff1a;将应用程序打包在容器内&#xff0c;在进程空间…...

react-quill 富文本组件编写和应用

index.tsx文件 import React, { useRef, useState } from react; import { Modal, Button } from antd; import RichEditor from ./RichEditor;const AnchorTouchHistory: React.FC () > {const editorRef useRef<any>(null);const [isModalVisible, setIsModalVis…...

LabVIEW轴承性能测试系统

本文介绍了基于LabVIEW的高效轴承性能测试系统的设计与开发。系统通过双端驱动技术实现高精度同步控制&#xff0c;针对轴承性能进行全面的测试与分析&#xff0c;以提高轴承的可靠性和寿命。 项目背景 随着工业自动化程度的提高&#xff0c;对轴承的性能要求越来越高。传统的…...

【《游戏编程模式》实战04】状态模式实现敌人AI

目录 1、状态模式 2、使用工具 3、状态模式适用范围 4、实现内容 5、代码及思路 Enemy.cs EnemyState.cs 6、unity里的设置 7、运行效果展示 1、状态模式 “允许一个对象在其内部状态改变时改变自身的行为。对象看起来好像是在修改自身类。” 就是一个对象能随着自己…...

借助免费GIS工具箱轻松实现las点云格式到3dtiles格式的转换

在当今数字化浪潮下&#xff0c;地理信息系统&#xff08;GIS&#xff09;技术日新月异&#xff0c;广泛渗透到城市规划、地质勘探、文化遗产保护等诸多领域。而 GISBox 作为一款功能强大且易用的 GIS 工具箱&#xff0c;以轻量级、免费使用、操作便捷等诸多优势&#xff0c;为…...

科研绘图系列:R语言科研绘图之标记热图(heatmap)

禁止商业或二改转载,仅供自学使用,侵权必究,如需截取部分内容请后台联系作者! 文章目录 介绍加载R包数据下载导入数据数据预处理画图系统信息参考介绍 科研绘图系列:R语言科研绘图之标记热图(heatmap) 加载R包 library(tidyverse) library(ggplot2) library(reshape)…...

【轻松学C:编程小白的大冒险】--- C语言简介 02

在编程的艺术世界里&#xff0c;代码和灵感需要寻找到最佳的交融点&#xff0c;才能打造出令人为之惊叹的作品。而在这座秋知叶i博客的殿堂里&#xff0c;我们将共同追寻这种完美结合&#xff0c;为未来的世界留下属于我们的独特印记。 【轻松学C&#xff1a;编程小白的大冒险】…...

《HeadFirst设计模式》笔记(上)

设计模式的目录&#xff1a; 1 设计模式介绍 要不断去学习如何利用其它开发人员的智慧与经验。学习前人的正统思想。 我们认为《Head First》的读者是一位学习者。 一些Head First的学习原则&#xff1a; 使其可视化将文字放在相关图形内部或附近&#xff0c;而不是放在底部…...

数据结构:ArrayList与顺序表

目录 &#x1f4d6;一、什么是List &#x1f4d6;二、线性表 &#x1f4d6;三、顺序表 &#x1f42c;1、display()方法 &#x1f42c;2、add(int data)方法 &#x1f42c;3、add(int pos, int data)方法 &#x1f42c;4、contains(int toFind)方法 &#x1f42c;5、inde…...

SpringBoot之核心配置

学习目标&#xff1a; 1.熟悉Spring Boot全局配置文件的使用 2.掌握Spring Boot配置文件属性值注入 3.熟悉Spring Boot自定义配置 4.掌握Profile多环境配置 5.了解随机值设置以及参数间引用 1.全局配置文件 Spring Boot使用 application.properties 或者application.yaml 的文…...

EasyExcel上传校验文件错误信息放到文件里以Base64 返回给前端

产品需求&#xff1a; 前端上传个csv 或 excel 文件&#xff0c;文件共4列&#xff0c;验证文件大小&#xff0c;类型&#xff0c;文件名长度&#xff0c;文件内容&#xff0c;如果某行某个单元格数据验证不通过&#xff0c;就把错误信息放到这行第五列&#xff0c;然后把带有…...

单片机软件定时器V4.0

单片机软件定时器V4.0 用于单片机定时执行任务等&#xff0c;比如LED GPIO等定时控制&#xff0c;内置前后台工作模式 头文件有使用例子 #ifndef __SORFTIME_APP_H #define __SORFTIME_APP_H#ifdef __cplusplus extern "C" { #endif#include <stdint.h>// #…...

超完整Docker学习记录,Docker常用命令详解

前言 关于国内拉取不到docker镜像的问题&#xff0c;可以利用Github Action将需要的镜像转存到阿里云私有仓库&#xff0c;然后再通过阿里云私有仓库去拉取就可以了。 参考项目地址&#xff1a;使用Github Action将国外的Docker镜像转存到阿里云私有仓库 一、Docker简介 Do…...

C++ 入门第26天:文件与流操作基础

往期回顾&#xff1a; C 入门第23天&#xff1a;Lambda 表达式与标准库算法入门-CSDN博客 C 入门第24天&#xff1a;C11 多线程基础-CSDN博客 C 入门第25天&#xff1a;线程池&#xff08;Thread Pool&#xff09;基础-CSDN博客 C 入门第26天&#xff1a;文件与流操作基础 前言…...

使用python将多个Excel表合并成一个表

import pandas as pd# 定义要合并的Excel文件路径和名称 file_paths [file1.xlsx, file2.xlsx, file3.xlsx, file4.xlsx, file5.xlsx]# 创建一个空的DataFrame来存储合并后的数据 merged_data pd.DataFrame()# 循环遍历每个Excel文件&#xff0c;并读取其中的数据 for file_p…...

数学建模期末速成 聚类分析与判别分析

聚类分析是在不知道有多少类别的前提下&#xff0c;建立某种规则对样本或变量进行分类。判别分析是已知类别&#xff0c;在已知训练样本的前提下&#xff0c;利用训练样本得到判别函数&#xff0c;然后对未知类别的测试样本判别其类别。 聚类分析 根据样本自身的属性&#xf…...

Houdini POP入门学习05 - 物理属性

接下来随着教程学习碰撞部分&#xff0c;当粒子较为复杂或者下载了一些粒子模板进行修改时&#xff0c;会遇到一些较奇怪问题&#xff0c;如粒子穿透等&#xff0c;这些问题实际上可以通过调节参数解决。 hip资源文件&#xff1a;https://download.csdn.net/download/grayrail…...

如何从浏览器中导出网站证书

以导出 GitHub 证书为例&#xff0c;点击 小锁 点击 导出 注意&#xff1a;这里需要根据你想要证书格式手动加上后缀名&#xff0c;我的是加 .crt 双击文件打开...

分享一道力扣

刚刚笔试遇到的。好像很简单&#xff0c;但又不容易写的 611 有效三角形 def triangleNumber(self, nums):count 0nums.sort()for i in range(len(nums) - 2):k i 2for j in range(i 1, len(nums) - 1):if nums[i] 0:breakwhile k < len(nums) and nums[i] nums[j] &g…...

11 - ArcGIS For JavaScript -- 高程分析

这里写自定义目录标题 描述代码实现结果 描述 高程分析是地理信息系统(GIS)中的核心功能之一&#xff0c;主要涉及对地表高度数据(数字高程模型, DEM)的处理和分析。 ArcGIS For JavaScript4.32版本的发布&#xff0c;提供了Web端的针对高程分析的功能。 代码实现 <!doct…...

My图床项目

引言: 在海量文件存储中尤其是小文件我们通常会用上fastdfs对数据进行高效存储,在现实生产中fastdfs通常用于图片,文档,音频等中小文件。 一.项目中用到的基础组件(Base) 1.网络库(muduo) 我们就以muduo网络库为例子讲解IO多路复用和reactor网络模型 1.1 IO多路复用 我们可以…...

《Java 并发神器:深入理解CompletableFuture.supplyAsync与线程池实战优化》

一、背景介绍 在 Java 后端开发中&#xff0c;我们经常会遇到以下问题&#xff1a; 需要并行执行多个数据库查询或远程调用&#xff1b;单线程执行多个 .list() 方法时耗时过长&#xff1b;希望提升系统响应速度&#xff0c;但又不想引入过多框架。 这时&#xff0c;Java 8 …...

科技创新驱动人工智能,计算中心建设加速产业腾飞​

在科技飞速发展的当下&#xff0c;人工智能正以前所未有的速度融入我们的生活。一辆辆无人驾驶的车辆在道路上自如地躲避车辆和行人&#xff0c;行驶平稳且操作熟练&#xff1b;刷脸支付让购物变得安全快捷&#xff0c;一秒即可通行。这些曾经只存在于想象中的场景&#xff0c;…...

AI Agent开发第78课-大模型结合Flink构建政务类长公文、长文件、OA应用Agent

开篇 AI Agent2025确定是进入了爆发期,到处都在冒出各种各样的实用AI Agent。很多人、组织都投身于开发AI Agent。 但是从3月份开始业界开始出现了一种这样的声音: AI开发入门并不难,一旦开发完后没法用! 经历过至少一个AI Agent从开发到上线的小伙伴们其实都听到过这种…...

从理论崩塌到新路径:捷克科学院APL Photonics论文重构涡旋光技术边界

理论预言 vs 实验挑战 光子轨道角动量&#xff08;Orbital Angular Momentum, OAM&#xff09;作为光场调控的新维度&#xff0c;曾被理论预言可突破传统拉曼散射的对称性限制——尤其是通过涡旋光&#xff08;如拉盖尔高斯光束&#xff09;激发晶体中常规手段无法探测的"…...