当前位置: 首页 > news >正文

【Triton 教程】低内存 Dropout

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。

更多 Triton 中文文档可访问 →https://triton.hyper.ai/

在本教程中,您将编写一个内存高效的 Dropout 实现,其状态将由单个 int32 seed 组成。这与传统 Dropout 实现不同,传统实现通常由与输入 shape 相同的位掩码张量组成。

在这过程中,您将学习到以下内容:

  • PyTorch 中 原生实现 Dropout 的局限性。

  • Triton 中的并行伪随机数生成。

简介

Dropout 是在 [SRIVASTAVA2014] 中引入的一种技术,用于改善低数据条件下深度神经网络的性能,通常用于正则化。它接受一个向量作为输入,并生成相同 shape 的输出向量。输出中的每个标量都有概率 p 被设为零,否则直接从输入复制。这使得网络在仅有输入的 1−p 标量时也能表现良好。

在评估阶段,为了充分利用网络的能力,将 p 设为 0。但是简单地将 p 设为 0 会增加输出的范数,可能会人为地降低输出的 softmax temperature。为了防止这种情况发生,输出被缩放为 1/(1-p),这使得无论 dropout 概率如何都能保持一致的范数。

Baseline

首先看一下 baseline 的实现。

import tabulate
import torchimport triton
import triton.language as tl@triton.jit
def _dropout(x_ptr,      # 输入指针x_keep_ptr, # pointer to a mask of 0s and 1s 由 01 组成的掩码的指针output_ptr, # pointer to the output 输出指针n_elements, # number of elements in the `x` tensor `x` 张量的元素数量p,          # probability that an element of `x` is changed to zero 元素 `x` 被设置为 0 的概率BLOCK_SIZE: tl.constexpr,
):pid = tl.program_id(axis=0)block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)mask = offsets < n_elements# Load data# 加载数据x = tl.load(x_ptr + offsets, mask=mask)x_keep = tl.load(x_keep_ptr + offsets, mask=mask)# The line below is the crucial part, described in the paragraph above!# 下一行是上段描述的关键部分output = tl.where(x_keep, x / (1 - p), 0.0)# Write-back output# 写回输出tl.store(output_ptr + offsets, output, mask=mask)def dropout(x, x_keep, p):output = torch.empty_like(x)assert x.is_contiguous()n_elements = x.numel()grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)return output# Input tensor
# 输入张量
x = torch.randn(size=(10, )).cuda()
# Dropout mask
# Dropout 掩码
p = 0.5
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([["input"] + x.tolist(),["keep mask"] + x_keep.tolist(),["output"] + output.tolist(),
]))

Out:
在这里插入图片描述

种子化 Dropout

上述 Dropout 实现效果良好,但管理 Dropout 状态可能会变得复杂,特别是在考虑反向传播和重新计算/检查点场景时。在这里,我们描述一种替代实现,它具有以下优点:

  1. 更小的内存占用。
  2. 较少的数据移动。
  3. 简化了在多次调用内核函数时持久化随机性的管理。

生成 Triton 中的伪随机数很简单!在本教程中,我们将使用 triton.language.rand 函数,该函数基于给定的种子和一组 int32 偏移量生成一个块的均匀分布的 float32 值,范围在 (0, 1) 内。但如果你需要,Triton 也提供其他随机数生成策略。

注意 Triton 的 PRNG 实现基于 Philox 算法(详见 [SALMON2011])。

现在将所有内容整合起来。

@triton.jit
def _seeded_dropout(x_ptr,output_ptr,n_elements,p,seed,BLOCK_SIZE: tl.constexpr,
):# compute memory offsets of elements handled by this instance# 计算由此实例处理的元素的内存偏移量pid = tl.program_id(axis=0)block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)# load data from x# 从 x 读取数据mask = offsets < n_elementsx = tl.load(x_ptr + offsets, mask=mask)# randomly prune it# 随机修剪random = tl.rand(seed, offsets)x_keep = random > p# write-back# 写回output = tl.where(x_keep, x / (1 - p), 0.0)tl.store(output_ptr + offsets, output, mask=mask)def seeded_dropout(x, p, seed):output = torch.empty_like(x)assert x.is_contiguous()n_elements = x.numel()grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)return outputx = torch.randn(size=(10, )).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
# 与基线相比 - dropout 掩码从未被实例化!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)print(tabulate.tabulate([["input"] + x.tolist(),["output (seed = 123)"] + output.tolist(),["output (seed = 123)"] + output2.tolist(),["output (seed = 512)"] + output3.tolist(),
]))

Out:

在这里插入图片描述

大功告成!我们现在有了一个 Triton 内核,可以在给定相同种子的情况下应用一致的 dropout 掩码。与传统的 dropout 实现相比,这种方法减少了内存开销并简化了状态管理。

练习

  1. 扩展内核以处理矩阵,并使用一个种子向量 — 每行一个种子。
  2. 添加对 striding 的支持。
  3. (挑战)实现稀疏 Johnson-Lindenstrauss 变换的内核,每次使用种子动态生成投影矩阵。

参考文献

  • [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011

  • [SRIVASTAVA2014] Nitish Srivastava et al., “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014

​Download Jupyter notebook: 04-low-memory-dropout.ipynb

Download Python source code: 04-low-memory-dropout.py

Download zipped: 04-low-memory-dropout.zip

相关文章:

【Triton 教程】低内存 Dropout

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境&#xff0c;以高效编写自定义 DNN 计算内核&#xff0c;并能够在现代 GPU 硬件上以最大吞吐量运行。 更多 Triton 中文文档可访问 →https://triton.hyper.ai/ 在本教程中&#xff0c;您将编…...

npx创建项目时,error fetch failed.TypeError: fetch failed

npx创建项目时&#xff0c;报以下错误&#xff1a; error fetch failed. TypeError: fetch failedat node:internal/deps/undici/undici:12345:11at process.processTicksAndRejections (node:internal/process/task_queues:95:5)at async getTemplateVersion (C:\Users\ymt30…...

《Kotlin实战》-附录

附录 本部分内容只是简单列举下Kotlin应用以便指引进一步深入学习Kotlin。 附录A&#xff1a;构建Kotlin项目 本节只会记录下gradle的应用&#xff0c;其他需要时请自行搜索查看。 A.1 用Gradle构建Kotlin代码的项目 构建Kotlin项目的标准Gradle脚本如下&#xff1a; bui…...

yelp数据集上识别潜在的热门商家

yelp数据集是研究B2C业态的一个很好的数据集&#xff0c;要识别潜在的热门商家是一个多维度的分析过程&#xff0c;涉及用户行为、商家特征和社区结构等多个因素。从yelp数据集里我们可以挖掘到下面信息有助于识别热门商家 用户评分和评论分析 评分均值: 商家的平均评分是反映其…...

【Linux】进程信号全攻略(一)

&#x1f308; 个人主页&#xff1a;Zfox_ &#x1f525; 系列专栏&#xff1a;Linux 目录 一&#xff1a;&#x1f525; 信号的概念 二&#xff1a;&#x1f525; 信号产生的方式 &#x1f98b; 使用键盘&#x1f98b; 系统调用函数&#x1f98b; 软件条件&#x1f98b; 进程异…...

linux文件重命名

Linux文件重命名 文件名显示异常问题出在哪里批量改名扩展 文件名显示异常 跑测CTS&#xff0c;linux环境看跑测结果log file显示没问题&#xff0c;倘若windows下看log file名却显示异常&#xff0c;不太方便操作。 问题出在哪里 linux环境下文件名可以显示正常&#xff0…...

如何选择适合的AWS EC2实例类型

在云计算的世界中&#xff0c;Amazon Web Services&#xff08;AWS&#xff09;提供了丰富的服务&#xff0c;其中Elastic Compute Cloud&#xff08;EC2&#xff09;是最受欢迎的服务之一。选择合适的EC2实例类型对于确保应用程序的性能和成本效益至关重要。我们九河云通过本文…...

【Uniapp】Uniapp Android原生插件开发指北

前言 在uniapp开发中当HBuilderX中提供的能力无法满足App功能需求&#xff0c;需要通过使用Andorid/iOS原生开发实现时&#xff0c;或者是第三方公司提供的是Android的库&#xff0c;这时候可使用App离线SDK开发原生插件来扩展原生能力。 插件类型有两种&#xff0c;Module模…...

【随手笔记】FLASH-W25Q16(三)

#include "bsp_w25q16.h"/*内部函数声明区*/ static HAL_StatusTypeDef bsp_w25q_Transmit(uint8_t * T_pData, uint16_t T_Size); static HAL_StatusTypeDef bsp_w25q_Receive(uint8_t * R_pData, uint16_t R_Size);/*内部函数定义区*//* 函数参数&#xff1a;1、T_…...

2024软件测试面试热点问题

&#x1f345; 点击文末小卡片 &#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 大厂面试热点问题 1、测试人员需要何时参加需求分析&#xff1f; 如果条件循序 原则上来说 是越早介入需求分析越好 因为测试人员对需求理解越深刻 对测试工…...

【JAVA】java 企业微信信息推送

前言 JAVA中 将信息 推送到企业微信 // 企微消息推送messageprivate String getMessage(String name, String problemType, String pushResults, Long orderId,java.util.Date submitTime, java.util.Date payTime) {String message "对接方&#xff1a;<font color\…...

介绍一下数组(c基础)(smart 版)

c初期&#xff0c;记住规则&#xff0c;用规则。 我只是介绍规则。&#xff08;有详细版&#xff0c;这适合smart人看&#xff09; 数组&#xff08;同类型&#xff09; int arr[n] {} ; int 是 元素类型。 int arr[n] {} ; arr为标识符。 {} 集合&#xff0c;元素有次…...

Java项目实战II基于Spring Boot的个人云盘管理系统设计与实现(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。 一、前言 基于Spring Boot的个人云盘管理系统设计…...

探索数据科学与大数据技术专业本科生的广阔就业前景

随着信息技术的不断发展&#xff0c;数据科学与大数据技术已经成为各大行业的关键推动力。在这样一个数据驱动的时代&#xff0c;越来越多的企业依赖数据来驱动决策、优化运营和创造价值。因此&#xff0c;数据科学与大数据技术专业的本科生在就业市场上具有广阔的前景和多样的…...

微服务架构面试内容整理-Zuul

Zuul 是由 Netflix 开发的一个边缘服务(API 网关),用于动态路由、监控、认证、以及对微服务架构中的请求进行过滤。它在微服务架构中扮演着重要的角色,提供了一种集中管理和控制服务访问的方式。以下是 Zuul 的主要特点、工作原理和使用场景: 主要特点 1. 动态路由: Zuu…...

解决Knife4j 接口界面UI中文乱码问题

1、查看乱码情况 2、修改 编码设置 3、删除 target 文件 项目重新启动 被坑死了...

微服务架构面试内容整理-Sleuth

Spring Cloud Sleuth 是一个分布式追踪工具&#xff0c;用于监控微服务系统中请求的传播情况。它通过在微服务之间传递追踪信息&#xff0c;帮助开发者理解系统的行为&#xff0c;快速定位性能瓶颈和问题。以下是 Sleuth 的主要特点、工作原理和使用场景&#xff1a; 主要特点 …...

Go语言的接口示例

Go语言的接口(interface)是一种轻量级的多态性实现方式,是构建高扩展性、高复用性代码的利器。Go语言的接口非常灵活,不要求显式的实现声明,只要一个类型实现了接口规定的方法,它就可以被视为该接口的实现者。在本篇博客中,我们将通过多个实际示例,探讨Go语言接口的使用…...

【Apache ECharts】<农作物病害发生防治面积>

在vs Code里打开&#xff0c; 实现 1. 首先引入 echarts.min.js 资源 2. 在body部分设一个 div&#xff0c;设置 id 为 main 3. 设置 script 3.1 基于准备好的dom&#xff0c;初始化echarts实例 var myChart echarts.init(document.getElementById(main)); 3.2 指定图表的…...

基于vue3实现的聊天机器人前端(附代码)

<template><div class"container"><!-- 页面头部 --><header><h1>跟它说说话吧&#xff01;</h1><p>一个活泼的伙伴&#xff0c;为你提供情感支持&#xff01;</p></header><!-- 聊天容器 --><div c…...

DICOM标准:深入详解DICOM医学影像中的传输语法

引言 DICOM&#xff08;数字成像和通信医学&#xff09;标准在医学影像数据交换中扮演着至关重要的角色。其中&#xff0c;*传输语法&#xff08;Transfer Syntax&#xff09;是DICOM标准中定义数据编码和传输方式的核心部分。理解传输语法对于确保不同设备和系统之间的互操作性…...

sql server 文件备份恢复

数据库介绍文件组 PRIMARY 文件 lys D:\Program Files\Microsoft SQL Server\MSSQL13.MSSQLSERVER\MSSQL\DATA\lys.mdf lys_02 D:\Program Files\Microsoft SQL Server\MSSQL13.MSSQLSERVER\MSSQL\DATA\lys_02.ndf文件组 sec 有2个表&#xff08;sec_1,sec_2&#xff09; 文件 …...

Gradle命令编译Android Studio工程项目并签名

文章目录 gradlew命令gradlew编译debug apkgradlew编译release apkapksigner签名apkgradlew注意事项 gradlew命令 gradlew 是一个脚本文件&#xff0c;它允许你在没有全局安装 Gradle 的情况下运行 Gradle 构建。这个脚本在多平台上可用&#xff0c;对于 Windows 系统来说是 g…...

lua入门教程:垃圾回收

Lua的垃圾回收机制是一种自动内存管理方式&#xff0c;用于回收不再被程序访问的对象&#xff0c;从而避免内存泄漏。以下是一个关于Lua垃圾回收机制的详细教程&#xff1a; 一、Lua垃圾回收机制概述 Lua使用自动内存管理&#xff0c;这意味着程序员不需要手动释放内存。Lua的…...

基于前后端分离架构,SaaS云平台与私有云部署的智慧校园源码,java电子班牌源码

基于前后端分离架构&#xff0c;SaaS云平台与私有云部署的智慧校园源码&#xff0c;java电子班牌源码&#xff0c;自主研发&#xff0c;自主版权&#xff0c;上百个学校应用案例&#xff0c;支持二次开发。 在信息技术飞速发展的今天&#xff0c;教育领域也迎来了一场革命性的变…...

知识总结五

一、C深浅拷贝 浅拷贝&#xff1a;只复制对象的成员变量的值&#xff0c;如果成员变量包含指针&#xff0c;则只复制指针值&#xff0c;不复制指针所指向的数据。深拷贝&#xff1a;复制对象的成员变量的值&#xff0c;并且如果成员变量包含指针&#xff0c;则还复制指针所指向…...

一、初识C语言(1)

1.C语言识别的是二进制语言 C语言是一门计算机语言&#xff0c;计算机是硬件&#xff0c;硬件分通电&#xff08;1&#xff09;和 未通电&#xff08;0&#xff09;两种情况&#xff0c;所以C语言识别的都是0 / 1信号&#xff0c;也就是二进制语言。 2.C语言文件类型以及基本框…...

petty 状态管理库文档

自研 Petty 状态管理库产生背景 petty 是一款适用于 vue2.5以下版本&#xff08;目前已兼容vue2.5x 以上版本&#xff09;的状态管理库&#xff0c;能够在 vue 2这种配置项的代码中&#xff0c;去实现类似于 vue3 里的 pinia、React 里的hook的调用形式&#xff0c;用函数式的…...

SpringMVC学习记录(三)之响应数据

SpringMVC学习记录&#xff08;三&#xff09;之响应数据 一、页面跳转控制1、快速返回模板视图2、转发和重定向 二、返回JSON数据1、前置准备2、ResponseBody 三、返回静态资源1、静态资源概念2、访问静态资源 /*** TODO: 一个controller的方法是控制层的一个处理器,我们称为h…...

ENSP GVRP动态学习VLAN

手工配置的VLAN称为静态VLAN&#xff0c;通过GVRP协议创建的VLAN称为动态VLAN。 GVRP有三种注册模式&#xff0c;不同的模式对静态VLAN和动态VLAN的处理方式也不同。 GVRP的三种注册模式分别定义如下&#xff1a; Normal模式&#xff1a;允许动态VLAN在端口上进行注册…...