当前位置: 首页 > news >正文

实现一个简单的线性回归和多项式回归(2)

对于多项式回归,可以同样使用前面线性回归中定义的LinearRegression算子、训练函数train、均方误差函数mean_squared_error,生成数据集create_toy_data,这里就不多做赘述咯~

        拟合的函数为

def sin(x):y = torch.sin(2 * math.pi * x)return y

1.数据集的建立

func = sin
interval = (0,1)
train_num = 15
test_num = 10
noise = 0.5
X_train, y_train = create_toy_data(func=func, interval=interval, sample_num=train_num, noise = noise)
X_test, y_test = create_toy_data(func=func, interval=interval, sample_num=test_num, noise = noise)X_underlying = torch.linspace(interval[0], interval[1], 100)
y_underlying = sin(X_underlying)# 绘制图像
plt.rcParams['figure.figsize'] = (8.0, 6.0)
plt.scatter(X_train, y_train, facecolor="none", edgecolor='#e4007f', s=50, label="train data")
#plt.scatter(X_test, y_test, facecolor="none", edgecolor="r", s=50, label="test data")
plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")
plt.legend(fontsize='x-large')
plt.savefig('ml-vis2.pdf')
plt.show()

        生成结果: 

2.模型构建

# 多项式转换
def polynomial_basis_function(x, degree=2):"""输入:- x: tensor, 输入的数据,shape=[N,1]- degree: int, 多项式的阶数example Input: [[2], [3], [4]], degree=2example Output: [[2^1, 2^2], [3^1, 3^2], [4^1, 4^2]]注意:本案例中,在degree>=1时不生成全为1的一列数据;degree为0时生成形状与输入相同,全1的Tensor输出:- x_result: tensor"""if degree == 0:return torch.ones(x.size(), dtype=torch.float32)x_tmp = xx_result = x_tmpfor i in range(2, degree + 1):x_tmp = torch.multiply(x_tmp, x)  # 逐元素相乘x_result = torch.concat((x_result, x_tmp), dim=-1)return x_result

        我的理解多项式回归的模型的建立更像是将原来的一个x变成x x^2 x^3 ... x^degree这个过程

3.模型训练

plt.rcParams['figure.figsize'] = (12.0, 8.0)for i, degree in enumerate([0, 1, 3, 8]):  # []中为多项式的阶数model = Linear(degree)X_train_transformed = polynomial_basis_function(X_train.reshape([-1, 1]), degree)X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1, 1]), degree)model = optimizer_lsm(model, X_train_transformed, y_train.reshape([-1, 1]))  # 拟合得到参数y_underlying_pred = model(X_underlying_transformed).squeeze()print(model.params)# 绘制图像plt.subplot(2, 2, i + 1)plt.scatter(X_train, y_train, facecolor="none", edgecolor='#e4007f', s=50, label="train data")plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")plt.plot(X_underlying, y_underlying_pred, c='#f19ec2', label="predicted function")plt.ylim(-2, 1.5)plt.annotate("M={}".format(degree), xy=(0.95, -1.4))# plt.legend(bbox_to_anchor=(1.05, 0.64), loc=2, borderaxespad=0.)
plt.legend(loc='lower left', fontsize='x-large')
plt.savefig('ml-vis3.pdf')
plt.show()

         训练结果:

观察可视化结果,红色的曲线表示不同阶多项式分布拟合数据的结果:

* 当 $M=0$$M=1$时,拟合曲线较简单,模型欠拟合;

* 当 $M=8$ 时,拟合曲线较复杂,模型过拟合;

* 当 $M=3$ 时,模型拟合最为合理。

4.模型评估

        下面通过均方误差来衡量训练误差、测试误差以及在没有噪音的加入下sin函数值与多项式回归值之间的误差,更加真实地反映拟合结果。多项式分布阶数从0到8进行遍历。

# 训练误差和测试误差
training_errors = []
test_errors = []
distribution_errors = []# 遍历多项式阶数
for i in range(9):model = Linear(i)X_train_transformed = polynomial_basis_function(X_train.reshape([-1, 1]), i)X_test_transformed = polynomial_basis_function(X_test.reshape([-1, 1]), i)X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1, 1]), i)optimizer_lsm(model, X_train_transformed, y_train.reshape([-1, 1]))y_train_pred = model(X_train_transformed).squeeze()y_test_pred = model(X_test_transformed).squeeze()y_underlying_pred = model(X_underlying_transformed).squeeze()train_mse = mean_squared_error(y_true=y_train, y_pred=y_train_pred).item()training_errors.append(train_mse)test_mse = mean_squared_error(y_true=y_test, y_pred=y_test_pred).item()test_errors.append(test_mse)# distribution_mse = mean_squared_error(y_true=y_underlying, y_pred=y_underlying_pred).item()# distribution_errors.append(distribution_mse)print("train errors: \n", training_errors)
print("test errors: \n", test_errors)
# print ("distribution errors: \n", distribution_errors)# 绘制图片
plt.rcParams['figure.figsize'] = (8.0, 6.0)
plt.plot(training_errors, '-.', mfc="none", mec='#e4007f', ms=10, c='#e4007f', label="Training")
plt.plot(test_errors, '--', mfc="none", mec='#f19ec2', ms=10, c='#f19ec2', label="Test")
# plt.plot(distribution_errors, '-', mfc="none", mec="#3D3D3F", ms=10, c="#3D3D3F", label="Distribution")
plt.legend(fontsize='x-large')
plt.xlabel("degree")
plt.ylabel("MSE")
plt.savefig('ml-mse-error.pdf')
plt.show()

        可视化结果如下:

由可视化结果可得:

  1. 当阶数较低的时候,模型的表示能力有限,训练误差和测试误差都很高,代表模型欠拟合;
  2. 当阶数较高的时候,模型表示能力强,但将训练数据中的噪声也作为特征进行学习,一般情况下训练误差继续降低而测试误差显著升高,代表模型过拟合。

        对于模型过拟合的情况,可以引入正则化方法,通过向误差函数中添加一个惩罚项来避免系数倾向于较大的取值。

degree = 8 # 多项式阶数
reg_lambda = 0.0001 # 正则化系数X_train_transformed = polynomial_basis_function(X_train.reshape([-1,1]), degree)
X_test_transformed = polynomial_basis_function(X_test.reshape([-1,1]), degree)
X_underlying_transformed = polynomial_basis_function(X_underlying.reshape([-1,1]), degree)model = Linear(degree) optimizer_lsm(model,X_train_transformed,y_train.reshape([-1,1]))y_test_pred=model(X_test_transformed).squeeze()
y_underlying_pred=model(X_underlying_transformed).squeeze()model_reg = Linear(degree) optimizer_lsm(model_reg,X_train_transformed,y_train.reshape([-1,1]),reg_lambda=reg_lambda)y_test_pred_reg=model_reg(X_test_transformed).squeeze()
y_underlying_pred_reg=model_reg(X_underlying_transformed).squeeze()mse = mean_squared_error(y_true = y_test, y_pred = y_test_pred).item()
print("mse:",mse)
mes_reg = mean_squared_error(y_true = y_test, y_pred = y_test_pred_reg).item()
print("mse_with_l2_reg:",mes_reg)# 绘制图像
plt.scatter(X_train, y_train, facecolor="none", edgecolor="#e4007f", s=50, label="train data")
plt.plot(X_underlying, y_underlying, c='#000000', label=r"$\sin(2\pi x)$")
plt.plot(X_underlying, y_underlying_pred, c='#e4007f', linestyle="--", label="$deg. = 8$")
plt.plot(X_underlying, y_underlying_pred_reg, c='#f19ec2', linestyle="-.", label="$deg. = 8, \ell_2 reg$")
plt.ylim(-1.5, 1.5)
plt.annotate("lambda={}".format(reg_lambda), xy=(0.82, -1.4))
plt.legend(fontsize='large')
plt.savefig('ml-vis4.pdf')
plt.show()

可视化结果为:

         多项式回归的难点就是是否真正理解了线性回归中的算子,均方误差等函数的目的和用法,其他的就是简单的函数调用问题啦~不懂的话还是建议多看看线性函数,先把线性函数的看明白最好~

相关文章:

实现一个简单的线性回归和多项式回归(2)

对于多项式回归,可以同样使用前面线性回归中定义的LinearRegression算子、训练函数train、均方误差函数mean_squared_error,生成数据集create_toy_data,这里就不多做赘述咯~ 拟合的函数为 def sin(x):y torch.sin(2 * math.pi * x)return y1.数据集的建…...

云原生Kubernetes:K8S集群kubectl命令汇总

目录 一、理论 1.概念 2. kubectl 帮助方法 3.kubectl 子命令使用分类 4.使用kubectl 命令的必要环境 5.kubectl 详细命令 一、理论 1.概念 kubectl是一个命令行工具,通过跟 K8S 集群的 API Server 通信,来执行集群的管理工作。 kubectl命令是操…...

Java使用模板导出word、pdf

使用deepoove根据模板导出word文档&#xff0c;包括文本、表格、图表、图片&#xff0c;使用WordConvertPdf可将word文档转换为pdf导出 模板样例&#xff1a; 导出结果&#xff1a; 一、引入相关依赖 <!-- 工具类--><dependency><groupId>cn.hutool&…...

速通Redis基础(二):掌握Redis的哈希类型和命令

目录 Redis 哈希类型简介 Redis 哈希命令 HSET HGET HEXISTS HDEL HKEYS HVALS HGETALL HMGET HLEN HSETNX ​编辑 HINCRBY HINCRBYFLOAT Redis的哈希类型命令小结 Redis 是一种高性能的键值存储数据库&#xff0c;支持多种数据类型&#xff0c;其中之…...

WebDAV之π-Disk派盘 + 书藏家

书藏家是一款书籍收藏的软件,对于喜欢阅读书籍的用户来说非常友好,记录你所阅读的书籍内容,对你所阅读的书籍内容进行全方位的管理,并且支持多种录入的方式,不管是实体书籍还是网络书籍都能够进行更为有效的管理;内置WebDAV 模块,更加便利的整理自己的文件资源;书藏家的…...

香港Web3.0生态现状

目前香港Web3.0生态正在快速发展。香港政府和金融机构正在积极推动Web3.0生态的建设&#xff0c;以推动数字经济和智慧城市的发展。香港政府已经发布了有关虚拟资产发展的政策宣言&#xff0c;鼓励和监管并重&#xff0c;加大力度推动虚拟资产产业向前发展。同时&#xff0c;香…...

LLMs之BELLE:源码解读(sft_train.py文件)

LLMs之BELLE:源码解读(sft_train.py文件) 目录 源码解读(sft_train.py文件) # 1、解析命令行参数,包括模型参数、数据参数和训练参数。...

【UE5 Cesium】17-Cesium for Unreal 建立飞行跟踪器(2)

目录 效果 步骤 一、飞机沿航线飞行 二、通过切换相机实现在不同角度观察飞机飞行 效果 步骤 一、飞机沿航线飞行 先去模型网站下载一个波音737飞机模型 然后将下载好的模型导入到UE项目中&#xff0c;导入时需要勾选“合并网格体”&#xff08;导入前最好在建模软件中将…...

【ElasticSearch】基于 Java 客户端 RestClient 实现对 ElasticSearch 索引库、文档的增删改查操作,以及文档的批量导入

文章目录 前言一、对 Java RestClient 的认识1.1 什么是 RestClient1.2 RestClient 核心类&#xff1a;RestHighLevelClient 二、使用 Java RestClient 操作索引库2.1 根据数据库表编写创建 ES 索引的 DSL 语句2.2 初始化 Java RestClient2.2.1 在 Spring Boot 项目中引入 Rest…...

【Node.js】stream 流模块

流是一种抽象的数据结构。从键盘输入到应用程序就是标准输入流&#xff08;stdin&#xff09;。应用程序把字符一个一个输出到显示器上叫做&#xff1a;标准输出流&#xff08;stdout&#xff09;。 流的特点是数据是有序的&#xff0c;而且必须依次读取&#xff0c;或者依次写…...

【LeetCode】——链式二叉树经典OJ题详解

主页点击直达&#xff1a;个人主页 我的小仓库&#xff1a;代码仓库 C语言偷着笑&#xff1a;C语言专栏 数据结构挨打小记&#xff1a;初阶数据结构专栏 Linux被操作记&#xff1a;Linux专栏 LeetCode刷题掉发记&#xff1a;LeetCode刷题 算法头疼记&#xff1a;算法专栏…...

代码注释对于程序员重要吗?

程序员对代码注释可以说是又爱又恨又双标……你是怎么看待程序员不写注释这一事件的呢&#xff1f; 代码注释的重要性 代码注释是指在程序代码中添加的解释性说明&#xff0c;用于描述代码的功能、目的、使用方法等。代码注释对于程序的重要性主要体现在以下几个方面&#x…...

OpenHamony开发笔记一:在HarmonyOS虚拟机上运行openharmony工程

在HarmonyOS的虚拟机上要运行openharmony的工程时需要修改的地方有 1.修改build-profile.json5&#xff0c;将runtimeOS改为HarmonyOS "targets": [{"name": "default","runtimeOS": "HarmonyOS"}, 2.修改工程引用的SDK&a…...

C++程序员入门需要怎么学?(InsCode AI 创作助手)

文章目录 &#xff08;一&#xff09;学习C概念&#xff08;二&#xff09;C主要应用场景和相关产品&#xff08;三&#xff09;学习C流程1. 学习C语法和基本示例&#xff1a;2. 深入学习面向对象编程&#xff08;OOP&#xff09;&#xff1a;3. 使用C标准库&#xff1a;4. 解决…...

Intel 高性能库之IPP信号处理简介及下载(版本5.1,含32位和64位及注册)

IPP是什么 IPP:Intel Integrated Performance Primitives 英特尔集成性能基元(英特尔IPP)是一款多核就绪的扩展函数库,其中包含众多针对多媒体、数据处理和通信应用高度优化的软件函数。它包括: 视频编码:用于 DV25/50/100、MPEG-2、MPEG-4、H.263 和 MPEG-4 Part 10 …...

【C++】运算符重载案例 - 字符串类 ② ( 重载 等号 = 运算符 | 重载 数组下标 [] 操作符 | 完整代码示例 )

文章目录 一、重载 等号 运算符1、等号 运算符 与 拷贝构造函数2、重载 等号 运算符 - 右操作数为 String 对象3、不同的右操作数对应的 重载运算符函数 二、重载 下标 [] 运算符三、完整代码示例1、String.h 类头文件2、String.cpp 类实现3、Test.cpp 测试类4、执行结果 一…...

Vue脚手架开发流程

一、项目运行时会先执行 public / index.html 文件 <!DOCTYPE html> <html lang""><head><meta charset"utf-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport&quo…...

从零开始学习线性回归:理论、实践与PyTorch实现

文章目录 &#x1f966;介绍&#x1f966;基本知识&#x1f966;代码实现&#x1f966;完整代码&#x1f966;总结 &#x1f966;介绍 线性回归是统计学和机器学习中最简单而强大的算法之一&#xff0c;用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性…...

[LeetCode]链式二叉树相关题目(c语言实现)

文章目录 LeetCode965. 单值二叉树LeetCode100. 相同的树LeetCode101. 对称二叉树LeetCode144. 二叉树的前序遍历LeetCode94. 二叉树的中序遍历LeetCode145. 二叉树的后序遍历LeetCode572. 另一棵树的子树 LeetCode965. 单值二叉树 题目 Oj链接 思路 一棵树的所有值都是一个…...

集成学习

集成学习&#xff08;Ensemble Learning) - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/27689464集成学习就是组合这里的多个弱监督模型以期得到一个更好更全面的强监督模型&#xff0c;集成学习潜在的思想是即便某一个弱分类器得到了错误的预测&#xff0c;其他的弱分类器…...

Cursor实现用excel数据填充word模版的方法

cursor主页&#xff1a;https://www.cursor.com/ 任务目标&#xff1a;把excel格式的数据里的单元格&#xff0c;按照某一个固定模版填充到word中 文章目录 注意事项逐步生成程序1. 确定格式2. 调试程序 注意事项 直接给一个excel文件和最终呈现的word文件的示例&#xff0c;…...

rknn优化教程(二)

文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK&#xff0c;开始写第二篇的内容了。这篇博客主要能写一下&#xff1a; 如何给一些三方库按照xmake方式进行封装&#xff0c;供调用如何按…...

黑马Mybatis

Mybatis 表现层&#xff1a;页面展示 业务层&#xff1a;逻辑处理 持久层&#xff1a;持久数据化保存 在这里插入图片描述 Mybatis快速入门 ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/6501c2109c4442118ceb6014725e48e4.png //logback.xml <?xml ver…...

R语言AI模型部署方案:精准离线运行详解

R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...

java 实现excel文件转pdf | 无水印 | 无限制

文章目录 目录 文章目录 前言 1.项目远程仓库配置 2.pom文件引入相关依赖 3.代码破解 二、Excel转PDF 1.代码实现 2.Aspose.License.xml 授权文件 总结 前言 java处理excel转pdf一直没找到什么好用的免费jar包工具,自己手写的难度,恐怕高级程序员花费一年的事件,也…...

376. Wiggle Subsequence

376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…...

Nuxt.js 中的路由配置详解

Nuxt.js 通过其内置的路由系统简化了应用的路由配置&#xff0c;使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…...

Android15默认授权浮窗权限

我们经常有那种需求&#xff0c;客户需要定制的apk集成在ROM中&#xff0c;并且默认授予其【显示在其他应用的上层】权限&#xff0c;也就是我们常说的浮窗权限&#xff0c;那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”&#xff0c;物流的终极形态正在诞生 想象这样的场景&#xff1a; 凌晨3点&#xff0c;某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径&#xff1b;AI视觉系统在0.1秒内扫描包裹信息&#xff1b;数字孪生平台正模拟次日峰值流量压力…...