当前位置: 首页 > news >正文

深入理解 PyTorch 自动微分机制与自定义 torch.autograd.Function

文章目录

  • 前言
  • 一、pytorch使用现有的自动微分机制
  • 二、torch.autograd.Function中的ctx解读
    • 1、`forward` 方法中的 `ctx`
    • 2、`backward` 方法中的 `ctx`
    • 3、小结
  • 三、pytorch自定义自动微分函数(torch.autograd.Function)
    • 1、torch.autograd.Function计算前向与后向传播梯度Demo
    • 2、前向传播梯度解读
    • 3、后向传播梯度解读
    • 4、运行结果


前言

随着深度学习技术的迅速发展,PyTorch 作为一款功能强大且灵活的深度学习框架,受到了广泛的关注和应用。它以其动态计算图、易用性以及强大的社区支持而闻名。在PyTorch中,自动微分(autograd)是其核心特性之一,它使得神经网络训练过程中的梯度计算变得简单高效。对于大多数应用场景而言,开发者无需手动编写反向传播逻辑,因为PyTorch能够自动处理这些细节。

然而,在某些特殊情况下,我们可能需要对特定的操作进行定制化的梯度计算,这时就需要深入了解并利用PyTorch提供的torch.autograd.Function类来实现自定义的前向和后向传播逻辑。通过这种方式,不仅可以实现更复杂的模型结构,还能优化性能或满足特定的研究需求。

本文将从基础出发,首先介绍如何使用PyTorch内置的自动微分机制完成常规的模型训练流程;接着详细解析torch.autograd.Function中的ctx对象及其在前后向传播间的作用;最后,通过一个具体的例子演示如何编写自定义的自动微分函数,并解释其中的关键概念和操作。希望通过这篇文章,读者能够掌握PyTorch自动微分的核心原理,以及如何根据实际需求设计高效的自定义梯度计算逻辑。

一、pytorch使用现有的自动微分机制

编写一个后向传播函数在 PyTorch 中通常是不需要的,因为 PyTorch 自动处理了自动微分(autograd),即通过 loss.backward() 来计算梯度。下面我们将展示如何编写一个简单的自定义后向传播函数,并解释如何在 PyTorch 中利用现有的自动微分机制进行反向传播。

通常情况下,你只需要调用 loss.backward() 即可完成反向传播,一个示列代码如下:

import torch
import torch.nn as nn
import torch.optim as optim# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 初始化模型、损失函数和优化器
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 输入和目标
input_tensor = torch.randn(1, 10, requires_grad=True)
target = torch.tensor([[1.]])# 前向传播
output = model(input_tensor)
loss = criterion(output, target)# 清除之前的梯度
optimizer.zero_grad()# 反向传播
loss.backward()# 更新参数
optimizer.step()

二、torch.autograd.Function中的ctx解读

在PyTorch中,torch.autograd.Function 是用来定义自定义自动求导函数的类。你提供的CustomReLU类继承了torch.autograd.Function并实现了自定义的前向传播和反向传播逻辑。这里的ctx(context)对象是用于存储信息以便在前向传播和反向传播之间共享。

1、forward 方法中的 ctx

forward方法中,ctx被用来保存在前向传播阶段计算的信息,这些信息可能在后续的反向传播过程中需要使用。例如:

@staticmethod
def forward(ctx, input):ctx.save_for_backward(input)  # 保存输入以供反向传播使用return input.clamp(min=0)
  • ctx.save_for_backward(input):这里我们保存了输入张量input。这很重要,因为在反向传播时我们需要知道哪些元素在前向传播中被设为零(即负数),以便正确地将梯度设为零。

2、backward 方法中的 ctx

backward方法中,ctx被用来访问在前向传播阶段保存的信息。例如:

@staticmethod
def backward(ctx, grad_output):input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input
  • input, = ctx.saved_tensors:这里我们从ctx中获取了在前向传播阶段保存的输入张量。注意,saved_tensors是一个元组,即使只保存了一个张量,也需要用逗号来解包。

  • 接下来,我们基于原始输入创建了grad_input,它初始化为grad_output的副本。然后我们将所有在前向传播中对应的输入小于0的位置的梯度设为0,这是因为ReLU激活函数对于所有负值输入都输出0,所以其梯度也应为0。

3、小结

ctx的作用是在前向传播和反向传播之间传递必要的信息。通过ctx.save_for_backward()可以在前向传播中保存任何需要在反向传播中使用的数据,而在反向传播中则可以通过ctx.saved_tensors来访问这些数据。这对于实现自定义的自动求导函数来说是非常重要的,因为它允许我们在不需要显式管理复杂状态的情况下执行复杂的梯度计算。

三、pytorch自定义自动微分函数(torch.autograd.Function)

1、torch.autograd.Function计算前向与后向传播梯度Demo

然而,如果你确实需要自定义反向传播逻辑或理解其工作原理,可以通过定义自定义的自动微分函数来实现。如果你想自定义某些操作的反向传播逻辑,可以使用 torch.autograd.Function 来创建自定义的自动微分函数。以下是一个简单的例子:

import torchclass CustomReLU(torch.autograd.Function):@staticmethoddef forward(ctx, input):"""在前向传播中,我们接收到一个上下文对象和一个输入张量,并返回一个经过 ReLU 激活的输出张量。"""ctx.save_for_backward(input)  # 保存输入以供反向传播使用return input.clamp(min=0)@staticmethoddef backward(ctx, grad_output):"""在反向传播中,我们接收到一个上下文对象和一个输出张量的梯度,并返回输入张量的梯度。"""input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input# 使用自定义 ReLU 函数
custom_relu = CustomReLU.apply# 示例:将自定义 ReLU 应用于输入
input_tensor = torch.randn(5, requires_grad=True)
output = custom_relu(input_tensor)# 创建一个简单的损失并进行反向传播
loss = output.sum()
loss.backward()print("Input tensor:", input_tensor)
print("Gradient of input tensor:", input_tensor.grad)

2、前向传播梯度解读

如果要计算前向传播梯度只有执行代码output = custom_relu(input_tensor)才能启动(input_tensor是一个维度[ …]),因此当启动了代码,就可以执行我们定义前向传播方法:

@staticmethod
def forward(ctx, input):"""在前向传播中,我们接收到一个上下文对象和一个输入张量,并返回一个经过 ReLU 激活的输出张量。"""ctx.save_for_backward(input)  # 保存输入以供反向传播使用return input.clamp(min=0)

而ctx.save_for_backward也是保存输入input内容,custom_relu只有一层模型,就是来自input_tensor值,所以ctx保存了input_tensor值,也是一个维度值。

3、后向传播梯度解读

如果要计算后向传播梯度只有执行代码loss.backward()才能启动,因此当启动了代码,就可以执行我们定义后向传播方法:

@staticmethod
def backward(ctx, grad_output):"""在反向传播中,我们接收到一个上下文对象和一个输出张量的梯度,并返回输入张量的梯度。"""input, = ctx.saved_tensorsgrad_input = grad_output.clone()grad_input[input < 0] = 0return grad_input

而ctx.saved_tensors是取前向保存的内容。

4、运行结果

在这里插入图片描述

相关文章:

深入理解 PyTorch 自动微分机制与自定义 torch.autograd.Function

文章目录 前言一、pytorch使用现有的自动微分机制二、torch.autograd.Function中的ctx解读1、forward 方法中的 ctx2、backward 方法中的 ctx3、小结 三、pytorch自定义自动微分函数&#xff08;torch.autograd.Function&#xff09;1、torch.autograd.Function计算前向与后向传…...

《C++ 赋能 K-Means 聚类算法:开启智能数据分类之旅》

在当今数字化浪潮汹涌澎湃的时代&#xff0c;人工智能无疑是引领科技变革的核心驱动力之一。而在人工智能的广袤天地中&#xff0c;数据分类与聚类作为挖掘数据内在价值、揭示数据潜在规律的关键技术手段&#xff0c;正发挥着前所未有的重要作用。K-Means 聚类算法&#xff0c;…...

对 JavaScript 说“不”

JavaScript编程语言历史悠久&#xff0c;但它是在 1995 年大约一周内创建的。 它最初被称为 LiveScript&#xff0c;但后来更名为 JavaScript&#xff0c;以赶上 Java 的潮流&#xff0c;尽管它与 Java 毫无关系。 它很快就变得非常流行&#xff0c;推动了 Web 应用程序革命&…...

spring下的beanutils.copyProperties实现深拷贝

spring下的beanutils.copyProperties方法是深拷贝还是浅拷贝&#xff1f;可以实现深拷贝吗&#xff1f; 答案&#xff1a;浅拷贝。 一、浅拷贝深拷贝的理解 简单说拷贝就是将一个类中的属性拷贝到另一个中&#xff0c;对于BeanUtils.copyProperties来说&#xff0c;你必须保…...

蓝桥杯二分题

P1083 [NOIP2012 提高组] 借教室 题目描述 在大学期间&#xff0c;经常需要租借教室。大到院系举办活动&#xff0c;小到学习小组自习讨论&#xff0c;都需要向学校申请借教室。教室的大小功能不同&#xff0c;借教室人的身份不同&#xff0c;借教室的手续也不一样。 面对海量租…...

3D数字化革新,探索博物馆的正确打开新方式!

3D数字化的发展&#xff0c;让博物馆也焕发新机&#xff0c;比如江苏省的“云上博物”&#xff0c;汇聚江苏全省博物馆展陈资源&#xff0c;采取线上展示和线下体验两种方式进行呈现的数字展览项目。在线上&#xff0c;用户可以通过H5或小程序进入“云上博物”数字展览空间&…...

工业检测基础-工业相机选型及应用场景

以下是一些常见的工业检测相机种类、检测原理、应用场景及选型依据&#xff1a; 2D相机 检测原理&#xff1a;基于二维图像捕获&#xff0c;通过分析图像的明暗、纹理、颜色等信息来检测物体的特征和缺陷.应用场景&#xff1a;广泛应用于平面工件的外观检测&#xff0c;如检测…...

通过 FRP 实现 P2P 通信:控制端与被控制端配置指南

本文介绍了如何通过 FRP 实现 P2P 通信。FRP&#xff08;Fast Reverse Proxy&#xff09;是一款高效的内网穿透工具&#xff0c;能够帮助用户突破 NAT 和防火墙的限制&#xff0c;将内网服务暴露到公网。通过 P2P 通信方式&#xff0c;FRP 提供了更加高效、低延迟的网络传输方式…...

即时通信系统项目总览

聊天室服务端项目总体介绍 本项目是一个全栈的即时通信系统, 前端使用QT实现聊天客户端, 后端采⽤微服务框架设计, 由网关子服务统一接收客户端的请求, 再分发到不同的子服务上处理并将结果返回给网关, 网关再将响应转发给客户端 拆分的微服务包含&#xff1a; 网关服务器&…...

QT获取tableview选中的行和列的值

查询数据库数据放入tableview&#xff08;tableView_database&#xff09;后 QSqlQueryModel* sql_model new QSqlQueryModel(this);sql_model->setQuery("select * from dxxxb_move_lot_tab");sql_model->setHeaderData(0, Qt::Horizontal, tr("id&quo…...

GDPU 人工智能 期末复习

1、python基础 2、回归、KNN、K-Means、搜索方法思想及算法实现步骤 3、知识表示基本概念 4、状态空间的相关概念、表示方法及应用 5、图搜索策略及应用 6、问题归约概念、与或图搜索、博弈树搜索与剪枝 7、决策树、贝叶斯决策算法及其应用 8、神经网络与深度学习基本概念 一、…...

编程之路,从0开始:补充篇

Hello大家好&#xff01;很高兴和大家又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路&#xff01; 我的博客&#xff1a;<但凡. 我的专栏&#xff1a;《编程之路》、《题海拾贝》、《数据结构与算法之美》 欢迎点赞&#xff0c;关注&#xff01; 这篇…...

使用缓存提升Web应用性能:从新手到高手的实践指南

引言 在现代Web开发中&#xff0c;性能优化是确保用户体验和系统稳定性的关键。使用缓存是提升网站性能的有效手段之一&#xff0c;可以显著减少数据库访问和计算开销。根据“网站优化第一定律”&#xff0c;缓存可以提升网站的响应速度&#xff0c;减少延迟&#xff0c;从而改…...

【数字电路与逻辑设计】实验一 序列检测器

文章总览&#xff1a;YuanDaiMa2048博客文章总览 【数字电路与逻辑设计】实验一 序列检测器 一、实验内容二、设计过程&#xff08;一&#xff09;作出状态图或状态表&#xff08;二&#xff09;状态化简&#xff08;三&#xff09;状态编码 三、源代码&#xff08;一&#xff…...

运动模糊效果

1、运动模糊效果 运动模糊效果&#xff0c;是一种用于 模拟真实世界中快速移动物体产生的模糊现象 的图像处理技术&#xff0c;当一个物体以较高速度移动时&#xff0c;由于人眼或摄像机的曝光时间过长&#xff0c;该物体会在图像中留下模糊的运动轨迹。这种效果游戏、动画、电…...

养老护理员培训考试题库;免费题库;大风车题库

下载链接&#xff1a;大风车题库-文件 大风车题库网站&#xff1a;大风车题库 大风车excel&#xff08;试题转excel&#xff09;&#xff1a;大风车excel...

Python-配置模块configparser使用指南

configparser 是 Python 标准库中的模块&#xff0c;用于处理配置文件&#xff08;如 .ini 文件&#xff09;。它适合管理程序的配置信息&#xff0c;比如数据库连接参数、应用程序设置等。 1. 配置文件的基本结构 配置文件通常是 .ini 格式&#xff0c;由 节&#xff08;Sec…...

C++的HDF5库将h5图像转为tif格式:szip压缩的图像也可转换

本文介绍基于C 语言的hdf5库与gdal库&#xff0c;将.h5格式的多波段HDF5图像批量转换为.tif格式的方法&#xff1b;其中&#xff0c;本方法支持对szip压缩的HDF5图像&#xff08;例如高分一号卫星遥感影像&#xff09;加以转换。 将HDF5图像批量转换为.tif格式&#xff0c;在部…...

【JAVA】Java第十三节:String类(String相关方法,以及StrinBuftrer , StringBulder相关方法)

本文详细介绍了String类以及常用的String相关方法&#xff0c;以及StrinBuftrer , StringBulder相关方法的使用&#xff0c;建议有印象即可&#xff0c;不需要都记住&#xff0c;使用时去查取即可 一、创建一个String类型的变量 我们平时创建String类型的变量一般是第一种形式…...

WordPress安装或访问时出现数据库连接错误的处理方式

一、在安装时出现数据库连接错误 1、如果数据库名称、用户名或密码错误&#xff0c;或者主机设置不正确&#xff08;如数据库服务器不是在本地localhost&#xff0c;而是在远程服务器&#xff0c;需要正确填写远程服务器的 IP 地址或域名&#xff09;&#xff0c;就会导致连接错…...

【JavaEE】-- HTTP

1. HTTP是什么&#xff1f; HTTP&#xff08;全称为"超文本传输协议"&#xff09;是一种应用非常广泛的应用层协议&#xff0c;HTTP是基于TCP协议的一种应用层协议。 应用层协议&#xff1a;是计算机网络协议栈中最高层的协议&#xff0c;它定义了运行在不同主机上…...

基于Flask实现的医疗保险欺诈识别监测模型

基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施&#xff0c;由雇主和个人按一定比例缴纳保险费&#xff0c;建立社会医疗保险基金&#xff0c;支付雇员医疗费用的一种医疗保险制度&#xff0c; 它是促进社会文明和进步的…...

屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!

5月28日&#xff0c;中天合创屋面分布式光伏发电项目顺利并网发电&#xff0c;该项目位于内蒙古自治区鄂尔多斯市乌审旗&#xff0c;项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站&#xff0c;总装机容量为9.96MWp。 项目投运后&#xff0c;每年可节约标煤3670…...

【HTTP三个基础问题】

面试官您好&#xff01;HTTP是超文本传输协议&#xff0c;是互联网上客户端和服务器之间传输超文本数据&#xff08;比如文字、图片、音频、视频等&#xff09;的核心协议&#xff0c;当前互联网应用最广泛的版本是HTTP1.1&#xff0c;它基于经典的C/S模型&#xff0c;也就是客…...

深度学习水论文:mamba+图像增强

&#x1f9c0;当前视觉领域对高效长序列建模需求激增&#xff0c;对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模&#xff0c;以及动态计算优势&#xff0c;在图像质量提升和细节恢复方面有难以替代的作用。 &#x1f9c0;因此短时间内&#xff0c;就有不…...

DingDing机器人群消息推送

文章目录 1 新建机器人2 API文档说明3 代码编写 1 新建机器人 点击群设置 下滑到群管理的机器人&#xff0c;点击进入 添加机器人 选择自定义Webhook服务 点击添加 设置安全设置&#xff0c;详见说明文档 成功后&#xff0c;记录Webhook 2 API文档说明 点击设置说明 查看自…...

从面试角度回答Android中ContentProvider启动原理

Android中ContentProvider原理的面试角度解析&#xff0c;分为​​已启动​​和​​未启动​​两种场景&#xff1a; 一、ContentProvider已启动的情况 1. ​​核心流程​​ ​​触发条件​​&#xff1a;当其他组件&#xff08;如Activity、Service&#xff09;通过ContentR…...

华为OD机试-最短木板长度-二分法(A卷,100分)

此题是一个最大化最小值的典型例题&#xff0c; 因为搜索范围是有界的&#xff0c;上界最大木板长度补充的全部木料长度&#xff0c;下界最小木板长度&#xff1b; 即left0,right10^6; 我们可以设置一个候选值x(mid)&#xff0c;将木板的长度全部都补充到x&#xff0c;如果成功…...

ZYNQ学习记录FPGA(二)Verilog语言

一、Verilog简介 1.1 HDL&#xff08;Hardware Description language&#xff09; 在解释HDL之前&#xff0c;先来了解一下数字系统设计的流程&#xff1a;逻辑设计 -> 电路实现 -> 系统验证。 逻辑设计又称前端&#xff0c;在这个过程中就需要用到HDL&#xff0c;正文…...

【Redis】Redis从入门到实战:全面指南

Redis从入门到实战:全面指南 一、Redis简介 Redis(Remote Dictionary Server)是一个开源的、基于内存的键值存储系统,它可以用作数据库、缓存和消息代理。由Salvatore Sanfilippo于2009年开发,因其高性能、丰富的数据结构和广泛的语言支持而广受欢迎。 Redis核心特点:…...