反向传播详解BP
误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。
BP本来只指损失函数对参数的梯度通过网络反向流动的过程,但现在也常被理解成神经网络整个的训练方法,由误差传播、参数更新两个环节循环迭代组成。
本文将以最基础的全连接深度前馈网络为例,详细展示Back-propagation的全过程,并以Numpy进行实现。
图1 神经元为层/权重为层
通常我们以神经元来计量“层”,但本文将权重抽象为“层”,个人认为这样更有助于反向传播的理解和代码的编写。如上图所示的网络就被抽象为两个中间层、一个输出层的结构。
简而言之,神经网络的训练过程中,前向传播和反向传播交替进行,如下图所示:前向传播通过训练数据和权重参数计算输出结果;反向传播通过导数链式法则计算损失函数对各参数的梯度,并根据梯度进行参数的更新,这一点是重点,会在后文详叙。
图2 前向传播&反向传播
前向传播
每层中前向传播的过程如下所示,很简单的矩阵运算。我们将权重作为层,中间层和输出层均可用Layer类来表示,只是对应的激活函数不同。如图2所示,每一层的输入和输出都是
,且前一层的输出是后一层的输入。
* 表示element-wise乘积,· 表示矩阵乘积
class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)
def forward(self, input_data): # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a # a: (1, output_dim)h = self.activate(a) # h: (1, output_dim)return h
- 反向传播
损失对参数梯度的反向传播可以被这样直观解释:由A到传播B,即由
得到
,由导数链式法则
实现。所以神经网络的BP就是通过链式法则求出
对所有参数梯度的过程。
如上图示例,输入
,经过网络的参数
,得到一系列中间结果
。
表示通过权重和偏置的结果,还未经过激活函数,
表示经过激活函数后的结果。灰色框内表示
对各中间计算结果的梯度,这些梯度的反向传播有两类:
由
到
,通过激活函数,如右上角
由
到
,通过权重,如橙线部分
可以看出梯度的传播和前向传播的模式是一致的,只是方向不同。
计算完了灰色框的部分(损失对中间结果
的梯度),损失对参数
的梯度也就显而易见了,以图中红色的
和
为例:
因此,我们可以如图2,将反向传播的表达式和代码如下。
注意代码和公式中
表示element-wise乘积,
表示矩阵乘积。
* 表示element-wise乘积,· 表示矩阵乘积
class Layer:
‘’‘中间层类’‘’
self.W # (input_dim, output_dim)
self.b # (1, output_dim)
self.activate(a) = sigmoid(a)/tanh(a)/ReLU(a)/Softmax(a)
def forward(self, input_data): # input_data: (1, input_dim)'''单个样本的前向传播'''input_data · self.W + self.b = a # a: (1, output_dim)h = self.activate(a) # h: (1, output_dim)return hdef backward(input_grad):'''单个样本的反向传播'''a_grad = input_grad * activate’(a) # (1, output_dim)b_grad = a_grad # (1, output_dim)W_grad = (input_data.T) · a_grad # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T # (1, input_dim)
输出层的反向传播略有不同,因为在分类任务中输出层若用到softmax激活函数,
到
不是逐个对应的,如下图所示,因此
中的element-wise相乘是失效的,需要用
乘以向量
到向量
的向量梯度(雅可比矩阵)。
但实际上,经过看上去复杂的计算后输出层
会计算出一个非常简洁的结果:
以分类任务为例(交叉熵损失、softmax、训练标签
为one-hot向量其中第
维为1):
)
)
以回归任务为例(二次损失、线性激活、训练标签
为实数向量):
因此输出层反向传播的公式和代码可以写成如下所示:
* 表示element-wise乘积,· 表示矩阵乘积
class Output_layer(Layer):
‘’‘属性和forward方法继承Layer类’‘’
def backward(input_grad):'''输出层backward方法''''''单个样本的反向传播'''a_grad = input_grad # (1, output_dim)b_grad = a_grad # (1, output_dim)W_grad = (input_data.T) · a_grad # (input_dim, output_dim)self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T # (1, input_dim)
- Batch 批量计算
除非用随机梯度下降,否则每次用以训练的样本都是整个batch计算的,损失函数
则是整个batch中样本得到损失的均值。
在计算中会以向量化的方式增加运算效率,用batch_size表示批的规模,代码可更改为:
* 表示element-wise乘积,· 表示矩阵乘积
class Layer:
‘’‘中间层类’‘’
def forward(self, input_data): # input_data: (batch_size, input_dim)'''batch_size个样本的前向传播'''input_data · self.W + self.b = a # a: (1, output_dim)h = self.activate(a) # h: (1, output_dim)return hdef backward(input_grad): # input_grad: (batch_size, output_dim)'''batch_size个样本的反向传播'''a_grad = input_grad * activate’(a) # (batch_size, output_dim)b_grad = a_grad.mean(axis=0) # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= lr * b_gradself.W -= lr * W_gradreturn a_grad · (self.W).T # output_grad: (batch_size, input_dim)
class Output_layer(Layer):
‘’‘输出层类:属性和forward方法继承Layer类’‘’
def backward(input_grad): # input_grad: (batch_size, output_dim)'''输出层backward方法''''''batch_size个样本的反向传播'''a_grad = input_grad # (batch_size, output_dim)b_grad = a_grad.mean(axis=0) # (1, output_dim)W_grad = (a_grad.reshape(batch_size,1,output_dim) * input_data.reshape(batch_size,input_dim,1)).mean(axis=0)# (input_dim, output_dim) self.b -= learning_rate * b_grad self.W -= learning_rate * W_gradreturn a_grad · (self.W).T # output_grad: (batch_size, input_dim)
这里比较易错的地方是什么时候求均值,对
求均值还是对
求均值:梯度在中间结果
上都不需要求均值,对参数
的梯度时才需要求均值。
- 代码
https://github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb
github.com/qcneverrepeat/ML01/blob/master/BP_DNN.ipynb
模拟一个三层神经网络的训练
相关文章:

反向传播详解BP
误差反向传播(Back-propagation, BP)算法的出现是神经网络发展的重大突破,也是现在众多深度学习训练方法的基础。该方法会计算神经网络中损失函数对各参数的梯度,配合优化方法更新参数,降低损失函数。 BP本来只指损失…...

2023.11.16-hive sql高阶函数lateral view,与行转列,列转行
目录 0.lateral view简介 1.行转列 需求1: 需求2: 2.列转行 解题思路: 0.lateral view简介 hive函数 lateral view 主要功能是将原本汇总在一条(行)的数据拆分成多条(行)成虚拟表,再与原表进行笛卡尔积,…...

解决Jetson Xavier NX上Invalid CUDA ‘--device 0‘ requested等问题
解决Jetson Xavier NX上Invalid CUDA --device 0 requested等问题 问题1:AssertionError: Invalid CUDA --device 0 requested, use --device cpu or pass valid CUDA device(s)问题2: “Illegal instruction(cpre dumped)”错误记录python http局域网文…...

git push 报错 The requested URL returned error: 500
今天gitpush时报错The requested URL returned error: 500 看报错应该是本地和gitlab服务器之间通信的问题,登录gitlab网站查看 登录时报错无法通过ldapadmin认证,ldap服务器连接失败。 首先,登录ldap服务器,查看是否是ldap服务…...

应用软件安全编程--17预防基于 DOM 的 XSS
DOM型XSS从效果上来说也属于反射型XSS,由于形成的原因比较特殊所以进行单独划分。在网站页面中有许多页面的元素,当页面到达浏览器时浏览器会为页面创建一个顶级的Document object 文档对象,接着生成各个子文档对象,每个页面元素对应一个文档…...

【FastCAE源码阅读9】鼠标框选网格、节点的实现
一、VTK的框选支持类vtkInteractorStyleRubberBandPick FastCAE的鼠标事件交互类是PropPickerInteractionStyle,它扩展自vtkInteractorStyleRubberBandPick。vtkInteractorStyleRubberBandPick类可以实现鼠标框选物体,默认情况下按下键盘r键开启框选模式…...

【ArcGIS处理】行政区划与流域区划间转化
【ArcGIS处理】行政区划与流域区划间转化 引言数据准备1、行政区划数据2、流域区划数据 ArcGIS详细处理步骤Step1:统计行政区划下子流域面积1、创建批量处理模型2、添加批量裁剪处理3、添加计算面积 Step2:根据子流域面积占比均化得到各行政区固定值 参考…...

Session、Token、Jwt三种登录方案介绍
新开发一个应用首先要考虑的就是登录怎么去做,登录本身就是判断一下输入的用户名和密码与系统存储的是否一致,但因为Http是无状态协议,用户请求其它接口时是怎么判断该用户已经登录了呢?下面聊一个三种实现方案。 一、传统sessio…...

Linux操作系统使用及C高级编程-D5Linux shell命令(进程管理、用户管理)
进程管理 查看进程ps 其中ps -eif可显示父进程 实时查看进程top 按q退出 树状图显示进程pstree 以父进程,子进程以树状形式展示 发送信号kill kill -l:查看都有哪些信号 9:进程终止 kill不指定信号,默认发送的是15信号SIGT…...

【TDSQL-PG数据库简单介绍】
TDSQL-PG数据库简单介绍 TDSQL-PGTDSQL-PG 设计目标 TDSQL-PG 腾讯 TDSQL-PG 分布式关系型数据库是一款面向海量在线实时分布式事务交易和 MPP 实时数据分析 通用型高性能数据库系统。 面对应用业务产生的不定性数据爆炸需求,不管是高并发的交易还是海量的实时数据…...

【文件包含】metinfo 5.0.4 文件包含漏洞复现
1.1漏洞描述 漏洞编号————漏洞类型文件包含漏洞等级⭐⭐⭐⭐⭐⭐⭐⭐⭐⭐漏洞环境windows攻击方式 MetInfo 是一套使用PHP 和MySQL 开发的内容管理系统。MetInfo 5.0.4 版本中的 /metinfo_5.0.4/about/index.php?fmodule文件存在任意文件包含漏洞。攻击者可利用漏洞读取网…...

差分信号的末端并联电容到底有什么作用?
差分信号的末端并联电容到底有什么作用? 在现代电子系统中,差分信号是一种常见的信号形式,它们通过两根互补的信号线传输信号,具有较低的噪声和更高的抗干扰能力。然而,当差分信号线长度较长或者遇到复杂的电路环境时&…...

pandas教程:GroupBy Mechanics 分组机制
文章目录 Chapter 10 Data Aggregation and Group Operations(数据汇总和组操作)10.1 GroupBy Mechanics(分组机制)1 Iterating Over Groups(对组进行迭代)2 Selecting a Column or Subset of Columns (选中…...

通过右键用WebStorm、Idea打开某个文件夹或者在某一文件夹下右键打开当前文件夹用上述两个应用
通过右键用WebStorm、Idea打开某个文件夹或者在某一文件夹下右键打开当前文件夹用上述两个应用 通过右键点击某个文件夹用Idea打开 首先打开注册表 win R 输入 regedit 然后找到HKEY_CLASSES_ROOT\Directory\shell 然后右键shell 新建一个项名字就叫 Idea 第一步…...

Android 10.0 framework层设置后台运行app进程最大数功能实现
1. 前言 在10.0的定制开发中,在系统中,对于后台运行的app过多的时候,会比较耗内存,导致系统运行有可能会卡顿,所以在系统优化的 过程中,会限制后台app进程运行的数量,来保证系统流畅不影响体验,所以需要分析下系统中关于限制app进程的相关源码来实现 功能 2.framewo…...

如何快速找到华为手机中下载的文档
手机的目录设置比较繁杂,尤其是查找刚刚下载的文件,有时候需要捣鼓半天,如何快速找到这些文件呢?以下提供了几种方法: 方法一: 文件管理-》搜索文档 方法二: 文件管理-》最近 方法三…...

iceoryx(冰羚)-Architecture
Architecture 本文概述了Eclipseiceoryx体系结构,并解释了它的基本原理。 Software layers Eclipse iceoryx所包含的主要包如下所示。 接下来的部分将逐一简要介绍组件及其库。 Components and libraries 下面描述了不同的库及其名称空间。 ### iceoryx hoofs …...

LeetCode2-两数相加
大佬解法 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode(int x) { val x; }* }*/ class Solution {public ListNode addTwoNumbers(ListNode l1, ListNode l2) {ListNode pre new ListNode(0);ListNo…...

css 灰质彩色的边框
border: 4px solid transparent; background-color:#fff; background-clip: padding-box,border-box; background-origin:padding-box, border-box; background-image: linear-gradient(90deg,#F5F6FA,#F5F6FA 42%,#F5F6FA),linear-gradient(151deg,#33e9bf,#c7e58a,#b1e8cc);...

OpenCV实现手势音量控制
前言: Hello大家好,我是Dream。 今天来学习一下如何使用OpenCV实现手势音量控制,欢迎大家一起前来探讨学习~ 一、需要的库及功能介绍 本次实验需要使用OpenCV和mediapipe库进行手势识别,并利用手势距离控制电脑音量。 导入库&am…...

pytorch 深度学习之余弦相似度
文章目录 用处定理代码F.normalize() 和 F.norm() 的区别 用处 此方法特别重要,经常可以用来修改论文,提出创新点. 定理 余弦相似度是通过计算两个向量之间的夹角余弦值来衡量它们的相似性。给定两个非零向量 x 和 y,它们之间的余弦相似度…...

Postman的常规断言/动态参数断言/全局断言
近期在复习Postman的基础知识,在小破站上跟着百里老师系统复习了一遍,也做了一些笔记,希望可以给大家一点点启发。 断言,包括状态码断言和业务断言,状态码断言有一个,业务断言有多个。 一)常规的…...

ruoyi若依前端请求接口超时,增加响应时长
问题: 前端查询请求超时 解决: 找到request.js的timeout属性由10秒改成了20秒,因为默认是10秒,请求肯定是超出了10秒 祝您万事顺心,没事点个赞呗,关注一下也行啊,有啥要求您评论哈...

贪吃蛇小游戏
一. 准备工作 首先获取贪吃蛇小游戏所需要的头部、身体、食物以及贪吃蛇标题等图片。、 然后,创建贪吃蛇游戏的Java项目命名为snake_game,并在这个项目里创建一个文件夹命名为images,将图片素材导入文件夹。 再在src文件下创建两个包&#…...

cocos----1
1 前言 刚体(Rigidbody)是运动学(Kinematic)中的一个概念,指在运动中和受力作用后,形状和大小不变,而且内部各点的相对位置不变的物体。在 Unity3D 中,刚体组件赋予了游戏对…...

第十九章绘图
Java绘图类 Graphics 类 Grapics 类是所有图形上下文的抽象基类,它允许应用程序在组件以及闭屏图像上进行绘制。Graphics 类封装了Java 支持的基本绘图操作所需的状态信息,主要包括颜色、字体、画笔、文本、图像等。 Graphics 类提供了绘图常用的…...

rpmbuild 包名 version 操作系统信息部分来源 /etc/rpm/macros.dist
/etc/rpm/macros.dist openeuler bclinux src.rpm openssl-1.1.1f-13.oe1.src.rpm 打包名称结果 openeuler openssl-1.1.1f-13.aarch64.rpm bclinux openssl-1.1.1f-13.oe1.bclinux.aarch64.rpm 验证 修改openeuler配置文件macros.dist 重新在openeuler上执行rpmbuild…...

【Linux专题】SFTP 用户配置 ChrootDirectory
【赠送】IT技术视频教程,白拿不谢!思科、华为、红帽、数据库、云计算等等https://xmws-it.blog.csdn.net/article/details/117297837?spm1001.2014.3001.5502 红帽认证 认证课程介绍:红帽RHCE9.0学什么内容,新版有什么变化-CSDN…...

openssl+ DES开发实例(Linux)
文章目录 一、DES介绍二、DES原理三、DES C实现源码 一、DES介绍 DES(Data Encryption Standard)是一种对称密钥加密算法,最初由 IBM 设计,于1977年成为美国国家标准,用于加密非机密但敏感的政府数据。DES 使用相同的…...

结构体几种实用的用法
结构体的初始化 结构体的初始化是指在声明结构体变量时,为其成员变量赋初值。结构体的初始化可以通过以下几种方式实现: 1. 在声明结构体变量的同时进行初始化: struct Student { char name[20]; int age; float score; } student {…...