Python 梯度下降法(二):RMSProp Optimize
文章目录
- Python 梯度下降法(二):RMSProp Optimize
- 一、数学原理
- 1.1 介绍
- 1.2 公式
- 二、代码实现
- 2.1 函数代码
- 2.2 总代码
- 三、代码优化
- 3.1 存在问题
- 3.2 收敛判断
- 3.3 函数代码
- 3.4 总代码
- 四、优缺点
- 4.1 优点
- 4.2 缺点
Python 梯度下降法(二):RMSProp Optimize
结合第一篇文章一起看:Python 梯度下降法(一):Gradient Descent-CSDN博客
一、数学原理
1.1 介绍
RMSProp(Root Mean Square Propagation)是一种自适应学习率优化算法,广泛用于深度学习中的梯度下降优化。它通过调整每个参数的学习率来解决传统梯度下降法中学习率固定的问题,从而加速收敛并提高性能。
RMSProp 的核心思想是对每个参数的学习率进行自适应调整。它通过维护一个指数加权移动平均(Exponential Moving Average, EMA)的梯度平方值来调整学习率:
- 对于梯度较大的参数,降低其学习率。
- 对于梯度较小的参数,增加其学习率。
这种方法可以有效缓解梯度下降中的震荡问题,尤其是在非凸优化问题中。
1.2 公式
符号说明:
θ : 需要优化的参数向量 J ( θ ) : 损失函数 g t : 在第 t 次迭代时损失函数关于 θ 的梯度, ∇ θ J ( θ t ) ρ : 衰减率,常用值为 0.9 η : 学习率,需要手动设置 ϵ : 一个及小的参数,无限趋近于零,避免不会出现零 ( 1 0 − 8 ) s t : 指数加权移动平均 \begin{array}{l} \theta&:需要优化的参数向量 \\ J(\theta)&: 损失函数 \\ g_{t}&:在第t次迭代时损失函数关于\theta的梯度,\nabla_{\theta}J(\theta_{t}) \\ \rho &: 衰减率,常用值为0.9\\ \eta&:学习率,需要手动设置 \\ \epsilon&: 一个及小的参数,无限趋近于零,避免不会出现零(10^{-8}) \\ s_{t}&:指数加权移动平均 \end{array} θJ(θ)gtρηϵst:需要优化的参数向量:损失函数:在第t次迭代时损失函数关于θ的梯度,∇θJ(θt):衰减率,常用值为0.9:学习率,需要手动设置:一个及小的参数,无限趋近于零,避免不会出现零(10−8):指数加权移动平均
-
初始化参数为: θ 0 , s 0 = 0 \theta_{0},s_{0}=0 θ0,s0=0
-
迭代更新,每次迭代 t t t中,更新指数加权移动平均:
s t = ρ s t − 1 + ( 1 − ρ ) g t ⊙ g t s_{t}=\rho s_{t-1}+(1-\rho)g_{t}\odot g_{t} st=ρst−1+(1−ρ)gt⊙gt s t s_{t} st可以理解为对梯度平方的一个平滑估计,它更关注近期的梯度信息, ρ \rho ρ控制了历史信息的衰减程度。 -
计算自适应学习率,公式为 η s t + ϵ \frac{\eta}{\sqrt{ s_{t}+\epsilon }} st+ϵη,其中,分母 s t + ϵ \sqrt{ s_{t}+\epsilon} st+ϵ起到了归一化梯度的作用,使得学习率可以更具梯度的尺寸进行自适应的调整。
-
更新参数: θ t + 1 = θ t − η s t + ϵ g t \theta_{t+1}=\theta_{t}- \frac{\eta}{\sqrt{ s_{t}+\epsilon }}g_{t} θt+1=θt−st+ϵηgt
二、代码实现
2.1 函数代码
RMSProp优化算法实现:
# 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, num_iter=1000, epsilon=1e-5, rho=0.9):"""X: 数据 x mxny: 数据 y nx1eta: 学习率 num_iter: 迭代次数epsilon: 无穷小rho: 衰减率"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1)) # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = [] # 存储损失率的变化,便于绘图for _ in range(num_iter):# 计算预测值h = np.dot(X, theta)# 计算误差error = h - yloss_.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(X.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2) # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient) # Hadamar productreturn theta, loss_
2.2 总代码
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, num_iter=1000, epsilon=1e-5, rho=0.9):"""X: 数据 x mxny: 数据 y nx1eta: 学习率 num_iter: 迭代次数epsilon: 无穷小rho: 衰减率"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1)) # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = [] # 存储损失率的变化,便于绘图for _ in range(num_iter):# 计算预测值h = np.dot(X, theta)# 计算误差error = h - yloss_.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(X.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2) # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient) # Hadamar productreturn theta, loss_# 生成一些示例数据
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 设置超参数
eta = 0.1# RMSProp优化算法
theta, loss_ = rmsprop_optimizer(X_b, y, eta)print("最优参数 theta:")
print(theta)
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.show()
可以发现,其对于损失值的下降性能也较好,损失率也较为稳定。
三、代码优化
3.1 存在问题
- 未使用小批量数据:该代码在每次迭代时使用了全部的训练数据
X
和y
来计算梯度,这相当于批量梯度下降的方式。在处理大规模数据集时,这种方式可能会导致计算效率低下,并且可能会陷入局部最优解。可以参考之前小批量梯度下降的代码,引入小批量数据的处理,以提高算法的效率和泛化能力。 - 缺乏数据预处理:在实际应用中,输入数据
X
可能需要进行预处理,例如归一化或标准化,以确保不同特征具有相似的尺度,从而加快算法的收敛速度。(这里不进行解决,参考特征缩放:数据归一化-CSDN博客) - 缺乏收敛判断:代码只是简单地进行了固定次数的迭代,没有设置收敛条件。在实际应用中,可以添加收敛判断,例如当损失值的变化小于某个阈值时提前停止迭代,以节省计算资源。
这里引入Mini-batch Gradient Descent,以及收敛判断,减少计算资源
3.2 收敛判断
# 收敛判断,设定阈值,进行收敛判断
# 满足条件即停止,减少系统资源的使用
if len(loss_) > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {_ + 1}")break # 注意,这里不能使用return
3.3 函数代码
# 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, batch_size=32, num_iter=1000, epsilon=1e-5, rho=0.9, threshold=1e-3):"""X: 数据 x mxn,可以在传入数据之前进行数据的归一化y: 数据 y nx1eta: 学习率 batch_size: 批量数据的大小num_iter: 迭代次数epsilon: 无穷小rho: 衰减率threshold: 收敛阈值"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1)) # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = [] # 存储损失率的变化,便于绘图num_batchs = m // batch_sizefor _ in range(num_iter):# 打乱数据集shuffled_indices = np.random.permutation(m)X_shuffled = X[shuffled_indices]y_shuffled = y[shuffled_indices]loss_temp = [] # 存储每次小批量样本生成的值for batch in range(num_batchs):# 选取小批量样本start_index = batch * batch_sizeend_index = start_index + batch_sizexi = X_shuffled[start_index:end_index]yi = y_shuffled[start_index:end_index]# 计算预测值h = np.dot(xi, theta)# 计算误差error = h - yiloss_temp.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(xi.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2) # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient) # Hadamar productloss_.append(np.mean(loss_temp))# 收敛判断if len(loss_) > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {_ + 1}")breakreturn theta, loss_
3.4 总代码
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, batch_size=32, num_iter=1000, epsilon=1e-5, rho=0.9, threshold=1e-3):"""X: 数据 x mxn,可以在传入数据之前进行数据的归一化y: 数据 y nx1eta: 学习率 batch_size: 批量数据的大小num_iter: 迭代次数epsilon: 无穷小rho: 衰减率threshold: 收敛阈值"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1)) # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = [] # 存储损失率的变化,便于绘图num_batchs = m // batch_sizefor _ in range(num_iter):# 打乱数据集shuffled_indices = np.random.permutation(m)X_shuffled = X[shuffled_indices]y_shuffled = y[shuffled_indices]loss_temp = [] # 存储每次小批量样本生成的值for batch in range(num_batchs):# 选取小批量样本start_index = batch * batch_sizeend_index = start_index + batch_sizexi = X_shuffled[start_index:end_index]yi = y_shuffled[start_index:end_index]# 计算预测值h = np.dot(xi, theta)# 计算误差error = h - yiloss_temp.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(xi.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2) # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient) # Hadamar productloss_.append(np.mean(loss_temp)) # 使用平均值作为参考# 收敛判断if len(loss_) > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {_ + 1}")breakreturn theta, loss_# 生成一些示例数据
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 设置超参数
eta = 0.1# RMSProp优化算法
theta, loss_ = rmsprop_optimizer(X_b, y, eta)print("最优参数 theta:")
print(theta)
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.show()
四、优缺点
4.1 优点
- 自适应学习率:RMSProp 能够根据参数的梯度变化情况自适应地调整学习率。对于梯度较大的参数,学习率会自动减小;对于梯度较小的参数,学习率会相对增大。这使得算法在处理不同尺度的梯度时更加稳定,有助于加快收敛速度。
- 缓解 Adagrad 学习率衰减过快问题:与 Adagrad 算法不同,RMSProp 使用指数加权移动平均来计算梯度平方的累积值,避免了 Adagrad 中学习率单调递减且后期学习率过小的问题,使得算法在训练后期仍然能够继续更新参数。
4.2 缺点
- 对超参数敏感:RMSProp 的性能依赖于超参数 η \eta η和 ρ \rho ρ的选择。如果超参数设置不当,可能会导致算法收敛速度慢或者无法收敛到最优解。
- 可能陷入局部最优:和其他基于梯度的优化算法一样,RMSProp 仍然有可能陷入局部最优解,尤其是在损失函数具有复杂的地形时。
相关文章:

Python 梯度下降法(二):RMSProp Optimize
文章目录 Python 梯度下降法(二):RMSProp Optimize一、数学原理1.1 介绍1.2 公式 二、代码实现2.1 函数代码2.2 总代码 三、代码优化3.1 存在问题3.2 收敛判断3.3 函数代码3.4 总代码 四、优缺点4.1 优点4.2 缺点 Python 梯度下降法ÿ…...

Android Studio 正式版 10 周年回顾,承载 Androider 的峥嵘十年
Android Studio 1.0 宣发于 2014 年 12 月,而现在时间来到 2025 ,不知不觉间 Android Studio 已经陪伴 Androider 走过十年历程。 Android Studio 10 周年,也代表着了我的职业生涯也超十年,现在回想起来依然觉得「唏嘘」ÿ…...

sem_wait的概念和使用案列
sem_wait 是 POSIX 标准中定义的一个用于同步的函数,它通常用于操作信号量(semaphore)。信号量是一个整数变量,可以用来控制对共享资源的访问。在多线程编程中,sem_wait 常用于实现线程间的同步。 概念 sem_wait 的基…...

集合的奇妙世界:Python集合的经典、避坑与实战
集合的奇妙世界:Python集合的经典、避坑与实战 内容简介 本系列文章是为 Python3 学习者精心设计的一套全面、实用的学习指南,旨在帮助读者从基础入门到项目实战,全面提升编程能力。文章结构由 5 个版块组成,内容层层递进&#x…...

专业视角深度解析:DeepSeek的核心优势何在?
杭州深度求索(DeepSeek)人工智能基础技术研究有限公司,是一家成立于2023年7月的中国人工智能初创企业,总部位于浙江省杭州市。该公司由量化对冲基金幻方量化(High-Flyer)的联合创始人梁文锋创立,…...

MySQL 索引存储结构
索引是优化数据库查询最重要的方式之一,它是在 MySQL 的存储引擎层中实现的,所以 每一种存储引擎对应的索引不一定相同。我们可以通过下面这张表格,看看不同的存储引擎 分别支持哪种索引类型: BTree 索引和 Hash 索引是我们比较…...

【ComfyUI专栏】如何使用Git命令行安装非Manager收录节点
当前的ComfyUI的收录的自定义节点很多,但是有些节点属于新出来,或者他的应用没有那么广泛,Manager管理节点 有可能没有收录到,这时候 如果我们需要安装需要怎么办呢?这就涉及到我们自己安装这些节点了。例如下面的内容…...

python算法和数据结构刷题[1]:数组、矩阵、字符串
一画图二伪代码三写代码 LeetCode必刷100题:一份来自面试官的算法地图(题解持续更新中)-CSDN博客 算法通关手册(LeetCode) | 算法通关手册(LeetCode) (itcharge.cn) 面试经典 150 题 - 学习计…...

数据分析系列--④RapidMiner进行关联分析(案例)
一、核心概念 1.项集(Itemset) 2.规则(Rule) 3.支持度(Support) 3.1 支持度的定义 3.2 支持度的意义 3.3 支持度的应用 3.4 支持度的示例 3.5 支持度的调整 3.6 支持度与其他指标的关系 4.置信度࿰…...

1/30每日一题
从输入 URL 到页面展示到底发生了什么? 1. 输入 URL 与浏览器解析 当你在浏览器地址栏输入 URL 并按下回车,浏览器首先会解析这个 URL(统一资源定位符),比如 https://www.example.com。浏览器会解析这个 URL 中的不同…...

vim的多文件操作
[rootxxx ~]# vim aa.txt bb.txt cc.txt #多文件操作 next #下一个文件 prev #上一个文件 first #第一个文件 last #最后一个文件 快捷键: ctrlshift^ #当前和上个之间切换 说明:快捷键ctrlshift^,…...

设计转换Apache Hive的HQL语句为Snowflake SQL语句的Python程序方法
首先,根据以下各类HQL语句的基本实例和官方文档记录的这些命令语句各种参数设置,得到各种HQL语句的完整实例,然后在Snowflake的官方文档找到它们对应的Snowflake SQL语句,建立起对应的关系表。在这个过程中要注意HQL语句和Snowfla…...

CAPL与外部接口
CAPL与外部接口 目录 CAPL与外部接口1. 引言2. CAPL与C/C++交互2.1 CAPL与C/C++交互简介2.2 CAPL与C/C++交互实现3. CAPL与Python交互3.1 CAPL与Python交互简介3.2 CAPL与Python交互实现4. CAPL与MATLAB交互4.1 CAPL与MATLAB交互简介4.2 CAPL与MATLAB交互实现5. 案例说明5.1 案…...

无公网IP 外网访问 本地部署夫人 hello-algo
hello-algo 是一个为帮助编程爱好者系统地学习数据结构和算法的开源项目。这款项目通过多种创新的方式,为学习者提供了一个直观、互动的学习平台。 本文将详细的介绍如何利用 Docker 在本地安装部署 hello-algo,并结合路由侠内网穿透实现外网访问本地部署…...

实验四 XML
实验四 XML 目的: 1、安装和使用XML的开发环境 2、认识XML的不同类型 3、掌握XML文档的基本语法 4、了解DTD的作用 5、掌握DTD的语法 6、掌握Schema的语法 实验过程: 1、安装XML的编辑器,可以选择以下之一 a)XMLSpy b)VScode,Vs…...

Autosar-Os是怎么运行的?(内存保护)
写在前面: 入行一段时间了,基于个人理解整理一些东西,如有错误,欢迎各位大佬评论区指正!!! 1.功能概述 以TC397芯片为例,英飞凌芯片集成了MPU模块, MPU模块采用了硬件机…...

题单:冒泡排序1
题目描述 给定 n 个元素的数组(下标从 1 开始计),请使用冒泡排序对其进行排序(升序)。 请输出每一次冒泡过程后数组的状态。 要求:每次从第一个元素开始,将最大的元素冒泡至最后。 输入格式…...

多目标优化策略之一:非支配排序
多目标优化策略中的非支配排序是一种关键的技术,它主要用于解决多目标优化问题中解的选择和排序问题,确定解集中的非支配解(也称为Pareto解)。 关于什么是多目标优化问题,可以查看我的文章:改进候鸟优化算法之五:基于多目标优化的候鸟优化算法(MBO-MO)-CSDN博客 多目…...

Go学习:字符、字符串需注意的点
Go语言与C/C语言编程有很多相似之处,但是Go语言中在声明一个字符时,数据类型与其他语言声明一个字符数据时有一点不同之处。通常,字符的数据类型为 char,例如 :声明一个字符 (字符名称为 ch) 的语句格式为 char ch&am…...

Linux文件原生操作
Linux 中一切皆文件,那么 Linux 文件是什么? 在 Linux 中的文件 可以是:传统意义上的有序数据集合,即:文件系统中的物理文件 也可以是:设备,管道,内存。。。(Linux 管理的一切对象…...

解决Oracle SQL语句性能问题(10.5)——常用Hint及语法(7)(其他Hint)
10.5.3. 常用hint 10.5.3.7. 其他Hint 1)cardinality:显式的指示优化器为SQL语句的某个行源指定势。该Hint具体语法如下所示。 SQL> select /*+ cardinality([@qb] [table] card ) */ ...; --注: 1)这里,第一个参数(@qb)为可选参数,指定查询语句块名;第二个参数…...

JavaScript系列(50)--编译器实现详解
JavaScript编译器实现详解 🔨 今天,让我们深入探讨JavaScript编译器的实现。编译器是一个将源代码转换为目标代码的复杂系统,通过理解其工作原理,我们可以更好地理解JavaScript的执行过程。 编译器基础概念 🌟 &…...

大数据相关职位 职业进阶路径
大数据相关职位 & 职业进阶路径 📌 大数据相关职位 & 职业进阶路径 大数据领域涵盖多个方向,包括数据工程、数据分析、数据治理、数据科学等,每个方向的进阶路径有所不同。以下是大数据相关职位的详细解析及其职业进阶关系。 &#…...

基础项目实战——学生管理系统(c++)
目录 前言一、功能菜单界面二、类与结构体的实现三、录入学生信息四、删除学生信息五、更改学生信息六、查找学生信息七、统计学生人数八、保存学生信息九、读取学生信息十、打印所有学生信息十一、退出系统十二、文件拆分结语 前言 这一期我们来一起学习我们在大学做过的课程…...

C++,STL,【目录篇】
文章目录 一、简介二、内容提纲第一部分:STL 概述第二部分:STL 容器第三部分:STL 迭代器第四部分:STL 算法第五部分:STL 函数对象第六部分:STL 高级主题第七部分:STL 实战应用 三、写作风格四、…...

【Rust自学】15.3. Deref trait Pt.2:隐式解引用转化与可变性
喜欢的话别忘了点赞、收藏加关注哦,对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 15.3.1. 函数和方法的隐式解引用转化(Deref Coercion) 隐式解引用转化(Deref Coercion)是为函数和方法提供的一种便捷特性。 它的原理是…...

密码强度验证代码解析:C语言实现与细节剖析
在日常的应用开发中,密码强度验证是保障用户账户安全的重要环节。今天,我们就来深入分析一段用C语言编写的密码强度验证代码,看看它是如何实现对密码强度的多维度检测的。 代码整体结构 这段C语言代码主要实现了对输入密码的一系列规则验证&a…...

arkts bridge使用示例
接上一篇:arkui-x跨平台与android java联合开发-CSDN博客 本篇讲前端arkui如何与后端其他平台进行数据交互,接上一篇,后端os平台为Android java。 arkui-x框架提供了一个独特的机制:bridge。 1、前端接口定义实现 定义一个bri…...

LINUX部署微服务项目步骤
项目简介技术栈 主体技术:SpringCloud,SpringBoot,VUE2, 中间件:RabbitMQ、Redis 创建用户 在linux服务器home下创建用户qshh,用于后续本项目需要的环境进行安装配置 #创建用户 useradd 用户名 #设置登录密…...

zsh安装插件
0 zsh不仅在外观上比较美观,而且其具有强大的插件,如果不使用那就亏大了。 官方插件库 https://github.com/ohmyzsh/ohmyzsh/wiki/Plugins 官方插件库并不一定有所有的插件,比如zsh-autosuggestions插件就不再列表里,下面演示zs…...