当前位置: 首页 > news >正文

机器学习第5天:多项式回归与学习曲线

文章目录

多项式回归介绍

方法与代码

方法描述

分离多项式

学习曲线的作用

场景

学习曲线介绍

欠拟合曲线

示例

结论

过拟合曲线

示例

​结论 


多项式回归介绍

当数据不是线性时我们该如何处理呢,考虑如下数据

import matplotlib.pyplot as plt
import numpy as npnp.random.seed(42)x = 8 * np.random.rand(100, 1) - 4
y = 2*x**2+3*x+np.random.randn(100, 1)plt.scatter(x, y)
plt.show()


方法与代码

方法描述

先讲思路,以这个二元函数为例

y=3*x^{2}+2*x+c

将多项式化为多个单项的,也就是将x的平方和x两个项分离开,然后单独给线性模型处理,求出参数,最后再组合在一起,很好理解,让我们来看一下代码


分离多项式

我们使用机器学习库的PolynomialFeatures来分离多项式

from sklearn.preprocessing import PolynomialFeaturespoly_features = PolynomialFeatures(degree=2, include_bias=False)
x_poly = poly_features.fit_transform(x)
print(x[0])
print(x_poly[0])

运行结果

可以看到,4, 5行代码将原始x和x平方挑选了出来,这时我们再把这个数据进行线性回归

model = LinearRegression()
model.fit(x_poly, y)
print(model.coef_)

 这段代码使用处理后的x拟合y,再打印模型拟合的参数,可以看到模型的两个参数分别是2.9和2左右,而我们的方程的一次参数和二次参数分别是3和2,可见效果还是很好的

把预测的结果绘制出来

model = LinearRegression()
model.fit(x_poly, y)
pre_y = model.predict(x_poly)# 这里是为了让x升序的排序算法, 可以尝试不加这段代码图会变成什么样
sorted_indices = sorted(range(len(x)), key=lambda k: x[k])
x_sorted = [x[i] for i in sorted_indices]
y_sorted = [pre_y[i] for i in sorted_indices]plt.plot(x_sorted, y_sorted, "r-")
plt.scatter(x, y)
plt.show()


学习曲线的作用

场景

设想一下,当你需要预测房价,你也有多组数据,包括离学校距离,交通状况等,但是问题来了,你只知道这些特征可能与房价有关,但并不知道这些特征与房价之间的方程关系,这时我们进行回归任务时,就可能导致欠拟合或者过拟合,幸运的是,我们可以通过学习曲线来判断


学习曲线介绍

学习曲线图就是以损失函数为纵坐标,数据集大小为横坐标,然后在图上画出训练集和验证集两条曲线的图,训练集就是我们用来训练模型的数据,验证集就是我们用来验证模型性能的数据集,我们往往将数据集分成训练集与验证集

我们先定义一个学习曲线绘制函数

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegressiondef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()

 简单介绍一下,这个函数接收模型参数,x,y参数,然后在for循环中,取不同数据集大小来计算RMSE损失(就是\sqrt{MSE}),然后把曲线绘制出来


欠拟合曲线

我们知道欠拟合就是模拟效果不好的情况,可以想象的到,无论在训练集还是验证集上,他的损失都会比较高

示例

我们将线性模型的学习曲线绘制出来

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegressiondef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()x = np.random.rand(100, 1)
y = 2 * x + np.random.rand(100, 1)model = LinearRegression()
plot_learning_curves(model, x, y)

 

结论

可以看到,在只有一点数据时,模型在训练集上效果很好(因为就是开始这一些数据训练出来的),而在验证集上效果不好,但随着训练集增加(模型学习到的越多),验证集上的误差逐渐减小,训练集上的误差增加(因为是学到了一个趋势,不会完全和训练集一样了)

这个图的特征是两条曲线非常接近,且误差都较大(差不多在0.3) ,这是欠拟合的表现(模型效果不好)


过拟合曲线

过拟合就是完全以数据集来模拟曲线,泛化能力很差

示例

我们来试试将一次函数模拟成三次函数,再来看看学习曲线(毫无疑问过拟合了)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipelinedef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()np.random.seed(10)
x = np.random.rand(200, 1)
y = 2 * x + np.random.rand(200, 1)poly_regression = Pipeline([("Poly", PolynomialFeatures(degree=3, include_bias=False)),("Line", LinearRegression())
])plot_learning_curves(poly_regression, x, y)

结论 

这条曲线的特征是训练集的效果比验证集好(两条线之间有一定间距),这往往是过拟合的表现(在训练集上效果好,验证集差,表面泛化能力差) 

相关文章:

机器学习第5天:多项式回归与学习曲线

文章目录 多项式回归介绍 方法与代码 方法描述 分离多项式 学习曲线的作用 场景 学习曲线介绍 欠拟合曲线 示例 结论 过拟合曲线 示例 ​结论 多项式回归介绍 当数据不是线性时我们该如何处理呢,考虑如下数据 import matplotlib.pyplot as plt impo…...

MSYS2介绍及工具安装

0 Preface/Foreword 1 MSYS2 官网:MSYS2...

Swift开发中:非逃逸闭包、逃逸闭包、自动闭包的区别

1. 非逃逸闭包(Non-Escaping Closure) 定义:默认情况下,在 Swift 中闭包是非逃逸的。这意味着闭包在函数结束之前被调用并完成,它不会“逃逸”出函数的范围。内存管理:由于闭包在函数返回前被调用&#xf…...

栈结构应用-进制转换-辗转相除法

// 定义类class Stack{// #items [] 前边加#变为私有 外部不能随意修改 内部使用也要加#items []pop(){return this.items.pop()}push(data){this.items.push(data)}peek(){return this.items[this.items.length-1]}isEmpty(){return this.items.length 0}size(){return th…...

【Azure 架构师学习笔记】-Azure Storage Account(6)- File Layer

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Storage Account】系列。 接上文 【Azure 架构师学习笔记】-Azure Storage Account(5)- Data Lake layers 前言 上一文介绍了存储帐户的概述,还有container的一些配置,在…...

idea 环境搭建及运行java后端源码

1、 idea 历史版本下载及安装 建议下载和我一样的版本,2020.3 https://www.jetbrains.com/idea/download/other.html,idea分为专业版本(Ultimate)和社区版本(Community),前期可以下载专业版本…...

掌握Shell:从新手到编程大师的Linux之旅

1 shell介绍 1.1 shell脚本的意义 1.记录命令执行的过程和执行逻辑,以便以后重复执行 2.脚本可以批量处理主机 3.脚本可以定时处理主机 1.2 脚本的创建 #!/bin/bash # 运行脚本时候执行的环境1.3 自动添加脚本说明信息 /etc/vimrc # vim主配置文件 ~/.vimrc # 该…...

有重复元素的快速排序

当涉及到处理重复元素的快速排序时,可以使用荷兰国旗问题的方法,也就是三路划分。下面是使用 Java 实现的示例代码: import java.util.Arrays;public class QuickSort {public static void quickSort(int[] arr, int low, int high) {if (lo…...

Bert浅谈

优点 首先,bert的创新点在于利用了双向transformer,这就跟openai的gpt有区别,gpt是采用单向的transformer,而作者认为双向transformer更能够融合上下文的信息。这里双向和单向的区别在于,单向只跟当前位置之前的tocke…...

产品运营的场景和运营策略

一、启动屏 1.概念 启动屏,特指 APP 产品启动时即显示的界面,这个界面一般会停留几秒钟时间,在这个时间内 APP 会在后台加载服务框架、启动各种服务 SDK 、获取用户地理位置、判断有无新版本、判断用户账户状态以及其他系统级别的…...

C#异常捕获try catch详细介绍

在C#中,异常处理是通过try、catch、finally和throw语句来实现的,它们提供了一种结构化和可预测的方法来处理运行时错误。 C#异常基本用法 try块 异常处理以try块开始,try块包含可能会引发异常的代码。如果在try块中的代码执行过程中发生了…...

切换阿里云ES方式及故障应急处理方案

一、阿里云es服务相关问题及答解 1.1 ES7.10扩容节点时间 增加节点数量需要节点拉起和数据Rebalance两步,拉起时间7.16及以上的新版本大概10分钟以内,7.16以前大概一小时,数据迁移的时间就看数据量了,一般整体在半小时以内 (需进行相关测试验证) 1.2 ES7.10扩容数据节点…...

CTFhub-RCE-过滤空格

1. 查看当前目录&#xff1a;127.0.0.1|ls 2. 查看 flag_890277429145.php 127.0.0.1|cat flag_890277429145.php 根据题目可以知道空格被过滤掉了 3.空格可以用以下字符代替&#xff1a; < 、>、<>、%20(space)、%09(tab)、$IFS$9、 ${IFS}、$IFS等 $IFS在li…...

无需添加udid,ios企业证书的自助生成方法

我们开发uniapp的app的时候&#xff0c;需要苹果证书去打包。 假如申请的是个人或company类型的苹果开发者账号&#xff0c;必须上架才能安装&#xff0c;异常的麻烦&#xff0c;但是有一些app&#xff0c;比如企业内部使用的app&#xff0c;是不需要上架苹果应用市场的。 假…...

【PTA题目】6-20 使用函数判断完全平方数 分数 10

6-20 使用函数判断完全平方数 分数 10 全屏浏览题目 切换布局 作者 张高燕 单位 浙大城市学院 本题要求实现一个判断整数是否为完全平方数的简单函数。 函数接口定义&#xff1a; int IsSquare( int n ); 其中n是用户传入的参数&#xff0c;在长整型范围内。如果n是完全…...

Nas搭建webdav服务器并同步Zotero科研文献

无需云盘&#xff0c;不限流量实现Zotero跨平台同步&#xff1a;内网穿透私有WebDAV服务器 文章目录 无需云盘&#xff0c;不限流量实现Zotero跨平台同步&#xff1a;内网穿透私有WebDAV服务器一、Zotero安装教程二、群晖NAS WebDAV设置三、Zotero设置四、使用公网地址同步Zote…...

一句话总结敏捷实践中不同方法

敏捷实践是指一组优先考虑灵活性、协作和客户满意度的软件开发和项目管理原则和方法。 不同方法论的敏捷实践&#xff1a; 1、敏捷&#xff1a; Sprints&#xff1a;限时迭代&#xff08;通常 2-4 周&#xff09;&#xff0c;在此期间创建潜在的可交付产品增量。每日站立会议…...

【数据结构】线段树(点修区查)

数据结构-线段树&#xff08;点修区查&#xff09; 前置知识 分治递归二叉树 思路 我们需要维护一个支持单点修改&#xff0c;区间查询的数据结构&#xff0c;并且要求在线&#xff0c;一般使用线段树解决。 线段树是一个二叉树形的数据结构。 线段树的思想很简单&#xff0c…...

Ansys Lumerical | 用于增强现实系统的表面浮雕光栅

在本示例中&#xff0c;我们使用 RCWA 求解器设计了一个斜面浮雕光栅 (SRG)&#xff0c;它将用于将光线耦合到单色增强现实 (AR) 系统的波导中。光栅的几何形状经过优化&#xff0c;可将正常入射光导入-1 光栅阶次。 然后我们将光栅特性导出为 Lumerical Sub-Wavelength Model …...

QT day3作业

1.思维导图 2、 完善对话框&#xff0c;点击登录对话框&#xff0c;如果账号和密码匹配&#xff0c;则弹出信息对话框&#xff0c;给出提示”登录成功“&#xff0c;提供一个Ok按钮&#xff0c;用户点击Ok后&#xff0c;关闭登录界面&#xff0c;跳转到其他界面 如果账号和密…...

Python|GIF 解析与构建(5):手搓截屏和帧率控制

目录 Python&#xff5c;GIF 解析与构建&#xff08;5&#xff09;&#xff1a;手搓截屏和帧率控制 一、引言 二、技术实现&#xff1a;手搓截屏模块 2.1 核心原理 2.2 代码解析&#xff1a;ScreenshotData类 2.2.1 截图函数&#xff1a;capture_screen 三、技术实现&…...

R语言AI模型部署方案:精准离线运行详解

R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

Python爬虫(二):爬虫完整流程

爬虫完整流程详解&#xff08;7大核心步骤实战技巧&#xff09; 一、爬虫完整工作流程 以下是爬虫开发的完整流程&#xff0c;我将结合具体技术点和实战经验展开说明&#xff1a; 1. 目标分析与前期准备 网站技术分析&#xff1a; 使用浏览器开发者工具&#xff08;F12&…...

Linux云原生安全:零信任架构与机密计算

Linux云原生安全&#xff1a;零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言&#xff1a;云原生安全的范式革命 随着云原生技术的普及&#xff0c;安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测&#xff0c;到2025年&#xff0c;零信任架构将成为超…...

高防服务器能够抵御哪些网络攻击呢?

高防服务器作为一种有着高度防御能力的服务器&#xff0c;可以帮助网站应对分布式拒绝服务攻击&#xff0c;有效识别和清理一些恶意的网络流量&#xff0c;为用户提供安全且稳定的网络环境&#xff0c;那么&#xff0c;高防服务器一般都可以抵御哪些网络攻击呢&#xff1f;下面…...

《C++ 模板》

目录 函数模板 类模板 非类型模板参数 模板特化 函数模板特化 类模板的特化 模板&#xff0c;就像一个模具&#xff0c;里面可以将不同类型的材料做成一个形状&#xff0c;其分为函数模板和类模板。 函数模板 函数模板可以简化函数重载的代码。格式&#xff1a;templa…...

初探Service服务发现机制

1.Service简介 Service是将运行在一组Pod上的应用程序发布为网络服务的抽象方法。 主要功能&#xff1a;服务发现和负载均衡。 Service类型的包括ClusterIP类型、NodePort类型、LoadBalancer类型、ExternalName类型 2.Endpoints简介 Endpoints是一种Kubernetes资源&#xf…...

人机融合智能 | “人智交互”跨学科新领域

本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...