当前位置: 首页 > news >正文

指数移动平均(EMA)

文章目录

  • 前言
  • EMA的定义
  • 在深度学习中的应用
    • PyTorch代码实现
    • yolov5中模型的EMA实现
  • 参考


前言

在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
实际上,_EMA可以看作是Temporal Ensembling,在模型学习过程中融合更多的历史状态,从而达到更好的优化效果。

EMA的定义

指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。
假设有n个权重数据image.png

  • 普通的平均数:image.png
  • EMA:image.png

其中, vt表示前 t条的平均值 ( v0=0 ),β是加权权重值 (一般设为0.9-0.999)。

Andrew Ng在Course 2 Improving Deep Neural Networks中讲到,EMA可以近似看成过去 1/(1−β) 个时刻 v 值的平均。
普通的过去n时刻的平均是这样的:image.png
类比EMA,可以发现当 image.png 时,两式形式上相等。需要注意的是,两个平均并不是严格相等的,这里只是为了帮助理解。
实际上,EMA计算时,过去 1/(1−β) 个时刻之前的数值平均会decay到 1/e 的加权比例,证明如下。
如果将这里的 vt展开,可以得到:
image.png
其中, image.png,代入可以得到 image.png

在深度学习中的应用

上面讲的是广义的ema定义和计算方法,特别的,在深度学习的优化过程中, image.png是t时刻的模型权重weights, vt是t时刻的影子权重(shadow weights)。在梯度下降的过程中,会一直维护着这个影子权重,但是这个影子权重并不会参与训练。
基本的假设是,模型权重在最后的n步内,会在实际的最优点处抖动,所以我们取最后n步的平均,能使得模型更加的鲁棒。

PyTorch代码实现

下面是一个简单的指数移动平均(EMA)的PyTorch实现:

import torchclass EMA():def __init__(self, alpha):self.alpha = alpha    # 初始化平滑因子alphaself.average = None   # 初始化平均值为空self.count = 0        # 初始化计数器为0def update(self, x):if self.average is None:  # 如果平均值为空,则将其初始化为与x相同大小的全零张量self.average = torch.zeros_like(x)self.average = self.alpha * x + (1 - self.alpha) * self.average  # 更新平均值self.count += 1   # 更新计数器def get(self):return self.average / (1 - self.alpha ** self.count)   # 根据计数器和平滑因子计算EMA值,并返回平均值除以衰减系数的结果

在这个类中,我们定义了三个方法,分别是__init__、update和get。

  • __init__方法用于初始化平滑因子alpha、平均值average和计数器count
  • update方法用于更新EMA值
  • get方法用于获取最终的EMA值。

使用这个类时,我们可以先实例化一个EMA对象,然后在每个时间步中调用update方法来更新EMA值,最后调用get方法来获取最终的EMA值。
例如:

ema = EMA(alpha=0.5)
for value in data:ema.update(torch.tensor(value))
smoothed_data = ema.get()

在这个例子中,我们使用alpha=0.5来初始化EMA对象,然后遍历数据集data中的每个数据点,调用update方法更新EMA值。最后我们调用get方法来获取平滑后的数据。

yolov5中模型的EMA实现

如下:

class ModelEMA:""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-modelsKeeps a moving average of everything in the model state_dict (parameters and buffers)For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage"""def __init__(self, model, decay=0.9999, tau=2000, updates=0):# Create EMAself.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMAself.updates = updates  # number of EMA updatesself.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)for p in self.ema.parameters():p.requires_grad_(False)def update(self, model):# Update EMA parametersself.updates += 1d = self.decay(self.updates)msd = de_parallel(model).state_dict()  # model state_dictfor k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:  # true for FP16 and FP32v *= dv += (1 - d) * msd[k].detach()# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):# Update EMA attributescopy_attr(self.ema, model, include, exclude)

参考

https://zhuanlan.zhihu.com/p/68748778


如果有用,请点个三连呗 点赞、关注、收藏
你的鼓励是我最大的动力

相关文章:

指数移动平均(EMA)

文章目录 前言EMA的定义在深度学习中的应用PyTorch代码实现yolov5中模型的EMA实现 参考 前言 在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。实际上,_EMA可以…...

无线表格识别模型LORE转换库:ConvertLOREToONNX

引言 总有小伙伴问到阿里的无线表格识别模型是如何转换为ONNX格式的。这个说来有些惭愧,现有的ONNX模型是很久之前转换的了,转换环境已经丢失,且没有做任何笔记。 今天下定决心再次尝试转换,庆幸的是转换成功了。于是有了转换笔…...

C# 视频转图片

在 C# 中将视频转换为图像可以使用 FFmpeg 库。下面是一个示例代码来完成这个任务: using System; using System.Diagnostics;class Program {static void Main(string[] args){string inputFile "input_video.mp4"; // 输入的视频文件路径string outpu…...

LINUX ADC使用

监测 ADC ,使用CAT 查看&#xff1a; LINUX ADC基本使用 &adc {pinctrl-names "default";pinctrl-0 <&adc6>;pinctrl-1 <&adc7>;pinctrl-2 <&adc8>;pinctrl-3 <&adc9>;pinctrl-4 <&adc10>;pinctrl-5 …...

Ubuntu 基本操作-嵌入式 Linux 入门

在 Ubuntu 基本操作 里面基本就分为两部分&#xff1a; 安装 VMware 运行 Ubuntu熟悉 Ubuntu 的各种操作、命令 如果你对 Ubuntu 比较熟悉的话&#xff0c;安装完 VMware 运行 Ubuntu 之后就可以来学习下一章节了。 1. 安装 VMware 运行 Ubuntu 我们首先来看看怎么去安装 V…...

Pytorch可形变卷积分类模型与可视化

E:. │ archs.py │ dataset.py │ deform_conv_v2.py │ train.py │ utils.py │ visual_net.py │ ├─grad_cam │ 2.png │ 3.png │ ├─image │ ├─1 │ │ 154.png │ │ 2.png │ │ │ ├─2 │ │ 143.png │…...

Mysql 表逻辑分区原理和应用

MySQL的表逻辑分区是一种数据库设计技术&#xff0c;它允许将一个表的数据分布在多个物理分区中&#xff0c;但在逻辑上仍然表现为一个单一的表。这种方式可以提高查询性能、简化数据管理&#xff0c;并有助于高效地进行大数据量的存储和访问。逻辑分区基于特定的规则&#xff…...

架构面试题汇总:网络协议34问(七)

码到三十五 &#xff1a; 个人主页 心中有诗画&#xff0c;指尖舞代码&#xff0c;目光览世界&#xff0c;步履越千山&#xff0c;人间尽值得 ! 网络协议是实现各种设备和应用程序之间顺畅通信的基石。无论是构建分布式系统、开发Web应用&#xff0c;还是进行网络通信&#x…...

lida,一个超级厉害的 Python 库!

目录 前言 什么是 lida 库&#xff1f; lida 库的安装 基本功能 1. 文本分词 2. 词性标注 3. 命名实体识别 高级功能 1. 情感分析 2. 关键词提取 实际应用场景 1. 文本分类 2. 情感分析 3. 实体识别 总结 前言 大家好&#xff0c;今天为大家分享一个超级厉害的 Python …...

K好数 C语言 蓝桥杯算法提升ALGO3 一个自然数N的K进制表示中任意的相邻的两位都不是相邻的数字

问题描述 如果一个自然数N的K进制表示中任意的相邻的两位都不是相邻的数字&#xff0c;那么我们就说这个数是K好数。求L位K进制数中K好数的数目。例如K 4&#xff0c;L 2的时候&#xff0c;所有K好数为11、13、20、22、30、31、33 共7个。由于这个数目很大&#xff0c;请你输…...

2195. 深海机器人问题(网络流,费用流,上下界可行流,网格图模型)

活动 - AcWing 深海资源考察探险队的潜艇将到达深海的海底进行科学考察。 潜艇内有多个深海机器人。 潜艇到达深海海底后&#xff0c;深海机器人将离开潜艇向预定目标移动。 深海机器人在移动中还必须沿途采集海底生物标本。 沿途生物标本由最先遇到它的深海机器人完成采…...

Vue/cli项目全局css使用

第一步&#xff1a;创建css文件 在合适的位置创建好css文件&#xff0c;文件可以是sass/less/stylus...第二步&#xff1a;响预处理器loader传递选项 //摘自官网&#xff0c;引入样式 // vue.config.js module.exports {css: {loaderOptions: {// 给 sass-loader 传递选项sa…...

【自然语言处理】【大模型】BitNet:用1-bit Transformer训练LLM

BitNet&#xff1a;用1-bit Transformer训练LLM 《BitNet: Scaling 1-bit Transformers for Large Language Models》 论文地址&#xff1a;https://arxiv.org/pdf/2310.11453.pdf 相关博客 【自然语言处理】【大模型】BitNet&#xff1a;用1-bit Transformer训练LLM 【自然语言…...

安装及管理docker

文章目录 1.Docker介绍2.Docker安装3.免sudo设置4. 使用docker命令5.Images6.运行docker容器7. 管理docker容器8.创建image9.Push Image 1.Docker介绍 Docker 是一个简化在容器中管理应用程序进程的应用程序。容器让你在资源隔离的进程中运行你的应用程序。类似于虚拟机&#…...

【MySQL】表的增删改查——MySQL基本查询、数据库表的创建、表的读取、表的更新、表的删除

文章目录 MySQL表的增删查改1. Create&#xff08;创建&#xff09;1.1 单行插入1.2 多行插入1.3 替换 2. Retrieve&#xff08;读取&#xff09;2.1 select查看2.2 where条件2.3 结果排序2.4 筛选分页结果 3. Update&#xff08;更新&#xff09;3.1 更新单个数据3.2 更新多个…...

C/C++蓝桥杯之日期问题

问题描述&#xff1a;小明正在整理一批文献&#xff0c;这些文献中出现了很多日期&#xff0c;小明知道这些日期都在1960年1月1日至2059年12月31日之间&#xff0c;令小明头疼的是&#xff0c;这些日期采用的格式非常不统一&#xff0c;有采用年/月/日的&#xff0c;有采用月/日…...

【理解指针(二)】

文章目录 一、指针的运算&#xff08;1&#xff09;指针加整数&#xff08;2&#xff09;指针减指针&#xff08;指针关系运算&#xff09; 二、野指针&#xff08;1&#xff09;野指针的成因&#xff08;1.1&#xff09;指针未初始化&#xff08;1.2&#xff09;指针的越界访问…...

使用AI纠正文章

我写了一段关于哲学自学的读书笔记&#xff0c;处于好奇的目的&#xff0c;让AI帮我纠正语法和逻辑。我的原文如下&#xff1a; 泰勒斯第一次提出了水是万物本源的说法&#xff0c;对于泰勒斯为什么提出这样的观点&#xff0c;或者是这样的观点是怎么来的&#xff0c;我们无从所…...

拼多多API批量获取商品详情信息

随着电子商务的蓬勃发展&#xff0c;淘宝作为中国最大的在线购物平台之一&#xff0c;每天需要处理海量的商品上架和交易。为了提高工作效率&#xff0c;自动化上架商品和批量获取商品详情信息成为了许多商家和开发者的迫切需求。本文将详细介绍淘宝的API接口及其相关技术&…...

杨辉三角(C语言)

杨辉三角 一.什么是杨辉三角 一.什么是杨辉三角 每个数等于它上方两数之和。 每行数字左右对称&#xff0c;由1开始逐渐变大。 第n行的数字有n项。 前n行共[(1n)n]/2 个数。 … 当前行的数上一行的数上一行的前一列的数 void yanghuisanjian(int arr[][20], int n) {for (int i…...

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇&#xff0c;在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下&#xff1a; 【Note】&#xff1a;如果你已经完成安装等操作&#xff0c;可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作&#xff0c;重…...

线程同步:确保多线程程序的安全与高效!

全文目录&#xff1a; 开篇语前序前言第一部分&#xff1a;线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分&#xff1a;synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分&#xff…...

高等数学(下)题型笔记(八)空间解析几何与向量代数

目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...

如何为服务器生成TLS证书

TLS&#xff08;Transport Layer Security&#xff09;证书是确保网络通信安全的重要手段&#xff0c;它通过加密技术保护传输的数据不被窃听和篡改。在服务器上配置TLS证书&#xff0c;可以使用户通过HTTPS协议安全地访问您的网站。本文将详细介绍如何在服务器上生成一个TLS证…...

css的定位(position)详解:相对定位 绝对定位 固定定位

在 CSS 中&#xff0c;元素的定位通过 position 属性控制&#xff0c;共有 5 种定位模式&#xff1a;static&#xff08;静态定位&#xff09;、relative&#xff08;相对定位&#xff09;、absolute&#xff08;绝对定位&#xff09;、fixed&#xff08;固定定位&#xff09;和…...

云原生玩法三问:构建自定义开发环境

云原生玩法三问&#xff1a;构建自定义开发环境 引言 临时运维一个古董项目&#xff0c;无文档&#xff0c;无环境&#xff0c;无交接人&#xff0c;俗称三无。 运行设备的环境老&#xff0c;本地环境版本高&#xff0c;ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

第7篇:中间件全链路监控与 SQL 性能分析实践

7.1 章节导读 在构建数据库中间件的过程中&#xff0c;可观测性 和 性能分析 是保障系统稳定性与可维护性的核心能力。 特别是在复杂分布式场景中&#xff0c;必须做到&#xff1a; &#x1f50d; 追踪每一条 SQL 的生命周期&#xff08;从入口到数据库执行&#xff09;&#…...

安卓基础(Java 和 Gradle 版本)

1. 设置项目的 JDK 版本 方法1&#xff1a;通过 Project Structure File → Project Structure... (或按 CtrlAltShiftS) 左侧选择 SDK Location 在 Gradle Settings 部分&#xff0c;设置 Gradle JDK 方法2&#xff1a;通过 Settings File → Settings... (或 CtrlAltS)…...

flow_controllers

关键点&#xff1a; 流控制器类型&#xff1a; 同步&#xff08;Sync&#xff09;&#xff1a;发布操作会阻塞&#xff0c;直到数据被确认发送。异步&#xff08;Async&#xff09;&#xff1a;发布操作非阻塞&#xff0c;数据发送由后台线程处理。纯同步&#xff08;PureSync…...

GB/T 43887-2024 核级柔性石墨板材检测

核级柔性石墨板材是指以可膨胀石墨为原料、未经改性和增强、用于核工业的核级柔性石墨板材。 GB/T 43887-2024核级柔性石墨板材检测检测指标&#xff1a; 测试项目 测试标准 外观 GB/T 43887 尺寸偏差 GB/T 43887 化学成分 GB/T 43887 密度偏差 GB/T 43887 拉伸强度…...