18、指数移动平均——EMA
简介
在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。
例子
假设有”温度-天数“的数据
θt\theta_tθt:在第 t 天的温度
vtv_tvt:在第 t 天的移动平均数
β\betaβ: 权重参数
v0=0v1=0.9v0+0.1θ1v2=0.9v1+0.1θ2⋯vt=0.9vt−1+0.1θt\begin{aligned} v_0 &= 0 \\ v_1 &=0.9v_0 + 0.1 \theta_1 \\ v_2 &=0.9v_1 + 0.1 \theta_2\\ &\cdots\\ v_t &=0.9v_{t-1} + 0.1\theta_t\\ \end{aligned} v0v1v2vt=0=0.9v0+0.1θ1=0.9v1+0.1θ2⋯=0.9vt−1+0.1θt
红线即是蓝色数据点的指数移动平均
VtV_tVt 和 β\betaβ 的关系
vtv_tvt 大概表示前 11−β\frac{1}{1-\beta}1−β1 天的平均数据(以第 t 天做参考)
β=0.9\beta = 0.9β=0.9 | 11−β≈10\frac{1}{1-\beta}\approx101−β1≈10 | vtv_tvt 大概表示前10天的平均数据 | 红线 |
---|---|---|---|
β=0.98\beta = 0.98β=0.98 | 11−β≈50\frac{1}{1-\beta}\approx501−β1≈50 | vtv_tvt 大概表示前50天的平均数据 | 绿线 |
β=0.5\beta = 0.5β=0.5 | 11−β≈2\frac{1}{1-\beta}\approx21−β1≈2 | vtv_tvt 大概表示前2天的平均数据 | 黄线 |
那么β\betaβ 越大,表示考虑的时间阔度越大 |
理解vtv_tvt
vt=β⋅vt−1+(1−β)⋅θt\begin{aligned} v_t = \beta \cdot v_{t-1} + (1-\beta) \cdot \theta_t \end{aligned} vt=β⋅vt−1+(1−β)⋅θt
当 β=0.9\beta = 0.9β=0.9,从v100v_{100}v100往回写
v100=0.9v99+0.1θ100v99=0.9v98+0.1θ99⋯v1=0.9v0+0.1θ1v0=0\begin{aligned} v_{100} &= 0.9v_{99} + 0.1\theta_{100}\\ v_{99} &= 0.9v_{98} + 0.1\theta_{99}\\ &\cdots \\ v_1 &=0.9v_0 + 0.1 \theta_1 \\ v_0 &= 0 \\ \end{aligned} v100v99v1v0=0.9v99+0.1θ100=0.9v98+0.1θ99⋯=0.9v0+0.1θ1=0
迭代该过程可知:
- $v_{100} 是 θ100θ99θ98⋯\theta_{100}\ \theta_{99}\ \theta_{98}\ \cdotsθ100 θ99 θ98 ⋯ 的加权求和
- θ\thetaθ 前的系数相加为 1 或逼近 1
当某项系数小于峰值系数 (𝟏−𝜷)的 1e\frac{1}{e}e1 时,可以忽略它的影响
(0.9)10≃0.34≃1e(0.9)^10 \simeq 0.34 \simeq \frac{1}{e}(0.9)10≃0.34≃e1 所以当β=0.9时,相当于前10天的加权平均。
(0.98)50≃0.36≃1e(0.98)^50 \simeq 0.36 \simeq \frac{1}{e}(0.98)50≃0.36≃e1 所以当β=0.98时,相当于前50天的加权平均。
(0.5)2≃0.25≃1e(0.5)^2 \simeq 0.25 \simeq \frac{1}{e}(0.5)2≃0.25≃e1 所以当β=0.5时,相当于前2天的加权平均。
带入深度学习模型
vt=β⋅vt−1+(1−β)⋅θtv_t = \beta \cdot v_{t-1} + (1-\beta) \cdot \theta_tvt=β⋅vt−1+(1−β)⋅θt
θt\theta_tθt:在第 t 次更新得到的所有参数权重
vtv_tvt:第 t 次更新的所有参数移动平均数
β\betaβ:权重参数
EMA内在
对于更新 t 次时普通的参数权重 θt\theta_tθt (gtg_tgt 是第 t 次传播得到的梯度):
θt=θt−1−gt−1=θt−2−gt−1−gt−2=⋯=θ1−∑i=1t−1gi\begin{aligned} \theta_t &= \theta_{t-1} - g_{t-1}\\ &=\theta_{t-2} - g_{t-1} - g_{t-2}\\ & = \cdots\\ &= \theta_1 - \sum^{t-1}_{i=1} g_i\\ \end{aligned} θt=θt−1−gt−1=θt−2−gt−1−gt−2=⋯=θ1−i=1∑t−1gi
对于更新 t 次时使用EMA的参数权重 vtv_tvt:
θt=θ1−∑i=1t−1givt=θ1−∑i=1t−1(1−βt−i)gi\begin{aligned} \theta_t &= \theta_1 - \sum^{t-1}_{i=1}g_i\\ v_t &= \theta_1 - \sum^{t-1}_{i=1}(1-\beta^{t-i})g_i \end{aligned} θtvt=θ1−i=1∑t−1gi=θ1−i=1∑t−1(1−βt−i)gi
推理如下:将 θn\theta_nθn 带入 vnv_nvn 表达式,并且令 v0=θ1v_0 = \theta_1v0=θ1:
vn=βtv0+(1−β)(θt+βθt−1+β2θt−2+⋯+βn−1θ1)=βtv0+(1−β)(θ1−∑i=1t−1gi+β(θ1−∑i=1t−2gi)+⋯+βt−2(θ1−∑i=11gi)+βt−1θ1)=βtv0+(1−βt)θ1−∑i=1n−1(1−βt−i)gi=θ1−∑i=1t−1(1−βt−i)gi\begin{aligned} v_n &= \beta^t v_0 +(1-\beta)(\theta_t + \beta\theta_{t-1}+\beta^2\theta_{t-2}+\cdots+\beta^{n-1}\theta_1) \\ &=\beta^tv_0 + (1-\beta)(\theta_1 - \sum^{t-1}_{i=1}g_i+\beta(\theta_1 - \sum^{t-2}_{i=1}g_i)+\cdots+\beta^{t-2}(\theta_1 - \sum^{1}_{i=1}g_i)+\beta^{t-1}\theta_1)\\ &=\beta^tv_0 + (1-\beta^t)\theta_1 - \sum^{n-1}_{i=1}(1-\beta^{t-i})g_i \\ &=\theta_1 - \sum^{t-1}_{i=1}(1-\beta^{t-i})g_i \end{aligned} vn=βtv0+(1−β)(θt+βθt−1+β2θt−2+⋯+βn−1θ1)=βtv0+(1−β)(θ1−i=1∑t−1gi+β(θ1−i=1∑t−2gi)+⋯+βt−2(θ1−i=1∑1gi)+βt−1θ1)=βtv0+(1−βt)θ1−i=1∑n−1(1−βt−i)gi=θ1−i=1∑t−1(1−βt−i)gi
普通的参数权重相当于一直累积更新整个训练过程的梯度,使用EMA的参数权重相当于使用训练过程梯度的加权平均(刚开始的梯度权值很小)。由于刚开始训练不稳定,得到的梯度给更小的权值更为合理,所以EMA会有效。
代码实现
class EMA(nn.Module):def __init__(self, model, decay=0.9999, device=None):super(EMA, self).__init__()# make a copy of the model for accumulating moving average of weightsself.module = deepcopy(model)self.module.eval()self.decay = decay# perform ema on different device from model if setself.device = deviceif self.device is not None:self.module.to(device=device)def _update(self, model, update_fn):with torch.no_grad():for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values):if self.device is not None:model_v = model_v.to(device=self.device)ema_v.copy_(update_fn(ema_v, model_v))def update(self, model):self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)def set(self, model):self._update(model, update_fn=lambda e, m: m)
class LitEma(nn.Module):def __init__(self, model, decay=0.9999, use_num_upates=True):super().__init__()if decay < 0.0 or decay > 1.0:raise ValueError('Decay must be between 0 and 1')self.m_name2s_name = {}self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upateselse torch.tensor(-1, dtype=torch.int))for name, p in model.named_parameters():if p.requires_grad:# remove as '.'-character is not allowed in bufferss_name = name.replace('.', '')self.m_name2s_name.update({name: s_name})self.register_buffer(s_name, p.clone().detach().data)self.collected_params = []def forward(self, model):decay = self.decayif self.num_updates >= 0:self.num_updates += 1decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))one_minus_decay = 1.0 - decaywith torch.no_grad():m_param = dict(model.named_parameters())shadow_params = dict(self.named_buffers())for key in m_param:if m_param[key].requires_grad:sname = self.m_name2s_name[key]shadow_params[sname] = shadow_params[sname].type_as(m_param[key])shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))else:assert not key in self.m_name2s_namedef copy_to(self, model):m_param = dict(model.named_parameters())shadow_params = dict(self.named_buffers())for key in m_param:if m_param[key].requires_grad:m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)else:assert not key in self.m_name2s_namedef store(self, parameters):"""Save the current parameters for restoring later.Args:parameters: Iterable of `torch.nn.Parameter`; the parameters to betemporarily stored."""self.collected_params = [param.clone() for param in parameters]def restore(self, parameters):"""Restore the parameters stored with the `store` method.Useful to validate the model with EMA parameters without affecting theoriginal optimization process. Store the parameters before the`copy_to` method. After validation (or model saving), use this torestore the former parameters.Args:parameters: Iterable of `torch.nn.Parameter`; the parameters to beupdated with the stored parameters."""for c_param, param in zip(self.collected_params, parameters):param.data.copy_(c_param.data)
相关文章:

18、指数移动平均——EMA
简介 在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。 指数移动平均(Exponential Moving Average)也叫权重移动平均(Weighted Moving…...

用Go快速搭建IM即时通讯系统
WebSocket的目标是在一个单独的持久连接上提供全双工、双向通信。在Javascript创建了Web Socket之后,会有一个HTTP请求发送到浏览器以发起连接。在取得服务器响应后,建立的连接会将HTTP升级从HTTP协议交换为WebSocket协议。由于WebSocket使用自定义的协议…...
2023年江苏省职业院校技能大赛中职网络安全赛项试卷-学生组-任务书
2023年江苏省职业院校技能大赛中职网络安全赛项试卷-学生组-任务书 2023年江苏省职业院校技能大赛中职网络安全赛项试卷-学生组-任务书第一阶段 (300分) [手敲的任务书 点个赞吧]任务一:主机发现与信息收集 (50分)任务二: 应急响应 (60分)任务三:数字取证与分析(80分)任务四:…...

如何使用码匠连接 MariaDB
MariaDB 是一个免费的、开源的关系型数据库管理系统,由 MariaDB 的创始人 Michael Widenius 于 2010 年创建。它基于 MariaDB,但在对数据存储的处理中加入了一些自己的特性。MariaDB 相对于 MariaDB 而言,具有更好的性能和更好的兼容性&#…...

JavaEE简单示例——Bean的实例化
简单介绍: 在我们之前使用某个对象,那么就要创建这个类的对象,创建对象的过程就叫做实例化。对于Spring来说,实例化Bean的方式有三种,分别是构造方法实例化,静态方法实例化,实例工厂实例化。我…...
1229. 日期问题
目录 题目链接 一些话 流程 套路 ac代码 题目链接 1229. 日期问题 - AcWing题库 一些话 切入点 // 小明知道这些日期都在1960年1月1日至2059年12月31日。 // 这些日期采用的格式非常不统一,有采用年/月/日的,有采用月/日/年的,还有采用…...

Java 中的浅拷贝和深拷贝
无论是浅拷贝还是深拷贝,都可以通过 Object 类的 clone() 方法来完成: /*** 拷贝** author qiaohaojie* date 2023/3/5 15:58*/ public class CloneTest {public static void main(String[] args) throws Exception {Person person1 new Person(23, &…...

【java】 java开发中 常遇到的各种难点 思路方案
文章目录逻辑删除如何建立唯一索引唯一索引失效问题加密字段模糊查询问题maven依赖冲突问题(jar包版本冲突问题)sql in条件查询时 将结果按照传入顺序排序作为一个开发人员 总会遇到各种难题 本文列举博主 遇见/想到 的例子 ,也希望同学们可以…...

ViewBinding 和 DataBinding的使用
1.ViewBinding:视图绑定 通过视图绑定功能,您可以更轻松地编写可与视图交互的代码。在模块中启用视图绑定之后,系统会为该模块中的每个 XML 布局文件生成一个绑定类。绑定类的实例包含对在相应布局中具有 ID 的所有视图的直接引用。在大多数情况下&…...
HTML+CSS入门
CSS概述 CSS指层叠样式表 (Cascading Style Sheets),用来定义HTML网页中的内容用什么样式来显示。 HTML: 指定网页显示的内容 CSS: 指定内容显示的样式CSS入门案例 <html><head><meta charset"UTF-8"><title>入门案例</tit…...

【Vue】vue2导出页面内容为pdf文件,自定义选中页面内容导出为pdf文件,打印选中页面内容,预览打印内容
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录前言一、安装html2canvas和jspdf二、导出pdf使用步骤1.在utils文件夹下创建htmlToPdf.js2.在main.js中引入3.在页面中使用三、打印预览1. 引入print-js2.页面中impor…...

保姆级使用PyTorch训练与评估自己的Replknet网络教程
文章目录前言0. 环境搭建&快速开始1. 数据集制作1.1 标签文件制作1.2 数据集划分1.3 数据集信息文件制作2. 修改参数文件3. 训练4. 评估5. 其他教程前言 项目地址:https://github.com/Fafa-DL/Awesome-Backbones 操作教程:https://www.bilibili.co…...

1/4车、1/2车、整车悬架PID控制仿真合集
目录 前言 1. 1/4悬架系统 1.1数学模型 1.2仿真分析 2. 1/2悬架系统 2.1数学模型 2.2仿真模型 2.3仿真分析 3. 整车悬架系统 3.1数学模型 3.2仿真分析 参考文献 前言 前面几篇文章介绍了LQR、SkyHook、H2/H∞控制,接下来会继续介绍滑模、反步法、MPC、…...

媒体邀约的形式和步骤
传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 做媒体服务很多年,今天就与大家分享下媒体邀约都有哪些形式: 1,电话邀约:通过电话与媒体记者进行沟通,邀请其参加活动或接受采…...

Unity合批处理
一.静态合批标记为Batching Static的物体(标记后物体运行不能移动、旋转、缩放)在使用相同材质球的条件下在项目打包的时候unity会自动将这些物体合并到一个大Mesh*缺点打包后体积增大运行时内存占用增大二.动态批处理不超过300个顶点不超过900个属性不包…...

Android 进阶——Binder IPC之Native 服务的启动及代理对象的获取详解(六)
文章大纲引言一、Binder线程池的启动1、ProcessState#startThreadPool函数来启动线程池2、IPCThreadState#joinThreadPool 将当前线程进入到线程池中去等待和处理IPC请求二、Service 代理对象的获取1、获取Service Manager 代理对象BpServiceManager2、调用BpServiceManager#ge…...

企业官网怎么做?
企业官网是企业展示形象和吸引潜在客户的重要渠道之一,因此如何打造一款优秀的企业官网显得尤为重要。本文将从策划、设计、开发和上线等方面,为您介绍企业官网的制作步骤。 一、策划 1.明确目标 企业官网的制作需要明确目标,即确定官网的主…...

FPGA和IC设计怎么选?哪个发展更好?
很多人纠结FPGA和IC设计怎么选,其实往小了说,要看你选择的具体是哪个方向岗位。往大了说,将来你要是走更远,要成为大佬,那基本各个方向的都要有涉及的。 不同方向就有不同的发展,目前在薪资上IC设计要比FP…...

宁盾目录成功对接Coremail邮箱,为其提供LDAP统一认证和双因子认证
近日,宁盾与 Coremail 完成兼容适配,在 LDAP 目录用户同步、统一身份认证及双因子认证等模块成功对接。借此机会,双方将加深在产品、解决方案等多个领域的合作,携手共建信创合作生态,打造信创 LDAP 身份目录服务新样本…...
Go: struct 结构体类型和指针【学习笔记记录】
struct 结构体类型和指针struct 结构体类型1. 定义结构体2. 访问结构体成员3. 结构体的使用及匿名字段指针1. 指针变量的声明及使用2. 指针数组的定义及使用3. 函数传参修改值struct 结构体类型 Go 语言中数组可以存储同一类型的数据,但在结构体中我们可以为不同项…...
ES6从入门到精通:前言
ES6简介 ES6(ECMAScript 2015)是JavaScript语言的重大更新,引入了许多新特性,包括语法糖、新数据类型、模块化支持等,显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var…...

《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》
在注意力分散、内容高度同质化的时代,情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现,消费者对内容的“有感”程度,正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中࿰…...

Redis数据倾斜问题解决
Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中,部分节点存储的数据量或访问量远高于其他节点,导致这些节点负载过高,影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...

AirSim/Cosys-AirSim 游戏开发(四)外部固定位置监控相机
这个博客介绍了如何通过 settings.json 文件添加一个无人机外的 固定位置监控相机,因为在使用过程中发现 Airsim 对外部监控相机的描述模糊,而 Cosys-Airsim 在官方文档中没有提供外部监控相机设置,最后在源码示例中找到了,所以感…...
纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join
纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join 1、依赖1.1、依赖版本1.2、pom.xml 2、代码2.1、SqlSession 构造器2.2、MybatisPlus代码生成器2.3、获取 config.yml 配置2.3.1、config.yml2.3.2、项目配置类 2.4、ftl 模板2.4.1、…...

基于IDIG-GAN的小样本电机轴承故障诊断
目录 🔍 核心问题 一、IDIG-GAN模型原理 1. 整体架构 2. 核心创新点 (1) 梯度归一化(Gradient Normalization) (2) 判别器梯度间隙正则化(Discriminator Gradient Gap Regularization) (3) 自注意力机制(Self-Attention) 3. 完整损失函数 二…...
C#学习第29天:表达式树(Expression Trees)
目录 什么是表达式树? 核心概念 1.表达式树的构建 2. 表达式树与Lambda表达式 3.解析和访问表达式树 4.动态条件查询 表达式树的优势 1.动态构建查询 2.LINQ 提供程序支持: 3.性能优化 4.元数据处理 5.代码转换和重写 适用场景 代码复杂性…...
GitHub 趋势日报 (2025年06月06日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 590 cognee 551 onlook 399 project-based-learning 348 build-your-own-x 320 ne…...
Caliper 配置文件解析:fisco-bcos.json
config.yaml 文件 config.yaml 是 Caliper 的主配置文件,通常包含以下内容: test:name: fisco-bcos-test # 测试名称description: Performance test of FISCO-BCOS # 测试描述workers:type: local # 工作进程类型number: 5 # 工作进程数量monitor:type: - docker- pro…...

破解路内监管盲区:免布线低位视频桩重塑停车管理新标准
城市路内停车管理常因行道树遮挡、高位设备盲区等问题,导致车牌识别率低、逃费率高,传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法,正成为破局关键。该设备安装于车位侧方0.5-0.7米高度,直接规避树枝遮…...