当前位置: 首页 > news >正文

梯度下降方法

2.5 梯度下降方法介绍

学习目标

  • 掌握梯度下降法的推导过程
  • 知道全梯度下降算法的原理
  • 知道随机梯度下降算法的原理
  • 知道随机平均梯度下降算法的原理
  • 知道小批量梯度下降算法的原理

上一节中给大家介绍了最基本的梯度下降法实现流程,本节我们将进一步介绍梯度下降法的详细过算法推导过程常见的梯度下降算法

1 详解梯度下降算法

1.1梯度下降的相关概念复习

在详细了解梯度下降的算法之前,我们先复习相关的一些概念。

  • 步长(Learning rate):

    • 步长决定了在梯度下降迭代的过程中,每一步沿梯度负方向前进的长度。用前面下山的例子,步长就是在当前这一步所在位置沿着最陡峭最易下山的位置走的那一步的长度。
  • 特征(feature):

    • 指的是样本中输入部分,比如2个单特征的样本(x(0),y(0)),(x(1),y(1))(x^{(0)},y^{(0)}),(x^{(1)},y^{(1)})(x(0),y(0)),(x(1),y(1)),则第一个样本特征为x(0)x^{(0)}x(0),第一个样本输出为y(0)y^{(0)}y(0)
  • 假设函数(hypothesis function):

    • 在监督学习中,为了拟合输入样本,而使用的假设函数,记为hθ(x)h_\theta (x)hθ(x)比如对于单个特征的m个样本(x(i),y(i))(i=1,2,...m)(x^{(i)},y^{(i)})(i=1,2,...m)(x(i),y(i))(i=1,2,...m),可以采用拟合函数如下: hθ(x)=θ0+θ1xh_\theta (x)=\theta _0+\theta _1xhθ(x)=θ0+θ1x
  • 损失函数(loss function):

    • 为了评估模型拟合的好坏,通常用损失函数来度量拟合的程度。损失函数极小化,意味着拟合程度最好,对应的模型参数即为最优参数。
    • 在线性回归中,损失函数通常为样本输出和假设函数的差取平方。比如对于m个样本(xi,yi)(i=1,2,...m)(x_i,y_i)(i=1,2,...m)(xi,yi)(i=1,2,...m),采用线性回归,损失函数为:

在这里插入图片描述

其中xix_ixi表示第i个样本特征,yiy_iyi表示第i个样本对应的输出,hθ(xi)h_\theta (x_i)hθ(xi)为假设函数。

1.2 梯度下降法的推导流程

1) 先决条件: 确认优化模型的假设函数和损失函数。

比如对于线性回归,假设函数表示为 hθ(x1,x2,...,xn)=θ0+θ1x1+...+θnxnh_\theta (x_1,x_2,...,x_n)=\theta _0+\theta _1x_1+...+\theta _nx_nhθ(x1,x2,...,xn)=θ0+θ1x1+...+θnxn, 其中θi(i=0,1,2...n)\theta _i (i = 0,1,2... n)θi(i=0,1,2...n)为模型参数,xi(i=0,1,2...n)x_i (i = 0,1,2... n)xi(i=0,1,2...n)为每个样本的n个特征值。这个表示可以简化,我们增加一个特征x0=1x_0=1x0=1 ,这样

在这里插入图片描述

同样是线性回归,对应于上面的假设函数,损失函数为:

在这里插入图片描述

2) 算法相关参数初始化,

主要是初始化θ0,θ1...,θn\theta _0,\theta _1...,\theta _nθ0,θ1...,θn,算法终止距离ε以及步长α\alphaα 。在没有任何先验知识的时候,我喜欢将所有的θ\thetaθ 初始化为0, 将步长初始化为1。在调优的时候再 优化。

3) 算法过程:

3.1) 确定当前位置的损失函数的梯度,对于θi\theta _iθi,其梯度表达式如下:

在这里插入图片描述

3.2) 用步长乘以损失函数的梯度,得到当前位置下降的距离,即

在这里插入图片描述

对应于前面登山例子中的某一步。

3.3) 确定是否所有的θi\theta _iθi,梯度下降的距离都小于ε,如果小于ε则算法终止,当前所有的θi(i=0,1,...n)\theta _i(i=0,1,...n)θi(i=0,1,...n)即为最终结果。否则进入步骤4.

4)更新所有的θ\thetaθ ,对于θi\theta _iθi,其更新表达式如下。更新完毕后继续转入步骤1.

在这里插入图片描述


下面用线性回归的例子来具体描述梯度下降。假设我们的样本是:

在这里插入图片描述

损失函数如前面先决条件所述:

在这里插入图片描述

则在算法过程步骤1中对于θi\theta _iθi 的偏导数计算如下:

在这里插入图片描述

由于样本中没有x0x_0x0上式中令所有的x0jx_0^jx0j为1.

步骤4中θi\theta _iθi的更新表达式如下:

在这里插入图片描述

从这个例子可以看出当前点的梯度方向是由所有的样本决定的,加1m\frac{1}{m}m1 是为了好理解。由于步长也为常数,他们的乘积也为常数,所以这里α\alphaα¨NBSP;1m\frac{1}{m}m1 可以用一个常数表示。


在下面一节中,咱们会详细讲到梯度下降法的变种,他们主要的区别就是对样本的采用方法不同。这里我们采用的是用所有样本。

2 梯度下降法大家族

首先,我们来看一下,常见的梯度下降算法有:

  • 全梯度下降算法(Full gradient descent),
  • 随机梯度下降算法(Stochastic gradient descent),
  • 小批量梯度下降算法(Mini-batch gradient descent),
  • 随机平均梯度下降算法(Stochastic average gradient descent)

它们都是为了正确地调节权重向量,通过为每个权重计算一个梯度,从而更新权值,使目标函数尽可能最小化。其差别在于样本的使用方式不同。

2.1 全梯度下降算法(FG)

批量梯度下降法,是梯度下降法最常用的形式,具体做法也就是在更新参数时使用所有的样本来进行更新。

计算训练集所有样本误差对其求和再取平均值作为目标函数

权重向量沿其梯度相反的方向移动,从而使当前目标函数减少得最多。

其是在整个训练数据集上计算损失函数关于参数θ\thetaθ 的梯度:

在这里插入图片描述

由于我们有m个样本,这里求梯度的时候就用了所有m个样本的梯度数据。

注意:

  • 因为在执行每次更新时,我们需要在整个数据集上计算所有的梯度,所以批梯度下降法的速度会很慢,同时,批梯度下降法无法处理超出内存容量限制的数据集。

  • 批梯度下降法同样也不能在线更新模型,即在运行的过程中,不能增加新的样本

2.2 随机梯度下降算法(SG)

由于FG每迭代更新一次权重都需要计算所有样本误差,而实际问题中经常有上亿的训练样本,故效率偏低,且容易陷入局部最优解,因此提出了随机梯度下降算法。

其每轮计算的目标函数不再是全体样本误差,而仅是单个样本误差,即每次只代入计算一个样本目标函数的梯度来更新权重,再取下一个样本重复此过程,直到损失函数值停止下降或损失函数值小于某个可以容忍的阈值。

此过程简单,高效,通常可以较好地避免更新迭代收敛到局部最优解。其迭代形式为

在这里插入图片描述

但是由于,SG每次只使用一个样本迭代,若遇上噪声则容易陷入局部最优解。

2.3 小批量梯度下降算法(mini-batch)

小批量梯度下降算法是FG和SG的折中方案,在一定程度上兼顾了以上两种方法的优点。

每次从训练样本集上随机抽取一个小样本集,在抽出来的小样本集上采用FG迭代更新权重。

被抽出的小样本集所含样本点的个数称为batch_size,通常设置为2的幂次方,更有利于GPU加速处理。

特别的,若batch_size=1,则变成了SG;若batch_size=n,则变成了FG.其迭代形式为

在这里插入图片描述

上式中,也就是我们从m个样本中,选择x个样本进行迭代(1<x<m),

2.4 随机平均梯度下降算法(SAG)

在SG方法中,虽然避开了运算成本大的问题,但对于大数据训练而言,SG效果常不尽如人意,因为每一轮梯度更新都完全与上一轮的数据和梯度无关。

随机平均梯度算法克服了这个问题,在内存中为每一个样本都维护一个旧的梯度,随机选择第i个样本来更新此样本的梯度,其他样本的梯度保持不变,然后求得所有梯度的平均值,进而更新了参数。

如此,每一轮更新仅需计算一个样本的梯度,计算成本等同于SG,但收敛速度快得多。

其迭代形式为:

θi=θi−αn(hθ(x0(j),x1(j),...xn(j))−yj)xi(j)\theta _i=\theta _i-\frac{\alpha }{n}(h_\theta (x^{(j)}_0,x^{(j)}_1,...x^{(j)}_n)-y_j)x_i^{(j)}θi=θinα(hθ(x0(j),x1(j),...xn(j))yj)xi(j)

  • 我们知道sgd是当前权重减去步长乘以梯度,得到新的权重。sag中的a,就是平均的意思,具体说,就是在第k步迭代的时候,我考虑的这一步和前面n-1个梯度的平均值,当前权重减去步长乘以最近n个梯度的平均值。
  • n是自己设置的,当n=1的时候,就是普通的sgd。
  • 这个想法非常的简单,在随机中又增加了确定性,类似于mini-batch sgd的作用,但不同的是,sag又没有去计算更多的样本,只是利用了之前计算出来的梯度,所以每次迭代的计算成本远小于mini-batch sgd,和sgd相当。效果而言,sag相对于sgd,收敛速度快了很多。这一点下面的论文中有具体的描述和证明。
  • SAG论文链接:https://arxiv.org/pdf/1309.2388.pdf
  • 拓展阅读:

    • 梯度下降法算法比较和进一步优化

3 小结

  • 全梯度下降算法(FG)【知道】
    • 在进行计算的时候,计算所有样本的误差平均值,作为我的目标函数
  • 随机梯度下降算法(SG)【知道】
    • 每次只选择一个样本进行考核
  • 小批量梯度下降算法(mini-batch)【知道】
    • 选择一部分样本进行考核
  • 随机平均梯度下降算法(SAG)【知道】
    • 会给每个样本都维持一个平均值,后期计算的时候,参考这个平均值

相关文章:

梯度下降方法

2.5 梯度下降方法介绍 学习目标 掌握梯度下降法的推导过程知道全梯度下降算法的原理知道随机梯度下降算法的原理知道随机平均梯度下降算法的原理知道小批量梯度下降算法的原理 上一节中给大家介绍了最基本的梯度下降法实现流程&#xff0c;本节我们将进一步介绍梯度下降法的详细…...

web3与AI结合-Sahara AI 项目介绍

背景介绍 Sahara AI 于 2023 年创立&#xff0c;是一个 "区块链AI" 领域的项目。其项目愿景是&#xff0c;利用区块链和隐私技术将现有的 AI 商业模式去中心化&#xff0c;打造公平、透明、低门槛的 “协作 AI 经济” 体系&#xff0c;旨在重构新的利益分配机制以及…...

Nginx——反向代理(三/五)

目录 1.Nginx 反向代理1.1.Nginx 反向代理概述1.2.Nginx 反向代理的配置语法1.2.1.proxy_pass1.2.2.proxy_set_header1.2.3.proxy_redirect 1.3.Nginx 反向代理实战1.4.Nginx 的安全控制1.4.1.如何使用 SSL 对流量进行加密1.4.2.Nginx 添加 SSL 的支持1.4.3.Nginx 的 SSL 相关指…...

环动科技平均售价波动下滑:大客户依赖明显,应收账款周转率骤降

《港湾商业观察》施子夫 2024年12月18日&#xff0c;浙江环动机器人关节科技股份有限公司&#xff08;以下简称&#xff0c;环动科技&#xff09;的上市审核状态变更为“已问询”&#xff0c;公司在11月25日科创板IPO获上交所受理&#xff0c;独家保荐机构为广发证券。 此次环…...

源网荷储:构建智慧能源生态的关键方案设计

一、技术融合基石 多元能源采集技术&#xff1a;在 “源” 端&#xff0c;除了常见的光伏、风电、火电&#xff0c;生物质能发电、地热能利用技术也应纳入考量。例如在有丰富生物质原料的农村地区&#xff0c;小型生物质发电厂可实现废物利用与供电双赢&#xff1b;地热资源丰…...

进程间通讯

简介&#xff1a; 进程间通讯方式有&#xff1a; 1.内存映射&#xff08;mmap&#xff09;&#xff1a; 使用mmap函数将磁盘空间映射到内存 2.管道 3.信号 4.套接字&#xff08;socket&#xff09; 5.信号机制 通过进程中kill函数&#xff0c;去给另一个函数发送信号&a…...

STM32-笔记33-OLED实验

实验目的 驱动 OLED 屏幕&#xff0c;显示点、线、字符、字符串、汉字、图片等内容。 项目实现-OLED通讯协议 复制项目文件19-串口打印功能 重命名为47-OLED实验 打开项目文件 加载文件 代码书写顺序&#xff1a; oled.c #include "oled.h"//初始化oled的gpio …...

低空管控技术-无人机云监视技术详解!

一、无人机监听技术的原理 无人机监听技术主要依赖于射频&#xff08;RF&#xff09;探测、光学和红外传感器等技术手段。这些技术通过被动监听和监测无人机与飞行员&#xff08;或控制器&#xff09;之间的通信链路传输&#xff0c;以确定无人机的位置&#xff0c;甚至在某些…...

RedisTemplate执行lua脚本及Lua 脚本语言详解

使用RedisTemplate执行lua脚本 在开发中&#xff0c;我们经常需要与Redis数据库进行交互&#xff0c;而Redis是一个基于内存的高性能键值存储数据库&#xff0c;它支持多种数据结构&#xff0c;并提供了丰富的命令接口。在某些情况下&#xff0c;我们可能需要执行一些复杂的逻…...

基于springboot的网上商城购物系统

作者&#xff1a;学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等 文末获取“源码数据库万字文档PPT”&#xff0c;支持远程部署调试、运行安装。 目录 项目包含&#xff1a; 开发说明&#xff1a; 系统功能&#xff1a; 项目截图…...

服务器攻击方式有哪几种?

随着互联网的快速发展&#xff0c;网络攻击事件频发&#xff0c;已泛滥成互联网行业的重病&#xff0c;受到了各个行业的关注与重视&#xff0c;因为它对网络安全乃至国家安全都形成了严重的威胁。面对复杂多样的网络攻击&#xff0c;想要有效防御就必须了解网络攻击的相关内容…...

【Unity3D】AB包加密(AssetBundle加密)

加密前&#xff1a; 加密后&#xff0c;直接无法加载ab&#xff0c;所以无法正常看到ab内容。 using UnityEngine; using UnityEditor; using System.IO; public static class AssetBundleDemoTest {[MenuItem("Tools/打包!")]public static void Build(){//注意:St…...

【FTP 协议】FTP主动模式

一、测试工具 服务器&#xff1a;FileZilla_Server-cn-0_9_60_2.exe 中文版本 客户端&#xff1a;FileZilla_3.66.5_win64 客户端IP: 192.168.9.186 服务端 IP: 192.168.9.161 在客户端请求PORT之前&#xff0c;抓包测试的结果跟被动模式流程相同。 二、客户端主动模式命令…...

十五、Vue 响应接口

文章目录 一、响应式系统基础什么是响应式系统响应式数据的声明与使用二、响应式原理深入Object.defineProperty () 方法的应用(Vue2)Proxy 对象的应用(Vue3)三、响应式接口之 ref 和 reactive(Vue3)ref 函数的使用reactive 函数的使用四、计算属性(computed)作为响应式…...

至强6搭配美光CZ122,证明CXL可以提高生成式AI的性能表现

最近发现了英特尔官网公布的一项最新测试报告&#xff0c;报告显示&#xff0c;将美光的CZ122 CXL内存模块放到英特尔至强6平台上&#xff0c;显著提升了HPC和AI工作负载的内存带宽&#xff0c;特别是在采用基于软件的交错配置&#xff08;interleave configuration&#xff09…...

一文理解ssh,ssl协议以及应用

在使用基于密钥的认证方式的时候&#xff0c;私钥的位置一定要符合远程服务器规定的位置&#xff0c;否则找不到私钥的位置会导致建立ssh连接失败 SSH 全称是 “Secure Shell”&#xff0c;即安全外壳协议。 它是一种网络协议&#xff0c;用于在不安全的网络中安全地进行远程登…...

电子应用设计方案87:智能AI收纳箱系统设计

智能 AI 收纳箱系统设计 一、引言 智能 AI 收纳箱系统旨在为用户提供更高效、便捷和智能的物品收纳与管理解决方案&#xff0c;通过融合人工智能技术和创新设计&#xff0c;提升用户的生活品质和物品整理效率。 二、系统概述 1. 系统目标 - 实现物品的自动分类和整理&#xf…...

BloombergGPT: A Large Language Model for Finance——面向金融领域的大语言模型

这篇文章介绍了BloombergGPT&#xff0c;一个专门为金融领域设计的大语言模型&#xff08;LLM&#xff09;。以下是文章的主要内容总结&#xff1a; 背景与动机&#xff1a; 大语言模型&#xff08;如GPT-3&#xff09;在多个任务上表现出色&#xff0c;但尚未有针对金融领域的…...

LeetCode - #180 Swift 实现连续数字查询

文章目录 摘要描述SQL 解法Swift 题解代码Swift 题解代码分析核心逻辑关键函数 示例测试及结果测试 1测试 2 时间复杂度空间复杂度总结 摘要 本文将解决如何从日志数据中找出连续出现至少三次的数字。通过 SQL 查询语句结合 Swift 数据库操作&#xff0c;我们将完成这一任务。…...

为什么ip属地一会河南一会江苏

在使用互联网的过程中&#xff0c;许多用户可能会遇到这样一个问题&#xff1a;自己的IP属地一会儿显示为河南&#xff0c;一会儿又变成了江苏。这种现象可能会让人感到困惑&#xff0c;甚至产生疑虑&#xff0c;担心自己的网络活动是否受到了某种影响。为了解答这一疑问&#…...

React 第五十五节 Router 中 useAsyncError的使用详解

前言 useAsyncError 是 React Router v6.4 引入的一个钩子&#xff0c;用于处理异步操作&#xff08;如数据加载&#xff09;中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误&#xff1a;捕获在 loader 或 action 中发生的异步错误替…...

【JavaEE】-- HTTP

1. HTTP是什么&#xff1f; HTTP&#xff08;全称为"超文本传输协议"&#xff09;是一种应用非常广泛的应用层协议&#xff0c;HTTP是基于TCP协议的一种应用层协议。 应用层协议&#xff1a;是计算机网络协议栈中最高层的协议&#xff0c;它定义了运行在不同主机上…...

css3笔记 (1) 自用

outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size&#xff1a;0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格&#xff…...

OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 在 GPU 上对图像执行 均值漂移滤波&#xff08;Mean Shift Filtering&#xff09;&#xff0c;用于图像分割或平滑处理。 该函数将输入图像中的…...

分布式增量爬虫实现方案

之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面&#xff0c;避免重复抓取&#xff0c;以节省资源和时间。 在分布式环境下&#xff0c;增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路&#xff1a;将增量判…...

Typeerror: cannot read properties of undefined (reading ‘XXX‘)

最近需要在离线机器上运行软件&#xff0c;所以得把软件用docker打包起来&#xff0c;大部分功能都没问题&#xff0c;出了一个奇怪的事情。同样的代码&#xff0c;在本机上用vscode可以运行起来&#xff0c;但是打包之后在docker里出现了问题。使用的是dialog组件&#xff0c;…...

Mobile ALOHA全身模仿学习

一、题目 Mobile ALOHA&#xff1a;通过低成本全身远程操作学习双手移动操作 传统模仿学习&#xff08;Imitation Learning&#xff09;缺点&#xff1a;聚焦与桌面操作&#xff0c;缺乏通用任务所需的移动性和灵活性 本论文优点&#xff1a;&#xff08;1&#xff09;在ALOHA…...

MySQL JOIN 表过多的优化思路

当 MySQL 查询涉及大量表 JOIN 时&#xff0c;性能会显著下降。以下是优化思路和简易实现方法&#xff1a; 一、核心优化思路 减少 JOIN 数量 数据冗余&#xff1a;添加必要的冗余字段&#xff08;如订单表直接存储用户名&#xff09;合并表&#xff1a;将频繁关联的小表合并成…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

(一)单例模式

一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...