当前位置: 首页 > news >正文

Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

1 model.train() 和 model.eval()用法和区别

1.1 model.train()

model.train()的作用是启用 Batch Normalization 和 Dropout

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

1.2 model.eval()

model.eval()的作用是不启用Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

1.3 分析原因

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval。model.eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!

# 定义一个网络
class Net(nn.Module):def __init__(self, l1=120, l2=84):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, l1)self.fc2 = nn.Linear(l1, l2)self.fc3 = nn.Linear(l2, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化这个网络Model = Net()# 训练模式使用.train()Model.train(mode=True)# 测试模型使用.eval()Model.eval()

为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。
在这里插入图片描述
想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。

2.model.eval()和torch.no_grad()的区别

1.在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

主要用于通知dropout层和BN层在train和validation/test模式间切换:
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2. 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

相关文章:

Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

1 model.train() 和 model.eval()用法和区别 1.1 model.train() model.train()的作用是启用 Batch Normalization 和 Dropout。 如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一…...

Linux CentOS 8(firewalld的配置与管理)

Linux CentOS 8(firewalld的配置与管理) 目录 一、firewalld 简介二、firewalld 工作概念1、预定义区域(管理员可以自定义修改)2、预定义服务 三、firewalld 配置方法1、通过firewall-cmd配置2、通过firewall图形界面配置 四、配置…...

C复习-指针

参考: 里科《C和指针》 指针存储的是一个地址,实际就是一个值。 如果像下面一样对未初始化的指针进行赋值,如果a的初始值是非法地址,那么会报错。UNIX会提示段错误segmentation violation,或内存错误memory fault&…...

Runnable和Thread的区别,以及如何调用start()方法

Runnable和Thread都是Java多线程编程中的核心概念,它们之间存在以下主要差异: Runnable是一个接口,而Thread是一个类。这意味着我们可以通过实现Runnable接口来创建线程,或者直接继承Thread类并重写其方法。Runnable只包含一个ru…...

云音乐Android Cronet接入实践

背景 网易云音乐产品线终端类型广泛,除了移动端(IOS/安卓)之外,还有PC、MAC、Iot多终端等等。移动端由于上线时间早,用户基数大,沉淀了一些端侧相对比较稳定的网络策略和网络基础能力。然而由于各端在基础…...

Linux dup和dup2

Linux dup和dup2函数,他们有什么区别,什么场景下会用到,使用它们有什么注意事项 dup和dup2都是Linux系统中的系统调用,用于复制文件描述符。它们的主要区别在于如何指定新的文件描述符以及处理新文件描述符的方式。 dup函数 #i…...

Spring Boot实战 | 如何整合高性能数据库连接池HikariCP

专栏集锦,大佬们可以收藏以备不时之需 Spring Cloud实战专栏:https://blog.csdn.net/superdangbo/category_9270827.html Python 实战专栏:https://blog.csdn.net/superdangbo/category_9271194.html Logback 详解专栏:https:/…...

Spring依赖注入

依赖注入底层原理流程图: https://www.processon.com/view/link/5f899fa5f346fb06e1d8f570 Spring中有两种依赖注入的方式 首先分两种: 手动注入自动注入 手动注入 在XML中定义Bean时,就是手动注入,因为是程序员手动给某个属…...

Linux下Jenkins自动化部署SpringBoot应用

Linux下Jenkins自动化部署SpringBoot应用 1、 Jenkins介绍 官方网址:https://www.jenkins.io/ 2、安装Jenkins 2.1 centos下命令行安装 访问官方,点击文档: 点击 Installing Jenkins: 点击 Linux: 选择 Red Hat/…...

【git 学习】--- ubuntu18.04 搭建本地git服务器

在Ubuntu18.04 上简单创建自己的git服务器~ 环境配置 Ubuntu: 18.04git服务器搭建步骤: ##1.安装git sudo apt-get install git##2.添加用户 sudo adduser test_git //test_git -- git用户名##3. 在Git用户的home目录下创建文件夹,作为裸仓库 sudo…...

JAVA电商平台免费搭建 B2B2C商城系统 多用户商城系统 直播带货 新零售商城 o2o商城 电子商务 拼团商城 分销商城

涉及平台 平台管理、商家端(PC端、手机端)、买家平台(H5/公众号、小程序、APP端(IOS/Android)、微服务平台(业务服务) 2. 核心架构 Spring Cloud、Spring Boot、Mybatis、Redis …...

Android 13 Framework 裁剪

裁剪应用 1. 修改 build/core/product.mk 添加PRODUCT_DEL_PACKAGES变量的声明 新增一行_product_single_value_vars PRODUCT_DEL_PACKAGES # The first API level this product shipped with _product_single_value_vars PRODUCT_SHIPPING_API_LEVEL _product_single_val…...

【Axios封装示例Vue2】

文章目录 为什么要封装axios?如何封装axios在Vue组件中使用封装的axios 为什么要封装axios? 在Vue 2项目中,直接在组件中使用axios可能会导致以下问题: 代码重复:每个组件都需要导入axios并编写相似的请求代码&#…...

k8s-----20、持久化存储--PV/PVC

PV/PVC 1、概念1.1 基本定义1.2 生命周期1.3 PV 卷阶段状态 2、 示例2.1 创建pod和PVC 与PV2.2 绑定PV2.3 强制删除pv,pvc2.4 测试 1、概念 1.1 基本定义 PersistentVolume(PV)是集群中由管理员配置的一段网络存储。 它是集群中的资源,就像…...

python matplotlib 生成矢量图

import matplotlib.pyplot as plt plt.savefig(r"xxx.svg", format"svg")注意: plt.savefig(r"xxx.svg", format"svg") 需要放在 plt.show()前面 原因:如果在 plt.show()调用后, 实际上已经创建了一…...

机器学习中常见的特征工程处理

一、特征工程 特征工程(Feature Engineering)对特征进行进一步分析,并对数据进行处理。 常见的特征工程包括:异常值处理、缺失值处理、数据分桶、特征处理、特征构造、特征筛选及降维等。 1、异常值处理 具体实现 from scipy.s…...

Spring IOC 和 AOP

核心概念 咱们这节就讲完了,在这节中我们讲了两个大概念,一个叫做IOC,一个叫做DI IOC是什么?是用对象的时候不要自己用new而是由外部提供,而spring在进行实现的时候是谁提供,就是IOC容器给你提供。 DI是什…...

echarts插件-liquidFill(水球图)

echarts插件-liquidFill(水球图) 1.下载2.引入:3.使用 1.下载 echarts.js下载:https://cdnjs.com/libraries/echarts echarts-liquidfill.js下载:https://github.com/ecomfe/echarts-liquidfill 2.引入: …...

c++ vscode cmake debug for mac

1. 下载vscode 2. 安装c插件 参考:C programming with Visual Studio Code 3. 安装llvm,可以使用brew安装 4. 配置llvm到系统环境变量中 5. 编写c代码 6. 编写CMakeLists.txt文件(前提安装cmake) cmake_minimum_required(V…...

17 结构型模式-享元模式

1 享元模式介绍 2 享元模式原理 3 享元模式实现 抽象享元类可以是一个接口也可以是一个抽象类,作为所有享元类的公共父类, 主要作用是提高系统的可扩展性. //* 抽象享元类 public abstract class Flyweight {public abstract void operation(String extrinsicState); }具体享…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

vscode里如何用git

打开vs终端执行如下: 1 初始化 Git 仓库(如果尚未初始化) git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...

微信小程序之bind和catch

这两个呢,都是绑定事件用的,具体使用有些小区别。 官方文档: 事件冒泡处理不同 bind:绑定的事件会向上冒泡,即触发当前组件的事件后,还会继续触发父组件的相同事件。例如,有一个子视图绑定了b…...

安宝特方案丨XRSOP人员作业标准化管理平台:AR智慧点检验收套件

在选煤厂、化工厂、钢铁厂等过程生产型企业,其生产设备的运行效率和非计划停机对工业制造效益有较大影响。 随着企业自动化和智能化建设的推进,需提前预防假检、错检、漏检,推动智慧生产运维系统数据的流动和现场赋能应用。同时,…...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted()是OpenCV库中用于图像处理的函数,主要功能是将两个输入图像(尺寸和类型相同)按照指定的权重进行加权叠加(图像融合),并添加一个标量值&#x…...

电脑插入多块移动硬盘后经常出现卡顿和蓝屏

当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

有限自动机到正规文法转换器v1.0

1 项目简介 这是一个功能强大的有限自动机(Finite Automaton, FA)到正规文法(Regular Grammar)转换器,它配备了一个直观且完整的图形用户界面,使用户能够轻松地进行操作和观察。该程序基于编译原理中的经典…...

docker 部署发现spring.profiles.active 问题

报错: org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

C++使用 new 来创建动态数组

问题: 不能使用变量定义数组大小 原因: 这是因为数组在内存中是连续存储的,编译器需要在编译阶段就确定数组的大小,以便正确地分配内存空间。如果允许使用变量来定义数组的大小,那么编译器就无法在编译时确定数组的大…...

2025年渗透测试面试题总结-腾讯[实习]科恩实验室-安全工程师(题目+回答)

安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 腾讯[实习]科恩实验室-安全工程师 一、网络与协议 1. TCP三次握手 2. SYN扫描原理 3. HTTPS证书机制 二…...