当前位置: 首页 > article >正文

2025-05-28 Python深度学习8——优化器

文章目录

  • 1 工作原理
  • 2 常见优化器
    • 2.1 SGD
    • 2.2 Adam
  • 3 优化器参数
  • 4 学习率
  • 5 使用最佳实践

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

​ 优化器 (Optimizer) 是深度学习中的核心组件,负责根据损失函数的梯度来更新模型的参数,使模型能够逐步逼近最优解。在 PyTorch 中,优化器通过torch.optim模块提供。

​ Pytorch 链接:https://docs.pytorch.org/docs/stable/optim.html。

1 工作原理

​ 优化器的工作流程如下:

  1. 计算损失函数的梯度 (通过backward()方法)。
  2. 根据梯度更新模型参数 (通过step()方法)。
  3. 清除之前的梯度 (通过zero_grad()方法)。
result_loss.backward()  # 计算梯度
optim.step()           # 更新参数
optim.zero_grad()      # 清除梯度

2 常见优化器

​ PyTorch 提供多种优化器,以 SGD 和 Adam 为例。

2.1 SGD

​ 基础优化器,可以添加动量 (momentum) 来加速收敛。

image-20250528111156055
参数类型默认值作用使用建议
paramsiterable-待优化参数必须传入model.parameters()或参数组字典,支持分层配置
lrfloat1e-3学习率控制参数更新步长,SGD常用0.01-0.1,Adam常用0.001
momentumfloat0动量因子加速梯度下降(Adam内置动量,无需单独设置)
dampeningfloat0动量阻尼抑制动量震荡(仅当momentum>0时生效)
weight_decayfloat0L2正则化防止过拟合,AdamW建议0.01-0.1
nesterovboolFalseNesterov动量改进版动量法(需momentum>0)
maximizeboolFalse最大化目标默认最小化损失,True时改为最大化
foreachboolNone向量化实现CUDA下默认开启,内存不足时禁用
differentiableboolFalse可微优化允许优化器步骤参与自动微分(影响性能)
fusedboolNone融合内核CUDA加速,支持float16/32/64/bfloat16

2.2 Adam

image-20250528111450759
  • 特点:自适应矩估计,结合了动量法和 RMSProp 的优点。
  • 优点:通常收敛速度快,对学习率不太敏感。
参数名称类型默认值作用使用建议
paramsiterable-需要优化的参数(如model.parameters()必须传入,支持参数分组配置
lrfloat1e-3学习率(控制参数更新步长)推荐0.001起调,CV任务可尝试0.0001-0.01
betas(float, float)(0.9, 0.999)梯度一阶矩(β₁)和二阶矩(β₂)的衰减系数保持默认,除非有特殊需求
epsfloat1e-8分母稳定项(防止除以零)混合精度训练时可增大至1e-6
weight_decayfloat0L2正则化系数推荐0.01-0.1(使用AdamW时更有效)
decoupled_weight_decayboolFalse启用AdamW模式(解耦权重衰减)需要权重衰减时建议设为True
amsgradboolFalse使用AMSGrad变体(解决收敛问题)训练不稳定时可尝试启用
foreachboolNone使用向量化实现加速(内存消耗更大)CUDA环境下默认开启,内存不足时禁用
maximizeboolFalse最大化目标函数(默认最小化)特殊需求场景使用
capturableboolFalse支持CUDA图捕获仅在图捕获场景启用
differentiableboolFalse允许通过优化器步骤进行自动微分高阶优化需求启用(性能下降)
fusedboolNone使用融合内核实现(需CUDA)支持float16/32/64时启用可加速

3 优化器参数

​ 所有优化器都接收两个主要参数:

  1. params:要优化的参数,通常是model.parameters()
  2. lr:学习率(learning rate),控制参数更新的步长。

​ 其他常见参数:

  • weight_decay:L2 正则化系数,防止过拟合。
  • momentum:动量因子,加速 SGD 在相关方向的收敛。
  • betas(Adam 专用):用于计算梯度及其平方的移动平均的系数。

4 学习率

​ 学习率是优化器中最重要的超参数之一。

  • 太大:可能导致震荡或发散。
  • 太小:收敛速度慢。
  • 常见策略:
    • 固定学习率 (如代码中的 0.01)。
    • 学习率调度器 (Learning Rate Scheduler) 动态调整。

5 使用最佳实践

  1. 梯度清零:每次迭代前调用optimizer.zero_grad(),避免梯度累积。
  2. 参数更新顺序:先backward()step()
  3. 学习率选择:可以从默认值开始 (如 Adam 的 0.001),然后根据效果调整。
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10(root='./dataset',  # 保存路径train=False,  # 是否为训练集transform=torchvision.transforms.ToTensor(),  # 转换为张量download=True  # 是否下载
)dataloader = DataLoader(dataset, batch_size=1)class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):return self.model(x)loss = nn.CrossEntropyLoss()
model = MyModel()
torch.optim.Adam(model.parameters(), lr=0.01)
optim = torch.optim.SGD(model.parameters(), lr=0.01)for epoch in range(20):running_loss = 0.0# 遍历dataloader中的数据for data in dataloader:# 获取数据和标签imgs, targets = data# 使用模型对数据进行预测output = model(imgs)# 计算预测结果和真实标签之间的损失result_loss = loss(output, targets)# 将梯度置零optim.zero_grad()# 反向传播计算梯度result_loss.backward()# 更新模型参数optim.step()running_loss += result_lossprint(f'第 {epoch + 1} 轮的损失为 {running_loss}')

相关文章:

2025-05-28 Python深度学习8——优化器

文章目录 1 工作原理2 常见优化器2.1 SGD2.2 Adam 3 优化器参数4 学习率5 使用最佳实践 本文环境: Pycharm 2025.1Python 3.12.9Pytorch 2.6.0cu124 ​ 优化器 (Optimizer) 是深度学习中的核心组件,负责根据损失函数的梯度来更新模型的参数,使…...

篇章二 数据结构——前置知识(二)

目录 1. 包装类 1.1 包装类的概念 1.2 基本数据类型和对应的包装类 1.3 装箱和拆箱 1.4 自动装箱和自动拆箱 1.5 练习 —— 面试题 2. 泛型 2.1 如果没有泛型——会出现什么情况&#xff1f; 2.2 语法 2.3 裸类型 1.没有写<> 但是没有报错为什么&#xff1f; …...

如果是在服务器的tty2终端怎么查看登陆服务器的IP呢

1. 如果是在服务器的tty2终端怎么查看登陆服务器的IP呢 在服务器的 tty2 或其他终端会话中&#xff0c;要查看与该服务器的连接相关的 IP 地址&#xff0c;可以使用几种命令来获取这些信息&#xff1a; 1.1 使用 who 命令&#xff1a; who 命令可以显示当前登录到服务器上的…...

Java求职面试:从核心技术到AI与大数据的全面考核

Java求职面试&#xff1a;从核心技术到AI与大数据的全面考核 第一轮&#xff1a;基础框架与核心技术 面试官&#xff1a;谢飞机&#xff0c;咱们先从简单的开始。请你说说Spring Boot的启动过程。 谢飞机&#xff1a;嗯&#xff0c;Spring Boot启动的时候会自动扫描组件&…...

ubuntu24.04与ubuntu22.04比,有什么新特性?

Ubuntu 24.04 LTS (Noble Numbat) 相较于 Ubuntu 22.04 LTS (Jammy Jellyfish) 带来了许多重要的新特性和改进。以下是一些关键的亮点&#xff1a; Linux Kernel: Ubuntu 24.04 LTS: 搭载了更新的 Linux Kernel 6.8&#xff08;发布时&#xff09;。 Ubuntu 22.04 LTS: 发布时…...

Flutter Container组件、Text组件详解

目录 1. Container容器组件 1.1 Container使用 1.2 Container alignment使用 1.3 Container border边框使用 1.4 Container borderRadius圆角的使用 1.5 Container boxShadow阴影的使用 1.6 Container gradient背景颜色渐变 1.7 Container gradient RadialGradient 背景颜色渐…...

Telegram平台分发其聊天机器人Grok

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…...

STM32 定时器输出比较深度解析:从原理到电机控制应用 (详解)

文章目录 定时器输出比较定时器通道结构输出比较通道(高级) PWM 信号原理输出比较 8 种工作模式互补输出概念极性选择内容 PWM硬件部分舵机直流电机及驱动简介 定时器输出比较 定时器通道结构 通道组成&#xff1a;定时器有四个通道&#xff0c;以通道一为例&#xff0c;中间是…...

用 NGINX 还原真实客户端 IP ngx_mail_realip_module

一、模块作用与使用前提 作用&#xff1a;解析 TCP 会话第一行的 PROXY 协议头&#xff0c;将客户端 IP/端口写回 NGINX 的内部变量&#xff0c;使后续 ngx_mail_proxy_module、认证模块、日志模块都能获取真实来源。 前提&#xff1a;监听指令中必须启用 proxy_protocol&…...

Mysql中索引B+树、最左前缀匹配

这里需要对索引的相关结构有一个基础的认识&#xff0c;比如线性索引&#xff0c;树形索引&#xff08;二叉树&#xff0c;平衡二叉树&#xff0c;红黑树等&#xff09;&#xff0c;这个up主我觉得讲的还是比较清楚的&#xff0c;可以看下。 终于把B树搞明白了(一)_B树的引入…...

Python训练营打卡 Day38

Dataset和Dataloader类 知识点回顾&#xff1a; Dataset类的__getitem__和__len__方法&#xff08;本质是python的特殊方法&#xff09;Dataloader类minist手写数据集的了解 作业&#xff1a;了解下cifar数据集&#xff0c;尝试获取其中一张图片 Dataset和Dataloader类 1. Data…...

【机器学习基础】机器学习入门核心算法:K均值(K-Means)

机器学习入门核心算法&#xff1a;K均值&#xff08;K-Means&#xff09; 1. 算法逻辑2. 算法原理与数学推导2.1 目标函数2.2 数学推导2.3 时间复杂度 3. 模型评估内部评估指标外部评估指标&#xff08;需真实标签&#xff09; 4. 应用案例4.1 客户细分4.2 图像压缩4.3 文档聚类…...

Python Day37

Task&#xff1a; 1.过拟合的判断&#xff1a;测试集和训练集同步打印指标 2.模型的保存和加载 a.仅保存权重 b.保存权重和模型 c.保存全部信息checkpoint&#xff0c;还包含训练状态 3.早停策略 1. 过拟合的判断&#xff1a;测试集和训练集同步打印指标 过拟合是指模型在训…...

RabbitMQ集群与负载均衡实战指南

文章目录 集群架构概述仲裁队列的使用1. 使用Spring框架代码创建2. 使用amqp-client创建3. 使用管理平台创建 负载均衡引入HAProxy 负载均衡&#xff1a;使用方法1. 修改配置文件2. 声明队列 test_cluster3. 发送消息 集群架构 概述 RabbitMQ支持部署多个结点&#xff0c;每个…...

怎么开机自动启动vscode项目

每次开机都得用 vscode 打开多个工程&#xff0c;然后用 vscode 里的终端启动&#xff0c;怎么设置成开机自动启动&#xff0c;省事点。 创建 bat 文件&#xff0c;用 cmd 启动&#xff0c;然后将 bat 文件放到 windows 启动文件夹中 yqp1.bat echo on cls d: cd D:\yqp\add…...

Unity 中 Update、FixedUpdate 和 LateUpdate 的区别及使用场景

在Unity开发中,Update、FixedUpdate 和 LateUpdate 是生命周期函数中最常见也最容易混淆的一组。 一、调用时机 方法名调用频率调用时机说明Update()每帧调用一次跟随帧率(帧率高则调用频率高)FixedUpdate()固定时间间隔调用默认每 0.02 秒执行一次LateUpdate()每帧调用一次…...

linux安装ffmpeg7.0.2全过程

​编辑 白眉大叔 发布于 2025年4月16日 评论关闭 阅读(341) centos 编译安装 ffmpeg 7.0.2 &#xff1a;连接https://www.baimeidashu.com/19668.html 下载 FFmpeg 源代码 在文章最后 一、在CentOS上编译安装FFmpeg 以常见的CentOS为例&#xff0c;FFmpeg的编译说明页面为h…...

Java中的设计模式实战:单例、工厂、策略模式的最佳实践

Java中的设计模式实战&#xff1a;单例、工厂、策略模式的最佳实践 在Java开发中&#xff0c;设计模式是构建高效、可维护、可扩展应用程序的关键。本文将深入探讨三种常见且实用的设计模式&#xff1a;单例模式、工厂模式和策略模式&#xff0c;并通过详细代码实例&#xff0…...

DexGarmentLab 论文翻译

单个 专家 演示 装扮 15 任务 场景 2500+ 服装 手套 棒球帽 裤子 围巾 碗 帽子 上衣 外套 服装-手部交互 捕捉 摇篮 夹紧 平滑 任务 ...... 投掷 悬挂 折叠 ... 多样化位置 ... 多样化 变形 ... 多样化服装形状 类别级 一般化 类别级(有或没有变形) 服装具有相同结构 变形 生…...

Elasticsearch性能优化全解析

Elasticsearch作为一款分布式搜索和分析引擎,其性能优化是实际生产环境中必须深入研究的课题。本文基于Elastic官方文档,系统性地总结了从硬件配置、索引设计到查询优化的全链路优化策略,帮助用户构建高性能、高稳定性的集群。 Elasticsearch的优化需结合业务场景综合决策:…...

2025.05.28【Parallel】Parallel绘图:拟时序分析专用图

Improve general appearance Add title, use a theme, change color palette, control variable orders and more Highlight a group Highlight a group of interest to help people understand your story 文章目录 Improve general appearanceHighlight a group探索Paralle…...

tc3975开发板上有ft2232这块的电路,我想知道这个开发板有哪些升级方式,重点关注是怎样通过ft2232实现的烧录升级的

关于TC3975开发板上FT2232芯片支持的升级方式&#xff0c;特别是如何通过FT2232实现烧录升级的问题。首先&#xff0c;我得回忆一下FT2232的基本功能和常见应用场景。 FT2232是FTDI公司的一款双通道USB转UART/FIFO芯片&#xff0c;常用于嵌入式系统的调试和编程。它支持多种协议…...

自动驾驶与智能交通:构建未来出行的智能引擎

随着人工智能、物联网、5G和大数据等前沿技术的发展&#xff0c;自动驾驶汽车和智能交通系统正以前所未有的速度改变人类的出行方式。这一变革不仅是技术的融合创新&#xff0c;更是推动城市可持续发展的关键支撑。 一、自动驾驶与智能交通的定义 1. 自动驾驶&#xff08;Auto…...

Kotlin Multiplatform与Flutter深度对比:跨平台开发方案的实战选择

简介 在当今多平台应用开发的浪潮中,Kotlin Multiplatform与Flutter代表了两种截然不同的技术路线。KMP以"共享代码、保留原生"为核心理念,允许开发者在业务逻辑层实现高达80%的跨平台代码共享,而Flutter则采用统一渲染引擎,在UI层提供100%的代码共享率。这两种…...

ELectron 中 BrowserView 如何进行实时定位和尺寸调整

背景 BrowserView 是继 Webview 后推出来的高性能多视图管理工具&#xff0c;与 Webview 最大的区别是&#xff0c;Webview 是一个 DOM 节点&#xff0c;依附于主渲染进程的附属进程&#xff0c;Webview 节点的崩溃会导致主渲染进程的连锁反应&#xff0c;会引起软件的崩溃。 …...

深兰科技董事长陈海波率队考察南京,加速AI大模型区域落地应用

近日&#xff0c;深兰科技创始人、董事长陈海波受邀率队赴南京市&#xff0c;先后考察了南京高新技术产业开发区与鼓楼区&#xff0c;就推进深兰AI医诊大模型在南京的落地应用&#xff0c;与当地政府及相关部门进行了深入交流与合作探讨。 此次考察聚焦于深兰科技自主研发的AI医…...

《深度关系-从建立关系到彼此信任》

陈海贤老师推荐的书&#xff0c;花了几个小时&#xff0c;感觉现在的人与人之间特别缺乏这种深度的关系&#xff0c;但是与一个人建立深度的关系并没有那么简单&#xff0c;反正至今为止&#xff0c;自己好像没有与任何一个人建立了这种深度的关系&#xff0c;那种双方高度同频…...

IT选型指南:电信行业需要怎样的服务器?

从第一条电报发出的 那一刻起 电信技术便踏上了飞速发展的征程 百余年间 将世界编织成一个紧密相连的整体 而在今年 我们迎来了第25届世界电信日 同时也是国际电联成立的第160周年 本届世界电信日的主题为:“弥合性别数字鸿沟,为所有人创造机遇”,但在新兴技术浪潮汹涌…...

【ConvLSTM第二期】模拟视频帧的时序建模(Python代码实现)

目录 1 准备工作&#xff1a;python库包安装1.1 安装必要库 案例说明&#xff1a;模拟视频帧的时序建模ConvLSTM概述损失函数说明&#xff08;python全代码&#xff09; 参考 ConvLSTM的原理说明可参见另一博客-【ConvLSTM第一期】ConvLSTM原理。 1 准备工作&#xff1a;pytho…...

[VMM]分享一个用SystemC编写的页表管理程序

分享一个用SystemC编写的页表管理程序 摘要:分享一个用SystemC编写的页表管理的程序,这个程序将模拟页表(PDE和PTE)的创建、虚拟地址(VA)到物理地址(PA)的转换,以及对内存的读写操作。 为了简化实现,我们做出以下假设: 页表是两级结构:PDE (Page Directory…...