深入理解 PyTorch 自动微分机制与自定义 torch.autograd.Function
文章目录
- 前言
- 一、pytorch使用现有的自动微分机制
- 二、torch.autograd.Function中的ctx解读
- 1、`forward` 方法中的 `ctx`
- 2、`backward` 方法中的 `ctx`
- 3、小结
- 三、pytorch自定义自动微分函数(torch.autograd.Function)
- 1、torch.autograd.Function计算前向与后向传播梯度Demo
- 2、前向传播梯度解读
- 3、后向传播梯度解读
- 4、运行结果
前言
随着深度学习技术的迅速发展,PyTorch 作为一款功能强大且灵活的深度学习框架,受到了广泛的关注和应用。它以其动态计算图、易用性以及强大的社区支持而闻名。在PyTorch中,自动微分(autograd)是其核心特性之一,它使得神经网络训练过程中的梯度计算变得简单高效。对于大多数应用场景而言,开发者无需手动编写反向传播逻辑,因为PyTorch能够自动处理这些细节。
然而,在某些特殊情况下,我们可能需要对特定的操作进行定制化的梯度计算,这时就需要深入了解并利用PyTorch提供的torch.autograd.Function类来实现自定义的前向和后向传播逻辑。通过这种方式,不仅可以实现更复杂的模型结构,还能优化性能或满足特定的研究需求。
本文将从基础出发,首先介绍如何使用PyTorch内置的自动微分机制完成常规的模型训练流程;接着详细解析torch.autograd.Function中的ctx对象及其在前后向传播间的作用;最后,通过一个具体的例子演示如何编写自定义的自动微分函数,并解释其中的关键概念和操作。希望通过这篇文章,读者能够掌握PyTorch自动微分的核心原理,以及如何根据实际需求设计高效的自定义梯度计算逻辑。
一、pytorch使用现有的自动微分机制
编写一个后向传播函数在 PyTorch 中通常是不需要的,因为 PyTorch 自动处理了自动微分(autograd),即通过 loss.backward() 来计算梯度。下面我们将展示如何编写一个简单的自定义后向传播函数,并解释如何在 PyTorch 中利用现有的自动微分机制进行反向传播。
通常情况下,你只需要调用 loss.backward() 即可完成反向传播,一个示列代码如下:
import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 输入和目标
input_tensor = torch.randn(1, 10, requires_grad=True)
target = torch.tensor([[1.]])# 前向传播
output = model(input_tensor)
loss = criterion(output, target)# 清除之前的梯度
optimizer.zero_grad()# 反向传播
loss.backward()# 更新参数
optimizer.step()
二、torch.autograd.Function中的ctx解读
在PyTorch中,torch.autograd.Function 是用来定义自定义自动求导函数的类。你提供的CustomReLU类继承了torch.autograd.Function并实现了自定义的前向传播和反向传播逻辑。这里的ctx(context)对象是用于存储信息以便在前向传播和反向传播之间共享。
1、forward 方法中的 ctx
在forward方法中,ctx被用来保存在前向传播阶段计算的信息,这些信息可能在后续的反向传播过程中需要使用。例如:
@staticmethod
def forward(ctx, input):ctx.save_for_backward(input) # 保存输入以供反向传播使用return input.clamp(min=0)
ctx.save_for_backward(input):这里我们保存了输入张量input。这很重要,因为在反向传播时我们需要知道哪些元素在前向传播中被设为零(即负数),以便正确地将梯度设为零。
2、backward 方法中的 ctx
在backward方法中,ctx被用来访问在前向传播阶段保存的信息。例如:
@staticmethod
def backward(ctx, grad_output):input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
-
input, = ctx.saved_tensors:这里我们从ctx中获取了在前向传播阶段保存的输入张量。注意,saved_tensors是一个元组,即使只保存了一个张量,也需要用逗号来解包。 -
接下来,我们基于原始输入创建了
grad_input,它初始化为grad_output的副本。然后我们将所有在前向传播中对应的输入小于0的位置的梯度设为0,这是因为ReLU激活函数对于所有负值输入都输出0,所以其梯度也应为0。
3、小结
ctx的作用是在前向传播和反向传播之间传递必要的信息。通过ctx.save_for_backward()可以在前向传播中保存任何需要在反向传播中使用的数据,而在反向传播中则可以通过ctx.saved_tensors来访问这些数据。这对于实现自定义的自动求导函数来说是非常重要的,因为它允许我们在不需要显式管理复杂状态的情况下执行复杂的梯度计算。
三、pytorch自定义自动微分函数(torch.autograd.Function)
1、torch.autograd.Function计算前向与后向传播梯度Demo
然而,如果你确实需要自定义反向传播逻辑或理解其工作原理,可以通过定义自定义的自动微分函数来实现。如果你想自定义某些操作的反向传播逻辑,可以使用 torch.autograd.Function 来创建自定义的自动微分函数。以下是一个简单的例子:
import torchclass CustomReLU(torch.autograd.Function):@staticmethoddef forward(ctx, input):"""在前向传播中,我们接收到一个上下文对象和一个输入张量,并返回一个经过 ReLU 激活的输出张量。"""ctx.save_for_backward(input) # 保存输入以供反向传播使用return input.clamp(min=0)@staticmethoddef backward(ctx, grad_output):"""在反向传播中,我们接收到一个上下文对象和一个输出张量的梯度,并返回输入张量的梯度。"""input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input# 使用自定义 ReLU 函数
custom_relu = CustomReLU.apply# 示例:将自定义 ReLU 应用于输入
input_tensor = torch.randn(5, requires_grad=True)
output = custom_relu(input_tensor)# 创建一个简单的损失并进行反向传播
loss = output.sum()
loss.backward()print("Input tensor:", input_tensor)
print("Gradient of input tensor:", input_tensor.grad)
2、前向传播梯度解读
如果要计算前向传播梯度只有执行代码output = custom_relu(input_tensor)才能启动(input_tensor是一个维度[ …]),因此当启动了代码,就可以执行我们定义前向传播方法:
@staticmethod
def forward(ctx, input):"""在前向传播中,我们接收到一个上下文对象和一个输入张量,并返回一个经过 ReLU 激活的输出张量。"""ctx.save_for_backward(input) # 保存输入以供反向传播使用return input.clamp(min=0)
而ctx.save_for_backward也是保存输入input内容,custom_relu只有一层模型,就是来自input_tensor值,所以ctx保存了input_tensor值,也是一个维度值。
3、后向传播梯度解读
如果要计算后向传播梯度只有执行代码loss.backward()才能启动,因此当启动了代码,就可以执行我们定义后向传播方法:
@staticmethod
def backward(ctx, grad_output):"""在反向传播中,我们接收到一个上下文对象和一个输出张量的梯度,并返回输入张量的梯度。"""input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
而ctx.saved_tensors是取前向保存的内容。
4、运行结果

相关文章:
深入理解 PyTorch 自动微分机制与自定义 torch.autograd.Function
文章目录 前言一、pytorch使用现有的自动微分机制二、torch.autograd.Function中的ctx解读1、forward 方法中的 ctx2、backward 方法中的 ctx3、小结 三、pytorch自定义自动微分函数(torch.autograd.Function)1、torch.autograd.Function计算前向与后向传…...
《C++ 赋能 K-Means 聚类算法:开启智能数据分类之旅》
在当今数字化浪潮汹涌澎湃的时代,人工智能无疑是引领科技变革的核心驱动力之一。而在人工智能的广袤天地中,数据分类与聚类作为挖掘数据内在价值、揭示数据潜在规律的关键技术手段,正发挥着前所未有的重要作用。K-Means 聚类算法,…...
对 JavaScript 说“不”
JavaScript编程语言历史悠久,但它是在 1995 年大约一周内创建的。 它最初被称为 LiveScript,但后来更名为 JavaScript,以赶上 Java 的潮流,尽管它与 Java 毫无关系。 它很快就变得非常流行,推动了 Web 应用程序革命&…...
spring下的beanutils.copyProperties实现深拷贝
spring下的beanutils.copyProperties方法是深拷贝还是浅拷贝?可以实现深拷贝吗? 答案:浅拷贝。 一、浅拷贝深拷贝的理解 简单说拷贝就是将一个类中的属性拷贝到另一个中,对于BeanUtils.copyProperties来说,你必须保…...
蓝桥杯二分题
P1083 [NOIP2012 提高组] 借教室 题目描述 在大学期间,经常需要租借教室。大到院系举办活动,小到学习小组自习讨论,都需要向学校申请借教室。教室的大小功能不同,借教室人的身份不同,借教室的手续也不一样。 面对海量租…...
3D数字化革新,探索博物馆的正确打开新方式!
3D数字化的发展,让博物馆也焕发新机,比如江苏省的“云上博物”,汇聚江苏全省博物馆展陈资源,采取线上展示和线下体验两种方式进行呈现的数字展览项目。在线上,用户可以通过H5或小程序进入“云上博物”数字展览空间&…...
工业检测基础-工业相机选型及应用场景
以下是一些常见的工业检测相机种类、检测原理、应用场景及选型依据: 2D相机 检测原理:基于二维图像捕获,通过分析图像的明暗、纹理、颜色等信息来检测物体的特征和缺陷.应用场景:广泛应用于平面工件的外观检测,如检测…...
通过 FRP 实现 P2P 通信:控制端与被控制端配置指南
本文介绍了如何通过 FRP 实现 P2P 通信。FRP(Fast Reverse Proxy)是一款高效的内网穿透工具,能够帮助用户突破 NAT 和防火墙的限制,将内网服务暴露到公网。通过 P2P 通信方式,FRP 提供了更加高效、低延迟的网络传输方式…...
即时通信系统项目总览
聊天室服务端项目总体介绍 本项目是一个全栈的即时通信系统, 前端使用QT实现聊天客户端, 后端采⽤微服务框架设计, 由网关子服务统一接收客户端的请求, 再分发到不同的子服务上处理并将结果返回给网关, 网关再将响应转发给客户端 拆分的微服务包含: 网关服务器&…...
QT获取tableview选中的行和列的值
查询数据库数据放入tableview(tableView_database)后 QSqlQueryModel* sql_model new QSqlQueryModel(this);sql_model->setQuery("select * from dxxxb_move_lot_tab");sql_model->setHeaderData(0, Qt::Horizontal, tr("id&quo…...
GDPU 人工智能 期末复习
1、python基础 2、回归、KNN、K-Means、搜索方法思想及算法实现步骤 3、知识表示基本概念 4、状态空间的相关概念、表示方法及应用 5、图搜索策略及应用 6、问题归约概念、与或图搜索、博弈树搜索与剪枝 7、决策树、贝叶斯决策算法及其应用 8、神经网络与深度学习基本概念 一、…...
编程之路,从0开始:补充篇
Hello大家好!很高兴和大家又见面啦!给生活添点passion,开始今天的编程之路! 我的博客:<但凡. 我的专栏:《编程之路》、《题海拾贝》、《数据结构与算法之美》 欢迎点赞,关注! 这篇…...
使用缓存提升Web应用性能:从新手到高手的实践指南
引言 在现代Web开发中,性能优化是确保用户体验和系统稳定性的关键。使用缓存是提升网站性能的有效手段之一,可以显著减少数据库访问和计算开销。根据“网站优化第一定律”,缓存可以提升网站的响应速度,减少延迟,从而改…...
【数字电路与逻辑设计】实验一 序列检测器
文章总览:YuanDaiMa2048博客文章总览 【数字电路与逻辑设计】实验一 序列检测器 一、实验内容二、设计过程(一)作出状态图或状态表(二)状态化简(三)状态编码 三、源代码(一ÿ…...
运动模糊效果
1、运动模糊效果 运动模糊效果,是一种用于 模拟真实世界中快速移动物体产生的模糊现象 的图像处理技术,当一个物体以较高速度移动时,由于人眼或摄像机的曝光时间过长,该物体会在图像中留下模糊的运动轨迹。这种效果游戏、动画、电…...
养老护理员培训考试题库;免费题库;大风车题库
下载链接:大风车题库-文件 大风车题库网站:大风车题库 大风车excel(试题转excel):大风车excel...
Python-配置模块configparser使用指南
configparser 是 Python 标准库中的模块,用于处理配置文件(如 .ini 文件)。它适合管理程序的配置信息,比如数据库连接参数、应用程序设置等。 1. 配置文件的基本结构 配置文件通常是 .ini 格式,由 节(Sec…...
C++的HDF5库将h5图像转为tif格式:szip压缩的图像也可转换
本文介绍基于C 语言的hdf5库与gdal库,将.h5格式的多波段HDF5图像批量转换为.tif格式的方法;其中,本方法支持对szip压缩的HDF5图像(例如高分一号卫星遥感影像)加以转换。 将HDF5图像批量转换为.tif格式,在部…...
【JAVA】Java第十三节:String类(String相关方法,以及StrinBuftrer , StringBulder相关方法)
本文详细介绍了String类以及常用的String相关方法,以及StrinBuftrer , StringBulder相关方法的使用,建议有印象即可,不需要都记住,使用时去查取即可 一、创建一个String类型的变量 我们平时创建String类型的变量一般是第一种形式…...
WordPress安装或访问时出现数据库连接错误的处理方式
一、在安装时出现数据库连接错误 1、如果数据库名称、用户名或密码错误,或者主机设置不正确(如数据库服务器不是在本地localhost,而是在远程服务器,需要正确填写远程服务器的 IP 地址或域名),就会导致连接错…...
椭圆曲线密码学(ECC)
一、ECC算法概述 椭圆曲线密码学(Elliptic Curve Cryptography)是基于椭圆曲线数学理论的公钥密码系统,由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA,ECC在相同安全强度下密钥更短(256位ECC ≈ 3072位RSA…...
WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)
一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解,适合用作学习或写简历项目背景说明。 🧠 一、概念简介:Solidity 合约开发 Solidity 是一种专门为 以太坊(Ethereum)平台编写智能合约的高级编…...
12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...
有限自动机到正规文法转换器v1.0
1 项目简介 这是一个功能强大的有限自动机(Finite Automaton, FA)到正规文法(Regular Grammar)转换器,它配备了一个直观且完整的图形用户界面,使用户能够轻松地进行操作和观察。该程序基于编译原理中的经典…...
Linux nano命令的基本使用
参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时,显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...
高考志愿填报管理系统---开发介绍
高考志愿填报管理系统是一款专为教育机构、学校和教师设计的学生信息管理和志愿填报辅助平台。系统基于Django框架开发,采用现代化的Web技术,为教育工作者提供高效、安全、便捷的学生管理解决方案。 ## 📋 系统概述 ### 🎯 系统定…...
热门Chrome扩展程序存在明文传输风险,用户隐私安全受威胁
赛门铁克威胁猎手团队最新报告披露,数款拥有数百万活跃用户的Chrome扩展程序正在通过未加密的HTTP连接静默泄露用户敏感数据,严重威胁用户隐私安全。 知名扩展程序存在明文传输风险 尽管宣称提供安全浏览、数据分析或便捷界面等功能,但SEMR…...
02.运算符
目录 什么是运算符 算术运算符 1.基本四则运算符 2.增量运算符 3.自增/自减运算符 关系运算符 逻辑运算符 &&:逻辑与 ||:逻辑或 !:逻辑非 短路求值 位运算符 按位与&: 按位或 | 按位取反~ …...
[特殊字符] 手撸 Redis 互斥锁那些坑
📖 手撸 Redis 互斥锁那些坑 最近搞业务遇到高并发下同一个 key 的互斥操作,想实现分布式环境下的互斥锁。于是私下顺手手撸了个基于 Redis 的简单互斥锁,也顺便跟 Redisson 的 RLock 机制对比了下,记录一波,别踩我踩过…...
二叉树-144.二叉树的前序遍历-力扣(LeetCode)
一、题目解析 对于递归方法的前序遍历十分简单,但对于一位合格的程序猿而言,需要掌握将递归转化为非递归的能力,毕竟递归调用的时候会调用大量的栈帧,存在栈溢出风险。 二、算法原理 递归调用本质是系统建立栈帧,而非…...
