线性回归矩阵求解和梯度求解
正规方程求解线性回归
首先正规方程如下:
Θ = ( X T X ) − 1 X T y \begin{equation} \Theta = (X^T X)^{-1} X^T y \end{equation} Θ=(XTX)−1XTy
接下来通过线性代数的角度理解这个问题。
二维空间
在二维空间上,有两个向量 a a a和 b b b,若 b b b投影到 a a a要怎么做,很简单,做垂线, 那么投影后的向量记为 p p p,那么 b b b和 p p p之间的error记为 e = b − p e=b-p e=b−p。同时 p p p在 a a a上,所以 p p p一定是 a a a的 x x x(标量)倍,记为 p = x a p=xa p=xa。因为 e e e垂直 a a a,所以 a T ( b − x a ) = 0 a^T(b-xa)=0 aT(b−xa)=0 ,即 x a T a = a T b xa^Ta=a^Tb xaTa=aTb,得到
x = a T b a T a x=\frac{a^Tb}{a^Ta} x=aTaaTb
那么
p = x a = a a T b a T a p=xa=a\frac{a^Tb}{a^Ta} p=xa=aaTaaTb
根据上面的公式,如果 a a a翻倍了,那么投影不变,如果 b b b翻倍了,投影也翻倍。投影是由一个矩阵 P P P完成的, p = P b p=Pb p=Pb,那么投影矩阵 P P P:
P = a a T a T a P=\frac{aa^T}{a^Ta} P=aTaaaT
用任何向量乘这个投影矩阵,你总会变换到它的列空间中。同时显然有: P T = P P^T=P PT=P , P 2 = P P^2=P P2=P,即投影两次的结果还是和第一次一样。
高维空间
为什么要做投影呢?
因为, A x = b Ax=b Ax=b可能无解,比如一堆等式,比未知数还多,就可能造成无解。那么该怎么办,只能求解最接近的哪个可能解,哪个才是最接近的呢?问题是 A x Ax Ax总是在 A A A的列空间中,而 b b b不一定在。所以要怎么微调 b b b将它变为列空间中最接近它的那一个,那么就将问题换作求解,有解的 A x ^ = p A\hat{x}=p Ax^=p。所以得找最好的那个投影 p p p,以最好的接近 b b b,这就是为什么要引入投影的原因了。
那么我们来看高维空间,这里以三维空间举例,自然可以推广到n维空间。
现在有一个不在平面上的 b b b向量,想要将 b b b投影在平面上,平面可以由两个基向量 a 1 a_1 a1和 a 2 a_2 a2表示。同样的 b b b投影到平面上的误差记为 e = b − p e=b-p e=b−p,这个 e e e是垂直平面的。 p = x 1 ^ a 1 + x 2 ^ a 2 = A x ^ p=\hat{x_1}a_1+\hat{x_2}a_2=A\hat{x} p=x1^a1+x2^a2=Ax^,我们想要解出 x ^ \hat{x} x^。因为 e e e是垂直平面,所以有 b − A x ^ b-A\hat{x} b−Ax^垂直平面,即有 a 1 T ( b − A x ^ ) = 0 a_1^T(b-A\hat{x})=0 a1T(b−Ax^)=0, a 2 T ( b − A x ^ ) = 0 a_2^T(b-A\hat{x})=0 a2T(b−Ax^)=0,表示为矩阵乘法便有
A T ( b − A x ^ ) = A e = 0 A^T(b-A\hat{x})=Ae=0 AT(b−Ax^)=Ae=0
这个形式与二维空间的很像吧。对于 A e = 0 Ae=0 Ae=0,可知 e e e位于 A T A^T AT的零空间,也就是说 e e e垂直于于 A A A的列空间。由上面式子可得
A T A x ^ = A T b A^TA\hat{x}=A^Tb ATAx^=ATb
继而
x ^ = ( A T A ) − 1 A T b \hat{x}=(A^TA)^{-1}A^Tb x^=(ATA)−1ATb
这不就是我们的正规方程吗。到这里我们的正规方程便推导出来了,但为了内容完整,我们下面收个尾。
p = A x ^ = A ( A T A ) − 1 A T b P = A ( A T A ) − 1 A T P T = P P 2 = P p=A\hat{x}=A(A^TA)^{-1}A^Tb \\ P=A(A^TA)^{-1}A^T\\ P^T=P\\ P^2=P p=Ax^=A(ATA)−1ATbP=A(ATA)−1ATPT=PP2=P
这些结论还是和二维空间上的一样, P T = P P^T=P PT=P , P 2 = P P^2=P P2=P,即投影两次的结果还是和第一次一样。
最小二乘法
正规方程的一个常见应用例子是最小二乘法。从线性代数的角度来看,正规方程是通过最小二乘法求解线性回归问题的一种方法。以下是正规方程的概述:
1. 模型表示
在线性回归中,我们假设目标变量 y y y 与特征矩阵 X X X 之间存在线性关系:
y ^ = X θ \hat{y} = X \theta y^=Xθ
其中:
- y ^ \hat{y} y^ 是预测值(一个 m m m 维列向量)。
- X X X 是特征矩阵( m × n m \times n m×n),每行代表一个样本,每列代表一个特征。
- θ \theta θ 是模型参数(权重向量)。
2. 目标函数
我们的目标是最小化预测值与实际值之间的误差,通常使用残差平方和:
J ( θ ) = ∥ y − X θ ∥ 2 J(\theta) = \|y - X\theta\|^2 J(θ)=∥y−Xθ∥2
3. 求解过程
为了找到使得 J ( θ ) J(\theta) J(θ) 最小的 θ \theta θ,我们可以通过对 J ( θ ) J(\theta) J(θ) 关于 θ \theta θ 的导数求解,设导数为零:
∇ J ( θ ) = − 2 X T ( y − X θ ) = 0 \nabla J(\theta) = -2X^T(y - X\theta) = 0 ∇J(θ)=−2XT(y−Xθ)=0
展开后得到:
X T X θ = X T y X^T X \theta = X^T y XTXθ=XTy
4. 正规方程
这个方程称为正规方程,其形式为:
X T X θ = X T y X^T X \theta = X^T y XTXθ=XTy
5. 解的唯一性
- 若 X T X X^T X XTX 是可逆的(即列向量线性无关),则可以通过求逆得到参数的解:
θ = ( X T X ) − 1 X T y \theta = (X^T X)^{-1} X^T y θ=(XTX)−1XTy
- 如果 X T X X^T X XTX 不可逆(即存在多重共线性),则正规方程可能没有唯一解。
6. 几何解释
从几何的角度,正规方程可以被视为在特征空间中寻找一个超平面,使得目标变量 y y y 的投影与预测值 X θ X \theta Xθ 之间的误差最小化。
总结
正规方程通过线性代数的方法为线性回归提供了解的表达式,使得我们可以有效地计算参数。其核心思想是通过最小化残差平方和,寻找最佳拟合的线性模型。
梯度下降求解线性回归
import numpy as np
def linear_regression_gradient_descent(X: np.ndarray, y: np.ndarray, alpha: float, iterations: int) -> np.ndarray:m, n = X.shapetheta = np.zeros((n, 1))for _ in range(iterations):predictions = X @ thetaerrors = predictions - y.reshape(-1, 1)updates = X.T @ errors / mtheta -= alpha * updatesreturn np.round(theta.flatten(), 4)
其他都好理解,下面主要讲梯度updates的推导
1. 定义损失函数
线性回归的损失函数通常是均方误差(Mean Squared Error, MSE):
MSE = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 \text{MSE} = \frac{1}{2m} \sum_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)})^2 MSE=2m1i=1∑m(hθ(x(i))−y(i))2
这里, h θ ( x ( i ) ) = X ( i ) ⋅ θ h_\theta(x^{(i)}) = X^{(i)} \cdot \theta hθ(x(i))=X(i)⋅θ 是模型的预测值, y ( i ) y^{(i)} y(i) 是实际值。
2. 对损失函数求导
为了最小化损失函数,我们需要对参数 θ \theta θ 求导:
∂ MSE ∂ θ = ∂ ∂ θ ( 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 ) \frac{\partial \text{MSE}}{\partial \theta} = \frac{\partial}{\partial \theta} \left( \frac{1}{2m} \sum_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)})^2 \right) ∂θ∂MSE=∂θ∂(2m1i=1∑m(hθ(x(i))−y(i))2)
应用链式法则,首先求导内部的平方项:
∂ ∂ θ ( h θ ( x ( i ) ) − y ( i ) ) 2 = 2 ( h θ ( x ( i ) ) − y ( i ) ) ⋅ ∂ h θ ( x ( i ) ) ∂ θ \frac{\partial}{\partial \theta} (h_\theta(x^{(i)}) - y^{(i)})^2 = 2(h_\theta(x^{(i)}) - y^{(i)}) \cdot \frac{\partial h_\theta(x^{(i)})}{\partial \theta} ∂θ∂(hθ(x(i))−y(i))2=2(hθ(x(i))−y(i))⋅∂θ∂hθ(x(i))
而且 h θ ( x ( i ) ) = X ( i ) ⋅ θ h_\theta(x^{(i)}) = X^{(i)} \cdot \theta hθ(x(i))=X(i)⋅θ,所以:
∂ h θ ( x ( i ) ) ∂ θ = X ( i ) \frac{\partial h_\theta(x^{(i)})}{\partial \theta} = X^{(i)} ∂θ∂hθ(x(i))=X(i)
将这个结果代入:
∂ MSE ∂ θ = 1 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) X ( i ) \frac{\partial \text{MSE}}{\partial \theta} = \frac{1}{m} \sum_{i=1}^{m} (h_\theta(x^{(i)}) - y^{(i)}) X^{(i)} ∂θ∂MSE=m1i=1∑m(hθ(x(i))−y(i))X(i)
3. 用向量表示
将上述和式转换为向量形式。定义误差向量:
errors = predictions − y \text{errors} = \text{predictions} - y errors=predictions−y
其中 predictions = X ⋅ θ \text{predictions} = X \cdot \theta predictions=X⋅θ。这样,梯度可以表示为:
gradient = 1 m ( X T ⋅ errors ) \text{gradient} = \frac{1}{m} (X^T \cdot \text{errors}) gradient=m1(XT⋅errors)
4. 结论
因此,梯度的计算公式来源于损失函数的求导过程,通过向量化的方式将每个样本的误差与特征相乘,得出对每个参数的影响。这是梯度下降法中更新参数的基础。
相关文章:
线性回归矩阵求解和梯度求解
正规方程求解线性回归 首先正规方程如下: Θ ( X T X ) − 1 X T y \begin{equation} \Theta (X^T X)^{-1} X^T y \end{equation} Θ(XTX)−1XTy 接下来通过线性代数的角度理解这个问题。 二维空间 在二维空间上,有两个向量 a a a和 b b b&…...
M3U8不知道如何转MP4?包能学会的4种格式转换教学!
在流媒体视频大量生产的今天,M3U8作为一种基于HTTP Live Streaming(HLS)协议的播放列表格式,广泛应用于网络视频直播和点播中。它包含了媒体播放列表的信息,指向了视频文件被分割成的多个TS(Transport Stre…...
C++第4课——swap、switch-case-for循环(含视频讲解)
文章目录 1、课程代码2、课程视频 1、课程代码 #include<iostream> using namespace std; int main(){/* //第一个任务:学会swap int a,b,c;//从小到大排序输出 升序 cin>>a>>b>>c;//5 4 3if(a>b)swap(a,b);//4 5 3 swap()函数是用于交…...
大数据新视界 -- 大数据大厂之大数据重塑影视娱乐产业的未来(4 - 4)
💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…...
在Java中,需要每120分钟刷新一次的`assetoken`,并且你想使用Redis作为缓存来存储和管理这个令牌
学习总结 1、掌握 JAVA入门到进阶知识(持续写作中……) 2、学会Oracle数据库入门到入土用法(创作中……) 3、手把手教你开发炫酷的vbs脚本制作(完善中……) 4、牛逼哄哄的 IDEA编程利器技巧(编写中……) 5、面经吐血整理的 面试技…...
linux网络编程7——协程设计原理与汇编实现
文章目录 协程设计原理与汇编实现1. 协程概念2. 协程的实现2.1 setjmp2.2 ucontext2.3 汇编实现2.4 优缺点2.5 实现协程原语2.5.1 create()2.5.2 yield()2.5.3 resume()2.5.4 exit()2.5.5 switch()2.5.6 sleep() 2.6 协程调度器 3. 利用hook使用协程版本的库函数学习参考 协程设…...
Ubuntu22.04版本左右,扩充用户可使用内存
1 取得root权限后,输入命令 lsblk 查看所有磁盘和分区,找到想要替换用户可使用文件夹内存的磁盘和分区。若没有进行分区,并转为所需要的分区数据类型,先进行分区与格式化,过程自行查阅。 扩充替换过程,例如…...
基于ArcMap中Python 批量处理栅格数据(以按掩膜提取为例)
注:图片来源于公众号,公众号也是我自己的。 ArcMap中的python编辑器是很多本科生使用ArcMap时容易忽略的一个工具,本人最近正在读一本书《ArcGIS Python 编程基础与应用》,在此和大家分享、交流一些相关的知识。 这篇文章主要分享…...
【flink】之集成mybatis对mysql进行读写
背景: 在现代大数据应用中,数据的高效处理和存储是核心需求之一。Flink作为一款强大的流处理框架,能够处理大规模的实时数据流,提供丰富的数据处理功能,如窗口操作、连接操作、聚合操作等。而MyBatis则是一款优秀的持…...
Java设计模式—观察者模式详解
引言 模式角色 UML图 示例代码 应用场景 优点 缺点 结论 引言 观察者模式(Observer Pattern)是一种行为设计模式,它定义了对象之间的一对多依赖关系,当一个对象的状态发生改变时,所有依赖于它的对象都会得到通知…...
【Cri-Dockerd】安装cri-dockerd
cri-dockerd的作用: 在k8s1.24之前。k8s会通过dockershim来调用docker进行容器运行时containerd,并且会自动安装dockershim,但是从1.24版本之前k8s为了降低容器运行时的调用的复杂度和效率,直接调用containerd了,并且…...
GCC及GDB的使用
参考视频及博客 https://www.bilibili.com/video/BV1EK411g7Li/?spm_id_from333.999.0.0&vd_sourceb3723521e243814388688d813c9d475f https://www.bilibili.com/video/BV1ei4y1V758/?buvidXU932919AEC08339E30CE57D39A2BABF6A44F&from_spmidsearch.search-result.0…...
大数据新视界 -- 大数据大厂之大数据重塑影视娱乐产业的未来(4 - 3)
💖💖💖亲爱的朋友们,热烈欢迎你们来到 青云交的博客!能与你们在此邂逅,我满心欢喜,深感无比荣幸。在这个瞬息万变的时代,我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…...
数据结构——基础知识补充
1.队列 1.普通队列 queue.Queue 是 Python 标准库 queue 模块中的一个类,适用于多线程环境。它实现了线程安全的 FIFO(先进先出)队列。 2.双端队列 双端队列(Deque,Double-Ended Queue)是一种具有队列和…...
只有.git文件夹时如何恢复项目
有时候误删文件但由于.git是隐藏文件夹而幸存,或者项目太大,单单甩给你一个.git文件夹让你自己恢复整个项目,该怎么办呢? 不用担心,只要进行以下步骤,即可把原项目重新搭建起来: 创建一个文件…...
anchor、anchor box、bounding box之间关系
最近学YOLO接触到这些概念,一下子有点蒙,简单总结一下。 anchor和anchor box Anchor:表示一组预定义的尺寸比例,用来代表常见物体的宽高比。可以把它看成是一个模板或规格,定义了物体框的“形状”和“比例”ÿ…...
代码随想录算法训练营第三十天 | 452.用最少数量的箭引爆气球 435.无重叠区间 763.划分字母区间
LeetCode 452.用最少数量的箭引爆气球: 文章链接 题目链接:452.用最少数量的箭引爆气球 思路: 气球的区间有重叠部分,只要弓箭从重叠部分射出来,那么就能减少所使用的弓箭数 **局部最优:**只要有重叠部分…...
海亮科技亮相第84届中国教装展 尽显生于校园 长于校园教育基因
10月25日,第84届中国教育装备展示会(以下简称“教装展”)在昆明滇池国际会展中心开幕。作为国内教育装备领域规模最大、影响最广的专业展会,本届教装展以“数字赋能教育,创新引领未来”为主题,为教育领域新…...
C语言数据结构学习:栈
C语言 数据结构学习 汇总入口: C语言数据结构学习:[汇总] 1. 栈 栈,实际上是一种特殊的线性表。这里使用的是链表栈,链表栈的博客:C语言数据结构学习:单链表 2. 栈的特点 只能在一端进行存取操作&#x…...
如何快速分析音频中的各种频率成分
从视频中提取音频 from moviepy.editor import VideoFileClip# Load the video file and extract audio video_path "/mnt/data/WeChat_20241026235630.mp4" video_clip VideoFileClip(video_path)# Extract audio and save as a temporary file for further anal…...
Python|GIF 解析与构建(5):手搓截屏和帧率控制
目录 Python|GIF 解析与构建(5):手搓截屏和帧率控制 一、引言 二、技术实现:手搓截屏模块 2.1 核心原理 2.2 代码解析:ScreenshotData类 2.2.1 截图函数:capture_screen 三、技术实现&…...
Linux应用开发之网络套接字编程(实例篇)
服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...
多场景 OkHttpClient 管理器 - Android 网络通信解决方案
下面是一个完整的 Android 实现,展示如何创建和管理多个 OkHttpClient 实例,分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...
实现弹窗随键盘上移居中
实现弹窗随键盘上移的核心思路 在Android中,可以通过监听键盘的显示和隐藏事件,动态调整弹窗的位置。关键点在于获取键盘高度,并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...
Swagger和OpenApi的前世今生
Swagger与OpenAPI的关系演进是API标准化进程中的重要篇章,二者共同塑造了现代RESTful API的开发范式。 本期就扒一扒其技术演进的关键节点与核心逻辑: 🔄 一、起源与初创期:Swagger的诞生(2010-2014) 核心…...
CSS设置元素的宽度根据其内容自动调整
width: fit-content 是 CSS 中的一个属性值,用于设置元素的宽度根据其内容自动调整,确保宽度刚好容纳内容而不会超出。 效果对比 默认情况(width: auto): 块级元素(如 <div>)会占满父容器…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...
Go 语言并发编程基础:无缓冲与有缓冲通道
在上一章节中,我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道,它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好࿰…...
RSS 2025|从说明书学习复杂机器人操作任务:NUS邵林团队提出全新机器人装配技能学习框架Manual2Skill
视觉语言模型(Vision-Language Models, VLMs),为真实环境中的机器人操作任务提供了极具潜力的解决方案。 尽管 VLMs 取得了显著进展,机器人仍难以胜任复杂的长时程任务(如家具装配),主要受限于人…...
Webpack性能优化:构建速度与体积优化策略
一、构建速度优化 1、升级Webpack和Node.js 优化效果:Webpack 4比Webpack 3构建时间降低60%-98%。原因: V8引擎优化(for of替代forEach、Map/Set替代Object)。默认使用更快的md4哈希算法。AST直接从Loa…...
