[PyTorch][chapter 54][Variational Auto-Encoder 实战]
前言:

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器
其训练效果如下

训练的过程中要注意调节forward 中的kle ,调参。
整个工程两个文件:
vae.py
main.py
目录:
- vae
- main
一 vae
文件名: vae.py
作用: Variational Autoencoders (VAE)
训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE),它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。
variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023@author: chengxf2
"""import torch
from torch import nn#ae: AutoEncoderclass VAE(nn.Module):def __init__(self,hidden_size=20):super(VAE, self).__init__()self.encoder = nn.Sequential(nn.Linear(in_features=784, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=hidden_size),nn.ReLU())# hidden [batch_size, 10]h_dim = int(hidden_size/2)self.hDim = h_dimself.decoder = nn.Sequential(nn.Linear(in_features=h_dim, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=784),nn.Sigmoid())def forward(self, x):'''param x:[batch, 1,28,28]return '''batchSz= x.size(0)#flattenx = x.view(batchSz, 784)#encoderh= self.encoder(x)#在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigmau, sigma = h.chunk(2,dim=1)#Reparameterize trick:#randn_like:产生一个正太分布 ~ N(0,1)#h.shape [batchSize,self.hDim]h = u+sigma* torch.randn_like(sigma)#kld :1e-8 防止sigma 平方为0kld = 0.5*torch.sum(torch.pow(u,2)+torch.pow(sigma,2)-torch.log(1e-8+torch.pow(sigma,2))-1)#MSE loss 是平均loss, 所以kld 也要算一个平均值kld = kld/(batchSz*32*32)xHat = self.decoder(h)#reshapexHat = xHat.view(batchSz,1,28,28)return xHat,kld
二 main
文件名: main.py
作用: 训练,测试数据集
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023@author: chengxf2
"""import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdomdef main():batchNum = 32lr = 1e-3epochs = 20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")torch.manual_seed(1234)viz = visdom.Visdom()viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))tf= transforms.Compose([ transforms.ToTensor()])mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)global_step =0model =VAE().to(device)criteon = nn.MSELoss().to(device) #损失函数optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则print("\n ----main-----")for epoch in range(epochs):start = time.perf_counter()for step ,(x,y) in enumerate(train_data):#[b,1,28,28]x = x.to(device)x_hat,kld = model(x)loss = criteon(x_hat, x)if kld is not None:elbo = -loss -1.0*kldloss = -elbo#backpropoptimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1end = time.perf_counter() interval = int(end - start)print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())x,target = iter(test_data).next()x = x.to(device)with torch.no_grad():x_hat,kld = model(x)tip = 'hat'+str(epoch)viz.images(x,nrow=8, win='x',opts=dict(title='x'))viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))if __name__ == '__main__':main()
参考:
课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili
相关文章:
[PyTorch][chapter 54][Variational Auto-Encoder 实战]
前言: 这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器 其训练效果如下 训练的过程中要注意调节forward 中的kle ,调参。 整个工程两个文件: vae.py main.py 目录: vae main 一 vae 文件名: vae…...
Java实现HTTP的上传与下载
相信很多人对于java文件下载的过程都存在一些疑惑,比如下载上传文件会不会占用vm内存,上传/下载大文件会不会导致oom。下面从字节流的角度看下载/上传的实现,可以更加深入理解文件的上传和下载功能。 文件下载 首先明确,文件下载…...
VPG算法
VPG算法 前言 首先来看经典的策略梯度REINFORCE算法: 在REINFORCE中,每次采集一个episode的轨迹,计算每一步动作的回报 G t G_t Gt,与动作概率对数相乘,作为误差反向传播,有以下几个特点: …...
docker 笔记5:redis 集群分布式存储案例
尚硅谷Docker实战教程(docker教程天花板)_哔哩哔哩_bilibili 目录 1.cluster(集群)模式-docker版哈希槽分区进行亿级数据存储 1.1面试题 1.1.1 方案1 哈希取余分区 1.1.2 方案2 一致性哈希算法分区 原理 优点 一致性哈希算法的容错性 一致性…...
【Vue2】 axios库
网络请求库-axios库 认识Axios库为什么选择Axios库安装Axios axios发送请求常见的配置选项简单请求可以给Axios设置公共的基础配置发送多个请求 axios创建实例为什么要创建axios的实例 axios的拦截器请求拦截器响应拦截器 axios请求封装 认识Axios库 为什么选择Axios库 在游览…...
云计算 - 百度AIStudio使用小结
云计算 - 百度AIStudio使用小结 前言 本文以ffmpeg处理视频为例,小结一下AI Studio的使用体验及一些避坑技巧。 算力获得 免费的算力获得方式为:每日登录后运行一个项目(只需要点击运行,不需要真正运行)即可获得8小…...
刷新你对Redis持久化的认知
认识持久化 redis是一个内存数据库,数据存储到内存中。而内存的数据是不持久的,要想做到持久化,就需要让redis把数据存储到硬盘上。因此redis既要在内存上存储一份数据,还要在硬盘上存储一份数据。这样这两份数据在理论上是完全相…...
Greenplum-最佳实践小结
注:本文翻译自https://docs.vmware.com/en/VMware-Greenplum/7/greenplum-database/best_practices-logfiles.html 数据模型 Greenplum数据库是一个分析型MPP无共享数据库。该模型与高度规范化/事务性的SMP数据库明显不同。Greenplum数据库使用适合MPP分析处理的非…...
从Gamma空间改为Linear空间会导致性能下降吗
1)从Gamma空间改为Linear空间会导致性能下降吗 2)如何处理没有使用Unity Ads却收到了GooglePlay平台的警告 3)C#端如何处理xLua在执行DoString时候死循环 4)Texture2DArray相关 这是第350篇UWA技术知识分享的推送,精选…...
双轨制的发展,弊端和前景
双轨制是一种经济体制,指两种不同的规则或机制并行运行,以适应不同的市场或客户需求。双轨制最早出现在中国的改革开放中,是从计划经济向市场经济过渡的一种渐进式改革方式。 双轨制的发展可以分为三个阶段: 第一阶段(…...
生成对抗网络(GAN):在图像生成和修复中的应用
文章目录 什么是生成对抗网络(GAN)?GAN在图像生成中的应用图像生成风格迁移 GAN在图像修复中的应用图像修复 拓展应用领域总结 🎉欢迎来到AIGC人工智能专栏~生成对抗网络(GAN):在图像生成和修复…...
扬杰科技携手企企通,召开SRM采购供应链协同系统项目启动会
近日,中国功率半导体领先企业扬州扬杰电子科技股份有限公司(以下简称“扬杰科技”)与企企通召开SRM采购供应链协同系统项目启动会,双方项目团队成员一同出席本次会议。 会上,双方就扬杰科技采购供应链管理平台项目的目…...
AtCoder Beginner Contest 318
目录 A - Full Moon B - Overlapping sheets C - Blue Spring D - General Weighted Max Matching E - Sandwiches F - Octopus A - Full Moon #include<bits/stdc.h> using namespace std; const int N1e65; typedef long long ll ; const int maxv4e65; typedef …...
《Python魔法大冒险》003 两个神奇的魔法工具
魔法师:小鱼,要开始编写魔法般的Python程序,我们首先需要两个神奇的工具:Python解释器和代码编辑器。 小鱼:这两个工具是做什么的? 魔法师:你可以把Python解释器看作是一个魔法棒,只要你向它说出正确的咒语,它就会为你施展魔法。 小鱼:那这个解释器和我之前用的电…...
每日一题-动态规划(从不同类型的物品中各挑选一个,使得最后花费总和等于1000)
四种类型的物品,每一种类型物品数量都是n,先要从每种类型的物品中挑选一件,使得最后花费总和等于1000 暴力做法10000^4 看到花费总和是1000,很小且固定的数字,肯定有玄机,从这里想应该是用dp,不…...
2023-9-3 试除法判定质数
题目链接:试除法判定质数 #include <iostream>using namespace std;bool is_prime(int n) {if(n < 2) return false;for(int i 2; i < n / i; i){if(n % i 0) return false;}return true; }int main() {int n;cin >> n;while(n--){int x;cin &g…...
【Apollo学习笔记】——规划模块TASK之RULE_BASED_STOP_DECIDER
文章目录 前言RULE_BASED_STOP_DECIDER相关配置RULE_BASED_STOP_DECIDER总体流程StopOnSidePassCheckClearDoneCheckSidePassStopIsPerceptionBlockedIsClearToChangeLaneCheckSidePassStopBuildStopDecisionELSE:涉及到的一些其他函数NormalizeAngleSelfRotate CheckLaneChang…...
【SpringBoot】最基础的项目架构(SpringBoot+Mybatis-plus+lombok+knife4j+hutool)
汝之观览,吾之幸也! 从本文开始讲下项目中用到的一些框架和技术,最基本的框架使用的是SpringBoot(2.5.10)Mybatis-plus(3.5.3.2)lombok(1.18.28)knife4j(3.0.3)hutool(5.8.21),可以做到代码自动生成,满足最基本的增删查改。 一、新…...
RNN 单元:分析 GRU 方程与 LSTM,以及何时选择 RNN 而不是变压器
一、说明 深度学习往往感觉像是在雪山上找到自己的道路。拥有坚实的原则会让你对做出决定更有信心。我们都去过那里 在上一篇文章中,我们彻底介绍并检查了 LSTM 单元的各个方面。有人可能会争辩说,RNN方法已经过时了,研究它们是没有意义的。的…...
Linux音频了解
ALPHA I.MX6U 开发板支持音频,板上搭载了音频编解码芯片 WM8960,支持播放以及录音功能! 本章将会讨论如下主题内容。 ⚫ Linux 下 ALSA 框架概述; ⚫ alsa-lib 库介绍; ⚫ alsa-lib 库移植; ⚫ alsa-l…...
别再乱用分支了!Flowable四种网关(排他/并行/包容/事件)实战选型指南
Flowable四大网关实战选型:从混乱到精准的决策艺术当你在设计一个请假审批流程时,是否遇到过这样的困惑:部门经理审批后需要同时通知HR和财务,但某些特殊情况下又需要跳过财务直接归档?这种看似简单的业务需求…...
Vue3 图片标框功能实现方案
基于 Vue3 组合式 API 的图片标框(画框、标注、选框)完整实现,核心逻辑封装在 GetBoxes 组件里,复制就能用 一、功能说明 ✅ 在图片上鼠标拖拽画矩形框 ✅ 实时显示框坐标(x, y, width, height) ✅ 支持多…...
[智能体-81]:工程化智能体 = 模型做脑力拆解 + 框架做流程落地。前者是决策者,后者是管理者,tools/function call是内部员工;mcp server是外部资源;
一、全角色人设 & 对应技术组件角色定位对应技术模块核心职责决策者(脑力大脑)大模型 LLM理解目标、任务拆解、逻辑判断、分支决策、内容生成,负责 “想方案、定步骤”管理者(流程总管)智能体编排框架(…...
上线前最后一道防线,DeepSeek代码审查如何帮你拦截87%的CVE类缺陷?
更多请点击: https://intelliparadigm.com 第一章:上线前最后一道防线,DeepSeek代码审查如何帮你拦截87%的CVE类缺陷? 在软件交付生命周期末期,传统人工代码审计与通用SAST工具常因误报率高、上下文理解弱而漏检高危漏…...
Unity动态自然系统:Forest Environment-Dynamic Nature深度解析
1. 这不是“贴图堆砌”,而是自然系统级建模:Forest Environment-Dynamic Nature 的真实定位你有没有试过在Unity里拖进几棵树、铺点草、加个天空盒,然后发现场景像一张静止的风景明信片——风不动、叶不摇、雨不落、雾不散?我做过…...
超维计算(HDC)原理与ScalableHD架构优化实践
1. 超维计算(HDC)基础解析超维计算(Hyperdimensional Computing, HDC)是一种受大脑信息处理机制启发的计算范式,其核心思想是用高维随机向量(通常称为超向量或HV)来表示和处理信息。与传统神经网…...
Hindsight测试策略:单元测试、集成测试和端到端测试
Hindsight测试策略:单元测试、集成测试和端到端测试 【免费下载链接】hindsight Hindsight: Agent Memory That Learns 项目地址: https://gitcode.com/GitHub_Trending/hindsight2/hindsight Hindsight作为一款专注于Agent Memory的开源项目,其可…...
基于Netburner NANO54415构建工业级嵌入式Web服务器:从硬件选型到广域监控实战
1. 项目概述:一个为广域与本地监控而生的嵌入式Web服务器如果你正在寻找一个能部署在野外、工厂角落或者任何需要远程数据采集与控制场景下的嵌入式Web服务器方案,并且对市面上那些要么性能孱弱、要么开发门槛极高的开发板感到厌倦,那么这个基…...
LVGL多页面开发避坑:用内部Timer替代轮询,解决页面切换时的内存踩踏问题
LVGL多页面开发中的内存安全实践:用Timer机制替代轮询的工程解决方案 在嵌入式UI开发中,LVGL因其轻量级和跨平台特性成为热门选择。但当项目复杂度提升到多页面交互时,开发者往往会遇到一个棘手问题:如何在频繁切换页面的同时保证…...
百度深度学习研究院的“叛将“,带着一颗芯片改变了中国智能驾驶——地平线余凯,从ImageNet冠军到征程出货1000万
大家好,我是写代码的篮球球痴。这篇文章跟我自己有点关系——我开的是理想汽车。理想的智驾系统 AD Pro,搭载的就是地平线征程 5 芯片。2026 年 1 月理想 AD Pro 4.0 推送,基于单颗征程 6M 实现了城市 NOA——这是行业里第一个用单颗 128TOPS…...
