当前位置: 首页 > news >正文

反向传播详解BP

误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。

BP本来只指损失函数对参数的梯度通过网络反向流动的过程,但现在也常被理解成神经网络整个的训练方法,由误差传播、参数更新两个环节循环迭代组成。

本文将以最基础的全连接深度前馈网络为例,详细展示Back-propagation的全过程,并以Numpy进行实现。

图1 神经元为层/权重为层
通常我们以神经元来计量“层”,但本文将权重抽象为“层”,个人认为这样更有助于反向传播的理解和代码的编写。如上图所示的网络就被抽象为两个中间层、一个输出层的结构。

简而言之,神经网络的训练过程中,前向传播和反向传播交替进行,如下图所示:前向传播通过训练数据和权重参数计算输出结果;反向传播通过导数链式法则计算损失函数对各参数的梯度,并根据梯度进行参数的更新,这一点是重点,会在后文详叙。

图2 前向传播&反向传播
前向传播
每层中前向传播的过程如下所示,很简单的矩阵运算。我们将权重作为层,中间层和输出层均可用Layer类来表示,只是对应的激活函数不同。如图2所示,每一层的输入和输出都是
,且前一层的输出是后一层的输入。

* 表示element-wise乘积,· 表示矩阵乘积

class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)

def forward(self, input_data):       # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a  # a: (1, output_dim)h = self.activate(a)              # h: (1, output_dim)return h
  1. 反向传播

损失对参数梯度的反向传播可以被这样直观解释:由A到传播B,即由
得到
,由导数链式法则
实现。所以神经网络的BP就是通过链式法则求出
对所有参数梯度的过程。

如上图示例,输入
,经过网络的参数
,得到一系列中间结果

表示通过权重和偏置的结果,还未经过激活函数,
表示经过激活函数后的结果。灰色框内表示
对各中间计算结果的梯度,这些梯度的反向传播有两类:



,通过激活函数,如右上角



,通过权重,如橙线部分

可以看出梯度的传播和前向传播的模式是一致的,只是方向不同。

计算完了灰色框的部分(损失对中间结果
的梯度),损失对参数
的梯度也就显而易见了,以图中红色的

为例:

因此,我们可以如图2,将反向传播的表达式和代码如下。

注意代码和公式中
表示element-wise乘积,
表示矩阵乘积。

* 表示element-wise乘积,· 表示矩阵乘积

class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)

def forward(self, input_data):       # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a  # a: (1, output_dim)h = self.activate(a)              # h: (1, output_dim)return hdef backward(input_grad):'''单个样本的反向传播'''a_grad = input_grad * activate’(a)  # (1, output_dim)b_grad = a_grad                     # (1, output_dim)W_grad = (input_data.T) · a_grad    # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T          # (1, input_dim)

输出层的反向传播略有不同,因为在分类任务中输出层若用到softmax激活函数,

不是逐个对应的,如下图所示,因此
中的element-wise相乘是失效的,需要用
乘以向量
到向量
的向量梯度(雅可比矩阵)。

但实际上,经过看上去复杂的计算后输出层
会计算出一个非常简洁的结果:

以分类任务为例(交叉熵损失、softmax、训练标签
为one-hot向量其中第
维为1):


以回归任务为例(二次损失、线性激活、训练标签
为实数向量):

因此输出层反向传播的公式和代码可以写成如下所示:

* 表示element-wise乘积,· 表示矩阵乘积

class Output_layer(Layer):
‘’‘属性和forward方法继承Layer类’‘’

def backward(input_grad):'''输出层backward方法''''''单个样本的反向传播'''a_grad = input_grad                 # (1, output_dim)b_grad = a_grad                     # (1, output_dim)W_grad = (input_data.T) · a_grad    # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T          # (1, input_dim)
  1. Batch 批量计算

除非用随机梯度下降,否则每次用以训练的样本都是整个batch计算的,损失函数
则是整个batch中样本得到损失的均值。

在计算中会以向量化的方式增加运算效率,用batch_size表示批的规模,代码可更改为:

* 表示element-wise乘积,· 表示矩阵乘积

class Layer:
‘’‘中间层类’‘’

def forward(self, input_data):       # input_data: (batch_size, input_dim)'''batch_size个样本的前向传播'''input_data · self.W + self.b = a  # a: (1, output_dim)h = self.activate(a)              # h: (1, output_dim)return hdef backward(input_grad):             # input_grad: (batch_size, output_dim)'''batch_size个样本的反向传播'''a_grad = input_grad * activate’(a) # (batch_size, output_dim)b_grad = a_grad.mean(axis=0)       # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= lr * b_gradself.W -= lr * W_gradreturn a_grad · (self.W).T         # output_grad: (batch_size, input_dim)

class Output_layer(Layer):
‘’‘输出层类:属性和forward方法继承Layer类’‘’

def backward(input_grad):             # input_grad: (batch_size, output_dim)'''输出层backward方法''''''batch_size个样本的反向传播'''a_grad = input_grad                # (batch_size, output_dim)b_grad = a_grad.mean(axis=0)       # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T          # output_grad: (batch_size, input_dim)

这里比较易错的地方是什么时候求均值,对
求均值还是对
求均值:梯度在中间结果
上都不需要求均值,对参数
的梯度时才需要求均值。

  1. 代码

https://github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb
​github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb

模拟一个三层神经网络的训练

相关文章:

反向传播详解BP

误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。 BP本来只指损失…...

2023.11.16-hive sql高阶函数lateral view,与行转列,列转行

目录 0.lateral view简介 1.行转列 需求1: 需求2: 2.列转行 解题思路: 0.lateral view简介 hive函数 lateral view 主要功能是将原本汇总在一条(行)的数据拆分成多条(行)成虚拟表,再与原表进行笛卡尔积&#xff0c…...

解决Jetson Xavier NX上Invalid CUDA ‘--device 0‘ requested等问题

解决Jetson Xavier NX上Invalid CUDA --device 0 requested等问题 问题1:AssertionError: Invalid CUDA --device 0 requested, use --device cpu or pass valid CUDA device(s)问题2: “Illegal instruction(cpre dumped)”错误记录python http局域网文…...

git push 报错 The requested URL returned error: 500

今天gitpush时报错The requested URL returned error: 500 看报错应该是本地和gitlab服务器之间通信的问题,登录gitlab网站查看 登录时报错无法通过ldapadmin认证,ldap服务器连接失败。 首先,登录ldap服务器,查看是否是ldap服务…...

应用软件安全编程--17预防基于 DOM 的 XSS

DOM型XSS从效果上来说也属于反射型XSS,由于形成的原因比较特殊所以进行单独划分。在网站页面中有许多页面的元素,当页面到达浏览器时浏览器会为页面创建一个顶级的Document object 文档对象,接着生成各个子文档对象,每个页面元素对应一个文档…...

【FastCAE源码阅读9】鼠标框选网格、节点的实现

一、VTK的框选支持类vtkInteractorStyleRubberBandPick FastCAE的鼠标事件交互类是PropPickerInteractionStyle,它扩展自vtkInteractorStyleRubberBandPick。vtkInteractorStyleRubberBandPick类可以实现鼠标框选物体,默认情况下按下键盘r键开启框选模式…...

【ArcGIS处理】行政区划与流域区划间转化

【ArcGIS处理】行政区划与流域区划间转化 引言数据准备1、行政区划数据2、流域区划数据 ArcGIS详细处理步骤Step1:统计行政区划下子流域面积1、创建批量处理模型2、添加批量裁剪处理3、添加计算面积 Step2:根据子流域面积占比均化得到各行政区固定值 参考…...

Session、Token、Jwt三种登录方案介绍

新开发一个应用首先要考虑的就是登录怎么去做,登录本身就是判断一下输入的用户名和密码与系统存储的是否一致,但因为Http是无状态协议,用户请求其它接口时是怎么判断该用户已经登录了呢?下面聊一个三种实现方案。 一、传统sessio…...

Linux操作系统使用及C高级编程-D5Linux shell命令(进程管理、用户管理)

进程管理 查看进程ps 其中ps -eif可显示父进程 实时查看进程top 按q退出 树状图显示进程pstree 以父进程,子进程以树状形式展示 发送信号kill kill -l:查看都有哪些信号 9:进程终止 kill不指定信号,默认发送的是15信号SIGT…...

【TDSQL-PG数据库简单介绍】

TDSQL-PG数据库简单介绍 TDSQL-PGTDSQL-PG 设计目标 TDSQL-PG 腾讯 TDSQL-PG 分布式关系型数据库是一款面向海量在线实时分布式事务交易和 MPP 实时数据分析 通用型高性能数据库系统。 面对应用业务产生的不定性数据爆炸需求,不管是高并发的交易还是海量的实时数据…...

【文件包含】metinfo 5.0.4 文件包含漏洞复现

1.1漏洞描述 漏洞编号————漏洞类型文件包含漏洞等级⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐漏洞环境windows攻击方式 MetInfo 是一套使用PHP 和MySQL 开发的内容管理系统。MetInfo 5.0.4 版本中的 /metinfo_5.0.4/about/index.php?fmodule文件存在任意文件包含漏洞。攻击者可利用漏洞读取网…...

差分信号的末端并联电容到底有什么作用?

差分信号的末端并联电容到底有什么作用? 在现代电子系统中,差分信号是一种常见的信号形式,它们通过两根互补的信号线传输信号,具有较低的噪声和更高的抗干扰能力。然而,当差分信号线长度较长或者遇到复杂的电路环境时&…...

pandas教程:GroupBy Mechanics 分组机制

文章目录 Chapter 10 Data Aggregation and Group Operations(数据汇总和组操作)10.1 GroupBy Mechanics(分组机制)1 Iterating Over Groups(对组进行迭代)2 Selecting a Column or Subset of Columns (选中…...

通过右键用WebStorm、Idea打开某个文件夹或者在某一文件夹下右键打开当前文件夹用上述两个应用

通过右键用WebStorm、Idea打开某个文件夹或者在某一文件夹下右键打开当前文件夹用上述两个应用 通过右键点击某个文件夹用Idea打开 首先打开注册表 win R 输入 regedit 然后找到HKEY_CLASSES_ROOT\Directory\shell 然后右键shell 新建一个项名字就叫 Idea 第一步&#xf…...

Android 10.0 framework层设置后台运行app进程最大数功能实现

1. 前言 在10.0的定制开发中,在系统中,对于后台运行的app过多的时候,会比较耗内存,导致系统运行有可能会卡顿,所以在系统优化的 过程中,会限制后台app进程运行的数量,来保证系统流畅不影响体验,所以需要分析下系统中关于限制app进程的相关源码来实现 功能 2.framewo…...

如何快速找到华为手机中下载的文档

手机的目录设置比较繁杂,尤其是查找刚刚下载的文件,有时候需要捣鼓半天,如何快速找到这些文件呢?以下提供了几种方法: 方法一: 文件管理-》搜索文档 方法二: 文件管理-》最近 方法三&#xf…...

iceoryx(冰羚)-Architecture

Architecture 本文概述了Eclipseiceoryx体系结构,并解释了它的基本原理。 Software layers Eclipse iceoryx所包含的主要包如下所示。 接下来的部分将逐一简要介绍组件及其库。 Components and libraries 下面描述了不同的库及其名称空间。 ### iceoryx hoofs …...

LeetCode2-两数相加

大佬解法 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode(int x) { val x; }* }*/ class Solution {public ListNode addTwoNumbers(ListNode l1, ListNode l2) {ListNode pre new ListNode(0);ListNo…...

css 灰质彩色的边框

border: 4px solid transparent; background-color:#fff; background-clip: padding-box,border-box; background-origin:padding-box, border-box; background-image: linear-gradient(90deg,#F5F6FA,#F5F6FA 42%,#F5F6FA),linear-gradient(151deg,#33e9bf,#c7e58a,#b1e8cc);...

OpenCV实现手势音量控制

前言: Hello大家好,我是Dream。 今天来学习一下如何使用OpenCV实现手势音量控制,欢迎大家一起前来探讨学习~ 一、需要的库及功能介绍 本次实验需要使用OpenCV和mediapipe库进行手势识别,并利用手势距离控制电脑音量。 导入库&am…...

C++_核心编程_多态案例二-制作饮品

#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为&#xff1a;煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例&#xff0c;提供抽象制作饮品基类&#xff0c;提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

智慧医疗能源事业线深度画像分析(上)

引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...

K8S认证|CKS题库+答案| 11. AppArmor

目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作&#xff1a; 1&#xff09;、切换集群 2&#xff09;、切换节点 3&#xff09;、切换到 apparmor 的目录 4&#xff09;、执行 apparmor 策略模块 5&#xff09;、修改 pod 文件 6&#xff09;、…...

三维GIS开发cesium智慧地铁教程(5)Cesium相机控制

一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点&#xff1a; 路径验证&#xff1a;确保相对路径.…...

如何在看板中有效管理突发紧急任务

在看板中有效管理突发紧急任务需要&#xff1a;设立专门的紧急任务通道、重新调整任务优先级、保持适度的WIP&#xff08;Work-in-Progress&#xff09;弹性、优化任务处理流程、提高团队应对突发情况的敏捷性。其中&#xff0c;设立专门的紧急任务通道尤为重要&#xff0c;这能…...

Module Federation 和 Native Federation 的比较

前言 Module Federation 是 Webpack 5 引入的微前端架构方案&#xff0c;允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...

EtherNet/IP转DeviceNet协议网关详解

一&#xff0c;设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络&#xff0c;本网关连接到EtherNet/IP总线中做为从站使用&#xff0c;连接到DeviceNet总线中做为从站使用。 在自动…...

【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具

第2章 虚拟机性能监控&#xff0c;故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令&#xff1a;jps [options] [hostid] 功能&#xff1a;本地虚拟机进程显示进程ID&#xff08;与ps相同&#xff09;&#xff0c;可同时显示主类&#x…...

Device Mapper 机制

Device Mapper 机制详解 Device Mapper&#xff08;简称 DM&#xff09;是 Linux 内核中的一套通用块设备映射框架&#xff0c;为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程&#xff0c;并配以详细的…...

以光量子为例,详解量子获取方式

光量子技术获取量子比特可在室温下进行。该方式有望通过与名为硅光子学&#xff08;silicon photonics&#xff09;的光波导&#xff08;optical waveguide&#xff09;芯片制造技术和光纤等光通信技术相结合来实现量子计算机。量子力学中&#xff0c;光既是波又是粒子。光子本…...