VQ-VAE(Neural Discrete Representation Learning)论文解读及实现
pytorch 实现git地址
论文地址:Neural Discrete Representation Learning
1 论文核心知识点
-
encoder
将图片通过encoder得到图片点表征
如输入shape [32,3,32,32]
通过encoder后输出 [32,64,8,8] (其中64位输出维度) -
量化码本
先随机构建一个码本,维度与encoder保持一致
这里定义512个离散特征,码本shape 为[512,64] -
encoder 码本中向量最近查找
encoder输出shape [32,64,8,8], 经过维度变换 shape [32 * 8 * 8,64]
在码本中找到最相近的向量,并替换为码本中相似向量
输出shape [3288,64],维度变换后,shape 为 [32,64,8,8] -
decoder
将上述数据,喂给decoder,还原原始图片 -
loss
loss 包含两部分
a . encoder输出和码本向量接近
b. 重构loss,重构图片与原图片接近
2 论文实现
2.1 encoder
encoder是常用的图片卷积神经网络
输入x shape [32,3,32,32]
输出 shape [32,128,8,8]
def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Encoder, self).__init__()kernel = 4stride = 2self.conv_stack = nn.Sequential(nn.Conv2d(in_dim, h_dim // 2, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim // 2, h_dim, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim, h_dim, kernel_size=kernel-1,stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers))def forward(self, x):return self.conv_stack(x)
2.2 VectorQuantizer 向量量化层
- 输入:
为encoder的输出z,shape : [32,64,8,8] - 码本维度:
encoder维度变换为[2024,64],和码本embeddign shape [512,64]计算相似度 - 相似计算:使用 ( x − y ) 2 = x 2 + y 2 − 2 x y (x-y)^2=x^2+y^2-2xy (x−y)2=x2+y2−2xy计算和码本的相似度
- z_q生成
然后取码本中最相似的向量替换encoder中的向量 - z_1维度:
得到z_q shape [2024,64],经维度变换 shape [32,64,8,8] ,维度与输入z一致 - 损失函数:
使 z_q和z接近,构建损失函数
decoder 层
decoder层比较简单,与encoder层相反
输入x shape 【32,64,8,8】
输出shape [32,3,32,32]
class Decoder(nn.Module):"""This is the p_phi (x|z) network. Given a latent sample z p_phi maps back to the original space z -> x.Inputs:- in_dim : the input dimension- h_dim : the hidden layer dimension- res_h_dim : the hidden dimension of the residual block- n_res_layers : number of layers to stack"""def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Decoder, self).__init__()kernel = 4stride = 2self.inverse_conv_stack = nn.Sequential(nn.ConvTranspose2d(in_dim, h_dim, kernel_size=kernel-1, stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers),nn.ConvTranspose2d(h_dim, h_dim // 2,kernel_size=kernel, stride=stride, padding=1),nn.ReLU(),nn.ConvTranspose2d(h_dim//2, 3, kernel_size=kernel,stride=stride, padding=1))def forward(self, x):return self.inverse_conv_stack(x)
2.3 损失函数
损失函数为重构损失和embedding损失之和
- decoder 输出为图片重构x_hat
- embedding损失,为encoder和码本的embedding近似损失
- 重点:(decoder计算损失时,由于中间有取最小值,导致梯度不连续,因此decoder loss 不能直接对encocer推荐进行求导,采用了复制梯度的方式: z_q = z + (z_q - z).detach(),及
for i in range(args.n_updates):(x, _) = next(iter(training_loader))x = x.to(device)optimizer.zero_grad()embedding_loss, x_hat, perplexity = model(x)recon_loss = torch.mean((x_hat - x)**2) / x_train_varloss = recon_loss + embedding_lossloss.backward()optimizer.step()
相关文章:

VQ-VAE(Neural Discrete Representation Learning)论文解读及实现
pytorch 实现git地址 论文地址:Neural Discrete Representation Learning 1 论文核心知识点 encoder 将图片通过encoder得到图片点表征 如输入shape [32,3,32,32] 通过encoder后输出 [32,64,8,8] (其中64位输出维度) 量化码本 先随机构建一个码本,维度…...
OpenAI的ChatGPT:引领人工智能交流的未来
如果您在使用ChatGPT工具的过程中感到迷茫,别担心,我在这里提供帮助。无论您是初次接触ChatGPT plus,还是在注册、操作过程中遇到难题,我都将为您提供一对一的指导和支持。(qq:1371410959) 一、ChatGPT简介 OpenAI的ChatGPT是一…...

es集群安装及优化
es主节点 192.168.23.100 es节点 192.168.23.101 192.168.23.102 1.安装主节点 1.去官网下载es的yum包 官网下载地址 https://www.elastic.co/cn/downloads/elasticsearch 根据自己的需要下载对应的包 2.下载好之后把所有的包都传到从节点上,安装 [rootlocalho…...

【开源】基于JAVA+Vue+SpringBoot的医院门诊预约挂号系统
目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 功能性需求2.1.1 数据中心模块2.1.2 科室医生档案模块2.1.3 预约挂号模块2.1.4 医院时政模块 2.2 可行性分析2.2.1 可靠性2.2.2 易用性2.2.3 维护性 三、数据库设计3.1 用户表3.2 科室档案表3.3 医生档案表3.4 医生放号…...

Java Swing 图书借阅系统 窗体项目 期末课程设计 窗体设计
视频教程: 【课程设计】图书借阅系统 功能描述: 图书管理系统有三个角色,系统管理员、图书管理员、借阅者; 系统管理员可以添加借阅用户; 图书管理员可以添加图书,操作图书借阅和归还; 借…...

2024.01.09.Apple_UI_BUG
我是软件行业的,虽然不是手机设计的,但是这个设计真的导致经常看信息不完整,要下拉的。 特别读取文本或者其他文件的时候,上面有个抬头就是看不到,烦,体验感很差...
K8S Nginx Ingress Controller client_max_body_size 上传文件大小限制
现象 k8s集群中,上传图片时,大于1M就会报错 413 Request Entity Too Large Nginx Ingress Controller 的版本是 0.29.0 解决方案 1. 修改configmap kubectl edit configmap nginx-configuration -n ingress-nginx在 ConfigMap 的 data 字段中设置参数…...

Untiy HTC Vive VRTK 开发记录
目录 一.概述 二.功能实现 1.模型抓取 1)基础抓取脚本 2)抓取物体在手柄上的角度 2.模型放置区域高亮并吸附 1)VRTK_SnapDropZone 2)VRTK_PolicyList 3)VRTK_SnapDropZone_UnityEvents 3.交互滑动条 4.交互旋…...

机器学习指南:如何学习机器学习?
机器学习 一、介绍 你有没有想过计算机是如何从数据中学习和变得更聪明的?这就是机器学习 (ML) 的魔力!这就像计算机科学和统计学的酷炫组合,计算机从大量信息中学习以解决问题并做出预测,就像人类一样。 …...

使用numpy处理图片——分离通道
大纲 读入图片分离通道堆叠法复制修改法 生成图片 在《使用numpy处理图片——滤镜》中,我们剥离了RGB中的一个颜色,达到一种滤镜的效果。 如果我们只保留一种元素,就可以做到PS中分离通道的效果。 读入图片 import numpy as np import PIL.…...
metartc5_jz源码阅读-yang_rtcpush_on_rtcp_ps_feedback
// (Payload-specific FB messages,有效载荷反馈信息),这个函数处理Payload重传 int32_t yang_rtcpush_on_rtcp_ps_feedback(YangRtcContext *context,YangRtcPushStream *pub, YangRtcpCommon *rtcp) {if (context NULL || pub NULL)return ERROR_RTC…...

计算机毕业设计 | SpringBoot+vue的家庭理财 财务管理系统(附源码)
1,绪论 1.1 项目背景 网络的发展已经过去了七十多年,网络技术的发展,将会影响到人类的方方面面,网络的出现让各行各业都得到了极大的发展,为整个社会带来了巨大的生机。 现在许多的产业都与因特网息息相关ÿ…...
前端面试题集合三(js)
目录 1. 介绍 js 的基本数据类型。2. JavaScript 有几种类型的值?你能画一下他们的内存图吗?3. 什么是堆?什么是栈?它们之间有什么区别和联系?4. 内部属性 [[Class]] 是什么?5. 介绍 js 有哪些内置对象&am…...

ssm基于JAVA的酒店客房管理系统论文
摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本酒店客房管理系统就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞大的数据信息…...

杨中科 .NETCORE ENTITY FRAMEWORK CORE-1 EFCORE 第一部分
一 、什么是EF Core 什么是ORM 1、说明: 本课程需要你有数据库、SOL等基础知识。 2、ORM: ObjectRelational Mapping。让开发者用对象操作的形式操作关系数据库 比如插入: User user new User(Name"admin"Password"123”; orm.Save(user);比如查询: Book b…...

微信小程序 全局配置||微信小程序 页面配置||微信小程序 sitemap配置
全局配置 小程序根目录下的 app.json 文件用来对微信小程序进行全局配置,决定页面文件的路径、窗口表现、设置网络超时时间、设置多 tab 等。 以下是一个包含了部分常用配置选项的 app.json : {"pages": ["pages/index/index",&q…...

使用ffmpeg对视频进行静音检测
1 原始视频信息 通过ffmpeg -i命令查看视频基本信息 ffmpeg version 6.1-essentials_build-www.gyan.dev Copyright (c) 2000-2023 the FFmpeg developersbuilt with gcc 12.2.0 (Rev10, Built by MSYS2 project)configuration: --enable-gpl --enable-version3 --enable-sta…...

Servlet-Request
一、预览 在上一篇Servlet体系结构中,我们初步了解了怎么快速本篇将介绍Servlet中请求Request的相关内容,包括Request的体系结构,Request常用API。 二、Request体系结构 我们注意到我们定义的Servlet类若实现Servlet接口时,请求…...
数据结构-怀化学院期末题(490)
哈希查找 题目描述: 实现哈希查找。要求根据给定的哈希函数进行存储,并查找相应元素的存储位置。本题目使用的哈希函数为除留取余法,即H(key)key%m,其中m为存储空间,冲突处理方法采用开放定址法中的线性探测再散列&am…...

Matlab字符识别实验
Matlab 字符识别OCR实验 图像来源于屏幕截图,要求黑底白字。数据来源是任意二进制文件,内容以16进制打印输出,0-9a-f’字符被16个可打印字符替代,这些替代字符经过挑选,使其相对容易被识别。 第一步进行线分割和字符…...

【JavaEE】-- HTTP
1. HTTP是什么? HTTP(全称为"超文本传输协议")是一种应用非常广泛的应用层协议,HTTP是基于TCP协议的一种应用层协议。 应用层协议:是计算机网络协议栈中最高层的协议,它定义了运行在不同主机上…...
Cesium1.95中高性能加载1500个点
一、基本方式: 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...
java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别
UnsatisfiedLinkError 在对接硬件设备中,我们会遇到使用 java 调用 dll文件 的情况,此时大概率出现UnsatisfiedLinkError链接错误,原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用,结果 dll 未实现 JNI 协…...
2024年赣州旅游投资集团社会招聘笔试真
2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

全球首个30米分辨率湿地数据集(2000—2022)
数据简介 今天我们分享的数据是全球30米分辨率湿地数据集,包含8种湿地亚类,该数据以0.5X0.5的瓦片存储,我们整理了所有属于中国的瓦片名称与其对应省份,方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

苍穹外卖--缓存菜品
1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...
Unit 1 深度强化学习简介
Deep RL Course ——Unit 1 Introduction 从理论和实践层面深入学习深度强化学习。学会使用知名的深度强化学习库,例如 Stable Baselines3、RL Baselines3 Zoo、Sample Factory 和 CleanRL。在独特的环境中训练智能体,比如 SnowballFight、Huggy the Do…...
css3笔记 (1) 自用
outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size:0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格ÿ…...
大学生职业发展与就业创业指导教学评价
这里是引用 作为软工2203/2204班的学生,我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要,而您认真负责的教学态度,让课程的每一部分都充满了实用价值。 尤其让我…...
Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析
Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析 一、第一轮提问(基础概念问题) 1. 请解释Spring框架的核心容器是什么?它在Spring中起到什么作用? Spring框架的核心容器是IoC容器&#…...