当前位置: 首页 > news >正文

Pytorch Advanced(二) Variational Auto-Encoder

自编码说白了就是一个特征提取器,也可以看作是一个降维器。下面找了一张很丑的图来说明自编码的过程。

自编码分为压缩和解码两个过程。从图中可以看出来,压缩过程就是将一组数据特征进行提取, 得到更深层次的特征。解码的过程就是利用之前的深层次特征再还原成为原来的数据特征。那么如何保证从压缩到解码两部分,原数据和解码数据保持一致呢?这就是要训练的过程。

如何理解降维?如果压缩的过程是卷积,维度可以根据核的个数变化,特征维度因此而改变。


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')sample_dir = 'samples'
if not os.path.exists(sample_dir):os.makedirs(sample_dir)
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3dataset = torchvision.datasets.MNIST(root='../../data',train=True,transform=transforms.ToTensor(),download=True)# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size, shuffle=True)

模型搭建:这里搭建的是一个变分自编码,Variational Autoencoder

那么变分自编码是为了解决什么问题呢? ——- 其主要思想还是希望学习隐层变量,并将其用来表示原始数据,但是它加另一个条件, 即隐层变量能学习原始数据的分布, 并反过来生产一些和原始数据相似的数据(这有啥用?—-可用于图片修复,让图片按训练集的数据分布变化)。

变分自编码 (Variational Autoencoder) 为了让隐层抓住输入数据特性, 而不是简单的输出数据=输入数据,他在隐层中加入随机噪声(单位高斯噪声)(这个过程也叫reparametrize),以确保隐层能较好抽象输入数据特点。

代码中怎么做的呢?

1、编码过程中我们保存了第二层线性层的输出。其中第二层包含有fc2与fc3两部分,他们是并联的。

2、给隐藏层加入随机噪声,作为解码的输入

class VAE(nn.Module):def __init__(self, image_size=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.fc1 = nn.Linear(image_size, h_dim)self.fc2 = nn.Linear(h_dim, z_dim)self.fc3 = nn.Linear(h_dim, z_dim)self.fc4 = nn.Linear(z_dim, h_dim)self.fc5 = nn.Linear(h_dim, image_size)def encode(self, x):h = F.relu(self.fc1(x))return self.fc2(h), self.fc3(h)def reparameterize(self, mu, log_var):std = torch.exp(log_var/2)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = F.relu(self.fc4(z))return F.sigmoid(self.fc5(h))def forward(self, x):mu, log_var = self.encode(x)z = self.reparameterize(mu, log_var)x_reconst = self.decode(z)return x_reconst, mu, log_var

训练:由于训练中加入了噪声,所以损失值的结构也因此改变。一部分来源于解码内容核原内容的相似度,另一部分是kl_div,具体是什么意义需查看论文。

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# Start training
for epoch in range(num_epochs):for i, (x, _) in enumerate(data_loader):# Forward passx = x.to(device).view(-1, image_size)x_reconst, mu, log_var = model(x)# Compute reconstruction loss and kl divergence# For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())# Backprop and optimizeloss = reconst_loss + kl_divoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 10 == 0:print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))with torch.no_grad():# Save the sampled imagesz = torch.randn(batch_size, z_dim).to(device)out = model.decode(z).view(-1, 1, 28, 28)save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))# Save the reconstructed imagesout, _, _ = model(x)x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

模型训练完成了之后该如何使用这个模型呢?

model.decode()是一个解码的过程,我们给他一个随机的中间特征z就可以输出一个数字图片了。

z = torch.randn(1,z_dim).to(device)
out = model.decode(z)
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

有了随机的一张图片之后,我们把他完整的放入模型中,生成了和输入相似的一张图片,也没看出来是修复了图像......

out,_,_ = model(out) 
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

相关文章:

Pytorch Advanced(二) Variational Auto-Encoder

自编码说白了就是一个特征提取器,也可以看作是一个降维器。下面找了一张很丑的图来说明自编码的过程。 自编码分为压缩和解码两个过程。从图中可以看出来,压缩过程就是将一组数据特征进行提取, 得到更深层次的特征。解码的过程就是利用之前的…...

Flask 使用 JWT(三)flask-jwt-extended

如果想要在 flask 中使用 JWT ,推荐使用 flask-jwt-extended 插件。 使用 pip 安装这个扩展插件的最简单方法是: pip install flask-jwt-extended基本使用 在接下来的案例中,我们看一下基本使用。我们可以使用 create_access_token() 函数用来生成实际的 JWT token。@jwt_r…...

堆与栈的区别

OVERVIEW 栈与堆的区别一、程序内存分区中的堆与栈1.栈2.堆3.堆&栈 二、数据结构中的堆与栈1.栈2.堆 三、堆的深入1.堆插入2.堆删除:3.堆建立:4.堆排序:5.堆实现优先队列:6.堆与栈的相关练习 栈与堆的区别 自整理,…...

OpenWrt kernel install分析(2)

一. 前言 接下来分析make -C image compile install TARGET_BUILD。 二. Makefile分析 1. 命令首先运行target/linux/mediatek/image/Makefile,该文件内容如下: target/linux/mediatek/image/Makefile: include $(TOPDIR)/rules.mk include $(INCLUDE_DIR)/image.…...

【计算机网络】传输层协议——TCP(下)

文章目录 1. 三次握手三次握手的本质是建立链接,什么是链接?整体过程三次握手过程中报文丢失问题为什么2次握手不可以?为什么要三次握手? 2. 四次挥手整体过程为什么要等待2MSL 3. 流量控制4. 滑动窗口共识滑动窗口的一般情况理解…...

Vue前端页面打印

前端依赖10-插件"print-js": “^1.6.0” 一:简介 print-js 是一个 Vue.js 插件,用于在 Vue.js 项目中实现打印功能。它依赖于 print-js 库,所以需要安装这个库。 能实现以下功能: PDF打印(默认&#xff…...

Visual Studio将C#项目编译成EXE可执行程序

经常看文章时会收获不少实用工具,有的在github上是编译好的,有的则是未编译的项目文件。所以经常会使用Visual Studio编译项目文件成exe可执行程序,以下为编译的流程。 第一步,从github上下载项目文件,举个例子&#…...

git把某一次commit修改过的文件打包导出(git)

1、使用命令把修改的文件打包导出:打包某次commit: git diff-tree -r --no-commit-id --name-only f4710c4a32975904b00609f3145c709f31392140 | xargs tar -rf xxx_1.1.tar 2、使用命令把某次节点后的文件导出: window 下: git diff f4710c4a32975904b00609f3145c709f31392…...

Vue3 Ajax(axios)异步

文章目录 Vue3 Ajax(axios)异步1. 基础1.1 安装Ajax1.2 使用方法1.3 浏览器支持情况 2. GET方法2.1 参数传递2.2 实例 3. POST方法4. 执行多个并发请求5. axios API5.1 传递配置创建请求5.2 请求方法的别名5.3 并发5.4 创建实例5.5 实例方法5.6 请求配置项5.7 响应结构5.8 配置…...

idea2023全量方法debug

为什么要全量debug 刚上手项目或者研读开源项目源码的时候,我们对项目的结构,尤其是功能链路非常陌生,想要debug根本不知道断点打在哪,光靠文件名类名或者方法名去猜也不是个事。这时候只要配置一下全量debug模式,就能…...

Docker镜像解析获取Dockerfile文件

01、概述 当涉及到容器镜像的安全时,特别是在出现镜像投毒引发的安全事件时,追溯镜像的来源和解析Dockerfile文件是应急事件处理的关键步骤。在这篇博客中,我们将探讨如何从镜像解析获取Dockerfile文件,这对容器安全至关重要。 02…...

使用maven命令打jar包

参考:https://blog.csdn.net/qq_27525611/article/details/123487255 https://blog.csdn.net/qq_35860138/article/details/82701919 小伙伴给我的项目自己尝试命令行打包遇到的坑,简单记录下 // 打包(1.8环境下打的,17会报错&…...

【多线程】死锁 详解

死锁 一. 死锁是什么二. 死锁的场景1. 一个线程一把锁2. 两个线程两把锁3. N 个线程 M 把锁 三. 死锁产生的四个必要条件四. 如何避免死锁 一. 死锁是什么 死锁是这样一种情形: 多个线程同时被阻塞,因为每个进程都在等其他线程释放某些资源,…...

成考[专升本政治]科目必背知识点

1. 马克思主义哲学研究的对象是:关于自然、社会、思维发展的一般规律。 2. 对待马克思主义的科学态度是:坚持和发展。 3. 物质的唯一特性是客观实在性。这里的客观实在是指:不以人的意志为转移。 4. 在实际工作中,要注意掌握…...

spring boot 使用AOP+自定义注解+反射实现操作日志记录修改前数据和修改后对比数据,并保存至日志表

一、添加aop starter依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-aop</artifactId> </dependency>二&#xff1a;自定义字段翻译注解。&#xff08;修改功能时&#xff0c;需要显示如…...

【深度学习】Pytorch 系列教程(二):PyTorch数据结构:1、Tensor(张量): GPU加速(GPU Acceleration)

目录 一、前言 二、实验环境 三、PyTorch数据结构 0、分类 1、张量&#xff08;Tensor&#xff09; 1. 维度&#xff08;Dimensions&#xff09; 2. 数据类型&#xff08;Data Types&#xff09; 3. GPU加速&#xff08;GPU Acceleration&#xff09; 一、前言 ChatGP…...

多线程|多进程|高并发网络编程

一.多进程并发服务器 多进程并发服务器是一种经典的服务器架构&#xff0c;它通过创建多个子进程来处理客户端连接&#xff0c;从而实现并发处理多个客户端请求的能力。 概念&#xff1a; 服务器启动时&#xff0c;创建主进程&#xff0c;并绑定监听端口。当有客户端连接请求…...

云计算——ACA学习 云计算分类

作者简介&#xff1a;一名云计算网络运维人员、每天分享网络与运维的技术与干货。 公众号&#xff1a;网络豆 座右铭&#xff1a;低头赶路&#xff0c;敬事如仪 个人主页&#xff1a; 网络豆的主页​​​​​ 目录 写在前面 前期回顾 本期介绍 一.云计算分类 1.公有云…...

3 分钟,带你了解低代码开发

一、低代码平台存在的意义 传统软件开发交付链中&#xff0c;需求经过3次传递&#xff0c;用户→业务→架构师→开发&#xff0c;每一层传递都可能使需求失真&#xff0c;导致最终交付的功能返工。 业务的变化促使软件开发过程不断更新、迭代和演进&#xff0c;而低代码开发即是…...

小白学Unity03-太空漫游游戏脚本,控制飞船移动旋转

首先搭建好太阳系以及飞机的场景 需要用到3个脚本 1.控制飞机移动旋转 2.控制摄像机LookAt朝向飞机和差值平滑跟踪飞机 3.控制各个星球自转以及围绕太阳旋转&#xff08;rotate()和RotateAround()&#xff09; 1.控制飞机移动旋转的脚本 using System.Collections; using…...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

C++使用 new 来创建动态数组

问题&#xff1a; 不能使用变量定义数组大小 原因&#xff1a; 这是因为数组在内存中是连续存储的&#xff0c;编译器需要在编译阶段就确定数组的大小&#xff0c;以便正确地分配内存空间。如果允许使用变量来定义数组的大小&#xff0c;那么编译器就无法在编译时确定数组的大…...

使用Spring AI和MCP协议构建图片搜索服务

目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式&#xff08;本地调用&#xff09; SSE模式&#xff08;远程调用&#xff09; 4. 注册工具提…...

Mysql8 忘记密码重置,以及问题解决

1.使用免密登录 找到配置MySQL文件&#xff0c;我的文件路径是/etc/mysql/my.cnf&#xff0c;有的人的是/etc/mysql/mysql.cnf 在里最后加入 skip-grant-tables重启MySQL服务 service mysql restartShutting down MySQL… SUCCESS! Starting MySQL… SUCCESS! 重启成功 2.登…...

Vue 3 + WebSocket 实战:公司通知实时推送功能详解

&#x1f4e2; Vue 3 WebSocket 实战&#xff1a;公司通知实时推送功能详解 &#x1f4cc; 收藏 点赞 关注&#xff0c;项目中要用到推送功能时就不怕找不到了&#xff01; 实时通知是企业系统中常见的功能&#xff0c;比如&#xff1a;管理员发布通知后&#xff0c;所有用户…...

【threejs】每天一个小案例讲解:创建基本的3D场景

代码仓 GitHub - TiffanyHoo/three_practices: Learning three.js together! 可自行clone&#xff0c;无需安装依赖&#xff0c;直接liver-server运行/直接打开chapter01中的html文件 运行效果图 知识要点 核心三要素 场景&#xff08;Scene&#xff09; 使用 THREE.Scene(…...

宠物车载安全座椅市场报告:解读行业趋势与投资前景

一、什么是宠物车载安全座椅&#xff1f; 宠物车载安全座椅是一种专为宠物设计的车内固定装置&#xff0c;旨在保障宠物在乘车过程中的安全性与舒适性。它通常由高强度材料制成&#xff0c;具备良好的缓冲性能&#xff0c;并可通过安全带或ISOFIX接口固定于车内。 近年来&…...

Linux【5】-----编译和烧写Linux系统镜像(RK3568)

参考&#xff1a;讯为 1、文件系统 不同的文件系统组成了&#xff1a;debian、ubuntu、buildroot、qt等系统 每个文件系统的uboot和kernel是一样的 2、源码目录介绍 目录 3、正式编译 编译脚本build.sh 帮助内容如下&#xff1a; Available options: uboot …...

Continue 开源 AI 编程助手框架深度分析

Continue 开源 AI 编程助手框架深度分析 一、项目简介 Continue 是一个模块化、可配置、跨平台的开源 AI 编程助手框架&#xff0c;目标是让开发者能在本地或云端环境中&#xff0c;快速集成和使用自定义的 LLM 编程辅助工具。它通过支持 VS Code 与 JetBrains 等主流 IDE 插件…...

一、ES6-let声明变量【解刨分析最详细】

一、块级作用域 { let Tim"Tim是靓仔&#xff01;" } console.log("Tim:",Tim) 打印结果&#xff1a;Tim未进行任何定义&#xff01; 原因&#xff1a;因为Tim定义再块级{}里面&#xff0c;它的声音Tim只服务于该块级里面。而打印结果是再块级外面&#…...