反向传播详解BP
误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。
BP本来只指损失函数对参数的梯度通过网络反向流动的过程,但现在也常被理解成神经网络整个的训练方法,由误差传播、参数更新两个环节循环迭代组成。
本文将以最基础的全连接深度前馈网络为例,详细展示Back-propagation的全过程,并以Numpy进行实现。
图1 神经元为层/权重为层
通常我们以神经元来计量“层”,但本文将权重抽象为“层”,个人认为这样更有助于反向传播的理解和代码的编写。如上图所示的网络就被抽象为两个中间层、一个输出层的结构。
简而言之,神经网络的训练过程中,前向传播和反向传播交替进行,如下图所示:前向传播通过训练数据和权重参数计算输出结果;反向传播通过导数链式法则计算损失函数对各参数的梯度,并根据梯度进行参数的更新,这一点是重点,会在后文详叙。
图2 前向传播&反向传播
前向传播
每层中前向传播的过程如下所示,很简单的矩阵运算。我们将权重作为层,中间层和输出层均可用Layer类来表示,只是对应的激活函数不同。如图2所示,每一层的输入和输出都是
,且前一层的输出是后一层的输入。
* 表示element-wise乘积,· 表示矩阵乘积
class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)
def forward(self, input_data): # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a # a: (1, output_dim)h = self.activate(a) # h: (1, output_dim)return h
- 反向传播
损失对参数梯度的反向传播可以被这样直观解释:由A到传播B,即由
得到
,由导数链式法则
实现。所以神经网络的BP就是通过链式法则求出
对所有参数梯度的过程。
如上图示例,输入
,经过网络的参数
,得到一系列中间结果
。
表示通过权重和偏置的结果,还未经过激活函数,
表示经过激活函数后的结果。灰色框内表示
对各中间计算结果的梯度,这些梯度的反向传播有两类:
由
到
,通过激活函数,如右上角
由
到
,通过权重,如橙线部分
可以看出梯度的传播和前向传播的模式是一致的,只是方向不同。
计算完了灰色框的部分(损失对中间结果
的梯度),损失对参数
的梯度也就显而易见了,以图中红色的
和
为例:
因此,我们可以如图2,将反向传播的表达式和代码如下。
注意代码和公式中
表示element-wise乘积,
表示矩阵乘积。
* 表示element-wise乘积,· 表示矩阵乘积
class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)
def forward(self, input_data): # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a # a: (1, output_dim)h = self.activate(a) # h: (1, output_dim)return hdef backward(input_grad):'''单个样本的反向传播'''a_grad = input_grad * activate’(a) # (1, output_dim)b_grad = a_grad # (1, output_dim)W_grad = (input_data.T) · a_grad # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T # (1, input_dim)
输出层的反向传播略有不同,因为在分类任务中输出层若用到softmax激活函数,
到
不是逐个对应的,如下图所示,因此
中的element-wise相乘是失效的,需要用
乘以向量
到向量
的向量梯度(雅可比矩阵)。
但实际上,经过看上去复杂的计算后输出层
会计算出一个非常简洁的结果:
以分类任务为例(交叉熵损失、softmax、训练标签
为one-hot向量其中第
维为1):
)
)
以回归任务为例(二次损失、线性激活、训练标签
为实数向量):
因此输出层反向传播的公式和代码可以写成如下所示:
* 表示element-wise乘积,· 表示矩阵乘积
class Output_layer(Layer):
‘’‘属性和forward方法继承Layer类’‘’
def backward(input_grad):'''输出层backward方法''''''单个样本的反向传播'''a_grad = input_grad # (1, output_dim)b_grad = a_grad # (1, output_dim)W_grad = (input_data.T) · a_grad # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T # (1, input_dim)
- Batch 批量计算
除非用随机梯度下降,否则每次用以训练的样本都是整个batch计算的,损失函数
则是整个batch中样本得到损失的均值。
在计算中会以向量化的方式增加运算效率,用batch_size表示批的规模,代码可更改为:
* 表示element-wise乘积,· 表示矩阵乘积
class Layer:
‘’‘中间层类’‘’
def forward(self, input_data): # input_data: (batch_size, input_dim)'''batch_size个样本的前向传播'''input_data · self.W + self.b = a # a: (1, output_dim)h = self.activate(a) # h: (1, output_dim)return hdef backward(input_grad): # input_grad: (batch_size, output_dim)'''batch_size个样本的反向传播'''a_grad = input_grad * activate’(a) # (batch_size, output_dim)b_grad = a_grad.mean(axis=0) # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= lr * b_gradself.W -= lr * W_gradreturn a_grad · (self.W).T # output_grad: (batch_size, input_dim)
class Output_layer(Layer):
‘’‘输出层类:属性和forward方法继承Layer类’‘’
def backward(input_grad): # input_grad: (batch_size, output_dim)'''输出层backward方法''''''batch_size个样本的反向传播'''a_grad = input_grad # (batch_size, output_dim)b_grad = a_grad.mean(axis=0) # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T # output_grad: (batch_size, input_dim)
这里比较易错的地方是什么时候求均值,对
求均值还是对
求均值:梯度在中间结果
上都不需要求均值,对参数
的梯度时才需要求均值。
- 代码
https://github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb
github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb
模拟一个三层神经网络的训练
相关文章:
反向传播详解BP
误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。 BP本来只指损失…...

2023.11.16-hive sql高阶函数lateral view,与行转列,列转行
目录 0.lateral view简介 1.行转列 需求1: 需求2: 2.列转行 解题思路: 0.lateral view简介 hive函数 lateral view 主要功能是将原本汇总在一条(行)的数据拆分成多条(行)成虚拟表,再与原表进行笛卡尔积,…...
解决Jetson Xavier NX上Invalid CUDA ‘--device 0‘ requested等问题
解决Jetson Xavier NX上Invalid CUDA --device 0 requested等问题 问题1:AssertionError: Invalid CUDA --device 0 requested, use --device cpu or pass valid CUDA device(s)问题2: “Illegal instruction(cpre dumped)”错误记录python http局域网文…...

git push 报错 The requested URL returned error: 500
今天gitpush时报错The requested URL returned error: 500 看报错应该是本地和gitlab服务器之间通信的问题,登录gitlab网站查看 登录时报错无法通过ldapadmin认证,ldap服务器连接失败。 首先,登录ldap服务器,查看是否是ldap服务…...
应用软件安全编程--17预防基于 DOM 的 XSS
DOM型XSS从效果上来说也属于反射型XSS,由于形成的原因比较特殊所以进行单独划分。在网站页面中有许多页面的元素,当页面到达浏览器时浏览器会为页面创建一个顶级的Document object 文档对象,接着生成各个子文档对象,每个页面元素对应一个文档…...

【FastCAE源码阅读9】鼠标框选网格、节点的实现
一、VTK的框选支持类vtkInteractorStyleRubberBandPick FastCAE的鼠标事件交互类是PropPickerInteractionStyle,它扩展自vtkInteractorStyleRubberBandPick。vtkInteractorStyleRubberBandPick类可以实现鼠标框选物体,默认情况下按下键盘r键开启框选模式…...

【ArcGIS处理】行政区划与流域区划间转化
【ArcGIS处理】行政区划与流域区划间转化 引言数据准备1、行政区划数据2、流域区划数据 ArcGIS详细处理步骤Step1:统计行政区划下子流域面积1、创建批量处理模型2、添加批量裁剪处理3、添加计算面积 Step2:根据子流域面积占比均化得到各行政区固定值 参考…...

Session、Token、Jwt三种登录方案介绍
新开发一个应用首先要考虑的就是登录怎么去做,登录本身就是判断一下输入的用户名和密码与系统存储的是否一致,但因为Http是无状态协议,用户请求其它接口时是怎么判断该用户已经登录了呢?下面聊一个三种实现方案。 一、传统sessio…...

Linux操作系统使用及C高级编程-D5Linux shell命令(进程管理、用户管理)
进程管理 查看进程ps 其中ps -eif可显示父进程 实时查看进程top 按q退出 树状图显示进程pstree 以父进程,子进程以树状形式展示 发送信号kill kill -l:查看都有哪些信号 9:进程终止 kill不指定信号,默认发送的是15信号SIGT…...
【TDSQL-PG数据库简单介绍】
TDSQL-PG数据库简单介绍 TDSQL-PGTDSQL-PG 设计目标 TDSQL-PG 腾讯 TDSQL-PG 分布式关系型数据库是一款面向海量在线实时分布式事务交易和 MPP 实时数据分析 通用型高性能数据库系统。 面对应用业务产生的不定性数据爆炸需求,不管是高并发的交易还是海量的实时数据…...

【文件包含】metinfo 5.0.4 文件包含漏洞复现
1.1漏洞描述 漏洞编号————漏洞类型文件包含漏洞等级⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐漏洞环境windows攻击方式 MetInfo 是一套使用PHP 和MySQL 开发的内容管理系统。MetInfo 5.0.4 版本中的 /metinfo_5.0.4/about/index.php?fmodule文件存在任意文件包含漏洞。攻击者可利用漏洞读取网…...

差分信号的末端并联电容到底有什么作用?
差分信号的末端并联电容到底有什么作用? 在现代电子系统中,差分信号是一种常见的信号形式,它们通过两根互补的信号线传输信号,具有较低的噪声和更高的抗干扰能力。然而,当差分信号线长度较长或者遇到复杂的电路环境时&…...
pandas教程:GroupBy Mechanics 分组机制
文章目录 Chapter 10 Data Aggregation and Group Operations(数据汇总和组操作)10.1 GroupBy Mechanics(分组机制)1 Iterating Over Groups(对组进行迭代)2 Selecting a Column or Subset of Columns (选中…...

通过右键用WebStorm、Idea打开某个文件夹或者在某一文件夹下右键打开当前文件夹用上述两个应用
通过右键用WebStorm、Idea打开某个文件夹或者在某一文件夹下右键打开当前文件夹用上述两个应用 通过右键点击某个文件夹用Idea打开 首先打开注册表 win R 输入 regedit 然后找到HKEY_CLASSES_ROOT\Directory\shell 然后右键shell 新建一个项名字就叫 Idea 第一步…...

Android 10.0 framework层设置后台运行app进程最大数功能实现
1. 前言 在10.0的定制开发中,在系统中,对于后台运行的app过多的时候,会比较耗内存,导致系统运行有可能会卡顿,所以在系统优化的 过程中,会限制后台app进程运行的数量,来保证系统流畅不影响体验,所以需要分析下系统中关于限制app进程的相关源码来实现 功能 2.framewo…...

如何快速找到华为手机中下载的文档
手机的目录设置比较繁杂,尤其是查找刚刚下载的文件,有时候需要捣鼓半天,如何快速找到这些文件呢?以下提供了几种方法: 方法一: 文件管理-》搜索文档 方法二: 文件管理-》最近 方法三…...

iceoryx(冰羚)-Architecture
Architecture 本文概述了Eclipseiceoryx体系结构,并解释了它的基本原理。 Software layers Eclipse iceoryx所包含的主要包如下所示。 接下来的部分将逐一简要介绍组件及其库。 Components and libraries 下面描述了不同的库及其名称空间。 ### iceoryx hoofs …...

LeetCode2-两数相加
大佬解法 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode(int x) { val x; }* }*/ class Solution {public ListNode addTwoNumbers(ListNode l1, ListNode l2) {ListNode pre new ListNode(0);ListNo…...
css 灰质彩色的边框
border: 4px solid transparent; background-color:#fff; background-clip: padding-box,border-box; background-origin:padding-box, border-box; background-image: linear-gradient(90deg,#F5F6FA,#F5F6FA 42%,#F5F6FA),linear-gradient(151deg,#33e9bf,#c7e58a,#b1e8cc);...

OpenCV实现手势音量控制
前言: Hello大家好,我是Dream。 今天来学习一下如何使用OpenCV实现手势音量控制,欢迎大家一起前来探讨学习~ 一、需要的库及功能介绍 本次实验需要使用OpenCV和mediapipe库进行手势识别,并利用手势距离控制电脑音量。 导入库&am…...
<6>-MySQL表的增删查改
目录 一,create(创建表) 二,retrieve(查询表) 1,select列 2,where条件 三,update(更新表) 四,delete(删除表…...

.Net框架,除了EF还有很多很多......
文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

23-Oracle 23 ai 区块链表(Blockchain Table)
小伙伴有没有在金融强合规的领域中遇见,必须要保持数据不可变,管理员都无法修改和留痕的要求。比如医疗的电子病历中,影像检查检验结果不可篡改行的,药品追溯过程中数据只可插入无法删除的特性需求;登录日志、修改日志…...
Pinocchio 库详解及其在足式机器人上的应用
Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库,专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性,并提供了一个通用的框架&…...
C++.OpenGL (14/64)多光源(Multiple Lights)
多光源(Multiple Lights) 多光源渲染技术概览 #mermaid-svg-3L5e5gGn76TNh7Lq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-3L5e5gGn76TNh7Lq .error-icon{fill:#552222;}#mermaid-svg-3L5e5gGn76TNh7Lq .erro…...
webpack面试题
面试题:webpack介绍和简单使用 一、webpack(模块化打包工具)1. webpack是把项目当作一个整体,通过给定的一个主文件,webpack将从这个主文件开始找到你项目当中的所有依赖文件,使用loaders来处理它们&#x…...
FOPLP vs CoWoS
以下是 FOPLP(Fan-out panel-level packaging 扇出型面板级封装)与 CoWoS(Chip on Wafer on Substrate)两种先进封装技术的详细对比分析,涵盖技术原理、性能、成本、应用场景及市场趋势等维度: 一、技术原…...
13.10 LangGraph多轮对话系统实战:Ollama私有部署+情感识别优化全解析
LangGraph多轮对话系统实战:Ollama私有部署+情感识别优化全解析 LanguageMentor 对话式训练系统架构与实现 关键词:多轮对话系统设计、场景化提示工程、情感识别优化、LangGraph 状态管理、Ollama 私有化部署 1. 对话训练系统技术架构 采用四层架构实现高扩展性的对话训练…...

MLP实战二:MLP 实现图像数字多分类
任务 实战(二):MLP 实现图像多分类 基于 mnist 数据集,建立 mlp 模型,实现 0-9 数字的十分类 task: 1、实现 mnist 数据载入,可视化图形数字; 2、完成数据预处理:图像数据维度转换与…...

生信服务器 | 做生信为什么推荐使用Linux服务器?
原文链接:生信服务器 | 做生信为什么推荐使用Linux服务器? 一、 做生信为什么推荐使用服务器? 大家好,我是小杜。在做生信分析的同学,或是将接触学习生信分析的同学,<font style"color:rgb(53, 1…...