当前位置: 首页 > news >正文

去掉乘法运算的加法移位神经网络架构

[CVPR 2020] AdderNet: Do We Really Need Multiplications in Deep Learning?

代码:https://github.com/huawei-noah/AdderNet/tree/master

核心贡献

  • 用filter与input feature之间的L1-范数距离作为“卷积层”的输出
  • 为了提升模型性能,提出全精度梯度的反向传播方法
  • 根据不同层的梯度级数,提出自适应学习率策略

研究动机

  • 加法远小于乘法的计算开销,L1-距离(加法)对硬件非常友好
  • BNN效率高,但是性能难以保证,同时训练不稳定,收敛慢
  • 几乎没有工作尝试用其他更高效的仅包含加法的相似性度量函数来取代卷积

传统卷积
在这里插入图片描述

其中, S S S是相似度(距离)衡量指标,如果定义为内积,则是传统卷积算法。

AdderNet
用L1-距离作为距离衡量指标:
在这里插入图片描述

从而,计算中不存在任何乘法计算。Adder层的输出都是负的,所以网络中引入batch normalization(BN)层和激活函数层。注意BN层虽然有乘法,但是其开销相比于卷积可以忽略不计。

为什么可以将卷积替换为加法?作者的解释是第一个公式类似于图像匹配领域,在这个领域中 S S S可以被替换为不同的函数,因此在卷积神经网络中把内积换成L1-距离也是很自然的想法。

优化方法
传统卷积的梯度:
在这里插入图片描述

signSGD梯度:
在这里插入图片描述

其中,sgn是符号函数。但是,signSGD几乎没有采取最陡的下降方向,随着维度的增长,下降方向只会变得更糟,所以不适用于大参数量的模型优化。

于是本文提出通过利用全精度梯度,精确地更新filter:
在这里插入图片描述

在形式上就是去掉了signSGD的sgn函数。

为了避免梯度爆炸的问题,提出将梯度裁剪到[-1, 1]范围内:
在这里插入图片描述
在这里插入图片描述

自适应学习率
传统CNN的输出方差:
在这里插入图片描述

AdderNet的输出方差:
在这里插入图片描述

CNN中filter的方差非常小,所以Y的方差很小;而AdderNet中Y的方差则非常大。

计算损失函数对x的梯度:
在这里插入图片描述

这个梯度的级数应该很小,本文对不同层weight梯度的L2-norm值进行了统计:
在这里插入图片描述

发现AdderNet的梯度确实相比于CNN非常小,这会严重减慢filter更新的过程。

一种最直接的思路就是采用更大的学习率,本文发现不同层的梯度值差异很大,所以为了考虑不同层的filter情况,提出了不同层的自适应学习率。

在这里插入图片描述

其中, γ \gamma γ是全局学习率, ∆ L ( F l ) ∆L(F_l) L(Fl)是第 l l l层filter梯度, α l \alpha_l αl是对应层的本地学习率。

在这里插入图片描述

k k k F l F_l Fl中元素的数量, η \eta η是超参数。于是,不同adder层中的filter可以用几乎相同的step进行更新。

训练算法流程
感觉没有什么特别需要注意的地方。
在这里插入图片描述

主要实验结果
在这里插入图片描述

在这里插入图片描述

可以看到,AdderNet在三个CNN模型上都掉点很少,并且省去了所以乘法,也没有BNN中的XNOR操作,只是有了更多的加法,效率应该显著提高。

核心代码
Adder层:

X_col = torch.nn.functional.unfold(X.view(1, -1, h_x, w_x), h_filter, dilation=1, padding=padding, stride=stride).view(n_x, -1, h_out*w_out)
X_col = X_col.permute(1,2,0).contiguous().view(X_col.size(1),-1)
W_col = W.view(n_filters, -1)output = -(W_col.unsqueeze(2)-X_col.unsqueeze(0)).abs().sum(1)

反向传播优化:

grad_W_col = ((X_col.unsqueeze(0)-W_col.unsqueeze(2))*grad_output.unsqueeze(1)).sum(2)
grad_W_col = grad_W_col/grad_W_col.norm(p=2).clamp(min=1e-12)*math.sqrt(W_col.size(1)*W_col.size(0))/5
grad_X_col = (-(X_col.unsqueeze(0)-W_col.unsqueeze(2)).clamp(-1,1)*grad_output.unsqueeze(1)).sum(0)

[NeurIPS 2020] ShiftAddNet: A Hardware-Inspired Deep Network

代码:https://github.com/GATECH-EIC/ShiftAddNet

主要贡献

  • 受到硬件设计的启发,提出bit-shift和add操作,ShiftAddNet具有完全表达能力和超高效率
  • 设计训练推理算法,利用这两个操作的不同的粒度级别,研究ShiftAddNet在训练效率和精度之间的权衡,例如,冻结所有的位移层

研究动机

  • Shift和add比乘法更高效
  • Add层学习的小粒度特征,shift层被认为可以提取大粒度特征提取

ShiftAddNet结构设计
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

反向传播优化
Add层的梯度计算
在这里插入图片描述

Shift层的梯度计算
在这里插入图片描述

冻结shift层
冻结ShiftAddNet中的shift层意味着 s , p s, p s,p在初始化后一样,然后进一步剪枝冻结的shift层以保留必要的大粒度anchor weight。

[NeurIPS 2023] ShiftAddViT: Mixture of Multiplication Primitives Towards Efficient Vision Transformers

代码:https://github.com/GATECH-EIC/ShiftAddViT

核心贡献

  • 用混合互补的乘法原语(shift和add)来重参数化预训练ViT(无需从头训练),得到“乘法降低”网络ShiftAddViT。Attention中所有乘法都被add kernel重参数化,剩下的线性层和MLP被shift kernel重参数化
  • 提出混合专家框架(MoE)维持重参数化后的ViT,其中每个专家都代表一个乘法或它的原语,比如移位。根据给定输入token的重要性,会激活合适的专家,例如,对重要token用乘法,并对不那么重要的token用移位
  • 在MoE中引入延迟感知和负载均衡的损失函数,动态地分配输入token给每个专家,这确保了分配的token数量与专家的处理速度相一致,显著减少了同步时间

研究动机

  • 乘法可以被替换为shift和add
  • 如果重参数化ViT?ShiftAddNet是级联结构,需要双倍的层数/参数Shift和add层的CUDA内核比PyTorch在CUDA上的训练和推理慢得多
  • 如何保持重参数化后ViT的性能?对于ViT,当图像被分割成不重叠token时,我们可以利用输入token之间固有的自适应敏感性。原则上包含目标对象的基本token需要使用更强大的乘法来处理(这个idea和token merging很类似)

总体框架设计

  • 对于attention,将4个linear层和2个矩阵乘转换为shift和add层
  • 对于MLP,直接替换为shift层会大幅降低准确率,因此设计了MoE框架合并乘法原语的混合,如乘法和移位
  • 注意:linear->shift, MatMul->add

在这里插入图片描述

Attention重参数化
考虑二值量化,于是两个矩阵之间的乘累加(MAC)运算将高效的加法运算所取代。
( Q K ) V (QK)V (QK)V改为 Q ( K V ) Q(KV) Q(KV)以实现线性复杂度, Q , K Q, K Q,K进行二值量化,而更敏感的 V V V保持高精度,并插入轻量级的DWConv增强模型局部性。

在这里插入图片描述

可以看到,实际上ShiftAddViT就是把浮点数乘法简化为了2的幂次的移位运算和二值的加法运算。
在这里插入图片描述

其中, s , P s, P s,P都是可以训练的。

敏感性分析
在attention层应用线性注意力、add或shift对ViT准确性影响不大,但是在MLP层应用shift影响很明显!同时,使MLP更高效,对能源效率有很大贡献,因此需要考虑新的MLP重参数化方法。

在这里插入图片描述

MLP重参数化
MLP同样主导ViT的延迟,所以用shift层替换MLP的linear层,但是性能下降明显,所以提出MoE来提升其性能。

MoE框架

  • 假设: 假设重要但敏感的输入token,需要更强大的网络,否则会显著精度下降

  • 乘法原语的混合: 考虑两种专家(乘法和shift)。根据router中gate值 p i = G ( x ) p_i=G(x) pi=G(x),每个输入token表示 x x x将被传递给一位专家,输出定义如下:
    在这里插入图片描述
    在这里插入图片描述
    其中, n , E i n, E_i n,Ei表示专家数和第 i i i个专家。

  • 延迟感知和负载均衡的损失函数: MoE框架的关键是设计一个router函数,以平衡所有专家有更高的准确性和更低的延迟。乘法高性能但慢,shift快但低性能,如何协调每个专家的工作负荷,以减少同步时间?
    在这里插入图片描述
    其中,SCV表示给定分布对专家的平方变异系数(本文没介绍)。通过设计的损失函数,可以满足(1)所有专家都收到gate值的预期加权和;(2)为所有专家分配预期的输入token数。

相关文章:

去掉乘法运算的加法移位神经网络架构

[CVPR 2020] AdderNet: Do We Really Need Multiplications in Deep Learning? 代码:https://github.com/huawei-noah/AdderNet/tree/master 核心贡献 用filter与input feature之间的L1-范数距离作为“卷积层”的输出为了提升模型性能,提出全精度梯度…...

【TB作品】51单片机,具有报时报温功能的电子钟

2.具有报时报温功能的电子钟 一、功能要求: 1.显示室温。 2.具有实时时间显示。 3.具有实时年月日显示和校对功能。 4.具有整点语音播报时间和温度功能。 5.定闹功能,闹钟音乐可选。 6.操作简单、界面友好。 二、设计建议: 1.单片机自选(C51、STM32或其他单片机)。 2.时钟日历芯…...

了解C++工作机制

基于hello.cpp对C的运行进行一个初步认识,并介绍国外C大佬Cherno常用的项目结构和调试Tips C是如何工作的 C工作流程1.实用工程(project)结构(1)Microsoft Visual Studio2022新建项目后,自动生成的原始文件…...

力扣题目学习笔记(OC + Swift) 14. 最长公共前缀

14. 最长公共前缀 编写一个函数来查找字符串数组中的最长公共前缀。 如果不存在公共前缀,返回空字符串 “”。 方法一 竖向扫描法 个人感觉纵向扫描方式比较直观,符合人类理解方式,从前往后遍历所有字符串的每一列,比较相同列上的…...

WinSW设置应用程序开机启动

前言 由于使用windows自动的自启方法,不管是将程序启动服务放到开机自启文件夹中,还是创建任务计划程序,都没有很好的实现程序的开机自启效果,而WinSW很好的解决了这个问题。 下载 WinSW下载地址 注意:不同版本&#…...

Leetcode—96.不同的二叉搜索树【中等】

2023每日刷题&#xff08;六十四&#xff09; Leetcode—96.不同的二叉搜索树 算法思想 实现代码 class Solution { public:int numTrees(int n) {vector<int> G(n 1, 0);G[0] 1;G[1] 1;for(int i 2; i < n; i) {for(int j 1; j < i; j) {G[i] G[j - 1] * …...

正则表达式零宽断言

正则表达式零宽断言 工具类&#xff0c;正则表达式匹配文本内容正则表达式语法例子例子01零宽断言?< 不包含左边值? 不包含右边值例子 常用正则表达式校验数字的表达式校验字符的表达式 工具类&#xff0c;正则表达式匹配文本内容 /*** 正则表达式工具类*/ public class…...

uni-app学习记录

uni-app注意点记录 跳转到 tabBar 页面只能使用 switchTab 跳转路由API的目标页面必须是在pages.json里注册的vue页面。如果想打开web url&#xff0c;在App平台可以使用 plus.runtime.openURL或web-view组件&#xff1b;H5平台使用 window.open&#xff1b;小程序平台使用web…...

API资源对象StorageClass;Ceph存储;搭建Ceph集群;k8s使用ceph

API资源对象StorageClass;Ceph存储;搭建Ceph集群;k8s使用ceph API资源对象StorageClass SC的主要作用在于&#xff0c;自动创建PV&#xff0c;从而实现PVC按需自动绑定PV。 下面我们通过创建一个基于NFS的SC来演示SC的作用。 要想使用NFS的SC&#xff0c;还需要安装一个NFS…...

Databend 开源周报第 124 期

Databend 是一款现代云数仓。专为弹性和高效设计&#xff0c;为您的大规模分析需求保驾护航。自由且开源。即刻体验云服务&#xff1a;https://app.databend.cn 。 Whats On In Databend 探索 Databend 本周新进展&#xff0c;遇到更贴近你心意的 Databend 。 新增对 Delta 和…...

Arduino开发实例-液体流量测量

液体流量测量 文章目录 液体流量测量1、流量传感器介绍2、硬件准备及接线3、代码实现在本文中,将介绍如何流量传感器进行测量液体流量。 流量传感器用于测量液体流速。 市场上有不同类型的流量传感器,在本文中,我们将使用霍尔效应流量传感器。 这些类型的流量传感器是非侵入…...

【idea】解决sprintboot项目创建遇到的问题

目录 一、报错Plugin ‘org.springframework.boot:spring-boot-maven-plugin:‘ not found 二、报错java: 错误: 无效的源发行版&#xff1a;17 三、java: 无法访问org.springframework.web.bind.annotation.CrossOrigin 四、整合mybatis的时候&#xff0c;报java.lang.Ill…...

ADC芯片CS1237在电子秤方案的优势

​随着科技的不断发展&#xff0c;电子秤已经成为我们日常生活中不可或缺的测量工具。为了满足用户对于高精度、高稳定性的需求&#xff0c;芯海ADC芯片CS1237应运而生&#xff0c;为电子秤方案带来了革命性的变革。 一、芯海ADC芯片CS1237介绍 芯海ADC芯片CS1237是一款高性能…...

Leetcode的AC指南 —— 哈希表:202. 快乐数

摘要&#xff1a; Leetcode的AC指南 —— 哈希表&#xff1a;202. 快乐数。题目介绍&#xff1a;编写一个算法来判断一个数 n 是不是快乐数。 文章目录 一、题目二、解析1、哈希表 一、题目 题目介绍&#xff1a;编写一个算法来判断一个数 n 是不是快乐数。 「快乐数」 定义为…...

机器学习 项目结构 数据预测 实验报告

需求&#xff1a; 我经过处理得到了测试值&#xff0c;然后进一步得到预测和真实值的比较&#xff0c;然后再把之前的所有相关的参数、评估指标、预测值、比较结果都存入excel,另外我还打算做测试报告模板&#xff0c;包括敏感性分析等。您建议我这些功能如何封装这些功能&…...

[Verilog] 设计方法和设计流程

主页&#xff1a; 元存储博客 文章目录 1. 设计方法2. 设计流程 3 Vivado软件设计流程总结 1. 设计方法 Verilog 的设计多采用自上而下的设计方法&#xff08;top-down&#xff09;。设计流程是指从一个项目开始从项目需求分析&#xff0c;架构设计&#xff0c;功能验证&#…...

C语言:指向数组的指针和指向数组元素的指针

相关阅读 C语言https://blog.csdn.net/weixin_45791458/category_12423166.html?spm1001.2014.3001.5482 指向数组的指针和指向数组元素的指针常常被混淆&#xff0c;或者笼统地被称为数组指针&#xff0c;但它们之间是有差别的&#xff0c;本文就将对此进行讨论。 下面的代码…...

SQL基础:SQL 介绍和数据库基础

SQL简介 常用的Java等语言是和计算机交流的工具&#xff0c;告诉计算机&#xff0c;让计算机做一些事。 和其类似&#xff0c;SQL是 Structured Query Language 的缩写&#xff0c;即结构化的查询语言&#xff0c;是和数据库交互的工具&#xff0c;即通过既定的一些格式&…...

SpringSecurity入门

前言 Spring Security是一个用于在Java应用程序中提供身份验证和授权功能的强大框架。它构建在Spring框架之上&#xff0c;为开发人员提供了一套灵活且全面的安全性服务&#xff0c;本篇将为大家带来Spring Security的详细介绍及入门 一.安全框架 在学习了解Spring Security之…...

iOS 应用在前台时显示通知

背景&#xff1a; 在iOS应用中&#xff0c;当应用在前台运行时&#xff0c;是不会默认弹出通知的。这是iOS的设计决定&#xff0c;以避免用户在使用应用的过程中被打扰。然而&#xff0c;如果你希望在应用在前台的时候也能收到通知&#xff0c;你可以在你的应用代码中进行一些…...

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...

Java 语言特性(面试系列2)

一、SQL 基础 1. 复杂查询 &#xff08;1&#xff09;连接查询&#xff08;JOIN&#xff09; 内连接&#xff08;INNER JOIN&#xff09;&#xff1a;返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

MongoDB学习和应用(高效的非关系型数据库)

一丶 MongoDB简介 对于社交类软件的功能&#xff0c;我们需要对它的功能特点进行分析&#xff1a; 数据量会随着用户数增大而增大读多写少价值较低非好友看不到其动态信息地理位置的查询… 针对以上特点进行分析各大存储工具&#xff1a; mysql&#xff1a;关系型数据库&am…...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括&#xff1a;采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中&#xff0c;设置任务排序规则尤其重要&#xff0c;因为它让看板视觉上直观地体…...

《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)

CSI-2 协议详细解析 (一&#xff09; 1. CSI-2层定义&#xff08;CSI-2 Layer Definitions&#xff09; 分层结构 &#xff1a;CSI-2协议分为6层&#xff1a; 物理层&#xff08;PHY Layer&#xff09; &#xff1a; 定义电气特性、时钟机制和传输介质&#xff08;导线&#…...

第25节 Node.js 断言测试

Node.js的assert模块主要用于编写程序的单元测试时使用&#xff0c;通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试&#xff0c;通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...

全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比

目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec&#xff1f; IPsec VPN 5.1 IPsec传输模式&#xff08;Transport Mode&#xff09; 5.2 IPsec隧道模式&#xff08;Tunne…...

Spring AI与Spring Modulith核心技术解析

Spring AI核心架构解析 Spring AI&#xff08;https://spring.io/projects/spring-ai&#xff09;作为Spring生态中的AI集成框架&#xff0c;其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似&#xff0c;但特别为多语…...

32位寻址与64位寻址

32位寻址与64位寻址 32位寻址是什么&#xff1f; 32位寻址是指计算机的CPU、内存或总线系统使用32位二进制数来标识和访问内存中的存储单元&#xff08;地址&#xff09;&#xff0c;其核心含义与能力如下&#xff1a; 1. 核心定义 地址位宽&#xff1a;CPU或内存控制器用32位…...

鸿蒙APP测试实战:从HDC命令到专项测试

普通APP的测试与鸿蒙APP的测试有一些共同的特征&#xff0c;但是也有一些区别&#xff0c;其中共同特征是&#xff0c;它们都可以通过cmd的命令提示符工具来进行app的性能测试。 其中区别主要是&#xff0c;对于稳定性测试的命令的区别&#xff0c;性能指标获取方式的命令的区…...