当前位置: 首页 > news >正文

使用Numpy手工模拟梯度下降算法

代码

import numpy as np # Compute every step manually# Linear regression
# f = w * x # here : f = 2 * x
X = np.array([1, 2, 3, 4], dtype=np.float32)
Y = np.array([2, 4, 6, 8], dtype=np.float32)w = 0.0# model output
def forward(x):return w * x# loss = MSE
def loss(y, y_pred):return ((y_pred - y)**2).mean()# J = MSE = 1/N * (w*x - y)**2
# dJ/dw = 1/N * 2x(w*x - y)
def gradient(x, y, y_pred):return np.mean(2*x*(y_pred - y))print(f'Prediction before training: f(5) = {forward(5):.3f}')# Training
learning_rate = 0.01
n_iters = 20for epoch in range(n_iters):# predict = forward passy_pred = forward(X)# lossl = loss(Y, y_pred)# calculate gradientsdw = gradient(X, Y, y_pred)# update weightsw -= learning_rate * dwif epoch % 2 == 0:print(f'epoch {epoch+1}: w = {w:.3f}, loss = {l:.8f}')print(f'Prediction after training: f(5) = {forward(5):.3f}')

输出

Prediction before training: f(5) = 0.000
epoch 1: w = 0.300, loss = 30.00000000
epoch 3: w = 0.772, loss = 15.66018677
epoch 5: w = 1.113, loss = 8.17471600
epoch 7: w = 1.359, loss = 4.26725292
epoch 9: w = 1.537, loss = 2.22753215
epoch 11: w = 1.665, loss = 1.16278565
epoch 13: w = 1.758, loss = 0.60698175
epoch 15: w = 1.825, loss = 0.31684822
epoch 17: w = 1.874, loss = 0.16539653
epoch 19: w = 1.909, loss = 0.08633806
Prediction after training: f(5) = 9.612

代码步骤详细解释

让我们通过一步一步地代入具体值来解释为什么在给定的线性回归示例中,权重 w w w 逐渐接近真实值,并且损失函数的值持续减小。这个过程展示了梯度下降法如何通过逐步迭代更新模型的权重 w w w 来最小化损失函数。

初始设置

  • 真实函数为 f ( x ) = 2 x f(x) = 2x f(x)=2x,我们的目标是通过学习找到这个关系。
  • 初始权重 w = 0.0 w = 0.0 w=0.0
  • 学习率 α = 0.01 \alpha = 0.01 α=0.01
  • 输入 X = [ 1 , 2 , 3 , 4 ] X = [1, 2, 3, 4] X=[1,2,3,4],对应的真实输出 Y = [ 2 , 4 , 6 , 8 ] Y = [2, 4, 6, 8] Y=[2,4,6,8]

第一次迭代

  1. 前向传播:使用初始权重 w = 0.0 w = 0.0 w=0.0 进行预测,
    y pred = w × X = 0 × X = [ 0 , 0 , 0 , 0 ] y_{\text{pred}} = w \times X = 0 \times X = [0, 0, 0, 0] ypred=w×X=0×X=[0,0,0,0]

  2. 计算损失(MSE):损失函数为 J = MSE = 1 N ∑ ( y pred − Y ) 2 J=\text{MSE} = \frac{1}{N} \sum (y_{\text{pred}} - Y)^2 J=MSE=N1(ypredY)2
    MSE = 1 4 ( ( 0 − 2 ) 2 + ( 0 − 4 ) 2 + ( 0 − 6 ) 2 + ( 0 − 8 ) 2 ) = 30 \text{MSE} = \frac{1}{4} \left((0-2)^2 + (0-4)^2 + (0-6)^2 + (0-8)^2\right) = 30 MSE=41((02)2+(04)2+(06)2+(08)2)=30

  3. 计算梯度:梯度 d J d w = 1 N ∑ 2 x ( w × x − y ) \frac{dJ}{dw} = \frac{1}{N} \sum 2x (w \times x - y) dwdJ=N12x(w×xy)
    d J d w = 1 4 ∑ 2 X ( 0 × X − Y ) = 1 4 ∑ 2 X ( − Y ) \frac{dJ}{dw} = \frac{1}{4} \sum 2X (0 \times X - Y) = \frac{1}{4} \sum 2X (-Y) dwdJ=412X(0×XY)=412X(Y)
    d J d w = 1 4 × 2 × ( ( 1 × − 2 ) + ( 2 × − 4 ) + ( 3 × − 6 ) + ( 4 × − 8 ) ) = − 30 \frac{dJ}{dw} = \frac{1}{4} \times 2 \times ((1 \times -2) + (2 \times -4) + (3 \times -6) + (4 \times -8)) = -30 dwdJ=41×2×((1×2)+(2×4)+(3×6)+(4×8))=30

  4. 更新权重 w = w − α d J d w w = w - \alpha \frac{dJ}{dw} w=wαdwdJ
    w = 0.0 − 0.01 × ( − 30 ) = 0.3 w = 0.0 - 0.01 \times (-30) = 0.3 w=0.00.01×(30)=0.3

这个过程解释了第一次迭代后为什么 w w w 更新为 0.3 并且损失减少到 30。梯度 d J d w = − 30 \frac{dJ}{dw} = -30 dwdJ=30 指示了 w w w 需要增加来减少损失

推导梯度

对于给定的输入 X X X 和输出 Y Y Y,梯度的计算可以展开为:
d J d w = 1 N ∑ 2 x ( w × x − y ) \frac{dJ}{dw} = \frac{1}{N} \sum 2x (w \times x - y) dwdJ=N12x(w×xy)

代入第一次迭代的值,
d J d w = 1 4 × 2 × [ 1 × ( 0 × 1 − 2 ) + 2 × ( 0 × 2 − 4 ) + 3 × ( 0 × 3 − 6 ) + 4 × ( 0 × 4 − 8 ) ] \frac{dJ}{dw} = \frac{1}{4} \times 2 \times [1 \times (0 \times 1 - 2) + 2 \times (0 \times 2 - 4) + 3 \times (0 \times 3 - 6) + 4 \times (0 \times 4 - 8)] dwdJ=41×2×[1×(0×12)+2×(0×24)+3×(0×36)+4×(0×48)]
= 1 4 × 2 × [ − 2 − 8 − 18 − 32 ] = − 30 = \frac{1}{4} \times 2 \times [-2 -8 -18 -32] = -30 =41×2×[281832]=30

让我们通过代入具体值来详细展示线性回归示例中第二次迭代的推导过程。

第二次迭代的起点

  • 初始权重(从第一次迭代更新后): w = 0.3 w = 0.3 w=0.3
  • 学习率: α = 0.01 \alpha = 0.01 α=0.01
  • 输入: X = [ 1 , 2 , 3 , 4 ] X = [1, 2, 3, 4] X=[1,2,3,4]
  • 真实输出: Y = [ 2 , 4 , 6 , 8 ] Y = [2, 4, 6, 8] Y=[2,4,6,8]

前向传播

计算预测值 y pred y_{\text{pred}} ypred
y pred = w × X = 0.3 × [ 1 , 2 , 3 , 4 ] = [ 0.3 , 0.6 , 0.9 , 1.2 ] y_{\text{pred}} = w \times X = 0.3 \times [1, 2, 3, 4] = [0.3, 0.6, 0.9, 1.2] ypred=w×X=0.3×[1,2,3,4]=[0.3,0.6,0.9,1.2]

损失计算(MSE)

L = 1 N ∑ i = 1 N ( y pred , i − Y i ) 2 L = \frac{1}{N} \sum_{i=1}^{N} (y_{\text{pred}, i} - Y_i)^2 L=N1i=1N(ypred,iYi)2
L = 1 4 ( ( 0.3 − 2 ) 2 + ( 0.6 − 4 ) 2 + ( 0.9 − 6 ) 2 + ( 1.2 − 8 ) 2 ) L = \frac{1}{4} \left((0.3-2)^2 + (0.6-4)^2 + (0.9-6)^2 + (1.2-8)^2\right) L=41((0.32)2+(0.64)2+(0.96)2+(1.28)2)
L = 1 4 ( 2.89 + 11.56 + 26.01 + 46.24 ) L = \frac{1}{4} \left(2.89 + 11.56 + 26.01 + 46.24\right) L=41(2.89+11.56+26.01+46.24)
L = 1 4 × 86.7 = 21.675 L = \frac{1}{4} \times 86.7 = 21.675 L=41×86.7=21.675

梯度计算

d L d w = 1 N ∑ i = 1 N 2 x i ( w x i − Y i ) \frac{dL}{dw} = \frac{1}{N} \sum_{i=1}^{N} 2x_i (w x_i - Y_i) dwdL=N1i=1N2xi(wxiYi)
d L d w = 1 4 × 2 × [ 1 × ( 0.3 × 1 − 2 ) + 2 × ( 0.3 × 2 − 4 ) + 3 × ( 0.3 × 3 − 6 ) + 4 × ( 0.3 × 4 − 8 ) ] \frac{dL}{dw} = \frac{1}{4} \times 2 \times [1 \times (0.3 \times 1 - 2) + 2 \times (0.3 \times 2 - 4) + 3 \times (0.3 \times 3 - 6) + 4 \times (0.3 \times 4 - 8)] dwdL=41×2×[1×(0.3×12)+2×(0.3×24)+3×(0.3×36)+4×(0.3×48)]
d L d w = 1 4 × 2 × [ ( − 1.7 ) + ( − 7.4 ) + ( − 16.1 ) + ( − 28.8 ) ] \frac{dL}{dw} = \frac{1}{4} \times 2 \times [(-1.7) + (-7.4) + (-16.1) + (-28.8)] dwdL=41×2×[(1.7)+(7.4)+(16.1)+(28.8)]
d L d w = 1 4 × 2 × [ − 53.999 ] = − 27.0 \frac{dL}{dw} = \frac{1}{4} \times 2 \times [-53.999] = -27.0 dwdL=41×2×[53.999]=27.0

更新权重

使用梯度下降法更新 w w w
w = w − α × d L d w w = w - \alpha \times \frac{dL}{dw} w=wα×dwdL
w = 0.3 − 0.01 × ( − 25.5 ) = 0.3 + 0.255 = 0.555 w = 0.3 - 0.01 \times (-25.5) = 0.3 + 0.255 = 0.555 w=0.30.01×(25.5)=0.3+0.255=0.555

这一系列计算表明,在第二次迭代中,通过计算损失和梯度,并根据这个梯度更新权重,权重 w w w 从 0.3 更新到了 0.555。这一过程逐步将模型从初步猜测调整为更接近真实模型 f ( x ) = 2 x f(x) = 2x f(x)=2x 的参数,损失从 21.675 21.675 21.675 减少,显示了模型准确度的提高。

总结

通过不断重复这个过程(前向传播、损失计算、梯度计算、权重更新), w w w 逐步被调整,以最小化模型的总损失。每次迭代,梯度告诉我们如何调整 w w w 以减少损失,学习率 α \alpha α 控制了这个调整的步长。随着迭代的进行,模型预测 y pred y_{\text{pred}} ypred 会逐渐接近真实值 Y Y Y,损失函数值会持续减小,直至收敛到最小值或达到学习的终止条件。

为什么梯度方向表明了减少损失的方向?

第一轮迭代中,梯度 d J d w = − 30 \frac{dJ}{dw} = -30 dwdJ=30 指出了权重 w w w 需要增加以减少损失。这是因为在梯度下降法中,我们通过从当前权重中减去梯度乘以学习率(一个小的正数)来更新权重。如果梯度为负(如此例中的 − 30 -30 30),减去一个负数相当于向正方向(增加)调整权重。

在梯度下降法中,梯度 d J d w \frac{dJ}{dw} dwdJ 描述了损失函数 J J J 关于权重 w w w 的变化率。如果梯度为负,这意味着增加 w w w 可以减少损失 J J J;如果梯度为正,减少 w w w 可以减少损失。

具体来说:

  • 梯度为负( d J d w < 0 \frac{dJ}{dw} < 0 dwdJ<0:这意味着增加权重 w w w (向梯度的反方向移动)会导致损失 J J J 减小。因此,为了减少损失,我们需要增加 w w w
  • 权重更新公式 w = w − α d J d w w = w - \alpha \frac{dJ}{dw} w=wαdwdJ)中,当 d J d w \frac{dJ}{dw} dwdJ 为负时, w w w 的更新实际上会增加 w w w 的值。

在我们的例子中,通过这种方式更新 w w w(从 0.0 0.0 0.0 更新到 0.3 0.3 0.3),正是因为我们沿着减少损失的方向调整了 w w w使得模型的预测与真实值之间的差异减小了,进而损失函数值减少。这个过程在多次迭代后,逐渐使模型更加准确,最终找到一个能够最小化损失函数的 w w w 值。

相关文章:

使用Numpy手工模拟梯度下降算法

代码 import numpy as np # Compute every step manually# Linear regression # f w * x # here : f 2 * x X np.array([1, 2, 3, 4], dtypenp.float32) Y np.array([2, 4, 6, 8], dtypenp.float32)w 0.0# model output def forward(x):return w * x# loss MSE def loss…...

金融数据采集与风险管理:Open-Spider工具的应用与实践

一、项目介绍 在当今快速发展的金融行业中&#xff0c;新的金融产品和服务层出不穷&#xff0c;为银行业务带来了巨大的机遇和挑战。为了帮助银行员工更好地应对这些挑战&#xff0c;我们曾成功实施了一个创新的项目&#xff0c;该项目采用了先进的爬虫技术&#xff0c;通过ope…...

鸿蒙Harmony应用开发—ArkTS声明式开发(通用属性:动态属性设置)

动态设置组件的属性&#xff0c;支持开发者在属性设置时使用if/else语法&#xff0c;且根据需要使用多态样式设置属性。 说明&#xff1a; 从API Version 11开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 attributeModifier attributeMo…...

Vue class和style绑定:动态美化你的组件

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…...

[C++] Windows中字符串函数的种类

文章目录 C标准库函数VC CRT函数Win32 APILinux C标准库函数 #include || #include <string.h> || #include 都可以使用以下函数&#xff1a; char *strcpy(char *dest, const char *src) //将Src字符串拷贝到Dst字符串地址。没有目标内存大小检查&#xff0c;可能会导致…...

Django工具

一、分页器介绍 1.1、介绍 分页,就是当我们在页面中显示一些信息列表,内容过多,一个页面显示不完,需要分成多个页面进行显示时,使用的技术就是分页技术 在django项目中,一般是使用3种分页的技术: 自定义分页功能,所有的分页功能都是自己实现django的插件 django-pagin…...

vue ui Starting GUI 图形化配置web新项目

前言&#xff1a;在vue框架里面&#xff0c; 以往大家都是习惯用命令行 vue create 、vue init webpack创建新前端项目&#xff0c;而vue ui是一个可视化的图形界面&#xff0c;对于新手来说更加友好了&#xff0c;不但可以创建、管理、还可以更新vue项目&#xff0c;也可以下载…...

Unity InputField宽度自适应内容

在Unity中&#xff0c;InputField在我们输入内容时&#xff0c;只会显示适应初始宽度的最新内容&#xff0c;或者自定义长度内容。 那么&#xff0c;要实现宽度自适应内容就需要另寻他法了。 以下是通过一个控制脚本来实现的一个简单方法。 直接上脚本&#xff1a; using S…...

加快代码审查的 7 个最佳实践

目录 前言 1-保持小的拉取请求 2-使用拉取请求模板 3-实施响应时间 SLA 4-培训初级和中级工程师 5-设置持续集成管道 6-使用拉取请求审查应用程序 7-生成图表以可视化您的代码更改 前言 代码审查可能会很痛苦软件工程师经常抱怨审查过程缓慢&#xff0c;延迟下游任务&…...

C++读写Excel(xlnt库的使用)

一、简介 官网&#xff1a;https://github.com/tfussell/xlnt Cross-platform user-friendly xlsx library for C11 xlnt is a modern C library for manipulating spreadsheets in memory and reading/writing them from/to XLSX files as described in ECMA 376 4th edition…...

【工具】conda常用命令

Conda 是一个流行的包管理器和环境管理器&#xff0c;用于安装、部署和管理软件包及其依赖项。 创建环境&#xff1a; conda create --name myenv 这将创建一个名为 myenv 的新环境。 激活环境&#xff1a; conda activate myenv 这会激活名为 myenv 的环境。在 Windows 上&am…...

Dockerfile编写实践篇

Docker通过一种打包和分发的软件&#xff0c;完成传统容器的封装。这个用来充当容器分发角色的组件被称为镜像。Docker镜像是一个容器中运行程序的所有文件的捆绑快照。当使用Docker分发软件&#xff0c;其实就是分发这些镜像&#xff0c;并在接收的机器上创建容器。镜像在Dock…...

BJFU|计算机网络缩写对照表

之前有过这个题型&#xff0c;但23年没考&#xff0c;所以按需准备 A ACK (ACKnowledgement) 确认 ADSL (Asymmetric Digital Subscriber Line) 非对称数字用户线 API (Applicatin Programming Interface) 应用编程接口 ARP (Address Resolution Protocol) 地址解析协议 ARQ (…...

Grafana dashboards as ConfigMaps

文章目录 1. 简介2. 创建 configmaps3. grafana 界面查看 1. 简介 将 Grafana 仪表板存储为 Kubernetes ConfigMap 相比传统的通过 Grafana 界面导入仪表板有以下一些主要优点: 版本控制&#xff1a; ConfigMap 可以存储在版本控制系统(如Git)中,便于跟踪和管理仪表板的变更历…...

【QA-SYSTEMS】CANTATA-解决Jenkins中build Cantata报错

【更多软件使用问题请点击亿道电子官方网站查询】 1、 文档目标 解决Jenkins中build Cantata测试项目报找不到license server的错误。 2、 问题场景 在Jenkins中build Cantata测试项目&#xff0c;报错“Failed to figure out the license server correctly”。 3、软硬件环…...

个人网站展示(静态)

大学期间做了一个个人博客网站&#xff0c;纯H5编码的网站&#xff0c;利用php搭建了一个留言模块。 有需要源码的同学&#xff0c;可以联系我~ 首页&#xff1a; IT杂记模块 文人墨客模块 劳有所获模块 生活日志模块 关于我 一个推崇全栈开发的前端开发人员 微信: itrzzh …...

C++——内存管理、模板

一、C内存管理 在C语言中我们曾学习过动态内存管理的相关知识&#xff0c;通过malloc、calloc、realloc和free等对堆上的空间进行申请和释放。在C中我们同样会面临类似的需求&#xff0c;因此C对动态开辟内存的方式进行了一些调整&#xff0c;我们可以使用new和delete操作符来对…...

商品上传上货搬家使用1688商品采集api接口

1688.item_get 公共参数 名称类型必须描述keyString是调用key&#xff08;必须以GET方式拼接在URL中&#xff09;secretString是调用密钥api_nameString是API接口名称&#xff08;包括在请求地址中&#xff09;[item_search,item_get,item_search_shop等]cacheString否[yes,no…...

redisson解决redis服务器的主从一致性问题

redisson解决redis的主节点和从节点一致性的问题。从而解决锁被错误获取的情况。 实际开发中我们会搭建多台redis服务器&#xff0c;但这些服务器分主次&#xff0c;主服务器负责处理写的操作&#xff08;增删改&#xff09;&#xff0c;从服务器负责处理读的操作&#xff0c;…...

Vue-router

router的使用&#xff08;52&#xff09; 5个基础步骤&#xff1a; 1.在终端执行yarn add vue-router3.6.5&#xff0c;安装router插件 yarn add vue-router3.6.5 2.在文件的main.js中引入router插件 import VueRouter from vue-router 3.在main.js中安装注册Vue.use(Vue…...

微信小程序之bind和catch

这两个呢&#xff0c;都是绑定事件用的&#xff0c;具体使用有些小区别。 官方文档&#xff1a; 事件冒泡处理不同 bind&#xff1a;绑定的事件会向上冒泡&#xff0c;即触发当前组件的事件后&#xff0c;还会继续触发父组件的相同事件。例如&#xff0c;有一个子视图绑定了b…...

将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?

Otsu 是一种自动阈值化方法&#xff0c;用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理&#xff0c;能够自动确定一个阈值&#xff0c;将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成&#xff0c;核心是利用 HTTP 协议的 Range 请求头指定下载范围&#xff1a; 实现原理 Range 请求头&#xff1a;向服务器请求文件的特定字节范围&#xff08;如 Range: bytes1024-&#xff09; 本地文件记录&#xff1a;保存已…...

【Oracle】分区表

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

MySQL JOIN 表过多的优化思路

当 MySQL 查询涉及大量表 JOIN 时&#xff0c;性能会显著下降。以下是优化思路和简易实现方法&#xff1a; 一、核心优化思路 减少 JOIN 数量 数据冗余&#xff1a;添加必要的冗余字段&#xff08;如订单表直接存储用户名&#xff09;合并表&#xff1a;将频繁关联的小表合并成…...

计算机基础知识解析:从应用到架构的全面拆解

目录 前言 1、 计算机的应用领域&#xff1a;无处不在的数字助手 2、 计算机的进化史&#xff1a;从算盘到量子计算 3、计算机的分类&#xff1a;不止 “台式机和笔记本” 4、计算机的组件&#xff1a;硬件与软件的协同 4.1 硬件&#xff1a;五大核心部件 4.2 软件&#…...

【LeetCode】3309. 连接二进制表示可形成的最大数值(递归|回溯|位运算)

LeetCode 3309. 连接二进制表示可形成的最大数值&#xff08;中等&#xff09; 题目描述解题思路Java代码 题目描述 题目链接&#xff1a;LeetCode 3309. 连接二进制表示可形成的最大数值&#xff08;中等&#xff09; 给你一个长度为 3 的整数数组 nums。 现以某种顺序 连接…...

tomcat指定使用的jdk版本

说明 有时候需要对tomcat配置指定的jdk版本号&#xff0c;此时&#xff0c;我们可以通过以下方式进行配置 设置方式 找到tomcat的bin目录中的setclasspath.bat。如果是linux系统则是setclasspath.sh set JAVA_HOMEC:\Program Files\Java\jdk8 set JRE_HOMEC:\Program Files…...

人工智能 - 在Dify、Coze、n8n、FastGPT和RAGFlow之间做出技术选型

在Dify、Coze、n8n、FastGPT和RAGFlow之间做出技术选型。这些平台各有侧重&#xff0c;适用场景差异显著。下面我将从核心功能定位、典型应用场景、真实体验痛点、选型决策关键点进行拆解&#xff0c;并提供具体场景下的推荐方案。 一、核心功能定位速览 平台核心定位技术栈亮…...

Windows电脑能装鸿蒙吗_Windows电脑体验鸿蒙电脑操作系统教程

鸿蒙电脑版操作系统来了&#xff0c;很多小伙伴想体验鸿蒙电脑版操作系统&#xff0c;可惜&#xff0c;鸿蒙系统并不支持你正在使用的传统的电脑来安装。不过可以通过可以使用华为官方提供的虚拟机&#xff0c;来体验大家心心念念的鸿蒙系统啦&#xff01;注意&#xff1a;虚拟…...