当前位置: 首页 > news >正文

第P1周:Pytorch实现mnist手写数字识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目标

1. 实现pytorch环境配置
2. 实现mnist手写数字识别
3. 自己写几个数字识别试试

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架:

(二)具体步骤
**1.**配置Pytorch环境

打开官网PyTorch,Get started:
image.png
接下来是选择安装版本,最难的就是确定Compute Platform的版本,是否要使用GPU。所以先要确定CUDA的版本。
image.png
会发现,pytorch官网根本没有对应12.7的版本,先安装最新的试试呗,选择12.4:
image.png
安装命令:pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
image.png
image.png
安装完成,我们建立python文件,输入如下代码:

import torch  
x = torch.rand(5, 3)  
print(x)  print(torch.cuda.is_available())---------output---------------
tensor([[0.3952, 0.6351, 0.3107],[0.8780, 0.6469, 0.6714],[0.4380, 0.0236, 0.5976],[0.4132, 0.9663, 0.7576],[0.4047, 0.4636, 0.2858]])
True

从输出来看,成功了。下面开始正式的mnist手写数字识别

2. 下载数据并加载数据
import torch  
import torch.nn as nn  
# import matplotlib.pyplot as plt  
import torchvision  # 第一步:设置硬件设备,有GPU就使用GPU,没有就使用GPU  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  
print(device)  # 第二步:导入数据  
# MNIST数据在torchvision.datasets中,自带的,可以通过代码在线下载数据。  
train_ds = torchvision.datasets.MNIST(root='./data',    # 下载的数据所存储的本地目录  train=True,       # True为训练集,False为测试集  transform=torchvision.transforms.ToTensor(),  # 将下载的数据直接转换成张量格式  download=True     # True直接在线下载,且下载到root指定的目录中,注意已经下载了,第二次以后就不会再下载了  )  
test_ds = torchvision.datasets.MNIST(root='./data',  train=False,  transform=torchvision.transforms.ToTensor(),  download=True  )  # 第三步:加载数据  
# Pytorch使用torch.utils.data.DataLoader进行数据加载  
batch_size = 32  
train_dl = torch.utils.data.DataLoader(dataset=train_ds, # 要加载的数据集  batch_size=batch_size, # 批次的大小  shuffle=True,     # 每个epoch重新排列数据  # 以下的参数有默认值可以不写  num_workers=0, # 用于加载的子进程数,默认值为0.注意在windows中如果设置非0,有可能会报错  pin_memory=True, # True-数据加载器将在返回之前将张量复制到设备/CUDA 固定内存中。 如果数据元素是自定义类型,或者collate_fn返回一个自定义类型的批次。  drop_last=False, #如果数据集大小不能被批次大小整除,则设置为 True 以删除最后一个不完整的批次。 如果 False 并且数据集的大小不能被批大小整除,则最后一批将保留。 (默认值:False)  timeout=0, # 设置数据读取的超时时间 , 超过这个时间还没读取到数据的话就会报错。(默认值:0)  worker_init_fn=None # 如果不是 None,这将在步长之后和数据加载之前在每个工作子进程上调用,并使用工作 id([0,num_workers - 1] 中的一个 int)的顺序逐个导入。(默认:None)  )  # 取一个批次看一下数据格式,数据的shape为[batch_size, channel, height, weight]  
# batch_size是已经设定的32,channel, height和weight分别是图片的通道数,高度和宽度  
images, labels = next(iter(train_dl))  
print(images.shape)

image.png
image.png
看这个图片的shape是torch.size([32, 1, 28, 28]),可以看图MNIST的数据集里的图像我猜应该是单色的(channel=1),28 * 28大小的图片(height=28, weight=28)。
将图片可视化展示出来看看:

# 数据可视化  
plt.figure(figsize=(20, 5)) # 指定图片大小 ,图像大小为20宽,高5的绘图(单位为英寸)  
for i , images in enumerate(images[:20]):  # 维度缩减,npimg = np.squeeze(images.numpy())  # 将整个figure分成2行10列,绘制第i+1个子图  plt.subplot(2, 10, i+1)  plt.imshow(npimg, cmap=plt.cm.binary)  plt.axis('off')  
plt.show()

image.png

**3.**构建CNN网络
num_classes = 10 # MNIST数据集中是识别0-9这10个数字,因此是10个类别。class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 特征提取网络self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) # 第一层卷积,卷积核大小3*3self.pool1 = nn.MaxPool2d(2)    # 池化层,池化核大小为2*2self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 第二层卷积,卷积核大小3*3self.pool2 = nn.MaxPool2d(2)# 分类网络self.fc1 = nn.Linear(1600, 64)self.fc2 = nn.Linear(64, num_classes)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = torch.flatten(x, start_dim=1)x = F.relu(self.fc1(x))x = self.fc2(x)return x# 第四步:加载并打印模型
# 将模型转移到GPU中
model = Model().to(device)
summary(model)>)

image.png

4.训练模型
# 第五步:训练模型  
loss_fn = nn.CrossEntropyLoss() # 创建损失函数  
learn_rate = 1e-2   # 设置学习率  
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)  # 循环训练  
def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset) # 训练集的大小  num_batches = len(dataloader) # 批次数目  train_loss, train_acc = 0, 0  # 初始化训练损失率和正确率都为0  for X, y in dataloader: # 获取图片及标签  X, y = X.to(device), y.to(device)   # 将图片和标准转换到GPU中  # 计算预测误差  pred = model(X) # 使用CNN网络预测输出pred  loss = loss_fn(pred, y) # 计算预测输出的pred和真实值y之间的差距  # 反向传播  optimizer.zero_grad()   # grad属性归零  loss.backward() # 反向传播  optimizer.step()    # 第一步自动更新  # 记录acc与loss  train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()  train_loss += loss.item()  train_acc /= size  train_loss /= num_batches  return train_acc, train_loss  # 测试函数,注意测试函数不需要进行梯度下降,不进行网络权重更新,所以不需要传入优化器  
def test(dataloader, model, loss_fn):  size = len(dataloader.dataset)  num_batches = len(dataloader)  test_loss, test_acc = 0, 0  # 当不进行训练时,停止梯度更新,节省计算内存消耗  with torch.no_grad():  for imgs, targets in dataloader:  imgs, target = imgs.to(device), targets.to(device)  # 计算 loss            target_pred = model(imgs)  loss = loss_fn(target_pred, target)  test_loss += loss.item()  test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  test_acc /= size  test_loss /= num_batches  return test_acc, test_loss  # 正式训练  
epochs = 5  
train_loss, train_acc, test_loss, test_acc = [], [], [], []  for epoch in range(epochs):  model.train()  epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)  model.eval()  epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)  train_acc.append(epoch_train_acc)  test_acc.append(epoch_test_acc)  train_loss.append(epoch_train_loss)  test_loss.append(epoch_test_loss)  template = 'Epoch: {:2d}, Train_acc:{:.1f}%, Train_loss: {:.3f}%, Test_acc: {:.1f}%, Test_loss: {:.3f}%'  print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))  
print('Done')

image.png

# 可见化一下训练结果  
warnings.filterwarnings("ignore")  
plt.rcParams['font.sans-serif'] = ['SimHei']    # 显示中文不标签,不设置会显示中文乱码  
plt.rcParams['axes.unicode_minus'] = False      # 显示负号  
plt.rcParams['figure.dpi'] = 100                # 分辨率  epochs_range = range(epochs)  plt.figure(figsize=(12, 3))  
plt.subplot(1, 2, 1)  plt.plot(epochs_range, train_acc, label='训练正确率')  
plt.plot(epochs_range, test_acc, label='测试正确率')  
plt.legend(loc='lower right')  
plt.title('训练与测试正确率')  plt.subplot(1, 2, 2)  
plt.plot(epochs_range, train_loss, label='训练损失率')  
plt.plot(epochs_range, test_loss, label='测试损失率')  
plt.legend(loc='upper right')  
plt.title('训练与测试损失率')  plt.show()

image.png

四:预测一下自己手写的数字

准备数据:
image.png
再手动将每个数字切割成单独的一个文件:
image.png
注意,这里并没有将每个图片的大小切割成一致,理论上切割成要求的28*28是最好。我这里用代码来重新生成28 * 28大小的图片。

import torch  
import numpy as np  
from PIL import Image  
from torchvision import transforms  
import torch.nn as nn  
import torch.nn.functional as F  
import matplotlib.pyplot as plt  
import os, pathlib  # 第一步:设置硬件设备,有GPU就使用GPU,没有就使用GPU  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  
print(device)  # 定义模型,要把模型搞过来嘛,不然加载模型会出错。  
class Model(nn.Module):  def __init__(self):  super().__init__()  # 特征提取网络  self.conv1 = nn.Conv2d(1, 32, kernel_size=3 ) # 第一层卷积,卷积核大小3*3  self.pool1 = nn.MaxPool2d(2)    # 池化层,池化核大小为2*2  self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # 第二层卷积,卷积核大小3*3  self.pool2 = nn.MaxPool2d(2)  # 分类网络  self.fc1 = nn.Linear(1600, 64)  self.fc2 = nn.Linear(64, 10)  def forward(self, x):  x = self.pool1(F.relu(self.conv1(x)))  x = self.pool2(F.relu(self.conv2(x)))  x = torch.flatten(x, start_dim=1)  x = F.relu(self.fc1(x))  x = self.fc2(x)  return x  # 加载模型  
model = torch.load('./models/cnn.pth')   
model.eval()  transform = transforms.Compose([  transforms.ToTensor(),  transforms.Normalize((0.1307,), (0.3081,))  
])  # 导入数据  
data_dir = "./mydata/handwrite"  
data_dir = pathlib.Path(data_dir)  
image_count = len(list(data_dir.glob('*.jpg')))  
print("图片总数量为:", image_count)  plt.rcParams['font.sans-serif'] = ['SimHei']    # 显示中文不标签,不设置会显示中文乱码  
plt.rcParams['axes.unicode_minus'] = False      # 显示负号  
plt.rcParams['figure.dpi'] = 100                # 分辨率  
plt.figure(figsize=(10, 10))  
i = 0  
for input_file in list(data_dir.glob('*.jpg')):  image = Image.open(input_file)  image_resize = image.resize((28, 28))   # 将图片转换成 28*28  image = image_resize.convert('L')  # 转换成灰度图  image_array = np.array(image)  # print(image_array.shape)    # (high, weight)  image = Image.fromarray(image_array)  image = transform(image)  image = torch.unsqueeze(image, 0)   # 返回维度为1的张量  image = image.to(device)  output = model(image)  pred = torch.argmax(output, dim=1)  image = torch.squeeze(image, 0)     # 返回一个张量,其中删除了大小为1的输入的所有指定维度  image = transforms.ToPILImage()(image)  plt.subplot(10, 4, i+1)  plt.tight_layout()  plt.imshow(image, cmap='gray', interpolation='none')  plt.title("实际值:{},预测值:{}".format(input_file.stem[:1], pred.item()))  plt.xticks([])  plt.yticks([])  i += 1  
plt.show()

image.png

准确性很低,40张图片预测准确数量:6,占比:15.0%.。看图片,感觉resize成28*28和转换成灰度图后,图片本身已经失真比较严重了。先把图片像素翻转一下,其实就是反色处理,加上这段代码:
image.png
image.png
准确率上了一个台阶(40张图片预测准确数量:30,占比:75.0%).。但是看图片,还是不清晰。

(三)总结
  1. epochs=5,预测的准确性达到97%,如果增加迭代的次数到10,准确性提升接近到99%。迭代20次则达到99.3,提升不明显。
    image.png
    image.png
  2. batch_size如何从32调整到64,准确性差不太多
    image.png
    image.png
  3. 后续研究图片增强

相关文章:

第P1周:Pytorch实现mnist手写数字识别

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 1. 实现pytorch环境配置 2. 实现mnist手写数字识别 3. 自己写几个数字识别试试具体实现 (一)环境 语言环境:Python…...

使用EventLog Analyzer进行Apache日志监控和日志分析

一、什么是Apache日志分析 Apache日志分析是网站管理和维护的重要部分,通过分析Apache服务器生成的日志文件,可以了解网站的访问情况、识别潜在的安全问题、优化网站性能等。 二、Apache日志类型 Apache日志主要有两种类型:访问日志&a…...

PaddleOCR模型ch_PP-OCRv3文本检测模型研究(二)颈部网络

上节研究了PaddleOCR文本检测v3模型的骨干网,本文接着研究其颈部网络。 文章目录 研究起点残注层颈部网络代码实验小结 研究起点 摘取开源yml配置文件,摘取网络架构Architecture中颈部网络的配置如下 Neck:name: RSEFPNout_channels: 96shortcut: True可…...

360极速浏览器不支持看PDF

360安全浏览器采用的是基于IE内核和Chrome内核的双核浏览器。360极速浏览器是源自Chromium开源项目的浏览器,不但完美融合了IE内核引擎,而且实现了双核引擎的无缝切换。因此在速度上,360极速浏览器的极速体验感更佳。 展示自己的时候要在有优…...

【深度学习】深刻理解ViT

ViT(Vision Transformer)是谷歌研究团队于2020年提出的一种新型图像识别模型,首次将Transformer架构成功应用于计算机视觉任务中。Transformer最初应用于自然语言处理(如BERT和GPT),而ViT展示了其在视觉任务…...

解决vue2中更新列表数据,页面dom没有重新渲染的问题

在 Vue 2 中,直接修改数组的某个项可能不会触发视图的更新。这是因为 Vue 不能检测到数组的索引变化或对象属性的直接赋值。为了确保 Vue 能够正确地响应数据变化,你可以使用以下几种方法: 1. 使用 Vue.set() 使用 Vue.set() 方法可以确保 …...

vscode通过ssh连接远程服务器(实习心得)

一、连接ssh服务器 1.打开Visual Studio Code,进入拓展市场(CtrlShiftX),下载拓展Remote - SSH 2. 点击远程资源管理器选项卡,并选择远程(隧道/SSH)类别 3. 点击ssh配置:输入你的账号主机ip地址 4.在弹出的选择配置文件中&#xf…...

知识图谱9:知识图谱的展示

1、知识图谱的展示有很多工具 Neo4j Browser - - - - 浏览器版本 Neo4j Desktop - - - - 桌面版本 graphX - - - - 可以集成到Neo4j Desktop Neo4j 提供的 Neo4j Bloom 是用户友好的可视化工具,适合非技术用户直观地浏览图数据。Cypher 是其核心查询语言&#x…...

leetcode 面试经典 150 题:验证回文串

链接验证回文串题序号125类型字符串解题方法双指针法难度简单 题目 如果在将所有大写字符转换为小写字符、并移除所有非字母数字字符之后,短语正着读和反着读都一样。则可以认为该短语是一个 回文串 。 字母和数字都属于字母数字字符。 给你一个字符串 s&#xf…...

【0363】Postgres内核 从 XLogReaderState readBuf 解析 XLOG Record( 8 )

上一篇: 【0362】Postgres内核 XLogReaderState readBuf 有完整 XLOG page header 信息 ? ( 7 ) 直接相关: 【0341】Postgres内核 读取单个 xlog page (2 - 2 ) 文章目录 1. readBuf 获取 page header 大小1.1 XLOG record 跨 page ?1.2 获取 XLOG Record 的 长度(xl…...

docker tdengine windows快速体验

#拉取镜像 docker pull tdengine/tdengine:2.6.0.34#容器运行 docker run -d --name td2.6 --restartalways -p 6030:6030 -p 6041:6041 -p 6043:6043 -p 6044-6049:6044-6049 -p 6044-6045:6044-6045/udp -p 6060:6060 tdengine/tdengine:2.6.0.34#容器数据持久化到本地 #/va…...

详解RabbitMQ在Ubuntu上的安装

​​​​​​​ 目录 Ubuntu 环境安装 安装Erlang 查看Erlang版本 退出命令 ​编辑安装RabbitMQ 确认安装结果 安装RabbitMQ管理界面 启动服务 查看服务状态 通过IP:port访问 添加管理员用户 给用户添加权限 再次访问 Ubuntu 环境安装 安装Erlang RabbitMq需要…...

Python的3D可视化库【vedo】2-2 (plotter模块) 访问绘制器信息、操作渲染器

文章目录 4 Plotter类的方法4.1 访问Plotter信息4.1.1 实例信息4.1.2 演员对象列表 4.2 渲染器操作4.2.1 选择渲染器4.2.2 更新渲染场景 4.3 控制渲染效果4.3.1 渲染窗格的背景色4.3.2 深度剥离效果4.3.3 隐藏线框的线条4.3.4 改为平行投影模式4.3.5 添加阴影4.3.6 环境光遮蔽4…...

【vue2】文本自动省略组件,支持单行和多行省略,超出显示tooltip

代码见文末 vue3实现 最开始就用的vue3实现,如下 Vue3实现方式 vue2开发和使用文档 组件功能 TooltipText 是一个文字展示组件,具有以下功能: 文本显示:支持单行和多行文本显示。自动判断溢出:判断文本是否溢出…...

网络安全产品之认识防病毒软件

随着计算机技术的不断发展,防病毒软件已成为企业和个人计算机系统中不可或缺的一部分。防病毒软件是网络安全产品中的一种,主要用于检测、清除计算机病毒,以及预防病毒的传播。本文我们一起来认识一下防病毒软件。 一、什么是计算机病毒 计算…...

游戏引擎学习第42天

仓库: https://gitee.com/mrxiao_com/2d_game 简介 目前我们正在研究的内容是如何构建一个基本的游戏引擎。我们将深入了解游戏开发的每一个环节,从最基础的技术实现到高级的游戏编程。 角色移动代码 我们主要讨论的是角色的移动代码。我一直希望能够使用一些基…...

区块链智能合约( solidity) 安全编程

引言:本文由天玄链开源开发者提供,欢迎报名公益天玄链训练营 https://blockchain.163.com/trainingCamp 一、重入和竞态 重入和竞态在solidity 编程安全中会多次提及,历史上也造成了重大的损失。 1.1 问题分析 竞态的描述不严格&#xf…...

GUNS搭建

一、准备工作 源码下载: 链接: https://pan.baidu.com/s/1bJZzAzGJRt-NxtIQ82KlBw 提取码: criq 官方文档 二、导入代码 1、导入后端IDE 导入完成需要,需要修改yml文件中的数据库配置,改成自己的。 2、导入前端IDE 我是用npm安装的yarn npm…...

【ETCD】【源码阅读】stepWithWaitOption方法解析

在分布式系统中,ETCD 作为一个强一致性、高可用的 key-value 存储系统,广泛应用于服务发现、配置管理等场景。ETCD 在内部采用了 Raft 协议来保证集群的一致性,而日志预提案(log proposal)是 Raft 协议中至关重要的一部…...

redis 怎么样查看list

在 Redis 中,可以通过以下方法查看列表的内容或属性: 1. 查看列表中的所有元素 使用 LRANGE 命令: LRANGE key start endkey 是列表的名称。start 是起始索引,0 表示第一个元素。end 是结束索引,-1 表示最后一个元素…...

RestClient

什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端,它允许HTTP与Elasticsearch 集群通信,而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级&#xff…...

在软件开发中正确使用MySQL日期时间类型的深度解析

在日常软件开发场景中,时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志,到供应链系统的物流节点时间戳,时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库,其日期时间类型的…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以? 在 Golang 的面试中,map 类型的使用是一个常见的考点,其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...

SCAU期末笔记 - 数据分析与数据挖掘题库解析

这门怎么题库答案不全啊日 来简单学一下子来 一、选择题(可多选) 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘:专注于发现数据中…...

连锁超市冷库节能解决方案:如何实现超市降本增效

在连锁超市冷库运营中,高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术,实现年省电费15%-60%,且不改动原有装备、安装快捷、…...

Linux云原生安全:零信任架构与机密计算

Linux云原生安全:零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言:云原生安全的范式革命 随着云原生技术的普及,安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测,到2025年,零信任架构将成为超…...

九天毕昇深度学习平台 | 如何安装库?

pip install 库名 -i https://pypi.tuna.tsinghua.edu.cn/simple --user 举个例子: 报错 ModuleNotFoundError: No module named torch 那么我需要安装 torch pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple --user pip install 库名&#x…...

《C++ 模板》

目录 函数模板 类模板 非类型模板参数 模板特化 函数模板特化 类模板的特化 模板,就像一个模具,里面可以将不同类型的材料做成一个形状,其分为函数模板和类模板。 函数模板 函数模板可以简化函数重载的代码。格式:templa…...

【电力电子】基于STM32F103C8T6单片机双极性SPWM逆变(硬件篇)

本项目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脉宽调制)电源模块,能够生成可调频率和幅值的正弦波交流电源输出。该项目适用于逆变器、UPS电源、变频器等应用场景。 供电电源 输入电压采集 上图为本设计的电源电路,图中 D1 为二极管, 其目的是防止正负极电源反接, …...