计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战
前一篇文章,Tensor 基本操作5 device 管理,使用 GPU 设备 | PyTorch 深度学习实战
本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started
PyTorch 计算图和 Autograd
- 微积分之于机器学习
- Computational Graphs 计算图
- Autograd 自动求导
- 一个训练过程及 no_grad 的使用
- 示例代码
- 执行结果
- 生成数据
- 第一轮后
- 第二轮后
- 第十轮后
- 更多计算图的知识
- 更为复杂点的计算图的样子
- 自动求导有关的参数
- Links
微积分之于机器学习
机器学习的主要工作原理,就是万事万物存在规律,而我们使用机器来完成参数评估。参数评估的过程是随机梯度下降,也就是任意选择起点,然后使用微积分技术指导我们调优,找到一组最优参数值。
这就像我们爬山,面对众多的山峰,我们从不同的出发点出发,不断的朝着山顶前进,最终,我们即便起点不同,都可以达到山顶 - 通向山顶的路有多条。另外一方面,我们可能来到了不同的山顶。
在我们爬山的过程中,如何选择下一步呢?这时,就是微积分大显身手的时候了。
在机器学习中,对参数优化的过程,使用了大量微积分的运算,PyTorch 能成为通用性的机器学习框架,就在于不同的机器学习任务底层的数学原理是一致的,而 PyTorch 内置了这些标准化的数学运算,在 PyTorch 中,除了 Tensor 外,还有两个关键的概念:
- 计算图
- 自动求导
Computational Graphs 计算图
神经网络是由很多神经元组成的网络,最简单的神经网络就是只包含一个线性神经元的神经网络,理解这个最简单的神经网络,有助于理解任何复杂的神经网络。
z = x ∗ w + b z = x * w + b z=x∗w+b
注意:这里没有添加激活函数,这个神经元是一个简单的线性神经元。

计算过程:
- 加权输出 z 与理想输出 y 之间,使用交叉熵(CE)计算出损失(loss)
- 然后基于 loss 计算梯度 grad
- 基于梯度更新 w 和 b
这个计算过程,可以用一张图表达,一个图就是由节点以及边组成,边上定义操作符。同时,这个计算过程会在训练中发生多次,因为梯度下降算法是 SGD 迭代运算。
PyTorch 为了让每次运算可以更灵活,比如使用 Dropout 随机丢弃一些神经元,PyTorch 实现了每次运算动态的生成这张图 - 动态计算图1。也就是说,对于每次运算,PyTorch 会生成一个计算图并附着计算状态。
Autograd 自动求导
附着状态,最主要的目的就是实现自动求导。因为每个节点都是一个变量,变量和变量之间通过操作符相互依赖,而操作符和变量构成的函数式,就可以实现求导,根据链式法则,实现计算图中,每个变量的导数的计算。
在上图,只有一个线性神经元的情况下,PyTorch 的自动求导是如何工作的呢?参考下面的代码。
import torch# 定义输入和理想输出
x = torch.ones(5) # input tensor
y = torch.zeros(3) # expected output# 定义参数
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)# 定义模型,并进行一次运算
z = torch.matmul(x, w)+b# 定义损失函数,并得到单次的损失
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)# 进行反向传播,并得到梯度
loss.backward()
print(w.grad)
print(b.grad)
如此一来,参数更新将变得非常简单。计算图允许每次迭代传入不同的操作符等,实现训练过程更灵活的配置。计算图保留了运算过程中的 Tensor、操作符、操作符对应的导函数。当 loss.backward() 调用时,顺序的调用自动求导变量的导函数,得到 .grad 梯度值。
一个训练过程及 no_grad 的使用
现在我们看一个例子,通过一个简单的模型,了解训练中,自动求导机制是如何工作的。
示例代码
'''
autograd
'''
import plotly.graph_objects as go
import plotly.express as px
from torch import nn
import numpy as np
import torch
import math# 输入变量 x,理想输出 yt(生成 y 的函数就是要拟合的模型)
X = torch.tensor(np.linspace(-10, 10, 1000))
y = 1.5 * torch.sin(X) + 1.2 * torch.cos(X/4) # 真实的模型
yt = y + np.random.normal(0, 1, 1000)# vis
def plotter(X, y, yhat=None, title=None):with torch.no_grad():fig = go.Figure()fig.add_trace(go.Scatter(x=X, y=y, mode='lines', name='y'))fig.add_trace(go.Scatter(x=X, y=yt, mode='markers', marker=dict(size=4), name='yt'))if yhat is not None: fig.add_trace(go.Scatter(x=X, y=yhat, mode='lines', name='yhat'))fig.update_layout(template='none', title=title)fig.show()plotter(X, y, title='Data Generating Process')# 计算模型的实际输出,这里前提是假设知道变量 X 和函数 sin|cos, 而不知道参数 theta
def fit_model(theta:torch.tensor=torch.rand(3, requires_grad=True)):return theta[0] * X + theta[1] * torch.sin(X) + theta[2] * torch.cos(X/4)# 随机初始化参数,开启自动求导
theta = torch.randn(3, requires_grad=True)# 损失函数和优化器
loss_fn = nn.MSELoss() # MSE loss
optimizer = torch.optim.SGD([theta], lr=0.01) # build optimizer # 迭代训练
epochs = 500
for i in range(epochs):yhat = fit_model(theta) # 计算实际输出loss = loss_fn(y, yhat) # 将实际输出和理想输出传入损失函数,得到损失 lossloss.backward() # 反向传播,完成 .grad 梯度的计算optimizer.step() # 基于梯度完成参数更新 optimizer.zero_grad() # 本轮计算完成,将梯度值归零,否则下次计算损失并调用 backward 导致梯度累计 if i % (epochs/10) == 0: # 验证及输出调试信息 msg = f"loss: {loss.item():>7f} theta: {theta.detach().numpy()}"yhat = fit_model(theta)plotter(X, y, yhat.detach(), title=f"loss: {loss.item():>7f} theta: {theta.detach().numpy().round(3)}")
执行结果
生成数据
创建了一个假数据:
- 分布在象限中的点就是 x,y
- 象限中的曲线,就是符合设想的模型,我们看最终的机器学习的模型,能否拟合这条曲线

第一轮后
初始化后,实际模型和理想模型差距很大。注意,此时 theta 和目标参数差距很大。

第二轮后
经过两次迭代,差距在缩小。

第十轮后
又经过了几轮训练,此时,我们发现图中已经分辨不出来,但是从 theta 的值,我们还可以看到一点差距,这已经证明,机器学习拟合上了目标空间。

更多计算图的知识
更为复杂点的计算图的样子
在训练中,生成的 DAG 类似如下。

自动求导有关的参数
# 做一个计算图
x = torch.rand(1)
b = torch.rand(1, requires_grad=True)
w = torch.rand(1, requires_grad=True)
y = w * x # y 是一个新的 tensor# 检查 y 是否是叶子节点,这里 y 是输出,也就是 root 节点而不是 leaf 节点
print(y.is_leaf)# 反向传播
y.backward(retain_graph=True) # retain_graph=True,保留计算图中的状态,https://discuss.pytorch.org/t/use-of-retain-graph-true/179658
print(w.grad) # 查看梯度
Links
- How Computational Graphs are Constructed in PyTorch
- How Computational Graphs are Executed in PyTorch
- PyTorch’s Dynamic Graphs (Autograd)
- Automatic Differentiation with torch.autograd
- Autograd mechanics
PyTorch 使用 DAG 有向无环图这种格式存储计算图,其中输入的 Tensor 称为叶子节点(leaves),输出的 Tensor 称为根节点(roots)。 ↩︎
相关文章:
计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战
前一篇文章,Tensor 基本操作5 device 管理,使用 GPU 设备 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started PyTorch 计算图和 Autograd 微积分之于机器学习Computational Graphs 计算图Autograd…...
基于 Java 开发的 MongoDB 企业级应用全解析
基于Java的MongoDB企业级应用开发实战 目录 背景与历史MongoDB的核心功能与特性企业级业务场景分析MongoDB的优缺点剖析开发环境搭建 5.1 JDK安装与配置5.2 MongoDB安装与集群配置5.3 开发工具选型 Java与MongoDB集成实战 6.1 项目依赖与驱动选择6.2 连接池与客户端配置6.3…...
构建一个测试助手Agent:提升测试效率的实践
在上一篇文章中,我们讨论了如何构建一个运维助手Agent。今天,我想分享另一个实际项目:如何构建一个测试助手Agent。这个项目源于我们一个大型互联网公司的真实需求 - 提升测试效率,保障产品质量。 从测试痛点说起 记得和测试团队讨论时的场景: 小张:每…...
ESXI虚拟机中部署docker会降低服务器性能
在 8 核 16GB 的 ESXi 虚拟机中部署 Docker 的性能影响分析 在 ESXi 虚拟机中运行 Docker 容器时,性能影响主要来自以下几个方面: 虚拟化开销:ESXi 虚拟化层和 Docker 容器化层的叠加。资源竞争:虚拟机与容器之间对 CPU、内存、…...
基于“蘑菇书”的强化学习知识点(五):条件期望
条件期望 摘要一、条件期望的定义二、条件期望的关键性质三、条件期望的直观理解四、条件期望的应用场景五、简单例子离散情况连续情况 摘要 本系列知识点讲解基于蘑菇书EasyRL中的内容进行详细的疑难点分析!具体内容请阅读蘑菇书EasyRL! 对应蘑菇书Eas…...
Linux抢占式内核:技术演进与源码解析
一、引言 Linux内核作为全球广泛使用的开源操作系统核心,其设计和实现一直是计算机科学领域的研究热点。从早期的非抢占式内核到2.6版本引入的抢占式内核,Linux在实时性和响应能力上取得了显著进步。本文将深入探讨Linux抢占式内核的引入背景、技术实现以及与非抢占式内核的…...
接入DeepSeek大模型
接入DeepSeek 下载并安装Ollamachatbox 软件配置大模型 下载并安装Ollama 下载并安装Ollama, 使用参数ollama -v查看是否安装成功。 输入命令ollama list, 可以看到已经存在4个目录了。 输入命令ollama pull deepseek-r1:1.5b, 下载deepse…...
【论文复现】粘菌算法在最优经济排放调度中的发展与应用
目录 1.摘要2.黏菌算法SMA原理3.改进策略4.结果展示5.参考文献6.代码获取 1.摘要 本文提出了一种改进粘菌算法(ISMA),并将其应用于考虑阀点效应的单目标和双目标经济与排放调度(EED)问题。为提升传统粘菌算法…...
SSM开发(十) SSM框架协同工作原理
目录 一、Spring扮演了一个整合者的角色 二、SSM拆解来看 三、SSM框架的核心优势 注: SSM框架(Spring + Spring MVC + MyBatis) 一、Spring扮演了一个整合者的角色 SSM框架中,Spring扮演了一个整合者的角色,它将Spring MVC的Web层和MyBatis的数据持久层连接起来。在SS…...
UE Bridge混合材质工具
打开虚幻内置Bridge 随便点个材质点右下角图标 就能打开材质混合工具 可以用来做顶点绘制...
基于 yolov8_pyqt5 自适应界面设计的火灾检测系统 demo:毕业设计参考
基于 yolov8_pyqt5 自适应界面设计的火灾检测系统 demo:毕业设计参考 【毕业设计参考】基于yolov8-pyqt5自适应界面设计的火灾检测系统demo.zip资源-CSDN文库 【毕业设计参考】基于yolov8-pyqt5自适应界面设计的火灾检测系统demo.zip资源-CSDN文库 一、项目背景 …...
Linux 传输层协议 UDP 和 TCP
UDP 协议 UDP 协议端格式 16 位 UDP 长度, 表示整个数据报(UDP 首部UDP 数据)的最大长度如果校验和出错, 就会直接丢弃 UDP 的特点 UDP 传输的过程类似于寄信 . 无连接: 知道对端的 IP 和端口号就直接进行传输, 不需要建立连接不可靠: 没有确认机制, 没有重传机制; 如果因…...
Android开发EventBus
Android开发EventBus 分享一个EventBus 工具类,封装一下,让你少写些代码 直接上代码: public class BaseEventBusUtils {public static void register(Object subscriber) {EventBus eventBus EventBus.getDefault();if (!eventBus.isReg…...
chrome浏览器chromedriver下载
chromedriver 下载地址 https://googlechromelabs.github.io/chrome-for-testing/ 上面的链接有和当前发布的chrome浏览器版本相近的chromedriver 实际使用感受 chrome浏览器会自动更新,可以去下载最新的chromedriver使用,自动化中使用新的chromedr…...
第一个Qt开发实例(一个Push Button按钮和两个Label)【包括如何在QtCreator中创建新工程、代码详解、编译、环境变量配置、测试程序运行等】
目录 Qt开发环境QtCreator的安装、配置在QtCreator中创建新工程在Forms→mainwindow.ui中拖曳出我们要的图形按钮查看拖曳出按钮后的代码为pushButton这个图形添加回调函数编译工程关闭开发板上QT的GUI(选做)禁止LCD黑屏(选做)设置Qt运行的环境变量运行Qt程序如何让程序在系统启…...
【react+redux】 react使用redux相关内容
首先说一下,文章中所提及的内容都是我自己的个人理解,是我理逻辑的时候,自我说服的方式,如果有问题有补充欢迎在评论区指出。 一、场景描述 为什么在react里面要使用redux,我的理解是因为想要使组件之间的通信更便捷…...
【435. 无重叠区间 中等】
题目: 给定一个区间的集合 intervals ,其中 intervals[i] [starti, endi] 。返回 需要移除区间的最小数量,使剩余区间互不重叠 。 注意 只在一点上接触的区间是 不重叠的。例如 [1, 2] 和 [2, 3] 是不重叠的。 示例 1: 输入: intervals …...
文献学习笔记:中风醒脑液(FYTF-919)临床试验解读:有效还是无效?
【中风醒脑液(FYTF-919)临床试验解读:有效还是无效?】 在发表于 The Lancet (2024 年 11 月 30 日,第 404 卷)的临床研究《Traditional Chinese medicine FYTF-919 (Zhongfeng Xingnao oral pr…...
4 前端前置技术(中):node.js环境
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 前言...
5.角色基础移动
能帮到你的话,就给个赞吧 😘 文章目录 角色的xyz轴与移动方向拌合输入轴值add movement inputget controller rotationget right vectorget forward vector 发现模型的旋转改变后,xyz轴也会改变,所以需要旋转值来计算xyz轴方向。 …...
vue2语法速通
首先,git clone下来的项目要npm install下载依赖,如果是vue项目,运行通常npm run serve或者npm run dev vue速通一下 使用vite创建项目(较快) npm create vite 配置文件 src/ ├── assets/ # 存放…...
doris:基于导入的批量删除
基于导入的批量删除 删除操作可以视为数据更新的一种特殊形式。在主键模型(Unique Key)表上,Doris 支持通过导入数据时添加删除标记来实现删除操作。 相比 DELETE 语句,使用删除标记在以下场景中具有更好的易用性和性能优势&a…...
【商品库存管理——差分、前缀和】
题目 代码 #include <bits/stdc.h> using namespace std; const int N 3e510; int l[N], r[N], b[N]; int s1[N], s0[N]; int main() {int n, m;cin >> n >> m;for(int i 1; i < m; i){cin >> l[i] >> r[i];b[l[i]], b[r[i]1]--;}int a 0…...
Linux基本指令2
07.man指令(重要): Linux的命令有很多参数,我们不可能全记住,我们可以通过查看联机手册获取帮助。访问Linux手册页的命令是 man 语法: man [选项] 命令 man ls查看ls指令更多的说明。 man man: man指令就…...
运维监控平台 WGCLOUD
WGCLOUD v3.5.7 于 2025 年 2 月 3 日发布1。这是一款开源免费的分布式运维监控平台,server 端基于 springboot 开发,agent 端使用 go 编写1。以下是 v3.5.7 版本的更新内容1: 2. 自定义告警批量添加设置 3. 告警通知渠道设置 4. 告警规则设置…...
GDAL矢量数据集相关接口的资源控制问题
1. 引言 笔者在《使用GDAL读写矢量文件》这篇文章中总结了通过GDAL读写矢量的具体实现。不过这篇文章中并没有谈到涉及到矢量数据集相关接口的资源控制问题。具体来说,GDAL/OGR诞生的年代连C语言本身都不是很完善(c11之前),因此提…...
Android学习19 -- 手搓App
1 前言 之前工作中,很多时候要搞一个简单的app去验证底层功能,Android studio又过于重型,之前用gradle,被版本匹配和下载外网包折腾的堪称噩梦。所以搞app都只有找应用的同事帮忙。一直想知道一些简单的app怎么能手搓一下&#x…...
人工智能导论-第3章-知识点与学习笔记
参考教材3.2节的内容,介绍什么是自然演绎推理;解释“肯定后件”与“否定前件”两类错误的演绎推理是什么意义,给出具体例子加以阐述。参考教材3.3节的内容,介绍什么是文字(literal);介绍什么是子…...
wxWidgets中wxGrid表格使用示例,去掉竖向表头
这里设置表格各种属性如下: // 去掉竖向表头 grid->SetRowLabelSize(0); // 设置表格背景色为黑色 grid->SetDefaultCellBackgroundColour(*wxBLACK); // 设置单元格内容居中,字体为16号,白色 wxFont cellFont(16, wxFONTFAMILY_DEFAULT, wx…...
全面掌握市场信息:xtquant库在证券品种数据获取中的应用
全面掌握市场信息:xtquant库在证券品种数据获取中的应用 开篇点题:技术背景和应用场景 在量化交易领域,快速准确地获取市场基础信息是至关重要的。xtquant库提供了一种便捷的途径来获取各类证券品种的数据,包括股票、指数、基金等…...
