当前位置: 首页 > news >正文

用你的手机/电脑运行文生图方案

f8b0fc16a231ae0cd7d8fcf298dd8fde.gif

随着ChatGPT和Stable Diffusion的发布,最近一两年,生成式AI已经火爆全球,已然成为移动互联网后一个重要的“风口”。就图片/视频生成领域来说,Stable Diffusion模型发挥着极其重要的作用。由于Stable Diffusion模型参数量是10亿参数的大模型,通常业界都是运行部署在显卡上。

但是随着量化、剪枝等模型压缩技术的进步,以及手机等终端设备的算力、带宽、内存持续增大。使得大模型在终端设备部署也成为的可能。大模型在终端部署可以有效保护用户隐私,而且终端设备日常广泛使用、用户可以随时随地生成想要的内容。

7ec7e6256a7667e08014fbdff9b54cea.png

MNN-Diffusion使用

本文是深度学习推理引擎MNN团队,做的Stable Diffusion端侧部署应用,代码开源,用户可以自行DIY各种好玩的Stable Diffusion应用。

MNN开源地址:

https://github.com/alibaba/MNN/tree/master

欢迎大家试用,使用教程如下:

https://mnn-docs.readthedocs.io/en/latest/transformers/diffusion.html



下面是在个人手机/电脑上生成的图片:

bebe439c12b0e97687df89db324f0bad.png

技术要点

业界加速Stable Diffusion部署通常有两个方向,一是算法层面的优化,包括优化网络结构、减少计算量或者降低推理迭代步数;二是工程部署优化,通过量化/算子高效实现等方式提高硬件计算效率、提高访存效率。MNN作为推理引擎,主要聚焦在工程部署优化上,下面分享下MNN Diffusion GPU在性能/内存方面做了优化工作。

  Self-Attention优化

Transformer结构中Self-Attention是一个基础结构,也是性能耗时的关键。如下结构是一个典型的Attention结构:

30df654d025272d5fc22fc5fd4eb287e.png

一个共有节点,分别经过三个Linear层,得到Query/Key/Value,Query/Key经过形状变换进行BatchMatMul操作,再进行Scale,取Softmax操作;该结果和Value经过形状变换做BatchMatMul;之后把结果进行形状变换,得到最终的输出。可以看到上述总共有19个算子,包括12个形状变化算子,7个计算型算子。



大量的形状变化会带来很多的访存耗时,对于GPU高算力的硬件来说,访存耗时往往容易成为热点。因此,将上述结构,融合成2个算子,第一个是将三个Linear层权重融合在一起,只做一个Linear,这样形成更大的矩阵乘尺寸,更容易打满GPU算力,带来性能收益;第二个算子是将Attention算子融合成一个算子Fused-MultiHead-Attention,融合之后在该新算子内部仅需5个Kernel就可以实现整个Attention功能。消除了大量额外的形状变换算子,降低了访存压力,同时可以更容易基于Attention算子特性做进一步优化工作。

500e2baca3a6db86eb4d5d2fd9550b6f.png

  GroupNorm/SplitGeLU融合

在Stable Diffusion中,有一个通用的结构ResnetBlock,其中包含了BroadCast Binary + GroupNorm + SiLU结构,在onnx模型图结构中包含了如下13个算子:

333736b66abbeff889e57a0fad37192b.png

可以看到GroupNorm采用InstanceNorm+形变算子实现,gamma/beta被单独拆解为mul/add算子,细碎的算子会增加全局内存的访存次数、以及Kernel launch的压力。因此将上述通用结构合并成一个GroupNorm算子,该算子把前面的BroadCast Binary和后续的SiLU激活函数,融合在一起。高效的只需一个Kernel就可以实现上述计算需求。



同样的图融合原理,在Transformer激活函数中,Stable Diffusion Feed-Forward模块中采用GEGLU结构,对应onnx图结构如下。将该8个onnx图算子,融合为通用的SplitGeLU算子。

c5251af434249b000dafca534c294940.png

  conv-winograd算法实现

在Stable Diffusion中有大量3x3卷积,在深度学习中,Winograd算法已经大量应用在加速3x3卷积实现。

Winograd F(m, r)算法,其中m代表一个计算tile的大小,r对应filter的尺寸,d=m+r-1 代表对应input tile大小。

4f6d41a9cea75a5c83330159ce22669b.png

下表是3x3 Winograd不同tile对应计算量的节省比例和中间内存占用的增大比例。

m

r

d

计算量前后比例

input中间内存

weight中间内存

2

3

4

9 : 4 = 2.25x

4x

1.78x

4

3

6

4 : 1 = 4x

2.25x

4x

6

3

8

81 : 16 = 5.06x

1.78x

7.11x

目前,我们使用的是F(2, 3) Winograd,控制内存增大量,同时带来一倍的性能提升效果。

  高性能Gemm/BatchGemm

上述分析可以看出,Attention/卷积3x3,核心计算量在BatchGemm上,Linear层实际上就是Gemm运算。实际上,Stable Diffusion中,核心的计算量或者说耗时的热点,归根溯源,都集中在Gemm/BatchGemm上。如何高效实现矩阵乘法 成为最核心的关键。

矩阵乘在各个维度上的分块策略,可以有效提升数据的复用度和数据cache命中率;合理的分块可以为矩阵乘法带来大幅度的性能提升。

a6c30dcdb2a2f826e6023ee302c4c2ab.png

上图展示了,矩阵乘在各个维度上面的分块变量,包括在并发M/N维度,单次数据访存向量化位宽、每个线程存取矩阵的尺寸、每个工作组存取矩阵的尺寸,以及如果使用local memory缓存的话每个线程/工作组的缓存量。

这些参量都决定了数据访存的效率、并发量的大小、计算访存比的大小。不同的设备有不同的寄存器资源、共享内存资源、访存带宽、计算核心数,这些参量都决定着矩阵乘法的性能效率。



对于特定的矩阵乘的尺寸M/N/K,针对特定设备采取Auto-Tuning的获取最佳的运行参数(OPWM/OPWN/OPTM/OPTN/VEC_M/VEC_N等),Tuning候选集数量是M的N次方(N是参数的个数、M是每个参数候选集个数)。如果暴力循环每个参数候选集,由于候选集数量巨大、并且大尺寸矩阵乘本身单次运行耗时较大,必然会导致要花费大量时间去Tuning完所有候选集。因此,根据经验和实际试跑,选出部分高频参数候选集进行Tuning,在控制好Tuning时间的同时,也可以带来极大的性能收益。

  Gemm Strassen探索

由于矩阵乘法是Stable Diffusion耗时的核心,因此进行了矩阵乘快速算法的研究探索。Strassen算法是利用矩阵拆解,通过引入矩阵加减法,来减少矩阵乘法次数的方式。最简单的方法,将M/N/K维度各对拆1/2的方法,朴素的矩阵拆解如下:

0b94ad48464fc4675377f555b66e0b0e.png

Strassen算法,通过15次子矩阵加减法,来减少一次子矩阵乘法。矩阵拆解如下:

46a06996401968f0e405de7c014f6a98.png

当N足够大时,矩阵加减法耗时会远低于矩阵乘法耗时,带来12.5%的计算量降低。当N较小时,受限于15次 子矩阵加减的 耗时,以及拆解子矩阵乘法算力打不满等损耗原因,将引起负优化。具体某个形状的矩阵乘法适不适合使用Strassen算法?



对于矩阵A形状为[M, K], 矩阵B形状为[N, K],输出矩阵C形状为[M, N]。15次子矩阵加减,数据访存量为:(3*M*K + 3*N*K + 3.5*M*N) * sizeof(DataType) Bytes。1次子矩阵乘法,数据计算量为:1/8 * M*N*K * 2 = 1/4 * M*N*K FLOPS。我们默认矩阵加减是带宽瓶颈,矩阵乘法是算力瓶颈。假设设备的内存带宽为X GB/s,算力是Y GFLOPS。

子矩阵加减耗时:(6*M*K + 6*N*K + 3.5*M*N)*sizeof(DataType) / X (ns)

子矩阵乘节省耗时:(1/4 * M * N * K) / Y (ns)



当节省的耗时大于损耗耗时,即可有性能收益。根据上述公式,计算访存比越低的设备,Strassen算法越容易有收益。对于手机设备来说,1024x1024x1024的子矩阵,通常可以获得约10%的性能收益。

  内存占用优化

在Attention优化中,Q/K做BatchMatMul得到中间数据QK时,张量维度为[Batch, HeadNum, SeqLen, SeqLen]。对于Stable Diffusion来说,会遇到Batch=2,HeadNum=16,SeqLen=4096。对于float16的数据类型,单个张量的存储就需要1GB的内存大小,这对于内存资源紧缺的端侧设备是不可接受的。

876e1c1eb1ae66fb68659378e93eced9.png

因此,将Attention操作进行分块处理,类似Paged Attention的思路,将整个Attention分成SeqNum次执行,这样每次仅需原先1/SeqNum中间内存大小,可以非常有效的控制内存的大小。

性能测评

MNN Stable Diffusion应用,生成512x512图片,在骁龙8Gen3上使用GPU float16精度达到2s/iter (20次迭代,手机上40s可以生成完一幅图),在Apple Mac M3上GPU float32精度达到1.1s/iter (20次迭代,Mac上22s可以生成完一幅图)。MNN CPU/GPU性能均较大幅度快于如下Stable Diffusion开源框架,例如:

  • stable-diffusion.cpp

    https://github.com/leejet/stable-diffusion.cpp/issues/15

  • Android OnnxRuntime Stable Diffusion应用

    https://github.com/ZTMIDGO/Android-Stable-diffusion-ONNX

9e9496859a04ea0ac02b2af87f69bd83.png

后续研究

后续在性能优化和内存优化上面仍然有空间可以挖掘。

性能优化方面:

  • Conv Winograd采用更大的分块,获取更高的计算量降低收益。

  • 矩阵乘尝试Image存储内存访问模式,提高访存效率。

  • Attention进一步采用Flash Attention等思路优化。

内存占用优化方面:

  • 采用低比特权重(int8/int4量化)。

  • 在线转换动态内存可复用,Conv Winograd权重尝试采用在线转换。

  • Attention 采用Flash Attention优化节省中间内存使用。

8c56100ad51dd1b6565cc7b815d90b01.png

参考资料

  • https://blog.csdn.net/xian0710830114/article/details/129194419

  • https://github.com/NVIDIA/TensorRT/tree/release/8.6/demo/Diffusion

  • https://arxiv.org/abs/0707.2347

  • https://courses.cs.cornell.edu/cs6810/2023fa/Matrix.pdf

  • https://github.com/CNugteren/CLBlast/tree/master

  • https://arxiv.org/pdf/1703.06503

  • https://github.com/leejet/stable-diffusion.cpp/

  • https://github.com/ZTMIDGO/Android-Stable-diffusion-ONNX

e2e86c44f7745eb2a3c1103c68e973db.png

团队介绍

我们是大淘宝技术Meta Team,负责面向消费场景的3D/XR基础技术建设和创新应用探索,通过技术和应用创新找到以手机及XR 新设备为载体的消费购物3D/XR新体验。团队在端智能、商品三维重建、3D引擎、XR引擎等方面有深厚的技术积累。团队在OSDI、MLSys、CVPR、ICCV、NeurIPS、TPAMI等顶级学术会议和期刊上发表多篇论文。

¤ 拓展阅读 ¤

3DXR技术 | 终端技术 | 音视频技术

服务端技术 | 技术质量 | 数据算法

相关文章:

用你的手机/电脑运行文生图方案

随着ChatGPT和Stable Diffusion的发布,最近一两年,生成式AI已经火爆全球,已然成为移动互联网后一个重要的“风口”。就图片/视频生成领域来说,Stable Diffusion模型发挥着极其重要的作用。由于Stable Diffusion模型参数量是10亿参…...

L1正则化详解

目录 L1 正则化优缺点:适合使用L1正则化的情况:不适合使用L1正则化的情况:参考 L1 正则化 L1正则化是一种常用的正则化技术,也被称为Lasso正则化(Least Absolute Shrinkage and Selection Operator)。它通…...

C语言在数据库开发中的应用及其代码实践

数据库作为现代软件开发中不可或缺的一部分,其开发和维护工作至关重要。C语言,以其接近硬件的特性和高效率,被广泛应用于数据库系统的核心组件开发中。本文将探讨C语言在数据库开发中的应用,并提供实际的代码示例。 C语言在数据库…...

java maven

参考链接 maven相关配置 maven依赖管理 依赖具有传递性。 maven依赖范围 maven的生命周期 分为三个相互独立的生命周期: 在执行对应生命周期的操作时,需要进行前面的操作。比如,执行打包install的时候,会执行test。...

Java爬虫:获取直播带货数据的实战指南

在当今数字化时代,直播带货已成为电商领域的新热点,通过直播平台展示商品并进行销售,有效促进了产品的曝光和销售量的提升。然而,如何在直播带货过程中进行数据分析和评估效果,成为了摆在商家面前的一个重要问题。本文…...

python 列表、元组、字典易误区

一、删除元素 1、删除列表中的元素 pop del (1)pop(索引) 用于删除指定索引处的元素,并返回被删除的元素的值。默认删除最后一个元素。 eg:list.pop() (2)del 用于删除列表中的指定索引处的元素,或者删除整个列表变量。del操作没有返回值。 eg:del a[1:…...

wireshark或tshark提取tcpdump捕获的数据包(附python脚本自动解析文件后缀)

tcpdump 捕获数据包后,保存的文件通常会被命名为 capture.pcap(或其他你指定的名称),并存储在你运行命令的当前目录中。以下是如何使用 tcpdump 进行流量捕获,并找到和使用捕获文件的详细步骤。 1. 使用 tcpdump 捕获…...

了解EasyNVR及EasyNVS,EasyNVR连接EasyNVS显示授权超时如何解决?什么原因?

我们先来了解NVR批量管理软件/平台EasyNVR,它深耕市场多年,为用户提供多种协议,兼容多种厂商设备,包括但不限于支持海康,大华,宇视,萤石,天地伟业,华为设备。 NVR录像机…...

【AUTOSAR标准文档】服务类型介绍

Introduction to types of services The Basic Software can be subdivided into the following types of services: ① Input/Output (I/O) Standardized access to sensors, actuators and ECU onboard peripherals ② Memory Standardized access to internal/external…...

Axure垂直菜单展开与折叠

亲爱的小伙伴,在您浏览之前,烦请关注一下,在此深表感谢! 课程主题:Axure垂直菜单展开与折叠 主要内容:垂直菜单单击实现展开/折叠,点击各菜单项显示选中效果 应用场景:后台菜单设…...

java简单理解哈希算法

这里需要大家有一些哈希表(散列表的理论基础) 比如冲突怎么处理 key-value是什么意思 有哪些处理冲突的方法 平均查找成功长度和失败长度是什么意思。 详细可以看一下这个数据结构散列表。在java中常用三种结构代表散列: map,set,数组。应在不…...

Python生成随机密码脚本

引言 在数字化时代,密码已成为我们保护个人信息和数据安全的重要手段。然而,手动创建复杂且难以猜测的密码是一项既繁琐又容易出错的任务。幸运的是,Python编程语言为我们提供了一种高效且灵活的方法来自动生成随机密码。本文将详细介绍如何…...

什么是ASC广告?Facebook ASC广告使用技巧

ASC广告全称AdvantageShopping Campaign,即进阶赋能型智能购物广告,许多投放Facebook广告的小伙伴听过这个词,但每用过这个功能,Facebook推出ASC广告已经有两年了,不少实例证明ASC广告在降低转化成本上有一定效果&…...

idea2024启动Java项目报Error running CloudPlApplication. Command line is too long.

idea2024启动Java项目报Error running CloudPlApplication. Command line is too long. 解决方案: 1、打开Edit Configurations 2、点击Modify options设置,勾选Shorten command line 3、在Edit Configurations界面下方新增的Shorten command line选项中…...

xtu oj 不定方程的正整数解

文章目录 回顾思路c 语言代码 回顾 AB III问题 H: 三角数问题 G: 3个数等式 数组下标查询,降低时间复杂度1405 问题 E: 世界杯xtu 数码串xtu oj 神经网络xtu oj 1167 逆序数(大数据)xtu oj 原根 思路 首先直观地理解这个题目的意思&#x…...

python爬虫技术实现酷我付费破解下载

python爬虫技术实现酷我付费破解下载 1.python编程环境 python解释器:pyhton3版本 代码编辑器:Vscode,PyCharm 2.实现爬虫程序过程 2.1浏览器访问网站的过程 在浏览器导航栏中输入域名并回车(在按下回车的那一瞬间浏览器向网站发送了一个http请求)当网站接收到请求后向…...

工具:Git分布式版本控制系统

文章目录 介绍分布式版本控制系统原理git安装和使用git软件分类安装软件注册开源社区githubgit ssh key 配置远程仓库分支管理标签管理 引用 介绍 分布式版本控制系统下的每一台终端都可以充当类似集中式版本控制系统的中央服务器。每台终端都可以保存版本库,并且版…...

python+docxtpl:word文件模版渲染

目录 操作流程 加载模版 模版渲染 文件保存 python-docx库结合 模版渲染说明 变量值的获取 模板代码语句 遍历生成列表 docxtpl使用jinja2作为框架的模板系统,基于python-docx,同样可以使用python-docx库的一些方法,如添加段落,添加图片、列表等。 安装:pip ins…...

018_基于python+django荣誉证书管理系统2024_jytq9489

目录 系统展示 开发背景 代码实现 项目案例 获取源码 博主介绍:CodeMentor毕业设计领航者、全网关注者30W群落,InfoQ特邀专栏作家、技术博客领航者、InfoQ新星培育计划导师、Web开发领域杰出贡献者,博客领航之星、开发者头条/腾讯云/AW…...

Vulkan 开发(三):Vulkan 物理设备

Vulkan 物理设备 图片来自《 Vulkan 应用开发指南》 上一节了解了 Vulkan 实例,一旦有了实例,就可以查找系统里安装的与 Vulkan 兼容的物理设备。 Vulkan 物理设备(PhysicalDevice)一般是指支持 Vulkan 的物理硬件,通…...

Netty无锁化设计之对象池实现

池化技术是比较常见的一种技术,在平时我们已经就接触很多了,比如线程池,数据库连接池等等。当我们要使用一个资源的时候从池中去获取,用完就放回池中以便其他线程可以使用,这样的目的就是为了减少资源开销,…...

工厂生成中关于WiFi的一些问题

一 背景: 主要做高通和MTK,工厂生成中通过使用adb wifi,因为这样生产效率高并且避免了新机器有划痕,但是也经常碰到adb wifi无法连接的问题,那么是什么原因导致呢? 二 案例 测试步骤: 使用adb wifi连接手机测试工厂case adb usb adb tcpip 5555 adb connect DU…...

Java爬虫:获取商品评论数据的高效工具

在电子商务的激烈竞争中,商品评论作为消费者购买决策的重要参考,对于商家来说具有极高的价值。它不仅能够帮助商家了解消费者的需求和反馈,还能作为改进产品和服务的依据。Java爬虫技术,以其稳健性和高效性,成为了获取…...

oracle中的exists 和not exists 用法

exists (sql 返回结果集为真) not exists (sql 不返回结果集为真) exists 与 in 意思相同,语法不同,效率高于in not exists 与 not in 意思相同,语法不同,效率高于in 基本概念: se…...

自定义导出Excel数据注解实践

目录 前言结构组成定义自定义注解定义导出数据的实体定义Excel导出逻辑定义导出服务注解验证总结 前言 在企业级应用中,导入导出 Excel 文件是很常见的需求。通过使用自定义注解不仅可以实现灵活的 Excel 数据导入导出还可以减少手动配置的麻烦,提高代码…...

CSS3 动画相关属性实例大全(一)(@keyframes ,background属性,border 属性)

CSS3 动画相关属性实例大全(一) (keyframes ,background属性,border 属性) 本文目录: 零、时光宝盒 一、CSS3 动画基本概念 (1)、CSS3的动画基本属性 (2&#xff09…...

拦截器或过滤器往本次请求体中添加信息

步骤一:定义新的Request package com.ict.lux.framework.interceptor;import java.util.Collections; import java.util.Enumeration; import java.util.Map; import java.util.TreeMap;import javax.servlet.http.HttpServletRequest; import javax.servlet.http.…...

Docker 安装达梦 DM8 数据库实战指南

Docker 安装达梦 DM8 数据库实战指南 文章目录 Docker 安装达梦 DM8 数据库实战指南一 安装环境二 下载 DM8 安装包三 导入镜像四 启动容器1)docker run 启动2)docker compose 启动3)名词解释 五 连接数据库 本文详细介绍了如何在 CentOS 7.9…...

QtCreator14调试Qt5.15出现 Launching Debugger 错误

1、问题描述 使用QtCreator14调试程序,Launching Debugger 显示红色,无法进入调试模式。 故障现象如下: 使能Debugger Log窗口,显示: 325^error,msg"Error while executing Python code." 不过&#xff…...

day1:基础了解

虚拟机网络设置 桥接模式:客户机使用宿主机的网段 使虚拟机像物理机一样直接连接到外部网络,拥有独立的IP地址,可与其他网络设备通信。 nat模式:客户机使用单独的局域网 通过宿主机的NAT功能,让虚拟机能够访问外部…...