Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别
1 model.train() 和 model.eval()用法和区别
1.1 model.train()
model.train()的作用是启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。
1.2 model.eval()
model.eval()的作用是不启用Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。
训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。
在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。
1.3 分析原因
使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval。model.eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!
# 定义一个网络
class Net(nn.Module):def __init__(self, l1=120, l2=84):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, l1)self.fc2 = nn.Linear(l1, l2)self.fc3 = nn.Linear(l2, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化这个网络Model = Net()# 训练模式使用.train()Model.train(mode=True)# 测试模型使用.eval()Model.eval()
为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。

想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。
2.model.eval()和torch.no_grad()的区别
1.在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:
主要用于通知dropout层和BN层在train和validation/test模式间切换:
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2. 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。
而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。
如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。
相关文章:
Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别
1 model.train() 和 model.eval()用法和区别 1.1 model.train() model.train()的作用是启用 Batch Normalization 和 Dropout。 如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一…...
Linux CentOS 8(firewalld的配置与管理)
Linux CentOS 8(firewalld的配置与管理) 目录 一、firewalld 简介二、firewalld 工作概念1、预定义区域(管理员可以自定义修改)2、预定义服务 三、firewalld 配置方法1、通过firewall-cmd配置2、通过firewall图形界面配置 四、配置…...
C复习-指针
参考: 里科《C和指针》 指针存储的是一个地址,实际就是一个值。 如果像下面一样对未初始化的指针进行赋值,如果a的初始值是非法地址,那么会报错。UNIX会提示段错误segmentation violation,或内存错误memory fault&…...
Runnable和Thread的区别,以及如何调用start()方法
Runnable和Thread都是Java多线程编程中的核心概念,它们之间存在以下主要差异: Runnable是一个接口,而Thread是一个类。这意味着我们可以通过实现Runnable接口来创建线程,或者直接继承Thread类并重写其方法。Runnable只包含一个ru…...
云音乐Android Cronet接入实践
背景 网易云音乐产品线终端类型广泛,除了移动端(IOS/安卓)之外,还有PC、MAC、Iot多终端等等。移动端由于上线时间早,用户基数大,沉淀了一些端侧相对比较稳定的网络策略和网络基础能力。然而由于各端在基础…...
Linux dup和dup2
Linux dup和dup2函数,他们有什么区别,什么场景下会用到,使用它们有什么注意事项 dup和dup2都是Linux系统中的系统调用,用于复制文件描述符。它们的主要区别在于如何指定新的文件描述符以及处理新文件描述符的方式。 dup函数 #i…...
Spring Boot实战 | 如何整合高性能数据库连接池HikariCP
专栏集锦,大佬们可以收藏以备不时之需 Spring Cloud实战专栏:https://blog.csdn.net/superdangbo/category_9270827.html Python 实战专栏:https://blog.csdn.net/superdangbo/category_9271194.html Logback 详解专栏:https:/…...
Spring依赖注入
依赖注入底层原理流程图: https://www.processon.com/view/link/5f899fa5f346fb06e1d8f570 Spring中有两种依赖注入的方式 首先分两种: 手动注入自动注入 手动注入 在XML中定义Bean时,就是手动注入,因为是程序员手动给某个属…...
Linux下Jenkins自动化部署SpringBoot应用
Linux下Jenkins自动化部署SpringBoot应用 1、 Jenkins介绍 官方网址:https://www.jenkins.io/ 2、安装Jenkins 2.1 centos下命令行安装 访问官方,点击文档: 点击 Installing Jenkins: 点击 Linux: 选择 Red Hat/…...
【git 学习】--- ubuntu18.04 搭建本地git服务器
在Ubuntu18.04 上简单创建自己的git服务器~ 环境配置 Ubuntu: 18.04git服务器搭建步骤: ##1.安装git sudo apt-get install git##2.添加用户 sudo adduser test_git //test_git -- git用户名##3. 在Git用户的home目录下创建文件夹,作为裸仓库 sudo…...
JAVA电商平台免费搭建 B2B2C商城系统 多用户商城系统 直播带货 新零售商城 o2o商城 电子商务 拼团商城 分销商城
涉及平台 平台管理、商家端(PC端、手机端)、买家平台(H5/公众号、小程序、APP端(IOS/Android)、微服务平台(业务服务) 2. 核心架构 Spring Cloud、Spring Boot、Mybatis、Redis …...
Android 13 Framework 裁剪
裁剪应用 1. 修改 build/core/product.mk 添加PRODUCT_DEL_PACKAGES变量的声明 新增一行_product_single_value_vars PRODUCT_DEL_PACKAGES # The first API level this product shipped with _product_single_value_vars PRODUCT_SHIPPING_API_LEVEL _product_single_val…...
【Axios封装示例Vue2】
文章目录 为什么要封装axios?如何封装axios在Vue组件中使用封装的axios 为什么要封装axios? 在Vue 2项目中,直接在组件中使用axios可能会导致以下问题: 代码重复:每个组件都需要导入axios并编写相似的请求代码&#…...
k8s-----20、持久化存储--PV/PVC
PV/PVC 1、概念1.1 基本定义1.2 生命周期1.3 PV 卷阶段状态 2、 示例2.1 创建pod和PVC 与PV2.2 绑定PV2.3 强制删除pv,pvc2.4 测试 1、概念 1.1 基本定义 PersistentVolume(PV)是集群中由管理员配置的一段网络存储。 它是集群中的资源,就像…...
python matplotlib 生成矢量图
import matplotlib.pyplot as plt plt.savefig(r"xxx.svg", format"svg")注意: plt.savefig(r"xxx.svg", format"svg") 需要放在 plt.show()前面 原因:如果在 plt.show()调用后, 实际上已经创建了一…...
机器学习中常见的特征工程处理
一、特征工程 特征工程(Feature Engineering)对特征进行进一步分析,并对数据进行处理。 常见的特征工程包括:异常值处理、缺失值处理、数据分桶、特征处理、特征构造、特征筛选及降维等。 1、异常值处理 具体实现 from scipy.s…...
Spring IOC 和 AOP
核心概念 咱们这节就讲完了,在这节中我们讲了两个大概念,一个叫做IOC,一个叫做DI IOC是什么?是用对象的时候不要自己用new而是由外部提供,而spring在进行实现的时候是谁提供,就是IOC容器给你提供。 DI是什…...
echarts插件-liquidFill(水球图)
echarts插件-liquidFill(水球图) 1.下载2.引入:3.使用 1.下载 echarts.js下载:https://cdnjs.com/libraries/echarts echarts-liquidfill.js下载:https://github.com/ecomfe/echarts-liquidfill 2.引入: …...
c++ vscode cmake debug for mac
1. 下载vscode 2. 安装c插件 参考:C programming with Visual Studio Code 3. 安装llvm,可以使用brew安装 4. 配置llvm到系统环境变量中 5. 编写c代码 6. 编写CMakeLists.txt文件(前提安装cmake) cmake_minimum_required(V…...
17 结构型模式-享元模式
1 享元模式介绍 2 享元模式原理 3 享元模式实现 抽象享元类可以是一个接口也可以是一个抽象类,作为所有享元类的公共父类, 主要作用是提高系统的可扩展性. //* 抽象享元类 public abstract class Flyweight {public abstract void operation(String extrinsicState); }具体享…...
AI智能体开发实战:agent-skills工具库核心技能解析与应用
1. 项目概述与核心价值最近在折腾AI智能体开发,发现一个挺有意思的现象:很多开发者,包括我自己在内,一开始都热衷于去研究那些大型的、功能全面的智能体框架,试图打造一个“全能”的AI助手。但实际落地时,往…...
基于SpringBoot的汽车美容养护管理系统的设计与开发
一、选题依据和意义 (一)选题依据 随着国内汽车保有量持续攀升,汽车后市场规模不断扩大,汽车美容养护行业迎来快速发展期,但行业整体仍存在管理效率低下、服务流程不规范等问题[1]。传统管理模式依赖人工记录客户信息…...
基于RT-Thread与N32G457的工业UART网关设计与实现
1. 项目概述与核心价值最近在做一个工业数据采集的项目,现场有十几台不同品牌、不同协议的串口设备,PLC、仪表、传感器什么都有,它们的数据都需要汇总到一台中心服务器上。最头疼的是,这些设备分布在车间各处,拉线成本…...
NotebookLM生物技术研究落地难?92%实验室尚未启用的3个隐藏功能(内部白皮书首次公开)
更多请点击: https://intelliparadigm.com 第一章:NotebookLM生物技术研究落地难?92%实验室尚未启用的3个隐藏功能(内部白皮书首次公开) NotebookLM 作为 Google 推出的实验性 AI 助手,其在生物技术领域的…...
NoFences:免费开源桌面分区工具,Windows用户必备的效率神器
NoFences:免费开源桌面分区工具,Windows用户必备的效率神器 【免费下载链接】NoFences 🚧 Open Source Stardock Fences alternative 项目地址: https://gitcode.com/gh_mirrors/no/NoFences NoFences是一款基于C#开发的开源桌面分区工…...
在Windows上轻松安装APK文件:APK Installer完全指南
在Windows上轻松安装APK文件:APK Installer完全指南 【免费下载链接】APK-Installer An Android Application Installer for Windows 项目地址: https://gitcode.com/GitHub_Trending/ap/APK-Installer 你是否曾想过在Windows电脑上直接运行Android应用&…...
MIMO AONN架构:量子干涉实现超低功耗光学神经网络
1. MIMO AONN架构的核心价值光学神经网络(AONN)正在突破传统电子计算的物理极限。在传统电子神经网络中,非线性激活函数需要消耗大量能量进行电子-光子转换,而基于量子干涉的光学非线性机制可以直接在光域实现这一关键操作。我们实…...
汇顶科技入围GSA奖项:中国芯片设计公司的全球化突破与启示
1. 项目概述:一次里程碑式的行业认可最近在半导体圈子里,一个消息引起了不小的关注:汇顶科技成功入围了全球半导体联盟(GSA)2019年度的两大奖项提名。对于不熟悉这个领域的朋友来说,这或许只是一个普通的公…...
开源对话机器人平台Dialoqbase:基于RAG与微服务架构的快速部署指南
1. 项目概述:一个开源的对话机器人构建平台最近在折腾AI应用,想自己搭个智能客服或者知识库问答机器人,发现市面上的SaaS服务要么太贵,要么定制性太差。后来在GitHub上翻到了一个叫dialoqbase的开源项目,眼前一亮。这玩…...
Gemini3.1Pro评估ViT平移不变性:4周MVP路线图
利用 Gemini 3.1 Pro 评估视觉 Transformer 的平移不变性:从机制刻画、对照验证到门控降级与4周MVP路线图“平移不变性(Translation Invariance)”是视觉 Transformer(ViT 等)稳健性的核心指标之一:当图像在…...
