当前位置: 首页 > news >正文

[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言:

   
 

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器
其训练效果如下

 

训练的过程中要注意调节forward 中的kle ,调参。

整个工程两个文件:

    vae.py

   main.py

目录:

  1.      vae
  2.       main

一  vae

  文件名: vae.py

   作用:   Variational Autoencoders (VAE)

 训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE),它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。
variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023@author: chengxf2
"""import torch
from torch import nn#ae: AutoEncoderclass VAE(nn.Module):def __init__(self,hidden_size=20):super(VAE, self).__init__()self.encoder = nn.Sequential(nn.Linear(in_features=784, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=hidden_size),nn.ReLU())# hidden [batch_size, 10]h_dim = int(hidden_size/2)self.hDim = h_dimself.decoder = nn.Sequential(nn.Linear(in_features=h_dim, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=784),nn.Sigmoid())def forward(self, x):'''param x:[batch, 1,28,28]return '''batchSz= x.size(0)#flattenx = x.view(batchSz, 784)#encoderh= self.encoder(x)#在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigmau, sigma = h.chunk(2,dim=1)#Reparameterize trick:#randn_like:产生一个正太分布 ~ N(0,1)#h.shape [batchSize,self.hDim]h = u+sigma* torch.randn_like(sigma)#kld :1e-8 防止sigma 平方为0kld = 0.5*torch.sum(torch.pow(u,2)+torch.pow(sigma,2)-torch.log(1e-8+torch.pow(sigma,2))-1)#MSE loss 是平均loss, 所以kld 也要算一个平均值kld = kld/(batchSz*32*32)xHat =   self.decoder(h)#reshapexHat = xHat.view(batchSz,1,28,28)return xHat,kld

二 main

文件名: main.py

作用: 训练,测试数据集

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023@author: chengxf2
"""import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdomdef main():batchNum = 32lr = 1e-3epochs = 20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")torch.manual_seed(1234)viz = visdom.Visdom()viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))tf= transforms.Compose([ transforms.ToTensor()])mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)global_step =0model =VAE().to(device)criteon = nn.MSELoss().to(device) #损失函数optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则print("\n ----main-----")for epoch in range(epochs):start = time.perf_counter()for step ,(x,y) in enumerate(train_data):#[b,1,28,28]x = x.to(device)x_hat,kld = model(x)loss = criteon(x_hat, x)if kld is not None:elbo = -loss -1.0*kldloss = -elbo#backpropoptimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1end = time.perf_counter()    interval = int(end - start)print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())x,target = iter(test_data).next()x = x.to(device)with torch.no_grad():x_hat,kld = model(x)tip = 'hat'+str(epoch)viz.images(x,nrow=8, win='x',opts=dict(title='x'))viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))if __name__ == '__main__':main()

 参考:

 课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili

相关文章:

[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言: 这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器 其训练效果如下 训练的过程中要注意调节forward 中的kle ,调参。 整个工程两个文件: vae.py main.py 目录: vae main 一 vae 文件名: vae…...

Java实现HTTP的上传与下载

相信很多人对于java文件下载的过程都存在一些疑惑,比如下载上传文件会不会占用vm内存,上传/下载大文件会不会导致oom。下面从字节流的角度看下载/上传的实现,可以更加深入理解文件的上传和下载功能。 文件下载 首先明确,文件下载…...

VPG算法

VPG算法 前言 首先来看经典的策略梯度REINFORCE算法: 在REINFORCE中,每次采集一个episode的轨迹,计算每一步动作的回报 G t G_t Gt​,与动作概率对数相乘,作为误差反向传播,有以下几个特点: …...

docker 笔记5:redis 集群分布式存储案例

尚硅谷Docker实战教程(docker教程天花板)_哔哩哔哩_bilibili 目录 1.cluster(集群)模式-docker版哈希槽分区进行亿级数据存储 1.1面试题 1.1.1 方案1 哈希取余分区 1.1.2 方案2 一致性哈希算法分区 原理 优点 一致性哈希算法的容错性 一致性…...

【Vue2】 axios库

网络请求库-axios库 认识Axios库为什么选择Axios库安装Axios axios发送请求常见的配置选项简单请求可以给Axios设置公共的基础配置发送多个请求 axios创建实例为什么要创建axios的实例 axios的拦截器请求拦截器响应拦截器 axios请求封装 认识Axios库 为什么选择Axios库 在游览…...

云计算 - 百度AIStudio使用小结

云计算 - 百度AIStudio使用小结 前言 本文以ffmpeg处理视频为例,小结一下AI Studio的使用体验及一些避坑技巧。 算力获得 免费的算力获得方式为:每日登录后运行一个项目(只需要点击运行,不需要真正运行)即可获得8小…...

刷新你对Redis持久化的认知

认识持久化 redis是一个内存数据库,数据存储到内存中。而内存的数据是不持久的,要想做到持久化,就需要让redis把数据存储到硬盘上。因此redis既要在内存上存储一份数据,还要在硬盘上存储一份数据。这样这两份数据在理论上是完全相…...

Greenplum-最佳实践小结

注:本文翻译自https://docs.vmware.com/en/VMware-Greenplum/7/greenplum-database/best_practices-logfiles.html 数据模型 Greenplum数据库是一个分析型MPP无共享数据库。该模型与高度规范化/事务性的SMP数据库明显不同。Greenplum数据库使用适合MPP分析处理的非…...

从Gamma空间改为Linear空间会导致性能下降吗

1)从Gamma空间改为Linear空间会导致性能下降吗 2)如何处理没有使用Unity Ads却收到了GooglePlay平台的警告 3)C#端如何处理xLua在执行DoString时候死循环 4)Texture2DArray相关 这是第350篇UWA技术知识分享的推送,精选…...

双轨制的发展,弊端和前景

双轨制是一种经济体制,指两种不同的规则或机制并行运行,以适应不同的市场或客户需求。双轨制最早出现在中国的改革开放中,是从计划经济向市场经济过渡的一种渐进式改革方式。 双轨制的发展可以分为三个阶段: 第一阶段(…...

生成对抗网络(GAN):在图像生成和修复中的应用

文章目录 什么是生成对抗网络(GAN)?GAN在图像生成中的应用图像生成风格迁移 GAN在图像修复中的应用图像修复 拓展应用领域总结 🎉欢迎来到AIGC人工智能专栏~生成对抗网络(GAN):在图像生成和修复…...

扬杰科技携手企企通,召开SRM采购供应链协同系统项目启动会

近日,中国功率半导体领先企业扬州扬杰电子科技股份有限公司(以下简称“扬杰科技”)与企企通召开SRM采购供应链协同系统项目启动会,双方项目团队成员一同出席本次会议。 会上,双方就扬杰科技采购供应链管理平台项目的目…...

AtCoder Beginner Contest 318

目录 A - Full Moon B - Overlapping sheets C - Blue Spring D - General Weighted Max Matching E - Sandwiches F - Octopus A - Full Moon #include<bits/stdc.h> using namespace std; const int N1e65; typedef long long ll ; const int maxv4e65; typedef …...

《Python魔法大冒险》003 两个神奇的魔法工具

魔法师:小鱼,要开始编写魔法般的Python程序,我们首先需要两个神奇的工具:Python解释器和代码编辑器。 小鱼:这两个工具是做什么的? 魔法师:你可以把Python解释器看作是一个魔法棒,只要你向它说出正确的咒语,它就会为你施展魔法。 小鱼:那这个解释器和我之前用的电…...

每日一题-动态规划(从不同类型的物品中各挑选一个,使得最后花费总和等于1000)

四种类型的物品&#xff0c;每一种类型物品数量都是n&#xff0c;先要从每种类型的物品中挑选一件&#xff0c;使得最后花费总和等于1000 暴力做法10000^4 看到花费总和是1000&#xff0c;很小且固定的数字&#xff0c;肯定有玄机&#xff0c;从这里想应该是用dp&#xff0c;不…...

2023-9-3 试除法判定质数

题目链接&#xff1a;试除法判定质数 #include <iostream>using namespace std;bool is_prime(int n) {if(n < 2) return false;for(int i 2; i < n / i; i){if(n % i 0) return false;}return true; }int main() {int n;cin >> n;while(n--){int x;cin &g…...

【Apollo学习笔记】——规划模块TASK之RULE_BASED_STOP_DECIDER

文章目录 前言RULE_BASED_STOP_DECIDER相关配置RULE_BASED_STOP_DECIDER总体流程StopOnSidePassCheckClearDoneCheckSidePassStopIsPerceptionBlockedIsClearToChangeLaneCheckSidePassStopBuildStopDecisionELSE:涉及到的一些其他函数NormalizeAngleSelfRotate CheckLaneChang…...

【SpringBoot】最基础的项目架构(SpringBoot+Mybatis-plus+lombok+knife4j+hutool)

汝之观览&#xff0c;吾之幸也&#xff01; 从本文开始讲下项目中用到的一些框架和技术&#xff0c;最基本的框架使用的是SpringBoot(2.5.10)Mybatis-plus(3.5.3.2)lombok(1.18.28)knife4j(3.0.3)hutool(5.8.21),可以做到代码自动生成&#xff0c;满足最基本的增删查改。 一、新…...

RNN 单元:分析 GRU 方程与 LSTM,以及何时选择 RNN 而不是变压器

一、说明 深度学习往往感觉像是在雪山上找到自己的道路。拥有坚实的原则会让你对做出决定更有信心。我们都去过那里 在上一篇文章中&#xff0c;我们彻底介绍并检查了 LSTM 单元的各个方面。有人可能会争辩说&#xff0c;RNN方法已经过时了&#xff0c;研究它们是没有意义的。的…...

Linux音频了解

ALPHA I.MX6U 开发板支持音频&#xff0c;板上搭载了音频编解码芯片 WM8960&#xff0c;支持播放以及录音功能&#xff01; 本章将会讨论如下主题内容。 ⚫ Linux 下 ALSA 框架概述&#xff1b; ⚫ alsa-lib 库介绍&#xff1b; ⚫ alsa-lib 库移植&#xff1b; ⚫ alsa-l…...

KubeSphere 容器平台高可用:环境搭建与可视化操作指南

Linux_k8s篇 欢迎来到Linux的世界&#xff0c;看笔记好好学多敲多打&#xff0c;每个人都是大神&#xff01; 题目&#xff1a;KubeSphere 容器平台高可用&#xff1a;环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式&#xff0c;可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

解锁数据库简洁之道:FastAPI与SQLModel实战指南

在构建现代Web应用程序时&#xff0c;与数据库的交互无疑是核心环节。虽然传统的数据库操作方式&#xff08;如直接编写SQL语句与psycopg2交互&#xff09;赋予了我们精细的控制权&#xff0c;但在面对日益复杂的业务逻辑和快速迭代的需求时&#xff0c;这种方式的开发效率和可…...

蓝牙 BLE 扫描面试题大全(2):进阶面试题与实战演练

前文覆盖了 BLE 扫描的基础概念与经典问题蓝牙 BLE 扫描面试题大全(1)&#xff1a;从基础到实战的深度解析-CSDN博客&#xff0c;但实际面试中&#xff0c;企业更关注候选人对复杂场景的应对能力&#xff08;如多设备并发扫描、低功耗与高发现率的平衡&#xff09;和前沿技术的…...

MODBUS TCP转CANopen 技术赋能高效协同作业

在现代工业自动化领域&#xff0c;MODBUS TCP和CANopen两种通讯协议因其稳定性和高效性被广泛应用于各种设备和系统中。而随着科技的不断进步&#xff0c;这两种通讯协议也正在被逐步融合&#xff0c;形成了一种新型的通讯方式——开疆智能MODBUS TCP转CANopen网关KJ-TCPC-CANP…...

基于Docker Compose部署Java微服务项目

一. 创建根项目 根项目&#xff08;父项目&#xff09;主要用于依赖管理 一些需要注意的点&#xff1a; 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件&#xff0c;否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...

【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统

目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索&#xff08;基于物理空间 广播范围&#xff09;2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...

均衡后的SNRSINR

本文主要摘自参考文献中的前两篇&#xff0c;相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程&#xff0c;其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt​ 根发送天线&#xff0c; n r n_r nr​ 根接收天线的 MIMO 系…...

C# 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...