当前位置: 首页 > news >正文

指数移动平均(EMA)

文章目录

  • 前言
  • EMA的定义
  • 在深度学习中的应用
    • PyTorch代码实现
    • yolov5中模型的EMA实现
  • 参考


前言

在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
实际上,_EMA可以看作是Temporal Ensembling,在模型学习过程中融合更多的历史状态,从而达到更好的优化效果。

EMA的定义

指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。
假设有n个权重数据image.png

  • 普通的平均数:image.png
  • EMA:image.png

其中, vt表示前 t条的平均值 ( v0=0 ),β是加权权重值 (一般设为0.9-0.999)。

Andrew Ng在Course 2 Improving Deep Neural Networks中讲到,EMA可以近似看成过去 1/(1−β) 个时刻 v 值的平均。
普通的过去n时刻的平均是这样的:image.png
类比EMA,可以发现当 image.png 时,两式形式上相等。需要注意的是,两个平均并不是严格相等的,这里只是为了帮助理解。
实际上,EMA计算时,过去 1/(1−β) 个时刻之前的数值平均会decay到 1/e 的加权比例,证明如下。
如果将这里的 vt展开,可以得到:
image.png
其中, image.png,代入可以得到 image.png

在深度学习中的应用

上面讲的是广义的ema定义和计算方法,特别的,在深度学习的优化过程中, image.png是t时刻的模型权重weights, vt是t时刻的影子权重(shadow weights)。在梯度下降的过程中,会一直维护着这个影子权重,但是这个影子权重并不会参与训练。
基本的假设是,模型权重在最后的n步内,会在实际的最优点处抖动,所以我们取最后n步的平均,能使得模型更加的鲁棒。

PyTorch代码实现

下面是一个简单的指数移动平均(EMA)的PyTorch实现:

import torchclass EMA():def __init__(self, alpha):self.alpha = alpha    # 初始化平滑因子alphaself.average = None   # 初始化平均值为空self.count = 0        # 初始化计数器为0def update(self, x):if self.average is None:  # 如果平均值为空,则将其初始化为与x相同大小的全零张量self.average = torch.zeros_like(x)self.average = self.alpha * x + (1 - self.alpha) * self.average  # 更新平均值self.count += 1   # 更新计数器def get(self):return self.average / (1 - self.alpha ** self.count)   # 根据计数器和平滑因子计算EMA值,并返回平均值除以衰减系数的结果

在这个类中,我们定义了三个方法,分别是__init__、update和get。

  • __init__方法用于初始化平滑因子alpha、平均值average和计数器count
  • update方法用于更新EMA值
  • get方法用于获取最终的EMA值。

使用这个类时,我们可以先实例化一个EMA对象,然后在每个时间步中调用update方法来更新EMA值,最后调用get方法来获取最终的EMA值。
例如:

ema = EMA(alpha=0.5)
for value in data:ema.update(torch.tensor(value))
smoothed_data = ema.get()

在这个例子中,我们使用alpha=0.5来初始化EMA对象,然后遍历数据集data中的每个数据点,调用update方法更新EMA值。最后我们调用get方法来获取平滑后的数据。

yolov5中模型的EMA实现

如下:

class ModelEMA:""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-modelsKeeps a moving average of everything in the model state_dict (parameters and buffers)For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage"""def __init__(self, model, decay=0.9999, tau=2000, updates=0):# Create EMAself.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMAself.updates = updates  # number of EMA updatesself.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)for p in self.ema.parameters():p.requires_grad_(False)def update(self, model):# Update EMA parametersself.updates += 1d = self.decay(self.updates)msd = de_parallel(model).state_dict()  # model state_dictfor k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:  # true for FP16 and FP32v *= dv += (1 - d) * msd[k].detach()# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):# Update EMA attributescopy_attr(self.ema, model, include, exclude)

参考

https://zhuanlan.zhihu.com/p/68748778


如果有用,请点个三连呗 点赞、关注、收藏
你的鼓励是我最大的动力

相关文章:

指数移动平均(EMA)

文章目录 前言EMA的定义在深度学习中的应用PyTorch代码实现yolov5中模型的EMA实现 参考 前言 在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。实际上,_EMA可以…...

无线表格识别模型LORE转换库:ConvertLOREToONNX

引言 总有小伙伴问到阿里的无线表格识别模型是如何转换为ONNX格式的。这个说来有些惭愧,现有的ONNX模型是很久之前转换的了,转换环境已经丢失,且没有做任何笔记。 今天下定决心再次尝试转换,庆幸的是转换成功了。于是有了转换笔…...

C# 视频转图片

在 C# 中将视频转换为图像可以使用 FFmpeg 库。下面是一个示例代码来完成这个任务: using System; using System.Diagnostics;class Program {static void Main(string[] args){string inputFile "input_video.mp4"; // 输入的视频文件路径string outpu…...

LINUX ADC使用

监测 ADC ,使用CAT 查看&#xff1a; LINUX ADC基本使用 &adc {pinctrl-names "default";pinctrl-0 <&adc6>;pinctrl-1 <&adc7>;pinctrl-2 <&adc8>;pinctrl-3 <&adc9>;pinctrl-4 <&adc10>;pinctrl-5 …...

Ubuntu 基本操作-嵌入式 Linux 入门

在 Ubuntu 基本操作 里面基本就分为两部分&#xff1a; 安装 VMware 运行 Ubuntu熟悉 Ubuntu 的各种操作、命令 如果你对 Ubuntu 比较熟悉的话&#xff0c;安装完 VMware 运行 Ubuntu 之后就可以来学习下一章节了。 1. 安装 VMware 运行 Ubuntu 我们首先来看看怎么去安装 V…...

Pytorch可形变卷积分类模型与可视化

E:. │ archs.py │ dataset.py │ deform_conv_v2.py │ train.py │ utils.py │ visual_net.py │ ├─grad_cam │ 2.png │ 3.png │ ├─image │ ├─1 │ │ 154.png │ │ 2.png │ │ │ ├─2 │ │ 143.png │…...

Mysql 表逻辑分区原理和应用

MySQL的表逻辑分区是一种数据库设计技术&#xff0c;它允许将一个表的数据分布在多个物理分区中&#xff0c;但在逻辑上仍然表现为一个单一的表。这种方式可以提高查询性能、简化数据管理&#xff0c;并有助于高效地进行大数据量的存储和访问。逻辑分区基于特定的规则&#xff…...

架构面试题汇总:网络协议34问(七)

码到三十五 &#xff1a; 个人主页 心中有诗画&#xff0c;指尖舞代码&#xff0c;目光览世界&#xff0c;步履越千山&#xff0c;人间尽值得 ! 网络协议是实现各种设备和应用程序之间顺畅通信的基石。无论是构建分布式系统、开发Web应用&#xff0c;还是进行网络通信&#x…...

lida,一个超级厉害的 Python 库!

目录 前言 什么是 lida 库&#xff1f; lida 库的安装 基本功能 1. 文本分词 2. 词性标注 3. 命名实体识别 高级功能 1. 情感分析 2. 关键词提取 实际应用场景 1. 文本分类 2. 情感分析 3. 实体识别 总结 前言 大家好&#xff0c;今天为大家分享一个超级厉害的 Python …...

K好数 C语言 蓝桥杯算法提升ALGO3 一个自然数N的K进制表示中任意的相邻的两位都不是相邻的数字

问题描述 如果一个自然数N的K进制表示中任意的相邻的两位都不是相邻的数字&#xff0c;那么我们就说这个数是K好数。求L位K进制数中K好数的数目。例如K 4&#xff0c;L 2的时候&#xff0c;所有K好数为11、13、20、22、30、31、33 共7个。由于这个数目很大&#xff0c;请你输…...

2195. 深海机器人问题(网络流,费用流,上下界可行流,网格图模型)

活动 - AcWing 深海资源考察探险队的潜艇将到达深海的海底进行科学考察。 潜艇内有多个深海机器人。 潜艇到达深海海底后&#xff0c;深海机器人将离开潜艇向预定目标移动。 深海机器人在移动中还必须沿途采集海底生物标本。 沿途生物标本由最先遇到它的深海机器人完成采…...

Vue/cli项目全局css使用

第一步&#xff1a;创建css文件 在合适的位置创建好css文件&#xff0c;文件可以是sass/less/stylus...第二步&#xff1a;响预处理器loader传递选项 //摘自官网&#xff0c;引入样式 // vue.config.js module.exports {css: {loaderOptions: {// 给 sass-loader 传递选项sa…...

【自然语言处理】【大模型】BitNet:用1-bit Transformer训练LLM

BitNet&#xff1a;用1-bit Transformer训练LLM 《BitNet: Scaling 1-bit Transformers for Large Language Models》 论文地址&#xff1a;https://arxiv.org/pdf/2310.11453.pdf 相关博客 【自然语言处理】【大模型】BitNet&#xff1a;用1-bit Transformer训练LLM 【自然语言…...

安装及管理docker

文章目录 1.Docker介绍2.Docker安装3.免sudo设置4. 使用docker命令5.Images6.运行docker容器7. 管理docker容器8.创建image9.Push Image 1.Docker介绍 Docker 是一个简化在容器中管理应用程序进程的应用程序。容器让你在资源隔离的进程中运行你的应用程序。类似于虚拟机&#…...

【MySQL】表的增删改查——MySQL基本查询、数据库表的创建、表的读取、表的更新、表的删除

文章目录 MySQL表的增删查改1. Create&#xff08;创建&#xff09;1.1 单行插入1.2 多行插入1.3 替换 2. Retrieve&#xff08;读取&#xff09;2.1 select查看2.2 where条件2.3 结果排序2.4 筛选分页结果 3. Update&#xff08;更新&#xff09;3.1 更新单个数据3.2 更新多个…...

C/C++蓝桥杯之日期问题

问题描述&#xff1a;小明正在整理一批文献&#xff0c;这些文献中出现了很多日期&#xff0c;小明知道这些日期都在1960年1月1日至2059年12月31日之间&#xff0c;令小明头疼的是&#xff0c;这些日期采用的格式非常不统一&#xff0c;有采用年/月/日的&#xff0c;有采用月/日…...

【理解指针(二)】

文章目录 一、指针的运算&#xff08;1&#xff09;指针加整数&#xff08;2&#xff09;指针减指针&#xff08;指针关系运算&#xff09; 二、野指针&#xff08;1&#xff09;野指针的成因&#xff08;1.1&#xff09;指针未初始化&#xff08;1.2&#xff09;指针的越界访问…...

使用AI纠正文章

我写了一段关于哲学自学的读书笔记&#xff0c;处于好奇的目的&#xff0c;让AI帮我纠正语法和逻辑。我的原文如下&#xff1a; 泰勒斯第一次提出了水是万物本源的说法&#xff0c;对于泰勒斯为什么提出这样的观点&#xff0c;或者是这样的观点是怎么来的&#xff0c;我们无从所…...

拼多多API批量获取商品详情信息

随着电子商务的蓬勃发展&#xff0c;淘宝作为中国最大的在线购物平台之一&#xff0c;每天需要处理海量的商品上架和交易。为了提高工作效率&#xff0c;自动化上架商品和批量获取商品详情信息成为了许多商家和开发者的迫切需求。本文将详细介绍淘宝的API接口及其相关技术&…...

杨辉三角(C语言)

杨辉三角 一.什么是杨辉三角 一.什么是杨辉三角 每个数等于它上方两数之和。 每行数字左右对称&#xff0c;由1开始逐渐变大。 第n行的数字有n项。 前n行共[(1n)n]/2 个数。 … 当前行的数上一行的数上一行的前一列的数 void yanghuisanjian(int arr[][20], int n) {for (int i…...

Leetcode 3576. Transform Array to All Equal Elements

Leetcode 3576. Transform Array to All Equal Elements 1. 解题思路2. 代码实现 题目链接&#xff1a;3576. Transform Array to All Equal Elements 1. 解题思路 这一题思路上就是分别考察一下是否能将其转化为全1或者全-1数组即可。 至于每一种情况是否可以达到&#xf…...

中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试

作者&#xff1a;Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位&#xff1a;中南大学地球科学与信息物理学院论文标题&#xff1a;BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接&#xff1a;https://arxiv.…...

为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?

在建筑行业&#xff0c;项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升&#xff0c;传统的管理模式已经难以满足现代工程的需求。过去&#xff0c;许多企业依赖手工记录、口头沟通和分散的信息管理&#xff0c;导致效率低下、成本失控、风险频发。例如&#…...

css的定位(position)详解:相对定位 绝对定位 固定定位

在 CSS 中&#xff0c;元素的定位通过 position 属性控制&#xff0c;共有 5 种定位模式&#xff1a;static&#xff08;静态定位&#xff09;、relative&#xff08;相对定位&#xff09;、absolute&#xff08;绝对定位&#xff09;、fixed&#xff08;固定定位&#xff09;和…...

前端开发面试题总结-JavaScript篇(一)

文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包&#xff08;Closure&#xff09;&#xff1f;闭包有什么应用场景和潜在问题&#xff1f;2.解释 JavaScript 的作用域链&#xff08;Scope Chain&#xff09; 二、原型与继承3.原型链是什么&#xff1f;如何实现继承&a…...

今日科技热点速览

&#x1f525; 今日科技热点速览 &#x1f3ae; 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售&#xff0c;主打更强图形性能与沉浸式体验&#xff0c;支持多模态交互&#xff0c;受到全球玩家热捧 。 &#x1f916; 人工智能持续突破 DeepSeek-R1&…...

优选算法第十二讲:队列 + 宽搜 优先级队列

优选算法第十二讲&#xff1a;队列 宽搜 && 优先级队列 1.N叉树的层序遍历2.二叉树的锯齿型层序遍历3.二叉树最大宽度4.在每个树行中找最大值5.优先级队列 -- 最后一块石头的重量6.数据流中的第K大元素7.前K个高频单词8.数据流的中位数 1.N叉树的层序遍历 2.二叉树的锯…...

HarmonyOS运动开发:如何用mpchart绘制运动配速图表

##鸿蒙核心技术##运动开发##Sensor Service Kit&#xff08;传感器服务&#xff09;# 前言 在运动类应用中&#xff0c;运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据&#xff0c;如配速、距离、卡路里消耗等&#xff0c;用户可以更清晰…...

R语言速释制剂QBD解决方案之三

本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配

目录 一、C 内存的基本概念​ 1.1 内存的物理与逻辑结构​ 1.2 C 程序的内存区域划分​ 二、栈内存分配​ 2.1 栈内存的特点​ 2.2 栈内存分配示例​ 三、堆内存分配​ 3.1 new和delete操作符​ 4.2 内存泄漏与悬空指针问题​ 4.3 new和delete的重载​ 四、智能指针…...