当前位置: 首页 > news >正文

从 X 入门Pytorch——BN、LN、IN、GN 四种归一化层的代码使用和原理

Pytorch中四种归一化层的原理和代码使用

    • 前言
    • 1 Batch Normalization(2015年提出)
      • Pytorch官网解释
      • 原理
      • Pytorch代码
      • 示例
    • 2 Layer Normalization(2016年提出)
      • Pytorch官网解释
      • 原理
      • Pytorch代码
      • 示例
    • 3 Instance Normalization(2017年提出)
      • Pytorch官方解释
      • 原理
      • Pytorch代码
      • 示例
    • 4 Group Normalization(2018年提出)
      • Pytorch官方解释
      • 原理
      • Pytorch代码
      • 示例

前言

在训练神经网络时,往往需要标准化(Normalization)输入数据,使得网络的训练更加快速和有效,然而SGD等学习算法会在训练中不断改变网络的参数,隐含层的激活值的分布会因此发生变化,而这一种变化就称为内协变量偏移(Internal Covariate Shift,ICS)

为了减轻ICS问题,Nomalization固定激活函数的输入变量的均值和方差(Nomalization后接激活函数),使得网络的训练更快。其计算方式大白话如下:将输入的数据,按照一定维度进行统计,得到数据的均值和方差,然后对数据结合均值和方差进行计算,将其变化为服从均值为0,方差为1的高斯分布,这样有助于加快网络训练速度。同时为了降低由于数据变化导致降低模型的表达能力,Nomalization提供了两个参数,用来对数据进行仿射变换,让这些数据靠近原始数据(最极限情况是这两个参数又将变换后的数据变为原来的数据)。

以下图片清晰的给出了不同归一化层的区别:

在这里插入图片描述

我当时见过另外一种图,较为抽象,一直没太看懂,当时看到这个图,醍醐灌顶,一下就都通了。

1 Batch Normalization(2015年提出)

Pytorch官网解释

BatchNorm2d

原理

针对输入到BN层的数据X,对所有 batch单个通道做归一化,每个通道都分别做一次,公式如下:
y=x−E[x]Var[x]+ϵ∗γ+β\mathrm{y}=\frac{\mathrm{x}-\mathrm{E}[\mathrm{x}]}{\sqrt{\mathrm{Var}[\mathrm{x}]+\epsilon}} * \gamma+\beta y=Var[x]+ϵxE[x]γ+β
其中:

  • E[x]\mathrm{E}[\mathrm{x}]E[x] 是向量x的均值
  • Var[x]\mathrm{Var}[\mathrm{x}]Var[x] 是向量x的方差
  • ϵ\epsilonϵ 为很小的常数,通常为0.00001,防止分母为0
  • γ\gammaγ 为仿射变换参数,模型可学习
  • β\betaβ 为放射变换参数,模型可学习

公式中gama之前的数据就是标准化后的数据,满足均值为0,方差为1的高斯分布,便于加快网络训练速度。
但是标准化有可能会降低模型的表达能力,因为网络中的某些隐藏层很有可能就是血需要输入数据是非标准分布的,因此提供gama和beta进行学习,用于恢复网络的表达能力。

Pytorch代码

torch.nn.BatchNorm2d(num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)
  • num_features: 传入的通道数
  • eps:常数ϵ\epsilonϵ, 默认为0.00001
  • monentum:动量参数,用来控制均值和方差的更新
  • affine:仿射变换的开关:默认打开
    • 如果 affine=False,则γ\gammaγ=1, β\betaβ=0,参数不能学习
    • 如果 affine=True, 则γ\gammaγ,β\betaβ可学习
  • track_running_stats: 如果为 True,则统计跟踪 batch 的个数,记录在 num_batches_tracked 中,一般不用管。

示例

import torch
import torch.nn as nninput = torch.randn(20, 6, 10, 10)
m = nn.BatchNorm2d(6) y = m(input)
print(y.shape)out:
torch.Size([20, 6, 10, 10])

本文介绍的四种归一化层都不改变输入数据的维度大小!!

2 Layer Normalization(2016年提出)

Pytorch官网解释

LayerNorm

原理

针对输入到LN层的数据X,对单个Batch中的所有通道数据做归一化,然后每个batch都单独做一次,公式如下:
y=x−E[x]Var[x]+ϵ∗γ+β\mathrm{y}=\frac{\mathrm{x}-\mathrm{E}[\mathrm{x}]}{\sqrt{\mathrm{Var}[\mathrm{x}]+\epsilon}} * \gamma+\beta y=Var[x]+ϵxE[x]γ+β
参照Batch Normalization公式和最上面的图进行理解,就是针对不同维度的数据进行标准化,其他的没变。Transformer中使用的就是LayerNorm。

Pytorch代码

torch.nn.LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True)
  • normalized_shape: 输入数据的维度(除了batch维度),例:数据维度【16, 64, 256, 256】 传入的normalized_shape维度为【64, 256, 256】。
  • eps: 常数,默认值为0.00001
  • elementwise_affine:仿射变换的开关:默认打开
    • 如果 elementwise_affine=False,则γ\gammaγ=1, β\betaβ=0,参数不能学习
    • 如果 elementwise_affine=True, 则γ\gammaγ,β\betaβ可学习

示例

import torch
import torch.nn as nnln = nn.LayerNorm([64, 256, 256])
x = torch.randn(16, 64, 256, 256)
y = ln(x)
print(y.shape)out:
torch.Size([16, 64, 256, 256])

我也是最近才用了这个,因为要事先传入维度大小(C,H,W),所以使用的时候没有BN那么方便,如果想用,需要提前自己计算一下数据到LN这里的维度是多少,然后进行实例化才可以。

3 Instance Normalization(2017年提出)

Pytorch官方解释

InstanceNorm2d

原理

针对输入到IN层的数据X,对单个Batch中的单个通道数据做归一化,然后每个batch每个通道单独做一次,公式如下:
y=x−E[x]Var[x]+ϵ∗γ+β\mathrm{y}=\frac{\mathrm{x}-\mathrm{E}[\mathrm{x}]}{\sqrt{\mathrm{Var}[\mathrm{x}]+\epsilon}} * \gamma+\beta y=Var[x]+ϵxE[x]γ+β
公式还是那个公式,理解还是那个理解,请看明白最上面的那种图,就清楚了。

Pytorch代码

torch.nn.InstanceNorm2d(num_features, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
  • num_features: 是通道数
  • eps:常数ϵ\epsilonϵ, 默认为0.00001
  • monentum:动量参数,用来控制均值和方差的更新
  • affine:仿射变换的开关:默认关闭
    • 如果 affine=False,则γ\gammaγ=1, β\betaβ=0,参数不能学习
    • 如果affine=True, 则γ\gammaγ,β\betaβ可学习

示例

# Without Learnable Parameters
m = nn.InstanceNorm2d(100)
# With Learnable Parameters
m = nn.InstanceNorm2d(100, affine=True)
input = torch.randn(20, 100, 35, 45)
output = m(input)
print(output.shape)out:
torch.Size([20, 100, 35, 45])

这个初始化只需要传入通道数即可,和BN的实例化方法一样,便于使用!

4 Group Normalization(2018年提出)

Pytorch官方解释

GroupNoem

原理

针对输入到GN层的数据X,对单个Batch中的多个通道数据(一组数据)做归一化,然后每个batch每组数据单独做一次,公式如下:
y=x−E[x]Var[x]+ϵ∗γ+β\mathrm{y}=\frac{\mathrm{x}-\mathrm{E}[\mathrm{x}]}{\sqrt{\mathrm{Var}[\mathrm{x}]+\epsilon}} * \gamma+\beta y=Var[x]+ϵxE[x]γ+β
公式还是那个公式,理解还是那个理解,请看明白最上面的那种图,就清楚了。
不是我偷懒,实际情况是,Pytorch官网给出的四个公式也就是一样的!

Pytorch代码

torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True)
  • num_groups: 需要将Batch中的通道分为几组
  • num_channels: 传入的数据通道数
  • eps:常数ϵ\epsilonϵ, 默认为0.00001
  • affine:仿射变换的开关:默认打开
    • 如果 affine=False,则γ\gammaγ=1, β\betaβ=0,参数不能学习
    • 如果affine=True, 则γ\gammaγ,β\betaβ可学习

示例

import torch
import torch.nn as nninput = torch.randn(20, 6, 10, 10)m3 = nn.GroupNorm(3, 6)  # 分成3组
m6 = nn.GroupNorm(6, 6)   # 分成6组,这个就和IN一样了
m1 = nn.GroupNorm(1, 6)  # 分成1组,这个就和LN一样了y1 = m1(input)
y3 = m3(input)
y6 = m6(input)
print(y1.shape)
print(y3.shape)
print(y6.shape)out:
torch.Size([20, 6, 10, 10])
torch.Size([20, 6, 10, 10])
torch.Size([20, 6, 10, 10])

参考:

https://blog.csdn.net/qq_33236581/article/details/124016573
https://zhuanlan.zhihu.com/p/470260895

如果感觉有用,就及时收藏点赞,否则后期又要翻找,跟着我,一步一步学CV

相关文章:

从 X 入门Pytorch——BN、LN、IN、GN 四种归一化层的代码使用和原理

Pytorch中四种归一化层的原理和代码使用前言1 Batch Normalization(2015年提出)Pytorch官网解释原理Pytorch代码示例2 Layer Normalization(2016年提出)Pytorch官网解释原理Pytorch代码示例3 Instance Normalization(2…...

Windows环境下实施域名访问的一些小知识

文章目录 前言一、windows域名访问流程二、网络域名访问配置设置DNS未正确设置DNS的结果三、本地hosts设置本地hosts本地hosts的优先机制本地hosts的内部访问次序示例一示例二总结前言 作为一种常见的操作系统,windows系统具有其特殊的域名访问管理机制。了解其访问机制,将有…...

78.qt QCustomPlot介绍

参考https://www.qcustomplot.com/index.php/tutorials/settingup 下载地址: https://www.qcustomplot.com/index.php/download 1.添加帮助文档 在QtCreator ——>工具——>选项——>帮助——>文档——>添加,选择qcustomplot.qch文件,确定,以后按F1就能跳转到…...

win32api之文件系统管理(七)

什么是文件系统 文件系统是一种用于管理计算机存储设备上文件和目录的机制。文件系统为文件和目录分配磁盘空间,管理文件和目录的存储和检索,以及提供对它们的访问和共享,以下是常见的两种文件系统: NTFSFAT32磁盘分区容量2T32G…...

点云规则格网化,且保存原始的点云索引

点云规则格网化,且保存原始的点云索引 点云深度学习Voxelize规则,参考PTV2:https://github.com/Gofinge/PointTransformerV2 1总执行文件 import numpy as np import torch from pcr.utils.registry import Registry TRANSFORMS Registry…...

入职第一天就被迫离职,找工作多月已读不回,面试拿不到offer我该怎么办?

大多数情况下,测试员的个人技能成长速度,远远大于公司规模或业务的成长速度。所以,跳槽成为了这个行业里最常见的一个词汇。 前言 前几天,我们一个粉丝跟我说,正常入职一家外包,什么都准备好了&#xff0…...

走进Vue【三】vue-router详解

目录🌟前言🌟路由🌟什么是前端路由?🌟前端路由优点缺点🌟vue-router🌟安装🌟路由初体验1.路由组件router-linkrouter-view2.步骤1. 定义路由组件2. 定义路由3. 创建 router 实例4. 挂…...

html+css制作

<!DOCTYPE html> <html><head><meta charset"utf-8"><title>校园官网</title><style type"text/css">*{padding: 0;margin: 0;}#logo{width:30%;float: left;}.nav{width: 100%;height: 100px;background-color…...

Python实现rar、zip和7z文件的压缩和解压

一、7z压缩文件的压缩和解压 1、安装py7zr 我们要先安装py7zr第三方库&#xff1a; pip install py7zr如果python环境有问题&#xff0c;执行上面那一条安装语句老是安装在默认的python环境的话&#xff0c;我们可以执行下面这条语句&#xff0c;将第三方库安装在项目的虚拟…...

从Hive源码解读大数据开发为什么可以脱离SQL、Java、Scala

从Hive源码解读大数据开发为什么可以脱离SQL、Java、Scala 前言 【本文适合有一定计算机基础/半年工作经验的读者食用。立个Flg&#xff0c;愿天下不再有肤浅的SQL Boy】 谈到大数据开发&#xff0c;占据绝大多数人口的就是SQL Boy&#xff0c;不接受反驳&#xff0c;毕竟大…...

RocketMQ 事务消息 原理及使用方法解析

&#x1f34a; Java学习&#xff1a;Java从入门到精通总结 &#x1f34a; 深入浅出RocketMQ设计思想&#xff1a;深入浅出RocketMQ设计思想 &#x1f34a; 绝对不一样的职场干货&#xff1a;大厂最佳实践经验指南 &#x1f4c6; 最近更新&#xff1a;2023年3月24日 &#x…...

为什么 ChatGPT 输出时经常会中断,需要输入“继续” 才可以继续输出?

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;蚂蚁集团高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《EffectiveJava》独家解析》专栏作者。 热门文章推荐…...

PyTorch 之 基于经典网络架构训练图像分类模型

文章目录一、 模块简单介绍1. 数据预处理部分2. 网络模块设置3. 网络模型保存与测试二、数据读取与预处理操作1. 制作数据源2. 读取标签对应的实际名字3. 展示数据三、模型构建与实现1. 加载 models 中提供的模型&#xff0c;并且直接用训练的好权重当做初始化参数2. 参考 pyto…...

Scrapy的callback进入不了回调方法

一、前言 有的时候&#xff0c;Scrapy的callback方法直接被略过了&#xff0c;不去执行其中的回调方法&#xff0c;可能排查好久都排查不出来&#xff0c;我来教大家集中解决方法。 yield Request(urlurl, callbackself.parse_detail, cb_kwargs{item: item})二、解决方法 1…...

第二十一天 数据库开发-MySQL

目录 数据库开发-MySQL 前言 1. MySQL概述 1.1 安装 1.2 数据模型 1.3 SQL介绍 1.4 项目开发流程 2. 数据库设计-DDL 2.1 数据库操作 2.2 图形化工具 2.3 表操作 3. 数据库操作-DML 3.1 增加(insert) 3.2 修改(update) 3.3 删除(delete) 数据库开发-MySQL 前言 …...

蓝桥杯每日一真题—— [蓝桥杯 2021 省 AB2] 完全平方数(数论,质因数分解)

文章目录[蓝桥杯 2021 省 AB2] 完全平方数题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1样例 #2样例输入 #2样例输出 #2提示思路&#xff1a;理论补充&#xff1a;完全平方数的一个性质&#xff1a;完全平方数的质因子的指数一定为偶数最终思路&#xff1a;小插曲&am…...

Linux编辑器-vim

一、vim简述1&#xff09;vi/vim2&#xff09;检查vim是否安装2)如何用vim打开文件3)vim的几种模式命令模式插入模式末行模式可视化模式二、vim的基本操作1)进入vim&#xff08;命令行模式&#xff09;2)[命令行模式]切换至[插入模式]3)[插入模式]切换至[命令行模式]4)[命令行模…...

5G将在五方面彻底改变制造业

想象一下这样一个未来&#xff0c;智能机器人通过在工厂车间重新配置自己&#xff0c;从多条生产线上组装产品。安全无人机处理着从监视入侵者到确认员工停车等繁琐的任务。自动驾驶汽车不仅可以在建筑物之间运输零部件&#xff0c;还可以在全国各地运输。工厂检查可以在千里之…...

http和https的区别?

http和https的区别&#xff1f;HTTPHTTPSHTTP与HTTPS区别HTTPS相比于HTTP协议的优点和缺点HTTP http是超文本传输协议 HTTP协议是基于传输层的TCP协议进行通信&#xff0c;通用无状态的协议。80端口 HTTPS https—安全的超文本传输协议 是以安全为目标的HTTP通道&#xff0c;…...

【Spring Cloud Alibaba】4.创建服务消费者

文章目录简介开始搭建创建项目修改POM文件添加启动类添加配置项添加Controller添加配置文件启动项目测试访问Nacos访问接口查看端点检查简介 接下来我们创建一个服务消费者&#xff0c;本操作先要完成之前的步骤&#xff0c;详情请参照【Spring Cloud Alibaba】Spring Cloud A…...

萨科微宋仕强“华强北山寨手机”研究

萨科微宋仕强“华强北山寨手机”研究&#xff08;十六&#xff09;&#xff0c;手机的灰色产业链。华强北每个手机柜台背后都有灰色供应链支撑。如香港手机比华强北便宜&#xff0c;就通过各种渠道从香港走私过来。沙头角的中英街两边分属于香港和深圳&#xff0c;香港一侧的走…...

ChatGPT高质量输出的隐藏开关:基于IEEE写作标准的11项自动校验清单(附可运行Python验证脚本)

更多请点击&#xff1a; https://kaifayun.com 第一章&#xff1a;ChatGPT高质量输出的底层逻辑与认知前提 ChatGPT生成高质量响应并非依赖“魔法”&#xff0c;而是建立在三个核心支柱之上&#xff1a;大规模语言建模的统计涌现能力、人类反馈强化学习&#xff08;RLHF&#…...

三国杀卡牌DIY终极指南:5分钟打造你的专属武将

三国杀卡牌DIY终极指南&#xff1a;5分钟打造你的专属武将 【免费下载链接】Lyciumaker 在线三国杀卡牌制作器 项目地址: https://gitcode.com/gh_mirrors/ly/Lyciumaker 还在羡慕别人能设计出酷炫的三国杀武将卡牌吗&#xff1f;Lyciumaker这个免费开源的三国杀卡牌制作…...

Harness 中的令牌级流控与字符级计费

Harness 中的令牌级流控与字符级计费:从原理到落地的全指南 关键词:Harness CI/CD, 令牌级流控, 字符级计费, 微服务流量治理, 用量计量, 云原生成本优化, 网关限流 摘要:作为全球领先的智能软件交付平台,Harness 每天要处理来自数千家企业客户的上亿次 API 调用、数百万次…...

终极指南:如何快速构建中文手写识别AI系统(免费数据集)

终极指南&#xff1a;如何快速构建中文手写识别AI系统&#xff08;免费数据集&#xff09; 【免费下载链接】Traditional-Chinese-Handwriting-Dataset Open source traditional chinese handwriting dataset. 项目地址: https://gitcode.com/gh_mirrors/tr/Traditional-Chin…...

【.NET新特性·第2篇】C# 12 全特性回顾:语法糖的盛宴

C# 12 带来了主构造函数、集合表达式、Inline Arrays 等 8 个新特性&#xff0c;让代码更简洁 版本定位 适用版本&#xff1a;.NET 8 | C# 12 前置知识&#xff1a;C# 11 基础语法 背景 C# 11 引入了原始字符串字面量、list patterns 等特性&#xff0c;但开发者们期待更多语法…...

高通8650 AudioReach实战:手把手调试GSL-Passthru-GPR数据流(附动态调试脚本)

高通8650 AudioReach实战&#xff1a;GSL-Passthru-GPR数据流调试全指南 当你在深夜的实验室里盯着示波器上那条毫无波动的音频信号线时&#xff0c;手机突然响起一阵刺耳的电流噪声——这可能是每位音频驱动工程师都经历过的噩梦时刻。高通AudioReach架构作为现代移动音频系统…...

【MATLAB】红外图像增强与目标检测实现

【MATLAB】红外图像增强与目标检测实现 摘要:红外成像技术可全天候、无源感知目标热辐射信息,不受光照、雾霾、黑夜环境限制,广泛应用于安防监控、军事侦察、设备故障巡检、森林防火等领域。但受红外传感器噪声、大气衰减、环境杂波干扰影响,原始红外图像普遍存在对比度低…...

OpenRGB:终结RGB灯光管理混乱的终极免费方案

OpenRGB&#xff1a;终结RGB灯光管理混乱的终极免费方案 【免费下载链接】OpenRGB Open source RGB lighting control that doesnt depend on manufacturer software. Supports Windows, Linux, MacOS. Mirror of https://gitlab.com/CalcProgrammer1/OpenRGB. Releases can be…...

ML模型生产部署:从Jupyter到高可用推理服务的工程化实践

1. 项目概述&#xff1a;当模型走出Jupyter&#xff0c;真正开始呼吸真实世界空气“From Notebook to Production: Running ML in the Real World (Part 4)”——这个标题本身就像一句暗号&#xff0c;专为那些在Jupyter里调通了模型、画出了漂亮ROC曲线、却在部署时被生产环境…...