当前位置: 首页 > article >正文

吴恩达机器学习笔记复盘(六)梯度下降算法

简介

梯度下降(Gradient Descent)是一种常用的优化算法,广泛应用于机器学习、深度学习等领域,在这里是用于求J(w,b)局部最小值。

我自己觉得这样说有点过于抽象。换个直观点的说法就是,一个人站在了一座小土包上,这个人要去找周围的最低点,求这个局部最低点的数学过程,就是这个梯度下降算法。

基本原理

梯度下降的核心思想是基于函数的梯度信息来寻找函数的最小值。对于一个多元函数J(\theta),其中 \theta = (\theta_1, \theta_2, \cdots, \theta_n)是函数的参数向量,梯度 \nabla J(\theta)是一个向量,它的每个元素是函数J 对相应参数 \theta_i的偏导数 \frac{\partial J}{\partial \theta_i}

梯度的方向是函数在当前点上升最快的方向,那么负梯度方向就是函数下降最快的方向。算法通过不断地沿着负梯度方向更新参数,来逐步减小目标函数的值,直到达到一个局部最小值或全局最小值。

算法步骤

初始化参数

随机选择一个初始参数向量\theta^{(0)},它可以是一个随机的数值向量,也可以根据具体问题的先验知识进行初始化。

计算梯度

对于给定的参数\theta^{(t)}(t表示当前的迭代次数),计算目标函数J(\theta)在该点的梯度 \nabla J(\theta^{(t)})。这需要对目标函数进行求导,根据函数的具体形式使用相应的求导规则来计算每个参数的偏导数。

更新参数

根据计算得到的梯度,按照以下公式更新参数:\theta^{(t + 1)}=\theta^{(t)}-\alpha\nabla J(\theta^{(t)}),其中 \alpha 是学习率,它控制着每次更新的步长大小。学习率是一个重要的超参数,需要根据具体问题进行调整。

检查收敛条件

判断是否满足收敛条件,常见的收敛条件有:达到预设的最大迭代次数、目标函数的变化量小于某个阈值、参数的变化量小于某个阈值等。如果满足收敛条件,则停止迭代,输出当前的参数 \theta^{(t + 1)} 作为最优解;否则,返回步骤2继续迭代。

学习率的选择

学习率 \alpha决定了梯度下降算法的收敛速度和最终结果。如果学习率过大,可能会导致算法跳过最优解,甚至无法收敛;如果学习率过小,算法可能会收敛得非常缓慢,需要大量的迭代才能达到满意的结果。

为了选择合适的学习率,可以采用一些策略,如固定学习率、动态调整学习率(如随着迭代次数增加逐渐减小学习率)、使用自适应学习率算法(如Adagrad、Adadelta、RMSProp、Adam等,这些算法可以根据参数的更新情况自动调整学习率)。

梯度下降的变体

批量梯度下降(Batch Gradient Descent,BGD)

在每次更新参数时,使用整个训练数据集来计算梯度。优点是能够找到全局最优解的可能性较大,缺点是当训练数据集很大时,计算梯度的成本很高,导致训练速度慢。

随机梯度下降(Stochastic Gradient Descent,SGD)

每次更新参数时,随机选择一个训练样本,使用该样本的梯度来更新参数。优点是训练速度快,能够处理大规模数据集,缺点是由于每次只使用一个样本,梯度估计可能存在较大的噪声,导致收敛过程可能会有波动,不一定能准确地收敛到全局最优解。

小批量梯度下降(Mini - Batch Gradient Descent,MBGD)

结合了批量梯度下降和随机梯度下降的优点,每次更新参数时,使用一小部分训练样本(称为一个小批量)来计算梯度。小批量的大小通常在几十到几百之间。这种方法既能够利用小批量数据的统计信息来稳定梯度估计,又能够在一定程度上提高训练速度,是实际应用中最常用的梯度下降变体之一。

应用场景

梯度下降在机器学习和深度学习中有广泛的应用,例如在线性回归、逻辑回归、神经网络等模型的训练中,用于最小化损失函数,以找到最优的模型参数。通过不断地调整模型的参数,使得模型的预测结果与真实标签之间的差异最小化,从而提高模型的性能和泛化能力。在这里就是应用在J(w,b)函数上。

简单的代码示例

import numpy as np
import matplotlib.pyplot as pltdef gradient_descent(x, y, learning_rate, num_iterations):# 初始化参数m = 0  # 斜率b = 0  # 截距n = len(x)for iteration in range(num_iterations):# 计算预测值y_pred = m * x + b# 计算梯度dm = (-2 / n) * np.sum(x * (y - y_pred))db = (-2 / n) * np.sum(y - y_pred)# 更新参数m = m - learning_rate * dmb = b - learning_rate * dbreturn m, b# 生成一些示例数据
np.random.seed(0)
x = np.array([1, 2, 3, 4, 5])
y = np.array([5, 7, 9, 11, 13])# 设置超参数
learning_rate = 0.01
num_iterations = 1000# 运行梯度下降算法
m, b = gradient_descent(x, y, learning_rate, num_iterations)# 输出结果
print(f"斜率 m: {m}")
print(f"截距 b: {b}")# 绘制原始数据和拟合直线
plt.scatter(x, y, label='原始数据')
plt.plot(x, m * x + b, color='red', label='拟合直线')
plt.xlabel('x')
plt.ylabel('y')
plt.title('梯度下降线性回归')
plt.legend()
plt.show()

代码解释

gradient_descent` 函数

该函数实现了梯度下降算法的核心逻辑。它接受输入特征 `x`、目标值 `y`、学习率 `learning_rate` 和迭代次数 `num_iterations` 作为参数。在函数内部,首先初始化斜率 `m` 和截距 `b` 为 0,然后进行指定次数的迭代。在每次迭代中,计算预测值 `y_pred`,接着计算斜率和截距的梯度 `dm` 和 `db`,最后根据梯度更新斜率和截距。 (m对应w,b对应b)

示例数据生成

使用 `numpy` 生成了一些简单的示例数据 `x` 和 `y`,模拟线性关系。

设置超参数

设置学习率 `learning_rate` 为 0.01,迭代次数 `num_iterations` 为 1000。

运行梯度下降算法

调用 `gradient_descent` 函数,得到最优的斜率和截距。

输出结果和绘图

打印出最优的斜率和截距,并使用 `matplotlib` 绘制原始数据点和拟合直线,直观展示梯度下降算法的效果。

相关文章:

吴恩达机器学习笔记复盘(六)梯度下降算法

简介 梯度下降(Gradient Descent)是一种常用的优化算法,广泛应用于机器学习、深度学习等领域,在这里是用于求J(w,b)局部最小值。 我自己觉得这样说有点过于抽象。换个直观点的说法就是,一个人…...

【机器学习chp14 — 3】生成式模型—生成对抗网络GAN(超详细分析,易于理解,推导严谨,一文就够了)

目录 三、生成对抗网络 ( Generative Adversarial Networks,GAN ) 1、GAN的基本思想 (1)生成器与判别器的基本结构与演变 (2)“对抗”机制及名词由来 2、GAN训练的基本算法 (1)网络初始化与…...

机器人打磨控制技术

工具姿态调整运动 法线方向对齐运动:机器人实时调整工具姿态,使打磨工具的轴线与工件曲面的法线方向一致。例如,在球面打磨时,工具需始终垂直于球面切线。角度补偿运动:针对倾斜或不规则曲面,通过调整机器人…...

K8S学习之基础四十:K8S配置altermanager发送告警到钉钉群

配置altermanager发送告警到钉钉群 ​ 创建钉钉群,设置机器人助手(必须是管理员才能设置),获取webhook webhook: https://oapi.dingtalk.com/robot/send?access_token25bed933a52d69f192347b5be4b2193bc0b257a6d9ae68d81619e3ae3d93f7c6…...

Spring Boot + Spring Integration整合MQTT打造双向通信客户端

1. 概述 本文分两个章节讲解MQTT相关的知识,第一部份主要讲解MQTT的原理和相关配置,第二个章节主要讲和Spring boot的integration相结合代码的具体实现,如果想快速实现功能,可直接跳过第一章节查看第二章讲。 1.1 MQTT搭建 为了…...

Sampling – Model Context Protocol Specification

网页链接 https://spec.modelcontextprotocol.io/specification/draft/client/sampling/ 主要内容概述 该网页详细介绍了Model Context Protocol (MCP) 中的“Sampling”功能。Sampling允许服务器通过客户端请求语言模型(LLM)生成文本、音频或图像内容…...

Java 填充 PDF 模版

制作 PDF 模版 安装 OnlyOffice 从 OnlyOffice 官网下载 OnlyOffice Desktop,安装过程很简单,一路下一步即可。用 OnlyOffice 制作 PDF 模版(表单) 使用 OnlyOffice 表单设计器,制作表单,如下图 注意命名…...

前端项目中应该如何选择正确的图片格式

在前端项目中选择正确的图片格式是优化页面性能、提升用户体验的关键步骤之一。以下是常见图片格式的特点、适用场景及选择建议,帮助你在不同场景下做出最优决策: 一、常见图片格式对比 格式特点适用场景不适用场景JPEG- 有损压缩,文件小- 不…...

Vulnhub-dedecms织梦通关攻略

姿势一、通过文件管理器上传WebShell 第一步:进入后台,找到文件管理器上传木马文件 第二步:使用蚁剑进行连接 #文件地址 http://localhost/dedecms/shell.php 姿势二、修改模板⽂件拿WebShell 第一步:修改模板文件,删除…...

数据集获取

sklearn数据集 sklearn有四部分数据。其中sklearn的数据集有两部分真实的数据,一部分嵌入到了sklearn库中,即安装好sklearn后就自带了一部分数据,这些数据的规模比较小称为small toy datasets ,还有一部分数据是需要在网上下载的,sklearn提供了下载的api接口,这些数据规…...

实验12深度学习

实验12深度学习 一、实验目的 (1)理解并熟悉深度神经网络的工作原理; (2)熟悉常用的深度神经网络模型及其应用环境; (3)掌握Anaconda的安装和设置方法,进一步熟悉Jupyte…...

2024年消费者权益数据分析

📅 2024年315消费者权益数据分析 数据见:https://mp.weixin.qq.com/s/eV5GoionxhGpw7PunhOVnQ 一、引言 在数字化时代,消费者维权数据对于市场监管、商家诚信和行业发展具有重要价值。本文基于 2024年315平台线上投诉数据,采用数…...

零知识证明:区块链隐私保护的变革力量

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,…...

rag-给一篇几百页的pdf,如何从中找到关键信息并汇总出关系图

小思考 对pdf肯定要做模糊chunk,能用模型切分就用模型切分,不能用模型就用规则,规则要尽可能保存连续文本,特殊数据格式(图、表格)必须完整保存,必须能被捕捉到。这些独立的表格or图数据&#…...

Rust语言学习

Rust语言学习 通用编程概念所有权所有权引用和借用slice struct(结构体)定义并实例化一个结构体使用结构体方法语法 枚举 enums定义枚举match控制流运算符if let 简单控制流 使用包、Crate和模块管理不断增长的项目(模块系统)包和crate定义模块来控制作用…...

wordPress WooCommerce 本地文件包含漏洞复现(CVE-2025-1661)(附脚本)

免责申明: 本文所描述的漏洞及其复现步骤仅供网络安全研究与教育目的使用。任何人不得将本文提供的信息用于非法目的或未经授权的系统测试。作者不对任何由于使用本文信息而导致的直接或间接损害承担责任。如涉及侵权,请及时与我们联系,我们将尽快处理并删除相关内容。 0x0…...

【CSS文字渐变动画】

CSS文字渐变动画 HTML代码CSS代码效果图 HTML代码 <div class"title"><h1>今天是春分</h1><p>正是春天到来的日子&#xff0c;花都开了&#xff0c;小鸟也飞回来了&#xff0c;大山也绿了起来&#xff0c;空气也有点嫩嫩的气息了</p>…...

2021-06-15 C逆序存入数组的元素

缘由编程&#xff0c;逆序存入数组的元素_编程语言-CSDN问答 #define N 7 main() { static int a[N]{12,9,16,5,7,2,l},k,s; for(k0;k<N;k) Printf("%4d",a[k]);for (k0;k<N/2; k) {sa[k]; a[k]a[N-1-k]; a[N-1-k]s; } for (k0;k<N;k) Printf("%4…...

Qt 控件概述 QLabel

目录 QLabel显示类控件 label如何做到与窗口同步变化 边框 Frame QLabel显示类控件 ​​ ​​ textFormat &#xff1a;设置文件格式 ​ Pixmap &#xff1a;标签图片 label如何做到与窗口同步变化 Qt中对应用户的操作 &#xff1a; 事件和信号 拖拽窗口大小就会触发…...

k8s服务中userspace,iptables,和ipvs的比较

在 Kubernetes 中&#xff0c;kube-proxy 是负责实现服务负载均衡的组件。它支持三种代理模式&#xff1a;userspace、iptables 和 ipvs。这三种模式在性能、功能和复杂性上有所不同。以下是它们的详细比较&#xff1a; 1. Userspace 模式 Userspace 是 Kubernetes 最早支持的…...

Vue 渲染 LaTeX 公式 Markdown 库

&#x1f31f; 前言 欢迎来到我的技术小宇宙&#xff01;&#x1f30c; 这里不仅是我记录技术点滴的后花园&#xff0c;也是我分享学习心得和项目经验的乐园。&#x1f4da; 无论你是技术小白还是资深大牛&#xff0c;这里总有一些内容能触动你的好奇心。&#x1f50d; &#x…...

KMP-子串匹配算法-关键点理解

1.理解next[]数组的使用与来历 2.求解next[]数组 一、kmp算法的原理 首先观察暴力解法&#xff1a;假设主串为&#xff1a;abdxxabc&#xff0c;模式串为abxxabd。 暴力解法&#xff0c;就是对主串每个字符作为第一个字符&#xff0c;开始和模式串比较。 比如&#xff1a;从…...

网络原理之网络层、数据链路层

1. 网络层 1.1 IP协议 1.1.1 基本概念 主机: 配有IP地址,但是不进⾏路由控制的设备路由器: 即配有IP地址,⼜能进⾏路由控制节点: 主机和路由器的统称 1.1.2 协议头格式 说明&#xff1a; 4位版本号(version): 指定IP协议的版本,对于IPv4来说,就是4,对于IPv6来说,就是6 4位头…...

ssh 多重验证的好处:降低密钥长度,动态密码

ssh 多重验证的好处&#xff1a; 多重验证&#xff1a;可能要比单纯提高密钥长度&#xff0c;或密码的长度更好&#xff0c;可以获得更好的保证服务器安全的效果。降低密钥长度&#xff1a;可以提高 CPU运行时的有效速度&#xff0c;特别是在传输大文件、或传输低比特率视频时…...

版本控制器Git ,Gitee如何连接Linux Gitee和Github区别

&#x1f4d6; 示例场景 假设你和朋友在开发一个「在线笔记网站」&#xff0c;代码需要频繁修改和协作&#xff1a; 只用本地文件管理 每次修改后手动复制文件&#xff0c;命名为 v1.html、v2.html 问题&#xff1a;无法追踪具体改动内容&#xff1b;多人修改易冲突&#xff1…...

网站测速:提升用户体验的关键

在互联网飞速发展的今天&#xff0c;网站已成为企业展示形象、提供服务以及用户获取信息的重要平台。而网站的速度&#xff0c;如同高速公路的路况&#xff0c;直接影响着用户的访问体验和满意度。因此&#xff0c;网站测速成为了网站运营和维护中不可或缺的关键环节。 网站速…...

【动态规划篇】91. 解码方法

91. 解码方法 题目链接&#xff1a; 91. 解码方法 题目叙述&#xff1a; 一条包含字母 A-Z 的消息通过以下映射进行了 编码 &#xff1a; “1” -> ‘A’ “2” -> ‘B’ … “25” -> ‘Y’ “26” -> ‘Z’ 然而&#xff0c;在解码已编码的消息时&#xff0c;你…...

Python高级——类的知识

一、知识梳理&#xff1a; 二、货币场景搭建&#xff1a; 1&#xff09;代码展示&#xff1a; class RMB:count 0def __init__(self,yuan0,jiao0,fen0):self.__yuan yuanself.__jiao jiaoself.__fen fenRMB.count 1def __add__(self, other):temp RMB()temp.__yuan se…...

Python 编程题 第十一节:选择排序、插入排序、删除字符、目标移动、尾部的0

选择排序 假定第一个为最小的为已排序序列&#xff0c;与后面的比较&#xff0c;找到未排序序列中最小的后&#xff0c;交换位置&#xff0c;获得最小元素&#xff0c;依次往后 lst[1,14,25,31,21,13,6,8,14,9,7] def selection_sort(lst):for i in range(len(lst)):min_inde…...

resnet与densenet的比较

一、 ResNet&#xff08;残差网络&#xff09;和 DenseNet&#xff08;密集连接网络&#xff09; ResNet&#xff08;残差网络&#xff09;和 DenseNet&#xff08;密集连接网络&#xff09;都是深度学习中非常经典的卷积神经网络架构&#xff0c;它们在图像分类、目标检测等诸…...