当前位置: 首页 > news >正文

【自制C++深度学习推理框架】Layer的设计思路

Layer的设计思路

Layer的抽象

如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。

基于层的共性,我们定义了一个Layer的基类,提供了一些基本接口,并可以通过继承和多态机制实现不同类型的Layer。

具体来说,该类包括以下几个成员函数:

  1. 构造函数 Layer(std::string layer_name),用于创建一个Layer对象并设置该层的名称。

  2. virtual ~Layer() = default,虚析构函数,在派生类中可以通过override关键字重新定义。

  3. virtual InferStatus Forward(const std::vector<std::shared_ptr<Tensor<float>>> &inputs, std::vector<std::shared_ptr<Tensor<float>>> &outputs) ,前向传播函数,将输入tensor作为参数,计算输出tensor。

  4. virtual const std::vector<std::shared_ptr<Tensor<float>>> &weights() const, 返回当前层的权重数组。

  5. virtual const std::vector<std::shared_ptr<Tensor<float>>> &bias() const, 返回当前层的偏置数组。

  6. virtual void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights),设置当前层的权重数组。

  7. virtual void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias),设置当前层的偏置数组。

  8. virtual void set_weights(const std::vector<float> &weights),将权重数据类型转换为shared_ptr后调用上述函数。

  9. virtual void set_bias(const std::vector<float> &bias),将偏置数据类型转换为shared_ptr后调用上述函数。

  10. virtual const std::string &layer_name() const,返回当前层的名称。

而成员变量只有一个,即

  • std::string layer_name_,Layer的名称

为什么定义成虚函数

在神经网络中,不同的层具有不同的结构和运算方式,因此需要不同的函数来实现它们。使用虚函数的方法可以将这些不同的函数封装到一个基类中,并通过多态机制来实现不同类型的层的动态绑定。

具体来说,当使用基类指针或引用调用虚函数时,程序会根据对象的动态类型(即实际指向的派生类类型)来选择相应的函数实现。这就使得不同类型的层可以通过共同的接口进行调用,从而提高了代码的可维护性和可扩展性。

此外,使用虚函数还可以方便地定义抽象类,即只声明虚函数但不提供实现的类。这可以为子类提供一个规范化的接口,要求其必须重写某些接口以满足特定的需求。这种机制可以有效避免在大型工程中出现微小的差错而导致底层实现不符合最终需求的问题。

带权重Layer的实现

我们把Layer基类来表示不带参数的Layer,并且通过继承该Layer基类的方式来定义了一个带参数的层ParamLayer子类,在ParamLayer中定义了成员变量bias和weights。

ParamLayer是具有可调参数的神经网络层实现,包括初始化权重和偏置的函数、重载读写权重和偏置的函数,以及保存权重和偏置的成员变量。

具体来说,该类包括以下几个成员函数和成员变量:

  1. 构造函数 ParamLayer(const std::string &layer_name),用于创建一个ParamLayer对象并设置该层的名称。

  2. void InitWeightParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化权重参数。

  3. void InitBiasParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化偏置参数。

  4. const std::vector<std::shared_ptr<Tensor<float>>> &weights() const override,重载虚函数weights(),返回保存权重参数的成员变量weights_

  5. const std::vector<std::shared_ptr<Tensor<float>>> &bias() const override,重载虚函数bias(),返回保存偏置参数的成员变量bias_

  6. void set_weights(const std::vector<float> &weights) override,重载虚函数set_weights(),将权重数据类型转换为shared_ptr后存储在成员变量weights_中。

  7. void set_bias(const std::vector<float> &bias) override,重载虚函数set_bias(),将偏置数据类型转换为shared_ptr后存储在成员变量bias_中。

  8. void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights) override,重载虚函数set_weights(),将参数复制到成员变量weights_中。

  9. void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias) override,重载虚函数set_bias(),将参数复制到成员变量bias_中。

  10. 成员变量std::vector<std::shared_ptr<Tensor<float>>> weights_,保存ParamLayer的权重参数。

  11. 成员变量std::vector<std::shared_ptr<Tensor<float>>> bias_,保存ParamLayer的偏置参数。

ParamLayer通过继承Layer类实现了一些共同接口,并在此基础上扩展了更多函数和成员,可以方便地实现带有参数的神经网络层。

Layer的注册机制

为了实现注册和创建神经网络层,并在运行时动态地生成不同类型的神经网络层,定义了两个类:LayerRegisterer和LayerRegistererWrapper。

具体来说,LayerRegisterer类提供了三个静态函数和一个静态成员变量:

  1. typedef ParseParameterAttrStatus (*Creator)(const std::shared_ptr<RuntimeOperator> &op, std::shared_ptr<Layer> &layer):定义了一个函数指针类型Creator,用于指向具体神经网络层的函数。

  2. typedef std::map<std::string, Creator> CreateRegistry:定义了一个映射类型CreateRegistry,用于保存层类型和对应创建函数的映射关系。

  3. static void RegisterCreator(const std::string &layer_type, const Creator &creator):将层类型和创建函数的映射关系注册到CreateRegistry中。

  4. static std::shared_ptr<Layer> CreateLayer(const std::shared_ptr<RuntimeOperator> &op):根据输入的op对象创建相应的神经网络层。

  5. static CreateRegistry &Registry():返回当前已经注册的所有层类型和创建函数的映射关系。

RuntimeOperator是计算图的某个计算节点,里面保存了计算节点所需的参数等信息,具体介绍请看3.Graph.md

而LayerRegistererWrapper类则提供了一个构造函数,用于将某一种类型的神经网络层和其创建函数注册到LayerRegisterer中,如下所示。

class LayerRegistererWrapper {public:LayerRegistererWrapper(const std::string &layer_type, const LayerRegisterer::Creator &creator) {LayerRegisterer::RegisterCreator(layer_type, creator);}
};

在LayerRegisterer类中,通过维护一个键值对(<std::string, Creator>CreateRegistry,管理Layer注册表,在注册和查找Layer时都要先检查一下是否注册,如果未注册输出错误信息。

为什么要把成员函数定义为静态的

静态函数与类相关联,而不是与类的对象相关。因此,静态函数可以在没有创建类的实例的情况下调用,从而方便地提供一些辅助函数或管理函数,例如工厂方法、单例等。

LayerRegisterer和LayerRegistererWrapper中定义的所有函数都是静态的,主要原因是这些函数需要全局地维护层类型和创建函数的映射关系,并控制新层类型的注册和创建过程。使用静态函数可以使得这些功能在整个程序中被共享和访问,同时避免了由于对象实例的含糊不清而导致的编码错误。

另外需要注意的是,静态函数可以直接使用静态成员变量,不需要通过对象来访问,这使得这些静态函数可以更容易地协同工作,并兼顾了效率和灵活性。

阅读的代码

  • include
    • layer
      • abstract
        • layer_factory.hpp
        • layer.hpp
        • param_layer.hpp
  • source
    • layer
      • abstract
        • layer.cpp
        • layer_factory.cpp
        • param_layer.cpp

相关文章:

【自制C++深度学习推理框架】Layer的设计思路

Layer的设计思路 Layer的抽象 如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。 基于层的共性&#xff0c;我们定义了一个Layer的基类&#xff0c;提供了一些基本接口&#xff0c;并可以通过继承和多态机制实现不同类型的L…...

Rust每日一练(Leetday0011) 下一排列、有效括号、搜索旋转数组

目录 31. 下一个排列 Next Permutation &#x1f31f;&#x1f31f; 32. 最长有效括号 Longest Valid Parentheses &#x1f31f;&#x1f31f;&#x1f31f; 33. 搜索旋转排序数组 Search-in-rotated-sorted-array &#x1f31f;&#x1f31f; &#x1f31f; 每日一练刷…...

STL --- 五. 函数对象 Function Objects

目录 1、函数对象的定义和作用 2、函数对象的分类和使用 3、std 常用的函数对象 4、函数对象的适配器 5、std 算法和函数对象区别 1、函数对象的定义和作用 STL&#xff08;Standard Template Library&#xff09;中的函数对象&#xff08;Functor&#xff09;是一种重载…...

Java IO 流操作详解

Java IO 流操作详解 一、简介1. 什么是IO流2. IO流的分类3. IO流的作用 二、Java IO流的输入操作1. 文件输入流2. 字节输入流3. 缓冲输入流4. 对象输入流 三、Java IO流的输出操作1. 文件输出流2. 字节输出流3. 缓冲输出流4. 对象输出流 四、Java IO流的常用方法解析1. 字节读写…...

Halcon 形状匹配参数详解

find_shape_model(Image : : ModelID, AngleStart, AngleExtent, MinScore, NumMatches, MaxOverlap, SubPixel, NumLevels, Greediness : Row, Column, Angle, Score) find_shape_model(Image : : //搜索图像 ModelID, //模板句柄 AngleStart, // 搜索时的起始角度 AngleExte…...

C++11强类型枚举

C11引入了强类型枚举&#xff08;enum class&#xff09;&#xff0c;也称为枚举类。 强类型枚举是一种更加类型安全的枚举类型&#xff0c;相对于传统的枚举类型&#xff0c;强类型枚举可以提供更好的安全性和可读性。 强类型枚举的格式如下&#xff1a; enum class 枚举名 …...

pytorch讲解(部分)

友爱的目录 自动求导机制从后向中排除子图自动求导如何编码历史信息Variable上的In-place操作In-place正确性检查 CUDA语义最佳实践使用固定的内存缓冲区使用 nn.DataParallel 替代 multiprocessing 扩展PyTorch扩展 torch.autograd扩展 torch.nn 多进程最佳实践共享CUDA张量最…...

C++ 基本的7种数据类型和4种类型转换(C++复习向p3)

文章目录 基本内置类型存储范围typedef 声明新名字enum 枚举类型类型转换 基本内置类型 boolcharintfloatdoublevoidwchar_t ⇒ short int 存储范围 可以这样 sizeof(int) 来确认 int 占用字节数 char&#xff0c;1字节&#xff0c;-128~127 或 0~255 wchar_t&#xff0c;2…...

Scrum敏捷迭代规划和执行

Sprint Backlog看板 迭代工作的开展是围绕Sprint Backlog展开的&#xff0c;在Leangoo中&#xff0c;我们需要为每个迭代创建一个Sprint Backlog看板。Sprint Backlog&#xff08;迭代&#xff09;看板&#xff0c;用于管理当前Sprint的需求和开发任务&#xff0c;可视化展示每…...

智警杯赛前学习1.1---excel基本操作

修改默认设置 步骤一&#xff1a;打开“Excel选项”窗口&#xff0c;打开“文件”菜单&#xff0c;选择“选项”标签 步骤二&#xff1a;在“Excel选项”窗口中&#xff0c;选择“常规与保存”标签&#xff0c;在“常规与保存”标签中&#xff0c;可以修改录入数据时的默认字体…...

【Android】Handle(一) 主要特点和用途

在Android中&#xff0c;Handler是一种消息处理机制&#xff0c;它允许我们在不同线程之间交换信息并更新UI。具体来说&#xff0c;Handler可以将一个Runnable或Message对象加入到消息队列中&#xff0c;并在合适的时间去执行它们。 以下是Handler的主要特点和用途&#xff1a…...

40亿个QQ号,限制1G内存,如何去重?【已通过代码实现】

前几天发现一个有趣的文章 “40亿个QQ号,限制1G内存,如何去重?”,发现很有意思,就想着用代码实现一下,下面是分析和实现过程 一、审题分析 一个 QQ 号现在最长有 11 位,因为 int 是四字节,数值范围是2的31次方,因此得使用 long 存储,但考虑到实现,使用 int 存储(1…...

Talk预告 | 新加坡国立大学张傲:10%成本定制类 GPT-4 多模态大模型

本期为TechBeat人工智能社区第502期线上Talk&#xff01; 北京时间06月01日(周四)20:00&#xff0c;新加坡国立大学在读博士生 — 张傲的Talk将准时在TechBeat人工智能社区开播&#xff01; 他与大家分享的主题是: “10%成本定制类 GPT-4 多模态大模型 ”&#xff0c;届时将介…...

从C语言到C++_13(string的模拟实现)深浅拷贝+传统/现代写法

前面已经对 string 类进行了简单的介绍和应用&#xff0c;大家只要能够正常使用即可。 在面试中&#xff0c;面试官总喜欢让学生自己 来模拟实现string类&#xff0c; 最主要是实现string类的构造、拷贝构造、赋值运算符重载以及析构函数。 为了更深入学习STL&#xff0c;下面我…...

reduce()方法详解

一、 定义和用法 reduce() 方法将数组缩减为单个值。 reduce() 方法为数组的每个值&#xff08;从左到右&#xff09;执行提供的函数。 函数的返回值存储在累加器中&#xff08;结果/总计&#xff09;。 注释&#xff1a;对没有值的数组元素&#xff0c;不执行 reduce() 方法。…...

C++虚假唤醒

概念&#xff1a; 虚假唤醒是指在使用条件变量时&#xff0c;线程被唤醒但条件并没有满足&#xff0c;导致线程执行错误的情况&#xff0c;这个过程就是虚假唤醒。 虚假唤醒弊端&#xff1a; 虚假唤醒会导致程序的正确性受到影响&#xff0c;因为唤醒的线程并没有满足条件&…...

【AI】dragonGPT - 单机部署、极速便捷

dragonGPT 从数据私有化&#xff0c;到prompt向量库匹配&#xff0c;再到查询&#xff0c;一条龙服务&#xff0c;单机部署&#xff0c;极简操作 pre a.需要下载gpt4all model到本地. ggml Model Download Link 然后将存放model的地址写入.env MODEL_PATH your pathb.…...

Uuiapp使用生命周期,路由跳转传参

Uniapp生命周期&#xff1a; 1. beforeCreate&#xff1a;在实例初始化之后&#xff0c;数据观测和事件配置之前被调用。 2. created&#xff1a;在实例创建完成后被立即调用。 3. beforeMount&#xff1a;在挂载开始之前被调用&#xff1a;相关的 render 函数首次被调用。 …...

定积分的计算(牛顿-莱布尼茨公式)习题

前置知识&#xff1a;定积分的计算&#xff08;牛顿-莱布尼茨公式&#xff09; 习题1 计算 ∫ 0 2 ( x 2 − 2 x 3 ) d x \int_0^2(x^2-2x3)dx ∫02​(x2−2x3)dx 解&#xff1a; \qquad 原式 ( 1 3 x 3 − x 2 3 x ) ∣ 0 2 ( 8 3 − 4 6 ) − 0 14 3 (\dfrac 13x^3-…...

leak 记录今天的一个小题

先看题, add没有大小限制,这里edit可以溢出8字节,也就是可以改后边的size,可以调用4次free没有调用函数只是把指针置0,show可以用一次. void __fastcall __noreturn main(__int64 a1, char **a2, char **a3) {init_0(a1, a2, a3);while ( 1 ){menu();switch ( read_n() ){cas…...

wordpress后台更新后 前端没变化的解决方法

使用siteground主机的wordpress网站&#xff0c;会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后&#xff0c;网站没有变化的情况。 不熟悉siteground主机的新手&#xff0c;遇到这个问题&#xff0c;就很抓狂&#xff0c;明明是哪都没操作错误&#x…...

谷歌浏览器插件

项目中有时候会用到插件 sync-cookie-extension1.0.0&#xff1a;开发环境同步测试 cookie 至 localhost&#xff0c;便于本地请求服务携带 cookie 参考地址&#xff1a;https://juejin.cn/post/7139354571712757767 里面有源码下载下来&#xff0c;加在到扩展即可使用FeHelp…...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成&#xff0c;核心是利用 HTTP 协议的 Range 请求头指定下载范围&#xff1a; 实现原理 Range 请求头&#xff1a;向服务器请求文件的特定字节范围&#xff08;如 Range: bytes1024-&#xff09; 本地文件记录&#xff1a;保存已…...

sqlserver 根据指定字符 解析拼接字符串

DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

招商蛇口 | 执笔CID,启幕低密生活新境

作为中国城市生长的力量&#xff0c;招商蛇口以“美好生活承载者”为使命&#xff0c;深耕全球111座城市&#xff0c;以央企担当匠造时代理想人居。从深圳湾的开拓基因到西安高新CID的战略落子&#xff0c;招商蛇口始终与城市发展同频共振&#xff0c;以建筑诠释对土地与生活的…...

如何应对敏捷转型中的团队阻力

应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中&#xff0c;明确沟通敏捷转型目的尤为关键&#xff0c;团队成员只有清晰理解转型背后的原因和利益&#xff0c;才能降低对变化的…...

「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案

在移动互联网营销竞争白热化的当下&#xff0c;推客小程序系统凭借其裂变传播、精准营销等特性&#xff0c;成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径&#xff0c;助力开发者打造具有市场竞争力的营销工具。​ 一、系统核心功能架构&…...

Linux部署私有文件管理系统MinIO

最近需要用到一个文件管理服务&#xff0c;但是又不想花钱&#xff0c;所以就想着自己搭建一个&#xff0c;刚好我们用的一个开源框架已经集成了MinIO&#xff0c;所以就选了这个 我这边对文件服务性能要求不是太高&#xff0c;单机版就可以 安装非常简单&#xff0c;几个命令就…...

用鸿蒙HarmonyOS5实现中国象棋小游戏的过程

下面是一个基于鸿蒙OS (HarmonyOS) 的中国象棋小游戏的实现代码。这个实现使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chinesechess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├──…...