当前位置: 首页 > article >正文

【漫话机器学习系列】087.常见的神经网络最优化算法(Common Optimizers Of Neural Nets)

常见的神经网络优化算法

1. 引言

在深度学习中,优化算法(Optimizers)用于更新神经网络的权重,以最小化损失函数(Loss Function)。一个高效的优化算法可以加速训练过程,并提高模型的性能和稳定性。本文介绍几种常见的神经网络优化算法,包括随机梯度下降(SGD)、带动量的随机梯度下降(Momentum SGD)、均方根传播算法(RMSProp)以及自适应矩估计(Adam),并提供相应的代码示例。

2. 常见的优化算法

2.1 随机梯度下降(Stochastic Gradient Descent, SGD)

随机梯度下降(SGD)是最基本的优化算法,其更新规则如下:

其中:

  • w 代表网络参数(权重);
  • α 是学习率(Learning Rate),控制更新步长;
  • ∇L(w) 是损失函数相对于权重的梯度。

代码示例(使用 PyTorch 实现 SGD)

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降# 训练步骤
for epoch in range(100):optimizer.zero_grad()  # 清空梯度inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果

Epoch [0/100], Loss: 4.9142
Epoch [10/100], Loss: 2.1721
Epoch [20/100], Loss: 0.9601
Epoch [30/100], Loss: 0.4244
Epoch [40/100], Loss: 0.1876
Epoch [50/100], Loss: 0.0829
Epoch [60/100], Loss: 0.0366
Epoch [70/100], Loss: 0.0162
Epoch [80/100], Loss: 0.0072
Epoch [90/100], Loss: 0.0032


2.2 带动量的随机梯度下降(Momentum SGD)

带动量的 SGD 在 SGD 的基础上加入动量(Momentum),用于加速收敛并减少震荡:


其中:

  • 是累积的梯度,类似于物理中的动量;
  • β 是动量系数(通常取 0.9)。

代码示例(Momentum SGD)

import torch
import torch.nn as nn
import torch.optim as optimmodel = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)for epoch in range(100):optimizer.zero_grad()inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果 

Epoch [0/100], Loss: 3.0073
Epoch [10/100], Loss: 1.3292
Epoch [20/100], Loss: 0.5875
Epoch [30/100], Loss: 0.2597
Epoch [40/100], Loss: 0.1148
Epoch [50/100], Loss: 0.0507
Epoch [60/100], Loss: 0.0224
Epoch [70/100], Loss: 0.0099
Epoch [80/100], Loss: 0.0044
Epoch [90/100], Loss: 0.0019

优点:

  • 缓解了 SGD 震荡问题,提高收敛速度;
  • 在非凸优化问题中表现更好。

2.3 均方根传播算法(RMSProp)

RMSProp 通过自适应调整学习率来加速训练,并缓解震荡问题:


其中:

  • 是梯度平方的滑动平均;
  • β 是衰减系数(一般取 0.9);
  • ϵ 是一个很小的数,防止除零错误。

代码示例(RMSProp)

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.9)for epoch in range(100):optimizer.zero_grad()inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

运行结果

Epoch [0/100], Loss: 1.1952
Epoch [10/100], Loss: 0.5887
Epoch [20/100], Loss: 0.3333
Epoch [30/100], Loss: 0.1731
Epoch [40/100], Loss: 0.0752
Epoch [50/100], Loss: 0.0239
Epoch [60/100], Loss: 0.0043
Epoch [70/100], Loss: 0.0003
Epoch [80/100], Loss: 0.0000
Epoch [90/100], Loss: 0.0000

优点:

  • 适用于非平稳目标函数;
  • 能有效处理不同特征尺度的问题;
  • 在 RNN(循环神经网络)等任务上表现较好。

2.4 自适应矩估计(Adam, Adaptive Moment Estimation)

Adam 结合了动量法(Momentum)和 RMSProp,同时考虑梯度的一阶矩(平均值)和二阶矩(方差):



其中:

  • ​ 是梯度的一阶矩估计;
  • ​ 是梯度的二阶矩估计;
  • ​ 分别控制一阶矩和二阶矩的指数衰减率(通常取 0.9 和 0.999)。

代码示例(Adam)

import torch
import torch.nn as nn
import torch.optim as optim# 定义简单的线性模型
model = nn.Linear(1, 1)  # 1 个输入特征,1 个输出特征
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.Adam(model.parameters(), lr=0.01)for epoch in range(100):optimizer.zero_grad()inputs = torch.tensor([[1.0]], requires_grad=True)targets = torch.tensor([[2.0]])outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()if epoch % 10 == 0:print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')

输出结果 

Epoch [0/100], Loss: 3.6065
Epoch [10/100], Loss: 2.8894
Epoch [20/100], Loss: 2.2642
Epoch [30/100], Loss: 1.7359
Epoch [40/100], Loss: 1.3021
Epoch [50/100], Loss: 0.9555
Epoch [60/100], Loss: 0.6855
Epoch [70/100], Loss: 0.4805
Epoch [80/100], Loss: 0.3287
Epoch [90/100], Loss: 0.2192

优点:

  • 结合 Momentum 和 RMSProp 的优势;
  • 适用于大规模数据集和高维参数优化;
  • 具有自适应学习率,适用于不同类型的问题。

3. 选择合适的优化算法

优化算法特点适用场景
SGD计算简单,但容易震荡适用于大规模数据,适合凸优化问题
Momentum SGD增加动量,减少震荡,加速收敛适用于复杂深度神经网络
RMSProp自适应调整学习率,适用于非平稳问题适用于 RNN、强化学习等
Adam结合 Momentum 和 RMSProp,自适应学习率适用于大多数深度学习任务

4. 结论

在神经网络训练过程中,优化算法的选择对最终的模型性能有重要影响。SGD 是最基础的优化方法,而带动量的 SGD 在收敛速度和稳定性上有所提升。RMSProp 适用于非平稳目标函数,而 Adam 结合了 Momentum 和 RMSProp 的优势,成为当前最流行的优化算法之一。

不同任务可能需要不同的优化算法,通常的建议是:

  • 对于简单的凸优化问题,可以使用 SGD。
  • 对于深度神经网络,可以使用 Momentum SGD 或 Adam。
  • 对于 RNN 和强化学习问题,RMSProp 是一个不错的选择。

合理选择优化算法可以显著提升模型训练的效率和效果!

相关文章:

【漫话机器学习系列】087.常见的神经网络最优化算法(Common Optimizers Of Neural Nets)

常见的神经网络优化算法 1. 引言 在深度学习中,优化算法(Optimizers)用于更新神经网络的权重,以最小化损失函数(Loss Function)。一个高效的优化算法可以加速训练过程,并提高模型的性能和稳定…...

react-native fetch在具有http远程服务器后端的Android设备上抛出“Network request failed“错误

问题描述: 在具有http远程服务器后端的Android设备上,使用react-native fetch时抛出"Network request failed"错误。 回答: "Network request failed"错误通常表示在进行网络请求时出现了问题。可能的原因包括网络连接…...

【JVM详解四】执行引擎

一、概述 Java程序运行时,JVM会加载.class字节码文件,但是字节码并不能直接运行在操作系统之上,而JVM中的执行引擎就是负责将字节码转化为对应平台的机器码让CPU运行的组件。 执行引擎是JVM核心的组成部分之一。可以把JVM架构分成三部分&am…...

route 与 router 之间的差别

简述&#xff1a; router&#xff1a;主要用于处理一些动作&#xff0c; route&#xff1a;主要获得或处理一些数据&#xff0c;比如地址、参数等 例&#xff1a; videoInfo1.vue&#xff1a; <template><div class"video-info"><h3>二级组件…...

[vue3] Ref Reactive

【b站-【前端面试】Vue3 ref 与 reactive 区别】 Ref&#xff1a;Ref用于创建一个响应式的基本数据类型&#xff0c;比如数字、字符串等。它将普通的数据变成响应式数据&#xff0c;可以监听数据的变化。使用Ref时&#xff0c;我们可以通过.value来访问和修改数据的值。 Reac…...

SamWaf开源轻量级的网站应用防火墙(安装包),私有化部署,加密本地存储的数据,易于启动,并支持 Linux 和 Windows 64 位和 Arm64

一、SamWaf轻量级开源防火墙介绍 &#xff08;文末提供下载&#xff09; SamWaf网站防火墙是一款适用于小公司、工作室和个人网站的开源轻量级网站防火墙&#xff0c;完全私有化部署&#xff0c;数据加密且仅保存本地&#xff0c;一键启动&#xff0c;支持Linux&#xff0c;Wi…...

极客说|利用 Azure AI Agent Service 创建自定义 VS Code Chat participant

作者&#xff1a;卢建晖 - 微软高级云技术布道师 「极客说」 是一档专注 AI 时代开发者分享的专栏&#xff0c;我们邀请来自微软以及技术社区专家&#xff0c;带来最前沿的技术干货与实践经验。在这里&#xff0c;您将看到深度教程、最佳实践和创新解决方案。关注「极客说」&a…...

22.2、Apache安全分析与增强

目录 Apache Web安全分析与增强 - Apache Web概述Apache Web安全分析与增强 - Apache Web安全威胁Apache Web安全机制Apache Web安全增强 Apache Web安全分析与增强 - Apache Web概述 阿帕奇是一个用于搭建WEB服务器的应用程序&#xff0c;它是开源的&#xff0c;它的配置文件…...

理邦仪器嵌入式(C/C++开发)开发面试题及参考答案

C++ 虚函数的概念和作用 C++ 中的虚函数是一种非常重要的机制,它在实现多态性方面起着关键作用。 概念上来说,虚函数是在基类中使用关键字 virtual 声明的成员函数。当基类的指针或引用指向派生类的对象时,通过这个基类的指针或引用调用虚函数,实际执行的是派生类中重写的该…...

windows + visual studio 2019 使用cmake 编译构建静、动态库并调用详解

环境 windows visual studio 2019 visual studio 2019创建cmake工程 1. 静态库.lib 1.1 静态库编译生成 以下是我创建的cmake工程文件结构&#xff0c;只关注高亮文件夹部分 libout 存放编译生成的.lib文件libsrc 存放编译用的源代码和头文件CMakeLists.txt 此次编译CMak…...

Chrome 浏览器 支持多账号登录和管理的浏览器容器解决方案

根据搜索结果&#xff0c;目前没有直接提到名为“chrometable”的浏览器容器或插件。不过&#xff0c;从功能描述来看&#xff0c;您可能需要的是一个能够支持多账号登录和管理的浏览器容器解决方案。以下是一些可能的实现方式&#xff1a; 1. 使用 Docker 容器化部署 Chrome …...

GrassWebProxy

GrassWebProxy第一版&#xff1a; using System; using System.Collections.Generic; using System.Linq; using System.Net.Sockets; using System.Net; using System.Text; using System.Threading; using System.Threading.Tasks; using System.IO; using Newtonsoft.Json;…...

DeepSeek API 调用 - Spring Boot 实现

DeepSeek API 调用 - Spring Boot 实现 1. 项目依赖 在 pom.xml 中添加以下依赖&#xff1a; <dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-webflux</artifactId></depe…...

【DeepSeek】Deepseek辅组编程-通过卫星轨道计算终端距离、相对速度和多普勒频移

引言 笔者在前面的文章中&#xff0c;介绍了基于卫星轨道参数如何计算终端和卫星的距离&#xff0c;相对速度和多普勒频移。 【一文读懂】卫星轨道的轨道参数&#xff08;六根数&#xff09;和位置速度矢量转换及其在终端距离、相对速度和多普勒频移计算中的应用 Matlab程序 …...

【kafka实战】05 Kafka消费者消费消息过程源码剖析

1. 概述 Kafka消费者&#xff08;Consumer&#xff09;是Kafka系统中负责从Kafka集群中拉取消息的客户端组件。消费者消费消息的过程涉及多个步骤&#xff0c;包括消费者组的协调、分区分配、消息拉取、消息处理等。本文将深入剖析Kafka消费者消费消息的源码&#xff0c;并结合…...

[EAI-033] SFT 记忆,RL 泛化,LLM和VLM的消融研究

Paper Card 论文标题&#xff1a;SFT Memorizes, RL Generalizes: A Comparative Study of Foundation Model Post-training 论文作者&#xff1a;Tianzhe Chu, Yuexiang Zhai, Jihan Yang, Shengbang Tong, Saining Xie, Dale Schuurmans, Quoc V. Le, Sergey Levine, Yi Ma 论…...

算法与数据结构(字符串相乘)

题目 思路 这道题我们可以使用竖式乘法&#xff0c;从右往左遍历每个乘数&#xff0c;将其相乘&#xff0c;并且把乘完的数记录在nums数组中&#xff0c;然后再进行进位运算&#xff0c;将同一列的数进行相加&#xff0c;进位。 解题过程 首先求出两个数组的长度&#xff0c;…...

DeepSeek从入门到精通:全面掌握AI大模型的核心能力

文章目录 一、DeepSeek是什么&#xff1f;性能对齐OpenAI-o1正式版 二、Deepseek可以做什么&#xff1f;能力图谱文本生成自然语言理解与分析编程与代码相关常规绘图 三、如何使用DeepSeek&#xff1f;四、DeepSeek从入门到精通推理模型推理大模型非推理大模型 快思慢想&#x…...

【Pytorch函数】PyTorch随机数生成全解析 | torch.rand()家族函数使用指南

&#x1f31f; PyTorch随机数生成全解析 | torch.rand()家族函数使用指南 &#x1f31f; &#x1f4cc; 一、核心函数参数详解 PyTorch提供多种随机数生成函数&#xff08;注意&#xff1a;无直接torch.random()函数&#xff09;&#xff0c;以下是常用函数及参数&#xff1a;…...

vue print 打印

vue 点击打印页面部分内容&#xff0c;或者打印弹窗内的内容 打印页面部分内容 <template><div><div id"print"><div class"info"><div class"bx_title">费用报销单<span class"code">NO.<s…...

【异常解决】在idea中提示 hutool 提示 HttpResponse used withoud try-with-resources statement

博主介绍&#xff1a;✌全网粉丝22W&#xff0c;CSDN博客专家、Java领域优质创作者&#xff0c;掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围&#xff1a;SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...

【Uniapp-Vue3】UniCloud云数据库获取指定字段的数据

使用where方法可以获取指定的字段&#xff1a; let db uniCloud.database(); db.collection("数据表").where({字段名1:数据, 字段名2:数据}).get({getOne:true}) 如果我们不在get中添加{getOne:true}&#xff0c;在只获取到一个数据res.result.data将会是一个数组&…...

信息科技伦理与道德3-2:智能决策

2.2 智能推荐 推荐算法介绍 推荐系统&#xff1a;猜你喜欢 https://blog.csdn.net/search_129_hr/article/details/120468187 推荐系统–矩阵分解 https://blog.csdn.net/search_129_hr/article/details/121598087 案例一&#xff1a;YouTube推荐算法向儿童推荐不适宜视频 …...

openssl使用

openssl使用 提取密钥对 数字证书pfx包含公钥和私钥&#xff0c;而cer证书只包含公钥。提取需输入证书保护密码 openssl pkcs12 -in xxx.pfx -nocerts -nodes -out pare.key提取私钥 openssl rsa -in pare.key -out pri.key提取公钥 openssl rsa -in pare.key -pubout -ou…...

Visual Studio 2022 中使用 Google Test

要在 Visual Studio 2022 中使用 Google Test (gtest)&#xff0c;可以按照以下步骤进行&#xff1a; 安装 Google Test&#xff1a;确保你已经安装了 Google Test。如果没有安装&#xff0c;可以通过 Visual Studio Installer 安装。在安装程序中&#xff0c;找到并选择 Googl…...

SpringBoot3 + Jedis5 + Redis集群 如何通过scan方法分页获取所有keys

背景: 由于需要升级老项目代码&#xff0c;从SpringBoot1.5.x 升级到 SpringBoot3.3.x&#xff0c;框架中引用的Jedis自动升级到了 5.x&#xff1b;正好代码中有需要获取Redis集群的所有keys的需求存在&#xff1b;代码就不适用了&#xff0c;修改如下&#xff1a; POM 由于…...

WGCLOUD监控系统部署教程

官网地址&#xff1a;下载WGCLOUD安装包 - WGCLOUD官网 第一步、环境配置 #安装jdk 1、安装 EPEL 仓库&#xff1a; sudo yum install -y epel-release 2、安装 OpenJDK 11&#xff1a; sudo yum install java-11-openjdk-devel 3、如果成功&#xff0c;你可以通过运行 java …...

协议-WebRTC-HLS

是什么&#xff1f; WebRTC&#xff08;Web Real-Time Communication&#xff09; 实现 Web 浏览器和移动应用程序之间通过互联网直接进行实时通信。允许点对点音频、视频和数据共享&#xff0c;而无需任何插件或其他软件。WebRTC 广泛用于构建视频会议、语音通话、直播、在线游…...

jQuery UI 下载指南

jQuery UI 下载指南 引言 jQuery UI 是一个基于 jQuery 的用户界面和交互库&#xff0c;它提供了一套丰富的交互组件和视觉效果&#xff0c;可以帮助开发者快速构建美观、交互性强的网页应用。本文将为您详细介绍如何下载 jQuery UI&#xff0c;并指导您进行安装和使用。 jQ…...

MySQL系列之数据类型(String)

导览 前言一、字符串类型知多少 1. 类型说明2. 字符和字节的转换 二、字符串类型的异同 1. CHAR & VARCHAR2. BINARY & VARBINARY3. BLOB & TEXT4. ENUM & SET 结语精彩回放 前言 MySQL数据类型第三弹闪亮登场&#xff0c;欢迎关注O。 本篇博主开始谈谈MySQ…...