当前位置: 首页 > news >正文

指数移动平均(EMA)

文章目录

  • 前言
  • EMA的定义
  • 在深度学习中的应用
    • PyTorch代码实现
    • yolov5中模型的EMA实现
  • 参考


前言

在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
实际上,_EMA可以看作是Temporal Ensembling,在模型学习过程中融合更多的历史状态,从而达到更好的优化效果。

EMA的定义

指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。
假设有n个权重数据image.png

  • 普通的平均数:image.png
  • EMA:image.png

其中, vt表示前 t条的平均值 ( v0=0 ),β是加权权重值 (一般设为0.9-0.999)。

Andrew Ng在Course 2 Improving Deep Neural Networks中讲到,EMA可以近似看成过去 1/(1−β) 个时刻 v 值的平均。
普通的过去n时刻的平均是这样的:image.png
类比EMA,可以发现当 image.png 时,两式形式上相等。需要注意的是,两个平均并不是严格相等的,这里只是为了帮助理解。
实际上,EMA计算时,过去 1/(1−β) 个时刻之前的数值平均会decay到 1/e 的加权比例,证明如下。
如果将这里的 vt展开,可以得到:
image.png
其中, image.png,代入可以得到 image.png

在深度学习中的应用

上面讲的是广义的ema定义和计算方法,特别的,在深度学习的优化过程中, image.png是t时刻的模型权重weights, vt是t时刻的影子权重(shadow weights)。在梯度下降的过程中,会一直维护着这个影子权重,但是这个影子权重并不会参与训练。
基本的假设是,模型权重在最后的n步内,会在实际的最优点处抖动,所以我们取最后n步的平均,能使得模型更加的鲁棒。

PyTorch代码实现

下面是一个简单的指数移动平均(EMA)的PyTorch实现:

import torchclass EMA():def __init__(self, alpha):self.alpha = alpha    # 初始化平滑因子alphaself.average = None   # 初始化平均值为空self.count = 0        # 初始化计数器为0def update(self, x):if self.average is None:  # 如果平均值为空,则将其初始化为与x相同大小的全零张量self.average = torch.zeros_like(x)self.average = self.alpha * x + (1 - self.alpha) * self.average  # 更新平均值self.count += 1   # 更新计数器def get(self):return self.average / (1 - self.alpha ** self.count)   # 根据计数器和平滑因子计算EMA值,并返回平均值除以衰减系数的结果

在这个类中,我们定义了三个方法,分别是__init__、update和get。

  • __init__方法用于初始化平滑因子alpha、平均值average和计数器count
  • update方法用于更新EMA值
  • get方法用于获取最终的EMA值。

使用这个类时,我们可以先实例化一个EMA对象,然后在每个时间步中调用update方法来更新EMA值,最后调用get方法来获取最终的EMA值。
例如:

ema = EMA(alpha=0.5)
for value in data:ema.update(torch.tensor(value))
smoothed_data = ema.get()

在这个例子中,我们使用alpha=0.5来初始化EMA对象,然后遍历数据集data中的每个数据点,调用update方法更新EMA值。最后我们调用get方法来获取平滑后的数据。

yolov5中模型的EMA实现

如下:

class ModelEMA:""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-modelsKeeps a moving average of everything in the model state_dict (parameters and buffers)For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage"""def __init__(self, model, decay=0.9999, tau=2000, updates=0):# Create EMAself.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMAself.updates = updates  # number of EMA updatesself.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)for p in self.ema.parameters():p.requires_grad_(False)def update(self, model):# Update EMA parametersself.updates += 1d = self.decay(self.updates)msd = de_parallel(model).state_dict()  # model state_dictfor k, v in self.ema.state_dict().items():if v.dtype.is_floating_point:  # true for FP16 and FP32v *= dv += (1 - d) * msd[k].detach()# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):# Update EMA attributescopy_attr(self.ema, model, include, exclude)

参考

https://zhuanlan.zhihu.com/p/68748778


如果有用,请点个三连呗 点赞、关注、收藏
你的鼓励是我最大的动力

相关文章:

指数移动平均(EMA)

文章目录 前言EMA的定义在深度学习中的应用PyTorch代码实现yolov5中模型的EMA实现 参考 前言 在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。实际上,_EMA可以…...

无线表格识别模型LORE转换库:ConvertLOREToONNX

引言 总有小伙伴问到阿里的无线表格识别模型是如何转换为ONNX格式的。这个说来有些惭愧,现有的ONNX模型是很久之前转换的了,转换环境已经丢失,且没有做任何笔记。 今天下定决心再次尝试转换,庆幸的是转换成功了。于是有了转换笔…...

C# 视频转图片

在 C# 中将视频转换为图像可以使用 FFmpeg 库。下面是一个示例代码来完成这个任务: using System; using System.Diagnostics;class Program {static void Main(string[] args){string inputFile "input_video.mp4"; // 输入的视频文件路径string outpu…...

LINUX ADC使用

监测 ADC ,使用CAT 查看&#xff1a; LINUX ADC基本使用 &adc {pinctrl-names "default";pinctrl-0 <&adc6>;pinctrl-1 <&adc7>;pinctrl-2 <&adc8>;pinctrl-3 <&adc9>;pinctrl-4 <&adc10>;pinctrl-5 …...

Ubuntu 基本操作-嵌入式 Linux 入门

在 Ubuntu 基本操作 里面基本就分为两部分&#xff1a; 安装 VMware 运行 Ubuntu熟悉 Ubuntu 的各种操作、命令 如果你对 Ubuntu 比较熟悉的话&#xff0c;安装完 VMware 运行 Ubuntu 之后就可以来学习下一章节了。 1. 安装 VMware 运行 Ubuntu 我们首先来看看怎么去安装 V…...

Pytorch可形变卷积分类模型与可视化

E:. │ archs.py │ dataset.py │ deform_conv_v2.py │ train.py │ utils.py │ visual_net.py │ ├─grad_cam │ 2.png │ 3.png │ ├─image │ ├─1 │ │ 154.png │ │ 2.png │ │ │ ├─2 │ │ 143.png │…...

Mysql 表逻辑分区原理和应用

MySQL的表逻辑分区是一种数据库设计技术&#xff0c;它允许将一个表的数据分布在多个物理分区中&#xff0c;但在逻辑上仍然表现为一个单一的表。这种方式可以提高查询性能、简化数据管理&#xff0c;并有助于高效地进行大数据量的存储和访问。逻辑分区基于特定的规则&#xff…...

架构面试题汇总:网络协议34问(七)

码到三十五 &#xff1a; 个人主页 心中有诗画&#xff0c;指尖舞代码&#xff0c;目光览世界&#xff0c;步履越千山&#xff0c;人间尽值得 ! 网络协议是实现各种设备和应用程序之间顺畅通信的基石。无论是构建分布式系统、开发Web应用&#xff0c;还是进行网络通信&#x…...

lida,一个超级厉害的 Python 库!

目录 前言 什么是 lida 库&#xff1f; lida 库的安装 基本功能 1. 文本分词 2. 词性标注 3. 命名实体识别 高级功能 1. 情感分析 2. 关键词提取 实际应用场景 1. 文本分类 2. 情感分析 3. 实体识别 总结 前言 大家好&#xff0c;今天为大家分享一个超级厉害的 Python …...

K好数 C语言 蓝桥杯算法提升ALGO3 一个自然数N的K进制表示中任意的相邻的两位都不是相邻的数字

问题描述 如果一个自然数N的K进制表示中任意的相邻的两位都不是相邻的数字&#xff0c;那么我们就说这个数是K好数。求L位K进制数中K好数的数目。例如K 4&#xff0c;L 2的时候&#xff0c;所有K好数为11、13、20、22、30、31、33 共7个。由于这个数目很大&#xff0c;请你输…...

2195. 深海机器人问题(网络流,费用流,上下界可行流,网格图模型)

活动 - AcWing 深海资源考察探险队的潜艇将到达深海的海底进行科学考察。 潜艇内有多个深海机器人。 潜艇到达深海海底后&#xff0c;深海机器人将离开潜艇向预定目标移动。 深海机器人在移动中还必须沿途采集海底生物标本。 沿途生物标本由最先遇到它的深海机器人完成采…...

Vue/cli项目全局css使用

第一步&#xff1a;创建css文件 在合适的位置创建好css文件&#xff0c;文件可以是sass/less/stylus...第二步&#xff1a;响预处理器loader传递选项 //摘自官网&#xff0c;引入样式 // vue.config.js module.exports {css: {loaderOptions: {// 给 sass-loader 传递选项sa…...

【自然语言处理】【大模型】BitNet:用1-bit Transformer训练LLM

BitNet&#xff1a;用1-bit Transformer训练LLM 《BitNet: Scaling 1-bit Transformers for Large Language Models》 论文地址&#xff1a;https://arxiv.org/pdf/2310.11453.pdf 相关博客 【自然语言处理】【大模型】BitNet&#xff1a;用1-bit Transformer训练LLM 【自然语言…...

安装及管理docker

文章目录 1.Docker介绍2.Docker安装3.免sudo设置4. 使用docker命令5.Images6.运行docker容器7. 管理docker容器8.创建image9.Push Image 1.Docker介绍 Docker 是一个简化在容器中管理应用程序进程的应用程序。容器让你在资源隔离的进程中运行你的应用程序。类似于虚拟机&#…...

【MySQL】表的增删改查——MySQL基本查询、数据库表的创建、表的读取、表的更新、表的删除

文章目录 MySQL表的增删查改1. Create&#xff08;创建&#xff09;1.1 单行插入1.2 多行插入1.3 替换 2. Retrieve&#xff08;读取&#xff09;2.1 select查看2.2 where条件2.3 结果排序2.4 筛选分页结果 3. Update&#xff08;更新&#xff09;3.1 更新单个数据3.2 更新多个…...

C/C++蓝桥杯之日期问题

问题描述&#xff1a;小明正在整理一批文献&#xff0c;这些文献中出现了很多日期&#xff0c;小明知道这些日期都在1960年1月1日至2059年12月31日之间&#xff0c;令小明头疼的是&#xff0c;这些日期采用的格式非常不统一&#xff0c;有采用年/月/日的&#xff0c;有采用月/日…...

【理解指针(二)】

文章目录 一、指针的运算&#xff08;1&#xff09;指针加整数&#xff08;2&#xff09;指针减指针&#xff08;指针关系运算&#xff09; 二、野指针&#xff08;1&#xff09;野指针的成因&#xff08;1.1&#xff09;指针未初始化&#xff08;1.2&#xff09;指针的越界访问…...

使用AI纠正文章

我写了一段关于哲学自学的读书笔记&#xff0c;处于好奇的目的&#xff0c;让AI帮我纠正语法和逻辑。我的原文如下&#xff1a; 泰勒斯第一次提出了水是万物本源的说法&#xff0c;对于泰勒斯为什么提出这样的观点&#xff0c;或者是这样的观点是怎么来的&#xff0c;我们无从所…...

拼多多API批量获取商品详情信息

随着电子商务的蓬勃发展&#xff0c;淘宝作为中国最大的在线购物平台之一&#xff0c;每天需要处理海量的商品上架和交易。为了提高工作效率&#xff0c;自动化上架商品和批量获取商品详情信息成为了许多商家和开发者的迫切需求。本文将详细介绍淘宝的API接口及其相关技术&…...

杨辉三角(C语言)

杨辉三角 一.什么是杨辉三角 一.什么是杨辉三角 每个数等于它上方两数之和。 每行数字左右对称&#xff0c;由1开始逐渐变大。 第n行的数字有n项。 前n行共[(1n)n]/2 个数。 … 当前行的数上一行的数上一行的前一列的数 void yanghuisanjian(int arr[][20], int n) {for (int i…...

23-Oracle 23 ai 区块链表(Blockchain Table)

小伙伴有没有在金融强合规的领域中遇见&#xff0c;必须要保持数据不可变&#xff0c;管理员都无法修改和留痕的要求。比如医疗的电子病历中&#xff0c;影像检查检验结果不可篡改行的&#xff0c;药品追溯过程中数据只可插入无法删除的特性需求&#xff1b;登录日志、修改日志…...

Frozen-Flask :将 Flask 应用“冻结”为静态文件

Frozen-Flask 是一个用于将 Flask 应用“冻结”为静态文件的 Python 扩展。它的核心用途是&#xff1a;将一个 Flask Web 应用生成成纯静态 HTML 文件&#xff0c;从而可以部署到静态网站托管服务上&#xff0c;如 GitHub Pages、Netlify 或任何支持静态文件的网站服务器。 &am…...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战

“&#x1f916;手搓TuyaAI语音指令 &#x1f60d;秒变表情包大师&#xff0c;让萌系Otto机器人&#x1f525;玩出智能新花样&#xff01;开整&#xff01;” &#x1f916; Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制&#xff08;TuyaAI…...

【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)

升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点&#xff0c;但无自动故障转移能力&#xff0c;Master宕机后需人工切换&#xff0c;期间消息可能无法读取。Slave仅存储数据&#xff0c;无法主动升级为Master响应请求&#xff…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析&#xff1a;CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展&#xff0c;AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者&#xff0c;分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

均衡后的SNRSINR

本文主要摘自参考文献中的前两篇&#xff0c;相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程&#xff0c;其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt​ 根发送天线&#xff0c; n r n_r nr​ 根接收天线的 MIMO 系…...

Python ROS2【机器人中间件框架】 简介

销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...

Linux C语言网络编程详细入门教程:如何一步步实现TCP服务端与客户端通信

文章目录 Linux C语言网络编程详细入门教程&#xff1a;如何一步步实现TCP服务端与客户端通信前言一、网络通信基础概念二、服务端与客户端的完整流程图解三、每一步的详细讲解和代码示例1. 创建Socket&#xff08;服务端和客户端都要&#xff09;2. 绑定本地地址和端口&#x…...

BLEU评分:机器翻译质量评估的黄金标准

BLEU评分&#xff1a;机器翻译质量评估的黄金标准 1. 引言 在自然语言处理(NLP)领域&#xff0c;衡量一个机器翻译模型的性能至关重要。BLEU (Bilingual Evaluation Understudy) 作为一种自动化评估指标&#xff0c;自2002年由IBM的Kishore Papineni等人提出以来&#xff0c;…...

从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障

关键领域软件测试的"安全密码"&#xff1a;Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天&#xff0c;软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力&#xff0c;从金融交易到交通管控&#xff0c;这些关乎国计民生的关键领域…...