当前位置: 首页 > news >正文

从零开始学习线性回归:理论、实践与PyTorch实现

文章目录

  • 🥦介绍
  • 🥦基本知识
  • 🥦代码实现
  • 🥦完整代码
  • 🥦总结

🥦介绍

线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性回归的理论基础、数学公式以及如何使用PyTorch实现一个简单的线性回归模型。

🥦基本知识

线性回归的数学基础
线性回归的核心思想是建立一个线性方程,它表示了自变量(输入特征)与因变量(输出)之间的关系。这个线性方程通常表示为:
y=b0+b1x1+b2x2+…+bpxp
y=b0​+b1​x1​+b2​x2​+…+bp​xp​

其中, y y y 是因变量, x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,,xp 是自变量, b 0 , b 1 , b 2 , … , b p b_0, b_1, b_2, \ldots, b_p b0,b1,b2,,bp 是模型的参数, p p p 是特征的数量。我们的目标是找到最佳的参数值,以最小化模型的误差。

损失函数
为了找到最佳参数,我们需要定义一个损失函数来度量模型的性能。在线性回归中,最常用的损失函数是均方误差(MSE),它表示了模型预测值与实际值之间的平方差的平均值:
在这里插入图片描述
其中, n n n 是样本数量, y i y_i yi 是实际值, y ^ i \hat{y}_i y^i 是模型的预测值。

梯度下降优化
为了最小化损失函数,我们使用梯度下降算法。梯度下降通过计算损失函数相对于参数的梯度,并迭代地更新参数,以减小损失。更新规则如下:
在这里插入图片描述
其中, b j b_j bj 是第 j j j个参数, α \alpha α 是学习率, ∂ ∂ b j M S E \frac{\partial}{\partial b_j} MSE bjMSE 是损失函数对参数 b j b_j bj的偏导数。

🥦代码实现

如果你想知道实现线性回归的大体步骤,下图可以充分进行说明
在这里插入图片描述

  • 准备数据
  • 设计模型(计算) y ^ i \hat{y}_i y^i
  • 构造损失和优化器
  • 训练周期(前向,反向 ,更新)

本节还是以刘二大人的视频讲解为例,结尾会设置传送门

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() # 调用父类的构造函数self.linear = torch.nn.Linear(1, 1)  # 参数详情下图展示def forward(self, x):y_pred = self.linear(x)   # x代表输入样本的张量return y_pred
model = LinearModel()

所以模型类都要继承Module,此类主要包含两个函数一个是构造函数(初始化对象时调用),另一个是前向计算

好奇的小伙伴会思考为何没有反向(backward),这是因为Module会帮你进行,但是如果后期自己有更高效的方法可以自行设置。
在这里插入图片描述

  • 第一个参数 in_features:这是输入特征的数量。在这里,表示我们的模型只有一个输入特征。如果你有多个输入特征,你可以将这个参数设置为输入特征的数量。

  • 第二个参数 out_features:这是输出特征的数量。这表示我们的模型将生成一个输出。在线性回归中,通常只有一个输出,因为我们试图预测一个连续的数值。

  • 第三个参数:意思是要不要偏置量。默认true

通常情况下特征代表列,比如我们有一个n×2的y和一个n×3的x,那么我们需要一个3×2的权重,有的书中会在两边做转置,但无论咋样目的都是为了让这个矩阵乘法成立

criterion = torch.nn.MSELoss(size_average=False)  # 使用均方误差损失 
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 使用随机梯度下降优化器

在这里插入图片描述
在这里插入图片描述
model.parameters() 用于告诉优化器哪些参数需要在训练过程中进行更新,这包括模型的权重和偏置项等。在线性回归示例中,模型的参数包括权重和偏置项。

优化器的选择有许多大家可以都试试看看
在这里插入图片描述

之后就进行训练了

for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad()   # 归零loss.backward()  # 反向optimizer.step()  # 更新
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)

🥦完整代码

x_data = torch.Tensor([[1.0], [2.0], [3.0]]) 
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x) return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)
predicted = model(x_data).detach().numpy()
plt.scatter(x_data, y_data, label='Original data')
plt.plot(x_data, predicted, label='Fitted line', color='r')
plt.legend()
plt.show()

运行结果如下
在这里插入图片描述

在这里插入图片描述

🥦总结

在本篇博客中,我们使用PyTorch实现了一个简单的线性回归模型,并使用随机生成的数据对其进行了训练和可视化。线性回归是一个入门级的机器学习模型,但它为理解模型训练和预测的基本概念提供了一个很好的起点。

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

相关文章:

从零开始学习线性回归:理论、实践与PyTorch实现

文章目录 🥦介绍🥦基本知识🥦代码实现🥦完整代码🥦总结 🥦介绍 线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性…...

[LeetCode]链式二叉树相关题目(c语言实现)

文章目录 LeetCode965. 单值二叉树LeetCode100. 相同的树LeetCode101. 对称二叉树LeetCode144. 二叉树的前序遍历LeetCode94. 二叉树的中序遍历LeetCode145. 二叉树的后序遍历LeetCode572. 另一棵树的子树 LeetCode965. 单值二叉树 题目 Oj链接 思路 一棵树的所有值都是一个…...

集成学习

集成学习(Ensemble Learning) - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/27689464集成学习就是组合这里的多个弱监督模型以期得到一个更好更全面的强监督模型,集成学习潜在的思想是即便某一个弱分类器得到了错误的预测,其他的弱分类器…...

算法练习11——买卖股票的最佳时机 II

LeetCode 122 买卖股票的最佳时机 II 给你一个整数数组 prices ,其中 prices[i] 表示某支股票第 i 天的价格。 在每一天,你可以决定是否购买和/或出售股票。你在任何时候 最多 只能持有 一股 股票。你也可以先购买,然后在 同一天 出售。 返回…...

linux——多线程,线程控制

目录 一.POSIX线程库 二.线程创建 1.创建线程接口 2.查看线程 3.多线程的健壮性问题 4.线程函数参数传递 5.线程id和地址空间 三.线程终止 1.pthread_exit 2.pthread_cancel 四.线程等待 五.线程分离 一.POSIX线程库 站在内核的角度,OS只有轻量级进程…...

Oracle 简介与 Docker Compose部署

最近,我翻阅了在之前公司工作时的笔记,偶然发现了一些有关数据库的记录。当初,我们的项目一开始采用的是 Oracle 数据库,但随着项目需求的变化,我们不得不转向使用 SQL Server。值得一提的是,公司之前采用的…...

mp4音视频分离技术

文章目录 问题描述一、分离MP3二、分离无声音的MP4三、结果 问题描述 MP4视频想拆分成一个MP3音频和一个无声音的MP4文件 一、分离MP3 ffmpeg -i C:\Users\Administrator\Desktop\一个文件夹\我在财神殿里长跪不起_完整版MV.mp4 -vn C:\Users\Administrator\Desktop\一个文件…...

JVM 参数

JVM 参数类型大致分为以下几类: 标准参数(-):保证在所有的 JVM 实现都支持的参数非标准参数(-X):通用的,特定于 HotSpot 虚拟机的参数,这些参数不保证在所有 JVM 实现中…...

黑马点评-07缓存击穿问题(热点key失效)及解决方案,互斥锁和设置逻辑过期时间

缓存击穿问题(热点key失效) 缓存击穿问题也叫热点Key问题,就是一个被高并发访问并且重建缓存业务较复杂的key突然失效了,此时无数的请求访问会在瞬间打到数据库,带来巨大的冲击 一件秒杀中的商品的key突然失效了,由于大家都在疯狂抢购那么这个瞬间就会有无数的请求…...

信息系统项目管理师第四版学习笔记——项目进度管理

项目进度管理过程 项目进度管理过程包括:规划进度管理、定义活动、排列活动顺序、估算活动持续时间、制订进度计划、控制进度。 规划进度管理 规划进度管理是为规划、编制、管理、执行和控制项目进度而制定政策、程序和文档的过程。本过程的主要作用是为如何在…...

指挥棒:C++ 与运算符

文章目录 参考描述算术运算符除法运算取模运算复合赋值运算符自增运算符自减运算符 比较运算符逻辑运算符概念短路为什么需要短路机制? 参考 项目描述微软C 语言文档搜索引擎Bing、GoogleAI 大模型文心一言、通义千问、讯飞星火认知大模型、ChatGPTC Primer Plus &…...

HTTPS建立连接的过程

HTTPS 协议是基于 TCP 协议的,因而要先建立 TCP 的连接。在这个例子中,TCP 的连接是在手机上的 App 和负载均衡器 SLB 之间的。 尽管中间要经过很多的路由器和交换机,但是 TCP 的连接是端到端的。TCP 这一层和更上层的 HTTPS 无法看到中间的包…...

Python接口自动化搭建过程,含request请求封装!

开篇碎碎念 接口测试自动化好处 显而易见的好处就是解放双手😀。 可以在短时间内自动执行大量的测试用例通过参数化和数据驱动的方式进行测试数据的变化,提高测试覆盖范围快速反馈测试执行结果和报告支持持续集成和持续交付的流程 使用Requestspytes…...

Vue3 编译原理

文章目录 一、编译流程1. 解读入口文件 packgages/vue/index.ts2. compile函数的运行流程 二、AST 解析器1. ast 的生成2. 创建ast的根节点3. 解析子节点 parseChildren(关键)4. 解析模版元素 Element模版元素解析-举例分析 一、编译流程 1. 解读入口文…...

spring boot整合Minio

MinIO 安装MinIo # 先创建minio 文件存放的位置 mkdir -p /opt/docker/minio/data# 启动并指定端口 docker run \-p 9000:9000 \-p 5001:5001 \--name minio \-v /opt/docker/minio/data:/data \-e "MINIO_ROOT_USERminioadmin" \-e "MINIO_ROOT_PASSWORDmini…...

Hadoop----Azkaban的使用与一些报错问题的解决

1.因为官方只放出源码,并没有放出其tar包,所以需要我们自己编译,通过查阅资料我们可以使用gradlew对其进行编译,还是比较简单,然后将里面需要用到的服务文件夹进行拷贝,完善其文件夹结构,通常会…...

「新房家装经验」客厅电视高度标准尺寸及客厅电视机买多大尺寸合适?

客厅电视悬挂高度标准尺寸是多少? 客厅电视悬挂高度通常在90~120厘米之间,电视挂墙高度也可以根据个人的喜好和实际情况来调整,但通常不宜过高,以坐在沙发上观看时眼睛能够平视到电视中心点或者中心稍微往下一点的位置为适宜。 客…...

ArduPilot开源飞控之AP_Baro_DroneCAN

ArduPilot开源飞控之AP_Baro_DroneCAN 1. 源由2. back-end抽象类3. 方法实现3.1 probe3.2 update3.3 subscribe_msgs3.4 handle_pressure/handle_temperature3.5 CAN port 4. 参考资料 1. 源由 鉴于ArduPilot开源飞控之AP_Baro中涉及Sensor Driver有以下总线类型: …...

Supervised Contrastive Pre-training for Mammographic Triage Screening Model

方法 品红色箭头表示将生成的孪生编码器分别迁移到单视角学习模块和双视角学习模块...

JVM技术文档--JVM优化思路以及问题定位--JVM可调整参数汇总

阿丹: 一个优秀的程序员,是因为在线上的排查以及遇到的线上、生产事故较多所以定位问题以及解决问题会比普通程序员快很多,所以一个优秀的程序员要逐渐形成自己的方法论,来完善和解决问题。 我们是如何发现问题的呢? …...

从零实现富文本编辑器#5-编辑器选区模型的状态结构表达

先前我们总结了浏览器选区模型的交互策略,并且实现了基本的选区操作,还调研了自绘选区的实现。那么相对的,我们还需要设计编辑器的选区表达,也可以称为模型选区。编辑器中应用变更时的操作范围,就是以模型选区为基准来…...

【网络安全产品大调研系列】2. 体验漏洞扫描

前言 2023 年漏洞扫描服务市场规模预计为 3.06(十亿美元)。漏洞扫描服务市场行业预计将从 2024 年的 3.48(十亿美元)增长到 2032 年的 9.54(十亿美元)。预测期内漏洞扫描服务市场 CAGR(增长率&…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

如何将联系人从 iPhone 转移到 Android

从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...

【AI学习】三、AI算法中的向量

在人工智能(AI)算法中,向量(Vector)是一种将现实世界中的数据(如图像、文本、音频等)转化为计算机可处理的数值型特征表示的工具。它是连接人类认知(如语义、视觉特征)与…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中,UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

CMake控制VS2022项目文件分组

我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...

分布式增量爬虫实现方案

之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面,避免重复抓取,以节省资源和时间。 在分布式环境下,增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路:将增量判…...

服务器--宝塔命令

一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行! sudo su - 1. CentOS 系统: yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...

论文笔记——相干体技术在裂缝预测中的应用研究

目录 相关地震知识补充地震数据的认识地震几何属性 相干体算法定义基本原理第一代相干体技术:基于互相关的相干体技术(Correlation)第二代相干体技术:基于相似的相干体技术(Semblance)基于多道相似的相干体…...