当前位置: 首页 > news >正文

batchnorm和layernorm的理解

batchnorm和layernorm原理和区别

batchnorm

原理

  对于一个特征tensor

x ∈ R b × c × f 1 × f 2 × … x \in \mathbb{R}^{b \times c \times f_1 \times f_2 \times \dots} xRb×c×f1×f2×

  其中, c c c是通道, f f f是通道中各种特征,batchnorm是对所有batch的每个通道特征分别进行归一化,即对于第 i i i个通道,选出他的所有batch的第 i i i个通道的所有特征 x [ : , i , : , : , . . . ] x[:,i,:,:,...] x[:,i,:,:,...]进行归一化:

μ = 1 m ∑ i = 1 m x i σ 2 = 1 m ∑ i = 1 m ( x i − μ ) 2 x ^ i = x i − μ σ 2 + ϵ y i = γ x ^ i + β \mu = \frac{1}{m} \sum_{i=1}^{m} x_i \\ \sigma^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu)^2 \\ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ y_i = \gamma \hat{x}_i + \beta μ=m1i=1mxiσ2=m1i=1m(xiμ)2x^i=σ2+ϵ xiμyi=γx^i+β
  其中, ϵ \epsilon ϵ是较小的数,为了防止分母为0, γ \gamma γ β \beta β是可训练参数,用于仿射变换
  因此,输出向量形状不发生改变,可训练参数取决于通道 c c c的个数

代码

import torch
import torch.nn as nn## BatchNorm1d
batchnorm1d = nn.BatchNorm1d(num_features=10)  # 输入特征数为 10
# 构造输入张量 (N, C)
x_1d = torch.randn(16, 10)  # 模拟训练数据 (batch_size=16, num_features=10)
y = batchnorm1d(x)
# 打印结果
print("Output shape :", y.shape)  # (N, C)## BatchNorm2d
batchnorm2d = nn.BatchNorm2d(num_features=16)  # 输入通道数为 16
# 构造输入张量 (N, C, H, W)
x = torch.randn(8, 16, 32, 32)  # 模拟训练数据
y = batchnorm2d(x)
# 打印结果
print("Output shape :", y.shape)  # (N, C, H, W)## BatchNorm3d
batchnorm3d = nn.BatchNorm3d(num_features=8)  # 输入通道数为 8
# 构造输入张量 (N, C, D, H, W)
x = torch.randn(4, 8, 16, 32, 32)  # 模拟训练数据
y = batchnorm3d(x)
# 打印结果
print("Output shape :", y.shape)  # (N, C, D, H, W)

作用

  • 加速模型收敛
  • 缓解梯度消失或爆炸问题
  • 降低对初始化的敏感性
  • 一定程度上防止过拟合

通俗理解

    从图片数据 x ∈ R b × c × h × w x \in \mathbb{R}^{b \times c \times h \times w} xRb×c×h×w来理解:对于不同通道的特征,我们经过batchnorm以后,不同通道的特征便失去了可比性或者相对性(通道一:9,8,7,6;通道二:8,7,6,5经过归一化后都一样),但是从人的视角来区分一张图像不是靠不同通道之间的差异性,换句话说RGB稍微变一下变成RRR一样可以区分出图片中物体、属于哪一类。因此我们使用batchnorm相当于告诉模型你应该从同一通道内的数据差异来得到结果,而不是通道间。
    其实你不告诉模型(不用batchnorm),模型在多次训练后也可以发现相同的规律,只是模型训练就慢了,因此我们说batchnorm可以加速收敛

注意点

    模型训练阶段,我们存在batch维度,使用batchnorm时模型的正向计算会受到batch大小影响。那么我们在推理阶段通常只有一个batch,怎么办呢?
    解决方法:在pytorch中,模型训练时,调用 m o d e l . t r a i n ( ) model.train() model.train(),当创建一个 BatchNorm 层时,PyTorch 会自动初始化 r u n n i n g m e a n running_{mean} runningmean r u n n i n g v a r running_{var} runningvar,并将它们存储为模型的缓冲区(buffers)。
    在训练阶段:训练模式下( m o d e l . t r a i n ( ) model.train() model.train()),每次对 mini-batch 进行前向传播时:

r u n n i n g m e a n = m o m e n t u m ⋅ r u n n i n g m e a n + ( 1 − m o m e n t u m ) ⋅ μ b a t c h r u n n i n g v a r = m o m e n t u m ⋅ r u n n i n g v a r + ( 1 − m o m e n t u m ) ⋅ σ b a t c h 2 r u n n i n g v a r = m o m e n t u m ⋅ r u n n i n g v a r + ( 1 − m o m e n t u m ) ⋅ σ b a t c h 2 running_{mean}=momentum⋅running_mean+(1−momentum)⋅μ_{batch} \\ running_{var}=momentum⋅running_var+(1−momentum)⋅σ_{batch}^2 \\ running_{var}=momentum⋅running_var+(1−momentum)⋅σ_{batch}^2 runningmean=momentumrunningmean+(1momentum)μbatchrunningvar=momentumrunningvar+(1momentum)σbatch2runningvar=momentumrunningvar+(1momentum)σbatch2
    在推理阶段:推理模型下( m o d e l . e v a l ( ) model.eval() model.eval()),模型会使用 r u n n i n g m e a n running_{mean} runningmean r u n n i n g v a r running_{var} runningvar进行归一化,从而避免上述问题

layernorm

原理

    对于一个时序特征tensor

x ∈ R b × t × d x \in \mathbb{R}^{b \times t \times d} xRb×t×d

  其中, t t t是时序, d d d是某一时间点的各种特征,layernorm是对某一时间点的特征进行归一化,即对于第 i i i个时间点,分别在各个batch上选出他的所有特征 x [ k , i , : ] x[k,i,:] x[k,i,:]进行归一化:
μ = 1 d ∑ i = 1 d x i σ 2 = 1 d ∑ i = 1 d ( x i − μ ) 2 x ^ i = x i − μ σ 2 + ϵ y i = γ d x ^ i + β d \mu = \frac{1}{d} \sum_{i=1}^d x_i \\ \sigma^2 = \frac{1}{d} \sum_{i=1}^d (x_i - \mu)^2 \\ \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\ y_i = \gamma_{d} \hat{x}_i + \beta_{d} μ=d1i=1dxiσ2=d1i=1d(xiμ)2x^i=σ2+ϵ xiμyi=γdx^i+βd
  其中, ϵ \epsilon ϵ是较小的数,为了防止分母为0, γ d \gamma_{d} γd β d \beta_{d} βd是可训练参数,用于仿射变换
  因此,输出向量形状不发生改变,可训练参数取决于特征维度$d$的个数

代码

import torch
import torch.nn as nn
# 特征维度 d=16,归一化最后一个维度
layernorm = nn.LayerNorm(normalized_shape=16)# 打印可训练参数
print("Trainable parameters (gamma):", layernorm.weight.shape)  # gamma (scale) (16)
print("Trainable parameters (beta):", layernorm.bias.shape)    # beta (shift) (16)# 构造输入张量 (b, t, d)
x = torch.randn(32, 10, 16) # 前向传播
output = layernorm(x)# 打印结果
print("Input shape:", x.shape)       # [32, 10, 16]
print("Output shape:", output.shape) # [32, 10, 16]

作用

  • 加速模型收敛
  • 缓解梯度消失或爆炸问题
  • 降低对初始化的敏感性
  • 一定程度上防止过拟合

通俗理解

  从时序数据(或者一段话)数据 x ∈ R b × t × d x \in \mathbb{R}^{b \times t \times d} xRb×t×d来理解:对于不同时间点的特征,我们经过layernorm以后,不同时间点的特征便失去了可比性或者相对性(时间点1:9,8,7,6;时间点2:8,7,6,5经过归一化后都一样),。。。目前我也不知道怎么直观理解(后续有新的理解持续更新。。)

注意点

和batchnorm一样,在pytorch中,训练和推理时,一定要指定模型当前状态

总结

BatchNorm适用于CV,LayerNorm适用于NLP

本文是我学习过程中的个人理解,有不对的地方希望大家帮忙指出。希望可以抛砖引玉,欢迎大家在评论区和我交流。

相关文章:

batchnorm和layernorm的理解

batchnorm和layernorm原理和区别 batchnorm 原理 对于一个特征tensor x ∈ R b c f 1 f 2 … x \in \mathbb{R}^{b \times c \times f_1 \times f_2 \times \dots} x∈Rbcf1​f2​… 其中, c c c是通道, f f f是通道中各种特征,batchno…...

在git commit之前让其自动执行一次git pull命令

文章目录 背景原因编写脚本测试效果 背景原因 有时候可以看到项目的git 提交日志里好多 Merge branch ‘master’ of …记录。这些记录是怎么产生的呢? 是因为在本地操作 git add . 、 git commit -m "xxxxx"时,没有提前进行git pull操作&…...

【Rust自学】6.3. 控制流运算符-match

喜欢的话别忘了点赞、收藏加关注哦,对接下来的教程有兴趣的可以关注专栏。谢谢喵!(・ω・) 6.3.1. 什么是match match允许一个值与一系列模式进行匹配,并执行匹配的模式对应的代码。模式可以是字面值、变量名、通配符等…...

大模型应用技术系列(三): 深入理解大模型应用中的Cache:GPTCache

前言 无论在什么技术栈中,缓存都是比较重要的一部分。在大模型技术栈中,缓存存在于技术栈中的不同层次。本文将主要聚焦于技术栈中应用层和底层基座之间中间件层的缓存(个人定位),以开源项目GPTCache(LLM的语义缓存)为例,深入讲解这部分缓存的结构和关键实现。 完整技术…...

『大模型笔记』评估大型语言模型的指标:ELO评分,BLEU,困惑度和交叉熵介绍以及举例解释

评估大型语言模型的指标:ELO评分,BLEU,困惑度和交叉熵介绍以及举例解释 文章目录 一. ELO Rating大模型的elo得分如何理解1. Elo评分的基本原理2. 示例说明3. 大模型中的Elo得分总结3个模型之间如何比较计算,给出示例进行解释1. 基本原理扩展到三方2. 示例计算第一场: A A…...

深度解析:Maven 和 Gradle 的使用比较及常见仓库推荐

Maven 和 Gradle 是 Java 项目中最常用的构建工具。它们各有优势,适用于不同的场景。本文将对两者进行详细的对比,并推荐一些常用的 Maven 和 Gradle 仓库,帮助开发者高效管理依赖。 一、Maven 和 Gradle 的使用比较 1.1 基本介绍 Maven 基…...

SQLite本地数据库的简介和适用场景——集成SpringBoot的图文说明

前言:现在项目普遍使用的数据库都是MySQL,而有些项目实际上使用SQLite既足矣。在一些特定的项目中,要比MySQL更适用。 这一篇文章简单的介绍一下SQLite,对比MySQL的优缺点、以及适用的项目类型和集成SpringBoot。 1. SQLite 简介 …...

管理面板Ajenti的在Windows10下Ubuntu24.04/Ubuntu22.04里的安装

Ajenti是一款基于Web的开源系统管理控制面板,可用于通过Web浏览器,管理远程系统管理性任务,这一点与 Webmin模块 非常相似。 Ajenti是一款功能非常强大的轻型工具,它提供了快速的、反应灵敏的Web界面,可用于管理小型服…...

在Python如何用Type创建类

文章目录 一,如何创建类1:创建一个简单类2:添加属性和方法3:动态继承父类4:结合元类的使用总结 二.在什么情境下适合使用Type创建类1. **运行时动态生成类**2. **避免重复代码**3. **依赖元类或高级元编程**4. **动态扩…...

Android学习19 -- NDK4--共享内存(TODO)

在安卓的NDK(Native Development Kit)中,C共享内存通常用于不同进程间的通信,或者在同一进程中多线程之间共享数据。这种方法相较于其他形式的IPC(进程间通信)来说,具有更高的性能和低延迟。共享…...

《Cocos Creator游戏实战》非固定摇杆实现原理

为什么要使用非固定摇杆 许多同学在开发摇杆功能时,会将摇杆固定在屏幕左下某一位置,不会让其随着大拇指触摸点改变,而且玩家只有按在了摇杆上才能移动人物(触摸监听事件在摇杆精灵上)。然而,不同玩家的大拇指长度不同…...

RabbitMQ工作模式(详解 工作模式:简单队列、工作队列、公平分发以及消息应答和消息持久化)

文章目录 十.RabbitMQ10.1 简单队列实现10.2 Work 模式(工作队列)10.3 公平分发10.4 RabbitMQ 消息应答与消息持久化消息应答概念配置 消息持久化概念配置 十.RabbitMQ 10.1 简单队列实现 简单队列通常指的是一个基本的消息队列,它可以用于…...

【VScode】第三方GPT编程工具-CodeMoss安装教程

一、CodeMoss是什么? CodeMoss是一款集编程、学习和办公于一体的高效工具。它兼容多种主流平台,包括VSCode、IDER、Chrome插件、Web和APP等,支持插件安装,尤其在VSCode和IDER上的表现尤为出色。无论你是编程新手还是资深开发者&a…...

在JavaScript中,let 和 const有什么不同

在JavaScript中,let 和 const 是用于声明变量的关键字,但它们有一些重要的区别 1.重新赋值: let 声明的变量可以重新赋值。const 声明的变量必须在声明时初始化,并且之后不能重新赋值 let a 10; a 20; // 有效,a 的…...

Mysq学习-Mysql查询(4)

5.子查询 子查询指一个查询语句嵌套在另一个查询语句内部的查询,这个特性从MySQL4.1开始引入.在SELECT子句中先计算子查询,子查询结果作为外层另一个查询的过滤条件,查询可以基于一个表或者多个表. 子查询中常用的操作符有ANY(SOME),ALL,IN,EXISTS.子查询可以添加到SELECT,UPD…...

安装torch-geometric库

目录 1.查看 torch 和 CUDA 版本 2.依次下载和 torch 和 CUDA 对应版本的四个依赖库pyg-lib、torch-scatter、torch-sparse、torch-cluster以及torch-spline-conv 3.下载并安装torch-geometric库 1.查看 torch 和 CUDA 版本 查看CUDA版本 nvcc -V 查看pytorch版本 pip s…...

Java数组深入解析:定义、操作、常见问题与高频练习

一、数组的定义 1. 什么是数组 数组是一个容器,用来存储多个相同类型的数据。它属于引用数据类型,可以存储基本数据类型(如int、char)或者引用数据类型(如String、对象)。 2. 数组的定义方式 a. 动态初…...

Docker-构建自己的Web-Linux系统-镜像webtop:ubuntu-kde

介绍 安装自己的linux-server,可以作为学习使用,web方式访问,基于ubuntu构建开源项目 https://github.com/linuxserver/docker-webtop安装 docker run -d -p 1336:3000 -e PASSWORD123456 --name webtop lscr.io/linuxserver/webtop:ubuntu-kde登录 …...

【C语言练习(17)—输出杨辉三角形】

C语言练习(17) 文章目录 C语言练习(17)前言题目题目解析整体代码 前言 杨辉三角形的输出可以分三步,第一步构建一个三角形、第二步根据规律将三角形内容填写、第三步将三角形以等腰的形式输出 题目 请输出一个十行的…...

SpringMVC学习(二)——RESTful API、拦截器、异常处理、数据类型转换

一、RESTful (一)RESTful概述 RESTful是一种软件架构风格,用于设计网络应用程序。REST是“Representational State Transfer”的缩写,中文意思是“表现层状态转移”。它基于客户端-服务器模型和无状态操作,以及使用HTTP请求来处理数据。RES…...

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具,该工具基于TUN接口实现其功能,利用反向TCP/TLS连接建立一条隐蔽的通信信道,支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式,适应复杂网…...

Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)

概述 在 Swift 开发语言中,各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过,在涉及到多个子类派生于基类进行多态模拟的场景下,…...

连锁超市冷库节能解决方案:如何实现超市降本增效

在连锁超市冷库运营中,高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术,实现年省电费15%-60%,且不改动原有装备、安装快捷、…...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类:块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

Python如何给视频添加音频和字幕

在Python中,给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加,包括必要的代码示例和详细解释。 环境准备 在开始之前,需要安装以下Python库:…...

Spring是如何解决Bean的循环依赖:三级缓存机制

1、什么是 Bean 的循环依赖 在 Spring框架中,Bean 的循环依赖是指多个 Bean 之间‌互相持有对方引用‌,形成闭环依赖关系的现象。 多个 Bean 的依赖关系构成环形链路,例如: 双向依赖:Bean A 依赖 Bean B,同时 Bean B 也依赖 Bean A(A↔B)。链条循环: Bean A → Bean…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行

项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战,克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

Go 并发编程基础:通道(Channel)的使用

在 Go 中,Channel 是 Goroutine 之间通信的核心机制。它提供了一个线程安全的通信方式,用于在多个 Goroutine 之间传递数据,从而实现高效的并发编程。 本章将介绍 Channel 的基本概念、用法、缓冲、关闭机制以及 select 的使用。 一、Channel…...

Java求职者面试指南:计算机基础与源码原理深度解析

Java求职者面试指南:计算机基础与源码原理深度解析 第一轮提问:基础概念问题 1. 请解释什么是进程和线程的区别? 面试官:进程是程序的一次执行过程,是系统进行资源分配和调度的基本单位;而线程是进程中的…...

力扣热题100 k个一组反转链表题解

题目: 代码: func reverseKGroup(head *ListNode, k int) *ListNode {cur : headfor i : 0; i < k; i {if cur nil {return head}cur cur.Next}newHead : reverse(head, cur)head.Next reverseKGroup(cur, k)return newHead }func reverse(start, end *ListNode) *ListN…...