
Exponential Moving Average (EMA) in Stable Diffusion

1. Moving Average in Stable Diffusion (SMA & EMA)

1. Moving average
2. 移动平均值 (Moving Average, in Chinese)
3. How We Trained Stable Diffusion for Less than $50k (Part 3)

Moving Average
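In statistics, a moving average is a calculation that analyzes data points by creating a series of averages of different subsets of the full data set.

Given a series of numbers and a fixed subset size, the first element of the moving average is obtained by averaging the initial fixed-size subset of the series. The subset is then modified by "shifting forward": the first number of the subset is dropped and the next value in the series is included.

This explanation of moving averages is taken from 移动平均值 (reference 2 above).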
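The "shifting forward" procedure described above can be sketched in a few lines of Python. This is only an illustration: the function name `simple_moving_average` and the `window` parameter are assumptions made for this example and do not come from the referenced articles.

```python
from collections import deque


def simple_moving_average(values, window):
    """Return the averages of every consecutive `window`-sized subset of `values`."""
    if window <= 0:
        raise ValueError("window must be positive")
    averages = []
    current = deque()   # the current fixed-size subset
    running_sum = 0.0
    for value in values:
        current.append(value)        # include the next value in the series
        running_sum += value
        if len(current) > window:
            running_sum -= current.popleft()   # drop the oldest value of the subset
        if len(current) == window:
            averages.append(running_sum / window)
    return averages


print(simple_moving_average([1, 2, 3, 4, 5], window=3))  # [2.0, 3.0, 4.0]
```

Every value inside the window contributes with the same weight `1 / window`, which is why this average is called unweighted.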

1.1 Simple Moving Average (SMA, an unweighted MA)


1.2 Exponential Moving Average (EMA, a weighted MA)

In the context of Stable Diffusion, the Exponential Moving Average (EMA) is a technique used during the training of machine learning models, particularly neural networks, to stabilize and improve the model’s performance.

The Exponential Moving Average is a method of averaging that gives more weight to recent data points, making it more responsive to recent changes compared to a simple moving average, which treats all data points equally.
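To make the weighting concrete, here is a minimal sketch of an EMA over a plain list of numbers. The function name and the parameter `alpha` (the weight given to the newest value) are assumptions for this example, and seeding the average with the first value is one common convention, not the only one.

```python
def exponential_moving_average(values, alpha):
    """EMA with weight `alpha` on the newest value, seeded with the first value."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must be in (0, 1]")
    ema = []
    current = None
    for value in values:
        if current is None:
            current = value   # seed the average with the first data point
        else:
            current = alpha * value + (1.0 - alpha) * current
        ema.append(current)
    return ema


series = [10.0, 10.0, 10.0, 20.0]
print(exponential_moving_average(series, alpha=0.5))  # [10.0, 10.0, 10.0, 15.0]
```

For the same four values, a simple average gives 12.5, while the EMA with `alpha=0.5` has already moved to 15.0, which shows how the newest point counts for more.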

1.2.1 EMA in Stable Diffusion

In the context of Stable Diffusion, EMA is applied to the model parameters during training to create a smoothed version of the model. This is particularly useful in machine learning because the training process can be noisy, with the model parameters oscillating as they converge towards an optimal solution. By maintaining an EMA of the model parameters, the training process can benefit from the following:

  1. Smoothing: EMA smooths out the noise in successive parameter updates, so the averaged weights are more stable than the raw weights at any single step.
  2. Better Generalization: The EMA version of the model often generalizes better on unseen data compared to the model with the raw parameters. This is because EMA tends to favor parameter values that are more consistent over time.
  3. Preventing Overfitting: By averaging the parameters over time, EMA can help mitigate overfitting, especially in cases where the model might otherwise converge too quickly to a suboptimal solution.

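Author's personal understanding
The loss function is a function of the parameters (weights and biases), so each set of parameter values yields one loss value; when the loss oscillates, the model parameters are changing as well. When training SD, the MSE loss oscillates up and down during gradient descent, and the model parameters oscillate with it. EMA can be used to take a middle value of these oscillating parameters. This middle value better represents the average level of the parameters over all time steps, which gives the model better generalization.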

Stable Diffusion 2 uses Exponential Moving Averaging (EMA), which maintains an exponential moving average of the weights. At every time step, the EMA model is updated by taking 0.9999 times the current EMA model plus 0.0001 times the new weights after the latest forward and backward pass. By default, the EMA algorithm is applied after every gradient update for the entire training period. However, this can be slow due to the memory operations required to read and write all the weights at every step.
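In other words, applying EMA to all parameters at every time step is expensive, because the full set of model parameters has to be read and written at each step.

$$\text{EMA}_t = 0.0001 \cdot x_t + 0.9999 \cdot \text{EMA}_{t-1}$$

To reduce the cost of computing the EMA, it is applied only during the final stretch of training.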
To avoid this costly procedure, we start with a key observation: since the old weights are decayed by a factor of 0.9999 at every batch, the early iterations of training only contribute minimally to the final average. This means we only need to take the exponential moving average of the final few steps. Concretely, we train for 1,400,000 batches and only apply EMA for the final 50,000 steps, which is about 3.5% of the training period. The weights from the first 1,350,000 iterations decay away by (0.9999)^50000, so their aggregate contribution would have a weight of less than 1% in the final model. Using this technique, we can avoid adding overhead for 96.5% of training and still achieve a nearly equivalent EMA model.
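The arithmetic in the paragraph above is easy to check. The sketch below uses only the numbers quoted there (decay 0.9999, 1,400,000 batches, EMA over the final 50,000 steps); the function name `residual_weight` is an assumption for this example.

```python
def residual_weight(decay, ema_steps):
    """Weight left on the value the EMA started from after `ema_steps` updates."""
    return decay ** ema_steps


decay = 0.9999
total_batches = 1_400_000
ema_steps = 50_000

# Everything learned before the EMA window enters the average only through
# its starting value, which is scaled by decay ** ema_steps.
print(f"{residual_weight(decay, ema_steps):.4f}")   # 0.0067, i.e. below 1%
print(f"{ema_steps / total_batches:.1%}")           # 3.6% of training uses EMA
```

Since `0.9999 ** 50000` is roughly `exp(-5)`, the first 1,350,000 iterations together keep less than 1% of the weight, which matches the claim in the text.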

1.2.2 Implementation in Stable Diffusion

During the training of a diffusion model, the EMA of the model’s weights is updated alongside the regular updates. Here’s a typical process:

  1. Initialize EMA Weights: At the start of training, initialize the EMA weights to be the same as the model’s initial weights.
  2. Update During Training: After each batch update, update the EMA weights using the formula mentioned above. This requires storing a separate set of weights for the EMA.
  3. Use for Inference: At the end of the training, use the EMA weights for inference instead of the raw model weights. This is because the EMA weights represent a more stable and potentially better-performing version of the model.
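The three steps above can be tried out without any deep learning framework. The following is a toy sketch under stated assumptions: the "model" is a single hypothetical weight that oscillates around an optimum of 1.0 as a stand-in for noisy optimizer steps, and the decay of 0.9 is chosen only so the effect is visible in a few hundred steps.

```python
def ema_update(ema_weights, weights, decay):
    """In place: ema = decay * ema + (1 - decay) * weights, element by element."""
    for i, w in enumerate(weights):
        ema_weights[i] = decay * ema_weights[i] + (1.0 - decay) * w


def toy_training_run(steps=200, decay=0.9):
    weights = [1.5]                 # hypothetical raw model weight
    ema_weights = list(weights)     # step 1: EMA copy starts equal to the model
    for step in range(1, steps + 1):
        # Stand-in for an optimizer step: the weight jumps between 0.5 and 1.5.
        weights[0] = 1.0 + 0.5 * (-1) ** step
        ema_update(ema_weights, weights, decay)   # step 2: update after each batch
    return weights[0], ema_weights[0]             # step 3: EMA value used for inference


raw, smoothed = toy_training_run()
print(abs(raw - 1.0), abs(smoothed - 1.0))
```

In this toy run the raw weight ends 0.5 away from the optimum, while the EMA copy stays within about 0.03 of it, which is the smoothing effect the steps are meant to provide.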

1.2.3 Practical Considerations

  1. Choosing $\alpha$: The smoothing factor $\alpha$ is the weight given to the newest value (so $\alpha = 0.0001$ in the update formula above, with decay $1-\alpha = 0.9999$). It is a hyperparameter that needs to be chosen carefully. A common practice is to set $\alpha$ based on the number of iterations or epochs to average over, such as $\alpha=\frac{2}{N+1}$, where $N$ is that number of iterations.
  2. Performance Overhead: Maintaining EMA weights requires additional memory and computational overhead, but the benefits in terms of model stability and performance often outweigh these costs.
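As a rough illustration of the rule of thumb for the smoothing factor, the sketch below converts between $N$ and $\alpha$ using $\alpha = \frac{2}{N+1}$. The helper names are assumptions for this example, and nothing in the source says that the decay values 0.9999 or 0.995 were actually chosen with this formula; the conversion is shown only to give a feel for the time scales involved.

```python
def alpha_from_span(n):
    """Smoothing factor for an averaging span of n iterations: 2 / (n + 1)."""
    return 2.0 / (n + 1)


def span_from_alpha(alpha):
    """Inverse of alpha_from_span: n = 2 / alpha - 1."""
    return 2.0 / alpha - 1.0


print(alpha_from_span(9))                 # 0.2
print(round(span_from_alpha(0.0001)))     # 19999 (decay 0.9999)
print(round(span_from_alpha(0.005)))      # 399   (decay 0.995)
```

Under this reading, a decay of 0.9999 averages over roughly 20,000 iterations, while a decay of 0.995 averages over roughly 400.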

module.py

class EMA:
    # Initializes the EMA object with a smoothing factor (beta) and a step counter (step).
    def __init__(self, beta):
        super().__init__()
        self.beta = beta  # Smoothing factor for the exponential moving average
        self.step = 0  # Step counter to keep track of the number of updates

    # Updates the moving average of the parameters of the EMA model (ma_model) based on the current model (current_model)
    def update_model_average(self, ma_model, current_model):
        # Update the moving average (EMA) of model parameters
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            # Update the moving average of the parameters
            ma_params.data = self.update_average(old_weight, up_weight)

    # Computes the exponentially weighted average of the old and new parameters.
    def update_average(self, old, new):
        # Compute the updated average
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    # Either resets the EMA model parameters to match the current model parameters
    # if the step count is less than step_start_ema,
    # or updates the EMA model parameters based on the current model parameters.
    # It increments the step counter after each call.
    def step_ema(self, ema_model, model, step_start_ema=2000):
        # Update EMA model parameters or reset them based on the step count
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1  # Increment the step counter

    # Copies the current model's parameters to the EMA model to initialize the EMA model parameters
    def reset_parameters(self, ema_model, model):
        # Initialize EMA model parameters to be the same as the current model's parameters
        ema_model.load_state_dict(model.state_dict())

train.py

import copy
import logging
import os

import torch
import torch.nn as nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from module import EMA  # the EMA class from module.py above

# UNET, train_loader, ddpm_sampler, diffusion and timesteps_to_time_emb
# are assumed to be defined elsewhere in the project.


def train(args):
    device = args.device  # Get the device to run the training on
    model = UNET().to(device)  # Initialize the model and move it to the device
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)  # Set up the optimizer with AdamW
    mse = nn.MSELoss()  # Mean Squared Error loss function
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)

    # EMA: Exponential Moving Average
    ema = EMA(0.995)  # Exponential Moving Average with decay rate 0.995

    # At the start of training, initialize the EMA weights to be the same as the model's initial weights.
    # Create a copy of the model for EMA, set to eval mode and no gradients
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)
    print('Start into the loop !')
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")  # Log the start of the epoch
        progress_bar = tqdm(train_loader)  # Progress bar for the dataloader
        optimizer.zero_grad()  # Explicitly zero the gradient buffers
        accumulation_steps = 4
        # Load all data into a batch
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)  # Move images to the device
            # The dataloader will add a batch size dimension to the tensor, but I've already added batch size to the VAE
            # and CLIP input, so we're going to remove a batch size and just keep the batch size of the dataloader
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)  # Move captions to the device
            text_embeddings = torch.squeeze(captions, dim=1)  # Squeeze batch_size
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)  # Sample random timesteps
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)  # Add noise to the images
            time_embeddings = timesteps_to_time_emb(timesteps)
            # x_t (batch_size, channel, Height/8, Width/8) (bs,4,256/8,256/8)
            # caption (batch_size, seq_len, dim) (bs, 77, 768)
            # t (batch_size, channel) (batch_size, 1280)
            # (bs,320,H/8,W/8)
            # Forward pass with gradients enabled so loss.backward() reaches the UNet weights
            last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
            # (bs,4,H/8,W/8)
            final_output = diffusion.final.to(device)
            predicted_noise = final_output(last_decoder_noise).to(device)
            loss = mse(noises, predicted_noise)  # Compute the loss
            loss.backward()  # Backpropagate the loss
            if (batch_idx + 1) % accumulation_steps == 0:  # Wait for several backward passes
                optimizer.step()  # Now we can do an optimizer step
                optimizer.zero_grad()  # Reset gradients to zero

            # EMA: Exponential Moving Average
            ema.step_ema(ema_model, model)
            progress_bar.set_postfix(MSE=loss.item())  # Update the progress bar with the loss
            # Log the loss to TensorBoard
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)

        # Save the model checkpoint
        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, "stable_diffusion.ckpt"))
        # Save the optimizer state
        torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, "optim.pt"))
