当前位置: 首页 > news >正文

pytorch加载的cifar10数据集,到底有没有经过归一化

pytorch加载cifar10的归一化

  • pytorch怎么加载cifar10数据集
    • torchvision.datasets.CIFAR10
    • transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】
  • torchvision.datasets加载的数据集搭配Dataloader使用
  • model.train()和model.eval()

pytorch怎么加载cifar10数据集

torchvision.datasets.CIFAR10

pytorch里面的torchvision.datasets中提供了大多数计算机视觉领域相关任务的数据集,可以根据实际需要加载相关数据集——需要cifar10就用torchvision.datasets.CIFAR10(),需要SVHN就调用torchvision.datasets.SVHN()。

针对cifar10数据集而言,调用torchvision.datasets.CIFAR10(),其中root是下载数据集后保存的位置;train是一个bool变量,为true就是训练数据集,false就是测试数据集;download也是一个bool变量,表示是否下载;transform是对数据集中的"image"进行一些操作,比如归一化、随机裁剪、各种数据增强操作等;target_transform是针对数据集中的"label"进行一些操作。
在这里插入图片描述

示例代码如下:

# 加载训练数据集
train_data = datasets.CIFAR10(root='../_datasets', train=True, download=True,transform= transforms.Compose([  transforms.ToTensor(),  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化  ])  )
# 加载测试数据集
test_data = datasets.CIFAR10(root='../_datasets', train=False,download=True, transform= transforms.Compose([  transforms.ToTensor(),  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化  ])  )

transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】

上面的代码中,我们用transforms.Compose([……])组合了一系列的对image的操作,其中trandforms.ToTensor()transforms.Normalize()都涉及到归一化操作:

  • 原始的cifar10数据集是numpy array的形式,其中数据范围是[0,255],pytorch加载时,并没有改变数据范围,依旧是[0,255],加载后的数据维度是(H, W, C),源码部分:
    在这里插入图片描述

  • __getitem__()函数中进行transforms操作,进行了归一化:实际上传入的transform在__getitem__()函数中被调用,其中transforms.Totensor()会将data(也就是image)的维度变成(C,H, W)的形式,并且归一化到[0.0,1.0]

在这里插入图片描述

  • transforms.Normalize()会根据z = (x-mean) / std 对数据进行归一化,上述代码中mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]是可以将3个通道单独进行归一化,3个通道可以设置不同的mean和std,最终数据范围变成[-0.5,+0.5] 。
    在这里插入图片描述

所以如果通过pytorch的cifar10加载数据集后,针对traindataset.data,依旧是没有进行归一化的;但是比如traindataset[index].data,其中[index]这样的按下标取元素的操作会直接调用的__getitem__()函数,此时的data就是经过了归一化的。
除traindataset[index]会隐式自动调用__getitem__()函数外,还有什么时候会调用这个函数呢?毕竟……只有调用了这个函数才会调用transforms中的归一化处理。——答案是与dataloader搭配使用!

torchvision.datasets加载的数据集搭配Dataloader使用

torchvision.datasets实际上是torch.utils.data.Dataset的子类,那么就能传入Dataloader中,迭代的按batch-size获取批量数据,用于训练或者测试。其中dataloader加载dataset中的数据时,就是用到了其__getitem__()函数,所以用dataloader加载数据集,得到的是经过归一化后的数据
在这里插入图片描述

model.train()和model.eval()

我发现上面的问题,是我用dataloader加载了训练数据集用于训练resnet18模型,训练过程中,我训练好并保存后,顺便测试了一下在测试数据集上的准确度。但是在测试的过程中,我没有用dataloader加载测试数据集,而是直接用的dataset.data来进行的测试。并且!由于是并没有将model设置成model.eval()【其实我设置了,但是我对自己很无语,我写的model.eval,忘记加括号了,无语呜呜】……也就是即便我的测试数据集没有经过归一化,由于模型还是在model.train()模式下,因此模型的BN层会自己调整,使得模型性能不受影响,因此在测试数据集上的accuracy达到了0.86,我就没有多想。
后来我用模型的时候,设置了model.eval()后,依旧是直接用的dataset.data(也就是没有归一化),不管是在测试数据集上还是在训练数据集上,accuracy都只有0.10+,我表示非常的迷茫疑惑啊!然后才发现是归一化的问题。

  • model.train()模式下进行预测时,PyTorch会默认启用一些训练相关的操作,例如Batch Normalization和Dropout,并且模型的参数是可变的,能够根据输入进行调整。这些操作在训练模式下可以帮助模型更好地适应训练数据,并产生较高的准确度。
  • model.eval()模式下进行预测时,PyTorch会将模型切换到评估模式,这会导致一些训练相关的操作行为发生变化。具体而言,Batch Normalization层会使用训练集上的统计信息进行归一化,而不是使用当前批次的统计信息。因此,如果输入数据没有进行归一化,模型在评估模式下的准确度可能会显著下降。

以下是我没有用dataloader加载数据集,进行预测的代码:

def correctness(model,data,target, device):batchsize = 1000batch_num = int(len(data) / batchsize)   # 对原始的数据进行操作 从H.W.C变成C.H.W data = torch.tensor(data).permute(0,3,1,2).type(torch.FloatTensor).to(device)# 手动归一化data = data/255data = (data - 0.5) / 0.5 # 求一个batch的correctnessdef _batch_correctness(i):images, labels = data[i*batchsize : (i+1)*batchsize], target[i*batchsize : (i+1)*batchsize]predict = model(images).detach().cpu()    correctness = np.array(torch.argmax(predict, dim = 1).numpy() == np.array(labels) , dtype= np.float32)return correctnessresult = np.array([_batch_correctness(i) for i in range(batch_num)])return result.flatten().sum()/data.shape[0]

我后面用上面的代码测试了四种情况:

  • model.eval() + 没有归一化:train_accuracy = 0.10,test_accuracy = 0.10;
  • model.eval() + 手动归一化:train_accuracy = 0.95,test_accuracy = 0.84;
  • model.train() + 没有归一化:train_accuracy = 0.95,test_accuracy = 0.83;
  • model.train() + 手动归一化:train_accuracy = 0.94,test_accuracy = 0.84;

由此可见,在model.eval()模式下,数据归一化对最终的测试结果有很大影响。

相关文章:

pytorch加载的cifar10数据集,到底有没有经过归一化

pytorch加载cifar10的归一化 pytorch怎么加载cifar10数据集torchvision.datasets.CIFAR10transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】 torchvision.datasets加载的数据集搭配Dataloader使用model.train()和model.eval() pytorch怎么加载…...

Day1 ARM基础

【ARM课程认知】 1.ARM课程的作用 承上启下 基础授课阶段:c语言、数据结构、linux嵌入式应用层课程:IO、进程线程、网络编程嵌入式底层课程:ARM体系结构、系统移植、linux设备驱动c/QT 2.ARM课程需要掌握的内容 自己能够实现简单的汇编编…...

ns3入门基础教程

ns3入门基础教程 文章目录 ns3入门基础教程ns环境配置测试ns3环境ns3简单案例 ns环境配置 官方网站:https://www.nsnam.org/releases/ 代码仓库:https://gitlab.com/nsnam/ns-3-dev 如果安装遇到问题,可以参考以下博文: https://…...

计算机视觉

目录 一、图像处理 main denoise 二、Harris角点检测 三、Hough变换直线检测 四、直方图显著性检测 五、人脸识别 六、kmeans import 函数 kmeanstext 七、神经网络 常用函数: imread----------读取图像 imshow---------显示图像 rgb2hsv---------RGB转…...

NSSCTF第10页(3)

[LitCTF 2023]彩蛋 第一题&#xff1a; LitCTF{First_t0_The_k3y! (1/?) 第三题&#xff1a; <?php // 第三个彩蛋&#xff01;(看过头号玩家么&#xff1f;) // R3ady_Pl4yer_000ne (3/?) ?> 第六题&#xff1a; wow 你找到了第二个彩蛋哦~ _S0_ne3t? (2/?) 第七题…...

MySQL性能分析工具的使用

1. 统计SQL的查询成本&#xff1a;last_query_cost SHOW STATUS LIKE last_query_cost; 使用场景&#xff1a;它对于比较开销是非常有用的&#xff0c;特别是我们有好几种查询方式可选的时候。 SQL 查询是一个动态的过程&#xff0c;从页加载的角度来看&#xff0c;我们可以得到…...

Uniapp使用AES128加解密16进制

在对接低功耗蓝牙时&#xff0c;我们需要对蓝牙传输数据进行加解密&#xff0c;由于我们对接的命令是16进制&#xff0c;如5500020101aa00&#xff0c;每个16进制表示特定的含义&#xff0c;所以直接对16进制加解密 import CryptoJS from crypto-js// AES128 加密函数 functio…...

C++基础——类与对象

1 概述 C是面向对象的语言&#xff0c;面向对象语言三大特性&#xff1a;封装、继承、多态。 C将万事万物抽象为对象&#xff0c;对象上有其属性和行为。 2 封装 2.1 封装的意义 封装是面向对象的三大特性之一&#xff0c;封装将属性和行为作为一个整体&#xff0c;对属性和…...

人工智能-卷积神经网络

从全连接层到卷积 我们之前讨论的多层感知机十分适合处理表格数据&#xff0c;其中行对应样本&#xff0c;列对应特征。 对于表格数据&#xff0c;我们寻找的模式可能涉及特征之间的交互&#xff0c;但是我们不能预先假设任何与特征交互相关的先验结构。 此时&#xff0c;多层感…...

MySQL的event的使用方法

MySQL的event的使用方法 一、事件定时策略 1、查看event事件开启状态 SHOW VARIABLES LIKE event_scheduler;如图&#xff0c;Value值 ON&#xff1a;打开&#xff0c;OFF&#xff1a;关闭。 2、设置event事件打开 SET GLOBAL event_scheduler ON;如果MySQL重启了&#x…...

Leetcode Daily Challenge 1845. Seat Reservation Manager

1845. Seat Reservation Manager 题目要求&#xff1a;初始化一个SeatManager类包括默认构造函数和类函数&#xff0c;所有的seat初始化为true。reverse函数返回最小的true&#xff0c;然后把这个编号的椅子赋值为false。unreverse(seatNumber)函数把编号为seatNumber的椅子恢…...

Blender vs 3ds Max:谁才是3D软件的未来

在不断发展的3D建模和动画领域&#xff0c;两大软件巨头Blender和3ds Max一直在争夺顶级地位。 随着技术的进步和用户需求的演变&#xff0c;一个重要问题逐渐浮出水面&#xff1a;Blender是否最终会取代3ds Max&#xff1f;本文将深入探讨二者各自的优势和劣势、当前状况&…...

MapReduce:大数据处理的范式

一、介绍 在当今的数字时代&#xff0c;生成和收集的数据量正以前所未有的速度增长。这种数据的爆炸式增长催生了大数据领域&#xff0c;传统的数据处理方法往往不足。MapReduce是一个编程模型和相关框架&#xff0c;已成为应对大数据处理挑战的强大解决方案。本文探讨了MapRed…...

【已解决】ModuleNotFoundError: No module named ‘dgl‘

禁止使用下面方法安装DGL,这种方法会更新你的pytorch版本&#xff0c;环境越变越乱 pip install dgl 二是进入DGL官网&#xff1a;Deep Graph Library (dgl.ai)&#xff0c;了解自己的配置情况&#xff0c;比如我cuda11.8&#xff0c;ubuntu&#xff0c;当然和linux是一样的 …...

R 复习 菜鸟教程

R语言老师说R好就业&#xff0c;学就完了 基础语法 cat()可以拼接函数&#xff1a; > cat(1, "加", 1, "等于", 2, \n) 1 加 1 等于 2sink()&#xff1a;重定向 sink("r_test.txt", splitTRUE) # 控制台同样输出 for (i in 1:5) print(i…...

第十二章《搞懂算法:朴素贝叶斯是怎么回事》笔记

朴素贝叶斯是经典的机器学习算法&#xff0c;也是统计模型中的一个基本方法。它的基本思想是利用统计学中的条件概率来进行分类。它是一种有监督学习算法&#xff0c;其中“朴素”是指该算法基于样本特征之间相互独立这个“朴素”假设。朴素贝叶斯原理简单、容易实现&#xff0…...

【从0到1开发一个网关】网关Mock功能的实现

文章目录 什么是Mock?如何实现Mock什么是Mock? Mock(模拟)是一种测试技术,用于创建虚拟对象来模拟真实对象的行为。Mock对象模拟了真实对象的行为,但是不依赖于真实对象的实现细节。它们可以在测试中替代真实对象,以便进行独立的单元测试。 需要使用Mock的原因包括以下几…...

前端框架Vue学习 ——(三)Vue生命周期

生命周期&#xff1a;指一个对象从创建到销毁的整个过程。 生命周期的八个阶段&#xff1a;每触发一个生命周期事件&#xff0c;会自动执行一个生命周期方法&#xff08;钩子&#xff09; mounted&#xff1a;挂载完成&#xff0c;Vue 初始化成功&#xff0c;HTML 页面渲染成功…...

相机滤镜软件Nevercenter CameraBag Photo mac中文版特点介绍

Nevercenter CameraBag Photo mac是一款相机和滤镜应用程序&#xff0c;它提供了一系列先进的滤镜、调整工具和预设&#xff0c;可以帮助用户快速地优化和编辑照片。 Nevercenter CameraBag Photo mac软件特点介绍 1. 滤镜&#xff1a;Nevercenter CameraBag Photo提供了超过2…...

游戏专用....

游戏专用&#xff1a;星际战甲 APP窗口以及键鼠监控 import tkinter as tk import time,threading from pynput.keyboard import Key,Listener import pynput.keyboard as kbclass myClass:def __init__(self):self.root tk.Tk()self.new_text self.flag threading.Event()…...

第三方登录和第三方支付

第三方登录 在现代Web应用中&#xff0c;提供第三方登录选项已经变得非常普遍。用户可以使用其社交媒体或其他在线帐户&#xff08;如Google、GitHub或Facebook&#xff09;来访问您的应用程序&#xff0c;而无需创建新的用户名和密码。这提供了更好的用户体验&#xff0c;减少…...

SpringMvc执行流程(含过滤器Filter+拦截器interceptor)

目录 1.Mvc的概念 2.SpringMvc的概念 3.SpringMvc的核心组件 4.SpringMvc的执行流程 5.SpringMvcFilterInterceptor执行流程 一、Mvc的概念 Mvc(Model View Controller)&#xff1a;Mvc是一种设计规范&#xff0c;它将数据、视图、业务逻辑代码进行分离&#xff0c;降低代码…...

【UDS基础】简单介绍“统一诊断服务“

1. 前言 我们将在这个实用教程中介绍UDS的基础知识,重点关注在CAN总线上的UDS(UDSonCAN)和CAN诊断(DoCAN)。此外,我们还会介绍ISO-TP协议,并解释UDS、OBD2、WWH-OBD和OBDonUDS之间的差异。 最后,我们将解释如何请求、记录和解码UDS消息,并提供一些实际示例,例如记录…...

深度学习框架TensorFlow.NET之数据类型及张量2(C#)

环境搭建参考&#xff1a; 深度学习框架TensorFlow.NET环境搭建1&#xff08;C#&#xff09;-CSDN博客 由于本文作者水平有限&#xff0c;如有写得不对的地方&#xff0c;往指出 声明变量&#xff1a;tf.Variable 声明常量&#xff1a;tf.constant 下面通过代码的方式进行学…...

Pandas指定多列组合形成新列

目录 1、数据准备2、多列组合 1、数据准备 df pd.DataFrame({first_name: [A, B], last_name: [a, b]}) print(df.to_string()) first_name last_name 0 A a 1 B b 2、多列组合 2.1、方式一&#xff1a;使用cat() df[full_name] df[firs…...

硕鼠——视频下载利器

相信很多做自媒体、剪辑的同志们&#xff0c;经常会遇到一个棘手的问题&#xff1a;剪辑的素材从何而来。诸如很多高燃混剪的视频&#xff0c;往往需要多个影视作品中的原画来进行二次创作&#xff0c;可是这些视频素材从何而来呢&#xff1f; 有小伙伴们提出&#xff0c;通过录…...

Android 13.0 Launcher3 app图标长按去掉应用信息按钮

1.前言 在13.0的rom定制化开发中,在Launcher3定制化开发中,对Launcher3的定制化功能中,在Launcher3的app列表页会在长按时,弹出微件和应用信息两个按钮,点击对应的按钮跳转到相关的功能页面, 现在由于产品需求要求禁用应用信息,不让进入到应用信息页面所以要去掉应用信息…...

10 DETR 论文精读【论文精读】End-to-End Object Detection with Transformers

目录 DETR 这篇论文&#xff0c;大家为什么喜欢它&#xff1f;为什么大家说它是一个目标检测里的里程碑式的工作&#xff1f;而且为什么说它是一个全新的架构&#xff1f; 1 题目 2摘要 2.1新的任务定义&#xff1a;把这个目标检测这个任务直接看成是一个集合预测的问题 2.…...

高数笔记05:不定积分与定积分

图源&#xff1a;文心一言 时间比较紧张&#xff0c;仅导图~~&#x1f95d;&#x1f95d; 第1版&#xff1a;查资料、画导图~&#x1f9e9;&#x1f9e9; 参考资料&#xff1a;《高等数学 基础篇》武忠祥 &#x1f433;目录 &#x1f433;目录 &#x1f433;不定积分 &#…...

【代码随想录】算法训练计划13

1、347. 前 K 个高频元素 题目&#xff1a; 给你一个整数数组 nums 和一个整数 k &#xff0c;请你返回其中出现频率前 k 高的元素。你可以按 任意顺序 返回答案。 输入: nums [1,1,1,2,2,3], k 2 输出: [1,2] 思路&#xff1a; sort.Slice学习一下&#xff0c;其实还有so…...