当前位置: 首页 > article >正文

pytorch实现变分自编码器

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

变分自编码器(Variational Autoencoder, VAE)是一种生成模型,属于深度学习中的无监督学习方法。它通过学习输入数据的潜在分布(Latent Distribution),生成与输入数据相似的新样本。VAE 可以用于数据生成、降维、异常检测等任务。

VAE 的关键思想是在传统的自编码器(Autoencoder)的基础上,引入了变分推断(Variational Inference)和概率模型,使得网络能够学习到数据的潜在分布,而不仅仅是数据的映射。

VAE 的结构:

  1. 编码器(Encoder):将输入数据映射到潜在空间的分布。不同于传统的自编码器直接将数据映射到一个固定的潜在向量,VAE 通过输出潜在变量的均值和方差来描述一个概率分布,这样潜在空间中的每个点都有一个概率分布。
  2. 潜在空间(Latent Space):表示数据的潜在特征。在 VAE 中,潜在空间的表示是一个分布而不是固定的值。通常,采用正态分布来作为潜在空间的先验分布。
  3. 解码器(Decoder):从潜在空间的样本中重构输入数据。解码器通过将潜在空间的点映射回数据空间来生成样本。

VAE 的目标函数:

VAE 的目标是最大化变分下界(Variational Lower Bound,简称 ELBO),即通过优化以下两部分的加权和:

  • 重构误差(Reconstruction Loss):衡量生成的数据和输入数据之间的差异,通常使用均方误差(MSE)或交叉熵(Cross-Entropy)。
  • KL 散度(KL Divergence):衡量潜在空间的分布与先验分布(通常是标准正态分布)之间的差异。

其最终的目标是使生成的数据尽可能接近真实数据,同时使潜在空间的分布接近先验分布。

优点:

  • VAE 能够生成具有多样性的样本,尤其适用于图像、音频等数据的生成。
  • 潜在空间通常具有良好的结构,可以进行插值、样本生成等操作。

应用:

  • 生成任务:如图像生成、文本生成等。
  • 数据重构:如去噪、自编码等。
  • 半监督学习:VAE 可以结合有标签和无标签的数据进行训练,提升模型的泛化能力。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt# 生成圆形图像的函数(使用PyTorch)
def generate_circle_image(size=64):image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像center = size // 2radius = size // 4for y in range(size):for x in range(size):if (x - center) ** 2 + (y - center) ** 2 <= radius ** 2:image[0, y, x] = 1  # 在圆内的点设置为白色return image# 生成方形图像的函数(使用PyTorch)
def generate_square_image(size=64):image = torch.zeros((1, size, size))  # 使用 PyTorch 创建空白图像padding = size // 4image[0, padding:size - padding, padding:size - padding] = 1  # 设置方形区域为白色return image# 自定义数据集:圆形和方形图像
class ShapeDataset(Dataset):def __init__(self, num_samples=1000, size=64):self.num_samples = num_samplesself.size = sizeself.data = []# 生成数据:一半是圆形图像,一半是方形图像for i in range(num_samples // 2):self.data.append(generate_circle_image(size))self.data.append(generate_square_image(size))def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx].float()  # 直接返回 PyTorch Tensor 格式的数据# VAE模型定义
class VAE(nn.Module):def __init__(self, latent_dim=2):super(VAE, self).__init__()self.latent_dim = latent_dim# 编码器self.fc1 = nn.Linear(64 * 64, 400)self.fc21 = nn.Linear(400, latent_dim)  # 均值self.fc22 = nn.Linear(400, latent_dim)  # 方差# 解码器self.fc3 = nn.Linear(latent_dim, 400)self.fc4 = nn.Linear(400, 64 * 64)def encode(self, x):h1 = torch.relu(self.fc1(x.view(-1, 64 * 64)))return self.fc21(h1), self.fc22(h1)  # 返回均值和方差def reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h3 = torch.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3)).view(-1, 1, 64, 64)  # 重构图像def forward(self, x):mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar# 损失函数:重构误差 + KL 散度
def loss_function(recon_x, x, mu, logvar):BCE = nn.functional.binary_cross_entropy(recon_x.view(-1, 64 * 64), x.view(-1, 64 * 64), reduction='sum')# KL 散度return BCE + 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1 - logvar)# 设置超参数
batch_size = 128
epochs = 10
latent_dim = 2
learning_rate = 1e-3# 数据加载
train_loader = DataLoader(ShapeDataset(num_samples=2000), batch_size=batch_size, shuffle=True)# 创建模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
def train(epoch):model.train()train_loss = 0for batch_idx, data in enumerate(train_loader):data = data.to(device)optimizer.zero_grad()recon_batch, mu, logvar = model(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item() / len(data):.6f}')print(f'Train Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')# 测试并显示一些真实图像和生成的图像
def test():model.eval()with torch.no_grad():# 获取一批真实的图像(原始图像)real_images = next(iter(train_loader))[:64]  # 只取前64个图像real_images = real_images.cpu().numpy()# 从潜在空间随机生成一些样本sample = torch.randn(64, latent_dim).to(device)generated_images = model.decode(sample).cpu().numpy()# 显示真实图像和生成的图像,分别标明fig, axes = plt.subplots(8, 8, figsize=(8, 8))axes = axes.flatten()for i in range(64):if i < 32:  # 前32个显示真实图像axes[i].imshow(real_images[i].squeeze(), cmap='gray')axes[i].set_title('Real', fontsize=8)else:  # 后32个显示生成图像axes[i].imshow(generated_images[i - 32].squeeze(), cmap='gray')axes[i].set_title('Generated', fontsize=8)axes[i].axis('off')plt.tight_layout()plt.show()# 训练模型
for epoch in range(1, epochs + 1):train(epoch)# 训练完成后,显示生成的图像
test()

解释:

  1. 真实图像 (real_images):我们通过 next(iter(train_loader)) 获取一批真实图像,并将其转换为 NumPy 数组,以便 matplotlib 显示。
  2. 生成图像 (generated_images):通过模型生成的图像,使用 decode() 方法生成潜在空间的样本。
  3. 图像展示:前 32 张图像展示真实图像,后 32 张图像展示生成的图像。每个图像上方都有 RealGenerated 标注。

结果:

  • 前32个图像:显示真实图像,并标注为 Real
  • 后32个图像:显示通过训练后的 VAE 生成的图像,并标注为 Generated

相关文章:

pytorch实现变分自编码器

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 变分自编码器&#xff08;Variational Autoencoder, VAE&#xff09;是一种生成模型&#xff0c;属于深度学习中的无监督学习方法。它通过学习输入数据的潜在分布&#xff08;Latent Distribution&#xff09;&…...

使用 Numpy 自定义数据集,使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

1. 导入必要的库 首先&#xff0c;导入我们需要的库&#xff1a;Numpy、Pytorch 和相关工具包。 import numpy as np import torch import torch.nn as nn import torch.optim as optim from sklearn.metrics import accuracy_score, recall_score, f1_score2. 自定义数据集 …...

JVM方法区

一、栈、堆、方法区的交互关系 二、方法区的理解: 尽管所有的方法区在逻辑上属于堆的一部分&#xff0c;但是一些简单的实现可能不会去进行垃圾收集或者进行压缩&#xff0c;方法区可以看作是一块独立于Java堆的内存空间。 方法区(Method Area)与Java堆一样&#xff0c;是各个…...

【Python】第七弹---Python基础进阶:深入字典操作与文件处理技巧

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】【MySQL】【Python】 目录 1、字典 1.1、字典是什么 1.2、创建字典 1.3、查找 key 1.4、新增/修改元素 1.5、删除元素 1.6、遍历…...

指导初学者使用Anaconda运行GitHub上One - DM项目的步骤

以下是指导初学者使用Anaconda运行GitHub上One - DM项目的步骤&#xff1a; 1. 安装Anaconda 下载Anaconda&#xff1a; 让初学者访问Anaconda官网&#xff08;https://www.anaconda.com/products/distribution&#xff09;&#xff0c;根据其操作系统&#xff08;Windows、M…...

在实际开发中,如何正确使用 INT(1) 和 INT(10)

在实际开发中&#xff0c;如何正确使用 INT(1) 和 INT(10) 前言 在数据库设计和开发过程中&#xff0c;数据类型的选择至关重要。 最近&#xff0c;我在工作中遇到了一个关于MySQL中INT类型的误解问题&#xff0c;这让我意识到很多开发者对INT类型的理解存在误区。 本文将深…...

像接口契约文档 这种工件,在需求 分析 设计 工作流里面 属于哪一个工作流

οゞ浪漫心情ゞο(20***328) 2016/2/18 10:26:47 请教一下&#xff0c;像接口契约文档 这种工件&#xff0c;在需求 分析 设计 工作流里面 属于哪一个工作流&#xff1f; 潘加宇(35***47) 17:17:28 你这相当于问用例图、序列图属于哪个工作流&#xff0c;看内容。 如果你的&quo…...

GAMES101学习笔记(六):Geometry 几何(基本表示方法、曲线与曲面、网格处理)

文章目录 几何的表示方法隐式几何 Implicit Geometry代数曲面(Algebraic surface)构造实体几何CSG(Constructive Solid Geometry)距离函数(Distance Function)水平集方法(Level Set Methods)分型几何(Fractal) 显式几何 Explicit Geometry点云(Point Cloud)多边形网格(Polygon …...

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.24 随机宇宙:生成现实世界数据的艺术

1.24 随机宇宙&#xff1a;生成现实世界数据的艺术 目录 #mermaid-svg-vN1An9qZ6t4JUcGa {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-vN1An9qZ6t4JUcGa .error-icon{fill:#552222;}#mermaid-svg-vN1An9qZ6t4JUc…...

深入解析:一个简单的浮动布局 HTML 示例

深入解析&#xff1a;一个简单的浮动布局 HTML 示例 示例代码解析代码结构分析1. HTML 结构2. CSS 样式 核心功能解析1. 浮动布局&#xff08;Float&#xff09;2. 清除浮动&#xff08;Clear&#xff09;3. 其他样式 效果展示代码优化与扩展总结 在网页设计中&#xff0c;浮动…...

爬虫基础(三)Session和Cookie讲解

目录 一、前备知识点 &#xff08;1&#xff09;静态网页 &#xff08;2&#xff09;动态网页 &#xff08;3&#xff09;无状态HTTP 二、Session和Cookie 三、Session 四、Cookie &#xff08;1&#xff09;维持过程 &#xff08;2&#xff09;结构 正式开始说 Sessi…...

HTMLCSS :下雪了

这段代码创建了一个动态的雪花飘落加载动画&#xff0c;通过 CSS 技术实现了雪花的下落和消失效果&#xff0c;为页面添加了视觉吸引力和动态感。 大家复制代码时&#xff0c;可能会因格式转换出现错乱&#xff0c;导致样式失效。建议先少量复制代码进行测试&#xff0c;若未能…...

力扣 84. 柱状图中最大的矩形

&#x1f517; https://leetcode.cn/problems/largest-rectangle-in-histogram 题目 给一个数组 num 表示位置 i 上圆柱的高度&#xff0c;求圆柱可以勾勒出的矩形的最大面积 思路 枚举圆柱 i&#xff0c;以该圆柱为高&#xff0c;计算其可以组成的矩形的最大面积。记录这过…...

【Windows Server实战】生产环境云和NPS快速搭建

前置条件 本文假定你已达成以下前提条件&#xff1a; 有域控DC。有证书服务器&#xff08;AD CS&#xff09;。已使用Microsoft Intune或者GPO为客户机申请证书。服务器上至少有两张网卡&#xff08;如果用虚拟机做的测试环境&#xff0c;可以用一张HostOnly网卡做测试&#…...

RHCSA——搭建FTP文件共享服务器

一、实验目的 1、掌握vsftpd服务器的配置方法 2、熟悉FTP客户端工具的使用 3、掌握常见的FTP服务器的故障排除 二、实验项目背景 某企业像架构一台FTP服务器&#xff0c;为企业局域网中的计算机提供文件传送的任务&#xff0c;为财务部门、销售部门和OA系统提供异地数据备…...

IM 即时通讯系统-50-[特殊字符]cim(cross IM) 适用于开发者的分布式即时通讯系统

IM 开源系列 IM 即时通讯系统-41-开源 野火IM 专注于即时通讯实时音视频技术&#xff0c;提供优质可控的IMRTC能力 IM 即时通讯系统-42-基于netty实现的IM服务端,提供客户端jar包,可集成自己的登录系统 IM 即时通讯系统-43-简单的仿QQ聊天安卓APP IM 即时通讯系统-44-仿QQ即…...

SSH代理實用指南

SSH是一種安全的遠程訪問協議&#xff0c;用於遠程登錄和代理工具&#xff0c;是一種通過SSH協議實現的網路代理&#xff0c;常用於將網路流量通過安全的SSH通道進行轉發。與傳統的HTTP代理不同&#xff0c;SSH代理能夠在多種協議下工作&#xff08;如HTTP、HTTPS、FTP等&#…...

Python在线编辑器

from flask import Flask, render_template, request, jsonify import sys from io import StringIO import contextlib import subprocess import importlib import threading import time import ast import reapp Flask(__name__)RESTRICTED_PACKAGES {tkinter: 抱歉&…...

ZZNUOJ(C/C++)基础练习1041——1050(详解版)

1041 : 数列求和2 题目描述 输入一个整数n&#xff0c;输出数列1-1/31/5-……前n项的和。 输入 输入只有一个整数n。 输出 结果保留2为小数,单独占一行。 样例输入 3 样例输出 0.87注意sum 1相当于sumsum1 注意sum * 1相当于sumsum*1 C语言版 #include<stdio.h> // 包含…...

JavaScript系列(51)--解释器实现详解

JavaScript解释器实现详解 &#x1f3af; 今天&#xff0c;让我们深入探讨JavaScript解释器的实现。解释器是一个将源代码直接转换为结果的程序&#xff0c;通过理解其工作原理&#xff0c;我们可以更好地理解JavaScript的执行过程。 解释器基础概念 &#x1f31f; &#x1f…...

浅析DDOS攻击及防御策略

DDoS&#xff08;分布式拒绝服务&#xff09;攻击是一种通过大量计算机或网络僵尸主机对目标服务器发起大量无效或高流量请求&#xff0c;耗尽其资源&#xff0c;从而导致服务中断的网络攻击方式。这种攻击方式利用了分布式系统的特性&#xff0c;使攻击规模更大、影响范围更广…...

深度学习 Pytorch 神经网络的学习

本节将从梯度下降法向外拓展&#xff0c;介绍更常用的优化算法&#xff0c;实现神经网络的学习和迭代。在本节课结束将完整实现一个神经网络训练的全流程。 对于像神经网络这样的复杂模型&#xff0c;可能会有数百个 w w w的存在&#xff0c;同时如果我们使用的是像交叉熵这样…...

【回溯】目标和 字母大小全排列

文章目录 494. 目标和解题思路&#xff1a;回溯784. 字母大小写全排列解题思路&#xff1a;回溯 494. 目标和 494. 目标和 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - &#xff0c;然后串联起所有整数&#xff0c;可以构造一个 表达式…...

Linux系统上安装与配置 MySQL( CentOS 7 )

目录 1. 下载并安装 MySQL 官方 Yum Repository 2. 启动 MySQL 并查看运行状态 3. 找到 root 用户的初始密码 4. 修改 root 用户密码 5. 设置允许远程登录 6. 在云服务器配置 MySQL 端口 7. 关闭防火墙 8. 解决密码错误的问题 前言 在 Linux 服务器上安装并配置 MySQL …...

Miniconda 安装及使用

文章目录 前言1、Miniconda 简介2、Linux 环境说明2.1、安装2.2、配置2.3、常用命令2.4、常见问题及解决方案 前言 在 Python 中&#xff0c;“环境管理”是一个非常重要的概念&#xff0c;它主要是指对 Python 解释器及其相关依赖库进行管理和隔离&#xff0c;以确保开发环境…...

记录一次,PyQT的报错,多线程Udp失效,使用工具如netstat来检查端口使用情况。

1.问题 报错Exception in thread Thread-1: Traceback (most recent call last): File "threading.py", line 932, in _bootstrap_inner File "threading.py", line 870, in run File "main.py", line 456, in udp_recv IndexError: list…...

kamailio-ACC_JSON模块详解【后端语言go】

要确认 ACC_JSON 模块是否已经成功将计费信息推送到消息队列&#xff08;MQueue&#xff09;&#xff0c;以及如何从队列中取值&#xff0c;可以按照以下步骤进行操作&#xff1a; 1. 确认 ACC_JSON 已推送到队列 1.1 配置 ACC_JSON 确保 ACC_JSON 模块已正确配置并启用。以下…...

群晖NAS安卓Calibre 个人图书馆

docker 下载镜像johngong/calibre-web&#xff0c;安装之 我是本地的/docker/xxx/metadata目录 映射到 /usr/local/calibre-web/app/cps/metadata_provider CALIBREDB_OTHER_OPTION 删除 CALIBRE_SERVER_USER calibre_server_user 缺省用户名口令 admin admin123 另外有个N…...

android主题设置为..DarkActionBar.Bridge时自定义DatePicker选中日期颜色

安卓自定义DatePicker选中日期颜色 背景&#xff1a;解决方案&#xff1a;方案一&#xff1a;方案二&#xff1a;实践效果&#xff1a; 背景&#xff1a; 最近在尝试用原生安卓实现仿element-ui表单校验功能&#xff0c;其中的的选择日期涉及到安卓DatePicker组件的使用&#…...

pytorch实现基于Word2Vec的词嵌入

PyTorch 实现 Word2Vec(Skip-gram 模型) 的完整代码,使用 中文语料 进行训练,包括数据预处理、模型定义、训练和测试。 1. 主要特点 支持中文数据,基于 jieba 进行分词 使用 Skip-gram 进行训练,适用于小数据集 支持负采样,提升训练效率 使用 cosine similarity 计算相…...