当前位置: 首页 > news >正文

浅谈torch.utils.data.TensorDataset和torch.utils.data.DataLoader

1.torch.utils.data.TensorDataset

功能定位

torch.utils.data.TensorDataset 是一个将多个张量(Tensor)数据进行简单包装整合的数据集类,它主要的作用是将相关联的数据(比如特征数据和对应的标签数据等)组合在一起,形成一个方便后续用于训练等操作的数据集对象。

例如,如果你有输入特征数据 x(形状为 [n_samples, feature_dim])和对应的标签数据 y(形状为 [n_samples]),且它们都是 torch.Tensor 类型,可以这样创建 TensorDataset

import torch
from torch.utils.data import TensorDatasetx = torch.randn(100, 10)  # 模拟100个样本,每个样本特征维度为10
y = torch.randint(0, 2, (100,))  # 模拟二分类标签dataset = TensorDataset(x, y)
特点
  • 简单包装:只是把传入的张量按照样本维度进行了对应组合,并没有对数据做复杂的预处理、采样等额外操作。

  • 索引支持:支持像普通列表那样通过索引访问其中的数据元素,例如 dataset[0] 会返回由对应索引的特征和标签组成的元组(按照传入构造函数的张量顺序)。

  • 适用于小型数据集直接使用:当数据量不大且数据格式已经整理为张量形式时,可以直接基于它来进行简单的模型训练循环等操作,不过对于批量处理等更复杂的情况支持有限,需要进一步配合其他工具。

2.torch.utils.data.DataLoader

功能定位

torch.utils.data.DataLoader 是一个用于加载数据的工具类,它围绕着给定的数据集(比如 TensorDataset 或者自定义的继承自 Dataset 的类实例等),实现了诸如批量加载数据、打乱数据顺序、并行加载数据等功能,旨在让数据能够以合适的方式、合适的批量大小等被送入到模型中进行训练、验证或测试等操作。

示例:

from torch.utils.data import DataLoaderbatch_size = 10
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)for batch_x, batch_y in dataloader:# 这里的batch_x和batch_y就是每次迭代取出的一个批量的特征和标签数据pass
特点
  • 批量处理:可以按照设定的 batch_size 参数,将数据集中的数据划分为一个个的小批量(mini-batch),方便模型以批量的方式进行梯度计算更新,有助于优化训练过程和提升效率,尤其在大数据集场景下优势明显。

  • 数据打乱:通过设置 shuffle=True 可以在每个训练轮次(epoch)开始时对数据集里面的数据顺序进行随机打乱,使得数据的输入顺序具有随机性,这有助于提升模型训练的泛化能力,避免模型因数据顺序固定而产生过拟合等问题。

  • 并行加载:支持多进程加载数据(通过设置 num_workers 参数大于 0),能够利用多核 CPU 的优势加快数据读取和预处理的速度,特别是在处理大规模数据集或者数据加载比较耗时的情况下,能显著提升整体训练效率。

  • 灵活性和通用性:它可以适配各种不同类型的数据集,只要这些数据集继承自 torch.utils.data.Dataset 抽象类并实现了必要的 __len____getitem__ 等方法,因此无论是简单的 TensorDataset 还是复杂的自定义数据集都可以用它来加载数据。

总的来说,TensorDataset 侧重于对已有张量数据进行简单的整合包装形成数据集;而 DataLoader 侧重于围绕数据集实现数据的批量加载、打乱顺序、并行化等复杂的数据加载相关功能,它们通常配合使用,先使用 TensorDataset 组织好数据,再通过 DataLoader 按照训练需求来加载和处理这些数据并送入模型中。

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs)

相关文章:

浅谈torch.utils.data.TensorDataset和torch.utils.data.DataLoader

1.torch.utils.data.TensorDataset 功能定位 torch.utils.data.TensorDataset 是一个将多个张量(Tensor)数据进行简单包装整合的数据集类,它主要的作用是将相关联的数据(比如特征数据和对应的标签数据等)组合在一起&…...

gesp(C++二级)(16)洛谷:B4037:[GESP202409 二级] 小杨的 N 字矩阵

gesp(C++二级)(16)洛谷:B4037:[GESP202409 二级] 小杨的 N 字矩阵 题目描述 小杨想要构造一个 m m m \times m m...

FFmpeg:详细安装教程与环境配置指南

FFmpeg 部署完整教程 在本篇博客中,我们将详细介绍如何下载并安装 FFmpeg,并将其添加到系统的环境变量中,以便在终端或命令行工具中直接调用。无论你是新手还是有一定基础的用户,这篇教程都能帮助你轻松完成 FFmpeg 的部署。 一、…...

《特征工程:自动化浪潮下的坚守与变革》

在机器学习的广阔天地中,特征工程一直占据着举足轻重的地位。它宛如一位幕后的工匠,精心雕琢着原始数据,将其转化为能够被机器学习模型高效利用的特征,从而推动模型性能迈向新的高度。然而,随着技术的飞速发展&#xf…...

webrtc 源码阅读 make_ref_counted模板函数用法

目录 1. 模板参数解析 1.1 typename T 1.2 typename... Args 1.3 typename std::enable_if::value, T>::type* nullptr 2. scoped_refptr 3. new RefCountedObject(std::forward(args)...); 4. 综合说明 5.在webrtc中的用法 5.1 peerConnectionFactory对象的构建过…...

【深度学习基础之多尺度特征提取】特征金字塔(Feature Pyramid)是如何在深度学习网络中提取多尺度特征的?附代码

【深度学习基础之多尺度特征提取】特征金字塔(Feature Pyramid)是如何在深度学习网络中提取多尺度特征的?附代码 【深度学习基础之多尺度特征提取】特征金字塔(Feature Pyramid)是如何在深度学习网络中提取多尺度特征…...

【Docker】离线安装 Docker

离线安装 Docker 在CentOS系统上安装Docker 1、下载 Docker 仓库文件 https://download.docker.com/linux/centos/docker-ce.repo 2、添加 Docker 仓库文件 将上一步下载的文件,移动到 /etc/yum.repos.d/ 目录 3、清除 YUM 缓存 sudo yum clean all sudo yum…...

三大行业案例:AI大模型+Agent实践全景

本文将从AI Agent和大模型的发展背景切入,结合51Talk、哈啰出行以及B站三个各具特色的行业案例,带你一窥事件驱动架构、RAG技术、人机协作流程,以及一整套行之有效的实操方法。具体包含内容有:51Talk如何让智能客服“主动进攻”&a…...

Dockerfile基础指令

1.FROM 基于基准镜像(建议使用官方提供的镜像作为基准镜像,相对安全一些) 举例: 制作基准镜像(基于centos:lastest) FROM cenots 不依赖于任何基准镜像 FROM scratch 依赖于9.0.22版本的tomcat镜像 FROM…...

12.30 linux 文件操作,磁盘分区挂载

ubuntu 在linux 对文件的相关操作【压缩,打包,软链接,文件权限】【head,tail,管道符,通配符,find,grep,cut等】脑图-CSDN博客 1.文件操作 在家目录下创建目录文件&#…...

[图形渲染]【Unity Shader】【游戏开发】 Shader数学基础17-法线变换基础与应用

在计算机图形学中,法线(normal) 是表示表面方向的向量。它在光照、阴影、碰撞检测等领域有着重要作用。本文将介绍如何在模型变换过程中正确变换法线,确保其在光照计算中的正确性,特别是法线与顶点的变换问题。 1. 法线与切线的基本概念 法线(Normal Vector) 法线(或…...

YOLOv9-0.1部分代码阅读笔记-train.py

train.py train.py 目录 train.py 1.所需的库和模块 2.def train(hyp, opt, device, callbacks): 3.def parse_opt(knownFalse): 4.def main(opt, callbacksCallbacks()): 5.def run(**kwargs): 6.if __name__ "__main__": 1.所需的库和模块 import …...

等保测评和密评的相关性和区别

等保测评和密评在网络安全领域均扮演着至关重要的角色,它们之间既存在相关性,又各具特色。 以下是对两者相关性和区别的详细阐述:相关性 1.法律基础:等保测评和密评都是依据国家相关法律法规开展的活动。 等保测评主要依据《网…...

活动预告 |【Part2】 Azure 在线技术公开课:迁移和保护 Windows Server 和 SQL Server 工作负载

课程介绍 通过 Microsoft Learn 免费参加 Microsoft Azure 在线技术公开课,掌握创造新机遇所需的技能,加快对 Microsoft 云技术的了解。参加我们举办的“迁移和保护 Windows Server 和 SQL Server 工作负载”活动,了解 Azure 如何为将工作负载…...

大语言模型(LLM)一般训练过程

大语言模型(LLM)一般训练过程 数据收集与预处理 收集:从多种来源收集海量文本数据,如互联网的新闻文章、博客、论坛,以及书籍、学术论文、社交媒体等,以涵盖丰富的语言表达和知识领域。例如,训练一个通用型的LLM时,可能会收集数十亿甚至上百亿字的文本数据.清洗:去除…...

单片机的基本组成

单片机,即单芯片微型计算机(Single-Chip Microcomputer),是一种将中央处理器(CPU)、内存、输入输出接口等功能集成在一块集成电路芯片上的微型计算机。它具有体积小、成本低、可靠性高、功耗低等优点,在现代电子产品中…...

GO性能优化的一些记录:trace工具的使用

使用场景: 1 想要查看接口延时性偏高 2 深入了解协程具体如何运营的详细信息(运行时长,或者什么原因导致了协程运行受阻) 可以使用 trace 功能,程序便会对下面的一系列事件进行详细记录,并且会依据所搜集到…...

dede-cms关于shell漏洞

一.文件式管理器 1.新建文件 新建一个php文件,内容写个php脚本语言 访问,可以运行 2.文件上传 上传一个php文件,内容同样写一个php代码 访问,运行成功 二.模块-广告管理 来到模块-广告管理——>增加一个新广告 在这里试一下…...

NAT 技术如何解决 IP 地址短缺问题?

NAT 技术如何解决 IP 地址短缺问题? 前言 这是我在这个网站整理的笔记,有错误的地方请指出,关注我,接下来还会持续更新。 作者:神的孩子都在歌唱 随着互联网的普及和发展,IP 地址的需求量迅速增加。尤其是 IPv4 地址&…...

使用 IDE生成 Java Doc

使用步骤 Android Studio界面->Tools->Generate JavaDoc zh-CN -encoding UTF-8 -charset UTF-8 -classpath “C:\Users\fangjian\AppData\Local\Android\Sdk\platforms\android-34\android.jar” 报错问题 错误: 目标 17 不允许选项 --boot-class-path 如果你正在使用…...

后进先出(LIFO)详解

LIFO 是 Last In, First Out 的缩写,中文译为后进先出。这是一种数据结构的工作原则,类似于一摞盘子或一叠书本: 最后放进去的元素最先出来 -想象往筒状容器里放盘子: (1)你放进的最后一个盘子&#xff08…...

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

【Linux】C语言执行shell指令

在C语言中执行Shell指令 在C语言中&#xff0c;有几种方法可以执行Shell指令&#xff1a; 1. 使用system()函数 这是最简单的方法&#xff0c;包含在stdlib.h头文件中&#xff1a; #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成&#xff0c;核心是利用 HTTP 协议的 Range 请求头指定下载范围&#xff1a; 实现原理 Range 请求头&#xff1a;向服务器请求文件的特定字节范围&#xff08;如 Range: bytes1024-&#xff09; 本地文件记录&#xff1a;保存已…...

linux 错误码总结

1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

IT供电系统绝缘监测及故障定位解决方案

随着新能源的快速发展&#xff0c;光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域&#xff0c;IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选&#xff0c;但在长期运行中&#xff0c;例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...

大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计

随着大语言模型&#xff08;LLM&#xff09;参数规模的增长&#xff0c;推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长&#xff0c;而KV缓存的内存消耗可能高达数十GB&#xff08;例如Llama2-7B处理100K token时需50GB内存&a…...

基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解

JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用&#xff0c;结合SQLite数据库实现联系人管理功能&#xff0c;并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能&#xff0c;同时可以最小化到系统…...

jmeter聚合报告中参数详解

sample、average、min、max、90%line、95%line,99%line、Error错误率、吞吐量Thoughput、KB/sec每秒传输的数据量 sample&#xff08;样本数&#xff09; 表示测试中发送的请求数量&#xff0c;即测试执行了多少次请求。 单位&#xff0c;以个或者次数表示。 示例&#xff1a;…...