基于深度可分离卷积的MNIST手势识别
基于深度可分离膨胀卷积的MNIST手写体识别
Github链接
项目背景:
MNIST手写体数据集是深度学习领域中一个经典的入门数据集,包含了从0到9的手写数字图像,用于评估不同模型在图像分类任务上的性能。在本项目中,我们通过设计一种基于深度可分离膨胀卷积的神经网络模型,解决模型参数量大与特征提取能力不足之间的矛盾,同时实现对MNIST手写数字的高效识别。
核心技术:
- 深度卷积:
深度卷积(Depthwise Convolution)将标准卷积分解为每个通道独立的卷积操作,从而显著减少计算量。相比传统卷积操作,深度卷积只需要处理每个通道的卷积,不会引入通道间的冗余计算。 - 点卷积:
点卷积(Pointwise Convolution)采用1×1的卷积核,用于整合深度卷积生成的特征,将通道间信息融合,增强表达能力。点卷积是深度可分离卷积不可或缺的部分,负责重建多通道的全局信息。 - 膨胀卷积:
膨胀卷积(Dilated Convolution)通过在卷积核间插入空洞扩展感受野,允许模型捕获更大范围的上下文信息,特别适合处理具有稀疏特征的任务,同时避免了增加参数量的开销。
实现流程:
- 数据准备:
- 加载MNIST数据集,进行标准化预处理。
- 将数据分为训练集和测试集,保证模型的评估结果具备可靠性。
- 模型设计:
- 构建基于深度卷积、点卷积及膨胀卷积的神经网络结构,重点在于设计轻量化且具有良好表达能力的卷积模块。
- 使用ReLU激活函数和全连接层对提取的特征进行分类处理。
- 模型训练与测试:
- 使用交叉熵损失函数和Adam优化器训练模型,记录损失值变化以监控收敛情况。
- 测试阶段,评估模型在MNIST测试集上的分类准确率,并验证其泛化能力。
- 参数量对比分析:
- 对比标准卷积和深度可分离卷积在参数量上的差异,直观展示优化效果。
- 在同等条件下,深度可分离卷积显著减少参数量,同时保持了分类性能的稳定性。
项目成果:
通过本项目的实验,模型在MNIST数据集上的分类准确率达到了接近 98.7% 的水平。结合深度可分离膨胀卷积的轻量化设计,我们在大幅减少参数量的同时,实现了与传统卷积模型媲美的性能。此方法为资源受限的场景(如移动设备和嵌入式系统)提供了一种有效的解决方案。
结论与展望:
本项目展示了深度可分离膨胀卷积在图像分类任务中的强大能力,特别是在参数量和计算量优化方面。未来的工作可以尝试将该方法应用于更复杂的数据集和任务场景,例如自然图像分类、目标检测或语义分割,从而进一步验证其通用性与潜力。
# @Time : 28/12/2024 上午 10:00
# @Author : Xuan
# @File : 基于深度可分离膨胀卷积的MNIST手写体识别.py
# @Software: PyCharmimport torch
import torch.nn as nn
import einops.layers.torch as elt
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 下载并加载训练集
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 下载并加载测试集
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)# 检查数据集大小
print(f'Train dataset size: {len(train_dataset)}') # Train dataset size: 60000
print(f'Test dataset size: {len(test_dataset)}') # Test dataset size: 10000class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 深度、可分离、膨胀卷积self.conv = nn.Sequential(nn.Conv2d(1, 12, kernel_size=7),nn.ReLU(),nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, groups=6, dilation=2),nn.Conv2d(in_channels=12, out_channels=24, kernel_size=1),nn.ReLU(),nn.Conv2d(24, 6, kernel_size=7),)self.logits_layer = nn.Linear(in_features=6 * 12 * 12, out_features=10)def forward(self, x):x = self.conv(x)x = elt.Rearrange('b c h w -> b (c h w)')(x)logits = self.logits_layer(x)return logitsdevice = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)epochs = 10
for epoch in range(epochs):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')# Save the model
torch.save(model.state_dict(), './modelsave/mnist.pth')# Load the model
model = Model().to(device)
model.load_state_dict(torch.load('./modelsave/mnist.pth'))# Evaluation
model.eval()
correct = 0
first_image_displayed = False
with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()# Display the first image and its predictionif not first_image_displayed:plt.imshow(data[0].cpu().numpy().squeeze(), cmap='gray')plt.title(f'Predicted: {pred[0].item()}')plt.show()first_image_displayed = Trueprint(f'Test set: Accuracy: {correct / len(test_loader.dataset):.4f}') # Test set: Accuracy: 0.9874# 深度可分离卷积参数比较
# 普通卷积参数量
conv = nn.Conv2d(in_channels=3 ,out_channels=3, kernel_size=3) # in(3) * out(3) * k(3) * k(3) + out(3) = 84
conv_params = sum(p.numel() for p in conv.parameters())
print('conv_params:', conv_params) # conv_params: 84# 深度可分离卷积参数量
depthwise = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, groups=3) # in(3) * k(3) * k(3) + out(3) = 30
pointwise = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1) # in(3) * out(3) * k(1) * k(1) + out(3) = 12
depthwise_separable = nn.Sequential(depthwise, pointwise)
depthwise_separable_params = sum(p.numel() for p in depthwise_separable.parameters())
print('depthwise_separable_params:', depthwise_separable_params) # depthwise_separable_params: 42
相关文章:

基于深度可分离卷积的MNIST手势识别
基于深度可分离膨胀卷积的MNIST手写体识别 Github链接 项目背景: MNIST手写体数据集是深度学习领域中一个经典的入门数据集,包含了从0到9的手写数字图像,用于评估不同模型在图像分类任务上的性能。在本项目中,我们通过设计一种基…...

Linux服务器pm2 运行chatgpt-on-wechat,搭建微信群ai机器人
安装 1.拉取项目 项目地址: chatgpt-on-wechat 2.安装依赖 pip3 install -r requirements.txt pip3 install -r requirements-optional.txt3、获取API信息 当前免费的有百度的文心一言,讯飞的个人认证提供500万token的额度。 控制台-讯飞开放平台 添加链接描述…...
Word批量更改题注
文章目录 批量更改批量去除空格 在写文章的时候,往往对图片题注有着统一的编码要求,例如以【图 1- xx】。一般会点击【引用】->【插入题注】来插入题注,并且在引用的时候,点击【引用】->【交叉引用】,并且在交叉…...
Springboot配置嵌入式服务器
一.如何定制和修改Servlet容器的相关配置 修改和server有关的配置(ServerProperties); server.port8081 server.context‐path/tx server.tomcat.uri‐encodingUTF‐8 在yml配置文件的写法: server:port: 8081servlet:context-pa…...
正交三角函数全面阐述
目录 1. 正交性定义 2. 正交三角函数 常见的正交三角函数 3. 正交三角函数的特性 4. 正交三角函数在傅里叶分析中的应用 5. 正交三角函数的应用领域 6. 总结 正交三角函数是指在特定条件下,三角函数之间的内积为零。更具体地说,在数学分析、信号处…...
《Vue3 四》Vue 的组件化
组件化:将一个页面拆分成一个个小的功能模块,每个功能模块完成自己部分的独立的功能。任何应用都可以被抽象成一棵组件树。 Vue 中的根组件: Vue.createApp() 中传入对象的本质上就是一个组件,称之为根组件(APP 组件…...
linux中,mysql数据库分片(分库分表)
1.mysql分库分表:解决单个mysql存储上限问题1.实现方法:存储层面:利用分布式存储解决方案分库分表:拆分库和表到其它服务器2.常用设计思路:垂直分库(库里面的表分开)水平分表(表里面的数据分开)分库:数据库分为多个,每个数据库里面都有表,数据均匀存储分库分表:在分的每库里面,…...

springboot503基于Sringboot+Vue个人驾校预约管理系统(论文+源码)_kaic
摘 要 传统办法管理信息首先需要花费的时间比较多,其次数据出错率比较高,而且对错误的数据进行更改也比较困难,最后,检索数据费事费力。因此,在计算机上安装个人驾校预约管理系统软件来发挥其高效地信息处理的作用&am…...

Docker应用-项目部署及DockerCompose
文章目录 Docker应用-项目部署1. 项目部署-后端1.1 修改配置1.2 项目打包1.3 编写Dockerfile1.4 创建镜像1.5 创建并运行容器1.6 测试 2. 项目部署-前端2.1 html前端静态目录2.2 nginx.config编写2.3 部署宿主机服务器2.4 创建容器并挂载2.5 测试 3. DockerCompose3.1 基本语法…...

从0入门自主空中机器人-2-1【无人机硬件框架】
关于本课程: 本次课程是一套面向对自主空中机器人感兴趣的学生、爱好者、相关从业人员的免费课程,包含了从硬件组装、机载电脑环境设置、代码部署、实机实验等全套详细流程,带你从0开始,组装属于自己的自主无人机,并让…...

Kafka高性能设计
高性能设计概述 Kafka高性能是多方面协同的结果,包括集群架构、分布式存储、ISR数据同步及高效利用磁盘和操作系统特性等。主要体现在消息分区、顺序读写、页缓存、零拷贝、消息压缩和分批发送六个方面。 消息分区 存储不受单台服务器限制,能处理更多数据…...
Redis字符串底层结构对数值型的支持常用数据结构和使用场景
字符串底层结构 SDS (Simple Dynamic Strings) 是 Redis 中用于实现字符串类型的一种数据结构。SDS 的设计目标是提供高效、灵活的字符串操作,同时避免传统 C 字符串的一些缺点。 struct sdshdr {int len; // 已使用的长度int free; // 未使用的长度char bu…...

uniapp 微信小程序 数据空白展示组件
效果图 html <template><view class"nodata"><view class""><image class"nodataimg":src"$publicfun.locaAndHttp()?localUrl:$publicfun.httpUrlImg(httUrl)"mode"aspectFit"></image>&l…...

在vscode的ESP-IDF中使用自定义组件
以hello-world为例,演示步骤和注意事项 1、新建ESP-IDF项目 选择模板 从hello-world模板创建 2、打开项目 3、编译结果没错 正在执行任务: /home/azhu/.espressif/python_env/idf5.1_py3.10_env/bin/python /home/azhu/esp/v5.1/esp-idf/tools/idf_size.py /home…...

目标检测,语义分割标注工具--labelimg labelme
1 labelimg labelimg可以用来标注目标检测的数据集, 提供多种格式的输出, 如Pascal Voc, YOLO等。 1.1 安装 pip install labelimg1.2 使用 命令行直接输入labelimg即可打开软件主界面进行操作。 使用非常简单, 不做过细的介绍࿰…...

发明专利与实用新型专利申请过程及自助与代办方式对比
申请专利(发明专利、实用新型专利、外观设计专利)有两种方式:1、自己直接向国家知识产权局申请。2、通过专利代办处申请。以下是对这两种专利类型(发明专利、实用新型专利)申请过程及两种申请方式的详细介绍和对比,参考…...

Datawhale AI冬令营(第二期)动手学AI Agent task2--学Prompt工程,优化Agent效果
目录 如何写好Prompt? 工具包神器1:Prompt框架——CO-STAR 框架 工具包神器2:Prompt结构优化 工具包神器3:引入案例 案例:构建虚拟女友小冰 1. 按照 CO-STAR框架 梳理目标 2. 撰写Prompt 3. 制作对话生成应用&…...

基于python对网页进行爬虫简单教程
python对网页进行爬虫 基于BeautifulSoup的爬虫—源码 """ 基于BeautifulSoup的爬虫###?一、BeautifulSoup简介1.?Beautiful?Soup提供一些简单的、python式的函数用来处理导航、搜索、修改分析树等功能。它是一个工具箱,通过解析文档为用户提供…...

【JavaEE进阶】@RequestMapping注解
目录 📕前言 🌴项目准备 🌲建立连接 🚩RequestMapping注解 🚩RequestMapping 注解介绍 🎄RequestMapping是GET还是POST请求? 🚩通过Fiddler查看 🚩Postman查看 …...

【WebAR-图像跟踪】在Unity中基于Imagine WebAR实现AR图像识别
写在前面的话 感慨一下, WebXR的发展是真的快,20年的时候,大多都在用AR.js做WebAR。随着WebXR标准发展,现在诸如Threejs、AFrame、Unity等多个平台都支持里WebXR。 本文将介绍在Unity中使用 Image Tracker实现Web端的AR图像识别功…...
Vim 调用外部命令学习笔记
Vim 外部命令集成完全指南 文章目录 Vim 外部命令集成完全指南核心概念理解命令语法解析语法对比 常用外部命令详解文本排序与去重文本筛选与搜索高级 grep 搜索技巧文本替换与编辑字符处理高级文本处理编程语言处理其他实用命令 范围操作示例指定行范围处理复合命令示例 实用技…...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版分享
平时用 iPhone 的时候,难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵,或者买了二手 iPhone 却被原来的 iCloud 账号锁住,这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...
pam_env.so模块配置解析
在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...
【论文笔记】若干矿井粉尘检测算法概述
总的来说,传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度,通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...
Qt Http Server模块功能及架构
Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...
反射获取方法和属性
Java反射获取方法 在Java中,反射(Reflection)是一种强大的机制,允许程序在运行时访问和操作类的内部属性和方法。通过反射,可以动态地创建对象、调用方法、改变属性值,这在很多Java框架中如Spring和Hiberna…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战
在现代战争中,电磁频谱已成为继陆、海、空、天之后的 “第五维战场”,雷达作为电磁频谱领域的关键装备,其干扰与抗干扰能力的较量,直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器,凭借数字射…...

如何在最短时间内提升打ctf(web)的水平?
刚刚刷完2遍 bugku 的 web 题,前来答题。 每个人对刷题理解是不同,有的人是看了writeup就等于刷了,有的人是收藏了writeup就等于刷了,有的人是跟着writeup做了一遍就等于刷了,还有的人是独立思考做了一遍就等于刷了。…...

【Oracle】分区表
个人主页:Guiat 归属专栏:Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习) 一、Aspose.PDF 简介二、说明(⚠️仅供学习与研究使用)三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...