Python 梯度下降法(五):Adam Optimize
文章目录
- Python 梯度下降法(五):Adam Optimize
- 一、数学原理
- 1.1 介绍
- 1.2 符号说明
- 1.3 实现流程
- 二、代码实现
- 2.1 函数代码
- 2.2 总代码
- 2.3 遇到的问题
- 2.4 算法优化
- 三、优缺点
- 3.1 优点
- 3.2 缺点
- 四、相关链接
Python 梯度下降法(五):Adam Optimize
一、数学原理
1.1 介绍
Adam 算法结合了 Adagrad 和 RMSProp 算法的优点。Adagrad 算法会根据每个参数的历史梯度信息来调整学习率,对于出现频率较低的参数会给予较大的学习率,而对于出现频率较高的参数则给予较小的学习率。RMSProp 算法则是对 Adagrad 算法的改进,它通过使用移动平均的方式来计算梯度的平方,从而避免了 Adagrad 算法中学习率单调下降的问题。
1.2 符号说明
| 参数 | 意义 |
|---|---|
| g t = ∇ θ J ( θ t ) g_{t}=\nabla_{\theta}J(\theta_{t}) gt=∇θJ(θt) | 第 t t t时刻的梯度 |
| m t m_{t} mt | 梯度的一阶矩(均值) |
| β 1 \beta_{1} β1 | 一阶矩衰减率,一般取0.9 |
| v t v_{t} vt | 梯度的二阶矩(未中心化的方差) |
| β 2 \beta_{2} β2 | 二阶矩衰减率,一般取0.99 |
| θ \theta θ | 线性拟合参数 |
| η \eta η | 学习率 |
| ϵ \epsilon ϵ | 无穷小量,一般取 1 0 − 8 10^{-8} 10−8 |
1.3 实现流程
- 初始化: θ \theta θ、 η \eta η、 m 0 ⃗ = 0 \vec{m_{0}}=0 m0=0、 v 0 ⃗ = 0 \vec{v_{0}}=0 v0=0
- 计算梯度: g t = ∇ θ J ( θ t ) = 1 m X T L g_{t}=\nabla_{\theta}J(\theta_{t})=\frac{1}{m}X^{T}L gt=∇θJ(θt)=m1XTL
- 梯度的一阶矩估计(均值): m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_{t}=\beta_{1}m_{t-1}+(1-\beta_{1})g_{t} mt=β1mt−1+(1−β1)gt
- 梯度的二阶矩估计(未中心化的方差): v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_{t}=\beta_{2}v_{t-1}+(1-\beta_{2})g_{t}^{2} vt=β2vt−1+(1−β2)gt2
- 偏差修正: m t ^ = m t 1 − β 1 t 、 v t ^ = v t 1 − β 2 t \hat{m_{t}}=\frac{m_{t}}{1-\beta_{1}^{t}}、\hat{v_{t}}=\frac{v_{t}}{1-\beta_{2}^{t}} mt^=1−β1tmt、vt^=1−β2tvt
- 更新参数: θ t = θ t − 1 − η m t ^ v t ^ + ϵ \theta_{t}=\theta_{t-1}-\frac{\eta \hat{m_{t}}}{\sqrt{ \hat{v_{t}} }+\epsilon} θt=θt−1−vt^+ϵηmt^
二、代码实现
2.1 函数代码
# 定义 Adam 函数
def adam_optimizer(X, y, eta, num_iter=1000, beta1=0.8, beta2=0.8, epsilon=1e-8, threshold=1e-8):"""X: 数据 x mxn,可以在传入数据之前进行数据的归一化y: 数据 y mx1eta: 学习率num_iter: 迭代次数beta: 衰减率epsilon: 无穷小threshold: 阈值"""m, n = X.shapetheta, mt, vt, loss_ = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1)), [] # 初始化数据for iter in range(num_iter):h = X.dot(theta)err = h - yloss_.append(np.mean((err ** 2) / 2))g = (1 / m ) * X.T.dot(err)# 一阶矩估计mt = beta1 * mt + (1 - beta1) * g# 二阶矩估计vt = beta2 * vt + (1 - beta2) * g ** 2# 偏差修正mt_ = mt / (1 - pow(beta1, (iter + 1))) # 得 + 1 不然在 iter = 0 时,分母为零vt_ = np.abs(vt / (1 - pow(beta2, (iter + 1))))# 更新参数theta = theta - (eta * mt_) / (np.sqrt(vt_) + epsilon)# 检查是否收敛if iter > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {iter + 1}")breakreturn theta.flatten(), loss_
2.2 总代码
import numpy as np
import matplotlib.pyplot as plt# 设置 matplotlib 支持中文
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False# 定义 Adam 函数
def adam_optimizer(X, y, eta, num_iter=1000, beta1=0.8, beta2=0.8, epsilon=1e-8, threshold=1e-8):"""X: 数据 x mxn,可以在传入数据之前进行数据的归一化y: 数据 y mx1eta: 学习率num_iter: 迭代次数beta: 衰减率epsilon: 无穷小threshold: 阈值"""m, n = X.shapetheta, mt, vt, loss_ = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1)), [] # 初始化数据for iter in range(num_iter):h = X.dot(theta)err = h - yloss_.append(np.mean((err ** 2) / 2))g = (1 / m ) * X.T.dot(err)# 一阶矩估计mt = beta1 * mt + (1 - beta1) * g# 二阶矩估计vt = beta2 * vt + (1 - beta2) * g ** 2# 偏差修正mt_ = mt / (1 - pow(beta1, (iter + 1))) # 得 + 1 不然在 iter = 0 时,分母为零vt_ = np.abs(vt / (1 - pow(beta2, (iter + 1))))# 更新参数theta = theta - (eta * mt_) / (np.sqrt(vt_) + epsilon)# 检查是否收敛if iter > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {iter + 1}")breakreturn theta.flatten(), loss_# 生成一些示例数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 超参数
eta = 0.01# 运行 Adam 优化器
theta, loss_ = adam_optimizer(X_b, y, eta)print("最优参数 theta:")
print(theta)# 绘制损失函数图像
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.legend() # 显示图例
plt.grid(True) # 显示网格线
plt.show()

2.3 遇到的问题
当偏差修正为以下算法时,出现报错:
# 偏差修正mt_ = mt / (1 - pow(beta1, (iter)))vt_ = np.abs(vt / (1 - pow(beta2, (iter))))

进行检验时,我们发现:

mt_,vt_ \text{mt\_,vt\_} mt_,vt_为无穷量,因此考虑分母为零的情况,而当 iter = 0 \text{iter}=0 iter=0时, 1 − β iter = 0 1- \beta^{\text{iter}}=0 1−βiter=0,故说明索引不能从0开始,而应该从1开始,因此引入 iter + 1 \text{iter}+1 iter+1,防止分母的无穷大引入。
2.4 算法优化
由于算法过程中,如果数据量太多会引起资源的严重浪费,因此我们引入小批量梯度下降法的类似方法,批量截取数据来进行拟合。
# 定义 Adam 函数
def adam_optimizer(X, y, eta, num_iter=1000, batch_size=32, beta1=0.8, beta2=0.8, epsilon=1e-8, threshold=1e-8):"""X: 数据 x mxn,可以在传入数据之前进行数据的归一化y: 数据 y mx1eta: 学习率num_iter: 迭代次数batch_size: 小批量分支法的批量数beta: 衰减率epsilon: 无穷小threshold: 阈值"""m, n = X.shapetheta, mt, vt, loss_ = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1)), [] # 初始化数据num_batchs = m // batch_sizefor _ in range(num_iter):range_shuffle = np.random.permutation(m)X_shuffled = X[range_shuffle]y_shuffled = y[range_shuffle]loss_temp = []for iter in range(num_batchs):start_index = batch_size * iterend_index = start_index + batch_sizexi = X_shuffled[start_index:end_index]yi = y_shuffled[start_index:end_index]h = xi.dot(theta)err = h - yiloss_temp.append(np.mean((err ** 2) / 2))g = (1 / m ) * xi.T.dot(err)# 一阶矩估计mt = beta1 * mt + (1 - beta1) * g# 二阶矩估计vt = beta2 * vt + (1 - beta2) * g ** 2# 偏差修正mt_ = mt / (1 - pow(beta1, (iter + 1)))vt_ = np.abs(vt / (1 - pow(beta2, (iter + 1))))# 更新参数theta = theta - (eta * mt_) / (np.sqrt(vt_) + epsilon)loss_.append(np.mean(loss_temp))# 检查是否收敛if _ > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {iter + 1}")breakreturn theta.flatten(), loss_

使用小批量进行Adam优化,可以大大节省系统的资源。
三、优缺点
3.1 优点
对不同参数调整学习率:Adam 能够为模型的每个参数自适应地调整学习率。它会根据参数的梯度历史信息,对出现频率较低的参数给予较大的学习率,对出现频率较高的参数给予较小的学习率。这使得模型在训练过程中能够更好地处理不同尺度和变化频率的参数,加速收敛过程。
无需手动精细调整:在很多情况下,Adam 算法提供的默认超参数就能取得不错的效果,不需要像传统优化算法那样进行大量的手动调参,节省了时间和精力。
低内存需求:Adam 只需要存储梯度的一阶矩估计(均值)和二阶矩估计(未中心化的方差),不需要像一些二阶优化方法那样存储复杂的海森矩阵(Hessian matrix),因此内存占用相对较小,适合处理大规模数据集和深度神经网络。
快速收敛:通过结合梯度的一阶矩和二阶矩信息,Adam 能够更准确地估计梯度的方向和大小,从而在大多数情况下比传统的随机梯度下降(SGD)算法更快地收敛到最优解。
利用稀疏信息:在处理稀疏数据(如自然语言处理中的词向量)时,Adam 能够根据数据的稀疏性调整学习率。对于那些很少出现的特征,算法会给予较大的学习率,使得模型能够更有效地学习这些特征,避免因数据稀疏而导致的学习困难
偏差修正机制:Adam 算法引入了偏差修正机制,用于修正一阶矩和二阶矩估计在训练初期的偏差。这使得算法在训练的早期阶段更加稳定,能够避免因初始估计不准确而导致的训练波动或不收敛问题。
3.2 缺点
自适应特性的局限性:虽然 Adam 能够自适应地调整学习率,但在某些情况下,这种自适应特性可能会导致算法陷入局部最优解。由于学习率会随着训练过程自动调整,可能会在接近局部最优解时过早地降低学习率,使得算法难以跳出局部最优区域,从而无法找到全局最优解。
需要一定的调参经验:尽管 Adam 提供了默认的超参数,但在某些复杂的任务或数据集上,这些默认参数可能不是最优的。例如, β \beta β、 ϵ \epsilon ϵ的取值会影响算法的性能,如果选择不当,可能会导致收敛速度变慢、模型性能下降等问题。因此,在实际应用中,可能仍然需要进行一定的超参数调优。
过度适应训练数据:由于 Adam 算法在训练过程中过于关注梯度的历史信息和自适应调整学习率,可能会导致模型过度适应训练数据,从而降低模型的泛化能力。在某些情况下,使用 Adam 训练的模型在测试集上的表现可能不如使用其他优化算法训练的模型。
四、相关链接
Python 梯度下降法合集:
- Python 梯度下降法(一):Gradient Descent-CSDN博客
- Python 梯度下降法(二):RMSProp Optimize-CSDN博客
- Python 梯度下降法(三):Adagrad Optimize-CSDN博客
- Python 梯度下降法(四):Adadelta Optimize-CSDN博客
- Python 梯度下降法(五):Adam Optimize-CSDN博客
- Python 梯度下降法(六):Nadam Optimize-CSDN博客
- Python 梯度下降法(七):Summary-CSDN博客
相关文章:
Python 梯度下降法(五):Adam Optimize
文章目录 Python 梯度下降法(五):Adam Optimize一、数学原理1.1 介绍1.2 符号说明1.3 实现流程 二、代码实现2.1 函数代码2.2 总代码2.3 遇到的问题2.4 算法优化 三、优缺点3.1 优点3.2 缺点 四、相关链接 Python 梯度下降法(五&a…...
笔试-二进制
应用题 将符合区间[l,r]内的十进制整数转换为二进制表示,请问不包含“101”的整数个数是多少? 实现 l int(input("请输入下限l,其值大于等于1:")) r int(input("请输入上限r,其值大于等于l&#x…...
springboot 2.7.6 security mysql redis jwt配置例子
数据库结构用的是若依的数据库基本结构,ruoyi.vip。 总体参考了文章:https://blog.csdn.net/qq_45847507/article/details/126681110 本文章只包含不同的地方,相同的不再赘述。 1、创建spring工程,jdk1.8,maven。 pom.xml中依赖部…...
FreeRTOS从入门到精通 第十六章(任务通知)
参考教程:【正点原子】手把手教你学FreeRTOS实时系统_哔哩哔哩_bilibili 一、任务通知简介 1、概述 (1)任务通知顾名思义是用来通知任务的,任务控制块中的结构体成员变量ulNotifiedValue就是这个通知值。 (2&#…...
TensorFlow 简单的二分类神经网络的训练和应用流程
展示了一个简单的二分类神经网络的训练和应用流程。主要步骤包括: 1. 数据准备与预处理 2. 构建模型 3. 编译模型 4. 训练模型 5. 评估模型 6. 模型应用与部署 加载和应用已训练的模型 1. 数据准备与预处理 在本例中,数据准备是通过两个 Numpy 数…...
无人机图传模块 wfb-ng openipc-fpv,4G
openipc 的定位是为各种模块提供底层的驱动和linux最小系统,openipc 是采用buildroot系统编译而成,因此二次开发能力有点麻烦。为啥openipc 会用于无人机图传呢?因为openipc可以将现有的网络摄像头ip-camera模块直接利用起来,从而…...
.cc扩展名是什么语言?C语言必须用.c为扩展名吗?主流编程语言扩展名?Java为什么不能用全数字的文件名?
.cc扩展名是什么语言? .cc是C语言使用的扩展名,一种说法是它是c with class的简写,当然C语言使用的扩展名不止.cc和.cpp, 还包含.cxx, .c, .C等,这些在不同编译器系统采用的默认设定不同,需要区分使用。当然,编译器提…...
【MyDB】4-VersionManager 之 3-死锁及超时检测
【MyDB】4-VersionManager 之 3-死锁及超时检测 死锁及超时检测案例背景LockTable锁请求与等待管理 addvm调用addputIntoList,isInList,removeFromList 死锁检测 hasDeadLock方法资源释放与重分配 参考资料 死锁及超时检测 本章涉及代码:top/…...
【Linux】使用管道实现一个简易版本的进程池
文章目录 使用管道实现一个简易版本的进程池流程图代码makefileTask.hppProcessPool.cc 程序流程: 使用管道实现一个简易版本的进程池 流程图 代码 makefile ProcessPool:ProcessPool.ccg -o $ $^ -g -stdc11 .PHONY:clean clean:rm -f ProcessPoolTask.hpp #pr…...
【OpenGL】OpenGL游戏案例(二)
文章目录 特殊效果数据结构生成逻辑更新逻辑 文本渲染类结构构造函数加载函数渲染函数 特殊效果 为提高游戏的趣味性,在游戏中提供了六种特殊效果。 数据结构 PowerUp 类只存储存活数据,实际逻辑在游戏代码中通过Type字段来区分执行 class PowerUp …...
28. 【.NET 8 实战--孢子记账--从单体到微服务】--简易报表--报表定时器与报表数据修正
这篇文章是《.NET 8 实战–孢子记账–从单体到微服务》系列专栏的《单体应用》专栏的最后一片和开发有关的文章。在这片文章中我们一起来实现一个数据统计的功能:报表数据汇总。这个功能为用户查看月度、年度、季度报表提供数据支持。 一、需求 数据统计方面&…...
Java 泛型<? extends Object>
在 Java 泛型中,<? extends Object> 和 <?> 都表示未知类型,但它们在某些情况下有细微的差异。泛型的引入是为了消除运行时错误并增强类型安全性,使代码更具可读性和可维护性。 在 JDK 5 中引入了泛型,以消除编译时…...
FPGA|使用quartus II通过AS下载POF固件
1、将开发板设置到AS下载挡位,或者把下载线插入到AS端口 2、打开quartus II,选择Tools→Programmer→ Mode选择Active Serial Programming 3、点击左侧Add file…,选择 .pof 文件 →start 4、勾选program和verify(可选࿰…...
“新月之智”智能战术头盔系统(CITHS)
新月人物传记:人物传记之新月篇-CSDN博客 相关文章链接(更新): 星际战争模拟系统:新月的编程之道-CSDN博客 新月智能护甲系统CMIA--未来战场的守护者-CSDN博客 目录 一、引言 二、智能头盔控制系统概述 三、系统架…...
php:代码中怎么搭建一个类似linux系统的crontab服务
一、前言 最近使用自己搭建的php框架写一些东西,需要用到异步脚本任务的执行,但是是因为自己搭建的框架没有现成的机制,所以想自己搭建一个类似linux系统的crontab服务的功能。 因为如果直接使用linux crontab的服务配置起来很麻烦࿰…...
【LeetCode: 958. 二叉树的完全性检验 + bfs + 二叉树】
🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…...
MinDoc 安装与部署
下载可执行文件 mindoc mindoc_linux_amd64.zip 上传并解压压缩包 cd /opt mkdir mindoc cd mindocunzip mindoc_linux_amd64.zip 创建数据库 CREATE DATABASE mindoc_db DEFAULT CHARSET utf8mb4 COLLATE utf8mb4_general_ci; 配置数据库 将解压目录下 conf/app.conf.exam…...
从0开始使用面对对象C语言搭建一个基于OLED的图形显示框架(基础组件实现)
目录 基础组件实现 如何将图像和文字显示到OLED上 如何绘制图像 如何绘制文字 如何获取字体? 如何正确的访问字体 如何抽象字体 如何绘制字符串 绘制方案 文本绘制 更加方便的绘制 字体附录 ascii 6x8字体 ascii 8 x 16字体 基础组件实现 我们现在离手…...
windows系统如何检查是否开启了mongodb服务
windows系统如何检查是否开启了mongodb服务!我们有很多软件开发,网站开发时候需要使用到这个mongodb数据库,下面我们看看,如何在windows系统内排查,是否已经启动了本地服务。 在 Windows 系统上,您可以通过…...
VS安卓仿真器下载失败怎么办?
如果网络不稳定,则VS的安卓仿真器很容易下载失败,如下 Downloaded file <USER_HOME>\AppData\Local\Temp\xamarin-android-sdk\x86_64-35_r08.zip not found for Android SDK archive https://dl.google.com/android/repository/sys-img/google_a…...
逻辑回归:给不确定性划界的分类大师
想象你是一名医生。面对患者的检查报告(肿瘤大小、血液指标),你需要做出一个**决定性判断**:恶性还是良性?这种“非黑即白”的抉择,正是**逻辑回归(Logistic Regression)** 的战场&a…...
23-Oracle 23 ai 区块链表(Blockchain Table)
小伙伴有没有在金融强合规的领域中遇见,必须要保持数据不可变,管理员都无法修改和留痕的要求。比如医疗的电子病历中,影像检查检验结果不可篡改行的,药品追溯过程中数据只可插入无法删除的特性需求;登录日志、修改日志…...
Qt Widget类解析与代码注释
#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...
第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明
AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...
Java 加密常用的各种算法及其选择
在数字化时代,数据安全至关重要,Java 作为广泛应用的编程语言,提供了丰富的加密算法来保障数据的保密性、完整性和真实性。了解这些常用加密算法及其适用场景,有助于开发者在不同的业务需求中做出正确的选择。 一、对称加密算法…...
OpenLayers 分屏对比(地图联动)
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能,和卷帘图层不一样的是,分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...
蓝桥杯3498 01串的熵
问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798, 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...
今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存
文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...
【生成模型】视频生成论文调研
工作清单 上游应用方向:控制、速度、时长、高动态、多主体驱动 类型工作基础模型WAN / WAN-VACE / HunyuanVideo控制条件轨迹控制ATI~镜头控制ReCamMaster~多主体驱动Phantom~音频驱动Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation速…...
代码随想录刷题day30
1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币,另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额,返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...
