当前位置: 首页 > news >正文

深度学习技术栈 —— Pytorch之TensorDataset、DataLoader

深度学习技术栈 —— Pytorch之TensorDataset、DataLoader

  • 前言
  • 一、TensorDataset、DataLoader的用法?
  • 二、从.csv文件-->tensor张量
  • 总结


前言

简单来说,TensorDatasetDataLoader这两个类的作用, 就是将数据读入并做整合,以便交给模型处理。就像石油加工厂一样,你不关心石油是如何采集与加工的,你关心的是自己去哪加油,油价是多少,对于一个模型而言,DataLoader就是这样的一个予取予求的数据服务商。

参考文章或视频链接
[1] How to use TensorDataset, Dataloader (pytorch)

一、TensorDataset、DataLoader的用法?

# coding:utf-8
# @Time: 2024/1/23 上午9:57
# @Author: 键盘国治理专家
# @File: __init__.py.py
# @Description: import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoaderdef test_TensorDataset():input = np.random.rand(4, 2)  # Input datacorrect = np.random.rand(4, 1)  # Correct answer datainput = torch.FloatTensor(input)  # Change to an array that can be handled by pytorchcorrect = torch.FloatTensor(correct)  # Same as aboveprint(input)print(correct)dataset = TensorDataset(input, correct)  # set the data,注意,是TensorDataset而不是Dataset,Dataset是个abstract class不能实例化print(dataset)  # 打印地址print(vars(dataset))  # vars prints the contents of the objectreturn datasetdef test_DataLoader(dataset):train_load = DataLoader(dataset, batch_size=3, shuffle=False)  # Data shuffle with shuffle=Truefor x, t in train_load:print('x-->', x)print('t-->', t)if __name__ == '__main__':dataset = test_TensorDataset()print("========================================================================================")test_DataLoader(dataset)

二、从.csv文件–>tensor张量

一般说来,大部分Kaggle比赛的数据都是以.csv为格式的,而Pytorch处理的是tensor张量,所以我们要了解如何将.csv文件的数据变成tensor张量数据。

"""
步骤如下
(1) xx.csv --> 经由pandas 变成 numpy 数组
(2) numpy 变成 tensor 张量
(3) tensor张量经过TensorDataset的组合
(4) dataset再经过DataLoader的处理,进而保证数据可用,以上为清洗过程
.csv --> numpy --> tensor --> dataset --> dataloader 四个过程,五个数据中转形式。
"""
# coding:utf-8
# @Time: 2024/1/23 下午1:01
# @Author: 键盘国治理专家
# @File: csv2tensor.py
# @Description:import numpy
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoaderdef csv2numpy(csv_path):data = pd.read_csv(csv_path, dtype=np.float64)# numpy_data = data.iloc[:, data.columns != "xx"]  # 另一种用法,data.columns != "xx" 可以过滤掉你不想读入的字段numpy_data = data.iloc[:].valuesreturn numpy_datadef numpy2tensor(numpy_data):tensor_data = torch.from_numpy(numpy_data)return tensor_datadef tensor2DataLoader(tensor_data):  # 一步到位,直接变成DataLoader。最简单的实现方式,这个func还有改进空间,DataSet可以接收多个tensor数据dataset = torch.utils.data.TensorDataset(tensor_data)data_loader = torch.utils.data.DataLoader(dataset, shuffle=False)return data_loader# 你甚至可以直接将.csv处理成DataLoader了,把这几个过程简单组合下形成一个新函数
def csv2DataLoader(csv_path):numpy_data = csv2numpy(csv_path)tensor_data = numpy2tensor(numpy_data)data_loader = tensor2DataLoader(tensor_data)return data_loaderif __name__ == '__main__':numpy_data = csv2numpy("./test.csv")# print(type(numpy_data))# print(numpy_data.shape)# print(numpy_data)tensor_data = numpy2tensor(numpy_data)# print(type(tensor_data))# print(tensor_data.shape)# print(tensor_data)data_loader = tensor2DataLoader(tensor_data)# print(type(data_loader))# print(data_loader)# print(data_loader.dataset)# # 用遍历的方式才能输出data_loader里的数据# for data_item in data_loader:#     print('data_item-->', data_item)# # 把数据的索引也一起输出# for i, data_item in enumerate(data_loader):#     print('i', i)#     print('data_item-->', data_item)

总结

本篇工作虽然简单,但确是进阶的一个不大不小的绊脚石,功夫虽小,也不能不练。

相关文章:

深度学习技术栈 —— Pytorch之TensorDataset、DataLoader

深度学习技术栈 —— Pytorch之TensorDataset、DataLoader 前言一、TensorDataset、DataLoader的用法?二、从.csv文件-->tensor张量总结 前言 简单来说,TensorDataset与DataLoader这两个类的作用, 就是将数据读入并做整合,以便…...

远程git开发

两种本地与远程仓库同步 """ 1)你作为项目仓库初始化人员:线上要创建空仓库 > 本地初始化好仓库 > 建立remote链接(remote add) > 提交本地仓库到远程(push)2)你作为项目后期开发人员:远程项目仓库已经创…...

Codeforces Round 812 (Div. 2) ---- C. Build Permutation --- 题解

目录 C. Build Permutation 题目描述: ​编辑 思路解析: 代码实现: C. Build Permutation 题目描述: 思路解析: 先证明在任何情况下答案均存在。 假设我们所求的为 m m1 m2.....n 的排列,我们称不小于n…...

Matlab 将工作区变量保存到文件中(save)

语法 1、save(filename) 2、save(filename,variables) 3、save(filename,variables,fmt) 4、save(filename,variables,version) 5、save(filename,variables,version,-nocompression) 6、save(filename,variables,-append) 7、save(filename,variables,-append,-nocompression…...

源码实现简介

本系列所有代码在文章底部,每一章节代码可独立编译运行 随着科技的飞速发展,自动驾驶技术正逐渐成为现实。而在自动驾驶技术中,感知是至关重要的一个环节。通过感知,自动驾驶车辆能够识别和理解周围环境,进而做出相应…...

我每天如何使用 ChatGPT

我们都清楚互联网的运作方式——充斥着各种“爆款观点”,极端分裂的意见,恶搞和无知现象屡见不鲜。 最近,大家对于人工智能(AI)特别是大语言模型(LLMs)和生成式 AI(GenAI&#xff0…...

MySQL修炼手册14:用户权限管理:安全保障与数据隔离

目录 写在开头1 用户与权限的关系1.1 用户的创建与删除1.1.1 创建新用户1.1.2 批量创建用户1.1.3 安全删除用户 1.2 授予与撤销权限1.2.1 授予权限1.2.2 批量授予权限1.2.3 撤销权限 2 角色的应用2.1 创建与管理角色2.1.1 创建角色2.1.2 管理角色 2.2 将权限赋予角色2.2.1 将权…...

动态规划解决马尔可夫决策过程

马尔可夫决策过程是强化学习中的基本问题模型之一,而解决马尔可夫决策过程的方法我们统称为强化学习算法。 动态规划( dynamic programming, DP )具体指的是在某些复杂问题中,将问题转化为若干个子问题,并在求解每个子…...

ubuntu1604安装及问题解决

虚拟机安装vmbox7 虚拟机操作: 安装增强功能 sudo mkdir /mnt/share sudo mount -t vboxsf sharefolder /mnt/share第一次使用sudo提示is not in the sudoers file. This incident will be reported 你的root需要设置好密码 sudo passwd root 输入如下指令&#x…...

Leetcode—24. 两两交换链表中的节点【中等】

2023每日刷题(八十七) Leetcode—24. 两两交换链表中的节点 实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int x) : val(x),…...

USRP相关报错解决办法

文章目录 前言一、本地环境二、相关报错信息二、解决办法1、更换电脑操作系统2、升级最新版固件 前言 在进行 USRP 开发时遇到了一些报错,这里做个记录解决问题的方法。 一、本地环境 电脑操作系统:Windows11MATLAB 版本:MATLAB 2021aUSRP …...

【剑指offer】重建二叉树

👑专栏内容:力扣刷题⛪个人主页:子夜的星的主页💕座右铭:前路未远,步履不停 目录 一、题目描述1、题目2、示例 二、题目分析1、递归2、栈 一、题目描述 1、题目 剑指offer:重建二叉树 给定节…...

中仕教育:事业编招考全流程介绍

一、报名阶段 1. 了解查看招聘信息:查看各类事业编岗位的招聘信息,包括岗位职责、招聘条件、报名时间等。 2. 填写报名表:按照要求填写报名表,包括个人信息、教育背景、工作经历等内容。 3. 提交报名材料:将报名表及…...

149. 直线上最多的点数

149. 直线上最多的点数 class MaxPoints:"""149. 直线上最多的点数https://leetcode.cn/problems/max-points-on-a-line/description/?envTypestudy-plan-v2&envIdtop-interview-150"""def solution(self, points: List[List[int]]) ->…...

不合格机器人工程讲师再读《悉达多》-2024-

一次又一次失败的经历,让我对经典书籍的认同感越来越多,越来越觉得原来的自己是多么多么的无知和愚昧。 ----zhangrelay 唯物也好,唯心也罢,我们都要先热爱这个世界,然后才能在其中找到自己所热爱的事业。 ----zh…...

【STM32CubeMX串口通信详解】USART2 -- DMA发送 + DMA空闲中断 接收不定长数据

( 本篇正在编写、更新状态中.....) 文章目录: 前言 前言 本篇,详细地用截图解释 CubeMX 对 USART2 的配置,HAL函数使用,和收发程序的编写。 收、发机制:DMA发送 DAM空闲中断接收。 DMA空…...

Webpack5入门到原理19:React 脚手架搭建

开发模式配置 // webpack.dev.js const path require("path"); const ESLintWebpackPlugin require("eslint-webpack-plugin"); const HtmlWebpackPlugin require("html-webpack-plugin"); const ReactRefreshWebpackPlugin require("…...

苹果眼镜(Vision Pro)的开发者指南(6)-实战应用场景开发 - 游戏、协作、空间音频、WebXR

第一部分:【构建游戏和媒体体验】 了解如何使用visionOS在游戏和媒体体验中创建真正身临其境的时刻。游戏和媒体可以利用全方位的沉浸感来讲述令人难以置信的故事,并以一种新的方式与人们联系。将向你展示可供你入门的visionOS游戏和叙事开发途径。了解如何使用RealityKit有…...

flutter底层架构初探

本文出处:​​​​​​​​​​​​​Flutter 中文开发者网站 架构 embedder嵌入层 提供程序入口(其他原生应用也采用此方式),程序由此和底层操作系统协调(surface渲染、辅助功能和输入服务,管理事件循环…...

初识SQL注入

目录 注入攻击 SQL注入 手工注入 Information_schema数据库 自动注入 介绍一下这款工具:sqlmap 半自动注入 前面给大家通过学习练习的方式将XSS攻击的几种形式和一些简单的靶场和例题的演示,从本篇开始我将和小伙伴们通过边复习、边练习的方式来进…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

synchronized 学习

学习源&#xff1a; https://www.bilibili.com/video/BV1aJ411V763?spm_id_from333.788.videopod.episodes&vd_source32e1c41a9370911ab06d12fbc36c4ebc 1.应用场景 不超卖&#xff0c;也要考虑性能问题&#xff08;场景&#xff09; 2.常见面试问题&#xff1a; sync出…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建&#xff08;全平台详解&#xff09; 在开始使用 React Native 开发移动应用之前&#xff0c;正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南&#xff0c;涵盖 macOS 和 Windows 平台的配置步骤&#xff0c;如何在 Android 和 iOS…...

可靠性+灵活性:电力载波技术在楼宇自控中的核心价值

可靠性灵活性&#xff1a;电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中&#xff0c;电力载波技术&#xff08;PLC&#xff09;凭借其独特的优势&#xff0c;正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据&#xff0c;无需额外布…...

Python爬虫(二):爬虫完整流程

爬虫完整流程详解&#xff08;7大核心步骤实战技巧&#xff09; 一、爬虫完整工作流程 以下是爬虫开发的完整流程&#xff0c;我将结合具体技术点和实战经验展开说明&#xff1a; 1. 目标分析与前期准备 网站技术分析&#xff1a; 使用浏览器开发者工具&#xff08;F12&…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年&#xff0c;截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始&#xff0c;将英文题库免费公布出来&#xff0c;并进行解析&#xff0c;帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南&#xff1a;计算机基础与源码原理深度解析 第一轮提问&#xff1a;基础概念问题 1. 请解释什么是进程和线程的区别&#xff1f; 面试官&#xff1a;进程是程序的一次执行过程&#xff0c;是系统进行资源分配和调度的基本单位&#xff1b;而线程是进程中的…...

【iOS】 Block再学习

iOS Block再学习 文章目录 iOS Block再学习前言Block的三种类型__ NSGlobalBlock____ NSMallocBlock____ NSStackBlock__小结 Block底层分析Block的结构捕获自由变量捕获全局(静态)变量捕获静态变量__block修饰符forwarding指针 Block的copy时机block作为函数返回值将block赋给…...

Linux-进程间的通信

1、IPC&#xff1a; Inter Process Communication&#xff08;进程间通信&#xff09;&#xff1a; 由于每个进程在操作系统中有独立的地址空间&#xff0c;它们不能像线程那样直接访问彼此的内存&#xff0c;所以必须通过某种方式进行通信。 常见的 IPC 方式包括&#…...

如何通过git命令查看项目连接的仓库地址?

要通过 Git 命令查看项目连接的仓库地址&#xff0c;您可以使用以下几种方法&#xff1a; 1. 查看所有远程仓库地址 使用 git remote -v 命令&#xff0c;它会显示项目中配置的所有远程仓库及其对应的 URL&#xff1a; git remote -v输出示例&#xff1a; origin https://…...