python-pytorch 利用pytorch对堆叠自编码器进行训练和验证
利用pytorch对堆叠自编码器进行训练和验证
- 一、数据生成
- 二、定义自编码器模型
- 三、训练函数
- 四、训练堆叠自编码器
- 五、将已训练的自编码器级联
- 六、微调整个堆叠自编码器
一、数据生成
随机生成一些数据来模拟训练和验证数据集:
import torch# 随机生成数据
n_samples = 1000
n_features = 784 # 例如,28x28图像的像素数
train_data = torch.rand(n_samples, n_features)
val_data = torch.rand(int(n_samples * 0.1), n_features)
二、定义自编码器模型
import torch.nn as nnclass Autoencoder(nn.Module):def __init__(self, input_size, hidden_size):super(Autoencoder, self).__init__()self.encoder = nn.Sequential(nn.Linear(input_size, hidden_size),nn.Tanh())self.decoder = nn.Sequential(nn.Linear(hidden_size, input_size),nn.Tanh())def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x
三、训练函数
定义一个函数来训练自编码器:
def train_ae(model, train_loader, val_loader, num_epochs, criterion, optimizer):for epoch in range(num_epochs):# Trainingmodel.train()train_loss = 0for batch_data in train_loader:optimizer.zero_grad()outputs = model(batch_data)loss = criterion(outputs, batch_data)loss.backward()optimizer.step()train_loss += loss.item()train_loss /= len(train_loader)print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}")# Validationmodel.eval()val_loss = 0with torch.no_grad():for batch_data in val_loader:outputs = model(batch_data)loss = criterion(outputs, batch_data)val_loss += loss.item()val_loss /= len(val_loader)print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")
四、训练堆叠自编码器
使用上面定义的函数来训练自编码器:
from torch.utils.data import DataLoader# DataLoader
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)# 训练第一个自编码器
ae1 = Autoencoder(input_size=784, hidden_size=400)
optimizer = torch.optim.Adam(ae1.parameters(), lr=0.001)
criterion = nn.MSELoss()
train_ae(ae1, train_loader, val_loader, 10, criterion, optimizer)# 使用第一个自编码器的编码器对数据进行编码
encoded_train_data = []
for data in train_loader:encoded_train_data.append(ae1.encoder(data))
encoded_train_loader = DataLoader(torch.cat(encoded_train_data), batch_size=batch_size, shuffle=True)encoded_val_data = []
for data in val_loader:encoded_val_data.append(ae1.encoder(data))
encoded_val_loader = DataLoader(torch.cat(encoded_val_data), batch_size=batch_size, shuffle=False)# 训练第二个自编码器
ae2 = Autoencoder(input_size=400, hidden_size=200)
optimizer = torch.optim.Adam(ae2.parameters(), lr=0.001)
train_ae(ae2, encoded_train_loader, encoded_val_loader, 10, criterion, optimizer)# 使用第二个自编码器的编码器对数据进行编码
encoded_train_data = []
for data in train_loader:encoded_train_data.append(ae2.encoder(data))
encoded_train_loader = DataLoader(torch.cat(encoded_train_data), batch_size=batch_size, shuffle=True)encoded_val_data = []
for data in val_loader:encoded_val_data.append(ae2.encoder(data))
encoded_val_loader = DataLoader(torch.cat(encoded_val_data), batch_size=batch_size, shuffle=False)# 训练第三个自编码器
ae3 = Autoencoder(input_size=400, hidden_size=200)
optimizer = torch.optim.Adam(ae3.parameters(), lr=0.001)
train_ae(ae3, encoded_train_loader, encoded_val_loader, 10, criterion, optimizer)# 使用第三个自编码器的编码器对数据进行编码
encoded_train_data = []
for data in train_loader:encoded_train_data.append(ae3.encoder(data))
encoded_train_loader = DataLoader(torch.cat(encoded_train_data), batch_size=batch_size, shuffle=True)encoded_val_data = []
for data in val_loader:encoded_val_data.append(ae3.encoder(data))
encoded_val_loader = DataLoader(torch.cat(encoded_val_data), batch_size=batch_size, shuffle=False)
五、将已训练的自编码器级联
class StackedAutoencoder(nn.Module):def __init__(self, ae1, ae2, ae3):super(StackedAutoencoder, self).__init__()self.encoder = nn.Sequential(ae1.encoder, ae2.encoder, ae3.encoder)self.decoder = nn.Sequential(ae3.decoder, ae2.decoder, ae1.decoder)def forward(self, x):x = self.encoder(x)x = self.decoder(x)return xsae = StackedAutoencoder(ae1, ae2, ae3)
六、微调整个堆叠自编码器
在整个数据集上重新训练堆叠自编码器来完成。
train_autoencoder(sae, train_dataset)
相关文章:
python-pytorch 利用pytorch对堆叠自编码器进行训练和验证
利用pytorch对堆叠自编码器进行训练和验证 一、数据生成二、定义自编码器模型三、训练函数四、训练堆叠自编码器五、将已训练的自编码器级联六、微调整个堆叠自编码器 一、数据生成 随机生成一些数据来模拟训练和验证数据集: import torch# 随机生成数据 n_sample…...

制作 3 档可调灯程序编写
PWM 0~255 可以将数据映射到0 75 150 225 尽可能均匀电压间隔...
源码分享-M3U8数据流ts的AES-128解密并合并---GoLang实现
之前使用C语言实现了一次,见M3U8数据流ts的AES-128解密并合并。 学习了Go语言后,又用Go重新实现了一遍。源码如下,无第三方库依赖。 package mainimport ("crypto/aes""crypto/cipher""encoding/binary"&quo…...
CSDN Q: “这段代码算是在STC89C52RC51单片机上完成PWM呼吸灯了吗?“
这是 CSDN上的一个问题 这段代码算是在STC89C52RC51单片机上完成PWM呼吸灯了吗,还是说得用上定时器和中断函数#include <regx52.h> 我个人认为: 效果上来说, 是的! 码以 以Time / 100-Time 调 Duty, 而 for i loop成 Period, 加上延时, 实现了 PWM周期, 虽然…...

Linux系统编程系列之线程池
Linux系统编程系列(16篇管饱,吃货都投降了!) 1、Linux系统编程系列之进程基础 2、Linux系统编程系列之进程间通信(IPC)-信号 3、Linux系统编程系列之进程间通信(IPC)-管道 4、Linux系统编程系列之进程间通信-IPC对象 5、Linux系统…...

Linux CentOS7 vim多文件与多窗口操作
窗口是可视化的分割区域。Windows中窗口的概念与linux中基本相同。连接xshell就是在Windows中新建一个窗口。而vim打开一个文件默认创建一个窗口。同时,Vim打开一个文件也就会建立一个缓冲区,打开多个文件就会创建多个缓冲区。 本文讨论vim中打开多个文…...

SPI 通信协议
1. SPI通信 1. 什么是SPI通信协议 2. SPI的通信过程 在一开始会先把发送缓冲器的数据(8位)。一次性放到移位寄存器里。 移位寄存器会一位一位发送出去。但是要先放到锁存器里。然后从机来读取。从机的过程也一样。当移位寄存器的数据全部发送完。其实…...

【图像处理】使用各向异性滤波器和分割图像处理从MRI图像检测脑肿瘤(Matlab代码实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

5个适合初学者的初级网络安全工作,网络安全就业必看
前言 网络安全涉及保护计算机系统、网络和数据免受未经授权的访问、破坏和盗窃 - 防止数字活动和数据访问的中断 - 同时也保护用户的资产和隐私。鉴于公共事业、医疗保健、金融以及联邦政府等行业的网络犯罪攻击不断升级,对网络专业人员的需求很高,这并…...
Kafka核心原理
1、Topic的分片和副本机制 分片作用: 解决单台节点容量有限的问题,节点多,效率提升,吞吐量提升。通过分片,将一个大的容器分解为多个小的容器,分布在不同的节点上,从而实现分布式存储。 分片…...

探秘前后端开发世界:猫头虎带你穿梭编程的繁忙街区,解锁全栈之路
🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…...
洛谷_分支循环
p2433 问题 5 甲列火车长 260 米,每秒行 12 米;乙列火车长220 米,每秒行 20 米,两车相向而行,从两车车头相遇时开始计时,多长时间后两车车尾相离?已知答案是整数。 计算方式:两车车…...

MySQL数据库入门到精通——进阶篇(3)
黑马程序员 MySQL数据库入门到精通——进阶篇(3) 1. 锁1.1 锁-介绍1.2 锁-全局锁1.3 锁-表级锁1.3.1 表级锁-表锁1.3.2 表级锁元数据锁( meta data lock,MDL)1.3.3 表级锁-意向锁1.3.4 表级锁意向锁测试 1.4 锁-行级锁1.4.1 行级锁-行锁1.4.2…...

Mind Map:大语言模型中的知识图谱提示激发思维图10.1+10.2
知识图谱提示激发思维图 摘要介绍相关工作方法第一步:证据图挖掘第二步:证据图聚合第三步:LLM Mind Map推理 实验实验设置医学问答长对话问题使用KG的部分知识生成深入分析 总结 摘要 LLM通常在吸收新知识的能力、generation of hallucinati…...

[引擎开发] 杂谈ue4中的Vulkan
接触Vulkan大概也有大半年,概述一下自己这段时间了解到的东西。本文实际上是杂谈性质而非综述性质,带有严重的主观认知,因此并没有那么严谨。 使用Vulkan会带来什么呢?简单来说就是对底层更好的控制。这意味着我们能够有更多的手段…...
docker--redis容器部署及地理空间API的使用示例-II
文章目录 Redis 地理位置类型API命令操作示例JAVA使用示例导入依赖RedisTemplate 操作GeoData示例CityInfo实体类Geo操作接口类Geo操作接口实现类SpringBoot测试类RedissonClient 操作GeoData示例docker–redis容器部署及与SpringBoot整合 docker–redis容器部署及地理空间API的…...

Vue中如何进行文件浏览与文件管理
Vue中的文件浏览与文件管理 文件浏览与文件管理是许多Web应用程序中常见的功能之一。在Vue.js中,您可以轻松地实现文件浏览和管理功能,使您的应用程序更具交互性和可用性。本文将向您展示如何使用Vue.js构建文件浏览器和文件管理功能,以及如…...

jenkins利用插件Active Choices Plug-in达到联动显示或隐藏参数,且参数值可修改
1. 添加组件 Active Choices Plug-in 如jenkins无法联网,可在以下两个地址中下载插件,然后放到/home/jenkins/.jenkins/plugin下面重启jenkins即可 Active Choices Active Choices | Jenkins plugin 2. 效果如下: sharding为空时…...

香蕉叶病害数据集
1.数据集 第一个文件夹为数据增强(旋转平移裁剪等操作)后的数据集 第二个文件夹为原始数据集 2.原始数据集 Cordana文件夹(162张照片) healthy文件夹(129张) Pestalotiopsis文件夹(173张照片&…...

天地无用 - 修改朋友圈的定位: 高德地图 + 爱思助手
1,电脑上打开高德地图网页版 高德地图 (amap.com) 2,网页最下一栏,点击“开放平台” 高德开放平台 | 高德地图API (amap.com) 3,在新网页中,需要登录高德账户才能操作。 可以使用手机号和验证码登录。 4,…...

Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误
HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误,它们的含义、原因和解决方法都有显著区别。以下是详细对比: 1. HTTP 406 (Not Acceptable) 含义: 客户端请求的内容类型与服务器支持的内容类型不匹…...

基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容
基于 UniApp + WebSocket实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...
Linux简单的操作
ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

深度学习习题2
1.如果增加神经网络的宽度,精确度会增加到一个特定阈值后,便开始降低。造成这一现象的可能原因是什么? A、即使增加卷积核的数量,只有少部分的核会被用作预测 B、当卷积核数量增加时,神经网络的预测能力会降低 C、当卷…...

html css js网页制作成品——HTML+CSS榴莲商城网页设计(4页)附源码
目录 一、👨🎓网站题目 二、✍️网站描述 三、📚网站介绍 四、🌐网站效果 五、🪓 代码实现 🧱HTML 六、🥇 如何让学习不再盲目 七、🎁更多干货 一、👨…...

SiFli 52把Imagie图片,Font字体资源放在指定位置,编译成指定img.bin和font.bin的问题
分区配置 (ptab.json) img 属性介绍: img 属性指定分区存放的 image 名称,指定的 image 名称必须是当前工程生成的 binary 。 如果 binary 有多个文件,则以 proj_name:binary_name 格式指定文件名, proj_name 为工程 名&…...

Golang——6、指针和结构体
指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...

LabVIEW双光子成像系统技术
双光子成像技术的核心特性 双光子成像通过双低能量光子协同激发机制,展现出显著的技术优势: 深层组织穿透能力:适用于活体组织深度成像 高分辨率观测性能:满足微观结构的精细研究需求 低光毒性特点:减少对样本的损伤…...

从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障
关键领域软件测试的"安全密码":Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力,从金融交易到交通管控,这些关乎国计民生的关键领域…...

如何应对敏捷转型中的团队阻力
应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中,明确沟通敏捷转型目的尤为关键,团队成员只有清晰理解转型背后的原因和利益,才能降低对变化的…...