详解三种常用标准化:Batch Norm、Layer Norm和RMSNorm
在深度学习中,标准化技术是提升模型训练速度、稳定性和性能的重要手段。本文将详细介绍三种常用的标准化方法:Batch Normalization(批量标准化)、Layer Normalization(层标准化)和 RMS Normalization(RMS标准化),并对其原理、实现和应用场景进行深入分析。
一、Batch Normalization
1.1 Batch Normalization的原理
Batch Normalization(BN)通过在每个小批量数据的每个神经元输出上进行标准化来减少内部协变量偏移。具体步骤如下:
-
计算小批量的均值和方差:
对于每个神经元的输出,计算该神经元在当前小批量中的均值和方差。[
\muB = \frac{1}{m} \sum{i=1}^m x_i
][
\sigmaB^2 = \frac{1}{m} \sum{i=1}^m (x_i - \mu_B)^2
] -
标准化:
使用计算得到的均值和方差对数据进行标准化。[
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
] -
缩放和平移:
引入可学习的参数进行缩放和平移。[
y_i = \gamma \hat{x}_i + \beta
]其中,(\gamma)和(\beta)是可学习的参数。
1.2 Batch Normalization的实现
在PyTorch中,Batch Normalization可以通过 torch.nn.BatchNorm2d
实现。
import torch
import torch.nn as nn# 创建BatchNorm层
batch_norm = nn.BatchNorm2d(num_features=64)# 输入数据
x = torch.randn(16, 64, 32, 32) # (batch_size, num_features, height, width)# 应用BatchNorm
output = batch_norm(x)
1.3 Batch Normalization的优缺点
优点:
- 加速训练:通过减少内部协变量偏移,加快了模型收敛速度。
- 稳定性提高:减小了梯度消失和爆炸的风险。
- 正则化效果:由于引入了噪声,有一定的正则化效果。
缺点:
- 依赖小批量大小:小批量大小过小时,均值和方差估计不准确。
- 训练和推理不一致:训练时使用小批量的均值和方差,推理时使用整个数据集的均值和方差。
二、Layer Normalization
2.1 Layer Normalization的原理
Layer Normalization(LN)通过在每一层的神经元输出上进行标准化,独立于小批量的大小。具体步骤如下:
-
计算每一层的均值和方差:
对于每一层的神经元输出,计算其均值和方差。[
\muL = \frac{1}{H} \sum{i=1}^H x_i
][
\sigmaL^2 = \frac{1}{H} \sum{i=1}^H (x_i - \mu_L)^2
] -
标准化:
使用计算得到的均值和方差对数据进行标准化。[
\hat{x}_i = \frac{x_i - \mu_L}{\sqrt{\sigma_L^2 + \epsilon}}
] -
缩放和平移:
引入可学习的参数进行缩放和平移。[
y_i = \gamma \hat{x}_i + \beta
]其中,(\gamma)和(\beta)是可学习的参数。
2.2 Layer Normalization的实现
在PyTorch中,Layer Normalization可以通过 torch.nn.LayerNorm
实现。
import torch
import torch.nn as nn# 创建LayerNorm层
layer_norm = nn.LayerNorm(normalized_shape=64)# 输入数据
x = torch.randn(16, 64)# 应用LayerNorm
output = layer_norm(x)
2.3 Layer Normalization的优缺点
优点:
- 与小批量大小无关:适用于小批量训练和在线学习。
- 更适合RNN:在循环神经网络中表现更好,因为它独立于时间步长。
缺点:
- 计算开销较大:每一层都需要计算均值和方差,计算开销较大。
- 对CNN效果不明显:在卷积神经网络中效果不如BN明显。
三、RMS Normalization
3.1 RMS Normalization的原理
RMS Normalization(RMSNorm)通过标准化每一层的RMS值,而不是均值和方差。具体步骤如下:
-
计算RMS值:
对于每一层的神经元输出,计算其RMS值。[
\text{RMS}(x) = \sqrt{\frac{1}{H} \sum_{i=1}^H x_i^2}
] -
标准化:
使用计算得到的RMS值对数据进行标准化。[
\hat{x}_i = \frac{x_i}{\text{RMS}(x) + \epsilon}
] -
缩放和平移:
引入可学习的参数进行缩放和平移。[
y_i = \gamma \hat{x}_i + \beta
]其中,(\gamma)和(\beta)是可学习的参数。
3.2 RMS Normalization的实现
在PyTorch中,RMS Normalization没有直接的内置实现,可以通过自定义层来实现。
import torch
import torch.nn as nnclass RMSNorm(nn.Module):def __init__(self, normalized_shape, epsilon=1e-8):super(RMSNorm, self).__init__()self.epsilon = epsilonself.gamma = nn.Parameter(torch.ones(normalized_shape))self.beta = nn.Parameter(torch.zeros(normalized_shape))def forward(self, x):rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.epsilon)x = x / rmsreturn self.gamma * x + self.beta# 创建RMSNorm层
rms_norm = RMSNorm(normalized_shape=64)# 输入数据
x = torch.randn(16, 64)# 应用RMSNorm
output = rms_norm(x)
3.3 RMS Normalization的优缺点
优点:
- 计算效率高:计算RMS值相对简单,计算开销较小。
- 稳定性好:在某些任务中可以表现出更好的稳定性。
缺点:
- 应用较少:相较于BN和LN,应用场景和研究较少。
- 效果不确定:在某些情况下效果可能不如BN和LN显著。
四、比较与应用场景
4.1 比较
特性 | Batch Norm | Layer Norm | RMSNorm |
---|---|---|---|
标准化维度 | 小批量内各特征维度 | 每层各特征维度 | 每层各特征维度的RMS |
计算开销 | 中等 | 较大 | 较小 |
对小批量大小依赖 | 依赖 | 不依赖 | 不依赖 |
应用场景 | CNN、MLP | RNN、Transformer | 各类神经网络 |
正则化效果 | 有一定正则化效果 | 无显著正则化效果 | 无显著正则化效果 |
4.2 应用场景
-
Batch Normalization:
- 适用于卷积神经网络(CNN)和多层感知机(MLP)。
- 对小批量大小有依赖,不适合小批量和在线学习。
-
Layer Normalization:
- 适用于循环神经网络(RNN)和Transformer。
- 独立于小批量大小,适合小批量和在线学习。
-
RMS Normalization:
- 适用于各种神经网络,尤其在计算效率和稳定性有要求的任务中。
- 相对较新,应用场景和研究较少,但在某些任务中可能表现优异。
五、总结
Batch Normalization
、Layer Normalization和RMS Normalization是深度学习中常用的标准化技术。它们各有优缺点,适用于不同的应用场景。通过理解其原理和实现,您可以根据具体需求选择合适的标准化方法,提升模型的训练速度和性能。
相关文章:
详解三种常用标准化:Batch Norm、Layer Norm和RMSNorm
在深度学习中,标准化技术是提升模型训练速度、稳定性和性能的重要手段。本文将详细介绍三种常用的标准化方法:Batch Normalization(批量标准化)、Layer Normalization(层标准化)和 RMS Normalization&#…...
linux+docker+nacos+mysql部署
一、下载 docker pull mysql:5.7 docker pull nacos/nacos-server:v2.2.2 docker images 二、mysql部署 1、创建目录存储数据信息 mkdir ~/mysql cd ~/mysql 2、运行 MySQL 容器 docker run -id \ -p 3306:3306 \ --name mysql \ -v $PWD/conf:/etc/mysql/conf.d \ -v $PWD/…...
如何实现gitlab和jira连通
将 GitLab 和 Jira 集成起来可以实现开发任务与代码变更的联动,提高团队协作效率。以下是实现两者连通的详细步骤: 1. 确保必要条件 在进行集成之前,确保以下条件满足: 你有 GitLab 和 Jira 的管理员权限。Jira 是 Jira Cloud 或…...
利用ML.NET精准提取人名
在当今信息爆炸的时代,文本处理任务层出不穷,其中人名提取作为基础且重要的工作,广泛应用于信息检索、社交网络分析、客户关系管理等领域。随着人工智能不断进步,ML.NET作为微软推出的开源机器学习框架,为开发者提供了…...
Node.js的解释
1. Node.js 入门教程 1.1 什么是 Node.js? 1.1.1 Node.js 是什么? Node.js 是一个基于 JavaScript 的开源服务器端运行时环境,允许开发者用 JavaScript 编写服务器端代码。与传统的前端 JavaScript 主要运行在浏览器端不同,Nod…...
Macos下交叉编译安卓的paq8px压缩算法
官方没有android的编译方法,自己编写脚本在macos下交叉编译. 下载源码: git clone https://github.com/hxim/paq8px.git 稍旧的ndk并不能编译成功,需要下载最新的ndkr27c, 最后是使用clang来编译。 编译build.sh export ANDROID_NDK/Vol…...

如何在data.table中处理缺失值
📊💻【R语言进阶】轻松搞定缺失值,让数据清洗更高效! 👋 大家好呀!今天我要和大家分享一个超实用的R语言技巧——如何在data.table中处理缺失值,并且提供了一个自定义函数calculate_missing_va…...

从零安装 LLaMA-Factory 微调 Qwen 大模型成功及所有的坑
文章目录 从零安装 LLaMA-Factory 微调 Qwen 大模型成功及所有的坑一 参考二 安装三 启动准备大模型文件 四 数据集(关键)!4.1 Alapaca格式4.2 sharegpt4.3 在 dataset_info.json 中注册4.4 官方 alpaca_zh_demo 例子 999条数据, 本机微调 5分…...
SQL-leetcode—1164. 指定日期的产品价格
1164. 指定日期的产品价格 产品数据表: Products ---------------------- | Column Name | Type | ---------------------- | product_id | int | | new_price | int | | change_date | date | ---------------------- (product_id, change_date) 是此表的主键(具…...

[Day 15]54.螺旋矩阵(简单易懂 有画图)
今天我们来看这道螺旋矩阵,和昨天发的题很类似。没有技巧,全是循环。小白也能懂~ 力扣54.螺旋矩阵 题目描述: 给你一个 m 行 n 列的矩阵 matrix ,请按照 顺时针螺旋顺序 ,返回矩阵中的所有元素。 示例 1: …...

HTTP 配置与应用(不同网段)
想做一个自己学习的有关的csdn账号,努力奋斗......会更新我计算机网络实验课程的所有内容,还有其他的学习知识^_^,为自己巩固一下所学知识,下次更新校园网设计。 我是一个萌新小白,有误地方请大家指正,谢谢…...

Quartus:开发使用及 Tips 总结
Quartus是Altera(现已被Intel收购)推出的一款针对其FPGA产品的综合性开发环境,用于设计、仿真和调试数字电路。以下是使用Quartus的一些总结和技巧(Tips),帮助更高效地进行FPGA项目开发: 这里写目录标题 使用总结TIPS…...

VSCode下EIDE插件开发STM32
VSCode下STM32开发环境搭建 本STM32教程使用vscode的EIDE插件的开发环境,完全免费,有管理代码文件的界面,不需要其它IDE。 视频教程见本人的 VSCodeEIDE开发STM32 安装EIDE插件 Embedded IDE 嵌入式IDE 这个插件可以帮我们管理代码文件&am…...

Golang并发机制及CSP并发模型
Golang 并发机制及 CSP 并发模型 Golang 是一门为并发而生的语言,其并发机制基于 CSP(Communicating Sequential Processes,通信顺序过程) 模型。CSP 是一种描述并发系统中交互模式的正式语言,强调通过通信来共享内存…...
HTML 文本格式化详解
在网页开发中,文本内容的呈现方式直接影响用户的阅读体验。HTML 提供了多种文本格式化元素,可以帮助我们更好地控制文本的显示效果。本文将详细介绍 HTML 中的文本格式化元素及其使用方法,帮助你轻松实现网页文本的美化。 什么是 HTML 文本格…...

我谈《概率论与数理统计》的知识体系
学习《概率论与数理统计》二十多年后,在廖老师的指导下,才厘清了各章之间的关系。首先,这是两个学科综合的一门课程,这一门课程中还有术语冲突的问题。这一门课程一条线两个分支,脉络很清晰。 概率论与统计学 概率论…...

五、华为 RSTP
RSTP(Rapid Spanning Tree Protocol,快速生成树协议)是 STP 的优化版本,能实现网络拓扑的快速收敛。 一、RSTP 原理 快速收敛机制:RSTP 通过引入边缘端口、P/A(Proposal/Agreement)机制等&…...

基于Java Web的网上房屋租售网站
内容摘要 本毕业设计题目为《基于Java Web的网上房屋租售网站》,是在信息化时代下充分利用互联网对传统房屋租售方式进行创新,在互联网上进行房屋租售突破了传统方式的局限性。对于房屋租售的当事人都提供了极大的便利。本稳针对了实际用户需求…...

Pyside6(PyQT5)中的QTableView与QSqlQueryModel、QSqlTableModel的联合使用
QTableView 是QT的一个强大的表视图部件,可以与模型结合使用以显示和编辑数据。QSqlQueryModel、QSqlTableModel 都是用于与 SQL 数据库交互的模型,将二者与QTableView结合使用可以轻松地展示和编辑数据库的数据。 QSqlQueryModel的简单应用 import sys from PySid…...

git常用命令学习
目录 文章目录 目录第一章 git简介1.Git 与SVN2.Git 工作区、暂存区和版本库 第二章 git常用命令学习1.ssh设置2.设置用户信息3.常用命令设置1.初始化本地仓库init2.克隆clone3.查看状态 git status4.添加add命令5.添加评论6.分支操作1.创建分支2.查看分支3.切换分支4.删除分支…...
Cursor实现用excel数据填充word模版的方法
cursor主页:https://www.cursor.com/ 任务目标:把excel格式的数据里的单元格,按照某一个固定模版填充到word中 文章目录 注意事项逐步生成程序1. 确定格式2. 调试程序 注意事项 直接给一个excel文件和最终呈现的word文件的示例,…...

跨链模式:多链互操作架构与性能扩展方案
跨链模式:多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈:模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展(H2Cross架构): 适配层…...
三体问题详解
从物理学角度,三体问题之所以不稳定,是因为三个天体在万有引力作用下相互作用,形成一个非线性耦合系统。我们可以从牛顿经典力学出发,列出具体的运动方程,并说明为何这个系统本质上是混沌的,无法得到一般解…...
实现弹窗随键盘上移居中
实现弹窗随键盘上移的核心思路 在Android中,可以通过监听键盘的显示和隐藏事件,动态调整弹窗的位置。关键点在于获取键盘高度,并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...

mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包
文章目录 现象:mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包遇到 rpm 命令找不到已经安装的 MySQL 包时,可能是因为以下几个原因:1.MySQL 不是通过 RPM 包安装的2.RPM 数据库损坏3.使用了不同的包名或路径4.使用其他包…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...
C#中的CLR属性、依赖属性与附加属性
CLR属性的主要特征 封装性: 隐藏字段的实现细节 提供对字段的受控访问 访问控制: 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性: 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑: 可以…...
区块链技术概述
区块链技术是一种去中心化、分布式账本技术,通过密码学、共识机制和智能合约等核心组件,实现数据不可篡改、透明可追溯的系统。 一、核心技术 1. 去中心化 特点:数据存储在网络中的多个节点(计算机),而非…...

Python训练营-Day26-函数专题1:函数定义与参数
题目1:计算圆的面积 任务: 编写一个名为 calculate_circle_area 的函数,该函数接收圆的半径 radius 作为参数,并返回圆的面积。圆的面积 π * radius (可以使用 math.pi 作为 π 的值)要求:函数接收一个位置参数 radi…...

数据分析六部曲?
引言 上一章我们说到了数据分析六部曲,何谓六部曲呢? 其实啊,数据分析没那么难,只要掌握了下面这六个步骤,也就是数据分析六部曲,就算你是个啥都不懂的小白,也能慢慢上手做数据分析啦。 第一…...