通俗易懂之线性回归时序预测PyTorch实践
线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入理解这一重要概念。
线性回归概述
线性回归是一种用于预测因变量(目标变量)与一个或多个自变量(特征变量)之间关系的统计方法。其目标是在数据点之间找到一条最佳拟合直线,使得预测值与实际值之间的误差最小。
基本形式:
- 简单线性回归:只有一个自变量。
- 多元线性回归:包含多个自变量。
本文将聚焦于简单线性回归,即仅考虑一个自变量的情况。
线性回归的数学原理
模型表达式
简单线性回归的模型表达式为:
y = w x + b y = wx + b y=wx+b
其中:
- y y y 是预测值。
- x x x 是输入特征。
- w w w 是权重(斜率)。
- b b b 是偏置(截距)。
损失函数
为了衡量模型预测值与实际值之间的差异,通常使用均方误差(Mean Squared Error, MSE)作为损失函数:
Loss = 1 2 ∑ i = 1 N ( y i pred − y i ) 2 \text{Loss} = \frac{1}{2} \sum_{i=1}^{N} (y_i^{\text{pred}} - y_i)^2 Loss=21i=1∑N(yipred−yi)2
优化算法
线性回归常用的优化算法是梯度下降(Gradient Descent)。通过计算损失函数关于参数 w w w 和 b b b 的梯度,迭代更新参数以最小化损失。
更新规则如下:
w : = w − η ∂ Loss ∂ w w := w - \eta \frac{\partial \text{Loss}}{\partial w} w:=w−η∂w∂Loss
b : = b − η ∂ Loss ∂ b b := b - \eta \frac{\partial \text{Loss}}{\partial b} b:=b−η∂b∂Loss
其中 η \eta η 是学习率。
应用场景
线性回归在多个领域有广泛应用,包括但不限于:
- 经济学:预测经济指标,如GDP、通货膨胀率等。
- 工程学:估计物理量之间的关系,如材料强度与应力。
- 医疗:预测疾病发展趋势,如体重增长与健康指标。
- 金融:股价预测、风险评估等。
PyTorch实现线性回归
接下来,我们将通过一段PyTorch代码实践线性回归,从数据生成、模型训练到可视化展示,全面演示线性回归的实现过程。代码参考《深度学习框架PyTorch入门与实践》一书的实现,为了感受线性回归的计算过程,代码并未直接调用python中已有的线性回归库。
代码解析
首先,我们导入必要的库并设置随机种子以确保结果可复现。
import torch as t
import matplotlib.pyplot as plt
from IPython import displayt.manual_seed(1000)
数据生成函数
定义一个函数 get_fake_data
来生成假数据,这些数据遵循线性关系 y = 2 x + 3 y = 2x + 3 y=2x+3 并添加了一定的噪声。
def get_fake_data(batch_size=8):x = t.randn(batch_size, 1, dtype=float) * 20 # 随机生成x,范围扩大到[-20, 20]y = x * 2 + (1 + t.randn(batch_size, 1, dtype=float)) * 3 # y = 2x + 3 + 噪声return x, y
调用该函数生成一批数据并进行可视化。
x, y = get_fake_data()plt.figure()
plt.scatter(x, y)
plt.show()
参数初始化
随机初始化权重 w w w 和偏置 b b b,并设置学习率 l r lr lr。
# 随机初始化参数
w = t.rand(1, 1, requires_grad=True, dtype=float)
b = t.zeros(1, 1, requires_grad=True, dtype=float)lr = 0.00001
训练过程
通过1000次迭代,使用梯度下降法优化参数 w w w 和 b b b。
for i in range(1000):x, y = get_fake_data()y_pred = x.mm(w) + b.expand_as(y) # 预测值loss = 0.5 * (y_pred - y) ** 2 # 均方误差loss = loss.sum()loss.backward() # 反向传播计算梯度# 更新参数w.data.sub_(lr * w.grad.data)b.data.sub_(lr * b.grad.data)# 梯度清零w.grad.data.zero_()b.grad.data.zero_()# 每100次迭代可视化一次结果if i % 100 == 0:display.clear_output(wait=True)x_plot = t.arange(0, 20, dtype=float).view(-1, 1)y_plot = x_plot.mm(w) + b.expand_as(x_plot)plt.plot(x_plot.data, y_plot.data, label='Fitting Line')x2, y2 = get_fake_data(batch_size=20)plt.scatter(x2, y2, color='red', label='Data Points')plt.xlim(0, 20)plt.ylim(0, 41)plt.legend()plt.show()plt.pause(0.5)
可视化与训练过程
训练过程中,每隔100次迭代,会清除之前的输出,绘制当前拟合的直线与新生成的数据点。随着训练的进行,拟合线将逐渐接近真实的线性关系 y = 2 x + 3 y = 2x + 3 y=2x+3。
以下是训练过程中的可视化效果示例:
注:实际运行代码时,图像会动态更新,展示拟合过程。
代码关键点解析
-
数据生成:
- 使用
torch.randn
生成标准正态分布的随机数,并通过线性变换获取x
和y
。 - 添加噪声使模型更贴近真实场景。
- 使用
-
参数初始化:
w
随机初始化,b
初始化为零。requires_grad=True
表示在反向传播时需要计算梯度。
-
前向传播:
- 计算预测值
y_pred = x.mm(w) + b.expand_as(y)
。 - 使用矩阵乘法
mm
实现线性变换。
- 计算预测值
-
损失计算:
- 采用均方误差损失函数。
loss.backward()
计算损失函数相对于参数的梯度。
-
参数更新:
- 使用学习率
lr
按梯度方向更新参数。 data.sub_
进行原地更新,避免梯度计算图的干扰。
- 使用学习率
-
梯度清零:
- 每次参数更新后,需要清零梯度
w.grad.data.zero_()
和b.grad.data.zero_()
,以防止梯度累积。
- 每次参数更新后,需要清零梯度
-
可视化:
- 使用
matplotlib
绘制拟合线和数据点。 display.clear_output(wait=True)
清除之前的图像,避免图形堆积。plt.pause(0.5)
控制图像更新速度。
- 使用
总结
本文从线性回归的基本概念出发,详细介绍了其数学原理和应用场景,并通过一段PyTorch代码演示了线性回归模型的实现过程。从数据生成、参数初始化、模型训练到结果可视化,全面展示了线性回归的实际应用。通过这种实例讲解,读者不仅能够理解线性回归的理论基础,还能掌握其在深度学习框架中的具体实现方法。
线性回归作为机器学习的基础模型,虽然简单,但其思想却深刻影响着更加复杂的算法和模型。在实际应用中,理解并掌握线性回归对于进一步学习和开发更加复杂的机器学习模型具有重要意义。
如果这篇文章对你有一点点的帮助,欢迎点赞、关注、收藏、转发、评论哦!
我也会在微信公众号“智识小站”坚持分享更多内容,以期记录成长、普及技术、造福后来者!
相关文章:

通俗易懂之线性回归时序预测PyTorch实践
线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入…...

[离线数仓] 总结二、Hive数仓分层开发
接 [离线数仓] 总结一、数据采集 5.8 数仓开发之ODS层 ODS层的设计要点如下: (1)ODS层的表结构设计依托于从业务系统同步过来的数据结构。 (2)ODS层要保存全部历史数据,故其压缩格式应选择压缩比率,较高的,此处选择gzip。 CompressedStorage - Apache Hive - Apac…...

页面顶部导航栏(Navbar)的功能(Navbar/index.vue)
这段代码是一个 Vue.js 组件,实现了页面顶部导航栏(Navbar)的功能。我将分块分析它的各个部分: 模板 (Template): <!-- spid-admin/src/layout/components/Navbar/index.vue --> <template><div class"navb…...
thinnkphp5.1和 thinkphp6以及nginx,apache 解决跨域问题
ThinkPHP 5.1 使用中间件设置响应头 ThinkPHP 5.1 及以上版本支持中间件,可以通过中间件统一设置跨域响应头。 步骤: 创建一个中间件文件,例如 CorsMiddleware.php: namespace app\middleware;class CorsMiddleware {public fu…...
vue2新增删除
(只是页面实现,不涉及数据库) list组件: <button click"onAdd">新增</button><el-table:header-cell-style"{ textAlign: center }" :cell-style"{ textAlign: center }":data&quo…...

测试ip端口-telnet开启与使用
前言 开发过程中我们总会要去测试ip通不通,或者ip下某个端口是否可以联通,为此我们可以使用telnet 命令来实现。 一、telnet 开启 可能有些人使用telnet报错,不是内部命令,可以如下开启: 1、打开控制面板ÿ…...
Python爬虫基础——XPath表达式
首先说一下这节内容在学习过程中存在的问题吧,在爬取百度网页文字时,出现了问题,就是通过表达式在网页搜索中可以定位,但是通过代码无法定位,请教了一位老师,他说是动态链接,目前这部分内容比较…...

ansible-性能优化
一. 简述: 搞过运维自动化工具的人,肯定会发现很多运维伙伴们经常用saltstack和ansible做比较,单从执行效率上来说,ansible确实比不上saltstack(ansible使用的是ssh,salt使用的是zeromq消息队列[暂没深入了解]),但其实…...
高等数学学习笔记 ☞ 一元函数微分的基础知识
1. 微分的定义 (1)定义:设函数在点的某领域内有定义,取附近的点,对应的函数值分别为和, 令,若可以表示成,则称函数在点是可微的。 【 若函数在点是可微的,则可以表达为】…...

前后端实现防抖节流实现
在前端和 Java 后端中实现防抖(Debounce)和节流(Throttle)主要用于减少频繁请求或事件触发对系统的压力。前端和后端的实现方式有些不同,以下是两种方法的具体实现: 1. 前端实现防抖和节流 在前端中&…...
【笔记】算法记录
1、求一个数的素因子(试除法) // 获取一个数的所有素因子 set<int> getPrimeFactors(int num) {set<int> primeFactors;for (int i 2; i * i < num; i) {while (num % i 0) {primeFactors.insert(i);num / i;}}if (num > 1) {prime…...
【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析
文章目录 一、选择题二、理论题三、实操题 【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析 一、选择题 生成树协议的主要作用是 B. 防止网络环路解释:生成树协议(STP)的主要目的是防止网络中…...
git push -f 指定分支
要将本地代码推送到指定的远程分支,你可以使用以下步骤和命令: 确认远程仓库: 确保你的本地仓库已经与远程仓库关联。你可以使用以下命令查看当前的远程仓库状态: git remote -v查看本地分支: 使用命令查看当前存在的本…...
CTF知识点总结(二)
异或注入:两个条件相同(同真或同假)即为假。 http://120.24.86.145:9004/1ndex.php?id1^(length(union)!0)-- 如上,如果union被过滤,则 length(union)!0 为假,那么返回页面正常。 2|0updatexml() 函数报…...
解决Edge打开PDF总是没有焦点
【问题描述】 使用Edge浏览器作为默认PDF阅读器打开本地PDF文件,Edge窗口总是不获得焦点,而是在任务栏以橙色显示,需要再手动点击一次才能查看文件内容。 本强迫症来治一治这个问题! 【解决方法】 GPT老师指出问题出在Edge的启动…...

69.基于SpringBoot + Vue实现的前后端分离-家乡特色推荐系统(项目 + 论文PPT)
项目介绍 在Internet高速发展的今天,我们生活的各个领域都涉及到计算机的应用,其中包括家乡特色推荐的网络应用,在外国家乡特色推荐系统已经是很普遍的方式,不过国内的管理网站可能还处于起步阶段。家乡特色推荐系统采用java技术&…...

计算机视觉目标检测-DETR网络
目录 摘要abstractDETR目标检测网络详解二分图匹配和损失函数 DETR总结总结 摘要 DETR(DEtection TRansformer)是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题,摒弃了锚框设计和非…...

《自动驾驶与机器人中的SLAM技术》ch1:自动驾驶
目录 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 L2 在技术实现上会更倾向于实时感知,乃至可以使用感知结果直接构建鸟瞰图(bird eye view, BEV),而 L4 则依赖离线地图。 高精地…...

【UE5 C++课程系列笔记】23——多线程基础——AsyncTask
目录 概念 函数说明 注意事项 (1)线程安全问题 (2)依赖特定线程执行的任务限制 (3)任务执行顺序和时间不确定性 使用示例 概念 AsyncTask 允许开发者将一个函数或者一段代码逻辑提交到特定的线程去执…...

基于Python的音乐播放器 毕业设计-附源码73733
摘 要 本项目基于Python开发了一款简单而功能强大的音乐播放器。通过该音乐播放器,用户可以轻松管理自己的音乐库,播放喜爱的音乐,并享受音乐带来的愉悦体验。 首先,我们使用Python语言结合相关库开发了这款音乐播放器。利用Tkin…...

19c补丁后oracle属主变化,导致不能识别磁盘组
补丁后服务器重启,数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后,存在与用户组权限相关的问题。具体表现为,Oracle 实例的运行用户(oracle)和集…...

linux之kylin系统nginx的安装
一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源(HTML/CSS/图片等),响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址,提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

【Oracle APEX开发小技巧12】
有如下需求: 有一个问题反馈页面,要实现在apex页面展示能直观看到反馈时间超过7天未处理的数据,方便管理员及时处理反馈。 我的方法:直接将逻辑写在SQL中,这样可以直接在页面展示 完整代码: SELECTSF.FE…...

【项目实战】通过多模态+LangGraph实现PPT生成助手
PPT自动生成系统 基于LangGraph的PPT自动生成系统,可以将Markdown文档自动转换为PPT演示文稿。 功能特点 Markdown解析:自动解析Markdown文档结构PPT模板分析:分析PPT模板的布局和风格智能布局决策:匹配内容与合适的PPT布局自动…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...

微信小程序云开发平台MySQL的连接方式
注:微信小程序云开发平台指的是腾讯云开发 先给结论:微信小程序云开发平台的MySQL,无法通过获取数据库连接信息的方式进行连接,连接只能通过云开发的SDK连接,具体要参考官方文档: 为什么? 因为…...

NLP学习路线图(二十三):长短期记忆网络(LSTM)
在自然语言处理(NLP)领域,我们时刻面临着处理序列数据的核心挑战。无论是理解句子的结构、分析文本的情感,还是实现语言的翻译,都需要模型能够捕捉词语之间依时序产生的复杂依赖关系。传统的神经网络结构在处理这种序列依赖时显得力不从心,而循环神经网络(RNN) 曾被视为…...
大数据学习(132)-HIve数据分析
🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言Ǵ…...

SiFli 52把Imagie图片,Font字体资源放在指定位置,编译成指定img.bin和font.bin的问题
分区配置 (ptab.json) img 属性介绍: img 属性指定分区存放的 image 名称,指定的 image 名称必须是当前工程生成的 binary 。 如果 binary 有多个文件,则以 proj_name:binary_name 格式指定文件名, proj_name 为工程 名&…...
scikit-learn机器学习
# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: # Also add the following code, # so that every time the environment (kernel) starts, # just run the following code: import sys sys.path.append(/home/aistudio/external-libraries)机…...