从零手写线性回归模型:PyTorch 实现深度学习入门教程
系列文章目录
01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量!
02-张量运算真简单!PyTorch 数值计算操作完全指南
03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧
04-揭秘数据处理神器:PyTorch 张量拼接与拆分实用技巧
05-深度学习从索引开始:PyTorch 张量索引与切片最全解析
06-张量形状任意改!PyTorch reshape、transpose 操作超详细教程
07-深入解读 PyTorch 张量运算:6 大核心函数全面解析,代码示例一步到位!
08-自动微分到底有多强?PyTorch 自动求导机制深度解析
09-从零手写线性回归模型:PyTorch 实现深度学习入门教程
10-PyTorch 框架实现线性回归:从数据预处理到模型训练全流程
文章目录
- 系列文章目录
- 前言
- 一、构建数据集
- 1.1 示例代码
- 1.2 示例输出
- 二、构建假设函数
- 2.1 示例代码
- 三、损失函数
- 3.1 示例代码
- 四、优化方法
- 4.1 示例代码
- 五、训练函数
- 5.1 示例代码
- 5.2 绘制结果
- 六、调用训练函数
- 6.1 示例输出
- 七、小结
- 7.1 完整代码
前言
在机器学习的学习过程中,我们接触过 线性回归 模型,并使用过如 Scikit-learn 这样的工具来快速实现。但在本文中,将深入理解线性回归的核心思想,并使用 PyTorch 从零开始手动实现一个线性回归模型。这包括:
- 数据集的构建;
- 假设函数的定义;
- 损失函数的设计;
- 梯度下降优化方法的实现;
- 模型训练和损失变化的可视化。
一、构建数据集
线性回归需要一个简单的线性数据集,我们将通过sklearn.datasets.make_regression
方法生成。
1.1 示例代码
import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import randomdef create_dataset():"""使用 make_regression 生成线性回归数据集,并转换为 PyTorch 张量。"""x, y, coef = make_regression(n_samples=120, # 样本数量,即数据点个数n_features=1, # 每个样本只有一个特征noise=15, # 添加噪声,模拟真实场景coef=True, # 是否返回生成数据的真实系数bias=12.0, # 偏置值,即 y = coef * x + biasrandom_state=42 # 随机种子,保证结果一致性)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32) # 输入特征张量y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1) # 输出标签张量转换为二维return x, y, coef # 返回输入特征张量、输出标签张量和真实系数# 数据加载器,用于批量获取数据
def data_loader(x, y, batch_size):"""数据加载器,按批次随机提取训练数据。参数:x: 输入特征张量y: 输出标签张量batch_size: 批量大小"""data_len = len(y) # 数据集长度indices = list(range(data_len)) # 创建索引列表random.shuffle(indices) # 随机打乱索引,保证数据随机性num_batches = data_len // batch_size # 计算批次数for i in range(num_batches): # 遍历每个批次start = i * batch_size # 当前批次起始索引end = start + batch_size # 当前批次结束索引# 提取当前批次的数据batch_x = x[indices[start:end]] # 当前批次输入batch_y = y[indices[start:end]] # 当前批次输出yield batch_x, batch_y # 返回当前批次数据
1.2 示例输出
运行 create_dataset
后的数据分布为线性趋势,同时包含噪声点。例如:
x:
tensor([[-1.3282],[ 0.1941],[ 0.8944],...])
y:
tensor([-40.2345, 15.2934, 45.1282, ...])
coef:
tensor([35.0])
二、构建假设函数
线性回归的假设函数可以表示为:
y = w ⋅ x + b y ^ =w⋅x+b y=w⋅x+b
其中,( w ) 是权重,( b ) 是偏置。使用 PyTorch 张量定义这些参数。
2.1 示例代码
# 模型参数初始化
w = torch.tensor(0.5, requires_grad=True, dtype=torch.float32) # 权重
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float32) # 偏置# 假设函数
def linear_regression(x):return w * x + b
三、损失函数
我们使用均方误差(MSE)作为损失函数,其公式为:
L = 1 n ∑ i = 1 n ( y ^ i − y i ) 2 L = \frac{1}{n} \sum_{i=1}^n (\hat{y}_i - y_i)^2 L=n1i=1∑n(y^i−yi)2
3.1 示例代码
def square_loss(y_pred, y_true):return (y_pred - y_true) ** 2
四、优化方法
为了更新模型参数 ( w ) 和 ( b ),使用 随机梯度下降(SGD) 算法,其更新公式为:
w = w − η ⋅ ∂ L ∂ w , b = b − η ⋅ ∂ L ∂ b w = w - \eta \cdot \frac{\partial L}{\partial w}, \quad b = b - \eta \cdot \frac{\partial L}{\partial b} w=w−η⋅∂w∂L,b=b−η⋅∂b∂L
其中,η 是学习率。
4.1 示例代码
def sgd(lr=0.01, batch_size=16):# 更新权重和偏置w.data = w.data - lr * w.grad.data / batch_sizeb.data = b.data - lr * b.grad.data / batch_size
五、训练函数
将前面定义的所有组件组合在一起,构建训练函数,通过多个 epoch 来优化模型。
5.1 示例代码
# 模型训练函数
def train():"""训练线性回归模型。"""# 加载数据集x, y, coef = create_dataset()# 设置训练参数epochs = 50 # 训练轮次learning_rate = 0.01 # 学习率batch_size = 16 # 每批次的数据大小# 使用 PyTorch 内置优化器optimizer = torch.optim.SGD([w, b], lr=learning_rate)epoch_loss = [] # 用于记录每轮的平均损失for epoch in range(epochs): # 遍历每一轮total_loss = 0.0 # 累计损失for batch_x, batch_y in data_loader(x, y, batch_size): # 遍历每个批次# 使用假设函数计算预测值y_pred = linear_regression(batch_x)# 计算损失loss = square_loss(y_pred, batch_y) # 当前批次的平均损失total_loss += loss.item() # 累加总损失# 梯度清零optimizer.zero_grad()# 反向传播计算梯度loss.backward()# 使用随机梯度下降更新参数optimizer.step()# 记录当前轮次的平均损失epoch_loss.append(total_loss / len(y))print(f"轮次 {epoch + 1}, 平均损失: {epoch_loss[-1]:.4f}") # 打印损失# 可视化训练结果plot_results(x, y, coef, epoch_loss)
5.2 绘制结果
# 可视化训练结果
def plot_results(x, y, coef, epoch_loss):"""绘制训练结果,包括拟合直线和损失变化曲线。参数:x: 输入特征张量y: 输出标签张量coef: 数据生成时的真实权重epoch_loss: 每轮的平均损失"""# 绘制训练数据点和拟合直线plt.scatter(x.numpy(), y.numpy(), label='数据点', alpha=0.7) # 数据点x_line = torch.linspace(x.min(), x.max(), 100).unsqueeze(1) # 连续 x 值y_pred = linear_regression(x_line).detach().numpy() # 模型预测值coef_tensor = torch.tensor(coef, dtype=torch.float32) # 将 coef 转换为 PyTorch 张量y_true = coef_tensor * x_line + 12.0 # 真实直线(生成数据时的公式)plt.plot(x_line.numpy(), y_pred, label='拟合直线', color='red') # 拟合直线plt.plot(x_line.numpy(), y_true.numpy(), label='真实直线', color='green') # 真实直线plt.legend()plt.grid()plt.title('线性回归拟合')plt.xlabel('特征值 X')plt.ylabel('标签值 Y')plt.show()# 绘制损失变化曲线plt.plot(range(len(epoch_loss)), epoch_loss)plt.title('损失变化曲线')plt.xlabel('轮次')plt.ylabel('损失')plt.grid()plt.show()
六、调用训练函数
在主程序中调用 train
函数,训练模型并观察输出。
if __name__ == "__main__":train()
6.1 示例输出
- 拟合直线:
- 损失变化曲线:
七、小结
本文通过手动实现线性回归模型,完成了以下内容:
- 构建数据集并设计数据加载器;
- 定义线性假设函数;
- 设计均方误差损失函数;
- 实现随机梯度下降优化方法;
- 训练模型并可视化损失变化和拟合直线。
7.1 完整代码
import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import random# 设置 Matplotlib 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号# 数据集生成函数
def create_dataset():"""使用 make_regression 生成线性回归数据集,并转换为 PyTorch 张量。"""x, y, coef = make_regression(n_samples=120, # 样本数量,即数据点个数n_features=1, # 每个样本只有一个特征noise=15, # 添加噪声,模拟真实场景coef=True, # 是否返回生成数据的真实系数bias=12.0, # 偏置值,即 y = coef * x + biasrandom_state=42 # 随机种子,保证结果一致性)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32) # 输入特征张量y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1) # 输出标签张量转换为二维return x, y, coef # 返回输入特征张量、输出标签张量和真实系数# 数据加载器,用于批量获取数据
def data_loader(x, y, batch_size):"""数据加载器,按批次随机提取训练数据。参数:x: 输入特征张量y: 输出标签张量batch_size: 批量大小"""data_len = len(y) # 数据集长度indices = list(range(data_len)) # 创建索引列表random.shuffle(indices) # 随机打乱索引,保证数据随机性num_batches = data_len // batch_size # 计算批次数for i in range(num_batches): # 遍历每个批次start = i * batch_size # 当前批次起始索引end = start + batch_size # 当前批次结束索引# 提取当前批次的数据batch_x = x[indices[start:end]] # 当前批次输入batch_y = y[indices[start:end]] # 当前批次输出yield batch_x, batch_y # 返回当前批次数据# 模型参数初始化
w = torch.tensor(0.5, requires_grad=True, dtype=torch.float32) # 权重,初始值为 0.5
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float32) # 偏置,初始值为 0# 线性假设函数
def linear_regression(x):"""线性回归假设函数。参数:x: 输入特征张量返回:模型预测值"""return w * x + b # 线性模型公式# 损失函数(均方误差)
def square_loss(y_pred, y_true):"""均方误差损失函数。参数:y_pred: 模型预测值y_true: 数据真实值返回:每个样本的平方误差"""return ((y_pred - y_true) ** 2).mean() # 返回均方误差# 模型训练函数
def train():"""训练线性回归模型。"""# 加载数据集x, y, coef = create_dataset()# 设置训练参数epochs = 50 # 训练轮次learning_rate = 0.01 # 学习率batch_size = 16 # 每批次的数据大小# 使用 PyTorch 内置优化器optimizer = torch.optim.SGD([w, b], lr=learning_rate)epoch_loss = [] # 用于记录每轮的平均损失for epoch in range(epochs): # 遍历每一轮total_loss = 0.0 # 累计损失for batch_x, batch_y in data_loader(x, y, batch_size): # 遍历每个批次# 使用假设函数计算预测值y_pred = linear_regression(batch_x)# 计算损失loss = square_loss(y_pred, batch_y) # 当前批次的平均损失total_loss += loss.item() # 累加总损失# 梯度清零optimizer.zero_grad()# 反向传播计算梯度loss.backward()# 使用随机梯度下降更新参数optimizer.step()# 记录当前轮次的平均损失epoch_loss.append(total_loss / len(y))print(f"轮次 {epoch + 1}, 平均损失: {epoch_loss[-1]:.4f}") # 打印损失# 可视化训练结果plot_results(x, y, coef, epoch_loss)# 可视化训练结果
def plot_results(x, y, coef, epoch_loss):"""绘制训练结果,包括拟合直线和损失变化曲线。参数:x: 输入特征张量y: 输出标签张量coef: 数据生成时的真实权重epoch_loss: 每轮的平均损失"""# 绘制训练数据点和拟合直线plt.scatter(x.numpy(), y.numpy(), label='数据点', alpha=0.7) # 数据点x_line = torch.linspace(x.min(), x.max(), 100).unsqueeze(1) # 连续 x 值y_pred = linear_regression(x_line).detach().numpy() # 模型预测值coef_tensor = torch.tensor(coef, dtype=torch.float32) # 将 coef 转换为 PyTorch 张量y_true = coef_tensor * x_line + 12.0 # 真实直线(生成数据时的公式)plt.plot(x_line.numpy(), y_pred, label='拟合直线', color='red') # 拟合直线plt.plot(x_line.numpy(), y_true.numpy(), label='真实直线', color='green') # 真实直线plt.legend()plt.grid()plt.title('线性回归拟合')plt.xlabel('特征值 X')plt.ylabel('标签值 Y')plt.show()# 绘制损失变化曲线plt.plot(range(len(epoch_loss)), epoch_loss)plt.title('损失变化曲线')plt.xlabel('轮次')plt.ylabel('损失')plt.grid()plt.show()# 调用训练函数
if __name__ == "__main__":train()
相关文章:

从零手写线性回归模型:PyTorch 实现深度学习入门教程
系列文章目录 01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量! 02-张量运算真简单!PyTorch 数值计算操作完全指南 03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧 04-揭秘数据处理神器:PyTor…...

【Cesium】自定义材质,添加带有方向的滚动路线
【Cesium】自定义材质,添加带有方向的滚动路线 🍖 前言🎶一、实现过程✨二、代码展示🏀三、运行结果🏆四、知识点提示 🍖 前言 【Cesium】自定义材质,添加带有方向的滚动路线 🎶一、…...
C 语言奇幻之旅 - 第11篇:C 语言动态内存管理
目录 引言1. 内存分配函数1.1 malloc 函数实际开发场景:动态数组 1.2 calloc 函数实际开发场景:初始化数据结构 1.3 realloc 函数实际开发场景:动态调整数据结构大小 2. 内存释放2.1 free 函数 3. 内存泄漏与调试3.1 常见内存问题3.2 内存调试…...

IDEA 撤销 merge 操作(详解)
作为一个开发者,我们都知道Git是一个非常重要的版本控制工具,尤其是在协作开发的过程中。然而,在使用Git的过程中难免会踩一些坑,今天我来给大家分享一个我曾经遇到的问题:在使用IDEA中进行merge操作后如何撤销错误的合…...
swarm天气智能体调用流程
Swarm 框架的调用流程: 入口点 (examples/weather_agent/run.py): run_demo_loop(weather_agent, streamTrue)初始化流程: # swarm/repl/repl.py -> run_demo_loop() client Swarm() # 创建 Swarm 实例消息处理流程: # swarm/core.py class Swarm:def run(…...

LED背光驱动芯片RT9293应用电路
一)简介: RT9293 是一款高频、异步的 Boost 升压型 LED 定电流驱动控制器,其工作原理如下: 1)基本电路结构及原理 RT9293的主要功能为上图的Q1. Boost 电路核心原理:基于电感和电容的特性实现升压功能。当…...
二叉树的二叉链表和三叉链表
在二叉树的数据结构中,通常有两种链表存储方式:二叉链表和三叉链表。这里,我们先澄清一下概念,通常我们讨论的是二叉链表,它用于存储二叉树的节点。而“三叉链表”这个术语在二叉树的上下文中不常见,可能是…...

【学习路线】Python 算法(人工智能)详细知识点学习路径(附学习资源)
学习本路线内容之前,请先学习Python的基础知识 其他路线: Python基础 >> Python进阶 >> Python爬虫 >> Python数据分析(数据科学) >> Python 算法(人工智能) >> Pyth…...
C++直接内存管理new和delete
0、前言 C语言定义了两个运算符来分配和释放动态内存。运算符new分配内存,delete释放new分配的内存。 1、new动态内存的分配 1.1、new动态分配和初始化对象 1)、new内存分配 在自由的空间分配的内存是无名的,new无法为其分配的对象…...
Linux 内核中网络接口的创建与管理
在 Linux 系统中,网络接口(如 eth0、wlan0 等)是计算机与外部网络通信的桥梁。无论是物理网卡还是虚拟网络接口,它们的创建和管理都依赖于 Linux 内核的复杂机制。本文将深入探讨 Linux 内核中网络接口的创建过程、命名规则、路由选择以及内核如何将网络接口映射到实际的硬…...
人工智能 前馈神经网络练习题
为了构建一个有两个输入( X 1 X_1 X1、 X 2 X_2 X2)和一个输出的单层感知器,并进行分类,我们需要计算权值 w 1 w_1 w1和 w 2 w_2 w2的更新过程。以下是详细的步骤和计算过程: 初始化参数 初始权值:…...

Windows搭建RTMP服务器
目录 一、Nginx-RTMP服务器搭建1、下载Nginx2、下载Nginx的RTMP扩展包3、修改配置文件4、启动服务器5、查看服务器状态6、其它ngnix命令 二、OBS推流1 、推流设置2、查看服务器状态 三、VLC拉流四、补充 本文转载自:Windows搭建RTMP服务器OBS推流VLC拉流_浏览器查看…...
Vue重新加载子组件
背景:组件需要重新加载,即重新走一遍组件的生命周期常见解决方案: 使用v-if指令:v-if 可以实现 true (加载)和 false (卸载) async reloadComponent() {this.show false// 加上 nextTick this.$nextTick(function() {this.show…...

【VScode】设置代理,通过代理连接服务器
文章目录 VScode编辑器设置代理1.图形化界面1.1 进入proxy设置界面1.2 配置代理服务器 2.配置文件(推荐)2.1 打开setting.json 文件2.2 配置代理 VScode编辑器设置代理 根据情况安装nmap 1.图形化界面 1.1 进入proxy设置界面 或者使用快捷键ctrl , 。…...
js es6 reduce函数, 通过规格生成sku
const specs [{ name: 颜色, values: [红色, 蓝色, 绿色] },{ name: 尺寸, values: [S, M, L] } ];function generateSKUs(specs) {return specs.reduce((acc, spec) > {const newAcc [];for (const combination of acc) {for (const value of spec.values) {newAcc.push(…...
基于R语言的DICE模型
DICE型是运用最广泛的综合模型之一。DICE和RICE模型虽然代码量不多,但涉及经济学与气候变化,原理较为复杂。 一:DICE模型的原理与推导 1.经济学 2.气候变化问题 3.DICE模型的经济学部分 4.DICE模型的气候相关部分 5.DICE模型的目标函数…...
【C】PAT 1006-1010
1006 换个格式输出整数 让我们用字母 B 来表示“百”、字母 S 表示“十”,用 12...n 来表示不为零的个位数字 n(<10),换个格式来输出任一个不超过 3 位的正整数。例如 234 应该被输出为 BBSSS1234,因为它有 2 个“…...

力扣双指针-算法模版总结
lc-15.三数之和 (时隔13天) 目前可通过,想法上无逻辑问题,一点细节小错误需注意即可 lc-283.移动零(时隔16天) 总结:观察案例直觉就是双指针遇零交换,两次实现都通过了,…...

解释一下:运放的输入偏置电流
输入偏置电流 首先看基础部分:这就是同相比例放大器 按照理论计算,输入VIN=0时,输出VOUT应为0,对吧 仿真与理论差距较大,有200多毫伏的偏差,这就是输入偏置电流IBIAS引起的,接着看它的定义 同向和反向输入电流的平均值,也就是Ib1、Ib2求平均,即(Ib1+Ib2)/2 按照下面…...
Windows 11 上通过 WSL (Windows Subsystem for Linux) 安装 MySQL 8
在 Windows 11 上通过 WSL (Windows Subsystem for Linux) 安装 MySQL 8 的步骤如下: ✅ 1. 检查 WSL 的安装 首先确保已经安装并启用了 WSL 2。 🔧 检查 WSL 版本 打开 PowerShell,执行以下命令: wsl --list --verbose确保 W…...

wordpress后台更新后 前端没变化的解决方法
使用siteground主机的wordpress网站,会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后,网站没有变化的情况。 不熟悉siteground主机的新手,遇到这个问题,就很抓狂,明明是哪都没操作错误&#x…...
在软件开发中正确使用MySQL日期时间类型的深度解析
在日常软件开发场景中,时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志,到供应链系统的物流节点时间戳,时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库,其日期时间类型的…...
【网络】每天掌握一个Linux命令 - iftop
在Linux系统中,iftop是网络管理的得力助手,能实时监控网络流量、连接情况等,帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...
Spring Boot 实现流式响应(兼容 2.7.x)
在实际开发中,我们可能会遇到一些流式数据处理的场景,比如接收来自上游接口的 Server-Sent Events(SSE) 或 流式 JSON 内容,并将其原样中转给前端页面或客户端。这种情况下,传统的 RestTemplate 缓存机制会…...

React第五十七节 Router中RouterProvider使用详解及注意事项
前言 在 React Router v6.4 中,RouterProvider 是一个核心组件,用于提供基于数据路由(data routers)的新型路由方案。 它替代了传统的 <BrowserRouter>,支持更强大的数据加载和操作功能(如 loader 和…...

【JVM】- 内存结构
引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...
vue3 字体颜色设置的多种方式
在Vue 3中设置字体颜色可以通过多种方式实现,这取决于你是想在组件内部直接设置,还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法: 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...

剑指offer20_链表中环的入口节点
链表中环的入口节点 给定一个链表,若其中包含环,则输出环的入口节点。 若其中不包含环,则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...

华为OD机试-食堂供餐-二分法
import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...

ElasticSearch搜索引擎之倒排索引及其底层算法
文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...