当前位置: 首页 > article >正文

机器学习实战(3):线性回归——预测连续变量

第3集:线性回归——预测连续变量

在机器学习的世界中,线性回归是最基础、最直观的算法之一。它用于解决回归问题,即预测连续变量(如房价、销售额等)。尽管简单,但线性回归却是许多复杂模型的基石。今天我们将深入探讨线性回归的基本原理,并通过实践部分使用 Boston 房价数据集 构建一个线性回归模型。


在这里插入图片描述

线性回归的基本原理

什么是线性回归?

线性回归是一种监督学习算法,其目标是找到一条直线(或超平面),使得这条直线能够最好地拟合数据点。对于单变量线性回归,公式如下:
y = w 0 + w 1 x y = w_0 + w_1x y=w0+w1x
其中:
y y y 是目标变量(预测值)。
x x x 是输入特征。
w 0 w_0 w0 是截距(偏置项)。
w 1 w_1 w1 是权重(斜率)。

图1:线性回归示意图
(图片描述:波士顿房价预测)
在这里插入图片描述

在线性回归中,我们的任务是找到最佳的 w 0 w_0 w0 w 1 w_1 w1,使得预测值与真实值之间的误差最小化。


损失函数与梯度下降法

损失函数

为了衡量模型的好坏,我们定义了一个损失函数(Loss Function)。最常用的损失函数是 均方误差(Mean Squared Error, MSE)
M S E = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE=n1i=1n(yiy^i)2
其中:
n n n 是样本数量。
y i y_i yi 是第 i 个样本的真实值。
y ^ i \hat{y}_i y^i 是第 i 个样本的预测值。

MSE 的目标是让所有样本的预测误差平方和最小。

梯度下降法

梯度下降是一种优化算法,用于最小化损失函数。其核心思想是沿着损失函数的负梯度方向更新参数 w w w,直到达到最优解。更新公式为:
w : = w − α ⋅ ∂ J ( w ) ∂ w w := w - \alpha \cdot \frac{\partial J(w)}{\partial w} w:=wαwJ(w)
其中:
α \alpha α 是学习率(控制步长)。
J ( w ) J(w) J(w) 是损失函数。

图2:梯度下降过程
(图片描述:三维曲面表示损失函数,小球从高处滚向最低点,代表参数逐步优化的过程。)
在这里插入图片描述


多元线性回归模型

当输入特征不止一个时,我们使用 多元线性回归。公式扩展为:
y = w 0 + w 1 x 1 + w 2 x 2 + . . . + w p x p y = w_0 + w_1x_1 + w_2x_2 + ... + w_px_p y=w0+w1x1+w2x2+...+wpxp
这可以写成矩阵形式:
y = X ⋅ w \mathbf{y} = \mathbf{X} \cdot \mathbf{w} y=Xw
其中:
y \mathbf{y} y 是目标变量向量。
X \mathbf{X} X 是特征矩阵。
w \mathbf{w} w 是权重向量。


如何评估回归模型性能

评估回归模型的性能通常使用以下指标:

1. 均方误差(MSE)

MSE 衡量预测值与真实值之间的平均误差平方。越小越好。

2. 决定系数(R²)

R² 表示模型对数据变异性的解释能力,取值范围为 [0, 1]。越接近 1,说明模型拟合效果越好。

示例代码(Python实现):
from sklearn.metrics import mean_squared_error, r2_score# 计算 MSE 和 R²
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)print(f"MSE: {mse:.2f}, R²: {r2:.2f}")

实践部分:使用 Boston 房价数据集构建线性回归模型

数据集简介

Boston 房价数据集包含波士顿地区房屋价格及其相关特征,共有 506 条记录和 13 个特征。目标是预测房屋的中位数价格(单位:千美元)。

完整代码

import numpy as np
import pandas as pd
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score# 加载数据
boston = load_boston()
data = pd.DataFrame(boston.data, columns=boston.feature_names)
data['PRICE'] = boston.target# 分割数据集
X = data.drop('PRICE', axis=1)
y = data['PRICE']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 构建线性回归模型
model = LinearRegression()
model.fit(X_train, y_train)# 预测
y_pred = model.predict(X_test)# 评估模型性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)print(f"模型评估结果:")
print(f"MSE: {mse:.2f}, R²: {r2:.2f}")

运行结果:

模型评估结果:
MSE: 24.29, R²: 0.67

可视化展示

在波士顿房价预测任务中,通过可视化展示线性回归的预测直线和散点图可以帮助我们直观地理解模型的拟合效果。以下增加可视化模块的完整实现。


完整代码(包含可视化模块)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score# 加载数据
boston = load_boston()
data = pd.DataFrame(boston.data, columns=boston.feature_names)
data['PRICE'] = boston.target# 分割数据集
X = data[['RM']]  # 使用房间数(RM)作为单一特征进行可视化
y = data['PRICE']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 构建线性回归模型
model = LinearRegression()
model.fit(X_train, y_train)# 预测
y_pred = model.predict(X_test)# 评估模型性能
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)print(f"模型评估结果:")
print(f"MSE: {mse:.2f}, R²: {r2:.2f}")# 可视化模块:绘制预测直线与散点图
plt.figure(figsize=(10, 6))# 将 X_test 和 y_pred 转换为 NumPy 数组,并确保是一维的
X_test_array = X_test.values.flatten()  # 转换为一维数组
y_pred_array = y_pred.flatten()         # 转换为一维数组# 绘制真实值的散点图
plt.scatter(X_test_array, y_test, color='blue', label='True Values', alpha=0.6)# 绘制预测直线
plt.plot(X_test_array, y_pred_array, color='red', linewidth=2, label='Predicted Line')# 添加标题和标签
plt.title('Linear Regression: Predicted vs True Values', fontsize=16)
plt.xlabel('Average Number of Rooms (RM)', fontsize=12)
plt.ylabel('House Price (in $1000s)', fontsize=12)
plt.legend()# 显示图表
plt.show()

代码解析
1. 选择单一特征进行可视化

为了简化可视化过程,我们选择了 RM(每栋住宅的平均房间数)作为唯一特征。这样可以将问题从多元线性回归降维到一元线性回归,便于绘制二维散点图和预测直线。

2. 绘制散点图
  • 使用 plt.scatter 绘制测试集中真实房价与房间数的关系。
  • 设置颜色为蓝色,透明度为 0.6,以便更好地观察数据分布。
3. 绘制预测直线
  • 使用 plt.plot 绘制线性回归模型的预测直线。
  • 预测值由 model.predict(X_test) 计算得出。
  • 设置颜色为红色,线宽为 2,突出显示预测直线。
4. 添加标题、标签和图例
  • 图表标题说明了可视化内容。
  • 添加 x 轴和 y 轴标签,分别表示房间数和房价。
  • 使用 plt.legend() 添加图例,区分真实值和预测值。

可视化结果

图1:线性回归预测直线与散点图
(图片描述:二维平面上展示了测试集的真实房价(蓝色散点)和线性回归模型的预测直线(红色)。大部分散点分布在直线附近,表明模型具有一定的拟合能力。)
在这里插入图片描述


通过增加可视化模块,我们可以直观地看到线性回归模型如何拟合数据。这种可视化方法不仅有助于理解模型的表现,还能帮助发现潜在的问题(如欠拟合或过拟合)。


总结

本文介绍了线性回归的核心概念,包括基本原理、损失函数、梯度下降法以及模型评估方法。通过实践部分,我们成功使用 Boston 房价数据集构建了一个线性回归模型,并对其性能通过数据分析和可视化进行了评估。是一篇非常具有实战价值的文章。

尽管线性回归简单易懂,但它仍然是许多实际问题的首选工具。希望这篇文章能帮助你更好地理解这一经典算法!


下集预告:第4集:逻辑回归——分类问题的基础

参考资料

  • Scikit-learn 文档: https://scikit-learn.org/stable/documentation.html
  • Boston 房价数据集: https://www.cs.toronto.edu/~delve/data/boston/bostonDetail.html

相关文章:

机器学习实战(3):线性回归——预测连续变量

第3集:线性回归——预测连续变量 在机器学习的世界中,线性回归是最基础、最直观的算法之一。它用于解决回归问题,即预测连续变量(如房价、销售额等)。尽管简单,但线性回归却是许多复杂模型的基石。今天我们…...

【AI-34】机器学习常用七大算法

以下是对这七大常用算法的浅显易懂解释: 1. k 邻近算法(k - Nearest Neighbors,KNN) 想象你在一个满是水果的大广场上,现在有个不认识的水果,想知道它是什么。k 邻近算法就是去看离这个水果最近的 k 个已…...

【漫话机器学习系列】093.代价函数和损失函数(Cost and Loss Functions)

代价函数和损失函数(Cost and Loss Functions)详解 1. 引言 在机器学习和深度学习领域,代价函数(Cost Function)和损失函数(Loss Function)是核心概念,它们决定了模型的优化方向。…...

ThreadLocal为什么会内存溢出

每个线程(Thread 对象)内部维护一个 ThreadLocalMap,用于存储该线程的所有 ThreadLocal 变量的键值对: ThreadLocalMap虽然是ThreadLocal的静态内部类,但是Thread 对象的属性,当线程存活时ThreadLocalMap不会被回收。 Key:ThreadLocal 实例的 弱引用(WeakReference)。…...

LabVIEW 天然气水合物电声联合探测

天然气水合物被认为是潜在的清洁能源,其储量丰富,预计将在未来能源格局中扮演重要角色。由于其独特的物理化学特性,天然气水合物的探测面临诸多挑战,涉及温度、压力、电学信号、声学信号等多个参数。传统的人工操作方式不仅效率低…...

【记忆化搜索】最长递增子序列

文章目录 300. 最长递增子序列解题思路:递归 -> 记忆化搜索 300. 最长递增子序列 300. 最长递增子序列 ​ 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 ​ 子序列 是由数组派生而来的序列,删除(或不删除&am…...

Tomcat的升级

一、为什么Tomcat需要升级 在生产环境中,我们都会指定对应的Tomcat版本进行安排配置,但是由于Tomcat厂商对于小版本的更新迭代会将一些Bug修复,这个时候在生产中出现问题/预防出现问题,可以通过小版本的升级解决前提:…...

4-制作UI

创建模块文件夹 Unity编辑器->Tools->YIUI自动化工具,在新增模块名称那里输入模块名字并点击创建。便可看到在GameRes/YIUI文件夹下有新建的文件夹与内容了。里面包含图集、预制体、Sprites。如果进行预制体的修改,则需要双击进入再修改&#xff0…...

零基础学习人工智能

零基础学习人工智能是一个既充满挑战又极具潜力的过程。以下是一份详细的学习指南,旨在帮助零基础的学习者有效地踏入人工智能领域。 一、理解基本概念 在学习人工智能之前,首先要对其基本概念有一个清晰的认识。人工智能(AI)是…...

vue3+element-plus中的el-table表头和el-table-column内容全部一行显示完整(hook函数)

hook函数封装 export const useTableColumnWidth _this > {const { refTable } _thisconst columnWidthObj ref()const getTableColumnWidth cb > {nextTick(() > {columnWidthObj.value {}// 获取行rowsconst tableEle refTable?.refBaseTable?.$elif (!tab…...

Word写论文常用操作的参考文章

1.插入多个引用文献:word中交叉引用多篇参考文献格式[1-2]操作以及显示错误问题 更改左域名,输入 \#"[0" 更改右域名,输入 \#"0]" 2.插入题注:word 中添加图片题注、目录、内部链接 3.插入公式编号&#x…...

深度学习在蛋白质-蛋白质相互作用(PPI)领域的研究进展(2022-2025)

一、蛋白质-蛋白质相互作用(PPI)的定义与生物学意义 蛋白质-蛋白质相互作用(Protein-Protein Interaction, PPI)是指两个或多个蛋白质通过物理结合形成复合物,进而调控细胞信号传导、代谢、免疫应答等生命活动的过程。…...

C++基础知识(三)之结构体、共同体、枚举、引用、函数重载

九、结构体、共同体和枚举 1、结构体的基本概念 结构体是用户自定义的类型,可以将多种数据的表示合并到一起,描述一个完整的对象。 使用结构体有两个步骤:1)定义结构体描述(类型);2&#xff…...

【java】方法的值传递

在 Java 中,方法的值传递 是指将实参的值传递给方法的形参。Java 中只有 值传递,没有引用传递。具体来说: 对于 基本数据类型,传递的是值的副本。 对于 引用数据类型,传递的是引用的副本(即地址的副本&…...

DeepSeek 助力 Vue 开发:打造丝滑的开关切换(Switch)

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 Deep…...

使用Python爬虫实时监控行业新闻案例

目录 背景环境准备请求网页数据解析网页数据定时任务综合代码使用代理IP提升稳定性运行截图与完整代码总结 在互联网时代,新闻的实时性和时效性变得尤为重要。很多行业、技术、商业等领域的新闻都可以为公司或者个人发展提供有价值的信息。如果你有一项需求是要实时…...

Centos搭建python环境

在 CentOS 上配置 Python 环境可以通过以下步骤完成: 1. 检查系统自带 Python 版本 CentOS 7/8 可能已经自带了 Python: python3 --version 如果没有,或者版本过低,可以手动安装。 2. 安装 Python(推荐&#xff0…...

语言大模型基础概念 一(先了解听说过的名词都是什么)

SFT(监督微调)和RLHF(基于人类反馈的强化学习)的区别 STF(Supervised Fine-Tuning)和RLHF(Reinforcement Learning from Human Feedback)是两种不同的模型训练方法,分别…...

DeepSeek-R1 蒸馏 Qwen 和 Llama 架构 企业级RAG知识库

“DeepSeek-R1的输出,蒸馏了6个小模型”意思是利用DeepSeek-R1这个大模型的输出结果,通过知识蒸馏技术训练出6个参数规模较小的模型,以下是具体解释: - **知识蒸馏技术原理**:知识蒸馏是一种模型压缩技术,核…...

ubuntu服务器 如何配置安全加固措施

下面提供一个更详细、一步步的服务器安全加固指南,适合新手操作。我们将从 Fail2Ban、SSH(密钥认证及端口更改)、Nginx 速率限制和日志轮转四个方面进行优化,同时补充一些额外的安全建议。 新的服务器,通常我们会创建一…...

DeepSeek v3 技术报告阅读笔记

注 本文参考 DeepSeek-v3 / v2 / v1 Technical Report 及相关参考模型论文本文不包括基础的知识点讲解,为笔记/大纲性质而非教程,建议阅读技术报告原文交流可发送至邮箱 henryhua0721foxmail.com 架构核心 核心: MLA 高效推理DeepSeekMOE 更…...

Spring 事务及管理方式

Spring 事务管理是 Spring 框架的核心功能之一,它为开发者提供了一种方便、灵活且强大的方式来管理数据库事务。 1、事务的基本概念 事务是一组不可分割的操作序列,这些操作要么全部成功执行,要么全部失败回滚,以确保数据的一致…...

GESP2024年9月认证C++七级( 第三部分编程题(1)小杨寻宝)

参考程序&#xff1a; #include <bits/stdc.h> using namespace std; const int N 1e510; vector<int> g[N]; // 图的邻接表 int col[N], dep[N], has[N];// 深度优先遍历&#xff0c;计算每个节点的深度 void dfs(int x, int fa) {dep[x] dep[fa] 1; // 计算…...

Pandas数据填充(fill)中的那些坑:避免机器学习中的数据泄露

1. 问题背景 在处理时间序列数据时,经常会遇到缺失值需要填充。Pandas提供了ffill(forward fill)和bfill(backward fill)两种填充方式,但使用不当可能会导致数据泄露,特别是在进行机器学习预测时。 2. 填充方式解析 2.1 基本概念 ffill(forward fill): 用前面的值填充后面的…...

ubuntu 安装vnc之后,本地黑屏,vnc正常

ubuntu 安装vnc之后,本地黑屏,vnc正常 在Ubuntu系统中安装VNC服务器&#xff08;如TightVNC或RealVNC&#xff09;后&#xff0c;如果遇到连接时本地屏幕变黑的情况&#xff0c;可能是由于几种不同的配置或兼容性问题。以下是一些解决步骤&#xff0c;可以帮助你解决这个问题&…...

解锁电商数据宝藏:淘宝商品详情API实战指南

在电商蓬勃发展的今天&#xff0c;数据已成为驱动业务增长的核心引擎。对于商家、开发者以及数据分析师而言&#xff0c;获取精准、实时的商品数据至关重要。而淘宝&#xff0c;作为国内最大的电商平台&#xff0c;其海量商品数据更是蕴含着巨大的价值。 本文将带你深入探索淘…...

webshell通信流量分析

环境安装 Apatche2 php sudo apt install apache2 -y sudo apt install php libapache2-mod-php php-mysql -y echo "<?php phpinfo(); ?>" | sudo tee /var/www/html/info.php sudo ufw allow Apache Full 如果成功访问info.php&#xff0c;则环境安…...

在 rtthread中,rt_list_entry (rt_container_of) 已知结构体成员的地址,反推出结构体的首地址

rt_list_entry (rt_container_of)宏定义&#xff1a; /*** rt_container_of - return the start address of struct type, while ptr is the* member of struct type.*/ #define rt_container_of(ptr, type, member) \((type *)((char *)(ptr) - (unsigned long)(&((type *…...

趣味魔法项目 LinuxPDF —— 在 PDF 中启动一个 Linux 操作系统

最近&#xff0c;一位开源爱好者开发了一个LinuxPDF 项目&#xff08;ading2210/linuxpdf: Linux running inside a PDF file via a RISC-V emulator&#xff09;&#xff0c;它的核心功能是在一个 PDF 文件中启动并运行 Linux 操作系统。它通过巧妙地使用 PDF 文件格式中的 Ja…...

DeepSeek教unity------MessagePack-03

数据契约兼容性 你可以使用 [DataContract] 注解代替 [MessagePackObject]。如果类型用 DataContract 进行注解&#xff0c;可以使用 [DataMember] 注解代替 [Key]&#xff0c;并使用 [IgnoreDataMember] 代替 [IgnoreMember]。 然后&#xff0c;[DataMember(Order int)] 的…...