从零手写线性回归模型:PyTorch 实现深度学习入门教程
系列文章目录
01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量!
02-张量运算真简单!PyTorch 数值计算操作完全指南
03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧
04-揭秘数据处理神器:PyTorch 张量拼接与拆分实用技巧
05-深度学习从索引开始:PyTorch 张量索引与切片最全解析
06-张量形状任意改!PyTorch reshape、transpose 操作超详细教程
07-深入解读 PyTorch 张量运算:6 大核心函数全面解析,代码示例一步到位!
08-自动微分到底有多强?PyTorch 自动求导机制深度解析
09-从零手写线性回归模型:PyTorch 实现深度学习入门教程
10-PyTorch 框架实现线性回归:从数据预处理到模型训练全流程
文章目录
- 系列文章目录
- 前言
- 一、构建数据集
- 1.1 示例代码
- 1.2 示例输出
- 二、构建假设函数
- 2.1 示例代码
- 三、损失函数
- 3.1 示例代码
- 四、优化方法
- 4.1 示例代码
- 五、训练函数
- 5.1 示例代码
- 5.2 绘制结果
- 六、调用训练函数
- 6.1 示例输出
- 七、小结
- 7.1 完整代码
前言
在机器学习的学习过程中,我们接触过 线性回归 模型,并使用过如 Scikit-learn 这样的工具来快速实现。但在本文中,将深入理解线性回归的核心思想,并使用 PyTorch 从零开始手动实现一个线性回归模型。这包括:
- 数据集的构建;
- 假设函数的定义;
- 损失函数的设计;
- 梯度下降优化方法的实现;
- 模型训练和损失变化的可视化。
一、构建数据集
线性回归需要一个简单的线性数据集,我们将通过sklearn.datasets.make_regression 方法生成。
1.1 示例代码
import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import randomdef create_dataset():"""使用 make_regression 生成线性回归数据集,并转换为 PyTorch 张量。"""x, y, coef = make_regression(n_samples=120, # 样本数量,即数据点个数n_features=1, # 每个样本只有一个特征noise=15, # 添加噪声,模拟真实场景coef=True, # 是否返回生成数据的真实系数bias=12.0, # 偏置值,即 y = coef * x + biasrandom_state=42 # 随机种子,保证结果一致性)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32) # 输入特征张量y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1) # 输出标签张量转换为二维return x, y, coef # 返回输入特征张量、输出标签张量和真实系数# 数据加载器,用于批量获取数据
def data_loader(x, y, batch_size):"""数据加载器,按批次随机提取训练数据。参数:x: 输入特征张量y: 输出标签张量batch_size: 批量大小"""data_len = len(y) # 数据集长度indices = list(range(data_len)) # 创建索引列表random.shuffle(indices) # 随机打乱索引,保证数据随机性num_batches = data_len // batch_size # 计算批次数for i in range(num_batches): # 遍历每个批次start = i * batch_size # 当前批次起始索引end = start + batch_size # 当前批次结束索引# 提取当前批次的数据batch_x = x[indices[start:end]] # 当前批次输入batch_y = y[indices[start:end]] # 当前批次输出yield batch_x, batch_y # 返回当前批次数据
1.2 示例输出
运行 create_dataset 后的数据分布为线性趋势,同时包含噪声点。例如:
x:
tensor([[-1.3282],[ 0.1941],[ 0.8944],...])
y:
tensor([-40.2345, 15.2934, 45.1282, ...])
coef:
tensor([35.0])
二、构建假设函数
线性回归的假设函数可以表示为:
y = w ⋅ x + b y ^ =w⋅x+b y=w⋅x+b
其中,( w ) 是权重,( b ) 是偏置。使用 PyTorch 张量定义这些参数。
2.1 示例代码
# 模型参数初始化
w = torch.tensor(0.5, requires_grad=True, dtype=torch.float32) # 权重
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float32) # 偏置# 假设函数
def linear_regression(x):return w * x + b
三、损失函数
我们使用均方误差(MSE)作为损失函数,其公式为:
L = 1 n ∑ i = 1 n ( y ^ i − y i ) 2 L = \frac{1}{n} \sum_{i=1}^n (\hat{y}_i - y_i)^2 L=n1i=1∑n(y^i−yi)2
3.1 示例代码
def square_loss(y_pred, y_true):return (y_pred - y_true) ** 2
四、优化方法
为了更新模型参数 ( w ) 和 ( b ),使用 随机梯度下降(SGD) 算法,其更新公式为:
w = w − η ⋅ ∂ L ∂ w , b = b − η ⋅ ∂ L ∂ b w = w - \eta \cdot \frac{\partial L}{\partial w}, \quad b = b - \eta \cdot \frac{\partial L}{\partial b} w=w−η⋅∂w∂L,b=b−η⋅∂b∂L
其中,η 是学习率。
4.1 示例代码
def sgd(lr=0.01, batch_size=16):# 更新权重和偏置w.data = w.data - lr * w.grad.data / batch_sizeb.data = b.data - lr * b.grad.data / batch_size
五、训练函数
将前面定义的所有组件组合在一起,构建训练函数,通过多个 epoch 来优化模型。
5.1 示例代码
# 模型训练函数
def train():"""训练线性回归模型。"""# 加载数据集x, y, coef = create_dataset()# 设置训练参数epochs = 50 # 训练轮次learning_rate = 0.01 # 学习率batch_size = 16 # 每批次的数据大小# 使用 PyTorch 内置优化器optimizer = torch.optim.SGD([w, b], lr=learning_rate)epoch_loss = [] # 用于记录每轮的平均损失for epoch in range(epochs): # 遍历每一轮total_loss = 0.0 # 累计损失for batch_x, batch_y in data_loader(x, y, batch_size): # 遍历每个批次# 使用假设函数计算预测值y_pred = linear_regression(batch_x)# 计算损失loss = square_loss(y_pred, batch_y) # 当前批次的平均损失total_loss += loss.item() # 累加总损失# 梯度清零optimizer.zero_grad()# 反向传播计算梯度loss.backward()# 使用随机梯度下降更新参数optimizer.step()# 记录当前轮次的平均损失epoch_loss.append(total_loss / len(y))print(f"轮次 {epoch + 1}, 平均损失: {epoch_loss[-1]:.4f}") # 打印损失# 可视化训练结果plot_results(x, y, coef, epoch_loss)
5.2 绘制结果
# 可视化训练结果
def plot_results(x, y, coef, epoch_loss):"""绘制训练结果,包括拟合直线和损失变化曲线。参数:x: 输入特征张量y: 输出标签张量coef: 数据生成时的真实权重epoch_loss: 每轮的平均损失"""# 绘制训练数据点和拟合直线plt.scatter(x.numpy(), y.numpy(), label='数据点', alpha=0.7) # 数据点x_line = torch.linspace(x.min(), x.max(), 100).unsqueeze(1) # 连续 x 值y_pred = linear_regression(x_line).detach().numpy() # 模型预测值coef_tensor = torch.tensor(coef, dtype=torch.float32) # 将 coef 转换为 PyTorch 张量y_true = coef_tensor * x_line + 12.0 # 真实直线(生成数据时的公式)plt.plot(x_line.numpy(), y_pred, label='拟合直线', color='red') # 拟合直线plt.plot(x_line.numpy(), y_true.numpy(), label='真实直线', color='green') # 真实直线plt.legend()plt.grid()plt.title('线性回归拟合')plt.xlabel('特征值 X')plt.ylabel('标签值 Y')plt.show()# 绘制损失变化曲线plt.plot(range(len(epoch_loss)), epoch_loss)plt.title('损失变化曲线')plt.xlabel('轮次')plt.ylabel('损失')plt.grid()plt.show()
六、调用训练函数
在主程序中调用 train 函数,训练模型并观察输出。
if __name__ == "__main__":train()
6.1 示例输出
- 拟合直线:

- 损失变化曲线:

七、小结
本文通过手动实现线性回归模型,完成了以下内容:
- 构建数据集并设计数据加载器;
- 定义线性假设函数;
- 设计均方误差损失函数;
- 实现随机梯度下降优化方法;
- 训练模型并可视化损失变化和拟合直线。
7.1 完整代码
import torch
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt
import random# 设置 Matplotlib 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号# 数据集生成函数
def create_dataset():"""使用 make_regression 生成线性回归数据集,并转换为 PyTorch 张量。"""x, y, coef = make_regression(n_samples=120, # 样本数量,即数据点个数n_features=1, # 每个样本只有一个特征noise=15, # 添加噪声,模拟真实场景coef=True, # 是否返回生成数据的真实系数bias=12.0, # 偏置值,即 y = coef * x + biasrandom_state=42 # 随机种子,保证结果一致性)# 转换为 PyTorch 张量x = torch.tensor(x, dtype=torch.float32) # 输入特征张量y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1) # 输出标签张量转换为二维return x, y, coef # 返回输入特征张量、输出标签张量和真实系数# 数据加载器,用于批量获取数据
def data_loader(x, y, batch_size):"""数据加载器,按批次随机提取训练数据。参数:x: 输入特征张量y: 输出标签张量batch_size: 批量大小"""data_len = len(y) # 数据集长度indices = list(range(data_len)) # 创建索引列表random.shuffle(indices) # 随机打乱索引,保证数据随机性num_batches = data_len // batch_size # 计算批次数for i in range(num_batches): # 遍历每个批次start = i * batch_size # 当前批次起始索引end = start + batch_size # 当前批次结束索引# 提取当前批次的数据batch_x = x[indices[start:end]] # 当前批次输入batch_y = y[indices[start:end]] # 当前批次输出yield batch_x, batch_y # 返回当前批次数据# 模型参数初始化
w = torch.tensor(0.5, requires_grad=True, dtype=torch.float32) # 权重,初始值为 0.5
b = torch.tensor(0.0, requires_grad=True, dtype=torch.float32) # 偏置,初始值为 0# 线性假设函数
def linear_regression(x):"""线性回归假设函数。参数:x: 输入特征张量返回:模型预测值"""return w * x + b # 线性模型公式# 损失函数(均方误差)
def square_loss(y_pred, y_true):"""均方误差损失函数。参数:y_pred: 模型预测值y_true: 数据真实值返回:每个样本的平方误差"""return ((y_pred - y_true) ** 2).mean() # 返回均方误差# 模型训练函数
def train():"""训练线性回归模型。"""# 加载数据集x, y, coef = create_dataset()# 设置训练参数epochs = 50 # 训练轮次learning_rate = 0.01 # 学习率batch_size = 16 # 每批次的数据大小# 使用 PyTorch 内置优化器optimizer = torch.optim.SGD([w, b], lr=learning_rate)epoch_loss = [] # 用于记录每轮的平均损失for epoch in range(epochs): # 遍历每一轮total_loss = 0.0 # 累计损失for batch_x, batch_y in data_loader(x, y, batch_size): # 遍历每个批次# 使用假设函数计算预测值y_pred = linear_regression(batch_x)# 计算损失loss = square_loss(y_pred, batch_y) # 当前批次的平均损失total_loss += loss.item() # 累加总损失# 梯度清零optimizer.zero_grad()# 反向传播计算梯度loss.backward()# 使用随机梯度下降更新参数optimizer.step()# 记录当前轮次的平均损失epoch_loss.append(total_loss / len(y))print(f"轮次 {epoch + 1}, 平均损失: {epoch_loss[-1]:.4f}") # 打印损失# 可视化训练结果plot_results(x, y, coef, epoch_loss)# 可视化训练结果
def plot_results(x, y, coef, epoch_loss):"""绘制训练结果,包括拟合直线和损失变化曲线。参数:x: 输入特征张量y: 输出标签张量coef: 数据生成时的真实权重epoch_loss: 每轮的平均损失"""# 绘制训练数据点和拟合直线plt.scatter(x.numpy(), y.numpy(), label='数据点', alpha=0.7) # 数据点x_line = torch.linspace(x.min(), x.max(), 100).unsqueeze(1) # 连续 x 值y_pred = linear_regression(x_line).detach().numpy() # 模型预测值coef_tensor = torch.tensor(coef, dtype=torch.float32) # 将 coef 转换为 PyTorch 张量y_true = coef_tensor * x_line + 12.0 # 真实直线(生成数据时的公式)plt.plot(x_line.numpy(), y_pred, label='拟合直线', color='red') # 拟合直线plt.plot(x_line.numpy(), y_true.numpy(), label='真实直线', color='green') # 真实直线plt.legend()plt.grid()plt.title('线性回归拟合')plt.xlabel('特征值 X')plt.ylabel('标签值 Y')plt.show()# 绘制损失变化曲线plt.plot(range(len(epoch_loss)), epoch_loss)plt.title('损失变化曲线')plt.xlabel('轮次')plt.ylabel('损失')plt.grid()plt.show()# 调用训练函数
if __name__ == "__main__":train()相关文章:
从零手写线性回归模型:PyTorch 实现深度学习入门教程
系列文章目录 01-PyTorch新手必看:张量是什么?5 分钟教你快速创建张量! 02-张量运算真简单!PyTorch 数值计算操作完全指南 03-Numpy 还是 PyTorch?张量与 Numpy 的神奇转换技巧 04-揭秘数据处理神器:PyTor…...
【Cesium】自定义材质,添加带有方向的滚动路线
【Cesium】自定义材质,添加带有方向的滚动路线 🍖 前言🎶一、实现过程✨二、代码展示🏀三、运行结果🏆四、知识点提示 🍖 前言 【Cesium】自定义材质,添加带有方向的滚动路线 🎶一、…...
C 语言奇幻之旅 - 第11篇:C 语言动态内存管理
目录 引言1. 内存分配函数1.1 malloc 函数实际开发场景:动态数组 1.2 calloc 函数实际开发场景:初始化数据结构 1.3 realloc 函数实际开发场景:动态调整数据结构大小 2. 内存释放2.1 free 函数 3. 内存泄漏与调试3.1 常见内存问题3.2 内存调试…...
IDEA 撤销 merge 操作(详解)
作为一个开发者,我们都知道Git是一个非常重要的版本控制工具,尤其是在协作开发的过程中。然而,在使用Git的过程中难免会踩一些坑,今天我来给大家分享一个我曾经遇到的问题:在使用IDEA中进行merge操作后如何撤销错误的合…...
swarm天气智能体调用流程
Swarm 框架的调用流程: 入口点 (examples/weather_agent/run.py): run_demo_loop(weather_agent, streamTrue)初始化流程: # swarm/repl/repl.py -> run_demo_loop() client Swarm() # 创建 Swarm 实例消息处理流程: # swarm/core.py class Swarm:def run(…...
LED背光驱动芯片RT9293应用电路
一)简介: RT9293 是一款高频、异步的 Boost 升压型 LED 定电流驱动控制器,其工作原理如下: 1)基本电路结构及原理 RT9293的主要功能为上图的Q1. Boost 电路核心原理:基于电感和电容的特性实现升压功能。当…...
二叉树的二叉链表和三叉链表
在二叉树的数据结构中,通常有两种链表存储方式:二叉链表和三叉链表。这里,我们先澄清一下概念,通常我们讨论的是二叉链表,它用于存储二叉树的节点。而“三叉链表”这个术语在二叉树的上下文中不常见,可能是…...
【学习路线】Python 算法(人工智能)详细知识点学习路径(附学习资源)
学习本路线内容之前,请先学习Python的基础知识 其他路线: Python基础 >> Python进阶 >> Python爬虫 >> Python数据分析(数据科学) >> Python 算法(人工智能) >> Pyth…...
C++直接内存管理new和delete
0、前言 C语言定义了两个运算符来分配和释放动态内存。运算符new分配内存,delete释放new分配的内存。 1、new动态内存的分配 1.1、new动态分配和初始化对象 1)、new内存分配 在自由的空间分配的内存是无名的,new无法为其分配的对象…...
Linux 内核中网络接口的创建与管理
在 Linux 系统中,网络接口(如 eth0、wlan0 等)是计算机与外部网络通信的桥梁。无论是物理网卡还是虚拟网络接口,它们的创建和管理都依赖于 Linux 内核的复杂机制。本文将深入探讨 Linux 内核中网络接口的创建过程、命名规则、路由选择以及内核如何将网络接口映射到实际的硬…...
人工智能 前馈神经网络练习题
为了构建一个有两个输入( X 1 X_1 X1、 X 2 X_2 X2)和一个输出的单层感知器,并进行分类,我们需要计算权值 w 1 w_1 w1和 w 2 w_2 w2的更新过程。以下是详细的步骤和计算过程: 初始化参数 初始权值:…...
Windows搭建RTMP服务器
目录 一、Nginx-RTMP服务器搭建1、下载Nginx2、下载Nginx的RTMP扩展包3、修改配置文件4、启动服务器5、查看服务器状态6、其它ngnix命令 二、OBS推流1 、推流设置2、查看服务器状态 三、VLC拉流四、补充 本文转载自:Windows搭建RTMP服务器OBS推流VLC拉流_浏览器查看…...
Vue重新加载子组件
背景:组件需要重新加载,即重新走一遍组件的生命周期常见解决方案: 使用v-if指令:v-if 可以实现 true (加载)和 false (卸载) async reloadComponent() {this.show false// 加上 nextTick this.$nextTick(function() {this.show…...
【VScode】设置代理,通过代理连接服务器
文章目录 VScode编辑器设置代理1.图形化界面1.1 进入proxy设置界面1.2 配置代理服务器 2.配置文件(推荐)2.1 打开setting.json 文件2.2 配置代理 VScode编辑器设置代理 根据情况安装nmap 1.图形化界面 1.1 进入proxy设置界面 或者使用快捷键ctrl , 。…...
js es6 reduce函数, 通过规格生成sku
const specs [{ name: 颜色, values: [红色, 蓝色, 绿色] },{ name: 尺寸, values: [S, M, L] } ];function generateSKUs(specs) {return specs.reduce((acc, spec) > {const newAcc [];for (const combination of acc) {for (const value of spec.values) {newAcc.push(…...
基于R语言的DICE模型
DICE型是运用最广泛的综合模型之一。DICE和RICE模型虽然代码量不多,但涉及经济学与气候变化,原理较为复杂。 一:DICE模型的原理与推导 1.经济学 2.气候变化问题 3.DICE模型的经济学部分 4.DICE模型的气候相关部分 5.DICE模型的目标函数…...
【C】PAT 1006-1010
1006 换个格式输出整数 让我们用字母 B 来表示“百”、字母 S 表示“十”,用 12...n 来表示不为零的个位数字 n(<10),换个格式来输出任一个不超过 3 位的正整数。例如 234 应该被输出为 BBSSS1234,因为它有 2 个“…...
力扣双指针-算法模版总结
lc-15.三数之和 (时隔13天) 目前可通过,想法上无逻辑问题,一点细节小错误需注意即可 lc-283.移动零(时隔16天) 总结:观察案例直觉就是双指针遇零交换,两次实现都通过了,…...
解释一下:运放的输入偏置电流
输入偏置电流 首先看基础部分:这就是同相比例放大器 按照理论计算,输入VIN=0时,输出VOUT应为0,对吧 仿真与理论差距较大,有200多毫伏的偏差,这就是输入偏置电流IBIAS引起的,接着看它的定义 同向和反向输入电流的平均值,也就是Ib1、Ib2求平均,即(Ib1+Ib2)/2 按照下面…...
Windows 11 上通过 WSL (Windows Subsystem for Linux) 安装 MySQL 8
在 Windows 11 上通过 WSL (Windows Subsystem for Linux) 安装 MySQL 8 的步骤如下: ✅ 1. 检查 WSL 的安装 首先确保已经安装并启用了 WSL 2。 🔧 检查 WSL 版本 打开 PowerShell,执行以下命令: wsl --list --verbose确保 W…...
docker详细操作--未完待续
docker介绍 docker官网: Docker:加速容器应用程序开发 harbor官网:Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台,用于将应用程序及其依赖项(如库、运行时环…...
MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例
一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...
Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)
文章目录 1.什么是Redis?2.为什么要使用redis作为mysql的缓存?3.什么是缓存雪崩、缓存穿透、缓存击穿?3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...
3.3.1_1 检错编码(奇偶校验码)
从这节课开始,我们会探讨数据链路层的差错控制功能,差错控制功能的主要目标是要发现并且解决一个帧内部的位错误,我们需要使用特殊的编码技术去发现帧内部的位错误,当我们发现位错误之后,通常来说有两种解决方案。第一…...
在rocky linux 9.5上在线安装 docker
前面是指南,后面是日志 sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo sudo dnf install docker-ce docker-ce-cli containerd.io -y docker version sudo systemctl start docker sudo systemctl status docker …...
DAY 47
三、通道注意力 3.1 通道注意力的定义 # 新增:通道注意力模块(SE模块) class ChannelAttention(nn.Module):"""通道注意力模块(Squeeze-and-Excitation)"""def __init__(self, in_channels, reduction_rat…...
css的定位(position)详解:相对定位 绝对定位 固定定位
在 CSS 中,元素的定位通过 position 属性控制,共有 5 种定位模式:static(静态定位)、relative(相对定位)、absolute(绝对定位)、fixed(固定定位)和…...
EtherNet/IP转DeviceNet协议网关详解
一,设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络,本网关连接到EtherNet/IP总线中做为从站使用,连接到DeviceNet总线中做为从站使用。 在自动…...
UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)
UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中,UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化…...
QT: `long long` 类型转换为 `QString` 2025.6.5
在 Qt 中,将 long long 类型转换为 QString 可以通过以下两种常用方法实现: 方法 1:使用 QString::number() 直接调用 QString 的静态方法 number(),将数值转换为字符串: long long value 1234567890123456789LL; …...
