PyTorch训练RNN, GRU, LSTM:手写数字识别
文章目录
- pytorch 神经网络训练demo
- Result
- 参考来源
pytorch 神经网络训练demo
数据集:MNIST
该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片

图片来源:https://tensornews.cn/mnist_intro/
神经网络:RNN, GRU, LSTM
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Hyperparameters
input_size = 28
sequence_length = 28
num_layers = 2
hidden_size = 256
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 2# Create a RNN
class RNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.rnn(x, h0)out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Create a GRU
class RNN_GRU(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN_GRU, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.gru(x, h0)out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Create a LSTM
class RNN_LSTM(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN_LSTM, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.lstm(x, (h0, c0))out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Load data
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(),download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(),download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)# Initialize network 选择一个即可
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_GRU(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_LSTM(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Train network
for epoch in range(num_epochs):# data: images, targets: labelsfor batch_idx, (data, targets) in enumerate(train_loader):# Get data to cuda if possibledata = data.to(device).squeeze(1) # 删除一个张量中所有维数为1的维度 (N, 1, 28, 28) -> (N, 28, 28)targets = targets.to(device)# forwardscores = model(data) # 64*10loss = criterion(scores, targets)# backwardoptimizer.zero_grad()loss.backward()# gradient descent or adam stepoptimizer.step()# Check accuracy on training & test to see how good our model
def check_accuracy(loader, model):if loader.dataset.train:print("Checking accuracy on training data")else:print("Checking accuracy on test data")num_correct = 0num_samples = 0model.eval()with torch.no_grad(): # 不计算梯度for x, y in loader:x = x.to(device).squeeze(1)y = y.to(device)# x = x.reshape(x.shape[0], -1) # 64*784scores = model(x)# 64*10_, predictions = scores.max(dim=1) #dim=1,表示对每行取最大值,每行代表一个样本。num_correct += (predictions == y).sum()num_samples += predictions.size(0) # 64print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%')model.train()check_accuracy(train_loader, model)
check_accuracy(test_loader, model)
Result
RNN Result
Checking accuracy on training data
Got 57926 / 60000 with accuracy 96.54%
Checking accuracy on test data
Got 9640 / 10000 with accuracy 96.40%GRU Result
Checking accuracy on training data
Got 59058 / 60000 with accuracy 98.43%
Checking accuracy on test data
Got 9841 / 10000 with accuracy 98.41%LSTM Result
Checking accuracy on training data
Got 59248 / 60000 with accuracy 98.75%
Checking accuracy on test data
Got 9849 / 10000 with accuracy 98.49%
参考来源
【1】https://www.youtube.com/watch?v=Gl2WXLIMvKA&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=5
相关文章:
PyTorch训练RNN, GRU, LSTM:手写数字识别
文章目录 pytorch 神经网络训练demoResult参考来源 pytorch 神经网络训练demo 数据集:MNIST 该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片 图片来源:https://tensornews.cn/mnist_intr…...
基于深度学习的高精度道路瑕疵检测系统(PyTorch+Pyside6+YOLOv5模型)
摘要:基于深度学习的高精度道路瑕疵(裂纹(Crack)、检查井(Manhole)、网(Net)、裂纹块(Patch-Crack)、网块(Patch-Net)、坑洼块&#x…...
【裸辞转行】是告别,也是新的开始
一年多了没有更新,是因为去年身体加心理因素辞职了,并且大概率不会再做程序员了,嗯。本来觉得可能再也不会打开 CSDN 了,想了想,还是来做个告别吧,任何事情都该有始有终才对。 回忆碎碎念 是在去年的 11 …...
了解交换机接口的链路类型(access、trunk、hybrid)
上一个章节中讲到了vlan的作用及使用,这篇了解一下交换机接口的链路类型和什么情况下使用 vlan在数据包中是如何体现的,在上一篇的时候提到测试了一下,从PC1去访问PC4的时候,只从E0/0/2发送给了E0/0/3这是,因为两个接…...
Android系统启动流程分析
当按下Android系统的开机电源按键时候,硬件会触发引导芯片,执行预定义的代码,然后加载引导程序(BootLoader)到RAM,Bootloader是Android系统起来前第一个程序,主要用来拉起Android系统程序,Android系统被拉起…...
如何在Ubuntu上安装OpenneBula
OpenNebula是一个开源云计算平台,允许我们在完全虚拟化云中组合和管理VMware和KVM虚拟机 第1步:安装MariaDB数据库服务器 OpenNebula还需要一个数据库服务器来存储其内容。 安装MariaDB: 1 2 sudo apt update sudo apt install mariadb-s…...
解决MySQL中分页查询时多页有重复数据,实际只有一条数据的问题
0 前言 有一个离奇的BUG,在查询时,第一页跟第二页有一个共同的数据。有的数据却不显示。 后来发现是在SQL排序时没用主键排序。 解决:使用主键排序 以下是我准备的举例,可以自己试试。 1 数据准备 SET NAMES utf8mb4; SET FORE…...
【数据结构】时间复杂度---OJ练习题
目录 🌴时间复杂度练习 📌面试题--->消失的数字 题目描述 题目链接:面试题 17.04. 消失的数字 🌴解题思路 📌思路1: malloc函数用法 📌思路2: 📌思路3&…...
京东自动化功能之商品信息监控是否有库存
这里有两个参数,分别是area和skuids area是地区编码,我这里统计了全国各个区县的area编码,用户可以根据实际地址进行构造skuids是商品的信息ID填写好这两个商品之后,会显示两种状态,判断有货或者无货状态,详情如下图所示 简单编写下python代码,比如我们的地址是北京市…...
【SwitchyOmega】SwitchyOmega 安装及使用
文章目录 安装教程使用教程 安装教程 SwitchyOmega 谷歌商店下载链接:https://chrome.google.com/webstore/detail/proxy-switchyomega/padekgcemlokbadohgkifijomclgjgif?hlen-US 在谷歌商店搜索 SwitchyOmega, 选择 Proxy SwitchyOmega 点击 Add t…...
CentOS5678 repo源 地址 阿里云开源镜像站
CentOS5678 repo 地址 阿里云开源镜像站 https://mirrors.aliyun.com/repo/ CentOS-5.repo https://mirrors.aliyun.com/repo/Centos-5.repo [base] nameCentOS-$releasever - Base - mirrors.aliyun.com failovermethodpriority baseurlhttp://mirrors.aliyun.com/centos/$r…...
【LLM】Langchain使用[二](模型链)
文章目录 1. SimpleSequentialChain2. SequentialChain3. 路由链 Router Chain Reference 1. SimpleSequentialChain 场景:一个输入和一个输出 from langchain.chat_models import ChatOpenAI #导入OpenAI模型 from langchain.prompts import ChatPromptTempla…...
简单机器学习工程化过程
1、确认需求(构建问题) 我们需要做什么? 比如根据一些输入数据,预测某个值? 比如输入一些特征,判断这个是个什么动物? 这里我们要可以尝试分析一下,我们要处理的是个什么问题&…...
【MongoDB】SpringBoot整合MongoDB
【MongoDB】SpringBoot整合MongoDB 文章目录 【MongoDB】SpringBoot整合MongoDB0. 准备工作1. 集合操作1.1 创建集合1.2 删除集合 2. 相关注解3. 文档操作3.1 添加文档3.2 批量添加文档3.3 查询文档3.3.1 查询所有文档3.3.2 根据id查询3.3.3 等值查询3.3.4 范围查询3.3.5 and查…...
关于游戏引擎(godot)对齐音乐bpm的技术
引擎默认底层 1. _process(): 每秒钟调用60次(无限的) 数学 1. bpm1分钟节拍数量60s节拍数量 bpm120 60s120拍 2. 每拍子时间 60/bpm 3. 每个拍子触发周期所需要的帧数 每拍子时间*60(帧率) 这个是从帧数级别上对齐拍子的时间&#x…...
【Go】实现一个代理Kerberos环境部分组件控制台的Web服务
实现一个代理Kerberos环境部分组件控制台的Web服务 背景安全措施引入的问题SSO单点登录 过程整体设计路由反向代理登录会话组件代理YarnHbase 结果 背景 首先要说明下我们目前有部分集群的环境使用的是HDP-3.1.5.0的大数据集群,除了集成了一些自定义的服务以外&…...
Spring Security 6.x 系列【63】扩展篇之匿名认证
有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot 版本 3.1.0 本系列Spring Security 版本 6.1.0 本系列Spring Authorization Server 版本 1.1.0 源码地址:https://gitee.com/pearl-organization/study-spring-security-demo 文章目录 1. 概述2. 配置3. Anonymo…...
供应链管理系统有哪些?
1万字干货分享,国内外 20款 供应链管理软件都给你讲的明明白白。如果你还不知道怎么选择,一定要翻到第三大段,这里我将会通过8年的软件产品选型经验告诉你,怎么样才能快速选到适合自己的软件工具。 (为防后续找不到&a…...
如何在PADS Logic中查找器件
PADS Logic提供类似于Windows的查找功能,可以进行器件的查找。 (1)在Logic设计界面中,将菜单显示中的“选择工具栏”进行打开,如图1所示,会弹出对应的“选择工具栏”的分栏菜单选项,如图2所示。…...
Android 生成pdf文件
Android 生成pdf文件 1.使用官方的方式 使用官方的方式也就是PdfDocument类的使用 1.1 基本使用 /**** 将tv内容写入到pdf文件*/RequiresApi(api Build.VERSION_CODES.KITKAT)private void newPdf() {// 创建一个PDF文本对象PdfDocument document new PdfDocument();//创建…...
阿里云ACP云计算备考笔记 (5)——弹性伸缩
目录 第一章 概述 第二章 弹性伸缩简介 1、弹性伸缩 2、垂直伸缩 3、优势 4、应用场景 ① 无规律的业务量波动 ② 有规律的业务量波动 ③ 无明显业务量波动 ④ 混合型业务 ⑤ 消息通知 ⑥ 生命周期挂钩 ⑦ 自定义方式 ⑧ 滚的升级 5、使用限制 第三章 主要定义 …...
MFC内存泄露
1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...
【网络安全产品大调研系列】2. 体验漏洞扫描
前言 2023 年漏洞扫描服务市场规模预计为 3.06(十亿美元)。漏洞扫描服务市场行业预计将从 2024 年的 3.48(十亿美元)增长到 2032 年的 9.54(十亿美元)。预测期内漏洞扫描服务市场 CAGR(增长率&…...
【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表
1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...
[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...
《C++ 模板》
目录 函数模板 类模板 非类型模板参数 模板特化 函数模板特化 类模板的特化 模板,就像一个模具,里面可以将不同类型的材料做成一个形状,其分为函数模板和类模板。 函数模板 函数模板可以简化函数重载的代码。格式:templa…...
基于TurtleBot3在Gazebo地图实现机器人远程控制
1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...
Java求职者面试指南:计算机基础与源码原理深度解析
Java求职者面试指南:计算机基础与源码原理深度解析 第一轮提问:基础概念问题 1. 请解释什么是进程和线程的区别? 面试官:进程是程序的一次执行过程,是系统进行资源分配和调度的基本单位;而线程是进程中的…...
【网络安全】开源系统getshell漏洞挖掘
审计过程: 在入口文件admin/index.php中: 用户可以通过m,c,a等参数控制加载的文件和方法,在app/system/entrance.php中存在重点代码: 当M_TYPE system并且M_MODULE include时,会设置常量PATH_OWN_FILE为PATH_APP.M_T…...
[ACTF2020 新生赛]Include 1(php://filter伪协议)
题目 做法 启动靶机,点进去 点进去 查看URL,有 ?fileflag.php说明存在文件包含,原理是php://filter 协议 当它与包含函数结合时,php://filter流会被当作php文件执行。 用php://filter加编码,能让PHP把文件内容…...
