当前位置: 首页 > news >正文

【Triton 教程】低内存 Dropout

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境,以高效编写自定义 DNN 计算内核,并能够在现代 GPU 硬件上以最大吞吐量运行。

更多 Triton 中文文档可访问 →https://triton.hyper.ai/

在本教程中,您将编写一个内存高效的 Dropout 实现,其状态将由单个 int32 seed 组成。这与传统 Dropout 实现不同,传统实现通常由与输入 shape 相同的位掩码张量组成。

在这过程中,您将学习到以下内容:

  • PyTorch 中 原生实现 Dropout 的局限性。

  • Triton 中的并行伪随机数生成。

简介

Dropout 是在 [SRIVASTAVA2014] 中引入的一种技术,用于改善低数据条件下深度神经网络的性能,通常用于正则化。它接受一个向量作为输入,并生成相同 shape 的输出向量。输出中的每个标量都有概率 p 被设为零,否则直接从输入复制。这使得网络在仅有输入的 1−p 标量时也能表现良好。

在评估阶段,为了充分利用网络的能力,将 p 设为 0。但是简单地将 p 设为 0 会增加输出的范数,可能会人为地降低输出的 softmax temperature。为了防止这种情况发生,输出被缩放为 1/(1-p),这使得无论 dropout 概率如何都能保持一致的范数。

Baseline

首先看一下 baseline 的实现。

import tabulate
import torchimport triton
import triton.language as tl@triton.jit
def _dropout(x_ptr,      # 输入指针x_keep_ptr, # pointer to a mask of 0s and 1s 由 01 组成的掩码的指针output_ptr, # pointer to the output 输出指针n_elements, # number of elements in the `x` tensor `x` 张量的元素数量p,          # probability that an element of `x` is changed to zero 元素 `x` 被设置为 0 的概率BLOCK_SIZE: tl.constexpr,
):pid = tl.program_id(axis=0)block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)mask = offsets < n_elements# Load data# 加载数据x = tl.load(x_ptr + offsets, mask=mask)x_keep = tl.load(x_keep_ptr + offsets, mask=mask)# The line below is the crucial part, described in the paragraph above!# 下一行是上段描述的关键部分output = tl.where(x_keep, x / (1 - p), 0.0)# Write-back output# 写回输出tl.store(output_ptr + offsets, output, mask=mask)def dropout(x, x_keep, p):output = torch.empty_like(x)assert x.is_contiguous()n_elements = x.numel()grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )_dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)return output# Input tensor
# 输入张量
x = torch.randn(size=(10, )).cuda()
# Dropout mask
# Dropout 掩码
p = 0.5
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([["input"] + x.tolist(),["keep mask"] + x_keep.tolist(),["output"] + output.tolist(),
]))

Out:
在这里插入图片描述

种子化 Dropout

上述 Dropout 实现效果良好,但管理 Dropout 状态可能会变得复杂,特别是在考虑反向传播和重新计算/检查点场景时。在这里,我们描述一种替代实现,它具有以下优点:

  1. 更小的内存占用。
  2. 较少的数据移动。
  3. 简化了在多次调用内核函数时持久化随机性的管理。

生成 Triton 中的伪随机数很简单!在本教程中,我们将使用 triton.language.rand 函数,该函数基于给定的种子和一组 int32 偏移量生成一个块的均匀分布的 float32 值,范围在 (0, 1) 内。但如果你需要,Triton 也提供其他随机数生成策略。

注意 Triton 的 PRNG 实现基于 Philox 算法(详见 [SALMON2011])。

现在将所有内容整合起来。

@triton.jit
def _seeded_dropout(x_ptr,output_ptr,n_elements,p,seed,BLOCK_SIZE: tl.constexpr,
):# compute memory offsets of elements handled by this instance# 计算由此实例处理的元素的内存偏移量pid = tl.program_id(axis=0)block_start = pid * BLOCK_SIZEoffsets = block_start + tl.arange(0, BLOCK_SIZE)# load data from x# 从 x 读取数据mask = offsets < n_elementsx = tl.load(x_ptr + offsets, mask=mask)# randomly prune it# 随机修剪random = tl.rand(seed, offsets)x_keep = random > p# write-back# 写回output = tl.where(x_keep, x / (1 - p), 0.0)tl.store(output_ptr + offsets, output, mask=mask)def seeded_dropout(x, p, seed):output = torch.empty_like(x)assert x.is_contiguous()n_elements = x.numel()grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )_seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)return outputx = torch.randn(size=(10, )).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
# 与基线相比 - dropout 掩码从未被实例化!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)print(tabulate.tabulate([["input"] + x.tolist(),["output (seed = 123)"] + output.tolist(),["output (seed = 123)"] + output2.tolist(),["output (seed = 512)"] + output3.tolist(),
]))

Out:

在这里插入图片描述

大功告成!我们现在有了一个 Triton 内核,可以在给定相同种子的情况下应用一致的 dropout 掩码。与传统的 dropout 实现相比,这种方法减少了内存开销并简化了状态管理。

练习

  1. 扩展内核以处理矩阵,并使用一个种子向量 — 每行一个种子。
  2. 添加对 striding 的支持。
  3. (挑战)实现稀疏 Johnson-Lindenstrauss 变换的内核,每次使用种子动态生成投影矩阵。

参考文献

  • [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, “Parallel Random Numbers: As Easy as 1, 2, 3”, 2011

  • [SRIVASTAVA2014] Nitish Srivastava et al., “Dropout: A Simple Way to Prevent Neural Networks from Overfitting”, JMLR 2014

​Download Jupyter notebook: 04-low-memory-dropout.ipynb

Download Python source code: 04-low-memory-dropout.py

Download zipped: 04-low-memory-dropout.zip

相关文章:

【Triton 教程】低内存 Dropout

Triton 是一种用于并行编程的语言和编译器。它旨在提供一个基于 Python 的编程环境&#xff0c;以高效编写自定义 DNN 计算内核&#xff0c;并能够在现代 GPU 硬件上以最大吞吐量运行。 更多 Triton 中文文档可访问 →https://triton.hyper.ai/ 在本教程中&#xff0c;您将编…...

npx创建项目时,error fetch failed.TypeError: fetch failed

npx创建项目时&#xff0c;报以下错误&#xff1a; error fetch failed. TypeError: fetch failedat node:internal/deps/undici/undici:12345:11at process.processTicksAndRejections (node:internal/process/task_queues:95:5)at async getTemplateVersion (C:\Users\ymt30…...

《Kotlin实战》-附录

附录 本部分内容只是简单列举下Kotlin应用以便指引进一步深入学习Kotlin。 附录A&#xff1a;构建Kotlin项目 本节只会记录下gradle的应用&#xff0c;其他需要时请自行搜索查看。 A.1 用Gradle构建Kotlin代码的项目 构建Kotlin项目的标准Gradle脚本如下&#xff1a; bui…...

yelp数据集上识别潜在的热门商家

yelp数据集是研究B2C业态的一个很好的数据集&#xff0c;要识别潜在的热门商家是一个多维度的分析过程&#xff0c;涉及用户行为、商家特征和社区结构等多个因素。从yelp数据集里我们可以挖掘到下面信息有助于识别热门商家 用户评分和评论分析 评分均值: 商家的平均评分是反映其…...

【Linux】进程信号全攻略(一)

&#x1f308; 个人主页&#xff1a;Zfox_ &#x1f525; 系列专栏&#xff1a;Linux 目录 一&#xff1a;&#x1f525; 信号的概念 二&#xff1a;&#x1f525; 信号产生的方式 &#x1f98b; 使用键盘&#x1f98b; 系统调用函数&#x1f98b; 软件条件&#x1f98b; 进程异…...

linux文件重命名

Linux文件重命名 文件名显示异常问题出在哪里批量改名扩展 文件名显示异常 跑测CTS&#xff0c;linux环境看跑测结果log file显示没问题&#xff0c;倘若windows下看log file名却显示异常&#xff0c;不太方便操作。 问题出在哪里 linux环境下文件名可以显示正常&#xff0…...

如何选择适合的AWS EC2实例类型

在云计算的世界中&#xff0c;Amazon Web Services&#xff08;AWS&#xff09;提供了丰富的服务&#xff0c;其中Elastic Compute Cloud&#xff08;EC2&#xff09;是最受欢迎的服务之一。选择合适的EC2实例类型对于确保应用程序的性能和成本效益至关重要。我们九河云通过本文…...

【Uniapp】Uniapp Android原生插件开发指北

前言 在uniapp开发中当HBuilderX中提供的能力无法满足App功能需求&#xff0c;需要通过使用Andorid/iOS原生开发实现时&#xff0c;或者是第三方公司提供的是Android的库&#xff0c;这时候可使用App离线SDK开发原生插件来扩展原生能力。 插件类型有两种&#xff0c;Module模…...

【随手笔记】FLASH-W25Q16(三)

#include "bsp_w25q16.h"/*内部函数声明区*/ static HAL_StatusTypeDef bsp_w25q_Transmit(uint8_t * T_pData, uint16_t T_Size); static HAL_StatusTypeDef bsp_w25q_Receive(uint8_t * R_pData, uint16_t R_Size);/*内部函数定义区*//* 函数参数&#xff1a;1、T_…...

2024软件测试面试热点问题

&#x1f345; 点击文末小卡片 &#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 大厂面试热点问题 1、测试人员需要何时参加需求分析&#xff1f; 如果条件循序 原则上来说 是越早介入需求分析越好 因为测试人员对需求理解越深刻 对测试工…...

【JAVA】java 企业微信信息推送

前言 JAVA中 将信息 推送到企业微信 // 企微消息推送messageprivate String getMessage(String name, String problemType, String pushResults, Long orderId,java.util.Date submitTime, java.util.Date payTime) {String message "对接方&#xff1a;<font color\…...

介绍一下数组(c基础)(smart 版)

c初期&#xff0c;记住规则&#xff0c;用规则。 我只是介绍规则。&#xff08;有详细版&#xff0c;这适合smart人看&#xff09; 数组&#xff08;同类型&#xff09; int arr[n] {} ; int 是 元素类型。 int arr[n] {} ; arr为标识符。 {} 集合&#xff0c;元素有次…...

Java项目实战II基于Spring Boot的个人云盘管理系统设计与实现(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。 一、前言 基于Spring Boot的个人云盘管理系统设计…...

探索数据科学与大数据技术专业本科生的广阔就业前景

随着信息技术的不断发展&#xff0c;数据科学与大数据技术已经成为各大行业的关键推动力。在这样一个数据驱动的时代&#xff0c;越来越多的企业依赖数据来驱动决策、优化运营和创造价值。因此&#xff0c;数据科学与大数据技术专业的本科生在就业市场上具有广阔的前景和多样的…...

微服务架构面试内容整理-Zuul

Zuul 是由 Netflix 开发的一个边缘服务(API 网关),用于动态路由、监控、认证、以及对微服务架构中的请求进行过滤。它在微服务架构中扮演着重要的角色,提供了一种集中管理和控制服务访问的方式。以下是 Zuul 的主要特点、工作原理和使用场景: 主要特点 1. 动态路由: Zuu…...

解决Knife4j 接口界面UI中文乱码问题

1、查看乱码情况 2、修改 编码设置 3、删除 target 文件 项目重新启动 被坑死了...

微服务架构面试内容整理-Sleuth

Spring Cloud Sleuth 是一个分布式追踪工具&#xff0c;用于监控微服务系统中请求的传播情况。它通过在微服务之间传递追踪信息&#xff0c;帮助开发者理解系统的行为&#xff0c;快速定位性能瓶颈和问题。以下是 Sleuth 的主要特点、工作原理和使用场景&#xff1a; 主要特点 …...

Go语言的接口示例

Go语言的接口(interface)是一种轻量级的多态性实现方式,是构建高扩展性、高复用性代码的利器。Go语言的接口非常灵活,不要求显式的实现声明,只要一个类型实现了接口规定的方法,它就可以被视为该接口的实现者。在本篇博客中,我们将通过多个实际示例,探讨Go语言接口的使用…...

【Apache ECharts】<农作物病害发生防治面积>

在vs Code里打开&#xff0c; 实现 1. 首先引入 echarts.min.js 资源 2. 在body部分设一个 div&#xff0c;设置 id 为 main 3. 设置 script 3.1 基于准备好的dom&#xff0c;初始化echarts实例 var myChart echarts.init(document.getElementById(main)); 3.2 指定图表的…...

基于vue3实现的聊天机器人前端(附代码)

<template><div class"container"><!-- 页面头部 --><header><h1>跟它说说话吧&#xff01;</h1><p>一个活泼的伙伴&#xff0c;为你提供情感支持&#xff01;</p></header><!-- 聊天容器 --><div c…...

Sammy.js项目实战:从零搭建完整的单页应用架构终极指南

Sammy.js项目实战&#xff1a;从零搭建完整的单页应用架构终极指南 【免费下载链接】sammy Sammy is a tiny javascript framework built on top of jQuery, Its RESTful Evented Javascript. 项目地址: https://gitcode.com/gh_mirrors/sa/sammy Sammy.js是一个轻量级的…...

雪花算法:分布式世界的“身份证号”

嘿&#xff0c;朋友&#xff01;想象一下&#xff0c;你是一家拥有几千台服务器的互联网大厂架构师。现在有个小麻烦&#xff1a;你的订单系统每秒钟要生成几万个订单号。如果让数据库自己搞&#xff08;自增ID&#xff09;&#xff0c;几台数据库凑在一起&#xff0c;肯定会出…...

别再重装系统了!用GParted给Ubuntu 20.04根目录无损扩容(Win11+Ubuntu双系统适用)

双系统用户必备&#xff1a;Ubuntu根目录无损扩容实战指南 1. 当根目录空间告急时 作为一名长期使用Win11Ubuntu双系统的开发者&#xff0c;我深刻理解那种看着根目录空间一点点被蚕食的焦虑。特别是进行深度学习训练或大型项目编译时&#xff0c;几十GB的空间转眼间就被占满。…...

利用快马平台与openclaw切换模型功能,快速构建待办事项应用原型

最近在尝试快速构建一个待办事项应用的原型时&#xff0c;发现InsCode(快马)平台的AI代码生成功能特别适合这种场景。通过平台内置的openclaw切换模型功能&#xff0c;可以快速比较不同AI模型生成的代码风格差异&#xff0c;大大缩短了原型开发周期。下面分享下我的实践过程&am…...

循环冷却水流量示意图设计 建筑水流量示意图绘制教程

一、引言 在建筑给排水、暖通空调及工业循环水系统设计中&#xff0c;循环冷却水流量示意图与建筑水流量示意图是核心技术图纸之一&#xff0c;其作用是直观呈现水流路径、管径规格、流量分配、设备连接关系及压力节点参数&#xff0c;为系统施工、调试、运维及故障排查提供可…...

基于vue的非遗文化传承平台[vue]-计算机毕业设计源码+LW文档

摘要&#xff1a;非物质文化遗产&#xff08;非遗&#xff09;作为民族文化的重要组成部分&#xff0c;承载着人类社会的文明和历史记忆。随着现代社会的快速发展&#xff0c;非遗文化的传承面临着诸多挑战。为了更好地保护和传承非遗文化&#xff0c;本文设计并实现了一个基于…...

MOS管选型实战指南

MOS管(金属氧化物半导体场效应晶体管)是现代电力电子和开关电路的核心元件。选型失误的后果往往是灾难性的——效率低下、发热严重、驱动振荡、甚至炸管冒烟。相比电阻电容,MOS管的选型需要权衡的维度更多:电压、电流、导通电阻、开关速度、驱动电压、热阻、体二极管特性……...

Vue2项目构建优化实战:时间戳防缓存与资源压缩的配置详解

1. 为什么Vue2项目需要构建优化 最近接手了一个老项目的维护工作&#xff0c;发现每次前端更新后总有用户反馈页面显示异常。排查后发现是浏览器缓存惹的祸——用户访问的仍然是旧版本的静态资源。这让我意识到构建优化的重要性&#xff0c;特别是对于需要频繁更新的业务系统。…...

揭秘JVM创世过程之Call Stub进入Java世界的门票

前言 本文旨在记录近期研读Java源码的学习心得与疑难问题。由于个人理解水平有限&#xff0c;文中内容可能存在疏漏&#xff0c;恳请读者不吝指正。 前情回顾 在揭秘JVM创世过程之两种语言首席外交官JavaCalls&#xff0c;一文中将JVM看作Java世界中一个拥有两种语言的领事馆…...

PyTorch 2.8镜像真实效果:物理实验→电磁场/流体力学可视化视频

PyTorch 2.8镜像真实效果&#xff1a;物理实验→电磁场/流体力学可视化视频 1. 开箱即用的专业级物理模拟环境 当你第一次启动这个基于RTX 4090D优化的PyTorch 2.8镜像时&#xff0c;最直接的感受就是"专业工具就该这样"。这个镜像不是普通的深度学习环境&#xff…...