Open-Sora代码详细解读(2):时空3D VAE
Diffusion Models视频生成
前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。
目录
3D VAE原理
代码剖析
2D VAE
时间VAE
因果3D卷积
3D VAE原理
之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。
Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。
代码剖析
2D VAE
来自华为pixart sdxl vae:
vae_2d = dict(type="VideoAutoencoderKL",from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",subfolder="vae",micro_batch_size=micro_batch_size,local_files_only=local_files_only,)
时间VAE
vae_temporal = dict(type="VAE_Temporal_SD",from_pretrained=None,)
@MODELS.register_module()
class VAE_Temporal(nn.Module):def __init__(self,in_out_channels=4,latent_embed_dim=4,embed_dim=4,filters=128,num_res_blocks=4,channel_multipliers=(1, 2, 2, 4),temporal_downsample=(True, True, False),num_groups=32, # for nn.GroupNormactivation_fn="swish",):super().__init__()self.time_downsample_factor = 2 ** sum(temporal_downsample)# self.time_padding = self.time_downsample_factor - 1self.patch_size = (self.time_downsample_factor, 1, 1)self.out_channels = in_out_channels# NOTE: following MAGVIT, conv in bias=False in encoder first convself.encoder = Encoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim * 2,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups, # for nn.GroupNormactivation_fn=activation_fn,)self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)self.decoder = Decoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups, # for nn.GroupNormactivation_fn=activation_fn,)def get_latent_size(self, input_size):latent_size = []for i in range(3):if input_size[i] is None:lsize = Noneelif i == 0:time_padding = (0if (input_size[i] % self.time_downsample_factor == 0)else self.time_downsample_factor - input_size[i] % self.time_downsample_factor)lsize = (input_size[i] + time_padding) // self.patch_size[i]else:lsize = input_size[i] // self.patch_size[i]latent_size.append(lsize)return latent_sizedef encode(self, x):time_padding = (0if (x.shape[2] % self.time_downsample_factor == 0)else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor)x = pad_at_dim(x, (time_padding, 0), dim=2)encoded_feature = self.encoder(x)moments = self.quant_conv(encoded_feature).to(x.dtype)posterior = DiagonalGaussianDistribution(moments)return posteriordef decode(self, z, num_frames=None):time_padding = (0if (num_frames % self.time_downsample_factor == 0)else self.time_downsample_factor - num_frames % self.time_downsample_factor)z = self.post_quant_conv(z)x = self.decoder(z)x = x[:, :, time_padding:]return xdef forward(self, x, sample_posterior=True):posterior = self.encode(x)if sample_posterior:z = posterior.sample()else:z = posterior.mode()recon_video = self.decode(z, num_frames=x.shape[2])return recon_video, posterior, z
因果3D卷积
class CausalConv3d(nn.Module):def __init__(self,chan_in,chan_out,kernel_size: Union[int, Tuple[int, int, int]],pad_mode="constant",strides=None, # allow custom stride**kwargs,):super().__init__()kernel_size = cast_tuple(kernel_size, 3)time_kernel_size, height_kernel_size, width_kernel_size = kernel_sizeassert is_odd(height_kernel_size) and is_odd(width_kernel_size)dilation = kwargs.pop("dilation", 1)stride = strides[0] if strides is not None else kwargs.pop("stride", 1)self.pad_mode = pad_modetime_pad = dilation * (time_kernel_size - 1) + (1 - stride)height_pad = height_kernel_size // 2width_pad = width_kernel_size // 2self.time_pad = time_padself.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)stride = strides if strides is not None else (stride, 1, 1)dilation = (dilation, 1, 1)self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)def forward(self, x):x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)x = self.conv(x)return x

相关文章:
Open-Sora代码详细解读(2):时空3D VAE
Diffusion Models视频生成 前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深…...
基于微信平台的旅游出行必备商城小程序+ssm(lw+演示+源码+运行)
摘 要 随着社会的发展,社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景,运用软件工程原理和开发方法,它主要是采用java语言技术和mysql数据库来完成对系统的设计。整个…...
AI绘画:科技赋能艺术的崭新时代
💯AI绘画:走进艺术创新的新时代 人工智能在改变世界的过程中,AI绘画工具逐渐成为创新的典范。 本文将为您揭示AI绘画背后的技术秘密、潜在的应用场景,并为您推荐几款出色的AI绘画工具,助您领略这一技术带来的艺术新体…...
性能诊断的方法(四):自下而上的资源诊断方法和发散的异常信息诊断方法
关于性能诊断的方法,我们可以按照“问题现象—直接原因—问题根源”这样一个思路去归纳。我们先从问题的现象去入手,包括时间的分析、资源的分析和异常信息的分析。接下来再去分析产生问题现象的直接原因是什么,这里我们归纳了自上而下的资源…...
GDPU Vue前端框架开发 计数器
计数器算不到你双向绑定的进度。 重要的更新公告 !!!GDPU的小伙伴,感谢大家的支持,希望到此一游的帅哥美女能有所帮助。本学期的前端框架及移动应用,采用专栏订阅量达到50才开始周更了哦( •̀ .̫ •́ )✧…...
最大流笔记
概念 求两点间的路径中可在同一时间内通过的最大量 EK算法 通过bfs找通路,找到后回溯; 每确定一条边时,同时建立一天反方向的边以用来进行反悔操作(毕竟一次性找到正确方案的概率太低了) code #include<bits/st…...
el-tree父子不互相关联时,手动实现全选、反选、子级全选、清空功能
el-tree父子不互相关联时,手动实现全选、反选、子级全选、清空功能 1、功能实现图示 2、实现思路 当属性check-strictly为true时,父子节点不互相关联,如果需要全部选中或选择某一节点下的全部节点就必须手动选择每个节点,十分麻…...
模板与泛型编程笔记(一)入门篇
1. 推荐书籍 《C新经典 模板与泛型编程》难得的很容易看得懂的好书,作者讲技术不跳跃,娓娓道来,只要花点时间就能看懂。 2. 笔记 2.1 模板基础 模板为什么要用尖括号?因为便于编译器解析,可以将模板和普通函数声明…...
浅谈WebApi
一、基本介绍 Web API(Web应用程序编程接口)是一种用于构建应用程序的接口,它允许软件应用程序通过HTTP请求与Web服务器进行交互。Web API通常用于构建客户端-服务器应用程序,其中客户端可以是Web浏览器、移动应用程序、桌面应用程…...
9月14日,每日信息差
第一、宝马集团宣布对设计部门进行重组,并将于 2024 年 10 月 1 日成立一个跨品牌设计团队,由范・霍伊顿克领导。该团队将引入极星汽车设计主管马克西米利安・米索尼,负责宝马中高档和豪华车型以及宝马 Alpina 的设计工作。 第二、小鹏汇天飞…...
无人机控制与三维AI感知处理平台正式上线!
低空经济被誉为推动我国经济高质量发展的全新增长引擎,是一种以民用有人驾驶和无人驾驶航空器的各类低空飞行活动为牵引,辐射带动相关领域融合发展的综合性经济形态,2024年全国两会首次被纳入政府工作报告。 大势智慧积极响应国家低空经济政…...
9.11-kubeadm方式安装k8s
一、安装环境 编号主机名称ip地址1k8s-master192.168.2.662k8s-node01192.168.2.773k8s-node02192.168.2.88 二、前期准备 1.设置免密登录 [rootk8s-master ~]# ssh-keygen [rootk8s-master ~]# ssh-copy-id root192.168.2.77 [rootk8s-master ~]# ssh-copy-id root192.168…...
限流,流量整形算法
写在前面 源码 。 本文看下流量整形相关算法。 目前流量整形算法主要有三种,计数器,漏桶,令牌桶。分别看下咯! 1:计数器 1.1:描述 单位时间内只允许指定数量的请求,如果是时间区间内超过指…...
【C++知识扫盲】------C++ 中的引用入门
在 C 中,引用(reference) 是一个非常重要的概念,它提供了一种别名机制,让我们可以给已经存在的变量起一个新的名字,并且能够通过这个别名直接操作原始变量。本文将详细介绍引用的定义、使用场景及其与指针的…...
【机器学习】6 ——最大熵模型
机器学习6——最大熵模型 目录 机器学习6——最大熵模型最大熵(maximum entropy)模型模型模型学习(估计参数)模型评价应用 最大熵(maximum entropy)模型 选择熵最大的概率模型 熵是衡量不确定性的…...
小程序——生命周期
文章目录 运行机制更新机制生命周期介绍应用级别生命周期页面级别生命周期组件生命周期生命周期两个细节补充说明总结 运行机制 用一张图简要概述一下小程序的运行机制 冷启动与热启动: 小程序启动可以分为两种情况,一种是冷启动,一种是热…...
基于微信小程序的宠物之家的设计与实现
作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 基于微信小程序JavaSpringBootVueMySQL的宠物之家/宠物综合…...
自定义EPICS在LabVIEW中的测试
继续上一篇:LabVIEW中EPICS客户端/服务端的测试 变量定义 You can use CaLabSoftIOC.vi to create new EPICS variables and start them. CA Lab - LabVIEW (Realtime) EPICS INPUT: PV set Cluster-array of names, data types and field definitions to crea…...
基于深度学习的农作物病害检测
基于深度学习的农作物病害检测利用卷积神经网络(CNN)、生成对抗网络(GAN)、Transformer等深度学习技术,自动识别和分类农作物的病害,帮助农业工作者提高作物管理效率、减少损失。 1. 农作物病害检测的挑战…...
【C#】命名规范
文章目录 C# 命名规范使用Pascal case使用Camel case方法、属性、类命名见名知义LINQ查询变量使用有意义的名称如何声明成员变量和字段正确格式化和缩进代码如何撰写备注 通用C#编码最佳实践如何将值与空字符串进行比较使用异常处理使用&&和||可获得更好的性能单一职责…...
2025年能源电力系统与流体力学国际会议 (EPSFD 2025)
2025年能源电力系统与流体力学国际会议(EPSFD 2025)将于本年度在美丽的杭州盛大召开。作为全球能源、电力系统以及流体力学领域的顶级盛会,EPSFD 2025旨在为来自世界各地的科学家、工程师和研究人员提供一个展示最新研究成果、分享实践经验及…...
反射获取方法和属性
Java反射获取方法 在Java中,反射(Reflection)是一种强大的机制,允许程序在运行时访问和操作类的内部属性和方法。通过反射,可以动态地创建对象、调用方法、改变属性值,这在很多Java框架中如Spring和Hiberna…...
12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...
R语言速释制剂QBD解决方案之三
本文是《Quality by Design for ANDAs: An Example for Immediate-Release Dosage Forms》第一个处方的R语言解决方案。 第一个处方研究评估原料药粒径分布、MCC/Lactose比例、崩解剂用量对制剂CQAs的影响。 第二处方研究用于理解颗粒外加硬脂酸镁和滑石粉对片剂质量和可生产…...
【JVM】Java虚拟机(二)——垃圾回收
目录 一、如何判断对象可以回收 (一)引用计数法 (二)可达性分析算法 二、垃圾回收算法 (一)标记清除 (二)标记整理 (三)复制 (四ÿ…...
uniapp 开发ios, xcode 提交app store connect 和 testflight内测
uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...
ubuntu22.04有线网络无法连接,图标也没了
今天突然无法有线网络无法连接任何设备,并且图标都没了 错误案例 往上一顿搜索,试了很多博客都不行,比如 Ubuntu22.04右上角网络图标消失 最后解决的办法 下载网卡驱动,重新安装 操作步骤 查看自己网卡的型号 lspci | gre…...
VisualXML全新升级 | 新增数据库编辑功能
VisualXML是一个功能强大的网络总线设计工具,专注于简化汽车电子系统中复杂的网络数据设计操作。它支持多种主流总线网络格式的数据编辑(如DBC、LDF、ARXML、HEX等),并能够基于Excel表格的方式生成和转换多种数据库文件。由此&…...
客户案例 | 短视频点播企业海外视频加速与成本优化:MediaPackage+Cloudfront 技术重构实践
01技术背景与业务挑战 某短视频点播企业深耕国内用户市场,但其后台应用系统部署于东南亚印尼 IDC 机房。 随着业务规模扩大,传统架构已较难满足当前企业发展的需求,企业面临着三重挑战: ① 业务:国内用户访问海外服…...
Java后端检查空条件查询
通过抛出运行异常:throw new RuntimeException("请输入查询条件!");BranchWarehouseServiceImpl.java // 查询试剂交易(入库/出库)记录Overridepublic List<BranchWarehouseTransactions> queryForReagent(Branch…...
