pytorch加载的cifar10数据集,到底有没有经过归一化
pytorch加载cifar10的归一化
- pytorch怎么加载cifar10数据集
- torchvision.datasets.CIFAR10
- transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】
- torchvision.datasets加载的数据集搭配Dataloader使用
- model.train()和model.eval()
pytorch怎么加载cifar10数据集
torchvision.datasets.CIFAR10
pytorch里面的torchvision.datasets中提供了大多数计算机视觉领域相关任务的数据集,可以根据实际需要加载相关数据集——需要cifar10就用torchvision.datasets.CIFAR10(),需要SVHN就调用torchvision.datasets.SVHN()。
针对cifar10数据集而言,调用torchvision.datasets.CIFAR10(),其中root是下载数据集后保存的位置;train是一个bool变量,为true就是训练数据集,false就是测试数据集;download也是一个bool变量,表示是否下载;transform是对数据集中的"image"进行一些操作,比如归一化、随机裁剪、各种数据增强操作等;target_transform是针对数据集中的"label"进行一些操作。
示例代码如下:
# 加载训练数据集
train_data = datasets.CIFAR10(root='../_datasets', train=True, download=True,transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化 ]) )
# 加载测试数据集
test_data = datasets.CIFAR10(root='../_datasets', train=False,download=True, transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化 ]) )
transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】
上面的代码中,我们用transforms.Compose([……])组合了一系列的对image的操作,其中trandforms.ToTensor()
和transforms.Normalize()
都涉及到归一化操作:
-
原始的cifar10数据集是numpy array的形式,其中数据范围是[0,255],pytorch加载时,并没有改变数据范围,依旧是[0,255],加载后的数据维度是(H, W, C),源码部分:
-
__getitem__()
函数中进行transforms操作,进行了归一化:实际上传入的transform在__getitem__()
函数中被调用,其中transforms.Totensor()
会将data(也就是image)的维度变成(C,H, W)的形式,并且归一化到[0.0,1.0];
transforms.Normalize()
会根据z = (x-mean) / std 对数据进行归一化,上述代码中mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
是可以将3个通道单独进行归一化,3个通道可以设置不同的mean和std,最终数据范围变成[-0.5,+0.5] 。
所以如果通过pytorch的cifar10加载数据集后,针对traindataset.data,依旧是没有进行归一化的;但是比如traindataset[index].data,其中[index]这样的按下标取元素的操作会直接调用的__getitem__()函数,此时的data就是经过了归一化的。
除traindataset[index]会隐式自动调用__getitem__()函数外,还有什么时候会调用这个函数呢?毕竟……只有调用了这个函数才会调用transforms中的归一化处理。——答案是与dataloader搭配使用!
torchvision.datasets加载的数据集搭配Dataloader使用
torchvision.datasets实际上是torch.utils.data.Dataset的子类,那么就能传入Dataloader中,迭代的按batch-size获取批量数据,用于训练或者测试。其中dataloader加载dataset中的数据时,就是用到了其__getitem__()函数,所以用dataloader加载数据集,得到的是经过归一化后的数据。
model.train()和model.eval()
我发现上面的问题,是我用dataloader加载了训练数据集用于训练resnet18模型,训练过程中,我训练好并保存后,顺便测试了一下在测试数据集上的准确度。但是在测试的过程中,我没有用dataloader加载测试数据集,而是直接用的dataset.data来进行的测试。并且!由于是并没有将model设置成model.eval()【其实我设置了,但是我对自己很无语,我写的model.eval,忘记加括号了,无语呜呜】……也就是即便我的测试数据集没有经过归一化,由于模型还是在model.train()模式下,因此模型的BN层会自己调整,使得模型性能不受影响,因此在测试数据集上的accuracy达到了0.86,我就没有多想。
后来我用模型的时候,设置了model.eval()后,依旧是直接用的dataset.data(也就是没有归一化),不管是在测试数据集上还是在训练数据集上,accuracy都只有0.10+,我表示非常的迷茫疑惑啊!然后才发现是归一化的问题。
- 在
model.train()
模式下进行预测时,PyTorch会默认启用一些训练相关的操作,例如Batch Normalization和Dropout,并且模型的参数是可变的,能够根据输入进行调整。这些操作在训练模式下可以帮助模型更好地适应训练数据,并产生较高的准确度。 - 在
model.eval()
模式下进行预测时,PyTorch会将模型切换到评估模式,这会导致一些训练相关的操作行为发生变化。具体而言,Batch Normalization层会使用训练集上的统计信息进行归一化,而不是使用当前批次的统计信息。因此,如果输入数据没有进行归一化,模型在评估模式下的准确度可能会显著下降。
以下是我没有用dataloader加载数据集,进行预测的代码:
def correctness(model,data,target, device):batchsize = 1000batch_num = int(len(data) / batchsize) # 对原始的数据进行操作 从H.W.C变成C.H.W data = torch.tensor(data).permute(0,3,1,2).type(torch.FloatTensor).to(device)# 手动归一化data = data/255data = (data - 0.5) / 0.5 # 求一个batch的correctnessdef _batch_correctness(i):images, labels = data[i*batchsize : (i+1)*batchsize], target[i*batchsize : (i+1)*batchsize]predict = model(images).detach().cpu() correctness = np.array(torch.argmax(predict, dim = 1).numpy() == np.array(labels) , dtype= np.float32)return correctnessresult = np.array([_batch_correctness(i) for i in range(batch_num)])return result.flatten().sum()/data.shape[0]
我后面用上面的代码测试了四种情况:
- model.eval() + 没有归一化:train_accuracy = 0.10,test_accuracy = 0.10;
- model.eval() + 手动归一化:train_accuracy = 0.95,test_accuracy = 0.84;
- model.train() + 没有归一化:train_accuracy = 0.95,test_accuracy = 0.83;
- model.train() + 手动归一化:train_accuracy = 0.94,test_accuracy = 0.84;
由此可见,在model.eval()模式下,数据归一化对最终的测试结果有很大影响。
相关文章:

pytorch加载的cifar10数据集,到底有没有经过归一化
pytorch加载cifar10的归一化 pytorch怎么加载cifar10数据集torchvision.datasets.CIFAR10transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】 torchvision.datasets加载的数据集搭配Dataloader使用model.train()和model.eval() pytorch怎么加载…...

Day1 ARM基础
【ARM课程认知】 1.ARM课程的作用 承上启下 基础授课阶段:c语言、数据结构、linux嵌入式应用层课程:IO、进程线程、网络编程嵌入式底层课程:ARM体系结构、系统移植、linux设备驱动c/QT 2.ARM课程需要掌握的内容 自己能够实现简单的汇编编…...

ns3入门基础教程
ns3入门基础教程 文章目录 ns3入门基础教程ns环境配置测试ns3环境ns3简单案例 ns环境配置 官方网站:https://www.nsnam.org/releases/ 代码仓库:https://gitlab.com/nsnam/ns-3-dev 如果安装遇到问题,可以参考以下博文: https://…...

计算机视觉
目录 一、图像处理 main denoise 二、Harris角点检测 三、Hough变换直线检测 四、直方图显著性检测 五、人脸识别 六、kmeans import 函数 kmeanstext 七、神经网络 常用函数: imread----------读取图像 imshow---------显示图像 rgb2hsv---------RGB转…...

NSSCTF第10页(3)
[LitCTF 2023]彩蛋 第一题: LitCTF{First_t0_The_k3y! (1/?) 第三题: <?php // 第三个彩蛋!(看过头号玩家么?) // R3ady_Pl4yer_000ne (3/?) ?> 第六题: wow 你找到了第二个彩蛋哦~ _S0_ne3t? (2/?) 第七题…...

MySQL性能分析工具的使用
1. 统计SQL的查询成本:last_query_cost SHOW STATUS LIKE last_query_cost; 使用场景:它对于比较开销是非常有用的,特别是我们有好几种查询方式可选的时候。 SQL 查询是一个动态的过程,从页加载的角度来看,我们可以得到…...

Uniapp使用AES128加解密16进制
在对接低功耗蓝牙时,我们需要对蓝牙传输数据进行加解密,由于我们对接的命令是16进制,如5500020101aa00,每个16进制表示特定的含义,所以直接对16进制加解密 import CryptoJS from crypto-js// AES128 加密函数 functio…...

C++基础——类与对象
1 概述 C是面向对象的语言,面向对象语言三大特性:封装、继承、多态。 C将万事万物抽象为对象,对象上有其属性和行为。 2 封装 2.1 封装的意义 封装是面向对象的三大特性之一,封装将属性和行为作为一个整体,对属性和…...

人工智能-卷积神经网络
从全连接层到卷积 我们之前讨论的多层感知机十分适合处理表格数据,其中行对应样本,列对应特征。 对于表格数据,我们寻找的模式可能涉及特征之间的交互,但是我们不能预先假设任何与特征交互相关的先验结构。 此时,多层感…...

MySQL的event的使用方法
MySQL的event的使用方法 一、事件定时策略 1、查看event事件开启状态 SHOW VARIABLES LIKE event_scheduler;如图,Value值 ON:打开,OFF:关闭。 2、设置event事件打开 SET GLOBAL event_scheduler ON;如果MySQL重启了&#x…...

Leetcode Daily Challenge 1845. Seat Reservation Manager
1845. Seat Reservation Manager 题目要求:初始化一个SeatManager类包括默认构造函数和类函数,所有的seat初始化为true。reverse函数返回最小的true,然后把这个编号的椅子赋值为false。unreverse(seatNumber)函数把编号为seatNumber的椅子恢…...

Blender vs 3ds Max:谁才是3D软件的未来
在不断发展的3D建模和动画领域,两大软件巨头Blender和3ds Max一直在争夺顶级地位。 随着技术的进步和用户需求的演变,一个重要问题逐渐浮出水面:Blender是否最终会取代3ds Max?本文将深入探讨二者各自的优势和劣势、当前状况&…...

MapReduce:大数据处理的范式
一、介绍 在当今的数字时代,生成和收集的数据量正以前所未有的速度增长。这种数据的爆炸式增长催生了大数据领域,传统的数据处理方法往往不足。MapReduce是一个编程模型和相关框架,已成为应对大数据处理挑战的强大解决方案。本文探讨了MapRed…...

【已解决】ModuleNotFoundError: No module named ‘dgl‘
禁止使用下面方法安装DGL,这种方法会更新你的pytorch版本,环境越变越乱 pip install dgl 二是进入DGL官网:Deep Graph Library (dgl.ai),了解自己的配置情况,比如我cuda11.8,ubuntu,当然和linux是一样的 …...

R 复习 菜鸟教程
R语言老师说R好就业,学就完了 基础语法 cat()可以拼接函数: > cat(1, "加", 1, "等于", 2, \n) 1 加 1 等于 2sink():重定向 sink("r_test.txt", splitTRUE) # 控制台同样输出 for (i in 1:5) print(i…...

第十二章《搞懂算法:朴素贝叶斯是怎么回事》笔记
朴素贝叶斯是经典的机器学习算法,也是统计模型中的一个基本方法。它的基本思想是利用统计学中的条件概率来进行分类。它是一种有监督学习算法,其中“朴素”是指该算法基于样本特征之间相互独立这个“朴素”假设。朴素贝叶斯原理简单、容易实现࿰…...

【从0到1开发一个网关】网关Mock功能的实现
文章目录 什么是Mock?如何实现Mock什么是Mock? Mock(模拟)是一种测试技术,用于创建虚拟对象来模拟真实对象的行为。Mock对象模拟了真实对象的行为,但是不依赖于真实对象的实现细节。它们可以在测试中替代真实对象,以便进行独立的单元测试。 需要使用Mock的原因包括以下几…...

前端框架Vue学习 ——(三)Vue生命周期
生命周期:指一个对象从创建到销毁的整个过程。 生命周期的八个阶段:每触发一个生命周期事件,会自动执行一个生命周期方法(钩子) mounted:挂载完成,Vue 初始化成功,HTML 页面渲染成功…...

相机滤镜软件Nevercenter CameraBag Photo mac中文版特点介绍
Nevercenter CameraBag Photo mac是一款相机和滤镜应用程序,它提供了一系列先进的滤镜、调整工具和预设,可以帮助用户快速地优化和编辑照片。 Nevercenter CameraBag Photo mac软件特点介绍 1. 滤镜:Nevercenter CameraBag Photo提供了超过2…...

游戏专用....
游戏专用:星际战甲 APP窗口以及键鼠监控 import tkinter as tk import time,threading from pynput.keyboard import Key,Listener import pynput.keyboard as kbclass myClass:def __init__(self):self.root tk.Tk()self.new_text self.flag threading.Event()…...

第三方登录和第三方支付
第三方登录 在现代Web应用中,提供第三方登录选项已经变得非常普遍。用户可以使用其社交媒体或其他在线帐户(如Google、GitHub或Facebook)来访问您的应用程序,而无需创建新的用户名和密码。这提供了更好的用户体验,减少…...

SpringMvc执行流程(含过滤器Filter+拦截器interceptor)
目录 1.Mvc的概念 2.SpringMvc的概念 3.SpringMvc的核心组件 4.SpringMvc的执行流程 5.SpringMvcFilterInterceptor执行流程 一、Mvc的概念 Mvc(Model View Controller):Mvc是一种设计规范,它将数据、视图、业务逻辑代码进行分离,降低代码…...

【UDS基础】简单介绍“统一诊断服务“
1. 前言 我们将在这个实用教程中介绍UDS的基础知识,重点关注在CAN总线上的UDS(UDSonCAN)和CAN诊断(DoCAN)。此外,我们还会介绍ISO-TP协议,并解释UDS、OBD2、WWH-OBD和OBDonUDS之间的差异。 最后,我们将解释如何请求、记录和解码UDS消息,并提供一些实际示例,例如记录…...

深度学习框架TensorFlow.NET之数据类型及张量2(C#)
环境搭建参考: 深度学习框架TensorFlow.NET环境搭建1(C#)-CSDN博客 由于本文作者水平有限,如有写得不对的地方,往指出 声明变量:tf.Variable 声明常量:tf.constant 下面通过代码的方式进行学…...

Pandas指定多列组合形成新列
目录 1、数据准备2、多列组合 1、数据准备 df pd.DataFrame({first_name: [A, B], last_name: [a, b]}) print(df.to_string()) first_name last_name 0 A a 1 B b 2、多列组合 2.1、方式一:使用cat() df[full_name] df[firs…...

硕鼠——视频下载利器
相信很多做自媒体、剪辑的同志们,经常会遇到一个棘手的问题:剪辑的素材从何而来。诸如很多高燃混剪的视频,往往需要多个影视作品中的原画来进行二次创作,可是这些视频素材从何而来呢? 有小伙伴们提出,通过录…...

Android 13.0 Launcher3 app图标长按去掉应用信息按钮
1.前言 在13.0的rom定制化开发中,在Launcher3定制化开发中,对Launcher3的定制化功能中,在Launcher3的app列表页会在长按时,弹出微件和应用信息两个按钮,点击对应的按钮跳转到相关的功能页面, 现在由于产品需求要求禁用应用信息,不让进入到应用信息页面所以要去掉应用信息…...

10 DETR 论文精读【论文精读】End-to-End Object Detection with Transformers
目录 DETR 这篇论文,大家为什么喜欢它?为什么大家说它是一个目标检测里的里程碑式的工作?而且为什么说它是一个全新的架构? 1 题目 2摘要 2.1新的任务定义:把这个目标检测这个任务直接看成是一个集合预测的问题 2.…...

高数笔记05:不定积分与定积分
图源:文心一言 时间比较紧张,仅导图~~🥝🥝 第1版:查资料、画导图~🧩🧩 参考资料:《高等数学 基础篇》武忠祥 🐳目录 🐳目录 🐳不定积分 &#…...

【代码随想录】算法训练计划13
1、347. 前 K 个高频元素 题目: 给你一个整数数组 nums 和一个整数 k ,请你返回其中出现频率前 k 高的元素。你可以按 任意顺序 返回答案。 输入: nums [1,1,1,2,2,3], k 2 输出: [1,2] 思路: sort.Slice学习一下,其实还有so…...