当前位置: 首页 > news >正文

【PyTorch】数据集

文章目录

  • 1. 创建数据集
    • 1.1. 直接继承Dataset类
    • 1.2. 使用TensorDataset类
  • 2. 数据集的划分
  • 3. 加载数据集
  • 4. 将数据转移到GPU

1. 创建数据集

主要是将数据集读入内存,并用Dataset类封装。

1.1. 直接继承Dataset类

必须要重写__getitem__方法,用于根据索引获得相应样本数据。必要时还可以重写__len__方法,用于返回数据集的大小。

from torch.utils.data import Datasetclass BostonHousingDataset(Dataset):"""定义波士顿房价数据集"""def __init__(self):self.data = np.load('../dataset/boston_housing/boston_housing.npz')def __getitem__(self, index):return self.data['x'][index], self.data['y'][index]def __len__(self):return self.data['x'].shape[0]

1.2. 使用TensorDataset类

将多个张量组合成一个数据集,要保证所有张量的第一个维度相等,保证每批样本数据格式相同。

import torch
from torch.utils.data import TensorDatasetdata = np.load('../dataset/boston_housing/boston_housing.npz')
X = torch.tensor(data['x'])
y = torch.tensor(data['y'])
dataset = TensorDataset(X, y)

2. 数据集的划分

数据集可以划分为训练集、验证集和测试集。

  • 训练集:用于模型拟合的数据样本集合。
  • 验证集:通常被用来调整模型的参数,以找出效果最佳的模型。
  • 测试集:用于训练好的模型性能评估的数据样本集合。
from torch.utils.data import random_splittrain_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

3. 加载数据集

使用DataLoader类将Dataset封装的数据集分成批次并进行迭代,以便于模型训练。DataLoader常用参数如下:

  • dataset
    要加载的数据集。
  • batch_size
    每个数据批次中包含的样本数。默认为1。
  • shuffle
    是否打乱数据集。默认为False。
  • num_workers
    使用几个进程来加载数据。默认为0,即在主进程中加载数据。
  • drop_last
    当数据集样本数不能被batch_size整除时,是否舍弃最后一个不完整的batch。默认为False。
from torch.utils.data import DataLoaderdataloader = DataLoader(dataset, batch_size=16, shuffle=True)

4. 将数据转移到GPU

一般在要运算时才将数据转移到GPU,有以下两种方法:

  1. var.to(device)
  2. var.cuda()
import torchdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for X,y in dataloader:# 将数据转移到GPUX = X.to(device)y = y.to(device)# 也可以X = X.cuda()y = y.cuda()

相关文章:

【PyTorch】数据集

文章目录 1. 创建数据集1.1. 直接继承Dataset类1.2. 使用TensorDataset类 2. 数据集的划分3. 加载数据集4. 将数据转移到GPU 1. 创建数据集 主要是将数据集读入内存,并用Dataset类封装。 1.1. 直接继承Dataset类 必须要重写__getitem__方法,用于根据索…...

oops-framework框架 之 本地存储(五)

引擎: CocosCreator 3.8.0 环境: Mac Gitee: oops-game-kit 注: 作者dgflash的oops-framework框架QQ群: 628575875 简介 在CocosCreator中,本地存储主要使用sys.localStorage 接口,通过 key-value的格式进…...

编程常见的问题

在现代社会中,编程已经成为一项非常重要的技能。随着科技的不断发展和普及,计算机已经渗透到我们生活的方方面面,从个人电脑、手机到智能家居、自动驾驶等。编程作为计算机科学的基础,为我们提供了解决问题和创造新事物的工具和方…...

针对Arrays.asList的坑,可以有哪些处理措施

上文讲述:Error querying database. Cause: java.lang.reflect.InaccessibleObjectException: 那么如果真的只习惯用Arrays.asList,那也是有对应的解决办法的。 一、解决办法大方向 不管做什么事情,都是先判定一个大方向,不管是…...

SE考研真题总结(一)

本帖开始分享考研真题中设计【软件工程】的部分,预计会出5期左右,敬请期待~ 一.单选题 1.程序编写不是软件质量保障过程~ 静态代码扫描是今年来多数被人提及的软件应用安全解决方案之一,指程序员在编写好代码后无需进行编译,直接…...

Xshell远程登录AWS EC2 Linux实例

文章目录 小结问题解决参考 小结 Xshell远程登录AWS EC2 Linux实例碰到些问题,进行解决并记录。 问题 在AWS中创建 EC2 Linux实例,生成的非对称密钥对,使用Xshell远程登录碰到一些问题。 解决 首先在Putty中可以使用的ppk密钥文件在Xshe…...

Elasticsearch:对时间序列数据流进行降采样(downsampling)

降采样提供了一种通过以降低的粒度存储时间序列数据来减少时间序列数据占用的方法。 指标(metrics)解决方案收集大量随时间增长的时间序列数据。 随着数据老化,它与系统当前状态的相关性越来越小。 降采样过程将固定时间间隔内的文档汇总为单…...

python自动化测试框架:unittest测试用例编写及执行

本文将介绍 unittest 自动化测试用例编写及执行的相关内容,包括测试用例编写、测试用例执行、测试报告等内容。 官方文档: https://docs.python.org/zh-cn/3/library/unittest.mock.html 1. 测试用例编写 在 unittest 中,一个测试用例通常…...

ctfhub技能树_web_web前置技能_HTTP

目录 一、HTTP协议 1.1、请求方式 1.2、302跳转 1.3、Cookie 1.4、基础认证 1.5、响应包源代码 一、HTTP协议 1.1、请求方式 注:HTTP协议中定义了八种请求方法。这八种都有:1、OPTIONS :返回服务器针对特定资源所支持的HTTP请求方法…...

mysql8报sql_mode=only_full_group_by(存储过程一直报)

1:修改数据库配置(重启失效) select global.sql_mode;会打印如下信息 ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ENGINE_SUBSTITUTION里面包含 ONLY_FULL_GROUP_BY,那么就重新设置,在数据库中输入以下代码,去掉ONLY_FULL_GROU…...

Vue2中v-html引发的安全问题

前言:v-html指令 1.作用:向指定节点中渲染包含html结构的内容。 2.与插值语法的区别: (1).v-html会替换掉节点中所有的内容,{{xx}}则不会。 (2).v-html可以识别html结构。 3.严重注意:v-html有安全性问题&#xff0…...

java内部类详解

文章目录 一、介绍二、为什么要使用内部类三、非静态内部类四、静态内部类五、局部内部类六、匿名内部类七、lambda表达式内部类八、成员重名九、序列化十、如何选择内部类 一、介绍 在java中,我们被允许在编写一个类(外部类OuterClass)时,在其内部再嵌…...

Python 潮流周刊#29:Rust 会比 Python 慢?!

△请给“Python猫”加星标 ,以免错过文章推送 你好,我是猫哥。这里每周分享优质的 Python、AI 及通用技术内容,大部分为英文。本周刊开源,欢迎投稿[1]。另有电报频道[2]作为副刊,补充发布更加丰富的资讯。 &#x1f43…...

吴恩达《机器学习》11-1-11-2:首先要做什么、误差分析

一、首先要做什么 选择特征向量的关键决策 以垃圾邮件分类器算法为例,首先需要决定如何选择和表达特征向量 𝑥。视频提到的一个示例是构建一个由 100 个最常出现在垃圾邮件中的词构成的列表,根据这些词是否在邮件中出现来创建特征向量&…...

Pandas在Excel同一个sheet里插入多个Dataframe和行

Pandas默认的to_excel是直接把完成的Datafrme写入一个sheet里,这并不能满足我们在一个sheet里插入多个Dataframe或多行的需求。为了实现插入多行或多Dataframe的目的,我们需要新建一个ExcelWriter对象,然后依次插入数据。 这里我们以插入2个Dataframe和三行单元格为例。 新…...

查看mysql 或SQL server 的连接数,mysql超时、最大连接数配置

1、mysql 的连接数 1.1、最大可连接数 show variables like max_connections; 1.2、运行中连接数 show status like Threads_connected; 1.3、配置最大连接数, mysql版本不同可配置的最大连接数不同,mysql8.0的版本默认151个连接数,…...

C++学习之路(七)C++ 实现简单的Qt界面(消息弹框、按钮点击事件监听)- 示例代码拆分讲解

这个示例创建了一个主窗口,其中包含两个按钮。第一个按钮点击时会显示一个简单的消息框,第二个按钮点击时会执行一个特定的操作(在这个例子中,仅打印一条调试信息)。 功能描述: 创建窗口和布局:…...

python实现一个计算器

实现一个计算器首先熟悉一下这个阅读器的属性import subprocess subprocess.run(["espeak", "-v", "enf3", "This is a Calculator"])class Calculator:def speaker(self,word):subprocess.run(["espeak", "-v", …...

C++ 共享内存ShellCode跨进程传输

在计算机安全领域,ShellCode是一段用于利用系统漏洞或执行特定任务的机器码。为了增加攻击的难度,研究人员经常探索新的传递ShellCode的方式。本文介绍了一种使用共享内存的方法,通过该方法,两个本地进程可以相互传递ShellCode&am…...

如何快速移植(从STM32F103到STM32F407)

最近用到F4的地方比较多,网上代码还是F1多一些,便需要移植代码,如何快速移植代码呢? 看下面这篇文章 外设 首先就是STM32的外设了。 STM32F407ZGT6的基本外设 STM32F407ZGT6 作为 MCU,该芯片是 STM32F407 里面配置…...

浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)

✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义(Task Definition&…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

UDP(Echoserver)

网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时,需结合业务场景设计数据流转链路,重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点: 一、核心对接场景与目标 商品数据同步 场景:将1688商品信息…...

蓝牙 BLE 扫描面试题大全(2):进阶面试题与实战演练

前文覆盖了 BLE 扫描的基础概念与经典问题蓝牙 BLE 扫描面试题大全(1):从基础到实战的深度解析-CSDN博客,但实际面试中,企业更关注候选人对复杂场景的应对能力(如多设备并发扫描、低功耗与高发现率的平衡)和前沿技术的…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口(适配服务端返回 Token) export const login async (code, avatar) > {const res await http…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析:CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展,AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者,分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...

动态 Web 开发技术入门篇

一、HTTP 协议核心 1.1 HTTP 基础 协议全称 :HyperText Transfer Protocol(超文本传输协议) 默认端口 :HTTP 使用 80 端口,HTTPS 使用 443 端口。 请求方法 : GET :用于获取资源,…...

Python Ovito统计金刚石结构数量

大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...