当前位置: 首页 > news >正文

昇思25天学习打卡营第6天|函数式自动微分

函数式自动微分

相关前置知识复习

深度学习的重点之一是神经网络。而神经网络很重要的一环是反向传播算法,这个算法用于调整神经网络的权重。

反向传播算法

这里举例说明反向传播在做什么。
假设你是一个学生,一次考试过后,你收到了一份老师打分后的试卷。
前向传播就是考试的过程,你通过自己学习的知识来解答试卷上的每个问题。这就是用神经网络的权重(已有的知识)推理(解答)输入数据(问题),得到预测结果(最终的答卷)。
计算误差是你评估自己答案和标准答案的差距。这里在神经网络中是使用损失函数来衡量预测结果和实际结果之间的差距。
反向传播是分析错误的问题,分析具体是哪个知识点没掌握清楚,再去针对性的学习对应的知识,加强弱项。这里在申请网络中就是使用误差来反向调整神经网络中的权重。

反向传播的步骤是:

  1. 前向传播:输入数据通过神经网络计算出预测结果
  2. 计算误差:通过神经网络的预测结果和实际结果,计算损失函数
  3. 反向传播:将误差从输出层向输入层反向传播,计算每一层的误差。
  4. 调整权重:根据计算出的误差,使用优化算法(如梯度下降)调整每一层的权重,以减少误差。

链式法则

链式法则(Chain Rule)是微积分的重要概念,用于计算复合函数的导数。这是计算梯度的基础。简单来说链式法则的基本思想是将复杂的过程分解成多个简单的部分,再将各部分的结果组合起来得到总结果。

假设我们有两个函数
f f f g g g,并且它们是复合的,即 y = f ( g ( x ) ) y=f(g(x)) y=f(g(x))。根据链式法则,复合函数
y y y x x x 的导数可以表示为: d y d x = d y d g ⋅ d g d x \frac{d_y}{d_x}= \frac{d_y}{d_g} \cdot \frac{d_g}{d_x} dxdy=dgdydxdg

为什么要在神经网络这里强调链式法则呢?
举个小例子。想象神经网络的结构,第一层传给第二层,第二层传给第三层,每一层的输出都是后一层的输入,整个就是一条链。通过最后一层的输出,计算损失函数对最终输出的导数,再根据链式法则,可以逐层往前推导梯度,将误差从输出层传递回输入层,计算出每个参数的梯度,再进行参数更新,完成反向传递。

自动微分

上面阐述了,神经网络的误差需要通过链式法则传递,从而计算出每个参数的梯度。自动微分利用链式法则自动计算反向传播的梯度,也就是将梯度计算过程变得自动化了。

构建模型

以下我们构建一个简单的线性模型来介绍如何进行函数式自动微分。

普通的线性模型: y = w × b + b y = w \times b+ b y=w×b+b
输入x,输出y,通过调整w和b参数来优化预测。

# 环境配置
import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter
x = ops.ones(5, mindspore.float32)  # 形状为5的全1输入
y = ops.zeros(3, mindspore.float32)  # 形状为3的全0输出
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # 线性函数的权重
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # 线性函数的偏差 
def function(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss

解读一下这段代码:
z是线性函数的组合,ops.matmul(x, w) 是输入x和权重w的矩阵乘法,然后再将偏置向量b加入到结果中。
这里使用binary_cross_entropy_with_logits作为损失函数。这个函数计算了预测值和目标值之间的二值交叉熵损失。

通过 f u n c t i o n ( x , y , w , b ) function(x,y,w,b) function(x,y,w,b)可以计算到loss值。

模型参数优化

为了优化模型参数,需要计算对两个参数的导数。

  • ∂ loss ⁡ ∂ w \frac{\partial \operatorname{loss}}{\partial w} wloss
  • ∂ loss ⁡ ∂ b \frac{\partial \operatorname{loss}}{\partial b} bloss

这里使用mindspore.grad函数,来获得function的微分函数用以计算梯度。
grad函数包括两个参数:

  • fn:待求导的函数。
  • grad_position:需要求导的参数的索引。

function的入参是x,y,w,b, 我们需要对w,b求导。w,b在入参中的索引是2,3。因此grad_postion为(2,3),可得function的微分函数

grad_fn = mindspore.grad(function, (2, 3))

停止梯度

一般来讲,只会求loss对参数的导数,因此只需要输出loss就可以。但如果要求的话,也可以输出多个loss以外的参数。在我们的例子里,也就是在function的函数中增加除了loss之外的输出值。

def function_with_logits(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, zgrad_fn = mindspore.grad(function_with_logits, (2, 3))
grads = grad_fn(x, y, w, b)

这里增加了对z的输出。在后续grad_fn调用时,z也会参与到梯度计算,对w和b的梯度结果造成影响。

stop gradient用以阻止某些张量的梯度计算。通俗的说,当对张量 z 应用 Stop Gradient 操作后,在反向传播过程中,其梯度会被置零,从而不会影响之前的计算。具体操作如下:

def function_stop_gradient(x, y, w, b):z = ops.matmul(x, w) + bloss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))return loss, ops.stop_gradient(z)

这样的函数定义下,函数依然输出z,但是对z进行了阻断,不会影响到后续梯度计算是对其他参数(w,b)更新。

Auxiliary data

Auxiliary data意思是辅助数据,是函数第一个输出以外的其他输出。一般loss是第一个输出,其他都是Auxiliary data。

has_aux指具有辅助数据,可以在grad函数中设置,就能自动将辅助数据添加stop gradient操作。

grad_fn = mindspore.grad(function_with_logits, (2, 3), has_aux=True)

神经网络实现自动微分

上面的操作是自己手动搭建了一个模型。在之前的章节中我们使用nn.Cell构建了神经网络,现在来看看如何在Cell模型中如何实现函数式自动微分。

初始化模型

# 模型初始化,定义参数w和b,并构建模型
class Network(nn.Cell):def __init__(self):super().__init__()self.w = wself.b = bdef construct(self, x):z = ops.matmul(x, self.w) + self.breturn z
# 实例化模型
model = Network()
# 定义损失函数
loss_fn = nn.BCEWithLogitsLoss()
# 定义前向计算函数
def forward_fn(x, y):z = model(x)loss = loss_fn(z, y)return loss
# 上面的步骤之前的文章已经介绍过了。

这里我们使用value_and_grad接口获得微分函数。w、b两个参数已经是网络属性的一部分了,因此不需要再次进行指定了,而是使用model.trainable_params()获取可求导的参数。

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())

总结

本章使用了两种方式实现了自动微分,也就是梯度计算。一种是手动构建模型,一种使用nn.Cell搭建的神经网络。此外,本节也复习了一些深度学习的基础知识

打卡凭证

在这里插入图片描述

相关文章:

昇思25天学习打卡营第6天|函数式自动微分

函数式自动微分 相关前置知识复习 深度学习的重点之一是神经网络。而神经网络很重要的一环是反向传播算法,这个算法用于调整神经网络的权重。 反向传播算法 这里举例说明反向传播在做什么。 假设你是一个学生,一次考试过后,你收到了一份老…...

作业7.2

用结构体数组以及函数完成: 录入你要增加的几个学生,之后输出所有的学生信息 删除你要删除的第几个学生,并打印所有的学生信息 修改你要修改的第几个学生,并打印所有的学生信息 查找你要查找的第几个学生,并打印该的学生信息 1 /*…...

PCL 点云聚类(基于体素连通性)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 这里的思路很简单,我们通过将点云转换为体素,基于体素的连通性实现对点云的聚类(有点类似于欧式聚类),不过这种方式进行的聚类有些粗糙,但聚类速度相对会快很多,具体的实现效果可以详细阅读代码。 二、实现代…...

python自动化运维--DNS处理模块dnspython

1.dnspython介绍 dnspython是Pyhton实现的一个DNS工具包,他几乎支持所有的记录类型,可以用于查询、传输并动态更新ZONE信息,同事支持TSIG(事物签名)验证消息和EDNS0(扩展DNS)。在系统管理方面&a…...

成人职场商务英语学习柯桥外语学校|邮件中的“备注”用英语怎么说?

在英语中,"备注"通常可以翻译为"Notes" 或 "Remarks"。 这两个词在邮件中都很常用。例如: 1. Notes Notes: 是最通用和最常见的表达,可以用在各种情况下,例如: 提供有关电子邮件内容的附加信息 列…...

AndroidStudio报错macMissing essential plugin

电脑重启后打开studio: Missing essential plugin: org.jetbrains.android Please reinstall Android Studio from scratch. 无法使用 对应Mac下disabled_plugins.txt位于如下目录: /Users/ACB/Library/Application Support/Google/AndroidStudio4.2 …...

doris集群物理部署保姆级教程

doris物理安装 1、安装要求 Linux 操作系统版本需求​ 查看CentOs版本(>7.1) cat /etc/redhat-release 1)设置系统最大打开文件句柄数​ vi /etc/security/limits.conf soft nofile 65536hard nofile 65536 echo ‘’’ soft nofile 655360hard nofile 655…...

探囊取物之多形式登录页面(基于BootStrap4)

基于BootStrap4的登录页面,支持手机验证码登录、账号密码登录、二维码登录、其它统一登录 低配置云服务器,首次加载速度较慢,请耐心等候;演练页面可点击查看源码 预览页面:http://www.daelui.com/#/tigerlair/saas/pr…...

【ONLYOFFICE】| 桌面编辑器从0-1使用初体验

目录 一. 🦁 写在前面二. 🦁 在线使用感受2.1 创建 ONLYOFFICE 账号2.2 编辑pdf文档2.3 pdf直接创建表格 三. 🦁 写在最后 一. 🦁 写在前面 所谓桌面编辑器就是一种用于编辑文本、图像、视频等多种自媒体的软件工具,具…...

20、PHP字符串的排列(含源码)

题目: PHP字符串的排列? 描述: 输入一个字符串,按字典序打印出该字符串中字符的所有排列。 例如输入字符串abc,则打印出由字符a,b,c所能排列出来的所有字符串abc,acb,bac,bca,cab和cba。 输入描述: 输入一个字符串,长度不超过9(可…...

Linux 标准IO的fopen和fclose

getchar(),putchar() ‐‐‐‐ 一个字符 gets(buf),puts(buf) ‐‐‐‐ 一串字符 scanf(),printf() ‐‐‐‐ 一个字符,一串字符都可以 fopen函数的形式 FILE * fopen(constchar *path , cost char *mode) /* * description : 打开一个文件 * param ‐ path…...

一个计算密集小程序在不同CPU下的表现

本文比较了几款CPU对同一测试程序的比较结果,用的是Oracle公有云OCI上的计算实例,均分配的1 OCPU,内存用的默认值,不过内存对此测试程序运行结果不重要。 本文只列结果,不做任何评价。下表中,最后一列为测…...

圈子系统搭建教程,以及圈子系统的功能特点,圈子系统,允许二开,免费源码,APP小程序H5

圈子是一款社区与群组的交友工具。你可以在软件内创造一个兴趣的群组从而达到按圈子来交友的效果用户可以根据自己的兴趣爱好。 1. 创建圈子 轻松创建专属圈子,支持付费型社群。 2. 加入圈子 加入不同圈子,设置不同名片,保护隐私。 3. 定…...

递归算法练习

112. 路径总和 package Tree;import java.util.HashMap; import java.util.Map;class TreeNode {int val;TreeNode left;TreeNode right;public TreeNode(int val) {this.val val;} }/*** 求 树的路径和* <p>* 递归 递减* <p>* 询问是否存在从*当前节点 root 到叶…...

WebDriver 类的常用属性和方法

目录 &#x1f38d;简介 &#x1f38a;WebDriver 核心概念 &#x1f389;WebDriver 常用属性 &#x1f381;WebDriver 常用方法 &#x1f437;示例代码 &#x1f3aa;注意事项 &#x1f390;结语 &#x1f9e3;参考资料 &#x1f38d;简介 Selenium WebDriver 是一个用…...

基于x86+FPGA+AI轴承缺陷视觉检测系统,摇枕弹簧智能检测系统

一、承缺陷视觉检测系统 应用场景 轴类零件自动检测设备&#xff0c;集光、机、软件、硬件&#xff0c;智能图像处理等先进技术于一体&#xff0c;利用轮廓特征匹配&#xff0c;目标与定位&#xff0c;区域选取&#xff0c;边缘提取&#xff0c;模糊运算等算法实现人工智能高…...

短剧小程序系统cps分销开发搭建

短剧小程序系统CPS分销开发搭建是一个相对复杂但具有广阔商业前景的过程。以下是关于短剧小程序系统CPS分销开发搭建的详细步骤和要点&#xff1a; 需求分析与市场调研&#xff1a; 深入了解市场需求、用户画像和竞品分析&#xff0c;明确产品定位和功能需求。研究目标用户的消…...

代理IP的10大误区:区分事实与虚构

在当今的数字时代&#xff0c;代理已成为在线环境不可或缺的一部分。它们的用途广泛&#xff0c;从增强在线隐私到绕过地理限制。然而&#xff0c;尽管代理无处不在&#xff0c;但仍存在许多围绕代理的误解。在本博客中&#xff0c;我们将探讨和消除一些最常见的代理误解&#…...

数组-长度最小的子数组

M长度最小的子数组&#xff08;leetcode209&#xff09; /*** param {number} target* param {number[]} nums* return {number}*/ var minSubArrayLen function(target, nums) {const n nums.length;let ans n 1;let sum 0; // 子数组元素和let left 0; // 子数组…...

深度学习之交叉验证

交叉验证&#xff08;Cross-Validation&#xff09;是一种用于评估和验证机器学习模型性能的技术&#xff0c;尤其是在数据量有限的情况下。它通过将数据集分成多个子集&#xff0c;反复训练和测试模型&#xff0c;以更稳定和可靠地估计模型的泛化能力。常见的交叉验证方法有以…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)

概述 在 Swift 开发语言中&#xff0c;各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过&#xff0c;在涉及到多个子类派生于基类进行多态模拟的场景下&#xff0c;…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时&#xff0c;需结合业务场景设计数据流转链路&#xff0c;重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点&#xff1a; 一、核心对接场景与目标 商品数据同步 场景&#xff1a;将1688商品信息…...

MMaDA: Multimodal Large Diffusion Language Models

CODE &#xff1a; https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA&#xff0c;它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构&#xf…...

剑指offer20_链表中环的入口节点

链表中环的入口节点 给定一个链表&#xff0c;若其中包含环&#xff0c;则输出环的入口节点。 若其中不包含环&#xff0c;则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...

【SQL学习笔记1】增删改查+多表连接全解析(内附SQL免费在线练习工具)

可以使用Sqliteviz这个网站免费编写sql语句&#xff0c;它能够让用户直接在浏览器内练习SQL的语法&#xff0c;不需要安装任何软件。 链接如下&#xff1a; sqliteviz 注意&#xff1a; 在转写SQL语法时&#xff0c;关键字之间有一个特定的顺序&#xff0c;这个顺序会影响到…...

12.找到字符串中所有字母异位词

&#x1f9e0; 题目解析 题目描述&#xff1a; 给定两个字符串 s 和 p&#xff0c;找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义&#xff1a; 若两个字符串包含的字符种类和出现次数完全相同&#xff0c;顺序无所谓&#xff0c;则互为…...

汇编常见指令

汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX&#xff08;不访问内存&#xff09;XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用

在工业制造领域&#xff0c;无损检测&#xff08;NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统&#xff0c;以非接触式光学麦克风技术为核心&#xff0c;打破传统检测瓶颈&#xff0c;为半导体、航空航天、汽车制造等行业提供了高灵敏…...

libfmt: 现代C++的格式化工具库介绍与酷炫功能

libfmt: 现代C的格式化工具库介绍与酷炫功能 libfmt 是一个开源的C格式化库&#xff0c;提供了高效、安全的文本格式化功能&#xff0c;是C20中引入的std::format的基础实现。它比传统的printf和iostream更安全、更灵活、性能更好。 基本介绍 主要特点 类型安全&#xff1a…...