当前位置: 首页 > news >正文

深度学习_Learning Rate Scheduling

我们在训练模型时学习率的设置非常重要。

  • 学习率的大小很重要。如果它太大,优化就会发散,如果它太小,训练时间太长,否则我们最终会得到次优的结果。
  • 其次,衰变率同样重要。如果学习率仍然很大,我们可能会简单地在最小值附近反弹,从而无法达到最优

我们可以通过学习率时间表(Learning Rate Scheduling)有效地管理准确性

一、基于FashionMNIST任务的学习率时间表实践准备

构建简单网络

def net_fn():model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))return model

模型结构如下(左-netron
在这里插入图片描述
简单的训练框架
全部脚本可以查看笔者的github: LearningRateScheduling.ipynb

def train(model, train_iter, test_iter, config, scheduler=None):device = config.deviceloss = config.lossopt = config.optnum_epochs = config.num_epochsmodel.to(device)animator = Animator(xlabel='epoch', xlim=[0, num_epochs],legend=['train loss', 'train acc', 'test acc'])ep_total_steps = len(train_iter)for ep in range(num_epochs):tq_bar = tqdm(enumerate(train_iter))tq_bar.set_description(f'[ Epoch {ep+1}/{num_epochs} ]')# train_loss, train_acc, num_examplesmetric = Accumulator(3)for idx, (X, y) in tq_bar:final_flag = (ep_total_steps == idx + 1) & (num_epochs == ep + 1)model.train()opt.zero_grad()X, y = X.to(device), y.to(device)y_hat = model(X)l = loss(y_hat, y)l.backward()opt.step()with torch.no_grad():metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])train_loss = metric[0] / metric[2]train_acc = metric[1] / metric[2]tq_bar.set_postfix({"loss" : f"{train_loss:.3f}","acc" : f"{train_acc:.3f}",})if (idx + 1) % 50 == 0:animator.add(ep + idx / len(train_iter), (train_loss, train_acc, None), clear_flag=not final_flag)test_acc = evaluate_accuracy_gpu(model, test_iter)animator.add(ep+1, (None, None, test_acc), clear_flag=not final_flag)if scheduler:if scheduler.__module__ == lr_scheduler.__name__:# 使用 PyTorch In-Built schedulerscheduler.step()else:# 使用自定义 schedulerfor param_group in opt.param_groups:param_group['lr'] = scheduler(ep) print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')plt.show()

二、基于FashionMNIST任务的学习率时间表实践

2.1 无learning rate Scheduler 训练

def test(train_iter, test_iter, scheduler=None):net = net_fn()cfg = Namespace(device=try_gpu(),loss=nn.CrossEntropyLoss(),lr=0.3, num_epochs=10,opt=torch.optim.SGD(net.parameters(), lr=0.3))train(net, train_iter, test_iter, cfg, scheduler)batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)
test(train_iter, test_iter)

在这里插入图片描述

2.2 Square Root Scheduler训练

更新方式为
η = η ∗ n u m _ u p d a t e + 1 \eta =\eta *\sqrt{num\_update + 1} η=ηnum_update+1
本次试验是每一个epoch更新一次

def get_lr(scheduler):lr = scheduler.get_last_lr()[0]scheduler.optimizer.step()scheduler.step()return lrdef plot_scheduler(scheduler, num_epochs=10):s = scheduler.__class__.__name__if scheduler.__module__ == lr_scheduler.__name__:print('pytorch build lr_scheduler')plot_y = [get_lr(scheduler) for _ in range(num_epochs)]else:plot_y = [scheduler(t) for t in range(num_epochs)]plt.title(f'train with learning rate scheduler: {s}')plt.plot(torch.arange(num_epochs), plot_y)plt.xlabel('num_epochs')plt.ylabel('learning_rate')plt.show()class SquareRootScheduler:"""使用均方根scheduler每一个epoch更新一次"""def __init__(self, lr=0.1):self.lr = lrdef __call__(self, num_update):return self.lr * pow(num_update + 1.0, -0.5)scheduler = SquareRootScheduler(lr=0.1)
plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

从下图中可以看出:曲线比以前更平滑了。其次,过度拟合较少。
在这里插入图片描述

2.3 FactorScheduler训练

学习率更新方式: η t + 1 ← m a x ( η m i n , η t ⋅ α ) \eta_{t+1} \leftarrow \mathop{\mathrm{max}}(\eta_{\mathrm{min}}, \eta_t \cdot \alpha) ηt+1max(ηmin,ηtα)

class FactorScheduler:def __init__(self, factor=1, stop_factor_lr=1e-7, base_lr=0.1):self.factor = factorself.stop_factor_lr = stop_factor_lrself.base_lr = base_lrdef __call__(self, num_update):self.base_lr = max(self.stop_factor_lr, self.base_lr * self.factor)return self.base_lrscheduler = FactorScheduler(factor=0.8, stop_factor_lr=1e-2, base_lr=0.6)
plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

2.4 Multi Factor Scheduler训练

保持学习率分段恒定,并每隔一段时间将其降低一个给定的量。也就是说,给定一组何时降低速率的时间比如$ (s = {3, 8} )$
d e c r e a s e ( η t + 1 ← η t ⋅ α ) t ∈ s decrease (\eta_{t+1} \leftarrow \eta_t \cdot \alpha) \ \ t \in s decrease(ηt+1ηtα)  ts

net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
scheduler = lr_scheduler.MultiStepLR(trainer, milestones=[3, 8], gamma=0.5)plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

2.5 Cosine Scheduler训练

Loshchilov和Hutter提出了一个相当令人困惑的启发式方法。它依赖于这样一种观察,即我们可能不想在一开始就大幅降低学习率,此外,我们可能希望在最后使用非常小的学习率来“完善”解决方案。这导致了一个类似余弦的时间表,具有以下函数形式,用于范围内的学习率 t ∈ [ 0 , T ] t \in [0, T] t[0,T]

η t = η T + η 0 − η T 2 ( 1 + cos ⁡ ( π t T ) ) \eta_t = \eta_T + \frac{\eta_0 - \eta_T}{2} \left(1 + \cos(\frac{\pi t}{T})\right) ηt=ηT+2η0ηT(1+cos(Tπt))

注:

  • η T \eta_T ηT: 为最终的学习率
  • η 0 \eta_0 η0: 为最开始的学习率
class CosineScheduler:def __init__(self, max_update, base_lr=0.01, final_lr=0,warmup_steps=0, warmup_begin_lr=0):self.base_lr_orig = base_lrself.max_update = max_updateself.final_lr = final_lrself.warmup_steps = warmup_stepsself.warmup_begin_lr = warmup_begin_lrself.max_steps = self.max_update - self.warmup_stepsdef get_warmup_lr(self, step):increase = (self.base_lr_orig - self.warmup_begin_lr) \* float(step) / float(self.warmup_steps)return self.warmup_begin_lr + increasedef __call__(self, step):if step < self.warmup_steps:return self.get_warmup_lr(step)if step <= self.max_update:self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * (1 + math.cos(math.pi * (step - self.warmup_steps) / self.max_steps)) / 2return self.base_lrscheduler = CosineScheduler(max_update=10, base_lr=0.2, final_lr=0.02)
plot_scheduler(scheduler)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

2.6 Warmup

在某些情况下,初始化参数不足以保证良好的解决方案。对于一些先进的网络设计来说,这尤其是一个问题(Transformer的训练常用该方法),可能会导致不稳定的优化问题。
我们可以通过选择一个足够小的学习率来解决这个问题,以防止一开始就出现分歧。不幸的是,这意味着进展缓慢。相反,学习率高最初会导致差异。

对于这种困境,一个相当简单的解决方案是使用一个预热期,在此期间学习速率增加到其初始最大值,并冷却速率直到优化过程结束。为了简单起见,通常使用线性增加来实现这一目的。

scheduler = CosineScheduler(max_update=10, warmup_steps=3, base_lr=0.2, final_lr=0.02)
plot_scheduler(scheduler, 15)

在这里插入图片描述
训练

test(train_iter, test_iter, scheduler)

在这里插入图片描述

小结

从上述的5个策略上来看,一般情况我们用 Cosine Scheduler 或者线性衰减就能得到较好的结果。不过对于较大的模型,需要用warmup 并且需要特意去设计,比如NoamOpt等。

相关文章:

深度学习_Learning Rate Scheduling

我们在训练模型时学习率的设置非常重要。 学习率的大小很重要。如果它太大&#xff0c;优化就会发散&#xff0c;如果它太小&#xff0c;训练时间太长&#xff0c;否则我们最终会得到次优的结果。其次&#xff0c;衰变率同样重要。如果学习率仍然很大&#xff0c;我们可能会简…...

snmp服务利用(端口:161、199、391、705、1993)

服务介绍 简单网络管理协议 是一种广泛应用于TCP/IP网络的网络管理标准协议(应用层协议),它提供了一种通过运行网络管理软件的中心计算机(即网络管理工作站)来监控和管理计算机网络的标准化管理框架(方法)。目前已颁布了SNMPv1、SNMPv2c和SNMPv3三个版本,广泛应用于网…...

MyBatis(二)—— 进阶

一、详解配置文件 1.1 核心配置文件 官方建议命名为mybatis-config.xml&#xff0c;核心配置文件里可以进行如下的配置&#xff1a; <environments> 和 <environment> mybatis可以配置多套环境&#xff08;开发一套、测试一套、、、&#xff09;&#xff0c; 在…...

婚恋交友app开发中需要注意的安全问题

前言 随着移动设备的普及&#xff0c;婚恋交友app已经成为了人们生活中重要的一部分。但是&#xff0c;这些应用的开发者需要确保应用的安全性&#xff0c;以保护用户的隐私和数据免受攻击。本文将介绍在婚恋交友app开发中需要注意的安全问题。 在当今数字化时代&#xff0c;…...

相机的内参和外参介绍

注&#xff1a;以下相机内参与外参介绍除来自网络整理外全部来自于《视觉SLAM十四讲从理论到实践 第2版》中的第5讲&#xff1a;相机与图像&#xff0c;为了方便查看&#xff0c;我将每节合并到了一幅图像中 相机与摄像机区别&#xff1a;相机着重于拍摄静态图像&#x…...

Node【包】

文章目录 &#x1f31f;前言&#x1f31f;Nodejs包&#x1f31f;什么是包&#xff1f;&#x1f31f;自定义包&#x1f31f;包配置文件&#x1f31f;示例&#x1f31f;Package.json 属性说明&#x1f31f;语义化版本号&#x1f31f;package.json示例 &#x1f31f;符合CommonJS规…...

CHAPTER 2: 《BACK-OF-THE-ENVELOPE ESTIMATION》 第2章 《初略的估计》

CHAPTER 2: BACK-OF-THE-ENVELOPE ESTIMATION 在系统设计面试中&#xff0c;有时您会被要求估计系统容量或使用粗略估计的性能需求。根据杰夫迪恩的说法&#xff0c;谷歌高级研究员&#xff0c;“粗略的计算是你使用结合思想实验和常见的性能数字&#xff0c;以获得良好的感觉…...

RocketMQ高级概念

一 RocketMQ核心概念 1.消息模型&#xff08;Message Model&#xff09; RocketMQ主要由 Producer、Broker、Consumer 三部分组成&#xff0c;其中Producer 负责⽣产消息&#xff0c;Consumer 负责消费消息&#xff0c;Broker 负责存储消息。Broker 在实际部署过程中对应⼀台…...

eureka注册中心和RestTemplate

eureka注册中心和restTemplate的使用说明 eureka的作用 消费者该如何获取服务提供者的具体信息 1.服务者启动时向eureka注册自己的信息 2.eureka保存这些信息 3.消费者根据服务名称向eureka拉去提供者的信息 如果有多个服务提供者&#xff0c;消费者该如何选择&#xff1f; 服…...

redis复制的设计与实现

一、复制 1.1旧版功能的实现 旧版Redis的复制功能分为 同步&#xff08;sync&#xff09;和 命令传播。 同步用于将从服务器更新至主服务器的当前状态。命令传播用于 主服务器状态变化时&#xff0c;让主从服务器状态回归一致。 1.1.1同步 当客户端向服务端发送slaveof命令…...

Docker更换国内镜像源

什么是Docker Docker 是一个开源的应用容器引擎&#xff0c;基于 Go 语言 并遵从 Apache2.0 协议开源。 Docker 可以让开发者打包他们的应用以及依赖包到一个轻量级、可移植的容器中&#xff0c;然后发布到任何流行的 Linux 机器上&#xff0c;也可以实现虚拟化。 容器是完全…...

【网络编程】网络套接字,UDP,TCP套接字编程

前言 小亭子正在努力的学习编程&#xff0c;接下来将开启javaEE的学习~~ 分享的文章都是学习的笔记和感悟&#xff0c;如有不妥之处希望大佬们批评指正~~ 同时如果本文对你有帮助的话&#xff0c;烦请点赞关注支持一波, 感激不尽~~ 特别说明&#xff1a;本文分享的代码运行结果…...

海斯坦普Gestamp EDI 需求分析

海斯坦普Gestamp&#xff08;以下简称&#xff1a;Gestamp&#xff09;是一家总部位于西班牙的全球性汽车零部件制造商&#xff0c;目前在全球23个国家拥有超过100家工厂。Gestamp的业务涵盖了车身、底盘和机电系统等多个领域&#xff0c;其产品范围包括钣金、车身结构件、车轮…...

gpt写文章批量写文章-gpt3中文生成教程

怎么用gpt写文章批量写文章 批量写作文章是很多网站、营销人员、编辑等需要的重要任务&#xff0c;GPT可以帮助您快速生成大量自然、通顺的文章。下面是一个简单的步骤介绍&#xff0c;告诉您如何使用GPT批量写作文章。 步骤1&#xff1a;选择好训练模型 首先&#xff0c;选…...

HashMap实现原理

HashMap是基于散列表的Map接口的实现。插入和查询的性能消耗是固定的。可以通过构造器设置容量和负载因子&#xff0c;一调整容易得性能。 散列表&#xff1a;给定表M&#xff0c;存在函数f(key)&#xff0c;对任意给定的关键字值key&#xff0c;代入函数后若能得到包含该关键字…...

【Java 数据结构】PriorityQueue(堆)的使用及源码分析

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了 博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点!人生格言&#xff1a;当你的才华撑不起你的野心的时候,你就应该静下心来学习! 欢迎志同道合的朋友一起加油喔&#x1f9be;&am…...

使用 Kubernetes 运行 non-root .NET 容器

翻译自 Richard Lander 的博客 Rootless 或 non-root Linux 容器一直是 .NET 容器团队最需要的功能。我们最近宣布了所有 .NET 8 容器镜像都可以通过一行代码配置为 non-root 用户。今天的文章将介绍如何使用 Kubernetes 处理 non-root 托管。 您可以尝试使用我们的 non-root…...

为什么大量失业集中爆发在2023年?被裁?别怕!失业是跨越职场瓶颈的关键一步!对于牛逼的人,这是白捡N+1!...

被裁究竟是因为自身能力不行&#xff0c;还是因为大环境不行&#xff1f; 一位网友说&#xff1a; 被裁后找不到工作&#xff0c;本质上还是因为原来的能力就配不上薪资。如果确实有技术在身&#xff0c;根本不怕被裁&#xff0c;相当于白送n1&#xff01; 有人赞同楼主的观点&…...

Word控件Spire.Doc 【脚注】字体(3):将Doc转换为PDF时如何使用卸载的字体

Spire.Doc for .NET是一款专门对 Word 文档进行操作的 .NET 类库。在于帮助开发人员无需安装 Microsoft Word情况下&#xff0c;轻松快捷高效地创建、编辑、转换和打印 Microsoft Word 文档。拥有近10年专业开发经验Spire系列办公文档开发工具&#xff0c;专注于创建、编辑、转…...

keil5使用c++编写stm32控制程序

keil5使用c编写stm32控制程序 一、前言二、配置图解三、std::cout串口重定向四、串口中断服务函数五、结尾废话 一、前言 想着搞个新奇的玩意玩一玩来着&#xff0c;想用c编写代码来控制stm32&#xff0c;结果在keil5中&#xff0c;把踩给我踩闷了&#xff0c;这里简单记录一下…...

23-Oracle 23 ai 区块链表(Blockchain Table)

小伙伴有没有在金融强合规的领域中遇见&#xff0c;必须要保持数据不可变&#xff0c;管理员都无法修改和留痕的要求。比如医疗的电子病历中&#xff0c;影像检查检验结果不可篡改行的&#xff0c;药品追溯过程中数据只可插入无法删除的特性需求&#xff1b;登录日志、修改日志…...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

系统设计 --- MongoDB亿级数据查询优化策略

系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log&#xff0c;共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题&#xff0c;不能使用ELK只能使用…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言&#xff1a;为什么 Eureka 依然是存量系统的核心&#xff1f; 尽管 Nacos 等新注册中心崛起&#xff0c;但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制&#xff0c;是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

《基于Apache Flink的流处理》笔记

思维导图 1-3 章 4-7章 8-11 章 参考资料 源码&#xff1a; https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)

本文把滑坡位移序列拆开、筛优质因子&#xff0c;再用 CNN-BiLSTM-Attention 来动态预测每个子序列&#xff0c;最后重构出总位移&#xff0c;预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵&#xff08;S…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件&#xff0c;用于在原生应用中加载 HTML 页面&#xff1a; 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...

【生成模型】视频生成论文调研

工作清单 上游应用方向&#xff1a;控制、速度、时长、高动态、多主体驱动 类型工作基础模型WAN / WAN-VACE / HunyuanVideo控制条件轨迹控制ATI~镜头控制ReCamMaster~多主体驱动Phantom~音频驱动Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation速…...

MySQL 索引底层结构揭秘:B-Tree 与 B+Tree 的区别与应用

文章目录 一、背景知识&#xff1a;什么是 B-Tree 和 BTree&#xff1f; B-Tree&#xff08;平衡多路查找树&#xff09; BTree&#xff08;B-Tree 的变种&#xff09; 二、结构对比&#xff1a;一张图看懂 三、为什么 MySQL InnoDB 选择 BTree&#xff1f; 1. 范围查询更快 2…...