从零开始学习线性回归:理论、实践与PyTorch实现
文章目录
- 🥦介绍
- 🥦基本知识
- 🥦代码实现
- 🥦完整代码
- 🥦总结
🥦介绍
线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性回归的理论基础、数学公式以及如何使用PyTorch实现一个简单的线性回归模型。
🥦基本知识
线性回归的数学基础
线性回归的核心思想是建立一个线性方程,它表示了自变量(输入特征)与因变量(输出)之间的关系。这个线性方程通常表示为:
其中, y y y 是因变量, x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,…,xp 是自变量, b 0 , b 1 , b 2 , … , b p b_0, b_1, b_2, \ldots, b_p b0,b1,b2,…,bp 是模型的参数, p p p 是特征的数量。我们的目标是找到最佳的参数值,以最小化模型的误差。
损失函数
为了找到最佳参数,我们需要定义一个损失函数来度量模型的性能。在线性回归中,最常用的损失函数是均方误差(MSE),它表示了模型预测值与实际值之间的平方差的平均值:
其中, n n n 是样本数量, y i y_i yi 是实际值, y ^ i \hat{y}_i y^i 是模型的预测值。
梯度下降优化
为了最小化损失函数,我们使用梯度下降算法。梯度下降通过计算损失函数相对于参数的梯度,并迭代地更新参数,以减小损失。更新规则如下:
其中, b j b_j bj 是第 j j j个参数, α \alpha α 是学习率, ∂ ∂ b j M S E \frac{\partial}{\partial b_j} MSE ∂bj∂MSE 是损失函数对参数 b j b_j bj的偏导数。
🥦代码实现
如果你想知道实现线性回归的大体步骤,下图可以充分进行说明
- 准备数据
- 设计模型(计算) y ^ i \hat{y}_i y^i
- 构造损失和优化器
- 训练周期(前向,反向 ,更新)
本节还是以刘二大人的视频讲解为例,结尾会设置传送门
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() # 调用父类的构造函数self.linear = torch.nn.Linear(1, 1) # 参数详情下图展示def forward(self, x):y_pred = self.linear(x) # x代表输入样本的张量return y_pred
model = LinearModel()
所以模型类都要继承Module,此类主要包含两个函数一个是构造函数(初始化对象时调用),另一个是前向计算
好奇的小伙伴会思考为何没有反向(backward),这是因为Module会帮你进行,但是如果后期自己有更高效的方法可以自行设置。
-
第一个参数 in_features:这是输入特征的数量。在这里,表示我们的模型只有一个输入特征。如果你有多个输入特征,你可以将这个参数设置为输入特征的数量。
-
第二个参数 out_features:这是输出特征的数量。这表示我们的模型将生成一个输出。在线性回归中,通常只有一个输出,因为我们试图预测一个连续的数值。
-
第三个参数:意思是要不要偏置量。默认true
通常情况下特征代表列,比如我们有一个n×2的y和一个n×3的x,那么我们需要一个3×2的权重,有的书中会在两边做转置,但无论咋样目的都是为了让这个矩阵乘法成立
criterion = torch.nn.MSELoss(size_average=False) # 使用均方误差损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 使用随机梯度下降优化器
model.parameters() 用于告诉优化器哪些参数需要在训练过程中进行更新,这包括模型的权重和偏置项等。在线性回归示例中,模型的参数包括权重和偏置项。
优化器的选择有许多大家可以都试试看看
之后就进行训练了
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() # 归零loss.backward() # 反向optimizer.step() # 更新
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
🥦完整代码
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x) return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
predicted = model(x_data).detach().numpy()
plt.scatter(x_data, y_data, label='Original data')
plt.plot(x_data, predicted, label='Fitted line', color='r')
plt.legend()
plt.show()
运行结果如下
🥦总结
在本篇博客中,我们使用PyTorch实现了一个简单的线性回归模型,并使用随机生成的数据对其进行了训练和可视化。线性回归是一个入门级的机器学习模型,但它为理解模型训练和预测的基本概念提供了一个很好的起点。
挑战与创造都是很痛苦的,但是很充实。
相关文章:

从零开始学习线性回归:理论、实践与PyTorch实现
文章目录 🥦介绍🥦基本知识🥦代码实现🥦完整代码🥦总结 🥦介绍 线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性…...

[LeetCode]链式二叉树相关题目(c语言实现)
文章目录 LeetCode965. 单值二叉树LeetCode100. 相同的树LeetCode101. 对称二叉树LeetCode144. 二叉树的前序遍历LeetCode94. 二叉树的中序遍历LeetCode145. 二叉树的后序遍历LeetCode572. 另一棵树的子树 LeetCode965. 单值二叉树 题目 Oj链接 思路 一棵树的所有值都是一个…...

集成学习
集成学习(Ensemble Learning) - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/27689464集成学习就是组合这里的多个弱监督模型以期得到一个更好更全面的强监督模型,集成学习潜在的思想是即便某一个弱分类器得到了错误的预测,其他的弱分类器…...

算法练习11——买卖股票的最佳时机 II
LeetCode 122 买卖股票的最佳时机 II 给你一个整数数组 prices ,其中 prices[i] 表示某支股票第 i 天的价格。 在每一天,你可以决定是否购买和/或出售股票。你在任何时候 最多 只能持有 一股 股票。你也可以先购买,然后在 同一天 出售。 返回…...

linux——多线程,线程控制
目录 一.POSIX线程库 二.线程创建 1.创建线程接口 2.查看线程 3.多线程的健壮性问题 4.线程函数参数传递 5.线程id和地址空间 三.线程终止 1.pthread_exit 2.pthread_cancel 四.线程等待 五.线程分离 一.POSIX线程库 站在内核的角度,OS只有轻量级进程…...

Oracle 简介与 Docker Compose部署
最近,我翻阅了在之前公司工作时的笔记,偶然发现了一些有关数据库的记录。当初,我们的项目一开始采用的是 Oracle 数据库,但随着项目需求的变化,我们不得不转向使用 SQL Server。值得一提的是,公司之前采用的…...

mp4音视频分离技术
文章目录 问题描述一、分离MP3二、分离无声音的MP4三、结果 问题描述 MP4视频想拆分成一个MP3音频和一个无声音的MP4文件 一、分离MP3 ffmpeg -i C:\Users\Administrator\Desktop\一个文件夹\我在财神殿里长跪不起_完整版MV.mp4 -vn C:\Users\Administrator\Desktop\一个文件…...

JVM 参数
JVM 参数类型大致分为以下几类: 标准参数(-):保证在所有的 JVM 实现都支持的参数非标准参数(-X):通用的,特定于 HotSpot 虚拟机的参数,这些参数不保证在所有 JVM 实现中…...

黑马点评-07缓存击穿问题(热点key失效)及解决方案,互斥锁和设置逻辑过期时间
缓存击穿问题(热点key失效) 缓存击穿问题也叫热点Key问题,就是一个被高并发访问并且重建缓存业务较复杂的key突然失效了,此时无数的请求访问会在瞬间打到数据库,带来巨大的冲击 一件秒杀中的商品的key突然失效了,由于大家都在疯狂抢购那么这个瞬间就会有无数的请求…...

信息系统项目管理师第四版学习笔记——项目进度管理
项目进度管理过程 项目进度管理过程包括:规划进度管理、定义活动、排列活动顺序、估算活动持续时间、制订进度计划、控制进度。 规划进度管理 规划进度管理是为规划、编制、管理、执行和控制项目进度而制定政策、程序和文档的过程。本过程的主要作用是为如何在…...

指挥棒:C++ 与运算符
文章目录 参考描述算术运算符除法运算取模运算复合赋值运算符自增运算符自减运算符 比较运算符逻辑运算符概念短路为什么需要短路机制? 参考 项目描述微软C 语言文档搜索引擎Bing、GoogleAI 大模型文心一言、通义千问、讯飞星火认知大模型、ChatGPTC Primer Plus &…...

HTTPS建立连接的过程
HTTPS 协议是基于 TCP 协议的,因而要先建立 TCP 的连接。在这个例子中,TCP 的连接是在手机上的 App 和负载均衡器 SLB 之间的。 尽管中间要经过很多的路由器和交换机,但是 TCP 的连接是端到端的。TCP 这一层和更上层的 HTTPS 无法看到中间的包…...

Python接口自动化搭建过程,含request请求封装!
开篇碎碎念 接口测试自动化好处 显而易见的好处就是解放双手😀。 可以在短时间内自动执行大量的测试用例通过参数化和数据驱动的方式进行测试数据的变化,提高测试覆盖范围快速反馈测试执行结果和报告支持持续集成和持续交付的流程 使用Requestspytes…...

Vue3 编译原理
文章目录 一、编译流程1. 解读入口文件 packgages/vue/index.ts2. compile函数的运行流程 二、AST 解析器1. ast 的生成2. 创建ast的根节点3. 解析子节点 parseChildren(关键)4. 解析模版元素 Element模版元素解析-举例分析 一、编译流程 1. 解读入口文…...

spring boot整合Minio
MinIO 安装MinIo # 先创建minio 文件存放的位置 mkdir -p /opt/docker/minio/data# 启动并指定端口 docker run \-p 9000:9000 \-p 5001:5001 \--name minio \-v /opt/docker/minio/data:/data \-e "MINIO_ROOT_USERminioadmin" \-e "MINIO_ROOT_PASSWORDmini…...

Hadoop----Azkaban的使用与一些报错问题的解决
1.因为官方只放出源码,并没有放出其tar包,所以需要我们自己编译,通过查阅资料我们可以使用gradlew对其进行编译,还是比较简单,然后将里面需要用到的服务文件夹进行拷贝,完善其文件夹结构,通常会…...

「新房家装经验」客厅电视高度标准尺寸及客厅电视机买多大尺寸合适?
客厅电视悬挂高度标准尺寸是多少? 客厅电视悬挂高度通常在90~120厘米之间,电视挂墙高度也可以根据个人的喜好和实际情况来调整,但通常不宜过高,以坐在沙发上观看时眼睛能够平视到电视中心点或者中心稍微往下一点的位置为适宜。 客…...

ArduPilot开源飞控之AP_Baro_DroneCAN
ArduPilot开源飞控之AP_Baro_DroneCAN 1. 源由2. back-end抽象类3. 方法实现3.1 probe3.2 update3.3 subscribe_msgs3.4 handle_pressure/handle_temperature3.5 CAN port 4. 参考资料 1. 源由 鉴于ArduPilot开源飞控之AP_Baro中涉及Sensor Driver有以下总线类型: …...

Supervised Contrastive Pre-training for Mammographic Triage Screening Model
方法 品红色箭头表示将生成的孪生编码器分别迁移到单视角学习模块和双视角学习模块...

JVM技术文档--JVM优化思路以及问题定位--JVM可调整参数汇总
阿丹: 一个优秀的程序员,是因为在线上的排查以及遇到的线上、生产事故较多所以定位问题以及解决问题会比普通程序员快很多,所以一个优秀的程序员要逐渐形成自己的方法论,来完善和解决问题。 我们是如何发现问题的呢? …...

Oracle10g数据库迁移方案
试验了很多次Oracle数据库迁移才成功,贴出来给大家参考一下,我看到有的地方写迁移之后还需要重新建立temp表空间,这个还没有研究。另外说一点的是两个数据库的版本一定要一致,之前失败过一次,就是因为两个数据库的版本…...

备忘录模式:对象状态的保存与恢复
欢迎来到设计模式系列的第十八篇文章,本篇将介绍备忘录模式。备忘录模式是一种行为型设计模式,它允许在不破坏封装性的前提下捕获一个对象的内部状态,并在之后恢复该状态。这种模式通常用于需要提供撤销操作的情况。 什么是备忘录模式&#…...

C# InvokeRequired线程安全
C# InvokeRequired线程安全 为了保证新家的线程可能要对主界面的控件元素的属性发生一些改变,此时防止此操作对于主线程的影响,就提出了 InvokeRequired方法,保证主线程的安全,同时新加的线程也可以改变主页面中元素的值。 定义…...

pdf怎么转成jpg图片格式
pdf怎么转成jpg图片格式?对于大家平时在工作或者生活中的图片使用习惯,经常需要将各种格式的文件转换成易于浏览和使用的JPG格式图片以便保存。如今,因为pdf文件具有更强的稳定性和设备兼容性,PDF文件在平时的电脑使用过程中可以说…...

React +ts + babel+webpack
babel babel/preset-typescript 专门处理ts "babel/cli": "^7.17.6", "babel/core": "^7.17.8", "babel/preset-env": "^7.16.11", "babel/preset-react": "^7.16.7", "babel/preset…...

红队专题-REVERSE二进制逆向反编译
红队专题 招募六边形战士队员IDA pro安装python2加入环境变量py2安装pip安装IDA 7.0 proIDAPython: importing "site" failed. 招募六边形战士队员 一起学习 代码审计、安全开发、web攻防、逆向等。。。 私信联系 IDA pro 安装python2 python-2.7.3.msi 加入环…...

Spring技术原理之Bean生命周期原理解析
Spring技术原理之Bean生命周期原理解析 Spring作为Java领域中的优秀框架,其核心功能之一是依赖注入和生命周期管理。其中,Bean的生命周期管理是Spring框架中一个重要的概念。在本篇文章中,我们将深入探讨Spring技术原理中的Bean生命周期原理…...

Unity实现设计模式——模板方法模式
Unity实现设计模式——模板方法模式 模板模式(Template Pattern), 指在一个抽象类公开定义了执行它的方法的模板。它的子类可以按需要重写方法实现,但调用将以抽象类中定义的方式进行。 简单说, 模板方法模式定义一个操作中的算法的骨架&…...

C++实现高性能内存池(二)
文章目录 一、设计内存池二、实现MemoryPool::construct() 实现MemoryPool::deallocate() 实现MemoryPool::~MemoryPool() 实现MemoryPool::allocate() 实现三、与 std::vector 的性能对比一、设计内存池 在上节中,我们在模板链表栈中使用了默认构造器来管理栈操作中的元素内…...

沪深300期权一个点多少钱?
经中国证监会批准,深圳证券交易所于2019年12月23日上市嘉实沪深300ETF期权合约品种。该产品是以沪深300为标的物的嘉实沪深300ETF交易型指数基金为标的衍生的标准化合约,下文介绍沪深300期权一个点多少钱?本文来自:期权酱 一、沪深300期权涨…...