从零开始学习线性回归:理论、实践与PyTorch实现
文章目录
- 🥦介绍
- 🥦基本知识
- 🥦代码实现
- 🥦完整代码
- 🥦总结
🥦介绍
线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性回归的理论基础、数学公式以及如何使用PyTorch实现一个简单的线性回归模型。
🥦基本知识
线性回归的数学基础
线性回归的核心思想是建立一个线性方程,它表示了自变量(输入特征)与因变量(输出)之间的关系。这个线性方程通常表示为:

其中, y y y 是因变量, x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,…,xp 是自变量, b 0 , b 1 , b 2 , … , b p b_0, b_1, b_2, \ldots, b_p b0,b1,b2,…,bp 是模型的参数, p p p 是特征的数量。我们的目标是找到最佳的参数值,以最小化模型的误差。
损失函数
为了找到最佳参数,我们需要定义一个损失函数来度量模型的性能。在线性回归中,最常用的损失函数是均方误差(MSE),它表示了模型预测值与实际值之间的平方差的平均值:

其中, n n n 是样本数量, y i y_i yi 是实际值, y ^ i \hat{y}_i y^i 是模型的预测值。
梯度下降优化
为了最小化损失函数,我们使用梯度下降算法。梯度下降通过计算损失函数相对于参数的梯度,并迭代地更新参数,以减小损失。更新规则如下:

其中, b j b_j bj 是第 j j j个参数, α \alpha α 是学习率, ∂ ∂ b j M S E \frac{\partial}{\partial b_j} MSE ∂bj∂MSE 是损失函数对参数 b j b_j bj的偏导数。
🥦代码实现
如果你想知道实现线性回归的大体步骤,下图可以充分进行说明

- 准备数据
- 设计模型(计算) y ^ i \hat{y}_i y^i
- 构造损失和优化器
- 训练周期(前向,反向 ,更新)
本节还是以刘二大人的视频讲解为例,结尾会设置传送门
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() # 调用父类的构造函数self.linear = torch.nn.Linear(1, 1) # 参数详情下图展示def forward(self, x):y_pred = self.linear(x) # x代表输入样本的张量return y_pred
model = LinearModel()
所以模型类都要继承Module,此类主要包含两个函数一个是构造函数(初始化对象时调用),另一个是前向计算
好奇的小伙伴会思考为何没有反向(backward),这是因为Module会帮你进行,但是如果后期自己有更高效的方法可以自行设置。

-
第一个参数 in_features:这是输入特征的数量。在这里,表示我们的模型只有一个输入特征。如果你有多个输入特征,你可以将这个参数设置为输入特征的数量。
-
第二个参数 out_features:这是输出特征的数量。这表示我们的模型将生成一个输出。在线性回归中,通常只有一个输出,因为我们试图预测一个连续的数值。
-
第三个参数:意思是要不要偏置量。默认true
通常情况下特征代表列,比如我们有一个n×2的y和一个n×3的x,那么我们需要一个3×2的权重,有的书中会在两边做转置,但无论咋样目的都是为了让这个矩阵乘法成立
criterion = torch.nn.MSELoss(size_average=False) # 使用均方误差损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 使用随机梯度下降优化器


model.parameters() 用于告诉优化器哪些参数需要在训练过程中进行更新,这包括模型的权重和偏置项等。在线性回归示例中,模型的参数包括权重和偏置项。
优化器的选择有许多大家可以都试试看看

之后就进行训练了
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() # 归零loss.backward() # 反向optimizer.step() # 更新
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
🥦完整代码
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x) return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
predicted = model(x_data).detach().numpy()
plt.scatter(x_data, y_data, label='Original data')
plt.plot(x_data, predicted, label='Fitted line', color='r')
plt.legend()
plt.show()
运行结果如下


🥦总结
在本篇博客中,我们使用PyTorch实现了一个简单的线性回归模型,并使用随机生成的数据对其进行了训练和可视化。线性回归是一个入门级的机器学习模型,但它为理解模型训练和预测的基本概念提供了一个很好的起点。

挑战与创造都是很痛苦的,但是很充实。
相关文章:
从零开始学习线性回归:理论、实践与PyTorch实现
文章目录 🥦介绍🥦基本知识🥦代码实现🥦完整代码🥦总结 🥦介绍 线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性…...
[LeetCode]链式二叉树相关题目(c语言实现)
文章目录 LeetCode965. 单值二叉树LeetCode100. 相同的树LeetCode101. 对称二叉树LeetCode144. 二叉树的前序遍历LeetCode94. 二叉树的中序遍历LeetCode145. 二叉树的后序遍历LeetCode572. 另一棵树的子树 LeetCode965. 单值二叉树 题目 Oj链接 思路 一棵树的所有值都是一个…...
集成学习
集成学习(Ensemble Learning) - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/27689464集成学习就是组合这里的多个弱监督模型以期得到一个更好更全面的强监督模型,集成学习潜在的思想是即便某一个弱分类器得到了错误的预测,其他的弱分类器…...
算法练习11——买卖股票的最佳时机 II
LeetCode 122 买卖股票的最佳时机 II 给你一个整数数组 prices ,其中 prices[i] 表示某支股票第 i 天的价格。 在每一天,你可以决定是否购买和/或出售股票。你在任何时候 最多 只能持有 一股 股票。你也可以先购买,然后在 同一天 出售。 返回…...
linux——多线程,线程控制
目录 一.POSIX线程库 二.线程创建 1.创建线程接口 2.查看线程 3.多线程的健壮性问题 4.线程函数参数传递 5.线程id和地址空间 三.线程终止 1.pthread_exit 2.pthread_cancel 四.线程等待 五.线程分离 一.POSIX线程库 站在内核的角度,OS只有轻量级进程…...
Oracle 简介与 Docker Compose部署
最近,我翻阅了在之前公司工作时的笔记,偶然发现了一些有关数据库的记录。当初,我们的项目一开始采用的是 Oracle 数据库,但随着项目需求的变化,我们不得不转向使用 SQL Server。值得一提的是,公司之前采用的…...
mp4音视频分离技术
文章目录 问题描述一、分离MP3二、分离无声音的MP4三、结果 问题描述 MP4视频想拆分成一个MP3音频和一个无声音的MP4文件 一、分离MP3 ffmpeg -i C:\Users\Administrator\Desktop\一个文件夹\我在财神殿里长跪不起_完整版MV.mp4 -vn C:\Users\Administrator\Desktop\一个文件…...
JVM 参数
JVM 参数类型大致分为以下几类: 标准参数(-):保证在所有的 JVM 实现都支持的参数非标准参数(-X):通用的,特定于 HotSpot 虚拟机的参数,这些参数不保证在所有 JVM 实现中…...
黑马点评-07缓存击穿问题(热点key失效)及解决方案,互斥锁和设置逻辑过期时间
缓存击穿问题(热点key失效) 缓存击穿问题也叫热点Key问题,就是一个被高并发访问并且重建缓存业务较复杂的key突然失效了,此时无数的请求访问会在瞬间打到数据库,带来巨大的冲击 一件秒杀中的商品的key突然失效了,由于大家都在疯狂抢购那么这个瞬间就会有无数的请求…...
信息系统项目管理师第四版学习笔记——项目进度管理
项目进度管理过程 项目进度管理过程包括:规划进度管理、定义活动、排列活动顺序、估算活动持续时间、制订进度计划、控制进度。 规划进度管理 规划进度管理是为规划、编制、管理、执行和控制项目进度而制定政策、程序和文档的过程。本过程的主要作用是为如何在…...
指挥棒:C++ 与运算符
文章目录 参考描述算术运算符除法运算取模运算复合赋值运算符自增运算符自减运算符 比较运算符逻辑运算符概念短路为什么需要短路机制? 参考 项目描述微软C 语言文档搜索引擎Bing、GoogleAI 大模型文心一言、通义千问、讯飞星火认知大模型、ChatGPTC Primer Plus &…...
HTTPS建立连接的过程
HTTPS 协议是基于 TCP 协议的,因而要先建立 TCP 的连接。在这个例子中,TCP 的连接是在手机上的 App 和负载均衡器 SLB 之间的。 尽管中间要经过很多的路由器和交换机,但是 TCP 的连接是端到端的。TCP 这一层和更上层的 HTTPS 无法看到中间的包…...
Python接口自动化搭建过程,含request请求封装!
开篇碎碎念 接口测试自动化好处 显而易见的好处就是解放双手😀。 可以在短时间内自动执行大量的测试用例通过参数化和数据驱动的方式进行测试数据的变化,提高测试覆盖范围快速反馈测试执行结果和报告支持持续集成和持续交付的流程 使用Requestspytes…...
Vue3 编译原理
文章目录 一、编译流程1. 解读入口文件 packgages/vue/index.ts2. compile函数的运行流程 二、AST 解析器1. ast 的生成2. 创建ast的根节点3. 解析子节点 parseChildren(关键)4. 解析模版元素 Element模版元素解析-举例分析 一、编译流程 1. 解读入口文…...
spring boot整合Minio
MinIO 安装MinIo # 先创建minio 文件存放的位置 mkdir -p /opt/docker/minio/data# 启动并指定端口 docker run \-p 9000:9000 \-p 5001:5001 \--name minio \-v /opt/docker/minio/data:/data \-e "MINIO_ROOT_USERminioadmin" \-e "MINIO_ROOT_PASSWORDmini…...
Hadoop----Azkaban的使用与一些报错问题的解决
1.因为官方只放出源码,并没有放出其tar包,所以需要我们自己编译,通过查阅资料我们可以使用gradlew对其进行编译,还是比较简单,然后将里面需要用到的服务文件夹进行拷贝,完善其文件夹结构,通常会…...
「新房家装经验」客厅电视高度标准尺寸及客厅电视机买多大尺寸合适?
客厅电视悬挂高度标准尺寸是多少? 客厅电视悬挂高度通常在90~120厘米之间,电视挂墙高度也可以根据个人的喜好和实际情况来调整,但通常不宜过高,以坐在沙发上观看时眼睛能够平视到电视中心点或者中心稍微往下一点的位置为适宜。 客…...
ArduPilot开源飞控之AP_Baro_DroneCAN
ArduPilot开源飞控之AP_Baro_DroneCAN 1. 源由2. back-end抽象类3. 方法实现3.1 probe3.2 update3.3 subscribe_msgs3.4 handle_pressure/handle_temperature3.5 CAN port 4. 参考资料 1. 源由 鉴于ArduPilot开源飞控之AP_Baro中涉及Sensor Driver有以下总线类型: …...
Supervised Contrastive Pre-training for Mammographic Triage Screening Model
方法 品红色箭头表示将生成的孪生编码器分别迁移到单视角学习模块和双视角学习模块...
JVM技术文档--JVM优化思路以及问题定位--JVM可调整参数汇总
阿丹: 一个优秀的程序员,是因为在线上的排查以及遇到的线上、生产事故较多所以定位问题以及解决问题会比普通程序员快很多,所以一个优秀的程序员要逐渐形成自己的方法论,来完善和解决问题。 我们是如何发现问题的呢? …...
Linux配置静态ip地址和Oracle VM VirtualBox导入/导出虚拟机Centos7
导入虚拟机选择管理 - 导入虚拟电脑找到自己的虚拟机位置修改内存大小,默认虚拟机电脑位置,MAC地址等导入后点击设置如下图:修改网络-网 -- 卡1,其他基本不需要修改桥接网络选好网卡接入网线;设置好网络以后使用命令重…...
mkcert 命令文档 - 本地 HTTPS 开发证书生成工具详解
1. 命令简介mkcert 是一个用 Go 语言编写的、零配置的本地开发用自签名证书生成工具。它能够自动创建并安装本地证书颁发机构(CA)到系统的信任存储中,并生成受本地信任的开发证书,大幅简化 HTTPS 本地开发环境的搭建过程ÿ…...
抖音批量下载助手:轻松管理您的抖音视频资源库
抖音批量下载助手:轻松管理您的抖音视频资源库 【免费下载链接】douyinhelper 抖音批量下载助手 项目地址: https://gitcode.com/gh_mirrors/do/douyinhelper 还在为手动保存抖音视频而烦恼吗?抖音批量下载助手正是您需要的效率工具!这…...
LeagueAkari:英雄联盟智能辅助工具完全指南
LeagueAkari:英雄联盟智能辅助工具完全指南 【免费下载链接】League-Toolkit 兴趣使然的、简单易用的英雄联盟工具集。支持战绩查询、自动秒选等功能。基于 LCU API。 项目地址: https://gitcode.com/gh_mirrors/le/League-Toolkit LeagueAkari是一款基于英雄…...
Excel VBA图像处理:如何在单元格中显示并调整图片大小
在Excel中处理图片时,VBA(Visual Basic for Applications)是一个强大的工具。今天我们将讨论如何通过VBA代码在Excel的单元格中插入并调整图片大小,以及如何解决一些常见的问题。 背景介绍 假设你有一个Excel工作表,A列从A2开始存放了几个图片文件名,如"test.jpg&…...
LumiPixel开箱即用教程:快速上手这个专为人像设计的AI创作平台
LumiPixel开箱即用教程:快速上手这个专为人像设计的AI创作平台 1. 认识LumiPixel:纯净人像创作平台 LumiPixel: Canvas Quest是一款专注于人像创作的AI视觉平台,它将先进的Z-Image扩散模型与复古像素艺术美学完美结合。这个平台特别适合需要…...
工业视觉检测避坑指南:CogBlobTool阈值设置5大常见错误及解决方案
工业视觉检测避坑指南:CogBlobTool阈值设置5大常见错误及解决方案 在工业视觉检测领域,斑点检测(Blob Analysis)是最基础也最关键的环节之一。作为Cognex VisionPro套件中的核心工具,CogBlobTool凭借其强大的图像分割能…...
内容营销对 SEO 有什么影响
<h3 id"seo">内容营销对 SEO 有什么影响</h3> <h4 id"">引言</h4> <p>在当今数字化时代,搜索引擎优化(SEO)和内容营销被广泛认为是网站流量和业务增长的关键驱动因素。许多企业在网站建设…...
百川2-13B模型实战:Python爬虫数据的智能分析与摘要生成
百川2-13B模型实战:Python爬虫数据的智能分析与摘要生成 每天,互联网上都会产生海量的文本信息,新闻、论坛帖子、社交媒体动态……对于市场分析师、舆情监控人员或者内容运营者来说,如何从这些信息海洋中快速提炼出有价值的内容&…...
2026年全国青少年信息素养大赛算法应用主题赛(C++赛项初赛模拟卷3:文末附答案)
2026年全国青少年信息素养大赛算法应用主题赛(C赛项初赛模拟卷3:文末附答案) 一、单选题 在C中,以下哪个关键字用于定义一个整型变量? A. int B. float C. char D. double 一支商队从长安出发,每天行进80里…...
