CNN手写数字识别1——模型搭建与数据准备
模型搭建
我们这次使用LeNet模型,LeNet是一个经典的卷积神经网络(Convolutional Neural Network, CNN)架构,最初由Yann LeCun等人在1998年提出,用于手写数字识别任务
创建一个文件model.py。实现以下代码。
源码
# 导入PyTorch库
import torch
# 从PyTorch库中导入神经网络模块
from torch import nn
# 从torchsummary库中导入summary函数,用于打印模型的结构和参数数量
from torchsummary import summary# 定义LeNet类,它继承自nn.Module,是一个神经网络模型
class LeNet(nn.Module):# 初始化函数,定义模型的层次结构def __init__(self):# 调用父类的初始化函数super().__init__()# 第一个卷积层,输入通道为1,输出通道为6,卷积核大小为5x5,padding为2self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)# Sigmoid激活函数self.sig = nn.Sigmoid()# 第一个平均池化层,池化窗口为2x2,步长为2self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)# 第二个卷积层,输入通道为6,输出通道为16,卷积核大小为5x5self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)# 第二个平均池化层,池化窗口为2x2,步长为2self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)# Flatten层,用于将多维的输入一维化,以便输入到全连接层self.flatten = nn.Flatten()# 第一个全连接层,输入特征数为400,输出特征数为120self.f5 = nn.Linear(400, 120)# 第二个全连接层,输入特征数为120,输出特征数为84self.f6 = nn.Linear(120, 84)# 第三个全连接层,输入特征数为84,输出特征数为10(通常对应分类任务中的类别数)self.f7 = nn.Linear(84, 10)# 前向传播函数,定义数据通过网络的方式def forward(self, x):x = self.sig(self.c1(x)) # 通过第一个卷积层和Sigmoid激活函数x = self.s2(x) # 通过第一个平均池化层x = self.sig(self.c3(x)) # 通过第二个卷积层和Sigmoid激活函数x = self.s4(x) # 通过第二个平均池化层x = self.flatten(x) # 通过Flatten层x = self.sig(self.f5(x)) # 通过第一个全连接层和Sigmoid激活函数x = self.sig(self.f6(x)) # 通过第二个全连接层和Sigmoid激活函数x = self.sig(self.f7(x)) # 通过第三个全连接层和Sigmoid激活函数return x# 主函数
if __name__ == "__main__":# 自动检测是否有可用的GPU,如果有则使用GPU,否则使用CPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 实例化LeNet模型,并将其移动到指定的设备上(GPU或CPU)model = LeNet().to(device)# 使用torchsummary的summary函数打印模型的结构和参数数量,输入形状为(1, 28, 28)print(summary(model, (1, 28, 28)))
源码解析
神经网络构建
LeNet网络主要由卷积层、池化层、激活函数和全连接层组成。有2个卷积层,2个池化层,3个全连接层。
- 卷积层(
nn.Conv2d):用于提取图像中的特征。这里有两个卷积层,第一个卷积层有6个输出通道,第二个卷积层有16个输出通道,卷积核大小都是5x5。 - 激活函数(
nn.Sigmoid):用于引入非线性,使得网络能够学习更复杂的模式。这里使用了Sigmoid激活函数。 - 池化层(
nn.AvgPool2d):用于降低特征图的尺寸,减少计算量,同时保留重要特征。这里使用了平均池化层,池化窗口大小为2*2,步长为2。 - Flatten层(
nn.Flatten):用于将多维的特征图展平成一维向量,以便输入到全连接层。 - 全连接层(
nn.Linear):用于分类任务,将特征向量映射到类别空间。这里有三个全连接层,分别将特征维度从400降到120,再从120降到84,最后从84降到10(对应10个类别)。
参数计算
从代码中可以看到每一层神经网络都有自己的参数,这里面通道数,卷积核大小,步长和感受野,一定程度上可以当做超参数人为自由设定,其他的参数都需要事先根据输入数据进行计算。
首先,假设输入的图像数据大小都是28*28*1,即宽28个像素,高28个像素,由于是灰度图所以色彩通道只有1。
在第一个卷积层c1,卷积核大小是5*5,卷积核个数是6个,步长默认是1,填充是2,这些是我们人为设定的。可可以得出输出通道数=卷积核个数=6,经过这一层的输出数据通道数为6,尺寸通过公式计算:
公式里面O是输出的宽/高,IN是输入的宽/高,P是填充,F是卷积核的宽/高或者感受野的宽/高,S是步长。
即输出图像的宽高是(28+2*2-5)/1+1=28。输出数据量是28*28*6。可以看到卷积层可以有效提升数据的通道数。
在第一个池化层s2,感受野是2*2,步长是2,填充默认是0,可以计算出输出图像的宽高是(28+2*0-2)/2+1=14。可以看到经过池化层之后数据的特征量明显减少,此时输出的数据量是14*14*6(池化不会改变通道数)。
在第二个卷积层c3,卷积核大小是5*5,步长默认是1,填充默认是0,有16个卷积核,也就是通道数增加到了16。根据公式可以计算出输出图像的宽高是(14+2*0-5)/1+1=10。输出数据量是10*10*16。
在第二个池化层s4,感受野是2*2,步长为2,那么输出数据宽高就是(10+2*0-2)/2+1=5。输出数据量是5*5*16。
到全连接层的第一层就比较关键了,因为这里的参数有一个“输入特征数”,也就是刚才算的5*5*16=400。如果前面的计算不对,到了这一步模型是会报错的,因此每一层的输出特征量都需要计算出来。
后面的全连接层就比较好写了,输入的特征量是上一层全连接层的神经元个数。
前向传播
前向传播定义了数据通过网络的方式。对于输入x,它首先通过第一个卷积层和Sigmoid激活函数,然后通过第一个平均池化层;接着通过第二个卷积层和Sigmoid激活函数,再通过第二个平均池化层;最后通过Flatten层将多维特征展平成一维向量,并依次通过三个全连接层和Sigmoid激活函数得到最终输出。
验证
在主函数中,我们首先检测是否有可用的GPU,并将模型移动到合适的计算设备上(GPU或CPU)。然后,我们使用torchsummary的summary函数打印模型的结构和参数数量,以便了解模型的复杂度和计算需求。

从图中可以看出,池化层是不包含参数的,整个模型的大部分参数都在全连接层(48120+10164+850 = 59134,将近6万个参数在伺候全连接层)。
数据准备
FashionMNIST是一个流行的数据集,包含了10种类别的70,000个灰度图像,通常用于计算机视觉和机器学习的教学与研究。我们这次通过远程下载的方式来获取数据。
另外创建一个plot.python。用来下载数据和预览数据。这部分代码可写可不写,模型训练的时候还会重新加载数据。
源码
# 导入必要的库和模块
from torchvision import transforms # 用于图像预处理的变换
from torchvision.datasets import FashionMNIST # 导入FashionMNIST数据集
import torch.utils.data as Data # 用于数据加载的实用工具
import numpy as np # 导入NumPy库,用于数值计算
import matplotlib.pyplot as plt# 准备训练数据
train_data = FashionMNIST(root='./data', # 数据集存储的根目录train=True, # 指定为训练数据集transform=transforms.Compose([ # 图像预处理步骤transforms.Resize(size=224), # 将图像大小调整为224x224transforms.ToTensor() # 将图像转换为PyTorch张量]),download=True # 如果数据集不存在,则下载
)# 创建数据加载器
train_loader = Data.DataLoader(dataset=train_data, # 指定数据集batch_size=64, # 每个批次的大小shuffle=True, # 在每个epoch开始时打乱数据num_workers=0 # 使用0个工作线程(对于Windows系统,有时需要设置为0以避免多进程问题)
)# 遍历数据加载器
for step, (b_x, b_y) in enumerate(train_loader):if step > 0: # 只处理第一个批次的数据break
# 将PyTorch张量转换为NumPy数组
batch_x = b_x.squeeze().numpy() # 移除批次维度(如果可能),并转换为NumPy数组
batch_y = b_y.numpy() # 将标签转换为NumPy数组# 获取数据集中的类别标签
class_label = train_data.classes # 这是一个包含所有类别名称的列表
# 打印类别标签
print(class_label) # 输出类别标签列表# 设置图形的大小
plt.figure(figsize=(12, 5))# 遍历batch_y中的每一个元素,即每一个样本的标签
for ii in np.arange(len(batch_y)):# 创建子图,4行16列,第ii+1个子图# 这里假设一个批次有64个样本,因此用4x16的布局来显示它们plt.subplot(4, 16, ii + 1)# 显示图像# batch_x[ii, :, :]表示第ii个样本的图像数据# cmap=plt.cm.gray指定使用灰度色彩映射plt.imshow(batch_x[ii, :, :], cmap=plt.cm.gray)# 设置标题为对应的类别标签# class_label[batch_y[ii]]根据标签索引获取类别名称# size=10设置标题字体大小plt.title(class_label[batch_y[ii]], size=10)# 关闭坐标轴显示plt.axis("off")# 调整子图之间的间距
# wspace=0.05设置子图之间的宽度间距
plt.subplots_adjust(wspace=0.05)# 显示图形
plt.show()
源码解析
下载和加载数据
首先准备训练数据。FashionMNIST数据集将被下载到指定的根目录,并进行图像预处理
为了高效地加载数据,我们使用PyTorch的DataLoader来创建数据加载器。
dataloader的参数解释如下:
dataset:指定要加载的数据集。batch_size:每个批次加载的样本数。shuffle:是否在每个epoch开始时打乱数据。num_workers:加载数据时使用的工作线程数。在Windows系统上,有时需要设置为0以避免多进程问题。
展示数据
我们遍历数据加载器,但只处理第一个批次的数据(为了简化示例)。使用squeeze()方法移除批次维度(如果可能),并将PyTorch张量转换为NumPy数组,以便使用matplotlib进行可视化。随后使用matplotlib的subplot()方法创建子图,并在每个子图中显示一个图像样本。我们使用灰度色彩映射(cmap=plt.cm.gray)来显示图像。最后,使用plt.show()方法显示图形。

相关文章:
CNN手写数字识别1——模型搭建与数据准备
模型搭建 我们这次使用LeNet模型,LeNet是一个经典的卷积神经网络(Convolutional Neural Network, CNN)架构,最初由Yann LeCun等人在1998年提出,用于手写数字识别任务 创建一个文件model.py。实现以下代码。 源码 #…...
深度学习04 数据增强、调整学习率
目录 数据增强 常用的数据增强方法 调整学习率 学习率 调整学习率 调整学习率的方法 有序调整 等间隔调整 多间隔调整 指数衰减 余弦退火 自适应调整 自定义调整 数据增强 数据增强是通过对训练数据进行各种变换(如旋转、翻转、裁剪等)&am…...
Python 自然语言处理(NLP)和文本挖掘的常规操作过程
Python 自然语言处理(NLP)和文本挖掘 自然语言处理(NLP)和文本挖掘是数据科学中的重要领域,涉及对文本数据的分析和处理。Python 提供了丰富的库和工具,用于执行各种 NLP 和文本挖掘任务。以下是一些常见的…...
掌握SQLite_轻量级数据库的全面指南
1. 引言 1.1 SQLite简介 SQLite 是一个嵌入式关系型数据库管理系统,它不需要单独的服务器进程或系统配置。它的设计目标是简单、高效、可靠,适用于各种应用场景,尤其是移动设备和嵌入式系统。 1.2 为什么选择SQLite 轻量级:文件大小通常在几百KB到几MB之间。无服务器架构…...
PH热榜 | 2025-02-16
1. Cal.com Routing 标语:根据客户线索,系统会智能地自动安排约会。 介绍:告别繁琐的排期!Cal.com 推出了新的路由功能,能更智能地分配预约,让你的日程安排更顺畅。这项功能运用智能逻辑和深入的数据分析…...
数据库基本概念及基本使用
数据库基本概念 什么是数据库: 数据库特点: 常见的数据库软件: 不同的公司进行不同的实践,生成了不同的产品。 比如买汽车,汽车只是一个概念,你要买哪个牌子哪个型号的汽车,才是真正的汽车的一…...
gozero实现数据库MySQL单例模式连接
在 GoZero 框架中实现数据库的单例连接可以通过以下步骤来完成。GoZero 使用 gorm 作为默认的数据库操作框架,接下来我会展示一个简单的单例模式实现。 ### 1. 定义数据库连接的单例结构 首先,你需要定义一个数据库连接的结构体,并在初始化…...
CSS flex布局 列表单个元素点击 本行下插入详情独占一行
技术栈:Vue2 javaScript 简介 在实际开发过程中有遇到一个场景:一个list,每行个数固定,点击单个元素后,在当前行与下一行之间插入一行元素详情,便于更直观的查看到对应的数据详情。 这种情形,…...
无人机航迹规划: 梦境优化算法(Dream Optimization Algorithm,DOA)求解无人机路径规划MATLAB
一、梦境优化算法 梦境优化算法(Dream Optimization Algorithm,DOA)是一种新型的元启发式算法,其灵感来源于人类的梦境行为。该算法结合了基础记忆策略、遗忘和补充策略以及梦境共享策略,通过模拟人类梦境中的部分记忆…...
权限五张表
重点:权限五张表的设计 核心概念: 在权限管理系统中,经典的设计通常涉及五张表,分别是用户表、角色表、权限表、用户角色表和角色权限表。这五张表的设计可以有效地管理用户的权限,确保系统的安全性和灵活性。 用户&…...
Docker-数据卷
1.数据卷 容器是隔离环境,容器内程序的文件、配置、运行时产生的容器都在容器内部,我们要读写容器内的文件非常不方便。大家思考几个问题: 如果要升级MySQL版本,需要销毁旧容器,那么数据岂不是跟着被销毁了࿱…...
在Linux系统下修改Docker的默认存储路径
在Linux系统下修改Docker的默认存储路径可以通过多种方法实现,下边是通过修改daemon.json文件方式实现 查看当前Docker存储路径 使用命令 docker info | grep "Docker Root Dir" 查看当前Docker的存储路径,默认为 /var/lib/docker 停止Docker…...
IT : 是工作還是嗜好? Delphi 30周年快乐!
又到2月14日了, 自从30多年前收到台湾宝蓝(Borland)公司一大包的3.5 磁盘片, 上面用黑色油性笔写着Delphi Beta开始, Delphi便和我的工作生涯有了密不可分的关系. 一年后Delphi大获成功, 自此对于使用Delphi的使用者来说2月14日也成了一个特殊的日子! 我清楚记得Delphi Beta使用…...
DeepPose
目录 摘要 Abstract DeepPose 算法框架 损失函数 创新点 局限性 训练过程 代码 总结 摘要 DeepPose是首个将CNN应用于姿态估计任务的模型。该模型在传统姿态估计方法的基础上,通过端到端的方式直接从图像中回归出人体关键点的二维坐标,避免了…...
[HarmonyOS]鸿蒙(添加服务卡片)推荐商品 修改卡片UI(内容)
什么是服务卡片 ? 鸿蒙系统中的服务卡片(Service Card)就是一种轻量级的应用展示形式,它可以让用户在不打开完整应用的情况下,快速访问应用内的特定功能或信息。以下是服务卡片的几个关键点: 轻量级&#…...
DeepSeek R1 本地部署和知识库搭建
一、本地部署 DeepSeek-R1,是幻方量化旗下AI公司深度求索(DeepSeek)研发的推理模型 。DeepSeek-R1采用强化学习进行后训练,旨在提升推理能力,尤其擅长数学、代码和自然语言推理等复杂任务 。 使用DeepSeek R1, 可以大大…...
领域驱动设计叕创新,平安保险申请DDD专利
DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 见下图: 这个名字拼得妙:领域驱动设计模式。 是领域驱动设计?还是设计模式?还是领域驱动设计设计模式?和下面这个知乎文章的…...
团体程序设计天梯赛-练习集——L1-041 寻找250
前言 10分的题,主要的想法就一个,按这个想法可以出几个写法 L1-041 寻找250 对方不想和你说话,并向你扔了一串数…… 而你必须从这一串数字中找到“250”这个高大上的感人数字。 输入格式: 输入在一行中给出不知道多少个绝对值…...
动量突破均值回归策略
动量突破均值回归策略:量化交易中的双剑合璧 引言 在量化交易的世界中,动量策略和均值回归策略是两种经典且广泛应用的策略。动量策略基于“强者恒强”的理念,认为过去表现良好的资产在未来一段时间内仍会继续表现良好;而均值回…...
vue3.x 的provide 与 inject详细解读
在 Vue 3.x 中,provide 和 inject 是一对用于实现依赖注入的 API。它们允许父组件向其所有子组件(无论嵌套多深)传递数据或方法,而不需要通过 props 逐层传递。这在开发复杂组件或高阶组件时非常有用。 1. provide 的基本用法 p…...
C#控制台大小Console.SetWindowSize函数失效解决
在使用C#修改控制台大小相关API会失效. 由于VS将控制台由命令提示符变成了终端,因此在设置大小时会出现问题 测试代码: Console.SetWindowSize(100, 50);...
spring boot 对接aws 的S3 服务,实现上传和查询
1.aws S3介绍 AWS S3(Amazon Simple Storage Service)是亚马逊提供的一种对象存储服务,旨在提供可扩展、高可用性和安全的数据存储解决方案。以下是AWS S3的一些主要特点和功能: 1.1. 对象存储 对象存储模型:S3使用…...
25/2/16 <算法笔记> DirectPose
DirectPose 是一种直接从图像中预测物体的 6DoF(位姿:6 Degrees of Freedom)姿态 的方法,包括平移和平面旋转。它在目标检测、机器人视觉、增强现实(AR)和自动驾驶等领域中具有广泛应用。相比于传统的位姿估…...
数据结构-8.Java. 七大排序算法(下篇)
本篇博客给大家带来的是排序的知识点, 由于时间有限, 分两天来写, 下篇主要实现最后一种排序算法: 归并排序。同时把中篇剩下的快排非递归实现补上. 文章专栏: Java-数据结构 若有问题 评论区见 欢迎大家点赞 评论 收藏 分享 如果你不知道分享给谁,那就分享给薯条. 你们的支持是…...
缓存穿透、缓存击穿、缓存雪崩的区别与解决方案
1. 缓存穿透(Cache Penetration) 定义:大量请求查询 数据库中不存在的数据,导致请求绕过缓存直接访问数据库,造成数据库压力过大。 场景: 恶意攻击:例如用不存在的用户ID频繁请求。 业务误操作…...
DeepSeek私有化部署+JAVA通过API调用离线大模型问答
在当今快速发展的数字化时代,企业对于高效、灵活的技术解决方案需求日益增长。DeepSeek作为一款领先的智能搜索与分析平台,凭借其强大的数据处理能力和精准的搜索结果,已经成为众多企业提升运营效率的得力助手。为了更好地满足企业对数据安全…...
【go语言规范】Gopherfest 2015 | Go Proverbs with Rob Pike的 总结
根据 Gopherfest 2015 | Go Proverbs with Rob Pike 的演讲,总结内容如下: 虽然已是十年前的产物,但是proverbs的价值依旧存在 以下是整合补充内容后的完整总结,涵盖 Rob Pike 在 Gopherfest 2015 演讲 “Go Proverbs” 中的核心…...
【吾爱出品】针对红警之类老游戏适用WIN10和11的补丁cnc-ddraw7.1汉化版
针对红警之类老游戏适用WIN10和11的补丁cnc-ddraw7.1汉化版 链接:https://pan.xunlei.com/s/VOJ8PZd4avMubnDzHQAeZDxWA1?pwdnjwm# 直接复制到游戏安装目录,保持与游戏主程序同目录下。...
内容中台驱动企业数字化内容管理高效协同架构
内容概要 在数字化转型加速的背景下,企业对内容管理的需求从单一存储向全链路协同演进。内容中台作为核心支撑架构,通过统一的内容资源池与智能化管理工具,重塑了内容生产、存储、分发及迭代的流程。其核心价值在于打破部门壁垒,…...
【第14章:神经符号集成与可解释AI—14.4 神经符号集成与可解释AI的未来发展趋势与挑战】
想象一下,如果AI既能像人类一样直觉感知(比如一眼认出街角的咖啡店),又能像数学家一样逻辑推理(比如计算最优路线避开拥堵),这个世界会变成什么样?这种“双脑协同”正是神经符号集成技术的终极目标。 但现实是,当前99%的AI系统要么只会“死记硬背”数据(如深度学习模…...
