当前位置: 首页 > news >正文

【自制C++深度学习推理框架】Layer的设计思路

Layer的设计思路

Layer的抽象

如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。

基于层的共性,我们定义了一个Layer的基类,提供了一些基本接口,并可以通过继承和多态机制实现不同类型的Layer。

具体来说,该类包括以下几个成员函数:

  1. 构造函数 Layer(std::string layer_name),用于创建一个Layer对象并设置该层的名称。

  2. virtual ~Layer() = default,虚析构函数,在派生类中可以通过override关键字重新定义。

  3. virtual InferStatus Forward(const std::vector<std::shared_ptr<Tensor<float>>> &inputs, std::vector<std::shared_ptr<Tensor<float>>> &outputs) ,前向传播函数,将输入tensor作为参数,计算输出tensor。

  4. virtual const std::vector<std::shared_ptr<Tensor<float>>> &weights() const, 返回当前层的权重数组。

  5. virtual const std::vector<std::shared_ptr<Tensor<float>>> &bias() const, 返回当前层的偏置数组。

  6. virtual void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights),设置当前层的权重数组。

  7. virtual void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias),设置当前层的偏置数组。

  8. virtual void set_weights(const std::vector<float> &weights),将权重数据类型转换为shared_ptr后调用上述函数。

  9. virtual void set_bias(const std::vector<float> &bias),将偏置数据类型转换为shared_ptr后调用上述函数。

  10. virtual const std::string &layer_name() const,返回当前层的名称。

而成员变量只有一个,即

  • std::string layer_name_,Layer的名称

为什么定义成虚函数

在神经网络中,不同的层具有不同的结构和运算方式,因此需要不同的函数来实现它们。使用虚函数的方法可以将这些不同的函数封装到一个基类中,并通过多态机制来实现不同类型的层的动态绑定。

具体来说,当使用基类指针或引用调用虚函数时,程序会根据对象的动态类型(即实际指向的派生类类型)来选择相应的函数实现。这就使得不同类型的层可以通过共同的接口进行调用,从而提高了代码的可维护性和可扩展性。

此外,使用虚函数还可以方便地定义抽象类,即只声明虚函数但不提供实现的类。这可以为子类提供一个规范化的接口,要求其必须重写某些接口以满足特定的需求。这种机制可以有效避免在大型工程中出现微小的差错而导致底层实现不符合最终需求的问题。

带权重Layer的实现

我们把Layer基类来表示不带参数的Layer,并且通过继承该Layer基类的方式来定义了一个带参数的层ParamLayer子类,在ParamLayer中定义了成员变量bias和weights。

ParamLayer是具有可调参数的神经网络层实现,包括初始化权重和偏置的函数、重载读写权重和偏置的函数,以及保存权重和偏置的成员变量。

具体来说,该类包括以下几个成员函数和成员变量:

  1. 构造函数 ParamLayer(const std::string &layer_name),用于创建一个ParamLayer对象并设置该层的名称。

  2. void InitWeightParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化权重参数。

  3. void InitBiasParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width),用于初始化偏置参数。

  4. const std::vector<std::shared_ptr<Tensor<float>>> &weights() const override,重载虚函数weights(),返回保存权重参数的成员变量weights_

  5. const std::vector<std::shared_ptr<Tensor<float>>> &bias() const override,重载虚函数bias(),返回保存偏置参数的成员变量bias_

  6. void set_weights(const std::vector<float> &weights) override,重载虚函数set_weights(),将权重数据类型转换为shared_ptr后存储在成员变量weights_中。

  7. void set_bias(const std::vector<float> &bias) override,重载虚函数set_bias(),将偏置数据类型转换为shared_ptr后存储在成员变量bias_中。

  8. void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights) override,重载虚函数set_weights(),将参数复制到成员变量weights_中。

  9. void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias) override,重载虚函数set_bias(),将参数复制到成员变量bias_中。

  10. 成员变量std::vector<std::shared_ptr<Tensor<float>>> weights_,保存ParamLayer的权重参数。

  11. 成员变量std::vector<std::shared_ptr<Tensor<float>>> bias_,保存ParamLayer的偏置参数。

ParamLayer通过继承Layer类实现了一些共同接口,并在此基础上扩展了更多函数和成员,可以方便地实现带有参数的神经网络层。

Layer的注册机制

为了实现注册和创建神经网络层,并在运行时动态地生成不同类型的神经网络层,定义了两个类:LayerRegisterer和LayerRegistererWrapper。

具体来说,LayerRegisterer类提供了三个静态函数和一个静态成员变量:

  1. typedef ParseParameterAttrStatus (*Creator)(const std::shared_ptr<RuntimeOperator> &op, std::shared_ptr<Layer> &layer):定义了一个函数指针类型Creator,用于指向具体神经网络层的函数。

  2. typedef std::map<std::string, Creator> CreateRegistry:定义了一个映射类型CreateRegistry,用于保存层类型和对应创建函数的映射关系。

  3. static void RegisterCreator(const std::string &layer_type, const Creator &creator):将层类型和创建函数的映射关系注册到CreateRegistry中。

  4. static std::shared_ptr<Layer> CreateLayer(const std::shared_ptr<RuntimeOperator> &op):根据输入的op对象创建相应的神经网络层。

  5. static CreateRegistry &Registry():返回当前已经注册的所有层类型和创建函数的映射关系。

RuntimeOperator是计算图的某个计算节点,里面保存了计算节点所需的参数等信息,具体介绍请看3.Graph.md

而LayerRegistererWrapper类则提供了一个构造函数,用于将某一种类型的神经网络层和其创建函数注册到LayerRegisterer中,如下所示。

class LayerRegistererWrapper {public:LayerRegistererWrapper(const std::string &layer_type, const LayerRegisterer::Creator &creator) {LayerRegisterer::RegisterCreator(layer_type, creator);}
};

在LayerRegisterer类中,通过维护一个键值对(<std::string, Creator>CreateRegistry,管理Layer注册表,在注册和查找Layer时都要先检查一下是否注册,如果未注册输出错误信息。

为什么要把成员函数定义为静态的

静态函数与类相关联,而不是与类的对象相关。因此,静态函数可以在没有创建类的实例的情况下调用,从而方便地提供一些辅助函数或管理函数,例如工厂方法、单例等。

LayerRegisterer和LayerRegistererWrapper中定义的所有函数都是静态的,主要原因是这些函数需要全局地维护层类型和创建函数的映射关系,并控制新层类型的注册和创建过程。使用静态函数可以使得这些功能在整个程序中被共享和访问,同时避免了由于对象实例的含糊不清而导致的编码错误。

另外需要注意的是,静态函数可以直接使用静态成员变量,不需要通过对象来访问,这使得这些静态函数可以更容易地协同工作,并兼顾了效率和灵活性。

阅读的代码

  • include
    • layer
      • abstract
        • layer_factory.hpp
        • layer.hpp
        • param_layer.hpp
  • source
    • layer
      • abstract
        • layer.cpp
        • layer_factory.cpp
        • param_layer.cpp

相关文章:

【自制C++深度学习推理框架】Layer的设计思路

Layer的设计思路 Layer的抽象 如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。 基于层的共性&#xff0c;我们定义了一个Layer的基类&#xff0c;提供了一些基本接口&#xff0c;并可以通过继承和多态机制实现不同类型的L…...

Rust每日一练(Leetday0011) 下一排列、有效括号、搜索旋转数组

目录 31. 下一个排列 Next Permutation &#x1f31f;&#x1f31f; 32. 最长有效括号 Longest Valid Parentheses &#x1f31f;&#x1f31f;&#x1f31f; 33. 搜索旋转排序数组 Search-in-rotated-sorted-array &#x1f31f;&#x1f31f; &#x1f31f; 每日一练刷…...

STL --- 五. 函数对象 Function Objects

目录 1、函数对象的定义和作用 2、函数对象的分类和使用 3、std 常用的函数对象 4、函数对象的适配器 5、std 算法和函数对象区别 1、函数对象的定义和作用 STL&#xff08;Standard Template Library&#xff09;中的函数对象&#xff08;Functor&#xff09;是一种重载…...

Java IO 流操作详解

Java IO 流操作详解 一、简介1. 什么是IO流2. IO流的分类3. IO流的作用 二、Java IO流的输入操作1. 文件输入流2. 字节输入流3. 缓冲输入流4. 对象输入流 三、Java IO流的输出操作1. 文件输出流2. 字节输出流3. 缓冲输出流4. 对象输出流 四、Java IO流的常用方法解析1. 字节读写…...

Halcon 形状匹配参数详解

find_shape_model(Image : : ModelID, AngleStart, AngleExtent, MinScore, NumMatches, MaxOverlap, SubPixel, NumLevels, Greediness : Row, Column, Angle, Score) find_shape_model(Image : : //搜索图像 ModelID, //模板句柄 AngleStart, // 搜索时的起始角度 AngleExte…...

C++11强类型枚举

C11引入了强类型枚举&#xff08;enum class&#xff09;&#xff0c;也称为枚举类。 强类型枚举是一种更加类型安全的枚举类型&#xff0c;相对于传统的枚举类型&#xff0c;强类型枚举可以提供更好的安全性和可读性。 强类型枚举的格式如下&#xff1a; enum class 枚举名 …...

pytorch讲解(部分)

友爱的目录 自动求导机制从后向中排除子图自动求导如何编码历史信息Variable上的In-place操作In-place正确性检查 CUDA语义最佳实践使用固定的内存缓冲区使用 nn.DataParallel 替代 multiprocessing 扩展PyTorch扩展 torch.autograd扩展 torch.nn 多进程最佳实践共享CUDA张量最…...

C++ 基本的7种数据类型和4种类型转换(C++复习向p3)

文章目录 基本内置类型存储范围typedef 声明新名字enum 枚举类型类型转换 基本内置类型 boolcharintfloatdoublevoidwchar_t ⇒ short int 存储范围 可以这样 sizeof(int) 来确认 int 占用字节数 char&#xff0c;1字节&#xff0c;-128~127 或 0~255 wchar_t&#xff0c;2…...

Scrum敏捷迭代规划和执行

Sprint Backlog看板 迭代工作的开展是围绕Sprint Backlog展开的&#xff0c;在Leangoo中&#xff0c;我们需要为每个迭代创建一个Sprint Backlog看板。Sprint Backlog&#xff08;迭代&#xff09;看板&#xff0c;用于管理当前Sprint的需求和开发任务&#xff0c;可视化展示每…...

智警杯赛前学习1.1---excel基本操作

修改默认设置 步骤一&#xff1a;打开“Excel选项”窗口&#xff0c;打开“文件”菜单&#xff0c;选择“选项”标签 步骤二&#xff1a;在“Excel选项”窗口中&#xff0c;选择“常规与保存”标签&#xff0c;在“常规与保存”标签中&#xff0c;可以修改录入数据时的默认字体…...

【Android】Handle(一) 主要特点和用途

在Android中&#xff0c;Handler是一种消息处理机制&#xff0c;它允许我们在不同线程之间交换信息并更新UI。具体来说&#xff0c;Handler可以将一个Runnable或Message对象加入到消息队列中&#xff0c;并在合适的时间去执行它们。 以下是Handler的主要特点和用途&#xff1a…...

40亿个QQ号,限制1G内存,如何去重?【已通过代码实现】

前几天发现一个有趣的文章 “40亿个QQ号,限制1G内存,如何去重?”,发现很有意思,就想着用代码实现一下,下面是分析和实现过程 一、审题分析 一个 QQ 号现在最长有 11 位,因为 int 是四字节,数值范围是2的31次方,因此得使用 long 存储,但考虑到实现,使用 int 存储(1…...

Talk预告 | 新加坡国立大学张傲:10%成本定制类 GPT-4 多模态大模型

本期为TechBeat人工智能社区第502期线上Talk&#xff01; 北京时间06月01日(周四)20:00&#xff0c;新加坡国立大学在读博士生 — 张傲的Talk将准时在TechBeat人工智能社区开播&#xff01; 他与大家分享的主题是: “10%成本定制类 GPT-4 多模态大模型 ”&#xff0c;届时将介…...

从C语言到C++_13(string的模拟实现)深浅拷贝+传统/现代写法

前面已经对 string 类进行了简单的介绍和应用&#xff0c;大家只要能够正常使用即可。 在面试中&#xff0c;面试官总喜欢让学生自己 来模拟实现string类&#xff0c; 最主要是实现string类的构造、拷贝构造、赋值运算符重载以及析构函数。 为了更深入学习STL&#xff0c;下面我…...

reduce()方法详解

一、 定义和用法 reduce() 方法将数组缩减为单个值。 reduce() 方法为数组的每个值&#xff08;从左到右&#xff09;执行提供的函数。 函数的返回值存储在累加器中&#xff08;结果/总计&#xff09;。 注释&#xff1a;对没有值的数组元素&#xff0c;不执行 reduce() 方法。…...

C++虚假唤醒

概念&#xff1a; 虚假唤醒是指在使用条件变量时&#xff0c;线程被唤醒但条件并没有满足&#xff0c;导致线程执行错误的情况&#xff0c;这个过程就是虚假唤醒。 虚假唤醒弊端&#xff1a; 虚假唤醒会导致程序的正确性受到影响&#xff0c;因为唤醒的线程并没有满足条件&…...

【AI】dragonGPT - 单机部署、极速便捷

dragonGPT 从数据私有化&#xff0c;到prompt向量库匹配&#xff0c;再到查询&#xff0c;一条龙服务&#xff0c;单机部署&#xff0c;极简操作 pre a.需要下载gpt4all model到本地. ggml Model Download Link 然后将存放model的地址写入.env MODEL_PATH your pathb.…...

Uuiapp使用生命周期,路由跳转传参

Uniapp生命周期&#xff1a; 1. beforeCreate&#xff1a;在实例初始化之后&#xff0c;数据观测和事件配置之前被调用。 2. created&#xff1a;在实例创建完成后被立即调用。 3. beforeMount&#xff1a;在挂载开始之前被调用&#xff1a;相关的 render 函数首次被调用。 …...

定积分的计算(牛顿-莱布尼茨公式)习题

前置知识&#xff1a;定积分的计算&#xff08;牛顿-莱布尼茨公式&#xff09; 习题1 计算 ∫ 0 2 ( x 2 − 2 x 3 ) d x \int_0^2(x^2-2x3)dx ∫02​(x2−2x3)dx 解&#xff1a; \qquad 原式 ( 1 3 x 3 − x 2 3 x ) ∣ 0 2 ( 8 3 − 4 6 ) − 0 14 3 (\dfrac 13x^3-…...

leak 记录今天的一个小题

先看题, add没有大小限制,这里edit可以溢出8字节,也就是可以改后边的size,可以调用4次free没有调用函数只是把指针置0,show可以用一次. void __fastcall __noreturn main(__int64 a1, char **a2, char **a3) {init_0(a1, a2, a3);while ( 1 ){menu();switch ( read_n() ){cas…...

软考A计划-试题模拟含答案解析-卷二

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&am…...

【C++】pthread

一、pthread简介 pthread是C98接口且只支持Linux&#xff0c;使用时需要包含头文件#include <pthread.h>&#xff0c;编译时需要链接pthread库&#xff0c;其中p是POSIX的缩写&#xff0c;而POSIX是Portable Operating System Interface的缩写&#xff0c;是IEEE为要在各…...

2023年前端面试题汇总-浏览器原理

1. 浏览器安全 1.1. 什么是 XSS 攻击&#xff1f; 1.1. 1. 概念 XSS 攻击指的是跨站脚本攻击&#xff0c;是一种代码注入攻击。攻击者通过在网站注入恶意脚本&#xff0c;使之在用户的浏览器上运行&#xff0c;从而盗取用户的信息如 cookie 等。 XSS 的本质是因为网站没有对…...

react介绍,react语法,react高级特性,react编程技巧

React是一个用于构建用户界面的JavaScript库。它由Facebook开发&#xff0c;于2013年首次发布。React的主要目标是提高应用程序的性能和可维护性。React采用了一种称为“组件”的模式&#xff0c;使开发人员可以将应用程序拆分为小而独立的部分&#xff0c;从而更容易编写和维护…...

Locust接口性能测试

谈到性能测试工具&#xff0c;我们首先想到的是LoadRunner或JMeter。LoadRunner是非常有名的商业性能测试工具&#xff0c;功能非常强大。但现在一般不推荐使用该工具来进行性能测试&#xff0c;主要是使用也较为复杂&#xff0c;而且该工具体积比较大&#xff0c;需要付费且价…...

Python类的特殊方法(通过故事来学习)

在一座森林里&#xff0c;住着三只动物&#xff1a;狼、兔和熊。这三只动物都有不同的特点和能力&#xff0c;但是它们所有的行为都可以被抽象成一个“动物”类。现在&#xff0c;让我们来看看Python中的类和特殊方法如何帮助我们实现这个故事。 首先&#xff0c;我们可以定义…...

Vue.js 中的父子组件通信方式

Vue.js 中的父子组件通信方式 在 Vue.js 中&#xff0c;组件是构建应用程序的基本单元。当我们在应用程序中使用组件时&#xff0c;组件之间的通信是非常重要的。在 Vue.js 中&#xff0c;父子组件通信是最常见的组件通信方式之一。在本文中&#xff0c;我们将讨论 Vue.js 中的…...

Python之并发编程二多进程理论

一、什么是进程 进程&#xff1a;正在进行的一个过程或者说一个任务。而负责执行任务则是cpu。 二、进程与程序的区别 程序仅仅只是一堆代码而已&#xff0c;而进程指的是程序的运行过程。 三、并发与并行 无论是并行还是并发&#xff0c;在用户看来都是’同时’运行的&am…...

纯干货:数据库连接耗时慢原因排查

背景 最近公司的社区相关的服务需要优化&#xff0c;由于对业务不熟悉&#xff0c;只能借助监控从一些慢接口开始尝试探索慢的原因。由于社区相关的功能务是公司小程序流量入口&#xff0c;所以相应的服务访问量还是比较高的。针对这类高访问的项目&#xff0c;任何不留神的地…...

【OneNet】| stm32+esp8266-01s—— OneNet初体验 | 平台注册及设备创建 | demo使用

系列文章目录 失败了也挺可爱&#xff0c;成功了就超帅。 文章目录 前言1. OneNet平台注册2. 创建多协议接入设备3. 硬件连接4. 下载并运行Demo4.1 Demo下载4.2 运行Demo本小节结束 前言 最近准备耍下 Onenet平台 。下载了官方demo 遇到几个问题 1、创建接入设备 因为平台网页…...