当前位置: 首页 > news >正文

VQ-VAE(Neural Discrete Representation Learning)论文解读及实现

pytorch 实现git地址
论文地址:Neural Discrete Representation Learning

1 论文核心知识点

  • encoder
    将图片通过encoder得到图片点表征
    如输入shape [32,3,32,32]
    通过encoder后输出 [32,64,8,8] (其中64位输出维度)

  • 量化码本
    先随机构建一个码本,维度与encoder保持一致
    这里定义512个离散特征,码本shape 为[512,64]

  • encoder 码本中向量最近查找
    encoder输出shape [32,64,8,8], 经过维度变换 shape [32 * 8 * 8,64]
    在码本中找到最相近的向量,并替换为码本中相似向量
    输出shape [3288,64],维度变换后,shape 为 [32,64,8,8]

  • decoder
    将上述数据,喂给decoder,还原原始图片

  • loss
    loss 包含两部分
    a . encoder输出和码本向量接近
    b. 重构loss,重构图片与原图片接近

在这里插入图片描述

2 论文实现

2.1 encoder

encoder是常用的图片卷积神经网络
输入x shape [32,3,32,32]
输出 shape [32,128,8,8]

def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Encoder, self).__init__()kernel = 4stride = 2self.conv_stack = nn.Sequential(nn.Conv2d(in_dim, h_dim // 2, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim // 2, h_dim, kernel_size=kernel,stride=stride, padding=1),nn.ReLU(),nn.Conv2d(h_dim, h_dim, kernel_size=kernel-1,stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers))def forward(self, x):return self.conv_stack(x)

2.2 VectorQuantizer 向量量化层

  • 输入:
    为encoder的输出z,shape : [32,64,8,8]
  • 码本维度:
    encoder维度变换为[2024,64],和码本embeddign shape [512,64]计算相似度
  • 相似计算:使用 ( x − y ) 2 = x 2 + y 2 − 2 x y (x-y)^2=x^2+y^2-2xy (xy)2=x2+y22xy计算和码本的相似度
  • z_q生成
    然后取码本中最相似的向量替换encoder中的向量
  • z_1维度:
    得到z_q shape [2024,64],经维度变换 shape [32,64,8,8] ,维度与输入z一致
  • 损失函数:
    使 z_q和z接近,构建损失函数
    在这里插入图片描述

decoder 层

decoder层比较简单,与encoder层相反
输入x shape 【32,64,8,8】
输出shape [32,3,32,32]

class Decoder(nn.Module):"""This is the p_phi (x|z) network. Given a latent sample z p_phi maps back to the original space z -> x.Inputs:- in_dim : the input dimension- h_dim : the hidden layer dimension- res_h_dim : the hidden dimension of the residual block- n_res_layers : number of layers to stack"""def __init__(self, in_dim, h_dim, n_res_layers, res_h_dim):super(Decoder, self).__init__()kernel = 4stride = 2self.inverse_conv_stack = nn.Sequential(nn.ConvTranspose2d(in_dim, h_dim, kernel_size=kernel-1, stride=stride-1, padding=1),ResidualStack(h_dim, h_dim, res_h_dim, n_res_layers),nn.ConvTranspose2d(h_dim, h_dim // 2,kernel_size=kernel, stride=stride, padding=1),nn.ReLU(),nn.ConvTranspose2d(h_dim//2, 3, kernel_size=kernel,stride=stride, padding=1))def forward(self, x):return self.inverse_conv_stack(x)

2.3 损失函数

损失函数为重构损失和embedding损失之和

  • decoder 输出为图片重构x_hat
  • embedding损失,为encoder和码本的embedding近似损失
  • 重点:(decoder计算损失时,由于中间有取最小值,导致梯度不连续,因此decoder loss 不能直接对encocer推荐进行求导,采用了复制梯度的方式: z_q = z + (z_q - z).detach(),及
    for i in range(args.n_updates):(x, _) = next(iter(training_loader))x = x.to(device)optimizer.zero_grad()embedding_loss, x_hat, perplexity = model(x)recon_loss = torch.mean((x_hat - x)**2) / x_train_varloss = recon_loss + embedding_lossloss.backward()optimizer.step()

相关文章:

VQ-VAE(Neural Discrete Representation Learning)论文解读及实现

pytorch 实现git地址 论文地址:Neural Discrete Representation Learning 1 论文核心知识点 encoder 将图片通过encoder得到图片点表征 如输入shape [32,3,32,32] 通过encoder后输出 [32,64,8,8] (其中64位输出维度) 量化码本 先随机构建一个码本,维度…...

OpenAI的ChatGPT:引领人工智能交流的未来

如果您在使用ChatGPT工具的过程中感到迷茫,别担心,我在这里提供帮助。无论您是初次接触ChatGPT plus,还是在注册、操作过程中遇到难题,我都将为您提供一对一的指导和支持。(qq:1371410959) 一、ChatGPT简介 OpenAI的ChatGPT是一…...

es集群安装及优化

es主节点 192.168.23.100 es节点 192.168.23.101 192.168.23.102 1.安装主节点 1.去官网下载es的yum包 官网下载地址 https://www.elastic.co/cn/downloads/elasticsearch 根据自己的需要下载对应的包 2.下载好之后把所有的包都传到从节点上,安装 [rootlocalho…...

【开源】基于JAVA+Vue+SpringBoot的医院门诊预约挂号系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 功能性需求2.1.1 数据中心模块2.1.2 科室医生档案模块2.1.3 预约挂号模块2.1.4 医院时政模块 2.2 可行性分析2.2.1 可靠性2.2.2 易用性2.2.3 维护性 三、数据库设计3.1 用户表3.2 科室档案表3.3 医生档案表3.4 医生放号…...

Java Swing 图书借阅系统 窗体项目 期末课程设计 窗体设计

视频教程: 【课程设计】图书借阅系统 功能描述: 图书管理系统有三个角色,系统管理员、图书管理员、借阅者; 系统管理员可以添加借阅用户; ​图书管理员可以添加图书,操作图书借阅和归还; 借…...

2024.01.09.Apple_UI_BUG

我是软件行业的,虽然不是手机设计的,但是这个设计真的导致经常看信息不完整,要下拉的。 特别读取文本或者其他文件的时候,上面有个抬头就是看不到,烦,体验感很差...

K8S Nginx Ingress Controller client_max_body_size 上传文件大小限制

现象 k8s集群中,上传图片时,大于1M就会报错 413 Request Entity Too Large Nginx Ingress Controller 的版本是 0.29.0 解决方案 1. 修改configmap kubectl edit configmap nginx-configuration -n ingress-nginx在 ConfigMap 的 data 字段中设置参数…...

Untiy HTC Vive VRTK 开发记录

目录 一.概述 二.功能实现 1.模型抓取 1)基础抓取脚本 2)抓取物体在手柄上的角度 2.模型放置区域高亮并吸附 1)VRTK_SnapDropZone 2)VRTK_PolicyList 3)VRTK_SnapDropZone_UnityEvents 3.交互滑动条 4.交互旋…...

机器学习指南:如何学习机器学习?

机器学习 一、介绍 你有没有想过计算机是如何从数据中学习和变得更聪明的?这就是机器学习 (ML) 的魔力!这就像计算机科学和统计学的酷炫组合,计算机从大量信息中学习以解决问题并做出预测,就像人类一样。 …...

使用numpy处理图片——分离通道

大纲 读入图片分离通道堆叠法复制修改法 生成图片 在《使用numpy处理图片——滤镜》中,我们剥离了RGB中的一个颜色,达到一种滤镜的效果。 如果我们只保留一种元素,就可以做到PS中分离通道的效果。 读入图片 import numpy as np import PIL.…...

metartc5_jz源码阅读-yang_rtcpush_on_rtcp_ps_feedback

// (Payload-specific FB messages,有效载荷反馈信息),这个函数处理Payload重传 int32_t yang_rtcpush_on_rtcp_ps_feedback(YangRtcContext *context,YangRtcPushStream *pub, YangRtcpCommon *rtcp) {if (context NULL || pub NULL)return ERROR_RTC…...

计算机毕业设计 | SpringBoot+vue的家庭理财 财务管理系统(附源码)

1,绪论 1.1 项目背景 网络的发展已经过去了七十多年,网络技术的发展,将会影响到人类的方方面面,网络的出现让各行各业都得到了极大的发展,为整个社会带来了巨大的生机。 现在许多的产业都与因特网息息相关&#xff…...

前端面试题集合三(js)

目录 1. 介绍 js 的基本数据类型。2. JavaScript 有几种类型的值?你能画一下他们的内存图吗?3. 什么是堆?什么是栈?它们之间有什么区别和联系?4. 内部属性 [[Class]] 是什么?5. 介绍 js 有哪些内置对象&am…...

ssm基于JAVA的酒店客房管理系统论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本酒店客房管理系统就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞大的数据信息…...

杨中科 .NETCORE ENTITY FRAMEWORK CORE-1 EFCORE 第一部分

一 、什么是EF Core 什么是ORM 1、说明: 本课程需要你有数据库、SOL等基础知识。 2、ORM: ObjectRelational Mapping。让开发者用对象操作的形式操作关系数据库 比如插入: User user new User(Name"admin"Password"123”; orm.Save(user);比如查询: Book b…...

微信小程序 全局配置||微信小程序 页面配置||微信小程序 sitemap配置

全局配置 小程序根目录下的 app.json 文件用来对微信小程序进行全局配置,决定页面文件的路径、窗口表现、设置网络超时时间、设置多 tab 等。 以下是一个包含了部分常用配置选项的 app.json : {"pages": ["pages/index/index",&q…...

使用ffmpeg对视频进行静音检测

1 原始视频信息 通过ffmpeg -i命令查看视频基本信息 ffmpeg version 6.1-essentials_build-www.gyan.dev Copyright (c) 2000-2023 the FFmpeg developersbuilt with gcc 12.2.0 (Rev10, Built by MSYS2 project)configuration: --enable-gpl --enable-version3 --enable-sta…...

Servlet-Request

一、预览 在上一篇Servlet体系结构中,我们初步了解了怎么快速本篇将介绍Servlet中请求Request的相关内容,包括Request的体系结构,Request常用API。 二、Request体系结构 我们注意到我们定义的Servlet类若实现Servlet接口时,请求…...

数据结构-怀化学院期末题(490)

哈希查找 题目描述: 实现哈希查找。要求根据给定的哈希函数进行存储,并查找相应元素的存储位置。本题目使用的哈希函数为除留取余法,即H(key)key%m,其中m为存储空间,冲突处理方法采用开放定址法中的线性探测再散列&am…...

Matlab字符识别实验

Matlab 字符识别OCR实验 图像来源于屏幕截图,要求黑底白字。数据来源是任意二进制文件,内容以16进制打印输出,0-9a-f’字符被16个可打印字符替代,这些替代字符经过挑选,使其相对容易被识别。 第一步进行线分割和字符…...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向: 逆向设计 通过神经网络快速预测微纳结构的光学响应,替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

2025年能源电力系统与流体力学国际会议 (EPSFD 2025)

2025年能源电力系统与流体力学国际会议(EPSFD 2025)将于本年度在美丽的杭州盛大召开。作为全球能源、电力系统以及流体力学领域的顶级盛会,EPSFD 2025旨在为来自世界各地的科学家、工程师和研究人员提供一个展示最新研究成果、分享实践经验及…...

CMake 从 GitHub 下载第三方库并使用

有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

安宝特方案丨船舶智造的“AR+AI+作业标准化管理解决方案”(装配)

船舶制造装配管理现状:装配工作依赖人工经验,装配工人凭借长期实践积累的操作技巧完成零部件组装。企业通常制定了装配作业指导书,但在实际执行中,工人对指导书的理解和遵循程度参差不齐。 船舶装配过程中的挑战与需求 挑战 (1…...

【Go语言基础【13】】函数、闭包、方法

文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数(函数作为参数、返回值) 三、匿名函数与闭包1. 匿名函数(Lambda函…...

GitFlow 工作模式(详解)

今天再学项目的过程中遇到使用gitflow模式管理代码,因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存,无论是github还是gittee,都是一种基于git去保存代码的形式,这样保存代码…...

比较数据迁移后MySQL数据库和OceanBase数据仓库中的表

设计一个MySQL数据库和OceanBase数据仓库的表数据比较的详细程序流程,两张表是相同的结构,都有整型主键id字段,需要每次从数据库分批取得2000条数据,用于比较,比较操作的同时可以再取2000条数据,等上一次比较完成之后,开始比较,直到比较完所有的数据。比较操作需要比较…...

uniapp 小程序 学习(一)

利用Hbuilder 创建项目 运行到内置浏览器看效果 下载微信小程序 安装到Hbuilder 下载地址 :开发者工具默认安装 设置服务端口号 在Hbuilder中设置微信小程序 配置 找到运行设置,将微信开发者工具放入到Hbuilder中, 打开后出现 如下 bug 解…...

go 里面的指针

指针 在 Go 中,指针(pointer)是一个变量的内存地址,就像 C 语言那样: a : 10 p : &a // p 是一个指向 a 的指针 fmt.Println(*p) // 输出 10,通过指针解引用• &a 表示获取变量 a 的地址 p 表示…...