【自制C++深度学习推理框架】Layer的设计思路
Layer的设计思路
Layer的抽象
如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。
基于层的共性,我们定义了一个Layer的基类,提供了一些基本接口,并可以通过继承和多态机制实现不同类型的Layer。
具体来说,该类包括以下几个成员函数:
-
构造函数
Layer(std::string layer_name),用于创建一个Layer对象并设置该层的名称。 -
virtual ~Layer() = default,虚析构函数,在派生类中可以通过override关键字重新定义。 -
virtual InferStatus Forward(const std::vector<std::shared_ptr<Tensor<float>>> &inputs, std::vector<std::shared_ptr<Tensor<float>>> &outputs),前向传播函数,将输入tensor作为参数,计算输出tensor。 -
virtual const std::vector<std::shared_ptr<Tensor<float>>> &weights() const, 返回当前层的权重数组。 -
virtual const std::vector<std::shared_ptr<Tensor<float>>> &bias() const, 返回当前层的偏置数组。 -
virtual void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights),设置当前层的权重数组。 -
virtual void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias),设置当前层的偏置数组。 -
virtual void set_weights(const std::vector<float> &weights),将权重数据类型转换为shared_ptr后调用上述函数。 -
virtual void set_bias(const std::vector<float> &bias),将偏置数据类型转换为shared_ptr后调用上述函数。 -
virtual const std::string &layer_name() const,返回当前层的名称。
而成员变量只有一个,即
std::string layer_name_,Layer的名称
为什么定义成虚函数
在神经网络中,不同的层具有不同的结构和运算方式,因此需要不同的函数来实现它们。使用虚函数的方法可以将这些不同的函数封装到一个基类中,并通过多态机制来实现不同类型的层的动态绑定。
具体来说,当使用基类指针或引用调用虚函数时,程序会根据对象的动态类型(即实际指向的派生类类型)来选择相应的函数实现。这就使得不同类型的层可以通过共同的接口进行调用,从而提高了代码的可维护性和可扩展性。
此外,使用虚函数还可以方便地定义抽象类,即只声明虚函数但不提供实现的类。这可以为子类提供一个规范化的接口,要求其必须重写某些接口以满足特定的需求。这种机制可以有效避免在大型工程中出现微小的差错而导致底层实现不符合最终需求的问题。
带权重Layer的实现
我们把Layer基类来表示不带参数的Layer,并且通过继承该Layer基类的方式来定义了一个带参数的层ParamLayer子类,在ParamLayer中定义了成员变量bias和weights。
ParamLayer是具有可调参数的神经网络层实现,包括初始化权重和偏置的函数、重载读写权重和偏置的函数,以及保存权重和偏置的成员变量。
具体来说,该类包括以下几个成员函数和成员变量:
-
构造函数
ParamLayer(const std::string &layer_name),用于创建一个ParamLayer对象并设置该层的名称。 -
void InitWeightParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化权重参数。 -
void InitBiasParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化偏置参数。 -
const std::vector<std::shared_ptr<Tensor<float>>> &weights() const override,重载虚函数weights(),返回保存权重参数的成员变量weights_。 -
const std::vector<std::shared_ptr<Tensor<float>>> &bias() const override,重载虚函数bias(),返回保存偏置参数的成员变量bias_。 -
void set_weights(const std::vector<float> &weights) override,重载虚函数set_weights(),将权重数据类型转换为shared_ptr后存储在成员变量weights_中。 -
void set_bias(const std::vector<float> &bias) override,重载虚函数set_bias(),将偏置数据类型转换为shared_ptr后存储在成员变量bias_中。 -
void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights) override,重载虚函数set_weights(),将参数复制到成员变量weights_中。 -
void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias) override,重载虚函数set_bias(),将参数复制到成员变量bias_中。 -
成员变量
std::vector<std::shared_ptr<Tensor<float>>> weights_,保存ParamLayer的权重参数。 -
成员变量
std::vector<std::shared_ptr<Tensor<float>>> bias_,保存ParamLayer的偏置参数。
ParamLayer通过继承Layer类实现了一些共同接口,并在此基础上扩展了更多函数和成员,可以方便地实现带有参数的神经网络层。
Layer的注册机制
为了实现注册和创建神经网络层,并在运行时动态地生成不同类型的神经网络层,定义了两个类:LayerRegisterer和LayerRegistererWrapper。
具体来说,LayerRegisterer类提供了三个静态函数和一个静态成员变量:
-
typedef ParseParameterAttrStatus (*Creator)(const std::shared_ptr<RuntimeOperator> &op, std::shared_ptr<Layer> &layer):定义了一个函数指针类型Creator,用于指向具体神经网络层的函数。 -
typedef std::map<std::string, Creator> CreateRegistry:定义了一个映射类型CreateRegistry,用于保存层类型和对应创建函数的映射关系。 -
static void RegisterCreator(const std::string &layer_type, const Creator &creator):将层类型和创建函数的映射关系注册到CreateRegistry中。 -
static std::shared_ptr<Layer> CreateLayer(const std::shared_ptr<RuntimeOperator> &op):根据输入的op对象创建相应的神经网络层。 -
static CreateRegistry &Registry():返回当前已经注册的所有层类型和创建函数的映射关系。
RuntimeOperator是计算图的某个计算节点,里面保存了计算节点所需的参数等信息,具体介绍请看3.Graph.md。
而LayerRegistererWrapper类则提供了一个构造函数,用于将某一种类型的神经网络层和其创建函数注册到LayerRegisterer中,如下所示。
class LayerRegistererWrapper {public:LayerRegistererWrapper(const std::string &layer_type, const LayerRegisterer::Creator &creator) {LayerRegisterer::RegisterCreator(layer_type, creator);}
};
在LayerRegisterer类中,通过维护一个键值对(<std::string, Creator>)CreateRegistry,管理Layer注册表,在注册和查找Layer时都要先检查一下是否注册,如果未注册输出错误信息。
为什么要把成员函数定义为静态的
静态函数与类相关联,而不是与类的对象相关。因此,静态函数可以在没有创建类的实例的情况下调用,从而方便地提供一些辅助函数或管理函数,例如工厂方法、单例等。
LayerRegisterer和LayerRegistererWrapper中定义的所有函数都是静态的,主要原因是这些函数需要全局地维护层类型和创建函数的映射关系,并控制新层类型的注册和创建过程。使用静态函数可以使得这些功能在整个程序中被共享和访问,同时避免了由于对象实例的含糊不清而导致的编码错误。
另外需要注意的是,静态函数可以直接使用静态成员变量,不需要通过对象来访问,这使得这些静态函数可以更容易地协同工作,并兼顾了效率和灵活性。
阅读的代码
- include
- layer
- abstract
- layer_factory.hpp
- layer.hpp
- param_layer.hpp
- abstract
- layer
- source
- layer
- abstract
- layer.cpp
- layer_factory.cpp
- param_layer.cpp
- abstract
- layer
相关文章:
【自制C++深度学习推理框架】Layer的设计思路
Layer的设计思路 Layer的抽象 如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。 基于层的共性,我们定义了一个Layer的基类,提供了一些基本接口,并可以通过继承和多态机制实现不同类型的L…...
Rust每日一练(Leetday0011) 下一排列、有效括号、搜索旋转数组
目录 31. 下一个排列 Next Permutation 🌟🌟 32. 最长有效括号 Longest Valid Parentheses 🌟🌟🌟 33. 搜索旋转排序数组 Search-in-rotated-sorted-array 🌟🌟 🌟 每日一练刷…...
STL --- 五. 函数对象 Function Objects
目录 1、函数对象的定义和作用 2、函数对象的分类和使用 3、std 常用的函数对象 4、函数对象的适配器 5、std 算法和函数对象区别 1、函数对象的定义和作用 STL(Standard Template Library)中的函数对象(Functor)是一种重载…...
Java IO 流操作详解
Java IO 流操作详解 一、简介1. 什么是IO流2. IO流的分类3. IO流的作用 二、Java IO流的输入操作1. 文件输入流2. 字节输入流3. 缓冲输入流4. 对象输入流 三、Java IO流的输出操作1. 文件输出流2. 字节输出流3. 缓冲输出流4. 对象输出流 四、Java IO流的常用方法解析1. 字节读写…...
Halcon 形状匹配参数详解
find_shape_model(Image : : ModelID, AngleStart, AngleExtent, MinScore, NumMatches, MaxOverlap, SubPixel, NumLevels, Greediness : Row, Column, Angle, Score) find_shape_model(Image : : //搜索图像 ModelID, //模板句柄 AngleStart, // 搜索时的起始角度 AngleExte…...
C++11强类型枚举
C11引入了强类型枚举(enum class),也称为枚举类。 强类型枚举是一种更加类型安全的枚举类型,相对于传统的枚举类型,强类型枚举可以提供更好的安全性和可读性。 强类型枚举的格式如下: enum class 枚举名 …...
pytorch讲解(部分)
友爱的目录 自动求导机制从后向中排除子图自动求导如何编码历史信息Variable上的In-place操作In-place正确性检查 CUDA语义最佳实践使用固定的内存缓冲区使用 nn.DataParallel 替代 multiprocessing 扩展PyTorch扩展 torch.autograd扩展 torch.nn 多进程最佳实践共享CUDA张量最…...
C++ 基本的7种数据类型和4种类型转换(C++复习向p3)
文章目录 基本内置类型存储范围typedef 声明新名字enum 枚举类型类型转换 基本内置类型 boolcharintfloatdoublevoidwchar_t ⇒ short int 存储范围 可以这样 sizeof(int) 来确认 int 占用字节数 char,1字节,-128~127 或 0~255 wchar_t,2…...
Scrum敏捷迭代规划和执行
Sprint Backlog看板 迭代工作的开展是围绕Sprint Backlog展开的,在Leangoo中,我们需要为每个迭代创建一个Sprint Backlog看板。Sprint Backlog(迭代)看板,用于管理当前Sprint的需求和开发任务,可视化展示每…...
智警杯赛前学习1.1---excel基本操作
修改默认设置 步骤一:打开“Excel选项”窗口,打开“文件”菜单,选择“选项”标签 步骤二:在“Excel选项”窗口中,选择“常规与保存”标签,在“常规与保存”标签中,可以修改录入数据时的默认字体…...
【Android】Handle(一) 主要特点和用途
在Android中,Handler是一种消息处理机制,它允许我们在不同线程之间交换信息并更新UI。具体来说,Handler可以将一个Runnable或Message对象加入到消息队列中,并在合适的时间去执行它们。 以下是Handler的主要特点和用途:…...
40亿个QQ号,限制1G内存,如何去重?【已通过代码实现】
前几天发现一个有趣的文章 “40亿个QQ号,限制1G内存,如何去重?”,发现很有意思,就想着用代码实现一下,下面是分析和实现过程 一、审题分析 一个 QQ 号现在最长有 11 位,因为 int 是四字节,数值范围是2的31次方,因此得使用 long 存储,但考虑到实现,使用 int 存储(1…...
Talk预告 | 新加坡国立大学张傲:10%成本定制类 GPT-4 多模态大模型
本期为TechBeat人工智能社区第502期线上Talk! 北京时间06月01日(周四)20:00,新加坡国立大学在读博士生 — 张傲的Talk将准时在TechBeat人工智能社区开播! 他与大家分享的主题是: “10%成本定制类 GPT-4 多模态大模型 ”,届时将介…...
从C语言到C++_13(string的模拟实现)深浅拷贝+传统/现代写法
前面已经对 string 类进行了简单的介绍和应用,大家只要能够正常使用即可。 在面试中,面试官总喜欢让学生自己 来模拟实现string类, 最主要是实现string类的构造、拷贝构造、赋值运算符重载以及析构函数。 为了更深入学习STL,下面我…...
reduce()方法详解
一、 定义和用法 reduce() 方法将数组缩减为单个值。 reduce() 方法为数组的每个值(从左到右)执行提供的函数。 函数的返回值存储在累加器中(结果/总计)。 注释:对没有值的数组元素,不执行 reduce() 方法。…...
C++虚假唤醒
概念: 虚假唤醒是指在使用条件变量时,线程被唤醒但条件并没有满足,导致线程执行错误的情况,这个过程就是虚假唤醒。 虚假唤醒弊端: 虚假唤醒会导致程序的正确性受到影响,因为唤醒的线程并没有满足条件&…...
【AI】dragonGPT - 单机部署、极速便捷
dragonGPT 从数据私有化,到prompt向量库匹配,再到查询,一条龙服务,单机部署,极简操作 pre a.需要下载gpt4all model到本地. ggml Model Download Link 然后将存放model的地址写入.env MODEL_PATH your pathb.…...
Uuiapp使用生命周期,路由跳转传参
Uniapp生命周期: 1. beforeCreate:在实例初始化之后,数据观测和事件配置之前被调用。 2. created:在实例创建完成后被立即调用。 3. beforeMount:在挂载开始之前被调用:相关的 render 函数首次被调用。 …...
定积分的计算(牛顿-莱布尼茨公式)习题
前置知识:定积分的计算(牛顿-莱布尼茨公式) 习题1 计算 ∫ 0 2 ( x 2 − 2 x 3 ) d x \int_0^2(x^2-2x3)dx ∫02(x2−2x3)dx 解: \qquad 原式 ( 1 3 x 3 − x 2 3 x ) ∣ 0 2 ( 8 3 − 4 6 ) − 0 14 3 (\dfrac 13x^3-…...
leak 记录今天的一个小题
先看题, add没有大小限制,这里edit可以溢出8字节,也就是可以改后边的size,可以调用4次free没有调用函数只是把指针置0,show可以用一次. void __fastcall __noreturn main(__int64 a1, char **a2, char **a3) {init_0(a1, a2, a3);while ( 1 ){menu();switch ( read_n() ){cas…...
利用ngx_stream_return_module构建简易 TCP/UDP 响应网关
一、模块概述 ngx_stream_return_module 提供了一个极简的指令: return <value>;在收到客户端连接后,立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量(如 $time_iso8601、$remote_addr 等)&a…...
java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别
UnsatisfiedLinkError 在对接硬件设备中,我们会遇到使用 java 调用 dll文件 的情况,此时大概率出现UnsatisfiedLinkError链接错误,原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用,结果 dll 未实现 JNI 协…...
Python实现prophet 理论及参数优化
文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候,写过一篇简单实现,后期随着对该模型的深入研究,本次记录涉及到prophet 的公式以及参数调优,从公式可以更直观…...
CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云
目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...
回溯算法学习
一、电话号码的字母组合 import java.util.ArrayList; import java.util.List;import javax.management.loading.PrivateClassLoader;public class letterCombinations {private static final String[] KEYPAD {"", //0"", //1"abc", //2"…...
【分享】推荐一些办公小工具
1、PDF 在线转换 https://smallpdf.com/cn/pdf-tools 推荐理由:大部分的转换软件需要收费,要么功能不齐全,而开会员又用不了几次浪费钱,借用别人的又不安全。 这个网站它不需要登录或下载安装。而且提供的免费功能就能满足日常…...
使用LangGraph和LangSmith构建多智能体人工智能系统
现在,通过组合几个较小的子智能体来创建一个强大的人工智能智能体正成为一种趋势。但这也带来了一些挑战,比如减少幻觉、管理对话流程、在测试期间留意智能体的工作方式、允许人工介入以及评估其性能。你需要进行大量的反复试验。 在这篇博客〔原作者&a…...
CSS | transition 和 transform的用处和区别
省流总结: transform用于变换/变形,transition是动画控制器 transform 用来对元素进行变形,常见的操作如下,它是立即生效的样式变形属性。 旋转 rotate(角度deg)、平移 translateX(像素px)、缩放 scale(倍数)、倾斜 skewX(角度…...
快刀集(1): 一刀斩断视频片头广告
一刀流:用一个简单脚本,秒杀视频片头广告,还你清爽观影体验。 1. 引子 作为一个爱生活、爱学习、爱收藏高清资源的老码农,平时写代码之余看看电影、补补片,是再正常不过的事。 电影嘛,要沉浸,…...
解读《网络安全法》最新修订,把握网络安全新趋势
《网络安全法》自2017年施行以来,在维护网络空间安全方面发挥了重要作用。但随着网络环境的日益复杂,网络攻击、数据泄露等事件频发,现行法律已难以完全适应新的风险挑战。 2025年3月28日,国家网信办会同相关部门起草了《网络安全…...
