通过卷积神经网络(CNN)识别和预测手写数字
一:卷积神经网络(CNN)和手写数字识别MNIST数据集的介绍
卷积神经网络(Convolutional Neural Networks,简称CNN)是一种深度学习模型,它在图像和视频识别、分类和分割任务中表现出色。CNN通过模仿人类视觉系统的工作原理来处理数据,能够从图像中自动学习和提取特征。以下是CNN的一些关键特点和组成部分:
卷积层(Convolutional Layer):
卷积层是CNN的核心,它使用滤波器(或称为卷积核)在输入图像上滑动,以提取图像的局部特征。
每个滤波器负责检测图像中的特定特征,如边缘、角点或纹理等。
卷积操作会产生一个特征图(feature map),它表示输入图像在滤波器下的特征响应。
激活函数:
通常在卷积层之后使用非线性激活函数,如ReLU(Rectified Linear Unit),以增加网络的非线性表达能力。
激活函数帮助网络处理复杂的模式,并使网络能够学习更复杂的特征组合。
池化层(Pooling Layer):
池化层用于降低特征图的空间尺寸,减少参数数量和计算量,同时使特征检测更加鲁棒。
最常见的池化操作是最大池化(max pooling)和平均池化(average pooling)。
全连接层(Fully Connected Layer):
在多个卷积和池化层之后,CNN通常包含一个或多个全连接层,这些层将学习到的特征映射到最终的输出类别上。
全连接层中的每个神经元都与前一层的所有激活值相连。
softmax层:
在网络的最后一层,通常使用softmax层将输出转换为概率分布,用于多分类任务中。
softmax函数确保输出层的输出值在0到1之间,并且所有输出值的总和为1。
卷积神经网络的训练:
CNN通过反向传播算法和梯度下降法进行训练,以最小化损失函数(如交叉熵损失)。
在训练过程中,网络的权重通过大量图像数据进行调整,以提高分类或识别的准确性。
数据增强(Data Augmentation):
为了提高CNN的泛化能力,经常使用数据增强技术,如旋转、缩放、裁剪和翻转图像,以创建更多的训练样本。
迁移学习(Transfer Learning):
迁移学习是一种技术,它允许CNN利用在一个大型数据集(如ImageNet)上预训练的网络权重,来提高在小型或特定任务上的性能。
CNN在计算机视觉领域的应用非常广泛,包括但不限于图像分类、目标检测、语义分割、物体跟踪和面部识别等任务。由于其强大的特征提取能力,CNN已成为这些任务的主流方法之一。
MNIST数据集是一个广泛使用的手写数字识别数据集,可以通过TensorFlow库或Pytorch库来获取, 也可以从官方网站下载:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
MNIST数据集它包含四个部分:训练数据集、训练数据集标签、测试数据集和测试数据集标签。这些文件是IDX格式的二进制文件,需要特定的程序来读取。这个数据集包含了60,000张训练集图像和10,000张测试集图像,每张图像都是28x28像素的手写数字,范围从0到9。这些图像被处理为灰度值,其中黑色背景用0表示,手写数字用0到1之间的灰度值表示,数值越接近1,颜色越白。
MNIST数据集的图像通常被拉直为一个一维数组,每个数组包含784个元素(28x28像素)。数据集中的每个图像都有一个对应的标签,标签以one-hot编码的形式给出,例如数字5的标签表示为[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]。
在机器学习模型中,MNIST数据集常用于训练分类器,以识别和预测手写数字。例如,在深度学习中,可以使用卷积神经网络(CNN)来处理这些图像,学习从图像像素到数字标签的映射。
二:通过Pytorch库建立CNN模型训练MNIST数据集
使用Python的Pytorch库来完成一个卷积神经网络(CNN)来训练MNIST数据集,需要遵循以下步骤:
- 导入必要的库:我们需要导入Pytorch以及其它可能需要的库,如torchvision用于数据加载和变换。
- 加载MNIST数据集:使用torchvision库中的datasets和DataLoader来加载和预处理MNIST数据集。
- 定义卷积神经网络结构:设计一个简单的CNN结构,包括卷积层、池化层和全连接层。
- 定义损失函数和优化器:选择一个合适的损失函数,如交叉熵损失,以及一个优化器,如Adam或SGD。
- 训练模型:在训练集上训练模型,并保存训练过程中的损失和准确率。
- 测试模型:在测试集上评估模型的性能。
接下来,我们将按照这些步骤使用Python代码来完成这个任务。
Step1:导入必要的库
# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
: 导入了PyTorch的主库,这是进行深度学习任务的基础。import torch.nn as nn
: 导入了PyTorch的神经网络模块,它包含了构建神经网络所需的许多类和函数。import torch.nn.functional as F
: 导入了PyTorch的功能性API,它提供了不需要维护状态的神经网络操作,例如激活函数、池化等。import torchvision
: 导入了PyTorch的视觉库,它提供了许多视觉任务所需的工具和数据集。import torchvision.transforms as transforms
: 导入了对数据进行预处理的工具。from torch.utils.data import DataLoader
: 导入了PyTorch的数据加载器,它可以方便地迭代数据集。
Step2:加载MNIST数据集
# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
transform = transforms.Compose(...)
: 创建了一个转换管道,用于对数据进行预处理。Compose
是一个函数,它将多个转换步骤组合成一个转换。transforms.ToTensor()
: 将图像数据从PIL Image或NumPy ndarray格式转换为浮点张量,并且将像素值缩放到[0,1]范围内。transforms.Normalize((0.5,), (0.5,))
: 对图像进行归一化处理。给定均值(mean)和标准差(std),这个转换将张量的每个通道都减去均值并除以标准差。在这里,它将每个像素值从[0,1]范围转换为[-1,1]范围。
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
- 这两行代码分别加载了MNIST数据集的训练集和测试集。
root='./data'
: 指定数据集下载和存储的根目录。train=True
: 对于trainset
,表示加载数据集的训练部分。train=False
: 对于testset
,表示加载数据集的测试部分。download=True
: 表示如果数据集不在指定的root
目录下,则从互联网上下载。transform=transform
: 应用之前定义的转换。
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
- 这两行代码创建了两个
DataLoader
对象,用于在训练和测试时迭代数据集。 batch_size=64
: 指定每个批次的样本数量。shuffle=True
: 对于trainloader
,在每次迭代时打乱数据,这对于训练是有益的,因为它可以减少模型学习数据的顺序性。shuffle=False
: 对于testloader
,不打乱数据,因为测试时不需要随机性。
得到了一个名为data的文件夹:
Step3:定义卷积神经网络结构
# 定义卷积神经网络结构
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 1024)self.fc2 = nn.Linear(1024, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 7 * 7)x = F.relu(self.fc1(x))x = self.fc2(x)return x
- 这段代码定义了一个名为
CNN
的卷积神经网络类,它继承自nn.Module
。 __init__
方法初始化了网络的结构:self.conv1
是一个2D卷积层,输入通道为1(MNIST图像为单通道),输出通道为32,卷积核大小为3x3,并带有1像素的填充。self.pool
是一个2x2的最大池化层,用于减小数据的维度。self.conv2
是第二个2D卷积层,输入通道为32,输出通道为64,卷积核大小为3x3,并带有1像素的填充。self.fc1
是一个全连接层,它将64个通道的7x7图像映射到1024个特征。self.fc2
是另一个全连接层,它将1024个特征映射到10个输出,对应于MNIST数据集的10个类别。
forward
方法定义了数据通过网络的前向传播路径:x
首先通过conv1
卷积层,然后应用ReLU激活函数,并使用pool
进行池化。- 接着,
x
通过conv2
卷积层,再次应用ReLU激活函数和池化。 x.view(-1, 64 * 7 * 7)
将数据扁平化,为全连接层准备。x
通过fc1
全连接层,并应用ReLU激活函数。- 最后,
x
通过fc2
全连接层,输出结果。
# 实例化网络
net = CNN()
- 创建了一个
CNN
类的实例,名为net
。
Step4:定义损失函数和优化器
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
criterion
是交叉熵损失函数,常用于多分类问题。optimizer
是Adam优化器,用于更新网络的权重。
Step5:训练模型
# 训练模型
epochs = 5
for epoch in range(epochs):running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/(i+1)}")
下面是这段代码的逐行解释:
epochs
是一个变量,表示训练过程中模型将遍历整个训练数据集的次数。这里设置为5,意味着整个训练数据集将被遍历5次。- 外层for循环,它将执行
epochs
次。在每次迭代中,epoch
变量将代表当前的迭代次数,从0开始到epochs-1
结束。 - 在每次epoch开始时,
running_loss
被重置为0.0。这个变量用于累加每个epoch中的所有批次损失,以便计算平均损失。 - 这是一个嵌套的for循环,它遍历
trainloader
返回的批次数据。enumerate
函数用于遍历可迭代对象,同时跟踪当前的索引(这里是i
)。 trainloader
是之前定义的数据加载器,它负责分批加载数据,以便于训练。- 参数
0
指定了索引的起始值。 - 然后解包了
data
元组,其中包含输入(图像)和标签(目标值)。inputs
是模型的输入数据,labels
是这些输入数据的正确类别标签。 - 在每次迭代开始时,调用
optimizer.zero_grad()
来清除之前梯度计算的结果。这是必要的,因为PyTorch的梯度是累加的。 - 输入
inputs
传递给神经网络net
,并得到输出outputs
。这是模型的前向传播步骤。 - 计算了模型输出的损失。
criterion
是之前定义的交叉熵损失函数,它比较outputs
(模型的预测)和labels
(实际类别标签)来计算损失。 - 执行了反向传播。它计算了损失相对于模型参数的梯度。
- 更新了模型的权重。
optimizer
使用计算出的梯度来调整网络参数,以减少下一次迭代的损失。 - 将当前的批次损失累加到
running_loss
变量中,用于后续计算平均损失。 - 在每个epoch结束时,打印出当前epoch的编号和平均损失。
epoch+1
是为了从1开始计数epoch,而不是从0开始。running_loss/(i+1)
计算了当前epoch的平均损失,其中i+1
是当前epoch中批次的数量。
最终得到每个epoch的平均损失如下:
Step6:测试模型
# 测试模型
correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Accuracy of the network on the 10000 test images: {100 * correct / total}%")
correct
和total
是两个变量,分别用于跟踪模型在测试数据集上正确预测的样本数量和总的样本数量。with torch.no_grad()
是一个上下文管理器,用于在测试阶段禁用梯度计算。因为测试阶段不需要计算梯度,这样可以节省内存并加快计算速度。- for循环,遍历
testloader
返回的测试数据集的批次数据。 - 这行代码解包了
data
元组,其中包含测试图像images
和它们对应的真实标签labels
。 - 这行代码将测试图像
images
输入到训练好的神经网络net
中,并得到输出outputs
。 torch.max(outputs.data, 1)
返回两个值:第一个是每个批次中最大值的元素,第二个是这些最大值的索引。在这里,最大值代表模型对每个图像的预测类别,而索引则代表预测的类别标签。predicted
是模型预测的类别标签的向量。- 这行代码累加测试集中总的样本数量。
labels.size(0)
给出了当前批次中样本的数量。 (predicted == labels)
是一个布尔表达式,它比较模型的预测predicted
和真实标签labels
,并返回一个布尔张量,其中正确预测的位置为True,否则为False。.sum()
计算布尔张量中True的数量,即正确预测的样本数量。.item()
将计算得到的张量(只有一个元素)转换为Python的标量值。- 这行代码计算并打印出模型在测试数据集上的准确率。准确率是通过将正确预测的样本数量
correct
除以总样本数量total
,然后乘以100来得到的百分比。这里假设测试数据集包含10000个样本。
得到准确率如下:
使用这个建立好的卷积神经网络(CNN)模型,主要用于训练分类器。具体来说,这个模型能够识别手写数字图像,并将它们分类为0到9中的一个类别。它适用于MNIST数据集。这个示例能够帮助更好的了解卷积神经网络(CNN)的原理。
想要探索更多元化的数据分析视角,可以关注之前发布的相关内容。
相关文章:

通过卷积神经网络(CNN)识别和预测手写数字
一:卷积神经网络(CNN)和手写数字识别MNIST数据集的介绍 卷积神经网络(Convolutional Neural Networks,简称CNN)是一种深度学习模型,它在图像和视频识别、分类和分割任务中表现出色。CNN通过模仿…...

【A题第二套完整论文已出】2024数模国赛A题第二套完整论文+可运行代码参考(无偿分享)
“板凳龙” 闹元宵路径速度问题 摘要 本文针对传统舞龙进行了轨迹分析,并针对一系列问题提出了解决方案,将这一运动进行了模型可视化。 针对问题一,我们首先对舞龙的螺线轨迹进行了建模,将直角坐标系转换为极坐标系࿰…...

一份热乎的数据分析(数仓)面试题 | 每天一点点,收获不止一点
目录 1. 已有ods层⽤⼾表为ods_online.user_info,有两个字段userid和age,现设计数仓⽤⼾表结构如 下: 2. 设计数据仓库的保单表(⾃⾏命名) 3. 根据上述两表,查询2024年8⽉份,每⽇,…...

3 html5之css新选择器和属性
要说css的变化那是发展比较快的,新增的选择器也很多,而且还有很多都是比较实用的。这里举出一些案例,看看你平时都是否用过。 1 新增的一些写法: 1.1 导入css 这个是非常好的一个变化。这样可以让我们将css拆分成公共部分或者多…...

【Kubernetes】K8s 的鉴权管理(一):基于角色的访问控制(RBAC 鉴权)
K8s 的鉴权管理(一):基于角色的访问控制(RBAC 鉴权) 1.Kubernetes 的鉴权管理1.1 审查客户端请求的属性1.2 确定请求的操作 2.基于角色的访问控制(RBAC 鉴权)2.1 基于角色的访问控制中的概念2.1…...

保研 比赛 利器: 用AI比赛助手降维打击数学建模
数学建模作为一个热门但又具有挑战性的赛道,在保研、学分加分、简历增色等方面具有独特优势。近年来,随着AI技术的发展,特别是像GPT-4模型的应用,数学建模的比赛变得不再那么“艰深”。通过利用AI比赛助手,不仅可以大大…...

秋招校招,在线性格测评应该如何应对
秋招校招,如果遇到在线测评,如何应对? 这里写个总结稿,希望对大家有些帮助。在线测评是企业深入了解求职人的渠道,如果是性格测试,会要求测试者能够快速答出,以便于反应实际情况(时间…...

chrome 插件开发入门
1. 介绍 Chrome 插件可用于在谷歌浏览器上控制当前页面的一些操作,可自主控制网页,提升效率。 平常我们可在谷歌应用商店中下载谷歌插件来增强浏览器功能,作为开发者,我们也可以自己开发一个浏览器插件来配合我们的日常学习工作…...

揭开面纱--机器学习
一、人工智能三大概念 1.1 AI、ML、DL 1.1.1 什么是人工智能? AI:Artificial Intelligence 人工智能 AI is the field that studies the synthesis and analysis of computational agents that act intelligently AI is to use computers to analog and instead…...

Python中的私有属性与方法:解锁面向对象编程的秘密
在Python的广阔世界里,面向对象编程(OOP)是一种强大而灵活的方法论,它帮助我们更好地组织代码、管理状态,并构建可复用的软件组件。而在这个框架内,私有属性与方法则是实现封装的关键机制之一。它们不仅有助…...

开篇_____何谓安卓机型“工程固件” 与其他固件的区别 作用
此系列博文将分析安卓系列机型与一些车机 wifi板子等工程固件的一些常识。从早期安卓1.0起始到目前的安卓15,一些厂家发布新机型的常规流程都是从工程机到量产的过程。在其中就需要调试各种参数以便后续的量产参数可以固定到最佳,工程固件由此诞生。 后…...

DBeaver 连接 MySQL 报错 Public Key Retrieval is not allowed
DBeaver 连接 MySQL 报错 Public Key Retrieval is not allowed 文章目录 DBeaver 连接 MySQL 报错 Public Key Retrieval is not allowed问题解决办法 问题 使用 DBeaver 连接 MySQL 数据库的时候, 一直报错下面的错误 Public Key Retrieval is not allowed详细…...

三个月涨粉两万,只因为知道了这个AI神器
大家好,我是凡人,最近midjourney的账号到期了,正准备充值时,被一个国内AI图片的生成神器给震惊了,不说废话,先上图看看生成效果。 怎么样还不错吧,是我非常喜欢的国风画,哈哈&#x…...

vulhub GhostScript 沙箱绕过(CVE-2018-16509)
1.搭建环境 2.进入网站 3.下载包含payload的png文件 vulhub/ghostscript/CVE-2018-16509/poc.png at master vulhub/vulhub GitHub 4.上传poc.png图片 5.查看创建的文件...

李宏毅机器学习笔记——反向传播算法
反向传播算法 反向传播(Backpropagation)是一种用于训练人工神经网络的算法,它通过计算损失函数相对于网络中每个参数的梯度来更新这些参数,从而最小化损失函数。反向传播是深度学习中最重要的算法之一,通常与梯度下降…...

内推|京东|后端开发|运维|算法...|北京 更多岗位扫内推码了解,直接投递,跟踪进度
热招岗位 更多岗位欢迎扫描末尾二维码,小程序直接提交简历等面试。实时帮你查询面试进程。 安全运营中心研发工程师 岗位要求 1、本科及以上学历,3年以上的安全相关工作经验; 2、熟悉c/c、go编程语言之一、熟悉linux网络编程和系统编程 3、…...

编写Dockerfile第二版
目标 更快的构建速度 更小的Docker镜像大小 更少的Docker镜像层 充分利用镜像缓存 增加Dockerfile可读性 让Docker容器使用起来更简单 总结 编写.dockerignore文件 容器只运行单个应用 将多个RUN指令合并为一个 基础镜像的标签不要用latest 每个RUN指令后删除多余文…...

校验码:奇偶校验,CRC循环冗余校验,海明校验码
文章目录 奇偶校验码CRC循环冗余校验码海明校验码 奇偶校验码 码距:任何一种编码都由许多码字构成,任意两个码字之间最少变化的二进制位数就称为数据检验码的码距。 奇偶校验码的编码方法是:由若干位有效信息(如一个字节),再加上…...

增维思考,减维问题,避免焦虑!
什么是嵌入式软件开发的核心技能? 1. 编程语言 熟练掌握C/C:C语言是嵌入式领域最重要也是最主要的编程语言,用于实现系统功能和性能优化。C在需要面向对象编程的场合也是重要的选择。了解汇编语言:在某些需要直接与硬件交互或优…...

自动化抢票 12306
自动化抢票 12306 1. 明确需求 明确采集的网站以及数据内容 网址: https://kyfw.12306.cn/otn/leftTicket/init数据: 车次相关信息 2. 抓包分析 通过浏览器开发者工具分析对应的数据位置 打开开发者工具 F12 或鼠标右键点击检查 刷新网页 点击下一页/下滑网页页面/点击搜…...

海外云服务器安装 MariaDB10.6.X (Ubuntu 18.04 记录篇二)
本文首发于 秋码记录 MariaDB 的由来(历史) 谈起新秀MariaDB,或许很多人都会感到陌生吧,但若聊起享誉开源界、业界知名的关系型数据库——Mysql,想必混迹于互联网的人们(coder)无不知晓。 其…...

Mybatis_基础
文章目录 第一章 Mybatis简介1.1 Mybatis特性1.2 和其它持久化层技术对比 第二章 Mybatis的增删改查第三章 Mybatis的增删改查 第一章 Mybatis简介 1.1 Mybatis特性 MyBatis 是支持定制化 SQL、存储过程以及高级映射的优秀的持久层框架。MyBatis 避免了几乎所有的 JDBC 代码和…...

8Manage采购申请管理:轻松实现手动采购流程自动化
您是否感受到通过手动采购申请流程管理成本的压力? 信息的不充分常常导致现金流的不透明,这已成为财务高管们的常见痛点。本文将展示采购申请管理软件如何帮助您减轻负担,使您能够简化流程。 没有采购申请软件会面临哪些挑战? …...

PADS Router 入门基础教程(一)
有将近三周没有更新过博客了,最近在整理PADS Router 入门基础教程,希望喜欢本系列教程的小伙伴可以点点关注和订阅!下面我们开始进入PADS Router课程的介绍。 一、PADS Router 快捷键 二、课程介绍 本教程主要介绍:PADS Rou…...

一台手机一个ip地址吗?手机ip地址泄露了怎么办
在数字化时代,手机作为我们日常生活中不可或缺的一部分,其网络安全性也日益受到关注。其中一个常见的疑问便是:“一台手机是否对应一个固定的IP地址?”实际上,情况并非如此简单。本文首先解答这一问题&a…...

【扇贝编程】使用Selenium模拟浏览器获取动态内容笔记
文章目录 selenium安装 selenium下载浏览器驱动 获取数据处理数据查找一个元素查找所有符合条件的元素 控制浏览器 selenium selenium是爬虫的好帮手, 可以控制你的浏览器,模仿人浏览网页,从而获取数据,自动操作等。 我们只要让…...

TCP Analysis Flags 之 TCP Port numbers reused
前言 默认情况下,Wireshark 的 TCP 解析器会跟踪每个 TCP 会话的状态,并在检测到问题或潜在问题时提供额外的信息。在第一次打开捕获文件时,会对每个 TCP 数据包进行一次分析,数据包按照它们在数据包列表中出现的顺序进行处理。可…...

【Python机器学习】核心数、进程、线程、超线程、L1、L2、L3级缓存
如何知道自己电脑的CPU是几核的,打开任务管理器(同时按下:Esc键、SHIFT键、CTRL键) 然后,点击任务管理器左上角的性能选项,观察右下角中的内核:后面的数字,就是你CPU的核心数,下图中我的是16个核心的。 需要注意的是,下面的逻辑处理器:32 表示支持 32 线程(即超线…...

JavaScript使用地理位置 API
前言 在JavaScript中,Geolocation API 是一种用于访问用户地理位置的接口。这个API允许网页应用程序获取用户的位置并提供基于位置的服务。 if (navigator.geolocation)navigator.geolocation.getCurrentPosition(function () {},function () {});这个函数中需要传…...

dockerfile部署fastapi项目
dockerfile部署fastapi项目 1、Dockerfile # 使用Python官方镜像作为基础镜像 FROM python:3.8-slim# 更新apt-get源并安装依赖 # RUN apt-get update -y && apt-get install -y git# 设置环境变量 ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONUNBUFFERED 1# 创建工作目…...