深度学习:使用卷积神经网络CNN实现MNIST手写数字识别
引言
本项目基于pytorch构建了一个深度学习神经网络,网络包含卷积层、池化层、全连接层,通过此网络实现对MINST数据集手写数字的识别,通过本项目代码,从原理上理解手写数字识别的全过程,包括反向传播,梯度下降等。
1 卷积神经网络介绍
1.1 什么是卷积神经网络
卷积神经网络是一种多层、前馈型神经网络。从功能上来说,可以分为两个阶段,特征提取阶段和分类识别阶段。

特征提取阶段能够自动提取输入数据中的特征作为分类的依据,它由多个特征层堆叠而成,每个特征层又由卷积层和池化层组成。处在前面的特征层捕获图像中局部细节的信息,而后面的特征层能够捕获到图像中更加高层、抽象的信息。
1.1.1 卷积核(Convolution Kernel)
在卷积神经网络的卷积层中,一个神经元只与部分邻层神经元连接。在CNN的一个卷积层中,通常包含若干个特征图(featureMap),每个特征图由一些矩形排列的的神经元组成,同一特征图的神经元共享权值,这里共享的权值就是卷积核。卷积核一般以随机小数矩阵的形式初始化,在网络的训练过程中卷积核将学习得到合理的权值。共享权值(卷积核)带来的直接好处是减少网络各层之间的连接,同时又降低了过拟合的风险。
1.1.2 感受野(Receptive Field)
定义:在卷积神经网络中,卷积神经网络每一层输出的特征图(feature map)上的像素点在输入图片上映射的区域大小。在典型CNN结构中,FC层每个输出节点的值都依赖FC层所有输入,而CONV层每个输出节点的值仅依赖CONV层输入的一个区域, 这个区域之外的其他输入值都不会影响输出值,该区域就是感受野。下图为感受野示意图:


当我们采用尺寸不同的卷积核时,最大的区别就是感受野的大小不同,所以经常会采用多层小卷积核来替换一层大卷积核,在保持感受野相同的情况下减少参数量和计算量。例如十分常见的用两层3*3卷积核来替换一层5*5卷积核的方法

1.3 标准化(Batch Normalization)
在引入BN之前,以前的model training有一些系统性的问题,导致很多算法收敛速度都非常慢,甚至根本就不能工作,尤其在使用sigmoid激活函数时。在机器学习中我们通常会对输入特征进行标准化或归一化,因为直接输入的数据每个维度量纲可能不同、数值差别很大,导致模型不能很好地从各个特征中学习。当上一层输出值太大或太小,其经过sigmoid激活函数时会落在饱和区域,反向传播会有梯度消失的问题。
批标准化(Batch Normalization):对一小批数据(batch),做标准化处理。使数据符合0均值,1为标准差的分布。

Batch Normalization层通常添加在每个神经网络层和激活层之间,对神经网络层输出的数据分布进行统一和调整,变成均值为0方差为1的标准正态分布,解决神经网络中梯度消失的问题使输出位于激活层的非饱和区,达到加快收敛的效果。

1.1.4 池化层(Pooling)
池化 (Pooling) 用来降低神经网络中的特征图(Feature Map)的维度。在卷积神经网络中,池化操作通常紧跟在卷积操作之后,用于降低特征图的空间大小。池化操作的基本思想是将特征图划分为若干个子区域(一般为矩形),并对每个子区域进行统计汇总。池化通常有均值子池化(mean pooling)和最大值池化(max pooling)两种形式。池化可以看作一种特殊的卷积过程。卷积和池化大大简化了模型复杂度,减少了模型的参数。
- 最大值池化可提取图片纹理
- 均值池化可保留背景特征

1.2 卷积的计算过程
假设我们输入的是5*5*1的图像,中间的那个3*3*1是我们定义的一个卷积核(简单来说可以看做一个矩阵形式运算器),通过原始输入图像和卷积核做运算可以得到绿色部分的结果,怎么样的运算呢?实际很简单就是我们看左图中深色部分,处于中间的数字是图像的像素,处于右下角的数字是我们卷积核的数字,只要对应相乘再相加就可以得到结果。例如图中‘3*0+1*1+2*2+2*2+0*2+0*0+2*0+0*1+0*2=9’
计算过程如下动图:

图中最左边的三个输入矩阵就是我们的相当于输入d=3时有三个通道图,每个通道图都有一个属于自己通道的卷积核,我们可以看到输出(output)的只有两个特征图意味着我们设置的输出d=2,有几个输出通道就有几层卷积核(比如图中就有FilterW0和FilterW1),这意味着我们的卷积核数量就是输入d的个数乘以输出d的个数(图中就是2*3=6个),其中每一层通道图的计算与上文中提到的一层计算相同,再把每一个通道输出的输出再加起来就是绿色的输出数字。
步长:每次卷积核移动的大小
输出特征尺寸计算:在了解神经网络中卷积计算的整个过程后,就可以对输出特征图的尺寸进行计算。如下图所示,5×5的图像经过3×3大小的卷积核做卷积计算后输出特征尺寸为3×3

全零填充
当卷积核尺寸大于 1 时,输出特征图的尺寸会小于输入图片尺寸。如果经过多次卷积,输出图片尺寸会不断减小。为了避免卷积之后图片尺寸变小,通常会在图片的外围进行填充(padding),如下图所示
全零填充(padding):为了保持输出图像尺寸与输入图像一致,经常会在输入图像周围进行全零填充,如下所示,在5×5的输入图像周围填0,则输出特征尺寸同为5×5。

当padding=1和paadding=2时,如下图所示:

2 使用CNN实现MNIST手写数字识别
机器识图的过程:机器识别图像并不是一下子将一个复杂的图片完整识别出来,而是将一个完整的图片分割成许多个小部分,把每个小部分里具有的特征提取出来,再将这些小部分具有的特征汇总到一起,从而完成机器识别整个图像。
2.1 MNIST数据介绍
MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集。其中的图像的尺寸为28*28。采样数据显示如下:

2.2 基于pytorch的代码实现
import torch
import torch.nn as nn
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils
import matplotlib.pyplot as plt
import numpy as np#获取数据集
train_data=dataset.MNIST(root="./data",train=True,transform=transforms.ToTensor(),download=True)
test_data=dataset.MNIST(root="./data",train=False,transform=transforms.ToTensor(),download=False)
train_loader=data_utils.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader=data_utils.DataLoader(dataset=test_data, batch_size=64, shuffle=True)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#创建网络
class Net(torch.nn.Module):def __init__(self):super().__init__()self.conv=nn.Conv2d(1, 32, kernel_size=5, padding=2)self.bat2d=nn.BatchNorm2d(32)self.relu=nn.ReLU()self.pool=nn.MaxPool2d(2)self.linear=nn.Linear(14 * 14 * 32, 70)self.tanh=nn.Tanh()self.linear1=nn.Linear(70,30)self.linear2=nn.Linear(30, 10)def forward(self,x):y=self.conv(x)y=self.bat2d(y)y=self.relu(y)y=self.pool(y)y=y.view(y.size()[0],-1)y=self.linear(y)y=self.tanh(y)y=self.linear1(y)y=self.tanh(y)y=self.linear2(y)return y
cnn=Net()
cnn = cnn.to(device)#损失函数
los=torch.nn.CrossEntropyLoss()#优化函数
optime=torch.optim.Adam(cnn.parameters(), lr=0.001)#训练模型
accuracy_rate = [0]
num_epochs = 10
for epo in range(num_epochs):for i, (images,lab) in enumerate(train_loader):images=images.to(device)lab=lab.to(device)out = cnn(images)loss=los(out,lab)optime.zero_grad()loss.backward()optime.step()print("epo:{},i:{},loss:{}".format(epo+1,i,loss))#测试模型loss_test=0accuracy=0with torch.no_grad():for j, (images_test,lab_test) in enumerate(test_loader):images_test = images_test.to(device)lab_test=lab_test.to(device)out1 = cnn(images_test)loss_test+=los(out1,lab_test)loss_test=loss_test/(len(test_data)//100)_,p=out1.max(1)accuracy += (p==lab_test).sum().item()accuracy=accuracy/len(test_data)accuracy_rate.append(accuracy)print("loss_test:{},accuracy:{}".format(loss_test,accuracy))accuracy_rate = np.array(accuracy_rate)
times = np.linspace(0, num_epochs, num_epochs+1)
plt.xlabel('times')
plt.ylabel('accuracy rate')
plt.plot(times, accuracy_rate)
plt.show()
运行结果:
epo:1,i:937,loss:0.2277517020702362
loss_test:0.0017883364344015718,accuracy:0.9729
epo:2,i:937,loss:0.01490325853228569
loss_test:9.064914047485217e-05,accuracy:0.9773
epo:3,i:937,loss:0.0903361514210701
loss_test:0.0003304268466308713,accuracy:0.9791
epo:4,i:937,loss:0.003910894505679607
loss_test:0.00019427068764343858,accuracy:0.9845
epo:5,i:937,loss:0.011963552795350552
loss_test:3.232352901250124e-05,accuracy:0.983
epo:6,i:937,loss:0.04549657553434372
loss_test:0.0001462855434510857,accuracy:0.9859
epo:7,i:937,loss:0.02365218661725521
loss_test:3.670657861221116e-06,accuracy:0.9867
epo:8,i:937,loss:0.00040980291669256985
loss_test:1.4913265658833552e-05,accuracy:0.9872
epo:9,i:937,loss:0.024399513378739357
loss_test:7.590289897052571e-05,accuracy:0.9865
epo:10,i:937,loss:0.0012365489965304732
loss_test:0.00014759502664674073,accuracy:0.9869

3 总结
本文介绍了卷积神经网络中的关键概念,包含卷积核、池化、标准化、感受野等,并基于MNIST数据集,构建了卷积神经网络识别模型,经过10个epochs训练,正确率达到了98%,充分展示了卷积神经网络在图片识别中的作用。
相关文章:
深度学习:使用卷积神经网络CNN实现MNIST手写数字识别
引言 本项目基于pytorch构建了一个深度学习神经网络,网络包含卷积层、池化层、全连接层,通过此网络实现对MINST数据集手写数字的识别,通过本项目代码,从原理上理解手写数字识别的全过程,包括反向传播,梯度…...
docker search 镜像报错: connect: no route to host (桥接模式配置静态IP)
如下 原因 可能有多种: ① 没有开放防火墙端口 ② ip地址配置有误 解决 我是因为虚拟机采用了桥接模式,配置静态ip地址有问题。 先确认虚拟机采用的是 桥接模式,然后启动虚拟机。 1、打开命令行,输入下面指令,打开…...
【VUE】[Violation] Added non-passive event listener to a scroll-blocking...
环境 chrome: 115.0.5790.170vue: ^3.3.4element-plus: ^2.3.4vite: ^4.4.7 问题 [Violation] Added non-passive event listener to a scroll-blocking <某些> 事件. Consider marking event handler as passive to make the page more responsive. See <URL> …...
runit-docker中管理多个服务
runit-docker中管理多个服务 介绍Runit, systemctl和supervisor是三种不同的服务管理工具区别runit优点程序构成快速开始runit实现服务退出执行指定操作runit监管服务打印日志到syslogrunit监管服务后台运行runit监管服务一些错误总结 介绍 runit 是一个轻量级的、稳定的、跨平…...
Intune 应用程序管理
由于云服务提供了增强的安全性、稳定性和灵活性,越来越多的组织正在采用基于云的解决方案来满足他们的需求。这正是提出Microsoft Endpoint Manager等解决方案的原因,它结合了SCCM和Microsoft Intune,以满足本地和基于云的端点管理。 与 Int…...
Oracle DB 安全性 : TDE HSM TCPS Wallet Imperva
• 配置口令文件以使用区分大小写的口令 • 对表空间进行加密 • 配置对网络服务的细粒度访问 TCPS 安全口令支持 Oracle Database 11g中的口令: • 区分大小写 • 包含更多的字符 • 使用更安全的散列算法 • 在散列算法中使用salt 用户名仍是Oracle 标识…...
leetcode27—移除元素
思路: 参考26题目双指针的思想,只不过这道题不是快慢指针。 看到示例里面数组是无序的,也就是说后面的元素也是可能跟给定 val值相等的,那么怎么处理呢。就想到了从前往后遍历,如果left对应的元素 val时,…...
flask---》更多查询方式/连表查询/原生sql(django-orm如何执行原生sql)/flask-sqlalchemy
更多查询方式 #1 查询: filer:写条件 filter_by:等于的值 # 查询所有 是list对象 res session.query(User).all() # 是个普通列表 print(type(res)) print(len(res))# 2 只查询某几个字段 # select name as xx,email from user; res session.…...
Chromium内核浏览器编译记(三)116版本内核UI定制
转载请注明出处:https://blog.csdn.net/kong_gu_you_lan/article/details/132180843?spm1001.2014.3001.5501 本文出自 容华谢后的博客 往期回顾: Chromium内核浏览器编译记(一)踩坑实录 Chromium内核浏览器编译记(…...
LoRaWan网关设计架构介绍
LoRa 数据包转发器是在基于 LoRa 的网关(带或不带 GPS)主机上运行的程序。它将集中器(上行链路)接收到的 RF 数据包通过安全的 IP 链路转发到LoRaWAN 网络服务器( LNS )。它还通过相同的安全 IP 将 LNS(下行链路)发送的 RF 数据包传输到一台或多台设备。此外,它还可以传…...
vue 全局状态管理(简单的store模式、使用Pinia)
目录 为什么使用状态管理简单的store模式服务器渲染(SSR) pinia简介示例1. 定义一个index.ts文件2. 在main.ts中引入3. 定义4. 使用 为什么使用状态管理 多个组件可能会依赖同一个状态时,我们有必要抽取出组件内的共同状态集中统一管理&…...
ORACLE和MYSQL区别
1,Oracle没有offet,limit,在mysql中我们用它们来控制显示的行数,最多的是分页了。oracle要分页的话,要换成rownum。 2,oracle建表时,没有auto_increment,所有要想让表的一个字段自增,…...
tensorflow 1.14 的 demo 02 —— tensorboard 远程访问
tensorflow 1.14.0, 提供远程访问 tensorboard 服务的方法 第一步生成 events 文件: 在上一篇demo的基础上加了一句,如下, tf.summary.FileWriter("./tmp/summary", graphsess1.graph) hello_tensorboard_remote.py …...
Spring中Bean的循环依赖问题
1.什么是Bean的循环依赖? 简单来说就是在A类中,初始化A时需要用到B对象,而在B类中,初始化B时需要用到A对象,这种状况下在Spring中,如果A和B同时初始化,A,B同时都需要对方的资源&…...
若依管理系统后端将 Mybatis 升级为 Mybatis-Plus
文章目录 说明流程增加依赖修改配置文件注释掉MybatisConfig里面的Bean 代码生成使用IDEA生成代码注意 Controller文件 说明 若依管理系统是一个非常完善的管理系统模板,里面含有代码生成的方法,可以帮助用户快速进行开发,但是项目使用的是m…...
剪切、复制、粘贴事件
剪切、复制、粘贴事件 oncopy 事件在用户拷贝元素上的内容时触发。onbeforecut 事件在用户剪切文本,且文本还未删除时触发触发。oncut 事件在用户剪切元素的内容时触发。onbeforepaste 事件在用户向元素中粘贴文本之前触发。onpaste 事件在用户向元素中粘贴文本时触…...
Redis储存结构
Redis怎么储存的 这个redisDb是数据库对象 里面的其他字段忽略了 然后里面有个dict列表(字典列表) 我们随便来看一个redisObject 区分一下子啊 他这个dict里面没有存redisObject的对象 也没有存dict对象 它只是存了个数据指针 你看那个redis每个底层编码 抠搜的 这块要是再保存…...
使用logback异步打印日志
文章目录 一、介绍二、运行环境三、演示项目1. 接口2. 日志配置文件3. 效果演示4. 异步输出验证 四、异步输出原理五、其他参数配置六、源码分析1. 同步输出2. 异步输出 七、总结 一、介绍 对于每一个开发人员来说,在业务代码中添加日志是至关重要的,尤…...
ArcGIS Pro暨基础入门、制图、空间分析、影像分析、三维建模、空间统计分析与建模、python融合、案例应用
GIS是利用电子计算机及其外部设备,采集、存储、分析和描述整个或部分地球表面与空间信息系统。简单地讲,它是在一定的地域内,将地理空间信息和 一些与该地域地理信息相关的属性信息结合起来,达到对地理和属性信息的综合管理。GIS的…...
Rabbitmq的消息确认
配置文件 spring:rabbitmq:publisher-confirm-type: correlated #开启确认回调publisher-returns: true #开启返回回调listener:simple:acknowledge-mode: manual #设置手动接受消息消息从生产者到交换机 无论消息是否到交换机ConfirmCallback都会触发。 Resourceprivate Rabb…...
解锁数据库简洁之道:FastAPI与SQLModel实战指南
在构建现代Web应用程序时,与数据库的交互无疑是核心环节。虽然传统的数据库操作方式(如直接编写SQL语句与psycopg2交互)赋予了我们精细的控制权,但在面对日益复杂的业务逻辑和快速迭代的需求时,这种方式的开发效率和可…...
为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?
在建筑行业,项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升,传统的管理模式已经难以满足现代工程的需求。过去,许多企业依赖手工记录、口头沟通和分散的信息管理,导致效率低下、成本失控、风险频发。例如&#…...
华为云Flexus+DeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建
华为云FlexusDeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建 前言 如今大模型其性能出色,华为云 ModelArts Studio_MaaS大模型即服务平台华为云内置了大模型,能助力我们轻松驾驭 DeepSeek-V3/R1,本文中将分享如何…...
在Mathematica中实现Newton-Raphson迭代的收敛时间算法(一般三次多项式)
考察一般的三次多项式,以r为参数: p[z_, r_] : z^3 (r - 1) z - r; roots[r_] : z /. Solve[p[z, r] 0, z]; 此多项式的根为: 尽管看起来这个多项式是特殊的,其实一般的三次多项式都是可以通过线性变换化为这个形式…...
Go语言多线程问题
打印零与奇偶数(leetcode 1116) 方法1:使用互斥锁和条件变量 package mainimport ("fmt""sync" )type ZeroEvenOdd struct {n intzeroMutex sync.MutexevenMutex sync.MutexoddMutex sync.Mutexcurrent int…...
【JVM】Java虚拟机(二)——垃圾回收
目录 一、如何判断对象可以回收 (一)引用计数法 (二)可达性分析算法 二、垃圾回收算法 (一)标记清除 (二)标记整理 (三)复制 (四ÿ…...
作为测试我们应该关注redis哪些方面
1、功能测试 数据结构操作:验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化:测试aof和aof持久化机制,确保数据在开启后正确恢复。 事务:检查事务的原子性和回滚机制。 发布订阅:确保消息正确传递。 2、性…...
tomcat入门
1 tomcat 是什么 apache开发的web服务器可以为java web程序提供运行环境tomcat是一款高效,稳定,易于使用的web服务器tomcathttp服务器Servlet服务器 2 tomcat 目录介绍 -bin #存放tomcat的脚本 -conf #存放tomcat的配置文件 ---catalina.policy #to…...
DBLP数据库是什么?
DBLP(Digital Bibliography & Library Project)Computer Science Bibliography是全球著名的计算机科学出版物的开放书目数据库。DBLP所收录的期刊和会议论文质量较高,数据库文献更新速度很快,很好地反映了国际计算机科学学术研…...
华为OD最新机试真题-数组组成的最小数字-OD统一考试(B卷)
题目描述 给定一个整型数组,请从该数组中选择3个元素 组成最小数字并输出 (如果数组长度小于3,则选择数组中所有元素来组成最小数字)。 输入描述 行用半角逗号分割的字符串记录的整型数组,0<数组长度<= 100,0<整数的取值范围<= 10000。 输出描述 由3个元素组成…...
