当前位置: 首页 > news >正文

Open-Sora代码详细解读(2):时空3D VAE

Diffusion Models视频生成

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

3D VAE原理

代码剖析

2D VAE

时间VAE

因果3D卷积


3D VAE原理

之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。

Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。

代码剖析

2D VAE

来自华为pixart sdxl vae:

    vae_2d = dict(type="VideoAutoencoderKL",from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",subfolder="vae",micro_batch_size=micro_batch_size,local_files_only=local_files_only,)

时间VAE

    vae_temporal = dict(type="VAE_Temporal_SD",from_pretrained=None,)
@MODELS.register_module()
class VAE_Temporal(nn.Module):def __init__(self,in_out_channels=4,latent_embed_dim=4,embed_dim=4,filters=128,num_res_blocks=4,channel_multipliers=(1, 2, 2, 4),temporal_downsample=(True, True, False),num_groups=32,  # for nn.GroupNormactivation_fn="swish",):super().__init__()self.time_downsample_factor = 2 ** sum(temporal_downsample)# self.time_padding = self.time_downsample_factor - 1self.patch_size = (self.time_downsample_factor, 1, 1)self.out_channels = in_out_channels# NOTE: following MAGVIT, conv in bias=False in encoder first convself.encoder = Encoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim * 2,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)self.decoder = Decoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)def get_latent_size(self, input_size):latent_size = []for i in range(3):if input_size[i] is None:lsize = Noneelif i == 0:time_padding = (0if (input_size[i] % self.time_downsample_factor == 0)else self.time_downsample_factor - input_size[i] % self.time_downsample_factor)lsize = (input_size[i] + time_padding) // self.patch_size[i]else:lsize = input_size[i] // self.patch_size[i]latent_size.append(lsize)return latent_sizedef encode(self, x):time_padding = (0if (x.shape[2] % self.time_downsample_factor == 0)else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor)x = pad_at_dim(x, (time_padding, 0), dim=2)encoded_feature = self.encoder(x)moments = self.quant_conv(encoded_feature).to(x.dtype)posterior = DiagonalGaussianDistribution(moments)return posteriordef decode(self, z, num_frames=None):time_padding = (0if (num_frames % self.time_downsample_factor == 0)else self.time_downsample_factor - num_frames % self.time_downsample_factor)z = self.post_quant_conv(z)x = self.decoder(z)x = x[:, :, time_padding:]return xdef forward(self, x, sample_posterior=True):posterior = self.encode(x)if sample_posterior:z = posterior.sample()else:z = posterior.mode()recon_video = self.decode(z, num_frames=x.shape[2])return recon_video, posterior, z

因果3D卷积

class CausalConv3d(nn.Module):def __init__(self,chan_in,chan_out,kernel_size: Union[int, Tuple[int, int, int]],pad_mode="constant",strides=None,  # allow custom stride**kwargs,):super().__init__()kernel_size = cast_tuple(kernel_size, 3)time_kernel_size, height_kernel_size, width_kernel_size = kernel_sizeassert is_odd(height_kernel_size) and is_odd(width_kernel_size)dilation = kwargs.pop("dilation", 1)stride = strides[0] if strides is not None else kwargs.pop("stride", 1)self.pad_mode = pad_modetime_pad = dilation * (time_kernel_size - 1) + (1 - stride)height_pad = height_kernel_size // 2width_pad = width_kernel_size // 2self.time_pad = time_padself.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)stride = strides if strides is not None else (stride, 1, 1)dilation = (dilation, 1, 1)self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)def forward(self, x):x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)x = self.conv(x)return x

相关文章:

Open-Sora代码详细解读(2):时空3D VAE

Diffusion Models视频生成 前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深…...

基于微信平台的旅游出行必备商城小程序+ssm(lw+演示+源码+运行)

摘 要 随着社会的发展,社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景,运用软件工程原理和开发方法,它主要是采用java语言技术和mysql数据库来完成对系统的设计。整个…...

AI绘画:科技赋能艺术的崭新时代

💯AI绘画:走进艺术创新的新时代 人工智能在改变世界的过程中,AI绘画工具逐渐成为创新的典范。 本文将为您揭示AI绘画背后的技术秘密、潜在的应用场景,并为您推荐几款出色的AI绘画工具,助您领略这一技术带来的艺术新体…...

性能诊断的方法(四):自下而上的资源诊断方法和发散的异常信息诊断方法

关于性能诊断的方法,我们可以按照“问题现象—直接原因—问题根源”这样一个思路去归纳。我们先从问题的现象去入手,包括时间的分析、资源的分析和异常信息的分析。接下来再去分析产生问题现象的直接原因是什么,这里我们归纳了自上而下的资源…...

GDPU Vue前端框架开发 计数器

计数器算不到你双向绑定的进度。 重要的更新公告 !!!GDPU的小伙伴,感谢大家的支持,希望到此一游的帅哥美女能有所帮助。本学期的前端框架及移动应用,采用专栏订阅量达到50才开始周更了哦( •̀ .̫ •́ )✧…...

最大流笔记

概念 求两点间的路径中可在同一时间内通过的最大量 EK算法 通过bfs找通路&#xff0c;找到后回溯&#xff1b; 每确定一条边时&#xff0c;同时建立一天反方向的边以用来进行反悔操作&#xff08;毕竟一次性找到正确方案的概率太低了&#xff09; code #include<bits/st…...

el-tree父子不互相关联时,手动实现全选、反选、子级全选、清空功能

el-tree父子不互相关联时&#xff0c;手动实现全选、反选、子级全选、清空功能 1、功能实现图示 2、实现思路 当属性check-strictly为true时&#xff0c;父子节点不互相关联&#xff0c;如果需要全部选中或选择某一节点下的全部节点就必须手动选择每个节点&#xff0c;十分麻…...

模板与泛型编程笔记(一)入门篇

1. 推荐书籍 《C新经典 模板与泛型编程》难得的很容易看得懂的好书&#xff0c;作者讲技术不跳跃&#xff0c;娓娓道来&#xff0c;只要花点时间就能看懂。 2. 笔记 2.1 模板基础 模板为什么要用尖括号&#xff1f;因为便于编译器解析&#xff0c;可以将模板和普通函数声明…...

浅谈WebApi

一、基本介绍 Web API&#xff08;Web应用程序编程接口&#xff09;是一种用于构建应用程序的接口&#xff0c;它允许软件应用程序通过HTTP请求与Web服务器进行交互。Web API通常用于构建客户端-服务器应用程序&#xff0c;其中客户端可以是Web浏览器、移动应用程序、桌面应用程…...

9月14日,每日信息差

第一、宝马集团宣布对设计部门进行重组&#xff0c;并将于 2024 年 10 月 1 日成立一个跨品牌设计团队&#xff0c;由范・霍伊顿克领导。该团队将引入极星汽车设计主管马克西米利安・米索尼&#xff0c;负责宝马中高档和豪华车型以及宝马 Alpina 的设计工作。 第二、小鹏汇天飞…...

无人机控制与三维AI感知处理平台正式上线!

低空经济被誉为推动我国经济高质量发展的全新增长引擎&#xff0c;是一种以民用有人驾驶和无人驾驶航空器的各类低空飞行活动为牵引&#xff0c;辐射带动相关领域融合发展的综合性经济形态&#xff0c;2024年全国两会首次被纳入政府工作报告。 大势智慧积极响应国家低空经济政…...

9.11-kubeadm方式安装k8s

一、安装环境 编号主机名称ip地址1k8s-master192.168.2.662k8s-node01192.168.2.773k8s-node02192.168.2.88 二、前期准备 1.设置免密登录 [rootk8s-master ~]# ssh-keygen [rootk8s-master ~]# ssh-copy-id root192.168.2.77 [rootk8s-master ~]# ssh-copy-id root192.168…...

限流,流量整形算法

写在前面 源码 。 本文看下流量整形相关算法。 目前流量整形算法主要有三种&#xff0c;计数器&#xff0c;漏桶&#xff0c;令牌桶。分别看下咯&#xff01; 1&#xff1a;计数器 1.1&#xff1a;描述 单位时间内只允许指定数量的请求&#xff0c;如果是时间区间内超过指…...

【C++知识扫盲】------C++ 中的引用入门

在 C 中&#xff0c;引用&#xff08;reference&#xff09; 是一个非常重要的概念&#xff0c;它提供了一种别名机制&#xff0c;让我们可以给已经存在的变量起一个新的名字&#xff0c;并且能够通过这个别名直接操作原始变量。本文将详细介绍引用的定义、使用场景及其与指针的…...

【机器学习】6 ——最大熵模型

机器学习6——最大熵模型 目录 机器学习6——最大熵模型最大熵&#xff08;maximum entropy&#xff09;模型模型模型学习&#xff08;估计参数&#xff09;模型评价应用 最大熵&#xff08;maximum entropy&#xff09;模型 选择熵最大的概率模型 熵是衡量不确定性的&#xf…...

小程序——生命周期

文章目录 运行机制更新机制生命周期介绍应用级别生命周期页面级别生命周期组件生命周期生命周期两个细节补充说明总结 运行机制 用一张图简要概述一下小程序的运行机制 冷启动与热启动&#xff1a; 小程序启动可以分为两种情况&#xff0c;一种是冷启动&#xff0c;一种是热…...

基于微信小程序的宠物之家的设计与实现

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 基于微信小程序JavaSpringBootVueMySQL的宠物之家/宠物综合…...

自定义EPICS在LabVIEW中的测试

继续上一篇&#xff1a;LabVIEW中EPICS客户端/服务端的测试 变量定义 You can use CaLabSoftIOC.vi to create new EPICS variables and start them. CA Lab - LabVIEW (Realtime) EPICS INPUT: PV set Cluster-array of names, data types and field definitions to crea…...

基于深度学习的农作物病害检测

基于深度学习的农作物病害检测利用卷积神经网络&#xff08;CNN&#xff09;、生成对抗网络&#xff08;GAN&#xff09;、Transformer等深度学习技术&#xff0c;自动识别和分类农作物的病害&#xff0c;帮助农业工作者提高作物管理效率、减少损失。 1. 农作物病害检测的挑战…...

【C#】命名规范

文章目录 C# 命名规范使用Pascal case使用Camel case方法、属性、类命名见名知义LINQ查询变量使用有意义的名称如何声明成员变量和字段正确格式化和缩进代码如何撰写备注 通用C#编码最佳实践如何将值与空字符串进行比较使用异常处理使用&&和||可获得更好的性能单一职责…...

超级帐本(Hyperledger)

1. Hyperledger 项目 Hyperledger 下有两类项目:第一类是区块链框架项目;第二类是支持这些区块链的相关工具或模块。 在 Hyperledger 框架下&#xff0c;目前有 5 个区块链框架项目&#xff1a;Fabric、Sawtooth Lake、Iroha、Burrow 和 Indy。 在模块类下&#xff0c;则有 Hyp…...

如何精细优化网站关键词排名:实战经验分享

在数字营销日益激烈的今天&#xff0c;我深知每一个关键词的排名都关乎着网站的流量与转化。凭借多年的实战经验&#xff0c;我深刻体会到&#xff0c;要想在浩如烟海的网络世界中脱颖而出&#xff0c;精细化的关键词优化策略至关重要。今天&#xff0c;我将从实战角度出发&…...

Ruoyi Cloud 本地启动

本文视频版本&#xff1a;https://www.bilibili.com/video/BV1SNtueBE9M 参考 http://doc.ruoyi.vip/ https://gitee.com/y_project/RuoYi-Cloud https://blog.csdn.net/cs_dnzk/article/details/135289966 https://doc.ruoyi.vip/ruoyi-cloud/cloud/seata.html#%E5%9F%BA%E6…...

Nginx解析:入门笔记

&#x1f308; 个人主页&#xff1a;danci_ &#x1f525; 系列专栏&#xff1a;《设计模式》《MYSQL》 &#x1f4aa;&#x1f3fb; 制定明确可量化的目标&#xff0c;坚持默默的做事。 ✨欢迎加入探索nginx之旅✨ &#x1f44b; 大家好&#xff01;文本学习和探索Nginx配置。…...

在 Mac 上安装双系统会影响性能吗,安装双系统会清除数据吗?

在 Mac 系统安装并使用双系统已经成为了许多用户办公的选择之一&#xff0c;双系统可以让用户在 Mac 上同时运行 Windows 或其他操作系统。然而&#xff0c;许多用户担心这样做会对 Mac 的性能产生影响。 接下来将给大家介绍 Mac 装双系统会影响性能吗&#xff0c;Mac装双系统…...

vue3提交按钮限制重复点击

下载lodash npm install lodash 引入并使用 <template><div click"submit()">提交</div> </template><script setup>import { debounce } from lodash;const submit debounce(() > {//业务代码},2000,{leading: true,trailing:…...

Java | Leetcode Java题解之第395题至少有K个重复字符的最长子串

题目&#xff1a; 题解&#xff1a; class Solution {public int longestSubstring(String s, int k) {int ret 0;int n s.length();for (int t 1; t < 26; t) {int l 0, r 0;int[] cnt new int[26];int tot 0;int less 0;while (r < n) {cnt[s.charAt(r) - a];…...

20240915 每日AI必读资讯

国家网信办发布《人工智能生成合成内容标识办法&#xff08;征求意见稿&#xff09;》 - 要求所有的AI生成内容都要打标&#xff0c;包括文字、图像、视频、音频… - 文本内容要插入标识符提醒&#xff0c;音频内容要在里面插入提示音 - 对创作者不太友好&#xff0c;对平台…...

量化交易需要注意的关于股票交易挂单排队规则的问题

炒股自动化&#xff1a;申请官方API接口&#xff0c;散户也可以 python炒股自动化&#xff08;0&#xff09;&#xff0c;申请券商API接口 python炒股自动化&#xff08;1&#xff09;&#xff0c;量化交易接口区别 Python炒股自动化&#xff08;2&#xff09;&#xff1a;获取…...

应急响应实战---是谁修改了我的密码?

前言&#xff1a;此次应急响应为真实案例&#xff0c;客户反馈无法通过密码登录服务器&#xff0c;疑似服务器被入侵 0x01 如何找回密码&#xff1f; 客户服务器为windows server2019&#xff0c;运维平台为PVE平台&#xff1b;实际上无论是windows系统或者是linux系统&#…...