当前位置: 首页 > news >正文

从零手写线性回归模型:PyTorch 实现深度学习入门教程

系列文章目录

01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量!
02-张量运算真简单!PyTorch 数值计算操作完全指南
03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧
04-揭秘数据处理神器:PyTorch 张量拼接与拆分实用技巧
05-深度学习从索引开始:PyTorch 张量索引与切片最全解析
06-张量形状任意改!PyTorch reshape、transpose 操作超详细教程
07-深入解读 PyTorch 张量运算:6 大核心函数全面解析,代码示例一步到位!
08-自动微分到底有多强?PyTorch 自动求导机制深度解析
09-从零手写线性回归模型:PyTorch 实现深度学习入门教程
10-PyTorch 框架实现线性回归:从数据预处理到模型训练全流程


文章目录

  • 系列文章目录
  • 前言
  • 一、构建数据集
    • 1.1 示例代码
    • 1.2 示例输出
  • 二、构建假设函数
    • 2.1 示例代码
  • 三、损失函数
    • 3.1 示例代码
  • 四、优化方法
    • 4.1 示例代码
  • 五、训练函数
    • 5.1 示例代码
    • 5.2 绘制结果
  • 六、调用训练函数
    • 6.1 示例输出
  • 七、小结
    • 7.1 完整代码


前言

在机器学习的学习过程中,我们接触过 线性回归 模型,并使用过如 Scikit-learn 这样的工具来快速实现。但在本文中,将深入理解线性回归的核心思想,并使用 PyTorch 从零开始手动实现一个线性回归模型。这包括:

  1. 数据集的构建;
  2. 假设函数的定义;
  3. 损失函数的设计;
  4. 梯度下降优化方法的实现;
  5. 模型训练和损失变化的可视化。

一、构建数据集

线性回归需要一个简单的线性数据集,我们将通过sklearn.datasets.make_regression 方法生成。

1.1 示例代码

import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import randomdef create_dataset():"""使用 make_regression 生成线性回归数据集,并转换为 PyTorch 张量。"""x, y, coef = make_regression(n_samples=120,  # 样本数量,即数据点个数n_features=1,   # 每个样本只有一个特征noise=15,       # 添加噪声,模拟真实场景coef=True,      # 是否返回生成数据的真实系数bias=12.0,      # 偏置值,即 y = coef * x + biasrandom_state=42 # 随机种子,保证结果一致性)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32)  # 输入特征张量y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)  # 输出标签张量转换为二维return x, y, coef  # 返回输入特征张量、输出标签张量和真实系数# 数据加载器,用于批量获取数据
def data_loader(x, y, batch_size):"""数据加载器,按批次随机提取训练数据。参数:x: 输入特征张量y: 输出标签张量batch_size: 批量大小"""data_len = len(y)  # 数据集长度indices = list(range(data_len))  # 创建索引列表random.shuffle(indices)  # 随机打乱索引,保证数据随机性num_batches = data_len // batch_size  # 计算批次数for i in range(num_batches):  # 遍历每个批次start = i * batch_size  # 当前批次起始索引end = start + batch_size  # 当前批次结束索引# 提取当前批次的数据batch_x = x[indices[start:end]]  # 当前批次输入batch_y = y[indices[start:end]]  # 当前批次输出yield batch_x, batch_y  # 返回当前批次数据

1.2 示例输出

运行 create_dataset 后的数据分布为线性趋势,同时包含噪声点。例如:

x:
tensor([[-1.3282],[ 0.1941],[ 0.8944],...])
y:
tensor([-40.2345,  15.2934,  45.1282, ...])
coef:
tensor([35.0])

二、构建假设函数

线性回归的假设函数可以表示为:
y ​ = w ⋅ x + b y ^ ​ =w⋅x+b y=wx+b
其中,( w ) 是权重,( b ) 是偏置。使用 PyTorch 张量定义这些参数。

2.1 示例代码

# 模型参数初始化
w = torch.tensor(0.5, requires_grad=True, dtype=torch.float32)  # 权重
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float32)  # 偏置# 假设函数
def linear_regression(x):return w * x + b

三、损失函数

我们使用均方误差(MSE)作为损失函数,其公式为:

L = 1 n ∑ i = 1 n ( y ^ i − y i ) 2 L = \frac{1}{n} \sum_{i=1}^n (\hat{y}_i - y_i)^2 L=n1i=1n(y^iyi)2

3.1 示例代码

def square_loss(y_pred, y_true):return (y_pred - y_true) ** 2

四、优化方法

为了更新模型参数 ( w ) 和 ( b ),使用 随机梯度下降(SGD) 算法,其更新公式为:

w = w − η ⋅ ∂ L ∂ w , b = b − η ⋅ ∂ L ∂ b w = w - \eta \cdot \frac{\partial L}{\partial w}, \quad b = b - \eta \cdot \frac{\partial L}{\partial b} w=wηwL,b=bηbL
其中,η 是学习率。

4.1 示例代码

def sgd(lr=0.01, batch_size=16):# 更新权重和偏置w.data = w.data - lr * w.grad.data / batch_sizeb.data = b.data - lr * b.grad.data / batch_size

五、训练函数

将前面定义的所有组件组合在一起,构建训练函数,通过多个 epoch 来优化模型。

5.1 示例代码

# 模型训练函数
def train():"""训练线性回归模型。"""# 加载数据集x, y, coef = create_dataset()# 设置训练参数epochs = 50  # 训练轮次learning_rate = 0.01  # 学习率batch_size = 16  # 每批次的数据大小# 使用 PyTorch 内置优化器optimizer = torch.optim.SGD([w, b], lr=learning_rate)epoch_loss = []  # 用于记录每轮的平均损失for epoch in range(epochs):  # 遍历每一轮total_loss = 0.0  # 累计损失for batch_x, batch_y in data_loader(x, y, batch_size):  # 遍历每个批次# 使用假设函数计算预测值y_pred = linear_regression(batch_x)# 计算损失loss = square_loss(y_pred, batch_y)  # 当前批次的平均损失total_loss += loss.item()  # 累加总损失# 梯度清零optimizer.zero_grad()# 反向传播计算梯度loss.backward()# 使用随机梯度下降更新参数optimizer.step()# 记录当前轮次的平均损失epoch_loss.append(total_loss / len(y))print(f"轮次 {epoch + 1}, 平均损失: {epoch_loss[-1]:.4f}")  # 打印损失# 可视化训练结果plot_results(x, y, coef, epoch_loss)

5.2 绘制结果

# 可视化训练结果
def plot_results(x, y, coef, epoch_loss):"""绘制训练结果,包括拟合直线和损失变化曲线。参数:x: 输入特征张量y: 输出标签张量coef: 数据生成时的真实权重epoch_loss: 每轮的平均损失"""# 绘制训练数据点和拟合直线plt.scatter(x.numpy(), y.numpy(), label='数据点', alpha=0.7)  # 数据点x_line = torch.linspace(x.min(), x.max(), 100).unsqueeze(1)  # 连续 x 值y_pred = linear_regression(x_line).detach().numpy()  # 模型预测值coef_tensor = torch.tensor(coef, dtype=torch.float32)  # 将 coef 转换为 PyTorch 张量y_true = coef_tensor * x_line + 12.0  # 真实直线(生成数据时的公式)plt.plot(x_line.numpy(), y_pred, label='拟合直线', color='red')  # 拟合直线plt.plot(x_line.numpy(), y_true.numpy(), label='真实直线', color='green')  # 真实直线plt.legend()plt.grid()plt.title('线性回归拟合')plt.xlabel('特征值 X')plt.ylabel('标签值 Y')plt.show()# 绘制损失变化曲线plt.plot(range(len(epoch_loss)), epoch_loss)plt.title('损失变化曲线')plt.xlabel('轮次')plt.ylabel('损失')plt.grid()plt.show()

六、调用训练函数

在主程序中调用 train 函数,训练模型并观察输出。

if __name__ == "__main__":train()

6.1 示例输出

  • 拟合直线:
    在这里插入图片描述
  • 损失变化曲线:
    在这里插入图片描述

七、小结

本文通过手动实现线性回归模型,完成了以下内容:

  1. 构建数据集并设计数据加载器;
  2. 定义线性假设函数;
  3. 设计均方误差损失函数;
  4. 实现随机梯度下降优化方法;
  5. 训练模型并可视化损失变化和拟合直线。

7.1 完整代码

import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import random# 设置 Matplotlib 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False   # 用来正常显示负号# 数据集生成函数
def create_dataset():"""使用 make_regression 生成线性回归数据集,并转换为 PyTorch 张量。"""x, y, coef = make_regression(n_samples=120,  # 样本数量,即数据点个数n_features=1,   # 每个样本只有一个特征noise=15,       # 添加噪声,模拟真实场景coef=True,      # 是否返回生成数据的真实系数bias=12.0,      # 偏置值,即 y = coef * x + biasrandom_state=42 # 随机种子,保证结果一致性)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32)  # 输入特征张量y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)  # 输出标签张量转换为二维return x, y, coef  # 返回输入特征张量、输出标签张量和真实系数# 数据加载器,用于批量获取数据
def data_loader(x, y, batch_size):"""数据加载器,按批次随机提取训练数据。参数:x: 输入特征张量y: 输出标签张量batch_size: 批量大小"""data_len = len(y)  # 数据集长度indices = list(range(data_len))  # 创建索引列表random.shuffle(indices)  # 随机打乱索引,保证数据随机性num_batches = data_len // batch_size  # 计算批次数for i in range(num_batches):  # 遍历每个批次start = i * batch_size  # 当前批次起始索引end = start + batch_size  # 当前批次结束索引# 提取当前批次的数据batch_x = x[indices[start:end]]  # 当前批次输入batch_y = y[indices[start:end]]  # 当前批次输出yield batch_x, batch_y  # 返回当前批次数据# 模型参数初始化
w = torch.tensor(0.5, requires_grad=True, dtype=torch.float32)  # 权重,初始值为 0.5
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float32)  # 偏置,初始值为 0# 线性假设函数
def linear_regression(x):"""线性回归假设函数。参数:x: 输入特征张量返回:模型预测值"""return w * x + b  # 线性模型公式# 损失函数(均方误差)
def square_loss(y_pred, y_true):"""均方误差损失函数。参数:y_pred: 模型预测值y_true: 数据真实值返回:每个样本的平方误差"""return ((y_pred - y_true) ** 2).mean()  # 返回均方误差# 模型训练函数
def train():"""训练线性回归模型。"""# 加载数据集x, y, coef = create_dataset()# 设置训练参数epochs = 50  # 训练轮次learning_rate = 0.01  # 学习率batch_size = 16  # 每批次的数据大小# 使用 PyTorch 内置优化器optimizer = torch.optim.SGD([w, b], lr=learning_rate)epoch_loss = []  # 用于记录每轮的平均损失for epoch in range(epochs):  # 遍历每一轮total_loss = 0.0  # 累计损失for batch_x, batch_y in data_loader(x, y, batch_size):  # 遍历每个批次# 使用假设函数计算预测值y_pred = linear_regression(batch_x)# 计算损失loss = square_loss(y_pred, batch_y)  # 当前批次的平均损失total_loss += loss.item()  # 累加总损失# 梯度清零optimizer.zero_grad()# 反向传播计算梯度loss.backward()# 使用随机梯度下降更新参数optimizer.step()# 记录当前轮次的平均损失epoch_loss.append(total_loss / len(y))print(f"轮次 {epoch + 1}, 平均损失: {epoch_loss[-1]:.4f}")  # 打印损失# 可视化训练结果plot_results(x, y, coef, epoch_loss)# 可视化训练结果
def plot_results(x, y, coef, epoch_loss):"""绘制训练结果,包括拟合直线和损失变化曲线。参数:x: 输入特征张量y: 输出标签张量coef: 数据生成时的真实权重epoch_loss: 每轮的平均损失"""# 绘制训练数据点和拟合直线plt.scatter(x.numpy(), y.numpy(), label='数据点', alpha=0.7)  # 数据点x_line = torch.linspace(x.min(), x.max(), 100).unsqueeze(1)  # 连续 x 值y_pred = linear_regression(x_line).detach().numpy()  # 模型预测值coef_tensor = torch.tensor(coef, dtype=torch.float32)  # 将 coef 转换为 PyTorch 张量y_true = coef_tensor * x_line + 12.0  # 真实直线(生成数据时的公式)plt.plot(x_line.numpy(), y_pred, label='拟合直线', color='red')  # 拟合直线plt.plot(x_line.numpy(), y_true.numpy(), label='真实直线', color='green')  # 真实直线plt.legend()plt.grid()plt.title('线性回归拟合')plt.xlabel('特征值 X')plt.ylabel('标签值 Y')plt.show()# 绘制损失变化曲线plt.plot(range(len(epoch_loss)), epoch_loss)plt.title('损失变化曲线')plt.xlabel('轮次')plt.ylabel('损失')plt.grid()plt.show()# 调用训练函数
if __name__ == "__main__":train()

相关文章:

从零手写线性回归模型:PyTorch 实现深度学习入门教程

系列文章目录 01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量! 02-张量运算真简单!PyTorch 数值计算操作完全指南 03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧 04-揭秘数据处理神器:PyTor…...

【Cesium】自定义材质,添加带有方向的滚动路线

【Cesium】自定义材质,添加带有方向的滚动路线 🍖 前言🎶一、实现过程✨二、代码展示🏀三、运行结果🏆四、知识点提示 🍖 前言 【Cesium】自定义材质,添加带有方向的滚动路线 🎶一、…...

C 语言奇幻之旅 - 第11篇:C 语言动态内存管理

目录 引言1. 内存分配函数1.1 malloc 函数实际开发场景:动态数组 1.2 calloc 函数实际开发场景:初始化数据结构 1.3 realloc 函数实际开发场景:动态调整数据结构大小 2. 内存释放2.1 free 函数 3. 内存泄漏与调试3.1 常见内存问题3.2 内存调试…...

IDEA 撤销 merge 操作(详解)

作为一个开发者,我们都知道Git是一个非常重要的版本控制工具,尤其是在协作开发的过程中。然而,在使用Git的过程中难免会踩一些坑,今天我来给大家分享一个我曾经遇到的问题:在使用IDEA中进行merge操作后如何撤销错误的合…...

swarm天气智能体调用流程

Swarm 框架的调用流程: 入口点 (examples/weather_agent/run.py): run_demo_loop(weather_agent, streamTrue)初始化流程: # swarm/repl/repl.py -> run_demo_loop() client Swarm() # 创建 Swarm 实例消息处理流程: # swarm/core.py class Swarm:def run(…...

LED背光驱动芯片RT9293应用电路

一)简介: RT9293 是一款高频、异步的 Boost 升压型 LED 定电流驱动控制器,其工作原理如下: 1)基本电路结构及原理 RT9293的主要功能为上图的Q1. Boost 电路核心原理:基于电感和电容的特性实现升压功能。当…...

二叉树的二叉链表和三叉链表

在二叉树的数据结构中,通常有两种链表存储方式:二叉链表和三叉链表。这里,我们先澄清一下概念,通常我们讨论的是二叉链表,它用于存储二叉树的节点。而“三叉链表”这个术语在二叉树的上下文中不常见,可能是…...

【学习路线】Python 算法(人工智能)详细知识点学习路径(附学习资源)

学习本路线内容之前,请先学习Python的基础知识 其他路线: Python基础 >> Python进阶 >> Python爬虫 >> Python数据分析(数据科学) >> Python 算法(人工智能) >> Pyth…...

C++直接内存管理new和delete

0、前言 C语言定义了两个运算符来分配和释放动态内存。运算符new分配内存,delete释放new分配的内存。 1、new动态内存的分配 1.1、new动态分配和初始化对象 1)、new内存分配 在自由的空间分配的内存是无名的,new无法为其分配的对象…...

Linux 内核中网络接口的创建与管理

在 Linux 系统中,网络接口(如 eth0、wlan0 等)是计算机与外部网络通信的桥梁。无论是物理网卡还是虚拟网络接口,它们的创建和管理都依赖于 Linux 内核的复杂机制。本文将深入探讨 Linux 内核中网络接口的创建过程、命名规则、路由选择以及内核如何将网络接口映射到实际的硬…...

人工智能 前馈神经网络练习题

为了构建一个有两个输入( X 1 X_1 X1​、 X 2 X_2 X2​)和一个输出的单层感知器,并进行分类,我们需要计算权值 w 1 w_1 w1​和 w 2 w_2 w2​的更新过程。以下是详细的步骤和计算过程: 初始化参数 初始权值&#xff1a…...

Windows搭建RTMP服务器

目录 一、Nginx-RTMP服务器搭建1、下载Nginx2、下载Nginx的RTMP扩展包3、修改配置文件4、启动服务器5、查看服务器状态6、其它ngnix命令 二、OBS推流1 、推流设置2、查看服务器状态 三、VLC拉流四、补充 本文转载自:Windows搭建RTMP服务器OBS推流VLC拉流_浏览器查看…...

Vue重新加载子组件

背景:组件需要重新加载,即重新走一遍组件的生命周期常见解决方案: 使用v-if指令:v-if 可以实现 true (加载)和 false (卸载) async reloadComponent() {this.show false// 加上 nextTick this.$nextTick(function() {this.show…...

【VScode】设置代理,通过代理连接服务器

文章目录 VScode编辑器设置代理1.图形化界面1.1 进入proxy设置界面1.2 配置代理服务器 2.配置文件(推荐)2.1 打开setting.json 文件2.2 配置代理 VScode编辑器设置代理 根据情况安装nmap 1.图形化界面 1.1 进入proxy设置界面 或者使用快捷键ctrl , 。…...

js es6 reduce函数, 通过规格生成sku

const specs [{ name: 颜色, values: [红色, 蓝色, 绿色] },{ name: 尺寸, values: [S, M, L] } ];function generateSKUs(specs) {return specs.reduce((acc, spec) > {const newAcc [];for (const combination of acc) {for (const value of spec.values) {newAcc.push(…...

基于R语言的DICE模型

DICE型是运用最广泛的综合模型之一。DICE和RICE模型虽然代码量不多,但涉及经济学与气候变化,原理较为复杂。 一:DICE模型的原理与推导 1.经济学 2.气候变化问题 3.DICE模型的经济学部分 4.DICE模型的气候相关部分 5.DICE模型的目标函数…...

【C】PAT 1006-1010

1006 换个格式输出整数 让我们用字母 B 来表示“百”、字母 S 表示“十”&#xff0c;用 12...n 来表示不为零的个位数字 n&#xff08;<10&#xff09;&#xff0c;换个格式来输出任一个不超过 3 位的正整数。例如 234 应该被输出为 BBSSS1234&#xff0c;因为它有 2 个“…...

力扣双指针-算法模版总结

lc-15.三数之和 &#xff08;时隔13天&#xff09; 目前可通过&#xff0c;想法上无逻辑问题&#xff0c;一点细节小错误需注意即可 lc-283.移动零&#xff08;时隔16天&#xff09; 总结&#xff1a;观察案例直觉就是双指针遇零交换&#xff0c;两次实现都通过了&#xff0c…...

解释一下:运放的输入偏置电流

输入偏置电流 首先看基础部分:这就是同相比例放大器 按照理论计算,输入VIN=0时,输出VOUT应为0,对吧 仿真与理论差距较大,有200多毫伏的偏差,这就是输入偏置电流IBIAS引起的,接着看它的定义 同向和反向输入电流的平均值,也就是Ib1、Ib2求平均,即(Ib1+Ib2)/2 按照下面…...

Windows 11 上通过 WSL (Windows Subsystem for Linux) 安装 MySQL 8

在 Windows 11 上通过 WSL (Windows Subsystem for Linux) 安装 MySQL 8 的步骤如下&#xff1a; ✅ 1. 检查 WSL 的安装 首先确保已经安装并启用了 WSL 2。 &#x1f527; 检查 WSL 版本 打开 PowerShell&#xff0c;执行以下命令&#xff1a; wsl --list --verbose确保 W…...

手游刚开服就被攻击怎么办?如何防御DDoS?

开服初期是手游最脆弱的阶段&#xff0c;极易成为DDoS攻击的目标。一旦遭遇攻击&#xff0c;可能导致服务器瘫痪、玩家流失&#xff0c;甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案&#xff0c;帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...

JVM垃圾回收机制全解析

Java虚拟机&#xff08;JVM&#xff09;中的垃圾收集器&#xff08;Garbage Collector&#xff0c;简称GC&#xff09;是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象&#xff0c;从而释放内存空间&#xff0c;避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...

【AI学习】三、AI算法中的向量

在人工智能&#xff08;AI&#xff09;算法中&#xff0c;向量&#xff08;Vector&#xff09;是一种将现实世界中的数据&#xff08;如图像、文本、音频等&#xff09;转化为计算机可处理的数值型特征表示的工具。它是连接人类认知&#xff08;如语义、视觉特征&#xff09;与…...

听写流程自动化实践,轻量级教育辅助

随着智能教育工具的发展&#xff0c;越来越多的传统学习方式正在被数字化、自动化所优化。听写作为语文、英语等学科中重要的基础训练形式&#xff0c;也迎来了更高效的解决方案。 这是一款轻量但功能强大的听写辅助工具。它是基于本地词库与可选在线语音引擎构建&#xff0c;…...

莫兰迪高级灰总结计划简约商务通用PPT模版

莫兰迪高级灰总结计划简约商务通用PPT模版&#xff0c;莫兰迪调色板清新简约工作汇报PPT模版&#xff0c;莫兰迪时尚风极简设计PPT模版&#xff0c;大学生毕业论文答辩PPT模版&#xff0c;莫兰迪配色总结计划简约商务通用PPT模版&#xff0c;莫兰迪商务汇报PPT模版&#xff0c;…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

DiscuzX3.5发帖json api

参考文章&#xff1a;PHP实现独立Discuz站外发帖(直连操作数据库)_discuz 发帖api-CSDN博客 简单改造了一下&#xff0c;适配我自己的需求 有一个站点存在多个采集站&#xff0c;我想通过主站拿标题&#xff0c;采集站拿内容 使用到的sql如下 CREATE TABLE pre_forum_post_…...

macOS 终端智能代理检测

&#x1f9e0; 终端智能代理检测&#xff1a;自动判断是否需要设置代理访问 GitHub 在开发中&#xff0c;使用 GitHub 是非常常见的需求。但有时候我们会发现某些命令失败、插件无法更新&#xff0c;例如&#xff1a; fatal: unable to access https://github.com/ohmyzsh/oh…...

大数据治理的常见方式

大数据治理的常见方式 大数据治理是确保数据质量、安全性和可用性的系统性方法&#xff0c;以下是几种常见的治理方式&#xff1a; 1. 数据质量管理 核心方法&#xff1a; 数据校验&#xff1a;建立数据校验规则&#xff08;格式、范围、一致性等&#xff09;数据清洗&…...

【java面试】微服务篇

【java面试】微服务篇 一、总体框架二、Springcloud&#xff08;一&#xff09;Springcloud五大组件&#xff08;二&#xff09;服务注册和发现1、Eureka2、Nacos &#xff08;三&#xff09;负载均衡1、Ribbon负载均衡流程2、Ribbon负载均衡策略3、自定义负载均衡策略4、总结 …...