当前位置: 首页 > news >正文

DLA :pytorch添加算子

pytorch的C++ extension写法

        这部分主要介绍如何在pytorch中添加自定义的算子,需要以下cuda基础。就总体的逻辑来说正向传播需要输入数据,反向传播需要输入数据和上一层的梯度,然后分别实现这两个kernel,将这两个kernerl绑定到pytorch即可。

add

  • 但实际上来说,这可能不是一个很好的教程,因为加法中没有对输入的grad_out进行继续的操作(不用写cuda的操作)。所以实际上只需要正向传播的launch_add2函数。更重要的是作者大佬写了博客介绍。
// https://github.com/godweiyang/NN-CUDA-Example/blob/master/kernel/add2_kernel.cu__global__ void add2_kernel(float* c,const float* a,const float* b,int n) {for (int i = blockIdx.x * blockDim.x + threadIdx.x; \i < n; i += gridDim.x * blockDim.x) {c[i] = a[i] + b[i];}
}void launch_add2(float* c,const float* a,const float* b,int n) {// 创建 [(n + 1023) / 1024 ,1 ,1]的三维向量数据dim3 grid((n + 1023) / 1024);//dim3 为CUDA中三维向量结构体// 创建 [1024 ,1 ,1]的三维向量数据dim3 block(1024);// 函数add2_kernel实现两个n维向量相加// 共有(n + 1023) / 1024*1*1个block , 每个block有1024*1*1个线程add2_kernel<<<grid, block>>>(c, a, b, n);
}

在这里插入图片描述

binary activation function

  • 正向计算为:
x > 1 ? 1 : -1;// 也可以使用sign() 函数(求符号函数)实现
  • 这篇文章作者没有自己写正向传播的算子,使用的是at::sign
// https://github1s.com/jxgu1016/BinActivateFunc_PyTorch/blob/master/src/cuda/BinActivateFunc_cuda.cpp#L17-L22
at::Tensor BinActivateFunc_forward(at::Tensor input) 
{CHECK_INPUT(input);return at::sign(input);
}
  • 这篇文章用的Setuptools将写好的算子和pytorch链接起来,运行时需要安装一下(JIT运行时编译也很香,代码直接运行,就是cmakelist.txt需要各种环境配置很麻烦)。绑定部分见链接。以下是作者实现的反向传播的kernel:
// https://github.com/jxgu1016/BinActivateFunc_PyTorch/blob/master/src/cuda/BinActivateFunc_cuda_kernel.cu
#include <ATen/ATen.h>#include <cuda.h>
#include <cuda_runtime.h>#include <vector>// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)namespace {
template <typename scalar_t>
__global__ void BinActivateFunc_cuda_backward_kernel(const int nthreads,const scalar_t* __restrict__ input_data,scalar_t* __restrict__ gradInput_data) 
{CUDA_KERNEL_LOOP(n, nthreads) {if (*(input_data + n) > 1 || *(input_data + n) < -1) {*(gradInput_data + n) = 0;}}
}
} // namespaceint BinActivateFunc_cuda_backward(at::Tensor input,at::Tensor gradInput) 
{const int nthreads = input.numel();const int CUDA_NUM_THREADS = 1024;const int nblocks = (nthreads + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;AT_DISPATCH_FLOATING_TYPES(input.type(), "BinActivateFunc_cuda_backward", ([&] {BinActivateFunc_cuda_backward_kernel<scalar_t><<<nblocks, CUDA_NUM_THREADS>>>(nthreads,input.data<scalar_t>(),gradInput.data<scalar_t>());}));return 1;
}

swish

// https://github1s.com/thomasbrandon/swish-torch/blob/HEAD/csrc/swish_kernel.cu
#include <torch/types.h>
#include <cuda_runtime.h>
#include "CUDAApplyUtils.cuh"// TORCH_CHECK replaces AT_CHECK in PyTorch 1,2, support 1.1 as well.
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif#ifndef __CUDACC_EXTENDED_LAMBDA__
#error "please compile with --expt-extended-lambda"
#endifnamespace kernel {
#include "swish.h"using at::cuda::CUDA_tensor_apply2;
using at::cuda::CUDA_tensor_apply3;
using at::cuda::TensorArgType;template <typename scalar_t>
void
swish_forward(torch::Tensor &output,const torch::Tensor &input
) {CUDA_tensor_apply2<scalar_t,scalar_t>(output, input,[=] __host__ __device__ (scalar_t &out, const scalar_t &inp) {swish_fwd_func(out, inp);},TensorArgType::ReadWrite, TensorArgType::ReadOnly);
}template <typename scalar_t>
void
swish_backward(torch::Tensor &grad_inp,const torch::Tensor &input,const torch::Tensor &grad_out
) {CUDA_tensor_apply3<scalar_t,scalar_t,scalar_t>(grad_inp, input, grad_out,[=] __host__ __device__ (scalar_t &grad_inp, const scalar_t &inp, const scalar_t &grad_out) {swish_bwd_func(grad_inp, inp, grad_out);},TensorArgType::ReadWrite, TensorArgType::ReadOnly, TensorArgType::ReadOnly);
}} // namespace kernelvoid
swish_forward_cuda(torch::Tensor &output, const torch::Tensor &input
) {auto in_arg  = torch::TensorArg(input,  "input",  0),out_arg = torch::TensorArg(output, "output", 1);torch::checkAllDefined("swish_forward_cuda", {in_arg, out_arg});torch::checkAllSameGPU("swish_forward_cuda", {in_arg, out_arg});AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "swish_forward_cuda", [&] {kernel::swish_forward<scalar_t>(output, input);});
}void
swish_backward_cuda(torch::Tensor &grad_inp, const torch::Tensor &input, const torch::Tensor &grad_out
) {auto gi_arg = torch::TensorArg(grad_inp, "grad_inp", 0),in_arg = torch::TensorArg(input,    "input",    1),go_arg = torch::TensorArg(grad_out, "grad_out", 2);torch::checkAllDefined("swish_backward_cuda", {gi_arg, in_arg, go_arg});torch::checkAllSameGPU("swish_backward_cuda", {gi_arg, in_arg, go_arg});AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_inp.scalar_type(), "swish_backward_cuda", [&] {kernel::swish_backward<scalar_t>(grad_inp, input, grad_out);});
}

cg

  • ScatWave是使用CUDA散射的Torch实现,主要使用lua语言https://github.com/edouardoyallon/scatwave

  • https://github.com/huangtinglin/PyTorch-extension-Convolution

  • This is a tutorial to explore how to customize operations in PyTorch.

  • https://pytorch.org/tutorials/advanced/cpp_extension.html

  • 台湾博主 Pytorch+cpp/cuda extension 教學 tutorial 1 - English CC - B站搬运地址

  • pytorch的C++ extension写法

  • https://github.com/salinaaaaaa/NVIDIA-GPU-Tensor-Core-Accelerator-PyTorch-OpenCV

  • https://github.com/MariyaSha/Inference_withTorchTensorRT

  • 项目介绍了简单的CUDA入门,涉及到CUDA执行模型、线程层次、CUDA内存模型、核函数的编写方式以及PyTorch使用CUDA扩展的两种方式。通过该项目可以基本入门基于PyTorch的CUDA扩展的开发方式。

RWKV CUDA

  • 实例:手写 CUDA 算子,让 Pytorch 提速 20 倍(某特殊算子) https://zhuanlan.zhihu.com/p/476297195
  • https://github.com/BlinkDL/RWKV-CUDA
  • The CUDA version of the RWKV language model

数据加速

  • 用于在 Pytorch 中更快地固定 CPU <-> GPU 传输的库

环境

  • Docker images and github actions for building packages containing PyTorch C++/CUDA extensions.
    一个构建系统,用于生成(相对)轻量级和便携式的 PyPI 轮子,其中包含 PyTorch C++/CUDA 扩展。使用Torch Extension Builder构建的轮子动态链接到用户PyTorch安装中包含的Torch和CUDA库。最终用户计算机上不需要安装 CUDA。

CG

  • 又发现一个部署工具
研究人员很难将机器学习模型交付到生产环境。解决方案的一部分是Docker,但要让它工作非常复杂:Dockerfiles,预/后处理,Flask服务器,CUDA版本。通常情况下,研究人员必须与工程师坐下来部署该死的东西。安德烈亚斯和本创造了Cog。Andreas曾经在Spotify工作,在那里他构建了使用Docker构建和部署ML模型的工具。Ben 曾在 Docker 工作,在那里他创建了 Docker Compose。我们意识到,除了Spotify之外,其他公司也在使用Docker来构建和部署机器学习模型。Uber和其他公司也建立了类似的系统。因此,我们正在制作一个开源版本,以便其他人也可以这样做。如果您有兴趣使用它或想与我们合作,请与我们联系。我们在 Discord 上或给我们发电子邮件 team@replicate.com.

相关文章:

DLA :pytorch添加算子

pytorch的C extension写法 这部分主要介绍如何在pytorch中添加自定义的算子&#xff0c;需要以下cuda基础。就总体的逻辑来说正向传播需要输入数据&#xff0c;反向传播需要输入数据和上一层的梯度&#xff0c;然后分别实现这两个kernel,将这两个kernerl绑定到pytorch即可。 a…...

Java特殊时间格式转化

平常开发过程当中&#xff0c;我们可能会见到有的日期格式是这样的。 1、2022-12-21T12:20:1608:00 2、2022-12-21T12:20:16.0000800 3、2022-12-21T12:20:16.00008:00下面来说一下这种时间格式怎么转换 第一种&#xff1a;2022-12-21T12:20:1608:00 代码如下&#xff1a; p…...

在CSDN学Golang云原生(Kubernetes声明式资源管理Kustomize)

一&#xff0c;生成资源 在 Kubernetes 中&#xff0c;我们可以通过 YAML 或 JSON 文件来定义和创建各种资源对象&#xff0c;例如 Pod、Service、Deployment 等。下面是一个简单的 YAML 文件示例&#xff0c;用于创建一个 Nginx Pod&#xff1a; apiVersion: v1 kind: Pod m…...

后台管理系统中常见的三栏布局总结:使用element ui构建

vue2 使用 el-menu构建的列表布局&#xff1a; 列表可以折叠展开 <template><div class"home"><header><el-button type"primary" click"handleClick">切换</el-button></header><div class"conte…...

SpringCloud学习路线(10)——分布式搜索ElasticSeach基础

一、初识ES &#xff08;一&#xff09;概念&#xff1a; ES是一款开源搜索引擎&#xff0c;结合数据可视化【Kibana】、数据抓取【Logstash、Beats】共同集成为ELK&#xff08;Elastic Stack&#xff09;&#xff0c;ELK被广泛应用于日志数据分析和实时监控等领域&#xff0…...

CSS翻转DIV展示顺序

项目国际化开发中&#xff0c;阿拉伯语是从右往左读的&#xff0c;在做样式兼容时&#xff0c;一些表单代码块也需要 label在右&#xff0c;表单在左。如果整个项目改div的话代价太大了&#xff0c;所以需要做样式翻转。 html <div class"container"><div …...

python 源码中 PyId_stdout 如何定义的

python 源代码中遇到一个变量名 PyId_stdout&#xff0c;搜不到在哪里定义的&#xff0c;如下只能搜到引用的位置&#xff08;python3.8.10&#xff09;&#xff1a; 找了半天发现是用宏来构造的声明语句&#xff1a; // filepath: Include/cpython/object.h typedef struct …...

Mybatis映射关系mybatis核心配置文件

目录 1.Mybatis映射关系 1.1一对一映射之resultType 1.2resultMap处理映射关系 2.mybatis核心配置文件 1. properties&#xff08;属性&#xff09; 2. settings&#xff08;设置&#xff09; 3.typeAliases&#xff08;类型别名&#xff09; 4.environments&#xff0…...

Mybatis中limit用法与分页查询

错误示范 错误示范一&#xff1a; <select id"fileInspectionList" resultType"map">SELECT <include refid"aip_n_static_cols"/>FROM sys_inspection_form WHERE<if test" type admin.toString() ">dept_id …...

libcomposite: Unknown symbol config_group_init (err 0)

加载libcomposite.ko 失败 问题描述 如图&#xff0c;在做USB OTG 设备模式的时候需要用到libcomposite.ko驱动&#xff0c;加载失败了。 原因&解决方法 有一个依赖叫configfs.ko的驱动没有安装。可以从内核代码的fs/configfs/configfs.ko中找到这个驱动。先加载confi…...

Spring Tool Suite 4

参考&#xff1a;Spring tool suite4 安装及配置_springtoolsuite4_猿界零零七的博客-CSDN博客 下载&#xff1a;Spring | Tools 将下载的JAR进行解压两次&#xff0c;直至解压出contents中的sts 双击启动 第一次打开需要指定工作区文件夹 配置Maven的config 安装插件...

带你读论文第三期:微软研究员、北大博士陈琪,荣获NeurIPS杰出论文奖

Datawhale干货 来源&#xff1a;WhalePaper&#xff0c;负责人&#xff1a;芙蕖 WhalePaper简介 由Datawhale团队成员发起&#xff0c;对目前学术论文中比较成熟的 Topic 和开源方案进行分享&#xff0c;通过一起阅读、分享论文学习的方式帮助大家更好地“高效全面自律”学习&…...

农业中的计算机视觉 2023

物体检测应用于检测田间收割机和果园苹果 一、说明 欢迎来到Voxel51的计算机视觉行业聚焦博客系列的第一期。每个月&#xff0c;我们都将重点介绍不同行业&#xff08;从建筑到气候技术&#xff0c;从零售到机器人等&#xff09;如何使用计算机视觉、机器学习和人工智能来推动…...

掌握三个基础平面构成法则 优漫动游

1.图形重复&#xff1a;通过重复使用同一种或类似的图形元素,创造出一种有节奏、有重复感的视觉效果。这种设计手法可以使海报看起来更加统一和协调,增强视觉冲击力。 掌握三个基础平面构成法则 2.字体重复&#xff1a;通过重复使用同一种或类似的字体元素,创造出一种有序…...

叶工好容5-日志与监控

目录 前言 平台维度 docker运行状态 cAdvisor-日志采集者 Heapster-日志收集 metrics-server-出生决定成败 kube-state-metrics-不完美中的完美 应用维度 日志 部署方式 输出方式 工具选择 日志接入 监控 serviceMonitor Annotation Prometheus扩展性 Thanos …...

Dubbo 指定调用固定ip+port dubbo调用指定服务 dubbo调用不随机 dubbo自定义调用服务 dubbo点对点通信 dubbo指定ip

1. 在写分布式im时nami-im: 分布式im, 集群 zookeeper netty kafka nacos rpc主要为gate&#xff08;长连接服务&#xff09; logic &#xff08;业务&#xff09; lsb &#xff08;负载均衡&#xff09;store&#xff08;存储&#xff09; - Gitee.com&#xff0c;需要指定某一…...

BCNet论文精读

Title—标题 Boundary Constraint Network&#xff08;边界约束网络&#xff09; With Cross Layer Feature Integration&#xff08;跨层特征融合&#xff09; for Polyp Segmentation&#xff08;息肉分割&#xff09; 结构分析 标题结构由三部分组成&#xff0c;分别是本文…...

PHP8的注释-PHP8知识详解

欢迎你来到PHP服务网&#xff0c;学习《PHP8知识详解》系列教程&#xff0c;本文学习的是《PHP8的注释》。 什么是注释&#xff1f; 注释是在程序代码中添加的文本&#xff0c;用于解释和说明代码的功能、逻辑或其他相关信息。注释通常不会被编译器或解释器处理&#xff0c;而…...

优化企业集成架构:iPaaS集成平台助力数字化转型

前言 在数字化时代全面来临之际&#xff0c;企业正面临着前所未有的挑战与机遇。技术的迅猛发展与数字化转型正在彻底颠覆各行各业的格局&#xff0c;不断推动着企业迈向新的前程。然而&#xff0c;这一数字化时代亦衍生出一系列复杂而深奥的难题&#xff1a;各异系统之间数据…...

前端存储之sessionStorage和localStorage

sessionStorage sessionStorage是一种用于web浏览器中临时保存数据的客户端存储机制。它允许在同一个浏览器窗口的会话期间&#xff0c;保存和访问临时数据&#xff0c;而这些数据在用户关闭窗口或者标签页会被清除。每个sessionStorage对象都与当前的浏览器会话相关联&#x…...

上海亚商投顾:沪指放量大涨1.84% 证券股掀涨停潮

上海亚商投顾前言&#xff1a;无惧大盘涨跌&#xff0c;解密龙虎榜资金&#xff0c;跟踪一线游资和机构资金动向&#xff0c;识别短期热点和强势个股。 市场情绪 三大指数今日低开高走&#xff0c;沪指午后放量涨近2%&#xff0c;上证50盘中大涨超3%。大金融板块全线爆发&#…...

微服务划分的原则

微服务的划分 微服务的划分要保证的原则 单一职责原则 1、耦合性也称块间联系。指软件系统结构中各模块间相互联系紧密程度的一种度量。模块之间联系越紧密&#xff0c;其耦合性就越强&#xff0c;模块的独立性则越差。模块间耦合高低取决于模块间接口的复杂性、调用的方式及…...

作业 - 3

[ 作业 - 3 ] Industrial Melanism: The Case of the Peppered Moth melanism n. 黑化&#xff1b;黑变病&#xff1b;黑色素沉着症 peppered adj. 用胡椒调味的&#xff1b;加胡椒的&#xff0c;撒胡椒粉的 pepper的过去分词和过去式 moth n. 蛾;飞蛾 Paragraph 2 Over a …...

MTK联发科安卓核心板MT8385(Genio 500)规格参数资料_性能介绍

简介 MT8385安卓核心板 是一个高度集成且功能强大的物联网平台&#xff0c;具有以下主要特性&#xff1a; l 四核 Arm Cortex-A73 处理器 l 四核Arm Cortex-A53处理器 l Arm Mali™-G72 MP3 3D 图形加速器 (GPU)&#xff0c;带有 Vulkan 1.0、OpenGL ES 3.2 和 OpenCL™ 2.x …...

ChatGPT付费创作系统小程序端开发工具提示打开显示无法打开页面解决办法

很多会员在上传小程序前端时经常出现首页无法打开的情况&#xff0c;错误提示无法打开该页面&#xff0c;不支持打开&#xff0c;这种问题其实就是权限问题&#xff0c;页面是通过调用web-view访问&#xff0c;说明业务域名有问题&#xff0c;很多都是合法域名加了&#xff0c;…...

CVPR2023新作:pix2pix3D

Title: 3D-Aware Conditional Image SynthesisAffiliation: Carnegie Mellon University (卡内基梅隆大学)Authors: Kangle Deng, Gengshan Yang, Deva Ramanan, Jun-Yan ZhuKeywords: Image Synthesis, 3D-aware, Neural Radiance Fields, Interactive Editing, Conditional G…...

Django自定义用户错误记录

django.db.migrations.exceptions.InconsistentMigrationHistory: Migration admin.0001_initial is applied before its dependency mysit.0001_initial on database default.执行&#xff1a; 1 setttings.py: 先注释掉 django.contrib.admin 2 注释掉urls.py path(“admin/…...

Abaqus 导出单元刚度矩阵和全局刚度矩阵

Abaqus 导出单元刚度矩阵和全局刚度矩阵 首次创建&#xff1a;2023.7.29 最后更新&#xff1a;2023.7.29 如有什么改进的地方&#xff0c;欢迎大家讨论&#xff01; 详细情况请查阅&#xff1a;Abaqus Analysis User’s Guide 一、Abaqus 导出单元刚度矩阵 1.生成单元刚度矩阵…...

Pytorch(一)

目录 一、基本操作 二、自动求导机制 三、线性回归DEMO 3.1模型的读取与保存 3.2利用GPU训练时 四、常见的Tensor形式 五、Hub模块 一、基本操作 操作代码如下: import torch import numpy as np#创建一个矩阵 x1 torch.empty(5,3)# 随机值 x2 torch.rand(5,3)# 初始化…...

图数据库Neo4j学习三——cypher语法总结

1MATCH 1.1作用 MATCH是Cypher查询语言中用于从图数据库中检索数据的关键字。它的作用是在图中查找满足指定条件的节点和边&#xff0c;并返回这些节点和边的属性信息。 在MATCH语句中&#xff0c;通过节点标签和边类型来限定查找范围&#xff0c;然后通过WHERE语句来筛选符合…...