当前位置: 首页 > news >正文

Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)

Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)

flyfish

目录

  • Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)
    • 先看LayerNorm和BatchNorm
    • 举个例子计算 LayerNorm
    • RMSNorm 的整个计算过程
      • 实际代码实现
      • 结果

先看LayerNorm和BatchNorm

展示计算的方向
在这里插入图片描述

  • axis=0 代表第一个轴,逐列处理数据。
  • axis=1 代表第二个轴,逐行处理数据。在二维数组中,axis=-1 等同于 axis=1。
  • axis=-1 代表最后一个轴。在二维数组中,axis=-1 等同于 axis=1,即最后一个轴。

在二维的情况 下,BatchNorm是按列算,LayerNorm按行算

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nnclass CustomLayerNorm:def __init__(self, eps=1e-5):self.eps = epsdef __call__(self, x):mean = np.mean(x, axis=-1, keepdims=True)std = np.std(x, axis=-1, keepdims=True)normalized = (x - mean) / (std + self.eps)return normalizedclass CustomBatchNorm:def __init__(self, eps=1e-5):self.eps = epsdef __call__(self, x):mean = np.mean(x, axis=0)std = np.std(x, axis=0)normalized = (x - mean) / (std + self.eps)return normalized# Original Data
data = np.array([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0],[7.0, 8.0, 9.0]])# Apply Custom LayerNorm
custom_layer_norm = CustomLayerNorm()
custom_layer_norm_data = custom_layer_norm(data)# Apply Custom BatchNorm
custom_batch_norm = CustomBatchNorm()
custom_batch_norm_data = custom_batch_norm(data)# Apply PyTorch LayerNorm
data_tensor = torch.tensor(data, dtype=torch.float32)
layer_norm = nn.LayerNorm(data_tensor.size()[1:])
pytorch_layer_norm_data = layer_norm(data_tensor).detach().numpy()# Compare Custom and PyTorch LayerNorm
print("Original Data:\n", data)
print("Custom LayerNorm Data:\n", custom_layer_norm_data)
print("PyTorch LayerNorm Data:\n", pytorch_layer_norm_data)
Original Data:[[1. 2. 3.][4. 5. 6.][7. 8. 9.]]
Custom LayerNorm Data:[[-1.22472987  0.          1.22472987][-1.22472987  0.          1.22472987][-1.22472987  0.          1.22472987]]
PyTorch LayerNorm Data:[[-1.2247356  0.         1.2247356][-1.2247356  0.         1.2247356][-1.2247356  0.         1.2247356]]

举个例子计算 LayerNorm

具体步骤如下:

  1. 计算每行的均值
  • 对每一行,计算其均值。
  • 第1行: mean = (1 + 2 + 3) / 3 = 2
  • 第2行: mean = (4 + 5 + 6) / 3 = 5
  • 第3行: mean = (7 + 8 + 9) / 3 = 8
  1. 计算每行的标准差
  • 对每一行,计算其标准差。
  • 第1行: s t d = s q r t ( ( ( 1 − 2 ) 2 + ( 2 − 2 ) 2 + ( 3 − 2 ) 2 ) / 3 ) = s q r t ( ( 1 + 0 + 1 ) / 3 ) = s q r t ( 2 / 3 ) ≈ 0.8165 std = sqrt(((1-2)^2 + (2-2)^2 + (3-2)^2) / 3) = sqrt((1 + 0 + 1) / 3) = sqrt(2 / 3) ≈ 0.8165 std=sqrt(((12)2+(22)2+(32)2)/3)=sqrt((1+0+1)/3)=sqrt(2/3)0.8165
  • 第2行: s t d = s q r t ( ( ( 4 − 5 ) 2 + ( 5 − 5 ) 2 + ( 6 − 5 ) 2 ) / 3 ) = s q r t ( ( 1 + 0 + 1 ) / 3 ) = s q r t ( 2 / 3 ) ≈ 0.8165 std = sqrt(((4-5)^2 + (5-5)^2 + (6-5)^2) / 3) = sqrt((1 + 0 + 1) / 3) = sqrt(2 / 3) ≈ 0.8165 std=sqrt(((45)2+(55)2+(65)2)/3)=sqrt((1+0+1)/3)=sqrt(2/3)0.8165
  • 第3行: s t d = s q r t ( ( ( 7 − 8 ) 2 + ( 8 − 8 ) 2 + ( 9 − 8 ) 2 ) / 3 ) = s q r t ( ( 1 + 0 + 1 ) / 3 ) = s q r t ( 2 / 3 ) ≈ 0.8165 std = sqrt(((7-8)^2 + (8-8)^2 + (9-8)^2) / 3) = sqrt((1 + 0 + 1) / 3) = sqrt(2 / 3) ≈ 0.8165 std=sqrt(((78)2+(88)2+(98)2)/3)=sqrt((1+0+1)/3)=sqrt(2/3)0.8165
  1. 标准化每一行
  • 对每一行,使用均值和标准差进行标准化。公式为: ( x − m e a n ) / ( s t d + e p s ) (x - mean) / (std + eps) (xmean)/(std+eps)。其中 eps 是一个小常数,防止除零,通常取值为 1e-5。
  • 计算结果如下:

标准化公式: n o r m a l i z e d = ( x − m e a n ) / ( s t d + e p s ) normalized = (x - mean) / (std + eps) normalized=(xmean)/(std+eps)

第1行: 
[(1-2)/(0.8165+1e-5), (2-2)/(0.8165+1e-5), (3-2)/(0.8165+1e-5)]
= [-1.2247, 0, 1.2247]第2行: 
[(4-5)/(0.8165+1e-5), (5-5)/(0.8165+1e-5), (6-5)/(0.8165+1e-5)]
= [-1.2247, 0, 1.2247]第3行: 
[(7-8)/(0.8165+1e-5), (8-8)/(0.8165+1e-5), (9-8)/(0.8165+1e-5)]
= [-1.2247, 0, 1.2247]

最终标准化结果矩阵为:

[[-1.2247, 0, 1.2247][-1.2247, 0, 1.2247][-1.2247, 0, 1.2247]]

RMSNorm 的整个计算过程

Meta Llama 3 使用了RMSNorm
假设我们有以下 2D 输入张量 X X X(为了简单起见,我们假设这个张量有 2 行 3 列):
[ 1 2 3 4 5 6 ] \begin{bmatrix}1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} [142536]
RMSNorm 的计算过程如下:

  1. 计算每行的均方根 (RMS)
    首先,对于每一行,我们计算该行元素的平方和的均值,然后取其平方根。
    对于第 1 行:
    RMS row1 = 1 2 + 2 2 + 3 2 3 = 1 + 4 + 9 3 = 4.67 ≈ 2.16 \text{RMS}_{\text{row1}} = \sqrt{\frac{1^2 + 2^2 + 3^2}{3}} = \sqrt{\frac{1 + 4 + 9}{3}} = \sqrt{4.67} \approx 2.16 RMSrow1=312+22+32 =31+4+9 =4.67 2.16
    对于第 2 行:
    RMS row2 = 4 2 + 5 2 + 6 2 3 = 16 + 25 + 36 3 = 25.67 ≈ 5.07 \text{RMS}_{\text{row2}} = \sqrt{\frac{4^2 + 5^2 + 6^2}{3}} = \sqrt{\frac{16 + 25 + 36}{3}} = \sqrt{25.67} \approx 5.07 RMSrow2=342+52+62 =316+25+36 =25.67 5.07
  2. 使用均方根对输入进行归一化
    将每行的元素除以该行的 RMS 值。这里的 epsilon 用于防止除以零的问题,我们假设 ϵ = 1 e − 6 \epsilon = 1e-6 ϵ=1e6
    对于第 1 行: Normed row1 = [ 1 2.16 + ϵ 2 2.16 + ϵ 3 2.16 + ϵ ] ≈ [ 0.462 0.925 1.387 ] \text{Normed}_{\text{row1}} = \begin{bmatrix} \frac{1}{2.16 + \epsilon} & \frac{2}{2.16 + \epsilon} & \frac{3}{2.16 + \epsilon} \end{bmatrix} \approx \begin{bmatrix} 0.462 & 0.925 & 1.387 \end{bmatrix} Normedrow1=[2.16+ϵ12.16+ϵ22.16+ϵ3][0.4620.9251.387]
    对于第 2 行: Normed row2 = [ 4 5.07 + ϵ 5 5.07 + ϵ 6 5.07 + ϵ ] ≈ [ 0.789 0.986 1.183 ] \text{Normed}_{\text{row2}} = \begin{bmatrix} \frac{4}{5.07 + \epsilon} & \frac{5}{5.07 + \epsilon} & \frac{6}{5.07 + \epsilon} \end{bmatrix} \approx \begin{bmatrix} 0.789 & 0.986 & 1.183 \end{bmatrix} Normedrow2=[5.07+ϵ45.07+ϵ55.07+ϵ6][0.7890.9861.183]
  3. 应用可学习的缩放参数
    假设权重参数 weight \text{weight} weight 为一个向量 [ 1 , 1 , 1 ] [1, 1, 1] [1,1,1],表示每个元素的缩放因子。对于第 1 行: Output row1 = [ 0.462 ⋅ 1 0.925 ⋅ 1 1.387 ⋅ 1 ] = [ 0.462 0.925 1.387 ] \text{Output}_{\text{row1}} = \begin{bmatrix} 0.462 \cdot 1 & 0.925 \cdot 1 & 1.387 \cdot 1 \end{bmatrix} = \begin{bmatrix} 0.462 & 0.925 & 1.387 \end{bmatrix} Outputrow1=[0.46210.92511.3871]=[0.4620.9251.387]对于第 2 行: Output row2 = [ 0.789 ⋅ 1 0.986 ⋅ 1 1.183 ⋅ 1 ] = [ 0.789 0.986 1.183 ] \text{Output}_{\text{row2}} = \begin{bmatrix} 0.789 \cdot 1 & 0.986 \cdot 1 & 1.183 \cdot 1 \end{bmatrix} = \begin{bmatrix} 0.789 & 0.986 & 1.183 \end{bmatrix} Outputrow2=[0.78910.98611.1831]=[0.7890.9861.183]

实际代码实现

以下是使用 PyTorch 实现上述步骤的代码示例:

import torch
import torch.nn as nnclass RMSNorm(nn.Module):def __init__(self, dim: int, eps: float = 1e-6):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):output = self._norm(x.float()).type_as(x)return output * self.weight# 示例数据
data = torch.tensor([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])# 实例化 RMSNorm 层
rms_norm = RMSNorm(dim=data.size(-1))# 计算归一化后的输出
normalized_data = rms_norm(data)print("Original Data:\n", data)
print("RMSNorm Normalized Data:\n", normalized_data)

结果

运行上述代码后,我们将得到归一化后的数据:

 tensor([[1., 2., 3.],[4., 5., 6.]])
RMSNorm Normalized Data:tensor([[0.4629, 0.9258, 1.3887],[0.7895, 0.9869, 1.1843]], grad_fn=<MulBackward0>)

相关文章:

Meta Llama 3 RMSNorm(Root Mean Square Layer Normalization)

Meta Llama 3 RMSNorm&#xff08;Root Mean Square Layer Normalization&#xff09; flyfish 目录 Meta Llama 3 RMSNorm&#xff08;Root Mean Square Layer Normalization&#xff09;先看LayerNorm和BatchNorm举个例子计算 LayerNormRMSNorm 的整个计算过程实际代码实现结…...

MySQL-6、单表访问方法

前言 前面介绍了MySQL表空间相关的内容。包括区、段、碎片区&#xff0c;还有一些不同的页类型的作用。 &#xff08;如果没有看前面五篇文章&#xff0c;不建议看此篇文章&#xff09; 传送门&#xff1a; MySQL-1、InnoDB行格式 MySQL-2、InnoDB数据页 MySQL-3、索引 M…...

C语言实现三角波生成

C语言实现三角波生成 #include <stdio.h>#define SAMPLE_RATE 10000 // 采样率10kHz=10000Hz 对应100us=0.1ms #define UP_TIME 12.5 //上升时间12.5ms #...

WPF国际化的最佳实践

WPF国际化的最佳实践 1.创建项目资源文件 如果你的项目没有Properties文件夹和Resources.resx文件&#xff0c;可以通过右键项目-资源-常规-添加创建或打开程序集资源 2.添加国际化字符串 打开Resources.resx文件&#xff0c;添加需要翻译的文本字符&#xff0c;并将访问修…...

ctfshow web

【nl】难了 <?php show_source(__FILE__); error_reporting(0); if(strlen($_GET[1])<4){echo shell_exec($_GET[1]); } else{echo "hack!!!"; } ?> //by Firebasky //by Firebasky ?1>nl //先写个文件 ?1*>b //这样子会把所有文件名写在b里…...

【力扣】矩阵中的最长递增路径

一、题目描述 二、解题思路 1、先求出以矩阵中的每个单元格为起点的最长递增路径 题目中说&#xff0c;对于每个单元格&#xff0c;你可以往上&#xff0c;下&#xff0c;左&#xff0c;右四个方向移动。那么以一个单元格为起点的最长递增路径就是&#xff1a;从该单元格往上…...

语音深度鉴伪识别项目实战:基于深度学习的语音深度鉴伪识别算法模型(二)音频数据预处理及去噪算法+Python源码应用

前言 深度学习技术在当今技术市场上面尚有余力和开发空间的&#xff0c;主流落地领域主要有&#xff1a;视觉&#xff0c;听觉&#xff0c;AIGC这三大板块。 目前视觉板块的框架和主流技术在我上一篇基于Yolov7-LPRNet的动态车牌目标识别算法模型已有较为详细的解说。与AIGC相…...

网络原理——http/https ---http(1)

T04BF &#x1f44b;专栏: 算法|JAVA|MySQL|C语言 &#x1faf5; 今天你敲代码了吗 网络原理 HTTP/HTTPS HTTP,全称为"超文本传输协议" HTTP 诞⽣与1991年. ⽬前已经发展为最主流使⽤的⼀种应⽤层协议. 实际上,HTTP最新已经发展到 3.0 但是当前行业中主要使用的HT…...

Docker安装、使用,容器化部署springboot项目

目录 一、使用官方安装脚本自动安装 二、Docker离线安装 1. 下载安装包 2. 解压 3.创建docker.service文件 4. 启动docker 三、docker常用命令 1. docker常用命令 2. docker镜像命令 3. docker镜像下载 4.docker镜像push到仓库 5. docker操作容器 6.docker …...

USB主机模式——Android

理论 摘自&#xff1a;USB 主机和配件概览 | Connectivity | Android Developers (google.cn) Android 通过 USB 配件和 USB 主机两种模式支持各种 USB 外围设备和 Android USB 配件&#xff08;实现 Android 配件协议的硬件&#xff09;。 在 USB 主机模式下&#xff0…...

240520Scala笔记

240520Scala笔记 第 7 章 集合 7.1 集合1 数组Array 集合(Test01_ImmutableArray): package chapter07 ​ object Test01_ImmutableArray {def main(args: Array[String]): Unit {// 1. 创建数组val arr: Array[Int] new Array[Int](5)// 另一种创建方式val arr2 Array(…...

【React】封装一个好用方便的消息框(Hooks Bootstrap 实践)

引言 以 Bootstrap 为例&#xff0c;使用模态框编写一个简单的消息框&#xff1a; import { useState } from "react"; import { Modal } from "react-bootstrap"; import Button from "react-bootstrap/Button"; import bootstrap/dist/css/b…...

tomcat10部署踩坑记录-公网IP和服务器系统IP搞混

1. 服务器基本条件 使用的阿里云服务器&#xff0c;镜像系统是Ubuntu16.04java version “17.0.11” 2024-04-16 LTS装的是tomcat10.1.24阿里云服务器安全组放行了&#xff1a;8080端口 服务器防火墙关闭&#xff1a; 监听情况和下图一样&#xff1a; tomcat正常启动&#xff…...

探索Sass:Web开发的强大工具

在现代Web开发中,CSS(层叠样式表)作为前端样式设计的核心技术,已经发展得非常成熟。然而,随着Web应用的复杂性不断增加,传统的CSS书写方式逐渐暴露出一些不足之处,如代码冗长、难以维护、缺乏编程功能等。为了解决这些问题,Sass(Syntactically Awesome Stylesheets)应…...

vue组件之间的通信方式有哪些

在开发过程中&#xff0c;数据传输是一个核心的知识点&#xff0c;掌握了数据传输&#xff0c;相当于掌握了80%的内容。 Vue.js 提供了多种组件间的通信方式&#xff0c;这些方式适应不同的场景和需求。下面是4种常见的通信方式&#xff1a; 1. Props & Events (父子组件通…...

111、二叉树的最小深度

给定一个二叉树&#xff0c;找出其最小深度。最小深度是从根节点到最近叶子节点的最短路径上的节点数量。 题解&#xff1a;找出最小深度也就是找出根节点相对所有叶子结点的最小高度&#xff0c;在这也表明了根节点的高度是变化的&#xff0c;相对不同的叶子结点有不同的高度。…...

SpringBoot3依赖管理,自动配置

文章目录 1. 项目新建2. 相关pom依赖3. 依赖管理机制导入 starter 所有相关依赖都会导入进来为什么版本号都不用写&#xff1f;如何自定义版本号第三方的jar包 4. 自动配置机制5. 核心注解 1. 项目新建 直接建Maven项目通过官方提供的Spring Initializr项目创建 2. 相关pom依…...

音视频开发17 FFmpeg 音频解码- 将 aac 解码成 pcm

这一节&#xff0c;接 音视频开发12 FFmpeg 解复用详情分析&#xff0c;前面我们已经对一个 MP4文件&#xff0c;或者 FLV文件&#xff0c;或者TS文件进行了 解复用&#xff0c;解出来的 视频是H264,音频是AAC&#xff0c;那么接下来就要对H264和AAC进行处理&#xff0c;这一节…...

vue2中封装图片上传获取方法类(针对后端返回的数据不是图片链接,只是图片编号)

在Vue 2中实现商品列表中带有图片编号&#xff0c;并将返回的图片插入到商品列表中&#xff0c;可以通过以下步骤完成&#xff1a; 在Vue组件的data函数中定义商品列表和图片URL数组。 创建一个方法来获取每个商品的图片URL。 使用v-for指令在模板中遍历商品列表&#xff0c;并…...

【C++面向对象编程】(二)this指针和静态成员

文章目录 this指针和静态成员this指针静态成员 this指针和静态成员 this指针 C中类的成员变量和成员函数的存储方式有所不同&#xff1a; 成员变量&#xff1a;对象的成员变量直接作为对象的一部分存储在内存中。成员函数&#xff1a;成员函数&#xff08;非静态成员函数&am…...

React Native 开发环境搭建(全平台详解)

React Native 开发环境搭建&#xff08;全平台详解&#xff09; 在开始使用 React Native 开发移动应用之前&#xff0c;正确设置开发环境是至关重要的一步。本文将为你提供一份全面的指南&#xff0c;涵盖 macOS 和 Windows 平台的配置步骤&#xff0c;如何在 Android 和 iOS…...

五年级数学知识边界总结思考-下册

目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解&#xff1a;由来、作用与意义**一、知识点核心内容****二、知识点的由来&#xff1a;从生活实践到数学抽象****三、知识的作用&#xff1a;解决实际问题的工具****四、学习的意义&#xff1a;培养核心素养…...

使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装

以下是基于 vant-ui&#xff08;适配 Vue2 版本 &#xff09;实现截图中照片上传预览、删除功能&#xff0c;并封装成可复用组件的完整代码&#xff0c;包含样式和逻辑实现&#xff0c;可直接在 Vue2 项目中使用&#xff1a; 1. 封装的图片上传组件 ImageUploader.vue <te…...

【Go】3、Go语言进阶与依赖管理

前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课&#xff0c;做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程&#xff0c;它的核心机制是 Goroutine 协程、Channel 通道&#xff0c;并基于CSP&#xff08;Communicating Sequential Processes&#xff0…...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 &#xff08;部分有免费额度&#x…...

浅谈不同二分算法的查找情况

二分算法原理比较简单&#xff0c;但是实际的算法模板却有很多&#xff0c;这一切都源于二分查找问题中的复杂情况和二分算法的边界处理&#xff0c;以下是博主对一些二分算法查找的情况分析。 需要说明的是&#xff0c;以下二分算法都是基于有序序列为升序有序的情况&#xf…...

【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分

一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计&#xff0c;提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合&#xff1a;各模块职责清晰&#xff0c;便于独立开发…...

【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统

目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索&#xff08;基于物理空间 广播范围&#xff09;2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...

Device Mapper 机制

Device Mapper 机制详解 Device Mapper&#xff08;简称 DM&#xff09;是 Linux 内核中的一套通用块设备映射框架&#xff0c;为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程&#xff0c;并配以详细的…...

OPENCV形态学基础之二腐蚀

一.腐蚀的原理 (图1) 数学表达式&#xff1a;dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一&#xff0c;腐蚀跟膨胀属于反向操作&#xff0c;膨胀是把图像图像变大&#xff0c;而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...