pytorch 3 计算图
计算图结构
分析:
- 起始节点 a
- b = 5 - 3a
- c = 2b + 3
- d = 5b + 6
- e = 7c + d^2
- f = 2e
- 最终输出 g = 3f - o(其中 o 是另一个输入)
前向传播
前向传播按照上述顺序计算每个节点的值。
反向传播过程
反向传播的目标是计算损失函数(这里假设为 g)对每个中间变量和输入的偏导数。从右向左进行计算:
- ∂g/∂o = -1
- ∂g/∂f = 3
- ∂f/∂e = 2
- ∂e/∂c = 7
- ∂e/∂d = 2d
- ∂d/∂b = 5
- ∂c/∂b = 2
- ∂b/∂a = -3
链式法则应用
使用链式法则计算出 g 对每个变量的全导数:
- dg/df = ∂g/∂f = 3
- dg/de = (∂g/∂f) * (∂f/∂e) = 3 * 2 = 6
- dg/dc = (dg/de) * (∂e/∂c) = 6 * 7 = 42
- dg/dd = (dg/de) * (∂e/∂d) = 6 * 2d
- dg/db = (dg/dc) * (∂c/∂b) + (dg/dd) * (∂d/∂b)
= 42 * 2 + 6 * 2d * 5
= 84 + 60d - dg/da = (dg/db) * (∂b/∂a)
= (84 + 60d) * (-3)
= -252 - 180d
最终梯度
最终得到 g 对输入 a 和 o 的梯度:
- dg/da = -252 - 180d
- dg/do = -1
代码实现
静态图
import mathclass Node:"""表示计算图中的一个节点。每个节点都可以存储一个值、梯度,并且知道如何计算前向传播和反向传播。"""def __init__(self, value=None):self.value = value # 节点的值self.gradient = 0 # 节点的梯度self.parents = [] # 父节点列表self.forward_fn = lambda: None # 前向传播函数self.backward_fn = lambda: None # 反向传播函数def __add__(self, other):"""加法操作"""return self._create_binary_operation(other, lambda x, y: x + y, lambda: (1, 1))def __mul__(self, other):"""乘法操作"""return self._create_binary_operation(other, lambda x, y: x * y, lambda: (other.value, self.value))def __sub__(self, other):"""减法操作"""return self._create_binary_operation(other, lambda x, y: x - y, lambda: (1, -1))def __pow__(self, power):"""幂运算"""result = Node()result.parents = [self]def forward():result.value = math.pow(self.value, power)def backward():self.gradient += power * math.pow(self.value, power-1) * result.gradientresult.forward_fn = forwardresult.backward_fn = backwardreturn resultdef _create_binary_operation(self, other, forward_op, gradient_op):"""创建二元操作的辅助方法。用于简化加法、乘法和减法的实现。"""result = Node()result.parents = [self, other]def forward():result.value = forward_op(self.value, other.value)def backward():grads = gradient_op()self.gradient += grads[0] * result.gradientother.gradient += grads[1] * result.gradientresult.forward_fn = forwardresult.backward_fn = backwardreturn resultdef topological_sort(node):"""对计算图进行拓扑排序。确保在前向和反向传播中按正确的顺序处理节点。"""visited = set()topo_order = []def dfs(n):if n not in visited:visited.add(n)for parent in n.parents:dfs(parent)topo_order.append(n)dfs(node)return topo_order# 构建计算图
a = Node(2) # 假设a的初始值为2
o = Node(1) # 假设o的初始值为1# 按照给定的数学表达式构建计算图
b = Node(5) - a * Node(3)
c = b * Node(2) + Node(3)
d = b * Node(5) + Node(6)
e = c * Node(7) + d ** 2
f = e * Node(2)
g = f * Node(3) - o# 前向传播
sorted_nodes = topological_sort(g)
for node in sorted_nodes:node.forward_fn()# 反向传播
g.gradient = 1 # 设置输出节点的梯度为1
for node in reversed(sorted_nodes):node.backward_fn()# 打印结果
print(f"g = {g.value}")
print(f"dg/da = {a.gradient}")
print(f"dg/do = {o.gradient}")# 验证手动计算的结果
d_value = 5 * b.value + 6
expected_dg_da = -252 - 180 * d_value
print(f"Expected dg/da = {expected_dg_da}")
print(f"Difference: {abs(a.gradient - expected_dg_da)}")
动态图
import mathclass Node:"""表示计算图中的一个节点。实现了动态计算图的核心功能,包括前向计算和反向传播。"""def __init__(self, value, children=(), op=''):self.value = value # 节点的值self.grad = 0 # 节点的梯度self._backward = lambda: None # 反向传播函数,默认为空操作self._prev = set(children) # 前驱节点集合self._op = op # 操作符,用于调试def __add__(self, other):"""加法操作"""other = other if isinstance(other, Node) else Node(other)result = Node(self.value + other.value, (self, other), '+')def _backward():self.grad += result.gradother.grad += result.gradresult._backward = _backwardreturn resultdef __mul__(self, other):"""乘法操作"""other = other if isinstance(other, Node) else Node(other)result = Node(self.value * other.value, (self, other), '*')def _backward():self.grad += other.value * result.gradother.grad += self.value * result.gradresult._backward = _backwardreturn resultdef __pow__(self, other):"""幂运算"""assert isinstance(other, (int, float)), "only supporting int/float powers for now"result = Node(self.value ** other, (self,), f'**{other}')def _backward():self.grad += (other * self.value**(other-1)) * result.gradresult._backward = _backwardreturn resultdef __neg__(self):"""取反操作"""return self * -1def __sub__(self, other):"""减法操作"""return self + (-other)def __truediv__(self, other):"""除法操作"""return self * other**-1def __radd__(self, other):"""反向加法"""return self + otherdef __rmul__(self, other):"""反向乘法"""return self * otherdef __rtruediv__(self, other):"""反向除法"""return other * self**-1def tanh(self):"""双曲正切函数"""x = self.valuet = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)result = Node(t, (self,), 'tanh')def _backward():self.grad += (1 - t**2) * result.gradresult._backward = _backwardreturn resultdef backward(self):"""执行反向传播,计算梯度。使用拓扑排序确保正确的反向传播顺序。"""topo = []visited = set()def build_topo(v):if v not in visited:visited.add(v)for child in v._prev:build_topo(child)topo.append(v)build_topo(self)self.grad = 1 # 设置输出节点的梯度为1for node in reversed(topo):node._backward() # 对每个节点执行反向传播def main():"""主函数,用于测试自动微分系统。构建一个计算图,执行反向传播,并验证结果。"""# 构建计算图a = Node(2)o = Node(1)b = Node(5) - a * 3c = b * 2 + 3d = b * 5 + 6e = c * 7 + d ** 2f = e * 2g = f * 3 - o# 反向传播g.backward()# 打印结果print(f"g = {g.value}")print(f"dg/da = {a.grad}")print(f"dg/do = {o.grad}")# 验证手动计算的结果d_value = 5 * b.value + 6expected_dg_da = -252 - 180 * d_valueprint(f"Expected dg/da = {expected_dg_da}")print(f"Difference: {abs(a.grad - expected_dg_da)}")if __name__ == "__main__":main()
解释:
Node
类代表计算图中的一个节点,包含值、梯度、父节点以及前向和反向传播函数。- 重载的数学运算符 (
__add__
,__mul__
,__sub__
,__pow__
) 允许直观地构建计算图。 _create_binary_operation
方法用于创建二元操作,简化了加法、乘法和减法的实现。topological_sort
函数对计算图进行拓扑排序,确保正确的计算顺序。
import mathclass Node:"""表示计算图中的一个节点。实现了动态计算图的核心功能,包括前向计算和反向传播。"""def __init__(self, value, children=(), op=''):self.value = value # 节点的值self.grad = 0 # 节点的梯度self._backward = lambda: None # 反向传播函数,默认为空操作self._prev = set(children) # 前驱节点集合self._op = op # 操作符,用于调试def __add__(self, other):"""加法操作"""other = other if isinstance(other, Node) else Node(other)result = Node(self.value + other.value, (self, other), '+')def _backward():self.grad += result.gradother.grad += result.gradresult._backward = _backwardreturn resultdef __mul__(self, other):"""乘法操作"""other = other if isinstance(other, Node) else Node(other)result = Node(self.value * other.value, (self, other), '*')def _backward():self.grad += other.value * result.gradother.grad += self.value * result.gradresult._backward = _backwardreturn resultdef __pow__(self, other):"""幂运算"""assert isinstance(other, (int, float)), "only supporting int/float powers for now"result = Node(self.value ** other, (self,), f'**{other}')def _backward():self.grad += (other * self.value**(other-1)) * result.gradresult._backward = _backwardreturn resultdef __neg__(self):"""取反操作"""return self * -1def __sub__(self, other):"""减法操作"""return self + (-other)def __truediv__(self, other):"""除法操作"""return self * other**-1def __radd__(self, other):"""反向加法"""return self + otherdef __rmul__(self, other):"""反向乘法"""return self * otherdef __rtruediv__(self, other):"""反向除法"""return other * self**-1def tanh(self):"""双曲正切函数"""x = self.valuet = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)result = Node(t, (self,), 'tanh')def _backward():self.grad += (1 - t**2) * result.gradresult._backward = _backwardreturn resultdef backward(self):"""执行反向传播,计算梯度。使用拓扑排序确保正确的反向传播顺序。"""topo = []visited = set()def build_topo(v):if v not in visited:visited.add(v)for child in v._prev:build_topo(child)topo.append(v)build_topo(self)self.grad = 1 # 设置输出节点的梯度为1for node in reversed(topo):node._backward() # 对每个节点执行反向传播def main():"""主函数,用于测试自动微分系统。构建一个计算图,执行反向传播,并验证结果。"""# 构建计算图a = Node(2)o = Node(1)b = Node(5) - a * 3c = b * 2 + 3d = b * 5 + 6e = c * 7 + d ** 2f = e * 2g = f * 3 - o# 反向传播g.backward()# 打印结果print(f"g = {g.value}")print(f"dg/da = {a.grad}")print(f"dg/do = {o.grad}")# 验证手动计算的结果d_value = 5 * b.value + 6expected_dg_da = -252 - 180 * d_valueprint(f"Expected dg/da = {expected_dg_da}")print(f"Difference: {abs(a.grad - expected_dg_da)}")if __name__ == "__main__":main()
解释:
-
Node
类是核心,它代表计算图中的一个节点,并实现了各种数学运算。 -
每个数学运算(如
__add__
,__mul__
等)都创建一个新的Node
,并定义了相应的反向传播函数。 -
backward
方法实现了反向传播算法,使用拓扑排序确保正确的计算顺序。
相关文章:

pytorch 3 计算图
计算图结构 分析: 起始节点 ab 5 - 3ac 2b 3d 5b 6e 7c d^2f 2e最终输出 g 3f - o(其中 o 是另一个输入) 前向传播 前向传播按照上述顺序计算每个节点的值。 反向传播过程 反向传播的目标是计算损失函数(这里假设为…...

一文吃透:暗水印是什么?企业防泄密可以加暗水印吗?
设计部主管:昨天下班的时候我在办公室捡到一张文件,上面可是我们最新产品的设计草稿,严禁打印的,到底是谁干的? 员工:办公室没有监控,似乎很难查到哦。 网络部经理:不用担心&#…...

Ajax-02.Axios
Axios入门 1.引入Axios的js文件 <script src"js/axios-0.18.0.js"></script> Axios 请求方式别名: axios.get(url[,config]) axios.delete(url[,config]) axios.post(url[,data[,config]]) axios.put(url[,data[,config]]) 发送GET/POST请求 axios.get…...
NodeJS的核心配置文件package.json和package.lock.json详解
package.json 文件 package.json 文件是 Node.js 项目的核心配置文件,它包含了项目的基本信息、依赖关系以及一些脚本命令等。以下是 package.json 文件的主要字段说明: name:项目的名称,必须是小写,可以包含字母、数…...
开源数据采集和跟踪系统:助力营销决策的关键工具
开源数据采集和跟踪系统:助力营销决策的关键工具 在现代营销中,数据是最重要的资产之一。了解用户行为、优化广告效果、提升转化率,这一切都离不开精准的数据分析。为了帮助商家更好地掌握这些数据,市场上出现了许多开源的数据采…...

Luminar Neo for Mac/Win:创新AI图像编辑软件的强大功能
Luminar Neo,这款由Skylum公司倾力打造的图像编辑软件,为Mac和Windows用户带来了前所未有的创作体验与编辑便利。作为一款融合了先进AI技术的图像处理工具,Luminar Neo以其独特的功能和高效的操作流程,成为了摄影师、设计师及摄影…...

Mac平台M1PRO芯片MiniCPM-V-2.6网页部署跑通
Mac平台M1PRO芯片MiniCPM-V-2.6网页部署跑通 契机 ⚙ 2.6的小钢炮可以输入视频了,我必须拉到本地跑跑。主要解决2.6版本默认绑定flash_atten问题,pip install flash_attn也无法安装,因为强制依赖cuda。主要解决的就是这个问题,还…...

MyBatis:Maven,Git,TortoiseGit,Gradle
1,Maven Maven是一个非常优秀的项目管理工具,采用一种“约定优于配置(CoC)”的策略来管理项目。使用Maven不仅可以把源代码构建成可发布的项目(包括编译、打包、测试和分发),还可以生成报告、生…...
获取链表中间位置的两种方法方法
方法一: 我们可以计算链表节点的数量,然后遍历链表找到前半部分的尾节点。 方法二: 我们也可以使用快慢指针在一次遍历中找到:慢指针一次走一步,快指针一次走两步,快慢指针同时出发。当快指针移动到链表的末尾时&am…...

第二十天的学习(2024.8.8)Vue拓展
昨天的笔记中,我们进行的项目已经可以在网页上显示查询到数据库中的数据,今天的笔记中将会完成在网页上进行增删改查的操作 1.删除表中数据 现在网页上只能呈现出数据库中的数据,我们首先添加一个删除按钮,使其可以对数据库数据…...

微信小程序教程011:全局配置:Window
文章目录 1、window1.1、`window`-小程序窗口的组成部分1.2、了解 window 节点常用的配置项1.3、设置导航栏的标题1.4、设置导航栏的背景色1.5、设置导航栏的标题颜色1.6、全局开启下拉刷新功能1.7、设置下拉刷新时窗口的背景色1.8、设置下拉刷新时 loading 的样式1.9、设置上拉…...

Tomcat服务器和Web项目的部署
目录 一、概述和作用 二、安装 1.进入官网 2.Download下面选择想要下载的版本 3.点击Which version查看版本所需要的JRE版本 4.返回上一页下载和电脑和操作系统匹配的Tomcat 5. 安装完成后,点击bin目录下的startup.bat(linux系统下就运行startup.sh&…...

PCIe学习笔记(22)
Transaction Ordering Transaction Ordering Rules 表2-40定义了PCI Express Transactions的排序要求。该表中定义的规则统一适用于PCI Express上所有类型的事务,包括内存、I/O、配置和消息。该表中定义的排序规则适用于单个流量类(TC)。不同TC标签的事务之间没有…...
Vue3 依赖注入Provide / Inject
在实际开发中,我们经常需要从父组件向子组件传递数据,一般情况下,我们使用 props。但有时候会遇到深度嵌套的组件,而深层的子组件只需要父组件的部分内容。在这种情况下,如果仍然将 prop 沿着组件链逐级传递下去&#…...

Python | Leetcode Python题解之第332题重新安排行程
题目: 题解: class Solution:def findItinerary(self, tickets: List[List[str]]) -> List[str]:def dfs(curr: str):while vec[curr]:tmp heapq.heappop(vec[curr])dfs(tmp)stack.append(curr)vec collections.defaultdict(list)for depart, arri…...

React状态管理:react-redux和redux-saga(适合由vue转到react的同学)
注意:本文不会把所有知识点都写一遍,并不适合纯新手阅读 首先Redux是一种状态管理方案,本身和react并没有什么联系,redux也可以结合其他框架来用。 react-redux是基于react的一种状态管理实现,他不像vuex那样直接内置在…...

刷题技巧:双指针法的核心思想总结+例题整合+力扣接雨水双指针c++实现
双指针法的核心思想是通过同时操作两个指针来遍历数据结构,通常是数组或链表,以达到优化算法性能的目的。具体来说,双指针法能够减少时间复杂度、空间复杂度,或者简化逻辑结构。以下是双指针法的几个核心思想: ps 下面…...

什么是前端微服务,有何优势
随着互联网技术的发展,传统的单体应用架构已经无法满足复杂业务场景的需求。微服务架构的兴起为后端应用的开发和部署提供了灵活性和可扩展性。与此同时,前端开发也经历了类似的演变,前端微服务作为一种新兴的架构模式应运而生。 一、前端微服…...
小论文写作——02:编故事
一篇论文,可以发水刊,也可以发顶刊顶会,这两者的区别就是一个故事编的好不好。 你的论文ABC,但不能之说有ABC。创新就是看你故事编的怎么样?创新是编出来的。 我们要说:我发现了问题,然后准备…...

GIT企业开发使用介绍
0.认识git git就是一个版本控制器,记录每次的修改以及版本迭代的一个管理系统 至于为什么会有git的出现,主要是为了解决一份代码改了又改,但最后还是要第一版的情况 git 可以控制电脑上所有格式的文档 1.安装git sudo yum install git -y…...

中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试
作者:Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位:中南大学地球科学与信息物理学院论文标题:BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接:https://arxiv.…...

3.3.1_1 检错编码(奇偶校验码)
从这节课开始,我们会探讨数据链路层的差错控制功能,差错控制功能的主要目标是要发现并且解决一个帧内部的位错误,我们需要使用特殊的编码技术去发现帧内部的位错误,当我们发现位错误之后,通常来说有两种解决方案。第一…...
服务器硬防的应用场景都有哪些?
服务器硬防是指一种通过硬件设备层面的安全措施来防御服务器系统受到网络攻击的方式,避免服务器受到各种恶意攻击和网络威胁,那么,服务器硬防通常都会应用在哪些场景当中呢? 硬防服务器中一般会配备入侵检测系统和预防系统&#x…...

(二)原型模式
原型的功能是将一个已经存在的对象作为源目标,其余对象都是通过这个源目标创建。发挥复制的作用就是原型模式的核心思想。 一、源型模式的定义 原型模式是指第二次创建对象可以通过复制已经存在的原型对象来实现,忽略对象创建过程中的其它细节。 📌 核心特点: 避免重复初…...

微服务商城-商品微服务
数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

Java面试专项一-准备篇
一、企业简历筛选规则 一般企业的简历筛选流程:首先由HR先筛选一部分简历后,在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如:Boss直聘(招聘方平台) 直接按照条件进行筛选 例如:…...
代码随想录刷题day30
1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币,另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额,返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...

使用Spring AI和MCP协议构建图片搜索服务
目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式(本地调用) SSE模式(远程调用) 4. 注册工具提…...

20个超级好用的 CSS 动画库
分享 20 个最佳 CSS 动画库。 它们中的大多数将生成纯 CSS 代码,而不需要任何外部库。 1.Animate.css 一个开箱即用型的跨浏览器动画库,可供你在项目中使用。 2.Magic Animations CSS3 一组简单的动画,可以包含在你的网页或应用项目中。 3.An…...
SQL慢可能是触发了ring buffer
简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...