反向传播详解BP
误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。
BP本来只指损失函数对参数的梯度通过网络反向流动的过程,但现在也常被理解成神经网络整个的训练方法,由误差传播、参数更新两个环节循环迭代组成。
本文将以最基础的全连接深度前馈网络为例,详细展示Back-propagation的全过程,并以Numpy进行实现。
图1 神经元为层/权重为层
通常我们以神经元来计量“层”,但本文将权重抽象为“层”,个人认为这样更有助于反向传播的理解和代码的编写。如上图所示的网络就被抽象为两个中间层、一个输出层的结构。
简而言之,神经网络的训练过程中,前向传播和反向传播交替进行,如下图所示:前向传播通过训练数据和权重参数计算输出结果;反向传播通过导数链式法则计算损失函数对各参数的梯度,并根据梯度进行参数的更新,这一点是重点,会在后文详叙。
图2 前向传播&反向传播
前向传播
每层中前向传播的过程如下所示,很简单的矩阵运算。我们将权重作为层,中间层和输出层均可用Layer类来表示,只是对应的激活函数不同。如图2所示,每一层的输入和输出都是
,且前一层的输出是后一层的输入。
* 表示element-wise乘积,· 表示矩阵乘积
class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)
def forward(self, input_data): # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a # a: (1, output_dim)h = self.activate(a) # h: (1, output_dim)return h
- 反向传播
损失对参数梯度的反向传播可以被这样直观解释:由A到传播B,即由
得到
,由导数链式法则
实现。所以神经网络的BP就是通过链式法则求出
对所有参数梯度的过程。
如上图示例,输入
,经过网络的参数
,得到一系列中间结果
。
表示通过权重和偏置的结果,还未经过激活函数,
表示经过激活函数后的结果。灰色框内表示
对各中间计算结果的梯度,这些梯度的反向传播有两类:
由
到
,通过激活函数,如右上角
由
到
,通过权重,如橙线部分
可以看出梯度的传播和前向传播的模式是一致的,只是方向不同。
计算完了灰色框的部分(损失对中间结果
的梯度),损失对参数
的梯度也就显而易见了,以图中红色的
和
为例:
因此,我们可以如图2,将反向传播的表达式和代码如下。
注意代码和公式中
表示element-wise乘积,
表示矩阵乘积。
* 表示element-wise乘积,· 表示矩阵乘积
class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)
def forward(self, input_data): # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a # a: (1, output_dim)h = self.activate(a) # h: (1, output_dim)return hdef backward(input_grad):'''单个样本的反向传播'''a_grad = input_grad * activate’(a) # (1, output_dim)b_grad = a_grad # (1, output_dim)W_grad = (input_data.T) · a_grad # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T # (1, input_dim)
输出层的反向传播略有不同,因为在分类任务中输出层若用到softmax激活函数,
到
不是逐个对应的,如下图所示,因此
中的element-wise相乘是失效的,需要用
乘以向量
到向量
的向量梯度(雅可比矩阵)。
但实际上,经过看上去复杂的计算后输出层
会计算出一个非常简洁的结果:
以分类任务为例(交叉熵损失、softmax、训练标签
为one-hot向量其中第
维为1):
)
)
以回归任务为例(二次损失、线性激活、训练标签
为实数向量):
因此输出层反向传播的公式和代码可以写成如下所示:
* 表示element-wise乘积,· 表示矩阵乘积
class Output_layer(Layer):
‘’‘属性和forward方法继承Layer类’‘’
def backward(input_grad):'''输出层backward方法''''''单个样本的反向传播'''a_grad = input_grad # (1, output_dim)b_grad = a_grad # (1, output_dim)W_grad = (input_data.T) · a_grad # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T # (1, input_dim)
- Batch 批量计算
除非用随机梯度下降,否则每次用以训练的样本都是整个batch计算的,损失函数
则是整个batch中样本得到损失的均值。
在计算中会以向量化的方式增加运算效率,用batch_size表示批的规模,代码可更改为:
* 表示element-wise乘积,· 表示矩阵乘积
class Layer:
‘’‘中间层类’‘’
def forward(self, input_data): # input_data: (batch_size, input_dim)'''batch_size个样本的前向传播'''input_data · self.W + self.b = a # a: (1, output_dim)h = self.activate(a) # h: (1, output_dim)return hdef backward(input_grad): # input_grad: (batch_size, output_dim)'''batch_size个样本的反向传播'''a_grad = input_grad * activate’(a) # (batch_size, output_dim)b_grad = a_grad.mean(axis=0) # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= lr * b_gradself.W -= lr * W_gradreturn a_grad · (self.W).T # output_grad: (batch_size, input_dim)
class Output_layer(Layer):
‘’‘输出层类:属性和forward方法继承Layer类’‘’
def backward(input_grad): # input_grad: (batch_size, output_dim)'''输出层backward方法''''''batch_size个样本的反向传播'''a_grad = input_grad # (batch_size, output_dim)b_grad = a_grad.mean(axis=0) # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T # output_grad: (batch_size, input_dim)
这里比较易错的地方是什么时候求均值,对
求均值还是对
求均值:梯度在中间结果
上都不需要求均值,对参数
的梯度时才需要求均值。
- 代码
https://github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb
github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb
模拟一个三层神经网络的训练
相关文章:
反向传播详解BP
误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。 BP本来只指损失…...
2023.11.16-hive sql高阶函数lateral view,与行转列,列转行
目录 0.lateral view简介 1.行转列 需求1: 需求2: 2.列转行 解题思路: 0.lateral view简介 hive函数 lateral view 主要功能是将原本汇总在一条(行)的数据拆分成多条(行)成虚拟表,再与原表进行笛卡尔积,…...
解决Jetson Xavier NX上Invalid CUDA ‘--device 0‘ requested等问题
解决Jetson Xavier NX上Invalid CUDA --device 0 requested等问题 问题1:AssertionError: Invalid CUDA --device 0 requested, use --device cpu or pass valid CUDA device(s)问题2: “Illegal instruction(cpre dumped)”错误记录python http局域网文…...
git push 报错 The requested URL returned error: 500
今天gitpush时报错The requested URL returned error: 500 看报错应该是本地和gitlab服务器之间通信的问题,登录gitlab网站查看 登录时报错无法通过ldapadmin认证,ldap服务器连接失败。 首先,登录ldap服务器,查看是否是ldap服务…...
应用软件安全编程--17预防基于 DOM 的 XSS
DOM型XSS从效果上来说也属于反射型XSS,由于形成的原因比较特殊所以进行单独划分。在网站页面中有许多页面的元素,当页面到达浏览器时浏览器会为页面创建一个顶级的Document object 文档对象,接着生成各个子文档对象,每个页面元素对应一个文档…...
【FastCAE源码阅读9】鼠标框选网格、节点的实现
一、VTK的框选支持类vtkInteractorStyleRubberBandPick FastCAE的鼠标事件交互类是PropPickerInteractionStyle,它扩展自vtkInteractorStyleRubberBandPick。vtkInteractorStyleRubberBandPick类可以实现鼠标框选物体,默认情况下按下键盘r键开启框选模式…...
【ArcGIS处理】行政区划与流域区划间转化
【ArcGIS处理】行政区划与流域区划间转化 引言数据准备1、行政区划数据2、流域区划数据 ArcGIS详细处理步骤Step1:统计行政区划下子流域面积1、创建批量处理模型2、添加批量裁剪处理3、添加计算面积 Step2:根据子流域面积占比均化得到各行政区固定值 参考…...
Session、Token、Jwt三种登录方案介绍
新开发一个应用首先要考虑的就是登录怎么去做,登录本身就是判断一下输入的用户名和密码与系统存储的是否一致,但因为Http是无状态协议,用户请求其它接口时是怎么判断该用户已经登录了呢?下面聊一个三种实现方案。 一、传统sessio…...
Linux操作系统使用及C高级编程-D5Linux shell命令(进程管理、用户管理)
进程管理 查看进程ps 其中ps -eif可显示父进程 实时查看进程top 按q退出 树状图显示进程pstree 以父进程,子进程以树状形式展示 发送信号kill kill -l:查看都有哪些信号 9:进程终止 kill不指定信号,默认发送的是15信号SIGT…...
【TDSQL-PG数据库简单介绍】
TDSQL-PG数据库简单介绍 TDSQL-PGTDSQL-PG 设计目标 TDSQL-PG 腾讯 TDSQL-PG 分布式关系型数据库是一款面向海量在线实时分布式事务交易和 MPP 实时数据分析 通用型高性能数据库系统。 面对应用业务产生的不定性数据爆炸需求,不管是高并发的交易还是海量的实时数据…...
【文件包含】metinfo 5.0.4 文件包含漏洞复现
1.1漏洞描述 漏洞编号————漏洞类型文件包含漏洞等级⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐漏洞环境windows攻击方式 MetInfo 是一套使用PHP 和MySQL 开发的内容管理系统。MetInfo 5.0.4 版本中的 /metinfo_5.0.4/about/index.php?fmodule文件存在任意文件包含漏洞。攻击者可利用漏洞读取网…...
差分信号的末端并联电容到底有什么作用?
差分信号的末端并联电容到底有什么作用? 在现代电子系统中,差分信号是一种常见的信号形式,它们通过两根互补的信号线传输信号,具有较低的噪声和更高的抗干扰能力。然而,当差分信号线长度较长或者遇到复杂的电路环境时&…...
pandas教程:GroupBy Mechanics 分组机制
文章目录 Chapter 10 Data Aggregation and Group Operations(数据汇总和组操作)10.1 GroupBy Mechanics(分组机制)1 Iterating Over Groups(对组进行迭代)2 Selecting a Column or Subset of Columns (选中…...
通过右键用WebStorm、Idea打开某个文件夹或者在某一文件夹下右键打开当前文件夹用上述两个应用
通过右键用WebStorm、Idea打开某个文件夹或者在某一文件夹下右键打开当前文件夹用上述两个应用 通过右键点击某个文件夹用Idea打开 首先打开注册表 win R 输入 regedit 然后找到HKEY_CLASSES_ROOT\Directory\shell 然后右键shell 新建一个项名字就叫 Idea 第一步…...
Android 10.0 framework层设置后台运行app进程最大数功能实现
1. 前言 在10.0的定制开发中,在系统中,对于后台运行的app过多的时候,会比较耗内存,导致系统运行有可能会卡顿,所以在系统优化的 过程中,会限制后台app进程运行的数量,来保证系统流畅不影响体验,所以需要分析下系统中关于限制app进程的相关源码来实现 功能 2.framewo…...
如何快速找到华为手机中下载的文档
手机的目录设置比较繁杂,尤其是查找刚刚下载的文件,有时候需要捣鼓半天,如何快速找到这些文件呢?以下提供了几种方法: 方法一: 文件管理-》搜索文档 方法二: 文件管理-》最近 方法三…...
iceoryx(冰羚)-Architecture
Architecture 本文概述了Eclipseiceoryx体系结构,并解释了它的基本原理。 Software layers Eclipse iceoryx所包含的主要包如下所示。 接下来的部分将逐一简要介绍组件及其库。 Components and libraries 下面描述了不同的库及其名称空间。 ### iceoryx hoofs …...
LeetCode2-两数相加
大佬解法 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode(int x) { val x; }* }*/ class Solution {public ListNode addTwoNumbers(ListNode l1, ListNode l2) {ListNode pre new ListNode(0);ListNo…...
css 灰质彩色的边框
border: 4px solid transparent; background-color:#fff; background-clip: padding-box,border-box; background-origin:padding-box, border-box; background-image: linear-gradient(90deg,#F5F6FA,#F5F6FA 42%,#F5F6FA),linear-gradient(151deg,#33e9bf,#c7e58a,#b1e8cc);...
OpenCV实现手势音量控制
前言: Hello大家好,我是Dream。 今天来学习一下如何使用OpenCV实现手势音量控制,欢迎大家一起前来探讨学习~ 一、需要的库及功能介绍 本次实验需要使用OpenCV和mediapipe库进行手势识别,并利用手势距离控制电脑音量。 导入库&am…...
椭圆曲线密码学(ECC)
一、ECC算法概述 椭圆曲线密码学(Elliptic Curve Cryptography)是基于椭圆曲线数学理论的公钥密码系统,由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA,ECC在相同安全强度下密钥更短(256位ECC ≈ 3072位RSA…...
golang循环变量捕获问题
在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - 循环变量捕获问题。让我详细解释一下: 问题背景 看这个代码片段: fo…...
React第五十七节 Router中RouterProvider使用详解及注意事项
前言 在 React Router v6.4 中,RouterProvider 是一个核心组件,用于提供基于数据路由(data routers)的新型路由方案。 它替代了传统的 <BrowserRouter>,支持更强大的数据加载和操作功能(如 loader 和…...
五年级数学知识边界总结思考-下册
目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解:由来、作用与意义**一、知识点核心内容****二、知识点的由来:从生活实践到数学抽象****三、知识的作用:解决实际问题的工具****四、学习的意义:培养核心素养…...
Caliper 配置文件解析:config.yaml
Caliper 是一个区块链性能基准测试工具,用于评估不同区块链平台的性能。下面我将详细解释你提供的 fisco-bcos.json 文件结构,并说明它与 config.yaml 文件的关系。 fisco-bcos.json 文件解析 这个文件是针对 FISCO-BCOS 区块链网络的 Caliper 配置文件,主要包含以下几个部…...
在WSL2的Ubuntu镜像中安装Docker
Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包: for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...
Java面试专项一-准备篇
一、企业简历筛选规则 一般企业的简历筛选流程:首先由HR先筛选一部分简历后,在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如:Boss直聘(招聘方平台) 直接按照条件进行筛选 例如:…...
招商蛇口 | 执笔CID,启幕低密生活新境
作为中国城市生长的力量,招商蛇口以“美好生活承载者”为使命,深耕全球111座城市,以央企担当匠造时代理想人居。从深圳湾的开拓基因到西安高新CID的战略落子,招商蛇口始终与城市发展同频共振,以建筑诠释对土地与生活的…...
作为测试我们应该关注redis哪些方面
1、功能测试 数据结构操作:验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化:测试aof和aof持久化机制,确保数据在开启后正确恢复。 事务:检查事务的原子性和回滚机制。 发布订阅:确保消息正确传递。 2、性…...
Unity中的transform.up
2025年6月8日,周日下午 在Unity中,transform.up是Transform组件的一个属性,表示游戏对象在世界空间中的“上”方向(Y轴正方向),且会随对象旋转动态变化。以下是关键点解析: 基本定义 transfor…...
