当前位置: 首页 > news >正文

学习率调整策略 | PyTorch 深度学习实战

前一篇文章,深度学习里面的而优化函数 Adam,SGD,动量法,AdaGrad 等 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

本篇文章内容来自于 强化学习必修课:引领人工智能新时代【梗直哥瞿炜】

PyTorch 学习率调整策略

  • 常见的学习率调节器
    • 学习率衰减
    • 指数衰减
    • 余弦学习率调节
    • 预热
  • 示例程序
    • 执行结果
      • 没有使用学习率自动调节时
      • 使用了学习率自动调节
      • 结论
  • 常见学习率调节器
  • Links

常见的学习率调节器

在这里插入图片描述

学习率衰减

在这里插入图片描述

指数衰减

在这里插入图片描述

余弦学习率调节

实现学习率循环降低或升高的效果

在这里插入图片描述

预热

在这里插入图片描述

示例程序

下面以指数衰减调节器(ExponentialLR)为例子,展示同样的数据条件下:不衰减学习率和衰减学习率两种情况下,损失函数loss的收敛情况。

import torch
torch.manual_seed(777)'''
Learning rate scheduler
'''
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset # 构造数据集加载器
from torch.utils.data import random_split # 划分数据集torch.manual_seed(777)# for reproducibility为了重复使用############################
# 生成数据
############################# 定义函数
def f(x,y):return x**2 + 2*y**2# 定义初始值
num_samples = 1000 # 1000 个样本点
X = torch.rand(num_samples) # 均匀分布
Y = torch.rand(num_samples)
Z = f(X,Y) + 3 * torch.randn(num_samples)# Concatenates a sequence of tensors along a new dimension.
# All tensors need to be of the same size.
# https://pytorch.org/docs/stable/generated/torch.stack.html
dataset = torch.stack([X,Y,Z], dim=1)
# print(dataset.shape) # torch.Size([1000, 3])# split data, 按照 7:3 划分数据集
train_size = int(0.7 * len(dataset))
test_size = len(dataset) - train_sizetrain_dataset, test_dataset = random_split(dataset=dataset, lengths=[train_size, test_size])# 将数据封装为数据加载器
# narrow 函数对数据进行切片操作,
# 
train_dataloader = DataLoader(TensorDataset(train_dataset.dataset.narrow(1,0,2), train_dataset.dataset.narrow(1,2,1)), batch_size=32, shuffle=False)
test_dataloader = DataLoader(TensorDataset(test_dataset.dataset.narrow(1,0,2), test_dataset.dataset.narrow(1,2,1)), batch_size=32, shuffle=False)############################
# 模型定义
############################# 定义一个简单的模型
class Model(nn.Module):def __init__(self):super().__init__()self.hidden = nn.Linear(2, 8)self.output = nn.Linear(8, 1)def forward(self, x):x = torch.relu(self.hidden(x))return self.output(x)############################
# 模型训练
############################# 超参数
num_epochs = 100
learning_rate = 0.1 # 学习率,调大一些更直观# 定义损失函数
loss_fn = nn.MSELoss()# 通过两次训练,对比有无调节器的效果
for with_scheduler in [False, True]:# 定义训练和测试误差数组train_losses = []test_losses = []# 初始化模型model = Model()# 定义优化器optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)# 定义学习率调节器scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99)# 迭代训练for epoch in range(num_epochs):# 设定模型工作在训练模式model.train()train_loss = 0# 遍历训练集for inputs, targets in train_dataloader:# 预测、损失函数、反向传播optimizer.zero_grad()outputs = model(inputs)loss = loss_fn(outputs, targets)loss.backward()optimizer.step()# 记录 losstrain_loss += loss.item()# 计算 loss 并记录到训练误差train_loss /= len(train_dataloader)train_losses.append(train_loss)# 在测试数据集上评估model.eval()test_loss = 0with torch.no_grad():# 遍历测试集for inputs, targets in test_dataloader:# 预测、损失函数outputs = model(inputs)loss = loss_fn(outputs, targets)# 记录 losstest_loss += loss.item()# 计算 loss 并记录到测试误差test_loss /= len(test_dataloader)test_losses.append(test_loss)# 是否更新学习率if with_scheduler:scheduler.step()# 绘制训练和测试误差曲线plt.figure(figsize= (8, 4))plt.plot(range(num_epochs), train_losses, label="Train")plt.plot(range(num_epochs), test_losses, label="Test")plt.title("{0} lr_scheduler".format("With " if with_scheduler else "Without"))plt.legend()# plt.ylim((1,2))plt.show()

执行结果

没有使用学习率自动调节时

在这里插入图片描述

使用了学习率自动调节

在这里插入图片描述

结论

使用了学习率自动调节,学习的速度更快,收敛速度更快。

常见学习率调节器

## 学习率衰减,例如每训练 100 次就将学习率降低为原来的一半
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=100, gamma=0.5)
## 指数衰减法,每次迭代将学习率乘上一个衰减率
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,gamma=0.99)
## 余弦学习率调节,optimizer 初始学习率为最大学习率,eta_min 是最小学习率,T_max 是最大的迭代次数
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=100, eta_min=0.00001)
## 自定义学习率,通过一个 lambda 函数自定义实现学习率调节器
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.99 ** epoch)
## 预热
warmup_steps = 20
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda t: min(t/warmup_steps, 0.001))

Links

  • PyTorch学习率调整策略.ipynb
  • 6.2 动态调整学习率
  • 【学习率】torch.optim.lr_scheduler学习率10种调整方法整理
  • 11.11. 学习率调度器
  • Pytorch – 手动调整学习率以及使用torch.optim.lr_scheduler调整学习率

相关文章:

学习率调整策略 | PyTorch 深度学习实战

前一篇文章,深度学习里面的而优化函数 Adam,SGD,动量法,AdaGrad 等 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started 本篇文章内容来自于 强化学习必修课:引…...

DeepSeekMoE 论文解读:混合专家架构的效能革新者

论文链接:DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models 目录 一、引言二、背景知识(一)MoE架构概述(二)现有MoE架构的问题 三、DeepSeekMoE架构详解(一&a…...

以下是基于巨控GRM241Q-4I4D4QHE模块的液位远程控制系统技术方案:

以下是基于巨控GRM241Q-4I4D4QHE模块的液位远程控制系统技术方案: 一、系统概述 本系统采用双巨控GRM241Q模块构建4G无线物联网络,实现山上液位数据实时传输至山下水泵站,通过预设逻辑自动控制水泵启停,同时支持APP远程监控及人工…...

【JVM详解五】JVM性能调优

示例: 配置JVM参数运行 #前台运行 java -XX:MetaspaceSize-128m -XX:MaxMetaspaceSize-128m -Xms1024m -Xmx1024m -Xmn256m -Xss256k -XX:SurvivorRatio8 - XX:UseConcMarkSweepGC -jar /jar包路径 #后台运行 nohup java -XX:MetaspaceSize-128m -XX:MaxMetaspaceS…...

2.10日学习总结

题目一&#xff1a; AC代码 #include <stdio.h>#define N 1000000typedef long long l;int main() {int n, m;l s 0;l a[N 1], b[N 1];int i 1, j 1;scanf("%d %d", &n, &m);for (int k 1; k < n; k) {scanf("%lld", &a[k]);…...

疯狂前端面试题(四)

一、Ajax、JSONP、JSON、Fetch 和 Axios 技术详解 1. Ajax&#xff08;异步 JavaScript 和 XML&#xff09; 什么是 Ajax&#xff1f; Ajax 是一种用于在不刷新页面的情况下与服务器进行数据交互的技术。它通过 XMLHttpRequest 对象实现。 优点 - 支持同步和异步请求。 - 能…...

YOLOv11-ultralytics-8.3.67部分代码阅读笔记-metrics.py

metrics.py ultralytics\utils\metrics.py 目录 metrics.py 1.所需的库和模块 2.def bbox_ioa(box1, box2, iouFalse, eps1e-7): 3.def box_iou(box1, box2, eps1e-7): 4.def bbox_iou(box1, box2, xywhTrue, GIoUFalse, DIoUFalse, CIoUFalse, eps1e-7): 5.def mas…...

SuperCopy解除网页禁用复制功能插件安装和使用

点击下载《SuperCopy解除网页禁用复制功能插件》 1. 前言 在当今数字化时代&#xff0c;网络已成为我们获取信息和知识的主要渠道。互联网如同一片浩瀚无垠的知识海洋&#xff0c;蕴藏着无数的资源&#xff0c;从学术论文到生活小窍门&#xff0c;从专业教程到娱乐资讯&#…...

UP-VLA:具身智体的统一理解与预测模型

25年1月来自清华大学和上海姚期智研究院的论文“UP-VLA: A Unified Understanding and Prediction Model for Embodied Agent”。 视觉-语言-动作 (VLA) 模型的最新进展&#xff0c;利用预训练的视觉语言模型 (VLM) 来提高泛化能力。VLM 通常经过视觉语言理解任务的预训练&…...

Unity 基于状态机的逻辑控制详解

状态机是游戏开发中常用的逻辑控制方法&#xff0c;它可以将复杂的逻辑分解成多个独立的状态&#xff0c;并通过状态转移来控制逻辑的执行流程。本文将详细介绍如何在 Unity 中基于状态机实现逻辑控制&#xff0c;并提供技术详解和代码实现。 一、状态机简介 1.1 基本概念 状…...

傅里叶单像素成像技术研究进展

摘要&#xff1a;计算光学成像&#xff0c;通过光学系统和信号处理的有机结合与联合优化实现特定成像特性的成像系统&#xff0c;摆脱了传统成像系统的限制&#xff0c;为光学成像技术添加了浓墨重彩的一笔&#xff0c;并逐步向简单化与智能化的方向发展。单像素成像(Single-Pi…...

IDEA接入DeepSeek

IDEA 目前有多个途径可以接入deepseek&#xff0c;比如CodeGPT或者Continue&#xff0c;这里借助CodeGPT插件接入&#xff0c;CodeGPT目前用的人最多&#xff0c;相对更稳定 一、安装 1.安装CodeGPT idea插件市场找到CodeGPT并安装 2.创建API Key 进入deepseek官网&#xf…...

前端如何判断浏览器 AdBlock/AdBlock Plus(最新版)广告屏蔽插件已开启拦截

2个月前AdBlock/AdBlock Plus疑似升级了一次 因为自己主要负责面对海外的用户项目&#xff0c;发现以前的检测AdBlock/AdBlock Plus开启状态方法已失效了&#xff0c;于是专门研究了一下。并尝试了很多方法。 已失效的老方法 // 定义一个检测 AdBlock 的函数 function chec…...

macOS 上部署 RAGFlow

在 macOS 上从源码部署 RAGFlow-0.14.1&#xff1a;详细指南 一、引言 RAGFlow 作为一款强大的工具&#xff0c;在人工智能领域应用广泛。本文将详细介绍如何在 macOS 系统上从源码部署 RAGFlow 0.14.1 版本&#xff0c;无论是开发人员进行项目实践&#xff0c;还是技术爱好者…...

如何在Kickstart自动化安装完成后ISO内拷贝文件到新系统或者执行命令

如何在Kickstart自动化安装完成后ISO内拷贝文件到新系统或者执行命令 需求 在自动化安装操作系统完成后&#xff0c;需要对操作系统进行配置需要拷贝一些文件到新的操作系统中需要运行一些脚本 问题分析 Linux安装操作系统时&#xff0c;实际上是将ISO镜像文件中的操作系统…...

在服务器部署JVM后,如何评估JVM的工作能力,比如吞吐量

在服务器部署JVM后&#xff0c;评估其工作能力&#xff08;如吞吐量&#xff09;可以通过以下步骤进行&#xff1a; 1. 选择合适的基准测试工具 JMH (Java Microbenchmark Harness)&#xff1a;适合微基准测试&#xff0c;测量特定代码片段的性能。Apache JMeter&#xff1a;…...

攻防世界32 very_easy_sql【SSRF/SQL时间盲注】

不太会&#xff0c;以后慢慢看 被骗了&#xff0c;看见very_easy就点进来了&#xff0c;结果所有sql能试的全试了一点用都没有 打开源代码发现有个use.php 好家伙&#xff0c;这是真的在考sql吗...... 制作gopher协议的脚本&#xff1a; import urllib.parsehost "12…...

STM32G474--Whetstone程序移植(双精度)笔记

1 获取Whetstone程序 Whetstone程序&#xff0c;我用github被墙了&#xff0c;所以用了KK的方式。 获取的程序目录如上所示。 2 新建STM32工程 配置如上&#xff0c;生成工程即可。 3 在生成的工程中添加并修改Whetstone程序 3.1 实现串口打印功能 在生成的usart.c文件中…...

【DeepSeek × Postman】请求回复

新建一个集合 在 Postman 中创建一个测试集合 DeepSeek API Test&#xff0c;并创建一个关联的测试环境 DeepSeek API Env&#xff0c;同时定义两个变量 base_url 和 api_key 的步骤如下&#xff1a; 1. 创建测试集合 DeepSeek API Test 打开 Postman。点击左侧导航栏中的 Co…...

开源身份和访问管理方案之keycloak(一)快速入门

文章目录 什么是IAM什么是keycloakKeycloak 的功能 核心概念client管理 OpenID Connect 客户端 Client Scoperealm roleAssigning role mappings分配角色映射Using default roles使用默认角色Role scope mappings角色范围映射 UsersGroupssessionsEventsKeycloak Policy创建策略…...

COMSOL实战:从微波炉到压电泵的多物理场魔法

comsol软件教程&#xff0c;电热力耦合&#xff0c;动网格&#xff0c;传热&#xff0c;优化&#xff0c;微波加热&#xff0c;压电&#xff08;非comsol官网搬运&#xff09; comsol仿真教程&#xff0c;多物理场&#xff0c;建模仿真&#xff0c;低频电磁今天咱们来点硬核的—…...

次元画室一键部署教程:Python环境快速配置与模型启动

次元画室一键部署教程&#xff1a;Python环境快速配置与模型启动 你是不是也对AI绘画感兴趣&#xff0c;想自己动手试试&#xff0c;结果被复杂的Python环境、CUDA版本、模型权重这些术语给吓退了&#xff1f;别担心&#xff0c;这种感觉我太懂了。几年前我第一次接触这些时&a…...

Web AR开发全指南:从技术原理到实战应用

Web AR开发全指南&#xff1a;从技术原理到实战应用 【免费下载链接】AR.js Image tracking, Location Based AR, Marker tracking. All on the Web. 项目地址: https://gitcode.com/gh_mirrors/arj/AR.js 随着增强现实技术的发展&#xff0c;Web AR开发已成为前端领域的…...

Qwen3.5-27B-GPTQ-Int4:超高效多模态AI新体验

Qwen3.5-27B-GPTQ-Int4&#xff1a;超高效多模态AI新体验 【免费下载链接】Qwen3.5-27B-GPTQ-Int4 项目地址: https://ai.gitcode.com/hf_mirrors/Qwen/Qwen3.5-27B-GPTQ-Int4 导语 阿里云推出Qwen3.5-27B-GPTQ-Int4模型&#xff0c;通过4位量化技术实现性能与效率的双…...

ESP32烧录全攻略:从命令行到GUI工具,新手也能轻松搞定

ESP32烧录全攻略&#xff1a;从命令行到GUI工具&#xff0c;新手也能轻松搞定 第一次接触ESP32开发板时&#xff0c;那块小小的芯片里蕴藏着无限可能&#xff0c;但如何将自己的代码"装进"这个硬件大脑却成了拦路虎。记得我最初尝试烧录时&#xff0c;面对各种专业术…...

Html2Pdf高性能转换引擎:PHP 7.2-8.4全版本兼容的企业级HTML转PDF解决方案

Html2Pdf高性能转换引擎&#xff1a;PHP 7.2-8.4全版本兼容的企业级HTML转PDF解决方案 【免费下载链接】html2pdf OFFICIAL PROJECT | HTML to PDF converter written in PHP 项目地址: https://gitcode.com/gh_mirrors/ht/html2pdf 在当今企业数字化转型浪潮中&#xf…...

从零到一:基于GitHub Pages与Jekyll搭建你的专属学术主页

1. 为什么选择GitHub Pages Jekyll搭建学术主页&#xff1f; 作为一个长期在学术界摸爬滚打的老兵&#xff0c;我见过太多同行花大价钱购买服务器和维护网站&#xff0c;结果最后因为各种技术问题半途而废。直到我发现GitHub Pages和Jekyll这对黄金组合&#xff0c;才真正找到…...

避开这些坑!群晖+acme.sh申请Let’s Encrypt证书的完整指南

群晖NAS上零踩坑申请Lets Encrypt证书的终极实践手册 每次看到浏览器地址栏那个刺眼的"不安全"提示就浑身难受&#xff1f;作为群晖深度用户&#xff0c;我花了三个周末时间踩遍了所有证书申请的坑。从idn指令缺失到nss验证失败&#xff0c;从API调用超时到证书自动更…...

第三章 Qt 编译及安装

1. Qt 编译安装 2 Qt 在线安装 在线安装包的下载地址&#xff1a; https://download.qt.io/official_releases/online_installers/ Qt对不同的平台提供了不同版本的安装包&#xff0c;可根据实际情况自行下载安装&#xff0c;本文档使用qt-online-installer-windows-x64-on…...

OpenClaw任务编排:GLM-4.7-Flash驱动复杂工作流

OpenClaw任务编排&#xff1a;GLM-4.7-Flash驱动复杂工作流 1. 为什么需要任务编排&#xff1f; 去年我接手了一个重复性极高的数据整理工作——每周需要从十几个不同来源收集数据&#xff0c;清洗后生成可视化报告。最初尝试用Python脚本自动化&#xff0c;但随着需求变化&a…...