【现代深度学习技术】深度学习计算 | 延后初始化自定义层

【作者主页】Francek Chen
【专栏介绍】 ⌈ ⌈ ⌈PyTorch深度学习 ⌋ ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重要的技术特征是具有自动提取特征的能力。神经网络算法、算力和数据是开展深度学习的三要素。深度学习在计算机视觉、自然语言处理、多模态数据分析、科学探索等领域都取得了很多成果。本专栏介绍基于PyTorch的深度学习算法实现。
【GitCode】专栏资源保存在我的GitCode仓库:https://gitcode.com/Morse_Chen/PyTorch_deep_learning。
文章目录
- 一、延后初始化
- 实例化网络
- 二、自定义层
- (一)不带参数的层
- (二)带参数的层
- 小结
一、延后初始化
到目前为止,我们忽略了建立网络时需要做的以下这些事情:
- 我们定义了网络架构,但没有指定输入维度。
- 我们添加层时没有指定前一层的输出维度。
- 我们在初始化参数时,甚至没有足够的信息来确定模型应该包含多少参数。
有些读者可能会对我们的代码能运行感到惊讶。 毕竟,深度学习框架无法判断网络的输入维度是什么。 这里的诀窍是框架的延后初始化(defers initialization), 即直到数据第一次通过模型传递时,框架才会动态地推断出每个层的大小。
在以后,当使用卷积神经网络时, 由于输入维度(即图像的分辨率)将影响每个后续层的维数, 有了该技术将更加方便。 现在我们在编写代码时无须知道维度是什么就可以设置参数, 这种能力可以大大简化定义和修改模型的任务。 接下来,我们将更深入地研究初始化机制。
实例化网络
首先,让我们实例化一个多层感知机。此时,因为输入维数是未知的,所以网络不可能知道输入层权重的维数。 因此,框架尚未初始化任何参数,我们通过尝试访问以下参数进行确认。
接下来让我们将数据通过网络,最终使框架初始化参数。
一旦我们知道输入维数是20,框架可以通过代入值20来识别第一层权重矩阵的形状。 识别出第一层的形状后,框架处理第二层,依此类推,直到所有形状都已知为止。 注意,在这种情况下,只有第一层需要延迟初始化,但是框架仍是按顺序初始化的。 等到知道了所有的参数形状,框架就可以初始化参数。
二、自定义层
深度学习成功背后的一个因素是神经网络的灵活性:我们可以用创造性的方式组合不同的层,从而设计出适用于各种任务的架构。例如,研究人员发明了专门用于处理图像、文本、序列数据和执行动态规划的层。有时我们会遇到或要自己发明一个现在在深度学习框架中还不存在的层。在这些情况下,必须构建自定义层。本节将展示如何构建自定义层。
(一)不带参数的层
首先,我们构造一个没有任何参数的自定义层。回忆一下在【现代深度学习技术】深度学习计算 | 层和块 对块的介绍,这应该看起来很眼熟。下面的CenteredLayer类要从其输入中减去均值。要构建它,我们只需继承基础层类并实现前向传播功能。
import torch
import torch.nn.functional as F
from torch import nnclass CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):return X - X.mean()
让我们向该层提供一些数据,验证它是否能按预期工作。
layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))

现在,我们可以将层作为组件合并到更复杂的模型中。
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
作为额外的健全性检查,我们可以在向该网络发送随机数据后,检查均值是否为0。由于我们处理的是浮点数,因为存储精度的原因,我们仍然可能会看到一个非常小的非零数。
Y = net(torch.rand(4, 8))
Y.mean()

(二)带参数的层
以上我们知道了如何定义简单的层,下面我们继续定义具有参数的层,这些参数可以通过训练进行调整。我们可以使用内置函数来创建参数,这些函数提供一些基本的管理功能。比如管理访问、初始化、共享、保存和加载模型参数。这样做的好处之一是:我们不需要为每个自定义层编写自定义的序列化程序。
现在,让我们实现自定义版本的全连接层。回想一下,该层需要两个参数,一个用于表示权重,另一个用于表示偏置项。在此实现中,我们使用修正线性单元作为激活函数。该层需要输入参数:in_units和units,分别表示输入数和输出数。
class MyLinear(nn.Module):def __init__(self, in_units, units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, units))self.bias = nn.Parameter(torch.randn(units,))def forward(self, X):linear = torch.matmul(X, self.weight.data) + self.bias.datareturn F.relu(linear)
接下来,我们实例化MyLinear类并访问其模型参数。
linear = MyLinear(5, 3)
linear.weight

我们可以使用自定义层直接执行前向传播计算。
linear(torch.rand(2, 5))

我们还可以使用自定义层构建模型,就像使用内置的全连接层一样使用自定义层。
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
net(torch.rand(2, 64))

小结
- 延后初始化使框架能够自动推断参数形状,使修改模型架构变得容易,避免了一些常见的错误。
- 我们可以通过模型传递数据,使框架最终初始化参数。
- 我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层,其行为与深度学习框架中的任何现有层不同。
- 在自定义层定义完成后,我们就可以在任意环境和网络架构中调用该自定义层。
- 层可以有局部参数,这些参数可以通过内置函数创建。
相关文章:
【现代深度学习技术】深度学习计算 | 延后初始化自定义层
【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈PyTorch深度学习 ⌋ ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重…...
LeetCode 3105. Longest Strictly Increasing or Strictly Decreasing Subarray
🔗 https://leetcode.com/problems/longest-strictly-increasing-or-strictly-decreasing-subarray 题目 给一个数组,返回其最长严格升序或者降序的子数组长度 思路 模拟 代码 class Solution { public:int longestMonotonicSubarray(vector<in…...
Java导出Excel简单工具类
一、maven配置 <!--jxl--><dependency><groupId>net.sourceforge.jexcelapi</groupId><artifactId>jxl</artifactId><version>2.6.12</version></dependency>二、工具类方法 package util2;import jxl.Workbook; impor…...
蓝桥与力扣刷题(141 环形链表)
题目:给你一个链表的头节点 head ,判断链表中是否有环。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统内部使用整数 pos 来表示链表尾连接到链表中的…...
【小鱼闪闪】做一个物联网控制小灯的制作流程简要介绍(图文)
1、注册物联网云平台,这里选用巴法云 2.、新建主题 “ledtest” 3、 使用Arduino或Mixly软件编写单片机程序(需要引用巴法云库文件),程序中订阅“ledtest”主题,用于接收单片机发送来的数据。此处会将连接的温度传感器…...
图论常见算法
图论常见算法 算法prim算法Dijkstra算法 用途最小生成树(MST):最短路径:拓扑排序:关键路径: 算法用途适用条件时间复杂度Kruskal最小生成树无向图(稀疏图)O(E log E)Prim最小生成树无…...
实战技巧:如何快速提高网站收录的权威性?
本文转自:百万收录网 原文链接:https://www.baiwanshoulu.com/68.html 快速提高网站收录的权威性是一个系统性的工作,涉及内容质量、网站结构、外部链接、用户体验等多个方面。以下是一些实战技巧,可以帮助你快速提升网站收录的权…...
BUU16 [ACTF2020 新生赛]BackupFile1
扫到index.php.bak 实在扫不出来可以试试一些常有的文件,比如flag.php(flag.php.bak),index.php(index.php.bak) <?php include_once "flag.php";if(isset($_GET[key])) {$key $_GET[key…...
js --- 获取随机数
介绍 使用js获取随机数 代码 Math.random()...
运维之MySQL锁机制(MySQL Lock Mechanism for Operation and Maintenance)
运维之MySQL锁机制 锁是一种常见的并发事务的控制方式。MySQL数据库中的锁机制主要用于控制对数据的并发访问,防止多个用户或进程同时对同一数据进行读写操作,从而避免数据不一致和丢失更新等问题。锁机制确保数据的一致性,保证在多个事务操作…...
用Python实现SVM分类器:从数据到决策边界可视化,以鸢尾花数据集为例
前言 在机器学习的世界里,支持向量机(Support Vector Machine,简称SVM)是一种非常强大的分类算法。它通过寻找最优的决策边界,将不同类别的数据分开。本文将通过一个简单的Python代码示例,展示如何使用SVM…...
pytorch使用SVM实现文本分类
人工智能例子汇总:AI常见的算法和例子-CSDN博客 完整代码: import torch import torch.nn as nn import torch.optim as optim import jieba import numpy as np from sklearn.model_selection import train_test_split from sklearn.feature_extract…...
一文速览DeepSeek-R1的本地部署——可联网、可实现本地知识库问答:包括671B满血版和各个蒸馏版的部署
前言 自从deepseek R1发布之后「详见《一文速览DeepSeek R1:如何通过纯RL训练大模型的推理能力以比肩甚至超越OpenAI o1(含Kimi K1.5的解读)》」,deepseek便爆火 爆火以后便应了“人红是非多”那句话,不但遭受各种大规模攻击,即便…...
Kubernetes学习之包管理工具(Helm)
一、基础知识 1.如果我们需要开发微服务架构的应用,组成应用的服务可能很多,使用原始的组织和管理方式就会非常臃肿和繁琐以及较难管理,此时我们需要一个更高层次的工具将这些配置组织起来。 2.helm架构: chart:一个应用的信息集合…...
2024美团春招硬件开发笔试真题及答案解析
目录 一、选择题 1、在 Linux,有一个名为 file 的文件,内容如下所示: 2、在 Linux 中,关于虚拟内存相关的说法正确的是() 3、AT89S52单片机中,在外部中断响应的期间,中断请求标志位查询占用了()。 4、下列关于8051单片机的结构与功能,说法不正确的是()? 5、…...
MyBatis-Plus速成指南:通用枚举 多数据源
通用枚举: 概述: 表中有些字段值是固定的,例如性别(男或女),此时我们可以使用 MyBatis-Plus 的通用枚举来实现 数据库表添加字段: 创建通用枚举类型: Getter public enum SexEnum {MALE(1, "男"…...
Android项目中使用Eclipse导出jar文件
2014年3月24日 天气晴朗 关于打包Android组件肯定是有用到的,比如开发了一个模块,为了更好的复用,我们可能会将它打包成jar文件方便其他项目引用。这个很好理解,也很简单。网上有一堆关于用Eclipse将Android项目打包成jar文件的&…...
网络安全学习 day4
防火墙的安全策略 规则--策略 条件 --- 检查报文的依据,防火墙将报文中携带的信息与条件逐一进行对比, 以此来判断报文是否是 匹配的 。不同的匹配条件之间属于 “ 与 ” 关系;相同的匹配条件中不同的参数信息之间的关系为 “ 或 ” 关系。…...
【SSM】Spring + SpringMVC + Mybatis
SSM课程,以下为该课程的笔记 bean:IOC容器创建的对象 P12 bean的生命周期 在bean中定义init()和destroy()方法,然后在xml中配置方法名,让bean对象能找到对应的生命周期方法。 或通过实现接口的方式定义声明周期方法。 P13 sett…...
智慧园区综合管理系统如何实现多个维度的高效管理与安全风险控制
内容概要 在当前快速发展的城市环境中,智慧园区综合管理系统正在成为各类园区管理的重要工具,无论是工业园、产业园、物流园,还是写字楼与公寓,都在积极寻求如何提升管理效率和保障安全。通过快鲸智慧园区管理系统,用…...
【协议详解】卫星通信5G IoT NTN SIB33-NB 信令详解
一、SIB33信令概述 在5G非地面网络(NTN)中,卫星的高速移动性和广域覆盖特性使得地面设备(UE)需要频繁切换卫星以维持连接。SIB32提供了UE预测当前服务的卫星覆盖信息,SystemInformationBlockType33&#x…...
《LLM大语言模型深度探索与实践:构建智能应用的新范式,融合代理与数据库的高级整合》
文章目录 Langchain的定义Langchain的组成三个核心组件实现整个核心组成部分 为什么要使用LangchainLangchain的底层原理Langchain实战操作LangSmithLangChain调用LLM安装openAI库-国内镜像源代码运行结果小结 使用Langchain的提示模板部署Langchain程序安装langserve代码请求格…...
Debian 10 中 Linux 4.19 内核在 x86_64 架构上对中断嵌套的支持情况
一、中断嵌套的定义与原理 中断嵌套是指在一个中断处理程序(ISR)正在执行的过程中,另一个更高优先级的中断请求到来,系统暂停当前中断处理程序,转而处理新的高优先级中断。处理完高优先级中断后,系统返回到原来的中断处理程序继续执行。这种机制允许系统更高效地响应紧急…...
【Envi遥感图像处理】010:归一化植被指数NDVI计算方法
文章目录 一、NDVI简介二、NDVI计算方法1. NDVI工具2. 波段运算三、注意事项1. 计算结果为一片黑2. 计算结果超出范围一、NDVI简介 归一化植被指数,是反映农作物长势和营养信息的重要参数之一,应用于遥感影像。NDVI是通过植被在近红外波段(NIR)和红光波段(R)的反射率差异…...
优选算法合集————双指针(专题二)
好久都没给大家带来算法专题啦,今天给大家带来滑动窗口专题的训练 题目一:长度最小的子数组 题目描述: 给定一个含有 n 个正整数的数组和一个正整数 target 。 找出该数组中满足其和 ≥ target 的长度最小的 连续子数组 [numsl, numsl1, …...
基于微信小程序的私家车位共享系统设计与实现(LW+源码+讲解)
专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…...
糖化之前,为什么要进行麦芽粉碎?
糖化的目的是将麦芽中的淀粉转化为可发酵性的糖分,而糖化之前,进行麦芽粉碎是确保糖化效果的关键步骤。本文天泰将阐述麦芽粉碎的重要性及其对酿造过程的影响。 一、麦芽粉碎的目的 增加酶的作用面积:麦芽中的淀粉和蛋白质等物质需要通过酶…...
PAT甲级1052、Linked LIst Sorting
题目 A linked list consists of a series of structures, which are not necessarily adjacent in memory. We assume that each structure contains an integer key and a Next pointer to the next structure. Now given a linked list, you are supposed to sort the stru…...
半导体器件与物理篇6 MESFET
金属-半导体接触 MESFET与MOSFET的相同点:它们的电压电流特性相似。都有源漏栅三极,强反型,漏极加正向电压,也会经历线性区、夹断点、饱和区三个阶段。 MESFET与MOSFET的不同点:在器件的栅电极部分,MESFE…...
BES2700源码解析之系统初始化
一 概述 bes2700凭借着超高的性能,超低的功耗,在可穿戴领域有着广泛的应用。笔者使用该芯片做了一些产品解决方案,发现该芯片的性能十分强大。这里做个系列的源码解析。 二 源码解析 1.GPIO和led灯的初始化: tgt_hardware_setup(…...
