Open-Sora代码详细解读(2):时空3D VAE
Diffusion Models视频生成
前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。
目录
3D VAE原理
代码剖析
2D VAE
时间VAE
因果3D卷积
3D VAE原理
之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。
Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。
代码剖析
2D VAE
来自华为pixart sdxl vae:
vae_2d = dict(type="VideoAutoencoderKL",from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",subfolder="vae",micro_batch_size=micro_batch_size,local_files_only=local_files_only,)
时间VAE
vae_temporal = dict(type="VAE_Temporal_SD",from_pretrained=None,)
@MODELS.register_module()
class VAE_Temporal(nn.Module):def __init__(self,in_out_channels=4,latent_embed_dim=4,embed_dim=4,filters=128,num_res_blocks=4,channel_multipliers=(1, 2, 2, 4),temporal_downsample=(True, True, False),num_groups=32, # for nn.GroupNormactivation_fn="swish",):super().__init__()self.time_downsample_factor = 2 ** sum(temporal_downsample)# self.time_padding = self.time_downsample_factor - 1self.patch_size = (self.time_downsample_factor, 1, 1)self.out_channels = in_out_channels# NOTE: following MAGVIT, conv in bias=False in encoder first convself.encoder = Encoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim * 2,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups, # for nn.GroupNormactivation_fn=activation_fn,)self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)self.decoder = Decoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups, # for nn.GroupNormactivation_fn=activation_fn,)def get_latent_size(self, input_size):latent_size = []for i in range(3):if input_size[i] is None:lsize = Noneelif i == 0:time_padding = (0if (input_size[i] % self.time_downsample_factor == 0)else self.time_downsample_factor - input_size[i] % self.time_downsample_factor)lsize = (input_size[i] + time_padding) // self.patch_size[i]else:lsize = input_size[i] // self.patch_size[i]latent_size.append(lsize)return latent_sizedef encode(self, x):time_padding = (0if (x.shape[2] % self.time_downsample_factor == 0)else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor)x = pad_at_dim(x, (time_padding, 0), dim=2)encoded_feature = self.encoder(x)moments = self.quant_conv(encoded_feature).to(x.dtype)posterior = DiagonalGaussianDistribution(moments)return posteriordef decode(self, z, num_frames=None):time_padding = (0if (num_frames % self.time_downsample_factor == 0)else self.time_downsample_factor - num_frames % self.time_downsample_factor)z = self.post_quant_conv(z)x = self.decoder(z)x = x[:, :, time_padding:]return xdef forward(self, x, sample_posterior=True):posterior = self.encode(x)if sample_posterior:z = posterior.sample()else:z = posterior.mode()recon_video = self.decode(z, num_frames=x.shape[2])return recon_video, posterior, z
因果3D卷积
class CausalConv3d(nn.Module):def __init__(self,chan_in,chan_out,kernel_size: Union[int, Tuple[int, int, int]],pad_mode="constant",strides=None, # allow custom stride**kwargs,):super().__init__()kernel_size = cast_tuple(kernel_size, 3)time_kernel_size, height_kernel_size, width_kernel_size = kernel_sizeassert is_odd(height_kernel_size) and is_odd(width_kernel_size)dilation = kwargs.pop("dilation", 1)stride = strides[0] if strides is not None else kwargs.pop("stride", 1)self.pad_mode = pad_modetime_pad = dilation * (time_kernel_size - 1) + (1 - stride)height_pad = height_kernel_size // 2width_pad = width_kernel_size // 2self.time_pad = time_padself.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)stride = strides if strides is not None else (stride, 1, 1)dilation = (dilation, 1, 1)self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)def forward(self, x):x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)x = self.conv(x)return x
相关文章:

Open-Sora代码详细解读(2):时空3D VAE
Diffusion Models视频生成 前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深…...

基于微信平台的旅游出行必备商城小程序+ssm(lw+演示+源码+运行)
摘 要 随着社会的发展,社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景,运用软件工程原理和开发方法,它主要是采用java语言技术和mysql数据库来完成对系统的设计。整个…...
AI绘画:科技赋能艺术的崭新时代
💯AI绘画:走进艺术创新的新时代 人工智能在改变世界的过程中,AI绘画工具逐渐成为创新的典范。 本文将为您揭示AI绘画背后的技术秘密、潜在的应用场景,并为您推荐几款出色的AI绘画工具,助您领略这一技术带来的艺术新体…...

性能诊断的方法(四):自下而上的资源诊断方法和发散的异常信息诊断方法
关于性能诊断的方法,我们可以按照“问题现象—直接原因—问题根源”这样一个思路去归纳。我们先从问题的现象去入手,包括时间的分析、资源的分析和异常信息的分析。接下来再去分析产生问题现象的直接原因是什么,这里我们归纳了自上而下的资源…...

GDPU Vue前端框架开发 计数器
计数器算不到你双向绑定的进度。 重要的更新公告 !!!GDPU的小伙伴,感谢大家的支持,希望到此一游的帅哥美女能有所帮助。本学期的前端框架及移动应用,采用专栏订阅量达到50才开始周更了哦( •̀ .̫ •́ )✧…...
最大流笔记
概念 求两点间的路径中可在同一时间内通过的最大量 EK算法 通过bfs找通路,找到后回溯; 每确定一条边时,同时建立一天反方向的边以用来进行反悔操作(毕竟一次性找到正确方案的概率太低了) code #include<bits/st…...

el-tree父子不互相关联时,手动实现全选、反选、子级全选、清空功能
el-tree父子不互相关联时,手动实现全选、反选、子级全选、清空功能 1、功能实现图示 2、实现思路 当属性check-strictly为true时,父子节点不互相关联,如果需要全部选中或选择某一节点下的全部节点就必须手动选择每个节点,十分麻…...
模板与泛型编程笔记(一)入门篇
1. 推荐书籍 《C新经典 模板与泛型编程》难得的很容易看得懂的好书,作者讲技术不跳跃,娓娓道来,只要花点时间就能看懂。 2. 笔记 2.1 模板基础 模板为什么要用尖括号?因为便于编译器解析,可以将模板和普通函数声明…...
浅谈WebApi
一、基本介绍 Web API(Web应用程序编程接口)是一种用于构建应用程序的接口,它允许软件应用程序通过HTTP请求与Web服务器进行交互。Web API通常用于构建客户端-服务器应用程序,其中客户端可以是Web浏览器、移动应用程序、桌面应用程…...
9月14日,每日信息差
第一、宝马集团宣布对设计部门进行重组,并将于 2024 年 10 月 1 日成立一个跨品牌设计团队,由范・霍伊顿克领导。该团队将引入极星汽车设计主管马克西米利安・米索尼,负责宝马中高档和豪华车型以及宝马 Alpina 的设计工作。 第二、小鹏汇天飞…...

无人机控制与三维AI感知处理平台正式上线!
低空经济被誉为推动我国经济高质量发展的全新增长引擎,是一种以民用有人驾驶和无人驾驶航空器的各类低空飞行活动为牵引,辐射带动相关领域融合发展的综合性经济形态,2024年全国两会首次被纳入政府工作报告。 大势智慧积极响应国家低空经济政…...
9.11-kubeadm方式安装k8s
一、安装环境 编号主机名称ip地址1k8s-master192.168.2.662k8s-node01192.168.2.773k8s-node02192.168.2.88 二、前期准备 1.设置免密登录 [rootk8s-master ~]# ssh-keygen [rootk8s-master ~]# ssh-copy-id root192.168.2.77 [rootk8s-master ~]# ssh-copy-id root192.168…...

限流,流量整形算法
写在前面 源码 。 本文看下流量整形相关算法。 目前流量整形算法主要有三种,计数器,漏桶,令牌桶。分别看下咯! 1:计数器 1.1:描述 单位时间内只允许指定数量的请求,如果是时间区间内超过指…...
【C++知识扫盲】------C++ 中的引用入门
在 C 中,引用(reference) 是一个非常重要的概念,它提供了一种别名机制,让我们可以给已经存在的变量起一个新的名字,并且能够通过这个别名直接操作原始变量。本文将详细介绍引用的定义、使用场景及其与指针的…...

【机器学习】6 ——最大熵模型
机器学习6——最大熵模型 目录 机器学习6——最大熵模型最大熵(maximum entropy)模型模型模型学习(估计参数)模型评价应用 最大熵(maximum entropy)模型 选择熵最大的概率模型 熵是衡量不确定性的…...

小程序——生命周期
文章目录 运行机制更新机制生命周期介绍应用级别生命周期页面级别生命周期组件生命周期生命周期两个细节补充说明总结 运行机制 用一张图简要概述一下小程序的运行机制 冷启动与热启动: 小程序启动可以分为两种情况,一种是冷启动,一种是热…...

基于微信小程序的宠物之家的设计与实现
作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 基于微信小程序JavaSpringBootVueMySQL的宠物之家/宠物综合…...

自定义EPICS在LabVIEW中的测试
继续上一篇:LabVIEW中EPICS客户端/服务端的测试 变量定义 You can use CaLabSoftIOC.vi to create new EPICS variables and start them. CA Lab - LabVIEW (Realtime) EPICS INPUT: PV set Cluster-array of names, data types and field definitions to crea…...
基于深度学习的农作物病害检测
基于深度学习的农作物病害检测利用卷积神经网络(CNN)、生成对抗网络(GAN)、Transformer等深度学习技术,自动识别和分类农作物的病害,帮助农业工作者提高作物管理效率、减少损失。 1. 农作物病害检测的挑战…...
【C#】命名规范
文章目录 C# 命名规范使用Pascal case使用Camel case方法、属性、类命名见名知义LINQ查询变量使用有意义的名称如何声明成员变量和字段正确格式化和缩进代码如何撰写备注 通用C#编码最佳实践如何将值与空字符串进行比较使用异常处理使用&&和||可获得更好的性能单一职责…...

Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
Java 8 Stream API 入门到实践详解
一、告别 for 循环! 传统痛点: Java 8 之前,集合操作离不开冗长的 for 循环和匿名类。例如,过滤列表中的偶数: List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

无法与IP建立连接,未能下载VSCode服务器
如题,在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈,发现是VSCode版本自动更新惹的祸!!! 在VSCode的帮助->关于这里发现前几天VSCode自动更新了,我的版本号变成了1.100.3 才导致了远程连接出…...

关于nvm与node.js
1 安装nvm 安装过程中手动修改 nvm的安装路径, 以及修改 通过nvm安装node后正在使用的node的存放目录【这句话可能难以理解,但接着往下看你就了然了】 2 修改nvm中settings.txt文件配置 nvm安装成功后,通常在该文件中会出现以下配置&…...

【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...
渲染学进阶内容——模型
最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...
macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用
文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...
【AI学习】三、AI算法中的向量
在人工智能(AI)算法中,向量(Vector)是一种将现实世界中的数据(如图像、文本、音频等)转化为计算机可处理的数值型特征表示的工具。它是连接人类认知(如语义、视觉特征)与…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍
文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结: 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析: 实际业务去理解体会统一注…...

BCS 2025|百度副总裁陈洋:智能体在安全领域的应用实践
6月5日,2025全球数字经济大会数字安全主论坛暨北京网络安全大会在国家会议中心隆重开幕。百度副总裁陈洋受邀出席,并作《智能体在安全领域的应用实践》主题演讲,分享了在智能体在安全领域的突破性实践。他指出,百度通过将安全能力…...