【深度学习】学习率及多种选择策略
学习率是最影响性能的超参数之一,如果我们只能调整一个超参数,那么最好的选择就是它。相比于其它超参数学习率以一种更加复杂的方式控制着模型的有效容量,当学习率最优时,模型的有效容量最大。本文从手动选择学习率到使用预热机制介绍了很多学习率的选择策略。
这篇文章记录了我对以下问题的理解:
- 学习速率是什么?学习速率有什么意义?
- 如何系统地获得良好的学习速率?
- 我们为什么要在训练过程中改变学习速率?
- 当使用预训练模型时,我们该如何解决学习速率的问题?
本文的大部分内容都是以 fast.ai 研究员写的内容 [1], [2], [5] 和 [3] 为基础的。本文是一个更为简洁的版本,通过本文可以快速获取这些文章的主要内容。如果您想了解更多详情,请参阅参考资料。
首先,什么是学习速率?
学习速率是指导我们该如何通过损失函数的梯度调整网络权重的超参数。学习率越低,损失函数的变化速度就越慢。虽然使用低学习率可以确保我们不会错过任何局部极小值,但也意味着我们将花费更长的时间来进行收敛,特别是在被困在高原区域的情况下。
下述公式表示了上面所说的这种关系。
new_weight = existing_weight — learning_rate * gradient
采用小学习速率(顶部)和大学习速率(底部)的梯度下降。来源:Coursera 上吴恩达(Andrew Ng)的机器学习课程。
一般而言,用户可以利用过去的经验(或其他类型的学习资料)直观地设定学习率的最佳值。
因此,想得到最佳学习速率是很难做到的。下图演示了配置学习速率时可能遇到的不同情况。
不同学习速率对收敛的影响(图片来源:cs231n)
此外,学习速率对模型收敛到局部极小值(也就是达到最好的精度)的速度也是有影响的。因此,从正确的方向做出正确的选择意味着我们可以用更短的时间来训练模型。
- Less training time, lesser money spent on GPU cloud compute. 😃
有更好的方法选择学习速率吗?
在「训练神经网络的周期性学习速率」[4] 的 3.3 节中,Leslie N. Smith 认为,用户可以以非常低的学习率开始训练模型,在每一次迭代过程中逐渐提高学习率(线性提高或是指数提高都可以),用户可以用这种方法估计出最佳学习率。
在每一个 mini-batch 后提升学习率
如果我们对每次迭代的学习进行记录,并绘制学习率(对数尺度)与损失,我们会看到,随着学习率的提高,从某个点开始损失会停止下降并开始提高。在实践中,学习速率的理想情况应该是从图的左边到最低点(如下图所示)。在本例中,是从 0.001 到 0.01。
上述方法看似有用,但该如何应用呢?
目前,上述方法在 fast.ai 包中作为一个函数进行使用。fast.ai 包是由 Jeremy Howard 开发的一种高级 pytorch 包(就像 Keras 之于 Tensorflow)。
在训练神经网络之前,只需输入以下命令即可开始找到最佳学习速率。
- # learn is an instance of Learner class or one of derived classes like ConvLearner
- learn.lr_find()
- learn.sched.plot_lr()
使之更好
现在我们已经知道了什么是学习速率,那么当我们开始训练模型时,怎样才能系统地得到最理想的值呢。接下来,我们将介绍如何利用学习率来改善模型的性能。
传统的方法
一般而言,当已经设定好学习速率并训练模型时,只有等学习速率随着时间的推移而下降,模型才能最终收敛。
然而,随着梯度达到高原,训练损失会更难得到改善。在 [3] 中,Dauphin 等人认为,减少损失的难度来自鞍点,而不是局部最低点。
误差曲面中的鞍点。鞍点是函数上的导数为零但不是轴上局部极值的点。(图片来源:safaribooksonline)
所以我们该如何解决这个问题?
我们可以采取几种办法。[1] 中是这么说的:
…无需使用固定的学习速率,并随着时间的推移而令它下降。如果训练不会改善损失,我们可根据一些周期函数 f 来改变每次迭代的学习速率。每个 Epoch 的迭代次数都是固定的。这种方法让学习速率在合理的边界值之间周期变化。这是有益的,因为如果我们卡在鞍点上,提高学习速率可以更快地穿越鞍点。
在 [2] 中,Leslie 提出了一种「三角」方法,这种方法可以在每次迭代之后重新开始调整学习速率。
Leslie N. Smith 提出的「Triangular」和「Triangular2」学习率周期变化的方法。左图中,LR 的最小值和最大值保持不变。右图中,每个周期之后 LR 最小值和最大值之间的差减半。
另一种常用的方法是由 Loshchilov&Hutter [6] 提出的预热重启(Warm Restarts)随机梯度下降。这种方法使用余弦函数作为周期函数,并在每个周期最大值时重新开始学习速率。「预热」是因为学习率重新开始时并不是从头开始的,而是由模型在最后一步收敛的参数决定的 [7]。
下图展示了伴随这种变化的过程,该过程将每个周期设置为相同的时间段。
SGDR 图,学习率 vs 迭代次数。
因此,我们现在可以通过周期性跳过「山脉」的办法缩短训练时间(下图)。
比较固定 LR 和周期 LR(图片来自 ruder.io)
研究表明,使用这些方法除了可以节省时间外,还可以在不调整的情况下提高分类准确性,而且可以减少迭代次数。
迁移学习中的学习速率
在 fast.ai 课程中,非常重视利用预训练模型解决 AI 问题。例如,在解决图像分类问题时,会教授学生如何使用 VGG 或 Resnet50 等预训练模型,并将其连接到想要预测的图像数据集。
我们采取下面的几个步骤,总结了 fast.ai 是如何完成模型构建(该程序不要与 fast.ai 包混淆)的:
1. 启用数据增强,precompute = True
2. 使用 lr_find() 找到损失仍在降低的最高学习速率
3. 从预计算激活值到最后一层训练 1~2 个 Epoch
4. 在 cycle_len = 1 的情况下使用数据增强(precompute=False)训练最后一层 2~3 次
5. 修改所有层为可训练状态
6. 将前面层的学习率设置得比下一个较高层低 3~10 倍
7. 再次使用 lr_find()
8. 在 cycle_mult=2 的情况下训练整个网络,直到过度拟合
从上面的步骤中,我们注意到步骤 2、5 和 7 提到了学习速率。这篇文章的前半部分已经基本涵盖了上述步骤中的第 2 项——如何在训练模型之前得出最佳学习率。
在下文中,我们会通过 SGDR 来了解如何通过重启学习速率来减少训练时间和提高准确性,以避免梯度接近零。
在最后一节中,我们将重点介绍差异学习(differential learning),以及如何在训练带有预训练模型中应用差异学习确定学习速率。
什么是差异学习
差异学习(different learning)在训练期间为网络中的不同层设置不同的学习速率。这种方法与人们常用的学习速率配置方法相反,常用的方法是训练时在整个网络中使用相同的学习速率。
在写这篇文章的时候,Jeremy 和 Sebastian Ruder 发表的一篇论文深入探讨了这个问题。所以我估计差异学习速率现在有一个新的名字——差别性的精调。😃
为了更清楚地说明这个概念,我们可以参考下面的图。在下图中将一个预训练模型分成 3 组,每个组的学习速率都是逐渐增加的。
具有差异学习速率的简单 CNN 模型。图片来自 [3]
这种方法的意义在于,前几个层通常会包含非常细微的数据细节,比如线和边,我们一般不希望改变这些细节并想保留它的信息。因此,无需大量改变权重。
相比之下,在后面的层,以绿色以上的层为例,我们可以从中获得眼球、嘴巴或鼻子等数据的细节特征,但我们可能不需要保留它们。
这种方法与其他微调方法相比如何?
在 [9] 中提出,微调整个模型太过昂贵,因为有些模型可能超过了 100 层。因此人们通常一次一层地对模型进行微调。
然而,这样的调整对顺序有要求,不具并行性,且因为需要通过数据集进行微调,导致模型会在小数据集上过拟合。
下表证明 [9] 中引入的方法能够在各种 NLP 分类任务中提高准确度且降低错误率。
参考文献:
[1] Improving the way we work with learning rate.
[2] The Cyclical Learning Rate technique.
[3] Transfer Learning using differential learning rates.
[4] Leslie N. Smith. Cyclical Learning Rates for Training Neural Networks.
[5] Estimating an Optimal Learning Rate for a Deep Neural Network
[6] Stochastic Gradient Descent with Warm Restarts
[7] Optimization for Deep Learning Highlights in 2017
[8] Lesson 1 Notebook, fast.ai Part 1 V2
[9] Fine-tuned Language Models for Text Classification
原文链接:https://towardsdatascience.com/understanding-learning-rates-and-how-it-improves-performance-in-deep-learning-d0d4059c1c10
相关文章:

【深度学习】学习率及多种选择策略
学习率是最影响性能的超参数之一,如果我们只能调整一个超参数,那么最好的选择就是它。相比于其它超参数学习率以一种更加复杂的方式控制着模型的有效容量,当学习率最优时,模型的有效容量最大。本文从手动选择学习率到使用预热机制…...

具有“真实触感”的动捕数据手套mhand pro,提供更精确的动作捕捉
随着人工智能的普及和万物互联,vr虚拟技术备受关注,为了更加真实的虚拟现实交互体验,动捕数据手套的使用逐渐普及,vr手套可以实时采集各手指关节运动数据,使用动捕数据手套可以在虚拟现实的场景中实现对真实手部运动的…...
Mongodb使用killCursors停止运行的cursor
cursor指向查询结果的游标,通过游标向下移动,获得下一条查询结果。MongoDB分批向用户返回数据结果。通过游标的移动, mongodb确定当前返回结果的位置,是否要加载更多数据到内存当中。cursor有默认的超时时间, 超时后cu…...
电脑风扇转一下停一下,无法正常开机问题解决
今天同事电话说电脑开不了机了,只听见风扇不停地呜呜地作响。笔者第一反应是不是硬件哪里出问题了,于是二话没说拿起心爱的螺丝刀就闪了过去。 按下电源,确实如电话所述。但感觉风扇并非一直在转,而是时断时续。由于听不大真切&a…...

无需部署服务器,如何结合内网穿透实现公网访问导航页工具Dashy
文章目录 简介1. 安装Dashy2. 安装cpolar3.配置公网访问地址4. 固定域名访问 简介 Dashy 是一个开源的自托管的导航页配置服务,具有易于使用的可视化编辑器、状态检查、小工具和主题等功能。你可以将自己常用的一些网站聚合起来放在一起,形成自己的导航…...

Go GORM简介
GORM(Go Object-Relational Mapping)是一个用于Go语言的ORM库,它提供了一种简单、优雅的方式来操作数据库。GORM支持多种数据库,包括MySQL、PostgreSQL、SQLite和SQL Server。以下是GORM的一些主要特性 全功能ORM:GORM…...

前端量子纠缠 效果炸裂 multipleWindow3dScene
我 | 在这里 🕵️ 读书 | 长沙 ⭐软件工程 ⭐ 本科 🏠 工作 | 广州 ⭐ Java 全栈开发(软件工程师) 🎃 爱好 | 研究技术、旅游、阅读、运动、喜欢流行歌曲 ✈️已经旅游的地点 | 新疆-乌鲁木齐、新疆-吐鲁番、广东-广州…...
第十七章 处理空字符串和 Null 值 - XMLIGNORENULL、XMLNIL 和 XMLUSEMPTYELEMENT 的详细信息
文章目录 第十七章 处理空字符串和 Null 值 - XMLIGNORENULL、XMLNIL 和 XMLUSEMPTYELEMENT 的详细信息XMLIGNORENULL、XMLNIL 和 XMLUSEMPTYELEMENT 的详细信息XMLIGNORENULLXMLNILXMLUSEEMPTYELEMENT 导入值 第十七章 处理空字符串和 Null 值 - XMLIGNORENULL、XMLNIL 和 XML…...

Asp.net core WebApi 配置自定义swaggerUI和中文注释
1.创建asp.net core webApi项目 默认会引入swagger的Nuget包 <PackageReference Include"Swashbuckle.AspNetCore" Version"6.2.3" />2.配置基本信息和中文注释(默认是没有中文注释的) 2.1创建一个新的controller using Micr…...
Xilinx SDK获取代码运行时间
Xilinx SDK获取代码运行时间 一、API 头文件 “xtime_l.h”函数XTime_GetTime(XTime * xtime),获取周期数时钟频率宏 COUNTS_PER_SECOND 二、使用 #include "xtime_l.h"int main(){XTime tBegin, tEnd;unsigned int t_us;unsigned long long cycles;XTime_GetTim…...

【力扣】189. 轮转数组
【力扣】189. 轮转数组 文章目录 【力扣】189. 轮转数组1. 题目介绍2. 解法2.1 方法一:不太正规,但是简单2.2 方法二:使用额外的数组2.3 方法三:环状替换2.4 方法四:数组翻转 3. Danger参考 1. 题目介绍 给定一个整数…...

Spring 拾枝杂谈—Spring原生容器结构剖析(通俗易懂)
目录 一、前言 二、Spring快速入门 1.简介 : 2. 入门实例 : 三、Spring容器结构分析 1.bean配置信息的存储 : 2.bean对象的存储 : 3.bean-id的快捷访问 : 四、总结 一、前言 开门见山,11.25日开始我们正式进入Java框架—Spring的学习,此前&…...

Java核心知识点整理大全22-笔记
目录 19.1.14. CAP 一致性(C): 可用性(A): 分区容忍性(P): 20. 一致性算法 20.1.1. Paxos Paxos 三种角色:Proposer,Acceptor,L…...
qt 5.15.2读取csv文件功能
qt 5.15.2读取csv文件功能 工程文件.pro 内容: QT core#添加网络模块 QT networkCONFIG c17 cmdline# You can make your code fail to compile if it uses deprecated APIs. # In order to do so, uncomment the following line. #DEFINES QT_DISABLE_DEPREC…...

【Vue】绝了!还有不懂生命周期的?
生命周期 Vue.js 组件生命周期: 生命周期函数(钩子)就是给我们提供了一些特定的时刻,让我们可以在这个周期段内加入自己的代码,做一些需要的事情; 生命周期钩子中的this指向是VM 或 组件实例对象 在JS 中,…...
关于IP与端口以及localhost
IP和域名 IP地址是一个规定,现在使用的是IPv4,既由4个0-255之间的数字组成,在计算机中,IP地址是分配给网卡的,每个网卡有一个唯一的IP地址。 域名(Domain Name)就是给IP取一个字符的名字,例如http://163.c…...

如何进行MySQL的主从复制(MySQL5.7)
背景:在一些Web服务器开发中,系统用户在进行数据访问时,基本都是直接操作数据库MySQL进行访问,而这种情况下,若只有一台MySQL服务器,可能会存在如下问题 数据的读和写的所有压力都会由一台数据库独…...
5:kotlin 类(Classes )
kotlin支持面向对象编程,也有雷和对象的概念 要声明一个类需要使用class关键字 class Customer属性(Properties) 可以在类名后边添加(),在()里边声明属性 class Contact(val id: Int, var email: String)声明了不…...
达梦:【1】达梦常用操作
达梦:【1】达梦常用操作 一、登录达梦二、创建表空间及用户模式三、查看表空间、用户、模式四、系统查询五、角色管理六、数据库导入导出七、达梦数据库汉字存储八、根据表生成ctl控制文件九、本地连多台数据库(RAC) 一、登录达梦 ./disql username/passwordip:por…...

数字人透明屏幕的技术原理是什么?
数字人透明屏幕的技术原理主要包括人脸识别和全息影像技术。其中,人脸识别技术是通过摄像头捕捉游客的面部表情和动作,并将其转化为数据指令,以便与数字人物进行互动。而全息影像技术则是利用透明屏幕,通过全息投影的方式将数字人…...
ubuntu搭建nfs服务centos挂载访问
在Ubuntu上设置NFS服务器 在Ubuntu上,你可以使用apt包管理器来安装NFS服务器。打开终端并运行: sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享,例如/shared: sudo mkdir /shared sud…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地
借阿里云中企出海大会的东风,以**「云启出海,智联未来|打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办,现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...
【论文笔记】若干矿井粉尘检测算法概述
总的来说,传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度,通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序
一、开发环境准备 工具安装: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 项目初始化: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...
汇编常见指令
汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX(不访问内存)XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析
Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...

AirSim/Cosys-AirSim 游戏开发(四)外部固定位置监控相机
这个博客介绍了如何通过 settings.json 文件添加一个无人机外的 固定位置监控相机,因为在使用过程中发现 Airsim 对外部监控相机的描述模糊,而 Cosys-Airsim 在官方文档中没有提供外部监控相机设置,最后在源码示例中找到了,所以感…...

MFC 抛体运动模拟:常见问题解决与界面美化
在 MFC 中开发抛体运动模拟程序时,我们常遇到 轨迹残留、无效刷新、视觉单调、物理逻辑瑕疵 等问题。本文将针对这些痛点,详细解析原因并提供解决方案,同时兼顾界面美化,让模拟效果更专业、更高效。 问题一:历史轨迹与小球残影残留 现象 小球运动后,历史位置的 “残影”…...

Qemu arm操作系统开发环境
使用qemu虚拟arm硬件比较合适。 步骤如下: 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载,下载地址:https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...
HTML前端开发:JavaScript 获取元素方法详解
作为前端开发者,高效获取 DOM 元素是必备技能。以下是 JS 中核心的获取元素方法,分为两大系列: 一、getElementBy... 系列 传统方法,直接通过 DOM 接口访问,返回动态集合(元素变化会实时更新)。…...