当前位置: 首页 > news >正文

机器学习之常用优化器

机器学习之常用优化器

  • 1、SGD 优化器
    • 1.2、 SGD 的优缺点
  • 2、 Adam 优化器
    • 2.1、设置 Adam 优化器
    • 2.2、使用 Adam 优化器的训练流程
    • 2.3、Adam 优化器的优缺点
  • 3. AdamW 优化器
    • 3.1、示例
    • 3.2、训练过程
    • 3.3、AdamW 优化器的优点

1、SGD 优化器

   在 PyTorch 中,设置 SGD 优化器很直接。可以设置的关键参数,如学习率(lr)、动量(momentum)、权重衰减(weight decay)等,以帮助改善训练过程中的收敛速度和最终模型的泛化能力。

1.1、示例代码

from torch.optim import SGD# 假设 model 是您的模型实例
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6)

相关参数解释:

  • lr (learning rate): 学习率控制着权重调整的幅度,即步长。如果学习率太高,训练可能会在最优解附近震荡或者完全偏离;如果学习率太低,训练过程可能会非常慢。
  • momentum: 动量帮助优化器在相关方向上持续前进,从而加快收敛,并减少震荡。
  • weight_decay: 权重衰减是一种正则化技术,可以防止模型过拟合。它通过在损失函数中添加一项与权重大小成比例的成本来工作。

在训练循环中使用 SGD 优化器进行参数更新:

# 训练模型的一个epoch
model.train()  # 将模型设置为训练模式
for inputs, targets in train_loader:  # 假设 train_loader 是您的数据加载器optimizer.zero_grad()  # 清除之前的梯度outputs = model(inputs)  # 获得模型的预测结果loss = loss_function(outputs, targets)  # 计算损失loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新模型参数

1.2、 SGD 的优缺点

  • 优点

    • 简单且易于实现。
    • 在许多情况下效果很好,尤其是数据集很大时。
  • 缺点

    • 对学习率和其他超参数比较敏感。
    • 可能需要更多的epoch来收敛。
    • 在高维空间中,SGD的收敛速度可能较慢,尤其是在参数空间存在很多不敏感的方向时。

在实际应用中,选择哪种优化器往往取决于具体的任务和模型。对于一些复杂的或非凸的优化问题,可能会考虑使用带有自适应学习率的优化器,如 Adam 或 RMSprop,它们能够在不同的训练阶段自动调整学习率,通常在实践中更容易获得好的性能。然而,SGD 由于其简单性和有效性,在许多问题上仍是一个非常有价值的选择。

2、 Adam 优化器

使用 Adam 优化器是深度学习训练中的一种常见选择,特别是当你需要一个鲁棒且自适应的优化策略时。Adam 优化器结合了 RMSprop 和 Momentum 两种优化算法的优点,通过计算梯度的一阶矩(均值)和二阶矩(未中心化的方差)来调整每个参数的学习率。这种方法使得它特别适用于处理非稳定目标函数和非常大的数据集或参数数量。

2.1、设置 Adam 优化器

在 PyTorch 中,设置 Adam 优化器可以通过以下方式:

from torch.optim import Adam# 假设 model 是您的模型实例
# T_lr 是您定义的学习率变量
optimizer = Adam(model.parameters(), lr=T_lr)

相关参数解释如下:

  • lr (learning rate): 学习率决定了参数更新的幅度。对于 Adam,通常会设置一个较小的值,如 1e-31e-4,因为 Adam 自身的调整机制已经非常有效。
  • betas: 一个元组 (beta1, beta2),用于计算梯度及其平方的运行平均值;默认值通常是 (0.9, 0.999)
  • eps: 用于数值稳定性的小常数,防止在计算中出现除以零的错误;默认值为 1e-8
  • weight_decay: 权重衰减系数,这在正则化和防止过拟合中非常有用;类似于 L2 正则化。
  • amsgrad: 一个布尔值,表明是否使用 AMSGrad 变种的 Adam 算法,据说可以提高该算法的收敛性,防止过早停滞。

2.2、使用 Adam 优化器的训练流程

以下是一个标准的训练循环,展示了如何使用 Adam 进行参数更新:

# 训练模型
model.train()  # 确保模型处于训练模式
for epoch in range(num_epochs):  # num_epochs 是总的迭代周期数for inputs, targets in train_loader:  # 假设 train_loader 是您的数据加载器optimizer.zero_grad()  # 清空之前的梯度outputs = model(inputs)  # 计算模型输出loss = loss_function(outputs, targets)  # 计算损失函数loss.backward()  # 反向传播,计算当前梯度optimizer.step()  # 根据梯度更新网络参数print(f"Epoch {epoch+1}, Loss: {loss.item()}")  # 打印损失

2.3、Adam 优化器的优缺点

  • 优点

    • 自适应学习率使得参数调整更加灵活和高效。
    • 非常适合处理大数据集和高维空间。
    • 较少的调整学习率的需求。
  • 缺点

    • 相对于简单的 SGD,计算资源消耗更大。
    • 在某些情况下,如非常深的网络或复杂的架构中,可能不如 SGD 配合学习率衰减策略稳定。

Adam 由于其自适应性强、实现简单且通常性能优良,成为许多深度学习应用的默认选择。确保根据您的具体任务需求调整优化器的参数,以便最大化模型的性能和效率。对于不同的任务和模型结构,可能需要进行一些试验和错误调整,以找到最佳的超参数设置。

3. AdamW 优化器

AdamW 是一种改进的 Adam 优化算法,它对 Adam 的权重衰减组件进行了修改,使得权重衰减不再是添加到梯度上而是直接对参数进行更新。这种方法与 L2 正则化的原理更为接近,通常可以带来更好的训练稳定性和泛化性能。

3.1、示例

在 PyTorch 中,AdamW 的使用与 Adam 非常相似,但它通常被认为是一种更适合带权重衰减的场景的优化器,特别是在深度学习中。这是因为它可以更有效地控制模型的过拟合。optim.AdamW 是一个优化器,它结合了 Adam 优化器和权重衰减(L2 正则化)。它有助于防止模型过拟合,并在训练过程中提高模型的泛化能力。
model.parameters(): 这是模型参数的迭代器,AdamW将使用这些参数来计算梯度并更新权重。

from torch.optim import AdamW# 假设 model 是您的模型实例
# T_lr 是定义的学习率
# weight_decay 是权重衰减系数,常用于正则化
optimizer = AdamW(model.parameters(), lr=T_lr, weight_decay=0.01)

相关参数解释如下:

  • lr (learning rate):控制优化器步长的大小,即每次参数更新的幅度。
  • weight_decay:权重衰减(L2正则化)系数,用于正则化和避免过拟合。在 AdamW 中,这个参数是直接在权重更新时应用的,而不是作为梯度的一部分。这个值越高,正则化的效果越强,可以帮助减少模型的复杂度,从而减少过拟合。

3.2、训练过程

使用 AdamW 优化器的训练过程与其他优化器类似,但由于其处理权重衰减的方式,它可能在需要正则化以提高模型泛化能力的任务中表现得更好。

# 训练模型
model.train()  # 确保模型处于训练模式
for epoch in range(num_epochs):  # num_epochs 是总的迭代周期数for inputs, targets in train_loader:  # 假设 train_loader 是您的数据加载器optimizer.zero_grad()  # 清空之前的梯度outputs = model(inputs)  # 计算模型输出loss = loss_function(outputs, targets)  # 计算损失函数loss.backward()  # 反向传播,计算当前梯度optimizer.step()  # 根据梯度更新网络参数print(f"Epoch {epoch+1}, Loss: {loss.item()}")  # 打印损失

3.3、AdamW 优化器的优点

  • 更有效的正则化:由于其更新权重衰减的方式,AdamW 可以更有效地进行正则化,避免深度学习模型中常见的过拟合问题。
  • 改善泛化:与传统的 Adam 相比,AdamW 在多个基准测试中显示出更好的泛化能力。
  • 易于集成AdamW 可以轻松替换 Adam 优化器,无需改动其他代码,提供了一种简单的提升模型性能的方式。

AdamW 适用于几乎所有使用 Adam 的场景,特别推荐在易于过拟合的大模型和复杂任务中使用。由于其改进的权重衰减机制,它特别适合需要正则化的应用,如在小数据集上训练的深度网络。在选择优化器时,如果你已经计划使用 Adam 并关注模型的泛化,不妨试试 AdamW

相关文章:

机器学习之常用优化器

机器学习之常用优化器 1、SGD 优化器1.2、 SGD 的优缺点 2、 Adam 优化器2.1、设置 Adam 优化器2.2、使用 Adam 优化器的训练流程2.3、Adam 优化器的优缺点 3. AdamW 优化器3.1、示例3.2、训练过程3.3、AdamW 优化器的优点 1、SGD 优化器 在 PyTorch 中,设置 SGD 优…...

机器学习基本概念,Numpy,matplotlib和张量Tensor知识进一步学习

机器学习一些基本概念: 监督学习 监督学习是机器学习中最常见的形式之一,它涉及到使用带标签的数据集来训练模型。这意味着每条训练数据都包含输入特征和对应的输出标签。目标是让模型学会从输入到输出的映射,这样当给出新的未见过的输入时…...

博客前端项目学习day01

这里写自定义目录标题 登录创建项目配置环境变量,方便使用登录页面验证码登陆表单 在VScode上写前端,采用vue3。 登录 创建项目 检查node版本 node -v 创建一个新的项目 npm init vitelatest blog-front-admin 中间会弹出询问是否要安装包&#xff0c…...

java Collections.synchronizedCollection方法介绍

Collections.synchronizedCollection 是 Java 中的一个实用方法,用于创建一个线程安全的集合。它通过包装现有的集合对象来实现线程安全,以确保在多线程环境中对集合的访问是安全的。 主要功能 线程安全:通过同步包装现有的集合,使得在多线程环境中对集合的所有访问(包括…...

力扣每日一题:3011. 判断一个数组是否可以变为有序

力扣官网:前往作答!!!! 今日份每日一题: 题目要求: 给你一个下标从 0 开始且全是 正 整数的数组 nums 。 一次 操作 中,如果两个 相邻 元素在二进制下数位为 1 的数目 相同 &…...

ubuntu 上vscode +cmake的debug调试配置方法

在ubuntu配置pcl点云库以及opencv库的时候,需要在CMakeLists.txt中加入相应的代码。配置完成后,无法调试,与在windows上体验vs studio差别有点大。 找了好多调试debug配置方法,最终能用的有几种,但是有一种特别好用&a…...

使用Redis实现签到功能:Java示例解析

使用Redis实现签到功能:Java示例解析 在本博客中,我们将讨论一个使用Redis实现的签到功能的Java示例。该示例包括两个主要方法:sign()和signCount(),分别用于用户签到和计算用户当月的签到次数。 1. 签到方法:sign()…...

tableau标靶图,甘特图与瀑布图绘制 - 9

标靶图,甘特图与瀑布图 1. 标靶图绘制1.1 筛选器筛选日期1.2 条形图绘制1.3 编辑参考线1.4 设置参考线1.5 设置参考区间1.6 四分位设置1.7 其他标靶图结果显示 2.甘特图绘制2.1 选择列属性2.2 选择列属性2.3 创建新字段2.4 设置天数大小及颜色 3. 瀑布图绘制3.1 she…...

双向链表专题

在之前的单链表专题中,了解的单链表的结构是如何实现的,以及学习了如何实现单链表得各个功能。单链表虽然也能实现数据的增、删、查、改等功能,但是要找到尾节点或者是要找到指定位置之前的节点时,还是需要遍历链表,这…...

SpringCoud组件

一、使用SpringCloudAlibaba <dependencyManagement><dependencies><dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-alibaba-dependencies</artifactId><version>2023.0.1.0</version><…...

向量的定义和解释

这是一个向量&#xff1a; 向量具有大小&#xff08;大小&#xff09;和方向&#xff1a; 线的长度显示其大小&#xff0c;箭头指向方向。 在这里玩一个&#xff1a; 我们可以通过将它们从头到尾连接来添加两个向量&#xff1a; 无论我们添加它们的顺序如何&#xff0c;我们都…...

IoTDB 集群高效管理:一键启停功能介绍

如何快速启动、停止 IoTDB 集群节点的功能详解&#xff01; 在部署 IoTDB 集群时&#xff0c;对于基础的单机模式&#xff0c;启动过程相对简单&#xff0c;仅需执行 start-standalone 脚本来启动 1 个 ConfigNode 节点和 1 个 DataNode 节点。然而&#xff0c;对于更高级的分布…...

一个spring boot项目的启动过程分析

1、web.xml 定义入口类 <context-param><param-name>contextConfigLocation</param-name><param-value>com.baosight.ApplicationBoot</param-value> </context-param> 2、主入口类: ApplicationBoot,SpringBoot项目的mian函数 SpringBo…...

智驭未来:人工智能与目标检测的深度交融

在科技日新月异的今天&#xff0c;人工智能&#xff08;AI&#xff09;如同一股不可阻挡的浪潮&#xff0c;正以前所未有的速度重塑着我们的世界。在众多AI应用领域中&#xff0c;目标检测以其独特的魅力和广泛的应用前景&#xff0c;成为了连接现实与智能世界的桥梁。本文旨在…...

01MFC建立单个文件类型——画线

文章目录 选择模式初始化文件作用解析各初始化文件解析 类导向创建鼠标按键按下抬起操作函数添加一个变量记录起始位置注意事项代码实现效果图 虚实/颜色线 选择模式 初始化文件作用解析 运行&#xff1a; 各初始化文件解析 MFC&#xff08;Microsoft Foundation Classes&am…...

免杀中用到的工具

&#x1f7e2; 绝大部分无法直接生成免杀木马&#xff0c;开发、测试免杀时会用到。 工具简称 概述 工具来源 下载路径 x64dbg 中文版安装程序(Jan 6 2024).exe 52pojie hellshell 官方的加密或混淆shellcode github Releases ORCA / HellShell GitLab hellshe…...

[vite] Pre-transform error: Cannot find package pnpm路径过长导致运行报错

下了套vue3的代码&#xff0c;执行pnpm install初始化&#xff0c;使用vite启动&#xff0c;启动后访问就会报错 报错信息 ERROR 16:40:53 [vite] Pre-transform error: Cannot find package E:\work\VSCodeProjectWork\jeecg\xxxxxxxxx-next\xxxxxxxxx-next-jeecgBoot-vue3\…...

Promise总结

Promise.then() 的返回值仍然是 Promise 对象 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>D…...

ROI 接口便捷修改

传入的图片截取ROI后再进入识别接口 &#xff08;识别接口比ROI接口的函数参数少一个传入的ROI&#xff09; 无点只有点集 返回双点集 //平直冷侧翅片 bool ImageProcessingTest::straightColdSideFin_ROI(cv::Mat img, cv::Rect ROI, std::vector<cv::Point>& topL…...

jenkins打包java项目报错Error: Unable to access jarfile tlm-admin.jar

jenkins打包boot项目 自动重启脚本失败 查看了一下项目日志报错&#xff1a; Error: Unable to access jarfile tlm-admin.jar我检查了一下这个配置&#xff0c;感觉没有问题&#xff0c;包可以正常打&#xff0c; cd 到项目目录下面&#xff0c;手动执行这个sh脚本也是能正常…...

SQL Server设置端口:跨平台指南

在使用SQL Server时&#xff0c;设置或修改其监听的端口是确保数据库服务安全访问和高效管理的重要步骤。由于SQL Server可以部署在多种操作系统上&#xff0c;包括Windows、Linux和Docker容器等&#xff0c;因此设置端口的步骤和方法也会因平台而异。本文将为您提供一个跨平台…...

ActiveMQ-CVE-2023-46604

Apache ActiveMQ OpenWire 协议反序列化命令执行漏洞 OpenWire协议在ActiveMQ中被用于多语言客户端与服务端通信。在Apache ActvieMQ5.18.2版本以及以前&#xff0c;OpenWire协议通信过程中存在一处反序列化漏洞&#xff0c;该漏洞可以允许具有网络访问权限的远程攻击者通过操作…...

TensorBoard ,PIL 和 OpenCV 在深度学习中的应用

重要工具介绍 TensorBoard&#xff1a; 是一个TensorFlow提供的强大工具&#xff0c;用于可视化和理解深度学习模型的训练过程和结果。下面我将介绍TensorBoard的相关知识和使用方法。 TensorBoard 简介 TensorBoard是TensorFlow提供的一个可视化工具&#xff0c;用于&#x…...

【超音速 专利 CN117576413A】基于全连接网络分类模型的AI涂布抓边处理方法及系统

申请号CN202311568976.4公开号&#xff08;公开&#xff09;CN117576413A申请日2023.11.22申请人&#xff08;公开&#xff09;超音速人工智能科技股份有限公司发明人&#xff08;公开&#xff09;张俊峰&#xff08;总&#xff09;; 杨培文&#xff08;总&#xff09;; 沈俊羽…...

iPhone数据恢复篇:iPhone 数据恢复软件有哪些

问题&#xff1a;iPhone 15 最好的免费恢复软件是什么&#xff1f;我一直在寻找一个恢复程序来恢复从iPhone中意外删除的照片&#xff0c;联系人和消息&#xff0c;但是我有很多选择。 谷歌一下&#xff0c;你会发现许多付费或免费的iPhone数据恢复工具&#xff0c;声称它们可…...

Html5+Css3学习笔记

Html5 CSS3 一、概念 1.什么是html5 html: Hyper Text Markup Language ( 超文本标记语言) 文本&#xff1a;记事本 超文本&#xff1a; 文字、图片、音频、视频、动画等等&#xff08;网页&#xff09; html语言经过浏览器的编译显示成超文本 开发者使用5种浏览器&#xf…...

WPF学习(2) -- 样式基础

一、代码 <Window x:Class"学习.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://schemas.microsoft.com/expression/blend/2008&…...

独家揭秘!五大内网穿透神器,访问你的私有服务

本文精心筛选了五款炙手可热的内网穿透工具&#xff0c;它们各怀绝技&#xff0c;无论您是企业用户、独立开发者&#xff0c;还是技术探索者&#xff0c;这篇文章都物有所值&#xff0c;废话不多说&#xff0c;主角们即将上场。 目录 1. 巴比达 - 安全至上的企业护航者 2. 花…...

Ubuntu 编译和运行ZLMediaKit

摘要 本文描述了如何在Ubuntu上构建ZLMediaKIt项目源码&#xff0c;以及如何体验其WebRTC推流和播放功能。 实验环境 操作系统版本&#xff1a;Ubuntu 22.04.3 LTS gcc版本&#xff1a;11.4.0 g版本&#xff1a;11.4.0 依赖库安装 #让ZLMediaKit媒体服务器具备WebRTC流转发…...

基于JavaSpringBoot+Vue+uniapp微信小程序校园宿舍管理系统设计与实现

基于JavaSpringBootVueuniapp微信小程序实现校园宿舍管理系统设计与实现 目录 第一章 绪论 1.1 研究背景 1.2 研究现状 1.3 研究内容 第二章 相关技术介绍 2.1 Java语言 2.2 HTML网页技术 2.3 MySQL数据库 2.4 Springboot 框架介绍 2.5 VueJS介绍 2.6 ElementUI介绍…...