当前位置: 首页 > news >正文

深入理解 PyTorch 自动微分机制与自定义 torch.autograd.Function

文章目录

  • 前言
  • 一、pytorch使用现有的自动微分机制
  • 二、torch.autograd.Function中的ctx解读
    • 1、`forward` 方法中的 `ctx`
    • 2、`backward` 方法中的 `ctx`
    • 3、小结
  • 三、pytorch自定义自动微分函数(torch.autograd.Function)
    • 1、torch.autograd.Function计算前向与后向传播梯度Demo
    • 2、前向传播梯度解读
    • 3、后向传播梯度解读
    • 4、运行结果


前言

随着深度学习技术的迅速发展,PyTorch 作为一款功能强大且灵活的深度学习框架,受到了广泛的关注和应用。它以其动态计算图、易用性以及强大的社区支持而闻名。在PyTorch中,自动微分(autograd)是其核心特性之一,它使得神经网络训练过程中的梯度计算变得简单高效。对于大多数应用场景而言,开发者无需手动编写反向传播逻辑,因为PyTorch能够自动处理这些细节。

然而,在某些特殊情况下,我们可能需要对特定的操作进行定制化的梯度计算,这时就需要深入了解并利用PyTorch提供的torch.autograd.Function类来实现自定义的前向和后向传播逻辑。通过这种方式,不仅可以实现更复杂的模型结构,还能优化性能或满足特定的研究需求。

本文将从基础出发,首先介绍如何使用PyTorch内置的自动微分机制完成常规的模型训练流程;接着详细解析torch.autograd.Function中的ctx对象及其在前后向传播间的作用;最后,通过一个具体的例子演示如何编写自定义的自动微分函数,并解释其中的关键概念和操作。希望通过这篇文章,读者能够掌握PyTorch自动微分的核心原理,以及如何根据实际需求设计高效的自定义梯度计算逻辑。

一、pytorch使用现有的自动微分机制

编写一个后向传播函数在 PyTorch 中通常是不需要的,因为 PyTorch 自动处理了自动微分(autograd),即通过 loss.backward() 来计算梯度。下面我们将展示如何编写一个简单的自定义后向传播函数,并解释如何在 PyTorch 中利用现有的自动微分机制进行反向传播。

通常情况下,你只需要调用 loss.backward() 即可完成反向传播,一个示列代码如下:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 输入和目标
input_tensor = torch.randn(1, 10, requires_grad=True)
target = torch.tensor([[1.]])# 前向传播
output = model(input_tensor)
loss = criterion(output, target)# 清除之前的梯度
optimizer.zero_grad()# 反向传播
loss.backward()# 更新参数
optimizer.step()

二、torch.autograd.Function中的ctx解读

在PyTorch中,torch.autograd.Function 是用来定义自定义自动求导函数的类。你提供的CustomReLU类继承了torch.autograd.Function并实现了自定义的前向传播和反向传播逻辑。这里的ctx(context)对象是用于存储信息以便在前向传播和反向传播之间共享。

1、forward 方法中的 ctx

forward方法中,ctx被用来保存在前向传播阶段计算的信息,这些信息可能在后续的反向传播过程中需要使用。例如:

@staticmethod
def forward(ctx, input):ctx.save_for_backward(input)  # 保存输入以供反向传播使用return input.clamp(min=0)
  • ctx.save_for_backward(input):这里我们保存了输入张量input。这很重要,因为在反向传播时我们需要知道哪些元素在前向传播中被设为零(即负数),以便正确地将梯度设为零。

2、backward 方法中的 ctx

backward方法中,ctx被用来访问在前向传播阶段保存的信息。例如:

@staticmethod
def backward(ctx, grad_output):input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
  • input, = ctx.saved_tensors:这里我们从ctx中获取了在前向传播阶段保存的输入张量。注意,saved_tensors是一个元组,即使只保存了一个张量,也需要用逗号来解包。

  • 接下来,我们基于原始输入创建了grad_input,它初始化为grad_output的副本。然后我们将所有在前向传播中对应的输入小于0的位置的梯度设为0,这是因为ReLU激活函数对于所有负值输入都输出0,所以其梯度也应为0。

3、小结

ctx的作用是在前向传播和反向传播之间传递必要的信息。通过ctx.save_for_backward()可以在前向传播中保存任何需要在反向传播中使用的数据,而在反向传播中则可以通过ctx.saved_tensors来访问这些数据。这对于实现自定义的自动求导函数来说是非常重要的,因为它允许我们在不需要显式管理复杂状态的情况下执行复杂的梯度计算。

三、pytorch自定义自动微分函数(torch.autograd.Function)

1、torch.autograd.Function计算前向与后向传播梯度Demo

然而,如果你确实需要自定义反向传播逻辑或理解其工作原理,可以通过定义自定义的自动微分函数来实现。如果你想自定义某些操作的反向传播逻辑,可以使用 torch.autograd.Function 来创建自定义的自动微分函数。以下是一个简单的例子:

import torchclass CustomReLU(torch.autograd.Function):@staticmethoddef forward(ctx, input):"""在前向传播中,我们接收到一个上下文对象和一个输入张量,并返回一个经过 ReLU 激活的输出张量。"""ctx.save_for_backward(input)  # 保存输入以供反向传播使用return input.clamp(min=0)@staticmethoddef backward(ctx, grad_output):"""在反向传播中,我们接收到一个上下文对象和一个输出张量的梯度,并返回输入张量的梯度。"""input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input# 使用自定义 ReLU 函数
custom_relu = CustomReLU.apply# 示例:将自定义 ReLU 应用于输入
input_tensor = torch.randn(5, requires_grad=True)
output = custom_relu(input_tensor)# 创建一个简单的损失并进行反向传播
loss = output.sum()
loss.backward()print("Input tensor:", input_tensor)
print("Gradient of input tensor:", input_tensor.grad)

2、前向传播梯度解读

如果要计算前向传播梯度只有执行代码output = custom_relu(input_tensor)才能启动(input_tensor是一个维度[ …]),因此当启动了代码,就可以执行我们定义前向传播方法:

@staticmethod
def forward(ctx, input):"""在前向传播中,我们接收到一个上下文对象和一个输入张量,并返回一个经过 ReLU 激活的输出张量。"""ctx.save_for_backward(input)  # 保存输入以供反向传播使用return input.clamp(min=0)

而ctx.save_for_backward也是保存输入input内容,custom_relu只有一层模型,就是来自input_tensor值,所以ctx保存了input_tensor值,也是一个维度值。

3、后向传播梯度解读

如果要计算后向传播梯度只有执行代码loss.backward()才能启动,因此当启动了代码,就可以执行我们定义后向传播方法:

@staticmethod
def backward(ctx, grad_output):"""在反向传播中,我们接收到一个上下文对象和一个输出张量的梯度,并返回输入张量的梯度。"""input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input

而ctx.saved_tensors是取前向保存的内容。

4、运行结果

在这里插入图片描述

相关文章:

深入理解 PyTorch 自动微分机制与自定义 torch.autograd.Function

文章目录 前言一、pytorch使用现有的自动微分机制二、torch.autograd.Function中的ctx解读1、forward 方法中的 ctx2、backward 方法中的 ctx3、小结 三、pytorch自定义自动微分函数&#xff08;torch.autograd.Function&#xff09;1、torch.autograd.Function计算前向与后向传…...

《C++ 赋能 K-Means 聚类算法:开启智能数据分类之旅》

在当今数字化浪潮汹涌澎湃的时代&#xff0c;人工智能无疑是引领科技变革的核心驱动力之一。而在人工智能的广袤天地中&#xff0c;数据分类与聚类作为挖掘数据内在价值、揭示数据潜在规律的关键技术手段&#xff0c;正发挥着前所未有的重要作用。K-Means 聚类算法&#xff0c;…...

对 JavaScript 说“不”

JavaScript编程语言历史悠久&#xff0c;但它是在 1995 年大约一周内创建的。 它最初被称为 LiveScript&#xff0c;但后来更名为 JavaScript&#xff0c;以赶上 Java 的潮流&#xff0c;尽管它与 Java 毫无关系。 它很快就变得非常流行&#xff0c;推动了 Web 应用程序革命&…...

spring下的beanutils.copyProperties实现深拷贝

spring下的beanutils.copyProperties方法是深拷贝还是浅拷贝&#xff1f;可以实现深拷贝吗&#xff1f; 答案&#xff1a;浅拷贝。 一、浅拷贝深拷贝的理解 简单说拷贝就是将一个类中的属性拷贝到另一个中&#xff0c;对于BeanUtils.copyProperties来说&#xff0c;你必须保…...

蓝桥杯二分题

P1083 [NOIP2012 提高组] 借教室 题目描述 在大学期间&#xff0c;经常需要租借教室。大到院系举办活动&#xff0c;小到学习小组自习讨论&#xff0c;都需要向学校申请借教室。教室的大小功能不同&#xff0c;借教室人的身份不同&#xff0c;借教室的手续也不一样。 面对海量租…...

3D数字化革新,探索博物馆的正确打开新方式!

3D数字化的发展&#xff0c;让博物馆也焕发新机&#xff0c;比如江苏省的“云上博物”&#xff0c;汇聚江苏全省博物馆展陈资源&#xff0c;采取线上展示和线下体验两种方式进行呈现的数字展览项目。在线上&#xff0c;用户可以通过H5或小程序进入“云上博物”数字展览空间&…...

工业检测基础-工业相机选型及应用场景

以下是一些常见的工业检测相机种类、检测原理、应用场景及选型依据&#xff1a; 2D相机 检测原理&#xff1a;基于二维图像捕获&#xff0c;通过分析图像的明暗、纹理、颜色等信息来检测物体的特征和缺陷.应用场景&#xff1a;广泛应用于平面工件的外观检测&#xff0c;如检测…...

通过 FRP 实现 P2P 通信:控制端与被控制端配置指南

本文介绍了如何通过 FRP 实现 P2P 通信。FRP&#xff08;Fast Reverse Proxy&#xff09;是一款高效的内网穿透工具&#xff0c;能够帮助用户突破 NAT 和防火墙的限制&#xff0c;将内网服务暴露到公网。通过 P2P 通信方式&#xff0c;FRP 提供了更加高效、低延迟的网络传输方式…...

即时通信系统项目总览

聊天室服务端项目总体介绍 本项目是一个全栈的即时通信系统, 前端使用QT实现聊天客户端, 后端采⽤微服务框架设计, 由网关子服务统一接收客户端的请求, 再分发到不同的子服务上处理并将结果返回给网关, 网关再将响应转发给客户端 拆分的微服务包含&#xff1a; 网关服务器&…...

QT获取tableview选中的行和列的值

查询数据库数据放入tableview&#xff08;tableView_database&#xff09;后 QSqlQueryModel* sql_model new QSqlQueryModel(this);sql_model->setQuery("select * from dxxxb_move_lot_tab");sql_model->setHeaderData(0, Qt::Horizontal, tr("id&quo…...

GDPU 人工智能 期末复习

1、python基础 2、回归、KNN、K-Means、搜索方法思想及算法实现步骤 3、知识表示基本概念 4、状态空间的相关概念、表示方法及应用 5、图搜索策略及应用 6、问题归约概念、与或图搜索、博弈树搜索与剪枝 7、决策树、贝叶斯决策算法及其应用 8、神经网络与深度学习基本概念 一、…...

编程之路,从0开始:补充篇

Hello大家好&#xff01;很高兴和大家又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路&#xff01; 我的博客&#xff1a;<但凡. 我的专栏&#xff1a;《编程之路》、《题海拾贝》、《数据结构与算法之美》 欢迎点赞&#xff0c;关注&#xff01; 这篇…...

使用缓存提升Web应用性能:从新手到高手的实践指南

引言 在现代Web开发中&#xff0c;性能优化是确保用户体验和系统稳定性的关键。使用缓存是提升网站性能的有效手段之一&#xff0c;可以显著减少数据库访问和计算开销。根据“网站优化第一定律”&#xff0c;缓存可以提升网站的响应速度&#xff0c;减少延迟&#xff0c;从而改…...

【数字电路与逻辑设计】实验一 序列检测器

文章总览&#xff1a;YuanDaiMa2048博客文章总览 【数字电路与逻辑设计】实验一 序列检测器 一、实验内容二、设计过程&#xff08;一&#xff09;作出状态图或状态表&#xff08;二&#xff09;状态化简&#xff08;三&#xff09;状态编码 三、源代码&#xff08;一&#xff…...

运动模糊效果

1、运动模糊效果 运动模糊效果&#xff0c;是一种用于 模拟真实世界中快速移动物体产生的模糊现象 的图像处理技术&#xff0c;当一个物体以较高速度移动时&#xff0c;由于人眼或摄像机的曝光时间过长&#xff0c;该物体会在图像中留下模糊的运动轨迹。这种效果游戏、动画、电…...

养老护理员培训考试题库;免费题库;大风车题库

下载链接&#xff1a;大风车题库-文件 大风车题库网站&#xff1a;大风车题库 大风车excel&#xff08;试题转excel&#xff09;&#xff1a;大风车excel...

Python-配置模块configparser使用指南

configparser 是 Python 标准库中的模块&#xff0c;用于处理配置文件&#xff08;如 .ini 文件&#xff09;。它适合管理程序的配置信息&#xff0c;比如数据库连接参数、应用程序设置等。 1. 配置文件的基本结构 配置文件通常是 .ini 格式&#xff0c;由 节&#xff08;Sec…...

C++的HDF5库将h5图像转为tif格式:szip压缩的图像也可转换

本文介绍基于C 语言的hdf5库与gdal库&#xff0c;将.h5格式的多波段HDF5图像批量转换为.tif格式的方法&#xff1b;其中&#xff0c;本方法支持对szip压缩的HDF5图像&#xff08;例如高分一号卫星遥感影像&#xff09;加以转换。 将HDF5图像批量转换为.tif格式&#xff0c;在部…...

【JAVA】Java第十三节:String类(String相关方法,以及StrinBuftrer , StringBulder相关方法)

本文详细介绍了String类以及常用的String相关方法&#xff0c;以及StrinBuftrer , StringBulder相关方法的使用&#xff0c;建议有印象即可&#xff0c;不需要都记住&#xff0c;使用时去查取即可 一、创建一个String类型的变量 我们平时创建String类型的变量一般是第一种形式…...

WordPress安装或访问时出现数据库连接错误的处理方式

一、在安装时出现数据库连接错误 1、如果数据库名称、用户名或密码错误&#xff0c;或者主机设置不正确&#xff08;如数据库服务器不是在本地localhost&#xff0c;而是在远程服务器&#xff0c;需要正确填写远程服务器的 IP 地址或域名&#xff09;&#xff0c;就会导致连接错…...

ubuntu搭建nfs服务centos挂载访问

在Ubuntu上设置NFS服务器 在Ubuntu上&#xff0c;你可以使用apt包管理器来安装NFS服务器。打开终端并运行&#xff1a; sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享&#xff0c;例如/shared&#xff1a; sudo mkdir /shared sud…...

基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容

基于 ​UniApp + WebSocket​实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配​微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...

pam_env.so模块配置解析

在PAM&#xff08;Pluggable Authentication Modules&#xff09;配置中&#xff0c; /etc/pam.d/su 文件相关配置含义如下&#xff1a; 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块&#xff0c;负责验证用户身份&am…...

ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放

简介 前面两期文章我们介绍了I2S的读取和写入&#xff0c;一个是通过INMP441麦克风模块采集音频&#xff0c;一个是通过PCM5102A模块播放音频&#xff0c;那如果我们将两者结合起来&#xff0c;将麦克风采集到的音频通过PCM5102A播放&#xff0c;是不是就可以做一个扩音器了呢…...

【配置 YOLOX 用于按目录分类的图片数据集】

现在的图标点选越来越多&#xff0c;如何一步解决&#xff0c;采用 YOLOX 目标检测模式则可以轻松解决 要在 YOLOX 中使用按目录分类的图片数据集&#xff08;每个目录代表一个类别&#xff0c;目录下是该类别的所有图片&#xff09;&#xff0c;你需要进行以下配置步骤&#x…...

css3笔记 (1) 自用

outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size&#xff1a;0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格&#xff…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...

Python基于历史模拟方法实现投资组合风险管理的VaR与ES模型项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档&#xff09;&#xff0c;如需数据代码文档可以直接到文章最后关注获取。 1.项目背景 在金融市场日益复杂和波动加剧的背景下&#xff0c;风险管理成为金融机构和个人投资者关注的核心议题之一。VaR&…...

【从零学习JVM|第三篇】类的生命周期(高频面试题)

前言&#xff1a; 在Java编程中&#xff0c;类的生命周期是指类从被加载到内存中开始&#xff0c;到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期&#xff0c;让读者对此有深刻印象。 目录 ​…...