【深度学习】线性回归的简洁实现
线性回归的简洁实现
在过去的几年里,出于对深度学习强烈的兴趣,许多公司、学者和业余爱好者开发了各种成熟的开源框架。
这些框架可以自动化基于梯度的学习算法中重复性的工作。
目前,我们只会运用:
(1)通过张量来进行数据存储和线性代数;
(2)通过自动微分来计算梯度。
实际上,由于数据迭代器、损失函数、优化器和神经网络层很常用,现代深度学习库也为我们实现了这些组件。
本节将介绍如何(通过使用深度学习框架来简洁地实现线性回归模型)。
生成数据集
我们首先[生成数据集]。
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])#用于创建张量(Tensor)
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)
读取数据集
我们可以[调用框架中现有的API来读取数据]。
我们将features和labels作为API的参数传递,并通过数据迭代器指定batch_size。
此外,布尔值is_train表示是否希望数据迭代器对象在每个迭代周期内打乱数据。
def load_array(data_arrays, batch_size, is_train=True): #@save"""构造一个PyTorch数据迭代器"""dataset = data.TensorDataset(*data_arrays)#TensorDataset可以将多个张量组合成一个数据集对象return data.DataLoader(dataset, batch_size, shuffle=is_train)#用于批量加载和处理数据集
batch_size = 10
data_iter = load_array((features, labels), batch_size)
为了验证是否正常工作,让我们读取并打印第一个小批量样本。这里我们使用iter构造Python迭代器,并使用next从迭代器中获取第一项。
next(iter(data_iter))
iter(data_iter)
iter() 是 Python 的内置函数,它的作用是将一个可迭代对象(如列表、元组、DataLoader 等)转换为迭代器对象。迭代器是一种特殊的对象,它实现了 iter() 和 next() 方法,允许我们逐个访问可迭代对象中的元素。
在 PyTorch 里,DataLoader 是一个可迭代对象,它用于批量加载数据。通过iter(DataLoader) 就可以将 DataLoader 转换为迭代器,以便后续使用 next() 函数逐个获取批次数据。next(iter(data_iter))
next() 也是 Python 的内置函数,它用于从迭代器中获取下一个元素。当调用 next(迭代器对象) 时,迭代器会返回其下一个元素,如果没有更多元素,会抛出 StopIteration 异常。

定义模型
当我们在实现线性回归时,我们明确定义了模型参数变量,并编写了计算的代码,这样通过基本的线性代数运算得到输出。
但是,如果模型变得更加复杂,且当我们几乎每天都需要实现模型时,自然会想简化这个过程。
这种情况类似于为自己的博客从零开始编写网页。做一两次是有益的,但如果每个新博客就需要工程师花一个月的时间重新开始编写网页,那并不高效。
对于标准深度学习模型,我们可以[使用框架的预定义好的层]。这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。
我们首先定义一个模型变量net,它是一个Sequential类的实例。Sequential类将多个层串联在一起。
当给定输入数据时,Sequential实例将数据传入到第一层,然后将第一层的输出作为第二层的输入,以此类推。
在下面的例子中,我们的模型只包含一个层,因此实际上不需要Sequential。但是由于以后几乎所有的模型都是多层的,在这里使用Sequential会让你熟悉“标准的流水线”。
单层网络架构,这一单层被称为全连接层(fully-connected layer),因为它的每一个输入都通过矩阵-向量乘法得到它的每个输出。
# nn是神经网络的缩写
from torch import nnnet = nn.Sequential(nn.Linear(2, 1))
nn.Sequential是 PyTorch 中的一个容器类,它可以按顺序将多个神经网络层组合在一起,形成一个完整的神经网络模型。使用 nn.Sequential 可以方便地定义一个简单的前馈神经网络,模型会按照添加层的顺序依次对输入数据进行处理。
nn.Linear是 PyTorch 中的全连接层(也称为线性层),它实现了一个线性变换,其公式为: y = x A T + b y = xA^T + b y=xAT+b,其中 x x x 是输入数据, A A A 是权重矩阵, b b b 是偏置向量, y y y 是输出数据。
nn.Linear类的构造函数为nn.Linear(in_features, out_features, bias=True),其中:
in_features:输入特征的数量,即输入数据的维度。out_features:输出特征的数量,即输出数据的维度。bias:是否使用偏置项,默认为True。
(初始化模型参数)
在使用net之前,我们需要初始化模型参数。如在线性回归模型中的权重和偏置,深度学习框架通常有预定义的方法来初始化参数。
在这里,我们指定每个权重参数应该从均值为0、标准差为0.01的正态分布中随机采样,偏置参数将初始化为零。
正如我们在构造nn.Linear时指定输入和输出尺寸一样,现在我们能直接访问参数以设定它们的初始值。
我们通过net[0]选择网络中的第一个图层,然后使用weight.data和bias.data方法访问参数。
我们还可以使用替换方法normal_和fill_来重写参数值。
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
weight全连接层的权重参数,它是一个 torch.Tensor 对象,代表了线性变换中的权重矩阵。
bias:全连接层的偏置参数,它也是一个 torch.Tensor 对象,代表了线性变换中的偏置向量。
data属性返回的是其底层的普通张量(不包含梯度信息)。我们直接操作data可以避免在初始化过程中触发不必要的梯度计算。
normal_(0, 0.01):这是 PyTorch 张量的一个原地操作方法(方法名末尾带 _ 表示原地操作,会直接修改调用该方法的张量),作用是将张量中的元素用均值为 0、标准差为 0.01 的正态分布。也就是说,这行代码把全连接层的权重矩阵的所有元素初始化为从该正态分布中采样得到的值。
fill_(0):这也是一个原地操作方法,它会把偏置张量中的所有元素都填充为 0,也就是将全连接层的偏置向量初始化为零向量。
定义损失函数
[计算均方误差使用的是MSELoss类,也称为平方 L 2 L_2 L2范数]。默认情况下,它返回所有样本损失的平均值。
loss = nn.MSELoss()
向量的平方 L2 范数
对于一个 n n n 维向量 x = [ x 1 , x 2 , ⋯ , x n ] T \mathbf{x} = [x_1, x_2, \cdots, x_n]^T x=[x1,x2,⋯,xn]T,其 L2 范数(也称为欧几里得范数)定义为向量各元素平方和的平方根,数学表达式为:
∥ x ∥ 2 = ∑ i = 1 n x i 2 \|\mathbf{x}\|_2 = \sqrt{\sum_{i=1}^{n} x_i^2} ∥x∥2=i=1∑nxi2
而向量的平方 L2 范数则是 L2 范数的平方,即:
∥ x ∥ 2 2 = ∑ i = 1 n x i 2 \|\mathbf{x}\|_2^2 = \sum_{i=1}^{n} x_i^2 ∥x∥22=i=1∑nxi2矩阵的平方 L2 范数
对于一个 m × n m \times n m×n 的矩阵 A = [ a i j ] \mathbf{A} = [a_{ij}] A=[aij],其 Frobenius 范数(可以看作是矩阵的一种 L2 范数)定义为矩阵所有元素平方和的平方根,表达式为:
∥ A ∥ F = ∑ i = 1 m ∑ j = 1 n a i j 2 \|\mathbf{A}\|_F = \sqrt{\sum_{i=1}^{m} \sum_{j=1}^{n} a_{ij}^2} ∥A∥F=i=1∑mj=1∑naij2
矩阵的平方 L2 范数(即平方 Frobenius 范数)为:
∥ A ∥ F 2 = ∑ i = 1 m ∑ j = 1 n a i j 2 \|\mathbf{A}\|_F^2 = \sum_{i=1}^{m} \sum_{j=1}^{n} a_{ij}^2 ∥A∥F2=i=1∑mj=1∑naij2
定义优化算法
小批量随机梯度下降算法是一种优化神经网络的标准工具,PyTorch在optim模块中实现了该算法的许多变种。
当我们(实例化一个SGD实例)时,我们要指定优化的参数(可通过net.parameters()从我们的模型中获得)以及优化算法所需的超参数字典。小批量随机梯度下降只需要设置lr值,这里设置为0.03。
trainer = torch.optim.SGD(net.parameters(), lr=0.03)
net.parameters()是 PyTorch 中 nn.Module 类(nn.Sequential 继承自 nn.Module)的一个方法,它会返回一个生成器(Python 中的迭代器对象),这个生成器会逐个产生模型 net 里所有可训练的参数。
训练
通过深度学习框架的高级API来实现我们的模型只需要相对较少的代码。
我们不必单独分配参数、不必定义我们的损失函数,也不必手动实现小批量随机梯度下降。
当我们需要更复杂的模型时,高级API的优势将大大增加。
当我们有了所有的基本组件,[训练过程代码与我们从零开始实现时所做的非常相似]。
回顾一下:在每个迭代周期里,我们将完整遍历一次数据集(train_data),
不停地从中获取一个小批量的输入和相应的标签。
对于每一个小批量,我们会进行以下步骤:
- 通过调用
net(X)生成预测并计算损失l(前向传播)。 - 通过进行反向传播来计算梯度。
- 通过调用优化器来更新模型参数。
为了更好的衡量训练效果,我们计算每个迭代周期后的损失,并打印它来监控训练过程。
num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X) ,y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')

下面我们[比较生成数据集的真实参数和通过有限数据训练获得的模型参数]。
要访问参数,我们首先从net访问所需的层,然后读取该层的权重和偏置。正如在从零开始实现中一样,我们估计得到的参数与生成数据的真实参数非常接近。
w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

相关文章:
【深度学习】线性回归的简洁实现
线性回归的简洁实现 在过去的几年里,出于对深度学习强烈的兴趣,许多公司、学者和业余爱好者开发了各种成熟的开源框架。 这些框架可以自动化基于梯度的学习算法中重复性的工作。 目前,我们只会运用: (1)通…...
渗透测试技法之口令安全
一、口令安全威胁 口令泄露途径 代码与文件存储不当:在软件开发和系统维护过程中,开发者可能会将口令以明文形式存储在代码文件、配置文件或注释中。例如,在开源代码托管平台 GitHub 上,一些开发者由于疏忽,将包含数据…...
【R语言】数学运算
一、基础运算 R语言中能实现加、减、乘、除、求模、取整、取绝对值、指数、对数等运算。 x <- 2 y <- 10 # 求模 y %% x # 整除 y %/% x # 取绝对值 abs(-x) # 指数运算 y ^x y^1/x #对数运算 log(x) #log()函数默认情况下以 e 为底 双等号“”的作用等同于identical(…...
小游戏源码开发搭建技术栈和服务器配置流程
近些年各种场景小游戏开发搭建版本层出不穷,山东布谷科技拥有多年海内外小游戏源码开发经验,现为从事小游戏源码开发或游戏运营的朋友们详细介绍小游戏开发及服务器配置流程。 一、可以对接到app的小游戏是如何开发的 1、小游戏源码开发的需求分析: 明…...
深度学习|表示学习|卷积神经网络|输出维度公式|15
如是我闻: 在卷积和池化操作中,计算输出维度的公式是关键,它们分别可以帮助我们计算卷积操作和池化操作后的输出大小。下面分别总结公式,并结合解释它们的意义: 1. 卷积操作的输出维度公式 当我们对输入图像进行卷积时…...
cpp智能指针
普通指针的不足 new和new[]的内存需要用delete和deletel]释放。 程序员的主观失误,忘了或漏了释放。 程序员也不确定何时释放。 普通指针的释放 类内的指针,在析构函数中释放。 C内置数据类型,如何释放? new出来的类,本身如…...
【面试题】 Java 三年工作经验(2025)
问题列表 为什么选择 spring boot 框架,它与 Spring 有什么区别?spring mvc 的执行流程是什么?如何实现 spring 的 IOC 过程,会用到什么技术?spring boot 的自动化配置的原理是什么?如何理解 spring boot 中…...
MOS的体二极管能通多大电流
第一个问题:MOS导通之后电流方向可以使任意的,既可以从D到S,也可以从S到D。 第二个问题:MOS里面的体二极管电流可以达到几百安培,这也就解释了MOS选型的时候很少考虑体二极管的最大电流,而是考虑DS之间电流…...
Node.js下载安装及环境配置教程 (详细版)
Node.js:是一个基于 Chrome V8 引擎的 JavaScript 运行时,用于构建可扩展的网络应用程序。Node.js 使用事件驱动、非阻塞 I/O 模型,使其非常适合构建实时应用程序。 Node.js 提供了一种轻量、高效、可扩展的方式来构建网络应用程序࿰…...
嵌入式MCU面试笔记2
目录 串口通信 概论 原理 配置 HAL库代码 1. 初始化函数 2. 数据发送和接收函数 3. 中断和DMA函数 4. 中断服务函数 串口通信 概论 我们知道,通信桥接了两个设备之间的交流。一个经典的例子就是使用串口通信交换上位机和单片机之间的数据。 比较常见的串…...
代码随想录算法【Day34】
Day34 62.不同路径 思路 第一种:深搜 -> 超时 第二种:动态规划 第三种:数论 动态规划代码如下: class Solution { public:int uniquePaths(int m, int n) {vector<vector<int>> dp(m, vector<int>(n,…...
《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》重印P126、P131勘误
勘误:打圈的地方有指数二字。 指数滤波器本身是错误的概念,我在书上打了一个叉,排版人员误删了。 滤波器部分从根本上有问题,本来要改,但是时间不够了。 和廖老师讨论多次后,决定大动。指数滤波器的概念…...
vim多文件操作如何同屏开多个文件
[rootxxx ~]# vimdiff aa.txt bb.txt cc.txt #带颜色比较的纵向排列打开的同屏多文件操作 示例: [rootxxx ~]# vimdiff -o aa.txt bb.txt cc.txt #带颜色比较的横向排列打开的同屏多文件操作 示例: [rootxxx ~]# vim -O aa.txt bb.txt c…...
day6手机摄影社区,可以去苹果摄影社区学习拍摄技巧
逛自己手机的社区:即(手机牌子)摄影社区 拍照时防止抖动可以控制自己的呼吸,不要大喘气 拍一张照片后,如何简单的用手机修图? HDR模式就是让高光部分和阴影部分更协调(拍风紧时可以打开&…...
渗透测试之WAF规则触发绕过规则之规则库绕过方式
目录 Waf触发规则的绕过 特殊字符替换空格 实例 特殊字符拼接绕过waf Mysql 内置得方法 注释包含关键字 实例 Waf触发规则的绕过 特殊字符替换空格 用一些特殊字符代替空格,比如在mysql中%0a是换行,可以代替空格 这个方法也可以部分绕过最新版本的…...
C语言【基础篇】之流程控制——掌握三大结构的奥秘
流程控制 🚀前言🦜顺序结构💯 定义💯执行规则 🌟选择结构💯if语句💯switch语句💯case穿透规则 🤔循环结构💯for循环💯while循环💯do -…...
c++小知识点
抽象类包含至少一个纯虚函数,不能实例化对象。派生类必须实现基类的所有纯虚函数才能成为非抽象类,从而可以实例化对象。可以使用抽象类的指针或引用指向派生类对象,实现多态性调用。抽象类虽然不能直接实例化,但可以拥有构造函数…...
团体程序设计天梯赛-练习集——L1-022 奇偶分家
前言 这几道题都偏简单一点,没有什么计算,10分 L1-022 奇偶分家 给定N个正整数,请统计奇数和偶数各有多少个? 输入格式: 输入第一行给出一个正整N(≤1000);第2行给出N个非负整数…...
vue项目中,如何获取某一部分的宽高
vue项目中,如何获取某一部分的宽高 在Vue项目中,如果你想要获取某个DOM元素的宽度和高度,可以使用原生的JavaScript方法或者结合Vue的特性来实现。以下是几种常见的方法: 使用ref属性 你可以给需要测量宽高的元素添加一个ref属…...
LeetCode - #195 Swift 实现打印文件中的第十行
网罗开发 (小红书、快手、视频号同名) 大家好,我是 展菲,目前在上市企业从事人工智能项目研发管理工作,平时热衷于分享各种编程领域的软硬技能知识以及前沿技术,包括iOS、前端、Harmony OS、Java、Python等…...
基于服务器使用 apt 安装、配置 Nginx
🧾 一、查看可安装的 Nginx 版本 首先,你可以运行以下命令查看可用版本: apt-cache madison nginx-core输出示例: nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...
2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面
代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口(适配服务端返回 Token) export const login async (code, avatar) > {const res await http…...
WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成
厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...
相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)
【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...
GitHub 趋势日报 (2025年06月08日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 884 cognee 566 dify 414 HumanSystemOptimization 414 omni-tools 321 note-gen …...
关于 WASM:1. WASM 基础原理
一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...
CSS | transition 和 transform的用处和区别
省流总结: transform用于变换/变形,transition是动画控制器 transform 用来对元素进行变形,常见的操作如下,它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...
并发编程 - go版
1.并发编程基础概念 进程和线程 A. 进程是程序在操作系统中的一次执行过程,系统进行资源分配和调度的一个独立单位。B. 线程是进程的一个执行实体,是CPU调度和分派的基本单位,它是比进程更小的能独立运行的基本单位。C.一个进程可以创建和撤销多个线程;同一个进程中…...
学习一下用鸿蒙DevEco Studio HarmonyOS5实现百度地图
在鸿蒙(HarmonyOS5)中集成百度地图,可以通过以下步骤和技术方案实现。结合鸿蒙的分布式能力和百度地图的API,可以构建跨设备的定位、导航和地图展示功能。 1. 鸿蒙环境准备 开发工具:下载安装 De…...
从零开始了解数据采集(二十八)——制造业数字孪生
近年来,我国的工业领域正经历一场前所未有的数字化变革,从“双碳目标”到工业互联网平台的推广,国家政策和市场需求共同推动了制造业的升级。在这场变革中,数字孪生技术成为备受关注的关键工具,它不仅让企业“看见”设…...
