深入理解 PyTorch 自动微分机制与自定义 torch.autograd.Function
文章目录
- 前言
- 一、pytorch使用现有的自动微分机制
- 二、torch.autograd.Function中的ctx解读
- 1、`forward` 方法中的 `ctx`
- 2、`backward` 方法中的 `ctx`
- 3、小结
- 三、pytorch自定义自动微分函数(torch.autograd.Function)
- 1、torch.autograd.Function计算前向与后向传播梯度Demo
- 2、前向传播梯度解读
- 3、后向传播梯度解读
- 4、运行结果
前言
随着深度学习技术的迅速发展,PyTorch 作为一款功能强大且灵活的深度学习框架,受到了广泛的关注和应用。它以其动态计算图、易用性以及强大的社区支持而闻名。在PyTorch中,自动微分(autograd)是其核心特性之一,它使得神经网络训练过程中的梯度计算变得简单高效。对于大多数应用场景而言,开发者无需手动编写反向传播逻辑,因为PyTorch能够自动处理这些细节。
然而,在某些特殊情况下,我们可能需要对特定的操作进行定制化的梯度计算,这时就需要深入了解并利用PyTorch提供的torch.autograd.Function类来实现自定义的前向和后向传播逻辑。通过这种方式,不仅可以实现更复杂的模型结构,还能优化性能或满足特定的研究需求。
本文将从基础出发,首先介绍如何使用PyTorch内置的自动微分机制完成常规的模型训练流程;接着详细解析torch.autograd.Function中的ctx对象及其在前后向传播间的作用;最后,通过一个具体的例子演示如何编写自定义的自动微分函数,并解释其中的关键概念和操作。希望通过这篇文章,读者能够掌握PyTorch自动微分的核心原理,以及如何根据实际需求设计高效的自定义梯度计算逻辑。
一、pytorch使用现有的自动微分机制
编写一个后向传播函数在 PyTorch 中通常是不需要的,因为 PyTorch 自动处理了自动微分(autograd),即通过 loss.backward()
来计算梯度。下面我们将展示如何编写一个简单的自定义后向传播函数,并解释如何在 PyTorch 中利用现有的自动微分机制进行反向传播。
通常情况下,你只需要调用 loss.backward()
即可完成反向传播,一个示列代码如下:
import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 输入和目标
input_tensor = torch.randn(1, 10, requires_grad=True)
target = torch.tensor([[1.]])# 前向传播
output = model(input_tensor)
loss = criterion(output, target)# 清除之前的梯度
optimizer.zero_grad()# 反向传播
loss.backward()# 更新参数
optimizer.step()
二、torch.autograd.Function中的ctx解读
在PyTorch中,torch.autograd.Function
是用来定义自定义自动求导函数的类。你提供的CustomReLU
类继承了torch.autograd.Function
并实现了自定义的前向传播和反向传播逻辑。这里的ctx
(context)对象是用于存储信息以便在前向传播和反向传播之间共享。
1、forward
方法中的 ctx
在forward
方法中,ctx
被用来保存在前向传播阶段计算的信息,这些信息可能在后续的反向传播过程中需要使用。例如:
@staticmethod
def forward(ctx, input):ctx.save_for_backward(input) # 保存输入以供反向传播使用return input.clamp(min=0)
ctx.save_for_backward(input)
:这里我们保存了输入张量input
。这很重要,因为在反向传播时我们需要知道哪些元素在前向传播中被设为零(即负数),以便正确地将梯度设为零。
2、backward
方法中的 ctx
在backward
方法中,ctx
被用来访问在前向传播阶段保存的信息。例如:
@staticmethod
def backward(ctx, grad_output):input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
-
input, = ctx.saved_tensors
:这里我们从ctx
中获取了在前向传播阶段保存的输入张量。注意,saved_tensors
是一个元组,即使只保存了一个张量,也需要用逗号来解包。 -
接下来,我们基于原始输入创建了
grad_input
,它初始化为grad_output
的副本。然后我们将所有在前向传播中对应的输入小于0的位置的梯度设为0,这是因为ReLU激活函数对于所有负值输入都输出0,所以其梯度也应为0。
3、小结
ctx
的作用是在前向传播和反向传播之间传递必要的信息。通过ctx.save_for_backward()
可以在前向传播中保存任何需要在反向传播中使用的数据,而在反向传播中则可以通过ctx.saved_tensors
来访问这些数据。这对于实现自定义的自动求导函数来说是非常重要的,因为它允许我们在不需要显式管理复杂状态的情况下执行复杂的梯度计算。
三、pytorch自定义自动微分函数(torch.autograd.Function)
1、torch.autograd.Function计算前向与后向传播梯度Demo
然而,如果你确实需要自定义反向传播逻辑或理解其工作原理,可以通过定义自定义的自动微分函数来实现。如果你想自定义某些操作的反向传播逻辑,可以使用 torch.autograd.Function
来创建自定义的自动微分函数。以下是一个简单的例子:
import torchclass CustomReLU(torch.autograd.Function):@staticmethoddef forward(ctx, input):"""在前向传播中,我们接收到一个上下文对象和一个输入张量,并返回一个经过 ReLU 激活的输出张量。"""ctx.save_for_backward(input) # 保存输入以供反向传播使用return input.clamp(min=0)@staticmethoddef backward(ctx, grad_output):"""在反向传播中,我们接收到一个上下文对象和一个输出张量的梯度,并返回输入张量的梯度。"""input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input# 使用自定义 ReLU 函数
custom_relu = CustomReLU.apply# 示例:将自定义 ReLU 应用于输入
input_tensor = torch.randn(5, requires_grad=True)
output = custom_relu(input_tensor)# 创建一个简单的损失并进行反向传播
loss = output.sum()
loss.backward()print("Input tensor:", input_tensor)
print("Gradient of input tensor:", input_tensor.grad)
2、前向传播梯度解读
如果要计算前向传播梯度只有执行代码output = custom_relu(input_tensor)
才能启动(input_tensor是一个维度[ …]),因此当启动了代码,就可以执行我们定义前向传播方法:
@staticmethod
def forward(ctx, input):"""在前向传播中,我们接收到一个上下文对象和一个输入张量,并返回一个经过 ReLU 激活的输出张量。"""ctx.save_for_backward(input) # 保存输入以供反向传播使用return input.clamp(min=0)
而ctx.save_for_backward也是保存输入input内容,custom_relu只有一层模型,就是来自input_tensor值,所以ctx保存了input_tensor值,也是一个维度值。
3、后向传播梯度解读
如果要计算后向传播梯度只有执行代码loss.backward()
才能启动,因此当启动了代码,就可以执行我们定义后向传播方法:
@staticmethod
def backward(ctx, grad_output):"""在反向传播中,我们接收到一个上下文对象和一个输出张量的梯度,并返回输入张量的梯度。"""input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
而ctx.saved_tensors是取前向保存的内容。
4、运行结果
相关文章:

深入理解 PyTorch 自动微分机制与自定义 torch.autograd.Function
文章目录 前言一、pytorch使用现有的自动微分机制二、torch.autograd.Function中的ctx解读1、forward 方法中的 ctx2、backward 方法中的 ctx3、小结 三、pytorch自定义自动微分函数(torch.autograd.Function)1、torch.autograd.Function计算前向与后向传…...

《C++ 赋能 K-Means 聚类算法:开启智能数据分类之旅》
在当今数字化浪潮汹涌澎湃的时代,人工智能无疑是引领科技变革的核心驱动力之一。而在人工智能的广袤天地中,数据分类与聚类作为挖掘数据内在价值、揭示数据潜在规律的关键技术手段,正发挥着前所未有的重要作用。K-Means 聚类算法,…...

对 JavaScript 说“不”
JavaScript编程语言历史悠久,但它是在 1995 年大约一周内创建的。 它最初被称为 LiveScript,但后来更名为 JavaScript,以赶上 Java 的潮流,尽管它与 Java 毫无关系。 它很快就变得非常流行,推动了 Web 应用程序革命&…...

spring下的beanutils.copyProperties实现深拷贝
spring下的beanutils.copyProperties方法是深拷贝还是浅拷贝?可以实现深拷贝吗? 答案:浅拷贝。 一、浅拷贝深拷贝的理解 简单说拷贝就是将一个类中的属性拷贝到另一个中,对于BeanUtils.copyProperties来说,你必须保…...

蓝桥杯二分题
P1083 [NOIP2012 提高组] 借教室 题目描述 在大学期间,经常需要租借教室。大到院系举办活动,小到学习小组自习讨论,都需要向学校申请借教室。教室的大小功能不同,借教室人的身份不同,借教室的手续也不一样。 面对海量租…...

3D数字化革新,探索博物馆的正确打开新方式!
3D数字化的发展,让博物馆也焕发新机,比如江苏省的“云上博物”,汇聚江苏全省博物馆展陈资源,采取线上展示和线下体验两种方式进行呈现的数字展览项目。在线上,用户可以通过H5或小程序进入“云上博物”数字展览空间&…...

工业检测基础-工业相机选型及应用场景
以下是一些常见的工业检测相机种类、检测原理、应用场景及选型依据: 2D相机 检测原理:基于二维图像捕获,通过分析图像的明暗、纹理、颜色等信息来检测物体的特征和缺陷.应用场景:广泛应用于平面工件的外观检测,如检测…...

通过 FRP 实现 P2P 通信:控制端与被控制端配置指南
本文介绍了如何通过 FRP 实现 P2P 通信。FRP(Fast Reverse Proxy)是一款高效的内网穿透工具,能够帮助用户突破 NAT 和防火墙的限制,将内网服务暴露到公网。通过 P2P 通信方式,FRP 提供了更加高效、低延迟的网络传输方式…...

即时通信系统项目总览
聊天室服务端项目总体介绍 本项目是一个全栈的即时通信系统, 前端使用QT实现聊天客户端, 后端采⽤微服务框架设计, 由网关子服务统一接收客户端的请求, 再分发到不同的子服务上处理并将结果返回给网关, 网关再将响应转发给客户端 拆分的微服务包含: 网关服务器&…...

QT获取tableview选中的行和列的值
查询数据库数据放入tableview(tableView_database)后 QSqlQueryModel* sql_model new QSqlQueryModel(this);sql_model->setQuery("select * from dxxxb_move_lot_tab");sql_model->setHeaderData(0, Qt::Horizontal, tr("id&quo…...

GDPU 人工智能 期末复习
1、python基础 2、回归、KNN、K-Means、搜索方法思想及算法实现步骤 3、知识表示基本概念 4、状态空间的相关概念、表示方法及应用 5、图搜索策略及应用 6、问题归约概念、与或图搜索、博弈树搜索与剪枝 7、决策树、贝叶斯决策算法及其应用 8、神经网络与深度学习基本概念 一、…...

编程之路,从0开始:补充篇
Hello大家好!很高兴和大家又见面啦!给生活添点passion,开始今天的编程之路! 我的博客:<但凡. 我的专栏:《编程之路》、《题海拾贝》、《数据结构与算法之美》 欢迎点赞,关注! 这篇…...

使用缓存提升Web应用性能:从新手到高手的实践指南
引言 在现代Web开发中,性能优化是确保用户体验和系统稳定性的关键。使用缓存是提升网站性能的有效手段之一,可以显著减少数据库访问和计算开销。根据“网站优化第一定律”,缓存可以提升网站的响应速度,减少延迟,从而改…...

【数字电路与逻辑设计】实验一 序列检测器
文章总览:YuanDaiMa2048博客文章总览 【数字电路与逻辑设计】实验一 序列检测器 一、实验内容二、设计过程(一)作出状态图或状态表(二)状态化简(三)状态编码 三、源代码(一ÿ…...

运动模糊效果
1、运动模糊效果 运动模糊效果,是一种用于 模拟真实世界中快速移动物体产生的模糊现象 的图像处理技术,当一个物体以较高速度移动时,由于人眼或摄像机的曝光时间过长,该物体会在图像中留下模糊的运动轨迹。这种效果游戏、动画、电…...

养老护理员培训考试题库;免费题库;大风车题库
下载链接:大风车题库-文件 大风车题库网站:大风车题库 大风车excel(试题转excel):大风车excel...

Python-配置模块configparser使用指南
configparser 是 Python 标准库中的模块,用于处理配置文件(如 .ini 文件)。它适合管理程序的配置信息,比如数据库连接参数、应用程序设置等。 1. 配置文件的基本结构 配置文件通常是 .ini 格式,由 节(Sec…...

C++的HDF5库将h5图像转为tif格式:szip压缩的图像也可转换
本文介绍基于C 语言的hdf5库与gdal库,将.h5格式的多波段HDF5图像批量转换为.tif格式的方法;其中,本方法支持对szip压缩的HDF5图像(例如高分一号卫星遥感影像)加以转换。 将HDF5图像批量转换为.tif格式,在部…...

【JAVA】Java第十三节:String类(String相关方法,以及StrinBuftrer , StringBulder相关方法)
本文详细介绍了String类以及常用的String相关方法,以及StrinBuftrer , StringBulder相关方法的使用,建议有印象即可,不需要都记住,使用时去查取即可 一、创建一个String类型的变量 我们平时创建String类型的变量一般是第一种形式…...

WordPress安装或访问时出现数据库连接错误的处理方式
一、在安装时出现数据库连接错误 1、如果数据库名称、用户名或密码错误,或者主机设置不正确(如数据库服务器不是在本地localhost,而是在远程服务器,需要正确填写远程服务器的 IP 地址或域名),就会导致连接错…...

JAVA-面向对象基础
文章目录 概要封装多态抽象类接口内部类为什么需要内部类 概要 面向对象是一种编程范式或设计哲学,它将软件系统设计为由多个对象组成,这些对象通过特定的方式相互作用 封装 将数据和操作数据的方法封装在一个类中,并通过访问修饰符控制对…...

[Java]项目入门
这篇简单介绍一些入门的有关项目和行业的知识,并带着实现一个小项目。便于已经编程入门的各位准备进阶到下一个阶段。 先大致地介绍,一个完整的项目(不看客户端、服务端的分类)基本可以划分为三部分: 1.前端。比如你现在看到的CSDN页面就是一…...

opencv Mat To Heif
高效率图像文件格式(英语:High Efficiency Image File Format, HEIF;也称高效图像文件格式)是一个用于单张图像或图像序列的文件格式。它由运动图像专家组(MPEG)开发,并在MPEG-H Part 12&#x…...

二刷代码随想录第24天
93. 复原 IP 地址 确定函数is_ip的实现细节,start不能超过end,没有0开头的非0数字,每个字符都在0-9之间,每段字符小于255在原字符串s上做操作会更简单一些 class Solution { public:vector<string> result;vector<string> rest…...

Java设计模式之状态模式架构高扩展的订单状态管理
🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c=1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,精通Java编程,高并发设计,Springboot和微服务,熟悉Linux,ESXI虚拟化以及云原生Docker和K8s…...

Yagmail邮件发送库:如何用Python实现自动化邮件营销?
目录 一、Yagmail简介 二、安装Yagmail 三、基本使用示例 1. 发送简单文本邮件 2. 发送HTML邮件 3. 发送带有附件的邮件 4. 多收件人处理 5. 自定义邮件头 四、高级功能 1. SMTP配置 2. 邮件模板 3. OAuth2认证 五、自动化邮件营销案例 六、错误处理和调试 七、…...

李宏毅深度学习-Pytorch Tutorial2
什么是张量? 张量(Tensor)是深度学习和机器学习中一个非常基础且重要的概念。在数学上,张量可以被看作是向量和矩阵的泛化。简单来说,张量是一种多维数组,它可以表示标量(0维)、向量…...

SaaS财务软件:赋能企业数字化转型
在数字化浪潮的推动下,企业财务管理正逐步迈向智能化、高效化的新阶段。在这个过程中,SaaS财务软件应运而生,成为许多企业的首选。以易舟云财务软件为例,这款软件不仅集成了众多先进的财务管理功能,而且在用户体验上做…...

FPGA实战篇(按键控制LDE实验)
1.按键简介 按键开关是一种电子开关,属于电子元器件类。我们的开发板上有两种按键开关:第一种是本实验所使用的轻触式按键开关,简称轻触开关。使用时以向开关的操作方向施加压力使内部电路闭合接通,当撤销压力时开关断开ÿ…...

在Ubuntu-22.04 [WSL2]中配置Docker
文章目录 0. 进入Ubuntu-22.041. 更新系统软件包2. 安装Docker相关依赖包3. 添加Docker官方GPG密钥4. 添加Docker软件源5. 安装Docker Engine5.1 更新软件包列表5.2 安装Docker相关软件包 6. 验证Docker安装是否成功6.1 查看Docker版本信息6.2 启动Docker6.3 配置镜像加速器6.4…...