使用Numpy手工模拟梯度下降算法
代码
import numpy as np # Compute every step manually# Linear regression
# f = w * x # here : f = 2 * x
X = np.array([1, 2, 3, 4], dtype=np.float32)
Y = np.array([2, 4, 6, 8], dtype=np.float32)w = 0.0# model output
def forward(x):return w * x# loss = MSE
def loss(y, y_pred):return ((y_pred - y)**2).mean()# J = MSE = 1/N * (w*x - y)**2
# dJ/dw = 1/N * 2x(w*x - y)
def gradient(x, y, y_pred):return np.mean(2*x*(y_pred - y))print(f'Prediction before training: f(5) = {forward(5):.3f}')# Training
learning_rate = 0.01
n_iters = 20for epoch in range(n_iters):# predict = forward passy_pred = forward(X)# lossl = loss(Y, y_pred)# calculate gradientsdw = gradient(X, Y, y_pred)# update weightsw -= learning_rate * dwif epoch % 2 == 0:print(f'epoch {epoch+1}: w = {w:.3f}, loss = {l:.8f}')print(f'Prediction after training: f(5) = {forward(5):.3f}')
输出
Prediction before training: f(5) = 0.000
epoch 1: w = 0.300, loss = 30.00000000
epoch 3: w = 0.772, loss = 15.66018677
epoch 5: w = 1.113, loss = 8.17471600
epoch 7: w = 1.359, loss = 4.26725292
epoch 9: w = 1.537, loss = 2.22753215
epoch 11: w = 1.665, loss = 1.16278565
epoch 13: w = 1.758, loss = 0.60698175
epoch 15: w = 1.825, loss = 0.31684822
epoch 17: w = 1.874, loss = 0.16539653
epoch 19: w = 1.909, loss = 0.08633806
Prediction after training: f(5) = 9.612
代码步骤详细解释
让我们通过一步一步地代入具体值来解释为什么在给定的线性回归示例中,权重 w w w 逐渐接近真实值,并且损失函数的值持续减小。这个过程展示了梯度下降法如何通过逐步迭代更新模型的权重 w w w 来最小化损失函数。
初始设置
- 真实函数为 f ( x ) = 2 x f(x) = 2x f(x)=2x,我们的目标是通过学习找到这个关系。
- 初始权重 w = 0.0 w = 0.0 w=0.0。
- 学习率 α = 0.01 \alpha = 0.01 α=0.01。
- 输入 X = [ 1 , 2 , 3 , 4 ] X = [1, 2, 3, 4] X=[1,2,3,4],对应的真实输出 Y = [ 2 , 4 , 6 , 8 ] Y = [2, 4, 6, 8] Y=[2,4,6,8]。
第一次迭代
-
前向传播:使用初始权重 w = 0.0 w = 0.0 w=0.0 进行预测,
y pred = w × X = 0 × X = [ 0 , 0 , 0 , 0 ] y_{\text{pred}} = w \times X = 0 \times X = [0, 0, 0, 0] ypred=w×X=0×X=[0,0,0,0] -
计算损失(MSE):损失函数为 J = MSE = 1 N ∑ ( y pred − Y ) 2 J=\text{MSE} = \frac{1}{N} \sum (y_{\text{pred}} - Y)^2 J=MSE=N1∑(ypred−Y)2,
MSE = 1 4 ( ( 0 − 2 ) 2 + ( 0 − 4 ) 2 + ( 0 − 6 ) 2 + ( 0 − 8 ) 2 ) = 30 \text{MSE} = \frac{1}{4} \left((0-2)^2 + (0-4)^2 + (0-6)^2 + (0-8)^2\right) = 30 MSE=41((0−2)2+(0−4)2+(0−6)2+(0−8)2)=30 -
计算梯度:梯度 d J d w = 1 N ∑ 2 x ( w × x − y ) \frac{dJ}{dw} = \frac{1}{N} \sum 2x (w \times x - y) dwdJ=N1∑2x(w×x−y),
d J d w = 1 4 ∑ 2 X ( 0 × X − Y ) = 1 4 ∑ 2 X ( − Y ) \frac{dJ}{dw} = \frac{1}{4} \sum 2X (0 \times X - Y) = \frac{1}{4} \sum 2X (-Y) dwdJ=41∑2X(0×X−Y)=41∑2X(−Y)
d J d w = 1 4 × 2 × ( ( 1 × − 2 ) + ( 2 × − 4 ) + ( 3 × − 6 ) + ( 4 × − 8 ) ) = − 30 \frac{dJ}{dw} = \frac{1}{4} \times 2 \times ((1 \times -2) + (2 \times -4) + (3 \times -6) + (4 \times -8)) = -30 dwdJ=41×2×((1×−2)+(2×−4)+(3×−6)+(4×−8))=−30 -
更新权重: w = w − α d J d w w = w - \alpha \frac{dJ}{dw} w=w−αdwdJ,
w = 0.0 − 0.01 × ( − 30 ) = 0.3 w = 0.0 - 0.01 \times (-30) = 0.3 w=0.0−0.01×(−30)=0.3
这个过程解释了第一次迭代后为什么 w w w 更新为 0.3 并且损失减少到 30。梯度 d J d w = − 30 \frac{dJ}{dw} = -30 dwdJ=−30 指示了 w w w 需要增加来减少损失。
推导梯度
对于给定的输入 X X X 和输出 Y Y Y,梯度的计算可以展开为:
d J d w = 1 N ∑ 2 x ( w × x − y ) \frac{dJ}{dw} = \frac{1}{N} \sum 2x (w \times x - y) dwdJ=N1∑2x(w×x−y)
代入第一次迭代的值,
d J d w = 1 4 × 2 × [ 1 × ( 0 × 1 − 2 ) + 2 × ( 0 × 2 − 4 ) + 3 × ( 0 × 3 − 6 ) + 4 × ( 0 × 4 − 8 ) ] \frac{dJ}{dw} = \frac{1}{4} \times 2 \times [1 \times (0 \times 1 - 2) + 2 \times (0 \times 2 - 4) + 3 \times (0 \times 3 - 6) + 4 \times (0 \times 4 - 8)] dwdJ=41×2×[1×(0×1−2)+2×(0×2−4)+3×(0×3−6)+4×(0×4−8)]
= 1 4 × 2 × [ − 2 − 8 − 18 − 32 ] = − 30 = \frac{1}{4} \times 2 \times [-2 -8 -18 -32] = -30 =41×2×[−2−8−18−32]=−30
让我们通过代入具体值来详细展示线性回归示例中第二次迭代的推导过程。
第二次迭代的起点
- 初始权重(从第一次迭代更新后): w = 0.3 w = 0.3 w=0.3
- 学习率: α = 0.01 \alpha = 0.01 α=0.01
- 输入: X = [ 1 , 2 , 3 , 4 ] X = [1, 2, 3, 4] X=[1,2,3,4]
- 真实输出: Y = [ 2 , 4 , 6 , 8 ] Y = [2, 4, 6, 8] Y=[2,4,6,8]
前向传播
计算预测值 y pred y_{\text{pred}} ypred:
y pred = w × X = 0.3 × [ 1 , 2 , 3 , 4 ] = [ 0.3 , 0.6 , 0.9 , 1.2 ] y_{\text{pred}} = w \times X = 0.3 \times [1, 2, 3, 4] = [0.3, 0.6, 0.9, 1.2] ypred=w×X=0.3×[1,2,3,4]=[0.3,0.6,0.9,1.2]
损失计算(MSE)
L = 1 N ∑ i = 1 N ( y pred , i − Y i ) 2 L = \frac{1}{N} \sum_{i=1}^{N} (y_{\text{pred}, i} - Y_i)^2 L=N1i=1∑N(ypred,i−Yi)2
L = 1 4 ( ( 0.3 − 2 ) 2 + ( 0.6 − 4 ) 2 + ( 0.9 − 6 ) 2 + ( 1.2 − 8 ) 2 ) L = \frac{1}{4} \left((0.3-2)^2 + (0.6-4)^2 + (0.9-6)^2 + (1.2-8)^2\right) L=41((0.3−2)2+(0.6−4)2+(0.9−6)2+(1.2−8)2)
L = 1 4 ( 2.89 + 11.56 + 26.01 + 46.24 ) L = \frac{1}{4} \left(2.89 + 11.56 + 26.01 + 46.24\right) L=41(2.89+11.56+26.01+46.24)
L = 1 4 × 86.7 = 21.675 L = \frac{1}{4} \times 86.7 = 21.675 L=41×86.7=21.675
梯度计算
d L d w = 1 N ∑ i = 1 N 2 x i ( w x i − Y i ) \frac{dL}{dw} = \frac{1}{N} \sum_{i=1}^{N} 2x_i (w x_i - Y_i) dwdL=N1i=1∑N2xi(wxi−Yi)
d L d w = 1 4 × 2 × [ 1 × ( 0.3 × 1 − 2 ) + 2 × ( 0.3 × 2 − 4 ) + 3 × ( 0.3 × 3 − 6 ) + 4 × ( 0.3 × 4 − 8 ) ] \frac{dL}{dw} = \frac{1}{4} \times 2 \times [1 \times (0.3 \times 1 - 2) + 2 \times (0.3 \times 2 - 4) + 3 \times (0.3 \times 3 - 6) + 4 \times (0.3 \times 4 - 8)] dwdL=41×2×[1×(0.3×1−2)+2×(0.3×2−4)+3×(0.3×3−6)+4×(0.3×4−8)]
d L d w = 1 4 × 2 × [ ( − 1.7 ) + ( − 7.4 ) + ( − 16.1 ) + ( − 28.8 ) ] \frac{dL}{dw} = \frac{1}{4} \times 2 \times [(-1.7) + (-7.4) + (-16.1) + (-28.8)] dwdL=41×2×[(−1.7)+(−7.4)+(−16.1)+(−28.8)]
d L d w = 1 4 × 2 × [ − 53.999 ] = − 27.0 \frac{dL}{dw} = \frac{1}{4} \times 2 \times [-53.999] = -27.0 dwdL=41×2×[−53.999]=−27.0
更新权重
使用梯度下降法更新 w w w:
w = w − α × d L d w w = w - \alpha \times \frac{dL}{dw} w=w−α×dwdL
w = 0.3 − 0.01 × ( − 25.5 ) = 0.3 + 0.255 = 0.555 w = 0.3 - 0.01 \times (-25.5) = 0.3 + 0.255 = 0.555 w=0.3−0.01×(−25.5)=0.3+0.255=0.555
这一系列计算表明,在第二次迭代中,通过计算损失和梯度,并根据这个梯度更新权重,权重 w w w 从 0.3 更新到了 0.555。这一过程逐步将模型从初步猜测调整为更接近真实模型 f ( x ) = 2 x f(x) = 2x f(x)=2x 的参数,损失从 21.675 21.675 21.675 减少,显示了模型准确度的提高。
总结
通过不断重复这个过程(前向传播、损失计算、梯度计算、权重更新), w w w 逐步被调整,以最小化模型的总损失。每次迭代,梯度告诉我们如何调整 w w w 以减少损失,学习率 α \alpha α 控制了这个调整的步长。随着迭代的进行,模型预测 y pred y_{\text{pred}} ypred 会逐渐接近真实值 Y Y Y,损失函数值会持续减小,直至收敛到最小值或达到学习的终止条件。
为什么梯度方向表明了减少损失的方向?
第一轮迭代中,梯度 d J d w = − 30 \frac{dJ}{dw} = -30 dwdJ=−30 指出了权重 w w w 需要增加以减少损失。这是因为在梯度下降法中,我们通过从当前权重中减去梯度乘以学习率(一个小的正数)来更新权重。如果梯度为负(如此例中的 − 30 -30 −30),减去一个负数相当于向正方向(增加)调整权重。
在梯度下降法中,梯度 d J d w \frac{dJ}{dw} dwdJ 描述了损失函数 J J J 关于权重 w w w 的变化率。如果梯度为负,这意味着增加 w w w 可以减少损失 J J J;如果梯度为正,减少 w w w 可以减少损失。
具体来说:
- 梯度为负( d J d w < 0 \frac{dJ}{dw} < 0 dwdJ<0):这意味着增加权重 w w w (向梯度的反方向移动)会导致损失 J J J 减小。因此,为了减少损失,我们需要增加 w w w。
- 权重更新公式( w = w − α d J d w w = w - \alpha \frac{dJ}{dw} w=w−αdwdJ)中,当 d J d w \frac{dJ}{dw} dwdJ 为负时, w w w 的更新实际上会增加 w w w 的值。
在我们的例子中,通过这种方式更新 w w w(从 0.0 0.0 0.0 更新到 0.3 0.3 0.3),正是因为我们沿着减少损失的方向调整了 w w w,使得模型的预测与真实值之间的差异减小了,进而损失函数值减少。这个过程在多次迭代后,逐渐使模型更加准确,最终找到一个能够最小化损失函数的 w w w 值。
相关文章:
使用Numpy手工模拟梯度下降算法
代码 import numpy as np # Compute every step manually# Linear regression # f w * x # here : f 2 * x X np.array([1, 2, 3, 4], dtypenp.float32) Y np.array([2, 4, 6, 8], dtypenp.float32)w 0.0# model output def forward(x):return w * x# loss MSE def loss…...
金融数据采集与风险管理:Open-Spider工具的应用与实践
一、项目介绍 在当今快速发展的金融行业中,新的金融产品和服务层出不穷,为银行业务带来了巨大的机遇和挑战。为了帮助银行员工更好地应对这些挑战,我们曾成功实施了一个创新的项目,该项目采用了先进的爬虫技术,通过ope…...
鸿蒙Harmony应用开发—ArkTS声明式开发(通用属性:动态属性设置)
动态设置组件的属性,支持开发者在属性设置时使用if/else语法,且根据需要使用多态样式设置属性。 说明: 从API Version 11开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 attributeModifier attributeMo…...
Vue class和style绑定:动态美化你的组件
🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…...
[C++] Windows中字符串函数的种类
文章目录 C标准库函数VC CRT函数Win32 APILinux C标准库函数 #include || #include <string.h> || #include 都可以使用以下函数: char *strcpy(char *dest, const char *src) //将Src字符串拷贝到Dst字符串地址。没有目标内存大小检查,可能会导致…...
Django工具
一、分页器介绍 1.1、介绍 分页,就是当我们在页面中显示一些信息列表,内容过多,一个页面显示不完,需要分成多个页面进行显示时,使用的技术就是分页技术 在django项目中,一般是使用3种分页的技术: 自定义分页功能,所有的分页功能都是自己实现django的插件 django-pagin…...
vue ui Starting GUI 图形化配置web新项目
前言:在vue框架里面, 以往大家都是习惯用命令行 vue create 、vue init webpack创建新前端项目,而vue ui是一个可视化的图形界面,对于新手来说更加友好了,不但可以创建、管理、还可以更新vue项目,也可以下载…...
Unity InputField宽度自适应内容
在Unity中,InputField在我们输入内容时,只会显示适应初始宽度的最新内容,或者自定义长度内容。 那么,要实现宽度自适应内容就需要另寻他法了。 以下是通过一个控制脚本来实现的一个简单方法。 直接上脚本: using S…...
加快代码审查的 7 个最佳实践
目录 前言 1-保持小的拉取请求 2-使用拉取请求模板 3-实施响应时间 SLA 4-培训初级和中级工程师 5-设置持续集成管道 6-使用拉取请求审查应用程序 7-生成图表以可视化您的代码更改 前言 代码审查可能会很痛苦软件工程师经常抱怨审查过程缓慢,延迟下游任务&…...
C++读写Excel(xlnt库的使用)
一、简介 官网:https://github.com/tfussell/xlnt Cross-platform user-friendly xlsx library for C11 xlnt is a modern C library for manipulating spreadsheets in memory and reading/writing them from/to XLSX files as described in ECMA 376 4th edition…...
【工具】conda常用命令
Conda 是一个流行的包管理器和环境管理器,用于安装、部署和管理软件包及其依赖项。 创建环境: conda create --name myenv 这将创建一个名为 myenv 的新环境。 激活环境: conda activate myenv 这会激活名为 myenv 的环境。在 Windows 上&am…...
Dockerfile编写实践篇
Docker通过一种打包和分发的软件,完成传统容器的封装。这个用来充当容器分发角色的组件被称为镜像。Docker镜像是一个容器中运行程序的所有文件的捆绑快照。当使用Docker分发软件,其实就是分发这些镜像,并在接收的机器上创建容器。镜像在Dock…...
BJFU|计算机网络缩写对照表
之前有过这个题型,但23年没考,所以按需准备 A ACK (ACKnowledgement) 确认 ADSL (Asymmetric Digital Subscriber Line) 非对称数字用户线 API (Applicatin Programming Interface) 应用编程接口 ARP (Address Resolution Protocol) 地址解析协议 ARQ (…...
Grafana dashboards as ConfigMaps
文章目录 1. 简介2. 创建 configmaps3. grafana 界面查看 1. 简介 将 Grafana 仪表板存储为 Kubernetes ConfigMap 相比传统的通过 Grafana 界面导入仪表板有以下一些主要优点: 版本控制: ConfigMap 可以存储在版本控制系统(如Git)中,便于跟踪和管理仪表板的变更历…...
【QA-SYSTEMS】CANTATA-解决Jenkins中build Cantata报错
【更多软件使用问题请点击亿道电子官方网站查询】 1、 文档目标 解决Jenkins中build Cantata测试项目报找不到license server的错误。 2、 问题场景 在Jenkins中build Cantata测试项目,报错“Failed to figure out the license server correctly”。 3、软硬件环…...
个人网站展示(静态)
大学期间做了一个个人博客网站,纯H5编码的网站,利用php搭建了一个留言模块。 有需要源码的同学,可以联系我~ 首页: IT杂记模块 文人墨客模块 劳有所获模块 生活日志模块 关于我 一个推崇全栈开发的前端开发人员 微信: itrzzh …...
C++——内存管理、模板
一、C内存管理 在C语言中我们曾学习过动态内存管理的相关知识,通过malloc、calloc、realloc和free等对堆上的空间进行申请和释放。在C中我们同样会面临类似的需求,因此C对动态开辟内存的方式进行了一些调整,我们可以使用new和delete操作符来对…...
商品上传上货搬家使用1688商品采集api接口
1688.item_get 公共参数 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中)secretString是调用密钥api_nameString是API接口名称(包括在请求地址中)[item_search,item_get,item_search_shop等]cacheString否[yes,no…...
redisson解决redis服务器的主从一致性问题
redisson解决redis的主节点和从节点一致性的问题。从而解决锁被错误获取的情况。 实际开发中我们会搭建多台redis服务器,但这些服务器分主次,主服务器负责处理写的操作(增删改),从服务器负责处理读的操作,…...
Vue-router
router的使用(52) 5个基础步骤: 1.在终端执行yarn add vue-router3.6.5,安装router插件 yarn add vue-router3.6.5 2.在文件的main.js中引入router插件 import VueRouter from vue-router 3.在main.js中安装注册Vue.use(Vue…...
mongodb源码分析session执行handleRequest命令find过程
mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程,并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令,把数据流转换成Message,状态转变流程是:State::Created 》 St…...
基于服务器使用 apt 安装、配置 Nginx
🧾 一、查看可安装的 Nginx 版本 首先,你可以运行以下命令查看可用版本: apt-cache madison nginx-core输出示例: nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...
服务器硬防的应用场景都有哪些?
服务器硬防是指一种通过硬件设备层面的安全措施来防御服务器系统受到网络攻击的方式,避免服务器受到各种恶意攻击和网络威胁,那么,服务器硬防通常都会应用在哪些场景当中呢? 硬防服务器中一般会配备入侵检测系统和预防系统&#x…...
TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案
一、TRS收益互换的本质与业务逻辑 (一)概念解析 TRS(Total Return Swap)收益互换是一种金融衍生工具,指交易双方约定在未来一定期限内,基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...
多种风格导航菜单 HTML 实现(附源码)
下面我将为您展示 6 种不同风格的导航菜单实现,每种都包含完整 HTML、CSS 和 JavaScript 代码。 1. 简约水平导航栏 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport&qu…...
Android第十三次面试总结(四大 组件基础)
Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成,用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机: onCreate() 调用时机:Activity 首次创建时调用。…...
[免费]微信小程序问卷调查系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】
大家好,我是java1234_小锋老师,看到一个不错的微信小程序问卷调查系统(SpringBoot后端Vue管理端)【论文源码SQL脚本】,分享下哈。 项目视频演示 【免费】微信小程序问卷调查系统(SpringBoot后端Vue管理端) Java毕业设计_哔哩哔哩_bilibili 项…...
宇树科技,改名了!
提到国内具身智能和机器人领域的代表企业,那宇树科技(Unitree)必须名列其榜。 最近,宇树科技的一项新变动消息在业界引发了不少关注和讨论,即: 宇树向其合作伙伴发布了一封公司名称变更函称,因…...
多模态图像修复系统:基于深度学习的图片修复实现
多模态图像修复系统:基于深度学习的图片修复实现 1. 系统概述 本系统使用多模态大模型(Stable Diffusion Inpainting)实现图像修复功能,结合文本描述和图片输入,对指定区域进行内容修复。系统包含完整的数据处理、模型训练、推理部署流程。 import torch import numpy …...
Scrapy-Redis分布式爬虫架构的可扩展性与容错性增强:基于微服务与容器化的解决方案
在大数据时代,海量数据的采集与处理成为企业和研究机构获取信息的关键环节。Scrapy-Redis作为一种经典的分布式爬虫架构,在处理大规模数据抓取任务时展现出强大的能力。然而,随着业务规模的不断扩大和数据抓取需求的日益复杂,传统…...
