当前位置: 首页 > news >正文

CNN卷积网络实现MNIST数据集手写数字识别

步骤一:加载MNIST数据集

train_data = MNIST(root='./data',train=True,download=False,transform=transforms.ToTensor())
train_loader = DataLoader(train_data,shuffle=True,batch_size=64)
# 测试数据集
test_data = MNIST(root='./data',train=False,download=False,transform=transforms.ToTensor())
test_loader = DataLoader(test_data,shuffle=False,batch_size=64)

首先,通过MNIST类创建了train_data对象,指定了数据集的路径root='./data',并且将数据集标记为训练集train=Truedownload=False表示不自动从网络上下载数据集,而是使用已经下载好的数据集。我是之前自己已经下载过该数据集所以这里填的是False,如果之前没有下载的话就要填True。下面测试集也是一样。transforms.ToTensor()将数据转换为张量形式。

然后,通过DataLoader类创建了train_loader对象,指定了使用train_data作为数据源。shuffle=True表示在每个epoch开始时,将数据打乱顺序。batch_size=64表示每次抓取64个样本。

接下来,同样的步骤也被用来创建了测试集的数据加载器test_loader。不同的是,这里将数据集标记为测试集train=False,并且shuffle=False表示不需要打乱顺序。

加载完的数据集存在MNIST文件夹的raw文件夹下内容如下:

其中t10k-images-idx3-ubyte是测试集的图像,t10k-labels-idx3-ubyte是测试集的标签。train-images-idx3-ubyte是训练集的图像,train-labels-idx1-ubyte是训练集的标签。

存下来的这些数据集是二进制的形式,可以通过下面的代码(1.py)读取:

"""
Created on Sat Jul 27 15:26:39 2024@author: wangyiyuan
"""
# 导入包
import struct
import numpy as np
from PIL import Imageclass MnistParser:# 加载图像def load_image(self, file_path):# 读取二进制数据binary = open(file_path,'rb').read()# 读取头文件fmt_head = '>iiii'offset = 0# 读取头文件magic_number,images_number,rows_number,columns_number = struct.unpack_from(fmt_head,binary,offset)# 打印头文件信息print('图片数量:%d,图片行数:%d,图片列数:%d'%(images_number,rows_number,columns_number))# 处理数据image_size = rows_number * columns_numberfmt_data = '>'+str(image_size)+'B'offset = offset + struct.calcsize(fmt_head)# 读取数据images = np.empty((images_number,rows_number,columns_number))for i in range(images_number):images[i] = np.array(struct.unpack_from(fmt_data, binary, offset)).reshape((rows_number, columns_number))offset = offset + struct.calcsize(fmt_data)# 每1万张打印一次信息if (i+1) % 10000 == 0:print('> 已读取:%d张图片'%(i+1))# 返回数据return images_number,rows_number,columns_number,images# 加载标签def load_labels(self, file_path):# 读取数据binary = open(file_path,'rb').read()# 读取头文件fmt_head = '>ii'offset = 0# 读取头文件magic_number,items_number = struct.unpack_from(fmt_head,binary,offset)# 打印头文件信息print('标签数:%d'%(items_number))# 处理数据fmt_data = '>B'offset = offset + struct.calcsize(fmt_head)# 读取数据labels = np.empty((items_number))for i in range(items_number):labels[i] = struct.unpack_from(fmt_data, binary, offset)[0]offset = offset + struct.calcsize(fmt_data)# 每1万张打印一次信息if (i+1)%10000 == 0:print('> 已读取:%d个标签'%(i+1))# 返回数据return items_number,labels# 图片可视化def visualaztion(self, images, labels, path):d = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0, 9:0}for i in range(images.__len__()):im = Image.fromarray(np.uint8(images[i]))im.save(path + "%d_%d.png"%(labels[i], d[labels[i]]))d[labels[i]] += 1# im.show()if (i+1)%10000 == 0:print('> 已保存:%d个图片'%(i+1))# 保存为图片格式
def change_and_save():mnist =  MnistParser()trainImageFile = './train-images-idx3-ubyte'_, _, _, images = mnist.load_image(trainImageFile)trainLabelFile = './train-labels-idx1-ubyte'_, labels = mnist.load_labels(trainLabelFile)mnist.visualaztion(images, labels, "./images/train/")testImageFile = './train-images-idx3-ubyte'_, _, _, images = mnist.load_image(testImageFile)testLabelFile = './train-labels-idx1-ubyte'_, labels = mnist.load_labels(testLabelFile)mnist.visualaztion(images, labels, "./images/test/")# 测试
if __name__ == '__main__':change_and_save()

将这个1.py文件和下载好的数据集放在同一个文件夹下:

新建一个文件夹images,在文件夹images里面新建两个文件夹分别叫test和train。

运行完可以发现train和test里的内容如下:

步骤二:建立模型

class Model(nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = nn.Linear(784,256)self.linear2 = nn.Linear(256,64)self.linear3 = nn.Linear(64,10) # 10个手写数字对应的10个输出def forward(self,x):x = x.view(-1,784) # 变形x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))# x = torch.relu(self.linear3(x))return x

这里是建立了一个神经网络模型类(Model)。这个模型有三个线性层(linear1、linear2、linear3)。输入维度为784(因为每一张图片的大小是28*28=784),输出维度为256、64、10(因为有十个类)。forward函数定义了模型的前向传播过程,其中x.view(-1, 784)将输入张量x变形为(batch_size, 784)的大小。然后经过三个线性层和relu激活函数进行运算,最后返回输出结果x。

步骤三:训练模型

model = Model()
criterion = nn.CrossEntropyLoss() # 交叉熵损失,相当于Softmax+Log+NllLoss
optimizer = torch.optim.SGD(model.parameters(),0.8) # 第一个参数是初始化参数值,第二个参数是学习率# 模型训练
# def train():
for index,data in enumerate(train_loader):input,target = data # input为输入数据,target为标签optimizer.zero_grad() # 梯度清零y_predict = model(input) # 模型预测loss = criterion(y_predict,target) # 计算损失loss.backward() # 反向传播optimizer.step() # 更新参数if index % 100 == 0: # 每一百次保存一次模型,打印损失torch.save(model.state_dict(),"./model/model.pkl") # 保存模型torch.save(optimizer.state_dict(),"./model/optimizer.pkl")print("损失值为:%.2f" % loss.item())

首先创建了一个模型对象model,一个损失函数对象criterion和一个优化器对象optimizer。然后使用一个for循环遍历训练数据集train_loader,每次取出一个batch的数据。接着将优化器的梯度清零,然后使用模型前向传播得到预测结果y_predict,计算损失值loss,然后进行反向传播和参数更新。每训练100个batch,保存模型和优化器的参数,并打印当前的损失值。

步骤四:保存模型参数

if os.path.exists('./model/model.pkl'):model.load_state_dict(torch.load("./model/model.pkl")) # 加载保存模型的参数

在当前文件夹下新建一个名叫model的文件夹。保存步骤三中训练完模型的参数。

步骤五:检验模型

correct = 0 # 正确预测的个数total = 0 # 总数with torch.no_grad(): # 测试不用计算梯度for data in test_loader:input,target = dataoutput=model(input) # output输出10个预测取值,其中最大的即为预测的数probability,predict=torch.max(output.data,dim=1) # 返回一个元组,第一个为最大概率值,第二个为最大值的下标total += target.size(0) # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item() # predict和target均为(batch_size,1)的矩阵,sum()求出相等的个数print("准确率为:%.2f" % (correct / total))

参数说明:

  • correct:记录正确预测的个数
  • total:记录总样本数
  • test_loader:测试集的数据加载器
  • input:输入数据
  • target:目标标签
  • output:模型的输出结果
  • probability:最大概率值
  • predict:最大值的下标

过程:

  • 使用torch.no_grad()包装测试过程,表示不需要计算梯度
  • 遍历测试集中的每个数据,获取输入数据和目标标签
  • 将输入数据输入模型,得到模型的输出结果
  • 使用torch.max()函数返回预测结果中的最大概率值和最大值的下标
  • 更新总数和正确预测的个数
  • 最后计算并输出准确率。

步骤六:检测自己的手写数据

if __name__ == '__main__':# 自定义测试image = Image.open('C:/Users/wangyiyuan/Desktop/20201116160729670.jpg') # 读取自定义手写图片image = image.resize((28,28)) # 裁剪尺寸为28*28image = image.convert('L') # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1,1,28,28)output = model(image)probability,predict=torch.max(output.data,dim=1)print("此手写图片值为:%d,其最大概率为:%.2f" % (predict[0],probability))plt.title('此手写图片值为:{}'.format((int(predict))),fontname="SimHei")plt.imshow(image.squeeze())plt.show()

这里的C:/Users/wangyiyuan/Desktop/20201116160729670.jpg是我自己从网上找的的手写图片。这段代码意思如下:

  1. 打开并读取一张手写图片,图片的路径为'C:/Users/wangyiyuan/Desktop/20201116160729670.jpg'。
  2. 调整图片尺寸为28x28。
  3. 将图片转换为灰度图像,以便后续处理。
  4. 使用transforms.ToTensor()将图片转换为PyTorch张量。
  5. 调整图片尺寸为(1, 1, 28, 28)以适应模型的输入要求。
  6. 将处理后的图片输入模型,获取预测输出。
  7. 通过torch.max函数获得输出中的最大值及其索引,即预测的数字和其概率。
  8. 打印预测的数字和概率。
  9. 在图像上显示预测结果和手写图片。
  10. 展示图像。

步骤七:结果展示

我的原图是:

测试得到的结果为:


损失值为:4.16
损失值为:0.93
损失值为:0.31
损失值为:0.19
损失值为:0.24
损失值为:0.15
损失值为:0.13
损失值为:0.11
损失值为:0.18
损失值为:0.02
此手写图片值为:2,其最大概率为:6.57

相关文章:

CNN卷积网络实现MNIST数据集手写数字识别

步骤一:加载MNIST数据集 train_data MNIST(root./data,trainTrue,downloadFalse,transformtransforms.ToTensor()) train_loader DataLoader(train_data,shuffleTrue,batch_size64) # 测试数据集 test_data MNIST(root./data,trainFalse,downloadFalse,transfor…...

深入理解Java中的时间处理与时区管理

在Java开发中,时间处理和时区管理是常见的需求,特别是在全球化应用中。Java 8引入了新的时间API(java.time包),使时间处理变得更加直观和高效。本文将详细介绍Java中的时间处理与时区管理,通过丰富的代码示…...

虚拟机windows server创建域

目录 准备工作 一、新建域控制器 二、提升为域控制器添加新林 三、新建组织单位(OU),用户 四、将计算机加域 五、在域控中管理计算机 六、在域控中配置组策略 七、域内计算机验证组策略配置 准备工作 安装域前,如果有DNS…...

Java 集合框架:Java 中的 Set 集合(HashSet LinkedHashSet TreeSet)特点与实现解析

大家好,我是栗筝i,这篇文章是我的 “栗筝i 的 Java 技术栈” 专栏的第 017 篇文章,在 “栗筝i 的 Java 技术栈” 这个专栏中我会持续为大家更新 Java 技术相关全套技术栈内容。专栏的主要目标是已经有一定 Java 开发经验,并希望进一步完善自己对整个 Java 技术体系来充实自…...

springboot智能健康管理平台-计算机毕业设计源码57256

摘要 在当今社会,人们越来越重视健康饮食和健康管理。借助SpringBoot框架和MySQL数据库的支持,开发智能健康管理平台成为可能。该平台结合了小程序技术的便利性和SpringBoot框架的快速开发能力,为用户提供了便捷的健康管理解决方案。 通过智能…...

LetterBox图像预处理方法

LetterBox图像预处理方法就是要将不同分辨率的图像转换成固定分辨率,比如v8输入网络的固定分辨率为6406403,因此这里分享一下默认情况下对训练集、验证集和测试图片做的letterBox的方法。 1.LetterBox-Train 对于训练集,默认输入网络的图像尺寸为640640,假设有一张7201280…...

C++第五篇 类和对象(下) 初始化列表

目录 1.再探构造函数 2.类型转换 3.static成员 4.友元 friiend 1.再探构造函数 (1).之前我们实现构造函数时,初始化成员变量主要使用函数体内赋值,构造函数初始化还有一种方式,就是初始化列表,初始化列表的使用方式是以一个冒…...

C#中的通信

上位机应用开发-串口通信1、基于C#的串口通信对象:SerialPort 2、字段属性 PortName:获取或设置通信端口 BaudRate:获取或设置串行波特率-DataBits:获取或设置每个字节的标准数据位长度 Parity:获取或设置奇偶校验检查协仪I-StopBits;获取或设置每个字节的标准停止位数 3、…...

CVE-2022-21663: WordPress <5.8.3 版本对象注入漏洞深入分析

引言 在网络安全领域,技术的研究与讨论是不断进步的动力。本文针对WordPress的一个对象注入漏洞进行分析,旨在分享技术细节并提醒安全的重要性。特别强调:本文内容仅限技术研究,严禁用于非法目的。 漏洞背景 继WordPress CVE-2…...

C语言笔试题(三)

本专栏通过整理各专业方向的面试资料并咨询业界相关人士,整合不同方向的面试资料,希望能为您的面试道路点亮一盏灯! 1 简单题 如何声明一个二维数组? 答案: int arr[3][4];解析: 二维数组可以看作数组的数组。 union和struct…...

minio笔记之windows下安装使用

minio安装使用 去官网下载安装包启动访问管理平台创建桶创建用户、资源授权访问访问策略创建创建用户创建accessKey,用于应用程序开发 去官网下载安装包 直接安装即可 启动 设置密码 set MINIO_ROOT_USERadmin set MINIO_ROOT_PASSWORD12345678 cd到安装目录 mi…...

代码随想录算法训练营day31 | 56. 合并区间、738.单调递增的数字

碎碎念:加油 参考:代码随想录 56. 合并区间 题目链接 56. 合并区间 思想 这道题的核心还是判断重叠区间,本题和之前做过的452. 用最少数量的箭引爆气球、435. 无重叠区间的区别在于判断出重叠区间之后的操作,本题需要做的是合…...

利用 Python 制作图片轮播应用

在这篇博客中,我将向大家展示如何使用 xPython 创建一个图片轮播应用。这个应用能够从指定文件夹中加载图片,定时轮播,并提供按钮来保存当前图片到收藏夹或仅轮播收藏夹中的图片。我们还将实现退出按钮和全屏显示的功能。 C:\pythoncode\new\…...

报表系统之Cube.js

Cube.js 是一个开源的分析框架,专为构建数据应用和分析工具而设计。它的主要目的是简化和加速构建复杂的分析和数据可视化应用。以下是对 Cube.js 的详细介绍: 核心功能和特点 1. 多数据源支持 Cube.js 支持从多个数据源中提取数据,包括 SQ…...

代码随想录算法训练营第45天

115.不同的子序列 但相对于刚讲过 392.判断子序列,本题 就有难度了 ,感受一下本题和 392.判断子序列 的区别。 代码随想录 class Solution {public int numDistinct(String s, String t) {int lenS s.length();int lenT t.length();int[][] dp new …...

solidity合约创建

合约可以通过使用new关键字来创建其他合约的实例。 这个过程会执行被创建合约的构造函数(如果存在的话),并返回一个指向新创建合约的地址的引用。 这种方式允许智能合约动态地在区块链上部署新合约,并与它们交互。 通过 new 创…...

队列---循环队列实现

循环队列详解 概述 循环队列是一种基于数组实现的队列数据结构,其中队列的队首和队尾是通过模运算连接起来形成一个逻辑上的环形结构。这样可以有效地利用数组的空间,避免出现“假溢出”的情况。 结构体定义 循环队列的结构体定义如下: …...

【视频讲解】后端增删改查接口有什么用?

B站视频地址 B站视频地址 前言 “后端增删改查接口有什么用”,其实这句话可以拆解为下面3个问题。 接口是什么意思?后端接口是什么意思?后端接口中的增删改查接口有什么用? 1、接口 概念:接口的概念在不同的领域中…...

双指针hard题

[LeetCode]4. Median of Two Sorted Arrays 中文 - YouTube 依赖merge sort和priorityqueue的废物 正式变身山景城一姐小迷妹✪ω✪ 寻找正序数组中位数 class Solution {public double findMedianSortedArrays(int[] nums1, int[] nums2) {int len1 nums1.length;int len2 …...

前端实现【 批量任务调度管理器 】demo优化

一、前提介绍 我在前文实现过一个【批量任务调度管理器】的 demo,能实现简单的任务批量并发分组,过滤等操作。但是还有很多优化空间,所以查找一些优化的库, 主要想优化两个方面, 上篇提到的: 针对 3&…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式,可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

conda相比python好处

Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理&#xff1a…...

【Python】 -- 趣味代码 - 小恐龙游戏

文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...

云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?

大家好,欢迎来到《云原生核心技术》系列的第七篇! 在上一篇,我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在,我们就像一个拥有了一块崭新数字土地的农场主,是时…...

(十)学生端搭建

本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

day52 ResNet18 CBAM

在深度学习的旅程中&#xff0c;我们不断探索如何提升模型的性能。今天&#xff0c;我将分享我在 ResNet18 模型中插入 CBAM&#xff08;Convolutional Block Attention Module&#xff09;模块&#xff0c;并采用分阶段微调策略的实践过程。通过这个过程&#xff0c;我不仅提升…...

定时器任务——若依源码分析

分析util包下面的工具类schedule utils&#xff1a; ScheduleUtils 是若依中用于与 Quartz 框架交互的工具类&#xff0c;封装了定时任务的 创建、更新、暂停、删除等核心逻辑。 createScheduleJob createScheduleJob 用于将任务注册到 Quartz&#xff0c;先构建任务的 JobD…...

ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放

简介 前面两期文章我们介绍了I2S的读取和写入&#xff0c;一个是通过INMP441麦克风模块采集音频&#xff0c;一个是通过PCM5102A模块播放音频&#xff0c;那如果我们将两者结合起来&#xff0c;将麦克风采集到的音频通过PCM5102A播放&#xff0c;是不是就可以做一个扩音器了呢…...