pytorch实现变分自编码器
人工智能例子汇总:AI常见的算法和例子-CSDN博客
变分自编码器(Variational Autoencoder, VAE)是一种生成模型,属于深度学习中的无监督学习方法。它通过学习输入数据的潜在分布(Latent Distribution),生成与输入数据相似的新样本。VAE 可以用于数据生成、降维、异常检测等任务。
VAE 的关键思想是在传统的自编码器(Autoencoder)的基础上,引入了变分推断(Variational Inference)和概率模型,使得网络能够学习到数据的潜在分布,而不仅仅是数据的映射。
VAE 的结构:
- 编码器(Encoder):将输入数据映射到潜在空间的分布。不同于传统的自编码器直接将数据映射到一个固定的潜在向量,VAE 通过输出潜在变量的均值和方差来描述一个概率分布,这样潜在空间中的每个点都有一个概率分布。
- 潜在空间(Latent Space):表示数据的潜在特征。在 VAE 中,潜在空间的表示是一个分布而不是固定的值。通常,采用正态分布来作为潜在空间的先验分布。
- 解码器(Decoder):从潜在空间的样本中重构输入数据。解码器通过将潜在空间的点映射回数据空间来生成样本。
VAE 的目标函数:
VAE 的目标是最大化变分下界(Variational Lower Bound,简称 ELBO),即通过优化以下两部分的加权和:
- 重构误差(Reconstruction Loss):衡量生成的数据和输入数据之间的差异,通常使用均方误差(MSE)或交叉熵(Cross-Entropy)。
- KL 散度(KL Divergence):衡量潜在空间的分布与先验分布(通常是标准正态分布)之间的差异。
其最终的目标是使生成的数据尽可能接近真实数据,同时使潜在空间的分布接近先验分布。
优点:
- VAE 能够生成具有多样性的样本,尤其适用于图像、音频等数据的生成。
- 潜在空间通常具有良好的结构,可以进行插值、样本生成等操作。
应用:
- 生成任务:如图像生成、文本生成等。
- 数据重构:如去噪、自编码等。
- 半监督学习:VAE 可以结合有标签和无标签的数据进行训练,提升模型的泛化能力。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt# 生成圆形图像的函数(使用PyTorch)
def generate_circle_image(size=64):image = torch.zeros((1, size, size)) # 使用 PyTorch 创建空白图像center = size // 2radius = size // 4for y in range(size):for x in range(size):if (x - center) ** 2 + (y - center) ** 2 <= radius ** 2:image[0, y, x] = 1 # 在圆内的点设置为白色return image# 生成方形图像的函数(使用PyTorch)
def generate_square_image(size=64):image = torch.zeros((1, size, size)) # 使用 PyTorch 创建空白图像padding = size // 4image[0, padding:size - padding, padding:size - padding] = 1 # 设置方形区域为白色return image# 自定义数据集:圆形和方形图像
class ShapeDataset(Dataset):def __init__(self, num_samples=1000, size=64):self.num_samples = num_samplesself.size = sizeself.data = []# 生成数据:一半是圆形图像,一半是方形图像for i in range(num_samples // 2):self.data.append(generate_circle_image(size))self.data.append(generate_square_image(size))def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx].float() # 直接返回 PyTorch Tensor 格式的数据# VAE模型定义
class VAE(nn.Module):def __init__(self, latent_dim=2):super(VAE, self).__init__()self.latent_dim = latent_dim# 编码器self.fc1 = nn.Linear(64 * 64, 400)self.fc21 = nn.Linear(400, latent_dim) # 均值self.fc22 = nn.Linear(400, latent_dim) # 方差# 解码器self.fc3 = nn.Linear(latent_dim, 400)self.fc4 = nn.Linear(400, 64 * 64)def encode(self, x):h1 = torch.relu(self.fc1(x.view(-1, 64 * 64)))return self.fc21(h1), self.fc22(h1) # 返回均值和方差def reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h3 = torch.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3)).view(-1, 1, 64, 64) # 重构图像def forward(self, x):mu, logvar = self.encode(x)z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar# 损失函数:重构误差 + KL 散度
def loss_function(recon_x, x, mu, logvar):BCE = nn.functional.binary_cross_entropy(recon_x.view(-1, 64 * 64), x.view(-1, 64 * 64), reduction='sum')# KL 散度return BCE + 0.5 * torch.sum(torch.exp(logvar) + mu ** 2 - 1 - logvar)# 设置超参数
batch_size = 128
epochs = 10
latent_dim = 2
learning_rate = 1e-3# 数据加载
train_loader = DataLoader(ShapeDataset(num_samples=2000), batch_size=batch_size, shuffle=True)# 创建模型和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练模型
def train(epoch):model.train()train_loss = 0for batch_idx, data in enumerate(train_loader):data = data.to(device)optimizer.zero_grad()recon_batch, mu, logvar = model(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()if batch_idx % 100 == 0:print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item() / len(data):.6f}')print(f'Train Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')# 测试并显示一些真实图像和生成的图像
def test():model.eval()with torch.no_grad():# 获取一批真实的图像(原始图像)real_images = next(iter(train_loader))[:64] # 只取前64个图像real_images = real_images.cpu().numpy()# 从潜在空间随机生成一些样本sample = torch.randn(64, latent_dim).to(device)generated_images = model.decode(sample).cpu().numpy()# 显示真实图像和生成的图像,分别标明fig, axes = plt.subplots(8, 8, figsize=(8, 8))axes = axes.flatten()for i in range(64):if i < 32: # 前32个显示真实图像axes[i].imshow(real_images[i].squeeze(), cmap='gray')axes[i].set_title('Real', fontsize=8)else: # 后32个显示生成图像axes[i].imshow(generated_images[i - 32].squeeze(), cmap='gray')axes[i].set_title('Generated', fontsize=8)axes[i].axis('off')plt.tight_layout()plt.show()# 训练模型
for epoch in range(1, epochs + 1):train(epoch)# 训练完成后,显示生成的图像
test()
解释:
- 真实图像 (
real_images
):我们通过next(iter(train_loader))
获取一批真实图像,并将其转换为 NumPy 数组,以便matplotlib
显示。 - 生成图像 (
generated_images
):通过模型生成的图像,使用decode()
方法生成潜在空间的样本。 - 图像展示:前 32 张图像展示真实图像,后 32 张图像展示生成的图像。每个图像上方都有
Real
或Generated
标注。
结果:
- 前32个图像:显示真实图像,并标注为
Real
。 - 后32个图像:显示通过训练后的 VAE 生成的图像,并标注为
Generated
。
相关文章:
pytorch实现变分自编码器
人工智能例子汇总:AI常见的算法和例子-CSDN博客 变分自编码器(Variational Autoencoder, VAE)是一种生成模型,属于深度学习中的无监督学习方法。它通过学习输入数据的潜在分布(Latent Distribution)&…...

Node.js与嵌入式开发:打破界限的创新结合
文章目录 一、Node.js的本质与核心优势1.1 什么是Node.js?1.2 嵌入式开发的范式转变二、Node.js与嵌入式结合的四大技术路径2.1 硬件交互层2.2 物联网协议栈2.3 边缘计算架构2.4 轻量化运行时方案三、实战案例:智能农业监测系统3.1 硬件配置3.2 软件架构3.3 核心代码片段四、…...
Noise Conditional Score Network
NCSN p σ ( x ~ ∣ x ) : N ( x ~ ; x , σ 2 I ) p_\sigma(\tilde{\mathrm{x}}|\mathrm{x}) : \mathcal{N}(\tilde{\mathrm{x}}; \mathrm{x}, \sigma^2\mathbf{I}) pσ(x~∣x):N(x~;x,σ2I) p σ ( x ~ ) : ∫ p d a t a ( x ) p σ ( x ~ ∣ x ) d x p_\sigma(\mathrm…...

低代码系统-产品架构案例介绍、蓝凌(十三)
蓝凌低代码系统,依旧是从下到上,从左至右的顺序。 技术平台h/iPaas 指低层使用了哪些技术,例如:微服务架构,MySql数据库。个人认为,如果是市场的主流,就没必要赘述了。 新一代门户 门户设计器&a…...

51单片机 02 独立按键
一、独立按键控制LED亮灭 轻触按键:相当于是一种电子开关,按下时开关接通,松开时开关断开,实现原理是通过轻触按键内部的金属弹片受力弹动来实现接通和断开。 #include <STC89C5xRC.H> void main() { // P20xFE;while(1){…...

2021.3.1的android studio版本就很好用
使用最新版的studio有个问题就是gradle版本也比较高,这样就容易出现之前项目不兼容问题,配置gradle可能会出现很多问题比较烦,所以干脆就用老版本的studio...

CSV数据分析智能工具(基于OpenAI API和streamlit)
utils.py: from langchain_openai import ChatOpenAI from langchain_experimental.agents.agent_toolkits import create_csv_agent import jsonPROMPT_TEMPLATE """你是一位数据分析助手,你的回应内容取决于用户的请求内容。1. 对于文…...
移除元素-双指针(下标)
题目 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素。元素的顺序可能发生改变。然后返回 nums 中与 val 不同的元素的数量。 假设 nums 中不等于 val 的元素数量为 k,要通过此题,您需要执行以下操作:…...
蓝桥杯备赛题目练习(一)
目录 一. 口算练习题 代码如下 代码解读(简略重点): 代码解读(详细): 二. 小乐乐改数字 代码(1):当做整数读取 代码(2):当做字符…...
FortiOS 存在身份验证绕过导致命令执行漏洞(CVE-2024-55591)
免责声明: 本文旨在提供有关特定漏洞的深入信息,帮助用户充分了解潜在的安全风险。发布此信息的目的在于提升网络安全意识和推动技术进步,未经授权访问系统、网络或应用程序,可能会导致法律责任或严重后果。因此,作者不对读者基于本文内容所采取的任何行为承担责任。读者在…...

【多线程】线程池核心数到底如何配置?
🥰🥰🥰来都来了,不妨点个关注叭! 👉博客主页:欢迎各位大佬!👈 文章目录 1. 前置回顾2. 动态线程池2.1 JMX 的介绍2.1.1 MBeans 介绍 2.2 使用 JMX jconsole 实现动态修改线程池2.2.…...
Windows图形界面(GUI)-QT-C/C++ - Qt Combo Box
公开视频 -> 链接点击跳转公开课程博客首页 -> 链接点击跳转博客主页 目录 一、概述 1.1 基本概念 1.2 应用场景对比 二、核心属性详解 2.1 行为控制 2.2 显示配置 三、数据操作与访问 3.1 基础数据管理 3.2 高级数据访问 四、用户交互处理 4.1 信号处…...

开源AI智能名片2 + 1链动模式S2B2C商城小程序:内容价值创造与传播新引擎
摘要:本文聚焦于信息爆炸时代下,内容价值的创造与传播。随着用户角色的转变,其在内容生产与传播中的价值日益凸显。同时,深入探讨开源AI智能名片2 1链动模式S2B2C商城小程序这一创新商业模式,如何借助用户创造内容并传…...
python读取excel工具:openpyxl | AI应用开发
python读取excel工具:openpyxl | AI应用开发 openpyxl 是一个 Python 库,专门用于读写 Excel 2010 xlsx/xlsm/xltx/xltm 文件。它是处理 Excel 文件的强大工具,可以让你在不需要安装 Excel 软件的情况下,对 Excel 文件进行创建、…...

堆的基本概念
1.1 堆的基本概念 虚拟机所在目录 E:\ctf\pwn-self 进入虚拟机的pwndocker环境 holyeyesubuntu:~$ pwd /home/holyeyes holyeyesubuntu:~$ sudo ./1run.sh IDA分析 int __fastcall main(int argc, const char **argv, const char **envp) { void *v4; // [rsp20h] [rbp-1…...
Android车机DIY开发之软件篇(九) NXP AutomotiveOS编译
Android车机DIY开发之软件篇(十一) NXP AutomotiveOS编译 Google 在汽车上也提供了用于汽车的 Google 汽车服务(GAS,Google Automotive Service),包含有 Google 地图、应用市场、Google 汽车助理等等。Google 汽车服务同样没有开…...

嵌入式工程师必学(143):模拟信号链基础
概述: 我们每天使用的许多电子设备,以及我们赖以生存的电子设备,如果不使用电子工程师设计的实际输入信号,就无法运行。 模拟信号链由四个主要元件组成:传感器、放大器、滤波器和模数转换器 (ADC)。这些传感器用于检测、调节模拟信号并将其转换为适合由微控制器或其他数…...

《LLM大语言模型深度探索与实践:构建智能应用的新范式,融合代理与数据库的高级整合》
文章目录 Langchain的定义Langchain的组成三个核心组件实现整个核心组成部分 为什么要使用LangchainLangchain的底层原理Langchain实战操作LangSmithLangChain调用LLM安装openAI库-国内镜像源代码运行结果小结 使用Langchain的提示模板部署Langchain程序安装langserve代码请求格…...

e2studio开发RA2E1(5)----GPIO输入检测
e2studio开发RA2E1.5--GPIO输入检测 概述视频教学样品申请硬件准备参考程序源码下载新建工程工程模板保存工程路径芯片配置工程模板选择时钟设置GPIO口配置按键口配置按键口&Led配置R_IOPORT_PortRead()函数原型R_IOPORT_PinRead()函数原型代码 概述 本篇文章主要介绍如何…...

Spring @Lazy:延迟初始化,为应用减负
在Spring框架中,Lazy注解的作用非常直观,它就是用来告诉Spring容器:“嘿,这个Bean嘛,先别急着创建和初始化,等到真正需要用到的时候再弄吧!” 默认情况下,Spring容器在启动时会立即创…...
<6>-MySQL表的增删查改
目录 一,create(创建表) 二,retrieve(查询表) 1,select列 2,where条件 三,update(更新表) 四,delete(删除表…...
系统设计 --- MongoDB亿级数据查询优化策略
系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log,共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题,不能使用ELK只能使用…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)
骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...
精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南
精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南 在数字化营销时代,邮件列表效度、用户参与度和网站性能等指标往往决定着创业公司的增长成败。今天,我们将深入解析邮件打开率、网站可用性、页面参与时…...

OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 在 GPU 上对图像执行 均值漂移滤波(Mean Shift Filtering),用于图像分割或平滑处理。 该函数将输入图像中的…...
Android第十三次面试总结(四大 组件基础)
Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成,用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机: onCreate() 调用时机:Activity 首次创建时调用。…...
服务器--宝塔命令
一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行! sudo su - 1. CentOS 系统: yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合
在汽车智能化的汹涌浪潮中,车辆不再仅仅是传统的交通工具,而是逐步演变为高度智能的移动终端。这一转变的核心支撑,来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒(T-Box)方案:NXP S32K146 与…...

mac 安装homebrew (nvm 及git)
mac 安装nvm 及git 万恶之源 mac 安装这些东西离不开Xcode。及homebrew 一、先说安装git步骤 通用: 方法一:使用 Homebrew 安装 Git(推荐) 步骤如下:打开终端(Terminal.app) 1.安装 Homebrew…...

STM32HAL库USART源代码解析及应用
STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...