当前位置: 首页 > news >正文

batchnorm和layernorm的理解

batchnorm和layernorm原理和区别

batchnorm

原理

  对于一个特征tensor

x ∈ R b × c × f 1 × f 2 × … x \in \mathbb{R}^{b \times c \times f_1 \times f_2 \times \dots} xRb×c×f1×f2×

  其中, c c c是通道, f f f是通道中各种特征,batchnorm是对所有batch的每个通道特征分别进行归一化,即对于第 i i i个通道,选出他的所有batch的第 i i i个通道的所有特征 x [ : , i , : , : , . . . ] x[:,i,:,:,...] x[:,i,:,:,...]进行归一化:

μ = 1 m ∑ i = 1 m x i σ 2 = 1 m ∑ i = 1 m ( x i − μ ) 2 x ^ i = x i − μ σ 2 + ϵ y i = γ x ^ i + β \mu = \frac{1}{m} \sum_{i=1}^{m} x_i \\ \sigma^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu)^2 \\ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ y_i = \gamma \hat{x}_i + \beta μ=m1i=1mxiσ2=m1i=1m(xiμ)2x^i=σ2+ϵ xiμyi=γx^i+β
  其中, ϵ \epsilon ϵ是较小的数,为了防止分母为0, γ \gamma γ β \beta β是可训练参数,用于仿射变换
  因此,输出向量形状不发生改变,可训练参数取决于通道 c c c的个数

代码

import torch
import torch.nn as nn## BatchNorm1d
batchnorm1d = nn.BatchNorm1d(num_features=10)  # 输入特征数为 10
# 构造输入张量 (N, C)
x_1d = torch.randn(16, 10)  # 模拟训练数据 (batch_size=16, num_features=10)
y = batchnorm1d(x)
# 打印结果
print("Output shape :", y.shape)  # (N, C)## BatchNorm2d
batchnorm2d = nn.BatchNorm2d(num_features=16)  # 输入通道数为 16
# 构造输入张量 (N, C, H, W)
x = torch.randn(8, 16, 32, 32)  # 模拟训练数据
y = batchnorm2d(x)
# 打印结果
print("Output shape :", y.shape)  # (N, C, H, W)## BatchNorm3d
batchnorm3d = nn.BatchNorm3d(num_features=8)  # 输入通道数为 8
# 构造输入张量 (N, C, D, H, W)
x = torch.randn(4, 8, 16, 32, 32)  # 模拟训练数据
y = batchnorm3d(x)
# 打印结果
print("Output shape :", y.shape)  # (N, C, D, H, W)

作用

  • 加速模型收敛
  • 缓解梯度消失或爆炸问题
  • 降低对初始化的敏感性
  • 一定程度上防止过拟合

通俗理解

    从图片数据 x ∈ R b × c × h × w x \in \mathbb{R}^{b \times c \times h \times w} xRb×c×h×w来理解:对于不同通道的特征,我们经过batchnorm以后,不同通道的特征便失去了可比性或者相对性(通道一:9,8,7,6;通道二:8,7,6,5经过归一化后都一样),但是从人的视角来区分一张图像不是靠不同通道之间的差异性,换句话说RGB稍微变一下变成RRR一样可以区分出图片中物体、属于哪一类。因此我们使用batchnorm相当于告诉模型你应该从同一通道内的数据差异来得到结果,而不是通道间。
    其实你不告诉模型(不用batchnorm),模型在多次训练后也可以发现相同的规律,只是模型训练就慢了,因此我们说batchnorm可以加速收敛

注意点

    模型训练阶段,我们存在batch维度,使用batchnorm时模型的正向计算会受到batch大小影响。那么我们在推理阶段通常只有一个batch,怎么办呢?
    解决方法:在pytorch中,模型训练时,调用 m o d e l . t r a i n ( ) model.train() model.train(),当创建一个 BatchNorm 层时,PyTorch 会自动初始化 r u n n i n g m e a n running_{mean} runningmean r u n n i n g v a r running_{var} runningvar,并将它们存储为模型的缓冲区(buffers)。
    在训练阶段:训练模式下( m o d e l . t r a i n ( ) model.train() model.train()),每次对 mini-batch 进行前向传播时:

r u n n i n g m e a n = m o m e n t u m ⋅ r u n n i n g m e a n + ( 1 − m o m e n t u m ) ⋅ μ b a t c h r u n n i n g v a r = m o m e n t u m ⋅ r u n n i n g v a r + ( 1 − m o m e n t u m ) ⋅ σ b a t c h 2 r u n n i n g v a r = m o m e n t u m ⋅ r u n n i n g v a r + ( 1 − m o m e n t u m ) ⋅ σ b a t c h 2 running_{mean}=momentum⋅running_mean+(1−momentum)⋅μ_{batch} \\ running_{var}=momentum⋅running_var+(1−momentum)⋅σ_{batch}^2 \\ running_{var}=momentum⋅running_var+(1−momentum)⋅σ_{batch}^2 runningmean=momentumrunningmean+(1momentum)μbatchrunningvar=momentumrunningvar+(1momentum)σbatch2runningvar=momentumrunningvar+(1momentum)σbatch2
    在推理阶段:推理模型下( m o d e l . e v a l ( ) model.eval() model.eval()),模型会使用 r u n n i n g m e a n running_{mean} runningmean r u n n i n g v a r running_{var} runningvar进行归一化,从而避免上述问题

layernorm

原理

    对于一个时序特征tensor

x ∈ R b × t × d x \in \mathbb{R}^{b \times t \times d} xRb×t×d

  其中, t t t是时序, d d d是某一时间点的各种特征,layernorm是对某一时间点的特征进行归一化,即对于第 i i i个时间点,分别在各个batch上选出他的所有特征 x [ k , i , : ] x[k,i,:] x[k,i,:]进行归一化:
μ = 1 d ∑ i = 1 d x i σ 2 = 1 d ∑ i = 1 d ( x i − μ ) 2 x ^ i = x i − μ σ 2 + ϵ y i = γ d x ^ i + β d \mu = \frac{1}{d} \sum_{i=1}^d x_i \\ \sigma^2 = \frac{1}{d} \sum_{i=1}^d (x_i - \mu)^2 \\ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ y_i = \gamma_{d} \hat{x}_i + \beta_{d} μ=d1i=1dxiσ2=d1i=1d(xiμ)2x^i=σ2+ϵ xiμyi=γdx^i+βd
  其中, ϵ \epsilon ϵ是较小的数,为了防止分母为0, γ d \gamma_{d} γd β d \beta_{d} βd是可训练参数,用于仿射变换
  因此,输出向量形状不发生改变,可训练参数取决于特征维度$d$的个数

代码

import torch
import torch.nn as nn
# 特征维度 d=16,归一化最后一个维度
layernorm = nn.LayerNorm(normalized_shape=16)# 打印可训练参数
print("Trainable parameters (gamma):", layernorm.weight.shape)  # gamma (scale) (16)
print("Trainable parameters (beta):", layernorm.bias.shape)    # beta (shift) (16)# 构造输入张量 (b, t, d)
x = torch.randn(32, 10, 16) # 前向传播
output = layernorm(x)# 打印结果
print("Input shape:", x.shape)       # [32, 10, 16]
print("Output shape:", output.shape) # [32, 10, 16]

作用

  • 加速模型收敛
  • 缓解梯度消失或爆炸问题
  • 降低对初始化的敏感性
  • 一定程度上防止过拟合

通俗理解

  从时序数据(或者一段话)数据 x ∈ R b × t × d x \in \mathbb{R}^{b \times t \times d} xRb×t×d来理解:对于不同时间点的特征,我们经过layernorm以后,不同时间点的特征便失去了可比性或者相对性(时间点1:9,8,7,6;时间点2:8,7,6,5经过归一化后都一样),。。。目前我也不知道怎么直观理解(后续有新的理解持续更新。。)

注意点

和batchnorm一样,在pytorch中,训练和推理时,一定要指定模型当前状态

总结

BatchNorm适用于CV,LayerNorm适用于NLP

本文是我学习过程中的个人理解,有不对的地方希望大家帮忙指出。希望可以抛砖引玉,欢迎大家在评论区和我交流。

相关文章:

batchnorm和layernorm的理解

batchnorm和layernorm原理和区别 batchnorm 原理 对于一个特征tensor x ∈ R b c f 1 f 2 … x \in \mathbb{R}^{b \times c \times f_1 \times f_2 \times \dots} x∈Rbcf1​f2​… 其中, c c c是通道, f f f是通道中各种特征,batchno…...

在git commit之前让其自动执行一次git pull命令

文章目录 背景原因编写脚本测试效果 背景原因 有时候可以看到项目的git 提交日志里好多 Merge branch ‘master’ of …记录。这些记录是怎么产生的呢? 是因为在本地操作 git add . 、 git commit -m "xxxxx"时,没有提前进行git pull操作&…...

【Rust自学】6.3. 控制流运算符-match

喜欢的话别忘了点赞、收藏加关注哦,对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 6.3.1. 什么是match match允许一个值与一系列模式进行匹配,并执行匹配的模式对应的代码。模式可以是字面值、变量名、通配符等…...

大模型应用技术系列(三): 深入理解大模型应用中的Cache:GPTCache

前言 无论在什么技术栈中,缓存都是比较重要的一部分。在大模型技术栈中,缓存存在于技术栈中的不同层次。本文将主要聚焦于技术栈中应用层和底层基座之间中间件层的缓存(个人定位),以开源项目GPTCache(LLM的语义缓存)为例,深入讲解这部分缓存的结构和关键实现。 完整技术…...

『大模型笔记』评估大型语言模型的指标:ELO评分,BLEU,困惑度和交叉熵介绍以及举例解释

评估大型语言模型的指标:ELO评分,BLEU,困惑度和交叉熵介绍以及举例解释 文章目录 一. ELO Rating大模型的elo得分如何理解1. Elo评分的基本原理2. 示例说明3. 大模型中的Elo得分总结3个模型之间如何比较计算,给出示例进行解释1. 基本原理扩展到三方2. 示例计算第一场: A A…...

深度解析:Maven 和 Gradle 的使用比较及常见仓库推荐

Maven 和 Gradle 是 Java 项目中最常用的构建工具。它们各有优势,适用于不同的场景。本文将对两者进行详细的对比,并推荐一些常用的 Maven 和 Gradle 仓库,帮助开发者高效管理依赖。 一、Maven 和 Gradle 的使用比较 1.1 基本介绍 Maven 基…...

SQLite本地数据库的简介和适用场景——集成SpringBoot的图文说明

前言:现在项目普遍使用的数据库都是MySQL,而有些项目实际上使用SQLite既足矣。在一些特定的项目中,要比MySQL更适用。 这一篇文章简单的介绍一下SQLite,对比MySQL的优缺点、以及适用的项目类型和集成SpringBoot。 1. SQLite 简介 …...

管理面板Ajenti的在Windows10下Ubuntu24.04/Ubuntu22.04里的安装

Ajenti是一款基于Web的开源系统管理控制面板,可用于通过Web浏览器,管理远程系统管理性任务,这一点与 Webmin模块 非常相似。 Ajenti是一款功能非常强大的轻型工具,它提供了快速的、反应灵敏的Web界面,可用于管理小型服…...

在Python如何用Type创建类

文章目录 一,如何创建类1:创建一个简单类2:添加属性和方法3:动态继承父类4:结合元类的使用总结 二.在什么情境下适合使用Type创建类1. **运行时动态生成类**2. **避免重复代码**3. **依赖元类或高级元编程**4. **动态扩…...

Android学习19 -- NDK4--共享内存(TODO)

在安卓的NDK(Native Development Kit)中,C共享内存通常用于不同进程间的通信,或者在同一进程中多线程之间共享数据。这种方法相较于其他形式的IPC(进程间通信)来说,具有更高的性能和低延迟。共享…...

《Cocos Creator游戏实战》非固定摇杆实现原理

为什么要使用非固定摇杆 许多同学在开发摇杆功能时,会将摇杆固定在屏幕左下某一位置,不会让其随着大拇指触摸点改变,而且玩家只有按在了摇杆上才能移动人物(触摸监听事件在摇杆精灵上)。然而,不同玩家的大拇指长度不同…...

RabbitMQ工作模式(详解 工作模式:简单队列、工作队列、公平分发以及消息应答和消息持久化)

文章目录 十.RabbitMQ10.1 简单队列实现10.2 Work 模式(工作队列)10.3 公平分发10.4 RabbitMQ 消息应答与消息持久化消息应答概念配置 消息持久化概念配置 十.RabbitMQ 10.1 简单队列实现 简单队列通常指的是一个基本的消息队列,它可以用于…...

【VScode】第三方GPT编程工具-CodeMoss安装教程

一、CodeMoss是什么? CodeMoss是一款集编程、学习和办公于一体的高效工具。它兼容多种主流平台,包括VSCode、IDER、Chrome插件、Web和APP等,支持插件安装,尤其在VSCode和IDER上的表现尤为出色。无论你是编程新手还是资深开发者&a…...

在JavaScript中,let 和 const有什么不同

在JavaScript中,let 和 const 是用于声明变量的关键字,但它们有一些重要的区别 1.重新赋值: let 声明的变量可以重新赋值。const 声明的变量必须在声明时初始化,并且之后不能重新赋值 let a 10; a 20; // 有效,a 的…...

Mysq学习-Mysql查询(4)

5.子查询 子查询指一个查询语句嵌套在另一个查询语句内部的查询,这个特性从MySQL4.1开始引入.在SELECT子句中先计算子查询,子查询结果作为外层另一个查询的过滤条件,查询可以基于一个表或者多个表. 子查询中常用的操作符有ANY(SOME),ALL,IN,EXISTS.子查询可以添加到SELECT,UPD…...

安装torch-geometric库

目录 1.查看 torch 和 CUDA 版本 2.依次下载和 torch 和 CUDA 对应版本的四个依赖库pyg-lib、torch-scatter、torch-sparse、torch-cluster以及torch-spline-conv 3.下载并安装torch-geometric库 1.查看 torch 和 CUDA 版本 查看CUDA版本 nvcc -V 查看pytorch版本 pip s…...

Java数组深入解析:定义、操作、常见问题与高频练习

一、数组的定义 1. 什么是数组 数组是一个容器,用来存储多个相同类型的数据。它属于引用数据类型,可以存储基本数据类型(如int、char)或者引用数据类型(如String、对象)。 2. 数组的定义方式 a. 动态初…...

Docker-构建自己的Web-Linux系统-镜像webtop:ubuntu-kde

介绍 安装自己的linux-server,可以作为学习使用,web方式访问,基于ubuntu构建开源项目 https://github.com/linuxserver/docker-webtop安装 docker run -d -p 1336:3000 -e PASSWORD123456 --name webtop lscr.io/linuxserver/webtop:ubuntu-kde登录 …...

【C语言练习(17)—输出杨辉三角形】

C语言练习(17) 文章目录 C语言练习(17)前言题目题目解析整体代码 前言 杨辉三角形的输出可以分三步,第一步构建一个三角形、第二步根据规律将三角形内容填写、第三步将三角形以等腰的形式输出 题目 请输出一个十行的…...

SpringMVC学习(二)——RESTful API、拦截器、异常处理、数据类型转换

一、RESTful (一)RESTful概述 RESTful是一种软件架构风格,用于设计网络应用程序。REST是“Representational State Transfer”的缩写,中文意思是“表现层状态转移”。它基于客户端-服务器模型和无状态操作,以及使用HTTP请求来处理数据。RES…...

Vue记事本应用实现教程

文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展:显示创建时间8. 功能扩展:记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言: 通过AI视觉技术,为船厂提供全面的安全监控解决方案,涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面,能够实现对应负责人反馈机制,并最终实现数据的统计报表。提升船厂…...

Admin.Net中的消息通信SignalR解释

定义集线器接口 IOnlineUserHub public interface IOnlineUserHub {/// 在线用户列表Task OnlineUserList(OnlineUserList context);/// 强制下线Task ForceOffline(object context);/// 发布站内消息Task PublicNotice(SysNotice context);/// 接收消息Task ReceiveMessage(…...

Python:操作 Excel 折叠

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

YSYX学习记录(八)

C语言&#xff0c;练习0&#xff1a; 先创建一个文件夹&#xff0c;我用的是物理机&#xff1a; 安装build-essential 练习1&#xff1a; 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件&#xff0c;随机修改或删除一部分&#xff0c;之后…...

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例

文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序

一、开发环境准备 ​​工具安装​​&#xff1a; 下载安装DevEco Studio 4.0&#xff08;支持HarmonyOS 5&#xff09;配置HarmonyOS SDK 5.0确保Node.js版本≥14 ​​项目初始化​​&#xff1a; ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错

出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上&#xff0c;所以报错&#xff0c;到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本&#xff0c;cu、torch、cp 的版本一定要对…...

【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)

1.获取 authorizationCode&#xff1a; 2.利用 authorizationCode 获取 accessToken&#xff1a;文档中心 3.获取手机&#xff1a;文档中心 4.获取昵称头像&#xff1a;文档中心 首先创建 request 若要获取手机号&#xff0c;scope必填 phone&#xff0c;permissions 必填 …...

Mac下Android Studio扫描根目录卡死问题记录

环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中&#xff0c;提示一个依赖外部头文件的cpp源文件需要同步&#xff0c;点…...