VQ-VAE(Neural Discrete Representation Learning)论文解读及实现
pytorch 实现git地址
论文地址:Neural Discrete Representation Learning
1 论文核心知识点
-
encoder
将图片通过encoder得到图片点表征
如输入shape [32,3,32,32]
通过encoder后输出 [32,64,8,8] (其中64位输出维度) -
量化码本
先随机构建一个码本,维度与encoder保持一致
这里定义512个离散特征,码本shape 为[512,64] -
encoder 码本中向量最近查找
encoder输出shape [32,64,8,8], 经过维度变换 shape [32 * 8 * 8,64]
在码本中找到最相近的向量,并替换为码本中相似向量
输出shape [3288,64],维度变换后,shape 为 [32,64,8,8] -
decoder
将上述数据,喂给decoder,还原原始图片 -
loss
loss 包含两部分
a . encoder输出和码本向量接近
b. 重构loss,重构图片与原图片接近
2 论文实现
2.1 encoder
encoder是常用的图片卷积神经网络
输入x shape [32,3,32,32]
输出 shape [32,128,8,8]
def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Encoder, self).__init__()kernel = 4stride = 2self.conv_stack = nn.Sequential(nn.Conv2d(in_dim, h_dim // 2, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim // 2, h_dim, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim, h_dim, kernel_size=kernel-1,stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers))def forward(self, x):return self.conv_stack(x)
2.2 VectorQuantizer 向量量化层
- 输入:
为encoder的输出z,shape : [32,64,8,8] - 码本维度:
encoder维度变换为[2024,64],和码本embeddign shape [512,64]计算相似度 - 相似计算:使用 ( x − y ) 2 = x 2 + y 2 − 2 x y (x-y)^2=x^2+y^2-2xy (x−y)2=x2+y2−2xy计算和码本的相似度
- z_q生成
然后取码本中最相似的向量替换encoder中的向量 - z_1维度:
得到z_q shape [2024,64],经维度变换 shape [32,64,8,8] ,维度与输入z一致 - 损失函数:
使 z_q和z接近,构建损失函数
decoder 层
decoder层比较简单,与encoder层相反
输入x shape 【32,64,8,8】
输出shape [32,3,32,32]
class Decoder(nn.Module):"""This is the p_phi (x|z) network. Given a latent sample z p_phi maps back to the original space z -> x.Inputs:- in_dim : the input dimension- h_dim : the hidden layer dimension- res_h_dim : the hidden dimension of the residual block- n_res_layers : number of layers to stack"""def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Decoder, self).__init__()kernel = 4stride = 2self.inverse_conv_stack = nn.Sequential(nn.ConvTranspose2d(in_dim, h_dim, kernel_size=kernel-1, stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers),nn.ConvTranspose2d(h_dim, h_dim // 2,kernel_size=kernel, stride=stride, padding=1),nn.ReLU(),nn.ConvTranspose2d(h_dim//2, 3, kernel_size=kernel,stride=stride, padding=1))def forward(self, x):return self.inverse_conv_stack(x)
2.3 损失函数
损失函数为重构损失和embedding损失之和
- decoder 输出为图片重构x_hat
- embedding损失,为encoder和码本的embedding近似损失
- 重点:(decoder计算损失时,由于中间有取最小值,导致梯度不连续,因此decoder loss 不能直接对encocer推荐进行求导,采用了复制梯度的方式: z_q = z + (z_q - z).detach(),及
for i in range(args.n_updates):(x, _) = next(iter(training_loader))x = x.to(device)optimizer.zero_grad()embedding_loss, x_hat, perplexity = model(x)recon_loss = torch.mean((x_hat - x)**2) / x_train_varloss = recon_loss + embedding_lossloss.backward()optimizer.step()
相关文章:

VQ-VAE(Neural Discrete Representation Learning)论文解读及实现
pytorch 实现git地址 论文地址:Neural Discrete Representation Learning 1 论文核心知识点 encoder 将图片通过encoder得到图片点表征 如输入shape [32,3,32,32] 通过encoder后输出 [32,64,8,8] (其中64位输出维度) 量化码本 先随机构建一个码本,维度…...
OpenAI的ChatGPT:引领人工智能交流的未来
如果您在使用ChatGPT工具的过程中感到迷茫,别担心,我在这里提供帮助。无论您是初次接触ChatGPT plus,还是在注册、操作过程中遇到难题,我都将为您提供一对一的指导和支持。(qq:1371410959) 一、ChatGPT简介 OpenAI的ChatGPT是一…...

es集群安装及优化
es主节点 192.168.23.100 es节点 192.168.23.101 192.168.23.102 1.安装主节点 1.去官网下载es的yum包 官网下载地址 https://www.elastic.co/cn/downloads/elasticsearch 根据自己的需要下载对应的包 2.下载好之后把所有的包都传到从节点上,安装 [rootlocalho…...

【开源】基于JAVA+Vue+SpringBoot的医院门诊预约挂号系统
目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 功能性需求2.1.1 数据中心模块2.1.2 科室医生档案模块2.1.3 预约挂号模块2.1.4 医院时政模块 2.2 可行性分析2.2.1 可靠性2.2.2 易用性2.2.3 维护性 三、数据库设计3.1 用户表3.2 科室档案表3.3 医生档案表3.4 医生放号…...

Java Swing 图书借阅系统 窗体项目 期末课程设计 窗体设计
视频教程: 【课程设计】图书借阅系统 功能描述: 图书管理系统有三个角色,系统管理员、图书管理员、借阅者; 系统管理员可以添加借阅用户; 图书管理员可以添加图书,操作图书借阅和归还; 借…...

2024.01.09.Apple_UI_BUG
我是软件行业的,虽然不是手机设计的,但是这个设计真的导致经常看信息不完整,要下拉的。 特别读取文本或者其他文件的时候,上面有个抬头就是看不到,烦,体验感很差...
K8S Nginx Ingress Controller client_max_body_size 上传文件大小限制
现象 k8s集群中,上传图片时,大于1M就会报错 413 Request Entity Too Large Nginx Ingress Controller 的版本是 0.29.0 解决方案 1. 修改configmap kubectl edit configmap nginx-configuration -n ingress-nginx在 ConfigMap 的 data 字段中设置参数…...

Untiy HTC Vive VRTK 开发记录
目录 一.概述 二.功能实现 1.模型抓取 1)基础抓取脚本 2)抓取物体在手柄上的角度 2.模型放置区域高亮并吸附 1)VRTK_SnapDropZone 2)VRTK_PolicyList 3)VRTK_SnapDropZone_UnityEvents 3.交互滑动条 4.交互旋…...

机器学习指南:如何学习机器学习?
机器学习 一、介绍 你有没有想过计算机是如何从数据中学习和变得更聪明的?这就是机器学习 (ML) 的魔力!这就像计算机科学和统计学的酷炫组合,计算机从大量信息中学习以解决问题并做出预测,就像人类一样。 …...

使用numpy处理图片——分离通道
大纲 读入图片分离通道堆叠法复制修改法 生成图片 在《使用numpy处理图片——滤镜》中,我们剥离了RGB中的一个颜色,达到一种滤镜的效果。 如果我们只保留一种元素,就可以做到PS中分离通道的效果。 读入图片 import numpy as np import PIL.…...
metartc5_jz源码阅读-yang_rtcpush_on_rtcp_ps_feedback
// (Payload-specific FB messages,有效载荷反馈信息),这个函数处理Payload重传 int32_t yang_rtcpush_on_rtcp_ps_feedback(YangRtcContext *context,YangRtcPushStream *pub, YangRtcpCommon *rtcp) {if (context NULL || pub NULL)return ERROR_RTC…...

计算机毕业设计 | SpringBoot+vue的家庭理财 财务管理系统(附源码)
1,绪论 1.1 项目背景 网络的发展已经过去了七十多年,网络技术的发展,将会影响到人类的方方面面,网络的出现让各行各业都得到了极大的发展,为整个社会带来了巨大的生机。 现在许多的产业都与因特网息息相关ÿ…...
前端面试题集合三(js)
目录 1. 介绍 js 的基本数据类型。2. JavaScript 有几种类型的值?你能画一下他们的内存图吗?3. 什么是堆?什么是栈?它们之间有什么区别和联系?4. 内部属性 [[Class]] 是什么?5. 介绍 js 有哪些内置对象&am…...

ssm基于JAVA的酒店客房管理系统论文
摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本酒店客房管理系统就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞大的数据信息…...

杨中科 .NETCORE ENTITY FRAMEWORK CORE-1 EFCORE 第一部分
一 、什么是EF Core 什么是ORM 1、说明: 本课程需要你有数据库、SOL等基础知识。 2、ORM: ObjectRelational Mapping。让开发者用对象操作的形式操作关系数据库 比如插入: User user new User(Name"admin"Password"123”; orm.Save(user);比如查询: Book b…...

微信小程序 全局配置||微信小程序 页面配置||微信小程序 sitemap配置
全局配置 小程序根目录下的 app.json 文件用来对微信小程序进行全局配置,决定页面文件的路径、窗口表现、设置网络超时时间、设置多 tab 等。 以下是一个包含了部分常用配置选项的 app.json : {"pages": ["pages/index/index",&q…...

使用ffmpeg对视频进行静音检测
1 原始视频信息 通过ffmpeg -i命令查看视频基本信息 ffmpeg version 6.1-essentials_build-www.gyan.dev Copyright (c) 2000-2023 the FFmpeg developersbuilt with gcc 12.2.0 (Rev10, Built by MSYS2 project)configuration: --enable-gpl --enable-version3 --enable-sta…...

Servlet-Request
一、预览 在上一篇Servlet体系结构中,我们初步了解了怎么快速本篇将介绍Servlet中请求Request的相关内容,包括Request的体系结构,Request常用API。 二、Request体系结构 我们注意到我们定义的Servlet类若实现Servlet接口时,请求…...
数据结构-怀化学院期末题(490)
哈希查找 题目描述: 实现哈希查找。要求根据给定的哈希函数进行存储,并查找相应元素的存储位置。本题目使用的哈希函数为除留取余法,即H(key)key%m,其中m为存储空间,冲突处理方法采用开放定址法中的线性探测再散列&am…...

Matlab字符识别实验
Matlab 字符识别OCR实验 图像来源于屏幕截图,要求黑底白字。数据来源是任意二进制文件,内容以16进制打印输出,0-9a-f’字符被16个可打印字符替代,这些替代字符经过挑选,使其相对容易被识别。 第一步进行线分割和字符…...

深度学习在微纳光子学中的应用
深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向: 逆向设计 通过神经网络快速预测微纳结构的光学响应,替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...
内存分配函数malloc kmalloc vmalloc
内存分配函数malloc kmalloc vmalloc malloc实现步骤: 1)请求大小调整:首先,malloc 需要调整用户请求的大小,以适应内部数据结构(例如,可能需要存储额外的元数据)。通常,这包括对齐调整,确保分配的内存地址满足特定硬件要求(如对齐到8字节或16字节边界)。 2)空闲…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)
HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

python打卡day49
知识点回顾: 通道注意力模块复习空间注意力模块CBAM的定义 作业:尝试对今天的模型检查参数数目,并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...
C++ 基础特性深度解析
目录 引言 一、命名空间(namespace) C 中的命名空间 与 C 语言的对比 二、缺省参数 C 中的缺省参数 与 C 语言的对比 三、引用(reference) C 中的引用 与 C 语言的对比 四、inline(内联函数…...
3403. 从盒子中找出字典序最大的字符串 I
3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...
JVM暂停(Stop-The-World,STW)的原因分类及对应排查方案
JVM暂停(Stop-The-World,STW)的完整原因分类及对应排查方案,结合JVM运行机制和常见故障场景整理而成: 一、GC相关暂停 1. 安全点(Safepoint)阻塞 现象:JVM暂停但无GC日志,日志显示No GCs detected。原因:JVM等待所有线程进入安全点(如…...

中医有效性探讨
文章目录 西医是如何发展到以生物化学为药理基础的现代医学?传统医学奠基期(远古 - 17 世纪)近代医学转型期(17 世纪 - 19 世纪末)现代医学成熟期(20世纪至今) 中医的源远流长和一脉相承远古至…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析
Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...
BLEU评分:机器翻译质量评估的黄金标准
BLEU评分:机器翻译质量评估的黄金标准 1. 引言 在自然语言处理(NLP)领域,衡量一个机器翻译模型的性能至关重要。BLEU (Bilingual Evaluation Understudy) 作为一种自动化评估指标,自2002年由IBM的Kishore Papineni等人提出以来,…...