当前位置: 首页 > news >正文

【深度学习】线性回归的简洁实现

线性回归的简洁实现

在过去的几年里,出于对深度学习强烈的兴趣,许多公司、学者和业余爱好者开发了各种成熟的开源框架。
这些框架可以自动化基于梯度的学习算法中重复性的工作。
目前,我们只会运用:
(1)通过张量来进行数据存储和线性代数;
(2)通过自动微分来计算梯度。
实际上,由于数据迭代器、损失函数、优化器和神经网络层很常用,现代深度学习库也为我们实现了这些组件。

本节将介绍如何(通过使用深度学习框架来简洁地实现线性回归模型)。

生成数据集

我们首先[生成数据集]。

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])#用于创建张量(Tensor)
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

读取数据集

我们可以[调用框架中现有的API来读取数据]。
我们将featureslabels作为API的参数传递,并通过数据迭代器指定batch_size
此外,布尔值is_train表示是否希望数据迭代器对象在每个迭代周期内打乱数据。

def load_array(data_arrays, batch_size, is_train=True):  #@save"""构造一个PyTorch数据迭代器"""dataset = data.TensorDataset(*data_arrays)#TensorDataset可以将多个张量组合成一个数据集对象return data.DataLoader(dataset, batch_size, shuffle=is_train)#用于批量加载和处理数据集
batch_size = 10
data_iter = load_array((features, labels), batch_size)

为了验证是否正常工作,让我们读取并打印第一个小批量样本。这里我们使用iter构造Python迭代器,并使用next从迭代器中获取第一项。

next(iter(data_iter))
iter(data_iter)

iter() 是 Python 的内置函数,它的作用是将一个可迭代对象(如列表、元组、DataLoader 等)转换为迭代器对象。迭代器是一种特殊的对象,它实现了 iter() 和 next() 方法,允许我们逐个访问可迭代对象中的元素。
在 PyTorch 里,DataLoader 是一个可迭代对象,它用于批量加载数据。通过iter(DataLoader) 就可以将 DataLoader 转换为迭代器,以便后续使用 next() 函数逐个获取批次数据。

next(iter(data_iter))

next() 也是 Python 的内置函数,它用于从迭代器中获取下一个元素。当调用 next(迭代器对象) 时,迭代器会返回其下一个元素,如果没有更多元素,会抛出 StopIteration 异常。

在这里插入图片描述

定义模型

当我们在实现线性回归时,我们明确定义了模型参数变量,并编写了计算的代码,这样通过基本的线性代数运算得到输出。
但是,如果模型变得更加复杂,且当我们几乎每天都需要实现模型时,自然会想简化这个过程。
这种情况类似于为自己的博客从零开始编写网页。做一两次是有益的,但如果每个新博客就需要工程师花一个月的时间重新开始编写网页,那并不高效。

对于标准深度学习模型,我们可以[使用框架的预定义好的层]。这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。

我们首先定义一个模型变量net,它是一个Sequential类的实例。Sequential类将多个层串联在一起。
当给定输入数据时,Sequential实例将数据传入到第一层,然后将第一层的输出作为第二层的输入,以此类推。
在下面的例子中,我们的模型只包含一个层,因此实际上不需要Sequential。但是由于以后几乎所有的模型都是多层的,在这里使用Sequential会让你熟悉“标准的流水线”。

单层网络架构,这一单层被称为全连接层(fully-connected layer),因为它的每一个输入都通过矩阵-向量乘法得到它的每个输出。

# nn是神经网络的缩写
from torch import nnnet = nn.Sequential(nn.Linear(2, 1))

nn.Sequential 是 PyTorch 中的一个容器类,它可以按顺序将多个神经网络层组合在一起,形成一个完整的神经网络模型。使用 nn.Sequential 可以方便地定义一个简单的前馈神经网络,模型会按照添加层的顺序依次对输入数据进行处理。

nn.Linear 是 PyTorch 中的全连接层(也称为线性层),它实现了一个线性变换,其公式为: y = x A T + b y = xA^T + b y=xAT+b,其中 x x x 是输入数据, A A A 是权重矩阵, b b b 是偏置向量, y y y 是输出数据。
nn.Linear 类的构造函数为 nn.Linear(in_features, out_features, bias=True),其中:

  • in_features:输入特征的数量,即输入数据的维度。
  • out_features:输出特征的数量,即输出数据的维度。
  • bias:是否使用偏置项,默认为 True

(初始化模型参数)

在使用net之前,我们需要初始化模型参数。如在线性回归模型中的权重和偏置,深度学习框架通常有预定义的方法来初始化参数。
在这里,我们指定每个权重参数应该从均值为0、标准差为0.01的正态分布中随机采样,偏置参数将初始化为零。

正如我们在构造nn.Linear时指定输入和输出尺寸一样,现在我们能直接访问参数以设定它们的初始值。
我们通过net[0]选择网络中的第一个图层,然后使用weight.databias.data方法访问参数。
我们还可以使用替换方法normal_fill_来重写参数值。

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

weight全连接层的权重参数,它是一个 torch.Tensor 对象,代表了线性变换中的权重矩阵。

bias:全连接层的偏置参数,它也是一个 torch.Tensor 对象,代表了线性变换中的偏置向量。

data 属性返回的是其底层的普通张量(不包含梯度信息)。我们直接操作 data 可以避免在初始化过程中触发不必要的梯度计算。

normal_(0, 0.01):这是 PyTorch 张量的一个原地操作方法(方法名末尾带 _ 表示原地操作,会直接修改调用该方法的张量),作用是将张量中的元素用均值为 0、标准差为 0.01 的正态分布。也就是说,这行代码把全连接层的权重矩阵的所有元素初始化为从该正态分布中采样得到的值。
fill_(0):这也是一个原地操作方法,它会把偏置张量中的所有元素都填充为 0,也就是将全连接层的偏置向量初始化为零向量。

定义损失函数

[计算均方误差使用的是MSELoss类,也称为平方 L 2 L_2 L2范数]。默认情况下,它返回所有样本损失的平均值。

loss = nn.MSELoss()
向量的平方 L2 范数

对于一个 n n n 维向量 x = [ x 1 , x 2 , ⋯ , x n ] T \mathbf{x} = [x_1, x_2, \cdots, x_n]^T x=[x1,x2,,xn]T,其 L2 范数(也称为欧几里得范数)定义为向量各元素平方和的平方根,数学表达式为:
∥ x ∥ 2 = ∑ i = 1 n x i 2 \|\mathbf{x}\|_2 = \sqrt{\sum_{i=1}^{n} x_i^2} x2=i=1nxi2
而向量的平方 L2 范数则是 L2 范数的平方,即:
∥ x ∥ 2 2 = ∑ i = 1 n x i 2 \|\mathbf{x}\|_2^2 = \sum_{i=1}^{n} x_i^2 x22=i=1nxi2

矩阵的平方 L2 范数

对于一个 m × n m \times n m×n 的矩阵 A = [ a i j ] \mathbf{A} = [a_{ij}] A=[aij],其 Frobenius 范数(可以看作是矩阵的一种 L2 范数)定义为矩阵所有元素平方和的平方根,表达式为:
∥ A ∥ F = ∑ i = 1 m ∑ j = 1 n a i j 2 \|\mathbf{A}\|_F = \sqrt{\sum_{i=1}^{m} \sum_{j=1}^{n} a_{ij}^2} AF=i=1mj=1naij2
矩阵的平方 L2 范数(即平方 Frobenius 范数)为:
∥ A ∥ F 2 = ∑ i = 1 m ∑ j = 1 n a i j 2 \|\mathbf{A}\|_F^2 = \sum_{i=1}^{m} \sum_{j=1}^{n} a_{ij}^2 AF2=i=1mj=1naij2

定义优化算法

小批量随机梯度下降算法是一种优化神经网络的标准工具,PyTorch在optim模块中实现了该算法的许多变种。
当我们(实例化一个SGD实例)时,我们要指定优化的参数(可通过net.parameters()从我们的模型中获得)以及优化算法所需的超参数字典。小批量随机梯度下降只需要设置lr值,这里设置为0.03。

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

net.parameters() 是 PyTorch 中 nn.Module 类(nn.Sequential 继承自 nn.Module)的一个方法,它会返回一个生成器(Python 中的迭代器对象),这个生成器会逐个产生模型 net 里所有可训练的参数。

训练

通过深度学习框架的高级API来实现我们的模型只需要相对较少的代码。
我们不必单独分配参数、不必定义我们的损失函数,也不必手动实现小批量随机梯度下降。
当我们需要更复杂的模型时,高级API的优势将大大增加。
当我们有了所有的基本组件,[训练过程代码与我们从零开始实现时所做的非常相似]。

回顾一下:在每个迭代周期里,我们将完整遍历一次数据集(train_data),
不停地从中获取一个小批量的输入和相应的标签。
对于每一个小批量,我们会进行以下步骤:

  • 通过调用net(X)生成预测并计算损失l(前向传播)。
  • 通过进行反向传播来计算梯度。
  • 通过调用优化器来更新模型参数。

为了更好的衡量训练效果,我们计算每个迭代周期后的损失,并打印它来监控训练过程。

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X) ,y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')

在这里插入图片描述
下面我们[比较生成数据集的真实参数和通过有限数据训练获得的模型参数]。
要访问参数,我们首先从net访问所需的层,然后读取该层的权重和偏置。正如在从零开始实现中一样,我们估计得到的参数与生成数据的真实参数非常接近。

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

在这里插入图片描述

相关文章:

【深度学习】线性回归的简洁实现

线性回归的简洁实现 在过去的几年里,出于对深度学习强烈的兴趣,许多公司、学者和业余爱好者开发了各种成熟的开源框架。 这些框架可以自动化基于梯度的学习算法中重复性的工作。 目前,我们只会运用: (1)通…...

渗透测试技法之口令安全

一、口令安全威胁 口令泄露途径 代码与文件存储不当:在软件开发和系统维护过程中,开发者可能会将口令以明文形式存储在代码文件、配置文件或注释中。例如,在开源代码托管平台 GitHub 上,一些开发者由于疏忽,将包含数据…...

【R语言】数学运算

一、基础运算 R语言中能实现加、减、乘、除、求模、取整、取绝对值、指数、对数等运算。 x <- 2 y <- 10 # 求模 y %% x # 整除 y %/% x # 取绝对值 abs(-x) # 指数运算 y ^x y^1/x #对数运算 log(x) #log()函数默认情况下以 e 为底 双等号“”的作用等同于identical(…...

小游戏源码开发搭建技术栈和服务器配置流程

近些年各种场景小游戏开发搭建版本层出不穷,山东布谷科技拥有多年海内外小游戏源码开发经验&#xff0c;现为从事小游戏源码开发或游戏运营的朋友们详细介绍小游戏开发及服务器配置流程。 一、可以对接到app的小游戏是如何开发的 1、小游戏源码开发的需求分析&#xff1a; 明…...

深度学习|表示学习|卷积神经网络|输出维度公式|15

如是我闻&#xff1a; 在卷积和池化操作中&#xff0c;计算输出维度的公式是关键&#xff0c;它们分别可以帮助我们计算卷积操作和池化操作后的输出大小。下面分别总结公式&#xff0c;并结合解释它们的意义&#xff1a; 1. 卷积操作的输出维度公式 当我们对输入图像进行卷积时…...

cpp智能指针

普通指针的不足 new和new[]的内存需要用delete和deletel]释放。 程序员的主观失误&#xff0c;忘了或漏了释放。 程序员也不确定何时释放。 普通指针的释放 类内的指针&#xff0c;在析构函数中释放。 C内置数据类型&#xff0c;如何释放? new出来的类&#xff0c;本身如…...

【面试题】 Java 三年工作经验(2025)

问题列表 为什么选择 spring boot 框架&#xff0c;它与 Spring 有什么区别&#xff1f;spring mvc 的执行流程是什么&#xff1f;如何实现 spring 的 IOC 过程&#xff0c;会用到什么技术&#xff1f;spring boot 的自动化配置的原理是什么&#xff1f;如何理解 spring boot 中…...

MOS的体二极管能通多大电流

第一个问题&#xff1a;MOS导通之后电流方向可以使任意的&#xff0c;既可以从D到S&#xff0c;也可以从S到D。 第二个问题&#xff1a;MOS里面的体二极管电流可以达到几百安培&#xff0c;这也就解释了MOS选型的时候很少考虑体二极管的最大电流&#xff0c;而是考虑DS之间电流…...

Node.js下载安装及环境配置教程 (详细版)

Node.js&#xff1a;是一个基于 Chrome V8 引擎的 JavaScript 运行时&#xff0c;用于构建可扩展的网络应用程序。Node.js 使用事件驱动、非阻塞 I/O 模型&#xff0c;使其非常适合构建实时应用程序。 Node.js 提供了一种轻量、高效、可扩展的方式来构建网络应用程序&#xff0…...

嵌入式MCU面试笔记2

目录 串口通信 概论 原理 配置 HAL库代码 1. 初始化函数 2. 数据发送和接收函数 3. 中断和DMA函数 4. 中断服务函数 串口通信 概论 我们知道&#xff0c;通信桥接了两个设备之间的交流。一个经典的例子就是使用串口通信交换上位机和单片机之间的数据。 比较常见的串…...

代码随想录算法【Day34】

Day34 62.不同路径 思路 第一种&#xff1a;深搜 -> 超时 第二种&#xff1a;动态规划 第三种&#xff1a;数论 动态规划代码如下&#xff1a; class Solution { public:int uniquePaths(int m, int n) {vector<vector<int>> dp(m, vector<int>(n,…...

《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》重印P126、P131勘误

勘误&#xff1a;打圈的地方有指数二字。 指数滤波器本身是错误的概念&#xff0c;我在书上打了一个叉&#xff0c;排版人员误删了。 滤波器部分从根本上有问题&#xff0c;本来要改&#xff0c;但是时间不够了。 和廖老师讨论多次后&#xff0c;决定大动。指数滤波器的概念…...

vim多文件操作如何同屏开多个文件

[rootxxx ~]# vimdiff aa.txt bb.txt cc.txt #带颜色比较的纵向排列打开的同屏多文件操作 示例&#xff1a; [rootxxx ~]# vimdiff -o aa.txt bb.txt cc.txt #带颜色比较的横向排列打开的同屏多文件操作 示例&#xff1a; [rootxxx ~]# vim -O aa.txt bb.txt c…...

day6手机摄影社区,可以去苹果摄影社区学习拍摄技巧

逛自己手机的社区&#xff1a;即&#xff08;手机牌子&#xff09;摄影社区 拍照时防止抖动可以控制自己的呼吸&#xff0c;不要大喘气 拍一张照片后&#xff0c;如何简单的用手机修图&#xff1f; HDR模式就是让高光部分和阴影部分更协调&#xff08;拍风紧时可以打开&…...

渗透测试之WAF规则触发绕过规则之规则库绕过方式

目录 Waf触发规则的绕过 特殊字符替换空格 实例 特殊字符拼接绕过waf Mysql 内置得方法 注释包含关键字 实例 Waf触发规则的绕过 特殊字符替换空格 用一些特殊字符代替空格&#xff0c;比如在mysql中%0a是换行&#xff0c;可以代替空格 这个方法也可以部分绕过最新版本的…...

C语言【基础篇】之流程控制——掌握三大结构的奥秘

流程控制 &#x1f680;前言&#x1f99c;顺序结构&#x1f4af; 定义&#x1f4af;执行规则 &#x1f31f;选择结构&#x1f4af;if语句&#x1f4af;switch语句&#x1f4af;case穿透规则 &#x1f914;循环结构&#x1f4af;for循环&#x1f4af;while循环&#x1f4af;do -…...

c++小知识点

抽象类包含至少一个纯虚函数&#xff0c;不能实例化对象。派生类必须实现基类的所有纯虚函数才能成为非抽象类&#xff0c;从而可以实例化对象。可以使用抽象类的指针或引用指向派生类对象&#xff0c;实现多态性调用。抽象类虽然不能直接实例化&#xff0c;但可以拥有构造函数…...

团体程序设计天梯赛-练习集——L1-022 奇偶分家

前言 这几道题都偏简单一点&#xff0c;没有什么计算&#xff0c;10分 L1-022 奇偶分家 给定N个正整数&#xff0c;请统计奇数和偶数各有多少个&#xff1f; 输入格式&#xff1a; 输入第一行给出一个正整N&#xff08;≤1000&#xff09;&#xff1b;第2行给出N个非负整数…...

vue项目中,如何获取某一部分的宽高

vue项目中&#xff0c;如何获取某一部分的宽高 在Vue项目中&#xff0c;如果你想要获取某个DOM元素的宽度和高度&#xff0c;可以使用原生的JavaScript方法或者结合Vue的特性来实现。以下是几种常见的方法&#xff1a; 使用ref属性 你可以给需要测量宽高的元素添加一个ref属…...

LeetCode - #195 Swift 实现打印文件中的第十行

网罗开发 &#xff08;小红书、快手、视频号同名&#xff09; 大家好&#xff0c;我是 展菲&#xff0c;目前在上市企业从事人工智能项目研发管理工作&#xff0c;平时热衷于分享各种编程领域的软硬技能知识以及前沿技术&#xff0c;包括iOS、前端、Harmony OS、Java、Python等…...

微信小程序之bind和catch

这两个呢&#xff0c;都是绑定事件用的&#xff0c;具体使用有些小区别。 官方文档&#xff1a; 事件冒泡处理不同 bind&#xff1a;绑定的事件会向上冒泡&#xff0c;即触发当前组件的事件后&#xff0c;还会继续触发父组件的相同事件。例如&#xff0c;有一个子视图绑定了b…...

docker详细操作--未完待续

docker介绍 docker官网: Docker&#xff1a;加速容器应用程序开发 harbor官网&#xff1a;Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台&#xff0c;用于将应用程序及其依赖项&#xff08;如库、运行时环…...

uni-app学习笔记二十二---使用vite.config.js全局导入常用依赖

在前面的练习中&#xff0c;每个页面需要使用ref&#xff0c;onShow等生命周期钩子函数时都需要像下面这样导入 import {onMounted, ref} from "vue" 如果不想每个页面都导入&#xff0c;需要使用node.js命令npm安装unplugin-auto-import npm install unplugin-au…...

大数据零基础学习day1之环境准备和大数据初步理解

学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 &#xff08;1&#xff09;设置网关 打开VMware虚拟机&#xff0c;点击编辑…...

连锁超市冷库节能解决方案:如何实现超市降本增效

在连锁超市冷库运营中&#xff0c;高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术&#xff0c;实现年省电费15%-60%&#xff0c;且不改动原有装备、安装快捷、…...

macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用

文章目录 问题现象问题原因解决办法 问题现象 macOS启动台&#xff08;Launchpad&#xff09;多出来了&#xff1a;Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显&#xff0c;都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

PL0语法,分析器实现!

简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同&#xff0c;结合所安装的tensorflow的目录结构修改from语句即可。 原语句&#xff1a; from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后&#xff1a; from tensorflow.python.keras.lay…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路&#xff1a; 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑&#xff1a;async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

【Post-process】【VBA】ETABS VBA FrameObj.GetNameList and write to EXCEL

ETABS API实战:导出框架元素数据到Excel 在结构工程师的日常工作中,经常需要从ETABS模型中提取框架元素信息进行后续分析。手动复制粘贴不仅耗时,还容易出错。今天我们来用简单的VBA代码实现自动化导出。 🎯 我们要实现什么? 一键点击,就能将ETABS中所有框架元素的基…...