【Triton 教程】低内存 Dropout
Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。
更多 Triton 中文文档可访问 →https://triton.hyper.ai/
在本教程中,您将编写一个内存高效的 Dropout 实现,其状态将由单个 int32 seed 组成。这与传统 Dropout 实现不同,传统实现通常由与输入 shape 相同的位掩码张量组成。
在这过程中,您将学习到以下内容:
-
PyTorch 中 原生实现 Dropout 的局限性。
-
Triton 中的并行伪随机数生成。
简介
Dropout 是在 [SRIVASTAVA2014] 中引入的一种技术,用于改善低数据条件下深度神经网络的性能,通常用于正则化。它接受一个向量作为输入,并生成相同 shape 的输出向量。输出中的每个标量都有概率 p 被设为零,否则直接从输入复制。这使得网络在仅有输入的 1−p 标量时也能表现良好。
在评估阶段,为了充分利用网络的能力,将 p 设为 0。但是简单地将 p 设为 0 会增加输出的范数,可能会人为地降低输出的 softmax temperature。为了防止这种情况发生,输出被缩放为 1/(1-p),这使得无论 dropout 概率如何都能保持一致的范数。
Baseline
首先看一下 baseline 的实现。
import tabulate
import torchimport triton
import triton.language as tl@triton.jit
def _dropout(x_ptr, # 输入指针x_keep_ptr, # pointer to a mask of 0s and 1s 由 0 和 1 组成的掩码的指针output_ptr, # pointer to the output 输出指针n_elements, # number of elements in the `x` tensor `x` 张量的元素数量p, # probability that an element of `x` is changed to zero 元素 `x` 被设置为 0 的概率BLOCK_SIZE: tl.constexpr,
):pid = tl.program_id(axis=0)block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)mask = offsets < n_elements# Load data# 加载数据x = tl.load(x_ptr + offsets, mask=mask)x_keep = tl.load(x_keep_ptr + offsets, mask=mask)# The line below is the crucial part, described in the paragraph above!# 下一行是上段描述的关键部分output = tl.where(x_keep, x / (1 - p), 0.0)# Write-back output# 写回输出tl.store(output_ptr + offsets, output, mask=mask)def dropout(x, x_keep, p):output = torch.empty_like(x)assert x.is_contiguous()n_elements = x.numel()grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)return output# Input tensor
# 输入张量
x = torch.randn(size=(10, )).cuda()
# Dropout mask
# Dropout 掩码
p = 0.5
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([["input"] + x.tolist(),["keep mask"] + x_keep.tolist(),["output"] + output.tolist(),
]))
Out:
种子化 Dropout
上述 Dropout 实现效果良好,但管理 Dropout 状态可能会变得复杂,特别是在考虑反向传播和重新计算/检查点场景时。在这里,我们描述一种替代实现,它具有以下优点:
- 更小的内存占用。
- 较少的数据移动。
- 简化了在多次调用内核函数时持久化随机性的管理。
生成 Triton 中的伪随机数很简单!在本教程中,我们将使用 triton.language.rand
函数,该函数基于给定的种子和一组 int32
偏移量生成一个块的均匀分布的 float32
值,范围在 (0, 1) 内。但如果你需要,Triton 也提供其他随机数生成策略。
注意 Triton 的 PRNG 实现基于 Philox 算法(详见 [SALMON2011])。
现在将所有内容整合起来。
@triton.jit
def _seeded_dropout(x_ptr,output_ptr,n_elements,p,seed,BLOCK_SIZE: tl.constexpr,
):# compute memory offsets of elements handled by this instance# 计算由此实例处理的元素的内存偏移量pid = tl.program_id(axis=0)block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)# load data from x# 从 x 读取数据mask = offsets < n_elementsx = tl.load(x_ptr + offsets, mask=mask)# randomly prune it# 随机修剪random = tl.rand(seed, offsets)x_keep = random > p# write-back# 写回output = tl.where(x_keep, x / (1 - p), 0.0)tl.store(output_ptr + offsets, output, mask=mask)def seeded_dropout(x, p, seed):output = torch.empty_like(x)assert x.is_contiguous()n_elements = x.numel()grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)return outputx = torch.randn(size=(10, )).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
# 与基线相比 - dropout 掩码从未被实例化!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)print(tabulate.tabulate([["input"] + x.tolist(),["output (seed = 123)"] + output.tolist(),["output (seed = 123)"] + output2.tolist(),["output (seed = 512)"] + output3.tolist(),
]))
Out:
大功告成!我们现在有了一个 Triton 内核,可以在给定相同种子的情况下应用一致的 dropout 掩码。与传统的 dropout 实现相比,这种方法减少了内存开销并简化了状态管理。
练习
- 扩展内核以处理矩阵,并使用一个种子向量 — 每行一个种子。
- 添加对 striding 的支持。
- (挑战)实现稀疏 Johnson-Lindenstrauss 变换的内核,每次使用种子动态生成投影矩阵。
参考文献
-
[SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011
-
[SRIVASTAVA2014] Nitish Srivastava et al., “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014
Download Jupyter notebook: 04-low-memory-dropout.ipynb
Download Python source code: 04-low-memory-dropout.py
Download zipped: 04-low-memory-dropout.zip
相关文章:

【Triton 教程】低内存 Dropout
Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。 更多 Triton 中文文档可访问 →https://triton.hyper.ai/ 在本教程中,您将编…...
npx创建项目时,error fetch failed.TypeError: fetch failed
npx创建项目时,报以下错误: error fetch failed. TypeError: fetch failedat node:internal/deps/undici/undici:12345:11at process.processTicksAndRejections (node:internal/process/task_queues:95:5)at async getTemplateVersion (C:\Users\ymt30…...
《Kotlin实战》-附录
附录 本部分内容只是简单列举下Kotlin应用以便指引进一步深入学习Kotlin。 附录A:构建Kotlin项目 本节只会记录下gradle的应用,其他需要时请自行搜索查看。 A.1 用Gradle构建Kotlin代码的项目 构建Kotlin项目的标准Gradle脚本如下: bui…...

yelp数据集上识别潜在的热门商家
yelp数据集是研究B2C业态的一个很好的数据集,要识别潜在的热门商家是一个多维度的分析过程,涉及用户行为、商家特征和社区结构等多个因素。从yelp数据集里我们可以挖掘到下面信息有助于识别热门商家 用户评分和评论分析 评分均值: 商家的平均评分是反映其…...

【Linux】进程信号全攻略(一)
🌈 个人主页:Zfox_ 🔥 系列专栏:Linux 目录 一:🔥 信号的概念 二:🔥 信号产生的方式 🦋 使用键盘🦋 系统调用函数🦋 软件条件🦋 进程异…...

linux文件重命名
Linux文件重命名 文件名显示异常问题出在哪里批量改名扩展 文件名显示异常 跑测CTS,linux环境看跑测结果log file显示没问题,倘若windows下看log file名却显示异常,不太方便操作。 问题出在哪里 linux环境下文件名可以显示正常࿰…...

如何选择适合的AWS EC2实例类型
在云计算的世界中,Amazon Web Services(AWS)提供了丰富的服务,其中Elastic Compute Cloud(EC2)是最受欢迎的服务之一。选择合适的EC2实例类型对于确保应用程序的性能和成本效益至关重要。我们九河云通过本文…...

【Uniapp】Uniapp Android原生插件开发指北
前言 在uniapp开发中当HBuilderX中提供的能力无法满足App功能需求,需要通过使用Andorid/iOS原生开发实现时,或者是第三方公司提供的是Android的库,这时候可使用App离线SDK开发原生插件来扩展原生能力。 插件类型有两种,Module模…...
【随手笔记】FLASH-W25Q16(三)
#include "bsp_w25q16.h"/*内部函数声明区*/ static HAL_StatusTypeDef bsp_w25q_Transmit(uint8_t * T_pData, uint16_t T_Size); static HAL_StatusTypeDef bsp_w25q_Receive(uint8_t * R_pData, uint16_t R_Size);/*内部函数定义区*//* 函数参数:1、T_…...

2024软件测试面试热点问题
🍅 点击文末小卡片 ,免费获取软件测试全套资料,资料在手,涨薪更快 大厂面试热点问题 1、测试人员需要何时参加需求分析? 如果条件循序 原则上来说 是越早介入需求分析越好 因为测试人员对需求理解越深刻 对测试工…...

【JAVA】java 企业微信信息推送
前言 JAVA中 将信息 推送到企业微信 // 企微消息推送messageprivate String getMessage(String name, String problemType, String pushResults, Long orderId,java.util.Date submitTime, java.util.Date payTime) {String message "对接方:<font color\…...

介绍一下数组(c基础)(smart 版)
c初期,记住规则,用规则。 我只是介绍规则。(有详细版,这适合smart人看) 数组(同类型) int arr[n] {} ; int 是 元素类型。 int arr[n] {} ; arr为标识符。 {} 集合,元素有次…...

Java项目实战II基于Spring Boot的个人云盘管理系统设计与实现(开发文档+数据库+源码)
目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。 一、前言 基于Spring Boot的个人云盘管理系统设计…...
探索数据科学与大数据技术专业本科生的广阔就业前景
随着信息技术的不断发展,数据科学与大数据技术已经成为各大行业的关键推动力。在这样一个数据驱动的时代,越来越多的企业依赖数据来驱动决策、优化运营和创造价值。因此,数据科学与大数据技术专业的本科生在就业市场上具有广阔的前景和多样的…...
微服务架构面试内容整理-Zuul
Zuul 是由 Netflix 开发的一个边缘服务(API 网关),用于动态路由、监控、认证、以及对微服务架构中的请求进行过滤。它在微服务架构中扮演着重要的角色,提供了一种集中管理和控制服务访问的方式。以下是 Zuul 的主要特点、工作原理和使用场景: 主要特点 1. 动态路由: Zuu…...

解决Knife4j 接口界面UI中文乱码问题
1、查看乱码情况 2、修改 编码设置 3、删除 target 文件 项目重新启动 被坑死了...
微服务架构面试内容整理-Sleuth
Spring Cloud Sleuth 是一个分布式追踪工具,用于监控微服务系统中请求的传播情况。它通过在微服务之间传递追踪信息,帮助开发者理解系统的行为,快速定位性能瓶颈和问题。以下是 Sleuth 的主要特点、工作原理和使用场景: 主要特点 …...
Go语言的接口示例
Go语言的接口(interface)是一种轻量级的多态性实现方式,是构建高扩展性、高复用性代码的利器。Go语言的接口非常灵活,不要求显式的实现声明,只要一个类型实现了接口规定的方法,它就可以被视为该接口的实现者。在本篇博客中,我们将通过多个实际示例,探讨Go语言接口的使用…...

【Apache ECharts】<农作物病害发生防治面积>
在vs Code里打开, 实现 1. 首先引入 echarts.min.js 资源 2. 在body部分设一个 div,设置 id 为 main 3. 设置 script 3.1 基于准备好的dom,初始化echarts实例 var myChart echarts.init(document.getElementById(main)); 3.2 指定图表的…...

基于vue3实现的聊天机器人前端(附代码)
<template><div class"container"><!-- 页面头部 --><header><h1>跟它说说话吧!</h1><p>一个活泼的伙伴,为你提供情感支持!</p></header><!-- 聊天容器 --><div c…...
生成xcframework
打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

微信小程序之bind和catch
这两个呢,都是绑定事件用的,具体使用有些小区别。 官方文档: 事件冒泡处理不同 bind:绑定的事件会向上冒泡,即触发当前组件的事件后,还会继续触发父组件的相同事件。例如,有一个子视图绑定了b…...
PHP和Node.js哪个更爽?
先说结论,rust完胜。 php:laravel,swoole,webman,最开始在苏宁的时候写了几年php,当时觉得php真的是世界上最好的语言,因为当初活在舒适圈里,不愿意跳出来,就好比当初活在…...
uni-app学习笔记二十二---使用vite.config.js全局导入常用依赖
在前面的练习中,每个页面需要使用ref,onShow等生命周期钩子函数时都需要像下面这样导入 import {onMounted, ref} from "vue" 如果不想每个页面都导入,需要使用node.js命令npm安装unplugin-auto-import npm install unplugin-au…...
五年级数学知识边界总结思考-下册
目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解:由来、作用与意义**一、知识点核心内容****二、知识点的由来:从生活实践到数学抽象****三、知识的作用:解决实际问题的工具****四、学习的意义:培养核心素养…...

论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)
笔记整理:刘治强,浙江大学硕士生,研究方向为知识图谱表示学习,大语言模型 论文链接:http://arxiv.org/abs/2407.16127 发表会议:ISWC 2024 1. 动机 传统的知识图谱补全(KGC)模型通过…...
土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等
🔍 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术,可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势,还能有效评价重大生态工程…...
Spring AI与Spring Modulith核心技术解析
Spring AI核心架构解析 Spring AI(https://spring.io/projects/spring-ai)作为Spring生态中的AI集成框架,其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似,但特别为多语…...

Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”
2025年#高考 将在近日拉开帷幕,#AI 监考一度冲上热搜。当AI深度融入高考,#时间同步 不再是辅助功能,而是决定AI监考系统成败的“生命线”。 AI亮相2025高考,40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕,江西、…...