基于深度可分离卷积的MNIST手势识别
基于深度可分离膨胀卷积的MNIST手写体识别
Github链接
项目背景:
MNIST手写体数据集是深度学习领域中一个经典的入门数据集,包含了从0到9的手写数字图像,用于评估不同模型在图像分类任务上的性能。在本项目中,我们通过设计一种基于深度可分离膨胀卷积的神经网络模型,解决模型参数量大与特征提取能力不足之间的矛盾,同时实现对MNIST手写数字的高效识别。
核心技术:
- 深度卷积:
深度卷积(Depthwise Convolution)将标准卷积分解为每个通道独立的卷积操作,从而显著减少计算量。相比传统卷积操作,深度卷积只需要处理每个通道的卷积,不会引入通道间的冗余计算。 - 点卷积:
点卷积(Pointwise Convolution)采用1×1的卷积核,用于整合深度卷积生成的特征,将通道间信息融合,增强表达能力。点卷积是深度可分离卷积不可或缺的部分,负责重建多通道的全局信息。 - 膨胀卷积:
膨胀卷积(Dilated Convolution)通过在卷积核间插入空洞扩展感受野,允许模型捕获更大范围的上下文信息,特别适合处理具有稀疏特征的任务,同时避免了增加参数量的开销。

实现流程:
- 数据准备:
- 加载MNIST数据集,进行标准化预处理。
- 将数据分为训练集和测试集,保证模型的评估结果具备可靠性。
- 模型设计:
- 构建基于深度卷积、点卷积及膨胀卷积的神经网络结构,重点在于设计轻量化且具有良好表达能力的卷积模块。
- 使用ReLU激活函数和全连接层对提取的特征进行分类处理。
- 模型训练与测试:
- 使用交叉熵损失函数和Adam优化器训练模型,记录损失值变化以监控收敛情况。
- 测试阶段,评估模型在MNIST测试集上的分类准确率,并验证其泛化能力。
- 参数量对比分析:
- 对比标准卷积和深度可分离卷积在参数量上的差异,直观展示优化效果。
- 在同等条件下,深度可分离卷积显著减少参数量,同时保持了分类性能的稳定性。
项目成果:
通过本项目的实验,模型在MNIST数据集上的分类准确率达到了接近 98.7% 的水平。结合深度可分离膨胀卷积的轻量化设计,我们在大幅减少参数量的同时,实现了与传统卷积模型媲美的性能。此方法为资源受限的场景(如移动设备和嵌入式系统)提供了一种有效的解决方案。
结论与展望:
本项目展示了深度可分离膨胀卷积在图像分类任务中的强大能力,特别是在参数量和计算量优化方面。未来的工作可以尝试将该方法应用于更复杂的数据集和任务场景,例如自然图像分类、目标检测或语义分割,从而进一步验证其通用性与潜力。
# @Time : 28/12/2024 上午 10:00
# @Author : Xuan
# @File : 基于深度可分离膨胀卷积的MNIST手写体识别.py
# @Software: PyCharmimport torch
import torch.nn as nn
import einops.layers.torch as elt
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 下载并加载训练集
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)# 下载并加载测试集
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)# 检查数据集大小
print(f'Train dataset size: {len(train_dataset)}') # Train dataset size: 60000
print(f'Test dataset size: {len(test_dataset)}') # Test dataset size: 10000class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 深度、可分离、膨胀卷积self.conv = nn.Sequential(nn.Conv2d(1, 12, kernel_size=7),nn.ReLU(),nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, groups=6, dilation=2),nn.Conv2d(in_channels=12, out_channels=24, kernel_size=1),nn.ReLU(),nn.Conv2d(24, 6, kernel_size=7),)self.logits_layer = nn.Linear(in_features=6 * 12 * 12, out_features=10)def forward(self, x):x = self.conv(x)x = elt.Rearrange('b c h w -> b (c h w)')(x)logits = self.logits_layer(x)return logitsdevice = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)epochs = 10
for epoch in range(epochs):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item()}')# Save the model
torch.save(model.state_dict(), './modelsave/mnist.pth')# Load the model
model = Model().to(device)
model.load_state_dict(torch.load('./modelsave/mnist.pth'))# Evaluation
model.eval()
correct = 0
first_image_displayed = False
with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()# Display the first image and its predictionif not first_image_displayed:plt.imshow(data[0].cpu().numpy().squeeze(), cmap='gray')plt.title(f'Predicted: {pred[0].item()}')plt.show()first_image_displayed = Trueprint(f'Test set: Accuracy: {correct / len(test_loader.dataset):.4f}') # Test set: Accuracy: 0.9874# 深度可分离卷积参数比较
# 普通卷积参数量
conv = nn.Conv2d(in_channels=3 ,out_channels=3, kernel_size=3) # in(3) * out(3) * k(3) * k(3) + out(3) = 84
conv_params = sum(p.numel() for p in conv.parameters())
print('conv_params:', conv_params) # conv_params: 84# 深度可分离卷积参数量
depthwise = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, groups=3) # in(3) * k(3) * k(3) + out(3) = 30
pointwise = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1) # in(3) * out(3) * k(1) * k(1) + out(3) = 12
depthwise_separable = nn.Sequential(depthwise, pointwise)
depthwise_separable_params = sum(p.numel() for p in depthwise_separable.parameters())
print('depthwise_separable_params:', depthwise_separable_params) # depthwise_separable_params: 42

相关文章:
基于深度可分离卷积的MNIST手势识别
基于深度可分离膨胀卷积的MNIST手写体识别 Github链接 项目背景: MNIST手写体数据集是深度学习领域中一个经典的入门数据集,包含了从0到9的手写数字图像,用于评估不同模型在图像分类任务上的性能。在本项目中,我们通过设计一种基…...
Linux服务器pm2 运行chatgpt-on-wechat,搭建微信群ai机器人
安装 1.拉取项目 项目地址: chatgpt-on-wechat 2.安装依赖 pip3 install -r requirements.txt pip3 install -r requirements-optional.txt3、获取API信息 当前免费的有百度的文心一言,讯飞的个人认证提供500万token的额度。 控制台-讯飞开放平台 添加链接描述…...
Word批量更改题注
文章目录 批量更改批量去除空格 在写文章的时候,往往对图片题注有着统一的编码要求,例如以【图 1- xx】。一般会点击【引用】->【插入题注】来插入题注,并且在引用的时候,点击【引用】->【交叉引用】,并且在交叉…...
Springboot配置嵌入式服务器
一.如何定制和修改Servlet容器的相关配置 修改和server有关的配置(ServerProperties); server.port8081 server.context‐path/tx server.tomcat.uri‐encodingUTF‐8 在yml配置文件的写法: server:port: 8081servlet:context-pa…...
正交三角函数全面阐述
目录 1. 正交性定义 2. 正交三角函数 常见的正交三角函数 3. 正交三角函数的特性 4. 正交三角函数在傅里叶分析中的应用 5. 正交三角函数的应用领域 6. 总结 正交三角函数是指在特定条件下,三角函数之间的内积为零。更具体地说,在数学分析、信号处…...
《Vue3 四》Vue 的组件化
组件化:将一个页面拆分成一个个小的功能模块,每个功能模块完成自己部分的独立的功能。任何应用都可以被抽象成一棵组件树。 Vue 中的根组件: Vue.createApp() 中传入对象的本质上就是一个组件,称之为根组件(APP 组件…...
linux中,mysql数据库分片(分库分表)
1.mysql分库分表:解决单个mysql存储上限问题1.实现方法:存储层面:利用分布式存储解决方案分库分表:拆分库和表到其它服务器2.常用设计思路:垂直分库(库里面的表分开)水平分表(表里面的数据分开)分库:数据库分为多个,每个数据库里面都有表,数据均匀存储分库分表:在分的每库里面,…...
springboot503基于Sringboot+Vue个人驾校预约管理系统(论文+源码)_kaic
摘 要 传统办法管理信息首先需要花费的时间比较多,其次数据出错率比较高,而且对错误的数据进行更改也比较困难,最后,检索数据费事费力。因此,在计算机上安装个人驾校预约管理系统软件来发挥其高效地信息处理的作用&am…...
Docker应用-项目部署及DockerCompose
文章目录 Docker应用-项目部署1. 项目部署-后端1.1 修改配置1.2 项目打包1.3 编写Dockerfile1.4 创建镜像1.5 创建并运行容器1.6 测试 2. 项目部署-前端2.1 html前端静态目录2.2 nginx.config编写2.3 部署宿主机服务器2.4 创建容器并挂载2.5 测试 3. DockerCompose3.1 基本语法…...
从0入门自主空中机器人-2-1【无人机硬件框架】
关于本课程: 本次课程是一套面向对自主空中机器人感兴趣的学生、爱好者、相关从业人员的免费课程,包含了从硬件组装、机载电脑环境设置、代码部署、实机实验等全套详细流程,带你从0开始,组装属于自己的自主无人机,并让…...
Kafka高性能设计
高性能设计概述 Kafka高性能是多方面协同的结果,包括集群架构、分布式存储、ISR数据同步及高效利用磁盘和操作系统特性等。主要体现在消息分区、顺序读写、页缓存、零拷贝、消息压缩和分批发送六个方面。 消息分区 存储不受单台服务器限制,能处理更多数据…...
Redis字符串底层结构对数值型的支持常用数据结构和使用场景
字符串底层结构 SDS (Simple Dynamic Strings) 是 Redis 中用于实现字符串类型的一种数据结构。SDS 的设计目标是提供高效、灵活的字符串操作,同时避免传统 C 字符串的一些缺点。 struct sdshdr {int len; // 已使用的长度int free; // 未使用的长度char bu…...
uniapp 微信小程序 数据空白展示组件
效果图 html <template><view class"nodata"><view class""><image class"nodataimg":src"$publicfun.locaAndHttp()?localUrl:$publicfun.httpUrlImg(httUrl)"mode"aspectFit"></image>&l…...
在vscode的ESP-IDF中使用自定义组件
以hello-world为例,演示步骤和注意事项 1、新建ESP-IDF项目 选择模板 从hello-world模板创建 2、打开项目 3、编译结果没错 正在执行任务: /home/azhu/.espressif/python_env/idf5.1_py3.10_env/bin/python /home/azhu/esp/v5.1/esp-idf/tools/idf_size.py /home…...
目标检测,语义分割标注工具--labelimg labelme
1 labelimg labelimg可以用来标注目标检测的数据集, 提供多种格式的输出, 如Pascal Voc, YOLO等。 1.1 安装 pip install labelimg1.2 使用 命令行直接输入labelimg即可打开软件主界面进行操作。 使用非常简单, 不做过细的介绍࿰…...
发明专利与实用新型专利申请过程及自助与代办方式对比
申请专利(发明专利、实用新型专利、外观设计专利)有两种方式:1、自己直接向国家知识产权局申请。2、通过专利代办处申请。以下是对这两种专利类型(发明专利、实用新型专利)申请过程及两种申请方式的详细介绍和对比,参考…...
Datawhale AI冬令营(第二期)动手学AI Agent task2--学Prompt工程,优化Agent效果
目录 如何写好Prompt? 工具包神器1:Prompt框架——CO-STAR 框架 工具包神器2:Prompt结构优化 工具包神器3:引入案例 案例:构建虚拟女友小冰 1. 按照 CO-STAR框架 梳理目标 2. 撰写Prompt 3. 制作对话生成应用&…...
基于python对网页进行爬虫简单教程
python对网页进行爬虫 基于BeautifulSoup的爬虫—源码 """ 基于BeautifulSoup的爬虫###?一、BeautifulSoup简介1.?Beautiful?Soup提供一些简单的、python式的函数用来处理导航、搜索、修改分析树等功能。它是一个工具箱,通过解析文档为用户提供…...
【JavaEE进阶】@RequestMapping注解
目录 📕前言 🌴项目准备 🌲建立连接 🚩RequestMapping注解 🚩RequestMapping 注解介绍 🎄RequestMapping是GET还是POST请求? 🚩通过Fiddler查看 🚩Postman查看 …...
【WebAR-图像跟踪】在Unity中基于Imagine WebAR实现AR图像识别
写在前面的话 感慨一下, WebXR的发展是真的快,20年的时候,大多都在用AR.js做WebAR。随着WebXR标准发展,现在诸如Threejs、AFrame、Unity等多个平台都支持里WebXR。 本文将介绍在Unity中使用 Image Tracker实现Web端的AR图像识别功…...
业务系统对接大模型的基础方案:架构设计与关键步骤
业务系统对接大模型:架构设计与关键步骤 在当今数字化转型的浪潮中,大语言模型(LLM)已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中,不仅可以优化用户体验,还能为业务决策提供…...
React 第五十五节 Router 中 useAsyncError的使用详解
前言 useAsyncError 是 React Router v6.4 引入的一个钩子,用于处理异步操作(如数据加载)中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误:捕获在 loader 或 action 中发生的异步错误替…...
树莓派超全系列教程文档--(62)使用rpicam-app通过网络流式传输视频
使用rpicam-app通过网络流式传输视频 使用 rpicam-app 通过网络流式传输视频UDPTCPRTSPlibavGStreamerRTPlibcamerasrc GStreamer 元素 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 使用 rpicam-app 通过网络流式传输视频 本节介绍来自 rpica…...
【JVM】- 内存结构
引言 JVM:Java Virtual Machine 定义:Java虚拟机,Java二进制字节码的运行环境好处: 一次编写,到处运行自动内存管理,垃圾回收的功能数组下标越界检查(会抛异常,不会覆盖到其他代码…...
2.Vue编写一个app
1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...
《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》
在注意力分散、内容高度同质化的时代,情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现,消费者对内容的“有感”程度,正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中࿰…...
高等数学(下)题型笔记(八)空间解析几何与向量代数
目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...
多模态大语言模型arxiv论文略读(108)
CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文标题:CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文作者:Sayna Ebrahimi, Sercan O. Arik, Tejas Nama, Tomas Pfister ➡️ 研究机构: Google Cloud AI Re…...
如何在最短时间内提升打ctf(web)的水平?
刚刚刷完2遍 bugku 的 web 题,前来答题。 每个人对刷题理解是不同,有的人是看了writeup就等于刷了,有的人是收藏了writeup就等于刷了,有的人是跟着writeup做了一遍就等于刷了,还有的人是独立思考做了一遍就等于刷了。…...
如何理解 IP 数据报中的 TTL?
目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...
