当前位置: 首页 > news >正文

Open-Sora代码详细解读(2):时空3D VAE

Diffusion Models视频生成

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

3D VAE原理

代码剖析

2D VAE

时间VAE

因果3D卷积


3D VAE原理

之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。

Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。

代码剖析

2D VAE

来自华为pixart sdxl vae:

    vae_2d = dict(type="VideoAutoencoderKL",from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",subfolder="vae",micro_batch_size=micro_batch_size,local_files_only=local_files_only,)

时间VAE

    vae_temporal = dict(type="VAE_Temporal_SD",from_pretrained=None,)
@MODELS.register_module()
class VAE_Temporal(nn.Module):def __init__(self,in_out_channels=4,latent_embed_dim=4,embed_dim=4,filters=128,num_res_blocks=4,channel_multipliers=(1, 2, 2, 4),temporal_downsample=(True, True, False),num_groups=32,  # for nn.GroupNormactivation_fn="swish",):super().__init__()self.time_downsample_factor = 2 ** sum(temporal_downsample)# self.time_padding = self.time_downsample_factor - 1self.patch_size = (self.time_downsample_factor, 1, 1)self.out_channels = in_out_channels# NOTE: following MAGVIT, conv in bias=False in encoder first convself.encoder = Encoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim * 2,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)self.decoder = Decoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)def get_latent_size(self, input_size):latent_size = []for i in range(3):if input_size[i] is None:lsize = Noneelif i == 0:time_padding = (0if (input_size[i] % self.time_downsample_factor == 0)else self.time_downsample_factor - input_size[i] % self.time_downsample_factor)lsize = (input_size[i] + time_padding) // self.patch_size[i]else:lsize = input_size[i] // self.patch_size[i]latent_size.append(lsize)return latent_sizedef encode(self, x):time_padding = (0if (x.shape[2] % self.time_downsample_factor == 0)else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor)x = pad_at_dim(x, (time_padding, 0), dim=2)encoded_feature = self.encoder(x)moments = self.quant_conv(encoded_feature).to(x.dtype)posterior = DiagonalGaussianDistribution(moments)return posteriordef decode(self, z, num_frames=None):time_padding = (0if (num_frames % self.time_downsample_factor == 0)else self.time_downsample_factor - num_frames % self.time_downsample_factor)z = self.post_quant_conv(z)x = self.decoder(z)x = x[:, :, time_padding:]return xdef forward(self, x, sample_posterior=True):posterior = self.encode(x)if sample_posterior:z = posterior.sample()else:z = posterior.mode()recon_video = self.decode(z, num_frames=x.shape[2])return recon_video, posterior, z

因果3D卷积

class CausalConv3d(nn.Module):def __init__(self,chan_in,chan_out,kernel_size: Union[int, Tuple[int, int, int]],pad_mode="constant",strides=None,  # allow custom stride**kwargs,):super().__init__()kernel_size = cast_tuple(kernel_size, 3)time_kernel_size, height_kernel_size, width_kernel_size = kernel_sizeassert is_odd(height_kernel_size) and is_odd(width_kernel_size)dilation = kwargs.pop("dilation", 1)stride = strides[0] if strides is not None else kwargs.pop("stride", 1)self.pad_mode = pad_modetime_pad = dilation * (time_kernel_size - 1) + (1 - stride)height_pad = height_kernel_size // 2width_pad = width_kernel_size // 2self.time_pad = time_padself.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)stride = strides if strides is not None else (stride, 1, 1)dilation = (dilation, 1, 1)self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)def forward(self, x):x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)x = self.conv(x)return x

相关文章:

Open-Sora代码详细解读(2):时空3D VAE

Diffusion Models视频生成 前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深…...

基于微信平台的旅游出行必备商城小程序+ssm(lw+演示+源码+运行)

摘 要 随着社会的发展,社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景,运用软件工程原理和开发方法,它主要是采用java语言技术和mysql数据库来完成对系统的设计。整个…...

AI绘画:科技赋能艺术的崭新时代

💯AI绘画:走进艺术创新的新时代 人工智能在改变世界的过程中,AI绘画工具逐渐成为创新的典范。 本文将为您揭示AI绘画背后的技术秘密、潜在的应用场景,并为您推荐几款出色的AI绘画工具,助您领略这一技术带来的艺术新体…...

性能诊断的方法(四):自下而上的资源诊断方法和发散的异常信息诊断方法

关于性能诊断的方法,我们可以按照“问题现象—直接原因—问题根源”这样一个思路去归纳。我们先从问题的现象去入手,包括时间的分析、资源的分析和异常信息的分析。接下来再去分析产生问题现象的直接原因是什么,这里我们归纳了自上而下的资源…...

GDPU Vue前端框架开发 计数器

计数器算不到你双向绑定的进度。 重要的更新公告 !!!GDPU的小伙伴,感谢大家的支持,希望到此一游的帅哥美女能有所帮助。本学期的前端框架及移动应用,采用专栏订阅量达到50才开始周更了哦( •̀ .̫ •́ )✧…...

最大流笔记

概念 求两点间的路径中可在同一时间内通过的最大量 EK算法 通过bfs找通路&#xff0c;找到后回溯&#xff1b; 每确定一条边时&#xff0c;同时建立一天反方向的边以用来进行反悔操作&#xff08;毕竟一次性找到正确方案的概率太低了&#xff09; code #include<bits/st…...

el-tree父子不互相关联时,手动实现全选、反选、子级全选、清空功能

el-tree父子不互相关联时&#xff0c;手动实现全选、反选、子级全选、清空功能 1、功能实现图示 2、实现思路 当属性check-strictly为true时&#xff0c;父子节点不互相关联&#xff0c;如果需要全部选中或选择某一节点下的全部节点就必须手动选择每个节点&#xff0c;十分麻…...

模板与泛型编程笔记(一)入门篇

1. 推荐书籍 《C新经典 模板与泛型编程》难得的很容易看得懂的好书&#xff0c;作者讲技术不跳跃&#xff0c;娓娓道来&#xff0c;只要花点时间就能看懂。 2. 笔记 2.1 模板基础 模板为什么要用尖括号&#xff1f;因为便于编译器解析&#xff0c;可以将模板和普通函数声明…...

浅谈WebApi

一、基本介绍 Web API&#xff08;Web应用程序编程接口&#xff09;是一种用于构建应用程序的接口&#xff0c;它允许软件应用程序通过HTTP请求与Web服务器进行交互。Web API通常用于构建客户端-服务器应用程序&#xff0c;其中客户端可以是Web浏览器、移动应用程序、桌面应用程…...

9月14日,每日信息差

第一、宝马集团宣布对设计部门进行重组&#xff0c;并将于 2024 年 10 月 1 日成立一个跨品牌设计团队&#xff0c;由范・霍伊顿克领导。该团队将引入极星汽车设计主管马克西米利安・米索尼&#xff0c;负责宝马中高档和豪华车型以及宝马 Alpina 的设计工作。 第二、小鹏汇天飞…...

无人机控制与三维AI感知处理平台正式上线!

低空经济被誉为推动我国经济高质量发展的全新增长引擎&#xff0c;是一种以民用有人驾驶和无人驾驶航空器的各类低空飞行活动为牵引&#xff0c;辐射带动相关领域融合发展的综合性经济形态&#xff0c;2024年全国两会首次被纳入政府工作报告。 大势智慧积极响应国家低空经济政…...

9.11-kubeadm方式安装k8s

一、安装环境 编号主机名称ip地址1k8s-master192.168.2.662k8s-node01192.168.2.773k8s-node02192.168.2.88 二、前期准备 1.设置免密登录 [rootk8s-master ~]# ssh-keygen [rootk8s-master ~]# ssh-copy-id root192.168.2.77 [rootk8s-master ~]# ssh-copy-id root192.168…...

限流,流量整形算法

写在前面 源码 。 本文看下流量整形相关算法。 目前流量整形算法主要有三种&#xff0c;计数器&#xff0c;漏桶&#xff0c;令牌桶。分别看下咯&#xff01; 1&#xff1a;计数器 1.1&#xff1a;描述 单位时间内只允许指定数量的请求&#xff0c;如果是时间区间内超过指…...

【C++知识扫盲】------C++ 中的引用入门

在 C 中&#xff0c;引用&#xff08;reference&#xff09; 是一个非常重要的概念&#xff0c;它提供了一种别名机制&#xff0c;让我们可以给已经存在的变量起一个新的名字&#xff0c;并且能够通过这个别名直接操作原始变量。本文将详细介绍引用的定义、使用场景及其与指针的…...

【机器学习】6 ——最大熵模型

机器学习6——最大熵模型 目录 机器学习6——最大熵模型最大熵&#xff08;maximum entropy&#xff09;模型模型模型学习&#xff08;估计参数&#xff09;模型评价应用 最大熵&#xff08;maximum entropy&#xff09;模型 选择熵最大的概率模型 熵是衡量不确定性的&#xf…...

小程序——生命周期

文章目录 运行机制更新机制生命周期介绍应用级别生命周期页面级别生命周期组件生命周期生命周期两个细节补充说明总结 运行机制 用一张图简要概述一下小程序的运行机制 冷启动与热启动&#xff1a; 小程序启动可以分为两种情况&#xff0c;一种是冷启动&#xff0c;一种是热…...

基于微信小程序的宠物之家的设计与实现

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 基于微信小程序JavaSpringBootVueMySQL的宠物之家/宠物综合…...

自定义EPICS在LabVIEW中的测试

继续上一篇&#xff1a;LabVIEW中EPICS客户端/服务端的测试 变量定义 You can use CaLabSoftIOC.vi to create new EPICS variables and start them. CA Lab - LabVIEW (Realtime) EPICS INPUT: PV set Cluster-array of names, data types and field definitions to crea…...

基于深度学习的农作物病害检测

基于深度学习的农作物病害检测利用卷积神经网络&#xff08;CNN&#xff09;、生成对抗网络&#xff08;GAN&#xff09;、Transformer等深度学习技术&#xff0c;自动识别和分类农作物的病害&#xff0c;帮助农业工作者提高作物管理效率、减少损失。 1. 农作物病害检测的挑战…...

【C#】命名规范

文章目录 C# 命名规范使用Pascal case使用Camel case方法、属性、类命名见名知义LINQ查询变量使用有意义的名称如何声明成员变量和字段正确格式化和缩进代码如何撰写备注 通用C#编码最佳实践如何将值与空字符串进行比较使用异常处理使用&&和||可获得更好的性能单一职责…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

微信小程序 - 手机震动

一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注&#xff1a;文档 https://developers.weixin.qq…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

在WSL2的Ubuntu镜像中安装Docker

Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包&#xff1a; for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...

Unity | AmplifyShaderEditor插件基础(第七集:平面波动shader)

目录 一、&#x1f44b;&#x1f3fb;前言 二、&#x1f608;sinx波动的基本原理 三、&#x1f608;波动起来 1.sinx节点介绍 2.vertexPosition 3.集成Vector3 a.节点Append b.连起来 4.波动起来 a.波动的原理 b.时间节点 c.sinx的处理 四、&#x1f30a;波动优化…...

Device Mapper 机制

Device Mapper 机制详解 Device Mapper&#xff08;简称 DM&#xff09;是 Linux 内核中的一套通用块设备映射框架&#xff0c;为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程&#xff0c;并配以详细的…...

算法岗面试经验分享-大模型篇

文章目录 A 基础语言模型A.1 TransformerA.2 Bert B 大语言模型结构B.1 GPTB.2 LLamaB.3 ChatGLMB.4 Qwen C 大语言模型微调C.1 Fine-tuningC.2 Adapter-tuningC.3 Prefix-tuningC.4 P-tuningC.5 LoRA A 基础语言模型 A.1 Transformer &#xff08;1&#xff09;资源 论文&a…...

C++:多态机制详解

目录 一. 多态的概念 1.静态多态&#xff08;编译时多态&#xff09; 二.动态多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写/覆盖 4.虚函数重写的一些其他问题 1&#xff09;.协变 2&#xff09;.析构函数的重写 5.override 和 final关键字 1&#…...

LLMs 系列实操科普(1)

写在前面&#xff1a; 本期内容我们继续 Andrej Karpathy 的《How I use LLMs》讲座内容&#xff0c;原视频时长 ~130 分钟&#xff0c;以实操演示主流的一些 LLMs 的使用&#xff0c;由于涉及到实操&#xff0c;实际上并不适合以文字整理&#xff0c;但还是决定尽量整理一份笔…...

iview框架主题色的应用

1.下载 less要使用3.0.0以下的版本 npm install less2.7.3 npm install less-loader4.0.52./src/config/theme.js文件 module.exports {yellow: {theme-color: #FDCE04},blue: {theme-color: #547CE7} }在sass中使用theme配置的颜色主题&#xff0c;无需引入&#xff0c;直接可…...