当前位置: 首页 > news >正文

【深度学习】学习率及多种选择策略

学习率是最影响性能的超参数之一,如果我们只能调整一个超参数,那么最好的选择就是它。相比于其它超参数学习率以一种更加复杂的方式控制着模型的有效容量,当学习率最优时,模型的有效容量最大。本文从手动选择学习率到使用预热机制介绍了很多学习率的选择策略。

这篇文章记录了我对以下问题的理解:

  • 学习速率是什么?学习速率有什么意义?
  • 如何系统地获得良好的学习速率?
  • 我们为什么要在训练过程中改变学习速率?
  • 当使用预训练模型时,我们该如何解决学习速率的问题?

本文的大部分内容都是以 fast.ai 研究员写的内容 [1], [2], [5] 和 [3] 为基础的。本文是一个更为简洁的版本,通过本文可以快速获取这些文章的主要内容。如果您想了解更多详情,请参阅参考资料。

首先,什么是学习速率?

学习速率是指导我们该如何通过损失函数的梯度调整网络权重的超参数。学习率越低,损失函数的变化速度就越慢。虽然使用低学习率可以确保我们不会错过任何局部极小值,但也意味着我们将花费更长的时间来进行收敛,特别是在被困在高原区域的情况下。

下述公式表示了上面所说的这种关系。

  1. new_weight = existing_weight — learning_rate * gradient

理解深度学习中的学习率及多种选择策略

采用小学习速率(顶部)和大学习速率(底部)的梯度下降。来源:Coursera 上吴恩达(Andrew Ng)的机器学习课程。

一般而言,用户可以利用过去的经验(或其他类型的学习资料)直观地设定学习率的最佳值。

因此,想得到最佳学习速率是很难做到的。下图演示了配置学习速率时可能遇到的不同情况。

理解深度学习中的学习率及多种选择策略

不同学习速率对收敛的影响(图片来源:cs231n)

此外,学习速率对模型收敛到局部极小值(也就是达到最好的精度)的速度也是有影响的。因此,从正确的方向做出正确的选择意味着我们可以用更短的时间来训练模型。

  
  1. Less training time, lesser money spent on GPU cloud compute. 😃

有更好的方法选择学习速率吗?

在「训练神经网络的周期性学习速率」[4] 的 3.3 节中,Leslie N. Smith 认为,用户可以以非常低的学习率开始训练模型,在每一次迭代过程中逐渐提高学习率(线性提高或是指数提高都可以),用户可以用这种方法估计出最佳学习率。

理解深度学习中的学习率及多种选择策略

在每一个 mini-batch 后提升学习率

如果我们对每次迭代的学习进行记录,并绘制学习率(对数尺度)与损失,我们会看到,随着学习率的提高,从某个点开始损失会停止下降并开始提高。在实践中,学习速率的理想情况应该是从图的左边到最低点(如下图所示)。在本例中,是从 0.001 到 0.01。

理解深度学习中的学习率及多种选择策略

上述方法看似有用,但该如何应用呢?

目前,上述方法在 fast.ai 包中作为一个函数进行使用。fast.ai 包是由 Jeremy Howard 开发的一种高级 pytorch 包(就像 Keras 之于 Tensorflow)。

在训练神经网络之前,只需输入以下命令即可开始找到最佳学习速率。

  
  1. # learn is an instance of Learner class or one of derived classes like ConvLearner
  2. learn.lr_find()
  3. learn.sched.plot_lr()

使之更好

现在我们已经知道了什么是学习速率,那么当我们开始训练模型时,怎样才能系统地得到最理想的值呢。接下来,我们将介绍如何利用学习率来改善模型的性能。

传统的方法

一般而言,当已经设定好学习速率并训练模型时,只有等学习速率随着时间的推移而下降,模型才能最终收敛。

然而,随着梯度达到高原,训练损失会更难得到改善。在 [3] 中,Dauphin 等人认为,减少损失的难度来自鞍点,而不是局部最低点。

理解深度学习中的学习率及多种选择策略

误差曲面中的鞍点。鞍点是函数上的导数为零但不是轴上局部极值的点。(图片来源:safaribooksonline)

所以我们该如何解决这个问题?

我们可以采取几种办法。[1] 中是这么说的:

…无需使用固定的学习速率,并随着时间的推移而令它下降。如果训练不会改善损失,我们可根据一些周期函数 f 来改变每次迭代的学习速率。每个 Epoch 的迭代次数都是固定的。这种方法让学习速率在合理的边界值之间周期变化。这是有益的,因为如果我们卡在鞍点上,提高学习速率可以更快地穿越鞍点。

在 [2] 中,Leslie 提出了一种「三角」方法,这种方法可以在每次迭代之后重新开始调整学习速率。

理解深度学习中的学习率及多种选择策略

Leslie N. Smith 提出的「Triangular」和「Triangular2」学习率周期变化的方法。左图中,LR 的最小值和最大值保持不变。右图中,每个周期之后 LR 最小值和最大值之间的差减半。

另一种常用的方法是由 Loshchilov&Hutter [6] 提出的预热重启(Warm Restarts)随机梯度下降。这种方法使用余弦函数作为周期函数,并在每个周期最大值时重新开始学习速率。「预热」是因为学习率重新开始时并不是从头开始的,而是由模型在最后一步收敛的参数决定的 [7]。

下图展示了伴随这种变化的过程,该过程将每个周期设置为相同的时间段。

理解深度学习中的学习率及多种选择策略

SGDR 图,学习率 vs 迭代次数。

因此,我们现在可以通过周期性跳过「山脉」的办法缩短训练时间(下图)。

理解深度学习中的学习率及多种选择策略

比较固定 LR 和周期 LR(图片来自 ruder.io)

研究表明,使用这些方法除了可以节省时间外,还可以在不调整的情况下提高分类准确性,而且可以减少迭代次数。

迁移学习中的学习速率

在 fast.ai 课程中,非常重视利用预训练模型解决 AI 问题。例如,在解决图像分类问题时,会教授学生如何使用 VGG 或 Resnet50 等预训练模型,并将其连接到想要预测的图像数据集。

我们采取下面的几个步骤,总结了 fast.ai 是如何完成模型构建(该程序不要与 fast.ai 包混淆)的:

1. 启用数据增强,precompute = True

2. 使用 lr_find() 找到损失仍在降低的最高学习速率

3. 从预计算激活值到最后一层训练 1~2 个 Epoch

4. 在 cycle_len = 1 的情况下使用数据增强(precompute=False)训练最后一层 2~3 次

5. 修改所有层为可训练状态

6. 将前面层的学习率设置得比下一个较高层低 3~10 倍

7. 再次使用 lr_find()

8. 在 cycle_mult=2 的情况下训练整个网络,直到过度拟合

从上面的步骤中,我们注意到步骤 2、5 和 7 提到了学习速率。这篇文章的前半部分已经基本涵盖了上述步骤中的第 2 项——如何在训练模型之前得出最佳学习率。

在下文中,我们会通过 SGDR 来了解如何通过重启学习速率来减少训练时间和提高准确性,以避免梯度接近零。

在最后一节中,我们将重点介绍差异学习(differential learning),以及如何在训练带有预训练模型中应用差异学习确定学习速率。

什么是差异学习

差异学习(different learning)在训练期间为网络中的不同层设置不同的学习速率。这种方法与人们常用的学习速率配置方法相反,常用的方法是训练时在整个网络中使用相同的学习速率。

理解深度学习中的学习率及多种选择策略

在写这篇文章的时候,Jeremy 和 Sebastian Ruder 发表的一篇论文深入探讨了这个问题。所以我估计差异学习速率现在有一个新的名字——差别性的精调。😃

为了更清楚地说明这个概念,我们可以参考下面的图。在下图中将一个预训练模型分成 3 组,每个组的学习速率都是逐渐增加的。

理解深度学习中的学习率及多种选择策略

具有差异学习速率的简单 CNN 模型。图片来自 [3]

这种方法的意义在于,前几个层通常会包含非常细微的数据细节,比如线和边,我们一般不希望改变这些细节并想保留它的信息。因此,无需大量改变权重。

相比之下,在后面的层,以绿色以上的层为例,我们可以从中获得眼球、嘴巴或鼻子等数据的细节特征,但我们可能不需要保留它们。

这种方法与其他微调方法相比如何?

在 [9] 中提出,微调整个模型太过昂贵,因为有些模型可能超过了 100 层。因此人们通常一次一层地对模型进行微调。

然而,这样的调整对顺序有要求,不具并行性,且因为需要通过数据集进行微调,导致模型会在小数据集上过拟合。

下表证明 [9] 中引入的方法能够在各种 NLP 分类任务中提高准确度且降低错误率。

理解深度学习中的学习率及多种选择策略

参考文献:

[1] Improving the way we work with learning rate.

[2] The Cyclical Learning Rate technique.

[3] Transfer Learning using differential learning rates.

[4] Leslie N. Smith. Cyclical Learning Rates for Training Neural Networks.

[5] Estimating an Optimal Learning Rate for a Deep Neural Network

[6] Stochastic Gradient Descent with Warm Restarts

[7] Optimization for Deep Learning Highlights in 2017

[8] Lesson 1 Notebook, fast.ai Part 1 V2

[9] Fine-tuned Language Models for Text Classification

原文链接:https://towardsdatascience.com/understanding-learning-rates-and-how-it-improves-performance-in-deep-learning-d0d4059c1c10

相关文章:

【深度学习】学习率及多种选择策略

学习率是最影响性能的超参数之一,如果我们只能调整一个超参数,那么最好的选择就是它。相比于其它超参数学习率以一种更加复杂的方式控制着模型的有效容量,当学习率最优时,模型的有效容量最大。本文从手动选择学习率到使用预热机制…...

具有“真实触感”的动捕数据手套mhand pro,提供更精确的动作捕捉

随着人工智能的普及和万物互联,vr虚拟技术备受关注,为了更加真实的虚拟现实交互体验,动捕数据手套的使用逐渐普及,vr手套可以实时采集各手指关节运动数据,使用动捕数据手套可以在虚拟现实的场景中实现对真实手部运动的…...

Mongodb使用killCursors停止运行的cursor

cursor指向查询结果的游标,通过游标向下移动,获得下一条查询结果。MongoDB分批向用户返回数据结果。通过游标的移动, mongodb确定当前返回结果的位置,是否要加载更多数据到内存当中。cursor有默认的超时时间, 超时后cu…...

电脑风扇转一下停一下,无法正常开机问题解决

今天同事电话说电脑开不了机了,只听见风扇不停地呜呜地作响。笔者第一反应是不是硬件哪里出问题了,于是二话没说拿起心爱的螺丝刀就闪了过去。 按下电源,确实如电话所述。但感觉风扇并非一直在转,而是时断时续。由于听不大真切&a…...

无需部署服务器,如何结合内网穿透实现公网访问导航页工具Dashy

文章目录 简介1. 安装Dashy2. 安装cpolar3.配置公网访问地址4. 固定域名访问 简介 Dashy 是一个开源的自托管的导航页配置服务,具有易于使用的可视化编辑器、状态检查、小工具和主题等功能。你可以将自己常用的一些网站聚合起来放在一起,形成自己的导航…...

Go GORM简介

GORM(Go Object-Relational Mapping)是一个用于Go语言的ORM库,它提供了一种简单、优雅的方式来操作数据库。GORM支持多种数据库,包括MySQL、PostgreSQL、SQLite和SQL Server。以下是GORM的一些主要特性 全功能ORM:GORM…...

前端量子纠缠 效果炸裂 multipleWindow3dScene

我 | 在这里 🕵️ 读书 | 长沙 ⭐软件工程 ⭐ 本科 🏠 工作 | 广州 ⭐ Java 全栈开发(软件工程师) 🎃 爱好 | 研究技术、旅游、阅读、运动、喜欢流行歌曲 ✈️已经旅游的地点 | 新疆-乌鲁木齐、新疆-吐鲁番、广东-广州…...

第十七章 处理空字符串和 Null 值 - XMLIGNORENULL、XMLNIL 和 XMLUSEMPTYELEMENT 的详细信息

文章目录 第十七章 处理空字符串和 Null 值 - XMLIGNORENULL、XMLNIL 和 XMLUSEMPTYELEMENT 的详细信息XMLIGNORENULL、XMLNIL 和 XMLUSEMPTYELEMENT 的详细信息XMLIGNORENULLXMLNILXMLUSEEMPTYELEMENT 导入值 第十七章 处理空字符串和 Null 值 - XMLIGNORENULL、XMLNIL 和 XML…...

Asp.net core WebApi 配置自定义swaggerUI和中文注释

1.创建asp.net core webApi项目 默认会引入swagger的Nuget包 <PackageReference Include"Swashbuckle.AspNetCore" Version"6.2.3" />2.配置基本信息和中文注释&#xff08;默认是没有中文注释的&#xff09; 2.1创建一个新的controller using Micr…...

Xilinx SDK获取代码运行时间

Xilinx SDK获取代码运行时间 一、API 头文件 “xtime_l.h”函数XTime_GetTime(XTime * xtime),获取周期数时钟频率宏 COUNTS_PER_SECOND 二、使用 #include "xtime_l.h"int main(){XTime tBegin, tEnd;unsigned int t_us;unsigned long long cycles;XTime_GetTim…...

【力扣】189. 轮转数组

【力扣】189. 轮转数组 文章目录 【力扣】189. 轮转数组1. 题目介绍2. 解法2.1 方法一&#xff1a;不太正规&#xff0c;但是简单2.2 方法二&#xff1a;使用额外的数组2.3 方法三&#xff1a;环状替换2.4 方法四&#xff1a;数组翻转 3. Danger参考 1. 题目介绍 给定一个整数…...

Spring 拾枝杂谈—Spring原生容器结构剖析(通俗易懂)

目录 一、前言 二、Spring快速入门 1.简介 : 2. 入门实例 : 三、Spring容器结构分析 1.bean配置信息的存储 : 2.bean对象的存储 : 3.bean-id的快捷访问 : 四、总结 一、前言 开门见山&#xff0c;11.25日开始我们正式进入Java框架—Spring的学习&#xff0c;此前&…...

Java核心知识点整理大全22-笔记

目录 19.1.14. CAP 一致性&#xff08;C&#xff09;&#xff1a; 可用性&#xff08;A&#xff09;&#xff1a; 分区容忍性&#xff08;P&#xff09;&#xff1a; 20. 一致性算法 20.1.1. Paxos Paxos 三种角色&#xff1a;Proposer&#xff0c;Acceptor&#xff0c;L…...

qt 5.15.2读取csv文件功能

qt 5.15.2读取csv文件功能 工程文件.pro 内容&#xff1a; QT core#添加网络模块 QT networkCONFIG c17 cmdline# You can make your code fail to compile if it uses deprecated APIs. # In order to do so, uncomment the following line. #DEFINES QT_DISABLE_DEPREC…...

【Vue】绝了!还有不懂生命周期的?

生命周期 Vue.js 组件生命周期&#xff1a; 生命周期函数&#xff08;钩子&#xff09;就是给我们提供了一些特定的时刻&#xff0c;让我们可以在这个周期段内加入自己的代码&#xff0c;做一些需要的事情; 生命周期钩子中的this指向是VM 或 组件实例对象 在JS 中&#xff0c;…...

关于IP与端口以及localhost

IP和域名 IP地址是一个规定&#xff0c;现在使用的是IPv4&#xff0c;既由4个0-255之间的数字组成&#xff0c;在计算机中&#xff0c;IP地址是分配给网卡的&#xff0c;每个网卡有一个唯一的IP地址。 域名(Domain Name)就是给IP取一个字符的名字&#xff0c;例如http://163.c…...

如何进行MySQL的主从复制(MySQL5.7)

背景&#xff1a;在一些Web服务器开发中&#xff0c;系统用户在进行数据访问时&#xff0c;基本都是直接操作数据库MySQL进行访问&#xff0c;而这种情况下&#xff0c;若只有一台MySQL服务器&#xff0c;可能会存在如下问题 数据的读和写的所有压力都会由一台数据库独…...

5:kotlin 类(Classes )

kotlin支持面向对象编程&#xff0c;也有雷和对象的概念 要声明一个类需要使用class关键字 class Customer属性&#xff08;Properties&#xfeff;&#xff09; 可以在类名后边添加()&#xff0c;在()里边声明属性 class Contact(val id: Int, var email: String)声明了不…...

达梦:【1】达梦常用操作

达梦&#xff1a;【1】达梦常用操作 一、登录达梦二、创建表空间及用户模式三、查看表空间、用户、模式四、系统查询五、角色管理六、数据库导入导出七、达梦数据库汉字存储八、根据表生成ctl控制文件九、本地连多台数据库(RAC) 一、登录达梦 ./disql username/passwordip:por…...

数字人透明屏幕的技术原理是什么?

数字人透明屏幕的技术原理主要包括人脸识别和全息影像技术。其中&#xff0c;人脸识别技术是通过摄像头捕捉游客的面部表情和动作&#xff0c;并将其转化为数据指令&#xff0c;以便与数字人物进行互动。而全息影像技术则是利用透明屏幕&#xff0c;通过全息投影的方式将数字人…...

19c补丁后oracle属主变化,导致不能识别磁盘组

补丁后服务器重启&#xff0c;数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后&#xff0c;存在与用户组权限相关的问题。具体表现为&#xff0c;Oracle 实例的运行用户&#xff08;oracle&#xff09;和集…...

【力扣数据库知识手册笔记】索引

索引 索引的优缺点 优点1. 通过创建唯一性索引&#xff0c;可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度&#xff08;创建索引的主要原因&#xff09;。3. 可以加速表和表之间的连接&#xff0c;实现数据的参考完整性。4. 可以在查询过程中&#xff0c;…...

相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解

【关注我&#xff0c;后续持续新增专题博文&#xff0c;谢谢&#xff01;&#xff01;&#xff01;】 上一篇我们讲了&#xff1a; 这一篇我们开始讲&#xff1a; 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下&#xff1a; 一、场景操作步骤 操作步…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

视觉slam十四讲实践部分记录——ch2、ch3

ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

push [特殊字符] present

push &#x1f19a; present 前言present和dismiss特点代码演示 push和pop特点代码演示 前言 在 iOS 开发中&#xff0c;push 和 present 是两种不同的视图控制器切换方式&#xff0c;它们有着显著的区别。 present和dismiss 特点 在当前控制器上方新建视图层级需要手动调用…...

基于Java+VUE+MariaDB实现(Web)仿小米商城

仿小米商城 环境安装 nodejs maven JDK11 运行 mvn clean install -DskipTestscd adminmvn spring-boot:runcd ../webmvn spring-boot:runcd ../xiaomi-store-admin-vuenpm installnpm run servecd ../xiaomi-store-vuenpm installnpm run serve 注意&#xff1a;运行前…...

如何应对敏捷转型中的团队阻力

应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中&#xff0c;明确沟通敏捷转型目的尤为关键&#xff0c;团队成员只有清晰理解转型背后的原因和利益&#xff0c;才能降低对变化的…...

pycharm 设置环境出错

pycharm 设置环境出错 pycharm 新建项目&#xff0c;设置虚拟环境&#xff0c;出错 pycharm 出错 Cannot open Local Failed to start [powershell.exe, -NoExit, -ExecutionPolicy, Bypass, -File, C:\Program Files\JetBrains\PyCharm 2024.1.3\plugins\terminal\shell-int…...

从物理机到云原生:全面解析计算虚拟化技术的演进与应用

前言&#xff1a;我的虚拟化技术探索之旅 我最早接触"虚拟机"的概念是从Java开始的——JVM&#xff08;Java Virtual Machine&#xff09;让"一次编写&#xff0c;到处运行"成为可能。这个软件层面的虚拟化让我着迷&#xff0c;但直到后来接触VMware和Doc…...