【Triton 教程】低内存 Dropout
Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。
更多 Triton 中文文档可访问 →https://triton.hyper.ai/
在本教程中,您将编写一个内存高效的 Dropout 实现,其状态将由单个 int32 seed 组成。这与传统 Dropout 实现不同,传统实现通常由与输入 shape 相同的位掩码张量组成。
在这过程中,您将学习到以下内容:
-
PyTorch 中 原生实现 Dropout 的局限性。
-
Triton 中的并行伪随机数生成。
简介
Dropout 是在 [SRIVASTAVA2014] 中引入的一种技术,用于改善低数据条件下深度神经网络的性能,通常用于正则化。它接受一个向量作为输入,并生成相同 shape 的输出向量。输出中的每个标量都有概率 p 被设为零,否则直接从输入复制。这使得网络在仅有输入的 1−p 标量时也能表现良好。
在评估阶段,为了充分利用网络的能力,将 p 设为 0。但是简单地将 p 设为 0 会增加输出的范数,可能会人为地降低输出的 softmax temperature。为了防止这种情况发生,输出被缩放为 1/(1-p),这使得无论 dropout 概率如何都能保持一致的范数。
Baseline
首先看一下 baseline 的实现。
import tabulate
import torchimport triton
import triton.language as tl@triton.jit
def _dropout(x_ptr, # 输入指针x_keep_ptr, # pointer to a mask of 0s and 1s 由 0 和 1 组成的掩码的指针output_ptr, # pointer to the output 输出指针n_elements, # number of elements in the `x` tensor `x` 张量的元素数量p, # probability that an element of `x` is changed to zero 元素 `x` 被设置为 0 的概率BLOCK_SIZE: tl.constexpr,
):pid = tl.program_id(axis=0)block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)mask = offsets < n_elements# Load data# 加载数据x = tl.load(x_ptr + offsets, mask=mask)x_keep = tl.load(x_keep_ptr + offsets, mask=mask)# The line below is the crucial part, described in the paragraph above!# 下一行是上段描述的关键部分output = tl.where(x_keep, x / (1 - p), 0.0)# Write-back output# 写回输出tl.store(output_ptr + offsets, output, mask=mask)def dropout(x, x_keep, p):output = torch.empty_like(x)assert x.is_contiguous()n_elements = x.numel()grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)return output# Input tensor
# 输入张量
x = torch.randn(size=(10, )).cuda()
# Dropout mask
# Dropout 掩码
p = 0.5
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([["input"] + x.tolist(),["keep mask"] + x_keep.tolist(),["output"] + output.tolist(),
]))
Out:

种子化 Dropout
上述 Dropout 实现效果良好,但管理 Dropout 状态可能会变得复杂,特别是在考虑反向传播和重新计算/检查点场景时。在这里,我们描述一种替代实现,它具有以下优点:
- 更小的内存占用。
- 较少的数据移动。
- 简化了在多次调用内核函数时持久化随机性的管理。
生成 Triton 中的伪随机数很简单!在本教程中,我们将使用 triton.language.rand 函数,该函数基于给定的种子和一组 int32 偏移量生成一个块的均匀分布的 float32 值,范围在 (0, 1) 内。但如果你需要,Triton 也提供其他随机数生成策略。
注意 Triton 的 PRNG 实现基于 Philox 算法(详见 [SALMON2011])。
现在将所有内容整合起来。
@triton.jit
def _seeded_dropout(x_ptr,output_ptr,n_elements,p,seed,BLOCK_SIZE: tl.constexpr,
):# compute memory offsets of elements handled by this instance# 计算由此实例处理的元素的内存偏移量pid = tl.program_id(axis=0)block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)# load data from x# 从 x 读取数据mask = offsets < n_elementsx = tl.load(x_ptr + offsets, mask=mask)# randomly prune it# 随机修剪random = tl.rand(seed, offsets)x_keep = random > p# write-back# 写回output = tl.where(x_keep, x / (1 - p), 0.0)tl.store(output_ptr + offsets, output, mask=mask)def seeded_dropout(x, p, seed):output = torch.empty_like(x)assert x.is_contiguous()n_elements = x.numel()grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)return outputx = torch.randn(size=(10, )).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
# 与基线相比 - dropout 掩码从未被实例化!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)print(tabulate.tabulate([["input"] + x.tolist(),["output (seed = 123)"] + output.tolist(),["output (seed = 123)"] + output2.tolist(),["output (seed = 512)"] + output3.tolist(),
]))
Out:

大功告成!我们现在有了一个 Triton 内核,可以在给定相同种子的情况下应用一致的 dropout 掩码。与传统的 dropout 实现相比,这种方法减少了内存开销并简化了状态管理。
练习
- 扩展内核以处理矩阵,并使用一个种子向量 — 每行一个种子。
- 添加对 striding 的支持。
- (挑战)实现稀疏 Johnson-Lindenstrauss 变换的内核,每次使用种子动态生成投影矩阵。
参考文献
-
[SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011
-
[SRIVASTAVA2014] Nitish Srivastava et al., “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014
Download Jupyter notebook: 04-low-memory-dropout.ipynb
Download Python source code: 04-low-memory-dropout.py
Download zipped: 04-low-memory-dropout.zip
相关文章:
【Triton 教程】低内存 Dropout
Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。 更多 Triton 中文文档可访问 →https://triton.hyper.ai/ 在本教程中,您将编…...
npx创建项目时,error fetch failed.TypeError: fetch failed
npx创建项目时,报以下错误: error fetch failed. TypeError: fetch failedat node:internal/deps/undici/undici:12345:11at process.processTicksAndRejections (node:internal/process/task_queues:95:5)at async getTemplateVersion (C:\Users\ymt30…...
《Kotlin实战》-附录
附录 本部分内容只是简单列举下Kotlin应用以便指引进一步深入学习Kotlin。 附录A:构建Kotlin项目 本节只会记录下gradle的应用,其他需要时请自行搜索查看。 A.1 用Gradle构建Kotlin代码的项目 构建Kotlin项目的标准Gradle脚本如下: bui…...
yelp数据集上识别潜在的热门商家
yelp数据集是研究B2C业态的一个很好的数据集,要识别潜在的热门商家是一个多维度的分析过程,涉及用户行为、商家特征和社区结构等多个因素。从yelp数据集里我们可以挖掘到下面信息有助于识别热门商家 用户评分和评论分析 评分均值: 商家的平均评分是反映其…...
【Linux】进程信号全攻略(一)
🌈 个人主页:Zfox_ 🔥 系列专栏:Linux 目录 一:🔥 信号的概念 二:🔥 信号产生的方式 🦋 使用键盘🦋 系统调用函数🦋 软件条件🦋 进程异…...
linux文件重命名
Linux文件重命名 文件名显示异常问题出在哪里批量改名扩展 文件名显示异常 跑测CTS,linux环境看跑测结果log file显示没问题,倘若windows下看log file名却显示异常,不太方便操作。 问题出在哪里 linux环境下文件名可以显示正常࿰…...
如何选择适合的AWS EC2实例类型
在云计算的世界中,Amazon Web Services(AWS)提供了丰富的服务,其中Elastic Compute Cloud(EC2)是最受欢迎的服务之一。选择合适的EC2实例类型对于确保应用程序的性能和成本效益至关重要。我们九河云通过本文…...
【Uniapp】Uniapp Android原生插件开发指北
前言 在uniapp开发中当HBuilderX中提供的能力无法满足App功能需求,需要通过使用Andorid/iOS原生开发实现时,或者是第三方公司提供的是Android的库,这时候可使用App离线SDK开发原生插件来扩展原生能力。 插件类型有两种,Module模…...
【随手笔记】FLASH-W25Q16(三)
#include "bsp_w25q16.h"/*内部函数声明区*/ static HAL_StatusTypeDef bsp_w25q_Transmit(uint8_t * T_pData, uint16_t T_Size); static HAL_StatusTypeDef bsp_w25q_Receive(uint8_t * R_pData, uint16_t R_Size);/*内部函数定义区*//* 函数参数:1、T_…...
2024软件测试面试热点问题
🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 大厂面试热点问题 1、测试人员需要何时参加需求分析? 如果条件循序 原则上来说 是越早介入需求分析越好 因为测试人员对需求理解越深刻 对测试工…...
【JAVA】java 企业微信信息推送
前言 JAVA中 将信息 推送到企业微信 // 企微消息推送messageprivate String getMessage(String name, String problemType, String pushResults, Long orderId,java.util.Date submitTime, java.util.Date payTime) {String message "对接方:<font color\…...
介绍一下数组(c基础)(smart 版)
c初期,记住规则,用规则。 我只是介绍规则。(有详细版,这适合smart人看) 数组(同类型) int arr[n] {} ; int 是 元素类型。 int arr[n] {} ; arr为标识符。 {} 集合,元素有次…...
Java项目实战II基于Spring Boot的个人云盘管理系统设计与实现(开发文档+数据库+源码)
目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。 一、前言 基于Spring Boot的个人云盘管理系统设计…...
探索数据科学与大数据技术专业本科生的广阔就业前景
随着信息技术的不断发展,数据科学与大数据技术已经成为各大行业的关键推动力。在这样一个数据驱动的时代,越来越多的企业依赖数据来驱动决策、优化运营和创造价值。因此,数据科学与大数据技术专业的本科生在就业市场上具有广阔的前景和多样的…...
微服务架构面试内容整理-Zuul
Zuul 是由 Netflix 开发的一个边缘服务(API 网关),用于动态路由、监控、认证、以及对微服务架构中的请求进行过滤。它在微服务架构中扮演着重要的角色,提供了一种集中管理和控制服务访问的方式。以下是 Zuul 的主要特点、工作原理和使用场景: 主要特点 1. 动态路由: Zuu…...
解决Knife4j 接口界面UI中文乱码问题
1、查看乱码情况 2、修改 编码设置 3、删除 target 文件 项目重新启动 被坑死了...
微服务架构面试内容整理-Sleuth
Spring Cloud Sleuth 是一个分布式追踪工具,用于监控微服务系统中请求的传播情况。它通过在微服务之间传递追踪信息,帮助开发者理解系统的行为,快速定位性能瓶颈和问题。以下是 Sleuth 的主要特点、工作原理和使用场景: 主要特点 …...
Go语言的接口示例
Go语言的接口(interface)是一种轻量级的多态性实现方式,是构建高扩展性、高复用性代码的利器。Go语言的接口非常灵活,不要求显式的实现声明,只要一个类型实现了接口规定的方法,它就可以被视为该接口的实现者。在本篇博客中,我们将通过多个实际示例,探讨Go语言接口的使用…...
【Apache ECharts】<农作物病害发生防治面积>
在vs Code里打开, 实现 1. 首先引入 echarts.min.js 资源 2. 在body部分设一个 div,设置 id 为 main 3. 设置 script 3.1 基于准备好的dom,初始化echarts实例 var myChart echarts.init(document.getElementById(main)); 3.2 指定图表的…...
基于vue3实现的聊天机器人前端(附代码)
<template><div class"container"><!-- 页面头部 --><header><h1>跟它说说话吧!</h1><p>一个活泼的伙伴,为你提供情感支持!</p></header><!-- 聊天容器 --><div c…...
以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:
一、属性动画概述NETX 作用:实现组件通用属性的渐变过渡效果,提升用户体验。支持属性:width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项: 布局类属性(如宽高)变化时&#…...
Oracle查询表空间大小
1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...
Qt Widget类解析与代码注释
#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...
C++ 基础特性深度解析
目录 引言 一、命名空间(namespace) C 中的命名空间 与 C 语言的对比 二、缺省参数 C 中的缺省参数 与 C 语言的对比 三、引用(reference) C 中的引用 与 C 语言的对比 四、inline(内联函数…...
现代密码学 | 椭圆曲线密码学—附py代码
Elliptic Curve Cryptography 椭圆曲线密码学(ECC)是一种基于有限域上椭圆曲线数学特性的公钥加密技术。其核心原理涉及椭圆曲线的代数性质、离散对数问题以及有限域上的运算。 椭圆曲线密码学是多种数字签名算法的基础,例如椭圆曲线数字签…...
浅谈不同二分算法的查找情况
二分算法原理比较简单,但是实际的算法模板却有很多,这一切都源于二分查找问题中的复杂情况和二分算法的边界处理,以下是博主对一些二分算法查找的情况分析。 需要说明的是,以下二分算法都是基于有序序列为升序有序的情况…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
Linux 内存管理实战精讲:核心原理与面试常考点全解析
Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...
解析两阶段提交与三阶段提交的核心差异及MySQL实现方案
引言 在分布式系统的事务处理中,如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议(2PC)通过准备阶段与提交阶段的协调机制,以同步决策模式确保事务原子性。其改进版本三阶段提交协议(3PC…...
