PyTorch训练RNN, GRU, LSTM:手写数字识别
文章目录
- pytorch 神经网络训练demo
- Result
- 参考来源
pytorch 神经网络训练demo
数据集:MNIST
该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片

图片来源:https://tensornews.cn/mnist_intro/
神经网络:RNN, GRU, LSTM
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Hyperparameters
input_size = 28
sequence_length = 28
num_layers = 2
hidden_size = 256
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 2# Create a RNN
class RNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.rnn(x, h0)out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Create a GRU
class RNN_GRU(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN_GRU, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.gru(x, h0)out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Create a LSTM
class RNN_LSTM(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN_LSTM, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size*sequence_length, num_classes) # fully connecteddef forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)# Forward Propout, _ = self.lstm(x, (h0, c0))out = out.reshape(out.shape[0], -1)out = self.fc(out)return out# Load data
train_dataset = datasets.MNIST(root='dataset/', train=True, transform=transforms.ToTensor(),download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(),download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)# Initialize network 选择一个即可
model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_GRU(input_size, hidden_size, num_layers, num_classes).to(device)
# model = RNN_LSTM(input_size, hidden_size, num_layers, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Train network
for epoch in range(num_epochs):# data: images, targets: labelsfor batch_idx, (data, targets) in enumerate(train_loader):# Get data to cuda if possibledata = data.to(device).squeeze(1) # 删除一个张量中所有维数为1的维度 (N, 1, 28, 28) -> (N, 28, 28)targets = targets.to(device)# forwardscores = model(data) # 64*10loss = criterion(scores, targets)# backwardoptimizer.zero_grad()loss.backward()# gradient descent or adam stepoptimizer.step()# Check accuracy on training & test to see how good our model
def check_accuracy(loader, model):if loader.dataset.train:print("Checking accuracy on training data")else:print("Checking accuracy on test data")num_correct = 0num_samples = 0model.eval()with torch.no_grad(): # 不计算梯度for x, y in loader:x = x.to(device).squeeze(1)y = y.to(device)# x = x.reshape(x.shape[0], -1) # 64*784scores = model(x)# 64*10_, predictions = scores.max(dim=1) #dim=1,表示对每行取最大值,每行代表一个样本。num_correct += (predictions == y).sum()num_samples += predictions.size(0) # 64print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%')model.train()check_accuracy(train_loader, model)
check_accuracy(test_loader, model)
Result
RNN Result
Checking accuracy on training data
Got 57926 / 60000 with accuracy 96.54%
Checking accuracy on test data
Got 9640 / 10000 with accuracy 96.40%GRU Result
Checking accuracy on training data
Got 59058 / 60000 with accuracy 98.43%
Checking accuracy on test data
Got 9841 / 10000 with accuracy 98.41%LSTM Result
Checking accuracy on training data
Got 59248 / 60000 with accuracy 98.75%
Checking accuracy on test data
Got 9849 / 10000 with accuracy 98.49%
参考来源
【1】https://www.youtube.com/watch?v=Gl2WXLIMvKA&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=5
相关文章:
PyTorch训练RNN, GRU, LSTM:手写数字识别
文章目录 pytorch 神经网络训练demoResult参考来源 pytorch 神经网络训练demo 数据集:MNIST 该数据集的内容是手写数字识别,其分为两部分,分别含有60000张训练图片和10000张测试图片 图片来源:https://tensornews.cn/mnist_intr…...
基于深度学习的高精度道路瑕疵检测系统(PyTorch+Pyside6+YOLOv5模型)
摘要:基于深度学习的高精度道路瑕疵(裂纹(Crack)、检查井(Manhole)、网(Net)、裂纹块(Patch-Crack)、网块(Patch-Net)、坑洼块&#x…...
【裸辞转行】是告别,也是新的开始
一年多了没有更新,是因为去年身体加心理因素辞职了,并且大概率不会再做程序员了,嗯。本来觉得可能再也不会打开 CSDN 了,想了想,还是来做个告别吧,任何事情都该有始有终才对。 回忆碎碎念 是在去年的 11 …...
了解交换机接口的链路类型(access、trunk、hybrid)
上一个章节中讲到了vlan的作用及使用,这篇了解一下交换机接口的链路类型和什么情况下使用 vlan在数据包中是如何体现的,在上一篇的时候提到测试了一下,从PC1去访问PC4的时候,只从E0/0/2发送给了E0/0/3这是,因为两个接…...
Android系统启动流程分析
当按下Android系统的开机电源按键时候,硬件会触发引导芯片,执行预定义的代码,然后加载引导程序(BootLoader)到RAM,Bootloader是Android系统起来前第一个程序,主要用来拉起Android系统程序,Android系统被拉起…...
如何在Ubuntu上安装OpenneBula
OpenNebula是一个开源云计算平台,允许我们在完全虚拟化云中组合和管理VMware和KVM虚拟机 第1步:安装MariaDB数据库服务器 OpenNebula还需要一个数据库服务器来存储其内容。 安装MariaDB: 1 2 sudo apt update sudo apt install mariadb-s…...
解决MySQL中分页查询时多页有重复数据,实际只有一条数据的问题
0 前言 有一个离奇的BUG,在查询时,第一页跟第二页有一个共同的数据。有的数据却不显示。 后来发现是在SQL排序时没用主键排序。 解决:使用主键排序 以下是我准备的举例,可以自己试试。 1 数据准备 SET NAMES utf8mb4; SET FORE…...
【数据结构】时间复杂度---OJ练习题
目录 🌴时间复杂度练习 📌面试题--->消失的数字 题目描述 题目链接:面试题 17.04. 消失的数字 🌴解题思路 📌思路1: malloc函数用法 📌思路2: 📌思路3&…...
京东自动化功能之商品信息监控是否有库存
这里有两个参数,分别是area和skuids area是地区编码,我这里统计了全国各个区县的area编码,用户可以根据实际地址进行构造skuids是商品的信息ID填写好这两个商品之后,会显示两种状态,判断有货或者无货状态,详情如下图所示 简单编写下python代码,比如我们的地址是北京市…...
【SwitchyOmega】SwitchyOmega 安装及使用
文章目录 安装教程使用教程 安装教程 SwitchyOmega 谷歌商店下载链接:https://chrome.google.com/webstore/detail/proxy-switchyomega/padekgcemlokbadohgkifijomclgjgif?hlen-US 在谷歌商店搜索 SwitchyOmega, 选择 Proxy SwitchyOmega 点击 Add t…...
CentOS5678 repo源 地址 阿里云开源镜像站
CentOS5678 repo 地址 阿里云开源镜像站 https://mirrors.aliyun.com/repo/ CentOS-5.repo https://mirrors.aliyun.com/repo/Centos-5.repo [base] nameCentOS-$releasever - Base - mirrors.aliyun.com failovermethodpriority baseurlhttp://mirrors.aliyun.com/centos/$r…...
【LLM】Langchain使用[二](模型链)
文章目录 1. SimpleSequentialChain2. SequentialChain3. 路由链 Router Chain Reference 1. SimpleSequentialChain 场景:一个输入和一个输出 from langchain.chat_models import ChatOpenAI #导入OpenAI模型 from langchain.prompts import ChatPromptTempla…...
简单机器学习工程化过程
1、确认需求(构建问题) 我们需要做什么? 比如根据一些输入数据,预测某个值? 比如输入一些特征,判断这个是个什么动物? 这里我们要可以尝试分析一下,我们要处理的是个什么问题&…...
【MongoDB】SpringBoot整合MongoDB
【MongoDB】SpringBoot整合MongoDB 文章目录 【MongoDB】SpringBoot整合MongoDB0. 准备工作1. 集合操作1.1 创建集合1.2 删除集合 2. 相关注解3. 文档操作3.1 添加文档3.2 批量添加文档3.3 查询文档3.3.1 查询所有文档3.3.2 根据id查询3.3.3 等值查询3.3.4 范围查询3.3.5 and查…...
关于游戏引擎(godot)对齐音乐bpm的技术
引擎默认底层 1. _process(): 每秒钟调用60次(无限的) 数学 1. bpm1分钟节拍数量60s节拍数量 bpm120 60s120拍 2. 每拍子时间 60/bpm 3. 每个拍子触发周期所需要的帧数 每拍子时间*60(帧率) 这个是从帧数级别上对齐拍子的时间&#x…...
【Go】实现一个代理Kerberos环境部分组件控制台的Web服务
实现一个代理Kerberos环境部分组件控制台的Web服务 背景安全措施引入的问题SSO单点登录 过程整体设计路由反向代理登录会话组件代理YarnHbase 结果 背景 首先要说明下我们目前有部分集群的环境使用的是HDP-3.1.5.0的大数据集群,除了集成了一些自定义的服务以外&…...
Spring Security 6.x 系列【63】扩展篇之匿名认证
有道无术,术尚可求,有术无道,止于术。 本系列Spring Boot 版本 3.1.0 本系列Spring Security 版本 6.1.0 本系列Spring Authorization Server 版本 1.1.0 源码地址:https://gitee.com/pearl-organization/study-spring-security-demo 文章目录 1. 概述2. 配置3. Anonymo…...
供应链管理系统有哪些?
1万字干货分享,国内外 20款 供应链管理软件都给你讲的明明白白。如果你还不知道怎么选择,一定要翻到第三大段,这里我将会通过8年的软件产品选型经验告诉你,怎么样才能快速选到适合自己的软件工具。 (为防后续找不到&a…...
如何在PADS Logic中查找器件
PADS Logic提供类似于Windows的查找功能,可以进行器件的查找。 (1)在Logic设计界面中,将菜单显示中的“选择工具栏”进行打开,如图1所示,会弹出对应的“选择工具栏”的分栏菜单选项,如图2所示。…...
Android 生成pdf文件
Android 生成pdf文件 1.使用官方的方式 使用官方的方式也就是PdfDocument类的使用 1.1 基本使用 /**** 将tv内容写入到pdf文件*/RequiresApi(api Build.VERSION_CODES.KITKAT)private void newPdf() {// 创建一个PDF文本对象PdfDocument document new PdfDocument();//创建…...
19c补丁后oracle属主变化,导致不能识别磁盘组
补丁后服务器重启,数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后,存在与用户组权限相关的问题。具体表现为,Oracle 实例的运行用户(oracle)和集…...
CTF show Web 红包题第六弹
提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框,很难让人不联想到SQL注入,但提示都说了不是SQL注入,所以就不往这方面想了 先查看一下网页源码,发现一段JavaScript代码,有一个关键类ctfs…...
简易版抽奖活动的设计技术方案
1.前言 本技术方案旨在设计一套完整且可靠的抽奖活动逻辑,确保抽奖活动能够公平、公正、公开地进行,同时满足高并发访问、数据安全存储与高效处理等需求,为用户提供流畅的抽奖体验,助力业务顺利开展。本方案将涵盖抽奖活动的整体架构设计、核心流程逻辑、关键功能实现以及…...
边缘计算医疗风险自查APP开发方案
核心目标:在便携设备(智能手表/家用检测仪)部署轻量化疾病预测模型,实现低延迟、隐私安全的实时健康风险评估。 一、技术架构设计 #mermaid-svg-iuNaeeLK2YoFKfao {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg…...
MySQL账号权限管理指南:安全创建账户与精细授权技巧
在MySQL数据库管理中,合理创建用户账号并分配精确权限是保障数据安全的核心环节。直接使用root账号进行所有操作不仅危险且难以审计操作行为。今天我们来全面解析MySQL账号创建与权限分配的专业方法。 一、为何需要创建独立账号? 最小权限原则…...
安卓基础(aar)
重新设置java21的环境,临时设置 $env:JAVA_HOME "D:\Android Studio\jbr" 查看当前环境变量 JAVA_HOME 的值 echo $env:JAVA_HOME 构建ARR文件 ./gradlew :private-lib:assembleRelease 目录是这样的: MyApp/ ├── app/ …...
RSS 2025|从说明书学习复杂机器人操作任务:NUS邵林团队提出全新机器人装配技能学习框架Manual2Skill
视觉语言模型(Vision-Language Models, VLMs),为真实环境中的机器人操作任务提供了极具潜力的解决方案。 尽管 VLMs 取得了显著进展,机器人仍难以胜任复杂的长时程任务(如家具装配),主要受限于人…...
苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会
在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...
消防一体化安全管控平台:构建消防“一张图”和APP统一管理
在城市的某个角落,一场突如其来的火灾打破了平静。熊熊烈火迅速蔓延,滚滚浓烟弥漫开来,周围群众的生命财产安全受到严重威胁。就在这千钧一发之际,消防救援队伍迅速行动,而豪越科技消防一体化安全管控平台构建的消防“…...
Windows 下端口占用排查与释放全攻略
Windows 下端口占用排查与释放全攻略 在开发和运维过程中,经常会遇到端口被占用的问题(如 8080、3306 等常用端口)。本文将详细介绍如何通过命令行和图形化界面快速定位并释放被占用的端口,帮助你高效解决此类问题。 一、准…...
