当前位置: 首页 > news >正文

[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言:

   
 

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器
其训练效果如下

 

训练的过程中要注意调节forward 中的kle ,调参。

整个工程两个文件:

    vae.py

   main.py

目录:

  1.      vae
  2.       main

一  vae

  文件名: vae.py

   作用:   Variational Autoencoders (VAE)

 训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE),它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。
variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023@author: chengxf2
"""import torch
from torch import nn#ae: AutoEncoderclass VAE(nn.Module):def __init__(self,hidden_size=20):super(VAE, self).__init__()self.encoder = nn.Sequential(nn.Linear(in_features=784, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=hidden_size),nn.ReLU())# hidden [batch_size, 10]h_dim = int(hidden_size/2)self.hDim = h_dimself.decoder = nn.Sequential(nn.Linear(in_features=h_dim, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=784),nn.Sigmoid())def forward(self, x):'''param x:[batch, 1,28,28]return '''batchSz= x.size(0)#flattenx = x.view(batchSz, 784)#encoderh= self.encoder(x)#在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigmau, sigma = h.chunk(2,dim=1)#Reparameterize trick:#randn_like:产生一个正太分布 ~ N(0,1)#h.shape [batchSize,self.hDim]h = u+sigma* torch.randn_like(sigma)#kld :1e-8 防止sigma 平方为0kld = 0.5*torch.sum(torch.pow(u,2)+torch.pow(sigma,2)-torch.log(1e-8+torch.pow(sigma,2))-1)#MSE loss 是平均loss, 所以kld 也要算一个平均值kld = kld/(batchSz*32*32)xHat =   self.decoder(h)#reshapexHat = xHat.view(batchSz,1,28,28)return xHat,kld

二 main

文件名: main.py

作用: 训练,测试数据集

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023@author: chengxf2
"""import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdomdef main():batchNum = 32lr = 1e-3epochs = 20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")torch.manual_seed(1234)viz = visdom.Visdom()viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))tf= transforms.Compose([ transforms.ToTensor()])mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)global_step =0model =VAE().to(device)criteon = nn.MSELoss().to(device) #损失函数optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则print("\n ----main-----")for epoch in range(epochs):start = time.perf_counter()for step ,(x,y) in enumerate(train_data):#[b,1,28,28]x = x.to(device)x_hat,kld = model(x)loss = criteon(x_hat, x)if kld is not None:elbo = -loss -1.0*kldloss = -elbo#backpropoptimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1end = time.perf_counter()    interval = int(end - start)print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())x,target = iter(test_data).next()x = x.to(device)with torch.no_grad():x_hat,kld = model(x)tip = 'hat'+str(epoch)viz.images(x,nrow=8, win='x',opts=dict(title='x'))viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))if __name__ == '__main__':main()

 参考:

 课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili

相关文章:

[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言: 这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器 其训练效果如下 训练的过程中要注意调节forward 中的kle ,调参。 整个工程两个文件: vae.py main.py 目录: vae main 一 vae 文件名: vae…...

Java实现HTTP的上传与下载

相信很多人对于java文件下载的过程都存在一些疑惑,比如下载上传文件会不会占用vm内存,上传/下载大文件会不会导致oom。下面从字节流的角度看下载/上传的实现,可以更加深入理解文件的上传和下载功能。 文件下载 首先明确,文件下载…...

VPG算法

VPG算法 前言 首先来看经典的策略梯度REINFORCE算法: 在REINFORCE中,每次采集一个episode的轨迹,计算每一步动作的回报 G t G_t Gt​,与动作概率对数相乘,作为误差反向传播,有以下几个特点: …...

docker 笔记5:redis 集群分布式存储案例

尚硅谷Docker实战教程(docker教程天花板)_哔哩哔哩_bilibili 目录 1.cluster(集群)模式-docker版哈希槽分区进行亿级数据存储 1.1面试题 1.1.1 方案1 哈希取余分区 1.1.2 方案2 一致性哈希算法分区 原理 优点 一致性哈希算法的容错性 一致性…...

【Vue2】 axios库

网络请求库-axios库 认识Axios库为什么选择Axios库安装Axios axios发送请求常见的配置选项简单请求可以给Axios设置公共的基础配置发送多个请求 axios创建实例为什么要创建axios的实例 axios的拦截器请求拦截器响应拦截器 axios请求封装 认识Axios库 为什么选择Axios库 在游览…...

云计算 - 百度AIStudio使用小结

云计算 - 百度AIStudio使用小结 前言 本文以ffmpeg处理视频为例,小结一下AI Studio的使用体验及一些避坑技巧。 算力获得 免费的算力获得方式为:每日登录后运行一个项目(只需要点击运行,不需要真正运行)即可获得8小…...

刷新你对Redis持久化的认知

认识持久化 redis是一个内存数据库,数据存储到内存中。而内存的数据是不持久的,要想做到持久化,就需要让redis把数据存储到硬盘上。因此redis既要在内存上存储一份数据,还要在硬盘上存储一份数据。这样这两份数据在理论上是完全相…...

Greenplum-最佳实践小结

注:本文翻译自https://docs.vmware.com/en/VMware-Greenplum/7/greenplum-database/best_practices-logfiles.html 数据模型 Greenplum数据库是一个分析型MPP无共享数据库。该模型与高度规范化/事务性的SMP数据库明显不同。Greenplum数据库使用适合MPP分析处理的非…...

从Gamma空间改为Linear空间会导致性能下降吗

1)从Gamma空间改为Linear空间会导致性能下降吗 2)如何处理没有使用Unity Ads却收到了GooglePlay平台的警告 3)C#端如何处理xLua在执行DoString时候死循环 4)Texture2DArray相关 这是第350篇UWA技术知识分享的推送,精选…...

双轨制的发展,弊端和前景

双轨制是一种经济体制,指两种不同的规则或机制并行运行,以适应不同的市场或客户需求。双轨制最早出现在中国的改革开放中,是从计划经济向市场经济过渡的一种渐进式改革方式。 双轨制的发展可以分为三个阶段: 第一阶段(…...

生成对抗网络(GAN):在图像生成和修复中的应用

文章目录 什么是生成对抗网络(GAN)?GAN在图像生成中的应用图像生成风格迁移 GAN在图像修复中的应用图像修复 拓展应用领域总结 🎉欢迎来到AIGC人工智能专栏~生成对抗网络(GAN):在图像生成和修复…...

扬杰科技携手企企通,召开SRM采购供应链协同系统项目启动会

近日,中国功率半导体领先企业扬州扬杰电子科技股份有限公司(以下简称“扬杰科技”)与企企通召开SRM采购供应链协同系统项目启动会,双方项目团队成员一同出席本次会议。 会上,双方就扬杰科技采购供应链管理平台项目的目…...

AtCoder Beginner Contest 318

目录 A - Full Moon B - Overlapping sheets C - Blue Spring D - General Weighted Max Matching E - Sandwiches F - Octopus A - Full Moon #include<bits/stdc.h> using namespace std; const int N1e65; typedef long long ll ; const int maxv4e65; typedef …...

《Python魔法大冒险》003 两个神奇的魔法工具

魔法师:小鱼,要开始编写魔法般的Python程序,我们首先需要两个神奇的工具:Python解释器和代码编辑器。 小鱼:这两个工具是做什么的? 魔法师:你可以把Python解释器看作是一个魔法棒,只要你向它说出正确的咒语,它就会为你施展魔法。 小鱼:那这个解释器和我之前用的电…...

每日一题-动态规划(从不同类型的物品中各挑选一个,使得最后花费总和等于1000)

四种类型的物品&#xff0c;每一种类型物品数量都是n&#xff0c;先要从每种类型的物品中挑选一件&#xff0c;使得最后花费总和等于1000 暴力做法10000^4 看到花费总和是1000&#xff0c;很小且固定的数字&#xff0c;肯定有玄机&#xff0c;从这里想应该是用dp&#xff0c;不…...

2023-9-3 试除法判定质数

题目链接&#xff1a;试除法判定质数 #include <iostream>using namespace std;bool is_prime(int n) {if(n < 2) return false;for(int i 2; i < n / i; i){if(n % i 0) return false;}return true; }int main() {int n;cin >> n;while(n--){int x;cin &g…...

【Apollo学习笔记】——规划模块TASK之RULE_BASED_STOP_DECIDER

文章目录 前言RULE_BASED_STOP_DECIDER相关配置RULE_BASED_STOP_DECIDER总体流程StopOnSidePassCheckClearDoneCheckSidePassStopIsPerceptionBlockedIsClearToChangeLaneCheckSidePassStopBuildStopDecisionELSE:涉及到的一些其他函数NormalizeAngleSelfRotate CheckLaneChang…...

【SpringBoot】最基础的项目架构(SpringBoot+Mybatis-plus+lombok+knife4j+hutool)

汝之观览&#xff0c;吾之幸也&#xff01; 从本文开始讲下项目中用到的一些框架和技术&#xff0c;最基本的框架使用的是SpringBoot(2.5.10)Mybatis-plus(3.5.3.2)lombok(1.18.28)knife4j(3.0.3)hutool(5.8.21),可以做到代码自动生成&#xff0c;满足最基本的增删查改。 一、新…...

RNN 单元:分析 GRU 方程与 LSTM,以及何时选择 RNN 而不是变压器

一、说明 深度学习往往感觉像是在雪山上找到自己的道路。拥有坚实的原则会让你对做出决定更有信心。我们都去过那里 在上一篇文章中&#xff0c;我们彻底介绍并检查了 LSTM 单元的各个方面。有人可能会争辩说&#xff0c;RNN方法已经过时了&#xff0c;研究它们是没有意义的。的…...

Linux音频了解

ALPHA I.MX6U 开发板支持音频&#xff0c;板上搭载了音频编解码芯片 WM8960&#xff0c;支持播放以及录音功能&#xff01; 本章将会讨论如下主题内容。 ⚫ Linux 下 ALSA 框架概述&#xff1b; ⚫ alsa-lib 库介绍&#xff1b; ⚫ alsa-lib 库移植&#xff1b; ⚫ alsa-l…...

label-studio的使用教程(导入本地路径)

文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络&#xff0c;将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具&#xff0c;支持 Chrome、Firefox、Safari 等主流浏览器&#xff0c;提供多语言 API&#xff08;Python、JavaScript、Java、.NET&#xff09;。它的特点包括&a…...

FastAPI 教程:从入门到实践

FastAPI 是一个现代、快速&#xff08;高性能&#xff09;的 Web 框架&#xff0c;用于构建 API&#xff0c;支持 Python 3.6。它基于标准 Python 类型提示&#xff0c;易于学习且功能强大。以下是一个完整的 FastAPI 入门教程&#xff0c;涵盖从环境搭建到创建并运行一个简单的…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口&#xff08;适配服务端返回 Token&#xff09; export const login async (code, avatar) > {const res await http…...

【电力电子】基于STM32F103C8T6单片机双极性SPWM逆变(硬件篇)

本项目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脉宽调制)电源模块,能够生成可调频率和幅值的正弦波交流电源输出。该项目适用于逆变器、UPS电源、变频器等应用场景。 供电电源 输入电压采集 上图为本设计的电源电路,图中 D1 为二极管, 其目的是防止正负极电源反接, …...

4. TypeScript 类型推断与类型组合

一、类型推断 (一) 什么是类型推断 TypeScript 的类型推断会根据变量、函数返回值、对象和数组的赋值和使用方式&#xff0c;自动确定它们的类型。 这一特性减少了显式类型注解的需要&#xff0c;在保持类型安全的同时简化了代码。通过分析上下文和初始值&#xff0c;TypeSc…...

Python实现简单音频数据压缩与解压算法

Python实现简单音频数据压缩与解压算法 引言 在音频数据处理中&#xff0c;压缩算法是降低存储成本和传输效率的关键技术。Python作为一门灵活且功能强大的编程语言&#xff0c;提供了丰富的库和工具来实现音频数据的压缩与解压。本文将通过一个简单的音频数据压缩与解压算法…...

Visual Studio Code 扩展

Visual Studio Code 扩展 change-case 大小写转换EmmyLua for VSCode 调试插件Bookmarks 书签 change-case 大小写转换 https://marketplace.visualstudio.com/items?itemNamewmaurer.change-case 选中单词后&#xff0c;命令 changeCase.commands 可预览转换效果 EmmyLua…...

实战设计模式之模板方法模式

概述 模板方法模式定义了一个操作中的算法骨架&#xff0c;并将某些步骤延迟到子类中实现。模板方法使得子类可以在不改变算法结构的前提下&#xff0c;重新定义算法中的某些步骤。简单来说&#xff0c;就是在一个方法中定义了要执行的步骤顺序或算法框架&#xff0c;但允许子类…...