pytorch加载的cifar10数据集,到底有没有经过归一化
pytorch加载cifar10的归一化
- pytorch怎么加载cifar10数据集
- torchvision.datasets.CIFAR10
- transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】
- torchvision.datasets加载的数据集搭配Dataloader使用
- model.train()和model.eval()
pytorch怎么加载cifar10数据集
torchvision.datasets.CIFAR10
pytorch里面的torchvision.datasets中提供了大多数计算机视觉领域相关任务的数据集,可以根据实际需要加载相关数据集——需要cifar10就用torchvision.datasets.CIFAR10(),需要SVHN就调用torchvision.datasets.SVHN()。
针对cifar10数据集而言,调用torchvision.datasets.CIFAR10(),其中root是下载数据集后保存的位置;train是一个bool变量,为true就是训练数据集,false就是测试数据集;download也是一个bool变量,表示是否下载;transform是对数据集中的"image"进行一些操作,比如归一化、随机裁剪、各种数据增强操作等;target_transform是针对数据集中的"label"进行一些操作。

示例代码如下:
# 加载训练数据集
train_data = datasets.CIFAR10(root='../_datasets', train=True, download=True,transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化 ]) )
# 加载测试数据集
test_data = datasets.CIFAR10(root='../_datasets', train=False,download=True, transform= transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化 ]) )
transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】
上面的代码中,我们用transforms.Compose([……])组合了一系列的对image的操作,其中trandforms.ToTensor()和transforms.Normalize()都涉及到归一化操作:
-
原始的cifar10数据集是numpy array的形式,其中数据范围是[0,255],pytorch加载时,并没有改变数据范围,依旧是[0,255],加载后的数据维度是(H, W, C),源码部分:

-
__getitem__()函数中进行transforms操作,进行了归一化:实际上传入的transform在__getitem__()函数中被调用,其中transforms.Totensor()会将data(也就是image)的维度变成(C,H, W)的形式,并且归一化到[0.0,1.0];

transforms.Normalize()会根据z = (x-mean) / std 对数据进行归一化,上述代码中mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]是可以将3个通道单独进行归一化,3个通道可以设置不同的mean和std,最终数据范围变成[-0.5,+0.5] 。

所以如果通过pytorch的cifar10加载数据集后,针对traindataset.data,依旧是没有进行归一化的;但是比如traindataset[index].data,其中[index]这样的按下标取元素的操作会直接调用的__getitem__()函数,此时的data就是经过了归一化的。
除traindataset[index]会隐式自动调用__getitem__()函数外,还有什么时候会调用这个函数呢?毕竟……只有调用了这个函数才会调用transforms中的归一化处理。——答案是与dataloader搭配使用!
torchvision.datasets加载的数据集搭配Dataloader使用
torchvision.datasets实际上是torch.utils.data.Dataset的子类,那么就能传入Dataloader中,迭代的按batch-size获取批量数据,用于训练或者测试。其中dataloader加载dataset中的数据时,就是用到了其__getitem__()函数,所以用dataloader加载数据集,得到的是经过归一化后的数据。

model.train()和model.eval()
我发现上面的问题,是我用dataloader加载了训练数据集用于训练resnet18模型,训练过程中,我训练好并保存后,顺便测试了一下在测试数据集上的准确度。但是在测试的过程中,我没有用dataloader加载测试数据集,而是直接用的dataset.data来进行的测试。并且!由于是并没有将model设置成model.eval()【其实我设置了,但是我对自己很无语,我写的model.eval,忘记加括号了,无语呜呜】……也就是即便我的测试数据集没有经过归一化,由于模型还是在model.train()模式下,因此模型的BN层会自己调整,使得模型性能不受影响,因此在测试数据集上的accuracy达到了0.86,我就没有多想。
后来我用模型的时候,设置了model.eval()后,依旧是直接用的dataset.data(也就是没有归一化),不管是在测试数据集上还是在训练数据集上,accuracy都只有0.10+,我表示非常的迷茫疑惑啊!然后才发现是归一化的问题。
- 在
model.train()模式下进行预测时,PyTorch会默认启用一些训练相关的操作,例如Batch Normalization和Dropout,并且模型的参数是可变的,能够根据输入进行调整。这些操作在训练模式下可以帮助模型更好地适应训练数据,并产生较高的准确度。 - 在
model.eval()模式下进行预测时,PyTorch会将模型切换到评估模式,这会导致一些训练相关的操作行为发生变化。具体而言,Batch Normalization层会使用训练集上的统计信息进行归一化,而不是使用当前批次的统计信息。因此,如果输入数据没有进行归一化,模型在评估模式下的准确度可能会显著下降。
以下是我没有用dataloader加载数据集,进行预测的代码:
def correctness(model,data,target, device):batchsize = 1000batch_num = int(len(data) / batchsize) # 对原始的数据进行操作 从H.W.C变成C.H.W data = torch.tensor(data).permute(0,3,1,2).type(torch.FloatTensor).to(device)# 手动归一化data = data/255data = (data - 0.5) / 0.5 # 求一个batch的correctnessdef _batch_correctness(i):images, labels = data[i*batchsize : (i+1)*batchsize], target[i*batchsize : (i+1)*batchsize]predict = model(images).detach().cpu() correctness = np.array(torch.argmax(predict, dim = 1).numpy() == np.array(labels) , dtype= np.float32)return correctnessresult = np.array([_batch_correctness(i) for i in range(batch_num)])return result.flatten().sum()/data.shape[0]
我后面用上面的代码测试了四种情况:
- model.eval() + 没有归一化:train_accuracy = 0.10,test_accuracy = 0.10;
- model.eval() + 手动归一化:train_accuracy = 0.95,test_accuracy = 0.84;
- model.train() + 没有归一化:train_accuracy = 0.95,test_accuracy = 0.83;
- model.train() + 手动归一化:train_accuracy = 0.94,test_accuracy = 0.84;
由此可见,在model.eval()模式下,数据归一化对最终的测试结果有很大影响。
相关文章:
pytorch加载的cifar10数据集,到底有没有经过归一化
pytorch加载cifar10的归一化 pytorch怎么加载cifar10数据集torchvision.datasets.CIFAR10transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】 torchvision.datasets加载的数据集搭配Dataloader使用model.train()和model.eval() pytorch怎么加载…...
Day1 ARM基础
【ARM课程认知】 1.ARM课程的作用 承上启下 基础授课阶段:c语言、数据结构、linux嵌入式应用层课程:IO、进程线程、网络编程嵌入式底层课程:ARM体系结构、系统移植、linux设备驱动c/QT 2.ARM课程需要掌握的内容 自己能够实现简单的汇编编…...
ns3入门基础教程
ns3入门基础教程 文章目录 ns3入门基础教程ns环境配置测试ns3环境ns3简单案例 ns环境配置 官方网站:https://www.nsnam.org/releases/ 代码仓库:https://gitlab.com/nsnam/ns-3-dev 如果安装遇到问题,可以参考以下博文: https://…...
计算机视觉
目录 一、图像处理 main denoise 二、Harris角点检测 三、Hough变换直线检测 四、直方图显著性检测 五、人脸识别 六、kmeans import 函数 kmeanstext 七、神经网络 常用函数: imread----------读取图像 imshow---------显示图像 rgb2hsv---------RGB转…...
NSSCTF第10页(3)
[LitCTF 2023]彩蛋 第一题: LitCTF{First_t0_The_k3y! (1/?) 第三题: <?php // 第三个彩蛋!(看过头号玩家么?) // R3ady_Pl4yer_000ne (3/?) ?> 第六题: wow 你找到了第二个彩蛋哦~ _S0_ne3t? (2/?) 第七题…...
MySQL性能分析工具的使用
1. 统计SQL的查询成本:last_query_cost SHOW STATUS LIKE last_query_cost; 使用场景:它对于比较开销是非常有用的,特别是我们有好几种查询方式可选的时候。 SQL 查询是一个动态的过程,从页加载的角度来看,我们可以得到…...
Uniapp使用AES128加解密16进制
在对接低功耗蓝牙时,我们需要对蓝牙传输数据进行加解密,由于我们对接的命令是16进制,如5500020101aa00,每个16进制表示特定的含义,所以直接对16进制加解密 import CryptoJS from crypto-js// AES128 加密函数 functio…...
C++基础——类与对象
1 概述 C是面向对象的语言,面向对象语言三大特性:封装、继承、多态。 C将万事万物抽象为对象,对象上有其属性和行为。 2 封装 2.1 封装的意义 封装是面向对象的三大特性之一,封装将属性和行为作为一个整体,对属性和…...
人工智能-卷积神经网络
从全连接层到卷积 我们之前讨论的多层感知机十分适合处理表格数据,其中行对应样本,列对应特征。 对于表格数据,我们寻找的模式可能涉及特征之间的交互,但是我们不能预先假设任何与特征交互相关的先验结构。 此时,多层感…...
MySQL的event的使用方法
MySQL的event的使用方法 一、事件定时策略 1、查看event事件开启状态 SHOW VARIABLES LIKE event_scheduler;如图,Value值 ON:打开,OFF:关闭。 2、设置event事件打开 SET GLOBAL event_scheduler ON;如果MySQL重启了&#x…...
Leetcode Daily Challenge 1845. Seat Reservation Manager
1845. Seat Reservation Manager 题目要求:初始化一个SeatManager类包括默认构造函数和类函数,所有的seat初始化为true。reverse函数返回最小的true,然后把这个编号的椅子赋值为false。unreverse(seatNumber)函数把编号为seatNumber的椅子恢…...
Blender vs 3ds Max:谁才是3D软件的未来
在不断发展的3D建模和动画领域,两大软件巨头Blender和3ds Max一直在争夺顶级地位。 随着技术的进步和用户需求的演变,一个重要问题逐渐浮出水面:Blender是否最终会取代3ds Max?本文将深入探讨二者各自的优势和劣势、当前状况&…...
MapReduce:大数据处理的范式
一、介绍 在当今的数字时代,生成和收集的数据量正以前所未有的速度增长。这种数据的爆炸式增长催生了大数据领域,传统的数据处理方法往往不足。MapReduce是一个编程模型和相关框架,已成为应对大数据处理挑战的强大解决方案。本文探讨了MapRed…...
【已解决】ModuleNotFoundError: No module named ‘dgl‘
禁止使用下面方法安装DGL,这种方法会更新你的pytorch版本,环境越变越乱 pip install dgl 二是进入DGL官网:Deep Graph Library (dgl.ai),了解自己的配置情况,比如我cuda11.8,ubuntu,当然和linux是一样的 …...
R 复习 菜鸟教程
R语言老师说R好就业,学就完了 基础语法 cat()可以拼接函数: > cat(1, "加", 1, "等于", 2, \n) 1 加 1 等于 2sink():重定向 sink("r_test.txt", splitTRUE) # 控制台同样输出 for (i in 1:5) print(i…...
第十二章《搞懂算法:朴素贝叶斯是怎么回事》笔记
朴素贝叶斯是经典的机器学习算法,也是统计模型中的一个基本方法。它的基本思想是利用统计学中的条件概率来进行分类。它是一种有监督学习算法,其中“朴素”是指该算法基于样本特征之间相互独立这个“朴素”假设。朴素贝叶斯原理简单、容易实现࿰…...
【从0到1开发一个网关】网关Mock功能的实现
文章目录 什么是Mock?如何实现Mock什么是Mock? Mock(模拟)是一种测试技术,用于创建虚拟对象来模拟真实对象的行为。Mock对象模拟了真实对象的行为,但是不依赖于真实对象的实现细节。它们可以在测试中替代真实对象,以便进行独立的单元测试。 需要使用Mock的原因包括以下几…...
前端框架Vue学习 ——(三)Vue生命周期
生命周期:指一个对象从创建到销毁的整个过程。 生命周期的八个阶段:每触发一个生命周期事件,会自动执行一个生命周期方法(钩子) mounted:挂载完成,Vue 初始化成功,HTML 页面渲染成功…...
相机滤镜软件Nevercenter CameraBag Photo mac中文版特点介绍
Nevercenter CameraBag Photo mac是一款相机和滤镜应用程序,它提供了一系列先进的滤镜、调整工具和预设,可以帮助用户快速地优化和编辑照片。 Nevercenter CameraBag Photo mac软件特点介绍 1. 滤镜:Nevercenter CameraBag Photo提供了超过2…...
游戏专用....
游戏专用:星际战甲 APP窗口以及键鼠监控 import tkinter as tk import time,threading from pynput.keyboard import Key,Listener import pynput.keyboard as kbclass myClass:def __init__(self):self.root tk.Tk()self.new_text self.flag threading.Event()…...
label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...
剑指offer20_链表中环的入口节点
链表中环的入口节点 给定一个链表,若其中包含环,则输出环的入口节点。 若其中不包含环,则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...
基础测试工具使用经验
背景 vtune,perf, nsight system等基础测试工具,都是用过的,但是没有记录,都逐渐忘了。所以写这篇博客总结记录一下,只要以后发现新的用法,就记得来编辑补充一下 perf 比较基础的用法: 先改这…...
Nuxt.js 中的路由配置详解
Nuxt.js 通过其内置的路由系统简化了应用的路由配置,使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...
2023赣州旅游投资集团
单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...
08. C#入门系列【类的基本概念】:开启编程世界的奇妙冒险
C#入门系列【类的基本概念】:开启编程世界的奇妙冒险 嘿,各位编程小白探险家!欢迎来到 C# 的奇幻大陆!今天咱们要深入探索这片大陆上至关重要的 “建筑”—— 类!别害怕,跟着我,保准让你轻松搞…...
机器学习的数学基础:线性模型
线性模型 线性模型的基本形式为: f ( x ) ω T x b f\left(\boldsymbol{x}\right)\boldsymbol{\omega}^\text{T}\boldsymbol{x}b f(x)ωTxb 回归问题 利用最小二乘法,得到 ω \boldsymbol{\omega} ω和 b b b的参数估计$ \boldsymbol{\hat{\omega}}…...
写一个shell脚本,把局域网内,把能ping通的IP和不能ping通的IP分类,并保存到两个文本文件里
写一个shell脚本,把局域网内,把能ping通的IP和不能ping通的IP分类,并保存到两个文本文件里 脚本1 #!/bin/bash #定义变量 ip10.1.1 #循环去ping主机的IP for ((i1;i<10;i)) doping -c1 $ip.$i &>/dev/null[ $? -eq 0 ] &&am…...
CentOS 7.9安装Nginx1.24.0时报 checking for LuaJIT 2.x ... not found
Nginx1.24编译时,报LuaJIT2.x错误, configuring additional modules adding module in /www/server/nginx/src/ngx_devel_kit ngx_devel_kit was configured adding module in /www/server/nginx/src/lua_nginx_module checking for LuaJIT 2.x ... not…...
RFID推动新能源汽车零部件生产系统管理应用案例
RFID推动新能源汽车零部件生产系统管理应用案例 一、项目背景 新能源汽车零部件场景 在新能源汽车零部件生产领域,电子冷却水泵等关键部件的装配溯源需求日益增长。传统 RFID 溯源方案采用 “网关 RFID 读写头” 模式,存在单点位单独头溯源、网关布线…...
