BatchNormalization
目录
Covariate Shift
Internal Covariate Shift
BatchNormalization
Q1:BN的原理
Q2:BN的作用
Q3:BN的缺陷
Q4:BN的均值、方差的计算维度
Q5:BN在训练和测试时有什么区别
Q6:BN的代码实现
Covariate Shift
机器学习中,一般会假设模型的输入数据分布时稳定的。如果这个假设不成立,即模型的输入数据的分布发生变化,则称为协变量偏移(Covariate Shift).例如:模型的训练数据和测试数据分布不一致;模型在训练过程中输入数据发生变化。
Internal Covariate Shift
在深层网络训练的过程中,由于网络中参数变化而引起内部节点数据分布发生变化这一过程被称作Internal Covariate Shift.深度神经网络涉及到很多层的叠加,而每一层的参数更新会导致高层的输入数据分布发生变化,通过层层叠加,高层的输入分布变化会非常剧烈,这就使得高层需要不断地去重新适应底层的参数更新。
Covariate Shift时模型的输入数据的分布发生变化,Internal Covariate Shift时网络内部的节点的输入数据分布发生变化。
ICS带来的问题
- 高层网络需要不断调整来适应输入数据分布的变化,导致网络学习速度的降低,使得学习的过程变得很不稳定
- 网络前几层参数的更新,很可能使得后面层输入数据变得过大或者过小,从而陷入梯度饱和区域(比如sigmoid的饱和区),减缓网络收敛速度
在各种Normalization提出之前,解决上述问题的方法是使用较小的学习率(避免参数更新太快),精细的参数初始化,训练时间会很长。
BatchNormalization
BN的提出就是为了解决Interval Covariate Shift问题。BN的作用是确保网络的各层,即使参数发生了变化,其输入数据的分布也不会发生太大的变化,将其拉回到均值0,方差1的正态分布,从而避免ICS。
Q1:BN的原理
BN可以看作带参数的标准化,它有两个需要学习的参数γ和β,称为偏移因子和缩放因子。BN包括两个步骤:第一步相当于标准化,对每层的输入统计均值和方差,然后进行去均值方差标准化处理,使得每层的输入都是均值为0,方差为1.第二步是对规范化的数据进行线性变化,使用两个可以学习的参数γ和β。第一步的标准化虽然缓解了ICS问题,使得每层网络输入都变得稳定,但却导致数据的表达能力减弱,使得底层网络学习的信息丢失(如果激活函数使用的是sigmoid或者tanh,0均值的数据大部分落在激活函数的近似线性区域,没有利用上非线性区域,极大削弱了非线性表达能力);因此加入了两个可学习的参数的线性变换来恢复数据的表达能力。
具体的计算公式如下:
Q2:BN的作用
-
加速网路训练。ICS问题会导致深层网络需要不断去适应底层网络参数的变化,因此训练速度会很慢,BN解决了ICS问题,因此可以加速网络训练。
- 防止过拟合。BN由于每次在计算均值方差时是依靠一个batch来计算的,引入了随机性,可以缓解过拟合问题,可以用来代替dropout以及降低L2正则的权重系数。
- 缓解梯度消失和梯度爆炸。BN使得每层的输入不会太大,因此不会梯度爆炸;每层输入绝对值不会太大,就不会落入sigmoid激活函数的饱和区域,从而缓解梯度消失。
- 调参更容易。之前由于ICS问题的存在,一般会采用更小的学习率,为了防止过拟合也会尝试dropout以及L2正则,并且要求很精细的网络初始化。有了BN后,缓解了ICS问题,就可以使用较大的学习率,对初始化的要求也没有那么高了。
Q3:BN的缺陷
- 受到Batch Size影响很大,如果batch size较小,每次训练计算的均值方差不具有代表性且不稳定,会降低模型效果。
- BN比较难用到RNN这种序列模型中。因为BN是batch内计算均值,而句子之间没有很强的语义关系,句子内部有比较强的语义关系,所以句子之间算均值对其再标准化,这种方式效果会不好。
Q4:BN的均值、方差的计算维度
对于全连接层:输入维度是[N,C],在N上计算平均,γ和β的维度是C
对于卷积层:输入维度是[N,C,H,W],在N,H,W上计算平均,γ和β的维度是C。
Q5:BN在训练和测试时有什么区别
训练时,均值、方差分别是该批次内数据相应维度的均值和方差;测试时,均值、方差是基于训练时批次数据均值方差的无偏估计,公司如下:(即在训练时保存所有批次的均值方差,然后计算无偏估计)
在推荐过程中BN采用如下公式:
这里的E[x]就是我上面那个式子里面的μtest,Var[x]就是我上面式子里面的σ^2test。
这个式子和训练时:
是等价的,不过是做了一些变换。在实际运行的时候,按照这种变体可以减少计算量,为啥呢?因为对于隐藏节点来说:
都是固定值,这样这两个值可以事先存起来,在推荐的时候直接用就行了,这样比原始的公式每一步都现成算少了除法的运算过程,如果隐藏节点个数多的话就会节省很多的计算量。
Q6:BN的代码实现
Batch Normalization里面有一个momentum参数,该参数作用于mean和variance的计算上,这里保留了历史batch里面的mean和variance值,即moving_mean和moving_variance,计算的是移动平均,将历史batch里的mean和variance的作用延续到当前batch。一般momentum的值为0.9,0.99等。多个batch后,即多个0.9连乘后,最早的batch的影响后变弱。
指数移动平均:指数移动平均是以指数式递减加权的移动平均。各数值的加权影响力随时间而指数式递减,越近期的数据加权影响力越重,但较旧的数据也给予一定的加权值。计算公式为:.优点:当想要计算均值的时候,不用保留所有时刻的值。随着时间推移,遥远过去的历史的影响会越来越小。
所以这里要注意实际实现的时候,测试阶段的均值和方差是通过在训练阶段指数加权移动平均来统计均值和方差得到的。
代码实现如下:
import torch
from torch import nn def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)#判断是全连接层还是卷积层,2代表全连接层,样本数和特征数;4代表卷积层,批量数,通道数,高宽if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)#1*n*高*宽var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta # 缩放和移位return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)
关于BN常见的面试题目整理:
1.BN为什么效果?
讲BN的归一化带来的好处以及mini-batch mean 和mini-batch variance引入正则作用这两个方面
2.为什么BN归一化后还要有scale-shift操作?
这个在文中提到了
3.BN改变了数据分布,为什么效果反而会更好?
虽然会改变数据分布,但是数据之间的关联性是不会变的;由于有目标函数在,所以神经网络自己会朝着分布最优的方向去学习
4.BN用在什么地方
一般用在全连接层+BN+激活函数
5.对于什么激活函数,BN效果会明显?
对于sigmoid或者tanh激活函数,BN效果会好一些。
6.BN中在训练和测试时怎么用?
文中讲到了。
7.BN缺点
小样本时,效果不好,均值和方差是有偏的;在RNN中效果通常不好。其实文中也讲到了
[深度学习基础][面经]Batch Normalization - 知乎
Batch Normalization导读 - 知乎
整理学习之Batch Normalization(批标准化)_笨笨犬牙的博客-CSDN博客
相关文章:
BatchNormalization
目录 Covariate Shift Internal Covariate Shift BatchNormalization Q1:BN的原理 Q2:BN的作用 Q3:BN的缺陷 Q4:BN的均值、方差的计算维度 Q5:BN在训练和测试时有什么区别 Q6:BN的代码实现 Covariate Shift 机器学习中&a…...
vue 中安装插件实现 rem 适配
vue 中实现 rem 适配vue 项目实现页面自适应,可以安装插件实现。 postcss-pxtorem 是 PostCSS 的插件,用于将像素单元生成 rem 单位。 autoprefixer 浏览器前缀处理插件。 amfe-flexible 可伸缩布局方案替代了原先的 lib-flexible 选用了当前众多浏览…...

Hadoop学习
1.分布式与集群 hosts文件: 域名映射文件 2.Linux常用命令 ls -a:查看当前目录下所有文件mkdir -p:如果没有对应的父文件夹,会自动创建rm -rf:-f:强制删除 -r:递归删除cp -r:复制文…...

Golang反射源码分析
在go的源码包及一些开源组件中,经常可以看到reflect反射包的使用,本文就与大家一起探讨go反射机制的原理、学习其实现源码 首先,了解一下反射的定义: 反射是指计算机程序能够在运行时,能够描述其自身状态或行为、调整…...

Qt之悬浮球菜单
一、概述 最近想做一个炫酷的悬浮式菜单,考虑到菜单展开和美观,所以考虑学习下Qt的动画系统和状态机内容,打开QtCreator的示例教程浏览了下,大致发现教程中2D Painting程序和Animated Tiles程序有所帮助,如下图所示&a…...

易优cms attribute 栏目属性列表
attribute 栏目属性列表 attribute 栏目属性列表 [基础用法] 标签:attribute 描述:获取栏目的属性列表,或者单独获取某个属性值。 用法: {eyou:attribute typeauto} {$attr.name}:{$attr.value} {/eyou:attri…...

表格中的table-layout属性讲解
表格中的table-layout属性讲解 定义和用法 tableLayout 属性用来显示表格单元格、行、列的算法规则。 table-layout有三个属性值:auto、fixed、inherit。 fixed:固定表格布局 固定表格布局与自动表格布局相比,允许浏览器更快地对表格进行布…...

【MFA】windows环境下,使用Montreal-Forced-Aligner训练并对齐音频
文章目录一、安装MFA1.安装anaconda2.创建并进入虚拟环境3.安装pyTorch二、训练新的声学模型1.确保数据集的格式正确2.训练声音模型-导出模型和对齐文件3.报错处理1.遇到类似: Command ‘[‘createdb’,–host‘ ’, ‘Librispeech’]’ returned non-zero exit sta…...

C语言实验小项目实例源码大全订票信息管理系统贪吃蛇图书商品管理网络通信等
wx供重浩:创享日记 对话框发送:c项目 获取完整源码源文件视频讲解环境资源包文档说明等 包括火车订票系统、学生个人消费管理系统、超级万年历、学生信息管理系统、网络通信编程、商品管理系统、通讯录管理系统、企业员工管理系统、贪吃蛇游戏、图书管理…...
电脑图片损坏是怎么回事
电脑图片损坏是怎么回事?对于经常使用电脑的我们,总是会下载各种各样的图片,用于平时的使用中。但难免会遇到莫名其妙就损坏的图片文件,一旦发生这种情况,要如何才能修复损坏的图片呢?下面小编为大家带来常用的修复方…...

【论文研读】无人机飞行模拟仿真平台设计
无人机飞行模拟仿真平台设计 摘要: 为提高飞行控制算法的研发效率,降低研发成本,基于数字孪生技术设计一个无人机硬件在环飞行模拟仿真平台。从几何、物理和行为3个方面研究无人机数字模型构建方法,将物理实体以数字化方式呈现。设计一种多元融合场景建模法,依据属…...
【算法题】2379. 得到 K 个黑块的最少涂色次数
插: 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 坚持不懈,越努力越幸运,大家一起学习鸭~~~ 题目: 给你一个长度为 n 下标从 0 开始的…...

DJ1-3 计算机网络和因特网
目录 一、物理介质 1. 双绞线 2. 同轴电缆 3. 光纤线缆 4. 无线电磁波 二、端系统上的 Internet 服务 1. 面向连接的服务 TCP(Transmission Control Protocol) 2. 无连接的服务 UDP(User Datagram Protocol) TCP 和 UD…...

Git学习笔记(六)-标签管理
发布一个版本时,我们通常先在版本库中打一个标签(tag),这样,就唯一确定了打标签时刻的版本。将来无论什么时候,取某个标签的版本,就是把那个打标签的时刻的历史版本取出来。所以,标签…...

Semaphore 源码解读
一、Semaphore Semaphore 通过设置一个固定数值的信号量,并发时线程通过 acquire() 获取一个信号量,如果能成功获得则可以继续执行,否则将阻塞等待,当某个线程使用 release() 释放一个信号量时,被阻塞的线程则可以被唤…...

RZ/G2L工业核心板U盘读写速率测试
1. 测试对象HD-G2L-IOT基于HD-G2L-CORE工业级核心板设计,双路千兆网口、双路CAN-bus、2路RS-232、2路RS-485、DSI、LCD、4G/5G、WiFi、CSI摄像头接口等,接口丰富,适用于工业现场应用需求,亦方便用户评估核心板及CPU的性能。HD-G2L…...
《SQL与数据库基础》18. MySQL管理
SQL - MySQL管理MySQL管理系统数据库常用工具mysqlmysqladminmysqlbinlogmysqlshowmysqldumpmysqlimportsource本文以 MySQL 为例 MySQL管理 系统数据库 Mysql数据库安装完成后,自带了以下四个数据库,具体作用如下: 数据库含义mysql存储My…...

达梦关系型数据库
达梦关系型数据库一、DM8 安装1. 安装包下载2. Docker 安装3. Linux 安装4. Windows 安装二、DM 管理工具三、命令行交互工具 DIsql四、DM8 SQL使用1. 创建模式2. 创建表3. 修改表4. 读写数据5. 查看库下所有的表名6. 查看表字段信息GitHub: link. 欢迎star国产自主研发的大型…...
Postgresql | 执行计划
SQL优化主要从三个角度进行: (1)扫描方式; (2)连接方式; (3)连接顺序。 如果解决好这三方面的问题,那么这条SQL的执行效率就基本上是靠谱的。看懂SQL的执行计…...
Vue3之父子组件通过事件通信
前言 组件间传值的章节我们知道父组件给子组件传值的时候,使用v-bind的方式定义一个属性传值,子组件根据这个属性名去接收父组件的值,但是假如子组件想给父组件一些反馈呢?就不能使用这种方式来,而是使用事件的方式&a…...

AI-调查研究-01-正念冥想有用吗?对健康的影响及科学指南
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...

8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂
蛋白质结合剂(如抗体、抑制肽)在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上,高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术,但这类方法普遍面临资源消耗巨大、研发周期冗长…...

Keil 中设置 STM32 Flash 和 RAM 地址详解
文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面
代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口(适配服务端返回 Token) export const login async (code, avatar) > {const res await http…...

Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...
高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数
高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数 在软件开发中,单例模式(Singleton Pattern)是一种常见的设计模式,确保一个类仅有一个实例,并提供一个全局访问点。在多线程环境下,实现单例模式时需要注意线程安全问题,以防止多个线程同时创建实例,导致…...

网站指纹识别
网站指纹识别 网站的最基本组成:服务器(操作系统)、中间件(web容器)、脚本语言、数据厍 为什么要了解这些?举个例子:发现了一个文件读取漏洞,我们需要读/etc/passwd,如…...
iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈
在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...

七、数据库的完整性
七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...
Qt 事件处理中 return 的深入解析
Qt 事件处理中 return 的深入解析 在 Qt 事件处理中,return 语句的使用是另一个关键概念,它与 event->accept()/event->ignore() 密切相关但作用不同。让我们详细分析一下它们之间的关系和工作原理。 核心区别:不同层级的事件处理 方…...