当前位置: 首页 > news >正文

Pytorch自定义算子反向传播

文章目录

      • 自定义一个线性函数算子
      • 如何实现反向传播

有关 自定义算子的实现前面已经提到,可以参考。本文讲述自定义算子如何前向推理+反向传播进行模型训练。

自定义一个线性函数算子

线性函数 Y = X W T + B Y = XW^T + B Y=XWT+B 定义输入M 个X变量,输出N个Y变量的线性方程组。
X X X 为一个 1 x M 矩阵, W W W为 N x M 矩阵, B B B 为 1xN 矩阵,根据公式,输出 Y Y Y为1xN 矩阵。其中 W 和 B 为算子权重参数,保存在模型中。
在训练时刻,模型输入 X X X , 和监督值 Y Y Y,根据 算子forward()计算的 Y p Y^p Yp ,计算Loss = criterion( Y Y Y, Y p Y^p Yp ),然后根据backward()链式求导反向传播计算梯度值。最后根据梯度更新W 和 B 参数。

class LinearF(torch.autograd.Function):@staticmethoddef symbolic(g, input, weight, bias):return g.op("MYLINEAR", input, weight, bias)@staticmethoddef forward(ctx, input:Tensor, weight: Tensor, bias: Tensor) -> Tensor:output = input @ weight.T + bias[None, ...]ctx.save_for_backward(input, weight)return output@staticmethoddef backward(ctx, grad_output:Tensor)->Tuple[Tensor, Tensor, Tensor]:# grad_output -- [B, N] = d(Loss) / d(Y)input, weight = ctx.saved_tensorsgrad_input = grad_output @ weightgrad_weight = grad_output.T @ inputgrad_bias = grad_output.sum(0)# print("grad_input: ", grad_input)# print("grad_weight: ", grad_weight)# print("grad_bias: ", grad_bias)return grad_input, grad_weight, grad_bias

如何实现反向传播

在这里插入图片描述
前向推理比较简单,就根据公式来既可以。反向传播backward() 怎么写呢?
反向传播有两个输入参数,第一个为ctx,第二个grad_output,grad_output就是对forward() 输出output 的求导,如果是最后的节点,那就是loss对输出的求导,否则就是下一层对输出求导,输出grad_input, grad_weight, grad_bias则分别对应了forward的输入input、weight、bias的梯度。这很容易理解,因为是在做链式求导,LinearFunction是这条链上的某个节点,输入输出的数量和含义刚好相反且对应。
根据公式:
Y = X W T + B Y = XW^T + B Y=XWT+B
Loss = criterion( Y t Y^t_{} Yt, Y Y_{} Y ), 假设我们选择判别函数为L2范数,Loss = ∑ j = 0 N 0.5 ∗ ( Y j t − Y j ) 2 \sum_{j=0}^N0.5 * (Y^t_{j}-Y_{j} )^2 j=0N0.5(YjtYj)2

grad_output(j) = d ( L o s s ) d ( Y j ) \frac{d(Loss) }{d(Y_{j})} d(Yj)d(Loss) = Y j t − Y j Y^t_{j} - Y_{j} YjtYj

其中 Y j t Y^t_{j} Yjt为监督值, Y j Y_{j} Yj为模型输出值。

根据链式求导法则, 对输入 X i X_{i} Xi 的求导为

grad_input[i] = ∑ j = 0 N d ( L o s s ) d ( Y j ) ∗ d ( Y j ) d ( X i ) \sum_{j=0}^N\frac{d(Loss) }{d(Y_{j})}*\frac{d(Y_{j}) }{d(X_{i})} j=0Nd(Yj)d(Loss)d(Xi)d(Yj)= ∑ j = 0 N g r a d _ o u t p u t [ j ] ∗ d ( Y j ) d ( X i ) \sum_{j=0}^N{grad\_output}[j] *\frac{d(Y_{j}) }{d(X_{i})} j=0Ngrad_output[j]d(Xi)d(Yj)

d ( Y j ) d ( X i ) \frac{d(Y_{j}) }{d(X_{i})} d(Xi)d(Yj) 即为 W i j T = W j i W^T_{ij} = W_{ji} WijT=Wji

其中i 对应X维度, j对应输出Y维度。

最后整理成矩阵形式:

g r a d _ i n p u t = g r a d _ o u t p u t ∗ W {grad\_input}={grad\_output} * W grad_input=grad_outputW

同理:
g r a d _ w e i g h t = g r a d _ o u t p u t T ∗ X {grad\_weight}={grad\_output}^T * X grad_weight=grad_outputTX

g r a d _ b i a s = ∑ q = 0 N g r a d _ o u t p u t {grad\_bias}=\sum_{q=0}^N{grad\_output} grad_bias=q=0Ngrad_output

最后根据公式形式得到backward()函数。

反向传播的梯度求解还是不容易的,一不小心可能算错了,所以务必在模型训练以前检查梯度计算的正确性。pytorch提供了torch.autograd.gradcheck方法来检验梯度计算的正确性。

其他参考文献:pytorch自定义算子实现详解及反向传播梯度推导

最后根据自定义算子,搭建模型,训练模型参数W,B。并导出onnx。参考代码如下:

import torch
from torch import Tensor
from typing import Tuple
import numpy as np
class LinearF(torch.autograd.Function):@staticmethoddef symbolic(g, input, weight, bias):return g.op("MYLINEAR", input, weight, bias)@staticmethoddef forward(ctx, input:Tensor, weight: Tensor, bias: Tensor) -> Tensor:output = input @ weight.T + bias[None, ...]ctx.save_for_backward(input, weight)return output@staticmethoddef backward(ctx, grad_output:Tensor)->Tuple[Tensor, Tensor, Tensor]:print("grad_output: ", grad_output)# grad_output -- [B, N] = d(Loss) / d(Y)input, weight = ctx.saved_tensorsgrad_input = grad_output @ weightgrad_weight = grad_output.T @ inputgrad_bias = grad_output.sum(0)return grad_input, grad_weight, grad_bias#对LinearFunction进行封装
class MyLinear(torch.nn.Module):def __init__(self, in_features: int, out_features: int, dtype:torch.dtype) -> None:super().__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype))self.bias = torch.nn.Parameter(torch.empty((out_features,), dtype=dtype))self.reset_parameters()# self.weight = torch.nn.Parameter(torch.Tensor([2.0, 3.0]))# self.bias = torch.nn.Parameter(torch.Tensor([4.0]))#y = 2 * x1 + 3 * x2 + 4def reset_parameters(self) -> None:torch.nn.init.uniform_(self.weight)torch.nn.init.uniform_(self.bias)def forward(self, input: Tensor) -> Tensor:# for name, pa in self.named_parameters():#     print(name, pa)return LinearF.apply(input, self.weight, self.bias)  # 在此处使用if __name__ == "__main__":device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(device.type)model = MyLinear(2, 1, dtype=torch.float64).to(device)# torch.Tensor 默认类型为float32,使用gpu时,输入数据类型与W权重类型一致,否则报错# torch.Tensor([3.0, 2.0].double() 转换为float64#input = torch.Tensor([3.0, 2.0], ).requires_grad_(True).unsqueeze(0).double()#input = input.to(device)#assert torch.autograd.gradcheck(model, input)import torch.optim as optim#定义优化策略和判别函数optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)criterion = torch.nn.MSELoss()for epoch in range(300):print("************** epoch: ", epoch , " ************************************* ")inputx = torch.Tensor(np.random.rand(2)).unsqueeze(0).double().to(device)lable = torch.Tensor(2 * inputx[:, 0] + 3 * inputx[:, 1] + 4).double().to(device)print("outlable", lable)optimizer.zero_grad()  # 梯度清零prob = model(inputx)print("prob", prob)loss = criterion(lable, prob)print("loss: ", loss)loss.backward()  #反向传播optimizer.step() #更新参数# 完成训练model.cpu().eval()input = torch.tensor([[3.0, 2.0]], dtype=torch.float64)output = model(input)torch.onnx.export(model,  # 这里的args,是指输入给model的参数,需要传递tuple,因此用括号(input,),"linear.onnx",  # 储存的文件路径verbose=True,  # 打印详细信息input_names=["x"],  #为输入和输出节点指定名称,方便后面查看或者操作output_names=["y"],opset_version=11,  #这里的opset,指,各类算子以何种方式导出,对应于symbolic_opset11dynamic_axes={"image": {0: "batch"},"output": {0: "batch"},},operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

相关文章:

Pytorch自定义算子反向传播

文章目录 自定义一个线性函数算子如何实现反向传播 有关 自定义算子的实现前面已经提到,可以参考。本文讲述自定义算子如何前向推理反向传播进行模型训练。 自定义一个线性函数算子 线性函数 Y X W T B Y XW^T B YXWTB 定义输入M 个X变量,输出N个…...

aws服务(二)机密数据存储

在AWS(Amazon Web Services)中存储机密数据时,安全性和合规性是最重要的考虑因素。AWS 提供了多个服务和工具,帮助用户确保数据的安全性、机密性以及合规性。以下是一些推荐的存储机密数据的AWS服务和最佳实践: 一、A…...

VMware Workstation 17.6.1

概述 目前 VMware Workstation Pro 发布了最新版 v17.6.1: 本月11号官宣:针对所有人免费提供,包括商业、教育和个人用户。 使用说明 软件安装 获取安装包后,双击默认安装即可: 一路单击下一步按钮: 等待…...

高校企业数据挖掘平台推荐

TipDM数据挖掘建模平台是由广东泰迪智能科技股份有限公司自主研发打造的可视化、一站式、高性能的数据挖掘与人工智能建模服务平台,致力于为使用者打通从数据接入、数据预处理、模型开发训练、模型评估比较、模型应用部署到模型任务调度的全链路。平台内置丰富的机器…...

Vue项目开发 formatData 函数有哪些常用的场景?

formatData 不是 JavaScript 中的内建函数,它通常是一个自定义函数,用来格式化数据。不同的开发环境和框架中可能有不同的 formatData 实现方式。如果你指的是某个特定框架或者库中的 formatData,请提供更多的上下文信息。不过,以…...

【AI知识】两类最主流AI应用(文生图、ChatGPT)中的目标函数

之前写过一篇 【AI知识】了解两类最主流AI任务中的目标函数,介绍了AI最常见的两类任务【分类、回归】的基础损失函数【交叉熵、均方差】,以初步了解AI的训练目标。 本篇更进一步,聊一聊流行的“文生图”、“聊天机器人ChatGPT”模型中的目标函…...

【单片机基础】定时器/计数器的工作原理

单片机中的定时器/计数器(Timer/Counter)是用于时间测量和事件计数的重要模块。它们可以用来生成精确的延时、测量外部信号的频率或周期、捕获外部事件的时间戳等。理解定时器/计数器的工作原理对于单片机编程和系统设计非常重要。以下是定时器/计数器的…...

ModuleNotFoundError: No module named ‘distutils.msvccompiler‘ 报错的解决

报错 在conda 环境安装 numpy 时,出现报错 ModuleNotFoundError: No module named distutils.msvccompiler 解决 Python 版本过高导致的,降低版本到 Python 3.8 conda install python3.8即可解决。...

HCIA笔记2--ARP+ICMP+VRP基础

1. ARP ARP: 地址解析协议(address resolve protocol)。 网络数据包在通信的时候一般是使用 I P IP IP地址进行通信。 但是在封装数据链路层的时候是需要目标 m a c mac mac地址的。 而 A R P ARP ARP协议实现的功能就是根据 I P IP IP地址来获得 m a c mac mac地址。 1.1 a…...

SpringBoot与MongoDB深度整合及应用案例

SpringBoot与MongoDB深度整合及应用案例 在当今快速发展的软件开发领域,NoSQL数据库因其灵活性和可扩展性而变得越来越流行。MongoDB,作为一款领先的NoSQL数据库,以其文档导向的存储模型和强大的查询能力脱颖而出。本文将为您提供一个全方位…...

Redis模拟延时队列 实现日程提醒

使用Redis模拟延时队列 实际上通过MQ实现延时队列更加方便,只是在实际业务中种种原因导致最终选择使用redis作为该业务实现的中间件,顺便记录一下。 该业务是用于日程短信提醒,用户添加日程后,就会被放入redis队列中等待被执行发…...

vue项目中富文本编辑器的实现

文章目录 vue前端实现富文本编辑器的功能需要用到第三方库1. 安装包2.全局引入注册3.组件内使用4.图片缩放功能实现①安装包②注册并添加配置项③报错解决 vue前端实现富文本编辑器的功能需要用到第三方库 vue2使用vue-quill-editor,vue3使用vueup/vue-quill&#…...

nginx 配置lua执行shell脚本

1.需要nginx安装lua_nginx_module模块,这一步安装时,遇到一个坑,nginx执行configure时,一直提示./configure: error: unsupported LuaJIT version; ngx_http_lua_module requires LuaJIT 2.x。 网上一堆方法都试了,都…...

Keil+VSCode优化开发体验

目录 一、引言 二、详细步骤 1、编译器准备 2、安装相应插件 2.1 安装C/C插件 2.2 安装Keil相关插件 3、添加keil环境变量 4、加载keil工程文件 5、VSCode中成功添加工程文件后可能出现的问题 5.1 编码不一致问题 6、在VSCode中进行编译工程以及烧录程序 7、效果展示…...

vue2中引入cesium全步骤

1.npm 下载cesium建议指定版本下载,最新版本有兼容性问题 npm install cesium1.95.0 2.在node_models中找到cesium将此文件下的Cesium文件复制出来放在项目的静态资源public中或者static中,获取去github上去下载zip包放在本地也可以 3.在index.html中引…...

工程师 - 智能家居方案介绍

1. 智能家居硬件方案概述 智能家居硬件方案是实现家庭自动化的重要组件,通过集成各种设备来提升生活的便利性、安全性和效率。这些方案通常结合了物联网技术,为用户提供智能化、自动化的生活体验。硬件方案的选择直接影响到智能家居系统的性能、兼容性、…...

中小企业人事管理:SpringBoot框架高级应用

摘 要 随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,中小企业人事管理系统当然也不能排除在外。中小企业人事管理系统是以实际运用为开发背景,运用软件工程原理和…...

嵌入式Linux驱动开发日记

目录 让我们从环境配置开始 目标平台 从Ubuntu开始 从交叉编译器继续 arm-linux-gnueabihf-gcc vscode 没学过ARM汇编 正文开始——速度体验一把 写一个链接脚本 写一个简单的Makefile脚本 使用正点原子的imxdownload下载到自己的SD卡上 更进一步的笔记和说明 从IM…...

迪杰特斯拉算法(Dijkstra‘s)

迪杰斯特拉算法(Dijkstras algorithm)是由荷兰计算机科学家艾兹格迪科斯彻(Edsger W. Dijkstra)在1956年提出的,用于在加权图中找到单个源点到所有其他顶点的最短路径的算法。这个算法广泛应用于网络路由、地图导航等领…...

reids基础

数据结构类型 String setnx //设置key不存在,则添加成功 setex name 10 jack // key 10s失效,自动删除 hash hset hget list 按添加数据排序 lpush //左侧插入 rpush //右侧插入 set 不重复 sadd //添加…...

基于大模型的 UI 自动化系统

基于大模型的 UI 自动化系统 下面是一个完整的 Python 系统,利用大模型实现智能 UI 自动化,结合计算机视觉和自然语言处理技术,实现"看屏操作"的能力。 系统架构设计 #mermaid-svg-2gn2GRvh5WCP2ktF {font-family:"trebuchet ms",verdana,arial,sans-…...

MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例

一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...

centos 7 部署awstats 网站访问检测

一、基础环境准备(两种安装方式都要做) bash # 安装必要依赖 yum install -y httpd perl mod_perl perl-Time-HiRes perl-DateTime systemctl enable httpd # 设置 Apache 开机自启 systemctl start httpd # 启动 Apache二、安装 AWStats&#xff0…...

CentOS下的分布式内存计算Spark环境部署

一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架,相比 MapReduce 具有以下核心优势: 内存计算:数据可常驻内存,迭代计算性能提升 10-100 倍(文档段落:3-79…...

相机从app启动流程

一、流程框架图 二、具体流程分析 1、得到cameralist和对应的静态信息 目录如下: 重点代码分析: 启动相机前,先要通过getCameraIdList获取camera的个数以及id,然后可以通过getCameraCharacteristics获取对应id camera的capabilities(静态信息)进行一些openCamera前的…...

Unsafe Fileupload篇补充-木马的详细教程与木马分享(中国蚁剑方式)

在之前的皮卡丘靶场第九期Unsafe Fileupload篇中我们学习了木马的原理并且学了一个简单的木马文件 本期内容是为了更好的为大家解释木马(服务器方面的)的原理,连接,以及各种木马及连接工具的分享 文件木马:https://w…...

Java编程之桥接模式

定义 桥接模式(Bridge Pattern)属于结构型设计模式,它的核心意图是将抽象部分与实现部分分离,使它们可以独立地变化。这种模式通过组合关系来替代继承关系,从而降低了抽象和实现这两个可变维度之间的耦合度。 用例子…...

push [特殊字符] present

push 🆚 present 前言present和dismiss特点代码演示 push和pop特点代码演示 前言 在 iOS 开发中,push 和 present 是两种不同的视图控制器切换方式,它们有着显著的区别。 present和dismiss 特点 在当前控制器上方新建视图层级需要手动调用…...

【JVM】Java虚拟机(二)——垃圾回收

目录 一、如何判断对象可以回收 (一)引用计数法 (二)可达性分析算法 二、垃圾回收算法 (一)标记清除 (二)标记整理 (三)复制 (四&#xff…...

作为测试我们应该关注redis哪些方面

1、功能测试 数据结构操作:验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化:测试aof和aof持久化机制,确保数据在开启后正确恢复。 事务:检查事务的原子性和回滚机制。 发布订阅:确保消息正确传递。 2、性…...