当前位置: 首页 > news >正文

自编码器简单介绍—使用PyTorch库实现一个简单的自编码器,并使用MNIST数据集进行训练和测试

文章目录

    • 自编码器简单介绍
    • 什么是自编码器?
    • 自动编码器和卷积神经网络的区别?
    • 如何构建一个自编码器?
    • 如何训练自编码器?
    • 如何使用自编码器进行图像压缩?
    • 总结
    • 使用PyTorch构建简单的自动编码器
      • 第一步:导入库和数据集
      • 第二步:建立编码器和解码器
      • 第三步:定义损失函数和优化器
      • 第四步:训练自编码器模型
      • 第五步:测试自编码器模型
      • 第六部:对比重构结果

自编码器简单介绍

自编码器是一种无监督学习算法,用于学习数据中的特征,并将这些特征用于重构与输入相似的新数据。自编码器由编码器和解码器两部分组成,编码器用于将输入数据压缩到一个低维度的表示形式,解码器将该表示形式还原回输入数据的形式。自编码器可以应用于多种领域,例如图像处理、语音识别和自然语言处理等。

什么是自编码器?

自编码器是一种无监督学习算法,用于学习数据的压缩表示。它由两个部分组成:编码器和解码器。编码器将输入数据压缩成低维表示,解码器将这个低维表示重构成与原始输入尽可能接近的输出。
简单结构如下:

自动编码器和卷积神经网络的区别?

自动编码器和卷积神经网络(CNN)都是深度学习中常用的模型,但它们的目的和结构略有不同。

自动编码器是一种无监督学习模型,其目的是学习一个对输入数据进行压缩和解压缩的函数。自动编码器通常由两个部分组成:编码器和解码器。编码器将输入数据转换为潜在表示,解码器将潜在表示转换回原始数据。自动编码器的目标是最小化重构误差,即在解码器输出的数据与原始数据之间的差异。

与之不同,卷积神经网络是一种用于图像分类、目标检测、语音识别等任务的有监督学习模型。卷积神经网络通常由卷积层、池化层、全连接层等组成。卷积层可以捕捉输入数据的局部特征,池化层可以减少特征图的大小,全连接层可以将特征图转换为分类结果。

虽然自动编码器和卷积神经网络的目的和结构不同,但它们都可以用于特征提取。自动编码器可以学习输入数据的低维表示,卷积神经网络可以提取输入数据的局部特征。在某些任务中,这些特征可以作为输入传递到其他模型中,以提高任务的性能。

除了特征提取,自动编码器和卷积神经网络还有其他方面的不同之处。

自动编码器可以用于数据的降维和去噪。通过学习输入数据的低维表示,自动编码器可以将高维数据降低到更低的维度,从而简化数据。此外,自动编码器还可以在输入数据中去除噪声,因为它们的目标是在重构时最小化重构误差,这有助于滤除输入数据中的噪声。

卷积神经网络则更适合处理图像、语音、视频等具有空间结构的数据。卷积层可以捕捉输入数据的局部特征,这对于图像分类、目标检测等任务非常重要。此外,卷积神经网络还可以通过使用池化层来减少特征图的大小,从而在处理大型图像时降低计算成本。

总的来说,自动编码器和卷积神经网络是两种不同的深度学习模型,它们的应用场景和目标略有不同。但是,它们都是非常有用的工具,可以在各种任务中发挥重要作用。

如何构建一个自编码器?

首先,需要确定编码器和解码器的结构。编码器可以是多层感知器(MLP)或卷积神经网络(CNN),其中每一层都包含多个神经元,并通过非线性函数进行激活。解码器通常与编码器结构相对称,也可以是一个MLP或CNN。在训练自编码器时,将输入数据输入编码器并计算编码器输出,然后将其输入解码器并计算解码器输出。最终的目标是最小化解码器输出和原始输入之间的差异。

如何训练自编码器?

训练自编码器的关键是确定损失函数。通常使用均方误差(MSE)来计算解码器输出和原始输入之间的差异。MSE的计算方式如下:

$ 在公式中, y i y_i yi 是实际值, y i ^ \hat{y_i} yi^ 是预测值, N N N 是数据集中的样本数。
M S E = 1 N ∑ i = 1 N ( y i − y ^ i ) 2 \mathrm MSE=\frac{1}{N}\sum_{i=1}^N(y_i-\hat y_i)^2 MSE=N1i=1N(yiy^i)2
其中,N是样本数量, y i y_i yi是实际值, y ^ i \hat y_i y^i是解码器输出,即预测值。通过反向传播算法,可以计算编码器和解码器中所有参数的梯度,并使用梯度下降算法对参数进行更新。

如何使用自编码器进行图像压缩?

在用自编码器压缩图像时,将图像输入自编码器的编码器部分,得到低维表示,该表示可以看作是图像的压缩版本。可以通过解码器部分将该低维表示解码成原始图像。由于自编码器的解码器部分重构图像,因此可以使用自编码器进行图像压缩。

总结

以上是自编码器的简单入门教程。自编码器是一种无监督学习算法,可以用于学习数据的压缩表示。在训练自编码器时需要确定编码器和解码器的结构以及损失函数,并使用梯度下降算法对参数进行更新。自编码器可以用于图像压缩等任务。

使用PyTorch构建简单的自动编码器

在本教程中,我们将使用PyTorch库实现一个简单的自编码器,并使用MNIST数据集进行训练和测试。

第一步:导入库和数据集

首先,我们需要导入必要的库并加载MNIST数据集。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 检查是否安装了CUDA,并且CUDA是否适用于你的GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 超参数设置
batch_size = 64
learning_rate = 1e-3
num_epochs = 10# 加载MNIST手写数字数据集
train_dataset = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor())# 创建数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

第二步:建立编码器和解码器

接下来,我们需要建立编码器和解码器模型。

# 定义一个自编码器的类
class Autoencoder(nn.Module):def __init__(self):super(Autoencoder, self).__init__()# 编码器部分self.encoder = nn.Sequential(nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 输入1通道,输出16通道,3x3卷积,步长为2,padding为1nn.ReLU(),nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 输入16通道,输出32通道,3x3卷积,步长为2,padding为1nn.ReLU(),nn.Conv2d(32, 64, kernel_size=7),                      # 输入32通道,输出64通道,7x7卷积)# 解码器部分self.decoder = nn.Sequential(nn.ConvTranspose2d(64, 32, kernel_size=7),             # 输入64通道,输出32通道,7x7卷积转置nn.ReLU(),nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # 输入32通道,输出16通道,3x3卷积转置,步长为2,padding为1,输出padding为1nn.ReLU(),nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 输入16通道,输出1通道,3x3卷积转置,步长为2,padding为1,输出padding为1nn.Sigmoid())def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x# 创建自编码器对象并将模型移动到GPU
model = Autoencoder().to(device)

第三步:定义损失函数和优化器

定义损失函数和优化器如下。

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

第四步:训练自编码器模型

现在可以使用MNIST数据集训练自编码器模型。

# 训练自编码器
for epoch in range(num_epochs):for data in train_loader:img, _ = dataimg = img.to(device)# 前向传播output = model(img)loss = criterion(output, img)# 反向传播和优化器优化optimizer.zero_grad()loss.backward()optimizer.step()print("Epoch[{}/{}], loss:{:.4f}".format(epoch+1, num_epochs, loss.data))

第五步:测试自编码器模型

# 测试自编码器
model.eval()
total_loss = 0
with torch.no_grad():for data in test_loader:img, _ = dataimg = img.to(device)output = model(img)loss = criterion(output, img)total_loss += loss.item()print("Test average loss:{:.4f}".format(total_loss / len(test_loader)))

第六部:对比重构结果

在训练完成后,我们可以使用自编码器模型重构一些测试数据,并将重构的结果与原始数据进行比较。

# 迭代测试数据集,生成迭代器
dataiter = iter(test_loader)# 从迭代器中获取下一个批次的图像和标签
images, labels = next(dataiter)# 使用模型进行推断,处理获取的图像数据,并将结果保存在output变量中
output = model(images.to(device))# 创建子图和轴对象,其中第一行显示原始图像,第二行显示重构后的图像
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))# 循环遍历前10个图像,绘制原始图像和重构图像并添加标题
for i in range(10):# 显示原始图像axes[0,i].imshow(images[i].squeeze().numpy(), cmap='gray')axes[0,i].set_title("Original")axes[0,i].get_xaxis().set_visible(False)axes[0,i].get_yaxis().set_visible(False)# 显示重构后的图像axes[1,i].imshow(output[i].squeeze().cpu().detach().numpy(), cmap='gray')axes[1,i].set_title("Reconstructed")axes[1,i].get_xaxis().set_visible(False)axes[1,i].get_yaxis().set_visible(False)# 显示生成的子图
plt.show()

运行完整的程序后,我们应该会看到原始图像和重构图像的对比结果。自编码器可以在其中一些图像上产生较好的结果,但在其他图像上可能会出现一些失真或模糊。

自编码器在实际应用中有许多变体和扩展,例如稀疏自编码器和卷积自编码器。这里仅仅是一个简单的入门教程,可以让您了解自编码器的基本原理和应用。

相关文章:

自编码器简单介绍—使用PyTorch库实现一个简单的自编码器,并使用MNIST数据集进行训练和测试

文章目录 自编码器简单介绍什么是自编码器?自动编码器和卷积神经网络的区别?如何构建一个自编码器?如何训练自编码器?如何使用自编码器进行图像压缩?总结使用PyTorch构建简单的自动编码器第一步:导入库和数…...

redis单机最大并发量

redis单机最大并发量 布隆过滤器多级缓存客户端缓存应用层缓存Expires和Cache-Control的区别Nginx缓存管理 服务层缓存进程内缓存进程外缓存 缓存数据一致性问题的解决引入多级缓存设计的时刻 Redis的速度非常的快,单机的Redis就可以⽀撑 每秒十几万的并发,相对于MySQL来说,性…...

MTLAB绘图

这里写目录标题 一、图例1、散点图 二、绘图1、总体图形参数2、坐标、图框、网格图框去上右边框小刻度网格坐标范围和刻度控制旋转 坐标、刻度 3、图例图例位置和方向 Location和Orientation图例加标题 、分多列 4、文本 字、字体、字号5、线型 符号6、颜色栏 colorbar7、颜色8…...

自媒体必备素材库,免费、商用,赶紧马住~

自媒体经常需要用到各类素材,本期就给大家安利6个自媒体必备的素材网站,免费、付费、商用都有,建议收藏起来~ 1、菜鸟图库 https://www.sucai999.com/video.html?vNTYwNDUx 菜鸟图库可以找到设计、办公、图片、视频、音频等各种素材。视频素…...

ESP32设备驱动-BMP388气压传感器驱动

BMP388气压传感器驱动 文章目录 BMP388气压传感器驱动1、BMP388介绍2、硬件准备3、软件准备4、驱动实现1、BMP388介绍 BMP388 是一款非常小巧、低功耗和低噪声的 24 位绝对气压传感器。 它可以实现精确的高度跟踪,特别适合无人机应用。 BMP388 在 0-65C 之间的同类最佳 TCO,…...

攻防世界-Reversing-x64Elf-100

Reversing-x64Elf-100 18最佳Writeup由 yuchouxuan 提供 收藏 反馈 难度:1 方向:Reverse 题解数:15 解出人数:2460 题目来源: 题目描述: 暂无 note:undefined8 FUN_004006fd(long param_1){int local_2c;char *local_28 …...

C/C++每日一练(20230419)

目录 1. 插入区间 🌟🌟🌟 2. 单词拆分 🌟🌟 3. 不同路径 🌟🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 专栏 Java每日…...

[自注意力神经网络]Mask Transfiner网络-论文解读

本文为CVPR2022的论文。国际惯例,先贴出原文和源码: 原论文地址https://arxiv.org/pdf/2111.13673.pdf源码地址https://github.com/SysCV/transfiner 一、概述 传统的Two-Stage网络,如Mask R-CNN虽然在实例分割上取得了较好的效果&#xff…...

漫画:是喜,还是悲?AI竟帮我们把Office破活干完了

图文原创:亲爱的数据 国产大模型烈火制造。阿里百度字节美团各科技大佬不等闲。 大模型嘛,重大工程,对我等“怀保小民”来说,只关心怎么用,不关心怎么造。 我来介绍一下自己,我是一个写稿男团组合的成员&am…...

ChatGPT的原理分析

1.前言 ChatGPT是一种基于自然语言处理和人工智能技术的聊天机器人,它的基础是由OpenAI研发的GPT模型,其中GPT是Generative Pre-trained Transformer的缩写。GPT模型的训练使用了海量的语料库,可以预测下一个单词、短语、句子或文本&#xf…...

在线免费把Markdown格式文件转换为PDF格式

用CSDN的MarkDown编辑器在线转换 CSDN的MarkDown编辑器说实话还是挺好用的。 导出PDF操作步骤,图文配合看: 在MD编辑模式下写好MarkDown文章或者直接把要转换的MarkDown贴进来; 使用预览模式,然后在预览文件上右键选择打印&…...

R7-5 列车厢调度

R7-5 列车厢调度 分数 25 全屏浏览题目 切换布局 作者 周强 单位 青岛大学 1 <--移动方向/3 \2 -->移动方向 大家或许在某些数据结构教材上见到过“列车厢调度问题”&#xff08;当然没见过也不要紧&#xff09;。今天&#xff0c;我们就来实际操作一下列车…...

English Learning - L2 第 16 次小组纠音 弱读和语调 2023.4.22 周六

English Learning - L2 第 16 次小组纠音 弱读和语调 2023.4.22 周六 共性问题help /help/ 中的 e 和 lsorry /ˈsɒri/ 中的 ɒ 和 ilook out /lʊk aʊt/ 中的 ɒ 和 aʊdont /dəʊnt/ 中的 əʊemergency /ɪˈmɜːʤənsɪ/ 中的 ɜːname /neɪm/ 中的 eɪright /raɪt/…...

( “树” 之 前中后序遍历) 145. 二叉树的后序遍历 ——【Leetcode每日一题】

基础概念&#xff1a;前中后序遍历 1/ \2 3/ \ \ 4 5 6层次遍历顺序&#xff1a;[1 2 3 4 5 6]前序遍历顺序&#xff1a;[1 2 4 5 3 6]中序遍历顺序&#xff1a;[4 2 5 1 3 6]后序遍历顺序&#xff1a;[4 5 2 6 3 1] 层次遍历使用 BFS 实现&#xff0c;利用的就是 BFS…...

NPOI與Crystal report 13.0關於ICSharpCode.SharpZipLib控件版本衝突的解決方法

公司原來的系統用了Crystal report 13.0&#xff0c;它關聯使用ICSharpCode.SharpZipLib.dll &#xff08;壓縮控件&#xff09;的版本為0.85.1.271&#xff1b;後來因需要新增加 NPOI2.3控件&#xff0c;它關聯使用了ICSharpCode.SharpZipLib.dll 的版本為 高版本0.86&#xf…...

Sass @extend 与 继承

Sass extend 与 继承 extend 指令告诉 Sass 一个选择器的样式从另一选择器继承。 如果一个样式与另外一个样式几乎相同&#xff0c;只有少量的区别&#xff0c;则使用 extend 就显得很有用。 以下 Sass 实例中&#xff0c;我们创建了一个基本的按钮样式 .button-basic&#…...

权限控制导入到项目中

在项目中应用 进行认证和授权需要前面课程中提到的权限模型涉及的7张表支撑&#xff0c;因为用户信息、权限信息、菜单信息、角色信息、关联信息等都保存在这7张表中&#xff0c;也就是这些表中的数据是进行认证和授权的依据。所以在真正进行认证和授权之前需要对这些数据进行…...

CVPR2020:训练多视图三维点云配准

CVPR2020&#xff1a;训练多视图三维点云配准 Learning Multiview 3D Point Cloud Registration 源代码和预训练模型&#xff1a;https://github.com/zgojcic/3D_multiview_reg 论文地址&#xff1a; https://openaccess.thecvf.com/content_CVPR_2020/papers/Gojcic_Learn…...

string容器及其简单使用

string容器 概述声明和初始化获取字符串长度字符串拼接字符串比较字符串插入和删除字符串转换 概述 string是C中的一个标准库容器&#xff0c;用于处理字符串。它提供了一系列的操作函数&#xff0c;使得我们可以像处理其他容器一样方便地处理字符串。下面是string容器的详细介…...

芴甲氧羰酰基-氨基-聚乙二醇-巯基吡啶Fmoc-NH-PEG-OPSS

修饰性PEG芴甲氧羰基-氨基-聚乙二醇-巯基吡啶Fmoc-NH-PEG-OPSS是保护氨基的PEG衍生物之一 结构式&#xff1a; 芴甲氧羰酰基-氨基-聚乙二醇-巯基吡啶Fmoc-NH-PEG-OPSS聚乙二醇化可以提高聚乙二醇分子的稳定性&#xff0c;降低其免疫原性&#xff0c;仅用于科研实验。 FMOC-NH…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

HTML 列表、表格、表单

1 列表标签 作用&#xff1a;布局内容排列整齐的区域 列表分类&#xff1a;无序列表、有序列表、定义列表。 例如&#xff1a; 1.1 无序列表 标签&#xff1a;ul 嵌套 li&#xff0c;ul是无序列表&#xff0c;li是列表条目。 注意事项&#xff1a; ul 标签里面只能包裹 li…...

第25节 Node.js 断言测试

Node.js的assert模块主要用于编写程序的单元测试时使用&#xff0c;通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试&#xff0c;通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...

ardupilot 开发环境eclipse 中import 缺少C++

目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...

AGain DB和倍数增益的关系

我在设置一款索尼CMOS芯片时&#xff0c;Again增益0db变化为6DB&#xff0c;画面的变化只有2倍DN的增益&#xff0c;比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析&#xff1a; 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...

Webpack性能优化:构建速度与体积优化策略

一、构建速度优化 1、​​升级Webpack和Node.js​​ ​​优化效果​​&#xff1a;Webpack 4比Webpack 3构建时间降低60%-98%。​​原因​​&#xff1a; V8引擎优化&#xff08;for of替代forEach、Map/Set替代Object&#xff09;。默认使用更快的md4哈希算法。AST直接从Loa…...

基于PHP的连锁酒店管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的连锁酒店管理系统 一 介绍 连锁酒店管理系统基于原生PHP开发&#xff0c;数据库mysql&#xff0c;前端bootstrap。系统角色分为用户和管理员。 技术栈 phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注册/登录/注销 2 个人中…...

Golang——7、包与接口详解

包与接口详解 1、Golang包详解1.1、Golang中包的定义和介绍1.2、Golang包管理工具go mod1.3、Golang中自定义包1.4、Golang中使用第三包1.5、init函数 2、接口详解2.1、接口的定义2.2、空接口2.3、类型断言2.4、结构体值接收者和指针接收者实现接口的区别2.5、一个结构体实现多…...

什么是VR全景技术

VR全景技术&#xff0c;全称为虚拟现实全景技术&#xff0c;是通过计算机图像模拟生成三维空间中的虚拟世界&#xff0c;使用户能够在该虚拟世界中进行全方位、无死角的观察和交互的技术。VR全景技术模拟人在真实空间中的视觉体验&#xff0c;结合图文、3D、音视频等多媒体元素…...

土建施工员考试:建筑施工技术重点知识有哪些?

《管理实务》是土建施工员考试中侧重实操应用与管理能力的科目&#xff0c;核心考查施工组织、质量安全、进度成本等现场管理要点。以下是结合考试大纲与高频考点整理的重点内容&#xff0c;附学习方向和应试技巧&#xff1a; 一、施工组织与进度管理 核心目标&#xff1a; 规…...