自编码器简单介绍—使用PyTorch库实现一个简单的自编码器,并使用MNIST数据集进行训练和测试
文章目录
- 自编码器简单介绍
- 什么是自编码器?
- 自动编码器和卷积神经网络的区别?
- 如何构建一个自编码器?
- 如何训练自编码器?
- 如何使用自编码器进行图像压缩?
- 总结
- 使用PyTorch构建简单的自动编码器
- 第一步:导入库和数据集
- 第二步:建立编码器和解码器
- 第三步:定义损失函数和优化器
- 第四步:训练自编码器模型
- 第五步:测试自编码器模型
- 第六部:对比重构结果
自编码器简单介绍
自编码器是一种无监督学习算法,用于学习数据中的特征,并将这些特征用于重构与输入相似的新数据。自编码器由编码器和解码器两部分组成,编码器用于将输入数据压缩到一个低维度的表示形式,解码器将该表示形式还原回输入数据的形式。自编码器可以应用于多种领域,例如图像处理、语音识别和自然语言处理等。
什么是自编码器?
自编码器是一种无监督学习算法,用于学习数据的压缩表示。它由两个部分组成:编码器和解码器。编码器将输入数据压缩成低维表示,解码器将这个低维表示重构成与原始输入尽可能接近的输出。
简单结构如下:

自动编码器和卷积神经网络的区别?
自动编码器和卷积神经网络(CNN)都是深度学习中常用的模型,但它们的目的和结构略有不同。
自动编码器是一种无监督学习模型,其目的是学习一个对输入数据进行压缩和解压缩的函数。自动编码器通常由两个部分组成:编码器和解码器。编码器将输入数据转换为潜在表示,解码器将潜在表示转换回原始数据。自动编码器的目标是最小化重构误差,即在解码器输出的数据与原始数据之间的差异。
与之不同,卷积神经网络是一种用于图像分类、目标检测、语音识别等任务的有监督学习模型。卷积神经网络通常由卷积层、池化层、全连接层等组成。卷积层可以捕捉输入数据的局部特征,池化层可以减少特征图的大小,全连接层可以将特征图转换为分类结果。
虽然自动编码器和卷积神经网络的目的和结构不同,但它们都可以用于特征提取。自动编码器可以学习输入数据的低维表示,卷积神经网络可以提取输入数据的局部特征。在某些任务中,这些特征可以作为输入传递到其他模型中,以提高任务的性能。
除了特征提取,自动编码器和卷积神经网络还有其他方面的不同之处。
自动编码器可以用于数据的降维和去噪。通过学习输入数据的低维表示,自动编码器可以将高维数据降低到更低的维度,从而简化数据。此外,自动编码器还可以在输入数据中去除噪声,因为它们的目标是在重构时最小化重构误差,这有助于滤除输入数据中的噪声。
卷积神经网络则更适合处理图像、语音、视频等具有空间结构的数据。卷积层可以捕捉输入数据的局部特征,这对于图像分类、目标检测等任务非常重要。此外,卷积神经网络还可以通过使用池化层来减少特征图的大小,从而在处理大型图像时降低计算成本。
总的来说,自动编码器和卷积神经网络是两种不同的深度学习模型,它们的应用场景和目标略有不同。但是,它们都是非常有用的工具,可以在各种任务中发挥重要作用。
如何构建一个自编码器?
首先,需要确定编码器和解码器的结构。编码器可以是多层感知器(MLP)或卷积神经网络(CNN),其中每一层都包含多个神经元,并通过非线性函数进行激活。解码器通常与编码器结构相对称,也可以是一个MLP或CNN。在训练自编码器时,将输入数据输入编码器并计算编码器输出,然后将其输入解码器并计算解码器输出。最终的目标是最小化解码器输出和原始输入之间的差异。
如何训练自编码器?
训练自编码器的关键是确定损失函数。通常使用均方误差(MSE)来计算解码器输出和原始输入之间的差异。MSE的计算方式如下:
$ 在公式中, y i y_i yi 是实际值, y i ^ \hat{y_i} yi^ 是预测值, N N N 是数据集中的样本数。
M S E = 1 N ∑ i = 1 N ( y i − y ^ i ) 2 \mathrm MSE=\frac{1}{N}\sum_{i=1}^N(y_i-\hat y_i)^2 MSE=N1i=1∑N(yi−y^i)2
其中,N是样本数量, y i y_i yi是实际值, y ^ i \hat y_i y^i是解码器输出,即预测值。通过反向传播算法,可以计算编码器和解码器中所有参数的梯度,并使用梯度下降算法对参数进行更新。
如何使用自编码器进行图像压缩?
在用自编码器压缩图像时,将图像输入自编码器的编码器部分,得到低维表示,该表示可以看作是图像的压缩版本。可以通过解码器部分将该低维表示解码成原始图像。由于自编码器的解码器部分重构图像,因此可以使用自编码器进行图像压缩。
总结
以上是自编码器的简单入门教程。自编码器是一种无监督学习算法,可以用于学习数据的压缩表示。在训练自编码器时需要确定编码器和解码器的结构以及损失函数,并使用梯度下降算法对参数进行更新。自编码器可以用于图像压缩等任务。
使用PyTorch构建简单的自动编码器
在本教程中,我们将使用PyTorch库实现一个简单的自编码器,并使用MNIST数据集进行训练和测试。
第一步:导入库和数据集
首先,我们需要导入必要的库并加载MNIST数据集。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 检查是否安装了CUDA,并且CUDA是否适用于你的GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 超参数设置
batch_size = 64
learning_rate = 1e-3
num_epochs = 10# 加载MNIST手写数字数据集
train_dataset = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data/', train=False, transform=transforms.ToTensor())# 创建数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
第二步:建立编码器和解码器
接下来,我们需要建立编码器和解码器模型。
# 定义一个自编码器的类
class Autoencoder(nn.Module):def __init__(self):super(Autoencoder, self).__init__()# 编码器部分self.encoder = nn.Sequential(nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1), # 输入1通道,输出16通道,3x3卷积,步长为2,padding为1nn.ReLU(),nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 输入16通道,输出32通道,3x3卷积,步长为2,padding为1nn.ReLU(),nn.Conv2d(32, 64, kernel_size=7), # 输入32通道,输出64通道,7x7卷积)# 解码器部分self.decoder = nn.Sequential(nn.ConvTranspose2d(64, 32, kernel_size=7), # 输入64通道,输出32通道,7x7卷积转置nn.ReLU(),nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # 输入32通道,输出16通道,3x3卷积转置,步长为2,padding为1,输出padding为1nn.ReLU(),nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1), # 输入16通道,输出1通道,3x3卷积转置,步长为2,padding为1,输出padding为1nn.Sigmoid())def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x# 创建自编码器对象并将模型移动到GPU
model = Autoencoder().to(device)
第三步:定义损失函数和优化器
定义损失函数和优化器如下。
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
第四步:训练自编码器模型
现在可以使用MNIST数据集训练自编码器模型。
# 训练自编码器
for epoch in range(num_epochs):for data in train_loader:img, _ = dataimg = img.to(device)# 前向传播output = model(img)loss = criterion(output, img)# 反向传播和优化器优化optimizer.zero_grad()loss.backward()optimizer.step()print("Epoch[{}/{}], loss:{:.4f}".format(epoch+1, num_epochs, loss.data))
第五步:测试自编码器模型
# 测试自编码器
model.eval()
total_loss = 0
with torch.no_grad():for data in test_loader:img, _ = dataimg = img.to(device)output = model(img)loss = criterion(output, img)total_loss += loss.item()print("Test average loss:{:.4f}".format(total_loss / len(test_loader)))
第六部:对比重构结果
在训练完成后,我们可以使用自编码器模型重构一些测试数据,并将重构的结果与原始数据进行比较。
# 迭代测试数据集,生成迭代器
dataiter = iter(test_loader)# 从迭代器中获取下一个批次的图像和标签
images, labels = next(dataiter)# 使用模型进行推断,处理获取的图像数据,并将结果保存在output变量中
output = model(images.to(device))# 创建子图和轴对象,其中第一行显示原始图像,第二行显示重构后的图像
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25,4))# 循环遍历前10个图像,绘制原始图像和重构图像并添加标题
for i in range(10):# 显示原始图像axes[0,i].imshow(images[i].squeeze().numpy(), cmap='gray')axes[0,i].set_title("Original")axes[0,i].get_xaxis().set_visible(False)axes[0,i].get_yaxis().set_visible(False)# 显示重构后的图像axes[1,i].imshow(output[i].squeeze().cpu().detach().numpy(), cmap='gray')axes[1,i].set_title("Reconstructed")axes[1,i].get_xaxis().set_visible(False)axes[1,i].get_yaxis().set_visible(False)# 显示生成的子图
plt.show()
运行完整的程序后,我们应该会看到原始图像和重构图像的对比结果。自编码器可以在其中一些图像上产生较好的结果,但在其他图像上可能会出现一些失真或模糊。
自编码器在实际应用中有许多变体和扩展,例如稀疏自编码器和卷积自编码器。这里仅仅是一个简单的入门教程,可以让您了解自编码器的基本原理和应用。
相关文章:
自编码器简单介绍—使用PyTorch库实现一个简单的自编码器,并使用MNIST数据集进行训练和测试
文章目录 自编码器简单介绍什么是自编码器?自动编码器和卷积神经网络的区别?如何构建一个自编码器?如何训练自编码器?如何使用自编码器进行图像压缩?总结使用PyTorch构建简单的自动编码器第一步:导入库和数…...
redis单机最大并发量
redis单机最大并发量 布隆过滤器多级缓存客户端缓存应用层缓存Expires和Cache-Control的区别Nginx缓存管理 服务层缓存进程内缓存进程外缓存 缓存数据一致性问题的解决引入多级缓存设计的时刻 Redis的速度非常的快,单机的Redis就可以⽀撑 每秒十几万的并发,相对于MySQL来说,性…...
MTLAB绘图
这里写目录标题 一、图例1、散点图 二、绘图1、总体图形参数2、坐标、图框、网格图框去上右边框小刻度网格坐标范围和刻度控制旋转 坐标、刻度 3、图例图例位置和方向 Location和Orientation图例加标题 、分多列 4、文本 字、字体、字号5、线型 符号6、颜色栏 colorbar7、颜色8…...
自媒体必备素材库,免费、商用,赶紧马住~
自媒体经常需要用到各类素材,本期就给大家安利6个自媒体必备的素材网站,免费、付费、商用都有,建议收藏起来~ 1、菜鸟图库 https://www.sucai999.com/video.html?vNTYwNDUx 菜鸟图库可以找到设计、办公、图片、视频、音频等各种素材。视频素…...
ESP32设备驱动-BMP388气压传感器驱动
BMP388气压传感器驱动 文章目录 BMP388气压传感器驱动1、BMP388介绍2、硬件准备3、软件准备4、驱动实现1、BMP388介绍 BMP388 是一款非常小巧、低功耗和低噪声的 24 位绝对气压传感器。 它可以实现精确的高度跟踪,特别适合无人机应用。 BMP388 在 0-65C 之间的同类最佳 TCO,…...
攻防世界-Reversing-x64Elf-100
Reversing-x64Elf-100 18最佳Writeup由 yuchouxuan 提供 收藏 反馈 难度:1 方向:Reverse 题解数:15 解出人数:2460 题目来源: 题目描述: 暂无 note:undefined8 FUN_004006fd(long param_1){int local_2c;char *local_28 …...
C/C++每日一练(20230419)
目录 1. 插入区间 🌟🌟🌟 2. 单词拆分 🌟🌟 3. 不同路径 🌟🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 专栏 Java每日…...
[自注意力神经网络]Mask Transfiner网络-论文解读
本文为CVPR2022的论文。国际惯例,先贴出原文和源码: 原论文地址https://arxiv.org/pdf/2111.13673.pdf源码地址https://github.com/SysCV/transfiner 一、概述 传统的Two-Stage网络,如Mask R-CNN虽然在实例分割上取得了较好的效果ÿ…...
漫画:是喜,还是悲?AI竟帮我们把Office破活干完了
图文原创:亲爱的数据 国产大模型烈火制造。阿里百度字节美团各科技大佬不等闲。 大模型嘛,重大工程,对我等“怀保小民”来说,只关心怎么用,不关心怎么造。 我来介绍一下自己,我是一个写稿男团组合的成员&am…...
ChatGPT的原理分析
1.前言 ChatGPT是一种基于自然语言处理和人工智能技术的聊天机器人,它的基础是由OpenAI研发的GPT模型,其中GPT是Generative Pre-trained Transformer的缩写。GPT模型的训练使用了海量的语料库,可以预测下一个单词、短语、句子或文本…...
在线免费把Markdown格式文件转换为PDF格式
用CSDN的MarkDown编辑器在线转换 CSDN的MarkDown编辑器说实话还是挺好用的。 导出PDF操作步骤,图文配合看: 在MD编辑模式下写好MarkDown文章或者直接把要转换的MarkDown贴进来; 使用预览模式,然后在预览文件上右键选择打印&…...
R7-5 列车厢调度
R7-5 列车厢调度 分数 25 全屏浏览题目 切换布局 作者 周强 单位 青岛大学 1 <--移动方向/3 \2 -->移动方向 大家或许在某些数据结构教材上见到过“列车厢调度问题”(当然没见过也不要紧)。今天,我们就来实际操作一下列车…...
English Learning - L2 第 16 次小组纠音 弱读和语调 2023.4.22 周六
English Learning - L2 第 16 次小组纠音 弱读和语调 2023.4.22 周六 共性问题help /help/ 中的 e 和 lsorry /ˈsɒri/ 中的 ɒ 和 ilook out /lʊk aʊt/ 中的 ɒ 和 aʊdont /dəʊnt/ 中的 əʊemergency /ɪˈmɜːʤənsɪ/ 中的 ɜːname /neɪm/ 中的 eɪright /raɪt/…...
( “树” 之 前中后序遍历) 145. 二叉树的后序遍历 ——【Leetcode每日一题】
基础概念:前中后序遍历 1/ \2 3/ \ \ 4 5 6层次遍历顺序:[1 2 3 4 5 6]前序遍历顺序:[1 2 4 5 3 6]中序遍历顺序:[4 2 5 1 3 6]后序遍历顺序:[4 5 2 6 3 1] 层次遍历使用 BFS 实现,利用的就是 BFS…...
NPOI與Crystal report 13.0關於ICSharpCode.SharpZipLib控件版本衝突的解決方法
公司原來的系統用了Crystal report 13.0,它關聯使用ICSharpCode.SharpZipLib.dll (壓縮控件)的版本為0.85.1.271;後來因需要新增加 NPOI2.3控件,它關聯使用了ICSharpCode.SharpZipLib.dll 的版本為 高版本0.86…...
Sass @extend 与 继承
Sass extend 与 继承 extend 指令告诉 Sass 一个选择器的样式从另一选择器继承。 如果一个样式与另外一个样式几乎相同,只有少量的区别,则使用 extend 就显得很有用。 以下 Sass 实例中,我们创建了一个基本的按钮样式 .button-basic&#…...
权限控制导入到项目中
在项目中应用 进行认证和授权需要前面课程中提到的权限模型涉及的7张表支撑,因为用户信息、权限信息、菜单信息、角色信息、关联信息等都保存在这7张表中,也就是这些表中的数据是进行认证和授权的依据。所以在真正进行认证和授权之前需要对这些数据进行…...
CVPR2020:训练多视图三维点云配准
CVPR2020:训练多视图三维点云配准 Learning Multiview 3D Point Cloud Registration 源代码和预训练模型:https://github.com/zgojcic/3D_multiview_reg 论文地址: https://openaccess.thecvf.com/content_CVPR_2020/papers/Gojcic_Learn…...
string容器及其简单使用
string容器 概述声明和初始化获取字符串长度字符串拼接字符串比较字符串插入和删除字符串转换 概述 string是C中的一个标准库容器,用于处理字符串。它提供了一系列的操作函数,使得我们可以像处理其他容器一样方便地处理字符串。下面是string容器的详细介…...
芴甲氧羰酰基-氨基-聚乙二醇-巯基吡啶Fmoc-NH-PEG-OPSS
修饰性PEG芴甲氧羰基-氨基-聚乙二醇-巯基吡啶Fmoc-NH-PEG-OPSS是保护氨基的PEG衍生物之一 结构式: 芴甲氧羰酰基-氨基-聚乙二醇-巯基吡啶Fmoc-NH-PEG-OPSS聚乙二醇化可以提高聚乙二醇分子的稳定性,降低其免疫原性,仅用于科研实验。 FMOC-NH…...
XCTF-web-easyupload
试了试php,php7,pht,phtml等,都没有用 尝试.user.ini 抓包修改将.user.ini修改为jpg图片 在上传一个123.jpg 用蚁剑连接,得到flag...
SkyWalking 10.2.0 SWCK 配置过程
SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外,K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案,全安装在K8S群集中。 具体可参…...
(十)学生端搭建
本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...
智能在线客服平台:数字化时代企业连接用户的 AI 中枢
随着互联网技术的飞速发展,消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁,不仅优化了客户体验,还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用,并…...
ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放
简介 前面两期文章我们介绍了I2S的读取和写入,一个是通过INMP441麦克风模块采集音频,一个是通过PCM5102A模块播放音频,那如果我们将两者结合起来,将麦克风采集到的音频通过PCM5102A播放,是不是就可以做一个扩音器了呢…...
涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战
“🤖手搓TuyaAI语音指令 😍秒变表情包大师,让萌系Otto机器人🔥玩出智能新花样!开整!” 🤖 Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制(TuyaAI…...
前端开发面试题总结-JavaScript篇(一)
文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包(Closure)?闭包有什么应用场景和潜在问题?2.解释 JavaScript 的作用域链(Scope Chain) 二、原型与继承3.原型链是什么?如何实现继承&a…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
Xen Server服务器释放磁盘空间
disk.sh #!/bin/bashcd /run/sr-mount/e54f0646-ae11-0457-b64f-eba4673b824c # 全部虚拟机物理磁盘文件存储 a$(ls -l | awk {print $NF} | cut -d. -f1) # 使用中的虚拟机物理磁盘文件 b$(xe vm-disk-list --multiple | grep uuid | awk {print $NF})printf "%s\n"…...
【VLNs篇】07:NavRL—在动态环境中学习安全飞行
项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战,克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...
