从零手写线性回归模型:PyTorch 实现深度学习入门教程
系列文章目录
01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量!
02-张量运算真简单!PyTorch 数值计算操作完全指南
03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧
04-揭秘数据处理神器:PyTorch 张量拼接与拆分实用技巧
05-深度学习从索引开始:PyTorch 张量索引与切片最全解析
06-张量形状任意改!PyTorch reshape、transpose 操作超详细教程
07-深入解读 PyTorch 张量运算:6 大核心函数全面解析,代码示例一步到位!
08-自动微分到底有多强?PyTorch 自动求导机制深度解析
09-从零手写线性回归模型:PyTorch 实现深度学习入门教程
10-PyTorch 框架实现线性回归:从数据预处理到模型训练全流程
文章目录
- 系列文章目录
- 前言
- 一、构建数据集
- 1.1 示例代码
- 1.2 示例输出
- 二、构建假设函数
- 2.1 示例代码
- 三、损失函数
- 3.1 示例代码
- 四、优化方法
- 4.1 示例代码
- 五、训练函数
- 5.1 示例代码
- 5.2 绘制结果
- 六、调用训练函数
- 6.1 示例输出
- 七、小结
- 7.1 完整代码
前言
在机器学习的学习过程中,我们接触过 线性回归 模型,并使用过如 Scikit-learn 这样的工具来快速实现。但在本文中,将深入理解线性回归的核心思想,并使用 PyTorch 从零开始手动实现一个线性回归模型。这包括:
- 数据集的构建;
- 假设函数的定义;
- 损失函数的设计;
- 梯度下降优化方法的实现;
- 模型训练和损失变化的可视化。
一、构建数据集
线性回归需要一个简单的线性数据集,我们将通过sklearn.datasets.make_regression
方法生成。
1.1 示例代码
import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import randomdef create_dataset():"""使用 make_regression 生成线性回归数据集,并转换为 PyTorch 张量。"""x, y, coef = make_regression(n_samples=120, # 样本数量,即数据点个数n_features=1, # 每个样本只有一个特征noise=15, # 添加噪声,模拟真实场景coef=True, # 是否返回生成数据的真实系数bias=12.0, # 偏置值,即 y = coef * x + biasrandom_state=42 # 随机种子,保证结果一致性)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32) # 输入特征张量y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1) # 输出标签张量转换为二维return x, y, coef # 返回输入特征张量、输出标签张量和真实系数# 数据加载器,用于批量获取数据
def data_loader(x, y, batch_size):"""数据加载器,按批次随机提取训练数据。参数:x: 输入特征张量y: 输出标签张量batch_size: 批量大小"""data_len = len(y) # 数据集长度indices = list(range(data_len)) # 创建索引列表random.shuffle(indices) # 随机打乱索引,保证数据随机性num_batches = data_len // batch_size # 计算批次数for i in range(num_batches): # 遍历每个批次start = i * batch_size # 当前批次起始索引end = start + batch_size # 当前批次结束索引# 提取当前批次的数据batch_x = x[indices[start:end]] # 当前批次输入batch_y = y[indices[start:end]] # 当前批次输出yield batch_x, batch_y # 返回当前批次数据
1.2 示例输出
运行 create_dataset
后的数据分布为线性趋势,同时包含噪声点。例如:
x:
tensor([[-1.3282],[ 0.1941],[ 0.8944],...])
y:
tensor([-40.2345, 15.2934, 45.1282, ...])
coef:
tensor([35.0])
二、构建假设函数
线性回归的假设函数可以表示为:
y = w ⋅ x + b y ^ =w⋅x+b y=w⋅x+b
其中,( w ) 是权重,( b ) 是偏置。使用 PyTorch 张量定义这些参数。
2.1 示例代码
# 模型参数初始化
w = torch.tensor(0.5, requires_grad=True, dtype=torch.float32) # 权重
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float32) # 偏置# 假设函数
def linear_regression(x):return w * x + b
三、损失函数
我们使用均方误差(MSE)作为损失函数,其公式为:
L = 1 n ∑ i = 1 n ( y ^ i − y i ) 2 L = \frac{1}{n} \sum_{i=1}^n (\hat{y}_i - y_i)^2 L=n1i=1∑n(y^i−yi)2
3.1 示例代码
def square_loss(y_pred, y_true):return (y_pred - y_true) ** 2
四、优化方法
为了更新模型参数 ( w ) 和 ( b ),使用 随机梯度下降(SGD) 算法,其更新公式为:
w = w − η ⋅ ∂ L ∂ w , b = b − η ⋅ ∂ L ∂ b w = w - \eta \cdot \frac{\partial L}{\partial w}, \quad b = b - \eta \cdot \frac{\partial L}{\partial b} w=w−η⋅∂w∂L,b=b−η⋅∂b∂L
其中,η 是学习率。
4.1 示例代码
def sgd(lr=0.01, batch_size=16):# 更新权重和偏置w.data = w.data - lr * w.grad.data / batch_sizeb.data = b.data - lr * b.grad.data / batch_size
五、训练函数
将前面定义的所有组件组合在一起,构建训练函数,通过多个 epoch 来优化模型。
5.1 示例代码
# 模型训练函数
def train():"""训练线性回归模型。"""# 加载数据集x, y, coef = create_dataset()# 设置训练参数epochs = 50 # 训练轮次learning_rate = 0.01 # 学习率batch_size = 16 # 每批次的数据大小# 使用 PyTorch 内置优化器optimizer = torch.optim.SGD([w, b], lr=learning_rate)epoch_loss = [] # 用于记录每轮的平均损失for epoch in range(epochs): # 遍历每一轮total_loss = 0.0 # 累计损失for batch_x, batch_y in data_loader(x, y, batch_size): # 遍历每个批次# 使用假设函数计算预测值y_pred = linear_regression(batch_x)# 计算损失loss = square_loss(y_pred, batch_y) # 当前批次的平均损失total_loss += loss.item() # 累加总损失# 梯度清零optimizer.zero_grad()# 反向传播计算梯度loss.backward()# 使用随机梯度下降更新参数optimizer.step()# 记录当前轮次的平均损失epoch_loss.append(total_loss / len(y))print(f"轮次 {epoch + 1}, 平均损失: {epoch_loss[-1]:.4f}") # 打印损失# 可视化训练结果plot_results(x, y, coef, epoch_loss)
5.2 绘制结果
# 可视化训练结果
def plot_results(x, y, coef, epoch_loss):"""绘制训练结果,包括拟合直线和损失变化曲线。参数:x: 输入特征张量y: 输出标签张量coef: 数据生成时的真实权重epoch_loss: 每轮的平均损失"""# 绘制训练数据点和拟合直线plt.scatter(x.numpy(), y.numpy(), label='数据点', alpha=0.7) # 数据点x_line = torch.linspace(x.min(), x.max(), 100).unsqueeze(1) # 连续 x 值y_pred = linear_regression(x_line).detach().numpy() # 模型预测值coef_tensor = torch.tensor(coef, dtype=torch.float32) # 将 coef 转换为 PyTorch 张量y_true = coef_tensor * x_line + 12.0 # 真实直线(生成数据时的公式)plt.plot(x_line.numpy(), y_pred, label='拟合直线', color='red') # 拟合直线plt.plot(x_line.numpy(), y_true.numpy(), label='真实直线', color='green') # 真实直线plt.legend()plt.grid()plt.title('线性回归拟合')plt.xlabel('特征值 X')plt.ylabel('标签值 Y')plt.show()# 绘制损失变化曲线plt.plot(range(len(epoch_loss)), epoch_loss)plt.title('损失变化曲线')plt.xlabel('轮次')plt.ylabel('损失')plt.grid()plt.show()
六、调用训练函数
在主程序中调用 train
函数,训练模型并观察输出。
if __name__ == "__main__":train()
6.1 示例输出
- 拟合直线:
- 损失变化曲线:
七、小结
本文通过手动实现线性回归模型,完成了以下内容:
- 构建数据集并设计数据加载器;
- 定义线性假设函数;
- 设计均方误差损失函数;
- 实现随机梯度下降优化方法;
- 训练模型并可视化损失变化和拟合直线。
7.1 完整代码
import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import random# 设置 Matplotlib 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号# 数据集生成函数
def create_dataset():"""使用 make_regression 生成线性回归数据集,并转换为 PyTorch 张量。"""x, y, coef = make_regression(n_samples=120, # 样本数量,即数据点个数n_features=1, # 每个样本只有一个特征noise=15, # 添加噪声,模拟真实场景coef=True, # 是否返回生成数据的真实系数bias=12.0, # 偏置值,即 y = coef * x + biasrandom_state=42 # 随机种子,保证结果一致性)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32) # 输入特征张量y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1) # 输出标签张量转换为二维return x, y, coef # 返回输入特征张量、输出标签张量和真实系数# 数据加载器,用于批量获取数据
def data_loader(x, y, batch_size):"""数据加载器,按批次随机提取训练数据。参数:x: 输入特征张量y: 输出标签张量batch_size: 批量大小"""data_len = len(y) # 数据集长度indices = list(range(data_len)) # 创建索引列表random.shuffle(indices) # 随机打乱索引,保证数据随机性num_batches = data_len // batch_size # 计算批次数for i in range(num_batches): # 遍历每个批次start = i * batch_size # 当前批次起始索引end = start + batch_size # 当前批次结束索引# 提取当前批次的数据batch_x = x[indices[start:end]] # 当前批次输入batch_y = y[indices[start:end]] # 当前批次输出yield batch_x, batch_y # 返回当前批次数据# 模型参数初始化
w = torch.tensor(0.5, requires_grad=True, dtype=torch.float32) # 权重,初始值为 0.5
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float32) # 偏置,初始值为 0# 线性假设函数
def linear_regression(x):"""线性回归假设函数。参数:x: 输入特征张量返回:模型预测值"""return w * x + b # 线性模型公式# 损失函数(均方误差)
def square_loss(y_pred, y_true):"""均方误差损失函数。参数:y_pred: 模型预测值y_true: 数据真实值返回:每个样本的平方误差"""return ((y_pred - y_true) ** 2).mean() # 返回均方误差# 模型训练函数
def train():"""训练线性回归模型。"""# 加载数据集x, y, coef = create_dataset()# 设置训练参数epochs = 50 # 训练轮次learning_rate = 0.01 # 学习率batch_size = 16 # 每批次的数据大小# 使用 PyTorch 内置优化器optimizer = torch.optim.SGD([w, b], lr=learning_rate)epoch_loss = [] # 用于记录每轮的平均损失for epoch in range(epochs): # 遍历每一轮total_loss = 0.0 # 累计损失for batch_x, batch_y in data_loader(x, y, batch_size): # 遍历每个批次# 使用假设函数计算预测值y_pred = linear_regression(batch_x)# 计算损失loss = square_loss(y_pred, batch_y) # 当前批次的平均损失total_loss += loss.item() # 累加总损失# 梯度清零optimizer.zero_grad()# 反向传播计算梯度loss.backward()# 使用随机梯度下降更新参数optimizer.step()# 记录当前轮次的平均损失epoch_loss.append(total_loss / len(y))print(f"轮次 {epoch + 1}, 平均损失: {epoch_loss[-1]:.4f}") # 打印损失# 可视化训练结果plot_results(x, y, coef, epoch_loss)# 可视化训练结果
def plot_results(x, y, coef, epoch_loss):"""绘制训练结果,包括拟合直线和损失变化曲线。参数:x: 输入特征张量y: 输出标签张量coef: 数据生成时的真实权重epoch_loss: 每轮的平均损失"""# 绘制训练数据点和拟合直线plt.scatter(x.numpy(), y.numpy(), label='数据点', alpha=0.7) # 数据点x_line = torch.linspace(x.min(), x.max(), 100).unsqueeze(1) # 连续 x 值y_pred = linear_regression(x_line).detach().numpy() # 模型预测值coef_tensor = torch.tensor(coef, dtype=torch.float32) # 将 coef 转换为 PyTorch 张量y_true = coef_tensor * x_line + 12.0 # 真实直线(生成数据时的公式)plt.plot(x_line.numpy(), y_pred, label='拟合直线', color='red') # 拟合直线plt.plot(x_line.numpy(), y_true.numpy(), label='真实直线', color='green') # 真实直线plt.legend()plt.grid()plt.title('线性回归拟合')plt.xlabel('特征值 X')plt.ylabel('标签值 Y')plt.show()# 绘制损失变化曲线plt.plot(range(len(epoch_loss)), epoch_loss)plt.title('损失变化曲线')plt.xlabel('轮次')plt.ylabel('损失')plt.grid()plt.show()# 调用训练函数
if __name__ == "__main__":train()
相关文章:

从零手写线性回归模型:PyTorch 实现深度学习入门教程
系列文章目录 01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量! 02-张量运算真简单!PyTorch 数值计算操作完全指南 03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧 04-揭秘数据处理神器:PyTor…...

【Cesium】自定义材质,添加带有方向的滚动路线
【Cesium】自定义材质,添加带有方向的滚动路线 🍖 前言🎶一、实现过程✨二、代码展示🏀三、运行结果🏆四、知识点提示 🍖 前言 【Cesium】自定义材质,添加带有方向的滚动路线 🎶一、…...
C 语言奇幻之旅 - 第11篇:C 语言动态内存管理
目录 引言1. 内存分配函数1.1 malloc 函数实际开发场景:动态数组 1.2 calloc 函数实际开发场景:初始化数据结构 1.3 realloc 函数实际开发场景:动态调整数据结构大小 2. 内存释放2.1 free 函数 3. 内存泄漏与调试3.1 常见内存问题3.2 内存调试…...

IDEA 撤销 merge 操作(详解)
作为一个开发者,我们都知道Git是一个非常重要的版本控制工具,尤其是在协作开发的过程中。然而,在使用Git的过程中难免会踩一些坑,今天我来给大家分享一个我曾经遇到的问题:在使用IDEA中进行merge操作后如何撤销错误的合…...
swarm天气智能体调用流程
Swarm 框架的调用流程: 入口点 (examples/weather_agent/run.py): run_demo_loop(weather_agent, streamTrue)初始化流程: # swarm/repl/repl.py -> run_demo_loop() client Swarm() # 创建 Swarm 实例消息处理流程: # swarm/core.py class Swarm:def run(…...

LED背光驱动芯片RT9293应用电路
一)简介: RT9293 是一款高频、异步的 Boost 升压型 LED 定电流驱动控制器,其工作原理如下: 1)基本电路结构及原理 RT9293的主要功能为上图的Q1. Boost 电路核心原理:基于电感和电容的特性实现升压功能。当…...
二叉树的二叉链表和三叉链表
在二叉树的数据结构中,通常有两种链表存储方式:二叉链表和三叉链表。这里,我们先澄清一下概念,通常我们讨论的是二叉链表,它用于存储二叉树的节点。而“三叉链表”这个术语在二叉树的上下文中不常见,可能是…...

【学习路线】Python 算法(人工智能)详细知识点学习路径(附学习资源)
学习本路线内容之前,请先学习Python的基础知识 其他路线: Python基础 >> Python进阶 >> Python爬虫 >> Python数据分析(数据科学) >> Python 算法(人工智能) >> Pyth…...
C++直接内存管理new和delete
0、前言 C语言定义了两个运算符来分配和释放动态内存。运算符new分配内存,delete释放new分配的内存。 1、new动态内存的分配 1.1、new动态分配和初始化对象 1)、new内存分配 在自由的空间分配的内存是无名的,new无法为其分配的对象…...
Linux 内核中网络接口的创建与管理
在 Linux 系统中,网络接口(如 eth0、wlan0 等)是计算机与外部网络通信的桥梁。无论是物理网卡还是虚拟网络接口,它们的创建和管理都依赖于 Linux 内核的复杂机制。本文将深入探讨 Linux 内核中网络接口的创建过程、命名规则、路由选择以及内核如何将网络接口映射到实际的硬…...
人工智能 前馈神经网络练习题
为了构建一个有两个输入( X 1 X_1 X1、 X 2 X_2 X2)和一个输出的单层感知器,并进行分类,我们需要计算权值 w 1 w_1 w1和 w 2 w_2 w2的更新过程。以下是详细的步骤和计算过程: 初始化参数 初始权值:…...

Windows搭建RTMP服务器
目录 一、Nginx-RTMP服务器搭建1、下载Nginx2、下载Nginx的RTMP扩展包3、修改配置文件4、启动服务器5、查看服务器状态6、其它ngnix命令 二、OBS推流1 、推流设置2、查看服务器状态 三、VLC拉流四、补充 本文转载自:Windows搭建RTMP服务器OBS推流VLC拉流_浏览器查看…...
Vue重新加载子组件
背景:组件需要重新加载,即重新走一遍组件的生命周期常见解决方案: 使用v-if指令:v-if 可以实现 true (加载)和 false (卸载) async reloadComponent() {this.show false// 加上 nextTick this.$nextTick(function() {this.show…...

【VScode】设置代理,通过代理连接服务器
文章目录 VScode编辑器设置代理1.图形化界面1.1 进入proxy设置界面1.2 配置代理服务器 2.配置文件(推荐)2.1 打开setting.json 文件2.2 配置代理 VScode编辑器设置代理 根据情况安装nmap 1.图形化界面 1.1 进入proxy设置界面 或者使用快捷键ctrl , 。…...
js es6 reduce函数, 通过规格生成sku
const specs [{ name: 颜色, values: [红色, 蓝色, 绿色] },{ name: 尺寸, values: [S, M, L] } ];function generateSKUs(specs) {return specs.reduce((acc, spec) > {const newAcc [];for (const combination of acc) {for (const value of spec.values) {newAcc.push(…...
基于R语言的DICE模型
DICE型是运用最广泛的综合模型之一。DICE和RICE模型虽然代码量不多,但涉及经济学与气候变化,原理较为复杂。 一:DICE模型的原理与推导 1.经济学 2.气候变化问题 3.DICE模型的经济学部分 4.DICE模型的气候相关部分 5.DICE模型的目标函数…...
【C】PAT 1006-1010
1006 换个格式输出整数 让我们用字母 B 来表示“百”、字母 S 表示“十”,用 12...n 来表示不为零的个位数字 n(<10),换个格式来输出任一个不超过 3 位的正整数。例如 234 应该被输出为 BBSSS1234,因为它有 2 个“…...

力扣双指针-算法模版总结
lc-15.三数之和 (时隔13天) 目前可通过,想法上无逻辑问题,一点细节小错误需注意即可 lc-283.移动零(时隔16天) 总结:观察案例直觉就是双指针遇零交换,两次实现都通过了,…...

解释一下:运放的输入偏置电流
输入偏置电流 首先看基础部分:这就是同相比例放大器 按照理论计算,输入VIN=0时,输出VOUT应为0,对吧 仿真与理论差距较大,有200多毫伏的偏差,这就是输入偏置电流IBIAS引起的,接着看它的定义 同向和反向输入电流的平均值,也就是Ib1、Ib2求平均,即(Ib1+Ib2)/2 按照下面…...
Windows 11 上通过 WSL (Windows Subsystem for Linux) 安装 MySQL 8
在 Windows 11 上通过 WSL (Windows Subsystem for Linux) 安装 MySQL 8 的步骤如下: ✅ 1. 检查 WSL 的安装 首先确保已经安装并启用了 WSL 2。 🔧 检查 WSL 版本 打开 PowerShell,执行以下命令: wsl --list --verbose确保 W…...

C++ Visual Studio 2017厂商给的源码没有.sln文件 易兆微芯片下载工具加开机动画下载。
1.先用Visual Studio 2017打开Yichip YC31xx loader.vcxproj,再用Visual Studio 2022打开。再保侟就有.sln文件了。 易兆微芯片下载工具加开机动画下载 ExtraDownloadFile1Info.\logo.bin|0|0|10D2000|0 MFC应用兼容CMD 在BOOL CYichipYC31xxloaderDlg::OnIni…...
Python Einops库:深度学习中的张量操作革命
Einops(爱因斯坦操作库)就像给张量操作戴上了一副"语义眼镜"——让你用人类能理解的方式告诉计算机如何操作多维数组。这个基于爱因斯坦求和约定的库,用类似自然语言的表达式替代了晦涩的API调用,彻底改变了深度学习工程…...

Rust 开发环境搭建
环境搭建 1、开发工具RustRover 或者vs code 2、Cygwin64 安装 https://cygwin.com/install.html 在工具终端执行: rustup toolchain install stable-x86_64-pc-windows-gnu rustup default stable-x86_64-pc-windows-gnu 2、Hello World fn main() { println…...
区块链技术概述
区块链技术是一种去中心化、分布式账本技术,通过密码学、共识机制和智能合约等核心组件,实现数据不可篡改、透明可追溯的系统。 一、核心技术 1. 去中心化 特点:数据存储在网络中的多个节点(计算机),而非…...

Matlab实现任意伪彩色图像可视化显示
Matlab实现任意伪彩色图像可视化显示 1、灰度原始图像2、RGB彩色原始图像 在科研研究中,如何展示好看的实验结果图像非常重要!!! 1、灰度原始图像 灰度图像每个像素点只有一个数值,代表该点的亮度(或…...
[QMT量化交易小白入门]-六十二、ETF轮动中简单的评分算法如何获取历史年化收益32.7%
本专栏主要是介绍QMT的基础用法,常见函数,写策略的方法,也会分享一些量化交易的思路,大概会写100篇左右。 QMT的相关资料较少,在使用过程中不断的摸索,遇到了一些问题,记录下来和大家一起沟通,共同进步。 文章目录 相关阅读1. 策略概述2. 趋势评分模块3 代码解析4 木头…...
LeetCode 0386.字典序排数:细心总结条件
【LetMeFly】386.字典序排数:细心总结条件 力扣题目链接:https://leetcode.cn/problems/lexicographical-numbers/ 给你一个整数 n ,按字典序返回范围 [1, n] 内所有整数。 你必须设计一个时间复杂度为 O(n) 且使用 O(1) 额外空间的算法。…...

【Linux】使用1Panel 面板让服务器定时自动执行任务
服务器就是一台24小时开机的主机,相比自己家中不定时开关机的主机更适合完成定时任务,例如下载资源、备份上传,或者登录某个网站执行一些操作,只需要编写 脚本,然后让服务器定时来执行这个脚本就可以。 有很多方法实现…...

MCP和Function Calling
MCP MCP(Model Context Protocol,模型上下文协议) ,2024年11月底,由 Anthropic 推出的一种开放标准,旨在统一大模型与外部数据源和工具之间的通信协议。MCP 的主要目的在于解决当前 AI 模型因数据孤岛限制而…...

docker容器互联
1.docker可以通过网路访问 2.docker允许映射容器内应用的服务端口到本地宿主主机 3.互联机制实现多个容器间通过容器名来快速访问 一 、端口映射实现容器访问 1.从外部访问容器应用 我们先把之前的删掉吧(如果不删的话,容器就提不起来,因…...