梯度下降方法
2.5 梯度下降方法介绍
学习目标
- 掌握梯度下降法的推导过程
- 知道全梯度下降算法的原理
- 知道随机梯度下降算法的原理
- 知道随机平均梯度下降算法的原理
- 知道小批量梯度下降算法的原理
上一节中给大家介绍了最基本的梯度下降法实现流程,本节我们将进一步介绍梯度下降法的详细过算法推导过程和常见的梯度下降算法。
1 详解梯度下降算法
1.1梯度下降的相关概念复习
在详细了解梯度下降的算法之前,我们先复习相关的一些概念。
步长(Learning rate):
- 步长决定了在梯度下降迭代的过程中,每一步沿梯度负方向前进的长度。用前面下山的例子,步长就是在当前这一步所在位置沿着最陡峭最易下山的位置走的那一步的长度。
特征(feature):
- 指的是样本中输入部分,比如2个单特征的样本(x(0),y(0)),(x(1),y(1))(x^{(0)},y^{(0)}),(x^{(1)},y^{(1)})(x(0),y(0)),(x(1),y(1)),则第一个样本特征为x(0)x^{(0)}x(0),第一个样本输出为y(0)y^{(0)}y(0)。
假设函数(hypothesis function):
- 在监督学习中,为了拟合输入样本,而使用的假设函数,记为hθ(x)h_\theta (x)hθ(x)。比如对于单个特征的m个样本(x(i),y(i))(i=1,2,...m)(x^{(i)},y^{(i)})(i=1,2,...m)(x(i),y(i))(i=1,2,...m),可以采用拟合函数如下: hθ(x)=θ0+θ1xh_\theta (x)=\theta _0+\theta _1xhθ(x)=θ0+θ1x。
损失函数(loss function):
- 为了评估模型拟合的好坏,通常用损失函数来度量拟合的程度。损失函数极小化,意味着拟合程度最好,对应的模型参数即为最优参数。
- 在线性回归中,损失函数通常为样本输出和假设函数的差取平方。比如对于m个样本(xi,yi)(i=1,2,...m)(x_i,y_i)(i=1,2,...m)(xi,yi)(i=1,2,...m),采用线性回归,损失函数为:
其中xix_ixi表示第i个样本特征,yiy_iyi表示第i个样本对应的输出,hθ(xi)h_\theta (x_i)hθ(xi)为假设函数。
1.2 梯度下降法的推导流程
1) 先决条件: 确认优化模型的假设函数和损失函数。
比如对于线性回归,假设函数表示为 hθ(x1,x2,...,xn)=θ0+θ1x1+...+θnxnh_\theta (x_1,x_2,...,x_n)=\theta _0+\theta _1x_1+...+\theta _nx_nhθ(x1,x2,...,xn)=θ0+θ1x1+...+θnxn, 其中θi(i=0,1,2...n)\theta _i (i = 0,1,2... n)θi(i=0,1,2...n)为模型参数,xi(i=0,1,2...n)x_i (i = 0,1,2... n)xi(i=0,1,2...n)为每个样本的n个特征值。这个表示可以简化,我们增加一个特征x0=1x_0=1x0=1 ,这样
同样是线性回归,对应于上面的假设函数,损失函数为:
2) 算法相关参数初始化,
主要是初始化θ0,θ1...,θn\theta _0,\theta _1...,\theta _nθ0,θ1...,θn,算法终止距离ε以及步长α\alphaα 。在没有任何先验知识的时候,我喜欢将所有的θ\thetaθ 初始化为0, 将步长初始化为1。在调优的时候再 优化。
3) 算法过程:
3.1) 确定当前位置的损失函数的梯度,对于θi\theta _iθi,其梯度表达式如下:
3.2) 用步长乘以损失函数的梯度,得到当前位置下降的距离,即
对应于前面登山例子中的某一步。
3.3) 确定是否所有的θi\theta _iθi,梯度下降的距离都小于ε,如果小于ε则算法终止,当前所有的θi(i=0,1,...n)\theta _i(i=0,1,...n)θi(i=0,1,...n)即为最终结果。否则进入步骤4.
4)更新所有的θ\thetaθ ,对于θi\theta _iθi,其更新表达式如下。更新完毕后继续转入步骤1.
下面用线性回归的例子来具体描述梯度下降。假设我们的样本是:
损失函数如前面先决条件所述:
则在算法过程步骤1中对于θi\theta _iθi 的偏导数计算如下:
由于样本中没有x0x_0x0上式中令所有的x0jx_0^jx0j为1.
步骤4中θi\theta _iθi的更新表达式如下:
从这个例子可以看出当前点的梯度方向是由所有的样本决定的,加1m\frac{1}{m}m1 是为了好理解。由于步长也为常数,他们的乘积也为常数,所以这里α\alphaα¨NBSP;1m\frac{1}{m}m1 可以用一个常数表示。
在下面一节中,咱们会详细讲到梯度下降法的变种,他们主要的区别就是对样本的采用方法不同。这里我们采用的是用所有样本。
2 梯度下降法大家族
首先,我们来看一下,常见的梯度下降算法有:
- 全梯度下降算法(Full gradient descent),
- 随机梯度下降算法(Stochastic gradient descent),
- 小批量梯度下降算法(Mini-batch gradient descent),
- 随机平均梯度下降算法(Stochastic average gradient descent)
它们都是为了正确地调节权重向量,通过为每个权重计算一个梯度,从而更新权值,使目标函数尽可能最小化。其差别在于样本的使用方式不同。
2.1 全梯度下降算法(FG)
批量梯度下降法,是梯度下降法最常用的形式,具体做法也就是在更新参数时使用所有的样本来进行更新。
计算训练集所有样本误差,对其求和再取平均值作为目标函数。
权重向量沿其梯度相反的方向移动,从而使当前目标函数减少得最多。
其是在整个训练数据集上计算损失函数关于参数θ\thetaθ 的梯度:
由于我们有m个样本,这里求梯度的时候就用了所有m个样本的梯度数据。
注意:
因为在执行每次更新时,我们需要在整个数据集上计算所有的梯度,所以批梯度下降法的速度会很慢,同时,批梯度下降法无法处理超出内存容量限制的数据集。
批梯度下降法同样也不能在线更新模型,即在运行的过程中,不能增加新的样本。
2.2 随机梯度下降算法(SG)
由于FG每迭代更新一次权重都需要计算所有样本误差,而实际问题中经常有上亿的训练样本,故效率偏低,且容易陷入局部最优解,因此提出了随机梯度下降算法。
其每轮计算的目标函数不再是全体样本误差,而仅是单个样本误差,即每次只代入计算一个样本目标函数的梯度来更新权重,再取下一个样本重复此过程,直到损失函数值停止下降或损失函数值小于某个可以容忍的阈值。
此过程简单,高效,通常可以较好地避免更新迭代收敛到局部最优解。其迭代形式为
但是由于,SG每次只使用一个样本迭代,若遇上噪声则容易陷入局部最优解。
2.3 小批量梯度下降算法(mini-batch)
小批量梯度下降算法是FG和SG的折中方案,在一定程度上兼顾了以上两种方法的优点。
每次从训练样本集上随机抽取一个小样本集,在抽出来的小样本集上采用FG迭代更新权重。
被抽出的小样本集所含样本点的个数称为batch_size,通常设置为2的幂次方,更有利于GPU加速处理。
特别的,若batch_size=1,则变成了SG;若batch_size=n,则变成了FG.其迭代形式为
上式中,也就是我们从m个样本中,选择x个样本进行迭代(1<x<m),
2.4 随机平均梯度下降算法(SAG)
在SG方法中,虽然避开了运算成本大的问题,但对于大数据训练而言,SG效果常不尽如人意,因为每一轮梯度更新都完全与上一轮的数据和梯度无关。
随机平均梯度算法克服了这个问题,在内存中为每一个样本都维护一个旧的梯度,随机选择第i个样本来更新此样本的梯度,其他样本的梯度保持不变,然后求得所有梯度的平均值,进而更新了参数。
如此,每一轮更新仅需计算一个样本的梯度,计算成本等同于SG,但收敛速度快得多。
其迭代形式为:
θi=θi−αn(hθ(x0(j),x1(j),...xn(j))−yj)xi(j)\theta _i=\theta _i-\frac{\alpha }{n}(h_\theta (x^{(j)}_0,x^{(j)}_1,...x^{(j)}_n)-y_j)x_i^{(j)}θi=θi−nα(hθ(x0(j),x1(j),...xn(j))−yj)xi(j)
- 我们知道sgd是当前权重减去步长乘以梯度,得到新的权重。sag中的a,就是平均的意思,具体说,就是在第k步迭代的时候,我考虑的这一步和前面n-1个梯度的平均值,当前权重减去步长乘以最近n个梯度的平均值。
- n是自己设置的,当n=1的时候,就是普通的sgd。
- 这个想法非常的简单,在随机中又增加了确定性,类似于mini-batch sgd的作用,但不同的是,sag又没有去计算更多的样本,只是利用了之前计算出来的梯度,所以每次迭代的计算成本远小于mini-batch sgd,和sgd相当。效果而言,sag相对于sgd,收敛速度快了很多。这一点下面的论文中有具体的描述和证明。
- SAG论文链接:https://arxiv.org/pdf/1309.2388.pdf
拓展阅读:
- 梯度下降法算法比较和进一步优化
3 小结
- 全梯度下降算法(FG)【知道】
- 在进行计算的时候,计算所有样本的误差平均值,作为我的目标函数
- 随机梯度下降算法(SG)【知道】
- 每次只选择一个样本进行考核
- 小批量梯度下降算法(mini-batch)【知道】
- 选择一部分样本进行考核
- 随机平均梯度下降算法(SAG)【知道】
- 会给每个样本都维持一个平均值,后期计算的时候,参考这个平均值
相关文章:

梯度下降方法
2.5 梯度下降方法介绍 学习目标 掌握梯度下降法的推导过程知道全梯度下降算法的原理知道随机梯度下降算法的原理知道随机平均梯度下降算法的原理知道小批量梯度下降算法的原理 上一节中给大家介绍了最基本的梯度下降法实现流程,本节我们将进一步介绍梯度下降法的详细…...

web3与AI结合-Sahara AI 项目介绍
背景介绍 Sahara AI 于 2023 年创立,是一个 "区块链AI" 领域的项目。其项目愿景是,利用区块链和隐私技术将现有的 AI 商业模式去中心化,打造公平、透明、低门槛的 “协作 AI 经济” 体系,旨在重构新的利益分配机制以及…...

Nginx——反向代理(三/五)
目录 1.Nginx 反向代理1.1.Nginx 反向代理概述1.2.Nginx 反向代理的配置语法1.2.1.proxy_pass1.2.2.proxy_set_header1.2.3.proxy_redirect 1.3.Nginx 反向代理实战1.4.Nginx 的安全控制1.4.1.如何使用 SSL 对流量进行加密1.4.2.Nginx 添加 SSL 的支持1.4.3.Nginx 的 SSL 相关指…...

环动科技平均售价波动下滑:大客户依赖明显,应收账款周转率骤降
《港湾商业观察》施子夫 2024年12月18日,浙江环动机器人关节科技股份有限公司(以下简称,环动科技)的上市审核状态变更为“已问询”,公司在11月25日科创板IPO获上交所受理,独家保荐机构为广发证券。 此次环…...

源网荷储:构建智慧能源生态的关键方案设计
一、技术融合基石 多元能源采集技术:在 “源” 端,除了常见的光伏、风电、火电,生物质能发电、地热能利用技术也应纳入考量。例如在有丰富生物质原料的农村地区,小型生物质发电厂可实现废物利用与供电双赢;地热资源丰…...

进程间通讯
简介: 进程间通讯方式有: 1.内存映射(mmap): 使用mmap函数将磁盘空间映射到内存 2.管道 3.信号 4.套接字(socket) 5.信号机制 通过进程中kill函数,去给另一个函数发送信号&a…...

STM32-笔记33-OLED实验
实验目的 驱动 OLED 屏幕,显示点、线、字符、字符串、汉字、图片等内容。 项目实现-OLED通讯协议 复制项目文件19-串口打印功能 重命名为47-OLED实验 打开项目文件 加载文件 代码书写顺序: oled.c #include "oled.h"//初始化oled的gpio …...

低空管控技术-无人机云监视技术详解!
一、无人机监听技术的原理 无人机监听技术主要依赖于射频(RF)探测、光学和红外传感器等技术手段。这些技术通过被动监听和监测无人机与飞行员(或控制器)之间的通信链路传输,以确定无人机的位置,甚至在某些…...

RedisTemplate执行lua脚本及Lua 脚本语言详解
使用RedisTemplate执行lua脚本 在开发中,我们经常需要与Redis数据库进行交互,而Redis是一个基于内存的高性能键值存储数据库,它支持多种数据结构,并提供了丰富的命令接口。在某些情况下,我们可能需要执行一些复杂的逻…...

基于springboot的网上商城购物系统
作者:学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等 文末获取“源码数据库万字文档PPT”,支持远程部署调试、运行安装。 目录 项目包含: 开发说明: 系统功能: 项目截图…...

服务器攻击方式有哪几种?
随着互联网的快速发展,网络攻击事件频发,已泛滥成互联网行业的重病,受到了各个行业的关注与重视,因为它对网络安全乃至国家安全都形成了严重的威胁。面对复杂多样的网络攻击,想要有效防御就必须了解网络攻击的相关内容…...

【Unity3D】AB包加密(AssetBundle加密)
加密前: 加密后,直接无法加载ab,所以无法正常看到ab内容。 using UnityEngine; using UnityEditor; using System.IO; public static class AssetBundleDemoTest {[MenuItem("Tools/打包!")]public static void Build(){//注意:St…...

【FTP 协议】FTP主动模式
一、测试工具 服务器:FileZilla_Server-cn-0_9_60_2.exe 中文版本 客户端:FileZilla_3.66.5_win64 客户端IP: 192.168.9.186 服务端 IP: 192.168.9.161 在客户端请求PORT之前,抓包测试的结果跟被动模式流程相同。 二、客户端主动模式命令…...

十五、Vue 响应接口
文章目录 一、响应式系统基础什么是响应式系统响应式数据的声明与使用二、响应式原理深入Object.defineProperty () 方法的应用(Vue2)Proxy 对象的应用(Vue3)三、响应式接口之 ref 和 reactive(Vue3)ref 函数的使用reactive 函数的使用四、计算属性(computed)作为响应式…...

至强6搭配美光CZ122,证明CXL可以提高生成式AI的性能表现
最近发现了英特尔官网公布的一项最新测试报告,报告显示,将美光的CZ122 CXL内存模块放到英特尔至强6平台上,显著提升了HPC和AI工作负载的内存带宽,特别是在采用基于软件的交错配置(interleave configuration)…...

一文理解ssh,ssl协议以及应用
在使用基于密钥的认证方式的时候,私钥的位置一定要符合远程服务器规定的位置,否则找不到私钥的位置会导致建立ssh连接失败 SSH 全称是 “Secure Shell”,即安全外壳协议。 它是一种网络协议,用于在不安全的网络中安全地进行远程登…...

电子应用设计方案87:智能AI收纳箱系统设计
智能 AI 收纳箱系统设计 一、引言 智能 AI 收纳箱系统旨在为用户提供更高效、便捷和智能的物品收纳与管理解决方案,通过融合人工智能技术和创新设计,提升用户的生活品质和物品整理效率。 二、系统概述 1. 系统目标 - 实现物品的自动分类和整理…...

BloombergGPT: A Large Language Model for Finance——面向金融领域的大语言模型
这篇文章介绍了BloombergGPT,一个专门为金融领域设计的大语言模型(LLM)。以下是文章的主要内容总结: 背景与动机: 大语言模型(如GPT-3)在多个任务上表现出色,但尚未有针对金融领域的…...

LeetCode - #180 Swift 实现连续数字查询
文章目录 摘要描述SQL 解法Swift 题解代码Swift 题解代码分析核心逻辑关键函数 示例测试及结果测试 1测试 2 时间复杂度空间复杂度总结 摘要 本文将解决如何从日志数据中找出连续出现至少三次的数字。通过 SQL 查询语句结合 Swift 数据库操作,我们将完成这一任务。…...

为什么ip属地一会河南一会江苏
在使用互联网的过程中,许多用户可能会遇到这样一个问题:自己的IP属地一会儿显示为河南,一会儿又变成了江苏。这种现象可能会让人感到困惑,甚至产生疑虑,担心自己的网络活动是否受到了某种影响。为了解答这一疑问&#…...

Prompt Tuning、P-Tuning、Prefix Tuning的区别
一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...

【OSG学习笔记】Day 18: 碰撞检测与物理交互
物理引擎(Physics Engine) 物理引擎 是一种通过计算机模拟物理规律(如力学、碰撞、重力、流体动力学等)的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互,广泛应用于 游戏开发、动画制作、虚…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合
强化学习(Reinforcement Learning, RL)是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程,然后使用强化学习的Actor-Critic机制(中文译作“知行互动”机制),逐步迭代求解…...

Qt/C++开发监控GB28181系统/取流协议/同时支持udp/tcp被动/tcp主动
一、前言说明 在2011版本的gb28181协议中,拉取视频流只要求udp方式,从2016开始要求新增支持tcp被动和tcp主动两种方式,udp理论上会丢包的,所以实际使用过程可能会出现画面花屏的情况,而tcp肯定不丢包,起码…...

从WWDC看苹果产品发展的规律
WWDC 是苹果公司一年一度面向全球开发者的盛会,其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具,对过去十年 WWDC 主题演讲内容进行了系统化分析,形成了这份…...

基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容
基于 UniApp + WebSocket实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...

Module Federation 和 Native Federation 的比较
前言 Module Federation 是 Webpack 5 引入的微前端架构方案,允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...
Matlab | matlab常用命令总结
常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...
Device Mapper 机制
Device Mapper 机制详解 Device Mapper(简称 DM)是 Linux 内核中的一套通用块设备映射框架,为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程,并配以详细的…...
代理篇12|深入理解 Vite中的Proxy接口代理配置
在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...