当前位置: 首页 > news >正文

大模型自定义算子优化方案学习笔记:CUDA算子定义、算子编译、正反向梯度实现

01算子优化的意义

随着大模型应用的普及以及算力紧缺,下一步对于计算性能的追求一定是技术的核心方向。因为目前大模型的计算逻辑是由一个个独立的算子或者说OP正反向求导实现的,底层往往调用的是GPU提供的CUDA的驱动程序。如果不能对于整个计算过程学习并了解,对于性能优化领域无非是隔靴搔痒,今天也是抽一点时间读了下网上的一些文档和CUDA的文档,整理了学习材料。

首先说下为什么要自定义算子,无非是两个原因,

(1)TF、PyTorch等提供的原生算子不满足需求,需要通过底层接口,比如CUDA层更灵活的实现个性化算子

(2)目前提供的算子性能不足,没有很好的利用到GPU的并行计算优势,有优化空间

接着性能优化的问题说,因为GPU与CPU最大的区别是计算单元占据了绝大部分的体积(图中绿色部分),而控制单元较少。

自定义手写算子可以更好地利用绿色的计算单元的并行化优势。大的思路是Grid包含Block,Block包含Thread。于是首先算子需要把计算逻辑拆分成Thread,让程序可以并行化的运行起来,然后有机的管理各个Block的执行节奏,解决好异步和同步问题,就可以让芯片的计算效率最大化。

Grid of Thread Blocks

02实现流程

整个自定义算子优化其实可以分为CUDA算子定义、算子编译、正方向梯度实现几个步骤。

1、CUDA算子定义

需要定义以下几个文件:

(1)function.cpp:python层和CUDA层的衔接,实现手写的C++ CUDA算子可以被python代码调用

(2)function.h:CUDA函数声明文件

(3)function.cu:CUDA函数的逻辑实现

这里比较核心的文件就是.cu文件,构建的时候主要做两个事:一个是建设Kernel函数,因为只有Kernel函数是在GPU端执行,执行完之后要将控制权给到控制函数,这里要控制好异步、同步的问题。二是在kernel函数中需要通过循环函数定义每个Thread以及每个Block的工作,真正让计算并行化

.cpp文件可以通过pybind函数实现,这个函数主要解决的是C++代码和Python绑定的问题,项目地址:GitHub - pybind/pybind11: Seamless operability between C++11 and Python

2.编译和执行

import torch
from torch.utils.cpp_extension import load
cuda_module = load(name="function",extra_include_paths=["include"],sources=["function.cpp", "function.cu"],verbose=True)

接着就是执行过程中的编译,通过load函数底层会执行JIT(Just in time)的动态编译模式调用.cpp和.cu文件,底层运行的是Ninjia编译器,通过调用nvcc编译.so文件

[1/2] nvcc -c function.cu -o function.cuda.o
[2/3] c++ -c function.cpp -o function.o
[3/3] c++ function.o function.cuda.o -shared -o function.so

3.正反向梯度实现

以上两步实现了自定义算子的逻辑,可以通过手写CUDA算子并在python框架层调用,如果要满足建模需求,还需要实现正方向梯度。具体的做法是在建模的函数中实现forward和backward函数,定义导数作为输出。

以上大概就是手写算子优化的简单流程,仅当学习笔记。

参考:

【1】熬了几个通宵,我写了份CUDA新手入门代码 - 知乎

【2】CUDA C++ Programming Guide

相关文章:

大模型自定义算子优化方案学习笔记:CUDA算子定义、算子编译、正反向梯度实现

01算子优化的意义 随着大模型应用的普及以及算力紧缺,下一步对于计算性能的追求一定是技术的核心方向。因为目前大模型的计算逻辑是由一个个独立的算子或者说OP正反向求导实现的,底层往往调用的是GPU提供的CUDA的驱动程序。如果不能对于整个计算过程学习…...

【密码学基础】Diffie-Hellman密钥交换协议

DH介绍 Diffie-Hellman密钥协议算法是一种确保共享密钥安全穿越不安全网络的方法。 这个机制的巧妙在于需要安全通信的双方可以用这个方法确定对称密钥,然后可以用这个密钥进行加密和解密。 但是注意,这个密钥交换协议 只能用于密钥的交换,而…...

最新AI绘画Midjourney绘画提示词Prompt教程

一、Midjourney绘画工具 SparkAi【无需魔法使用】: sparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的…...

AI助力DevOps新时代

根据2023年Gitlab全球DevSecOps报告,62%使用AI和ML的开发人员表示他们正在使用AI来检查代码,而2022年这一比例只有51%。 人工智能在 DevOps 中的作用 虽然今年年初,随着GPT的爆火,AI技术逐渐深入人心,但在很早以前&…...

Spring之容器:IOC(2)

学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。各位小伙伴,如果您: 想系统/深入学习某技术知识点… 一个人摸索学习很难坚持,想组团高效学习… 想写博客但无从下手,急需…...

Spring 依赖查找知识点总结

前言 源码在我github的guide-spring仓库中,可以克隆下来 直接执行。 我们本文主要来介绍依赖查找的使用示例 依赖查找 什么是依赖查找 依赖查找并不是 Spring 框架特有的概念,它是一种在软件开发中获取依赖对象的方式。它通常用于获取运行时需要的服…...

html5新增特性

对于这行代码&#xff0c;要写在html页面的最前端&#xff1a; <!DOCTYPE html> 为什么要写在前面&#xff1f; 这是声明&#xff0c;是html5的新特性 对于html4来说&#xff0c;它有三种声明格式&#xff0c;而html5只需要统一声明&#xff0c;用来告诉浏览器文档使用…...

4、APScheduler: 详解Scheduler种类用法、常见错误与解决方法【Python3测试任务管理总结】

调度器(Scheduler)是将其他组件绑在一起的关键。通常在应用程序中只运行一个调度器。应用程序开发者通常不直接处理作业存储(job stores)、执行器(executors)或触发器(triggers)。相反,调度器提供了适当的接口来处理所有这些。通过调度器配置作业存储和执行器,以及添…...

微服务实战系列之ZooKeeper(实践篇)

前言 关于ZooKeeper&#xff0c;博主已完整的通过庖丁解牛式的“解法”&#xff0c;完成了概述。我想掌握了这些基础原理和概念后&#xff0c;工作的问题自然迎刃而解&#xff0c;甚至offer也可能手到擒来&#xff0c;真实一举两得&#xff0c;美极了。 为了更有直观的体验&a…...

C++ 开发中为什么要使用继承

为何继承 实验介绍 继承是 C++ 中的特性之一,使用继承能够有效减轻工作量,使得开发更加高效。 知识点 什么是继承为何继承继承的内容权限关键字什么是继承 生活中继承是指孩子继承父亲的财产等。C++ 使用了这一思想,却又与生活中的继承不一样。 在使用继承时,派生类是…...

2020蓝桥杯c组纸张大小

题目名字 纸张大小 题目链接 题意 给一张纸&#xff0c;通过不断折叠&#xff0c;求最终长宽&#xff0c;给十个数字&#xff0c;输入哪个数字就求哪次折叠的长宽&#xff0c;其实就是&#xff0c;每次折叠后长度的一半变为宽度&#xff0c;原来的宽度变成长度 思路 因为数字…...

【Image】图像处理

计算机视觉 CV Perception 如自动驾驶领域。 只要是从所谓的图像当中去抽取信息的过程&#xff0c;我们都叫做Perception。 视觉检测可以涵盖二维检测&#xff0c;如车辆、人和信号灯的检测。另外&#xff0c;还可以控制三维信息&#xff0c;直接在三维空间中操作数据。 SL…...

JAVA对文档加密

当 Word 文档中包含无法公开的机密信息时&#xff0c;我们可以对其进行加密&#xff0c;使其在没有密码的情况下无法打开。本文将向您介绍如何使用 Spire.Doc for Java 加密 Word 文档和移除 Word 密码保护。 加密 Word 文档删除 Word 密码保护 安装 Spire.Doc for Java 首先…...

EmbedAI:一个可以上传文件训练自己ChatGPT的AI工具,妈妈再也不用担心我的GPT不会回答问题

功能介绍&#xff1a; 个性化定制&#xff1a;提供灵活的训练选项&#xff0c;用户能够通过文件、网站、Notion文档甚至YouTube等多种数据源对ChatGPT进行训练&#xff0c;以满足不同领域和需求的个性化定制。广泛应用场景&#xff1a;ChatGPT支持多种用例&#xff0c;包括智能…...

runCatching异常捕获onSuccess/onFailure返回函数,Kotlin

runCatching异常捕获onSuccess/onFailure返回函数&#xff0c;Kotlin fun test(a: Int, b: Int) {runCatching {a / b}.onSuccess {println("onSuccess: $it")return ok(it)}.onFailure {println("onFailure: $it")return fail(it)} }fun ok(o: Any) {prin…...

IDEA报错处理

问题1 IDEA 新建 Maven 项目没有文件结构 pom 文件为空 将JDK换成1.8后解决。 网络说法&#xff1a;别用 java18&#xff0c;换成 java17 或者 java1.8 都可以&#xff0c;因为 java18 不是 LTS 版本&#xff0c;有着各种各样的问题。。...

使用动画曲线编辑器打造炫酷的3D可视化ACE

前言 在制作3D可视化看板时&#xff0c;除了精细的模型结构外&#xff0c;炫酷的动画效果也是必不可少的。无论是复杂的还是简单的动画效果&#xff0c;要实现100%的自然平滑都是具有挑战性的工作。这涉及到物理引擎的计算和对动画效果的数学建模分析。一般来说&#xff0c;只…...

使用 React 和 ECharts 创建地球模拟扩散和飞线效果

在本博客中&#xff0c;我们将学习如何使用 React 和 ECharts 创建一个酷炫的地球模拟扩散效果。我们将使用 ECharts 作为可视化库&#xff0c;以及 React 来构建我们的应用。地球贴图在文章的结尾。 最终效果 准备工作 首先&#xff0c;确保你已经安装了 React&#xff0c;并…...

http状态码(一)400报错

一 400报错汇总 ① 综述 一、4xx状态码报错说明&#xff1a; 客户端行为导致的报错二、通用的4xxHTTP报错1) 4002) 4013) 4034) 4045) 405 --> 不允许方法&#xff0c;可能跨域或者nginx限制请求方法6) 4087) 4138) 419三、ngin自身定义的4xx报错495、496、497、498、4…...

【深度学习目标检测】五、基于深度学习的安全帽识别(python,目标检测)

深度学习目标检测方法则是利用深度神经网络模型进行目标检测&#xff0c;主要有以下几种&#xff1a; R-CNN系列&#xff1a;包括R-CNN、Fast R-CNN、Faster R-CNN等&#xff0c;通过候选区域法生成候选目标区域&#xff0c;然后使用卷积神经网络提取特征&#xff0c;并通过分类…...

芒果RT-DETR改进实验:深度集成版目标检测 RT-DETR 热力图来了!支持自定义数据集训练出来的模型

💡该教程为改进RT-DETR指南,属于《芒果书》📚系列,包含大量的原创改进方式🚀 💡🚀🚀🚀内含改进源代码 按步骤操作运行改进后的代码即可💡更方便的统计更多实验数据,方便写作 芒果RT-DETR改进实验:深度集成版目标检测 RT-DETR 热力图来了!支持自定义数据集…...

c语言实验八

实验1&#xff1a;在主函数中输入num个字符串&#xff0c;写一个函数&#xff0c;从传入的num个字符串中找出最长的一个字符串&#xff0c;并通过形参指针max传回该串地址&#xff0c;在主函数中输出。&#xff08;注意&#xff1a;用****作为结束输入的标志。&#xff09; #i…...

ArcGIS Pro SDK文件选择对话框

文件保存对话框 // 获取默认数据库var gdbPath Project.Current.DefaultGeodatabasePath;//设置文件的保存路径SaveItemDialog saveLayerFileDialog new SaveItemDialog(){Title "Save Layer File",OverwritePrompt true,//获取或设置当同名文件已存在时是否出现…...

ACT、NAT、NATPT和EASY-IP

目录 一、ACL 1.ACL 2.ACL的两种应用匹配机制 3.ACL的基本类型 4.ACL命令操作 5.ACL实验&#xff1a; 4.ACL的应用原则&#xff1a; 5.匹配原则&#xff1a; 二、NAT 1.NAT的原理及作用&#xff1a; 2.NAT分类 3.NAT配置 三、EASY-ip实验 四、NATPT 五、通配符 …...

HTML实现每天单词积累

注册页面 <!DOCTYPE html> <html> <head><meta charset"UTF-8"><title>注册</title><style>body {font-family: Arial, sans-serif;background-color: #f5f5f5;}form {max-width: 500px;margin: 50px auto;padding: 40px…...

【ECMAScript笔记二】运算符分类,流程控制(顺序结构、分支结构、循环结构)

文章目录 4 运算符4.1 算术运算符4.2 递增和递减运算符4.3 比较运算符4.4 逻辑运算符4.5 赋值运算符4.6 运算优先级 5 流程控制5.1 顺序结构5.2 分支结构5.2.1 if 语句5.2.2 switch 语句 5.3 循环结构5.3.1 for循环5.3.2 while循环5.3.3 do while循环5.3.4 continue和break 5.4…...

ShenYu网关注册中心之Zookeeper注册原理

文章目录 1、客户端注册流程1.1、读取配置1.1.1、用于注册的 ZookeeperClientRegisterRepository1.1.2、用于扫描构建 元数据 和 URI 的 SpringMvcClientEventListener 1.2、扫描注解&#xff0c;注册元数据和URI1.2.1、构建URI并写入Disruptor1.2.2、构建元数据并写入Disrupto…...

高级C#技术(二)

前言 本章为高级C#技术的第二节也是最后一节。前一节在下面这个链接 高级C#技术https://blog.csdn.net/qq_71897293/article/details/134930989?spm1001.2014.3001.5501 匿名类型 匿名类型如其名&#xff0c;匿名的没有指定变量的具体类型。 举个例子&#xff1a; 1 创建…...

【性能测试】基础知识篇-压力模型

常见压力模式 并发模式&#xff08;即虚拟用户模式&#xff09;和RPS模式&#xff08;即Requests Per Second&#xff0c;每秒请求数&#xff0c;吞吐量模式&#xff09;。 本文介绍这两种压力模式的区别&#xff0c;以便根据自身业务场景选择更合适的压力模式。 并发模式 …...

springboot-redis设置定时触发任务详解

最近研究了一下“redis定时触发”&#xff0c;网上查了查资料&#xff0c;这里记录一下。 从Redis 2.8.0开始&#xff0c;Redis加入了发布/订阅模式以及键空间消息提醒&#xff08;keyspace notification&#xff09;功能。键空间消息提醒提供了允许客户端通过订阅指定信道获取…...