VQ-VAE(Neural Discrete Representation Learning)论文解读及实现
pytorch 实现git地址
论文地址:Neural Discrete Representation Learning
1 论文核心知识点
-
encoder
将图片通过encoder得到图片点表征
如输入shape [32,3,32,32]
通过encoder后输出 [32,64,8,8] (其中64位输出维度) -
量化码本
先随机构建一个码本,维度与encoder保持一致
这里定义512个离散特征,码本shape 为[512,64] -
encoder 码本中向量最近查找
encoder输出shape [32,64,8,8], 经过维度变换 shape [32 * 8 * 8,64]
在码本中找到最相近的向量,并替换为码本中相似向量
输出shape [3288,64],维度变换后,shape 为 [32,64,8,8] -
decoder
将上述数据,喂给decoder,还原原始图片 -
loss
loss 包含两部分
a . encoder输出和码本向量接近
b. 重构loss,重构图片与原图片接近
2 论文实现
2.1 encoder
encoder是常用的图片卷积神经网络
输入x shape [32,3,32,32]
输出 shape [32,128,8,8]
def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Encoder, self).__init__()kernel = 4stride = 2self.conv_stack = nn.Sequential(nn.Conv2d(in_dim, h_dim // 2, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim // 2, h_dim, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim, h_dim, kernel_size=kernel-1,stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers))def forward(self, x):return self.conv_stack(x)
2.2 VectorQuantizer 向量量化层
- 输入:
为encoder的输出z,shape : [32,64,8,8] - 码本维度:
encoder维度变换为[2024,64],和码本embeddign shape [512,64]计算相似度 - 相似计算:使用 ( x − y ) 2 = x 2 + y 2 − 2 x y (x-y)^2=x^2+y^2-2xy (x−y)2=x2+y2−2xy计算和码本的相似度
- z_q生成
然后取码本中最相似的向量替换encoder中的向量 - z_1维度:
得到z_q shape [2024,64],经维度变换 shape [32,64,8,8] ,维度与输入z一致 - 损失函数:
使 z_q和z接近,构建损失函数
decoder 层
decoder层比较简单,与encoder层相反
输入x shape 【32,64,8,8】
输出shape [32,3,32,32]
class Decoder(nn.Module):"""This is the p_phi (x|z) network. Given a latent sample z p_phi maps back to the original space z -> x.Inputs:- in_dim : the input dimension- h_dim : the hidden layer dimension- res_h_dim : the hidden dimension of the residual block- n_res_layers : number of layers to stack"""def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Decoder, self).__init__()kernel = 4stride = 2self.inverse_conv_stack = nn.Sequential(nn.ConvTranspose2d(in_dim, h_dim, kernel_size=kernel-1, stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers),nn.ConvTranspose2d(h_dim, h_dim // 2,kernel_size=kernel, stride=stride, padding=1),nn.ReLU(),nn.ConvTranspose2d(h_dim//2, 3, kernel_size=kernel,stride=stride, padding=1))def forward(self, x):return self.inverse_conv_stack(x)
2.3 损失函数
损失函数为重构损失和embedding损失之和
- decoder 输出为图片重构x_hat
- embedding损失,为encoder和码本的embedding近似损失
- 重点:(decoder计算损失时,由于中间有取最小值,导致梯度不连续,因此decoder loss 不能直接对encocer推荐进行求导,采用了复制梯度的方式: z_q = z + (z_q - z).detach(),及
for i in range(args.n_updates):(x, _) = next(iter(training_loader))x = x.to(device)optimizer.zero_grad()embedding_loss, x_hat, perplexity = model(x)recon_loss = torch.mean((x_hat - x)**2) / x_train_varloss = recon_loss + embedding_lossloss.backward()optimizer.step()
相关文章:

VQ-VAE(Neural Discrete Representation Learning)论文解读及实现
pytorch 实现git地址 论文地址:Neural Discrete Representation Learning 1 论文核心知识点 encoder 将图片通过encoder得到图片点表征 如输入shape [32,3,32,32] 通过encoder后输出 [32,64,8,8] (其中64位输出维度) 量化码本 先随机构建一个码本,维度…...

OpenAI的ChatGPT:引领人工智能交流的未来
如果您在使用ChatGPT工具的过程中感到迷茫,别担心,我在这里提供帮助。无论您是初次接触ChatGPT plus,还是在注册、操作过程中遇到难题,我都将为您提供一对一的指导和支持。(qq:1371410959) 一、ChatGPT简介 OpenAI的ChatGPT是一…...

es集群安装及优化
es主节点 192.168.23.100 es节点 192.168.23.101 192.168.23.102 1.安装主节点 1.去官网下载es的yum包 官网下载地址 https://www.elastic.co/cn/downloads/elasticsearch 根据自己的需要下载对应的包 2.下载好之后把所有的包都传到从节点上,安装 [rootlocalho…...

【开源】基于JAVA+Vue+SpringBoot的医院门诊预约挂号系统
目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 功能性需求2.1.1 数据中心模块2.1.2 科室医生档案模块2.1.3 预约挂号模块2.1.4 医院时政模块 2.2 可行性分析2.2.1 可靠性2.2.2 易用性2.2.3 维护性 三、数据库设计3.1 用户表3.2 科室档案表3.3 医生档案表3.4 医生放号…...

Java Swing 图书借阅系统 窗体项目 期末课程设计 窗体设计
视频教程: 【课程设计】图书借阅系统 功能描述: 图书管理系统有三个角色,系统管理员、图书管理员、借阅者; 系统管理员可以添加借阅用户; 图书管理员可以添加图书,操作图书借阅和归还; 借…...

2024.01.09.Apple_UI_BUG
我是软件行业的,虽然不是手机设计的,但是这个设计真的导致经常看信息不完整,要下拉的。 特别读取文本或者其他文件的时候,上面有个抬头就是看不到,烦,体验感很差...

K8S Nginx Ingress Controller client_max_body_size 上传文件大小限制
现象 k8s集群中,上传图片时,大于1M就会报错 413 Request Entity Too Large Nginx Ingress Controller 的版本是 0.29.0 解决方案 1. 修改configmap kubectl edit configmap nginx-configuration -n ingress-nginx在 ConfigMap 的 data 字段中设置参数…...

Untiy HTC Vive VRTK 开发记录
目录 一.概述 二.功能实现 1.模型抓取 1)基础抓取脚本 2)抓取物体在手柄上的角度 2.模型放置区域高亮并吸附 1)VRTK_SnapDropZone 2)VRTK_PolicyList 3)VRTK_SnapDropZone_UnityEvents 3.交互滑动条 4.交互旋…...

机器学习指南:如何学习机器学习?
机器学习 一、介绍 你有没有想过计算机是如何从数据中学习和变得更聪明的?这就是机器学习 (ML) 的魔力!这就像计算机科学和统计学的酷炫组合,计算机从大量信息中学习以解决问题并做出预测,就像人类一样。 …...

使用numpy处理图片——分离通道
大纲 读入图片分离通道堆叠法复制修改法 生成图片 在《使用numpy处理图片——滤镜》中,我们剥离了RGB中的一个颜色,达到一种滤镜的效果。 如果我们只保留一种元素,就可以做到PS中分离通道的效果。 读入图片 import numpy as np import PIL.…...

metartc5_jz源码阅读-yang_rtcpush_on_rtcp_ps_feedback
// (Payload-specific FB messages,有效载荷反馈信息),这个函数处理Payload重传 int32_t yang_rtcpush_on_rtcp_ps_feedback(YangRtcContext *context,YangRtcPushStream *pub, YangRtcpCommon *rtcp) {if (context NULL || pub NULL)return ERROR_RTC…...

计算机毕业设计 | SpringBoot+vue的家庭理财 财务管理系统(附源码)
1,绪论 1.1 项目背景 网络的发展已经过去了七十多年,网络技术的发展,将会影响到人类的方方面面,网络的出现让各行各业都得到了极大的发展,为整个社会带来了巨大的生机。 现在许多的产业都与因特网息息相关ÿ…...

前端面试题集合三(js)
目录 1. 介绍 js 的基本数据类型。2. JavaScript 有几种类型的值?你能画一下他们的内存图吗?3. 什么是堆?什么是栈?它们之间有什么区别和联系?4. 内部属性 [[Class]] 是什么?5. 介绍 js 有哪些内置对象&am…...

ssm基于JAVA的酒店客房管理系统论文
摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本酒店客房管理系统就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞大的数据信息…...

杨中科 .NETCORE ENTITY FRAMEWORK CORE-1 EFCORE 第一部分
一 、什么是EF Core 什么是ORM 1、说明: 本课程需要你有数据库、SOL等基础知识。 2、ORM: ObjectRelational Mapping。让开发者用对象操作的形式操作关系数据库 比如插入: User user new User(Name"admin"Password"123”; orm.Save(user);比如查询: Book b…...

微信小程序 全局配置||微信小程序 页面配置||微信小程序 sitemap配置
全局配置 小程序根目录下的 app.json 文件用来对微信小程序进行全局配置,决定页面文件的路径、窗口表现、设置网络超时时间、设置多 tab 等。 以下是一个包含了部分常用配置选项的 app.json : {"pages": ["pages/index/index",&q…...

使用ffmpeg对视频进行静音检测
1 原始视频信息 通过ffmpeg -i命令查看视频基本信息 ffmpeg version 6.1-essentials_build-www.gyan.dev Copyright (c) 2000-2023 the FFmpeg developersbuilt with gcc 12.2.0 (Rev10, Built by MSYS2 project)configuration: --enable-gpl --enable-version3 --enable-sta…...

Servlet-Request
一、预览 在上一篇Servlet体系结构中,我们初步了解了怎么快速本篇将介绍Servlet中请求Request的相关内容,包括Request的体系结构,Request常用API。 二、Request体系结构 我们注意到我们定义的Servlet类若实现Servlet接口时,请求…...

数据结构-怀化学院期末题(490)
哈希查找 题目描述: 实现哈希查找。要求根据给定的哈希函数进行存储,并查找相应元素的存储位置。本题目使用的哈希函数为除留取余法,即H(key)key%m,其中m为存储空间,冲突处理方法采用开放定址法中的线性探测再散列&am…...

Matlab字符识别实验
Matlab 字符识别OCR实验 图像来源于屏幕截图,要求黑底白字。数据来源是任意二进制文件,内容以16进制打印输出,0-9a-f’字符被16个可打印字符替代,这些替代字符经过挑选,使其相对容易被识别。 第一步进行线分割和字符…...

MySQL夯实之路-存储引擎深入浅出
innoDB Mysql4.1以后的版本将表的数据和索引放在单独的文件中 采用mvcc来支持高并发,实现了四个标准的隔离级别,默认为可重复读,并且通过间隙锁(next-key locking)策略防止幻读(查询的行中的间隙也会锁定…...

内存卡为什么会提示格式化,内存卡提示格式化还能恢复吗
对于许多电脑用户来说,执行内存卡格式化操作导致数据丢失是一个常见的问题。在日常生活中,数据丢失的情况并不少见,但内存卡格式化后的数据恢复相对较难。目前,能够使用的方法较少,且成功率较低,但并不是没…...

阅读文献-胃癌
写在前面 今天先不阅读肺癌的了,先读一篇胃癌的文章 文献 An individualized stemness-related signature to predict prognosis and immunotherapy responses for gastric cancer using single-cell and bulk tissue transcriptomes IF:4.0 中科院分区:2区 医学…...

水仙花数(Java解法)
什么是水仙花数? 水仙花数是指一个 3 位数,它每位上的数字的 3 次幂之和等于它本身(例如: 1 5 3 153 ),水仙花数的取值范围在 100~1000 之间。 解题思路: 这个题需要把所以的数字都拿到&…...

vue3 源码解析(3)— computed 计算属性的实现
前言 本文是 vue3 源码分析系列的第三篇文章,主要介绍 vue3 computed 原理。computed 是 vue3 的一个特性,可以根据其他响应式数据创建响应式的计算属性。计算属性的值会根据依赖的数据变化而自动更新,而且具有缓存机制,提高了性…...

Alibaba-> EasyExcel 整理3
1 导入依赖 <!-- easyExcel --><dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version >3.2.1</version><exclusions><exclusion><artifactId>poi-ooxml-schemas</art…...

创建组-RibbonGroup
使用实例如下: 1、main中: #include "QRibbonDemo.h" #include <QtWidgets/QApplication> int main(int argc, char *argv[]) { QApplication a(argc, argv); a.setStyle(new RibbonStyle()); a.setApplicationName(&quo…...

面试题目1
文章目录 1、安装系统的方法2、总线型3、OSL参考模型(网络七层模型)4、计算机系统的组成5、计算机硬件 1、安装系统的方法 U盘安装 硬盘安装 刻光盘安装 PE系统中安装 网络安装 2、总线型 所有设备都连接到公共总线上,结点间使用广播通信方…...

考古学家 - 华为OD统一考试
OD统一考试 分值: 200分 题解: Java / Python / C++ 题目描述 有一个考古学家发现一个石碑,但是很可惜发现时其已经断成多段。 原地发现N个断口整齐的石碑碎片,为了破解石碑内容,考古学家希望有程序能帮忙计算复原后的石碑文字组合数,你能帮忙吗? 备注: 如果存在石碑…...

Linux服务器安全配置基线
基线要求: 安全类别 检查项 检查要求 检查步骤 备注 账户及口令安全 1.1 检查是否设置口令生存周期 应配置口令生存周期,密码最长使用期限应小于等于90天,密码最短使用期限应非0。 执行:cat /etc/login.defs,检查是否配置了以下参数。 PASS_MAX_DAYS 配置项决定密码最长使…...