【自制C++深度学习推理框架】Layer的设计思路
Layer的设计思路
Layer的抽象
如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。
基于层的共性,我们定义了一个Layer的基类,提供了一些基本接口,并可以通过继承和多态机制实现不同类型的Layer。
具体来说,该类包括以下几个成员函数:
-
构造函数
Layer(std::string layer_name),用于创建一个Layer对象并设置该层的名称。 -
virtual ~Layer() = default,虚析构函数,在派生类中可以通过override关键字重新定义。 -
virtual InferStatus Forward(const std::vector<std::shared_ptr<Tensor<float>>> &inputs, std::vector<std::shared_ptr<Tensor<float>>> &outputs),前向传播函数,将输入tensor作为参数,计算输出tensor。 -
virtual const std::vector<std::shared_ptr<Tensor<float>>> &weights() const, 返回当前层的权重数组。 -
virtual const std::vector<std::shared_ptr<Tensor<float>>> &bias() const, 返回当前层的偏置数组。 -
virtual void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights),设置当前层的权重数组。 -
virtual void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias),设置当前层的偏置数组。 -
virtual void set_weights(const std::vector<float> &weights),将权重数据类型转换为shared_ptr后调用上述函数。 -
virtual void set_bias(const std::vector<float> &bias),将偏置数据类型转换为shared_ptr后调用上述函数。 -
virtual const std::string &layer_name() const,返回当前层的名称。
而成员变量只有一个,即
std::string layer_name_,Layer的名称
为什么定义成虚函数
在神经网络中,不同的层具有不同的结构和运算方式,因此需要不同的函数来实现它们。使用虚函数的方法可以将这些不同的函数封装到一个基类中,并通过多态机制来实现不同类型的层的动态绑定。
具体来说,当使用基类指针或引用调用虚函数时,程序会根据对象的动态类型(即实际指向的派生类类型)来选择相应的函数实现。这就使得不同类型的层可以通过共同的接口进行调用,从而提高了代码的可维护性和可扩展性。
此外,使用虚函数还可以方便地定义抽象类,即只声明虚函数但不提供实现的类。这可以为子类提供一个规范化的接口,要求其必须重写某些接口以满足特定的需求。这种机制可以有效避免在大型工程中出现微小的差错而导致底层实现不符合最终需求的问题。
带权重Layer的实现
我们把Layer基类来表示不带参数的Layer,并且通过继承该Layer基类的方式来定义了一个带参数的层ParamLayer子类,在ParamLayer中定义了成员变量bias和weights。
ParamLayer是具有可调参数的神经网络层实现,包括初始化权重和偏置的函数、重载读写权重和偏置的函数,以及保存权重和偏置的成员变量。
具体来说,该类包括以下几个成员函数和成员变量:
-
构造函数
ParamLayer(const std::string &layer_name),用于创建一个ParamLayer对象并设置该层的名称。 -
void InitWeightParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化权重参数。 -
void InitBiasParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化偏置参数。 -
const std::vector<std::shared_ptr<Tensor<float>>> &weights() const override,重载虚函数weights(),返回保存权重参数的成员变量weights_。 -
const std::vector<std::shared_ptr<Tensor<float>>> &bias() const override,重载虚函数bias(),返回保存偏置参数的成员变量bias_。 -
void set_weights(const std::vector<float> &weights) override,重载虚函数set_weights(),将权重数据类型转换为shared_ptr后存储在成员变量weights_中。 -
void set_bias(const std::vector<float> &bias) override,重载虚函数set_bias(),将偏置数据类型转换为shared_ptr后存储在成员变量bias_中。 -
void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights) override,重载虚函数set_weights(),将参数复制到成员变量weights_中。 -
void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias) override,重载虚函数set_bias(),将参数复制到成员变量bias_中。 -
成员变量
std::vector<std::shared_ptr<Tensor<float>>> weights_,保存ParamLayer的权重参数。 -
成员变量
std::vector<std::shared_ptr<Tensor<float>>> bias_,保存ParamLayer的偏置参数。
ParamLayer通过继承Layer类实现了一些共同接口,并在此基础上扩展了更多函数和成员,可以方便地实现带有参数的神经网络层。
Layer的注册机制
为了实现注册和创建神经网络层,并在运行时动态地生成不同类型的神经网络层,定义了两个类:LayerRegisterer和LayerRegistererWrapper。
具体来说,LayerRegisterer类提供了三个静态函数和一个静态成员变量:
-
typedef ParseParameterAttrStatus (*Creator)(const std::shared_ptr<RuntimeOperator> &op, std::shared_ptr<Layer> &layer):定义了一个函数指针类型Creator,用于指向具体神经网络层的函数。 -
typedef std::map<std::string, Creator> CreateRegistry:定义了一个映射类型CreateRegistry,用于保存层类型和对应创建函数的映射关系。 -
static void RegisterCreator(const std::string &layer_type, const Creator &creator):将层类型和创建函数的映射关系注册到CreateRegistry中。 -
static std::shared_ptr<Layer> CreateLayer(const std::shared_ptr<RuntimeOperator> &op):根据输入的op对象创建相应的神经网络层。 -
static CreateRegistry &Registry():返回当前已经注册的所有层类型和创建函数的映射关系。
RuntimeOperator是计算图的某个计算节点,里面保存了计算节点所需的参数等信息,具体介绍请看3.Graph.md。
而LayerRegistererWrapper类则提供了一个构造函数,用于将某一种类型的神经网络层和其创建函数注册到LayerRegisterer中,如下所示。
class LayerRegistererWrapper {public:LayerRegistererWrapper(const std::string &layer_type, const LayerRegisterer::Creator &creator) {LayerRegisterer::RegisterCreator(layer_type, creator);}
};
在LayerRegisterer类中,通过维护一个键值对(<std::string, Creator>)CreateRegistry,管理Layer注册表,在注册和查找Layer时都要先检查一下是否注册,如果未注册输出错误信息。
为什么要把成员函数定义为静态的
静态函数与类相关联,而不是与类的对象相关。因此,静态函数可以在没有创建类的实例的情况下调用,从而方便地提供一些辅助函数或管理函数,例如工厂方法、单例等。
LayerRegisterer和LayerRegistererWrapper中定义的所有函数都是静态的,主要原因是这些函数需要全局地维护层类型和创建函数的映射关系,并控制新层类型的注册和创建过程。使用静态函数可以使得这些功能在整个程序中被共享和访问,同时避免了由于对象实例的含糊不清而导致的编码错误。
另外需要注意的是,静态函数可以直接使用静态成员变量,不需要通过对象来访问,这使得这些静态函数可以更容易地协同工作,并兼顾了效率和灵活性。
阅读的代码
- include
- layer
- abstract
- layer_factory.hpp
- layer.hpp
- param_layer.hpp
- abstract
- layer
- source
- layer
- abstract
- layer.cpp
- layer_factory.cpp
- param_layer.cpp
- abstract
- layer
相关文章:
【自制C++深度学习推理框架】Layer的设计思路
Layer的设计思路 Layer的抽象 如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。 基于层的共性,我们定义了一个Layer的基类,提供了一些基本接口,并可以通过继承和多态机制实现不同类型的L…...
Rust每日一练(Leetday0011) 下一排列、有效括号、搜索旋转数组
目录 31. 下一个排列 Next Permutation 🌟🌟 32. 最长有效括号 Longest Valid Parentheses 🌟🌟🌟 33. 搜索旋转排序数组 Search-in-rotated-sorted-array 🌟🌟 🌟 每日一练刷…...
STL --- 五. 函数对象 Function Objects
目录 1、函数对象的定义和作用 2、函数对象的分类和使用 3、std 常用的函数对象 4、函数对象的适配器 5、std 算法和函数对象区别 1、函数对象的定义和作用 STL(Standard Template Library)中的函数对象(Functor)是一种重载…...
Java IO 流操作详解
Java IO 流操作详解 一、简介1. 什么是IO流2. IO流的分类3. IO流的作用 二、Java IO流的输入操作1. 文件输入流2. 字节输入流3. 缓冲输入流4. 对象输入流 三、Java IO流的输出操作1. 文件输出流2. 字节输出流3. 缓冲输出流4. 对象输出流 四、Java IO流的常用方法解析1. 字节读写…...
Halcon 形状匹配参数详解
find_shape_model(Image : : ModelID, AngleStart, AngleExtent, MinScore, NumMatches, MaxOverlap, SubPixel, NumLevels, Greediness : Row, Column, Angle, Score) find_shape_model(Image : : //搜索图像 ModelID, //模板句柄 AngleStart, // 搜索时的起始角度 AngleExte…...
C++11强类型枚举
C11引入了强类型枚举(enum class),也称为枚举类。 强类型枚举是一种更加类型安全的枚举类型,相对于传统的枚举类型,强类型枚举可以提供更好的安全性和可读性。 强类型枚举的格式如下: enum class 枚举名 …...
pytorch讲解(部分)
友爱的目录 自动求导机制从后向中排除子图自动求导如何编码历史信息Variable上的In-place操作In-place正确性检查 CUDA语义最佳实践使用固定的内存缓冲区使用 nn.DataParallel 替代 multiprocessing 扩展PyTorch扩展 torch.autograd扩展 torch.nn 多进程最佳实践共享CUDA张量最…...
C++ 基本的7种数据类型和4种类型转换(C++复习向p3)
文章目录 基本内置类型存储范围typedef 声明新名字enum 枚举类型类型转换 基本内置类型 boolcharintfloatdoublevoidwchar_t ⇒ short int 存储范围 可以这样 sizeof(int) 来确认 int 占用字节数 char,1字节,-128~127 或 0~255 wchar_t,2…...
Scrum敏捷迭代规划和执行
Sprint Backlog看板 迭代工作的开展是围绕Sprint Backlog展开的,在Leangoo中,我们需要为每个迭代创建一个Sprint Backlog看板。Sprint Backlog(迭代)看板,用于管理当前Sprint的需求和开发任务,可视化展示每…...
智警杯赛前学习1.1---excel基本操作
修改默认设置 步骤一:打开“Excel选项”窗口,打开“文件”菜单,选择“选项”标签 步骤二:在“Excel选项”窗口中,选择“常规与保存”标签,在“常规与保存”标签中,可以修改录入数据时的默认字体…...
【Android】Handle(一) 主要特点和用途
在Android中,Handler是一种消息处理机制,它允许我们在不同线程之间交换信息并更新UI。具体来说,Handler可以将一个Runnable或Message对象加入到消息队列中,并在合适的时间去执行它们。 以下是Handler的主要特点和用途:…...
40亿个QQ号,限制1G内存,如何去重?【已通过代码实现】
前几天发现一个有趣的文章 “40亿个QQ号,限制1G内存,如何去重?”,发现很有意思,就想着用代码实现一下,下面是分析和实现过程 一、审题分析 一个 QQ 号现在最长有 11 位,因为 int 是四字节,数值范围是2的31次方,因此得使用 long 存储,但考虑到实现,使用 int 存储(1…...
Talk预告 | 新加坡国立大学张傲:10%成本定制类 GPT-4 多模态大模型
本期为TechBeat人工智能社区第502期线上Talk! 北京时间06月01日(周四)20:00,新加坡国立大学在读博士生 — 张傲的Talk将准时在TechBeat人工智能社区开播! 他与大家分享的主题是: “10%成本定制类 GPT-4 多模态大模型 ”,届时将介…...
从C语言到C++_13(string的模拟实现)深浅拷贝+传统/现代写法
前面已经对 string 类进行了简单的介绍和应用,大家只要能够正常使用即可。 在面试中,面试官总喜欢让学生自己 来模拟实现string类, 最主要是实现string类的构造、拷贝构造、赋值运算符重载以及析构函数。 为了更深入学习STL,下面我…...
reduce()方法详解
一、 定义和用法 reduce() 方法将数组缩减为单个值。 reduce() 方法为数组的每个值(从左到右)执行提供的函数。 函数的返回值存储在累加器中(结果/总计)。 注释:对没有值的数组元素,不执行 reduce() 方法。…...
C++虚假唤醒
概念: 虚假唤醒是指在使用条件变量时,线程被唤醒但条件并没有满足,导致线程执行错误的情况,这个过程就是虚假唤醒。 虚假唤醒弊端: 虚假唤醒会导致程序的正确性受到影响,因为唤醒的线程并没有满足条件&…...
【AI】dragonGPT - 单机部署、极速便捷
dragonGPT 从数据私有化,到prompt向量库匹配,再到查询,一条龙服务,单机部署,极简操作 pre a.需要下载gpt4all model到本地. ggml Model Download Link 然后将存放model的地址写入.env MODEL_PATH your pathb.…...
Uuiapp使用生命周期,路由跳转传参
Uniapp生命周期: 1. beforeCreate:在实例初始化之后,数据观测和事件配置之前被调用。 2. created:在实例创建完成后被立即调用。 3. beforeMount:在挂载开始之前被调用:相关的 render 函数首次被调用。 …...
定积分的计算(牛顿-莱布尼茨公式)习题
前置知识:定积分的计算(牛顿-莱布尼茨公式) 习题1 计算 ∫ 0 2 ( x 2 − 2 x 3 ) d x \int_0^2(x^2-2x3)dx ∫02(x2−2x3)dx 解: \qquad 原式 ( 1 3 x 3 − x 2 3 x ) ∣ 0 2 ( 8 3 − 4 6 ) − 0 14 3 (\dfrac 13x^3-…...
leak 记录今天的一个小题
先看题, add没有大小限制,这里edit可以溢出8字节,也就是可以改后边的size,可以调用4次free没有调用函数只是把指针置0,show可以用一次. void __fastcall __noreturn main(__int64 a1, char **a2, char **a3) {init_0(a1, a2, a3);while ( 1 ){menu();switch ( read_n() ){cas…...
使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式
一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明:假设每台服务器已…...
在软件开发中正确使用MySQL日期时间类型的深度解析
在日常软件开发场景中,时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志,到供应链系统的物流节点时间戳,时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库,其日期时间类型的…...
Ubuntu系统下交叉编译openssl
一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机:Ubuntu 20.04.6 LTSHost:ARM32位交叉编译器:arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...
docker详细操作--未完待续
docker介绍 docker官网: Docker:加速容器应用程序开发 harbor官网:Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台,用于将应用程序及其依赖项(如库、运行时环…...
java_网络服务相关_gateway_nacos_feign区别联系
1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...
MFC内存泄露
1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...
多场景 OkHttpClient 管理器 - Android 网络通信解决方案
下面是一个完整的 Android 实现,展示如何创建和管理多个 OkHttpClient 实例,分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...
【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...
《通信之道——从微积分到 5G》读书总结
第1章 绪 论 1.1 这是一本什么样的书 通信技术,说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号(调制) 把信息从信号中抽取出来&am…...
新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案
随着新能源汽车的快速普及,充电桩作为核心配套设施,其安全性与可靠性备受关注。然而,在高温、高负荷运行环境下,充电桩的散热问题与消防安全隐患日益凸显,成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...
