基于TorchViz详解计算图(附代码)
文章目录
- 0. 前言
- 1. 计算图是什么?
- 2. TorchViz的安装
- 3. 计算图详解
0. 前言
按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。
本文的主旨是基于TorchVis模块详细说明计算图以及叶子节点等相关概念。
创作本文的目的主要有两个:
- 计算图这个概念在深度学习中经常被提及,但是对于新手(甚至部分老手)而言,可能很少人能明白计算图究竟是个什么东西,用来干嘛的;
- CSND上关于计算图的介绍文章不少,但基本都是引用TorchViz生成计算图后就完事了,缺乏对计算图的理解。
1. 计算图是什么?
答:计算图是用于表示计算过程的图,例如下面这个:

这个图可以理解为最简单的单层神经元网络,其中: x x x为训练输入数据, w w w和 b b b是要优化的参数, y y y为训练输出数据, l o s s loss loss为损失值。
PyTorch官方对计算图(Computational Graph)的介绍是:一个有向开环图(DAG),这个有向开环图记录了①所有的输入数据(张量),②这些数据(张量)的计算过程,③通过这些计算过程生成的新数据(张量)。
在计算图中,“叶子”代表了输入数据(张量),“根”代表了输出数据(张量)。追溯从“根”到“叶子”的过程,通过链式法则可以计算出(损失值对神经元网络模型参数的)偏导。
PyTorch官网原文链接:https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html?highlight=grad_fn
2. TorchViz的安装
TorchViz是一个用于可视化 PyTorch计算图的工具库,后面的说明都是使用TorchViz生成的计算图来讲解,所以先介绍下TorchViz的安装。
其方法非常简单。。。使用Conda直接安装TorchViz:首先进入到Pycharm/settings/Python Interpreter,然后点“+”。

再搜torchviz,点“Install Package”

因为TorchViz中引用了GraphViz库中的方法,所以也得把GraphViz安装上。

其实不太想水这一章的内容,但是我实在不理解为什么大家都千篇一律喜欢用pip?
3. 计算图详解
首先我们先生成一个最简单的 h = w × x + b h = w×x + b h=w×x+b的计算图,代码如下:
import torch
from torchviz import make_dotx = torch.tensor([1],dtype=torch.float32,requires_grad=True)
w = torch.tensor([4],dtype=torch.float32,requires_grad=True)
b = torch.tensor([0.5],dtype=torch.float32,requires_grad=True)h = w*x + b
graph_forward = make_dot(h)
graph_forward.render(filename='C:\\Users\\Lenovo\\Desktop\\DL\\calc_graph\\graph_forward', view=False, format='pdf')
这里的路径filename一定要写道最终文件的名字,而不是最终文件夹!!!也就是说calc_graph最后一层文件夹,生成的文件是graph_forward.pdf
生成的计算图如下:

其中,蓝底色的3个(1)即是第1章中说明的计算图中的“叶子”,绿底色的(1)是“根”。
这里的“叶子”即为我们经常听说的叶子节点(leaf node)。PyTorch为了节省内存,只会记录叶子节点的相关操作,计算梯度时也只对叶子节点进行计算。
回到 h = w × x + b h = w×x + b h=w×x+b的计算图,如果它代表的是某个深度学习网络模型中的某个隐层的计算过程,那显然我们不用知道对 x x x的偏导,这样我们就可以把它从计算图中剥离出来,把计算资源都给到对参数 w w w和 b b b的计算。把 x x x从计算图中剥离出来的方式也很简单,只要指定requires_grad为False就可以了。
x = torch.tensor([1],dtype=torch.float32,requires_grad=False)
这里再说明另外一个方法——.detach()。有人也会介绍.detach()的作用也是把张量从计算图中剥离出来,甚至有人不明所以会说.detach()和requires_grad=False作用等效。
这里最大的区别就是requires_grad=False会把这个张量直接从计算图中砍掉,这点在下面的计算图中也可以看出来。
而.detach()的作用更类似于“复制”,张量在.detach()操作后在原来的计算图中仍然存在,只是把这个节点的数据复制出来用作别的计算而不会影响原来的计算图。
这样新的计算图就成了这样:

其中左上角的蓝框代表权重 w w w( x x x已经被砍掉),(1)代表1维向量且只有1个元素,右边蓝框代表偏差 b b b,下边绿框代表“根” h h h,箭头方向代表正向传播方向。
反向传播是从“根”通过链式法则回溯到“叶子”的过程,这里从“根”往上回溯,经历了如下操作过程(灰色框):
- AddBackward0:加法过程,代表 h = w × x + b h = w×x + b h=w×x+b中的“+”;
- MulBackward0:乘法过程,代表 h = w × x + b h = w×x + b h=w×x+b中的“×”;
- AccumulateGrad:梯度积累,在Pytorch中,权重梯度的计算是累加的,这是为了提升训练效率,在每个batch中梯度都进行累加,不同batch间进行梯度清零,这也就是为什么训练的时候要用.zero_grad()的原因;
在 <操作>Backward<层数>中,常见的<操作>有以下几种:
Add代表加法;
Sub代表减法;
Mul代表乘法;
Mm代表矩阵乘法;
Div代表除法;
T代表矩阵转置;
Pow代表乘方;
Squeeze, Unsqueeze, Relu, Sigmoid就代表原本的含义;
<层数>为从"根"到"叶子"的操作层数,本示例中只有1层,所以Backward后面都为0。这里需要注意<层数>是从"根"到"叶子"从下往上数的,所以离"根"越近<层数>越小。
这样我们就把计算图说明白了,无论多复杂的模型,原理都是一样的,只不过是输入输出,操作的复杂度不同而已。
最后需要说明的一点是:计算图在PyTorch中是动态的,在每次调用.backward()之后都会生成一个新的计算图,这样就可以允许在每次学习迭代中调整计算图。
相关文章:
基于TorchViz详解计算图(附代码)
文章目录 0. 前言1. 计算图是什么?2. TorchViz的安装3. 计算图详解 0. 前言 按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,…...
解决GitHub的速度很慢的几种方式
1. GitHub 镜像访问 这里提供两个最常用的镜像地址: https://hub.njuu.cf/search https://www.gitclone.com/gogs/search/clonesearch 也就是说上面的镜像就是一个克隆版的 GitHub,你可以访问上面的镜像网站,网站的内容跟 GitHub 是完整同步…...
设计模式再探——策略模式
目录 一、背景介绍二、思路&方案三、过程1.策略模式简介2.策略模式的类图3.策略模式代码4.策略模式还可以优化的地方5.策略模式的例子改造(配置文件反射) 四、总结五、升华 一、背景介绍 最近在做产品的过程中,对于主题讨论回复内容,按照追评次数排…...
基于Googlenet深度学习网络的人员行为动作识别matlab仿真
目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 1. 原理 1.1 深度学习与卷积神经网络(CNN) 1.2 GoogLeNet 2. 实现过程 2.1 数据预处理 2.2 构建网络模型 2.3 数据输入与训练 2.4 模型评估与调优 3. 应用领域…...
存储过程的学习
1,前言 这是实习期间学习的,我可能是在学校没好好听课,(或者就是学校比较垃,没教这部分,在公司经理让我下去自己学习,太难了,因为是公司代码很多部分都是很多表的操作&#…...
zookeeperAPI操作与写数据原理
要执行API操作需要在idea中创建maven项目 (改成自己的阿里仓库)导入特定依赖 添加日志文件 上边操作做成后就可以进行一些API的实现了 目录 导入maven依赖: 创建日志文件: 创建API客户端: (1)…...
防火墙对双通道协议的处理
防火墙是一种网络安全设备或软件,用于控制网络流量并保护计算机网络免受未经授权的访问、恶意攻击和网络威胁。它作为网络的第一道防线,用于监视、过滤和管理进出网络的数据包。 防火墙可以基于预设的安全策略对网络流量进行评估和筛选。它通过比较数据…...
vscode搭建c语言环境问题
c语言环境搭建参考文章:【C语言初级阶段学习1】使用vscode运行C语言,vscode配置环境超详细过程(包括安装vscode和MinGW-W64安装及后续配置使用的详细过程,vscode用户代码片段的使用)[考研专用]_QAQshift的博客-CSDN博客 问题如下:…...
全网最全的接口自动化测试教程
为什么要做接口自动化 相对于UI自动化而言,接口自动化具有更大的价值。 为了优化转化路径或者提升用户体验,APP/web界面的按钮控件和布局几乎每个版本都会发生一次变化,导致自动化的代码频繁变更,没有起到减少工作量的效果。 而…...
数据结构----结构--线性结构--链式存储--链表
数据结构----结构–线性结构–链式存储–链表 1.链表的特点 空间可以不连续,长度不固定,相对于数组灵活自由 搜索: 时间复杂度O(n) 增删: 头增头删时间复杂度O(1) 其他时间复杂度为O(n) 扩展:单向循环链表的特性 从任意节…...
【5G 核心网】5G 多PDU会话锚点技术介绍
博主未授权任何人或组织机构转载博主任何原创文章,感谢各位对原创的支持! 博主链接 本人就职于国际知名终端厂商,负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作,目前牵头6G算力网络技术标准研究。 博客…...
K8s环境下监控告警平台搭建及配置
Promethues是可以单机搭建的,参考prometheus入门[1] 本文是就PromethuesGrafana在K8s环境下的搭建及配置 Prometheus度量指标监控平台简介 启动minikube minikube start 安装helm 使用Helm Chart 安装 Prometheus Operator: helm install prometheus-operator stabl…...
微信小程序在使用vant组件库时构建npm报错
在跟着vant官方进行使用步骤一步步操作时,由于要构建NPM,但NPM包在App配置文件的外部 所以在做下图这一步时: 接着再进行npm构建时会报错 message:发生错误 Error: F:\前端学习\前端框架\小程序\project\demo\miniprogram解决方法 …...
Django实现音乐网站 ⑽
使用Python Django框架制作一个音乐网站, 本篇主要是后台对歌曲类型、歌单功能原有功能进行部分功能实现和显示优化。 目录 歌曲类型功能优化 新增编辑 优化输入项标题显示 父类型显示改为下拉菜单 列表显示 父类型显示名称 过滤器增加父类型 歌单表功能优化…...
SpringMVC的架构有什么优势?——异常处理与文件上传(五)
前言 「作者主页」:雪碧有白泡泡 「个人网站」:雪碧的个人网站 「推荐专栏」: ★java一站式服务 ★ ★ React从入门到精通★ ★前端炫酷代码分享 ★ ★ 从0到英雄,vue成神之路★ ★ uniapp-从构建到提升★ ★ 从0到英雄ÿ…...
【java面向对象中static关键字】
提纲 static修饰成员变量static修饰成员变量的应用场景static修饰成员方法static修饰成员方法的应用场景static的注意事项static的应用知识:代码块static的应用知识:单例设计模式 static静态的意思,可以修饰成员变量,成员方法&a…...
系统学习Linux-Redis集群
目录 一、Redis主从复制 概念 作用 缺点 流程 二、Reids哨兵模式(sentinel) 概念 作用 缺点 结构 搭建 三、redis集群 概述 原理 架构细节 选举过程 实验环境模拟 一、Redis主从复制 概念 是指将一台Redis服务器的数据,复制…...
【每日随笔】帝王心术 ② ( 如何培养下一代 | 重点培养孩子某一项特长 | 价值观培养 | 独立思考 | 人性和谋略教育 | 资源传承 | 人生指引 )
文章目录 一、重点培养孩子某一项特长二、价值观培养三、独立思考四、人性和谋略教育五、资源传承六、人生指引 一、重点培养孩子某一项特长 很多人 作为 父母 , 教育孩子 , 没有出息的占大多数 ; 父母 教育 孩子 , 给孩子培训 , 一般都给报个兴趣班 , 如果兴趣班的种类超过两…...
Git简介
Git是一个开源的分布式版本控制系统,用于敏捷高效地处理任何或大或小的项目。 Git是Linus Torvalds为了帮助管理Linux内核开发而开发的一个开放源代码的版本控制软件。 Git与常用的版本控制工具CVS、Subversion等不同,它采用了分布式版本库的方式&#x…...
STM32入门学习之定时器输入捕获
1.定时器的输入捕获可以用来测量脉冲宽度或者测量频率。输入捕获的原理图如下: 假设定时器是向上计数。在图中,t1~t2之间的便是我们要测量的高电平的时间(脉冲宽度)。首先,设置定时器为上升沿捕获,如此一来,在t1时刻可…...
RestClient
什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端,它允许HTTP与Elasticsearch 集群通信,而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级ÿ…...
基于ASP.NET+ SQL Server实现(Web)医院信息管理系统
医院信息管理系统 1. 课程设计内容 在 visual studio 2017 平台上,开发一个“医院信息管理系统”Web 程序。 2. 课程设计目的 综合运用 c#.net 知识,在 vs 2017 平台上,进行 ASP.NET 应用程序和简易网站的开发;初步熟悉开发一…...
线程同步:确保多线程程序的安全与高效!
全文目录: 开篇语前序前言第一部分:线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分:synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分ÿ…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...
Java - Mysql数据类型对应
Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...
Axios请求超时重发机制
Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...
RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文全面剖析RNN核心原理,深入讲解梯度消失/爆炸问题,并通过LSTM/GRU结构实现解决方案,提供时间序列预测和文本生成…...
LeetCode - 199. 二叉树的右视图
题目 199. 二叉树的右视图 - 力扣(LeetCode) 思路 右视图是指从树的右侧看,对于每一层,只能看到该层最右边的节点。实现思路是: 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...
Fabric V2.5 通用溯源系统——增加图片上传与下载功能
fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...
【Redis】笔记|第8节|大厂高并发缓存架构实战与优化
缓存架构 代码结构 代码详情 功能点: 多级缓存,先查本地缓存,再查Redis,最后才查数据库热点数据重建逻辑使用分布式锁,二次查询更新缓存采用读写锁提升性能采用Redis的发布订阅机制通知所有实例更新本地缓存适用读多…...
