pytorch加载的cifar10数据集,到底有没有经过归一化
pytorch加载cifar10的归一化
- pytorch怎么加载cifar10数据集
- torchvision.datasets.CIFAR10
- transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】
- torchvision.datasets加载的数据集搭配Dataloader使用
- model.train()和model.eval()
pytorch怎么加载cifar10数据集
torchvision.datasets.CIFAR10
pytorch里面的torchvision.datasets中提供了大多数计算机视觉领域相关任务的数据集,可以根据实际需要加载相关数据集——需要cifar10就用torchvision.datasets.CIFAR10(),需要SVHN就调用torchvision.datasets.SVHN()。
针对cifar10数据集而言,调用torchvision.datasets.CIFAR10(),其中root是下载数据集后保存的位置;train是一个bool变量,为true就是训练数据集,false就是测试数据集;download也是一个bool变量,表示是否下载;transform是对数据集中的"image"进行一些操作,比如归一化、随机裁剪、各种数据增强操作等;target_transform是针对数据集中的"label"进行一些操作。

示例代码如下:
# 加载训练数据集
train_data = datasets.CIFAR10(root='../_datasets', train=True, download=True,transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化 ]) )
# 加载测试数据集
test_data = datasets.CIFAR10(root='../_datasets', train=False,download=True, transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化 ]) )
transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】
上面的代码中,我们用transforms.Compose([……])组合了一系列的对image的操作,其中trandforms.ToTensor()和transforms.Normalize()都涉及到归一化操作:
-
原始的cifar10数据集是numpy array的形式,其中数据范围是[0,255],pytorch加载时,并没有改变数据范围,依旧是[0,255],加载后的数据维度是(H, W, C),源码部分:

-
__getitem__()函数中进行transforms操作,进行了归一化:实际上传入的transform在__getitem__()函数中被调用,其中transforms.Totensor()会将data(也就是image)的维度变成(C,H, W)的形式,并且归一化到[0.0,1.0];

transforms.Normalize()会根据z = (x-mean) / std 对数据进行归一化,上述代码中mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]是可以将3个通道单独进行归一化,3个通道可以设置不同的mean和std,最终数据范围变成[-0.5,+0.5] 。

所以如果通过pytorch的cifar10加载数据集后,针对traindataset.data,依旧是没有进行归一化的;但是比如traindataset[index].data,其中[index]这样的按下标取元素的操作会直接调用的__getitem__()函数,此时的data就是经过了归一化的。
除traindataset[index]会隐式自动调用__getitem__()函数外,还有什么时候会调用这个函数呢?毕竟……只有调用了这个函数才会调用transforms中的归一化处理。——答案是与dataloader搭配使用!
torchvision.datasets加载的数据集搭配Dataloader使用
torchvision.datasets实际上是torch.utils.data.Dataset的子类,那么就能传入Dataloader中,迭代的按batch-size获取批量数据,用于训练或者测试。其中dataloader加载dataset中的数据时,就是用到了其__getitem__()函数,所以用dataloader加载数据集,得到的是经过归一化后的数据。

model.train()和model.eval()
我发现上面的问题,是我用dataloader加载了训练数据集用于训练resnet18模型,训练过程中,我训练好并保存后,顺便测试了一下在测试数据集上的准确度。但是在测试的过程中,我没有用dataloader加载测试数据集,而是直接用的dataset.data来进行的测试。并且!由于是并没有将model设置成model.eval()【其实我设置了,但是我对自己很无语,我写的model.eval,忘记加括号了,无语呜呜】……也就是即便我的测试数据集没有经过归一化,由于模型还是在model.train()模式下,因此模型的BN层会自己调整,使得模型性能不受影响,因此在测试数据集上的accuracy达到了0.86,我就没有多想。
后来我用模型的时候,设置了model.eval()后,依旧是直接用的dataset.data(也就是没有归一化),不管是在测试数据集上还是在训练数据集上,accuracy都只有0.10+,我表示非常的迷茫疑惑啊!然后才发现是归一化的问题。
- 在
model.train()模式下进行预测时,PyTorch会默认启用一些训练相关的操作,例如Batch Normalization和Dropout,并且模型的参数是可变的,能够根据输入进行调整。这些操作在训练模式下可以帮助模型更好地适应训练数据,并产生较高的准确度。 - 在
model.eval()模式下进行预测时,PyTorch会将模型切换到评估模式,这会导致一些训练相关的操作行为发生变化。具体而言,Batch Normalization层会使用训练集上的统计信息进行归一化,而不是使用当前批次的统计信息。因此,如果输入数据没有进行归一化,模型在评估模式下的准确度可能会显著下降。
以下是我没有用dataloader加载数据集,进行预测的代码:
def correctness(model,data,target, device):batchsize = 1000batch_num = int(len(data) / batchsize) # 对原始的数据进行操作 从H.W.C变成C.H.W data = torch.tensor(data).permute(0,3,1,2).type(torch.FloatTensor).to(device)# 手动归一化data = data/255data = (data - 0.5) / 0.5 # 求一个batch的correctnessdef _batch_correctness(i):images, labels = data[i*batchsize : (i+1)*batchsize], target[i*batchsize : (i+1)*batchsize]predict = model(images).detach().cpu() correctness = np.array(torch.argmax(predict, dim = 1).numpy() == np.array(labels) , dtype= np.float32)return correctnessresult = np.array([_batch_correctness(i) for i in range(batch_num)])return result.flatten().sum()/data.shape[0]
我后面用上面的代码测试了四种情况:
- model.eval() + 没有归一化:train_accuracy = 0.10,test_accuracy = 0.10;
- model.eval() + 手动归一化:train_accuracy = 0.95,test_accuracy = 0.84;
- model.train() + 没有归一化:train_accuracy = 0.95,test_accuracy = 0.83;
- model.train() + 手动归一化:train_accuracy = 0.94,test_accuracy = 0.84;
由此可见,在model.eval()模式下,数据归一化对最终的测试结果有很大影响。
相关文章:
pytorch加载的cifar10数据集,到底有没有经过归一化
pytorch加载cifar10的归一化 pytorch怎么加载cifar10数据集torchvision.datasets.CIFAR10transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】 torchvision.datasets加载的数据集搭配Dataloader使用model.train()和model.eval() pytorch怎么加载…...
Day1 ARM基础
【ARM课程认知】 1.ARM课程的作用 承上启下 基础授课阶段:c语言、数据结构、linux嵌入式应用层课程:IO、进程线程、网络编程嵌入式底层课程:ARM体系结构、系统移植、linux设备驱动c/QT 2.ARM课程需要掌握的内容 自己能够实现简单的汇编编…...
ns3入门基础教程
ns3入门基础教程 文章目录 ns3入门基础教程ns环境配置测试ns3环境ns3简单案例 ns环境配置 官方网站:https://www.nsnam.org/releases/ 代码仓库:https://gitlab.com/nsnam/ns-3-dev 如果安装遇到问题,可以参考以下博文: https://…...
计算机视觉
目录 一、图像处理 main denoise 二、Harris角点检测 三、Hough变换直线检测 四、直方图显著性检测 五、人脸识别 六、kmeans import 函数 kmeanstext 七、神经网络 常用函数: imread----------读取图像 imshow---------显示图像 rgb2hsv---------RGB转…...
NSSCTF第10页(3)
[LitCTF 2023]彩蛋 第一题: LitCTF{First_t0_The_k3y! (1/?) 第三题: <?php // 第三个彩蛋!(看过头号玩家么?) // R3ady_Pl4yer_000ne (3/?) ?> 第六题: wow 你找到了第二个彩蛋哦~ _S0_ne3t? (2/?) 第七题…...
MySQL性能分析工具的使用
1. 统计SQL的查询成本:last_query_cost SHOW STATUS LIKE last_query_cost; 使用场景:它对于比较开销是非常有用的,特别是我们有好几种查询方式可选的时候。 SQL 查询是一个动态的过程,从页加载的角度来看,我们可以得到…...
Uniapp使用AES128加解密16进制
在对接低功耗蓝牙时,我们需要对蓝牙传输数据进行加解密,由于我们对接的命令是16进制,如5500020101aa00,每个16进制表示特定的含义,所以直接对16进制加解密 import CryptoJS from crypto-js// AES128 加密函数 functio…...
C++基础——类与对象
1 概述 C是面向对象的语言,面向对象语言三大特性:封装、继承、多态。 C将万事万物抽象为对象,对象上有其属性和行为。 2 封装 2.1 封装的意义 封装是面向对象的三大特性之一,封装将属性和行为作为一个整体,对属性和…...
人工智能-卷积神经网络
从全连接层到卷积 我们之前讨论的多层感知机十分适合处理表格数据,其中行对应样本,列对应特征。 对于表格数据,我们寻找的模式可能涉及特征之间的交互,但是我们不能预先假设任何与特征交互相关的先验结构。 此时,多层感…...
MySQL的event的使用方法
MySQL的event的使用方法 一、事件定时策略 1、查看event事件开启状态 SHOW VARIABLES LIKE event_scheduler;如图,Value值 ON:打开,OFF:关闭。 2、设置event事件打开 SET GLOBAL event_scheduler ON;如果MySQL重启了&#x…...
Leetcode Daily Challenge 1845. Seat Reservation Manager
1845. Seat Reservation Manager 题目要求:初始化一个SeatManager类包括默认构造函数和类函数,所有的seat初始化为true。reverse函数返回最小的true,然后把这个编号的椅子赋值为false。unreverse(seatNumber)函数把编号为seatNumber的椅子恢…...
Blender vs 3ds Max:谁才是3D软件的未来
在不断发展的3D建模和动画领域,两大软件巨头Blender和3ds Max一直在争夺顶级地位。 随着技术的进步和用户需求的演变,一个重要问题逐渐浮出水面:Blender是否最终会取代3ds Max?本文将深入探讨二者各自的优势和劣势、当前状况&…...
MapReduce:大数据处理的范式
一、介绍 在当今的数字时代,生成和收集的数据量正以前所未有的速度增长。这种数据的爆炸式增长催生了大数据领域,传统的数据处理方法往往不足。MapReduce是一个编程模型和相关框架,已成为应对大数据处理挑战的强大解决方案。本文探讨了MapRed…...
【已解决】ModuleNotFoundError: No module named ‘dgl‘
禁止使用下面方法安装DGL,这种方法会更新你的pytorch版本,环境越变越乱 pip install dgl 二是进入DGL官网:Deep Graph Library (dgl.ai),了解自己的配置情况,比如我cuda11.8,ubuntu,当然和linux是一样的 …...
R 复习 菜鸟教程
R语言老师说R好就业,学就完了 基础语法 cat()可以拼接函数: > cat(1, "加", 1, "等于", 2, \n) 1 加 1 等于 2sink():重定向 sink("r_test.txt", splitTRUE) # 控制台同样输出 for (i in 1:5) print(i…...
第十二章《搞懂算法:朴素贝叶斯是怎么回事》笔记
朴素贝叶斯是经典的机器学习算法,也是统计模型中的一个基本方法。它的基本思想是利用统计学中的条件概率来进行分类。它是一种有监督学习算法,其中“朴素”是指该算法基于样本特征之间相互独立这个“朴素”假设。朴素贝叶斯原理简单、容易实现࿰…...
【从0到1开发一个网关】网关Mock功能的实现
文章目录 什么是Mock?如何实现Mock什么是Mock? Mock(模拟)是一种测试技术,用于创建虚拟对象来模拟真实对象的行为。Mock对象模拟了真实对象的行为,但是不依赖于真实对象的实现细节。它们可以在测试中替代真实对象,以便进行独立的单元测试。 需要使用Mock的原因包括以下几…...
前端框架Vue学习 ——(三)Vue生命周期
生命周期:指一个对象从创建到销毁的整个过程。 生命周期的八个阶段:每触发一个生命周期事件,会自动执行一个生命周期方法(钩子) mounted:挂载完成,Vue 初始化成功,HTML 页面渲染成功…...
相机滤镜软件Nevercenter CameraBag Photo mac中文版特点介绍
Nevercenter CameraBag Photo mac是一款相机和滤镜应用程序,它提供了一系列先进的滤镜、调整工具和预设,可以帮助用户快速地优化和编辑照片。 Nevercenter CameraBag Photo mac软件特点介绍 1. 滤镜:Nevercenter CameraBag Photo提供了超过2…...
游戏专用....
游戏专用:星际战甲 APP窗口以及键鼠监控 import tkinter as tk import time,threading from pynput.keyboard import Key,Listener import pynput.keyboard as kbclass myClass:def __init__(self):self.root tk.Tk()self.new_text self.flag threading.Event()…...
eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)
说明: 想象一下,你正在用eNSP搭建一个虚拟的网络世界,里面有虚拟的路由器、交换机、电脑(PC)等等。这些设备都在你的电脑里面“运行”,它们之间可以互相通信,就像一个封闭的小王国。 但是&#…...
基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销,平衡网络负载,延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...
C++:std::is_convertible
C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...
shell脚本--常见案例
1、自动备份文件或目录 2、批量重命名文件 3、查找并删除指定名称的文件: 4、批量删除文件 5、查找并替换文件内容 6、批量创建文件 7、创建文件夹并移动文件 8、在文件夹中查找文件...
大数据零基础学习day1之环境准备和大数据初步理解
学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 (1)设置网关 打开VMware虚拟机,点击编辑…...
Linux云原生安全:零信任架构与机密计算
Linux云原生安全:零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言:云原生安全的范式革命 随着云原生技术的普及,安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测,到2025年,零信任架构将成为超…...
ABAP设计模式之---“简单设计原则(Simple Design)”
“Simple Design”(简单设计)是软件开发中的一个重要理念,倡导以最简单的方式实现软件功能,以确保代码清晰易懂、易维护,并在项目需求变化时能够快速适应。 其核心目标是避免复杂和过度设计,遵循“让事情保…...
Kubernetes 节点自动伸缩(Cluster Autoscaler)原理与实践
在 Kubernetes 集群中,如何在保障应用高可用的同时有效地管理资源,一直是运维人员和开发者关注的重点。随着微服务架构的普及,集群内各个服务的负载波动日趋明显,传统的手动扩缩容方式已无法满足实时性和弹性需求。 Cluster Auto…...
从零开始了解数据采集(二十八)——制造业数字孪生
近年来,我国的工业领域正经历一场前所未有的数字化变革,从“双碳目标”到工业互联网平台的推广,国家政策和市场需求共同推动了制造业的升级。在这场变革中,数字孪生技术成为备受关注的关键工具,它不仅让企业“看见”设…...
[特殊字符] 手撸 Redis 互斥锁那些坑
📖 手撸 Redis 互斥锁那些坑 最近搞业务遇到高并发下同一个 key 的互斥操作,想实现分布式环境下的互斥锁。于是私下顺手手撸了个基于 Redis 的简单互斥锁,也顺便跟 Redisson 的 RLock 机制对比了下,记录一波,别踩我踩过…...
