当前位置: 首页 > news >正文

【PyTorch实战演练】自调整学习率实例应用(附代码)

目录

0. 前言

1. 自调整学习率的常用方法

1.1  ExponentialLR 指数衰减方法

1.2 CosineAnnealingLR 余弦退火方法

1.3 ChainedScheduler 链式方法

2. 实例说明

3. 结果说明

3.1 余弦退火法训练过程

3.2 指数衰减法训练过程

3.3 恒定学习率训练过程

3.4 结果解读

4. 完整代码


0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文介绍深度学习训练中经常能使用到的实用技巧——自调整学习率,并基于PyTorch框架通过实例进行使用,最后对比同样条件下以自调整学习率和固定学习率在模型训练上的表现。

在深度学习模型训练过程中,经常会出现损失值不收敛的头疼情况,令人怀疑自己设计的网络模型存在不合理或者错误之处,但是导致这种情况的真正原因往往非常简单——就是学习率的设定不合理。

这就导致在深度学习模型训练的初期,可能需要“摸索”一个比较合适的学习率,在后期逐渐细化的过程可能还需要尝试其他的学习率,进行“分段训练”。学习率的这种“摸索”费时费力,因此自调整学习率的方法应运而生。

1. 自调整学习率的常用方法

自调整学习率是一种通用的方法,通过在训练期间动态地更新学习率以使模型更好地收敛,提高模型的精确度和稳定性。在PyTorch框架 torch.optim.lr_scheduler 中集成了大量的自调整学习率的方法:

lr_scheduler.LambdaLR

lr_scheduler.MultiplicativeLR

lr_scheduler.StepLR

lr_scheduler.MultiStepLR

lr_scheduler.ConstantLR

lr_scheduler.LinearLR

lr_scheduler.ExponentialLR

lr_scheduler.PolynomialLR

lr_scheduler.CosineAnnealingLR

lr_scheduler.ChainedScheduler

lr_scheduler.SequentialLR

lr_scheduler.ReduceLROnPlateau

lr_scheduler.CyclicLR

lr_scheduler.OneCycleLR

lr_scheduler.CosineAnnealingWarmRestarts

这里具体方法虽然很多,但是每种自动调整学习率的策略都比较简单,本文仅介绍以下两种常见的方法,其余可以查看PyTorch官网说明。

1.1  ExponentialLR 指数衰减方法

这种方法的学习率为:

lr = lr_{initial} \cdot \gamma^{epoch}

这里有两个参数:

  • lr_{initial}:初始学习率
  • \gamma:衰减指数,取值范围(0,1),一般来说取值要非常接近1(>0.95,甚至要>0.99),否则在动辄几百几千的迭代次数epoch条件下学习率很快会衰减到0

PyTorch调用代码为:

torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch)
  • optimizer:训练该模型的优化算法
  • gamma:上面的衰减指数\gamma
  • last_epoch:不用理会,默认-1即可
1.2 CosineAnnealingLR 余弦退火方法

首先我要吐槽下不知道是哪个大聪明给起的这个名字,我学过《机械加工原理与工艺》,但我不明白这个算法和退火有什么关系?名字非常唬人,方法并不复杂。

这个方法就是让学习率按余弦曲线进行震荡,其震荡范围为[0, lr_initial],震荡周期为T_max:

PyTorch调用代码为:

torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, last_epoch)

其参数同上说明不再赘述。

1.3 ChainedScheduler 链式方法

这个方法说白了就是多种自调节学习率方法进行组合使用,例如以下:

scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
scheduler2 = ExponentialLR(self.opt, gamma=0.9)
scheduler = ChainedScheduler([scheduler1, scheduler2])

最后再强调一下,无论使用哪种方法,在coding时别忘了每个epoch学习后加上 scheduler.step() ,这样才能更新学习率。

2. 实例说明

本文目的是验证自调节学习率的作用,选用一个简单的“平方网络”实例,即期望输出为输入值的平方。

训练输入数据x_train,输出数据y_train分别为:

x_train = torch.tensor([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],dtype=torch.float32).unsqueeze(-1)
y_train = torch.tensor([0.01,0.04,0.09,0.16,0.25,0.36,0.49,0.64,0.81,1],dtype=torch.float32).unsqueeze(-1)

验证输入数据x_val为:

x_val = torch.tensor([0.15,0.25,0.35,0.45,0.55,0.65,0.75,0.85,0.95],dtype=torch.float32).unsqueeze(-1)

网络模型使用5层全连接网络,对应这种简单问题足够用了:

class Linear(torch.nn.Module):def __init__(self):super().__init__()self.layers = torch.nn.Sequential(torch.nn.Linear(in_features=1, out_features=3),torch.nn.Sigmoid(),torch.nn.Linear(in_features=3,out_features=5),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=10),torch.nn.Sigmoid(),torch.nn.Linear(in_features=10,out_features=5),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=1),torch.nn.ReLU(),)

优化器选用Adam,训练组及验证组的损失函数都选用MSE均方差。

3. 结果说明

对于本文对比的三种方法,其训练参数设定如下表:

学习率调整方法训练参数设定
指数衰减迭代次数epoch=2000,初始学习率initial_lr=0.0005,衰减指数gamma=0.99995
余弦退火迭代次数epoch=2000,初始学习率initial_lr=0.001(因为学习率均值为初始值的一半,公平起见给余弦退火方法初始学习率*2),震荡周期T_max=200
恒定学习率学习率lr=0.0005

各种方法的训练过程如下图:

3.1 余弦退火法训练过程

3.2 指数衰减法训练过程

3.3 恒定学习率训练过程

3.4 结果解读
  1. 指数衰减方法的gamma真的要选择非常接近1才行!理由我上面也解释了,下面是gamma=0.999时的loss下降情况,可见其下降速度之慢;

  2. 本实例中,使用余弦退火方法训练的模型在验证组表现最好,验证组损失值val=0.0027,其次是指数衰减法(val=0.0037),最差是恒定学习率(val=0.0053),但这个结果仅限本实例中,并不是说哪种方法就比哪种方法好,在不同模型上可能要因地制宜选择合适的方法
  3. 目前这些所谓的“自调节学习率”我认为仅能算是“半自动调节”,因为仍有超参数需要事前设定(例如衰减指数gamma)。而选择哪种方法最好,以及这种方法的参数如何设定呢?可能还是需要做一定的“摸索”,但是相比恒定学习率肯定会节省很多时间!

4. 完整代码

import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparsetorch.manual_seed(25)x_train = torch.tensor([0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],dtype=torch.float32).unsqueeze(-1)
y_train = torch.tensor([0.01,0.04,0.09,0.16,0.25,0.36,0.49,0.64,0.81,1],dtype=torch.float32).unsqueeze(-1)
x_val = torch.tensor([0.15,0.25,0.35,0.45,0.55,0.65,0.75,0.85,0.95],dtype=torch.float32).unsqueeze(-1)class Linear(torch.nn.Module):def __init__(self):super().__init__()self.layers = torch.nn.Sequential(torch.nn.Linear(in_features=1, out_features=3),torch.nn.Sigmoid(),torch.nn.Linear(in_features=3,out_features=5),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=10),torch.nn.Sigmoid(),torch.nn.Linear(in_features=10,out_features=5),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=1),torch.nn.ReLU(),)def forward(self,x):return self.layers(x)criterion = torch.nn.MSELoss()def train_with_CosinAnneal(epoch, initial_lr, T_max):linear1 = Linear()opt = torch.optim.Adam(linear1.parameters(), lr=initial_lr)scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=T_max, last_epoch=-1)last_epoch_loss = 0for i in tqdm(range(epoch)):opt.zero_grad()total_loss = 0for j in range(len(x_train)):output = linear1(x_train[j])loss = criterion(output,y_train[j])total_loss = total_loss + loss.detach()loss.backward()opt.step()if i == epoch-1:last_epoch_loss = total_lossscheduler.step()cur_lr = scheduler.get_lr()  #查看学习率变化情况# plt.scatter(i, cur_lr, s=2, c='g')  plt.scatter(i, total_loss,s=2,c='r')val = 0for x in x_val:val = val + ((linear1(x) - x * x) / x * x)**2plt.title('CosinAnneal---val=%f---last_epoch_loss=%f'%(val,last_epoch_loss))plt.xlabel('epoch')plt.ylabel('loss')plt.show()def train_with_ExponentialLR(epoch, initial_lr, gamma):linear2 = Linear()opt = torch.optim.Adam(linear2.parameters(), lr=initial_lr)scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=gamma, last_epoch=-1)last_epoch_loss = 0for i in tqdm(range(epoch)):opt.zero_grad()total_loss = 0for j in range(len(x_train)):output = linear2(x_train[j])loss = criterion(output,y_train[j])total_loss = total_loss + loss.detach()loss.backward()opt.step()if i == epoch-1:last_epoch_loss = total_lossscheduler.step()cur_lr = scheduler.get_lr()# plt.scatter(i, cur_lr, s=2, c='g')plt.scatter(i, total_loss,s=2,c='g')val = 0for x in x_val:val = val + ((linear2(x) - x * x) / x * x)**2plt.title('ExponentialLR---val=%f---last_epoch_loss=%f'%(val,last_epoch_loss))plt.xlabel('epoch')plt.ylabel('loss')plt.show()def train_without_lr_scheduler(epoch, initial_lr):linear3 = Linear()opt = torch.optim.Adam(linear3.parameters(), lr=initial_lr)last_epoch_loss = 0for i in tqdm(range(epoch)):opt.zero_grad()total_loss = 0for j in range(len(x_train)):output = linear3(x_train[j])loss = criterion(output, y_train[j])total_loss = total_loss + loss.detach()loss.backward()opt.step()if i == epoch-1:last_epoch_loss = total_lossplt.scatter(i, total_loss, s=2, c='b')val = 0for x in x_val:val = val + ((linear3(x) - x * x) / x * x)**2plt.title('without_lr_scheduler---val=%f---last_epoch_loss=%f'%(val,last_epoch_loss))plt.xlabel('epoch')plt.ylabel('loss')plt.show()if __name__ == '__main__' :epoch = 2000initial_lr = 0.0005gamma = 0.99995T_max = 200Cosine_Anneal_args = [epoch,initial_lr*2,T_max]train_with_CosinAnneal(*Cosine_Anneal_args)# ExponentialLR_args = [epoch, initial_lr, gamma]# train_with_ExponentialLR(*ExponentialLR_args)## without_scheduler_args = [epoch, initial_lr]# train_without_lr_scheduler(*without_scheduler_args)

相关文章:

【PyTorch实战演练】自调整学习率实例应用(附代码)

目录 0. 前言 1. 自调整学习率的常用方法 1.1 ExponentialLR 指数衰减方法 1.2 CosineAnnealingLR 余弦退火方法 1.3 ChainedScheduler 链式方法 2. 实例说明 3. 结果说明 3.1 余弦退火法训练过程 3.2 指数衰减法训练过程 3.3 恒定学习率训练过程 3.4 结果解读 4. …...

app拉新渠道整合 一手地推、网推拉新平台整理

1.聚量推客 聚量推客自己本身是服务商,自己直营的平台,相对来说数据更好,我们也拿到了平台首码:000000 填这个就行,属于官方渠道 2.蓝猫推客 蓝猫推客我认为是比较又潜力的平台,经过几天测试数据和结算都…...

十六进制IP转换点分十进制代码

以下是一个可以实现将输入的十六进制格式的IP地址转换为点分十进制格式并输出的简单程序。它使用了 sscanf 函数将输入的字符串解析成无符号整数&#xff0c;然后使用 inet_ntoa 函数将其转换成点分十进制格式&#xff0c;并打印输出&#xff1a; #include <stdio.h> #i…...

面试官的一句话,让五年功能测试老手彻夜难眠!

小王是一名软件测试工程师&#xff0c;已经在目前的公司做了四五年的功能测试。虽然一直表现得非常努力&#xff0c;但他还是没能躲过裁员。只能被动跳槽&#xff0c;寻找更好的职业机会。 然而事情并没有像他想象中那样顺利。在多次面试中小王屡屡碰壁&#xff0c;被面试官吐槽…...

向量检索库Milvus架构及数据处理流程

文章目录 背景milvus想做的事milvus之前——向量检索的一些基础近似算法欧式距离余弦距离 常见向量索引1&#xff09; FLAT2&#xff09; Hash based3&#xff09; Tree based4&#xff09; 基于聚类的倒排5&#xff09; NSW&#xff08;Navigable Small World&#xff09;图 向…...

【华为路由器】配置企业通过5G链路接入Internet示例

场景介绍 5G Cellular接口是路由器用来实现5G技术的物理接口&#xff0c;它为用户提供了企业级的无线广域网接入服务&#xff0c;主要用于eMBB场景。与LTE相比&#xff0c;5G系统可以为企业用户提供更大带宽的无线广域接入服务。 路由器的5G功能&#xff0c;可以实现企业分支…...

python安装.whl文件

python --version https://www.lfd.uci.edu/~gohlke/pythonlibs/ 用CtrlF找需要安装的包 下载对应版本的whl python3.8 把下载好的whl放到安装路径下&#xff1a;C:\Users\Administrator\AppData\Local\Programs\Python\Python38\Lib\site-packages 并在该路径下打开cmd执行…...

Java方法调用动态绑定(多态性)详解

CONTENTS 1. 方法调用绑定2. 尝试重写Private方法3. 字段访问与静态方法的多态4. 构造器内部的多态方法行为 1. 方法调用绑定 我们首先来看下面这个例子&#xff1a; package com.yyj;enum Tone {LOW, MIDDLE, HIGH; }class Instrument {public void play(Tone t) {System.ou…...

【SwiftUI模块】0060、SwiftUI基于Firebase搭建一个类似InstagramApp 2/7部分-搭建TabBar

SwiftUI模块系列 - 已更新60篇 SwiftUI项目 - 已更新5个项目 往期Demo源码下载 技术:SwiftUI、SwiftUI4.0、Instagram、Firebase 运行环境: SwiftUI4.0 Xcode14 MacOS12.6 iPhone Simulator iPhone 14 Pro Max SwiftUI基于Firebase搭建一个类似InstagramApp 2/7部分-搭建Tab…...

代码随想录第50天 | 84.柱状图中最大的矩形

84.柱状图中最大的矩形 //双指针 js中运行速度最快 var largestRectangleArea function(heights) {const len heights.length;const minLeftIndex new Array(len);const maxRigthIndex new Array(len);// 记录每个柱子 左边第一个小于该柱子的下标minLeftIndex[0] -1; //…...

深度学习---卷积神经网络

卷积神经网络概述 卷积神经网络是深度学习在计算机视觉领域的突破性成果。在计算机视觉领域。往往输入的图像都很大&#xff0c;使用全连接网络的话&#xff0c;计算的代价较高。另外图像也很难保留原有的特征&#xff0c;导致图像处理的准确率不高。 卷积神经网络&#xff0…...

Windows系统下安装CouchDB3.3.2教程

安装 前往CouchDB官网 官网点击download下载msi文件 双击该msi文件&#xff0c;一直下一步 创建个人account 设置cookie value 用于进行身份验证和授权。 愉快下载 点击OK 重启 启动 重启电脑后 打开浏览器并访问以下链接&#xff1a;http://127.0.0.1:5984/ 如果没有问…...

JavaScript基础知识(二)

JavaScript基础知识&#xff08;二&#xff09; 一、ES2015 基础语法1.变量2.常量3.模板字符串4.结构赋值 二、函数进阶1. 设置默认参数值2. 立即执行函数3. 闭包4. 箭头函数 三、面向对象1. 面向对象概述2. 基本概念3. 新语法 与 旧语法3.1 ES5 面向对象的知识ES5构造函数原型…...

SQL NULL Values(空值)

什么是SQL NULL值&#xff1f; SQL 中&#xff0c;NULL 用于表示缺失的值。数据表中的 NULL 值表示该值所处的字段为空。 具有NULL值的字段是没有值的字段。 如果表中的字段是可选的&#xff0c;则可以插入新记录或更新记录而不向该字段添加值。然后&#xff0c;该字段将被保存…...

云原生Docker网络管理

目录 Docker网络 Docker 网络实现原理 为容器创建端口映射 查看容器的输出和日志信息 Docker 的网络模式 查看docker网络列表 指定容器网络模式 网络模式详解 host模式 container模式 none模式 bridge模式 自定义网络 Docker网络 Docker 网络实现原理 Docker使用Lin…...

聊聊线程池的预热

序 本文主要研究一下线程池的预热 prestartCoreThread java/util/concurrent/ThreadPoolExecutor.java /*** Starts a core thread, causing it to idly wait for work. This* overrides the default policy of starting core threads only when* new tasks are executed. T…...

VueComponent的原型对象

一、prototype 每一个构造函数身上又有一个prototype指向其原型对象。 如果我们在控制台输入如下代码&#xff0c;就能看到Vue构造函数的信息&#xff0c;在他身上可以找到prototype属性&#xff0c;指向的是Vue原型对象&#xff1a; 二、__proto__ 通过构造函数创建的实例对…...

Redis不止能存储字符串,还有List、Set、Hash、Zset,用对了能给你带来哪些优势?

文章目录 &#x1f31f; Redis五大数据类型的应用场景&#x1f34a; 一、String&#x1f34a; 二、Hash&#x1f34a; 三、List&#x1f34a; 四、Set&#x1f34a; 五、Zset &#x1f4d5;我是廖志伟&#xff0c;一名Java开发工程师、Java领域优质创作者、CSDN博客专家、51CTO…...

Python OpenCV通过灰度平均值进行二值化处理以减少像素误差

Python OpenCV通过灰度平均值进行二值化处理以减少像素误差 前言前提条件相关介绍实验环境通过灰度平均值进行二值化处理以减少像素误差固定阈值二值化代码实现 灰度平均值二值化代码实现 前言 由于本人水平有限&#xff0c;难免出现错漏&#xff0c;敬请批评改正。更多精彩内容…...

[Golang]多返回值函数、defer关键字、内置函数、变参函数、类成员函数、匿名函数

函数 文章目录 函数多返回值函数按值传递、按引用传递类成员函数改变外部变量变参函数defer和追踪说明一些常见操作实现 使用defer实现代码追踪记录函数的参数和返回值 常见的内置函数将函数作为参数闭包实例闭包将函数作为返回值 计算函数执行时间使用内存缓存来提升性能 参考…...

eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)

说明&#xff1a; 想象一下&#xff0c;你正在用eNSP搭建一个虚拟的网络世界&#xff0c;里面有虚拟的路由器、交换机、电脑&#xff08;PC&#xff09;等等。这些设备都在你的电脑里面“运行”&#xff0c;它们之间可以互相通信&#xff0c;就像一个封闭的小王国。 但是&#…...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中&#xff0c;我们可能会遇到一些流式数据处理的场景&#xff0c;比如接收来自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 内容&#xff0c;并将其原样中转给前端页面或客户端。这种情况下&#xff0c;传统的 RestTemplate 缓存机制会…...

智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql

智慧工地管理云平台系统&#xff0c;智慧工地全套源码&#xff0c;java版智慧工地源码&#xff0c;支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求&#xff0c;提供“平台网络终端”的整体解决方案&#xff0c;提供劳务管理、视频管理、智能监测、绿色施工、安全管…...

3.3.1_1 检错编码(奇偶校验码)

从这节课开始&#xff0c;我们会探讨数据链路层的差错控制功能&#xff0c;差错控制功能的主要目标是要发现并且解决一个帧内部的位错误&#xff0c;我们需要使用特殊的编码技术去发现帧内部的位错误&#xff0c;当我们发现位错误之后&#xff0c;通常来说有两种解决方案。第一…...

ETLCloud可能遇到的问题有哪些?常见坑位解析

数据集成平台ETLCloud&#xff0c;主要用于支持数据的抽取&#xff08;Extract&#xff09;、转换&#xff08;Transform&#xff09;和加载&#xff08;Load&#xff09;过程。提供了一个简洁直观的界面&#xff0c;以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...

Linux-07 ubuntu 的 chrome 启动不了

文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了&#xff0c;报错如下四、启动不了&#xff0c;解决如下 总结 问题原因 在应用中可以看到chrome&#xff0c;但是打不开(说明&#xff1a;原来的ubuntu系统出问题了&#xff0c;这个是备用的硬盘&a…...

dify打造数据可视化图表

一、概述 在日常工作和学习中&#xff0c;我们经常需要和数据打交道。无论是分析报告、项目展示&#xff0c;还是简单的数据洞察&#xff0c;一个清晰直观的图表&#xff0c;往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server&#xff0c;由蚂蚁集团 AntV 团队…...

USB Over IP专用硬件的5个特点

USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中&#xff0c;从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备&#xff08;如专用硬件设备&#xff09;&#xff0c;从而消除了直接物理连接的需要。USB over IP的…...

免费PDF转图片工具

免费PDF转图片工具 一款简单易用的PDF转图片工具&#xff0c;可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件&#xff0c;也不需要在线上传文件&#xff0c;保护您的隐私。 工具截图 主要特点 &#x1f680; 快速转换&#xff1a;本地转换&#xff0c;无需等待上…...

多模态图像修复系统:基于深度学习的图片修复实现

多模态图像修复系统:基于深度学习的图片修复实现 1. 系统概述 本系统使用多模态大模型(Stable Diffusion Inpainting)实现图像修复功能,结合文本描述和图片输入,对指定区域进行内容修复。系统包含完整的数据处理、模型训练、推理部署流程。 import torch import numpy …...