当前位置: 首页 > news >正文

Python 中的机器学习简介:多项式回归

一、说明

        多项式回归可以识别自变量和因变量之间的非线性关系。本文是关于回归、梯度下降和 MSE 系列文章的第三篇。前面的文章介绍了简单线性回归、回归的正态方程和多元线性回归。

二、多项式回归

        多项式回归用于最适合曲线拟合的复杂数据。它可以被视为多元线性回归的子集。

        请注意,X₀ 是偏差的一列;这允许在第一篇文章中讨论的广义公式。使用上述等式,每个“自变量”都可以被视为 X₁ 的指数版本。

        这允许从多元线性回归使用相同的模型,因为只需要识别每个变量的系数。可以创建一个简单的三阶多项式模型作为示例。其等式如下:

        模型、梯度下降和 MSE 的广义函数可用于前面的文章:

# line of best fit
def model(w, X):"""Inputs:w: array of weights | (num features, 1)X: array of inputs  | (n samples, num features)Output:returns the output of X@w | (n samples, 1)"""return torch.matmul(X, w)
# mean squared error (MSE)
def MSE(Yhat, Y):"""Inputs:Yhat: array of predictions | (n samples, 1)Y: array of expected outputs | (n samples, 1)Output:returns the loss of the model, which is a scalar"""return torch.mean((Yhat-Y)**2) # mean((error)^2)
# optimizer
def gradient_descent(w):"""Inputs:w: array of weights | (num features, 1)Global Variables / Constants:X: array of inputs  | (n samples, num features)Y: array of expected outputs | (n samples, 1)lr: learning rate to scale the gradientOutput:returns the updated weights""" n = X.shape[0]return w - (lr * 2/n) * (torch.matmul(-Y.T, X) + torch.matmul(torch.matmul(w.T, X.T), X)).reshape(w.shape)

三、创建数据

        现在,所需要的只是一些用于训练模型的数据。可以使用“蓝图”功能,并且可以添加随机性。这遵循与前面文章相同的方法。蓝图如下所示:

        可以创建大小为 (800, 4) 的训练集和大小为 (200, 4) 的测试集。请注意,除偏差外,每个特征都是第一个特征的指数版本。

import torchtorch.manual_seed(5)
torch.set_printoptions(precision=2)# features
X0 = torch.ones((1000,1))
X1 = (100*(torch.rand(1000) - 0.5)).reshape(-1,1) # generates 1000 random numbers from -50 to 50
X2, X3 = X1**2, X1**3
X = torch.hstack((X0,X1,X2,X3))# normal distribution with a mean of 0 and std of 8
normal = torch.distributions.Normal(loc=0, scale=8)# targets
Y = (3*X[:,3] + 2*X[:,2] + 1*X[:,1] + 5 + normal.sample(torch.ones(1000).shape)).reshape(-1,1)# train, test
Xtrain, Xtest = X[:800], X[800:]
Ytrain, Ytest = Y[:800], Y[800:]

        定义初始权重后,可以使用最佳拟合线绘制数据。

torch.manual_seed(5)
w = torch.rand(size=(4, 1))
w
tensor([[0.83],[0.13],[0.91],[0.82]])
import matplotlib.pyplot as pltdef plot_lbf():"""Output:prints the line of best fit in comparison to the train and test data"""# plot the train and test setsplt.scatter(Xtrain[:,1],Ytrain,label="train")plt.scatter(Xtest[:,1],Ytest,label="test")# plot the line of best fitX1_plot = torch.arange(-50, 50.1,.1).reshape(-1,1) X2_plot, X3_plot = X1_plot**2, X1_plot**3X0_plot = torch.ones(X1_plot.shape)X_plot = torch.hstack((X0_plot,X1_plot,X2_plot,X3_plot))plt.plot(X1_plot.flatten(), model(w, X_plot).flatten(), color="red", zorder=4)plt.xlim(-50, 50)plt.xlabel("$X$")plt.ylabel("$Y$")plt.legend()plt.show()plot_lbf()
图片来源:作者

四、训练模型

        为了部分最小化成本函数,可以使用 5e-11 和 500,000 epoch 的学习率与梯度下降一起使用。

lr = 5e-11
epochs = 500000# update the weights 1000 times
for i in range(0, epochs):# update the weightsw = gradient_descent(w)# print the new values every 10 iterationsif (i+1) % 100000 == 0:print("epoch:", i+1)print("weights:", w)print("Train MSE:", MSE(model(w,Xtrain), Ytrain))print("Test MSE:", MSE(model(w,Xtest), Ytest))print("="*10)plot_lbf()
epoch: 100000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(163.87)
Test MSE: tensor(162.55)
==========
epoch: 200000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(163.52)
Test MSE: tensor(162.22)
==========
epoch: 300000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(163.19)
Test MSE: tensor(161.89)
==========
epoch: 400000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(162.85)
Test MSE: tensor(161.57)
==========
epoch: 500000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(162.51)
Test MSE: tensor(161.24)
==========
图片来源:作者

        即使有 500,000 个 epoch 和极小的学习率,该模型也无法识别前两个权重。虽然当前的解决方案非常准确,MSE为161.24,但可能需要数百万个epoch才能完全最小化它。这是多项式回归梯度下降的局限性之一。

五、正态方程

        作为替代方案,可以使用第二篇文章中的正态方程直接计算优化权重:

def NormalEquation(X, Y):"""Inputs:X: array of input values | (n samples, num features)Y: array of expected outputs | (n samples, 1)Output:returns the optimized weights | (num features, 1)"""return torch.inverse(X.T @ X) @ X.T @ Yw = NormalEquation(Xtrain, Ytrain)
w
tensor([[4.57],[0.98],[2.00],[3.00]])

        正态方程能够立即识别每个权重的正确值,并且每组的MSE比梯度下降时低约100点:

MSE(model(w,Xtrain), Ytrain), MSE(model(w,Xtest), Ytest)
(tensor(60.64), tensor(63.84))

六、结论

        通过实现简单线性、多重线性和多项式回归,接下来的两篇文章将介绍套索和岭回归。这些类型的回归在机器学习中引入了两个重要概念:过拟合和正则化。

 参考文章:

亨特·菲利普斯

相关文章:

Python 中的机器学习简介:多项式回归

一、说明 多项式回归可以识别自变量和因变量之间的非线性关系。本文是关于回归、梯度下降和 MSE 系列文章的第三篇。前面的文章介绍了简单线性回归、回归的正态方程和多元线性回归。 二、多项式回归 多项式回归用于最适合曲线拟合的复杂数据。它可以被视为多元线性回归的子集。…...

docker 容器中执行命令出现错误: 13: Permission denied

错误 13: Permission denied [rootVM-32-11-tencentos ~]# docker exec -it kibana1 /bin/bash kibana76c20c215dcb:~$ apt-get install vi E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied) E: Unable to acquire the dpkg frontend…...

JavaWeb学习|JavaBean;MVC三层架构;Filter;Listener

1.JavaBean 实体类 JavaBean有特定的写法: 必须要有一个无参构造 属性必须私有化。 必须有对应的get/set方法 用来和数据库的字段做映射 ORM; ORM:对象关系映射 表--->类 字段-->属性 行记录---->对象 2.<jsp&#xff1a;useBean 标签 3. MVC三层架构 4. Filter …...

arx 外部参照文件(XREF)的添加、删除、卸载和重载_objectarx

添加参照 CString strFileName;int nIndex = strFilePath.ReverseFind(\\);if (nIndex != -1){strFileName = strFilePath.Right(strFilePath....

【博客699】docker daemon预置iptables剖析

docker daemon预置iptables剖析 没有安装docker的机器&#xff1a;iptables为空&#xff0c;且每个链路的默认policy均为ACCEPT [root~]# iptables-save[root ~]# iptables -t raw -nvL Chain PREROUTING (policy ACCEPT 0 packets, 0 bytes)pkts bytes target prot opt …...

Golang 中的交叉编译详解

Golang 中的交叉编译 在 Golang 中&#xff0c;交叉编译指的是在同一台机器上生成针对不同操作系统或硬件架构的二进制文件。这在开发跨平台应用或构建特定平台的发布版本时非常有用。 交叉编译 Golang 程序的基本步骤如下&#xff1a; 指定目标操作系统和工具链并设置对应的…...

Python中的诡异事:不可见字符!

文章目录 前言1. 起因2. 调查3. 高能4. 释惑 前言 今天分享一件很诡异的事情&#xff0c;我写代码的时候遇到了不可见的字符&#xff01;&#xff01;&#xff01; 1. 起因 今天在使用pipreqs导出项目中所依赖的库时突然报错了&#xff1a; pipreqs . --encodingutf-8 --forc…...

【uniapp】uniapp使用微信开发者工具制作骨架屏:

文章目录 一、效果&#xff1a;二、过程&#xff1a; 一、效果&#xff1a; 二、过程&#xff1a; 【1】微信开发者工具打开项目&#xff0c;生成骨架屏&#xff0c;将wxml改造为vue页面组件&#xff0c;并放入样式 【2】页面使用骨架屏组件 【3】改造骨架屏&#xff08;去除…...

【UE4 RTS】06-Camera Edge Scroll

前言 本篇实现的效果是当玩家将鼠标移至屏幕边缘时&#xff0c;视野会相应的上下左右移动 效果 步骤 1. 打开玩家控制器“RTS_PlayerController_BP”&#xff0c;在类默认值中设置如下选项 新建一个宏&#xff0c;命名为“EdgeSroll”&#xff0c; 添加两个输入和三个输出&a…...

无涯教程-Perl - length函数

描述 此函数返回EXPR值的长度(以字符为单位),如果未指定,则返回$_。如果要确定相应的大小,请在数组或哈希上使用标量context。 语法 以下是此函数的简单语法- length EXPRlength返回值 此函数返回字符串的大小。 例 以下是显示其基本用法的示例代码- #!/usr/bin/perl$o…...

怎样在 CentOS 里下载 RPM 包及其所有依赖包

前几天我尝试去创建一个仅包含我们经常在 CentOS 7 下使用的软件的本地仓库。当然,我们可以使用 curl 或者 wget 下载任何软件包,然而这些命令并不能下载要求的依赖软件包。你必须去花一些时间而且手动的去寻找和下载被安装的软件所依赖的软件包。然而,我们并不是必须这样。…...

在Ubuntu上使用NFS挂载

假设要把192.16.2.101服务器上的 /home/sharedata 挂载到192.16.2.102服务器上的 /home/receive_data 一、服务端 1、安装NFS服务端 sudo apt-get install nfs-kernel-server 2、修改NFS挂载配置文件 sudo vim /etc/exports 在文件中输入 /home/sharedata 192.16.2.102(…...

复现海康威视综合安防管理平台artemis接口Spring boot heapdump内存泄露漏洞

目录 一、漏洞描述 二、影响版本 三、资产测绘 四、漏洞复现 一、漏洞描述 HIKVISION iSecure Center综合安防管理平台是一套“集成化”、“智能化”的平台,通过接入视频监控、一卡通...

哈希unordered系列介绍(上)

一.Unordered_map,Unordered_set介绍 在之前我们已经介绍过set,map,multiset等等关联式容器&#xff0c;它们的底层是红黑树进行模拟实现的&#xff0c;在查询时效率可达到 l o g 2 N log_2 N log2​N&#xff0c;即最差情况下需要比较红黑树的高度次&#xff0c;当树中的节点…...

MySQL随心记第二篇

一、正则表达式篇&#xff1a; regular expression--> regexp 元字符: . : 单个的任意字符&#xff08;默认不包含换行&#xff09; \d:数字: 0-9 补集:\D \w:ascil:数字&#xff0c;大写字母&#xff0c;小写字母&#xff0c;以及下划线 unicode: 数字&#xff0c;大…...

0001nginx简介、相关模型与原理

文章目录 一. 什么是Nginx二. ngnix的一些模型1、nginx的进程模型2、worker的抢占&#xff08;锁&#xff09;机制模型3. nginx事件处理模型 三. nginx加载静态资源的过程 一. 什么是Nginx Nginx是一个高性能HTTP反向代理服务器&#xff0c;以下是nginx的相关能力 反向代理&am…...

elasticsearch简单入门语法

基本操作 创建不同的分词器 ik_smart&#xff1a; 极简分词 &#xff1b; ik_max_word: 最细力再度分词 基本的rest命令 methodurl地址描述PUTlocalhost:9200/索引名称/类型名称/文档id创建文档&#xff08;指定文档id&#xff09;POSTlocalhost:9200/索引名称/类型名称创建文…...

Python自动化测试用例:如何优雅的完成Json格式数据断言

目录 前言 直接使用 优化 封装 小结 进阶 总结 资料获取方法 前言 记录Json断言在工作中的应用进阶。 直接使用 很早以前写过一篇博客&#xff0c;记录当时获取一个多级json中指定key的数据&#xff1a; #! /usr/bin/python # coding:utf-8 """ aut…...

阿里云对象存储服务OSS

1、引依赖 <dependency><groupId>com.aliyun.oss</groupId><artifactId>aliyun-sdk-oss</artifactId><version>3.15.1</version> </dependency> <dependency><groupId>javax.xml.bind</groupId><artifa…...

第三节:在WORD为应用主窗口下关闭EXCEL的操作(1)

【分享成果&#xff0c;随喜正能量】夏日里的遗憾&#xff0c;一定都会被秋风温柔化解。吃素不难&#xff0c;难于不肯捨贪口腹之心。若不贪口腹&#xff0c;有何吃素之不便乎。虽吃华素&#xff0c;不吃素日&#xff0c;亦须少吃。以一切物类&#xff0c;皆是贪生怕死&#xf…...

(LeetCode 每日一题) 3442. 奇偶频次间的最大差值 I (哈希、字符串)

题目&#xff1a;3442. 奇偶频次间的最大差值 I 思路 &#xff1a;哈希&#xff0c;时间复杂度0(n)。 用哈希表来记录每个字符串中字符的分布情况&#xff0c;哈希表这里用数组即可实现。 C版本&#xff1a; class Solution { public:int maxDifference(string s) {int a[26]…...

练习(含atoi的模拟实现,自定义类型等练习)

一、结构体大小的计算及位段 &#xff08;结构体大小计算及位段 详解请看&#xff1a;自定义类型&#xff1a;结构体进阶-CSDN博客&#xff09; 1.在32位系统环境&#xff0c;编译选项为4字节对齐&#xff0c;那么sizeof(A)和sizeof(B)是多少&#xff1f; #pragma pack(4)st…...

LeetCode - 394. 字符串解码

题目 394. 字符串解码 - 力扣&#xff08;LeetCode&#xff09; 思路 使用两个栈&#xff1a;一个存储重复次数&#xff0c;一个存储字符串 遍历输入字符串&#xff1a; 数字处理&#xff1a;遇到数字时&#xff0c;累积计算重复次数左括号处理&#xff1a;保存当前状态&a…...

STM32标准库-DMA直接存储器存取

文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA&#xff08;Direct Memory Access&#xff09;直接存储器存取 DMA可以提供外设…...

[Java恶补day16] 238.除自身以外数组的乘积

给你一个整数数组 nums&#xff0c;返回 数组 answer &#xff0c;其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法&#xff0c;且在 O(n) 时间复杂度…...

【Oracle】分区表

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南

1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发&#xff0c;使用DevEco Studio作为开发工具&#xff0c;采用Java语言实现&#xff0c;包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...

无人机侦测与反制技术的进展与应用

国家电网无人机侦测与反制技术的进展与应用 引言 随着无人机&#xff08;无人驾驶飞行器&#xff0c;UAV&#xff09;技术的快速发展&#xff0c;其在商业、娱乐和军事领域的广泛应用带来了新的安全挑战。特别是对于关键基础设施如电力系统&#xff0c;无人机的“黑飞”&…...

push [特殊字符] present

push &#x1f19a; present 前言present和dismiss特点代码演示 push和pop特点代码演示 前言 在 iOS 开发中&#xff0c;push 和 present 是两种不同的视图控制器切换方式&#xff0c;它们有着显著的区别。 present和dismiss 特点 在当前控制器上方新建视图层级需要手动调用…...

CSS | transition 和 transform的用处和区别

省流总结&#xff1a; transform用于变换/变形&#xff0c;transition是动画控制器 transform 用来对元素进行变形&#xff0c;常见的操作如下&#xff0c;它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...