[PyTorch][chapter 54][Variational Auto-Encoder 实战]
前言:

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器
其训练效果如下

训练的过程中要注意调节forward 中的kle ,调参。
整个工程两个文件:
vae.py
main.py
目录:
- vae
- main
一 vae
文件名: vae.py
作用: Variational Autoencoders (VAE)
训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE),它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。
variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023@author: chengxf2
"""import torch
from torch import nn#ae: AutoEncoderclass VAE(nn.Module):def __init__(self,hidden_size=20):super(VAE, self).__init__()self.encoder = nn.Sequential(nn.Linear(in_features=784, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=hidden_size),nn.ReLU())# hidden [batch_size, 10]h_dim = int(hidden_size/2)self.hDim = h_dimself.decoder = nn.Sequential(nn.Linear(in_features=h_dim, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=784),nn.Sigmoid())def forward(self, x):'''param x:[batch, 1,28,28]return '''batchSz= x.size(0)#flattenx = x.view(batchSz, 784)#encoderh= self.encoder(x)#在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigmau, sigma = h.chunk(2,dim=1)#Reparameterize trick:#randn_like:产生一个正太分布 ~ N(0,1)#h.shape [batchSize,self.hDim]h = u+sigma* torch.randn_like(sigma)#kld :1e-8 防止sigma 平方为0kld = 0.5*torch.sum(torch.pow(u,2)+torch.pow(sigma,2)-torch.log(1e-8+torch.pow(sigma,2))-1)#MSE loss 是平均loss, 所以kld 也要算一个平均值kld = kld/(batchSz*32*32)xHat = self.decoder(h)#reshapexHat = xHat.view(batchSz,1,28,28)return xHat,kld
二 main
文件名: main.py
作用: 训练,测试数据集
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023@author: chengxf2
"""import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdomdef main():batchNum = 32lr = 1e-3epochs = 20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")torch.manual_seed(1234)viz = visdom.Visdom()viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))tf= transforms.Compose([ transforms.ToTensor()])mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)global_step =0model =VAE().to(device)criteon = nn.MSELoss().to(device) #损失函数optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则print("\n ----main-----")for epoch in range(epochs):start = time.perf_counter()for step ,(x,y) in enumerate(train_data):#[b,1,28,28]x = x.to(device)x_hat,kld = model(x)loss = criteon(x_hat, x)if kld is not None:elbo = -loss -1.0*kldloss = -elbo#backpropoptimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1end = time.perf_counter() interval = int(end - start)print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())x,target = iter(test_data).next()x = x.to(device)with torch.no_grad():x_hat,kld = model(x)tip = 'hat'+str(epoch)viz.images(x,nrow=8, win='x',opts=dict(title='x'))viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))if __name__ == '__main__':main()
参考:
课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili
相关文章:
[PyTorch][chapter 54][Variational Auto-Encoder 实战]
前言: 这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器 其训练效果如下 训练的过程中要注意调节forward 中的kle ,调参。 整个工程两个文件: vae.py main.py 目录: vae main 一 vae 文件名: vae…...
Java实现HTTP的上传与下载
相信很多人对于java文件下载的过程都存在一些疑惑,比如下载上传文件会不会占用vm内存,上传/下载大文件会不会导致oom。下面从字节流的角度看下载/上传的实现,可以更加深入理解文件的上传和下载功能。 文件下载 首先明确,文件下载…...
VPG算法
VPG算法 前言 首先来看经典的策略梯度REINFORCE算法: 在REINFORCE中,每次采集一个episode的轨迹,计算每一步动作的回报 G t G_t Gt,与动作概率对数相乘,作为误差反向传播,有以下几个特点: …...
docker 笔记5:redis 集群分布式存储案例
尚硅谷Docker实战教程(docker教程天花板)_哔哩哔哩_bilibili 目录 1.cluster(集群)模式-docker版哈希槽分区进行亿级数据存储 1.1面试题 1.1.1 方案1 哈希取余分区 1.1.2 方案2 一致性哈希算法分区 原理 优点 一致性哈希算法的容错性 一致性…...
【Vue2】 axios库
网络请求库-axios库 认识Axios库为什么选择Axios库安装Axios axios发送请求常见的配置选项简单请求可以给Axios设置公共的基础配置发送多个请求 axios创建实例为什么要创建axios的实例 axios的拦截器请求拦截器响应拦截器 axios请求封装 认识Axios库 为什么选择Axios库 在游览…...
云计算 - 百度AIStudio使用小结
云计算 - 百度AIStudio使用小结 前言 本文以ffmpeg处理视频为例,小结一下AI Studio的使用体验及一些避坑技巧。 算力获得 免费的算力获得方式为:每日登录后运行一个项目(只需要点击运行,不需要真正运行)即可获得8小…...
刷新你对Redis持久化的认知
认识持久化 redis是一个内存数据库,数据存储到内存中。而内存的数据是不持久的,要想做到持久化,就需要让redis把数据存储到硬盘上。因此redis既要在内存上存储一份数据,还要在硬盘上存储一份数据。这样这两份数据在理论上是完全相…...
Greenplum-最佳实践小结
注:本文翻译自https://docs.vmware.com/en/VMware-Greenplum/7/greenplum-database/best_practices-logfiles.html 数据模型 Greenplum数据库是一个分析型MPP无共享数据库。该模型与高度规范化/事务性的SMP数据库明显不同。Greenplum数据库使用适合MPP分析处理的非…...
从Gamma空间改为Linear空间会导致性能下降吗
1)从Gamma空间改为Linear空间会导致性能下降吗 2)如何处理没有使用Unity Ads却收到了GooglePlay平台的警告 3)C#端如何处理xLua在执行DoString时候死循环 4)Texture2DArray相关 这是第350篇UWA技术知识分享的推送,精选…...
双轨制的发展,弊端和前景
双轨制是一种经济体制,指两种不同的规则或机制并行运行,以适应不同的市场或客户需求。双轨制最早出现在中国的改革开放中,是从计划经济向市场经济过渡的一种渐进式改革方式。 双轨制的发展可以分为三个阶段: 第一阶段(…...
生成对抗网络(GAN):在图像生成和修复中的应用
文章目录 什么是生成对抗网络(GAN)?GAN在图像生成中的应用图像生成风格迁移 GAN在图像修复中的应用图像修复 拓展应用领域总结 🎉欢迎来到AIGC人工智能专栏~生成对抗网络(GAN):在图像生成和修复…...
扬杰科技携手企企通,召开SRM采购供应链协同系统项目启动会
近日,中国功率半导体领先企业扬州扬杰电子科技股份有限公司(以下简称“扬杰科技”)与企企通召开SRM采购供应链协同系统项目启动会,双方项目团队成员一同出席本次会议。 会上,双方就扬杰科技采购供应链管理平台项目的目…...
AtCoder Beginner Contest 318
目录 A - Full Moon B - Overlapping sheets C - Blue Spring D - General Weighted Max Matching E - Sandwiches F - Octopus A - Full Moon #include<bits/stdc.h> using namespace std; const int N1e65; typedef long long ll ; const int maxv4e65; typedef …...
《Python魔法大冒险》003 两个神奇的魔法工具
魔法师:小鱼,要开始编写魔法般的Python程序,我们首先需要两个神奇的工具:Python解释器和代码编辑器。 小鱼:这两个工具是做什么的? 魔法师:你可以把Python解释器看作是一个魔法棒,只要你向它说出正确的咒语,它就会为你施展魔法。 小鱼:那这个解释器和我之前用的电…...
每日一题-动态规划(从不同类型的物品中各挑选一个,使得最后花费总和等于1000)
四种类型的物品,每一种类型物品数量都是n,先要从每种类型的物品中挑选一件,使得最后花费总和等于1000 暴力做法10000^4 看到花费总和是1000,很小且固定的数字,肯定有玄机,从这里想应该是用dp,不…...
2023-9-3 试除法判定质数
题目链接:试除法判定质数 #include <iostream>using namespace std;bool is_prime(int n) {if(n < 2) return false;for(int i 2; i < n / i; i){if(n % i 0) return false;}return true; }int main() {int n;cin >> n;while(n--){int x;cin &g…...
【Apollo学习笔记】——规划模块TASK之RULE_BASED_STOP_DECIDER
文章目录 前言RULE_BASED_STOP_DECIDER相关配置RULE_BASED_STOP_DECIDER总体流程StopOnSidePassCheckClearDoneCheckSidePassStopIsPerceptionBlockedIsClearToChangeLaneCheckSidePassStopBuildStopDecisionELSE:涉及到的一些其他函数NormalizeAngleSelfRotate CheckLaneChang…...
【SpringBoot】最基础的项目架构(SpringBoot+Mybatis-plus+lombok+knife4j+hutool)
汝之观览,吾之幸也! 从本文开始讲下项目中用到的一些框架和技术,最基本的框架使用的是SpringBoot(2.5.10)Mybatis-plus(3.5.3.2)lombok(1.18.28)knife4j(3.0.3)hutool(5.8.21),可以做到代码自动生成,满足最基本的增删查改。 一、新…...
RNN 单元:分析 GRU 方程与 LSTM,以及何时选择 RNN 而不是变压器
一、说明 深度学习往往感觉像是在雪山上找到自己的道路。拥有坚实的原则会让你对做出决定更有信心。我们都去过那里 在上一篇文章中,我们彻底介绍并检查了 LSTM 单元的各个方面。有人可能会争辩说,RNN方法已经过时了,研究它们是没有意义的。的…...
Linux音频了解
ALPHA I.MX6U 开发板支持音频,板上搭载了音频编解码芯片 WM8960,支持播放以及录音功能! 本章将会讨论如下主题内容。 ⚫ Linux 下 ALSA 框架概述; ⚫ alsa-lib 库介绍; ⚫ alsa-lib 库移植; ⚫ alsa-l…...
7.4.分块查找
一.分块查找的算法思想: 1.实例: 以上述图片的顺序表为例, 该顺序表的数据元素从整体来看是乱序的,但如果把这些数据元素分成一块一块的小区间, 第一个区间[0,1]索引上的数据元素都是小于等于10的, 第二…...
Ubuntu系统下交叉编译openssl
一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机:Ubuntu 20.04.6 LTSHost:ARM32位交叉编译器:arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...
React Native 导航系统实战(React Navigation)
导航系统实战(React Navigation) React Navigation 是 React Native 应用中最常用的导航库之一,它提供了多种导航模式,如堆栈导航(Stack Navigator)、标签导航(Tab Navigator)和抽屉…...
React第五十七节 Router中RouterProvider使用详解及注意事项
前言 在 React Router v6.4 中,RouterProvider 是一个核心组件,用于提供基于数据路由(data routers)的新型路由方案。 它替代了传统的 <BrowserRouter>,支持更强大的数据加载和操作功能(如 loader 和…...
通过Wrangler CLI在worker中创建数据库和表
官方使用文档:Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后,会在本地和远程创建数据库: npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库: 现在,您的Cloudfla…...
循环冗余码校验CRC码 算法步骤+详细实例计算
通信过程:(白话解释) 我们将原始待发送的消息称为 M M M,依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)(意思就是 G ( x ) G(x) G(x) 是已知的)࿰…...
STM32+rt-thread判断是否联网
一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...
CentOS下的分布式内存计算Spark环境部署
一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架,相比 MapReduce 具有以下核心优势: 内存计算:数据可常驻内存,迭代计算性能提升 10-100 倍(文档段落:3-79…...
高危文件识别的常用算法:原理、应用与企业场景
高危文件识别的常用算法:原理、应用与企业场景 高危文件识别旨在检测可能导致安全威胁的文件,如包含恶意代码、敏感数据或欺诈内容的文档,在企业协同办公环境中(如Teams、Google Workspace)尤为重要。结合大模型技术&…...
CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云
目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...
