当前位置: 首页 > news >正文

PyTorch训练RNN, GRU, LSTM:手写数字识别

文章目录

    • pytorch 神经网络训练demo
    • Result
    • 参考来源

pytorch 神经网络训练demo

数据集:MNIST

该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片

在这里插入图片描述
图片来源:https://tensornews.cn/mnist_intro/

神经网络:RNN, GRU, LSTM

# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Hyperparameters
input_size = 28
sequence_length = 28
num_layers = 2
hidden_size = 256
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 2# Create a RNN
class RNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.rnn(x, h0)out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Create a GRU
class RNN_GRU(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN_GRU, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.gru(x, h0)out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Create a LSTM
class RNN_LSTM(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN_LSTM, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.lstm(x, (h0, c0))out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Load data
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(),download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(),download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)# Initialize network 选择一个即可
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_GRU(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_LSTM(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Train network
for epoch in range(num_epochs):# data: images, targets: labelsfor batch_idx, (data, targets) in enumerate(train_loader):# Get data to cuda if possibledata = data.to(device).squeeze(1) # 删除一个张量中所有维数为1的维度 (N, 1, 28, 28) -> (N, 28, 28)targets = targets.to(device)# forwardscores = model(data) # 64*10loss = criterion(scores, targets)# backwardoptimizer.zero_grad()loss.backward()# gradient descent or adam stepoptimizer.step()# Check accuracy on training & test to see how good our model
def check_accuracy(loader, model):if loader.dataset.train:print("Checking accuracy on training data")else:print("Checking accuracy on test data")num_correct = 0num_samples = 0model.eval()with torch.no_grad(): # 不计算梯度for x, y in loader:x = x.to(device).squeeze(1)y = y.to(device)# x = x.reshape(x.shape[0], -1) # 64*784scores = model(x)# 64*10_, predictions = scores.max(dim=1) #dim=1,表示对每行取最大值,每行代表一个样本。num_correct += (predictions == y).sum()num_samples += predictions.size(0) # 64print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%')model.train()check_accuracy(train_loader, model)
check_accuracy(test_loader, model)

Result

RNN Result
Checking accuracy on training data
Got 57926 / 60000 with accuracy 96.54%
Checking accuracy on test data
Got 9640 / 10000 with accuracy 96.40%GRU Result
Checking accuracy on training data
Got 59058 / 60000 with accuracy 98.43%
Checking accuracy on test data
Got 9841 / 10000 with accuracy 98.41%LSTM Result
Checking accuracy on training data
Got 59248 / 60000 with accuracy 98.75%
Checking accuracy on test data
Got 9849 / 10000 with accuracy 98.49%

参考来源

【1】https://www.youtube.com/watch?v=Gl2WXLIMvKA&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=5

相关文章:

PyTorch训练RNN, GRU, LSTM:手写数字识别

文章目录 pytorch 神经网络训练demoResult参考来源 pytorch 神经网络训练demo 数据集:MNIST 该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片 图片来源:https://tensornews.cn/mnist_intr…...

基于深度学习的高精度道路瑕疵检测系统(PyTorch+Pyside6+YOLOv5模型)

摘要:基于深度学习的高精度道路瑕疵(裂纹(Crack)、检查井(Manhole)、网(Net)、裂纹块(Patch-Crack)、网块(Patch-Net)、坑洼块&#x…...

【裸辞转行】是告别,也是新的开始

一年多了没有更新,是因为去年身体加心理因素辞职了,并且大概率不会再做程序员了,嗯。本来觉得可能再也不会打开 CSDN 了,想了想,还是来做个告别吧,任何事情都该有始有终才对。 回忆碎碎念 是在去年的 11 …...

了解交换机接口的链路类型(access、trunk、hybrid)

上一个章节中讲到了vlan的作用及使用,这篇了解一下交换机接口的链路类型和什么情况下使用 vlan在数据包中是如何体现的,在上一篇的时候提到测试了一下,从PC1去访问PC4的时候,只从E0/0/2发送给了E0/0/3这是,因为两个接…...

Android系统启动流程分析

当按下Android系统的开机电源按键时候,硬件会触发引导芯片,执行预定义的代码,然后加载引导程序(BootLoader)到RAM,Bootloader是Android系统起来前第一个程序,主要用来拉起Android系统程序,Android系统被拉起…...

如何在Ubuntu上安装OpenneBula

OpenNebula是一个开源云计算平台,允许我们在完全虚拟化云中组合和管理VMware和KVM虚拟机 第1步:安装MariaDB数据库服务器 OpenNebula还需要一个数据库服务器来存储其内容。 安装MariaDB: 1 2 sudo apt update sudo apt install mariadb-s…...

解决MySQL中分页查询时多页有重复数据,实际只有一条数据的问题

0 前言 有一个离奇的BUG,在查询时,第一页跟第二页有一个共同的数据。有的数据却不显示。 后来发现是在SQL排序时没用主键排序。 解决:使用主键排序 以下是我准备的举例,可以自己试试。 1 数据准备 SET NAMES utf8mb4; SET FORE…...

【数据结构】时间复杂度---OJ练习题

目录 🌴时间复杂度练习 📌面试题--->消失的数字 题目描述 题目链接:面试题 17.04. 消失的数字 🌴解题思路 📌思路1: malloc函数用法 📌思路2: 📌思路3&…...

京东自动化功能之商品信息监控是否有库存

这里有两个参数,分别是area和skuids area是地区编码,我这里统计了全国各个区县的area编码,用户可以根据实际地址进行构造skuids是商品的信息ID填写好这两个商品之后,会显示两种状态,判断有货或者无货状态,详情如下图所示 简单编写下python代码,比如我们的地址是北京市…...

【SwitchyOmega】SwitchyOmega 安装及使用

文章目录 安装教程使用教程 安装教程 SwitchyOmega 谷歌商店下载链接:https://chrome.google.com/webstore/detail/proxy-switchyomega/padekgcemlokbadohgkifijomclgjgif?hlen-US 在谷歌商店搜索 SwitchyOmega, 选择 Proxy SwitchyOmega 点击 Add t…...

CentOS5678 repo源 地址 阿里云开源镜像站

CentOS5678 repo 地址 阿里云开源镜像站 https://mirrors.aliyun.com/repo/ CentOS-5.repo https://mirrors.aliyun.com/repo/Centos-5.repo [base] nameCentOS-$releasever - Base - mirrors.aliyun.com failovermethodpriority baseurlhttp://mirrors.aliyun.com/centos/$r…...

【LLM】Langchain使用[二](模型链)

文章目录 1. SimpleSequentialChain2. SequentialChain3. 路由链 Router Chain Reference 1. SimpleSequentialChain 场景:一个输入和一个输出 from langchain.chat_models import ChatOpenAI #导入OpenAI模型 from langchain.prompts import ChatPromptTempla…...

简单机器学习工程化过程

1、确认需求(构建问题) 我们需要做什么? 比如根据一些输入数据,预测某个值? 比如输入一些特征,判断这个是个什么动物? 这里我们要可以尝试分析一下,我们要处理的是个什么问题&…...

【MongoDB】SpringBoot整合MongoDB

【MongoDB】SpringBoot整合MongoDB 文章目录 【MongoDB】SpringBoot整合MongoDB0. 准备工作1. 集合操作1.1 创建集合1.2 删除集合 2. 相关注解3. 文档操作3.1 添加文档3.2 批量添加文档3.3 查询文档3.3.1 查询所有文档3.3.2 根据id查询3.3.3 等值查询3.3.4 范围查询3.3.5 and查…...

关于游戏引擎(godot)对齐音乐bpm的技术

引擎默认底层 1. _process(): 每秒钟调用60次(无限的) 数学 1. bpm1分钟节拍数量60s节拍数量 bpm120 60s120拍 2. 每拍子时间 60/bpm 3. 每个拍子触发周期所需要的帧数 每拍子时间*60(帧率) 这个是从帧数级别上对齐拍子的时间&#x…...

【Go】实现一个代理Kerberos环境部分组件控制台的Web服务

实现一个代理Kerberos环境部分组件控制台的Web服务 背景安全措施引入的问题SSO单点登录 过程整体设计路由反向代理登录会话组件代理YarnHbase 结果 背景 首先要说明下我们目前有部分集群的环境使用的是HDP-3.1.5.0的大数据集群,除了集成了一些自定义的服务以外&…...

Spring Security 6.x 系列【63】扩展篇之匿名认证

有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot 版本 3.1.0 本系列Spring Security 版本 6.1.0 本系列Spring Authorization Server 版本 1.1.0 源码地址:https://gitee.com/pearl-organization/study-spring-security-demo 文章目录 1. 概述2. 配置3. Anonymo…...

供应链管理系统有哪些?

1万字干货分享,国内外 20款 供应链管理软件都给你讲的明明白白。如果你还不知道怎么选择,一定要翻到第三大段,这里我将会通过8年的软件产品选型经验告诉你,怎么样才能快速选到适合自己的软件工具。 (为防后续找不到&a…...

如何在PADS Logic中查找器件

PADS Logic提供类似于Windows的查找功能,可以进行器件的查找。 (1)在Logic设计界面中,将菜单显示中的“选择工具栏”进行打开,如图1所示,会弹出对应的“选择工具栏”的分栏菜单选项,如图2所示。…...

Android 生成pdf文件

Android 生成pdf文件 1.使用官方的方式 使用官方的方式也就是PdfDocument类的使用 1.1 基本使用 /**** 将tv内容写入到pdf文件*/RequiresApi(api Build.VERSION_CODES.KITKAT)private void newPdf() {// 创建一个PDF文本对象PdfDocument document new PdfDocument();//创建…...

【推荐算法】DeepFM:特征交叉建模的革命性架构

DeepFM:特征交叉建模的革命性架构 一、算法背景知识:特征交叉的演进困境1.1 特征交叉的核心价值1.2 传统方法的局限性 二、算法理论/结构:双路并行架构2.1 FM组件:显式特征交叉专家2.2 Deep组件:隐式高阶交叉挖掘机2.3…...

KAG与RAG在医疗人工智能系统中的多维对比分析

1、引言 随着人工智能技术的迅猛发展,大型语言模型(LLM)凭借其卓越的生成能力在医疗健康领域展现出巨大潜力。然而,这些模型在面对专业性、时效性和准确性要求极高的医疗场景时,往往面临知识更新受限、事实准确性不足以及幻觉问题等挑战。为解决这些问题,检索增强生成(…...

使用Python和OpenCV实现图像识别与目标检测

在计算机视觉领域,图像识别和目标检测是两个非常重要的任务。图像识别是指识别图像中的内容,例如判断一张图片中是否包含某个特定物体;目标检测则是在图像中定位并识别多个物体的位置和类别。OpenCV是一个功能强大的开源计算机视觉库&#xf…...

​​TLV4062-Q1​​、TLV4082-Q1​​迟滞电压比较器应用笔记

文章目录 主要作用应用场景关键优势典型应用示意图TLV4062-Q1 和 TLV4082-Q1 的主要作用及应用场景如下: 主要作用 精密电压监测:是一款双通道、低功耗比较器,用于监测输入电压是否超过预设阈值。 集成高精度基准电压源(阈值精度1%),内置60mV迟滞功能,可避免因噪声导致的…...

解决cocos 2dx/creator2.4在ios18下openURL无法调用的问题

由于ios18废弃了旧的openURL接口,我们需要修改CCApplication-ios.mm文件的Application::openURL方法: //修复openURL在ios18下无法调用的问题 bool Application::openURL(const std::string &url) {// NSString* msg [NSString stringWithCString:…...

bug 记录 - 使用 el-dialog 的 before-close 的坑

需求说明 弹窗中内嵌一个 form 表单 原始代码 <script setup lang"ts"> import { reactive, ref } from "vue" import type { FormRules } from element-plus const ruleFormRef ref() interface RuleForm {name: stringregion: number | null } …...

前端js获取当前经纬度(H5/pc/mac/window都可用)

前端JS获取当前位置的经纬度&#xff08;H5/PC/mac/window都可用&#xff0c;亲测&#xff01;&#xff09;&#xff0c;效果如下。 完整代码如下&#xff1a; <!-- 用原生api获取经纬度&#xff0c;转化为百度经纬度与服务端交互&#xff0c; 只支持https&#xff01; --&g…...

JavaScript 数组与流程控制:从基础操作到实战应用

在 JavaScript 编程的世界里&#xff0c;数组是一种极为重要的数据结构&#xff0c;它就像是一个有序的 “收纳盒”&#xff0c;能够将多个值整齐地存储起来。而流程控制语句则像是 “指挥官”&#xff0c;能够按照特定的逻辑对数组进行遍历和操作。接下来&#xff0c;就让我们…...

Kubernetes (k8s)版本发布情况

Kubernetes (k8s)版本发布情况 代码放在 GitHub - kubernetes/kubernetes: Production-Grade Container Scheduling and Management https://github.com/kubernetes/kubernetes/releases 文档放在 kubernetes.io各个版本变更等: https://github.com/kubernetes/kubernet…...

技术文档写作全攻略

一、引言 在快速迭代的软件开发中&#xff0c;技术文档早已不只是附属品&#xff0c;而是与代码同等重要的交付物&#xff1a; 帮助新成员 T0 → T1 学习曲线指数下降&#xff1b;降低支持成本&#xff0c;将重复性问答前移到自助文档&#xff1b;为合规审计、知识传承及商业…...