详解三种常用标准化:Batch Norm、Layer Norm和RMSNorm
在深度学习中,标准化技术是提升模型训练速度、稳定性和性能的重要手段。本文将详细介绍三种常用的标准化方法:Batch Normalization(批量标准化)、Layer Normalization(层标准化)和 RMS Normalization(RMS标准化),并对其原理、实现和应用场景进行深入分析。
一、Batch Normalization
1.1 Batch Normalization的原理
Batch Normalization(BN)通过在每个小批量数据的每个神经元输出上进行标准化来减少内部协变量偏移。具体步骤如下:
-
计算小批量的均值和方差:
对于每个神经元的输出,计算该神经元在当前小批量中的均值和方差。[
\muB = \frac{1}{m} \sum{i=1}^m x_i
][
\sigmaB^2 = \frac{1}{m} \sum{i=1}^m (x_i - \mu_B)^2
] -
标准化:
使用计算得到的均值和方差对数据进行标准化。[
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
] -
缩放和平移:
引入可学习的参数进行缩放和平移。[
y_i = \gamma \hat{x}_i + \beta
]其中,(\gamma)和(\beta)是可学习的参数。
1.2 Batch Normalization的实现
在PyTorch中,Batch Normalization可以通过 torch.nn.BatchNorm2d实现。
import torch
import torch.nn as nn# 创建BatchNorm层
batch_norm = nn.BatchNorm2d(num_features=64)# 输入数据
x = torch.randn(16, 64, 32, 32) # (batch_size, num_features, height, width)# 应用BatchNorm
output = batch_norm(x)
1.3 Batch Normalization的优缺点
优点:
- 加速训练:通过减少内部协变量偏移,加快了模型收敛速度。
- 稳定性提高:减小了梯度消失和爆炸的风险。
- 正则化效果:由于引入了噪声,有一定的正则化效果。
缺点:
- 依赖小批量大小:小批量大小过小时,均值和方差估计不准确。
- 训练和推理不一致:训练时使用小批量的均值和方差,推理时使用整个数据集的均值和方差。
二、Layer Normalization
2.1 Layer Normalization的原理
Layer Normalization(LN)通过在每一层的神经元输出上进行标准化,独立于小批量的大小。具体步骤如下:
-
计算每一层的均值和方差:
对于每一层的神经元输出,计算其均值和方差。[
\muL = \frac{1}{H} \sum{i=1}^H x_i
][
\sigmaL^2 = \frac{1}{H} \sum{i=1}^H (x_i - \mu_L)^2
] -
标准化:
使用计算得到的均值和方差对数据进行标准化。[
\hat{x}_i = \frac{x_i - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}}
] -
缩放和平移:
引入可学习的参数进行缩放和平移。[
y_i = \gamma \hat{x}_i + \beta
]其中,(\gamma)和(\beta)是可学习的参数。
2.2 Layer Normalization的实现
在PyTorch中,Layer Normalization可以通过 torch.nn.LayerNorm实现。
import torch
import torch.nn as nn# 创建LayerNorm层
layer_norm = nn.LayerNorm(normalized_shape=64)# 输入数据
x = torch.randn(16, 64)# 应用LayerNorm
output = layer_norm(x)
2.3 Layer Normalization的优缺点
优点:
- 与小批量大小无关:适用于小批量训练和在线学习。
- 更适合RNN:在循环神经网络中表现更好,因为它独立于时间步长。
缺点:
- 计算开销较大:每一层都需要计算均值和方差,计算开销较大。
- 对CNN效果不明显:在卷积神经网络中效果不如BN明显。
三、RMS Normalization
3.1 RMS Normalization的原理
RMS Normalization(RMSNorm)通过标准化每一层的RMS值,而不是均值和方差。具体步骤如下:
-
计算RMS值:
对于每一层的神经元输出,计算其RMS值。[
\text{RMS}(x) = \sqrt{\frac{1}{H} \sum_{i=1}^H x_i^2}
] -
标准化:
使用计算得到的RMS值对数据进行标准化。[
\hat{x}_i = \frac{x_i}{\text{RMS}(x) + \epsilon}
] -
缩放和平移:
引入可学习的参数进行缩放和平移。[
y_i = \gamma \hat{x}_i + \beta
]其中,(\gamma)和(\beta)是可学习的参数。
3.2 RMS Normalization的实现
在PyTorch中,RMS Normalization没有直接的内置实现,可以通过自定义层来实现。
import torch
import torch.nn as nnclass RMSNorm(nn.Module):def __init__(self, normalized_shape, epsilon=1e-8):super(RMSNorm, self).__init__()self.epsilon = epsilonself.gamma = nn.Parameter(torch.ones(normalized_shape))self.beta = nn.Parameter(torch.zeros(normalized_shape))def forward(self, x):rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.epsilon)x = x / rmsreturn self.gamma * x + self.beta# 创建RMSNorm层
rms_norm = RMSNorm(normalized_shape=64)# 输入数据
x = torch.randn(16, 64)# 应用RMSNorm
output = rms_norm(x)
3.3 RMS Normalization的优缺点
优点:
- 计算效率高:计算RMS值相对简单,计算开销较小。
- 稳定性好:在某些任务中可以表现出更好的稳定性。
缺点:
- 应用较少:相较于BN和LN,应用场景和研究较少。
- 效果不确定:在某些情况下效果可能不如BN和LN显著。
四、比较与应用场景
4.1 比较
| 特性 | Batch Norm | Layer Norm | RMSNorm |
|---|---|---|---|
| 标准化维度 | 小批量内各特征维度 | 每层各特征维度 | 每层各特征维度的RMS |
| 计算开销 | 中等 | 较大 | 较小 |
| 对小批量大小依赖 | 依赖 | 不依赖 | 不依赖 |
| 应用场景 | CNN、MLP | RNN、Transformer | 各类神经网络 |
| 正则化效果 | 有一定正则化效果 | 无显著正则化效果 | 无显著正则化效果 |
4.2 应用场景
-
Batch Normalization:
- 适用于卷积神经网络(CNN)和多层感知机(MLP)。
- 对小批量大小有依赖,不适合小批量和在线学习。
-
Layer Normalization:
- 适用于循环神经网络(RNN)和Transformer。
- 独立于小批量大小,适合小批量和在线学习。
-
RMS Normalization:
- 适用于各种神经网络,尤其在计算效率和稳定性有要求的任务中。
- 相对较新,应用场景和研究较少,但在某些任务中可能表现优异。
五、总结
Batch Normalization
、Layer Normalization和RMS Normalization是深度学习中常用的标准化技术。它们各有优缺点,适用于不同的应用场景。通过理解其原理和实现,您可以根据具体需求选择合适的标准化方法,提升模型的训练速度和性能。
相关文章:
详解三种常用标准化:Batch Norm、Layer Norm和RMSNorm
在深度学习中,标准化技术是提升模型训练速度、稳定性和性能的重要手段。本文将详细介绍三种常用的标准化方法:Batch Normalization(批量标准化)、Layer Normalization(层标准化)和 RMS Normalization&#…...
linux+docker+nacos+mysql部署
一、下载 docker pull mysql:5.7 docker pull nacos/nacos-server:v2.2.2 docker images 二、mysql部署 1、创建目录存储数据信息 mkdir ~/mysql cd ~/mysql 2、运行 MySQL 容器 docker run -id \ -p 3306:3306 \ --name mysql \ -v $PWD/conf:/etc/mysql/conf.d \ -v $PWD/…...
如何实现gitlab和jira连通
将 GitLab 和 Jira 集成起来可以实现开发任务与代码变更的联动,提高团队协作效率。以下是实现两者连通的详细步骤: 1. 确保必要条件 在进行集成之前,确保以下条件满足: 你有 GitLab 和 Jira 的管理员权限。Jira 是 Jira Cloud 或…...
利用ML.NET精准提取人名
在当今信息爆炸的时代,文本处理任务层出不穷,其中人名提取作为基础且重要的工作,广泛应用于信息检索、社交网络分析、客户关系管理等领域。随着人工智能不断进步,ML.NET作为微软推出的开源机器学习框架,为开发者提供了…...
Node.js的解释
1. Node.js 入门教程 1.1 什么是 Node.js? 1.1.1 Node.js 是什么? Node.js 是一个基于 JavaScript 的开源服务器端运行时环境,允许开发者用 JavaScript 编写服务器端代码。与传统的前端 JavaScript 主要运行在浏览器端不同,Nod…...
Macos下交叉编译安卓的paq8px压缩算法
官方没有android的编译方法,自己编写脚本在macos下交叉编译. 下载源码: git clone https://github.com/hxim/paq8px.git 稍旧的ndk并不能编译成功,需要下载最新的ndkr27c, 最后是使用clang来编译。 编译build.sh export ANDROID_NDK/Vol…...
如何在data.table中处理缺失值
📊💻【R语言进阶】轻松搞定缺失值,让数据清洗更高效! 👋 大家好呀!今天我要和大家分享一个超实用的R语言技巧——如何在data.table中处理缺失值,并且提供了一个自定义函数calculate_missing_va…...
从零安装 LLaMA-Factory 微调 Qwen 大模型成功及所有的坑
文章目录 从零安装 LLaMA-Factory 微调 Qwen 大模型成功及所有的坑一 参考二 安装三 启动准备大模型文件 四 数据集(关键)!4.1 Alapaca格式4.2 sharegpt4.3 在 dataset_info.json 中注册4.4 官方 alpaca_zh_demo 例子 999条数据, 本机微调 5分…...
SQL-leetcode—1164. 指定日期的产品价格
1164. 指定日期的产品价格 产品数据表: Products ---------------------- | Column Name | Type | ---------------------- | product_id | int | | new_price | int | | change_date | date | ---------------------- (product_id, change_date) 是此表的主键(具…...
[Day 15]54.螺旋矩阵(简单易懂 有画图)
今天我们来看这道螺旋矩阵,和昨天发的题很类似。没有技巧,全是循环。小白也能懂~ 力扣54.螺旋矩阵 题目描述: 给你一个 m 行 n 列的矩阵 matrix ,请按照 顺时针螺旋顺序 ,返回矩阵中的所有元素。 示例 1: …...
HTTP 配置与应用(不同网段)
想做一个自己学习的有关的csdn账号,努力奋斗......会更新我计算机网络实验课程的所有内容,还有其他的学习知识^_^,为自己巩固一下所学知识,下次更新校园网设计。 我是一个萌新小白,有误地方请大家指正,谢谢…...
Quartus:开发使用及 Tips 总结
Quartus是Altera(现已被Intel收购)推出的一款针对其FPGA产品的综合性开发环境,用于设计、仿真和调试数字电路。以下是使用Quartus的一些总结和技巧(Tips),帮助更高效地进行FPGA项目开发: 这里写目录标题 使用总结TIPS…...
VSCode下EIDE插件开发STM32
VSCode下STM32开发环境搭建 本STM32教程使用vscode的EIDE插件的开发环境,完全免费,有管理代码文件的界面,不需要其它IDE。 视频教程见本人的 VSCodeEIDE开发STM32 安装EIDE插件 Embedded IDE 嵌入式IDE 这个插件可以帮我们管理代码文件&am…...
Golang并发机制及CSP并发模型
Golang 并发机制及 CSP 并发模型 Golang 是一门为并发而生的语言,其并发机制基于 CSP(Communicating Sequential Processes,通信顺序过程) 模型。CSP 是一种描述并发系统中交互模式的正式语言,强调通过通信来共享内存…...
HTML 文本格式化详解
在网页开发中,文本内容的呈现方式直接影响用户的阅读体验。HTML 提供了多种文本格式化元素,可以帮助我们更好地控制文本的显示效果。本文将详细介绍 HTML 中的文本格式化元素及其使用方法,帮助你轻松实现网页文本的美化。 什么是 HTML 文本格…...
我谈《概率论与数理统计》的知识体系
学习《概率论与数理统计》二十多年后,在廖老师的指导下,才厘清了各章之间的关系。首先,这是两个学科综合的一门课程,这一门课程中还有术语冲突的问题。这一门课程一条线两个分支,脉络很清晰。 概率论与统计学 概率论…...
五、华为 RSTP
RSTP(Rapid Spanning Tree Protocol,快速生成树协议)是 STP 的优化版本,能实现网络拓扑的快速收敛。 一、RSTP 原理 快速收敛机制:RSTP 通过引入边缘端口、P/A(Proposal/Agreement)机制等&…...
基于Java Web的网上房屋租售网站
内容摘要 本毕业设计题目为《基于Java Web的网上房屋租售网站》,是在信息化时代下充分利用互联网对传统房屋租售方式进行创新,在互联网上进行房屋租售突破了传统方式的局限性。对于房屋租售的当事人都提供了极大的便利。本稳针对了实际用户需求…...
Pyside6(PyQT5)中的QTableView与QSqlQueryModel、QSqlTableModel的联合使用
QTableView 是QT的一个强大的表视图部件,可以与模型结合使用以显示和编辑数据。QSqlQueryModel、QSqlTableModel 都是用于与 SQL 数据库交互的模型,将二者与QTableView结合使用可以轻松地展示和编辑数据库的数据。 QSqlQueryModel的简单应用 import sys from PySid…...
git常用命令学习
目录 文章目录 目录第一章 git简介1.Git 与SVN2.Git 工作区、暂存区和版本库 第二章 git常用命令学习1.ssh设置2.设置用户信息3.常用命令设置1.初始化本地仓库init2.克隆clone3.查看状态 git status4.添加add命令5.添加评论6.分支操作1.创建分支2.查看分支3.切换分支4.删除分支…...
【Oracle APEX开发小技巧12】
有如下需求: 有一个问题反馈页面,要实现在apex页面展示能直观看到反馈时间超过7天未处理的数据,方便管理员及时处理反馈。 我的方法:直接将逻辑写在SQL中,这样可以直接在页面展示 完整代码: SELECTSF.FE…...
深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法
深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...
12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...
关于 WASM:1. WASM 基础原理
一、WASM 简介 1.1 WebAssembly 是什么? WebAssembly(WASM) 是一种能在现代浏览器中高效运行的二进制指令格式,它不是传统的编程语言,而是一种 低级字节码格式,可由高级语言(如 C、C、Rust&am…...
dify打造数据可视化图表
一、概述 在日常工作和学习中,我们经常需要和数据打交道。无论是分析报告、项目展示,还是简单的数据洞察,一个清晰直观的图表,往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server,由蚂蚁集团 AntV 团队…...
深度学习水论文:mamba+图像增强
🧀当前视觉领域对高效长序列建模需求激增,对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模,以及动态计算优势,在图像质量提升和细节恢复方面有难以替代的作用。 🧀因此短时间内,就有不…...
RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)
RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发,后来由Pivotal Software Inc.(现为VMware子公司)接管。RabbitMQ 是一个开源的消息代理和队列服务器,用 Erlang 语言编写。广泛应用于各种分布…...
第八部分:阶段项目 6:构建 React 前端应用
现在,是时候将你学到的 React 基础知识付诸实践,构建一个简单的前端应用来模拟与后端 API 的交互了。在这个阶段,你可以先使用模拟数据,或者如果你的后端 API(阶段项目 5)已经搭建好,可以直接连…...
加密通信 + 行为分析:运营商行业安全防御体系重构
在数字经济蓬勃发展的时代,运营商作为信息通信网络的核心枢纽,承载着海量用户数据与关键业务传输,其安全防御体系的可靠性直接关乎国家安全、社会稳定与企业发展。随着网络攻击手段的不断升级,传统安全防护体系逐渐暴露出局限性&a…...
对象回调初步研究
_OBJECT_TYPE结构分析 在介绍什么是对象回调前,首先要熟悉下结构 以我们上篇线程回调介绍过的导出的PsProcessType 结构为例,用_OBJECT_TYPE这个结构来解析它,0x80处就是今天要介绍的回调链表,但是先不着急,先把目光…...
