当前位置: 首页 > news >正文

论文学习——VideoGPT

论文学习——VideoGPT: Video Generation using VQ-VAE and Transformers

原文链接:https://arxiv.org/abs/2104.10157

1. 设计思路

不同种类的生成模型在一下多个维度各有权衡:采样速度、样本多样性、样本质量、优化稳定性、计算需求、评估难易程度等等。

这些模型,除分数匹配模型(score-matching models)之外,广义上可以分为基于似然的模型(PixelCNNs, iGPT, NVAE, VQ-VAE, Glow)和对抗生成模型(GANs)。那么哪一类的模型适于研究和视频生成任务呢?

首先,从两大类模型中进行选择。基于似然的模型训练更为方便,因为目标是很容易理解的,在不同的batch size上都很容易优化,相对于GANs的判别器来讲,也十分易于评估。考虑到由于数据的性质,对视频任务建模已经是一个较大的挑战,因此我们任务基于似然的模型在优化和评估过程中存在的困难较少,因此可以关注于结构的改进上。

其次,在许多基于似然的模型中,我们选择了自回归模型,仅因为其在离散数据上运行良好,在样本质量上表现优异,且训练方法和模型架构上较为乘数,可以利用transformer中的最新改进。

在自回归模型中,考虑如下问题:自回归模型是在没有时空冗余的下采样潜空间内进行建模更好,还是在时空领域的所有帧、所有像素上训练好呢?考虑到自然视频的时空上的冗余度,作者选择了前者,通过将高维输入编码乘一个去噪后的下采样编码的方式去除冗余度。如在时空上进行4倍下采样,总分辨率就是64倍下采样,因此生成模型就能在更少更有用的信息上倾注计算量。如在VQ-VAE上,即使一个残缺的decoder也能将潜向量转化为足够真实的样本。并且在潜空间内建模也提升了计算速度。

上述三个原因促使VideoGPT的产生,这是一款使用基于似然的生成式模型,生成对象是自然视频。VideoGPT主体上有两个结构:VQ-VAE和GPT。

VQ-VAE中的autoencoder,通过3d卷积和轴向的注意力机制(axial self.attention)来从视频中学习其下采样潜空间的离散表征。

而类似于GPT的架构(强大的自回归先验)可以使用时空位置编码来为(VQ-VAE获得的)离散潜向量自回归地建模。

上述过程得到的潜向量再通过VQ-VAE的解码器,恢复为原像素规模的视频

后续在消融实验中,作者研究了axial attention blocks的优点、VQ-VAE潜空间大小、codebooks的输入、自回归先验的容量(模型大小)的影响。

2. 具体实现

VideoGPT的整体结构如下图所示:
在这里插入图片描述
将模型分为两个部分进行讲解:

2.1 学习潜编码——VQ-VAE

为了学习到离散的latent code,首先在视频数据上训练VQ-VAE。编码器在时空维度使用3d卷积进行下采样,然后是残差注意力模块,该模块的结构如下所示,在模块中使用layernorm和轴向注意力机制。
在这里插入图片描述
解码器的结构则是编码器的反向,先通过残差注意力模块,再通过3d转置卷积,在时空维度上进行上采样。位置编码是学习到的时间+空间上的嵌入,它们可以在encoder和decoder之间,所有轴向注意力层中共享。

关于VQ-VAE的轴向注意力,下面对其代码进行展示:

(1)需要注意的是VQ-VAE分为encoder和decoder,两部分对称。

class VQVAE(pl.LightningModule):def __init__(self, args):super().__init__()self.args = args# codebooks 中embedding的维度self.embedding_dim = args.embedding_dim# codebook中code 的数目self.n_codes = args.n_codes# n_hiddens: 残差块儿中隐藏特征的数目# n_res_layers: 残差块儿的数目# downsample: T, H, W三个维度下采样倍数self.encoder = Encoder(args.n_hiddens, args.n_res_layers, args.downsample)self.decoder = Decoder(args.n_hiddens, args.n_res_layers, args.downsample)

(2)以encoder为例,其残差层数目n_res_layers取值为4,故而其self.res_stack部分共有4层

class Encoder(nn.Module):def __init__(self, n_hiddens, n_res_layers, downsample):super().__init__()n_times_downsample = np.array([int(math.log2(d)) for d in downsample])self.convs = nn.ModuleList()max_ds = n_times_downsample.max()for i in range(max_ds):in_channels = 3 if i == 0 else n_hiddensstride = tuple([2 if d > 0 else 1 for d in n_times_downsample])conv = SamePadConv3d(in_channels, n_hiddens, 4, stride=stride)self.convs.append(conv)n_times_downsample -= 1self.conv_last = SamePadConv3d(in_channels, n_hiddens, kernel_size=3)self.res_stack = nn.Sequential(*[AttentionResidualBlock(n_hiddens)for _ in range(n_res_layers)],nn.BatchNorm3d(n_hiddens),nn.ReLU())

(3)对于其中的每一层AttentionResidualBlock,即之前图中所提的残差注意力模块,模块的末端各对应一个AxialBlock,每个AxialBlock中对应时空三个维度的多头注意力机制

class AxialBlock(nn.Module):def __init__(self, n_hiddens, n_head):super().__init__()kwargs = dict(shape=(0,) * 3, dim_q=n_hiddens,dim_kv=n_hiddens, n_head=n_head,n_layer=1, causal=False, attn_type='axial')self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2),**kwargs)self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3),**kwargs)self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4),**kwargs)

(4)对于每一个多头注意力,其注意力部分对应一个AxialAttention机制

class AxialAttention(nn.Module):def __init__(self, n_dim, axial_dim):super().__init__()# encoder 里4个attentionResidualBlock,对应4组axial-attention,每组3个# decoder 结构上与encoder对称,故也有4个attrntionResidualBlock# print(n_dim, axial_dim) # 如下内容,共8组,应该是共8个attention block# 3 -2# 3 -3# 3 -4if axial_dim < 0:axial_dim = 2 + n_dim + 1 + axial_dimelse:axial_dim += 2 # account for batch, head, dimself.axial_dim = axial_dimdef forward(self, q, k, v, decode_step, decode_idx):q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3)k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3)v = shift_dim(v, self.axial_dim, -2)old_shape = list(v.shape)v = v.flatten(end_dim=-3)# scaled dot-product attention,计算分类结果out = scaled_dot_product_attention(q, k, v, training=self.training)out = out.view(*old_shape)out = shift_dim(out, -2, self.axial_dim)return out

以上

2.2 学习到先验

模型的第二阶段是在VQ-VAE第一阶段的latent code的基础上学习一个先验。先验网络遵循Image-GPT的结构,另外还在feedforward layer和注意力块儿后面加入了dropout,以实现正则化。

以上过程是无条件限制的情况下进行训练的。可以通过训练带条件的先验(conditional prior)来生成conditional samples。条件限制有两种方法:

  • 交叉注意力(Cross Attention):作为视频帧的限制,在先验网络的训练过程中,我们首先向3d的resnet中喂入有条件限制的帧,然后再resnet的输出上使用cross attention。
  • 条件正则(Conditional Norms):与GANs中使用的限制方法相似,我们在transformer的Layer Normalization层上参数化gain和bias,作为条件张量(conditional vector)的放射函数。这种方法适用于对动作和类别进行限制的模型

相关文章:

论文学习——VideoGPT

论文学习——VideoGPT: Video Generation using VQ-VAE and Transformers 原文链接&#xff1a;https://arxiv.org/abs/2104.10157 1. 设计思路 不同种类的生成模型在一下多个维度各有权衡&#xff1a;采样速度、样本多样性、样本质量、优化稳定性、计算需求、评估难易程度等…...

Flutter系列(五)底部导航详解

Flutter系列&#xff08;四&#xff09;底部导航顶部导航图文列表完整代码&#xff0c;如下&#xff1a; Flutter系列&#xff08;四&#xff09;底部导航顶部导航图文列表完整代码_摸金青年v的博客-CSDN博客 目录 一、前言 二、Scaffold组件 三、BottomNavigationBar组件 …...

『pyqt5 从0基础开始项目实战』02. 页面布局设计(保姆级图文)

目录弹性布局介绍导包和框架代码布局框架搭建1. 总体布局框架2. 顶部菜单布局3. form添加内容布局4. table数据展示布局5. footer底部菜单完整项目代码总结欢迎关注 『pyqt5 从0基础开始项目实战』 专栏&#xff0c;持续更新中 欢迎关注 『pyqt5 从0基础开始项目实战』 专栏&am…...

【Python机器学习】——平均中位数模式

Python机器学习——平均中位数模式 文章目录 Python机器学习——平均中位数模式一、Python 平均中位数模式一、Python 平均中位数模式 均值、中值和众数 从一组数字中我们可以学到什么? 在机器学习(和数学)中,通常存在三中我们感兴趣的值: 均值(Mean) - 平均值 中值(M…...

Windows窗口

Windows窗口 Unit01注册窗口类 01窗口类的概念 窗口类是包括了窗口的各种参数信息的数据结构每个窗口都具有窗口类&#xff0c;基于窗口类创建窗口每个窗口都具有一个名称&#xff0c;使用前必须注册到系统 02窗口类的分类 系统窗口类 系统已经定义好的窗口类&#xff0c;…...

Spring Transaction 源码解读

Spring Transaction 规范的maven坐标如下&#xff1a; <dependency><groupId>org.springframework</groupId><artifactId>spring-tx</artifactId><version>...</version></dependency>该包提供了spring事务规范和默认的jta(ja…...

[Netty] Channel和ChannelFuture和ChannelFutureListener (六)

文章目录1.Channel介绍2.ChannelFuture接口介绍3.GenericFutureListener接口介绍1.Channel介绍 NIO的Channel与Netty的Channel 不一样 Netty重新设计了Channel接口,并且给予了很多不同的实现, Channel是Netty网络的抽象类, 除了NIO中Channel所包含的网络I/O操作, 主动建立和关…...

条件渲染

组件经常需要根据不同条件显示不同内容。在React中&#xff0c;你可以使用类似于if语句、&&和?:运算符的JavaScript语法有条件地呈现JSX。你将学到&#xff1a;如何根据条件返回不同的JSX如何有条件地包含或排除一段JSX在React代码库中常见的条件语法快捷方式有条件地…...

springboot(10)异步任务

文章目录1、SpringBoot异步任务1.1使用注解EnableAsync开启异步任务支持1.2使用Async注解标记要进行异步执行的方法1.3controller测试2.异步任务相关限制3.1自定义 Executor3.1.1应用层级&#xff1a;3.1.2方法层级&#xff1a;3.2自定义 Executor (第二种方式)4.1异常处理4.1.…...

清华大学开源的chatGLM-6B部署实战

Windows部署 win10 通过wsl部署 常见问题: torch.cuda.OutOfMemoryError: CUDA out of memory. 在Windows的系统环境变量中增加 变量名:PYTORCH_CUDA_ALLOC_CONF 变量值:max_split_size_mb:32 文档书写时使用3090 24G显存配置,其他规格酌情调整 32 至其他值,如未设置变…...

通过矩阵从整体角度搞懂快速傅里叶变换原理

离散傅里叶变换公式 公式 f[k]∑n0N−1g[n]e−i(2π/N)kn,其中(0<n<N)f[k]\sum_{n0}^{N-1}g[n]e^{-i(2\pi/N)kn}, 其中(0<n<N) f[k]n0∑N−1​g[n]e−i(2π/N)kn,其中(0<n<N) 逆变换公式 g[n]1N∑k0N−1f[k]ei(2π/N)kn,其中(0<k<N)g[n]\frac{1}{N}\…...

【C++从0到1】25、C++中嵌套使用循环

C从0到1全系列教程 1、实例代码 #include <iostream> // 包含头文件。 using namespace std; // 指定缺省的命名空间。int main() {// 超女分4个小组&#xff0c;每个小组有3名超女&#xff0c;在控制台显示每个超女的小组编号和组内编号。// 用一个循环…...

FastDFS与Nginx结合搭建文件服务器,并内网穿透实现公网访问

文章目录前言1. 本地搭建FastDFS文件系统1.1 环境安装1.2 安装libfastcommon1.3 安装FastDFS1.4 配置Tracker1.5 配置Storage1.6 测试上传下载1.7 与Nginx整合1.8 安装Nginx1.9 配置Nginx2. 局域网测试访问FastDFS3. 安装cpolar内网穿透4. 配置公网访问地址5. 固定公网地址5.1 …...

密集场景下的行人跟踪替代算法,头部跟踪算法 | CVPR 2021

一个不知名大学生&#xff0c;江湖人称菜狗 original author: Jacky LiEmail : 3435673055qq.com Time of completion&#xff1a;2023.4.8 Last edited: 2023.4.8 目录 摘要 主要内容 结果 这篇文章是CVPR 2021 的最新论文&#xff0c;文章的标题&#xff1a; 文章的主要内…...

Matlab与ROS(1/2)---服务端和客户端数据通信(五)

0. 简介 在前几讲我们讲了Matlab中的Message以及Topic的相关知识。而ROS主要支持的通信机制还有服务这一类。服务通过允许请求以及响应的通信方式&#xff0c;来给整个系统完成更紧密的耦合。服务客户端向服务服务器发送请求消息并等待响应。服务器将使用请求中的数据构造响应…...

数字化转型的避坑指南:细说数字化转型十二大坑

随着信息技术的快速发展&#xff0c;数字化转型已经成为许多企业发展的必经之路。然而&#xff0c;数字化转型过程中也存在许多坑&#xff0c;如果不谨慎处理&#xff0c;就可能导致企业陷入困境。本文将细说数字化转型的十二大坑&#xff0c;并提供相应的避坑指南。 1、不了解…...

pt05Encapsulationinherit

Encapsulation &inherit 封装继承 封装 向类外提供必要的功能&#xff0c;隐藏实现的细节, 代码可读性更高优势&#xff1a;简化编程&#xff0c;使用者不必了解具体的实现细节&#xff0c;只需要调用对外提供的功能。私有成员&#xff1a;作用&#xff1a;无需向类外提供…...

面向对象编程(基础)9:封装性(encapsulation)

目录 9.1 为什么需要封装&#xff1f; 而“高内聚&#xff0c;低耦合”的体现之一&#xff1a; 9.2 何为封装性&#xff1f; 9.3 Java如何实现数据封装 9.4 封装性的体现 9.4.1 成员变量/属性私有化 实现步骤&#xff1a; 成员变量封装的好处&#xff1a; 9.4.2 私有化…...

fate-serving-server增加取数逻辑并源码编译

1.什么是fate-serving-server? FATE-Serving 是一个高性能、工业化的联邦学习模型服务系统&#xff0c;专为生产环境而设计,主要用于在线推理。 2.fate-serving-server源码编译 下载fate-serving-serving项目&#xff08;GitHub - FederatedAI/FATE-Serving: A scalable, h…...

循环队列、双端队列 C和C++

队列 目录 概念 实现方式 顺序队列 循环队列 队列的数组实现 用循环链表实现队列 STL 之 queue 实现队列 STL 之 dequeue 实现双端队列 概念 队列是一种特殊的线性表&#xff0c;它只允许在表的前端&#xff08;称为队头&#xff0c;front&#xff09;进行删除操作…...

使用VSCode开发Django指南

使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架&#xff0c;专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用&#xff0c;其中包含三个使用通用基本模板的页面。在此…...

K8S认证|CKS题库+答案| 11. AppArmor

目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作&#xff1a; 1&#xff09;、切换集群 2&#xff09;、切换节点 3&#xff09;、切换到 apparmor 的目录 4&#xff09;、执行 apparmor 策略模块 5&#xff09;、修改 pod 文件 6&#xff09;、…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

AtCoder 第409​场初级竞赛 A~E题解

A Conflict 【题目链接】 原题链接&#xff1a;A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串&#xff0c;只有在同时为 o 时输出 Yes 并结束程序&#xff0c;否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中&#xff0c;接口是一种抽象类型&#xff0c;它定义了一组方法的集合&#xff1a; // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的&#xff1a; // 矩形结构体…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具

文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...

页面渲染流程与性能优化

页面渲染流程与性能优化详解&#xff08;完整版&#xff09; 一、现代浏览器渲染流程&#xff08;详细说明&#xff09; 1. 构建DOM树 浏览器接收到HTML文档后&#xff0c;会逐步解析并构建DOM&#xff08;Document Object Model&#xff09;树。具体过程如下&#xff1a; (…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言&#xff1a;为什么 Eureka 依然是存量系统的核心&#xff1f; 尽管 Nacos 等新注册中心崛起&#xff0c;但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制&#xff0c;是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

优选算法第十二讲:队列 + 宽搜 优先级队列

优选算法第十二讲&#xff1a;队列 宽搜 && 优先级队列 1.N叉树的层序遍历2.二叉树的锯齿型层序遍历3.二叉树最大宽度4.在每个树行中找最大值5.优先级队列 -- 最后一块石头的重量6.数据流中的第K大元素7.前K个高频单词8.数据流的中位数 1.N叉树的层序遍历 2.二叉树的锯…...

LINUX 69 FTP 客服管理系统 man 5 /etc/vsftpd/vsftpd.conf

FTP 客服管理系统 实现kefu123登录&#xff0c;不允许匿名访问&#xff0c;kefu只能访问/data/kefu目录&#xff0c;不能查看其他目录 创建账号密码 useradd kefu echo 123|passwd -stdin kefu [rootcode caozx26420]# echo 123|passwd --stdin kefu 更改用户 kefu 的密码…...