当前位置: 首页 > news >正文

深度学习_Learning Rate Scheduling

我们在训练模型时学习率的设置非常重要。

  • 学习率的大小很重要。如果它太大,优化就会发散,如果它太小,训练时间太长,否则我们最终会得到次优的结果。
  • 其次,衰变率同样重要。如果学习率仍然很大,我们可能会简单地在最小值附近反弹,从而无法达到最优

我们可以通过学习率时间表(Learning Rate Scheduling)有效地管理准确性

一、基于FashionMNIST任务的学习率时间表实践准备

构建简单网络

def net_fn():model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))return model

模型结构如下(左-netron
在这里插入图片描述
简单的训练框架
全部脚本可以查看笔者的github: LearningRateScheduling.ipynb

def train(model, train_iter, test_iter, config, scheduler=None):device = config.deviceloss = config.lossopt = config.optnum_epochs = config.num_epochsmodel.to(device)animator = Animator(xlabel='epoch', xlim=[0, num_epochs],legend=['train loss', 'train acc', 'test acc'])ep_total_steps = len(train_iter)for ep in range(num_epochs):tq_bar = tqdm(enumerate(train_iter))tq_bar.set_description(f'[ Epoch {ep+1}/{num_epochs} ]')# train_loss, train_acc, num_examplesmetric = Accumulator(3)for idx, (X, y) in tq_bar:final_flag = (ep_total_steps == idx + 1) & (num_epochs == ep + 1)model.train()opt.zero_grad()X, y = X.to(device), y.to(device)y_hat = model(X)l = loss(y_hat, y)l.backward()opt.step()with torch.no_grad():metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])train_loss = metric[0] / metric[2]train_acc = metric[1] / metric[2]tq_bar.set_postfix({"loss" : f"{train_loss:.3f}","acc" : f"{train_acc:.3f}",})if (idx + 1) % 50 == 0:animator.add(ep + idx / len(train_iter), (train_loss, train_acc, None), clear_flag=not final_flag)test_acc = evaluate_accuracy_gpu(model, test_iter)animator.add(ep+1, (None, None, test_acc), clear_flag=not final_flag)if scheduler:if scheduler.__module__ == lr_scheduler.__name__:# 使用 PyTorch In-Built schedulerscheduler.step()else:# 使用自定义 schedulerfor param_group in opt.param_groups:param_group['lr'] = scheduler(ep) print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')plt.show()

二、基于FashionMNIST任务的学习率时间表实践

2.1 无learning rate Scheduler 训练

def test(train_iter, test_iter, scheduler=None):net = net_fn()cfg = Namespace(device=try_gpu(),loss=nn.CrossEntropyLoss(),lr=0.3, num_epochs=10,opt=torch.optim.SGD(net.parameters(), lr=0.3))train(net, train_iter, test_iter, cfg, scheduler)batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)
test(train_iter, test_iter)

在这里插入图片描述

2.2 Square Root Scheduler训练

更新方式为
η = η ∗ n u m _ u p d a t e + 1 \eta =\eta *\sqrt{num\_update + 1} η=ηnum_update+1
本次试验是每一个epoch更新一次

def get_lr(scheduler):lr = scheduler.get_last_lr()[0]scheduler.optimizer.step()scheduler.step()return lrdef plot_scheduler(scheduler, num_epochs=10):s = scheduler.__class__.__name__if scheduler.__module__ == lr_scheduler.__name__:print('pytorch build lr_scheduler')plot_y = [get_lr(scheduler) for _ in range(num_epochs)]else:plot_y = [scheduler(t) for t in range(num_epochs)]plt.title(f'train with learning rate scheduler: {s}')plt.plot(torch.arange(num_epochs), plot_y)plt.xlabel('num_epochs')plt.ylabel('learning_rate')plt.show()class SquareRootScheduler:"""使用均方根scheduler每一个epoch更新一次"""def __init__(self, lr=0.1):self.lr = lrdef __call__(self, num_update):return self.lr * pow(num_update + 1.0, -0.5)scheduler = SquareRootScheduler(lr=0.1)
plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

从下图中可以看出:曲线比以前更平滑了。其次,过度拟合较少。
在这里插入图片描述

2.3 FactorScheduler训练

学习率更新方式: η t + 1 ← m a x ( η m i n , η t ⋅ α ) \eta_{t+1} \leftarrow \mathop{\mathrm{max}}(\eta_{\mathrm{min}}, \eta_t \cdot \alpha) ηt+1max(ηmin,ηtα)

class FactorScheduler:def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):self.factor = factorself.stop_factor_lr = stop_factor_lrself.base_lr = base_lrdef __call__(self, num_update):self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)return self.base_lrscheduler = FactorScheduler(factor=0.8, stop_factor_lr=1e-2, base_lr=0.6)
plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

2.4 Multi Factor Scheduler训练

保持学习率分段恒定,并每隔一段时间将其降低一个给定的量。也就是说,给定一组何时降低速率的时间比如$ (s = {3, 8} )$
d e c r e a s e ( η t + 1 ← η t ⋅ α ) t ∈ s decrease (\eta_{t+1} \leftarrow \eta_t \cdot \alpha) \ \ t \in s decrease(ηt+1ηtα)  ts

net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
scheduler = lr_scheduler.MultiStepLR(trainer, milestones=[3, 8], gamma=0.5)plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

2.5 Cosine Scheduler训练

Loshchilov和Hutter提出了一个相当令人困惑的启发式方法。它依赖于这样一种观察,即我们可能不想在一开始就大幅降低学习率,此外,我们可能希望在最后使用非常小的学习率来“完善”解决方案。这导致了一个类似余弦的时间表,具有以下函数形式,用于范围内的学习率 t ∈ [ 0 , T ] t \in [0, T] t[0,T]

η t = η T + η 0 − η T 2 ( 1 + cos ⁡ ( π t T ) ) \eta_t = \eta_T + \frac{\eta_0 - \eta_T}{2} \left(1 + \cos(\frac{\pi t}{T})\right) ηt=ηT+2η0ηT(1+cos(Tπt))

注:

  • η T \eta_T ηT: 为最终的学习率
  • η 0 \eta_0 η0: 为最开始的学习率
class CosineScheduler:def __init__(self, max_update, base_lr=0.01, final_lr=0,warmup_steps=0, warmup_begin_lr=0):self.base_lr_orig = base_lrself.max_update = max_updateself.final_lr = final_lrself.warmup_steps = warmup_stepsself.warmup_begin_lr = warmup_begin_lrself.max_steps = self.max_update - self.warmup_stepsdef get_warmup_lr(self, step):increase = (self.base_lr_orig - self.warmup_begin_lr) \* float(step) / float(self.warmup_steps)return self.warmup_begin_lr + increasedef __call__(self, step):if step < self.warmup_steps:return self.get_warmup_lr(step)if step <= self.max_update:self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * (1 + math.cos(math.pi * (step - self.warmup_steps) / self.max_steps)) / 2return self.base_lrscheduler = CosineScheduler(max_update=10, base_lr=0.2, final_lr=0.02)
plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

2.6 Warmup

在某些情况下,初始化参数不足以保证良好的解决方案。对于一些先进的网络设计来说,这尤其是一个问题(Transformer的训练常用该方法),可能会导致不稳定的优化问题。
我们可以通过选择一个足够小的学习率来解决这个问题,以防止一开始就出现分歧。不幸的是,这意味着进展缓慢。相反,学习率高最初会导致差异。

对于这种困境,一个相当简单的解决方案是使用一个预热期,在此期间学习速率增加到其初始最大值,并冷却速率直到优化过程结束。为了简单起见,通常使用线性增加来实现这一目的。

scheduler = CosineScheduler(max_update=10, warmup_steps=3, base_lr=0.2, final_lr=0.02)
plot_scheduler(scheduler, 15)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

小结

从上述的5个策略上来看,一般情况我们用 Cosine Scheduler 或者线性衰减就能得到较好的结果。不过对于较大的模型,需要用warmup 并且需要特意去设计,比如NoamOpt等。

相关文章:

深度学习_Learning Rate Scheduling

我们在训练模型时学习率的设置非常重要。 学习率的大小很重要。如果它太大&#xff0c;优化就会发散&#xff0c;如果它太小&#xff0c;训练时间太长&#xff0c;否则我们最终会得到次优的结果。其次&#xff0c;衰变率同样重要。如果学习率仍然很大&#xff0c;我们可能会简…...

snmp服务利用(端口:161、199、391、705、1993)

服务介绍 简单网络管理协议 是一种广泛应用于TCP/IP网络的网络管理标准协议(应用层协议),它提供了一种通过运行网络管理软件的中心计算机(即网络管理工作站)来监控和管理计算机网络的标准化管理框架(方法)。目前已颁布了SNMPv1、SNMPv2c和SNMPv3三个版本,广泛应用于网…...

MyBatis(二)—— 进阶

一、详解配置文件 1.1 核心配置文件 官方建议命名为mybatis-config.xml&#xff0c;核心配置文件里可以进行如下的配置&#xff1a; <environments> 和 <environment> mybatis可以配置多套环境&#xff08;开发一套、测试一套、、、&#xff09;&#xff0c; 在…...

婚恋交友app开发中需要注意的安全问题

前言 随着移动设备的普及&#xff0c;婚恋交友app已经成为了人们生活中重要的一部分。但是&#xff0c;这些应用的开发者需要确保应用的安全性&#xff0c;以保护用户的隐私和数据免受攻击。本文将介绍在婚恋交友app开发中需要注意的安全问题。 在当今数字化时代&#xff0c;…...

相机的内参和外参介绍

注&#xff1a;以下相机内参与外参介绍除来自网络整理外全部来自于《视觉SLAM十四讲从理论到实践 第2版》中的第5讲&#xff1a;相机与图像&#xff0c;为了方便查看&#xff0c;我将每节合并到了一幅图像中 相机与摄像机区别&#xff1a;相机着重于拍摄静态图像&#x…...

Node【包】

文章目录 &#x1f31f;前言&#x1f31f;Nodejs包&#x1f31f;什么是包&#xff1f;&#x1f31f;自定义包&#x1f31f;包配置文件&#x1f31f;示例&#x1f31f;Package.json 属性说明&#x1f31f;语义化版本号&#x1f31f;package.json示例 &#x1f31f;符合CommonJS规…...

CHAPTER 2: 《BACK-OF-THE-ENVELOPE ESTIMATION》 第2章 《初略的估计》

CHAPTER 2: BACK-OF-THE-ENVELOPE ESTIMATION 在系统设计面试中&#xff0c;有时您会被要求估计系统容量或使用粗略估计的性能需求。根据杰夫迪恩的说法&#xff0c;谷歌高级研究员&#xff0c;“粗略的计算是你使用结合思想实验和常见的性能数字&#xff0c;以获得良好的感觉…...

RocketMQ高级概念

一 RocketMQ核心概念 1.消息模型&#xff08;Message Model&#xff09; RocketMQ主要由 Producer、Broker、Consumer 三部分组成&#xff0c;其中Producer 负责⽣产消息&#xff0c;Consumer 负责消费消息&#xff0c;Broker 负责存储消息。Broker 在实际部署过程中对应⼀台…...

eureka注册中心和RestTemplate

eureka注册中心和restTemplate的使用说明 eureka的作用 消费者该如何获取服务提供者的具体信息 1.服务者启动时向eureka注册自己的信息 2.eureka保存这些信息 3.消费者根据服务名称向eureka拉去提供者的信息 如果有多个服务提供者&#xff0c;消费者该如何选择&#xff1f; 服…...

redis复制的设计与实现

一、复制 1.1旧版功能的实现 旧版Redis的复制功能分为 同步&#xff08;sync&#xff09;和 命令传播。 同步用于将从服务器更新至主服务器的当前状态。命令传播用于 主服务器状态变化时&#xff0c;让主从服务器状态回归一致。 1.1.1同步 当客户端向服务端发送slaveof命令…...

Docker更换国内镜像源

什么是Docker Docker 是一个开源的应用容器引擎&#xff0c;基于 Go 语言 并遵从 Apache2.0 协议开源。 Docker 可以让开发者打包他们的应用以及依赖包到一个轻量级、可移植的容器中&#xff0c;然后发布到任何流行的 Linux 机器上&#xff0c;也可以实现虚拟化。 容器是完全…...

【网络编程】网络套接字,UDP,TCP套接字编程

前言 小亭子正在努力的学习编程&#xff0c;接下来将开启javaEE的学习~~ 分享的文章都是学习的笔记和感悟&#xff0c;如有不妥之处希望大佬们批评指正~~ 同时如果本文对你有帮助的话&#xff0c;烦请点赞关注支持一波, 感激不尽~~ 特别说明&#xff1a;本文分享的代码运行结果…...

海斯坦普Gestamp EDI 需求分析

海斯坦普Gestamp&#xff08;以下简称&#xff1a;Gestamp&#xff09;是一家总部位于西班牙的全球性汽车零部件制造商&#xff0c;目前在全球23个国家拥有超过100家工厂。Gestamp的业务涵盖了车身、底盘和机电系统等多个领域&#xff0c;其产品范围包括钣金、车身结构件、车轮…...

gpt写文章批量写文章-gpt3中文生成教程

怎么用gpt写文章批量写文章 批量写作文章是很多网站、营销人员、编辑等需要的重要任务&#xff0c;GPT可以帮助您快速生成大量自然、通顺的文章。下面是一个简单的步骤介绍&#xff0c;告诉您如何使用GPT批量写作文章。 步骤1&#xff1a;选择好训练模型 首先&#xff0c;选…...

HashMap实现原理

HashMap是基于散列表的Map接口的实现。插入和查询的性能消耗是固定的。可以通过构造器设置容量和负载因子&#xff0c;一调整容易得性能。 散列表&#xff1a;给定表M&#xff0c;存在函数f(key)&#xff0c;对任意给定的关键字值key&#xff0c;代入函数后若能得到包含该关键字…...

【Java 数据结构】PriorityQueue(堆)的使用及源码分析

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了 博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点!人生格言&#xff1a;当你的才华撑不起你的野心的时候,你就应该静下心来学习! 欢迎志同道合的朋友一起加油喔&#x1f9be;&am…...

使用 Kubernetes 运行 non-root .NET 容器

翻译自 Richard Lander 的博客 Rootless 或 non-root Linux 容器一直是 .NET 容器团队最需要的功能。我们最近宣布了所有 .NET 8 容器镜像都可以通过一行代码配置为 non-root 用户。今天的文章将介绍如何使用 Kubernetes 处理 non-root 托管。 您可以尝试使用我们的 non-root…...

为什么大量失业集中爆发在2023年?被裁?别怕!失业是跨越职场瓶颈的关键一步!对于牛逼的人,这是白捡N+1!...

被裁究竟是因为自身能力不行&#xff0c;还是因为大环境不行&#xff1f; 一位网友说&#xff1a; 被裁后找不到工作&#xff0c;本质上还是因为原来的能力就配不上薪资。如果确实有技术在身&#xff0c;根本不怕被裁&#xff0c;相当于白送n1&#xff01; 有人赞同楼主的观点&…...

Word控件Spire.Doc 【脚注】字体(3):将Doc转换为PDF时如何使用卸载的字体

Spire.Doc for .NET是一款专门对 Word 文档进行操作的 .NET 类库。在于帮助开发人员无需安装 Microsoft Word情况下&#xff0c;轻松快捷高效地创建、编辑、转换和打印 Microsoft Word 文档。拥有近10年专业开发经验Spire系列办公文档开发工具&#xff0c;专注于创建、编辑、转…...

keil5使用c++编写stm32控制程序

keil5使用c编写stm32控制程序 一、前言二、配置图解三、std::cout串口重定向四、串口中断服务函数五、结尾废话 一、前言 想着搞个新奇的玩意玩一玩来着&#xff0c;想用c编写代码来控制stm32&#xff0c;结果在keil5中&#xff0c;把踩给我踩闷了&#xff0c;这里简单记录一下…...

中国社科院与美国杜兰大学金融管理硕士项目——在职读研的日子里藏着我们未来无限可能

人生充满期待&#xff0c;梦想连接着未来。每一天都可以看作新的一页&#xff0c;要努力去成为最好的自己。在职读研的光阴里藏着无限的可能&#xff0c;只有不断的努力&#xff0c;不断的强大自己&#xff0c;未来会因为你的不懈坚持而发生改变&#xff0c;纵使眼前看不到希望…...

hardhat 本地连接matemask钱包

Hardhat 安装 https://hardhat.org/hardhat-runner/docs/getting-started#quick-start Running a Local Hardhat Network Hardhat greatly simplifies the process of setting up a local network by having an in-built local blockchain which can be easily run through a…...

【华为OD机试真题】1001 - 在字符串中找出连续最长的数字串含-号(Java C++ Python JS)| 机试题+算法思路+考点+代码解析

文章目录 一、题目🔸题目描述🔸输入输出二、代码参考🔸Java代码🔸 C++代码🔸 Python代码🔸 JS代码作者:KJ.JK🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🍂个人博客首页: KJ.JK 💖系列专栏:华为OD机试(Java C++ Python JS)...

CrackMapExec 域渗透工具使用

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、CrackMapExec 是什么&#xff1f;二、简单使用1、获取帮助信息2、smb连接执行命令3、使用winrm执行命令&#xff08;躲避杀软&#xff09;4、smb 协议常用枚…...

Modbus协议学习

以下内容从参考文章学习提炼 [参考文章](https://www.cnblogs.com/The-explosion/p/11512677.html) ## 基本概念 Modbus用的是主从通讯技术&#xff0c;主设备操作查询从设备。可以通过物理接口&#xff0c;可选用串口&#xff08;RS232、RS485、RS422&#xff09;&#xff0c…...

camunda如何处理流程待办任务

在 Camunda 中处理流程任务需要使用 Camunda 提供的 API 或者用户界面进行操作。以下是两种常用的处理流程任务的方式&#xff1a; 1、通过 Camunda 任务列表处理任务&#xff1a;在 Camunda 任务列表中&#xff0c;可以看到当前需要处理的任务&#xff0c;点击任务链接&#…...

git部分文件不想提交解决方案

正确的做法应该是&#xff1a;git rm --cached logs/xx.log&#xff0c;然后更新 .gitignore 忽略掉目标文件&#xff0c;最后 git commit -m "We really dont want Git to track this anymore!" 具体的原因如下&#xff1a; 被采纳的答案虽然能达到&#xff08;暂…...

2023年全国最新道路运输从业人员精选真题及答案58

百分百题库提供道路运输安全员考试试题、道路运输从业人员考试预测题、道路安全员考试真题、道路运输从业人员证考试题库等&#xff0c;提供在线做题刷题&#xff0c;在线模拟考试&#xff0c;助你考试轻松过关。 69.根据《公路水路行业安全生产风险管理暂行办法》&#xff0c;…...

Zimbra 远程代码执行漏洞(CVE-2019-9670)漏洞分析

Zimbra 远程代码执行漏洞(CVE-2019-9670)漏洞分析 漏洞简介 Zimbra是著名的开源系统&#xff0c;提供了一套开源协同办公套件包括WebMail&#xff0c;日历&#xff0c;通信录&#xff0c;Web文档管理和创作。一体化地提供了邮件收发、文件共享、协同办公、即时聊天等一系列解决…...

【数据结构初阶】第七节.树和二叉树的性质

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 前言 一、树 1.1 树的概念 1.2 树的结点分类 1.3 结点之间的关系 1.4 树的存储结构 1.5 其他相关概念 二、 二叉树 2.1 二叉树的概念 2.2 特殊的二叉树 2.3 二叉树的性质 2.4…...