[内存泄漏][PyTorch](create_graph=True)
PyTorch保存计算图导致内存泄漏
- 1. 内存泄漏定义
- 2. 问题发现背景
- 3. pytorch中关于这个问题的讨论
1. 内存泄漏定义
内存泄漏(Memory Leak)是指程序中已动态分配的堆内存由于某种原因程序未释放或无法释放,造成系统内存的浪费,导致程序运行速度减慢甚至系统崩溃等严重后果。
2. 问题发现背景
在使用深度学习求解PDE时,由于经常需要计算高阶导数,在pytorch框架下写的代码需要用到torch.autograd.grad(create_graph=True)
或者torch.backward(create_graph=True)
这个参数,然后发现了这个内存泄漏的问题。如果要保存计算图用来计算高阶导数,那么其所占的内存不会被释放,会一直占用。也就是如果设置create_graph=True
,那么其保存的计算图所占的内存只有在程序运行结束时才会释放,这样导致了一个问题,即如果在循环中需要保存计算图,例如每个循环都需要计算一次黑塞矩阵,那么这个内存占用就会越来越多,最终导致out of memory
报错。
3. pytorch中关于这个问题的讨论
官网中关于这个问题的讨论见https://github.com/pytorch/pytorch/issues/7343,这里提出的内存泄漏的例子如下:
import torch
import gc_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
y.backward(torch.ones_like(y), create_graph=True)
del x, y
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
可以看到虽然删除了变量,依然造成了内存泄漏。这里红色的警告就是关于这个内存泄漏的问题。
UserWarning: Using backward() with create_graph=True will create a reference cycle between
the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad
when creating the graph to avoid this. If you have to use this function, make sure to reset
the .grad fields of your parameters to None after use to break the cycle and avoid the leak.
(Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\autograd\engine.cpp:1000.)
allow_unreachable=True, accumulate_grad=True)
# Calls into the C++ engine to run the backward pass
看这个UserWarning
,提示我们使用torch.autograd.grad()函数可以避免这个梯度泄漏,然后对代码进行改动:
import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
# y.backward(torch.ones_like(y), create_graph=True)
del x, y, z
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
结果显示没有梯度泄漏。进一步,我们求一下二阶导数:
import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
del x, y, z, q
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
结果也没有内存泄漏。但是,如果我们不删除结果二阶导数q
,这样是出于如果写在一个函数中,需要将q
作为return
值返回的情况。
import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
del x, y, z
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
可以看到,这还是会导致一部分内存泄漏。那如果我们一定要返回q
,又不想内存泄漏,这里本人想到一直办法,就是将q
转换成numpy数据类型,返回这个numpy数组,就不会导致内存泄漏了。
import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
k = q[0].cpu().numpy()
del x, y, z, q
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
相关文章:

[内存泄漏][PyTorch](create_graph=True)
PyTorch保存计算图导致内存泄漏 1. 内存泄漏定义2. 问题发现背景3. pytorch中关于这个问题的讨论 1. 内存泄漏定义 内存泄漏(Memory Leak)是指程序中已动态分配的堆内存由于某种原因程序未释放或无法释放,造成系统内存的浪费,导致…...

【Git学习二】时光回溯:git reset和git checkout命令详解
😁 作者简介:一名大四的学生,致力学习前端开发技术 ⭐️个人主页:夜宵饽饽的主页 ❔ 系列专栏:Git等软件工具技术的使用 👐学习格言:成功不是终点,失败也并非末日,最重要…...

多维时序 | MATLAB实现PSO-GRU-Attention粒子群优化门控循环单元融合注意力机制的多变量时间序列预测
多维时序 | MATLAB实现PSO-GRU-Attention粒子群优化门控循环单元融合注意力机制的多变量时间序列预测 目录 多维时序 | MATLAB实现PSO-GRU-Attention粒子群优化门控循环单元融合注意力机制的多变量时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 MAT…...
MySQL缓冲池的优化与性能提升
“不积跬步,无以至千里。” MySQL是许多Web应用的核心数据库,而数据库的性能对于应用的稳定运行至关重要。在MySQL中,缓冲池(Buffer Pool)是一个关键的组件,它直接影响着数据库的性能和响应速度。今天这篇文…...

一些RLHF的平替汇总
卷友们好,我是rumor。 众所周知,RLHF十分玄学且令人望而却步。我听过有的小道消息说提升很大,也有小道消息说效果不明显,究其根本还是系统链路太长自由度太高,不像SFT一样可以通过数据配比、prompt、有限的超参数来可控…...
7.docker部署前端vue项目,实现反向代理配置
介绍: 构建镜像:通过docker构建以nginx为基础的镜像,将vue项目生成的dist包拷贝至nginx目录下,.conf文件做反向代理配置;部署服务:docker stack启动部署服务; 通过执行两个脚本既可以实现构建…...

字符串函数详解
一.字母大小写转换函数. 1.1.tolower 结合cppreference.com 有以下结论: 1.头文件为#include <ctype.h> 2.使用规则为 #include <stdio.h> #include <ctype.h> int main() {char ch A;printf("%c\n",tolower(ch));//大写转换为小…...

Mybatis学习笔记-映射文件,标签,插件
目录 概述 mybatis做了什么 原生JDBC存在什么问题 MyBatis组成部分 Mybatis工作原理 mybatis和hibernate区别 使用mybatis(springboot) mybatis核心-sql映射文件 基础标签说明 1.namespace,命名空间 2.select,insert&a…...

【C++】模板初阶 【 深入浅出理解 模板 】
模板初阶 前言:泛型编程一、函数模板(一)函数模板概念(二)函数模板格式(三)函数模板的原理(四)函数模板的实例化(五)模板参数的匹配原则 三、类模…...

无需API开发,伯俊科技实现电商与客服系统的无缝集成
伯俊科技的无代码开发实现系统连接 自1999年成立以来,伯俊科技一直致力于为企业提供全渠道一盘货的服务。凭借其24年的深耕零售行业的经验,伯俊科技推出了一种无需API开发的方法,实现电商系统和客服系统的连接与集成。这种无代码开发的方式不…...

Python | 机器学习之逻辑回归
🌈个人主页:Sarapines Programmer🔥 系列专栏:《人工智能奇遇记》🔖少年有梦不应止于心动,更要付诸行动。 目录结构 1. 机器学习之逻辑回归概念 1.1 机器学习 1.2 逻辑回归 2. 逻辑回归 2.1 实验目的…...

手机,蓝牙开发板,TTL/USB模块,电脑四者之间的通讯
一,意图 通过手机蓝牙连接WeMosD1R32开发板,开发板又通过TTL转USB与电脑连接.手机通过蓝牙控制开发板上的LED灯的开,关,闪等动作,在电脑上打开串口监视工具观察其状态.也可以通过电脑上的串口监视工具来控制开发板上LED灯的动作,而在手机蓝牙监测工具中显示灯的状态. 二,原料…...

Springboot更新用户头像
人们通常(为徒省事)把一个包含了修改后userName的完整userInfo对象传给后端,做完整更新。但仔细想想,这种做法感觉有点二,而且浪费带宽。 于是patch诞生,只传一个userName到指定资源去,表示该请求是一个局部更新&#…...

Express.js 与 Nest.js对比
Express.js 与 Nest.js对比 自从 Node.js 发布以来,Javascript 在后端领域的使用有所增加。由于 Node.js 的使用越来越多,每天都会有新的框架和工具发布。Express 和 Nest 是使用 Node.js 创建后端应用程序的最著名的框架之一,在本文中&…...

总结 CNN 模型:将焦点转移到基于注意力的架构
一、说明 在计算机视觉时代,卷积神经网络(CNN)几十年来一直是主导范式。直到 2021 年 Vision Transformers (ViTs) 出现,这个领域才开始发生变化。现在,是时候采用受 Transformer 架构启发的基于注意力的模型了&#x…...

2023.11.16 hivesql高阶函数之开窗函数
目录 1.开窗函数的定义 2.数据准备 3.开窗函数之排序 需求:用三种排序方法查询学生的语文成绩排名,并降序显示 4.开窗函数分组 需求:按照科目来分类,使用三种排序方式来排序学生的成绩 5.聚合函数与分组配合使用 6.聚合函数同时和分组以及排序关键字配合使用 --需求1&…...
QTableWidget常用信号的功能
2023年11月18日,周六上午 itemPressed(QTableWidgetItem *item):当某个项目被按下时发出信号。itemClicked(QTableWidgetItem *item):当某个项目被单击时发出信号。itemDoubleClicked(QTableWidgetItem *item):当某个项目被双击时…...
Vue理解01
项目建立流程 项目文件夹终端vue ui可视化新建项目(需要一些时间)vscode打开项目npm run serve运行 架构理解: 首先打开的页面默认是index.htmlindex.html默认引用main.jsmain.js引用需要的页面,默认App.vue。Vue示例挂载可以在…...

4、FFmpeg命令行操作8
生成测试文件 找三个不同的视频每个视频截取10秒内容 ffmpeg -i 沙海02.mp4 -ss 00:05:00 -t 10 -codec copy 1.mp4 ffmpeg -i 复仇者联盟3.mp4 -ss 00:05:00 -t 10 -codec copy 2.mp4 ffmpeg -i 红海行动.mp4 -ss 00:05:00 -t 10 -codec copy 3.mp4 如果音视…...

【MySQL】索引与事务
作者主页:paper jie_博客 本文作者:大家好,我是paper jie,感谢你阅读本文,欢迎一建三连哦。 本文录入于《MySQL》专栏,本专栏是针对于大学生,编程小白精心打造的。笔者用重金(时间和精力)打造&a…...
JVM垃圾回收机制全解析
Java虚拟机(JVM)中的垃圾收集器(Garbage Collector,简称GC)是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象,从而释放内存空间,避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...
Qwen3-Embedding-0.6B深度解析:多语言语义检索的轻量级利器
第一章 引言:语义表示的新时代挑战与Qwen3的破局之路 1.1 文本嵌入的核心价值与技术演进 在人工智能领域,文本嵌入技术如同连接自然语言与机器理解的“神经突触”——它将人类语言转化为计算机可计算的语义向量,支撑着搜索引擎、推荐系统、…...
[Java恶补day16] 238.除自身以外数组的乘积
给你一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法,且在 O(n) 时间复杂度…...
【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)
1.获取 authorizationCode: 2.利用 authorizationCode 获取 accessToken:文档中心 3.获取手机:文档中心 4.获取昵称头像:文档中心 首先创建 request 若要获取手机号,scope必填 phone,permissions 必填 …...
CMake控制VS2022项目文件分组
我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...

使用LangGraph和LangSmith构建多智能体人工智能系统
现在,通过组合几个较小的子智能体来创建一个强大的人工智能智能体正成为一种趋势。但这也带来了一些挑战,比如减少幻觉、管理对话流程、在测试期间留意智能体的工作方式、允许人工介入以及评估其性能。你需要进行大量的反复试验。 在这篇博客〔原作者&a…...

无人机侦测与反制技术的进展与应用
国家电网无人机侦测与反制技术的进展与应用 引言 随着无人机(无人驾驶飞行器,UAV)技术的快速发展,其在商业、娱乐和军事领域的广泛应用带来了新的安全挑战。特别是对于关键基础设施如电力系统,无人机的“黑飞”&…...

Ubuntu系统复制(U盘-电脑硬盘)
所需环境 电脑自带硬盘:1块 (1T) U盘1:Ubuntu系统引导盘(用于“U盘2”复制到“电脑自带硬盘”) U盘2:Ubuntu系统盘(1T,用于被复制) !!!建议“电脑…...
redis和redission的区别
Redis 和 Redisson 是两个密切相关但又本质不同的技术,它们扮演着完全不同的角色: Redis: 内存数据库/数据结构存储 本质: 它是一个开源的、高性能的、基于内存的 键值存储数据库。它也可以将数据持久化到磁盘。 核心功能: 提供丰…...

【51单片机】4. 模块化编程与LCD1602Debug
1. 什么是模块化编程 传统编程会将所有函数放在main.c中,如果使用的模块多,一个文件内会有很多代码,不利于组织和管理 模块化编程则是将各个模块的代码放在不同的.c文件里,在.h文件里提供外部可调用函数声明,其他.c文…...