动手学深度学习-3.2 线性回归的从0开始
以下是代码的逐段解析及其实际作用:
1. 环境设置与库导入
%matplotlib inline
import random
import torch
from d2l import torch as d2l
- 作用:
%matplotlib inline:在 Jupyter Notebook 中内嵌显示 matplotlib 图形。random:生成随机索引用于数据打乱。torch:PyTorch 深度学习框架。d2l:《动手学深度学习》提供的工具函数库(如绘图工具)。
2. 生成合成数据
假设真实权重向量为 w true ∈ R n \mathbf{w}_{\text{true}} \in \mathbb{R}^n wtrue∈Rn,偏置为 b true b_{\text{true}} btrue,噪声为高斯分布 ϵ ∼ N ( 0 , σ 2 ) \epsilon \sim \mathcal{N}(0, \sigma^2) ϵ∼N(0,σ2),则合成数据生成公式为:
y = X w true + b true + ϵ \mathbf{y} = \mathbf{X} \mathbf{w}_{\text{true}} + b_{\text{true}} + \epsilon y=Xwtrue+btrue+ϵ
其中:
- X ∈ R m × n \mathbf{X} \in \mathbb{R}^{m \times n} X∈Rm×n:输入特征矩阵( m m m 个样本, n n n 个特征)。
- w true ∈ R n \mathbf{w}_{\text{true}} \in \mathbb{R}^n wtrue∈Rn:真实权重向量。
- ϵ ∈ R m \epsilon \in \mathbb{R}^m ϵ∈Rm:噪声向量。
def synthetic_data(w, b, num_examples): #@save"""生成y=Xw+b+噪声"""X = torch.normal(0, 1, (num_examples, len(w))) # 生成标准正态分布的输入特征 num_examples行,len(w)列y = torch.matmul(X, w) + b # 计算线性输出 y = Xw + by += torch.normal(0, 0.01, y.shape) # 添加高斯噪声return X, y.reshape((-1, 1)) # y行数不定(值为-1,列数为1)true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
生成的函数是一个二维线性回归模型,其数学表达式为:
y = w 1 x 1 + w 2 x 2 + b + ϵ y = w_1 x_1 + w_2 x_2 + b + \epsilon y=w1x1+w2x2+b+ϵ
其中:
- 权重: w = [ w 1 , w 2 ] = [ 2 , − 3.4 ] \mathbf{w} = [w_1, w_2] = [2, -3.4] w=[w1,w2]=[2,−3.4],由
true_w定义。 - 偏置: b = 4.2 b = 4.2 b=4.2,由
true_b定义。 - 噪声: ϵ ∼ N ( 0 , 0.0 1 2 ) \epsilon \sim \mathcal{N}(0, 0.01^2) ϵ∼N(0,0.012),即均值为 0、标准差为 0.01 的高斯噪声。
展开为标量形式:
y i = 2 ⋅ x i 1 − 3.4 ⋅ x i 2 + 4.2 + ϵ i ( i = 1 , 2 , … , 1000 ) y_i = 2 \cdot x_{i1} - 3.4 \cdot x_{i2} + 4.2 + \epsilon_i \quad (i = 1, 2, \dots, 1000) yi=2⋅xi1−3.4⋅xi2+4.2+ϵi(i=1,2,…,1000)
3. 数据可视化
d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);
- 绘制第二个特征(
features[:,1] => n行第1列)与标签labels的散点图。
4. 定义数据迭代器
def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices) # 打乱索引顺序for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices] # 生成小批量数据
- 作用:
- 将数据集按
batch_size划分为小批量,并随机打乱顺序。 - 使用生成器 (
yield) 逐批返回数据,避免一次性加载全部数据到内存。
- 将数据集按
5. 初始化模型参数
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
- 初始化w和b的值:
w:从均值为 0、标准差为 0.01 的正态分布中初始化权重,启用梯度追踪。b:初始化为 0 的偏置,启用梯度追踪。- 参数需梯度追踪以支持反向传播。
6. 定义模型、损失函数和优化器
def linreg(X, w, b): #@save"""线性回归模型"""return torch.matmul(X, w) + bdef squared_loss(y_hat, y): #@save"""均方损失"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2 # 除以2便于梯度计算def sgd(params, lr, batch_size): #@save"""小批量随机梯度下降"""with torch.no_grad(): # 禁用梯度计算for param in params:param -= lr * param.grad / batch_size # 参数更新param.grad.zero_() # 梯度清零
-
linreg:模型预测值 y ^ \hat{\mathbf{y}} y^ 的矩阵形式为:
y ^ = X w + b \hat{\mathbf{y}} = \mathbf{X} \mathbf{w} + b y^=Xw+b
其中:- w ∈ R n \mathbf{w} \in \mathbb{R}^n w∈Rn:待学习的权重向量。
- b ∈ R b \in \mathbb{R} b∈R:待学习的偏置。
-
squared_loss:损失函数的矩阵形式为:
L = 1 2 ∥ y ^ − y ∥ 2 L = \frac{1}{2} \| \hat{\mathbf{y}} - \mathbf{y} \|^2 L=21∥y^−y∥2
为
L ( w , b ) = 1 2 m ∥ X w + b − y ∥ 2 L(\mathbf{w}, b) = \frac{1}{2m} \| \mathbf{X} \mathbf{w} + b - \mathbf{y} \|^2 L(w,b)=2m1∥Xw+b−y∥2
展开后:
L ( w , b ) = 1 2 m ( X w + b 1 − y ) ⊤ ( X w + b 1 − y ) L(\mathbf{w}, b) = \frac{1}{2m} (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y})^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}) L(w,b)=2m1(Xw+b1−y)⊤(Xw+b1−y) -
sgd:小批量随机梯度下降优化器,-
对权重 w \mathbf{w} w 的梯度
∇ w L = 1 m X ⊤ ( X w + b 1 − y ) \nabla_{\mathbf{w}} L = \frac{1}{m} \mathbf{X}^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}) ∇wL=m1X⊤(Xw+b1−y) -
对偏置 b b b 的梯度
∇ b L = 1 m 1 ⊤ ( X w + b 1 − y ) , 1 为单位列向量 \nabla_{b} L = \frac{1}{m} \mathbf{1}^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}),\mathbf{1} 为单位列向量 ∇bL=m11⊤(Xw+b1−y),1为单位列向量 -
使用学习率 η \eta η,参数更新公式为:
w ← w − η ∇ w L b ← b − η ∇ b L \mathbf{w} \leftarrow \mathbf{w} - \eta \nabla_{\mathbf{w}} L\\ b \leftarrow b - \eta \nabla_{b} L w←w−η∇wLb←b−η∇bL
-
7. 训练循环
lr = 0.03
num_epochs = 3
batch_size = 10 # 需补充定义(原代码未显式定义)for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y) # 计算小批量损失l.sum().backward() # 反向传播计算梯度sgd([w, b], lr, batch_size) # 更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
-
作用:
- 外层循环:遍历训练轮次 (
num_epochs)。 - 内层循环:按小批量遍历数据,计算损失并反向传播。
l.sum().backward():将小批量损失求和后反向传播,计算梯度。sgd:根据梯度更新参数,梯度需除以batch_size以保持学习率一致性。- 每个 epoch 结束后,计算并打印整体训练损失。
- mean()函数计算平均值
- 外层循环:遍历训练轮次 (
-
梯度下降
l.sum().backward() # 反向传播计算梯度sgd([w, b], lr, batch_size) # 更新参数
- 小批量梯度计算公式:
∇ w L batch = 1 batch_size X batch ⊤ ( X batch w + b − y batch ) \nabla_{\mathbf{w}} L_{\text{batch}} = \frac{1}{\text{batch\_size}} \mathbf{X}_{\text{batch}}^\top (\mathbf{X}_{\text{batch}} \mathbf{w} + b - \mathbf{y}_{\text{batch}}) ∇wLbatch=batch_size1Xbatch⊤(Xbatchw+b−ybatch)
∇ b L batch = 1 batch_size 1 ⊤ ( X batch w + b − y batch ) \nabla_{b} L_{\text{batch}} = \frac{1}{\text{batch\_size}} \mathbf{1}^\top (\mathbf{X}_{\text{batch}} \mathbf{w} + b - \mathbf{y}_{\text{batch}}) ∇bLbatch=batch_size11⊤(Xbatchw+b−ybatch)
相关文章:
动手学深度学习-3.2 线性回归的从0开始
以下是代码的逐段解析及其实际作用: 1. 环境设置与库导入 %matplotlib inline import random import torch from d2l import torch as d2l作用: %matplotlib inline:在 Jupyter Notebook 中内嵌显示 matplotlib 图形。random:生成…...
“深度强化学习揭秘:掌握DQN与PPO算法的精髓“
深度Q网络(Deep Q-Network,简称DQN)是一种结合了Q学习和深度神经网络的强化学习算法。它使用神经网络来近似Q值函数,从而实现对复杂状态空间中的动作选择。DQN的核心思想是通过贝尔曼方程(Bellman Equation)…...
如何让DeepSeek恢复联网功能?解决(由于技术原因,联网搜索暂不可用)
DeekSeek提示:(由于技术原因,联网搜索暂不可用) 众所周知,因为海外黑客的ddos攻击、僵尸网络攻击,deepseek的联网功能一直处于宕机阶段,但是很多问题不联网出来的结果都还是2023年的,…...
Unity-编译构建Android的问题记录
文章目录 报错:AAPT2 aapt2-4.1.2-6503028-osx Daemon #0 Failed to shutdown within timeout报错信息解读:原因分析最终处理方法 报错:AAPT2 aapt2-4.1.2-6503028-osx Daemon #0 Failed to shutdown within timeout 报错信息解读࿱…...
python的ruff简单使用
Ruff 是一个用 Rust 编写的高性能 Python 静态分析工具和代码格式化工具。它旨在提供快速的代码检查和格式化功能,同时支持丰富的配置选项和与现有工具的兼容性。ruff是用rust实现的python Linter&Formatter。 安装: conda install -c conda-forge…...
Docker 部署 GLPI(IT 资产管理软件系统)
GLPI 简介 GLPI open source tool to manage Helpdesk and IT assets GLPI stands for Gestionnaire Libre de Parc Informatique(法语 资讯设备自由软件 的缩写) is a Free Asset and IT Management Software package, that provides ITIL Service De…...
【漫话机器学习系列】077.范数惩罚是如何起作用的(How Norm Penalties Work)
范数惩罚的作用与原理 范数惩罚(Norm Penalty) 是一种常用于机器学习模型中的正则化技术,它的主要目的是控制模型复杂度,防止过拟合。通过对模型的参数进行惩罚(即在损失函数中加入惩罚项),使得…...
【C++ STL】vector容器详解:从入门到精通
【C STL】vector容器详解:从入门到精通 摘要:本文深入讲解C STL中vector容器的使用方法,涵盖常用函数、代码示例及注意事项,助你快速掌握动态数组的核心操作! 一、vector概述 vector是C标准模板库(STL&am…...
LLMs之OpenAI o系列:OpenAI o3-mini的简介、安装和使用方法、案例应用之详细攻略
LLMs之OpenAI o系列:OpenAI o3-mini的简介、安装和使用方法、案例应用之详细攻略 目录 相关文章 LLMs之o3:《Deliberative Alignment: Reasoning Enables Safer Language Models》翻译与解读 LLMs之OpenAI o系列:OpenAI o3-mini的简介、安…...
Notepad++消除生成bak文件
设置(T) ⇒ 首选项... ⇒ 备份 ⇒ 勾选 "禁用" 勾选禁用 就不会再生成bak文件了 notepad怎么修改字符集编码格式为gbk 如图所示...
后台管理系统通用页面抽离=>高阶组件+配置文件+hooks
目录结构 配置文件和通用页面组件 content.config.ts const contentConfig {pageName: "role",header: {title: "角色列表",btnText: "新建角色"},propsList: [{ type: "selection", label: "选择", width: "80px&q…...
Spring Boot项目如何使用MyBatis实现分页查询
写在前面:大家好!我是晴空๓。如果博客中有不足或者的错误的地方欢迎在评论区或者私信我指正,感谢大家的不吝赐教。我的唯一博客更新地址是:https://ac-fun.blog.csdn.net/。非常感谢大家的支持。一起加油,冲鸭&#x…...
[Java]多态
1. 多态的基本概念 1.1 定义: 多态是指同一操作作用于不同的对象时,能够表现出不同的行为。多态通常通过以下两种方式实现: 方法重载(Overloading)方法重写(Overriding) 1.2 示例࿱…...
用Impala对存储在HDFS中的大规模数据集进行快速、实时的交互式SQL查询的具体步骤和关键代码
AWS EMR(Elastic MapReduce)中应用Impala的典型案例,主要体现在大型企业和数据密集型组织如何利用Impala对存储在Hadoop分布式文件系统(HDFS)中的大规模数据集进行快速、实时的交互式SQL查询。以下是一个具体的案例说明…...
Intellij 插件开发-快速开始
目录 一、开发环境搭建以及创建action1. 安装 Plugin DevKit 插件2. 新建idea插件项目3. 创建 Action4. 向新的 Action 表单注册 Action5. Enabling Internal Mode 二、插件实战开发[不推荐]UI Designer 基础JBPanel类(JPanel面板)需求:插件设…...
GIt使用笔记大全
Git 使用笔记大全 1. 安装 Git 在终端或命令提示符中,输入以下命令检查是否已安装 Git: git --version如果未安装,可以从 Git 官方网站 下载并安装适合你操作系统的版本。 2. 配置 Git 首次使用 Git 时,需要配置用户名和邮箱…...
语言月赛 202412【题目名没活了】题解(AC)
》》》点我查看「视频」详解》》》 [语言月赛 202412] 题目名没活了 题目描述 在 XCPC 竞赛里,会有若干道题目,一支队伍可以对每道题目提交若干次。我们称一支队伍对一道题目的一次提交是有效的,当且仅当: 在本次提交以前&…...
MySQL锁类型(详解)
锁的分类图,如下: 锁操作类型划分 读锁 : 也称为共享锁 、英文用S表示。针对同一份数据,多个事务的读操作可以同时进行而不会互相影响,相互不阻塞的。 写锁 : 也称为排他锁 、英文用X表示。当前写操作没有完成前,它会…...
面经--C语言——static,volatile,malloc,使用异或进行数据交换
文章目录 static静态变量和全局变量的区别volatile主要作用 malloc1. 内存分配器的作用2. 内存分配过程(1) 查找空闲内存块(2) 扩展堆空间(3) 元数据 3. 内存释放过程(1) 标记为可用(2) 合并相邻空闲块(3) 延迟释放 4. 内存管理策略(1) 分配缓存(Allocation Caching…...
stm32小白成长为高手的学习步骤和方法
我们假定大家已经对STM32的书籍或者文档有一定的理解。如不理解,请立即阅读STM32的文档,以获取最基本的知识点。STM32单片机自学教程 这篇博文也是一篇不错的入门教程,初学者可以看看,讲的真心不错。 英文好的同学…...
OSCP - Proving Grounds - Roquefort
主要知识点 githook 注入Linux path覆盖 具体步骤 依旧是nmap扫描开始,3000端口不是很熟悉,先看一下 Nmap scan report for 192.168.54.67 Host is up (0.00083s latency). Not shown: 65530 filtered tcp ports (no-response) PORT STATE SERV…...
集合通讯概览
(1)通信的算法 是根据通讯的链路组成的 (2)因为通信链路 跟硬件强相关,所以每个CCL的库都不一样 芯片与芯片、不同U之间是怎么通信的!!!!!! 很重要…...
【贪心算法篇】:“贪心”之旅--算法练习题中的智慧与策略(二)
✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:贪心算法篇–CSDN博客 文章目录 前言例题1.买卖股票的最佳时机2.买卖股票的最佳时机23.k次取…...
oracle: 表分区>>范围分区,列表分区,散列分区/哈希分区,间隔分区,参考分区,组合分区,子分区/复合分区/组合分区
分区表 是将一个逻辑上的大表按照特定的规则划分为多个物理上的子表,这些子表称为分区。 分区可以基于不同的维度,如时间、数值范围、字符串值等,将数据分散存储在不同的分区 中,以提高数据管理的效率和查询性能,同时…...
基于SpringBoot 前端接收中文显示解决方案
一. 问题 返回给前端的的中文值会变成“???” 二. 解决方案 1. 在application.yml修改字符编码 (无效) 在网上看到说修改servlet字符集编码,尝试了不行 server:port: 8083servlet:encoding:charset: UTF-8enabled: trueforce: true2. …...
java练习(5)
ps:题目来自力扣 给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的,并且每个节点只能存储 一位 数字。 请你将两个数相加,并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外,这…...
python算法和数据结构刷题[3]:哈希表、滑动窗口、双指针、回溯算法、贪心算法
回溯算法 「所有可能的结果」,而不是「结果的个数」,一般情况下,我们就知道需要暴力搜索所有的可行解了,可以用「回溯法」。 回溯算法关键在于:不合适就退回上一步。在回溯算法中,递归用于深入到所有可能的分支&…...
大数据数仓实战项目(离线数仓+实时数仓)1
目录 1.课程目标 2.电商行业与电商系统介绍 3.数仓项目整体技术架构介绍 4.数仓项目架构-kylin补充 5.数仓具体技术介绍与项目环境介绍 6.kettle的介绍与安装 7.kettle入门案例 8.kettle输入组件之JSON输入与表输入 9.kettle输入组件之生成记录组件 10.kettle输出组件…...
【开源免费】基于Vue和SpringBoot的公寓报修管理系统(附论文)
本文项目编号 T 186 ,文末自助获取源码 \color{red}{T186,文末自助获取源码} T186,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…...
使用QMUI实现用户协议对话框
使用QMUI实现用户协议对话框 懒加载用于初始化 TermServiceDialogController 对象。 懒加载 lazy var 的作用 lazy var dialogController: TermServiceDialogController {let r TermServiceDialogController()r.primaryButton.addTarget(self, action: #selector(primaryC…...
