当前位置: 首页 > news >正文

变分自编码器(VAE)PyTorch Lightning 实现

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • VAE 简介
      • 基本原理
      • 应用与优点
      • 缺点与挑战
    • 使用 VAE 生成 MNIST 手写数字
      • 忽略警告
      • 导入必要的库
      • 设置随机种子
      • cuDNN 设置
      • 超参数设置
      • 数据加载
      • 定义 VAE 模型
      • 定义损失函数
      • 定义 Lightning 模型
      • 训练模型
      • 绘制训练过程
      • 随机生成新样本
      • 根据潜变量插值生成新样本


VAE 简介

变分自编码器(Variational Autoencoder,VAE)是一种深度学习中的生成模型,它结合了自编码器(Autoencoder, AE)和概率建模的思想,在无监督学习环境中表现出了强大的能力。VAE 在 2013 年由 Diederik P. Kingma 和 Max Welling 首次提出,并迅速成为生成模型领域的重要组成部分。

基本原理

自编码器(AE)基础:
自编码器是一种神经网络结构,通常由两部分组成:编码器(Encoder)和解码器(Decoder)。原始数据通过编码器映射到一个低维的潜在空间(或称为隐空间),这个低维向量被称为潜变量(latent variable)。然后,潜变量再通过解码器重构回原始数据的近似版本。在训练过程中,自编码器的目标是使得输入数据经过编码-解码过程后能够尽可能地恢复原貌,从而学习到数据的有效表示。

VAE的引入与扩展:
VAE 将自编码器的概念推广到了概率框架下。在 VAE 中,潜变量不再是确定性的,而是被赋予了概率分布。具体来说,对于给定的输入数据,编码器不直接输出一个点估计值,而是输出潜变量的均值和方差(假设潜变量服从高斯分布)。这样,每个输入数据可以被视为是从某个潜在的概率分布中采样得到的。

变分推断(Variational Inference):
训练 VA E时,由于真实的后验概率分布难以直接计算,因此采用变分推断来近似后验分布。编码器实际上输出的是一个参数化的概率分布 q ( z ∣ x ) q(z|x) q(zx),即给定输入 x x x 时潜变量 z z z 的概率分布。然后通过最小化 KL 散度(Kullback-Leibler divergence)来优化这个近似分布,使其尽可能接近真实的后验分布 p ( z ∣ x ) p(z|x) p(zx)

目标函数 - Evidence Lower Bound (ELBO):
VAE 的目标函数是证据下界(ELBO),它是原始数据 log-likelihood 的下界。优化该目标函数既鼓励编码器找到数据的高效潜在表示,又促使解码器基于这些表示重建出类似原始数据的新样本。

数学表达上,ELBO 通常分解为两个部分:

  1. 重构损失(Reconstruction Loss):衡量从潜变量重构出来的数据与原始数据之间的差异。
  2. KL散度损失(KL Divergence Loss):衡量编码器产生的潜变量分布与预设的标准正态分布(或其他先验分布)之间的距离。

应用与优点

  • VAE 可以用于生成新数据,例如图像、文本、音频等。
  • 由于其对潜变量进行概率建模,所以它可以提供连续的数据生成,并且能够探索数据的不同模式。
  • 在处理连续和离散数据时具有一定的灵活性。
  • 可以用于特征学习,提取数据的有效低维表示。

缺点与挑战

  • 训练 VAE 可能需要大量的计算资源和时间。
  • 生成的样本有时可能不够清晰或细节模糊,尤其是在复杂数据集上。
  • 对于某些复杂的分布形式,VAE 可能无法完美捕获所有细节。

使用 VAE 生成 MNIST 手写数字

下面我们将使用 PyTorch Lightning 来实现一个简单的 VAE 模型,并使用 MNIST 数据集来进行训练和生成。

在线 Notebook:https://www.kaggle.com/code/marquis03/vae-mnist

忽略警告

import warnings
warnings.filterwarnings("ignore")

导入必要的库

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as snssns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasetsimport lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

设置随机种子

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

cuDNN 设置

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

超参数设置

batch_size = 64epochs = 10
KLD_weight = 1
lr = 0.001input_dim = 784  # 28 * 28
h_dim = 256  # 隐藏层维度  
z_dim = 2  # 潜变量维度

数据加载

train_dataset = datasets.MNIST(root="data", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

定义 VAE 模型

class VAE(nn.Module):def __init__(self, input_dim=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.input_dim = input_dimself.h_dim = h_dimself.z_dim = z_dim# Encoderself.fc1 = nn.Linear(input_dim, h_dim)self.fc21 = nn.Linear(h_dim, z_dim)  # muself.fc22 = nn.Linear(h_dim, z_dim)  # log_var# Decoderself.fc3 = nn.Linear(z_dim, h_dim)self.fc4 = nn.Linear(h_dim, input_dim)def encode(self, x):h = torch.relu(self.fc1(x))mean = self.fc21(h)log_var = self.fc22(h)return mean, log_vardef reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = torch.relu(self.fc3(z))out = torch.sigmoid(self.fc4(h))return outdef forward(self, x):mean, log_var = self.encode(x)z = self.reparameterize(mean, log_var)reconstructed_x = self.decode(z)return reconstructed_x, mean, log_varvae = VAE(input_dim, h_dim, z_dim)
x = torch.randn((10, input_dim))
reconstructed_x, mean, log_var = vae(x)
print(reconstructed_x.shape, mean.shape, log_var.shape)
# torch.Size([10, 784]) torch.Size([10, 2]) torch.Size([10, 2])

定义损失函数

def loss_function(x_hat, x, mu, log_var, KLD_weight=1):BCE_loss = F.binary_cross_entropy(x_hat, x, reduction="sum") # 重构损失KLD_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL 散度损失loss = BCE_loss + KLD_loss * KLD_weightreturn loss, BCE_loss, KLD_loss

定义 Lightning 模型

class LitModel(pl.LightningModule):def __init__(self, input_dim=784, h_dim=400, z_dim=20):super().__init__()self.model = VAE(input_dim, h_dim, z_dim)def forward(self, x):x = self.model(x)return xdef configure_optimizers(self):optimizer = optim.Adam(self.parameters(), lr=lr, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5)return optimizerdef training_step(self, batch, batch_idx):x, y = batchx = x.view(x.size(0), -1)reconstructed_x, mean, log_var = self(x)loss, BCE_loss, KLD_loss = loss_function(reconstructed_x, x, mean, log_var, KLD_weight=KLD_weight)self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)self.log_dict({"BCE_loss": BCE_loss,"KLD_loss": KLD_loss,},on_step=False,on_epoch=True,logger=True,)return lossdef decode(self, z):out = self.model.decode(z)return out

训练模型

model = LitModel(input_dim, h_dim, z_dim)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(monitor="loss", min_delta=0.00, patience=5, verbose=False, mode="min")
trainer = pl.Trainer(max_epochs=epochs,enable_progress_bar=True,logger=logger,callbacks=[early_stop_callback],
)
trainer.fit(model, train_loader)

绘制训练过程

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="loss", data=metrics, label="Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="BCE_loss", data=metrics, label="BCE Loss", linewidth=2, marker="^", markersize=12)
sns.lineplot(x=x_name, y="KLD_loss", data=metrics, label="KLD Loss", linewidth=2, marker="s", markersize=10)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()

训练过程

随机生成新样本

row, col = 4, 18
z = torch.randn(row * col, z_dim)
random_res = model.model.decode(z).view(-1, 1, 28, 28).detach().numpy()plt.figure(figsize=(col, row))
for i in range(row * col):plt.subplot(row, col, i + 1)plt.imshow(random_res[i].squeeze(), cmap="gray")plt.xticks([])plt.yticks([])plt.axis("off")
plt.show()

随机生成新样本

根据潜变量插值生成新样本

from scipy.stats import normn = 15
digit_size = 28grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))figure = np.zeros((digit_size * n, digit_size * n))
for i, yi in enumerate(grid_y):for j, xi in enumerate(grid_x):t = [xi, yi]z_sampled = torch.FloatTensor(t)with torch.no_grad():decode = model.decode(z_sampled)digit = decode.view((digit_size, digit_size))figure[i * digit_size : (i + 1) * digit_size,j * digit_size : (j + 1) * digit_size,] = digitplt.figure(figsize=(10, 10))
plt.imshow(figure, cmap="gray")
plt.xticks([])
plt.yticks([])
plt.axis("off")
plt.show()

根据潜变量插值生成新样本

相关文章:

变分自编码器(VAE)PyTorch Lightning 实现

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。 🍎个人主页:小嗷犬的个人主页 🍊个人网站:小嗷犬的技术小站 🥭个人信条:为天地立心&…...

设备驱动开发_1

可加载模块如何工作的 主要内容 描述可加载模块优势使用模块命令效率使用和定义模块密钥和模块工作1 描述可加载模块优势 开发周期优势: 静态模块在/boot下的vmlinuz中,需要配置、编译、重启。 开发周期长。 LKM 不需要重启。 开发周期优于静态模块。 2 使用模块命令效率…...

C语言位域(Bit Fields)知识点精要解析

在C语言中,位域(Bit Field)是一种独特的数据结构特性,它允许程序员在结构体(struct)中定义成员变量,并精确指定其占用的位数。通过使用位域,我们可以更高效地利用存储空间&#xff0…...

离散数学——图论(笔记及思维导图)

离散数学——图论(笔记及思维导图) 目录 大纲 内容 参考 大纲 内容 参考 笔记来自【电子科大】离散数学 王丽杰...

opencv图像像素的读写操作

void QuickDemo::pixel_visit_demo(Mat & image) {int w image.cols;//宽度int h image.rows;//高度int dims image.channels();//通道数 图像为灰度dims等于一 图像为彩色时dims等于三 for (int row 0; row < h; row) {for (int col 0; col < w; col) {if…...

Java学习第十四节之冒泡排序

冒泡排序 package array;import java.util.Arrays;//冒泡排序 //1.比较数组中&#xff0c;两个相邻的元素&#xff0c;如果第一个数比第二个数大&#xff0c;我们就交换他们的位置 //2.每一次比较&#xff0c;都会产生出一个最大&#xff0c;或者最小的数字 //3.下一轮则可以少…...

第1章 计算机网络体系结构-1.1计算机网络概述

1.1.1计算机网络概念 计算机网络是将一个分散的&#xff0c;具有独立功能的计算机系统通过通信设备与路线连接起来&#xff0c;由功能完善的软件实现资源共享和信息传递的系统。(计算机网络就是一些互连的&#xff0c;自治的计算机系统的集合) 1.1.2计算机网络的组成 从不同角…...

蓝桥杯:C++排序

排序 排序和排列是算法题目常见的基本算法。几乎每次蓝桥杯软件类大赛都有题目会用到排序或排列。常见的排序算法如下。 第(3)种排序算法不是基于比较的&#xff0c;而是对数值按位划分&#xff0c;按照以空间换取时间的思路来排序。看起来它们的复杂度更好&#xff0c;但实际…...

数据结构-堆

1.容器 容器用于容纳元素集合&#xff0c;并对元素集合进行管理和维护&#xff0e; 传统意义上的管理和维护就是&#xff1a;增&#xff0c;删&#xff0c;改&#xff0c;查&#xff0e; 我们分析每种类型容器时&#xff0c;主要分析其增&#xff0c;删&#xff0c;改&#xff…...

奔跑吧小恐龙(Java)

前言 Google浏览器内含了一个小彩蛋当没有网络连接时&#xff0c;浏览器会弹出一个小恐龙&#xff0c;当我们点击它时游戏就会开始进行&#xff0c;大家也可以玩一下试试&#xff0c;网址&#xff1a;恐龙快跑 - 霸王龙游戏. (ur1.fun) 今天我们也可以用Java来简单的实现一下这…...

Ubuntu 1804 And Above Coredump Settings

查看 coredump 是否开启 # 查询&#xff0c; 0 未开启&#xff0c; unlimited 开启 xiaoUbuntu:/var/core$ ulimit -c 0# 开启 xiaoUbuntu:/var/core$ ulimit -c unlimited查看 coredump 保存路径 默认情况下&#xff0c;Ubuntu 使用 apport 服务处理 coredump 文件&#xff…...

docker 2:安装

docker 2&#xff1a;安装 ‍ ubuntu 安装 docker sudo apt install docker.io‍ 把当前用户放进 docker 用户组&#xff0c;避免每次运行 docker 命都要使用 sudo​ 或者 root​ 权限。 sudo usermod -aG docker $USER​id $USER ​看到用户已加入 docker 组 ​​ ‍ …...

LeetCode Python - 19.删除链表的倒数第N个结点

目录 题目答案运行结果 题目 给你一个链表&#xff0c;删除链表的倒数第 n 个结点&#xff0c;并且返回链表的头结点。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], n 2 输出&#xff1a;[1,2,3,5] 示例 2&#xff1a; 输入&#xff1a;head [1], n 1 输出&a…...

Spring Boot 笔记 005 环境搭建

1.1 创建数据库和表&#xff08;略&#xff09; 2.1 创建Maven工程 2.2 补齐resource文件夹和application.yml文件 2.3 porn.xml中引入web,mybatis,mysql等依赖 2.3.1 引入springboot parent 2.3.2 删除junit 依赖--不能删&#xff0c;删了会报错 2.3.3 引入spring web依赖…...

【解决(几乎)任何机器学习问题】:超参数优化篇(超详细)

这篇文章相当长&#xff0c;您可以添加至收藏夹&#xff0c;以便在后续有空时候悠闲地阅读。 有了优秀的模型&#xff0c;就有了优化超参数以获得最佳得分模型的难题。那么&#xff0c;什么是超参数优化呢&#xff1f;假设您的机器学习项⽬有⼀个简单的流程。有⼀个数据集&…...

面试计算机网络框架八股文十问十答第七期

面试计算机网络框架八股文十问十答第七期 作者&#xff1a;程序员小白条&#xff0c;个人博客 相信看了本文后&#xff0c;对你的面试是有一定帮助的&#xff01;关注专栏后就能收到持续更新&#xff01; ⭐点赞⭐收藏⭐不迷路&#xff01;⭐ 1&#xff09;UDP协议为什么不可…...

Codeforces Round 926 (Div. 2)

A. Sasha and the Beautiful Array&#xff08;模拟&#xff09; 思路 最大值减去最小值 #include<iostream> #include<algorithm> using namespace std; const int N 110; int a[N];int main(){int t, n;cin>>t;while(t--){cin>>n;for(int i 0; i…...

构建智慧交通平台:架构设计与实现

随着城市交通的不断发展和智能化技术的迅速进步&#xff0c;智慧交通平台作为提升城市交通管理效率和水平的重要手段备受关注。本文将探讨如何设计和实现智慧交通平台的系统架构&#xff0c;以应对日益增长的城市交通需求&#xff0c;并提高交通管理的智能化水平。 ### 1. 智慧…...

移动端设置position: fixed;固定定位,底部出现一条缝隙,不知原因,欢迎探讨!!!

1、问题 在父盒子中有一个子盒子&#xff0c;父盒子加了固定定位&#xff0c;需要子盒子上下都有要边距&#xff0c;用margin或者padding挤开时&#xff0c;会出现缝隙是子盒子背景颜色的。 测试过了&#xff0c;有些手机型号有&#xff0c;有些没有&#xff0c;微信小程序同移…...

有关网络安全的课程学习网页

1.思科网络学院 免费学习skillsforall的课程 课程链接&#xff1a;Introduction to Cybersecurity by Cisco: Free Online Course (skillsforall.com) 2.斯坦福大学计算机和网络安全基础 该证书对于初学者来说最有价值&#xff0c;它由最著名的大学之一斯坦福大学提供。您可…...

丹青识画系统Prompt工程指南:如何用文本描述引导更精准的风格鉴定

丹青识画系统Prompt工程指南&#xff1a;如何用文本描述引导更精准的风格鉴定 丹青识画这类AI系统&#xff0c;很多人以为它就是个“看图说话”的工具&#xff0c;把图片丢进去&#xff0c;它告诉你这是什么风格、哪个流派。这确实没错&#xff0c;但如果你只这么用&#xff0…...

PySR社区贡献指南:如何参与这个革命性符号回归开源项目的开发

PySR社区贡献指南&#xff1a;如何参与这个革命性符号回归开源项目的开发 【免费下载链接】PySR High-Performance Symbolic Regression in Python and Julia 项目地址: https://gitcode.com/gh_mirrors/py/PySR 想要为高性能符号回归工具PySR做出贡献吗&#xff1f;这份…...

ReAct Agent:新手程序员必看!收藏这款融合推理与行动的AI智能体框架,轻松入门大模型应用开发

ReAct框架通过结合推理与行动&#xff0c;解决了传统提示工程的局限性&#xff0c;构建出能主动思考、决策并执行复杂任务的智能体。本文详细介绍了ReAct的核心设计思想&#xff0c;包括推理模块的动态思考链和错误回溯机制&#xff0c;以及行动模块的工具集成和环境状态感知。…...

高效管理惠普OMEN游戏本:OmenSuperHub全面解析与实战指南

高效管理惠普OMEN游戏本&#xff1a;OmenSuperHub全面解析与实战指南 【免费下载链接】OmenSuperHub 项目地址: https://gitcode.com/gh_mirrors/om/OmenSuperHub OmenSuperHub是一款专为惠普OMEN系列游戏本设计的轻量级系统管理工具&#xff0c;它通过替代原厂Omen Ga…...

别再一条条Update了!MyBatis批量更新数据,用这个Case When写法性能翻倍

MyBatis批量更新性能优化实战&#xff1a;告别低效循环&#xff0c;拥抱CASE WHEN 每次看到代码里用循环一条条执行update语句&#xff0c;我的数据库性能监控图表就会剧烈波动——这简直是DBA的噩梦。上周排查一个后台任务卡死问题&#xff0c;发现同事在处理5万条数据更新时&…...

G-Helper:华硕笔记本电池健康管理的终极轻量化解决方案

G-Helper&#xff1a;华硕笔记本电池健康管理的终极轻量化解决方案 【免费下载链接】g-helper Lightweight Armoury Crate alternative for Asus laptops. Control tool for ROG Zephyrus G14, G15, G16, M16, Flow X13, Flow X16, TUF, Strix, Scar and other models 项目地…...

SPIRAN ART SUMMONER优化指南:如何调整参数让生成的图片更符合预期

SPIRAN ART SUMMONER优化指南&#xff1a;如何调整参数让生成的图片更符合预期 1. 理解SPIRAN ART SUMMONER的核心参数 SPIRAN ART SUMMONER作为一款基于Flux.1-Dev模型的图像生成工具&#xff0c;其参数设置直接影响最终输出效果。与普通AI绘画工具不同&#xff0c;它融入了…...

如何选择可靠的第三方软件测试机构,构建全生命周期的软件安全防线

在数字化转型的浪潮中&#xff0c;软件已成为企业运营的核心。然而&#xff0c;伴随其重要性一同增长的&#xff0c;是日益严峻的安全威胁。传统软件开发流程中&#xff0c;安全测试往往被置于交付前的独立环节&#xff0c;这种“事后补丁”的模式导致安全漏洞发现晚、修复成本…...

老旧Mac如何重获新生?OCLP-Mod带来的系统升级解决方案

老旧Mac如何重获新生&#xff1f;OCLP-Mod带来的系统升级解决方案 【免费下载链接】OCLP-Mod A mod version for OCLP,with more interesting features. 项目地址: https://gitcode.com/gh_mirrors/oc/OCLP-Mod 随着科技的快速迭代&#xff0c;许多曾经性能卓越的Mac设备…...

ssm+java2026年毕设私教预约系统【源码+论文】

本系统&#xff08;程序源码&#xff09;带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容一、选题背景关于会议管理问题的研究&#xff0c;现有研究主要以传统纸质登记和简单的OA系统为主&#xff0c;专门针对智能化、全流程会议预…...