当前位置: 首页 > news >正文

batchnorm和layernorm的理解

batchnorm和layernorm原理和区别

batchnorm

原理

  对于一个特征tensor

x ∈ R b × c × f 1 × f 2 × … x \in \mathbb{R}^{b \times c \times f_1 \times f_2 \times \dots} xRb×c×f1×f2×

  其中, c c c是通道, f f f是通道中各种特征,batchnorm是对所有batch的每个通道特征分别进行归一化,即对于第 i i i个通道,选出他的所有batch的第 i i i个通道的所有特征 x [ : , i , : , : , . . . ] x[:,i,:,:,...] x[:,i,:,:,...]进行归一化:

μ = 1 m ∑ i = 1 m x i σ 2 = 1 m ∑ i = 1 m ( x i − μ ) 2 x ^ i = x i − μ σ 2 + ϵ y i = γ x ^ i + β \mu = \frac{1}{m} \sum_{i=1}^{m} x_i \\ \sigma^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu)^2 \\ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ y_i = \gamma \hat{x}_i + \beta μ=m1i=1mxiσ2=m1i=1m(xiμ)2x^i=σ2+ϵ xiμyi=γx^i+β
  其中, ϵ \epsilon ϵ是较小的数,为了防止分母为0, γ \gamma γ β \beta β是可训练参数,用于仿射变换
  因此,输出向量形状不发生改变,可训练参数取决于通道 c c c的个数

代码

import torch
import torch.nn as nn## BatchNorm1d
batchnorm1d = nn.BatchNorm1d(num_features=10)  # 输入特征数为 10
# 构造输入张量 (N, C)
x_1d = torch.randn(16, 10)  # 模拟训练数据 (batch_size=16, num_features=10)
y = batchnorm1d(x)
# 打印结果
print("Output shape :", y.shape)  # (N, C)## BatchNorm2d
batchnorm2d = nn.BatchNorm2d(num_features=16)  # 输入通道数为 16
# 构造输入张量 (N, C, H, W)
x = torch.randn(8, 16, 32, 32)  # 模拟训练数据
y = batchnorm2d(x)
# 打印结果
print("Output shape :", y.shape)  # (N, C, H, W)## BatchNorm3d
batchnorm3d = nn.BatchNorm3d(num_features=8)  # 输入通道数为 8
# 构造输入张量 (N, C, D, H, W)
x = torch.randn(4, 8, 16, 32, 32)  # 模拟训练数据
y = batchnorm3d(x)
# 打印结果
print("Output shape :", y.shape)  # (N, C, D, H, W)

作用

  • 加速模型收敛
  • 缓解梯度消失或爆炸问题
  • 降低对初始化的敏感性
  • 一定程度上防止过拟合

通俗理解

    从图片数据 x ∈ R b × c × h × w x \in \mathbb{R}^{b \times c \times h \times w} xRb×c×h×w来理解:对于不同通道的特征,我们经过batchnorm以后,不同通道的特征便失去了可比性或者相对性(通道一:9,8,7,6;通道二:8,7,6,5经过归一化后都一样),但是从人的视角来区分一张图像不是靠不同通道之间的差异性,换句话说RGB稍微变一下变成RRR一样可以区分出图片中物体、属于哪一类。因此我们使用batchnorm相当于告诉模型你应该从同一通道内的数据差异来得到结果,而不是通道间。
    其实你不告诉模型(不用batchnorm),模型在多次训练后也可以发现相同的规律,只是模型训练就慢了,因此我们说batchnorm可以加速收敛

注意点

    模型训练阶段,我们存在batch维度,使用batchnorm时模型的正向计算会受到batch大小影响。那么我们在推理阶段通常只有一个batch,怎么办呢?
    解决方法:在pytorch中,模型训练时,调用 m o d e l . t r a i n ( ) model.train() model.train(),当创建一个 BatchNorm 层时,PyTorch 会自动初始化 r u n n i n g m e a n running_{mean} runningmean r u n n i n g v a r running_{var} runningvar,并将它们存储为模型的缓冲区(buffers)。
    在训练阶段:训练模式下( m o d e l . t r a i n ( ) model.train() model.train()),每次对 mini-batch 进行前向传播时:

r u n n i n g m e a n = m o m e n t u m ⋅ r u n n i n g m e a n + ( 1 − m o m e n t u m ) ⋅ μ b a t c h r u n n i n g v a r = m o m e n t u m ⋅ r u n n i n g v a r + ( 1 − m o m e n t u m ) ⋅ σ b a t c h 2 r u n n i n g v a r = m o m e n t u m ⋅ r u n n i n g v a r + ( 1 − m o m e n t u m ) ⋅ σ b a t c h 2 running_{mean}=momentum⋅running_mean+(1−momentum)⋅μ_{batch} \\ running_{var}=momentum⋅running_var+(1−momentum)⋅σ_{batch}^2 \\ running_{var}=momentum⋅running_var+(1−momentum)⋅σ_{batch}^2 runningmean=momentumrunningmean+(1momentum)μbatchrunningvar=momentumrunningvar+(1momentum)σbatch2runningvar=momentumrunningvar+(1momentum)σbatch2
    在推理阶段:推理模型下( m o d e l . e v a l ( ) model.eval() model.eval()),模型会使用 r u n n i n g m e a n running_{mean} runningmean r u n n i n g v a r running_{var} runningvar进行归一化,从而避免上述问题

layernorm

原理

    对于一个时序特征tensor

x ∈ R b × t × d x \in \mathbb{R}^{b \times t \times d} xRb×t×d

  其中, t t t是时序, d d d是某一时间点的各种特征,layernorm是对某一时间点的特征进行归一化,即对于第 i i i个时间点,分别在各个batch上选出他的所有特征 x [ k , i , : ] x[k,i,:] x[k,i,:]进行归一化:
μ = 1 d ∑ i = 1 d x i σ 2 = 1 d ∑ i = 1 d ( x i − μ ) 2 x ^ i = x i − μ σ 2 + ϵ y i = γ d x ^ i + β d \mu = \frac{1}{d} \sum_{i=1}^d x_i \\ \sigma^2 = \frac{1}{d} \sum_{i=1}^d (x_i - \mu)^2 \\ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ y_i = \gamma_{d} \hat{x}_i + \beta_{d} μ=d1i=1dxiσ2=d1i=1d(xiμ)2x^i=σ2+ϵ xiμyi=γdx^i+βd
  其中, ϵ \epsilon ϵ是较小的数,为了防止分母为0, γ d \gamma_{d} γd β d \beta_{d} βd是可训练参数,用于仿射变换
  因此,输出向量形状不发生改变,可训练参数取决于特征维度$d$的个数

代码

import torch
import torch.nn as nn
# 特征维度 d=16,归一化最后一个维度
layernorm = nn.LayerNorm(normalized_shape=16)# 打印可训练参数
print("Trainable parameters (gamma):", layernorm.weight.shape)  # gamma (scale) (16)
print("Trainable parameters (beta):", layernorm.bias.shape)    # beta (shift) (16)# 构造输入张量 (b, t, d)
x = torch.randn(32, 10, 16) # 前向传播
output = layernorm(x)# 打印结果
print("Input shape:", x.shape)       # [32, 10, 16]
print("Output shape:", output.shape) # [32, 10, 16]

作用

  • 加速模型收敛
  • 缓解梯度消失或爆炸问题
  • 降低对初始化的敏感性
  • 一定程度上防止过拟合

通俗理解

  从时序数据(或者一段话)数据 x ∈ R b × t × d x \in \mathbb{R}^{b \times t \times d} xRb×t×d来理解:对于不同时间点的特征,我们经过layernorm以后,不同时间点的特征便失去了可比性或者相对性(时间点1:9,8,7,6;时间点2:8,7,6,5经过归一化后都一样),。。。目前我也不知道怎么直观理解(后续有新的理解持续更新。。)

注意点

和batchnorm一样,在pytorch中,训练和推理时,一定要指定模型当前状态

总结

BatchNorm适用于CV,LayerNorm适用于NLP

本文是我学习过程中的个人理解,有不对的地方希望大家帮忙指出。希望可以抛砖引玉,欢迎大家在评论区和我交流。

相关文章:

batchnorm和layernorm的理解

batchnorm和layernorm原理和区别 batchnorm 原理 对于一个特征tensor x ∈ R b c f 1 f 2 … x \in \mathbb{R}^{b \times c \times f_1 \times f_2 \times \dots} x∈Rbcf1​f2​… 其中, c c c是通道, f f f是通道中各种特征,batchno…...

在git commit之前让其自动执行一次git pull命令

文章目录 背景原因编写脚本测试效果 背景原因 有时候可以看到项目的git 提交日志里好多 Merge branch ‘master’ of …记录。这些记录是怎么产生的呢? 是因为在本地操作 git add . 、 git commit -m "xxxxx"时,没有提前进行git pull操作&…...

【Rust自学】6.3. 控制流运算符-match

喜欢的话别忘了点赞、收藏加关注哦,对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 6.3.1. 什么是match match允许一个值与一系列模式进行匹配,并执行匹配的模式对应的代码。模式可以是字面值、变量名、通配符等…...

大模型应用技术系列(三): 深入理解大模型应用中的Cache:GPTCache

前言 无论在什么技术栈中,缓存都是比较重要的一部分。在大模型技术栈中,缓存存在于技术栈中的不同层次。本文将主要聚焦于技术栈中应用层和底层基座之间中间件层的缓存(个人定位),以开源项目GPTCache(LLM的语义缓存)为例,深入讲解这部分缓存的结构和关键实现。 完整技术…...

『大模型笔记』评估大型语言模型的指标:ELO评分,BLEU,困惑度和交叉熵介绍以及举例解释

评估大型语言模型的指标:ELO评分,BLEU,困惑度和交叉熵介绍以及举例解释 文章目录 一. ELO Rating大模型的elo得分如何理解1. Elo评分的基本原理2. 示例说明3. 大模型中的Elo得分总结3个模型之间如何比较计算,给出示例进行解释1. 基本原理扩展到三方2. 示例计算第一场: A A…...

深度解析:Maven 和 Gradle 的使用比较及常见仓库推荐

Maven 和 Gradle 是 Java 项目中最常用的构建工具。它们各有优势,适用于不同的场景。本文将对两者进行详细的对比,并推荐一些常用的 Maven 和 Gradle 仓库,帮助开发者高效管理依赖。 一、Maven 和 Gradle 的使用比较 1.1 基本介绍 Maven 基…...

SQLite本地数据库的简介和适用场景——集成SpringBoot的图文说明

前言:现在项目普遍使用的数据库都是MySQL,而有些项目实际上使用SQLite既足矣。在一些特定的项目中,要比MySQL更适用。 这一篇文章简单的介绍一下SQLite,对比MySQL的优缺点、以及适用的项目类型和集成SpringBoot。 1. SQLite 简介 …...

管理面板Ajenti的在Windows10下Ubuntu24.04/Ubuntu22.04里的安装

Ajenti是一款基于Web的开源系统管理控制面板,可用于通过Web浏览器,管理远程系统管理性任务,这一点与 Webmin模块 非常相似。 Ajenti是一款功能非常强大的轻型工具,它提供了快速的、反应灵敏的Web界面,可用于管理小型服…...

在Python如何用Type创建类

文章目录 一,如何创建类1:创建一个简单类2:添加属性和方法3:动态继承父类4:结合元类的使用总结 二.在什么情境下适合使用Type创建类1. **运行时动态生成类**2. **避免重复代码**3. **依赖元类或高级元编程**4. **动态扩…...

Android学习19 -- NDK4--共享内存(TODO)

在安卓的NDK(Native Development Kit)中,C共享内存通常用于不同进程间的通信,或者在同一进程中多线程之间共享数据。这种方法相较于其他形式的IPC(进程间通信)来说,具有更高的性能和低延迟。共享…...

《Cocos Creator游戏实战》非固定摇杆实现原理

为什么要使用非固定摇杆 许多同学在开发摇杆功能时,会将摇杆固定在屏幕左下某一位置,不会让其随着大拇指触摸点改变,而且玩家只有按在了摇杆上才能移动人物(触摸监听事件在摇杆精灵上)。然而,不同玩家的大拇指长度不同…...

RabbitMQ工作模式(详解 工作模式:简单队列、工作队列、公平分发以及消息应答和消息持久化)

文章目录 十.RabbitMQ10.1 简单队列实现10.2 Work 模式(工作队列)10.3 公平分发10.4 RabbitMQ 消息应答与消息持久化消息应答概念配置 消息持久化概念配置 十.RabbitMQ 10.1 简单队列实现 简单队列通常指的是一个基本的消息队列,它可以用于…...

【VScode】第三方GPT编程工具-CodeMoss安装教程

一、CodeMoss是什么? CodeMoss是一款集编程、学习和办公于一体的高效工具。它兼容多种主流平台,包括VSCode、IDER、Chrome插件、Web和APP等,支持插件安装,尤其在VSCode和IDER上的表现尤为出色。无论你是编程新手还是资深开发者&a…...

在JavaScript中,let 和 const有什么不同

在JavaScript中,let 和 const 是用于声明变量的关键字,但它们有一些重要的区别 1.重新赋值: let 声明的变量可以重新赋值。const 声明的变量必须在声明时初始化,并且之后不能重新赋值 let a 10; a 20; // 有效,a 的…...

Mysq学习-Mysql查询(4)

5.子查询 子查询指一个查询语句嵌套在另一个查询语句内部的查询,这个特性从MySQL4.1开始引入.在SELECT子句中先计算子查询,子查询结果作为外层另一个查询的过滤条件,查询可以基于一个表或者多个表. 子查询中常用的操作符有ANY(SOME),ALL,IN,EXISTS.子查询可以添加到SELECT,UPD…...

安装torch-geometric库

目录 1.查看 torch 和 CUDA 版本 2.依次下载和 torch 和 CUDA 对应版本的四个依赖库pyg-lib、torch-scatter、torch-sparse、torch-cluster以及torch-spline-conv 3.下载并安装torch-geometric库 1.查看 torch 和 CUDA 版本 查看CUDA版本 nvcc -V 查看pytorch版本 pip s…...

Java数组深入解析:定义、操作、常见问题与高频练习

一、数组的定义 1. 什么是数组 数组是一个容器,用来存储多个相同类型的数据。它属于引用数据类型,可以存储基本数据类型(如int、char)或者引用数据类型(如String、对象)。 2. 数组的定义方式 a. 动态初…...

Docker-构建自己的Web-Linux系统-镜像webtop:ubuntu-kde

介绍 安装自己的linux-server,可以作为学习使用,web方式访问,基于ubuntu构建开源项目 https://github.com/linuxserver/docker-webtop安装 docker run -d -p 1336:3000 -e PASSWORD123456 --name webtop lscr.io/linuxserver/webtop:ubuntu-kde登录 …...

【C语言练习(17)—输出杨辉三角形】

C语言练习(17) 文章目录 C语言练习(17)前言题目题目解析整体代码 前言 杨辉三角形的输出可以分三步,第一步构建一个三角形、第二步根据规律将三角形内容填写、第三步将三角形以等腰的形式输出 题目 请输出一个十行的…...

SpringMVC学习(二)——RESTful API、拦截器、异常处理、数据类型转换

一、RESTful (一)RESTful概述 RESTful是一种软件架构风格,用于设计网络应用程序。REST是“Representational State Transfer”的缩写,中文意思是“表现层状态转移”。它基于客户端-服务器模型和无状态操作,以及使用HTTP请求来处理数据。RES…...

Qt Widget类解析与代码注释

#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...

基于服务器使用 apt 安装、配置 Nginx

🧾 一、查看可安装的 Nginx 版本 首先,你可以运行以下命令查看可用版本: apt-cache madison nginx-core输出示例: nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

反射获取方法和属性

Java反射获取方法 在Java中,反射(Reflection)是一种强大的机制,允许程序在运行时访问和操作类的内部属性和方法。通过反射,可以动态地创建对象、调用方法、改变属性值,这在很多Java框架中如Spring和Hiberna…...

WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)

一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解,适合用作学习或写简历项目背景说明。 🧠 一、概念简介:Solidity 合约开发 Solidity 是一种专门为 以太坊(Ethereum)平台编写智能合约的高级编…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现(两者等价),用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例: 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南

🚀 C extern 关键字深度解析:跨文件编程的终极指南 📅 更新时间:2025年6月5日 🏷️ 标签:C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言🔥一、extern 是什么?&…...

JAVA后端开发——多租户

数据隔离是多租户系统中的核心概念,确保一个租户(在这个系统中可能是一个公司或一个独立的客户)的数据对其他租户是不可见的。在 RuoYi 框架(您当前项目所使用的基础框架)中,这通常是通过在数据表中增加一个…...

浪潮交换机配置track检测实现高速公路收费网络主备切换NQA

浪潮交换机track配置 项目背景高速网络拓扑网络情况分析通信线路收费网络路由 收费汇聚交换机相应配置收费汇聚track配置 项目背景 在实施省内一条高速公路时遇到的需求,本次涉及的主要是收费汇聚交换机的配置,浪潮网络设备在高速项目很少,通…...

推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材)

推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材) 这个项目能干嘛? 使用 gemini 2.0 的 api 和 google 其他的 api 来做衍生处理 简化和优化了文生图和图生图的行为(我的最主要) 并且有一些目标检测和切割(我用不到) 视频和 imagefx 因为没 a…...

音视频——I2S 协议详解

I2S 协议详解 I2S (Inter-IC Sound) 协议是一种串行总线协议,专门用于在数字音频设备之间传输数字音频数据。它由飞利浦(Philips)公司开发,以其简单、高效和广泛的兼容性而闻名。 1. 信号线 I2S 协议通常使用三根或四根信号线&a…...