当前位置: 首页 > news >正文

Pytorch自定义算子反向传播

文章目录

      • 自定义一个线性函数算子
      • 如何实现反向传播

有关 自定义算子的实现前面已经提到,可以参考。本文讲述自定义算子如何前向推理+反向传播进行模型训练。

自定义一个线性函数算子

线性函数 Y = X W T + B Y = XW^T + B Y=XWT+B 定义输入M 个X变量,输出N个Y变量的线性方程组。
X X X 为一个 1 x M 矩阵, W W W为 N x M 矩阵, B B B 为 1xN 矩阵,根据公式,输出 Y Y Y为1xN 矩阵。其中 W 和 B 为算子权重参数,保存在模型中。
在训练时刻,模型输入 X X X , 和监督值 Y Y Y,根据 算子forward()计算的 Y p Y^p Yp ,计算Loss = criterion( Y Y Y, Y p Y^p Yp ),然后根据backward()链式求导反向传播计算梯度值。最后根据梯度更新W 和 B 参数。

class LinearF(torch.autograd.Function):@staticmethoddef symbolic(g, input, weight, bias):return g.op("MYLINEAR", input, weight, bias)@staticmethoddef forward(ctx, input:Tensor, weight: Tensor, bias: Tensor) -> Tensor:output = input @ weight.T + bias[None, ...]ctx.save_for_backward(input, weight)return output@staticmethoddef backward(ctx, grad_output:Tensor)->Tuple[Tensor, Tensor, Tensor]:# grad_output -- [B, N] = d(Loss) / d(Y)input, weight = ctx.saved_tensorsgrad_input = grad_output @ weightgrad_weight = grad_output.T @ inputgrad_bias = grad_output.sum(0)# print("grad_input: ", grad_input)# print("grad_weight: ", grad_weight)# print("grad_bias: ", grad_bias)return grad_input, grad_weight, grad_bias

如何实现反向传播

在这里插入图片描述
前向推理比较简单,就根据公式来既可以。反向传播backward() 怎么写呢?
反向传播有两个输入参数,第一个为ctx,第二个grad_output,grad_output就是对forward() 输出output 的求导,如果是最后的节点,那就是loss对输出的求导,否则就是下一层对输出求导,输出grad_input, grad_weight, grad_bias则分别对应了forward的输入input、weight、bias的梯度。这很容易理解,因为是在做链式求导,LinearFunction是这条链上的某个节点,输入输出的数量和含义刚好相反且对应。
根据公式:
Y = X W T + B Y = XW^T + B Y=XWT+B
Loss = criterion( Y t Y^t_{} Yt, Y Y_{} Y ), 假设我们选择判别函数为L2范数,Loss = ∑ j = 0 N 0.5 ∗ ( Y j t − Y j ) 2 \sum_{j=0}^N0.5 * (Y^t_{j}-Y_{j} )^2 j=0N0.5(YjtYj)2

grad_output(j) = d ( L o s s ) d ( Y j ) \frac{d(Loss) }{d(Y_{j})} d(Yj)d(Loss) = Y j t − Y j Y^t_{j} - Y_{j} YjtYj

其中 Y j t Y^t_{j} Yjt为监督值, Y j Y_{j} Yj为模型输出值。

根据链式求导法则, 对输入 X i X_{i} Xi 的求导为

grad_input[i] = ∑ j = 0 N d ( L o s s ) d ( Y j ) ∗ d ( Y j ) d ( X i ) \sum_{j=0}^N\frac{d(Loss) }{d(Y_{j})}*\frac{d(Y_{j}) }{d(X_{i})} j=0Nd(Yj)d(Loss)d(Xi)d(Yj)= ∑ j = 0 N g r a d _ o u t p u t [ j ] ∗ d ( Y j ) d ( X i ) \sum_{j=0}^N{grad\_output}[j] *\frac{d(Y_{j}) }{d(X_{i})} j=0Ngrad_output[j]d(Xi)d(Yj)

d ( Y j ) d ( X i ) \frac{d(Y_{j}) }{d(X_{i})} d(Xi)d(Yj) 即为 W i j T = W j i W^T_{ij} = W_{ji} WijT=Wji

其中i 对应X维度, j对应输出Y维度。

最后整理成矩阵形式:

g r a d _ i n p u t = g r a d _ o u t p u t ∗ W {grad\_input}={grad\_output} * W grad_input=grad_outputW

同理:
g r a d _ w e i g h t = g r a d _ o u t p u t T ∗ X {grad\_weight}={grad\_output}^T * X grad_weight=grad_outputTX

g r a d _ b i a s = ∑ q = 0 N g r a d _ o u t p u t {grad\_bias}=\sum_{q=0}^N{grad\_output} grad_bias=q=0Ngrad_output

最后根据公式形式得到backward()函数。

反向传播的梯度求解还是不容易的,一不小心可能算错了,所以务必在模型训练以前检查梯度计算的正确性。pytorch提供了torch.autograd.gradcheck方法来检验梯度计算的正确性。

其他参考文献:pytorch自定义算子实现详解及反向传播梯度推导

最后根据自定义算子,搭建模型,训练模型参数W,B。并导出onnx。参考代码如下:

import torch
from torch import Tensor
from typing import Tuple
import numpy as np
class LinearF(torch.autograd.Function):@staticmethoddef symbolic(g, input, weight, bias):return g.op("MYLINEAR", input, weight, bias)@staticmethoddef forward(ctx, input:Tensor, weight: Tensor, bias: Tensor) -> Tensor:output = input @ weight.T + bias[None, ...]ctx.save_for_backward(input, weight)return output@staticmethoddef backward(ctx, grad_output:Tensor)->Tuple[Tensor, Tensor, Tensor]:print("grad_output: ", grad_output)# grad_output -- [B, N] = d(Loss) / d(Y)input, weight = ctx.saved_tensorsgrad_input = grad_output @ weightgrad_weight = grad_output.T @ inputgrad_bias = grad_output.sum(0)return grad_input, grad_weight, grad_bias#对LinearFunction进行封装
class MyLinear(torch.nn.Module):def __init__(self, in_features: int, out_features: int, dtype:torch.dtype) -> None:super().__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype))self.bias = torch.nn.Parameter(torch.empty((out_features,), dtype=dtype))self.reset_parameters()# self.weight = torch.nn.Parameter(torch.Tensor([2.0, 3.0]))# self.bias = torch.nn.Parameter(torch.Tensor([4.0]))#y = 2 * x1 + 3 * x2 + 4def reset_parameters(self) -> None:torch.nn.init.uniform_(self.weight)torch.nn.init.uniform_(self.bias)def forward(self, input: Tensor) -> Tensor:# for name, pa in self.named_parameters():#     print(name, pa)return LinearF.apply(input, self.weight, self.bias)  # 在此处使用if __name__ == "__main__":device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(device.type)model = MyLinear(2, 1, dtype=torch.float64).to(device)# torch.Tensor 默认类型为float32,使用gpu时,输入数据类型与W权重类型一致,否则报错# torch.Tensor([3.0, 2.0].double() 转换为float64#input = torch.Tensor([3.0, 2.0], ).requires_grad_(True).unsqueeze(0).double()#input = input.to(device)#assert torch.autograd.gradcheck(model, input)import torch.optim as optim#定义优化策略和判别函数optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)criterion = torch.nn.MSELoss()for epoch in range(300):print("************** epoch: ", epoch , " ************************************* ")inputx = torch.Tensor(np.random.rand(2)).unsqueeze(0).double().to(device)lable = torch.Tensor(2 * inputx[:, 0] + 3 * inputx[:, 1] + 4).double().to(device)print("outlable", lable)optimizer.zero_grad()  # 梯度清零prob = model(inputx)print("prob", prob)loss = criterion(lable, prob)print("loss: ", loss)loss.backward()  #反向传播optimizer.step() #更新参数# 完成训练model.cpu().eval()input = torch.tensor([[3.0, 2.0]], dtype=torch.float64)output = model(input)torch.onnx.export(model,  # 这里的args,是指输入给model的参数,需要传递tuple,因此用括号(input,),"linear.onnx",  # 储存的文件路径verbose=True,  # 打印详细信息input_names=["x"],  #为输入和输出节点指定名称,方便后面查看或者操作output_names=["y"],opset_version=11,  #这里的opset,指,各类算子以何种方式导出,对应于symbolic_opset11dynamic_axes={"image": {0: "batch"},"output": {0: "batch"},},operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

相关文章:

Pytorch自定义算子反向传播

文章目录 自定义一个线性函数算子如何实现反向传播 有关 自定义算子的实现前面已经提到,可以参考。本文讲述自定义算子如何前向推理反向传播进行模型训练。 自定义一个线性函数算子 线性函数 Y X W T B Y XW^T B YXWTB 定义输入M 个X变量,输出N个…...

aws服务(二)机密数据存储

在AWS(Amazon Web Services)中存储机密数据时,安全性和合规性是最重要的考虑因素。AWS 提供了多个服务和工具,帮助用户确保数据的安全性、机密性以及合规性。以下是一些推荐的存储机密数据的AWS服务和最佳实践: 一、A…...

VMware Workstation 17.6.1

概述 目前 VMware Workstation Pro 发布了最新版 v17.6.1: 本月11号官宣:针对所有人免费提供,包括商业、教育和个人用户。 使用说明 软件安装 获取安装包后,双击默认安装即可: 一路单击下一步按钮: 等待…...

高校企业数据挖掘平台推荐

TipDM数据挖掘建模平台是由广东泰迪智能科技股份有限公司自主研发打造的可视化、一站式、高性能的数据挖掘与人工智能建模服务平台,致力于为使用者打通从数据接入、数据预处理、模型开发训练、模型评估比较、模型应用部署到模型任务调度的全链路。平台内置丰富的机器…...

Vue项目开发 formatData 函数有哪些常用的场景?

formatData 不是 JavaScript 中的内建函数,它通常是一个自定义函数,用来格式化数据。不同的开发环境和框架中可能有不同的 formatData 实现方式。如果你指的是某个特定框架或者库中的 formatData,请提供更多的上下文信息。不过,以…...

【AI知识】两类最主流AI应用(文生图、ChatGPT)中的目标函数

之前写过一篇 【AI知识】了解两类最主流AI任务中的目标函数,介绍了AI最常见的两类任务【分类、回归】的基础损失函数【交叉熵、均方差】,以初步了解AI的训练目标。 本篇更进一步,聊一聊流行的“文生图”、“聊天机器人ChatGPT”模型中的目标函…...

【单片机基础】定时器/计数器的工作原理

单片机中的定时器/计数器(Timer/Counter)是用于时间测量和事件计数的重要模块。它们可以用来生成精确的延时、测量外部信号的频率或周期、捕获外部事件的时间戳等。理解定时器/计数器的工作原理对于单片机编程和系统设计非常重要。以下是定时器/计数器的…...

ModuleNotFoundError: No module named ‘distutils.msvccompiler‘ 报错的解决

报错 在conda 环境安装 numpy 时,出现报错 ModuleNotFoundError: No module named distutils.msvccompiler 解决 Python 版本过高导致的,降低版本到 Python 3.8 conda install python3.8即可解决。...

HCIA笔记2--ARP+ICMP+VRP基础

1. ARP ARP: 地址解析协议(address resolve protocol)。 网络数据包在通信的时候一般是使用 I P IP IP地址进行通信。 但是在封装数据链路层的时候是需要目标 m a c mac mac地址的。 而 A R P ARP ARP协议实现的功能就是根据 I P IP IP地址来获得 m a c mac mac地址。 1.1 a…...

SpringBoot与MongoDB深度整合及应用案例

SpringBoot与MongoDB深度整合及应用案例 在当今快速发展的软件开发领域,NoSQL数据库因其灵活性和可扩展性而变得越来越流行。MongoDB,作为一款领先的NoSQL数据库,以其文档导向的存储模型和强大的查询能力脱颖而出。本文将为您提供一个全方位…...

Redis模拟延时队列 实现日程提醒

使用Redis模拟延时队列 实际上通过MQ实现延时队列更加方便,只是在实际业务中种种原因导致最终选择使用redis作为该业务实现的中间件,顺便记录一下。 该业务是用于日程短信提醒,用户添加日程后,就会被放入redis队列中等待被执行发…...

vue项目中富文本编辑器的实现

文章目录 vue前端实现富文本编辑器的功能需要用到第三方库1. 安装包2.全局引入注册3.组件内使用4.图片缩放功能实现①安装包②注册并添加配置项③报错解决 vue前端实现富文本编辑器的功能需要用到第三方库 vue2使用vue-quill-editor,vue3使用vueup/vue-quill&#…...

nginx 配置lua执行shell脚本

1.需要nginx安装lua_nginx_module模块,这一步安装时,遇到一个坑,nginx执行configure时,一直提示./configure: error: unsupported LuaJIT version; ngx_http_lua_module requires LuaJIT 2.x。 网上一堆方法都试了,都…...

Keil+VSCode优化开发体验

目录 一、引言 二、详细步骤 1、编译器准备 2、安装相应插件 2.1 安装C/C插件 2.2 安装Keil相关插件 3、添加keil环境变量 4、加载keil工程文件 5、VSCode中成功添加工程文件后可能出现的问题 5.1 编码不一致问题 6、在VSCode中进行编译工程以及烧录程序 7、效果展示…...

vue2中引入cesium全步骤

1.npm 下载cesium建议指定版本下载,最新版本有兼容性问题 npm install cesium1.95.0 2.在node_models中找到cesium将此文件下的Cesium文件复制出来放在项目的静态资源public中或者static中,获取去github上去下载zip包放在本地也可以 3.在index.html中引…...

工程师 - 智能家居方案介绍

1. 智能家居硬件方案概述 智能家居硬件方案是实现家庭自动化的重要组件,通过集成各种设备来提升生活的便利性、安全性和效率。这些方案通常结合了物联网技术,为用户提供智能化、自动化的生活体验。硬件方案的选择直接影响到智能家居系统的性能、兼容性、…...

中小企业人事管理:SpringBoot框架高级应用

摘 要 随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,中小企业人事管理系统当然也不能排除在外。中小企业人事管理系统是以实际运用为开发背景,运用软件工程原理和…...

嵌入式Linux驱动开发日记

目录 让我们从环境配置开始 目标平台 从Ubuntu开始 从交叉编译器继续 arm-linux-gnueabihf-gcc vscode 没学过ARM汇编 正文开始——速度体验一把 写一个链接脚本 写一个简单的Makefile脚本 使用正点原子的imxdownload下载到自己的SD卡上 更进一步的笔记和说明 从IM…...

迪杰特斯拉算法(Dijkstra‘s)

迪杰斯特拉算法(Dijkstras algorithm)是由荷兰计算机科学家艾兹格迪科斯彻(Edsger W. Dijkstra)在1956年提出的,用于在加权图中找到单个源点到所有其他顶点的最短路径的算法。这个算法广泛应用于网络路由、地图导航等领…...

reids基础

数据结构类型 String setnx //设置key不存在,则添加成功 setex name 10 jack // key 10s失效,自动删除 hash hset hget list 按添加数据排序 lpush //左侧插入 rpush //右侧插入 set 不重复 sadd //添加…...

java_网络服务相关_gateway_nacos_feign区别联系

1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...

微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】

微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来,Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...

Module Federation 和 Native Federation 的比较

前言 Module Federation 是 Webpack 5 引入的微前端架构方案,允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...

k8s业务程序联调工具-KtConnect

概述 原理 工具作用是建立了一个从本地到集群的单向VPN,根据VPN原理,打通两个内网必然需要借助一个公共中继节点,ktconnect工具巧妙的利用k8s原生的portforward能力,简化了建立连接的过程,apiserver间接起到了中继节…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用

1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...

Java多线程实现之Thread类深度解析

Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

Web 架构之 CDN 加速原理与落地实践

文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 &#xf…...

OPENCV形态学基础之二腐蚀

一.腐蚀的原理 (图1) 数学表达式:dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一,腐蚀跟膨胀属于反向操作,膨胀是把图像图像变大,而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...

关键领域软件测试的突围之路:如何破解安全与效率的平衡难题

在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的核心战斗力。不同于普通商业软件,这些承载着国家安全使命的软件系统面临着前所未有的质量挑战——如何在确保绝对安全的前提下,实现高效测试与快速迭代?这一命题正考验着…...

【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论

路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中(图1): mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...