Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别
1 model.train() 和 model.eval()用法和区别
1.1 model.train()
model.train()的作用是启用 Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。
1.2 model.eval()
model.eval()的作用是不启用Batch Normalization 和 Dropout。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。
训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。
在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。
1.3 分析原因
使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval。model.eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!
# 定义一个网络
class Net(nn.Module):def __init__(self, l1=120, l2=84):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, l1)self.fc2 = nn.Linear(l1, l2)self.fc3 = nn.Linear(l2, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化这个网络Model = Net()# 训练模式使用.train()Model.train(mode=True)# 测试模型使用.eval()Model.eval()
为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。

想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。
2.model.eval()和torch.no_grad()的区别
1.在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:
主要用于通知dropout层和BN层在train和validation/test模式间切换:
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2. 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。
而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。
如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。
相关文章:
Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别
1 model.train() 和 model.eval()用法和区别 1.1 model.train() model.train()的作用是启用 Batch Normalization 和 Dropout。 如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一…...
Linux CentOS 8(firewalld的配置与管理)
Linux CentOS 8(firewalld的配置与管理) 目录 一、firewalld 简介二、firewalld 工作概念1、预定义区域(管理员可以自定义修改)2、预定义服务 三、firewalld 配置方法1、通过firewall-cmd配置2、通过firewall图形界面配置 四、配置…...
C复习-指针
参考: 里科《C和指针》 指针存储的是一个地址,实际就是一个值。 如果像下面一样对未初始化的指针进行赋值,如果a的初始值是非法地址,那么会报错。UNIX会提示段错误segmentation violation,或内存错误memory fault&…...
Runnable和Thread的区别,以及如何调用start()方法
Runnable和Thread都是Java多线程编程中的核心概念,它们之间存在以下主要差异: Runnable是一个接口,而Thread是一个类。这意味着我们可以通过实现Runnable接口来创建线程,或者直接继承Thread类并重写其方法。Runnable只包含一个ru…...
云音乐Android Cronet接入实践
背景 网易云音乐产品线终端类型广泛,除了移动端(IOS/安卓)之外,还有PC、MAC、Iot多终端等等。移动端由于上线时间早,用户基数大,沉淀了一些端侧相对比较稳定的网络策略和网络基础能力。然而由于各端在基础…...
Linux dup和dup2
Linux dup和dup2函数,他们有什么区别,什么场景下会用到,使用它们有什么注意事项 dup和dup2都是Linux系统中的系统调用,用于复制文件描述符。它们的主要区别在于如何指定新的文件描述符以及处理新文件描述符的方式。 dup函数 #i…...
Spring Boot实战 | 如何整合高性能数据库连接池HikariCP
专栏集锦,大佬们可以收藏以备不时之需 Spring Cloud实战专栏:https://blog.csdn.net/superdangbo/category_9270827.html Python 实战专栏:https://blog.csdn.net/superdangbo/category_9271194.html Logback 详解专栏:https:/…...
Spring依赖注入
依赖注入底层原理流程图: https://www.processon.com/view/link/5f899fa5f346fb06e1d8f570 Spring中有两种依赖注入的方式 首先分两种: 手动注入自动注入 手动注入 在XML中定义Bean时,就是手动注入,因为是程序员手动给某个属…...
Linux下Jenkins自动化部署SpringBoot应用
Linux下Jenkins自动化部署SpringBoot应用 1、 Jenkins介绍 官方网址:https://www.jenkins.io/ 2、安装Jenkins 2.1 centos下命令行安装 访问官方,点击文档: 点击 Installing Jenkins: 点击 Linux: 选择 Red Hat/…...
【git 学习】--- ubuntu18.04 搭建本地git服务器
在Ubuntu18.04 上简单创建自己的git服务器~ 环境配置 Ubuntu: 18.04git服务器搭建步骤: ##1.安装git sudo apt-get install git##2.添加用户 sudo adduser test_git //test_git -- git用户名##3. 在Git用户的home目录下创建文件夹,作为裸仓库 sudo…...
JAVA电商平台免费搭建 B2B2C商城系统 多用户商城系统 直播带货 新零售商城 o2o商城 电子商务 拼团商城 分销商城
涉及平台 平台管理、商家端(PC端、手机端)、买家平台(H5/公众号、小程序、APP端(IOS/Android)、微服务平台(业务服务) 2. 核心架构 Spring Cloud、Spring Boot、Mybatis、Redis …...
Android 13 Framework 裁剪
裁剪应用 1. 修改 build/core/product.mk 添加PRODUCT_DEL_PACKAGES变量的声明 新增一行_product_single_value_vars PRODUCT_DEL_PACKAGES # The first API level this product shipped with _product_single_value_vars PRODUCT_SHIPPING_API_LEVEL _product_single_val…...
【Axios封装示例Vue2】
文章目录 为什么要封装axios?如何封装axios在Vue组件中使用封装的axios 为什么要封装axios? 在Vue 2项目中,直接在组件中使用axios可能会导致以下问题: 代码重复:每个组件都需要导入axios并编写相似的请求代码&#…...
k8s-----20、持久化存储--PV/PVC
PV/PVC 1、概念1.1 基本定义1.2 生命周期1.3 PV 卷阶段状态 2、 示例2.1 创建pod和PVC 与PV2.2 绑定PV2.3 强制删除pv,pvc2.4 测试 1、概念 1.1 基本定义 PersistentVolume(PV)是集群中由管理员配置的一段网络存储。 它是集群中的资源,就像…...
python matplotlib 生成矢量图
import matplotlib.pyplot as plt plt.savefig(r"xxx.svg", format"svg")注意: plt.savefig(r"xxx.svg", format"svg") 需要放在 plt.show()前面 原因:如果在 plt.show()调用后, 实际上已经创建了一…...
机器学习中常见的特征工程处理
一、特征工程 特征工程(Feature Engineering)对特征进行进一步分析,并对数据进行处理。 常见的特征工程包括:异常值处理、缺失值处理、数据分桶、特征处理、特征构造、特征筛选及降维等。 1、异常值处理 具体实现 from scipy.s…...
Spring IOC 和 AOP
核心概念 咱们这节就讲完了,在这节中我们讲了两个大概念,一个叫做IOC,一个叫做DI IOC是什么?是用对象的时候不要自己用new而是由外部提供,而spring在进行实现的时候是谁提供,就是IOC容器给你提供。 DI是什…...
echarts插件-liquidFill(水球图)
echarts插件-liquidFill(水球图) 1.下载2.引入:3.使用 1.下载 echarts.js下载:https://cdnjs.com/libraries/echarts echarts-liquidfill.js下载:https://github.com/ecomfe/echarts-liquidfill 2.引入: …...
c++ vscode cmake debug for mac
1. 下载vscode 2. 安装c插件 参考:C programming with Visual Studio Code 3. 安装llvm,可以使用brew安装 4. 配置llvm到系统环境变量中 5. 编写c代码 6. 编写CMakeLists.txt文件(前提安装cmake) cmake_minimum_required(V…...
17 结构型模式-享元模式
1 享元模式介绍 2 享元模式原理 3 享元模式实现 抽象享元类可以是一个接口也可以是一个抽象类,作为所有享元类的公共父类, 主要作用是提高系统的可扩展性. //* 抽象享元类 public abstract class Flyweight {public abstract void operation(String extrinsicState); }具体享…...
逆向实战:WASM加密在荔枝网x-itouchtv-ca参数中的定位与Hook技巧
1. WASM加密技术解析 WebAssembly(简称WASM)是一种新兴的二进制指令格式,它的出现让前端加密技术迈上了新台阶。与传统JavaScript加密相比,WASM具有明显的性能优势。在我的实际测试中,相同加密算法在WASM环境下的执行速…...
双倍效率:在快马平台中融合chatgpt实现智能代码生成与即时调试
最近在开发过程中,我发现了一个能显著提升效率的工作方式:将ChatGPT的智能生成能力与InsCode(快马)平台的即时调试环境结合起来。这种组合让我在代码编写、问题排查和逻辑优化上都节省了大量时间,今天就来分享一下具体的使用体验。 自然语言…...
如何让任何老旧手柄在PC游戏中完美工作:3步终极解决方案
如何让任何老旧手柄在PC游戏中完美工作:3步终极解决方案 【免费下载链接】ViGEmBus Windows kernel-mode driver emulating well-known USB game controllers. 项目地址: https://gitcode.com/gh_mirrors/vi/ViGEmBus 还在为心爱的游戏手柄无法在PC上使用而烦…...
开源工具本地化实践:FigmaCN插件让设计协作更高效
开源工具本地化实践:FigmaCN插件让设计协作更高效 【免费下载链接】figmaCN 中文 Figma 插件,设计师人工翻译校验 项目地址: https://gitcode.com/gh_mirrors/fi/figmaCN 在全球化协作与本地化需求日益增长的今天,开源工具本地化已成为…...
我用 QClaw 打造了一只“养生龙虾“——打工人保命健康守护助手
从一个简单的健康需求,到完整的健康提醒系统,我用 QClaw 这个智能助手完成了从"想法"到"落地"的全过程。缘起:打工人的健康焦虑 作为一个长期久坐、对着电脑敲代码的打工人,我越来越意识到健康的重要性。心血…...
Winhance中文版:图形化系统优化工具让Windows用户实现高效系统管理与个性化定制
Winhance中文版:图形化系统优化工具让Windows用户实现高效系统管理与个性化定制 【免费下载链接】Winhance-zh_CN A Chinese version of Winhance. C# application designed to optimize and customize your Windows experience. 项目地址: https://gitcode.com/g…...
实战应用:基于快马平台开发一个具备节点测速功能的网络工具面板
最近在折腾服务器节点管理时,发现手动测试各个节点的延迟特别麻烦。正好看到InsCode(快马)平台这个在线开发环境,就尝试用它快速搭建了一个带测速功能的网络工具面板。整个过程比想象中简单很多,分享下具体实现思路。 项目构思 这个工具的核…...
针对C++开源项目的AI工具讲解。我将它们分为两大类,便于理解
以下是针对C开源项目的AI工具讲解。我将它们分为两大类,便于理解: C开发者使用AI工具来提升开源项目开发效率(代码补全、调试、重构、文档生成等)。用C开发的开源AI工具/框架(这些工具本身是C开源项目,常用…...
Qt VS Tools配置全攻略:从安装到解决‘No Qt version assigned‘错误
Qt开发环境配置实战:从工具链搭建到疑难解析 Visual Studio作为主流的集成开发环境,与Qt框架的结合为C开发者提供了强大的生产力工具组合。但在实际项目配置过程中,"No Qt version assigned"这类基础错误却频繁困扰着开发者。本文…...
快速原型:用快马一键生成win11右键菜单传统样式恢复工具
快速原型:用快马一键生成win11右键菜单传统样式恢复工具 最近升级到Windows 11后,最让我不习惯的就是那个右键菜单了。新版的设计把所有选项都折叠起来,每次想找个功能还得点"显示更多选项",效率大打折扣。作为一个习惯…...
