深度学习:使用卷积神经网络CNN实现MNIST手写数字识别
引言
本项目基于pytorch构建了一个深度学习神经网络,网络包含卷积层、池化层、全连接层,通过此网络实现对MINST数据集手写数字的识别,通过本项目代码,从原理上理解手写数字识别的全过程,包括反向传播,梯度下降等。
1 卷积神经网络介绍
1.1 什么是卷积神经网络
卷积神经网络是一种多层、前馈型神经网络。从功能上来说,可以分为两个阶段,特征提取阶段和分类识别阶段。
特征提取阶段能够自动提取输入数据中的特征作为分类的依据,它由多个特征层堆叠而成,每个特征层又由卷积层和池化层组成。处在前面的特征层捕获图像中局部细节的信息,而后面的特征层能够捕获到图像中更加高层、抽象的信息。
1.1.1 卷积核(Convolution Kernel)
在卷积神经网络的卷积层中,一个神经元只与部分邻层神经元连接。在CNN的一个卷积层中,通常包含若干个特征图(featureMap),每个特征图由一些矩形排列的的神经元组成,同一特征图的神经元共享权值,这里共享的权值就是卷积核。卷积核一般以随机小数矩阵的形式初始化,在网络的训练过程中卷积核将学习得到合理的权值。共享权值(卷积核)带来的直接好处是减少网络各层之间的连接,同时又降低了过拟合的风险。
1.1.2 感受野(Receptive Field)
定义:在卷积神经网络中,卷积神经网络每一层输出的特征图(feature map)上的像素点在输入图片上映射的区域大小。在典型CNN结构中,FC层每个输出节点的值都依赖FC层所有输入,而CONV层每个输出节点的值仅依赖CONV层输入的一个区域, 这个区域之外的其他输入值都不会影响输出值,该区域就是感受野。下图为感受野示意图:
当我们采用尺寸不同的卷积核时,最大的区别就是感受野的大小不同,所以经常会采用多层小卷积核来替换一层大卷积核,在保持感受野相同的情况下减少参数量和计算量。例如十分常见的用两层3*3卷积核来替换一层5*5卷积核的方法
1.3 标准化(Batch Normalization)
在引入BN之前,以前的model training有一些系统性的问题,导致很多算法收敛速度都非常慢,甚至根本就不能工作,尤其在使用sigmoid激活函数时。在机器学习中我们通常会对输入特征进行标准化或归一化,因为直接输入的数据每个维度量纲可能不同、数值差别很大,导致模型不能很好地从各个特征中学习。当上一层输出值太大或太小,其经过sigmoid激活函数时会落在饱和区域,反向传播会有梯度消失的问题。
批标准化(Batch Normalization):对一小批数据(batch),做标准化处理。使数据符合0均值,1为标准差的分布。
Batch Normalization层通常添加在每个神经网络层和激活层之间,对神经网络层输出的数据分布进行统一和调整,变成均值为0方差为1的标准正态分布,解决神经网络中梯度消失的问题使输出位于激活层的非饱和区,达到加快收敛的效果。
1.1.4 池化层(Pooling)
池化 (Pooling) 用来降低神经网络中的特征图(Feature Map)的维度。在卷积神经网络中,池化操作通常紧跟在卷积操作之后,用于降低特征图的空间大小。池化操作的基本思想是将特征图划分为若干个子区域(一般为矩形),并对每个子区域进行统计汇总。池化通常有均值子池化(mean pooling)和最大值池化(max pooling)两种形式。池化可以看作一种特殊的卷积过程。卷积和池化大大简化了模型复杂度,减少了模型的参数。
- 最大值池化可提取图片纹理
- 均值池化可保留背景特征
1.2 卷积的计算过程
假设我们输入的是5*5*1的图像,中间的那个3*3*1是我们定义的一个卷积核(简单来说可以看做一个矩阵形式运算器),通过原始输入图像和卷积核做运算可以得到绿色部分的结果,怎么样的运算呢?实际很简单就是我们看左图中深色部分,处于中间的数字是图像的像素,处于右下角的数字是我们卷积核的数字,只要对应相乘再相加就可以得到结果。例如图中‘3*0+1*1+2*2+2*2+0*2+0*0+2*0+0*1+0*2=9’
计算过程如下动图:
图中最左边的三个输入矩阵就是我们的相当于输入d=3时有三个通道图,每个通道图都有一个属于自己通道的卷积核,我们可以看到输出(output)的只有两个特征图意味着我们设置的输出d=2,有几个输出通道就有几层卷积核(比如图中就有FilterW0和FilterW1),这意味着我们的卷积核数量就是输入d的个数乘以输出d的个数(图中就是2*3=6个),其中每一层通道图的计算与上文中提到的一层计算相同,再把每一个通道输出的输出再加起来就是绿色的输出数字。
步长:每次卷积核移动的大小
输出特征尺寸计算:在了解神经网络中卷积计算的整个过程后,就可以对输出特征图的尺寸进行计算。如下图所示,5×5的图像经过3×3大小的卷积核做卷积计算后输出特征尺寸为3×3
全零填充
当卷积核尺寸大于 1 时,输出特征图的尺寸会小于输入图片尺寸。如果经过多次卷积,输出图片尺寸会不断减小。为了避免卷积之后图片尺寸变小,通常会在图片的外围进行填充(padding),如下图所示
全零填充(padding):为了保持输出图像尺寸与输入图像一致,经常会在输入图像周围进行全零填充,如下所示,在5×5的输入图像周围填0,则输出特征尺寸同为5×5。
当padding=1和paadding=2时,如下图所示:
2 使用CNN实现MNIST手写数字识别
机器识图的过程:机器识别图像并不是一下子将一个复杂的图片完整识别出来,而是将一个完整的图片分割成许多个小部分,把每个小部分里具有的特征提取出来,再将这些小部分具有的特征汇总到一起,从而完成机器识别整个图像。
2.1 MNIST数据介绍
MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集。其中的图像的尺寸为28*28。采样数据显示如下:
2.2 基于pytorch的代码实现
import torch
import torch.nn as nn
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils
import matplotlib.pyplot as plt
import numpy as np#获取数据集
train_data=dataset.MNIST(root="./data",train=True,transform=transforms.ToTensor(),download=True)
test_data=dataset.MNIST(root="./data",train=False,transform=transforms.ToTensor(),download=False)
train_loader=data_utils.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader=data_utils.DataLoader(dataset=test_data, batch_size=64, shuffle=True)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#创建网络
class Net(torch.nn.Module):def __init__(self):super().__init__()self.conv=nn.Conv2d(1, 32, kernel_size=5, padding=2)self.bat2d=nn.BatchNorm2d(32)self.relu=nn.ReLU()self.pool=nn.MaxPool2d(2)self.linear=nn.Linear(14 * 14 * 32, 70)self.tanh=nn.Tanh()self.linear1=nn.Linear(70,30)self.linear2=nn.Linear(30, 10)def forward(self,x):y=self.conv(x)y=self.bat2d(y)y=self.relu(y)y=self.pool(y)y=y.view(y.size()[0],-1)y=self.linear(y)y=self.tanh(y)y=self.linear1(y)y=self.tanh(y)y=self.linear2(y)return y
cnn=Net()
cnn = cnn.to(device)#损失函数
los=torch.nn.CrossEntropyLoss()#优化函数
optime=torch.optim.Adam(cnn.parameters(), lr=0.001)#训练模型
accuracy_rate = [0]
num_epochs = 10
for epo in range(num_epochs):for i, (images,lab) in enumerate(train_loader):images=images.to(device)lab=lab.to(device)out = cnn(images)loss=los(out,lab)optime.zero_grad()loss.backward()optime.step()print("epo:{},i:{},loss:{}".format(epo+1,i,loss))#测试模型loss_test=0accuracy=0with torch.no_grad():for j, (images_test,lab_test) in enumerate(test_loader):images_test = images_test.to(device)lab_test=lab_test.to(device)out1 = cnn(images_test)loss_test+=los(out1,lab_test)loss_test=loss_test/(len(test_data)//100)_,p=out1.max(1)accuracy += (p==lab_test).sum().item()accuracy=accuracy/len(test_data)accuracy_rate.append(accuracy)print("loss_test:{},accuracy:{}".format(loss_test,accuracy))accuracy_rate = np.array(accuracy_rate)
times = np.linspace(0, num_epochs, num_epochs+1)
plt.xlabel('times')
plt.ylabel('accuracy rate')
plt.plot(times, accuracy_rate)
plt.show()
运行结果:
epo:1,i:937,loss:0.2277517020702362
loss_test:0.0017883364344015718,accuracy:0.9729
epo:2,i:937,loss:0.01490325853228569
loss_test:9.064914047485217e-05,accuracy:0.9773
epo:3,i:937,loss:0.0903361514210701
loss_test:0.0003304268466308713,accuracy:0.9791
epo:4,i:937,loss:0.003910894505679607
loss_test:0.00019427068764343858,accuracy:0.9845
epo:5,i:937,loss:0.011963552795350552
loss_test:3.232352901250124e-05,accuracy:0.983
epo:6,i:937,loss:0.04549657553434372
loss_test:0.0001462855434510857,accuracy:0.9859
epo:7,i:937,loss:0.02365218661725521
loss_test:3.670657861221116e-06,accuracy:0.9867
epo:8,i:937,loss:0.00040980291669256985
loss_test:1.4913265658833552e-05,accuracy:0.9872
epo:9,i:937,loss:0.024399513378739357
loss_test:7.590289897052571e-05,accuracy:0.9865
epo:10,i:937,loss:0.0012365489965304732
loss_test:0.00014759502664674073,accuracy:0.9869
3 总结
本文介绍了卷积神经网络中的关键概念,包含卷积核、池化、标准化、感受野等,并基于MNIST数据集,构建了卷积神经网络识别模型,经过10个epochs训练,正确率达到了98%,充分展示了卷积神经网络在图片识别中的作用。
相关文章:

深度学习:使用卷积神经网络CNN实现MNIST手写数字识别
引言 本项目基于pytorch构建了一个深度学习神经网络,网络包含卷积层、池化层、全连接层,通过此网络实现对MINST数据集手写数字的识别,通过本项目代码,从原理上理解手写数字识别的全过程,包括反向传播,梯度…...

docker search 镜像报错: connect: no route to host (桥接模式配置静态IP)
如下 原因 可能有多种: ① 没有开放防火墙端口 ② ip地址配置有误 解决 我是因为虚拟机采用了桥接模式,配置静态ip地址有问题。 先确认虚拟机采用的是 桥接模式,然后启动虚拟机。 1、打开命令行,输入下面指令,打开…...
【VUE】[Violation] Added non-passive event listener to a scroll-blocking...
环境 chrome: 115.0.5790.170vue: ^3.3.4element-plus: ^2.3.4vite: ^4.4.7 问题 [Violation] Added non-passive event listener to a scroll-blocking <某些> 事件. Consider marking event handler as passive to make the page more responsive. See <URL> …...
runit-docker中管理多个服务
runit-docker中管理多个服务 介绍Runit, systemctl和supervisor是三种不同的服务管理工具区别runit优点程序构成快速开始runit实现服务退出执行指定操作runit监管服务打印日志到syslogrunit监管服务后台运行runit监管服务一些错误总结 介绍 runit 是一个轻量级的、稳定的、跨平…...

Intune 应用程序管理
由于云服务提供了增强的安全性、稳定性和灵活性,越来越多的组织正在采用基于云的解决方案来满足他们的需求。这正是提出Microsoft Endpoint Manager等解决方案的原因,它结合了SCCM和Microsoft Intune,以满足本地和基于云的端点管理。 与 Int…...

Oracle DB 安全性 : TDE HSM TCPS Wallet Imperva
• 配置口令文件以使用区分大小写的口令 • 对表空间进行加密 • 配置对网络服务的细粒度访问 TCPS 安全口令支持 Oracle Database 11g中的口令: • 区分大小写 • 包含更多的字符 • 使用更安全的散列算法 • 在散列算法中使用salt 用户名仍是Oracle 标识…...

leetcode27—移除元素
思路: 参考26题目双指针的思想,只不过这道题不是快慢指针。 看到示例里面数组是无序的,也就是说后面的元素也是可能跟给定 val值相等的,那么怎么处理呢。就想到了从前往后遍历,如果left对应的元素 val时,…...
flask---》更多查询方式/连表查询/原生sql(django-orm如何执行原生sql)/flask-sqlalchemy
更多查询方式 #1 查询: filer:写条件 filter_by:等于的值 # 查询所有 是list对象 res session.query(User).all() # 是个普通列表 print(type(res)) print(len(res))# 2 只查询某几个字段 # select name as xx,email from user; res session.…...
Chromium内核浏览器编译记(三)116版本内核UI定制
转载请注明出处:https://blog.csdn.net/kong_gu_you_lan/article/details/132180843?spm1001.2014.3001.5501 本文出自 容华谢后的博客 往期回顾: Chromium内核浏览器编译记(一)踩坑实录 Chromium内核浏览器编译记(…...

LoRaWan网关设计架构介绍
LoRa 数据包转发器是在基于 LoRa 的网关(带或不带 GPS)主机上运行的程序。它将集中器(上行链路)接收到的 RF 数据包通过安全的 IP 链路转发到LoRaWAN 网络服务器( LNS )。它还通过相同的安全 IP 将 LNS(下行链路)发送的 RF 数据包传输到一台或多台设备。此外,它还可以传…...
vue 全局状态管理(简单的store模式、使用Pinia)
目录 为什么使用状态管理简单的store模式服务器渲染(SSR) pinia简介示例1. 定义一个index.ts文件2. 在main.ts中引入3. 定义4. 使用 为什么使用状态管理 多个组件可能会依赖同一个状态时,我们有必要抽取出组件内的共同状态集中统一管理&…...

ORACLE和MYSQL区别
1,Oracle没有offet,limit,在mysql中我们用它们来控制显示的行数,最多的是分页了。oracle要分页的话,要换成rownum。 2,oracle建表时,没有auto_increment,所有要想让表的一个字段自增,…...

tensorflow 1.14 的 demo 02 —— tensorboard 远程访问
tensorflow 1.14.0, 提供远程访问 tensorboard 服务的方法 第一步生成 events 文件: 在上一篇demo的基础上加了一句,如下, tf.summary.FileWriter("./tmp/summary", graphsess1.graph) hello_tensorboard_remote.py …...
Spring中Bean的循环依赖问题
1.什么是Bean的循环依赖? 简单来说就是在A类中,初始化A时需要用到B对象,而在B类中,初始化B时需要用到A对象,这种状况下在Spring中,如果A和B同时初始化,A,B同时都需要对方的资源&…...

若依管理系统后端将 Mybatis 升级为 Mybatis-Plus
文章目录 说明流程增加依赖修改配置文件注释掉MybatisConfig里面的Bean 代码生成使用IDEA生成代码注意 Controller文件 说明 若依管理系统是一个非常完善的管理系统模板,里面含有代码生成的方法,可以帮助用户快速进行开发,但是项目使用的是m…...

剪切、复制、粘贴事件
剪切、复制、粘贴事件 oncopy 事件在用户拷贝元素上的内容时触发。onbeforecut 事件在用户剪切文本,且文本还未删除时触发触发。oncut 事件在用户剪切元素的内容时触发。onbeforepaste 事件在用户向元素中粘贴文本之前触发。onpaste 事件在用户向元素中粘贴文本时触…...

Redis储存结构
Redis怎么储存的 这个redisDb是数据库对象 里面的其他字段忽略了 然后里面有个dict列表(字典列表) 我们随便来看一个redisObject 区分一下子啊 他这个dict里面没有存redisObject的对象 也没有存dict对象 它只是存了个数据指针 你看那个redis每个底层编码 抠搜的 这块要是再保存…...

使用logback异步打印日志
文章目录 一、介绍二、运行环境三、演示项目1. 接口2. 日志配置文件3. 效果演示4. 异步输出验证 四、异步输出原理五、其他参数配置六、源码分析1. 同步输出2. 异步输出 七、总结 一、介绍 对于每一个开发人员来说,在业务代码中添加日志是至关重要的,尤…...

ArcGIS Pro暨基础入门、制图、空间分析、影像分析、三维建模、空间统计分析与建模、python融合、案例应用
GIS是利用电子计算机及其外部设备,采集、存储、分析和描述整个或部分地球表面与空间信息系统。简单地讲,它是在一定的地域内,将地理空间信息和 一些与该地域地理信息相关的属性信息结合起来,达到对地理和属性信息的综合管理。GIS的…...

Rabbitmq的消息确认
配置文件 spring:rabbitmq:publisher-confirm-type: correlated #开启确认回调publisher-returns: true #开启返回回调listener:simple:acknowledge-mode: manual #设置手动接受消息消息从生产者到交换机 无论消息是否到交换机ConfirmCallback都会触发。 Resourceprivate Rabb…...

C++实现分布式网络通信框架RPC(3)--rpc调用端
目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中,我们已经大致实现了rpc服务端的各项功能代…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘
美国西海岸的夏天,再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至,这不仅是开发者的盛宴,更是全球数亿苹果用户翘首以盼的科技春晚。今年,苹果依旧为我们带来了全家桶式的系统更新,包括 iOS 26、iPadOS 26…...
多场景 OkHttpClient 管理器 - Android 网络通信解决方案
下面是一个完整的 Android 实现,展示如何创建和管理多个 OkHttpClient 实例,分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...
2024年赣州旅游投资集团社会招聘笔试真
2024年赣州旅游投资集团社会招聘笔试真 题 ( 满 分 1 0 0 分 时 间 1 2 0 分 钟 ) 一、单选题(每题只有一个正确答案,答错、不答或多答均不得分) 1.纪要的特点不包括()。 A.概括重点 B.指导传达 C. 客观纪实 D.有言必录 【答案】: D 2.1864年,()预言了电磁波的存在,并指出…...

如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...

PL0语法,分析器实现!
简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...
数据库分批入库
今天在工作中,遇到一个问题,就是分批查询的时候,由于批次过大导致出现了一些问题,一下是问题描述和解决方案: 示例: // 假设已有数据列表 dataList 和 PreparedStatement pstmt int batchSize 1000; // …...
音视频——I2S 协议详解
I2S 协议详解 I2S (Inter-IC Sound) 协议是一种串行总线协议,专门用于在数字音频设备之间传输数字音频数据。它由飞利浦(Philips)公司开发,以其简单、高效和广泛的兼容性而闻名。 1. 信号线 I2S 协议通常使用三根或四根信号线&a…...

【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看
文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...
WebRTC从入门到实践 - 零基础教程
WebRTC从入门到实践 - 零基础教程 目录 WebRTC简介 基础概念 工作原理 开发环境搭建 基础实践 三个实战案例 常见问题解答 1. WebRTC简介 1.1 什么是WebRTC? WebRTC(Web Real-Time Communication)是一个支持网页浏览器进行实时语音…...