Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别
1 model.train() 和 model.eval()用法和区别
1.1 model.train()
model.train()的作用是启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。
1.2 model.eval()
model.eval()的作用是不启用Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。
训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。
在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。
1.3 分析原因
使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval。model.eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!
# 定义一个网络
class Net(nn.Module):def __init__(self, l1=120, l2=84):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, l1)self.fc2 = nn.Linear(l1, l2)self.fc3 = nn.Linear(l2, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化这个网络Model = Net()# 训练模式使用.train()Model.train(mode=True)# 测试模型使用.eval()Model.eval()
为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。
想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。
2.model.eval()和torch.no_grad()的区别
1.在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:
主要用于通知dropout层和BN层在train和validation/test模式间切换:
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2. 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。
而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。
如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。
相关文章:

Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别
1 model.train() 和 model.eval()用法和区别 1.1 model.train() model.train()的作用是启用 Batch Normalization 和 Dropout。 如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一…...

Linux CentOS 8(firewalld的配置与管理)
Linux CentOS 8(firewalld的配置与管理) 目录 一、firewalld 简介二、firewalld 工作概念1、预定义区域(管理员可以自定义修改)2、预定义服务 三、firewalld 配置方法1、通过firewall-cmd配置2、通过firewall图形界面配置 四、配置…...
C复习-指针
参考: 里科《C和指针》 指针存储的是一个地址,实际就是一个值。 如果像下面一样对未初始化的指针进行赋值,如果a的初始值是非法地址,那么会报错。UNIX会提示段错误segmentation violation,或内存错误memory fault&…...
Runnable和Thread的区别,以及如何调用start()方法
Runnable和Thread都是Java多线程编程中的核心概念,它们之间存在以下主要差异: Runnable是一个接口,而Thread是一个类。这意味着我们可以通过实现Runnable接口来创建线程,或者直接继承Thread类并重写其方法。Runnable只包含一个ru…...

云音乐Android Cronet接入实践
背景 网易云音乐产品线终端类型广泛,除了移动端(IOS/安卓)之外,还有PC、MAC、Iot多终端等等。移动端由于上线时间早,用户基数大,沉淀了一些端侧相对比较稳定的网络策略和网络基础能力。然而由于各端在基础…...
Linux dup和dup2
Linux dup和dup2函数,他们有什么区别,什么场景下会用到,使用它们有什么注意事项 dup和dup2都是Linux系统中的系统调用,用于复制文件描述符。它们的主要区别在于如何指定新的文件描述符以及处理新文件描述符的方式。 dup函数 #i…...

Spring Boot实战 | 如何整合高性能数据库连接池HikariCP
专栏集锦,大佬们可以收藏以备不时之需 Spring Cloud实战专栏:https://blog.csdn.net/superdangbo/category_9270827.html Python 实战专栏:https://blog.csdn.net/superdangbo/category_9271194.html Logback 详解专栏:https:/…...
Spring依赖注入
依赖注入底层原理流程图: https://www.processon.com/view/link/5f899fa5f346fb06e1d8f570 Spring中有两种依赖注入的方式 首先分两种: 手动注入自动注入 手动注入 在XML中定义Bean时,就是手动注入,因为是程序员手动给某个属…...

Linux下Jenkins自动化部署SpringBoot应用
Linux下Jenkins自动化部署SpringBoot应用 1、 Jenkins介绍 官方网址:https://www.jenkins.io/ 2、安装Jenkins 2.1 centos下命令行安装 访问官方,点击文档: 点击 Installing Jenkins: 点击 Linux: 选择 Red Hat/…...
【git 学习】--- ubuntu18.04 搭建本地git服务器
在Ubuntu18.04 上简单创建自己的git服务器~ 环境配置 Ubuntu: 18.04git服务器搭建步骤: ##1.安装git sudo apt-get install git##2.添加用户 sudo adduser test_git //test_git -- git用户名##3. 在Git用户的home目录下创建文件夹,作为裸仓库 sudo…...

JAVA电商平台免费搭建 B2B2C商城系统 多用户商城系统 直播带货 新零售商城 o2o商城 电子商务 拼团商城 分销商城
涉及平台 平台管理、商家端(PC端、手机端)、买家平台(H5/公众号、小程序、APP端(IOS/Android)、微服务平台(业务服务) 2. 核心架构 Spring Cloud、Spring Boot、Mybatis、Redis …...
Android 13 Framework 裁剪
裁剪应用 1. 修改 build/core/product.mk 添加PRODUCT_DEL_PACKAGES变量的声明 新增一行_product_single_value_vars PRODUCT_DEL_PACKAGES # The first API level this product shipped with _product_single_value_vars PRODUCT_SHIPPING_API_LEVEL _product_single_val…...
【Axios封装示例Vue2】
文章目录 为什么要封装axios?如何封装axios在Vue组件中使用封装的axios 为什么要封装axios? 在Vue 2项目中,直接在组件中使用axios可能会导致以下问题: 代码重复:每个组件都需要导入axios并编写相似的请求代码&#…...
k8s-----20、持久化存储--PV/PVC
PV/PVC 1、概念1.1 基本定义1.2 生命周期1.3 PV 卷阶段状态 2、 示例2.1 创建pod和PVC 与PV2.2 绑定PV2.3 强制删除pv,pvc2.4 测试 1、概念 1.1 基本定义 PersistentVolume(PV)是集群中由管理员配置的一段网络存储。 它是集群中的资源,就像…...
python matplotlib 生成矢量图
import matplotlib.pyplot as plt plt.savefig(r"xxx.svg", format"svg")注意: plt.savefig(r"xxx.svg", format"svg") 需要放在 plt.show()前面 原因:如果在 plt.show()调用后, 实际上已经创建了一…...

机器学习中常见的特征工程处理
一、特征工程 特征工程(Feature Engineering)对特征进行进一步分析,并对数据进行处理。 常见的特征工程包括:异常值处理、缺失值处理、数据分桶、特征处理、特征构造、特征筛选及降维等。 1、异常值处理 具体实现 from scipy.s…...

Spring IOC 和 AOP
核心概念 咱们这节就讲完了,在这节中我们讲了两个大概念,一个叫做IOC,一个叫做DI IOC是什么?是用对象的时候不要自己用new而是由外部提供,而spring在进行实现的时候是谁提供,就是IOC容器给你提供。 DI是什…...

echarts插件-liquidFill(水球图)
echarts插件-liquidFill(水球图) 1.下载2.引入:3.使用 1.下载 echarts.js下载:https://cdnjs.com/libraries/echarts echarts-liquidfill.js下载:https://github.com/ecomfe/echarts-liquidfill 2.引入: …...
c++ vscode cmake debug for mac
1. 下载vscode 2. 安装c插件 参考:C programming with Visual Studio Code 3. 安装llvm,可以使用brew安装 4. 配置llvm到系统环境变量中 5. 编写c代码 6. 编写CMakeLists.txt文件(前提安装cmake) cmake_minimum_required(V…...

17 结构型模式-享元模式
1 享元模式介绍 2 享元模式原理 3 享元模式实现 抽象享元类可以是一个接口也可以是一个抽象类,作为所有享元类的公共父类, 主要作用是提高系统的可扩展性. //* 抽象享元类 public abstract class Flyweight {public abstract void operation(String extrinsicState); }具体享…...

论文阅读笔记——Large Language Models Are Zero-Shot Fuzzers
TitanFuzz 论文 深度学习库(TensorFlow 和 Pytorch)中的 bug 对下游任务系统是重要的,保障安全性和有效性。在深度学习(DL)库的模糊测试领域,直接生成满足输入语言(例如 Python )语法/语义和张量计算的DL A…...

【版本控制】Git 和 GitHub 入门教程
目录 0 引言1 Git与GitHub的诞生1.1 Git:Linus的“两周奇迹”,拯救Linux内核1.2 GitHub:为Git插上协作的翅膀1.3 协同进化:从工具到生态的质变1.4 关键历程时间轴(2005–2008) 2 Git与GitHub入门指南2.1 Gi…...
MySQL间隙锁入手,拿下间隙锁面试与实操
一、MySQL 间隙锁,究竟是什么? 在 MySQL 的世界里,间隙锁(Gap Lock)就像是一个默默守护数据一致性的卫士,看似低调,却在并发控制中扮演着至关重要的角色。 想象一下,你去图书馆借…...

Jenkins的学习与使用(CI/CD)
文章目录 前言背景CI/CDJenkins简介Jenkins特性 安装Jenkins工作流程(仅供参考)安装maven和其他插件新建任务任务源码管理配置maven配置git(非必需) 尝试手动构建jar包可能遇到的错误 发布到远程服务器前置清理工作构建触发器git钩…...

前端技能包
ES6 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title> </head> <body><script>// 变量定义var a1;let b5; // 现在使用let 定义变量// 对象解构let person{&quo…...
stress-ng 服务器压力测试的工具学习
一、stress-ng (下一代压力测试) 介绍 项目地址:https://github.com/ColinIanKing/stress-ng stress-ng 将以多种可选方式对计算机系统进行压力测试。它旨在锻炼计算机的各种物理子系统以及各种操作系统内核接口。stress-ng 的特点: 360 压力测试100 …...

跟进一下目前最新的大数据技术
搭建最新平台 40C64G服务器,搭建3节点kvm,8C12G。 apache-hive-4.0.1-bin apache-tez-0.10.4-bin flink-1.20.1 hadoop-3.4.1 hbase-2.6.2 jdk-11.0.276 jdk8u452-b09 jdk8终于可以不用了 spark-3.5.5-bin-hadoop3 zookeeper-3.9.3 trino…...
笔记:算法题目中需要处理 int 某个位的三种方法:for、while、to_string
int n; cin >> n; 1. 使用for观察高位、低位、本位 for(int i 1; i < n; i * 10){ //i 1 当前位为个位, i 10 为十位,以此类推 high n / (i * 10); //这是相对于 i 的高位,例如 i 为个位…...
SQL进阶之旅 Day 14:数据透视与行列转换技巧
【SQL进阶之旅 Day 14】数据透视与行列转换技巧 开篇 欢迎来到“SQL进阶之旅”系列的第14天!今天我们将探讨数据透视与行列转换技巧,这是数据分析和报表生成中的核心技能。无论你是数据库开发工程师、数据分析师还是后端开发人员,行转列或列…...

破解HTTP无状态:基于Java的Session与Cookie协同工作指南
HTTP协议自身是属于“无状态”协议 无状态是指:默认情况下,HTTP协议的客户端和服务器之间的这次通信,和下次通信之间没有直接的关系 但在实际开发中,我们很多时候是需要知道请求之间的关联关系的 上述图中的令牌,通常就…...