当前位置: 首页 > news >正文

Open-Sora代码详细解读(2):时空3D VAE

Diffusion Models视频生成

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

3D VAE原理

代码剖析

2D VAE

时间VAE

因果3D卷积


3D VAE原理

之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。

Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。

代码剖析

2D VAE

来自华为pixart sdxl vae:

    vae_2d = dict(type="VideoAutoencoderKL",from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",subfolder="vae",micro_batch_size=micro_batch_size,local_files_only=local_files_only,)

时间VAE

    vae_temporal = dict(type="VAE_Temporal_SD",from_pretrained=None,)
@MODELS.register_module()
class VAE_Temporal(nn.Module):def __init__(self,in_out_channels=4,latent_embed_dim=4,embed_dim=4,filters=128,num_res_blocks=4,channel_multipliers=(1, 2, 2, 4),temporal_downsample=(True, True, False),num_groups=32,  # for nn.GroupNormactivation_fn="swish",):super().__init__()self.time_downsample_factor = 2 ** sum(temporal_downsample)# self.time_padding = self.time_downsample_factor - 1self.patch_size = (self.time_downsample_factor, 1, 1)self.out_channels = in_out_channels# NOTE: following MAGVIT, conv in bias=False in encoder first convself.encoder = Encoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim * 2,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)self.decoder = Decoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)def get_latent_size(self, input_size):latent_size = []for i in range(3):if input_size[i] is None:lsize = Noneelif i == 0:time_padding = (0if (input_size[i] % self.time_downsample_factor == 0)else self.time_downsample_factor - input_size[i] % self.time_downsample_factor)lsize = (input_size[i] + time_padding) // self.patch_size[i]else:lsize = input_size[i] // self.patch_size[i]latent_size.append(lsize)return latent_sizedef encode(self, x):time_padding = (0if (x.shape[2] % self.time_downsample_factor == 0)else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor)x = pad_at_dim(x, (time_padding, 0), dim=2)encoded_feature = self.encoder(x)moments = self.quant_conv(encoded_feature).to(x.dtype)posterior = DiagonalGaussianDistribution(moments)return posteriordef decode(self, z, num_frames=None):time_padding = (0if (num_frames % self.time_downsample_factor == 0)else self.time_downsample_factor - num_frames % self.time_downsample_factor)z = self.post_quant_conv(z)x = self.decoder(z)x = x[:, :, time_padding:]return xdef forward(self, x, sample_posterior=True):posterior = self.encode(x)if sample_posterior:z = posterior.sample()else:z = posterior.mode()recon_video = self.decode(z, num_frames=x.shape[2])return recon_video, posterior, z

因果3D卷积

class CausalConv3d(nn.Module):def __init__(self,chan_in,chan_out,kernel_size: Union[int, Tuple[int, int, int]],pad_mode="constant",strides=None,  # allow custom stride**kwargs,):super().__init__()kernel_size = cast_tuple(kernel_size, 3)time_kernel_size, height_kernel_size, width_kernel_size = kernel_sizeassert is_odd(height_kernel_size) and is_odd(width_kernel_size)dilation = kwargs.pop("dilation", 1)stride = strides[0] if strides is not None else kwargs.pop("stride", 1)self.pad_mode = pad_modetime_pad = dilation * (time_kernel_size - 1) + (1 - stride)height_pad = height_kernel_size // 2width_pad = width_kernel_size // 2self.time_pad = time_padself.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)stride = strides if strides is not None else (stride, 1, 1)dilation = (dilation, 1, 1)self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)def forward(self, x):x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)x = self.conv(x)return x

相关文章:

Open-Sora代码详细解读(2):时空3D VAE

Diffusion Models视频生成 前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深…...

基于微信平台的旅游出行必备商城小程序+ssm(lw+演示+源码+运行)

摘 要 随着社会的发展,社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景,运用软件工程原理和开发方法,它主要是采用java语言技术和mysql数据库来完成对系统的设计。整个…...

AI绘画:科技赋能艺术的崭新时代

💯AI绘画:走进艺术创新的新时代 人工智能在改变世界的过程中,AI绘画工具逐渐成为创新的典范。 本文将为您揭示AI绘画背后的技术秘密、潜在的应用场景,并为您推荐几款出色的AI绘画工具,助您领略这一技术带来的艺术新体…...

性能诊断的方法(四):自下而上的资源诊断方法和发散的异常信息诊断方法

关于性能诊断的方法,我们可以按照“问题现象—直接原因—问题根源”这样一个思路去归纳。我们先从问题的现象去入手,包括时间的分析、资源的分析和异常信息的分析。接下来再去分析产生问题现象的直接原因是什么,这里我们归纳了自上而下的资源…...

GDPU Vue前端框架开发 计数器

计数器算不到你双向绑定的进度。 重要的更新公告 !!!GDPU的小伙伴,感谢大家的支持,希望到此一游的帅哥美女能有所帮助。本学期的前端框架及移动应用,采用专栏订阅量达到50才开始周更了哦( •̀ .̫ •́ )✧…...

最大流笔记

概念 求两点间的路径中可在同一时间内通过的最大量 EK算法 通过bfs找通路&#xff0c;找到后回溯&#xff1b; 每确定一条边时&#xff0c;同时建立一天反方向的边以用来进行反悔操作&#xff08;毕竟一次性找到正确方案的概率太低了&#xff09; code #include<bits/st…...

el-tree父子不互相关联时,手动实现全选、反选、子级全选、清空功能

el-tree父子不互相关联时&#xff0c;手动实现全选、反选、子级全选、清空功能 1、功能实现图示 2、实现思路 当属性check-strictly为true时&#xff0c;父子节点不互相关联&#xff0c;如果需要全部选中或选择某一节点下的全部节点就必须手动选择每个节点&#xff0c;十分麻…...

模板与泛型编程笔记(一)入门篇

1. 推荐书籍 《C新经典 模板与泛型编程》难得的很容易看得懂的好书&#xff0c;作者讲技术不跳跃&#xff0c;娓娓道来&#xff0c;只要花点时间就能看懂。 2. 笔记 2.1 模板基础 模板为什么要用尖括号&#xff1f;因为便于编译器解析&#xff0c;可以将模板和普通函数声明…...

浅谈WebApi

一、基本介绍 Web API&#xff08;Web应用程序编程接口&#xff09;是一种用于构建应用程序的接口&#xff0c;它允许软件应用程序通过HTTP请求与Web服务器进行交互。Web API通常用于构建客户端-服务器应用程序&#xff0c;其中客户端可以是Web浏览器、移动应用程序、桌面应用程…...

9月14日,每日信息差

第一、宝马集团宣布对设计部门进行重组&#xff0c;并将于 2024 年 10 月 1 日成立一个跨品牌设计团队&#xff0c;由范・霍伊顿克领导。该团队将引入极星汽车设计主管马克西米利安・米索尼&#xff0c;负责宝马中高档和豪华车型以及宝马 Alpina 的设计工作。 第二、小鹏汇天飞…...

无人机控制与三维AI感知处理平台正式上线!

低空经济被誉为推动我国经济高质量发展的全新增长引擎&#xff0c;是一种以民用有人驾驶和无人驾驶航空器的各类低空飞行活动为牵引&#xff0c;辐射带动相关领域融合发展的综合性经济形态&#xff0c;2024年全国两会首次被纳入政府工作报告。 大势智慧积极响应国家低空经济政…...

9.11-kubeadm方式安装k8s

一、安装环境 编号主机名称ip地址1k8s-master192.168.2.662k8s-node01192.168.2.773k8s-node02192.168.2.88 二、前期准备 1.设置免密登录 [rootk8s-master ~]# ssh-keygen [rootk8s-master ~]# ssh-copy-id root192.168.2.77 [rootk8s-master ~]# ssh-copy-id root192.168…...

限流,流量整形算法

写在前面 源码 。 本文看下流量整形相关算法。 目前流量整形算法主要有三种&#xff0c;计数器&#xff0c;漏桶&#xff0c;令牌桶。分别看下咯&#xff01; 1&#xff1a;计数器 1.1&#xff1a;描述 单位时间内只允许指定数量的请求&#xff0c;如果是时间区间内超过指…...

【C++知识扫盲】------C++ 中的引用入门

在 C 中&#xff0c;引用&#xff08;reference&#xff09; 是一个非常重要的概念&#xff0c;它提供了一种别名机制&#xff0c;让我们可以给已经存在的变量起一个新的名字&#xff0c;并且能够通过这个别名直接操作原始变量。本文将详细介绍引用的定义、使用场景及其与指针的…...

【机器学习】6 ——最大熵模型

机器学习6——最大熵模型 目录 机器学习6——最大熵模型最大熵&#xff08;maximum entropy&#xff09;模型模型模型学习&#xff08;估计参数&#xff09;模型评价应用 最大熵&#xff08;maximum entropy&#xff09;模型 选择熵最大的概率模型 熵是衡量不确定性的&#xf…...

小程序——生命周期

文章目录 运行机制更新机制生命周期介绍应用级别生命周期页面级别生命周期组件生命周期生命周期两个细节补充说明总结 运行机制 用一张图简要概述一下小程序的运行机制 冷启动与热启动&#xff1a; 小程序启动可以分为两种情况&#xff0c;一种是冷启动&#xff0c;一种是热…...

基于微信小程序的宠物之家的设计与实现

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 基于微信小程序JavaSpringBootVueMySQL的宠物之家/宠物综合…...

自定义EPICS在LabVIEW中的测试

继续上一篇&#xff1a;LabVIEW中EPICS客户端/服务端的测试 变量定义 You can use CaLabSoftIOC.vi to create new EPICS variables and start them. CA Lab - LabVIEW (Realtime) EPICS INPUT: PV set Cluster-array of names, data types and field definitions to crea…...

基于深度学习的农作物病害检测

基于深度学习的农作物病害检测利用卷积神经网络&#xff08;CNN&#xff09;、生成对抗网络&#xff08;GAN&#xff09;、Transformer等深度学习技术&#xff0c;自动识别和分类农作物的病害&#xff0c;帮助农业工作者提高作物管理效率、减少损失。 1. 农作物病害检测的挑战…...

【C#】命名规范

文章目录 C# 命名规范使用Pascal case使用Camel case方法、属性、类命名见名知义LINQ查询变量使用有意义的名称如何声明成员变量和字段正确格式化和缩进代码如何撰写备注 通用C#编码最佳实践如何将值与空字符串进行比较使用异常处理使用&&和||可获得更好的性能单一职责…...

Chapter03-Authentication vulnerabilities

文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...

C++实现分布式网络通信框架RPC(3)--rpc调用端

目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中&#xff0c;我们已经大致实现了rpc服务端的各项功能代…...

黑马Mybatis

Mybatis 表现层&#xff1a;页面展示 业务层&#xff1a;逻辑处理 持久层&#xff1a;持久数据化保存 在这里插入图片描述 Mybatis快速入门 ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/6501c2109c4442118ceb6014725e48e4.png //logback.xml <?xml ver…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中&#xff0c;接口是一种抽象类型&#xff0c;它定义了一组方法的集合&#xff1a; // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的&#xff1a; // 矩形结构体…...

最新SpringBoot+SpringCloud+Nacos微服务框架分享

文章目录 前言一、服务规划二、架构核心1.cloud的pom2.gateway的异常handler3.gateway的filter4、admin的pom5、admin的登录核心 三、code-helper分享总结 前言 最近有个活蛮赶的&#xff0c;根据Excel列的需求预估的工时直接打骨折&#xff0c;不要问我为什么&#xff0c;主要…...

零基础设计模式——行为型模式 - 责任链模式

第四部分&#xff1a;行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习&#xff01;行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想&#xff1a;使多个对象都有机会处…...

【Oracle】分区表

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

pgsql:还原数据库后出现重复序列导致“more than one owned sequence found“报错问题的解决

问题&#xff1a; pgsql数据库通过备份数据库文件进行还原时&#xff0c;如果表中有自增序列&#xff0c;还原后可能会出现重复的序列&#xff0c;此时若向表中插入新行时会出现“more than one owned sequence found”的报错提示。 点击菜单“其它”-》“序列”&#xff0c;…...

WEB3全栈开发——面试专业技能点P4数据库

一、mysql2 原生驱动及其连接机制 概念介绍 mysql2 是 Node.js 环境中广泛使用的 MySQL 客户端库&#xff0c;基于 mysql 库改进而来&#xff0c;具有更好的性能、Promise 支持、流式查询、二进制数据处理能力等。 主要特点&#xff1a; 支持 Promise / async-await&#xf…...

写一个shell脚本,把局域网内,把能ping通的IP和不能ping通的IP分类,并保存到两个文本文件里

写一个shell脚本&#xff0c;把局域网内&#xff0c;把能ping通的IP和不能ping通的IP分类&#xff0c;并保存到两个文本文件里 脚本1 #!/bin/bash #定义变量 ip10.1.1 #循环去ping主机的IP for ((i1;i<10;i)) doping -c1 $ip.$i &>/dev/null[ $? -eq 0 ] &&am…...