当前位置: 首页 > news >正文

PyTorch训练RNN, GRU, LSTM:手写数字识别

文章目录

    • pytorch 神经网络训练demo
    • Result
    • 参考来源

pytorch 神经网络训练demo

数据集:MNIST

该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片

在这里插入图片描述
图片来源:https://tensornews.cn/mnist_intro/

神经网络:RNN, GRU, LSTM

# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Hyperparameters
input_size = 28
sequence_length = 28
num_layers = 2
hidden_size = 256
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 2# Create a RNN
class RNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.rnn(x, h0)out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Create a GRU
class RNN_GRU(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN_GRU, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.gru(x, h0)out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Create a LSTM
class RNN_LSTM(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN_LSTM, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.lstm(x, (h0, c0))out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Load data
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(),download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(),download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)# Initialize network 选择一个即可
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_GRU(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_LSTM(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Train network
for epoch in range(num_epochs):# data: images, targets: labelsfor batch_idx, (data, targets) in enumerate(train_loader):# Get data to cuda if possibledata = data.to(device).squeeze(1) # 删除一个张量中所有维数为1的维度 (N, 1, 28, 28) -> (N, 28, 28)targets = targets.to(device)# forwardscores = model(data) # 64*10loss = criterion(scores, targets)# backwardoptimizer.zero_grad()loss.backward()# gradient descent or adam stepoptimizer.step()# Check accuracy on training & test to see how good our model
def check_accuracy(loader, model):if loader.dataset.train:print("Checking accuracy on training data")else:print("Checking accuracy on test data")num_correct = 0num_samples = 0model.eval()with torch.no_grad(): # 不计算梯度for x, y in loader:x = x.to(device).squeeze(1)y = y.to(device)# x = x.reshape(x.shape[0], -1) # 64*784scores = model(x)# 64*10_, predictions = scores.max(dim=1) #dim=1,表示对每行取最大值,每行代表一个样本。num_correct += (predictions == y).sum()num_samples += predictions.size(0) # 64print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%')model.train()check_accuracy(train_loader, model)
check_accuracy(test_loader, model)

Result

RNN Result
Checking accuracy on training data
Got 57926 / 60000 with accuracy 96.54%
Checking accuracy on test data
Got 9640 / 10000 with accuracy 96.40%GRU Result
Checking accuracy on training data
Got 59058 / 60000 with accuracy 98.43%
Checking accuracy on test data
Got 9841 / 10000 with accuracy 98.41%LSTM Result
Checking accuracy on training data
Got 59248 / 60000 with accuracy 98.75%
Checking accuracy on test data
Got 9849 / 10000 with accuracy 98.49%

参考来源

【1】https://www.youtube.com/watch?v=Gl2WXLIMvKA&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=5

相关文章:

PyTorch训练RNN, GRU, LSTM:手写数字识别

文章目录 pytorch 神经网络训练demoResult参考来源 pytorch 神经网络训练demo 数据集:MNIST 该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片 图片来源:https://tensornews.cn/mnist_intr…...

基于深度学习的高精度道路瑕疵检测系统(PyTorch+Pyside6+YOLOv5模型)

摘要:基于深度学习的高精度道路瑕疵(裂纹(Crack)、检查井(Manhole)、网(Net)、裂纹块(Patch-Crack)、网块(Patch-Net)、坑洼块&#x…...

【裸辞转行】是告别,也是新的开始

一年多了没有更新,是因为去年身体加心理因素辞职了,并且大概率不会再做程序员了,嗯。本来觉得可能再也不会打开 CSDN 了,想了想,还是来做个告别吧,任何事情都该有始有终才对。 回忆碎碎念 是在去年的 11 …...

了解交换机接口的链路类型(access、trunk、hybrid)

上一个章节中讲到了vlan的作用及使用,这篇了解一下交换机接口的链路类型和什么情况下使用 vlan在数据包中是如何体现的,在上一篇的时候提到测试了一下,从PC1去访问PC4的时候,只从E0/0/2发送给了E0/0/3这是,因为两个接…...

Android系统启动流程分析

当按下Android系统的开机电源按键时候,硬件会触发引导芯片,执行预定义的代码,然后加载引导程序(BootLoader)到RAM,Bootloader是Android系统起来前第一个程序,主要用来拉起Android系统程序,Android系统被拉起…...

如何在Ubuntu上安装OpenneBula

OpenNebula是一个开源云计算平台,允许我们在完全虚拟化云中组合和管理VMware和KVM虚拟机 第1步:安装MariaDB数据库服务器 OpenNebula还需要一个数据库服务器来存储其内容。 安装MariaDB: 1 2 sudo apt update sudo apt install mariadb-s…...

解决MySQL中分页查询时多页有重复数据,实际只有一条数据的问题

0 前言 有一个离奇的BUG,在查询时,第一页跟第二页有一个共同的数据。有的数据却不显示。 后来发现是在SQL排序时没用主键排序。 解决:使用主键排序 以下是我准备的举例,可以自己试试。 1 数据准备 SET NAMES utf8mb4; SET FORE…...

【数据结构】时间复杂度---OJ练习题

目录 🌴时间复杂度练习 📌面试题--->消失的数字 题目描述 题目链接:面试题 17.04. 消失的数字 🌴解题思路 📌思路1: malloc函数用法 📌思路2: 📌思路3&…...

京东自动化功能之商品信息监控是否有库存

这里有两个参数,分别是area和skuids area是地区编码,我这里统计了全国各个区县的area编码,用户可以根据实际地址进行构造skuids是商品的信息ID填写好这两个商品之后,会显示两种状态,判断有货或者无货状态,详情如下图所示 简单编写下python代码,比如我们的地址是北京市…...

【SwitchyOmega】SwitchyOmega 安装及使用

文章目录 安装教程使用教程 安装教程 SwitchyOmega 谷歌商店下载链接:https://chrome.google.com/webstore/detail/proxy-switchyomega/padekgcemlokbadohgkifijomclgjgif?hlen-US 在谷歌商店搜索 SwitchyOmega, 选择 Proxy SwitchyOmega 点击 Add t…...

CentOS5678 repo源 地址 阿里云开源镜像站

CentOS5678 repo 地址 阿里云开源镜像站 https://mirrors.aliyun.com/repo/ CentOS-5.repo https://mirrors.aliyun.com/repo/Centos-5.repo [base] nameCentOS-$releasever - Base - mirrors.aliyun.com failovermethodpriority baseurlhttp://mirrors.aliyun.com/centos/$r…...

【LLM】Langchain使用[二](模型链)

文章目录 1. SimpleSequentialChain2. SequentialChain3. 路由链 Router Chain Reference 1. SimpleSequentialChain 场景:一个输入和一个输出 from langchain.chat_models import ChatOpenAI #导入OpenAI模型 from langchain.prompts import ChatPromptTempla…...

简单机器学习工程化过程

1、确认需求(构建问题) 我们需要做什么? 比如根据一些输入数据,预测某个值? 比如输入一些特征,判断这个是个什么动物? 这里我们要可以尝试分析一下,我们要处理的是个什么问题&…...

【MongoDB】SpringBoot整合MongoDB

【MongoDB】SpringBoot整合MongoDB 文章目录 【MongoDB】SpringBoot整合MongoDB0. 准备工作1. 集合操作1.1 创建集合1.2 删除集合 2. 相关注解3. 文档操作3.1 添加文档3.2 批量添加文档3.3 查询文档3.3.1 查询所有文档3.3.2 根据id查询3.3.3 等值查询3.3.4 范围查询3.3.5 and查…...

关于游戏引擎(godot)对齐音乐bpm的技术

引擎默认底层 1. _process(): 每秒钟调用60次(无限的) 数学 1. bpm1分钟节拍数量60s节拍数量 bpm120 60s120拍 2. 每拍子时间 60/bpm 3. 每个拍子触发周期所需要的帧数 每拍子时间*60(帧率) 这个是从帧数级别上对齐拍子的时间&#x…...

【Go】实现一个代理Kerberos环境部分组件控制台的Web服务

实现一个代理Kerberos环境部分组件控制台的Web服务 背景安全措施引入的问题SSO单点登录 过程整体设计路由反向代理登录会话组件代理YarnHbase 结果 背景 首先要说明下我们目前有部分集群的环境使用的是HDP-3.1.5.0的大数据集群,除了集成了一些自定义的服务以外&…...

Spring Security 6.x 系列【63】扩展篇之匿名认证

有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot 版本 3.1.0 本系列Spring Security 版本 6.1.0 本系列Spring Authorization Server 版本 1.1.0 源码地址:https://gitee.com/pearl-organization/study-spring-security-demo 文章目录 1. 概述2. 配置3. Anonymo…...

供应链管理系统有哪些?

1万字干货分享,国内外 20款 供应链管理软件都给你讲的明明白白。如果你还不知道怎么选择,一定要翻到第三大段,这里我将会通过8年的软件产品选型经验告诉你,怎么样才能快速选到适合自己的软件工具。 (为防后续找不到&a…...

如何在PADS Logic中查找器件

PADS Logic提供类似于Windows的查找功能,可以进行器件的查找。 (1)在Logic设计界面中,将菜单显示中的“选择工具栏”进行打开,如图1所示,会弹出对应的“选择工具栏”的分栏菜单选项,如图2所示。…...

Android 生成pdf文件

Android 生成pdf文件 1.使用官方的方式 使用官方的方式也就是PdfDocument类的使用 1.1 基本使用 /**** 将tv内容写入到pdf文件*/RequiresApi(api Build.VERSION_CODES.KITKAT)private void newPdf() {// 创建一个PDF文本对象PdfDocument document new PdfDocument();//创建…...

XML Group端口详解

在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...

论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)

HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别

UnsatisfiedLinkError 在对接硬件设备中,我们会遇到使用 java 调用 dll文件 的情况,此时大概率出现UnsatisfiedLinkError链接错误,原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用,结果 dll 未实现 JNI 协…...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted()是OpenCV库中用于图像处理的函数,主要功能是将两个输入图像(尺寸和类型相同)按照指定的权重进行加权叠加(图像融合),并添加一个标量值&#x…...

微信小程序 - 手机震动

一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注&#xff1a;文档 https://developers.weixin.qq…...

uniapp中使用aixos 报错

问题&#xff1a; 在uniapp中使用aixos&#xff0c;运行后报如下错误&#xff1a; AxiosError: There is no suitable adapter to dispatch the request since : - adapter xhr is not supported by the environment - adapter http is not available in the build 解决方案&…...

Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)

参考官方文档&#xff1a;https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java&#xff08;供 Kotlin 使用&#xff09; 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...

【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统

目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索&#xff08;基于物理空间 广播范围&#xff09;2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代&#xff0c;加密货币作为一种新兴的金融现象&#xff0c;正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而&#xff0c;加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下&#xff0c;稳定…...

保姆级教程:在无网络无显卡的Windows电脑的vscode本地部署deepseek

文章目录 1 前言2 部署流程2.1 准备工作2.2 Ollama2.2.1 使用有网络的电脑下载Ollama2.2.2 安装Ollama&#xff08;有网络的电脑&#xff09;2.2.3 安装Ollama&#xff08;无网络的电脑&#xff09;2.2.4 安装验证2.2.5 修改大模型安装位置2.2.6 下载Deepseek模型 2.3 将deepse…...