从零开始学习线性回归:理论、实践与PyTorch实现
文章目录
- 🥦介绍
- 🥦基本知识
- 🥦代码实现
- 🥦完整代码
- 🥦总结
🥦介绍
线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性回归的理论基础、数学公式以及如何使用PyTorch实现一个简单的线性回归模型。
🥦基本知识
线性回归的数学基础
线性回归的核心思想是建立一个线性方程,它表示了自变量(输入特征)与因变量(输出)之间的关系。这个线性方程通常表示为:
其中, y y y 是因变量, x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,…,xp 是自变量, b 0 , b 1 , b 2 , … , b p b_0, b_1, b_2, \ldots, b_p b0,b1,b2,…,bp 是模型的参数, p p p 是特征的数量。我们的目标是找到最佳的参数值,以最小化模型的误差。
损失函数
为了找到最佳参数,我们需要定义一个损失函数来度量模型的性能。在线性回归中,最常用的损失函数是均方误差(MSE),它表示了模型预测值与实际值之间的平方差的平均值:
其中, n n n 是样本数量, y i y_i yi 是实际值, y ^ i \hat{y}_i y^i 是模型的预测值。
梯度下降优化
为了最小化损失函数,我们使用梯度下降算法。梯度下降通过计算损失函数相对于参数的梯度,并迭代地更新参数,以减小损失。更新规则如下:
其中, b j b_j bj 是第 j j j个参数, α \alpha α 是学习率, ∂ ∂ b j M S E \frac{\partial}{\partial b_j} MSE ∂bj∂MSE 是损失函数对参数 b j b_j bj的偏导数。
🥦代码实现
如果你想知道实现线性回归的大体步骤,下图可以充分进行说明
- 准备数据
- 设计模型(计算) y ^ i \hat{y}_i y^i
- 构造损失和优化器
- 训练周期(前向,反向 ,更新)
本节还是以刘二大人的视频讲解为例,结尾会设置传送门
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() # 调用父类的构造函数self.linear = torch.nn.Linear(1, 1) # 参数详情下图展示def forward(self, x):y_pred = self.linear(x) # x代表输入样本的张量return y_pred
model = LinearModel()
所以模型类都要继承Module,此类主要包含两个函数一个是构造函数(初始化对象时调用),另一个是前向计算
好奇的小伙伴会思考为何没有反向(backward),这是因为Module会帮你进行,但是如果后期自己有更高效的方法可以自行设置。
-
第一个参数 in_features:这是输入特征的数量。在这里,表示我们的模型只有一个输入特征。如果你有多个输入特征,你可以将这个参数设置为输入特征的数量。
-
第二个参数 out_features:这是输出特征的数量。这表示我们的模型将生成一个输出。在线性回归中,通常只有一个输出,因为我们试图预测一个连续的数值。
-
第三个参数:意思是要不要偏置量。默认true
通常情况下特征代表列,比如我们有一个n×2的y和一个n×3的x,那么我们需要一个3×2的权重,有的书中会在两边做转置,但无论咋样目的都是为了让这个矩阵乘法成立
criterion = torch.nn.MSELoss(size_average=False) # 使用均方误差损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 使用随机梯度下降优化器
model.parameters() 用于告诉优化器哪些参数需要在训练过程中进行更新,这包括模型的权重和偏置项等。在线性回归示例中,模型的参数包括权重和偏置项。
优化器的选择有许多大家可以都试试看看
之后就进行训练了
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() # 归零loss.backward() # 反向optimizer.step() # 更新
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
🥦完整代码
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x) return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
predicted = model(x_data).detach().numpy()
plt.scatter(x_data, y_data, label='Original data')
plt.plot(x_data, predicted, label='Fitted line', color='r')
plt.legend()
plt.show()
运行结果如下
🥦总结
在本篇博客中,我们使用PyTorch实现了一个简单的线性回归模型,并使用随机生成的数据对其进行了训练和可视化。线性回归是一个入门级的机器学习模型,但它为理解模型训练和预测的基本概念提供了一个很好的起点。
挑战与创造都是很痛苦的,但是很充实。
相关文章:

从零开始学习线性回归:理论、实践与PyTorch实现
文章目录 🥦介绍🥦基本知识🥦代码实现🥦完整代码🥦总结 🥦介绍 线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性…...

[LeetCode]链式二叉树相关题目(c语言实现)
文章目录 LeetCode965. 单值二叉树LeetCode100. 相同的树LeetCode101. 对称二叉树LeetCode144. 二叉树的前序遍历LeetCode94. 二叉树的中序遍历LeetCode145. 二叉树的后序遍历LeetCode572. 另一棵树的子树 LeetCode965. 单值二叉树 题目 Oj链接 思路 一棵树的所有值都是一个…...

集成学习
集成学习(Ensemble Learning) - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/27689464集成学习就是组合这里的多个弱监督模型以期得到一个更好更全面的强监督模型,集成学习潜在的思想是即便某一个弱分类器得到了错误的预测,其他的弱分类器…...

算法练习11——买卖股票的最佳时机 II
LeetCode 122 买卖股票的最佳时机 II 给你一个整数数组 prices ,其中 prices[i] 表示某支股票第 i 天的价格。 在每一天,你可以决定是否购买和/或出售股票。你在任何时候 最多 只能持有 一股 股票。你也可以先购买,然后在 同一天 出售。 返回…...

linux——多线程,线程控制
目录 一.POSIX线程库 二.线程创建 1.创建线程接口 2.查看线程 3.多线程的健壮性问题 4.线程函数参数传递 5.线程id和地址空间 三.线程终止 1.pthread_exit 2.pthread_cancel 四.线程等待 五.线程分离 一.POSIX线程库 站在内核的角度,OS只有轻量级进程…...

Oracle 简介与 Docker Compose部署
最近,我翻阅了在之前公司工作时的笔记,偶然发现了一些有关数据库的记录。当初,我们的项目一开始采用的是 Oracle 数据库,但随着项目需求的变化,我们不得不转向使用 SQL Server。值得一提的是,公司之前采用的…...

mp4音视频分离技术
文章目录 问题描述一、分离MP3二、分离无声音的MP4三、结果 问题描述 MP4视频想拆分成一个MP3音频和一个无声音的MP4文件 一、分离MP3 ffmpeg -i C:\Users\Administrator\Desktop\一个文件夹\我在财神殿里长跪不起_完整版MV.mp4 -vn C:\Users\Administrator\Desktop\一个文件…...

JVM 参数
JVM 参数类型大致分为以下几类: 标准参数(-):保证在所有的 JVM 实现都支持的参数非标准参数(-X):通用的,特定于 HotSpot 虚拟机的参数,这些参数不保证在所有 JVM 实现中…...

黑马点评-07缓存击穿问题(热点key失效)及解决方案,互斥锁和设置逻辑过期时间
缓存击穿问题(热点key失效) 缓存击穿问题也叫热点Key问题,就是一个被高并发访问并且重建缓存业务较复杂的key突然失效了,此时无数的请求访问会在瞬间打到数据库,带来巨大的冲击 一件秒杀中的商品的key突然失效了,由于大家都在疯狂抢购那么这个瞬间就会有无数的请求…...

信息系统项目管理师第四版学习笔记——项目进度管理
项目进度管理过程 项目进度管理过程包括:规划进度管理、定义活动、排列活动顺序、估算活动持续时间、制订进度计划、控制进度。 规划进度管理 规划进度管理是为规划、编制、管理、执行和控制项目进度而制定政策、程序和文档的过程。本过程的主要作用是为如何在…...
指挥棒:C++ 与运算符
文章目录 参考描述算术运算符除法运算取模运算复合赋值运算符自增运算符自减运算符 比较运算符逻辑运算符概念短路为什么需要短路机制? 参考 项目描述微软C 语言文档搜索引擎Bing、GoogleAI 大模型文心一言、通义千问、讯飞星火认知大模型、ChatGPTC Primer Plus &…...

HTTPS建立连接的过程
HTTPS 协议是基于 TCP 协议的,因而要先建立 TCP 的连接。在这个例子中,TCP 的连接是在手机上的 App 和负载均衡器 SLB 之间的。 尽管中间要经过很多的路由器和交换机,但是 TCP 的连接是端到端的。TCP 这一层和更上层的 HTTPS 无法看到中间的包…...

Python接口自动化搭建过程,含request请求封装!
开篇碎碎念 接口测试自动化好处 显而易见的好处就是解放双手😀。 可以在短时间内自动执行大量的测试用例通过参数化和数据驱动的方式进行测试数据的变化,提高测试覆盖范围快速反馈测试执行结果和报告支持持续集成和持续交付的流程 使用Requestspytes…...

Vue3 编译原理
文章目录 一、编译流程1. 解读入口文件 packgages/vue/index.ts2. compile函数的运行流程 二、AST 解析器1. ast 的生成2. 创建ast的根节点3. 解析子节点 parseChildren(关键)4. 解析模版元素 Element模版元素解析-举例分析 一、编译流程 1. 解读入口文…...
spring boot整合Minio
MinIO 安装MinIo # 先创建minio 文件存放的位置 mkdir -p /opt/docker/minio/data# 启动并指定端口 docker run \-p 9000:9000 \-p 5001:5001 \--name minio \-v /opt/docker/minio/data:/data \-e "MINIO_ROOT_USERminioadmin" \-e "MINIO_ROOT_PASSWORDmini…...

Hadoop----Azkaban的使用与一些报错问题的解决
1.因为官方只放出源码,并没有放出其tar包,所以需要我们自己编译,通过查阅资料我们可以使用gradlew对其进行编译,还是比较简单,然后将里面需要用到的服务文件夹进行拷贝,完善其文件夹结构,通常会…...

「新房家装经验」客厅电视高度标准尺寸及客厅电视机买多大尺寸合适?
客厅电视悬挂高度标准尺寸是多少? 客厅电视悬挂高度通常在90~120厘米之间,电视挂墙高度也可以根据个人的喜好和实际情况来调整,但通常不宜过高,以坐在沙发上观看时眼睛能够平视到电视中心点或者中心稍微往下一点的位置为适宜。 客…...
ArduPilot开源飞控之AP_Baro_DroneCAN
ArduPilot开源飞控之AP_Baro_DroneCAN 1. 源由2. back-end抽象类3. 方法实现3.1 probe3.2 update3.3 subscribe_msgs3.4 handle_pressure/handle_temperature3.5 CAN port 4. 参考资料 1. 源由 鉴于ArduPilot开源飞控之AP_Baro中涉及Sensor Driver有以下总线类型: …...

Supervised Contrastive Pre-training for Mammographic Triage Screening Model
方法 品红色箭头表示将生成的孪生编码器分别迁移到单视角学习模块和双视角学习模块...
JVM技术文档--JVM优化思路以及问题定位--JVM可调整参数汇总
阿丹: 一个优秀的程序员,是因为在线上的排查以及遇到的线上、生产事故较多所以定位问题以及解决问题会比普通程序员快很多,所以一个优秀的程序员要逐渐形成自己的方法论,来完善和解决问题。 我们是如何发现问题的呢? …...

eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)
说明: 想象一下,你正在用eNSP搭建一个虚拟的网络世界,里面有虚拟的路由器、交换机、电脑(PC)等等。这些设备都在你的电脑里面“运行”,它们之间可以互相通信,就像一个封闭的小王国。 但是&#…...
QMC5883L的驱动
简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...
线程同步:确保多线程程序的安全与高效!
全文目录: 开篇语前序前言第一部分:线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分:synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分ÿ…...

【机器视觉】单目测距——运动结构恢复
ps:图是随便找的,为了凑个封面 前言 在前面对光流法进行进一步改进,希望将2D光流推广至3D场景流时,发现2D转3D过程中存在尺度歧义问题,需要补全摄像头拍摄图像中缺失的深度信息,否则解空间不收敛…...
Python爬虫(二):爬虫完整流程
爬虫完整流程详解(7大核心步骤实战技巧) 一、爬虫完整工作流程 以下是爬虫开发的完整流程,我将结合具体技术点和实战经验展开说明: 1. 目标分析与前期准备 网站技术分析: 使用浏览器开发者工具(F12&…...

SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现
摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...
Neo4j 集群管理:原理、技术与最佳实践深度解析
Neo4j 的集群技术是其企业级高可用性、可扩展性和容错能力的核心。通过深入分析官方文档,本文将系统阐述其集群管理的核心原理、关键技术、实用技巧和行业最佳实践。 Neo4j 的 Causal Clustering 架构提供了一个强大而灵活的基石,用于构建高可用、可扩展且一致的图数据库服务…...

论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)
笔记整理:刘治强,浙江大学硕士生,研究方向为知识图谱表示学习,大语言模型 论文链接:http://arxiv.org/abs/2407.16127 发表会议:ISWC 2024 1. 动机 传统的知识图谱补全(KGC)模型通过…...
JDK 17 新特性
#JDK 17 新特性 /**************** 文本块 *****************/ python/scala中早就支持,不稀奇 String json “”" { “name”: “Java”, “version”: 17 } “”"; /**************** Switch 语句 -> 表达式 *****************/ 挺好的ÿ…...

如何理解 IP 数据报中的 TTL?
目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...