当前位置: 首页 > news >正文

VAE-变分自编码器(Variational Autoencoder,VAE)

变分自编码器(Variational Autoencoder,VAE)是一种生成模型,结合了概率图模型与神经网络技术,广泛应用于数据生成、表示学习和数据压缩等领域。以下是对VAE的详细解释和理解:

基本概念

1. 自编码器(Autoencoder)

自编码器是一种无监督学习模型,通常用于降维和特征提取。它由两个主要部分组成:

  • 编码器(Encoder):将输入数据映射到一个低维隐变量空间。
  • 解码器(Decoder):从低维隐变量空间重建输入数据。
    自编码器的目标是使重建的数据尽可能与原始输入数据相似。

2. 变分自编码器(VAE)

VAE 是自编码器的一种扩展,它通过引入概率分布的概念来对隐变量空间进行建模。VAE 的目标不仅是重建输入数据,还要使隐变量遵循某种已知的概率分布(通常是标准正态分布)。这样可以通过采样隐变量来生成新数据。

VAE的工作原理

  1. 编码器
    在VAE中,编码器不是直接输出一个隐变量,而是输出隐变量的参数(均值 μ 和标准差 σ)。这些参数定义了隐变量的一个概率分布,通常假设为正态分布 N(μ, σ^2)。

  2. 重新参数化技巧(Reparameterization Trick)
    为了使模型能够通过梯度下降进行训练,VAE引入了重新参数化技巧。通过采样一个标准正态分布的变量 ε ~ N(0, 1),然后进行线性变换得到隐变量 z:
    在这里插入图片描述

这样,采样操作变成了一个确定性的操作,允许梯度反向传播。

  1. 解码器
    解码器接受从上述分布中采样的隐变量 z,并尝试重建输入数据。解码器的目标是最大化重建数据的概率。

损失函数

VAE 的损失函数由两部分组成:

  • 重构损失(Reconstruction Loss):衡量重建数据与原始数据的相似度,通常使用均方误差(MSE)或交叉熵损失。 KL

  • 散度(KL Divergence):衡量隐变量分布与标准正态分布的差异。通过最小化KL散度,使隐变量分布接近标准正态分布。

综合起来,VAE的损失函数为:

在这里插入图片描述

VAE的优点

  1. 生成能力:可以从隐变量空间采样生成新数据,具有良好的生成能力。
  2. 隐变量解释性:通过将隐变量空间约束为标准正态分布,隐变量具有一定的解释性和可操作性。
  3. 无监督学习:VAE是一种无监督学习模型,不需要标签数据即可进行训练。

VAE的缺点

  1. **生成质量有限:**生成数据的质量有时不如GAN(生成对抗网络)等其他生成模型。
  2. **训练复杂:**VAE的训练涉及到复杂的概率推断和优化过程。

总结

变分自编码器通过引入概率分布和重新参数化技巧,使得隐变量具有良好的生成能力和解释性。其核心思想是在保持重建数据质量的同时,使隐变量遵循标准正态分布,从而实现数据生成和表示学习。尽管存在一些缺点,但VAE在许多应用场景中仍然表现出色,并为生成模型的研究提供了重要的理论基础。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable# 定义VAE模型
class VAE(nn.Module):def __init__(self, input_dim, hidden_dim, latent_dim):super(VAE, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc21 = nn.Linear(hidden_dim, latent_dim)self.fc22 = nn.Linear(hidden_dim, latent_dim)self.fc3 = nn.Linear(latent_dim, hidden_dim)self.fc4 = nn.Linear(hidden_dim, input_dim)def encode(self, x):h1 = F.relu(self.fc1(x))return self.fc21(h1), self.fc22(h1)def reparameterize(self, mu, logvar):std = torch.exp(0.5*logvar)eps = torch.randn_like(std)return mu + eps*stddef decode(self, z):h3 = F.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3))def forward(self, x):mu, logvar = self.encode(x.view(-1, 784))z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar# 定义损失函数
def loss_function(recon_x, x, mu, logvar):BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())return BCE + KLD# 加载MNIST数据集
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.ToTensor()),batch_size=128, shuffle=True)# 初始化模型
vae = VAE(input_dim=784, hidden_dim=512, latent_dim=20)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)# 训练模型
def train(epoch):vae.train()train_loss = 0for batch_idx, (data, _) in enumerate(train_loader):optimizer.zero_grad()recon_batch, mu, logvar = vae(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader),loss.item() / len(data)))print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))# 开始训练
for epoch in range(1, 11):train(epoch)

代码说明

  • 编码器和解码器:编码器将输入图像编码为潜在空间的均值和对数方差,解码器从潜在变量生成重建的图像。
  • Sampling层:这是实现重参数化技巧的关键部分,将均值和对数方差转换为潜在变量。
  • VAE类:组合编码器和解码器,并实现自定义训练步骤,包括计算重建损失和KL散度损失。
  • 数据准备和训练:加载MNIST数据集,对数据进行预处理,然后训练VAE模型。
    这个示例展示了一个简单的VAE模型。根据具体的应用需求,你可能需要调整网络结构和超参数。

相关文章:

VAE-变分自编码器(Variational Autoencoder,VAE)

变分自编码器(Variational Autoencoder,VAE)是一种生成模型,结合了概率图模型与神经网络技术,广泛应用于数据生成、表示学习和数据压缩等领域。以下是对VAE的详细解释和理解: 基本概念 1. 自编码器&#…...

Android Room 使用模版

文章目录 一、配置依赖 plugins {id kotlin-kapt }android {compileOptions {sourceCompatibility JavaVersion.VERSION_17targetCompatibility JavaVersion.VERSION_17}kotlinOptions {jvmTarget 17} }dependencies {implementation("androidx.room:room-runtime:2.4.2&…...

Linux/Ubuntu 中安装 ZeroTier,实现内网穿透,2分钟搞定

相信很多人都有远程连接家中设备的需求,如远程连接家中的NAS、Windows等服务,所以会涉及到一个内网穿透工具的使用,如果没有公网IP的情况下,推荐大家使用ZeroTier,这是一款强大的内网穿透工具。 mac和windows版的操作…...

java技术:oauth2协议

目录 一、黑马程序员Java进阶教程快速入门Spring Security OAuth2.0认证授权详解 1、oauth服务 WebSecurityConfig TokenConfig AuthorizationServer 改写密码校验逻辑实现类 2、oauth2支持的四种方式: 3、oauth2授权 ResouceServerConfig TokenConfig 4、…...

Java 18 新特性详解

Java 18 新特性详解 Java 18 作为 Oracle 推出的又一重要版本,继续秉持着 Java 平台“创新但不破坏”的原则,带来了多项旨在提升开发效率、性能和安全性的新特性。本篇文章将深入解析 Java 18 引入的主要特性,并探讨它们如何影响开发者的工作…...

【css3】06-css3新特性之网页布局篇

目录 伸缩布局或者弹性布局【响应式布局】 1 设置父元素为伸缩盒子 2 设置伸缩盒子主轴方向 3 设置元素在主轴的对齐方式 4 设置元素在侧轴的对齐方式 5 设置元素是否换行显示 6 设置元素换行后的对齐方式 7 效果测试原码 伸缩布局或者弹性布局【响应式布局】 1 设置父元…...

【开源】大学生竞赛管理系统 JAVA+Vue+SpringBoot+MySQL

目录 一、系统介绍 学生管理模块 教师管理模块 竞赛信息模块 竞赛报名模块 二、系统截图 三、核心代码 一、系统介绍 基于Vue.js和SpringBoot的大学生竞赛管理系统,分为管理后台和用户网页端,可以给管理员、学生和教师角色使用,包括学…...

跨境选品师不是神话:普通人也能轻松掌握,开启全球贸易新篇章!

随着互联网技术的飞速发展,跨境电商行业已成为全球经济的新增长点。在这个背景下,一个新兴的职业——跨境选品师,逐渐走进了人们的视野。那么,跨境选品师究竟是做什么的?普通人又该如何成为优秀的跨境选品师呢? 一、跨境选品师的…...

前缀和,差分算法理解

前缀和是什么: 前缀和指一个数组的某下标之前的所有数组元素的和(包含其自身)。前缀和分为一维前缀和,以及二维前缀和。前缀和是一种重要的预处理,能够降低算法的时间复杂度 说个人话就是比如有一个数组: …...

ubuntu/部分docker容器无法访问https站点

ubuntu/部分docker容器无法访问https站点 解决方案 解决方案 默认的系统内可能没有安装根证书,需要安装一下 apt install ca-certificates如果官方源比较慢,可以换为国内源,但是不要使用https...

【MySQL】库的基础操作

🌎库的操作 文章目录: 库的操作 创建删除数据库 数据库编码集和校验集 数据库的增删查改       数据库查找       数据库修改 备份和恢复 查看数据库连接情况 总结 前言:   数据库操作是软件开发中不可或缺的一部分&#xff0…...

嵌入式0基础开始学习 ⅠC语言(2)运算符与表达式

1.运算符 什么是运算符? 用来进来某种运算的符号 如: - * / (取余,取模) a,几目运算符 根据其操作数的不同 单目运算符 该运算符…...

汇编语言(一)

寄存器:cpu中可以储存数据的器件(AX,BX) 汇编语言的组成:1.汇编指令 2.伪指令 3.其他符号 存储器:cpu,传入指令和数据,加以运算。(内存) 指令和数据&#…...

2010-2022年各省新质生产力数据(含原始数据+测算代码+计算结果)

2010-2022年各省新质生产力数据(含原始数据测算代码计算结果) 1、时间:2010-2022年 2、范围:31省 3、指标:gdp(亿元)、在岗职工工资:元、第三产业就业比重、人均受教育平均年限、…...

需求分析部分图形工具

描述复杂的事物时,图形远比文字叙述优越得多,它形象直观容易理解。前面已经介绍了用于建立功能模型的数据流图、用于建立数据模型的实体-联系图和用于建立行为模型的状态图,本节再简要地介绍在需求分析阶段可能用到的另外3种图形工具。 1 层次方框图 层次方框图用树形结…...

ML307R OpenCPU GPIO使用

一、GPIO使用流程图 二、函数介绍 三、GPIO 点亮LED 四、代码下载地址 一、GPIO使用流程图 这个图是官网找到的,ML307R GPIO引脚电平默认为1.8V,需注意和外部电路的电平匹配,具体可参考《ML307R_硬件设计手册_OpenCPU版本适用.pdf》中的描…...

python基于深度学习的聊天机器人设计

python基于深度学习的聊天机器人设计 开发语言:Python 数据库:MySQL所用到的知识:Django框架工具:pycharm、Navicat、Maven 系统功能实现 登录注册功能 用户在没有登录自己的用户名之前只能浏览本网站的首页,想要使用其他功能都…...

Golang设计模式(四):观察者模式

观察者模式 什么是观察者 观察者模式(Observer Pattern):定义对象之间的一种一对多依赖关系,使得每当一个对象状态发生改变时,其相关依赖对象皆得到通知并被自动更新。观察者模式的别名包括发布-订阅(Publish/Subscribe&#xf…...

huggingface 笔记:查看GPU占用情况

0 准备部分 0.1 创建虚拟数据 import numpy as npfrom datasets import Datasetseq_len, dataset_size 512, 512 dummy_data {"input_ids": np.random.randint(100, 30000, (dataset_size, seq_len)),"labels": np.random.randint(0, 1, (dataset_size…...

JavaSE 学习记录

1. Java 内存 2. this VS super this和super是两个关键字,用于引用当前对象和其父类对象 this 关键字: this 关键字用于引用当前对象,即调用该关键字的方法所属的对象。 主要用途包括: 在类的实例方法中,通过 this …...

设计方案:核心框架搭建与落地实操全指南

当前很多团队在输出设计方案时容易陷入两个极端:要么过度追求创意忽略落地可行性,导致方案最终停留在概念阶段无法产生实际价值;要么完全照搬模板缺乏针对性,无法匹配业务的个性化需求。尤其是电商、新媒体、企业服务等领域的设计…...

Win11Debloat:轻量高效的Windows系统优化开源工具

Win11Debloat:轻量高效的Windows系统优化开源工具 【免费下载链接】Win11Debloat A simple, lightweight PowerShell script that allows you to remove pre-installed apps, disable telemetry, as well as perform various other changes to declutter and custom…...

英雄联盟智能助手ChampR:快速提升游戏水平的终极指南

英雄联盟智能助手ChampR:快速提升游戏水平的终极指南 【免费下载链接】champr 🐶 Yet another League of Legends helper 项目地址: https://gitcode.com/gh_mirrors/ch/champr 你是否在英雄联盟游戏中苦苦寻找最佳的出装和符文配置?C…...

FLUX.1-schnell:如何彻底改变文本到图像生成的技术范式

FLUX.1-schnell:如何彻底改变文本到图像生成的技术范式 【免费下载链接】FLUX.1-schnell 项目地址: https://ai.gitcode.com/hf_mirrors/black-forest-labs/FLUX.1-schnell 在当今人工智能图像生成领域,高质量图像创作一直面临着效率与质量难以兼…...

Human3.6M数据集获取与预处理实战指南:从百度网盘到可用的.pkl文件

1. Human3.6M数据集简介与下载准备 Human3.6M是目前人体姿态估计领域最权威的基准数据集之一,包含11名专业演员在17种日常活动场景下的360万帧动作捕捉数据。我第一次接触这个数据集时,面对近50GB的原始文件和复杂的目录结构也一头雾水。这里分享从下载到…...

StructBERT中文语义匹配系统安全审计:本地化部署带来的合规优势

StructBERT中文语义匹配系统安全审计:本地化部署带来的合规优势 1. 项目概述 StructBERT中文语义智能匹配系统是一个基于先进孪生网络模型的本地化部署解决方案。该系统专门针对中文文本处理需求,提供高精度的语义相似度计算和特征提取能力。 与传统方…...

1222万人同台竞技——这套AI工具组合,正在帮更多毕业生把简历捞率翻倍

2026届高校毕业生规模预计达1222万人,创历史新高。在这个数字背后,是更多人在同一个时间窗口、竞争有限的岗位机会。如何在同等条件下,让自己的求职路走得更快、更准、更稳,是2026春招最核心的命题。 这篇文章,我们想…...

zk(zookeeper)的选举机制

zk中有两种角色:Leader 和 Fllower,Leader是集群各台电脑投票选举出来的。事务【非常重要】:一通操作,要么同时成立,要么都不成立。LeaderZookeeper 集群工作的核心。1.事务请求(写操作)的唯一调…...

【无标题】1

把AI领域最前沿的技术,最硬核的实现,最有趣的见闻,变成能看懂,能上手,能直接用的内容。把不可能变成产品落地。...

MVT协议深度解析:从Protobuf编码到GISBox实战,看它如何碾压传统栅格瓦片

MVT协议技术内幕:从二进制编码到百万级数据渲染实战 当我们打开手机地图App,双指放大查看小区楼栋轮廓时,很少有人会思考这流畅体验背后的技术革命。传统栅格瓦片就像打印在纸上的地图,放大后必然出现马赛克;而MVT协议…...