Python 梯度下降法(二):RMSProp Optimize
文章目录
- Python 梯度下降法(二):RMSProp Optimize
- 一、数学原理
- 1.1 介绍
- 1.2 公式
- 二、代码实现
- 2.1 函数代码
- 2.2 总代码
- 三、代码优化
- 3.1 存在问题
- 3.2 收敛判断
- 3.3 函数代码
- 3.4 总代码
- 四、优缺点
- 4.1 优点
- 4.2 缺点
Python 梯度下降法(二):RMSProp Optimize
结合第一篇文章一起看:Python 梯度下降法(一):Gradient Descent-CSDN博客
一、数学原理
1.1 介绍
RMSProp(Root Mean Square Propagation)是一种自适应学习率优化算法,广泛用于深度学习中的梯度下降优化。它通过调整每个参数的学习率来解决传统梯度下降法中学习率固定的问题,从而加速收敛并提高性能。
RMSProp 的核心思想是对每个参数的学习率进行自适应调整。它通过维护一个指数加权移动平均(Exponential Moving Average, EMA)的梯度平方值来调整学习率:
- 对于梯度较大的参数,降低其学习率。
- 对于梯度较小的参数,增加其学习率。
这种方法可以有效缓解梯度下降中的震荡问题,尤其是在非凸优化问题中。
1.2 公式
符号说明:
θ : 需要优化的参数向量 J ( θ ) : 损失函数 g t : 在第 t 次迭代时损失函数关于 θ 的梯度, ∇ θ J ( θ t ) ρ : 衰减率,常用值为 0.9 η : 学习率,需要手动设置 ϵ : 一个及小的参数,无限趋近于零,避免不会出现零 ( 1 0 − 8 ) s t : 指数加权移动平均 \begin{array}{l} \theta&:需要优化的参数向量 \\ J(\theta)&: 损失函数 \\ g_{t}&:在第t次迭代时损失函数关于\theta的梯度,\nabla_{\theta}J(\theta_{t}) \\ \rho &: 衰减率,常用值为0.9\\ \eta&:学习率,需要手动设置 \\ \epsilon&: 一个及小的参数,无限趋近于零,避免不会出现零(10^{-8}) \\ s_{t}&:指数加权移动平均 \end{array} θJ(θ)gtρηϵst:需要优化的参数向量:损失函数:在第t次迭代时损失函数关于θ的梯度,∇θJ(θt):衰减率,常用值为0.9:学习率,需要手动设置:一个及小的参数,无限趋近于零,避免不会出现零(10−8):指数加权移动平均
-
初始化参数为: θ 0 , s 0 = 0 \theta_{0},s_{0}=0 θ0,s0=0
-
迭代更新,每次迭代 t t t中,更新指数加权移动平均:
s t = ρ s t − 1 + ( 1 − ρ ) g t ⊙ g t s_{t}=\rho s_{t-1}+(1-\rho)g_{t}\odot g_{t} st=ρst−1+(1−ρ)gt⊙gt s t s_{t} st可以理解为对梯度平方的一个平滑估计,它更关注近期的梯度信息, ρ \rho ρ控制了历史信息的衰减程度。 -
计算自适应学习率,公式为 η s t + ϵ \frac{\eta}{\sqrt{ s_{t}+\epsilon }} st+ϵη,其中,分母 s t + ϵ \sqrt{ s_{t}+\epsilon} st+ϵ起到了归一化梯度的作用,使得学习率可以更具梯度的尺寸进行自适应的调整。
-
更新参数: θ t + 1 = θ t − η s t + ϵ g t \theta_{t+1}=\theta_{t}- \frac{\eta}{\sqrt{ s_{t}+\epsilon }}g_{t} θt+1=θt−st+ϵηgt
二、代码实现
2.1 函数代码
RMSProp优化算法实现:
# 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, num_iter=1000, epsilon=1e-5, rho=0.9):"""X: 数据 x mxny: 数据 y nx1eta: 学习率 num_iter: 迭代次数epsilon: 无穷小rho: 衰减率"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1)) # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = [] # 存储损失率的变化,便于绘图for _ in range(num_iter):# 计算预测值h = np.dot(X, theta)# 计算误差error = h - yloss_.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(X.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2) # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient) # Hadamar productreturn theta, loss_
2.2 总代码
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, num_iter=1000, epsilon=1e-5, rho=0.9):"""X: 数据 x mxny: 数据 y nx1eta: 学习率 num_iter: 迭代次数epsilon: 无穷小rho: 衰减率"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1)) # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = [] # 存储损失率的变化,便于绘图for _ in range(num_iter):# 计算预测值h = np.dot(X, theta)# 计算误差error = h - yloss_.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(X.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2) # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient) # Hadamar productreturn theta, loss_# 生成一些示例数据
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 设置超参数
eta = 0.1# RMSProp优化算法
theta, loss_ = rmsprop_optimizer(X_b, y, eta)print("最优参数 theta:")
print(theta)
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.show()

可以发现,其对于损失值的下降性能也较好,损失率也较为稳定。
三、代码优化
3.1 存在问题
- 未使用小批量数据:该代码在每次迭代时使用了全部的训练数据
X和y来计算梯度,这相当于批量梯度下降的方式。在处理大规模数据集时,这种方式可能会导致计算效率低下,并且可能会陷入局部最优解。可以参考之前小批量梯度下降的代码,引入小批量数据的处理,以提高算法的效率和泛化能力。 - 缺乏数据预处理:在实际应用中,输入数据
X可能需要进行预处理,例如归一化或标准化,以确保不同特征具有相似的尺度,从而加快算法的收敛速度。(这里不进行解决,参考特征缩放:数据归一化-CSDN博客) - 缺乏收敛判断:代码只是简单地进行了固定次数的迭代,没有设置收敛条件。在实际应用中,可以添加收敛判断,例如当损失值的变化小于某个阈值时提前停止迭代,以节省计算资源。
这里引入Mini-batch Gradient Descent,以及收敛判断,减少计算资源
3.2 收敛判断
# 收敛判断,设定阈值,进行收敛判断
# 满足条件即停止,减少系统资源的使用
if len(loss_) > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {_ + 1}")break # 注意,这里不能使用return
3.3 函数代码
# 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, batch_size=32, num_iter=1000, epsilon=1e-5, rho=0.9, threshold=1e-3):"""X: 数据 x mxn,可以在传入数据之前进行数据的归一化y: 数据 y nx1eta: 学习率 batch_size: 批量数据的大小num_iter: 迭代次数epsilon: 无穷小rho: 衰减率threshold: 收敛阈值"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1)) # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = [] # 存储损失率的变化,便于绘图num_batchs = m // batch_sizefor _ in range(num_iter):# 打乱数据集shuffled_indices = np.random.permutation(m)X_shuffled = X[shuffled_indices]y_shuffled = y[shuffled_indices]loss_temp = [] # 存储每次小批量样本生成的值for batch in range(num_batchs):# 选取小批量样本start_index = batch * batch_sizeend_index = start_index + batch_sizexi = X_shuffled[start_index:end_index]yi = y_shuffled[start_index:end_index]# 计算预测值h = np.dot(xi, theta)# 计算误差error = h - yiloss_temp.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(xi.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2) # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient) # Hadamar productloss_.append(np.mean(loss_temp))# 收敛判断if len(loss_) > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {_ + 1}")breakreturn theta, loss_
3.4 总代码
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 定义RMSProp优化算法
def rmsprop_optimizer(X, y, eta, batch_size=32, num_iter=1000, epsilon=1e-5, rho=0.9, threshold=1e-3):"""X: 数据 x mxn,可以在传入数据之前进行数据的归一化y: 数据 y nx1eta: 学习率 batch_size: 批量数据的大小num_iter: 迭代次数epsilon: 无穷小rho: 衰减率threshold: 收敛阈值"""m, n = X.shape theta, s = np.zeros((n, 1)), np.zeros((n, 1)) # 初始参数 nx1,以及指数加权移动平均 nx1loss_ = [] # 存储损失率的变化,便于绘图num_batchs = m // batch_sizefor _ in range(num_iter):# 打乱数据集shuffled_indices = np.random.permutation(m)X_shuffled = X[shuffled_indices]y_shuffled = y[shuffled_indices]loss_temp = [] # 存储每次小批量样本生成的值for batch in range(num_batchs):# 选取小批量样本start_index = batch * batch_sizeend_index = start_index + batch_sizexi = X_shuffled[start_index:end_index]yi = y_shuffled[start_index:end_index]# 计算预测值h = np.dot(xi, theta)# 计算误差error = h - yiloss_temp.append(np.mean(error**2) / 2)# 计算梯度gradient = (1/m) * np.dot(xi.T, error)s = rho * s + (1 - rho) * np.pow(gradient, 2) # 利用广播机制来进行运算每个维度上的平滑估计theta = theta - np.multiply(eta / np.sqrt(s + epsilon), gradient) # Hadamar productloss_.append(np.mean(loss_temp)) # 使用平均值作为参考# 收敛判断if len(loss_) > 1 and abs(loss_[-1] - loss_[-2]) < threshold:print(f"Converged at iteration {_ + 1}")breakreturn theta, loss_# 生成一些示例数据
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]# 设置超参数
eta = 0.1# RMSProp优化算法
theta, loss_ = rmsprop_optimizer(X_b, y, eta)print("最优参数 theta:")
print(theta)
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.show()

四、优缺点
4.1 优点
- 自适应学习率:RMSProp 能够根据参数的梯度变化情况自适应地调整学习率。对于梯度较大的参数,学习率会自动减小;对于梯度较小的参数,学习率会相对增大。这使得算法在处理不同尺度的梯度时更加稳定,有助于加快收敛速度。
- 缓解 Adagrad 学习率衰减过快问题:与 Adagrad 算法不同,RMSProp 使用指数加权移动平均来计算梯度平方的累积值,避免了 Adagrad 中学习率单调递减且后期学习率过小的问题,使得算法在训练后期仍然能够继续更新参数。
4.2 缺点
- 对超参数敏感:RMSProp 的性能依赖于超参数 η \eta η和 ρ \rho ρ的选择。如果超参数设置不当,可能会导致算法收敛速度慢或者无法收敛到最优解。
- 可能陷入局部最优:和其他基于梯度的优化算法一样,RMSProp 仍然有可能陷入局部最优解,尤其是在损失函数具有复杂的地形时。
相关文章:
Python 梯度下降法(二):RMSProp Optimize
文章目录 Python 梯度下降法(二):RMSProp Optimize一、数学原理1.1 介绍1.2 公式 二、代码实现2.1 函数代码2.2 总代码 三、代码优化3.1 存在问题3.2 收敛判断3.3 函数代码3.4 总代码 四、优缺点4.1 优点4.2 缺点 Python 梯度下降法ÿ…...
Android Studio 正式版 10 周年回顾,承载 Androider 的峥嵘十年
Android Studio 1.0 宣发于 2014 年 12 月,而现在时间来到 2025 ,不知不觉间 Android Studio 已经陪伴 Androider 走过十年历程。 Android Studio 10 周年,也代表着了我的职业生涯也超十年,现在回想起来依然觉得「唏嘘」ÿ…...
sem_wait的概念和使用案列
sem_wait 是 POSIX 标准中定义的一个用于同步的函数,它通常用于操作信号量(semaphore)。信号量是一个整数变量,可以用来控制对共享资源的访问。在多线程编程中,sem_wait 常用于实现线程间的同步。 概念 sem_wait 的基…...
集合的奇妙世界:Python集合的经典、避坑与实战
集合的奇妙世界:Python集合的经典、避坑与实战 内容简介 本系列文章是为 Python3 学习者精心设计的一套全面、实用的学习指南,旨在帮助读者从基础入门到项目实战,全面提升编程能力。文章结构由 5 个版块组成,内容层层递进&#x…...
专业视角深度解析:DeepSeek的核心优势何在?
杭州深度求索(DeepSeek)人工智能基础技术研究有限公司,是一家成立于2023年7月的中国人工智能初创企业,总部位于浙江省杭州市。该公司由量化对冲基金幻方量化(High-Flyer)的联合创始人梁文锋创立,…...
MySQL 索引存储结构
索引是优化数据库查询最重要的方式之一,它是在 MySQL 的存储引擎层中实现的,所以 每一种存储引擎对应的索引不一定相同。我们可以通过下面这张表格,看看不同的存储引擎 分别支持哪种索引类型: BTree 索引和 Hash 索引是我们比较…...
【ComfyUI专栏】如何使用Git命令行安装非Manager收录节点
当前的ComfyUI的收录的自定义节点很多,但是有些节点属于新出来,或者他的应用没有那么广泛,Manager管理节点 有可能没有收录到,这时候 如果我们需要安装需要怎么办呢?这就涉及到我们自己安装这些节点了。例如下面的内容…...
python算法和数据结构刷题[1]:数组、矩阵、字符串
一画图二伪代码三写代码 LeetCode必刷100题:一份来自面试官的算法地图(题解持续更新中)-CSDN博客 算法通关手册(LeetCode) | 算法通关手册(LeetCode) (itcharge.cn) 面试经典 150 题 - 学习计…...
数据分析系列--④RapidMiner进行关联分析(案例)
一、核心概念 1.项集(Itemset) 2.规则(Rule) 3.支持度(Support) 3.1 支持度的定义 3.2 支持度的意义 3.3 支持度的应用 3.4 支持度的示例 3.5 支持度的调整 3.6 支持度与其他指标的关系 4.置信度࿰…...
1/30每日一题
从输入 URL 到页面展示到底发生了什么? 1. 输入 URL 与浏览器解析 当你在浏览器地址栏输入 URL 并按下回车,浏览器首先会解析这个 URL(统一资源定位符),比如 https://www.example.com。浏览器会解析这个 URL 中的不同…...
vim的多文件操作
[rootxxx ~]# vim aa.txt bb.txt cc.txt #多文件操作 next #下一个文件 prev #上一个文件 first #第一个文件 last #最后一个文件 快捷键: ctrlshift^ #当前和上个之间切换 说明:快捷键ctrlshift^,…...
设计转换Apache Hive的HQL语句为Snowflake SQL语句的Python程序方法
首先,根据以下各类HQL语句的基本实例和官方文档记录的这些命令语句各种参数设置,得到各种HQL语句的完整实例,然后在Snowflake的官方文档找到它们对应的Snowflake SQL语句,建立起对应的关系表。在这个过程中要注意HQL语句和Snowfla…...
CAPL与外部接口
CAPL与外部接口 目录 CAPL与外部接口1. 引言2. CAPL与C/C++交互2.1 CAPL与C/C++交互简介2.2 CAPL与C/C++交互实现3. CAPL与Python交互3.1 CAPL与Python交互简介3.2 CAPL与Python交互实现4. CAPL与MATLAB交互4.1 CAPL与MATLAB交互简介4.2 CAPL与MATLAB交互实现5. 案例说明5.1 案…...
无公网IP 外网访问 本地部署夫人 hello-algo
hello-algo 是一个为帮助编程爱好者系统地学习数据结构和算法的开源项目。这款项目通过多种创新的方式,为学习者提供了一个直观、互动的学习平台。 本文将详细的介绍如何利用 Docker 在本地安装部署 hello-algo,并结合路由侠内网穿透实现外网访问本地部署…...
实验四 XML
实验四 XML 目的: 1、安装和使用XML的开发环境 2、认识XML的不同类型 3、掌握XML文档的基本语法 4、了解DTD的作用 5、掌握DTD的语法 6、掌握Schema的语法 实验过程: 1、安装XML的编辑器,可以选择以下之一 a)XMLSpy b)VScode,Vs…...
Autosar-Os是怎么运行的?(内存保护)
写在前面: 入行一段时间了,基于个人理解整理一些东西,如有错误,欢迎各位大佬评论区指正!!! 1.功能概述 以TC397芯片为例,英飞凌芯片集成了MPU模块, MPU模块采用了硬件机…...
题单:冒泡排序1
题目描述 给定 n 个元素的数组(下标从 1 开始计),请使用冒泡排序对其进行排序(升序)。 请输出每一次冒泡过程后数组的状态。 要求:每次从第一个元素开始,将最大的元素冒泡至最后。 输入格式…...
多目标优化策略之一:非支配排序
多目标优化策略中的非支配排序是一种关键的技术,它主要用于解决多目标优化问题中解的选择和排序问题,确定解集中的非支配解(也称为Pareto解)。 关于什么是多目标优化问题,可以查看我的文章:改进候鸟优化算法之五:基于多目标优化的候鸟优化算法(MBO-MO)-CSDN博客 多目…...
Go学习:字符、字符串需注意的点
Go语言与C/C语言编程有很多相似之处,但是Go语言中在声明一个字符时,数据类型与其他语言声明一个字符数据时有一点不同之处。通常,字符的数据类型为 char,例如 :声明一个字符 (字符名称为 ch) 的语句格式为 char ch&am…...
Linux文件原生操作
Linux 中一切皆文件,那么 Linux 文件是什么? 在 Linux 中的文件 可以是:传统意义上的有序数据集合,即:文件系统中的物理文件 也可以是:设备,管道,内存。。。(Linux 管理的一切对象…...
linux之kylin系统nginx的安装
一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源(HTML/CSS/图片等),响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址,提高安全性 3.负载均衡服务器 支持多种策略分发流量…...
云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地
借阿里云中企出海大会的东风,以**「云启出海,智联未来|打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办,现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...
【JVM】- 内存结构
引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...
Python爬虫(一):爬虫伪装
一、网站防爬机制概述 在当今互联网环境中,具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类: 身份验证机制:直接将未经授权的爬虫阻挡在外反爬技术体系:通过各种技术手段增加爬虫获取数据的难度…...
令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍
文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结: 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析: 实际业务去理解体会统一注…...
解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错
出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上,所以报错,到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本,cu、torch、cp 的版本一定要对…...
土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等
🔍 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术,可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势,还能有效评价重大生态工程…...
优选算法第十二讲:队列 + 宽搜 优先级队列
优选算法第十二讲:队列 宽搜 && 优先级队列 1.N叉树的层序遍历2.二叉树的锯齿型层序遍历3.二叉树最大宽度4.在每个树行中找最大值5.优先级队列 -- 最后一块石头的重量6.数据流中的第K大元素7.前K个高频单词8.数据流的中位数 1.N叉树的层序遍历 2.二叉树的锯…...
USB Over IP专用硬件的5个特点
USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中,从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备(如专用硬件设备),从而消除了直接物理连接的需要。USB over IP的…...
网站指纹识别
网站指纹识别 网站的最基本组成:服务器(操作系统)、中间件(web容器)、脚本语言、数据厍 为什么要了解这些?举个例子:发现了一个文件读取漏洞,我们需要读/etc/passwd,如…...
