【深度学习】线性回归的简洁实现
线性回归的简洁实现
在过去的几年里,出于对深度学习强烈的兴趣,许多公司、学者和业余爱好者开发了各种成熟的开源框架。
这些框架可以自动化基于梯度的学习算法中重复性的工作。
目前,我们只会运用:
(1)通过张量来进行数据存储和线性代数;
(2)通过自动微分来计算梯度。
实际上,由于数据迭代器、损失函数、优化器和神经网络层很常用,现代深度学习库也为我们实现了这些组件。
本节将介绍如何(通过使用深度学习框架来简洁地实现线性回归模型)。
生成数据集
我们首先[生成数据集]。
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])#用于创建张量(Tensor)
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)
读取数据集
我们可以[调用框架中现有的API来读取数据]。
我们将features
和labels
作为API的参数传递,并通过数据迭代器指定batch_size
。
此外,布尔值is_train
表示是否希望数据迭代器对象在每个迭代周期内打乱数据。
def load_array(data_arrays, batch_size, is_train=True): #@save"""构造一个PyTorch数据迭代器"""dataset = data.TensorDataset(*data_arrays)#TensorDataset可以将多个张量组合成一个数据集对象return data.DataLoader(dataset, batch_size, shuffle=is_train)#用于批量加载和处理数据集
batch_size = 10
data_iter = load_array((features, labels), batch_size)
为了验证是否正常工作,让我们读取并打印第一个小批量样本。这里我们使用iter
构造Python迭代器,并使用next
从迭代器中获取第一项。
next(iter(data_iter))
iter(data_iter)
iter() 是 Python 的内置函数,它的作用是将一个可迭代对象(如列表、元组、DataLoader 等)转换为迭代器对象。迭代器是一种特殊的对象,它实现了 iter() 和 next() 方法,允许我们逐个访问可迭代对象中的元素。
在 PyTorch 里,DataLoader 是一个可迭代对象,它用于批量加载数据。通过iter(DataLoader) 就可以将 DataLoader 转换为迭代器,以便后续使用 next() 函数逐个获取批次数据。next(iter(data_iter))
next() 也是 Python 的内置函数,它用于从迭代器中获取下一个元素。当调用 next(迭代器对象) 时,迭代器会返回其下一个元素,如果没有更多元素,会抛出 StopIteration 异常。
定义模型
当我们在实现线性回归时,我们明确定义了模型参数变量,并编写了计算的代码,这样通过基本的线性代数运算得到输出。
但是,如果模型变得更加复杂,且当我们几乎每天都需要实现模型时,自然会想简化这个过程。
这种情况类似于为自己的博客从零开始编写网页。做一两次是有益的,但如果每个新博客就需要工程师花一个月的时间重新开始编写网页,那并不高效。
对于标准深度学习模型,我们可以[使用框架的预定义好的层]。这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。
我们首先定义一个模型变量net
,它是一个Sequential
类的实例。Sequential
类将多个层串联在一起。
当给定输入数据时,Sequential
实例将数据传入到第一层,然后将第一层的输出作为第二层的输入,以此类推。
在下面的例子中,我们的模型只包含一个层,因此实际上不需要Sequential
。但是由于以后几乎所有的模型都是多层的,在这里使用Sequential
会让你熟悉“标准的流水线”。
单层网络架构,这一单层被称为全连接层(fully-connected layer),因为它的每一个输入都通过矩阵-向量乘法得到它的每个输出。
# nn是神经网络的缩写
from torch import nnnet = nn.Sequential(nn.Linear(2, 1))
nn.Sequential
是 PyTorch 中的一个容器类,它可以按顺序将多个神经网络层组合在一起,形成一个完整的神经网络模型。使用 nn.Sequential 可以方便地定义一个简单的前馈神经网络,模型会按照添加层的顺序依次对输入数据进行处理。
nn.Linear
是 PyTorch 中的全连接层(也称为线性层),它实现了一个线性变换,其公式为: y = x A T + b y = xA^T + b y=xAT+b,其中 x x x 是输入数据, A A A 是权重矩阵, b b b 是偏置向量, y y y 是输出数据。
nn.Linear
类的构造函数为nn.Linear(in_features, out_features, bias=True)
,其中:
in_features
:输入特征的数量,即输入数据的维度。out_features
:输出特征的数量,即输出数据的维度。bias
:是否使用偏置项,默认为True
。
(初始化模型参数)
在使用net
之前,我们需要初始化模型参数。如在线性回归模型中的权重和偏置,深度学习框架通常有预定义的方法来初始化参数。
在这里,我们指定每个权重参数应该从均值为0、标准差为0.01的正态分布中随机采样,偏置参数将初始化为零。
正如我们在构造nn.Linear
时指定输入和输出尺寸一样,现在我们能直接访问参数以设定它们的初始值。
我们通过net[0]
选择网络中的第一个图层,然后使用weight.data
和bias.data
方法访问参数。
我们还可以使用替换方法normal_
和fill_
来重写参数值。
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
weight
全连接层的权重参数,它是一个 torch.Tensor 对象,代表了线性变换中的权重矩阵。
bias
:全连接层的偏置参数,它也是一个 torch.Tensor 对象,代表了线性变换中的偏置向量。
data
属性返回的是其底层的普通张量(不包含梯度信息)。我们直接操作data
可以避免在初始化过程中触发不必要的梯度计算。
normal_(0, 0.01)
:这是 PyTorch 张量的一个原地操作方法(方法名末尾带 _ 表示原地操作,会直接修改调用该方法的张量),作用是将张量中的元素用均值为 0、标准差为 0.01 的正态分布。也就是说,这行代码把全连接层的权重矩阵的所有元素初始化为从该正态分布中采样得到的值。
fill_(0)
:这也是一个原地操作方法,它会把偏置张量中的所有元素都填充为 0,也就是将全连接层的偏置向量初始化为零向量。
定义损失函数
[计算均方误差使用的是MSELoss
类,也称为平方 L 2 L_2 L2范数]。默认情况下,它返回所有样本损失的平均值。
loss = nn.MSELoss()
向量的平方 L2 范数
对于一个 n n n 维向量 x = [ x 1 , x 2 , ⋯ , x n ] T \mathbf{x} = [x_1, x_2, \cdots, x_n]^T x=[x1,x2,⋯,xn]T,其 L2 范数(也称为欧几里得范数)定义为向量各元素平方和的平方根,数学表达式为:
∥ x ∥ 2 = ∑ i = 1 n x i 2 \|\mathbf{x}\|_2 = \sqrt{\sum_{i=1}^{n} x_i^2} ∥x∥2=i=1∑nxi2
而向量的平方 L2 范数则是 L2 范数的平方,即:
∥ x ∥ 2 2 = ∑ i = 1 n x i 2 \|\mathbf{x}\|_2^2 = \sum_{i=1}^{n} x_i^2 ∥x∥22=i=1∑nxi2矩阵的平方 L2 范数
对于一个 m × n m \times n m×n 的矩阵 A = [ a i j ] \mathbf{A} = [a_{ij}] A=[aij],其 Frobenius 范数(可以看作是矩阵的一种 L2 范数)定义为矩阵所有元素平方和的平方根,表达式为:
∥ A ∥ F = ∑ i = 1 m ∑ j = 1 n a i j 2 \|\mathbf{A}\|_F = \sqrt{\sum_{i=1}^{m} \sum_{j=1}^{n} a_{ij}^2} ∥A∥F=i=1∑mj=1∑naij2
矩阵的平方 L2 范数(即平方 Frobenius 范数)为:
∥ A ∥ F 2 = ∑ i = 1 m ∑ j = 1 n a i j 2 \|\mathbf{A}\|_F^2 = \sum_{i=1}^{m} \sum_{j=1}^{n} a_{ij}^2 ∥A∥F2=i=1∑mj=1∑naij2
定义优化算法
小批量随机梯度下降算法是一种优化神经网络的标准工具,PyTorch在optim
模块中实现了该算法的许多变种。
当我们(实例化一个SGD
实例)时,我们要指定优化的参数(可通过net.parameters()
从我们的模型中获得)以及优化算法所需的超参数字典。小批量随机梯度下降只需要设置lr
值,这里设置为0.03。
trainer = torch.optim.SGD(net.parameters(), lr=0.03)
net.parameters()
是 PyTorch 中 nn.Module 类(nn.Sequential 继承自 nn.Module)的一个方法,它会返回一个生成器(Python 中的迭代器对象),这个生成器会逐个产生模型 net 里所有可训练的参数。
训练
通过深度学习框架的高级API来实现我们的模型只需要相对较少的代码。
我们不必单独分配参数、不必定义我们的损失函数,也不必手动实现小批量随机梯度下降。
当我们需要更复杂的模型时,高级API的优势将大大增加。
当我们有了所有的基本组件,[训练过程代码与我们从零开始实现时所做的非常相似]。
回顾一下:在每个迭代周期里,我们将完整遍历一次数据集(train_data
),
不停地从中获取一个小批量的输入和相应的标签。
对于每一个小批量,我们会进行以下步骤:
- 通过调用
net(X)
生成预测并计算损失l
(前向传播)。 - 通过进行反向传播来计算梯度。
- 通过调用优化器来更新模型参数。
为了更好的衡量训练效果,我们计算每个迭代周期后的损失,并打印它来监控训练过程。
num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X) ,y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')
下面我们[比较生成数据集的真实参数和通过有限数据训练获得的模型参数]。
要访问参数,我们首先从net
访问所需的层,然后读取该层的权重和偏置。正如在从零开始实现中一样,我们估计得到的参数与生成数据的真实参数非常接近。
w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)
相关文章:

【深度学习】线性回归的简洁实现
线性回归的简洁实现 在过去的几年里,出于对深度学习强烈的兴趣,许多公司、学者和业余爱好者开发了各种成熟的开源框架。 这些框架可以自动化基于梯度的学习算法中重复性的工作。 目前,我们只会运用: (1)通…...
渗透测试技法之口令安全
一、口令安全威胁 口令泄露途径 代码与文件存储不当:在软件开发和系统维护过程中,开发者可能会将口令以明文形式存储在代码文件、配置文件或注释中。例如,在开源代码托管平台 GitHub 上,一些开发者由于疏忽,将包含数据…...

【R语言】数学运算
一、基础运算 R语言中能实现加、减、乘、除、求模、取整、取绝对值、指数、对数等运算。 x <- 2 y <- 10 # 求模 y %% x # 整除 y %/% x # 取绝对值 abs(-x) # 指数运算 y ^x y^1/x #对数运算 log(x) #log()函数默认情况下以 e 为底 双等号“”的作用等同于identical(…...

小游戏源码开发搭建技术栈和服务器配置流程
近些年各种场景小游戏开发搭建版本层出不穷,山东布谷科技拥有多年海内外小游戏源码开发经验,现为从事小游戏源码开发或游戏运营的朋友们详细介绍小游戏开发及服务器配置流程。 一、可以对接到app的小游戏是如何开发的 1、小游戏源码开发的需求分析: 明…...
深度学习|表示学习|卷积神经网络|输出维度公式|15
如是我闻: 在卷积和池化操作中,计算输出维度的公式是关键,它们分别可以帮助我们计算卷积操作和池化操作后的输出大小。下面分别总结公式,并结合解释它们的意义: 1. 卷积操作的输出维度公式 当我们对输入图像进行卷积时…...
cpp智能指针
普通指针的不足 new和new[]的内存需要用delete和deletel]释放。 程序员的主观失误,忘了或漏了释放。 程序员也不确定何时释放。 普通指针的释放 类内的指针,在析构函数中释放。 C内置数据类型,如何释放? new出来的类,本身如…...
【面试题】 Java 三年工作经验(2025)
问题列表 为什么选择 spring boot 框架,它与 Spring 有什么区别?spring mvc 的执行流程是什么?如何实现 spring 的 IOC 过程,会用到什么技术?spring boot 的自动化配置的原理是什么?如何理解 spring boot 中…...

MOS的体二极管能通多大电流
第一个问题:MOS导通之后电流方向可以使任意的,既可以从D到S,也可以从S到D。 第二个问题:MOS里面的体二极管电流可以达到几百安培,这也就解释了MOS选型的时候很少考虑体二极管的最大电流,而是考虑DS之间电流…...

Node.js下载安装及环境配置教程 (详细版)
Node.js:是一个基于 Chrome V8 引擎的 JavaScript 运行时,用于构建可扩展的网络应用程序。Node.js 使用事件驱动、非阻塞 I/O 模型,使其非常适合构建实时应用程序。 Node.js 提供了一种轻量、高效、可扩展的方式来构建网络应用程序࿰…...
嵌入式MCU面试笔记2
目录 串口通信 概论 原理 配置 HAL库代码 1. 初始化函数 2. 数据发送和接收函数 3. 中断和DMA函数 4. 中断服务函数 串口通信 概论 我们知道,通信桥接了两个设备之间的交流。一个经典的例子就是使用串口通信交换上位机和单片机之间的数据。 比较常见的串…...

代码随想录算法【Day34】
Day34 62.不同路径 思路 第一种:深搜 -> 超时 第二种:动态规划 第三种:数论 动态规划代码如下: class Solution { public:int uniquePaths(int m, int n) {vector<vector<int>> dp(m, vector<int>(n,…...

《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》重印P126、P131勘误
勘误:打圈的地方有指数二字。 指数滤波器本身是错误的概念,我在书上打了一个叉,排版人员误删了。 滤波器部分从根本上有问题,本来要改,但是时间不够了。 和廖老师讨论多次后,决定大动。指数滤波器的概念…...

vim多文件操作如何同屏开多个文件
[rootxxx ~]# vimdiff aa.txt bb.txt cc.txt #带颜色比较的纵向排列打开的同屏多文件操作 示例: [rootxxx ~]# vimdiff -o aa.txt bb.txt cc.txt #带颜色比较的横向排列打开的同屏多文件操作 示例: [rootxxx ~]# vim -O aa.txt bb.txt c…...

day6手机摄影社区,可以去苹果摄影社区学习拍摄技巧
逛自己手机的社区:即(手机牌子)摄影社区 拍照时防止抖动可以控制自己的呼吸,不要大喘气 拍一张照片后,如何简单的用手机修图? HDR模式就是让高光部分和阴影部分更协调(拍风紧时可以打开&…...

渗透测试之WAF规则触发绕过规则之规则库绕过方式
目录 Waf触发规则的绕过 特殊字符替换空格 实例 特殊字符拼接绕过waf Mysql 内置得方法 注释包含关键字 实例 Waf触发规则的绕过 特殊字符替换空格 用一些特殊字符代替空格,比如在mysql中%0a是换行,可以代替空格 这个方法也可以部分绕过最新版本的…...

C语言【基础篇】之流程控制——掌握三大结构的奥秘
流程控制 🚀前言🦜顺序结构💯 定义💯执行规则 🌟选择结构💯if语句💯switch语句💯case穿透规则 🤔循环结构💯for循环💯while循环💯do -…...
c++小知识点
抽象类包含至少一个纯虚函数,不能实例化对象。派生类必须实现基类的所有纯虚函数才能成为非抽象类,从而可以实例化对象。可以使用抽象类的指针或引用指向派生类对象,实现多态性调用。抽象类虽然不能直接实例化,但可以拥有构造函数…...
团体程序设计天梯赛-练习集——L1-022 奇偶分家
前言 这几道题都偏简单一点,没有什么计算,10分 L1-022 奇偶分家 给定N个正整数,请统计奇数和偶数各有多少个? 输入格式: 输入第一行给出一个正整N(≤1000);第2行给出N个非负整数…...
vue项目中,如何获取某一部分的宽高
vue项目中,如何获取某一部分的宽高 在Vue项目中,如果你想要获取某个DOM元素的宽度和高度,可以使用原生的JavaScript方法或者结合Vue的特性来实现。以下是几种常见的方法: 使用ref属性 你可以给需要测量宽高的元素添加一个ref属…...

LeetCode - #195 Swift 实现打印文件中的第十行
网罗开发 (小红书、快手、视频号同名) 大家好,我是 展菲,目前在上市企业从事人工智能项目研发管理工作,平时热衷于分享各种编程领域的软硬技能知识以及前沿技术,包括iOS、前端、Harmony OS、Java、Python等…...
rknn优化教程(二)
文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...

P3 QT项目----记事本(3.8)
3.8 记事本项目总结 项目源码 1.main.cpp #include "widget.h" #include <QApplication> int main(int argc, char *argv[]) {QApplication a(argc, argv);Widget w;w.show();return a.exec(); } 2.widget.cpp #include "widget.h" #include &q…...
Java 加密常用的各种算法及其选择
在数字化时代,数据安全至关重要,Java 作为广泛应用的编程语言,提供了丰富的加密算法来保障数据的保密性、完整性和真实性。了解这些常用加密算法及其适用场景,有助于开发者在不同的业务需求中做出正确的选择。 一、对称加密算法…...

04-初识css
一、css样式引入 1.1.内部样式 <div style"width: 100px;"></div>1.2.外部样式 1.2.1.外部样式1 <style>.aa {width: 100px;} </style> <div class"aa"></div>1.2.2.外部样式2 <!-- rel内表面引入的是style样…...

uniapp微信小程序视频实时流+pc端预览方案
方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度WebSocket图片帧定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐RTMP推流TRTC/即构SDK推流❌ 付费方案 (部分有免费额度&#x…...

OPENCV形态学基础之二腐蚀
一.腐蚀的原理 (图1) 数学表达式:dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一,腐蚀跟膨胀属于反向操作,膨胀是把图像图像变大,而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...

springboot整合VUE之在线教育管理系统简介
可以学习到的技能 学会常用技术栈的使用 独立开发项目 学会前端的开发流程 学会后端的开发流程 学会数据库的设计 学会前后端接口调用方式 学会多模块之间的关联 学会数据的处理 适用人群 在校学生,小白用户,想学习知识的 有点基础,想要通过项…...

STM32---外部32.768K晶振(LSE)无法起振问题
晶振是否起振主要就检查两个1、晶振与MCU是否兼容;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容(CL)与匹配电容(CL1、CL2)的关系 2. 如何选择 CL1 和 CL…...
Linux系统部署KES
1、安装准备 1.版本说明V008R006C009B0014 V008:是version产品的大版本。 R006:是release产品特性版本。 C009:是通用版 B0014:是build开发过程中的构建版本2.硬件要求 #安全版和企业版 内存:1GB 以上 硬盘…...

关于easyexcel动态下拉选问题处理
前些日子突然碰到一个问题,说是客户的导入文件模版想支持部分导入内容的下拉选,于是我就找了easyexcel官网寻找解决方案,并没有找到合适的方案,没办法只能自己动手并分享出来,针对Java生成Excel下拉菜单时因选项过多导…...