当前位置: 首页 > news >正文

机器学习第5天:多项式回归与学习曲线

文章目录

多项式回归介绍

方法与代码

方法描述

分离多项式

学习曲线的作用

场景

学习曲线介绍

欠拟合曲线

示例

结论

过拟合曲线

示例

​结论 


多项式回归介绍

当数据不是线性时我们该如何处理呢,考虑如下数据

import matplotlib.pyplot as plt
import numpy as npnp.random.seed(42)x = 8 * np.random.rand(100, 1) - 4
y = 2*x**2+3*x+np.random.randn(100, 1)plt.scatter(x, y)
plt.show()


方法与代码

方法描述

先讲思路,以这个二元函数为例

y=3*x^{2}+2*x+c

将多项式化为多个单项的,也就是将x的平方和x两个项分离开,然后单独给线性模型处理,求出参数,最后再组合在一起,很好理解,让我们来看一下代码


分离多项式

我们使用机器学习库的PolynomialFeatures来分离多项式

from sklearn.preprocessing import PolynomialFeaturespoly_features = PolynomialFeatures(degree=2, include_bias=False)
x_poly = poly_features.fit_transform(x)
print(x[0])
print(x_poly[0])

运行结果

可以看到,4, 5行代码将原始x和x平方挑选了出来,这时我们再把这个数据进行线性回归

model = LinearRegression()
model.fit(x_poly, y)
print(model.coef_)

 这段代码使用处理后的x拟合y,再打印模型拟合的参数,可以看到模型的两个参数分别是2.9和2左右,而我们的方程的一次参数和二次参数分别是3和2,可见效果还是很好的

把预测的结果绘制出来

model = LinearRegression()
model.fit(x_poly, y)
pre_y = model.predict(x_poly)# 这里是为了让x升序的排序算法, 可以尝试不加这段代码图会变成什么样
sorted_indices = sorted(range(len(x)), key=lambda k: x[k])
x_sorted = [x[i] for i in sorted_indices]
y_sorted = [pre_y[i] for i in sorted_indices]plt.plot(x_sorted, y_sorted, "r-")
plt.scatter(x, y)
plt.show()


学习曲线的作用

场景

设想一下,当你需要预测房价,你也有多组数据,包括离学校距离,交通状况等,但是问题来了,你只知道这些特征可能与房价有关,但并不知道这些特征与房价之间的方程关系,这时我们进行回归任务时,就可能导致欠拟合或者过拟合,幸运的是,我们可以通过学习曲线来判断


学习曲线介绍

学习曲线图就是以损失函数为纵坐标,数据集大小为横坐标,然后在图上画出训练集和验证集两条曲线的图,训练集就是我们用来训练模型的数据,验证集就是我们用来验证模型性能的数据集,我们往往将数据集分成训练集与验证集

我们先定义一个学习曲线绘制函数

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegressiondef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()

 简单介绍一下,这个函数接收模型参数,x,y参数,然后在for循环中,取不同数据集大小来计算RMSE损失(就是\sqrt{MSE}),然后把曲线绘制出来


欠拟合曲线

我们知道欠拟合就是模拟效果不好的情况,可以想象的到,无论在训练集还是验证集上,他的损失都会比较高

示例

我们将线性模型的学习曲线绘制出来

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegressiondef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()x = np.random.rand(100, 1)
y = 2 * x + np.random.rand(100, 1)model = LinearRegression()
plot_learning_curves(model, x, y)

 

结论

可以看到,在只有一点数据时,模型在训练集上效果很好(因为就是开始这一些数据训练出来的),而在验证集上效果不好,但随着训练集增加(模型学习到的越多),验证集上的误差逐渐减小,训练集上的误差增加(因为是学到了一个趋势,不会完全和训练集一样了)

这个图的特征是两条曲线非常接近,且误差都较大(差不多在0.3) ,这是欠拟合的表现(模型效果不好)


过拟合曲线

过拟合就是完全以数据集来模拟曲线,泛化能力很差

示例

我们来试试将一次函数模拟成三次函数,再来看看学习曲线(毫无疑问过拟合了)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipelinedef plot_learning_curves(model, x, y):x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)train_errors, val_errors = [], []for m in range(1, len(x_train)):model.fit(x_train[:m], y_train[:m])y_train_predict = model.predict(x_train[:m])y_val_predict = model.predict(x_val)train_errors.append(mean_squared_error(y_train[:m], y_train_predict))val_errors.append(mean_squared_error(y_val, y_val_predict))plt.plot(np.sqrt(train_errors), "r-+", linewidth=2, label="train")plt.plot(np.sqrt(val_errors), "b-", linewidth=3, label="val")plt.legend()plt.show()np.random.seed(10)
x = np.random.rand(200, 1)
y = 2 * x + np.random.rand(200, 1)poly_regression = Pipeline([("Poly", PolynomialFeatures(degree=3, include_bias=False)),("Line", LinearRegression())
])plot_learning_curves(poly_regression, x, y)

结论 

这条曲线的特征是训练集的效果比验证集好(两条线之间有一定间距),这往往是过拟合的表现(在训练集上效果好,验证集差,表面泛化能力差) 

相关文章:

机器学习第5天:多项式回归与学习曲线

文章目录 多项式回归介绍 方法与代码 方法描述 分离多项式 学习曲线的作用 场景 学习曲线介绍 欠拟合曲线 示例 结论 过拟合曲线 示例 ​结论 多项式回归介绍 当数据不是线性时我们该如何处理呢,考虑如下数据 import matplotlib.pyplot as plt impo…...

MSYS2介绍及工具安装

0 Preface/Foreword 1 MSYS2 官网:MSYS2...

Swift开发中:非逃逸闭包、逃逸闭包、自动闭包的区别

1. 非逃逸闭包(Non-Escaping Closure) 定义:默认情况下,在 Swift 中闭包是非逃逸的。这意味着闭包在函数结束之前被调用并完成,它不会“逃逸”出函数的范围。内存管理:由于闭包在函数返回前被调用&#xf…...

栈结构应用-进制转换-辗转相除法

// 定义类class Stack{// #items [] 前边加#变为私有 外部不能随意修改 内部使用也要加#items []pop(){return this.items.pop()}push(data){this.items.push(data)}peek(){return this.items[this.items.length-1]}isEmpty(){return this.items.length 0}size(){return th…...

【Azure 架构师学习笔记】-Azure Storage Account(6)- File Layer

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Storage Account】系列。 接上文 【Azure 架构师学习笔记】-Azure Storage Account(5)- Data Lake layers 前言 上一文介绍了存储帐户的概述,还有container的一些配置,在…...

idea 环境搭建及运行java后端源码

1、 idea 历史版本下载及安装 建议下载和我一样的版本,2020.3 https://www.jetbrains.com/idea/download/other.html,idea分为专业版本(Ultimate)和社区版本(Community),前期可以下载专业版本…...

掌握Shell:从新手到编程大师的Linux之旅

1 shell介绍 1.1 shell脚本的意义 1.记录命令执行的过程和执行逻辑,以便以后重复执行 2.脚本可以批量处理主机 3.脚本可以定时处理主机 1.2 脚本的创建 #!/bin/bash # 运行脚本时候执行的环境1.3 自动添加脚本说明信息 /etc/vimrc # vim主配置文件 ~/.vimrc # 该…...

有重复元素的快速排序

当涉及到处理重复元素的快速排序时,可以使用荷兰国旗问题的方法,也就是三路划分。下面是使用 Java 实现的示例代码: import java.util.Arrays;public class QuickSort {public static void quickSort(int[] arr, int low, int high) {if (lo…...

Bert浅谈

优点 首先,bert的创新点在于利用了双向transformer,这就跟openai的gpt有区别,gpt是采用单向的transformer,而作者认为双向transformer更能够融合上下文的信息。这里双向和单向的区别在于,单向只跟当前位置之前的tocke…...

产品运营的场景和运营策略

一、启动屏 1.概念 启动屏,特指 APP 产品启动时即显示的界面,这个界面一般会停留几秒钟时间,在这个时间内 APP 会在后台加载服务框架、启动各种服务 SDK 、获取用户地理位置、判断有无新版本、判断用户账户状态以及其他系统级别的…...

C#异常捕获try catch详细介绍

在C#中,异常处理是通过try、catch、finally和throw语句来实现的,它们提供了一种结构化和可预测的方法来处理运行时错误。 C#异常基本用法 try块 异常处理以try块开始,try块包含可能会引发异常的代码。如果在try块中的代码执行过程中发生了…...

切换阿里云ES方式及故障应急处理方案

一、阿里云es服务相关问题及答解 1.1 ES7.10扩容节点时间 增加节点数量需要节点拉起和数据Rebalance两步,拉起时间7.16及以上的新版本大概10分钟以内,7.16以前大概一小时,数据迁移的时间就看数据量了,一般整体在半小时以内 (需进行相关测试验证) 1.2 ES7.10扩容数据节点…...

CTFhub-RCE-过滤空格

1. 查看当前目录&#xff1a;127.0.0.1|ls 2. 查看 flag_890277429145.php 127.0.0.1|cat flag_890277429145.php 根据题目可以知道空格被过滤掉了 3.空格可以用以下字符代替&#xff1a; < 、>、<>、%20(space)、%09(tab)、$IFS$9、 ${IFS}、$IFS等 $IFS在li…...

无需添加udid,ios企业证书的自助生成方法

我们开发uniapp的app的时候&#xff0c;需要苹果证书去打包。 假如申请的是个人或company类型的苹果开发者账号&#xff0c;必须上架才能安装&#xff0c;异常的麻烦&#xff0c;但是有一些app&#xff0c;比如企业内部使用的app&#xff0c;是不需要上架苹果应用市场的。 假…...

【PTA题目】6-20 使用函数判断完全平方数 分数 10

6-20 使用函数判断完全平方数 分数 10 全屏浏览题目 切换布局 作者 张高燕 单位 浙大城市学院 本题要求实现一个判断整数是否为完全平方数的简单函数。 函数接口定义&#xff1a; int IsSquare( int n ); 其中n是用户传入的参数&#xff0c;在长整型范围内。如果n是完全…...

Nas搭建webdav服务器并同步Zotero科研文献

无需云盘&#xff0c;不限流量实现Zotero跨平台同步&#xff1a;内网穿透私有WebDAV服务器 文章目录 无需云盘&#xff0c;不限流量实现Zotero跨平台同步&#xff1a;内网穿透私有WebDAV服务器一、Zotero安装教程二、群晖NAS WebDAV设置三、Zotero设置四、使用公网地址同步Zote…...

一句话总结敏捷实践中不同方法

敏捷实践是指一组优先考虑灵活性、协作和客户满意度的软件开发和项目管理原则和方法。 不同方法论的敏捷实践&#xff1a; 1、敏捷&#xff1a; Sprints&#xff1a;限时迭代&#xff08;通常 2-4 周&#xff09;&#xff0c;在此期间创建潜在的可交付产品增量。每日站立会议…...

【数据结构】线段树(点修区查)

数据结构-线段树&#xff08;点修区查&#xff09; 前置知识 分治递归二叉树 思路 我们需要维护一个支持单点修改&#xff0c;区间查询的数据结构&#xff0c;并且要求在线&#xff0c;一般使用线段树解决。 线段树是一个二叉树形的数据结构。 线段树的思想很简单&#xff0c…...

Ansys Lumerical | 用于增强现实系统的表面浮雕光栅

在本示例中&#xff0c;我们使用 RCWA 求解器设计了一个斜面浮雕光栅 (SRG)&#xff0c;它将用于将光线耦合到单色增强现实 (AR) 系统的波导中。光栅的几何形状经过优化&#xff0c;可将正常入射光导入-1 光栅阶次。 然后我们将光栅特性导出为 Lumerical Sub-Wavelength Model …...

QT day3作业

1.思维导图 2、 完善对话框&#xff0c;点击登录对话框&#xff0c;如果账号和密码匹配&#xff0c;则弹出信息对话框&#xff0c;给出提示”登录成功“&#xff0c;提供一个Ok按钮&#xff0c;用户点击Ok后&#xff0c;关闭登录界面&#xff0c;跳转到其他界面 如果账号和密…...

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)

【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...

C# SqlSugar:依赖注入与仓储模式实践

C# SqlSugar&#xff1a;依赖注入与仓储模式实践 在 C# 的应用开发中&#xff0c;数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护&#xff0c;许多开发者会选择成熟的 ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;SqlSugar 就是其中备受…...

css3笔记 (1) 自用

outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size&#xff1a;0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格&#xff…...

scikit-learn机器学习

# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: # Also add the following code, # so that every time the environment (kernel) starts, # just run the following code: import sys sys.path.append(/home/aistudio/external-libraries)机…...

【JVM】Java虚拟机(二)——垃圾回收

目录 一、如何判断对象可以回收 &#xff08;一&#xff09;引用计数法 &#xff08;二&#xff09;可达性分析算法 二、垃圾回收算法 &#xff08;一&#xff09;标记清除 &#xff08;二&#xff09;标记整理 &#xff08;三&#xff09;复制 &#xff08;四&#xff…...

C语言中提供的第三方库之哈希表实现

一. 简介 前面一篇文章简单学习了C语言中第三方库&#xff08;uthash库&#xff09;提供对哈希表的操作&#xff0c;文章如下&#xff1a; C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...

OD 算法题 B卷【正整数到Excel编号之间的转换】

文章目录 正整数到Excel编号之间的转换 正整数到Excel编号之间的转换 excel的列编号是这样的&#xff1a;a b c … z aa ab ac… az ba bb bc…yz za zb zc …zz aaa aab aac…; 分别代表以下的编号1 2 3 … 26 27 28 29… 52 53 54 55… 676 677 678 679 … 702 703 704 705;…...

深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏

一、引言 在深度学习中&#xff0c;我们训练出的神经网络往往非常庞大&#xff08;比如像 ResNet、YOLOv8、Vision Transformer&#xff09;&#xff0c;虽然精度很高&#xff0c;但“太重”了&#xff0c;运行起来很慢&#xff0c;占用内存大&#xff0c;不适合部署到手机、摄…...