当前位置: 首页 > news >正文

探索用卷积神经网络实现MNIST数据集分类

问题
对比单个全连接网络,在卷积神经网络层的加持下,初始时,整个神经网络模型的性能是否会更好。

方法

模型设计
两层卷积神经网络(包含池化层),一层全连接网络。

  1. 选择 5 x 5 的卷积核,输入通道为 1,输出通道为 10:

    此时图像矩阵经过 5 x 5 的卷积核后会小两圈,也就是4个数位,变成 24 x 24,输出通道为10;

  2. 选择 2 x 2 的最大池化层:

    此时图像大小缩短一半,变成 12 x 12,通道数不变;

  3. 再次经过5 x 5的卷积核,输入通道为 10,输出通道为 20:

    此时图像再小两圈,变成 8*8,输出通道为20;

  4. 再次经过2 x 2的最大池化层:

    此时图像大小缩短一半,变成 4 x 4,通道数不变;

  5. 最后将图像整型变换成向量,输入到全连接层中:

    输入一共有 4 x 4 x 20 = 320个元素,输出为 10.

代码

准备数据集

# 准备数据集

batch_size = 64

transform = transforms.Compose([

   transforms.ToTensor(),

   transforms.Normalize((0.1307,), (0.3081,))

])

train_dataset = datasets.MNIST(root='data’,

                              train=True,

                              download=True,

                              transform=transform)

train_loader = DataLoader(train_dataset,

                         shuffle=True,

                         batch_size=batch_size)

test_dataset = datasets.MNIST(root='data',

                             train=False,

                             download=True,

                             transform=transform)

test_loader = DataLoader(test_dataset,

                        shuffle=False,

                        batch_size=batch_size)

建立模型

class Net(torch.nn.Module):

   def __init__(self):

       super(Net, self).__init__()

       self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)

       self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)

       self.pooling = torch.nn.MaxPool2d(2)

       self.fc = torch.nn.Linear(320, 10)

   def forward(self, x):

       batch_size = x.size(0)

       x = F.relu(self.pooling(self.conv1(x)))

       x = F.relu(self.pooling(self.conv2(x)))

       x = x.view(batch_size, -1)

       x = self.fc(x)

       return x

model = Net()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)

构造损失函数+优化器

criterion = torch.nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

训练+测试

def train(epoch):

   running_loss = 0.0

   for batch_idx, data in enumerate(train_loader, 0):

       inputs, target = data

       inputs,target=inputs.to(device),target.to(device)

       optimizer.zero_grad()

       outputs = model(inputs)

       loss = criterion(outputs, target)

       loss.backward()

       optimizer.step()

       running_loss += loss.item()

       if batch_idx % 300 == 299:

           print('[%d,%.5d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 2000))

           running_loss = 0.0

def test():

   correct=0

   total=0

   with torch.no_grad():

       for data in test_loader:

           inputs,target=data

           inputs,target=inputs.to(device),target.to(device)

           outputs=model(inputs)

           _,predicted=torch.max(outputs.data,dim=1)

           total+=target.size(0)

           correct+=(predicted==target).sum().item()

   print('Accuracy on test set:%d %% [%d%d]' %(100*correct/total,correct,total))

if __name__ =='__main__':

   for epoch in range(10):

       train(epoch)

       test()

运行结果

(1)batch_size:64,训练次数:10

180acf4f7c17af3db55512491cf13c94.png

0e80fc811d9fca29fe2d204f6636c7f0.png

(2)batch_size:128,训练次数:10

3f1ecc11eaa504aaac05dd597b2e2c4b.png

(3)batch_size:128,训练次数:10

542f1af0a5712cdb6403b0be0d47f394.png

结语

对比单个全连接网络,在卷积神经网络层的加持下,初始时,整个神经网络模型的性能显著提升,准确率最低为96%。在batch_size:64,训练次数:100情况下,准确率达到99%。下一阶在平均池化,3*3卷积核,以及不同通道数的情况下,探索对模型性能的影响。                                     

相关文章:

探索用卷积神经网络实现MNIST数据集分类

问题对比单个全连接网络,在卷积神经网络层的加持下,初始时,整个神经网络模型的性能是否会更好。方法模型设计两层卷积神经网络(包含池化层),一层全连接网络。选择 5 x 5 的卷积核,输入通道为 1&…...

MySQL 索引失效场景

1,前言 索引主要是为了提高表的查询速率,但在某些情况下,索引也会失效的情况。 2,失效场景 2.1 最左前缀法则 查询从索引最左列开始,如果跳过索引中的age列,那么age后面字段的索引都将失效,…...

Xcode开发工具,图片放入ios工程

Xcode开发工具,图片放入ios工程,有三种方式: 一:Assets Assets.xcassets 一般是以蓝色的Assets.xcassets的文件夹形式在工程中,以Image Set的形式管理。当一组图片放入的时候同时会生成描述文件Contents.jso…...

操作系统权限提升(十九)之Linux提权-SUID提权

系列文章 操作系统权限提升(十八)之Linux提权-内核提权 SUID提权 SUID介绍 SUID是一种特殊权限,设置了suid的程序文件,在用户执行该程序时,用户的权限是该程序文件属主的权限,例如程序文件的属主是root,那么执行该…...

直播 | StarRocks 实战系列第三期--StarRocks 运维的那些事

2023 年开春, StarRocks 社区重磅推出入门级实战系列直播,手把手带你从 Zero to Hero 成为一个 “StarRocks Pro”!通过实际操作和应用场景的结合,我们将帮你系统性地学习 StarRocks 这个当今最热门的开源 OLAP 数据库。本次&…...

KingabseES执行计划-分区剪枝(partition pruning)

概述 分区修剪(Partition Pruning)是分区表性能的查询优化技术 。在分区修剪中,优化器分析SQL语句中的FROM和WHERE子句,以在构建分区访问列表时消除不需要的分区。此功能使数据库只能在与SQL语句相关的分区上执行操作。 参数 enable_partition_pruning 设…...

Operator-sdk 在 KaiwuDB 容器云中的使用

一、使用背景KaiwuDB Operator 是一个自动运维部署工具,可以在 Kubernetes 环境上部署 KaiwuDB集群,借助 Operator 可实现无缝运行在公有云厂商提供的 Kubernetes 平台上,让 KaiwuDB 成为真正的 Cloud-Native 数据库。使用传统的自动化工具会…...

【数据挖掘】2、数据预处理

文章目录一、数据预处理的意义1.1 缺失数据1.1.1 原因1.1.2 方案1.1.3 离群点分析1.2 重复数据1.2.1 原因1.2.2 去重的方案1.3 数据转换1.4 数据描述二、数据预处理方法2.1 特征选择 Feature Selection2.2 特征提取 Feature Extraction2.2.1 PCA 主成分分析2.2.2 LDA 线性判别分…...

(四十六)大白话在数据库里,哪些操作会导致在表级别加锁呢?

之前我们已经给大家讲解了数据库里的行锁的概念,其实还是比较简单,容易理解的,因为在讲解锁这个概念之前,对于多事务并发以及隔离,我们已经深入讲解过了,所以大家应该很容易在脑子里有一个多事务并发执行的…...

【Android源码面试宝典】MMKV从使用到原理分析(二)

上一章节,我们从使用入手,进行了MMKV的简单讲解,我们通过分析简单的运行时日志,从中大概猜到了一些MMKV的代码内部流程,同时,我们也提出了若干的疑问?还是那句话,带着目标(问题)去阅读一篇源码,那么往往收获的知识,更加深入&扎实。 本节,我们一起来从源码层次…...

如何使用ADFSRelay分析和研究针对ADFS的NTLM中继攻击

关于ADFSRelay ADFSRelay是一款功能强大的概念验证工具,可以帮助广大研究人员分析和研究针对ADFS的NTLM中继攻击。 ADFSRelay这款工具由NTLMParse和ADFSRelay这两个实用程序组成。其中,NTLMParse用于解码base64编码的NTLM消息,并打印有关消…...

【Python学习笔记】第二十二节 Python XML 解析

一、什么是XMLXML即ExtentsibleMarkup Language(可扩展标记语言),是用来定义其它语言的一种元语言。XML 被设计用来传输和存储数据。XML 是一套定义语义标记的规则,它没有标签集(tagset),也没有语法规则(grammatical rule)。任何XML文档对任何…...

5分钟轻松拿下Java枚举

文章目录一、枚举(Enum)1.1 枚举概述1.2 定义枚举类型1.2.1 静态常量案例1.2.2 枚举案例1.2.3 枚举与switch1.3 枚举的用法1.3.1 枚举类的成员1.3.2 枚举类的构造方法1)枚举的无参构造方法2)枚举的有参构造方法1.3.3 枚举中的抽象方法1.4 Enum 类1.4.1 E…...

华为OD机试【独家】提供C语言题解 - 最小传递延迟

最近更新的博客 华为od 2023 | 什么是华为od,od 薪资待遇,od机试题清单华为OD机试真题大全,用 Python 解华为机试题 | 机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南)华为od机试,独家整理 已参加机试人员的实战技巧文章目录 最近更新的博客使用说明最小…...

【Web前端】关于JS数组方法的一些理解

一、具备栈特性的方法unshift(...items: T[]) : number将一个或多个元素添加到数组的开头,并返回该数组的新长度。shift(): T | undefined从数组中删除第一个元素,并返回该元素的值。此方法更改数组的长度。二、具备队列特性的方法push(...items: T[]): …...

多智能体集群协同控制笔记(1):线性无领航多智能体系统的一致性

对于连续时间高阶线性多智能体系统的状态方程为: x˙i(t)Axi(t)Bui(t),i1,2..N\dot {\mathbf{x}}_i(t)A\mathbf{x}_i(t)B\mathbf{u}_i(t),i1,2..N x˙i​(t)Axi​(t)Bui​(t),i1,2..N 下标iii代表第iii个智能体,ui(t)∈Rq1\mathbf{u}_i(t)\in R^{q \time…...

hadoop-Yarn资源调度器【尚硅谷】

大数据学习笔记 Yarn资源调度器 Yarn是一个资源调度平台,负责为运算程序提供服务器运算资源,相当于一个分布式的操作系统平台,而MapReduce等运算程序则相当于运行与操作系统之上的应用程序。 (也就是负责MapTask、ReduceTask等任…...

聊聊如何避免多个jar通过maven打包成一个jar,多个同名配置文件发生覆盖问题

前言 不知道大家在开发的过程中,有没有遇到这种场景,外部的项目想访问内部nexus私仓的jar,因为私仓不对外开放,导致外部的项目没法下载到私仓的jar,导致项目因缺少jar而无法运行。 通常遇到这种场景,常用…...

Flume 使用小案例

案例一:采集文件内容上传到HDFS 1)把Agent的配置保存到flume的conf目录下的 file-to-hdfs.conf 文件中 # Name the components on this agent a1.sources r1 a1.sinks k1 a1.channels c1 # Describe/configure the source a1.sources.r1.type spoo…...

DLO-SLAM代码阅读

文章目录DLO-SLAM点评代码解析OdomNode代码结构主函数 main激光回调函数 icpCB初始化 initializeDLO重力对齐 gravityAlign点云预处理 preprocessPoints关键帧指标 computeMetrics设定关键帧阈值setAdaptiveParams初始化目标数据 initializeInputTarget设置源数据 setInputSour…...

Python|GIF 解析与构建(5):手搓截屏和帧率控制

目录 Python|GIF 解析与构建(5):手搓截屏和帧率控制 一、引言 二、技术实现:手搓截屏模块 2.1 核心原理 2.2 代码解析:ScreenshotData类 2.2.1 截图函数:capture_screen 三、技术实现&…...

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...

无法与IP建立连接,未能下载VSCode服务器

如题,在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈,发现是VSCode版本自动更新惹的祸!!! 在VSCode的帮助->关于这里发现前几天VSCode自动更新了,我的版本号变成了1.100.3 才导致了远程连接出…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

Leetcode 3577. Count the Number of Computer Unlocking Permutations

Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接:3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯,要想要能够将所有的电脑解锁&#x…...

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?

论文网址:pdf 英文是纯手打的!论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误,若有发现欢迎评论指正!文章偏向于笔记,谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话: “利润不是赚出来的,是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业,很多企业看着销售不错,账上却没钱、利润也不见了,一翻库存才发现: 一堆卖不动的旧货…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作:ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等(ArcGIS出图图例8大技巧),那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

基于Java+VUE+MariaDB实现(Web)仿小米商城

仿小米商城 环境安装 nodejs maven JDK11 运行 mvn clean install -DskipTestscd adminmvn spring-boot:runcd ../webmvn spring-boot:runcd ../xiaomi-store-admin-vuenpm installnpm run servecd ../xiaomi-store-vuenpm installnpm run serve 注意:运行前…...

【Linux系统】Linux环境变量:系统配置的隐形指挥官

。# Linux系列 文章目录 前言一、环境变量的概念二、常见的环境变量三、环境变量特点及其相关指令3.1 环境变量的全局性3.2、环境变量的生命周期 四、环境变量的组织方式五、C语言对环境变量的操作5.1 设置环境变量:setenv5.2 删除环境变量:unsetenv5.3 遍历所有环境…...