当前位置: 首页 > news >正文

深度学习之pytorch常见的学习率绘制

文章目录

    • 0. Scope
    • 1. StepLR
    • 2. MultiStepLR
    • 3. ExponentialLR
    • 4. CosineAnnealingLR
    • 5. ReduceLROnPlateau
    • 6. CyclicLR
    • 7. OneCycleLR
    • 小结
    • 参考文献

https://blog.csdn.net/coldasice342/article/details/143435848

0. Scope

在深度学习中,学习率(Learning Rate, LR)是一个非常重要的超参数,它决定了模型权重更新的步长。选择合适的学习率对于训练过程至关重要,因为它不仅影响模型收敛的速度,还会影响最终模型的性能。然而,固定的学习率可能无法在整个训练过程中都保持最优,因此,学习率衰减(Learning Rate Decay, 或称 Learning Rate Schedule)策略应运而生,通过调整学习率来优化训练过程。

在PyTorch中,可以通过torch.optim.lr_scheduler模块提供的多个学习率调度器(Learning Rate Scheduler)来实现学习率的动态调整。这些调度器可以帮助优化训练过程,提高模型的性能。以下是PyTorch中一些常用的学习率调度器及其简要说明。

1. StepLR

每隔一定数量的epoch后,将学习率乘以一个固定的衰减因子。

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

参数:
step_size:经过多少个epoch后进行一次学习率衰减。
gamma:学习率的衰减因子,默认为0.1。
示例:

import torch
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = models.resnet18(pretrained=False)max_epoch = 50  # 一共50 epoch
iters = 20      # 每个epoch 有 20 个 bach
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.9)lr = []
for epoch in range(max_epoch):for batch in range(iters):optimizer.zero_grad()optimizer.step()scheduler.step()  # 更新learning rate# current_lr = scheduler.get_last_lr()[0]  #  注意:获取当前学习率不能使用get_lr()current_lr = optimizer.param_groups[0]['lr']lr.append(current_lr)print(f"End of Epoch {epoch + 1}: Current Learning Rate: {current_lr:.6f}")plt.figure(figsize=(10, 8))
plt.plot(range(1, max_epoch + 1), lr, marker='o')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)
plt.show()

在这里插入图片描述

2. MultiStepLR

类似于 StepLR,但允许在不同 epoch 设置不同的学习率衰减点,提供更精细的控制。在指定的epoch列表处,将学习率乘以一个固定的衰减因子。

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 30], gamma=0.1)

参数:
milestones:一个列表,表示在哪些epoch处进行学习率衰减。
gamma:学习率的衰减因子,默认为0.1。
示例:
在这里插入图片描述

3. ExponentialLR

每个 epoch 将学习率按固定的指数衰减因子 gamma 进行调整。相比于 StepLR,它的衰减更平滑,适合需要持续减小学习率的任务。

scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

参数:
gamma:每个epoch结束时学习率的乘法因子。
示例:
在这里插入图片描述

4. CosineAnnealingLR

CosineAnnealingLR 利用余弦函数的特点,使学习率在训练过程中按照一个周期性变化的余弦曲线来衰减,即学习率从大到小再到大反复变化。通常用于长时间训练任务,能在训练后期有效避免学习率过快下降。
在这里插入图片描述

torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=- 1, verbose=False)

参数:
T_max:一个周期的最大epoch数。
eta_min:学习率的最小值,默认为0。

示例:

import torch
from torchvision import models
import matplotlib.pyplot as plt
import numpy as npnet = models.resnet18(pretrained=False)
max_epoch = 50  # 一共50 epoch
iters = 200     # 每个epoch 有 200 个 bach
update_mode = 'epoch'
if update_mode == 'epoch':optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=max_epoch)  # * iterslr = []for epoch in range(max_epoch):for batch in range(iters):optimizer.step()lr.append(scheduler.get_lr()[0])scheduler.step()  # 注意 每个epoch 结束, 更新learning rate
else:optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)# 调整了四分之一周期的长度 max_epoch * itersscheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=max_epoch * iters)  lr = []for epoch in range(max_epoch):for batch in range(iters):optimizer.step()lr.append(scheduler.get_lr()[0])scheduler.step()  # 注意 每个batch 结束, 更新learning rateplt.figure(figsize=(10, 8))
plt.plot(np.arange(len(lr)), lr)
plt.xlabel('Iterations')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)
plt.show()

每个epoch更新一次
在这里插入图片描述
每个iteration更新一次
在这里插入图片描述

5. ReduceLROnPlateau

ReduceLROnPlateau 是基于验证集表现来调整学习率的一种方法。当模型的验证集指标(如损失)在一段时间内没有改善时,学习率会自动减小。

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

和其他学习率更新不一样,ReduceLROnPlateau学习率更新时需要传入对应的参,例如:scheduler.step(ac) ,ac可以是loss或验证集的准确率之类的
参数:
mode:‘min’表示当监测指标不再下降时减少学习率,‘max’表示当监测指标不再上升时减少学习率。
factor:学习率的衰减因子,默认为0.1。
patience:在没有观察到性能提升的epoch数之后减少学习率。
示例:

import torch
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = models.resnet18(pretrained=False)max_epoch = 50  # 一共50 epoch
iters = 20      # 每个epoch 有 20 个 bach
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)lr = []
for epoch in range(max_epoch):for batch in range(iters):optimizer.zero_grad()optimizer.step()ac = 1if epoch > 20:ac = 10else:ac = ac - 0.1*epochscheduler.step(ac)  # 更新learning rate# current_lr = scheduler.get_last_lr()[0]  #  注意:获取当前学习率不能使用get_lr()current_lr = optimizer.param_groups[0]['lr']lr.append(current_lr)print(f"End of Epoch {epoch + 1}: Current Learning Rate: {current_lr:.6f}")plt.figure(figsize=(10, 8))
plt.plot(range(1, max_epoch + 1), lr, marker='o')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)
plt.show()

在这里插入图片描述

6. CyclicLR

学习率在一个范围内循环变化。

scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=20, step_size_down=None,mode="triangular")

参数:
base_lr:学习率的下限。
max_lr:学习率的上限。
step_size_up:从base_lr到max_lr的步数。
step_size_down:从max_lr到base_lr的步数,如果为None,则默认与step_size_up相同。
示例:

CyclicLR - triangular

import torch
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = models.resnet18(pretrained=False)max_epoch = 50  # 一共50 epoch
iters = 20      # 每个epoch 有 20 个 bach
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=2, step_size_down=None)lr = []
for epoch in range(max_epoch):for batch in range(iters):optimizer.zero_grad()optimizer.step()scheduler.step()  # 更新learning rate# current_lr = scheduler.get_last_lr()[0]  #  注意:获取当前学习率不能使用get_lr()current_lr = optimizer.param_groups[0]['lr']lr.append(current_lr)print(f"End of Epoch {epoch + 1}: Current Learning Rate: {current_lr:.6f}")plt.figure(figsize=(10, 8))
plt.plot(range(1, max_epoch + 1), lr, marker='o')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)
plt.show()

在这里插入图片描述
CyclicLR - triangular2

scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=2, step_size_down=None,mode="triangular2")

在这里插入图片描述
CyclicLR - exp_range

scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01,max_lr=0.1, step_size_up=5,mode="exp_range", gamma=0.85)

在这里插入图片描述
当step_size_up设置较大时:

scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01,max_lr=0.1, step_size_up=20,mode="exp_range", gamma=0.85)

在这里插入图片描述

7. OneCycleLR

根据 “1cycle” 策略,先逐步增加学习率,然后在训练的后期快速减小学习率,这种方式能在训练初期提供更快的收敛速度,同时在后期细化模型。

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, total_steps=None, epochs=100, steps_per_epoch=1)

参数:
max_lr:周期内的最高学习率。
total_steps:整个训练过程中的总步数。注意,如果这里是None,那么必须通过提供epochs和step_per_epoch的值来推断它。
epochs:训练的总轮数。
steps_per_epoch:每个epoch中的步数。
示例:
若每个epoch更新学习率:

import torch
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = models.resnet18(pretrained=False)max_epoch = 50  # 一共50 epoch
iters = 20      # 每个epoch 有 20 个 bach
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, total_steps=None, epochs=max_epoch, steps_per_epoch=1)lr = []
for epoch in range(max_epoch):for batch in range(iters):optimizer.zero_grad()optimizer.step()scheduler.step()  # 更新learning rate# current_lr = scheduler.get_last_lr()[0]  #  注意:获取当前学习率不能使用get_lr()current_lr = optimizer.param_groups[0]['lr']lr.append(current_lr)print(f"End of Epoch {epoch + 1}: Current Learning Rate: {current_lr:.6f}")plt.figure(figsize=(10, 8))
plt.plot(range(1, max_epoch + 1), lr, marker='o')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)
plt.show()

在这里插入图片描述
若每个batch更新学习率:

import torch
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = models.resnet18(pretrained=False)max_epoch = 50  # 一共50 epoch
iters = 20      # 每个epoch 有 20 个 bach
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, total_steps=None, epochs=max_epoch, steps_per_epoch=iters)lr = []
for epoch in range(max_epoch):for batch in range(iters):optimizer.zero_grad()optimizer.step()scheduler.step()current_lr = optimizer.param_groups[0]['lr']lr.append(current_lr)# scheduler.step()  # 更新learning rate# current_lr = scheduler.get_last_lr()[0]  #  注意:获取当前学习率不能使用get_lr()# current_lr = optimizer.param_groups[0]['lr']# lr.append(current_lr)print(f"End of Epoch {epoch + 1}: Current Learning Rate: {current_lr:.6f}")plt.figure(figsize=(10, 8))
plt.plot(range(1, max_epoch*iters + 1), lr, marker='o')
plt.xlabel('Epochs')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)
plt.show()

在这里插入图片描述

小结

本文绘制了pytorch中7种常见的学习率,其中没有最好的,只有适合的。无论使用何种学习率策略,主要还是得适合自己的模型训练,切勿邯郸学步。谨以此记,以备后续训练模型时选择合适的学习率。

参考文献

[1] 图解Pytorch学习率衰减策略(一)
[2] 深度学习】图解 9 种PyTorch中常用的学习率调整策略
[3] pytorch余弦退火学习率CosineAnnealingLR的使用

相关文章:

深度学习之pytorch常见的学习率绘制

文章目录 0. Scope1. StepLR2. MultiStepLR3. ExponentialLR4. CosineAnnealingLR5. ReduceLROnPlateau6. CyclicLR7. OneCycleLR小结参考文献 https://blog.csdn.net/coldasice342/article/details/143435848 0. Scope 在深度学习中,学习率(Learning R…...

Spring Boot集成SQL Server快速入门Demo

1.什么是SQL Server? SQL Server是由Microsoft开发和推广的以客户/服务器(c/s)模式访问、使用Transact-SQL语言的关系数据库管理系统(DBMS),它最初是由Microsoft、Sybase和Ashton-Tate三家公司共同开发的&…...

低代码牵手 AI 接口:开启智能化开发新征程

一、低代码与 AI 接口的结合趋势 低代码开发平台近年来在软件开发领域迅速崛起。随着企业数字化转型的需求不断增长,低代码开发平台以其快速构建应用程序的优势,满足了企业对高效开发的需求。例如,启效云低代码平台通过范式化和高颗粒度的可配…...

【已解决】git push一直提示输入用户名及密码、fatal: Could not read from remote repository的问题

问题描述: 在实操中,git push代码到github上一直提示输入用户名及密码,并且跳出的输入框输入用户名和密码后,报错找不到远程仓库 实际解决中,发现我环境有两个问题解决: git push一直提示输入用户名及密码…...

python语言基础-4 常用模块-4.13 其他模块

声明:本内容非盈利性质,也不支持任何组织或个人将其用作盈利用途。本内容来源于参考书或网站,会尽量附上原文链接,并鼓励大家看原文。侵删。 4.13 其他模块 除此之外python中还有大量的功能模块,如: pill…...

微信小程序=》基础=》常见问题=》性能总结

文章目录 微信小程序开发应用 实例小程序生命周期 以及 各生命周期应用实例小程序图片 展示方案 小程序打包应用方案技术细节(分包应用实例)技术细节(压缩处理)一、准备工作二、JavaScript 代码压缩三、WXML 文件优化&#xff08…...

JWT深度解析:Java Web中的安全传输与身份验证

标题:JWT深度解析:Java Web中的安全传输与身份验证 引言 JSON Web Token(JWT)是一种轻量级的身份验证和授权标准,它允许在各方之间安全地传输信息。在Java Web开发中,JWT因其无状态、可扩展性和跨域支持而…...

使用Java爬虫获取商品订单详情:从API到数据存储

在电子商务日益发展的今天,获取商品订单详情成为了许多开发者和数据分析师的需求。无论是为了分析用户行为,还是为了优化库存管理,订单数据的获取都是至关重要的。本文将详细介绍如何使用Java编写爬虫,通过API获取商品订单详情&am…...

Mybatis中批量插入foreach优化

数据库批量入库方常见方式:Java中foreach和xml中使用foreach 两者的区别: 通过Java的foreach循环批量插入: 当我们在Java通过foreach循环插入的时候,是一条一条sql执行然后将事物统一交给spring的事物来管理(Transa…...

Word VBA如何间隔选中多个(非连续)段落

实例需求:Word文档中的有多个段落,段落总数量不确定,现在需要先选中所有基数段落,即:段落1,段落3 … ,然后一次性设置粗体格式。 也许有的读者会认为这个无厘头的需求,循环遍历遍历文…...

Linux系统常用操作与命令指南

一、快捷分类 1、移动光标 h, j, k, l 左, 下, 上, 右 Ctrl-F:下翻一页 Ctrl-B:上翻一页 Ctrl-U:上翻半页 Ctrl-d:下翻半页 0:跳至行首,不管有无缩进,就是跳到第0个字…...

StructuredStreaming (一)

一、sparkStreaming的不足 1.基于微批,延迟高不能做到真正的实时 2.DStream基于RDD,不直接支持SQL 3.流批处理的API应用层不统一,(流用的DStream-底层是RDD,批用的DF/DS/RDD) 4.不支持EventTime事件时间(一般流处理都会有两个时间:事件发生的事件&am…...

由播客转向个人定制的音频频道(1)平台搭建

项目的背景 最近开始听喜马拉雅播客的内容,但是发现许多不方便的地方。 休息的时候收听喜马拉雅,但是还需要不断地选择喜马拉雅的内容,比较麻烦,而且黑灯操作反而伤眼睛。 喜马拉雅为代表的播客平台都是VOD 形式的&#xff0…...

[自然语言处理] [AI]深入理解语言与情感分类:从基础到深度学习的进展

语言是人类智能的核心组成部分,具有极高的复杂性和多样性。理解语言,尤其是语言中的隐含部分,向来是人工智能研究的一个巨大挑战。图灵测试本身便是一场关于语言生成与理解的比赛,旨在检验机器是否能够模拟人类的语言能力。随着深度学习的飞速发展,语音识别、情感分析等自…...

【GPTs】Gif-PT:DALL·E制作创意动图与精灵动画

博客主页: [小ᶻZ࿆] 本文专栏: AIGC | GPTs应用实例 文章目录 💯GPTs指令💯前言💯Gif-PT主要功能适用场景优点缺点 💯小结 💯GPTs指令 中文翻译: 使用Dalle生成用户请求的精灵图动画&#…...

云原生周刊:Istio 1.24.0 正式发布

云原生周刊:Istio 1.24.0 正式发布 开源项目推荐 Kopf Kopf 是一个简洁高效的 Python 框架,只需几行代码即可编写 Kubernetes Operator。Kubernetes(K8s)作为强大的容器编排系统,虽自带命令行工具(kubec…...

Linux设置jar包开机启动

操作系统环境:CentOS 7 【需要 root 权限,使用 root 用户进行操作 或 普通用户使用 sudo 进行操作】 一、系统服务的方式 原理:利用系统服务管理应用程序的生命周期, systemctl 为系统服务管理工具 systemctl start applicati…...

计算机视觉和机器人技术中的下一个标记预测与视频扩散相结合

一种新方法可以训练神经网络对损坏的数据进行分类,同时预测下一步操作。 它可以为机器人制定灵活的计划,生成高质量的视频,并帮助人工智能代理导航数字环境。 Diffusion Forcing 方法可以对嘈杂的数据进行分类,并可靠地预测任务的…...

C语言之简单的获取命令行参数和环境变量

C语言之简单的获取命令行参数和环境变量 本人的开发环境为WIN10操作系统用VMWARE虚拟的UBUNTU LINUX 18.04LTS!!! 所有代码的编辑、编译、运行都在虚拟机上操作,初学的朋友要注意这一点!!! 详细…...

STL之vecor的使用(超详解)

目录 1. C/C中的数组 1.1. C语言中的数组 1.2. C中的数组 2. vector的接口 2.1. vector的迭代器 2.2. vector的初始化与销毁 2.3. vector的容量操作 2.4. vector的访问操作 2.5. vector的修改操作 💓 博客主页:C-SDN花园GGbond ⏩ 文章专栏…...

椭圆曲线密码学(ECC)

一、ECC算法概述 椭圆曲线密码学(Elliptic Curve Cryptography)是基于椭圆曲线数学理论的公钥密码系统,由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA,ECC在相同安全强度下密钥更短(256位ECC ≈ 3072位RSA…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建(全平台详解) 在开始使用 React Native 开发移动应用之前,正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南,涵盖 macOS 和 Windows 平台的配置步骤,如何在 Android 和 iOS…...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版​分享

平时用 iPhone 的时候,难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵,或者买了二手 iPhone 却被原来的 iCloud 账号锁住,这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...

连锁超市冷库节能解决方案:如何实现超市降本增效

在连锁超市冷库运营中,高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术,实现年省电费15%-60%,且不改动原有装备、安装快捷、…...

Pinocchio 库详解及其在足式机器人上的应用

Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库,专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性,并提供了一个通用的框架&…...

HDFS分布式存储 zookeeper

hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架,允许使用简单的变成模型跨计算机对大型集群进行分布式处理(1.海量的数据存储 2.海量数据的计算)Hadoop核心组件 hdfs(分布式文件存储系统)&a…...

短视频矩阵系统文案创作功能开发实践,定制化开发

在短视频行业迅猛发展的当下,企业和个人创作者为了扩大影响力、提升传播效果,纷纷采用短视频矩阵运营策略,同时管理多个平台、多个账号的内容发布。然而,频繁的文案创作需求让运营者疲于应对,如何高效产出高质量文案成…...

Java数值运算常见陷阱与规避方法

整数除法中的舍入问题 问题现象 当开发者预期进行浮点除法却误用整数除法时,会出现小数部分被截断的情况。典型错误模式如下: void process(int value) {double half = value / 2; // 整数除法导致截断// 使用half变量 }此时...

mac 安装homebrew (nvm 及git)

mac 安装nvm 及git 万恶之源 mac 安装这些东西离不开Xcode。及homebrew 一、先说安装git步骤 通用: 方法一:使用 Homebrew 安装 Git(推荐) 步骤如下:打开终端(Terminal.app) 1.安装 Homebrew…...

day36-多路IO复用

一、基本概念 (服务器多客户端模型) 定义:单线程或单进程同时监测若干个文件描述符是否可以执行IO操作的能力 作用:应用程序通常需要处理来自多条事件流中的事件,比如我现在用的电脑,需要同时处理键盘鼠标…...