通俗易懂之线性回归时序预测PyTorch实践
线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入理解这一重要概念。
线性回归概述
线性回归是一种用于预测因变量(目标变量)与一个或多个自变量(特征变量)之间关系的统计方法。其目标是在数据点之间找到一条最佳拟合直线,使得预测值与实际值之间的误差最小。
基本形式:
- 简单线性回归:只有一个自变量。
- 多元线性回归:包含多个自变量。
本文将聚焦于简单线性回归,即仅考虑一个自变量的情况。
线性回归的数学原理
模型表达式
简单线性回归的模型表达式为:
y = w x + b y = wx + b y=wx+b
其中:
- y y y 是预测值。
- x x x 是输入特征。
- w w w 是权重(斜率)。
- b b b 是偏置(截距)。
损失函数
为了衡量模型预测值与实际值之间的差异,通常使用均方误差(Mean Squared Error, MSE)作为损失函数:
Loss = 1 2 ∑ i = 1 N ( y i pred − y i ) 2 \text{Loss} = \frac{1}{2} \sum_{i=1}^{N} (y_i^{\text{pred}} - y_i)^2 Loss=21i=1∑N(yipred−yi)2
优化算法
线性回归常用的优化算法是梯度下降(Gradient Descent)。通过计算损失函数关于参数 w w w 和 b b b 的梯度,迭代更新参数以最小化损失。
更新规则如下:
w : = w − η ∂ Loss ∂ w w := w - \eta \frac{\partial \text{Loss}}{\partial w} w:=w−η∂w∂Loss
b : = b − η ∂ Loss ∂ b b := b - \eta \frac{\partial \text{Loss}}{\partial b} b:=b−η∂b∂Loss
其中 η \eta η 是学习率。
应用场景
线性回归在多个领域有广泛应用,包括但不限于:
- 经济学:预测经济指标,如GDP、通货膨胀率等。
- 工程学:估计物理量之间的关系,如材料强度与应力。
- 医疗:预测疾病发展趋势,如体重增长与健康指标。
- 金融:股价预测、风险评估等。
PyTorch实现线性回归
接下来,我们将通过一段PyTorch代码实践线性回归,从数据生成、模型训练到可视化展示,全面演示线性回归的实现过程。代码参考《深度学习框架PyTorch入门与实践》一书的实现,为了感受线性回归的计算过程,代码并未直接调用python中已有的线性回归库。
代码解析
首先,我们导入必要的库并设置随机种子以确保结果可复现。
import torch as t
import matplotlib.pyplot as plt
from IPython import displayt.manual_seed(1000)
数据生成函数
定义一个函数 get_fake_data 来生成假数据,这些数据遵循线性关系 y = 2 x + 3 y = 2x + 3 y=2x+3 并添加了一定的噪声。
def get_fake_data(batch_size=8):x = t.randn(batch_size, 1, dtype=float) * 20 # 随机生成x,范围扩大到[-20, 20]y = x * 2 + (1 + t.randn(batch_size, 1, dtype=float)) * 3 # y = 2x + 3 + 噪声return x, y
调用该函数生成一批数据并进行可视化。
x, y = get_fake_data()plt.figure()
plt.scatter(x, y)
plt.show()
参数初始化
随机初始化权重 w w w 和偏置 b b b,并设置学习率 l r lr lr。
# 随机初始化参数
w = t.rand(1, 1, requires_grad=True, dtype=float)
b = t.zeros(1, 1, requires_grad=True, dtype=float)lr = 0.00001
训练过程
通过1000次迭代,使用梯度下降法优化参数 w w w 和 b b b。
for i in range(1000):x, y = get_fake_data()y_pred = x.mm(w) + b.expand_as(y) # 预测值loss = 0.5 * (y_pred - y) ** 2 # 均方误差loss = loss.sum()loss.backward() # 反向传播计算梯度# 更新参数w.data.sub_(lr * w.grad.data)b.data.sub_(lr * b.grad.data)# 梯度清零w.grad.data.zero_()b.grad.data.zero_()# 每100次迭代可视化一次结果if i % 100 == 0:display.clear_output(wait=True)x_plot = t.arange(0, 20, dtype=float).view(-1, 1)y_plot = x_plot.mm(w) + b.expand_as(x_plot)plt.plot(x_plot.data, y_plot.data, label='Fitting Line')x2, y2 = get_fake_data(batch_size=20)plt.scatter(x2, y2, color='red', label='Data Points')plt.xlim(0, 20)plt.ylim(0, 41)plt.legend()plt.show()plt.pause(0.5)
可视化与训练过程
训练过程中,每隔100次迭代,会清除之前的输出,绘制当前拟合的直线与新生成的数据点。随着训练的进行,拟合线将逐渐接近真实的线性关系 y = 2 x + 3 y = 2x + 3 y=2x+3。
以下是训练过程中的可视化效果示例:

注:实际运行代码时,图像会动态更新,展示拟合过程。
代码关键点解析
-
数据生成:
- 使用
torch.randn生成标准正态分布的随机数,并通过线性变换获取x和y。 - 添加噪声使模型更贴近真实场景。
- 使用
-
参数初始化:
w随机初始化,b初始化为零。requires_grad=True表示在反向传播时需要计算梯度。
-
前向传播:
- 计算预测值
y_pred = x.mm(w) + b.expand_as(y)。 - 使用矩阵乘法
mm实现线性变换。
- 计算预测值
-
损失计算:
- 采用均方误差损失函数。
loss.backward()计算损失函数相对于参数的梯度。
-
参数更新:
- 使用学习率
lr按梯度方向更新参数。 data.sub_进行原地更新,避免梯度计算图的干扰。
- 使用学习率
-
梯度清零:
- 每次参数更新后,需要清零梯度
w.grad.data.zero_()和b.grad.data.zero_(),以防止梯度累积。
- 每次参数更新后,需要清零梯度
-
可视化:
- 使用
matplotlib绘制拟合线和数据点。 display.clear_output(wait=True)清除之前的图像,避免图形堆积。plt.pause(0.5)控制图像更新速度。
- 使用
总结
本文从线性回归的基本概念出发,详细介绍了其数学原理和应用场景,并通过一段PyTorch代码演示了线性回归模型的实现过程。从数据生成、参数初始化、模型训练到结果可视化,全面展示了线性回归的实际应用。通过这种实例讲解,读者不仅能够理解线性回归的理论基础,还能掌握其在深度学习框架中的具体实现方法。
线性回归作为机器学习的基础模型,虽然简单,但其思想却深刻影响着更加复杂的算法和模型。在实际应用中,理解并掌握线性回归对于进一步学习和开发更加复杂的机器学习模型具有重要意义。
如果这篇文章对你有一点点的帮助,欢迎点赞、关注、收藏、转发、评论哦!
我也会在微信公众号“智识小站”坚持分享更多内容,以期记录成长、普及技术、造福后来者!


相关文章:
通俗易懂之线性回归时序预测PyTorch实践
线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入…...
[离线数仓] 总结二、Hive数仓分层开发
接 [离线数仓] 总结一、数据采集 5.8 数仓开发之ODS层 ODS层的设计要点如下: (1)ODS层的表结构设计依托于从业务系统同步过来的数据结构。 (2)ODS层要保存全部历史数据,故其压缩格式应选择压缩比率,较高的,此处选择gzip。 CompressedStorage - Apache Hive - Apac…...
页面顶部导航栏(Navbar)的功能(Navbar/index.vue)
这段代码是一个 Vue.js 组件,实现了页面顶部导航栏(Navbar)的功能。我将分块分析它的各个部分: 模板 (Template): <!-- spid-admin/src/layout/components/Navbar/index.vue --> <template><div class"navb…...
thinnkphp5.1和 thinkphp6以及nginx,apache 解决跨域问题
ThinkPHP 5.1 使用中间件设置响应头 ThinkPHP 5.1 及以上版本支持中间件,可以通过中间件统一设置跨域响应头。 步骤: 创建一个中间件文件,例如 CorsMiddleware.php: namespace app\middleware;class CorsMiddleware {public fu…...
vue2新增删除
(只是页面实现,不涉及数据库) list组件: <button click"onAdd">新增</button><el-table:header-cell-style"{ textAlign: center }" :cell-style"{ textAlign: center }":data&quo…...
测试ip端口-telnet开启与使用
前言 开发过程中我们总会要去测试ip通不通,或者ip下某个端口是否可以联通,为此我们可以使用telnet 命令来实现。 一、telnet 开启 可能有些人使用telnet报错,不是内部命令,可以如下开启: 1、打开控制面板ÿ…...
Python爬虫基础——XPath表达式
首先说一下这节内容在学习过程中存在的问题吧,在爬取百度网页文字时,出现了问题,就是通过表达式在网页搜索中可以定位,但是通过代码无法定位,请教了一位老师,他说是动态链接,目前这部分内容比较…...
ansible-性能优化
一. 简述: 搞过运维自动化工具的人,肯定会发现很多运维伙伴们经常用saltstack和ansible做比较,单从执行效率上来说,ansible确实比不上saltstack(ansible使用的是ssh,salt使用的是zeromq消息队列[暂没深入了解]),但其实…...
高等数学学习笔记 ☞ 一元函数微分的基础知识
1. 微分的定义 (1)定义:设函数在点的某领域内有定义,取附近的点,对应的函数值分别为和, 令,若可以表示成,则称函数在点是可微的。 【 若函数在点是可微的,则可以表达为】…...
前后端实现防抖节流实现
在前端和 Java 后端中实现防抖(Debounce)和节流(Throttle)主要用于减少频繁请求或事件触发对系统的压力。前端和后端的实现方式有些不同,以下是两种方法的具体实现: 1. 前端实现防抖和节流 在前端中&…...
【笔记】算法记录
1、求一个数的素因子(试除法) // 获取一个数的所有素因子 set<int> getPrimeFactors(int num) {set<int> primeFactors;for (int i 2; i * i < num; i) {while (num % i 0) {primeFactors.insert(i);num / i;}}if (num > 1) {prime…...
【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析
文章目录 一、选择题二、理论题三、实操题 【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析 一、选择题 生成树协议的主要作用是 B. 防止网络环路解释:生成树协议(STP)的主要目的是防止网络中…...
git push -f 指定分支
要将本地代码推送到指定的远程分支,你可以使用以下步骤和命令: 确认远程仓库: 确保你的本地仓库已经与远程仓库关联。你可以使用以下命令查看当前的远程仓库状态: git remote -v查看本地分支: 使用命令查看当前存在的本…...
CTF知识点总结(二)
异或注入:两个条件相同(同真或同假)即为假。 http://120.24.86.145:9004/1ndex.php?id1^(length(union)!0)-- 如上,如果union被过滤,则 length(union)!0 为假,那么返回页面正常。 2|0updatexml() 函数报…...
解决Edge打开PDF总是没有焦点
【问题描述】 使用Edge浏览器作为默认PDF阅读器打开本地PDF文件,Edge窗口总是不获得焦点,而是在任务栏以橙色显示,需要再手动点击一次才能查看文件内容。 本强迫症来治一治这个问题! 【解决方法】 GPT老师指出问题出在Edge的启动…...
69.基于SpringBoot + Vue实现的前后端分离-家乡特色推荐系统(项目 + 论文PPT)
项目介绍 在Internet高速发展的今天,我们生活的各个领域都涉及到计算机的应用,其中包括家乡特色推荐的网络应用,在外国家乡特色推荐系统已经是很普遍的方式,不过国内的管理网站可能还处于起步阶段。家乡特色推荐系统采用java技术&…...
计算机视觉目标检测-DETR网络
目录 摘要abstractDETR目标检测网络详解二分图匹配和损失函数 DETR总结总结 摘要 DETR(DEtection TRansformer)是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题,摒弃了锚框设计和非…...
《自动驾驶与机器人中的SLAM技术》ch1:自动驾驶
目录 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 L2 在技术实现上会更倾向于实时感知,乃至可以使用感知结果直接构建鸟瞰图(bird eye view, BEV),而 L4 则依赖离线地图。 高精地…...
【UE5 C++课程系列笔记】23——多线程基础——AsyncTask
目录 概念 函数说明 注意事项 (1)线程安全问题 (2)依赖特定线程执行的任务限制 (3)任务执行顺序和时间不确定性 使用示例 概念 AsyncTask 允许开发者将一个函数或者一段代码逻辑提交到特定的线程去执…...
基于Python的音乐播放器 毕业设计-附源码73733
摘 要 本项目基于Python开发了一款简单而功能强大的音乐播放器。通过该音乐播放器,用户可以轻松管理自己的音乐库,播放喜爱的音乐,并享受音乐带来的愉悦体验。 首先,我们使用Python语言结合相关库开发了这款音乐播放器。利用Tkin…...
Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)
参考官方文档:https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java(供 Kotlin 使用) 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...
在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?
uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件,用于在原生应用中加载 HTML 页面: 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...
Reasoning over Uncertain Text by Generative Large Language Models
https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829 1. 概述 文本中的不确定性在许多语境中传达,从日常对话到特定领域的文档(例如医学文档)(Heritage 2013;Landmark、Gulbrandsen 和 Svenevei…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...
C++:多态机制详解
目录 一. 多态的概念 1.静态多态(编译时多态) 二.动态多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写/覆盖 4.虚函数重写的一些其他问题 1).协变 2).析构函数的重写 5.override 和 final关键字 1&#…...
人机融合智能 | “人智交互”跨学科新领域
本文系统地提出基于“以人为中心AI(HCAI)”理念的人-人工智能交互(人智交互)这一跨学科新领域及框架,定义人智交互领域的理念、基本理论和关键问题、方法、开发流程和参与团队等,阐述提出人智交互新领域的意义。然后,提出人智交互研究的三种新范式取向以及它们的意义。最后,总结…...
C++.OpenGL (20/64)混合(Blending)
混合(Blending) 透明效果核心原理 #mermaid-svg-SWG0UzVfJms7Sm3e {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-icon{fill:#552222;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-text{fill…...
破解路内监管盲区:免布线低位视频桩重塑停车管理新标准
城市路内停车管理常因行道树遮挡、高位设备盲区等问题,导致车牌识别率低、逃费率高,传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法,正成为破局关键。该设备安装于车位侧方0.5-0.7米高度,直接规避树枝遮…...
MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释
以Module Federation 插件详为例,Webpack.config.js它可能的配置和含义如下: 前言 Module Federation 的Webpack.config.js核心配置包括: name filename(定义应用标识) remotes(引用远程模块࿰…...
Unity中的transform.up
2025年6月8日,周日下午 在Unity中,transform.up是Transform组件的一个属性,表示游戏对象在世界空间中的“上”方向(Y轴正方向),且会随对象旋转动态变化。以下是关键点解析: 基本定义 transfor…...
