当前位置: 首页 > news >正文

探索用卷积神经网络实现MNIST数据集分类

问题
对比单个全连接网络,在卷积神经网络层的加持下,初始时,整个神经网络模型的性能是否会更好。

方法

模型设计
两层卷积神经网络(包含池化层),一层全连接网络。

  1. 选择 5 x 5 的卷积核,输入通道为 1,输出通道为 10:

    此时图像矩阵经过 5 x 5 的卷积核后会小两圈,也就是4个数位,变成 24 x 24,输出通道为10;

  2. 选择 2 x 2 的最大池化层:

    此时图像大小缩短一半,变成 12 x 12,通道数不变;

  3. 再次经过5 x 5的卷积核,输入通道为 10,输出通道为 20:

    此时图像再小两圈,变成 8*8,输出通道为20;

  4. 再次经过2 x 2的最大池化层:

    此时图像大小缩短一半,变成 4 x 4,通道数不变;

  5. 最后将图像整型变换成向量,输入到全连接层中:

    输入一共有 4 x 4 x 20 = 320个元素,输出为 10.

代码

准备数据集

# 准备数据集

batch_size = 64

transform = transforms.Compose([

   transforms.ToTensor(),

   transforms.Normalize((0.1307,), (0.3081,))

])

train_dataset = datasets.MNIST(root='data’,

                              train=True,

                              download=True,

                              transform=transform)

train_loader = DataLoader(train_dataset,

                         shuffle=True,

                         batch_size=batch_size)

test_dataset = datasets.MNIST(root='data',

                             train=False,

                             download=True,

                             transform=transform)

test_loader = DataLoader(test_dataset,

                        shuffle=False,

                        batch_size=batch_size)

建立模型

class Net(torch.nn.Module):

   def __init__(self):

       super(Net, self).__init__()

       self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)

       self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)

       self.pooling = torch.nn.MaxPool2d(2)

       self.fc = torch.nn.Linear(320, 10)

   def forward(self, x):

       batch_size = x.size(0)

       x = F.relu(self.pooling(self.conv1(x)))

       x = F.relu(self.pooling(self.conv2(x)))

       x = x.view(batch_size, -1)

       x = self.fc(x)

       return x

model = Net()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)

构造损失函数+优化器

criterion = torch.nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

训练+测试

def train(epoch):

   running_loss = 0.0

   for batch_idx, data in enumerate(train_loader, 0):

       inputs, target = data

       inputs,target=inputs.to(device),target.to(device)

       optimizer.zero_grad()

       outputs = model(inputs)

       loss = criterion(outputs, target)

       loss.backward()

       optimizer.step()

       running_loss += loss.item()

       if batch_idx % 300 == 299:

           print('[%d,%.5d] loss:%.3f' % (epoch + 1, batch_idx + 1, running_loss / 2000))

           running_loss = 0.0

def test():

   correct=0

   total=0

   with torch.no_grad():

       for data in test_loader:

           inputs,target=data

           inputs,target=inputs.to(device),target.to(device)

           outputs=model(inputs)

           _,predicted=torch.max(outputs.data,dim=1)

           total+=target.size(0)

           correct+=(predicted==target).sum().item()

   print('Accuracy on test set:%d %% [%d%d]' %(100*correct/total,correct,total))

if __name__ =='__main__':

   for epoch in range(10):

       train(epoch)

       test()

运行结果

(1)batch_size:64,训练次数:10

180acf4f7c17af3db55512491cf13c94.png

0e80fc811d9fca29fe2d204f6636c7f0.png

(2)batch_size:128,训练次数:10

3f1ecc11eaa504aaac05dd597b2e2c4b.png

(3)batch_size:128,训练次数:10

542f1af0a5712cdb6403b0be0d47f394.png

结语

对比单个全连接网络,在卷积神经网络层的加持下,初始时,整个神经网络模型的性能显著提升,准确率最低为96%。在batch_size:64,训练次数:100情况下,准确率达到99%。下一阶在平均池化,3*3卷积核,以及不同通道数的情况下,探索对模型性能的影响。                                     

相关文章:

探索用卷积神经网络实现MNIST数据集分类

问题对比单个全连接网络,在卷积神经网络层的加持下,初始时,整个神经网络模型的性能是否会更好。方法模型设计两层卷积神经网络(包含池化层),一层全连接网络。选择 5 x 5 的卷积核,输入通道为 1&…...

MySQL 索引失效场景

1,前言 索引主要是为了提高表的查询速率,但在某些情况下,索引也会失效的情况。 2,失效场景 2.1 最左前缀法则 查询从索引最左列开始,如果跳过索引中的age列,那么age后面字段的索引都将失效,…...

Xcode开发工具,图片放入ios工程

Xcode开发工具,图片放入ios工程,有三种方式: 一:Assets Assets.xcassets 一般是以蓝色的Assets.xcassets的文件夹形式在工程中,以Image Set的形式管理。当一组图片放入的时候同时会生成描述文件Contents.jso…...

操作系统权限提升(十九)之Linux提权-SUID提权

系列文章 操作系统权限提升(十八)之Linux提权-内核提权 SUID提权 SUID介绍 SUID是一种特殊权限,设置了suid的程序文件,在用户执行该程序时,用户的权限是该程序文件属主的权限,例如程序文件的属主是root,那么执行该…...

直播 | StarRocks 实战系列第三期--StarRocks 运维的那些事

2023 年开春, StarRocks 社区重磅推出入门级实战系列直播,手把手带你从 Zero to Hero 成为一个 “StarRocks Pro”!通过实际操作和应用场景的结合,我们将帮你系统性地学习 StarRocks 这个当今最热门的开源 OLAP 数据库。本次&…...

KingabseES执行计划-分区剪枝(partition pruning)

概述 分区修剪(Partition Pruning)是分区表性能的查询优化技术 。在分区修剪中,优化器分析SQL语句中的FROM和WHERE子句,以在构建分区访问列表时消除不需要的分区。此功能使数据库只能在与SQL语句相关的分区上执行操作。 参数 enable_partition_pruning 设…...

Operator-sdk 在 KaiwuDB 容器云中的使用

一、使用背景KaiwuDB Operator 是一个自动运维部署工具,可以在 Kubernetes 环境上部署 KaiwuDB集群,借助 Operator 可实现无缝运行在公有云厂商提供的 Kubernetes 平台上,让 KaiwuDB 成为真正的 Cloud-Native 数据库。使用传统的自动化工具会…...

【数据挖掘】2、数据预处理

文章目录一、数据预处理的意义1.1 缺失数据1.1.1 原因1.1.2 方案1.1.3 离群点分析1.2 重复数据1.2.1 原因1.2.2 去重的方案1.3 数据转换1.4 数据描述二、数据预处理方法2.1 特征选择 Feature Selection2.2 特征提取 Feature Extraction2.2.1 PCA 主成分分析2.2.2 LDA 线性判别分…...

(四十六)大白话在数据库里,哪些操作会导致在表级别加锁呢?

之前我们已经给大家讲解了数据库里的行锁的概念,其实还是比较简单,容易理解的,因为在讲解锁这个概念之前,对于多事务并发以及隔离,我们已经深入讲解过了,所以大家应该很容易在脑子里有一个多事务并发执行的…...

【Android源码面试宝典】MMKV从使用到原理分析(二)

上一章节,我们从使用入手,进行了MMKV的简单讲解,我们通过分析简单的运行时日志,从中大概猜到了一些MMKV的代码内部流程,同时,我们也提出了若干的疑问?还是那句话,带着目标(问题)去阅读一篇源码,那么往往收获的知识,更加深入&扎实。 本节,我们一起来从源码层次…...

如何使用ADFSRelay分析和研究针对ADFS的NTLM中继攻击

关于ADFSRelay ADFSRelay是一款功能强大的概念验证工具,可以帮助广大研究人员分析和研究针对ADFS的NTLM中继攻击。 ADFSRelay这款工具由NTLMParse和ADFSRelay这两个实用程序组成。其中,NTLMParse用于解码base64编码的NTLM消息,并打印有关消…...

【Python学习笔记】第二十二节 Python XML 解析

一、什么是XMLXML即ExtentsibleMarkup Language(可扩展标记语言),是用来定义其它语言的一种元语言。XML 被设计用来传输和存储数据。XML 是一套定义语义标记的规则,它没有标签集(tagset),也没有语法规则(grammatical rule)。任何XML文档对任何…...

5分钟轻松拿下Java枚举

文章目录一、枚举(Enum)1.1 枚举概述1.2 定义枚举类型1.2.1 静态常量案例1.2.2 枚举案例1.2.3 枚举与switch1.3 枚举的用法1.3.1 枚举类的成员1.3.2 枚举类的构造方法1)枚举的无参构造方法2)枚举的有参构造方法1.3.3 枚举中的抽象方法1.4 Enum 类1.4.1 E…...

华为OD机试【独家】提供C语言题解 - 最小传递延迟

最近更新的博客 华为od 2023 | 什么是华为od,od 薪资待遇,od机试题清单华为OD机试真题大全,用 Python 解华为机试题 | 机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南)华为od机试,独家整理 已参加机试人员的实战技巧文章目录 最近更新的博客使用说明最小…...

【Web前端】关于JS数组方法的一些理解

一、具备栈特性的方法unshift(...items: T[]) : number将一个或多个元素添加到数组的开头,并返回该数组的新长度。shift(): T | undefined从数组中删除第一个元素,并返回该元素的值。此方法更改数组的长度。二、具备队列特性的方法push(...items: T[]): …...

多智能体集群协同控制笔记(1):线性无领航多智能体系统的一致性

对于连续时间高阶线性多智能体系统的状态方程为: x˙i(t)Axi(t)Bui(t),i1,2..N\dot {\mathbf{x}}_i(t)A\mathbf{x}_i(t)B\mathbf{u}_i(t),i1,2..N x˙i​(t)Axi​(t)Bui​(t),i1,2..N 下标iii代表第iii个智能体,ui(t)∈Rq1\mathbf{u}_i(t)\in R^{q \time…...

hadoop-Yarn资源调度器【尚硅谷】

大数据学习笔记 Yarn资源调度器 Yarn是一个资源调度平台,负责为运算程序提供服务器运算资源,相当于一个分布式的操作系统平台,而MapReduce等运算程序则相当于运行与操作系统之上的应用程序。 (也就是负责MapTask、ReduceTask等任…...

聊聊如何避免多个jar通过maven打包成一个jar,多个同名配置文件发生覆盖问题

前言 不知道大家在开发的过程中,有没有遇到这种场景,外部的项目想访问内部nexus私仓的jar,因为私仓不对外开放,导致外部的项目没法下载到私仓的jar,导致项目因缺少jar而无法运行。 通常遇到这种场景,常用…...

Flume 使用小案例

案例一:采集文件内容上传到HDFS 1)把Agent的配置保存到flume的conf目录下的 file-to-hdfs.conf 文件中 # Name the components on this agent a1.sources r1 a1.sinks k1 a1.channels c1 # Describe/configure the source a1.sources.r1.type spoo…...

DLO-SLAM代码阅读

文章目录DLO-SLAM点评代码解析OdomNode代码结构主函数 main激光回调函数 icpCB初始化 initializeDLO重力对齐 gravityAlign点云预处理 preprocessPoints关键帧指标 computeMetrics设定关键帧阈值setAdaptiveParams初始化目标数据 initializeInputTarget设置源数据 setInputSour…...

Python:操作 Excel 折叠

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)

概述 在 Swift 开发语言中,各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过,在涉及到多个子类派生于基类进行多态模拟的场景下,…...

基于Docker Compose部署Java微服务项目

一. 创建根项目 根项目&#xff08;父项目&#xff09;主要用于依赖管理 一些需要注意的点&#xff1a; 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件&#xff0c;否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...

自然语言处理——Transformer

自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效&#xff0c;它能挖掘数据中的时序信息以及语义信息&#xff0c;但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN&#xff0c;但是…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现&#xff08;两者等价&#xff09;&#xff0c;用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例&#xff1a; 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

Vite中定义@软链接

在webpack中可以直接通过符号表示src路径&#xff0c;但是vite中默认不可以。 如何实现&#xff1a; vite中提供了resolve.alias&#xff1a;通过别名在指向一个具体的路径 在vite.config.js中 import { join } from pathexport default defineConfig({plugins: [vue()],//…...

实战三:开发网页端界面完成黑白视频转为彩色视频

​一、需求描述 设计一个简单的视频上色应用&#xff0c;用户可以通过网页界面上传黑白视频&#xff0c;系统会自动将其转换为彩色视频。整个过程对用户来说非常简单直观&#xff0c;不需要了解技术细节。 效果图 ​二、实现思路 总体思路&#xff1a; 用户通过Gradio界面上…...

redis和redission的区别

Redis 和 Redisson 是两个密切相关但又本质不同的技术&#xff0c;它们扮演着完全不同的角色&#xff1a; Redis: 内存数据库/数据结构存储 本质&#xff1a; 它是一个开源的、高性能的、基于内存的 键值存储数据库。它也可以将数据持久化到磁盘。 核心功能&#xff1a; 提供丰…...

Java数组Arrays操作全攻略

Arrays类的概述 Java中的Arrays类位于java.util包中&#xff0c;提供了一系列静态方法用于操作数组&#xff08;如排序、搜索、填充、比较等&#xff09;。这些方法适用于基本类型数组和对象数组。 常用成员方法及代码示例 排序&#xff08;sort&#xff09; 对数组进行升序…...

【大模型】RankRAG:基于大模型的上下文排序与检索增强生成的统一框架

文章目录 A 论文出处B 背景B.1 背景介绍B.2 问题提出B.3 创新点 C 模型结构C.1 指令微调阶段C.2 排名与生成的总和指令微调阶段C.3 RankRAG推理&#xff1a;检索-重排-生成 D 实验设计E 个人总结 A 论文出处 论文题目&#xff1a;RankRAG&#xff1a;Unifying Context Ranking…...