当前位置: 首页 > article >正文

2. 手写数字预测 gui版

2. 手写数字预测 gui版

  • 背景
  • 1.界面绘制
  • 2.处理图片
  • 3. 加载模型
  • 4. 预测
  • 5.结果
  • 6.一点小问题

在这里插入图片描述

背景

做了手写数字预测的模型,但是老是跑模型太无聊了,就配合pyqt做了一个可视化界面出来玩一下

源代码可以去这里https://github.com/Leezed525/pytorch_toy拿

1.界面绘制

在这里插入图片描述

整个页面布局逻辑很简单,搭建一下就好了

class MainWindow(QMainWindow):def __init__(self):super().__init__()self.net = self.get_net()  # 获取数字预测模型self.setWindowTitle("PyQt 数字预测")self.setGeometry(100, 100, 500, 550)  # 设置主窗口的初始位置和大小,留出空间给按钮self.setFixedSize(500, 550)self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.WindowMaximizeButtonHint)central_widget = QWidget()  # 创建一个中央 QWidgetself.setCentralWidget(central_widget)  # 设置中央 QWidget 为主窗口的中心部件layout = QVBoxLayout(central_widget)  # 为中央 QWidget 创建一个垂直布局# 创建一个水平布局operation_layer = QHBoxLayout()  # 创建一个水平布局用于放置操作区域left_operation_layer = QVBoxLayout()right_operation_layer = QVBoxLayout()self.canvas = DrawingCanvas(self)  # 创建 DrawingCanvas 实例canvas_label = QLabel("请在此处绘制数字")  # 创建一个标签,提示用户在画布上绘制数字canvas_label.setAlignment(Qt.AlignmentFlag.AlignCenter)canvas_label.setStyleSheet("font-size: 20px;")  # 设置标签的样式left_operation_layer.addWidget(canvas_label)  # 将标签添加到左侧操作区域布局中left_operation_layer.addWidget(self.canvas)left_operation_layer.setStretch(0, 1)left_operation_layer.setStretch(1, 10)  # 设置画布的伸缩比例,使其占据更多空间operation_layer.addLayout(left_operation_layer)  # 将左侧操作区域布局添加到操作层布局中# 右侧操作区域self.predict_label = QLabel("预测结果: ")  # 创建一个标签,显示预测结果right_operation_layer.addWidget(self.predict_label)self.predict_digit_labels = []for i in range(10):predict_digit_label = QLabel(f"数字 {i}: 0.00%")  # 创建标签显示每个数字的预测概率self.predict_digit_labels.append(predict_digit_label)  # 将标签添加到列表中for label in self.predict_digit_labels:right_operation_layer.addWidget(label)operation_layer.addLayout(right_operation_layer)  # 将右侧操作区域布局添加到操作层布局中operation_layer.setStretch(0, 10)operation_layer.setStretch(1, 1)layout.addLayout(operation_layer)  # 将操作层布局添加到主布局中# 按钮区布局button_layout = QHBoxLayout()  # 创建一个垂直布局用于放置按钮clear_button = QPushButton("清空画布")  # 清空画布按钮clear_button.clicked.connect(self.canvas.clear_canvas)  # 连接按钮的点击信号到清空画布方法predict_button = QPushButton("预测")  # 清空画布按钮predict_button.clicked.connect(self.predict)  # 连接按钮的点击信号到预测方法button_layout.addStretch(6)button_layout.addWidget(clear_button)button_layout.addWidget(predict_button)layout.addLayout(button_layout)  # 将按钮布局添加到主布局中

其中稍微有点心智压力的区域就是画图区域,这里配合ai然后再自行修改一下就好了,逻辑就是鼠标按住然后绘制,松开后停止绘制。

canvas代码

class DrawingCanvas(QWidget):"""一个自定义的 QWidget 类,用作绘图画布。用户可以在此画布上用鼠标点击并拖动来绘制线条。"""def __init__(self, parent=None):super().__init__(parent)  # 调用父类 QWidget 的构造函数self.setWindowTitle("绘图画布")  # 设置窗口标题self.setGeometry(100, 100, 280, 280)  # 设置窗口的初始位置和大小 (x, y, width, height)self.setMinimumSize(280, 280)# 创建一个 QImage 对象作为绘图缓冲区# 所有的绘图操作都在这个 QImage 上进行,然后整体绘制到屏幕,可以避免闪烁。# QImage.Format.Format_RGB32 是 PyQt6 中推荐的 RGBA 格式,支持透明度。self.image = QImage(self.size(), QImage.Format.Format_RGB32)# 将 QImage 填充为白色。self.image.fill(Qt.GlobalColor.white)self.drawing = False  # 一个布尔标志,指示当前是否正在进行鼠标拖拽绘图self.last_point = QPoint()  # 存储鼠标上次的位置,用于绘制连续的线条# 同样,颜色常量需要通过 Qt.GlobalColor 访问。self.pen_color = Qt.GlobalColor.blackself.pen_size = 20def paintEvent(self, event):"""绘制事件处理函数。每当窗口需要被重新绘制时(例如,首次显示、窗口大小改变、调用 update() 时),Qt 就会自动调用这个方法。"""painter = QPainter(self)  # 创建一个 QPainter 对象,指定在当前 QWidget (self) 上进行绘制# 将 self.image (绘图缓冲区) 的内容绘制到当前 QWidget 的整个矩形区域内。painter.drawImage(self.rect(), self.image, self.image.rect())def mousePressEvent(self, event):# 检查是否是鼠标左键被按下。if event.button() == Qt.MouseButton.LeftButton:self.drawing = True  # 设置绘图标志为 Trueself.last_point = event.pos()  # 记录当前鼠标位置作为线条的起始点def mouseMoveEvent(self, event):"""鼠标移动事件处理函数。当鼠标在窗口内移动时触发。"""# 只有当正在绘图 (self.drawing 为 True) 并且鼠标左键被按住时才执行绘图操作。# event.buttons() 返回当前按下的所有鼠标按钮的位掩码,Qt.MouseButton.LeftButton 用于检查左键是否按下。if self.drawing and event.buttons() & Qt.MouseButton.LeftButton:painter = QPainter(self.image)  # 在 QImage (绘图缓冲区) 上创建 QPainter 进行绘制# 设置画笔的颜色、粗细和样式。painter.setPen(QPen(QColor(self.pen_color), self.pen_size,Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap, Qt.PenJoinStyle.RoundJoin))# 绘制从上次记录的点到当前鼠标位置的直线painter.drawLine(self.last_point, event.pos())self.last_point = event.pos()  # 更新 last_point 为当前鼠标位置,为下一次绘制做准备self.update()  # 请求窗口重绘。这会间接调用 paintEvent,将 QImage 的最新内容显示到屏幕上。def mouseReleaseEvent(self, event):"""鼠标释放事件处理函数。当用户释放鼠标按钮时触发。"""# 检查是否是鼠标左键被释放。if event.button() == Qt.MouseButton.LeftButton:self.drawing = False  # 停止绘图def resizeEvent(self, event):"""窗口大小改变事件处理函数。当窗口大小改变时触发。"""# 如果新窗口的宽度或高度大于当前 QImage 的尺寸,则需要创建一个新的 QImage。if self.width() > self.image.width() or self.height() > self.image.height():new_image = QImage(self.size(), QImage.Format.Format_RGB32)# 填充新图像为白色new_image.fill(Qt.GlobalColor.white)painter = QPainter(new_image)# 将旧图像的内容绘制到新图像上,以保留已有的绘图。painter.drawImage(QPoint(0, 0), self.image)self.image = new_image  # 更新 self.image 为新的 QImageself.update()  # 请求重绘窗口def clear_canvas(self):"""清空画布内容,将整个 QImage 重新填充为白色。"""self.image.fill(Qt.GlobalColor.white)self.update()  # 请求重绘以显示空白画布def set_pen_size(self, size):"""设置画笔粗细。"""self.pen_size = size

2.处理图片

当布局完成后就只需要处理将图片变成输入的过程就好了,先给代码,在讲解

    def get_image(self):"""获取当前画布上的图像数据。返回一个 QImage 对象,包含当前画布的绘图内容。"""image = self.canvas.image# 将图像缩放到 28x28 像素并转换为灰度图scaled_image = image.scaled(28, 28,Qt.AspectRatioMode.IgnoreAspectRatio,  # 不保持宽高比Qt.TransformationMode.SmoothTransformation  # 平滑缩放)# 转换为 8 位灰度图grayscale_image = scaled_image.convertToFormat(QImage.Format.Format_Grayscale8)# 使用 qimage2ndarray.byte_view() 获取 NumPy 数组arr_3d = qimage2ndarray.byte_view(grayscale_image)arr = arr_3d.squeeze()# 将 NumPy 数组转换为 PyTorch 张量tensor_image = torch.from_numpy(arr).float()# --- 关键修正:添加颜色反转和标准化 ---# 1. 将像素值从 [0, 255] 归一化到 [0.0, 1.0]tensor_image = tensor_image / 255.0# 2. 颜色反转:如果你的模型是基于白色数字黑色背景训练的 而画布是黑色数字白色背景,则需要反转颜色tensor_image = 1.0 - tensor_image# 3. 标准化:应用训练时使用的均值和标准差# MNIST 均值和标准差mean = 0.1307std = 0.3081tensor_image = (tensor_image - mean) / std# 添加批次维度和通道维度,使形状变为 (1, 1, 28, 28)tensor_image = tensor_image.unsqueeze(0).unsqueeze(0).cuda()# --- 可视化 PyTorch 张量 ---# 为了可视化,我们先将其恢复到 [0,1] 范围,否则标准化后的值可能很难看# 逆标准化 (用于可视化,不影响模型输入)# visual_tensor = tensor_image * std + mean# # 确保在 [0,1] 范围内# visual_tensor = torch.clamp(visual_tensor, 0.0, 1.0)# plt.figure(figsize=(2, 2))# plt.imshow(visual_tensor.cpu().squeeze().numpy(), cmap='gray')# plt.title("input")# plt.axis('off')# plt.show()return tensor_image

其中有几个注意点
1.
目前的画布是白色的,画笔是黑色,但是mnist数据集的底是黑色的,画笔是白色的,因此需要使用

tensor_image = 1.0 - tensor_image

来将颜色取反,不然跟训练数据不一样模型无法良好运行。
2.
QT中的image是Qimage,转换成numpy代码有点麻烦,我这里图省事直接用了qimage2ndarray库,因此只需一行代码

arr_3d = qimage2ndarray.byte_view(grayscale_image)

就完成了这个操作。
3.
在输入到模型之前,要进行数据预处理,如上面的代码中

        # 3. 标准化:应用训练时使用的均值和标准差# MNIST 均值和标准差mean = 0.1307std = 0.3081tensor_image = (tensor_image - mean) / std

来优化模型效果。

3. 加载模型

这里的预训练权重就直接用了上一篇文章中训练出来的权重,还给她放到cuda上了,不过这么小的模型其实放不放其实都无所谓,没有太大的影响。

    def get_net(self):"""获取数字预测模型。返回一个 DigitCNN 模型实例。"""# 创建并返回一个 DigitCNN 模型实例net = DigitCNN()net.eval()net.cuda()net.load_state_dict(torch.load('./digit_CNN.pth'))return net

4. 预测

这里就没什么好说的了,就是简单地预测然后将结果同步到gui上了。

    def predict(self):"""预测当前画布上绘制的数字。这里可以调用模型进行预测,并更新预测结果标签。"""input = self.get_image()  # 获取当前画布上的图像数据# 使用模型进行预测with torch.no_grad():output = self.net(input)# 获取预测结果self.update_predict_result(output)def update_predict_result(self, output):_, predict = output.max(1)  # 获取预测的数字类别predict = predict.cpu().numpy()[0]# 更新预测结果标签self.predict_label.setText(f"预测结果: {predict}")# 更新每个数字的预测概率probabilities = torch.softmax(output, dim=1).cpu().numpy()[0]for i, label in enumerate(self.predict_digit_labels):label.setText(f"数字 {i}: {probabilities[i] * 100:.2f}%")

5.结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

6.一点小问题

现在模型是可以用了,但是因为Mnist数据集本身的局限性,已经网络也比较小,泛化性能比较差(但是没差到不能用的地步),所以预测结果又是后会比较奇怪,例如:

.在这里插入图片描述
这是mnist数据集中的数据,可以看出这里的0大部分都是上面闭合,导致模型预测奇怪位置的闭合的0会失准。

还有其中的4大部分都是开口的,并没有闭合4上面的开口,导致写一个很标准的4反倒有时候会预测出错,还有其他的一些问题我就不赘述了。

总之如果想要模型想要获得更好的表现,一是可以增强一下模型的能力,第二个我觉得更重的是把数据好好清洗一下,有些数据真的太差了

相关文章:

2. 手写数字预测 gui版

2. 手写数字预测 gui版 背景1.界面绘制2.处理图片3. 加载模型4. 预测5.结果6.一点小问题 背景 做了手写数字预测的模型,但是老是跑模型太无聊了,就配合pyqt做了一个可视化界面出来玩一下 源代码可以去这里https://github.com/Leezed525/pytorch_toy拿 …...

js数据类型有哪些?它们有什么区别?

js数据类型共有8种,分别是undefined,null,boolean,number,string,Object,symbol,bigint symbol和bigint是es6中提出来的数据类型 symbol创建后独一无二不可变的数据类型,它主要是为了解决出现全局变量冲突的问题 bigint 是一种数字类型的数据,它可以表示任意精度格式的整数,…...

大模型应用开发第五讲:成熟度模型:从ChatGPT(L2)到未来自主Agent(L4)

大模型应用开发第五讲:成熟度模型:从ChatGPT(L2)到未来自主Agent(L4) 资料取自《大模型应用开发:动手做AI Agent 》。 查看总目录:学习大纲 关于DeepSeek本地部署指南可以看下我之…...

特别篇-产品经理(三)

一、市场与竞品分析—竞品分析 1. 课后总结 案例框架:通过"小新吃蛋糕"案例展示行业分析方法,包含四个关键步骤: 明确目标行业调研确定竞品分析竞争策略输出结论 1)行业背景分析方法 PEST分析法:从四个…...

IP地址扫描 网络状态监测 企业网络管理 免安装,企业级 IP 监控防未授权接入

各位网络小卫士们!今天咱来聊聊一款超厉害的局域网IP地址扫描工具——IPScaner V1.22。这玩意儿就像网络世界的大侦探,能快速识别网络里设备的状态和资源分布。下面咱就好好唠唠它的那些事儿。 软件获取夸克网盘下载 先说说它的核心功能。第一个是IP…...

【unity游戏开发——编辑器扩展】AssetDatabase公共类在编辑器环境中管理和操作项目中的资源

注意:考虑到编辑器扩展的内容比较多,我将编辑器扩展的内容分开,并全部整合放在【unity游戏开发——编辑器扩展】专栏里,感兴趣的小伙伴可以前往逐一查看学习。 文章目录 前言一、AssetDatabase常用API1、创建资源1.1 API1.2 示例 …...

BLE协议全景图:从0开始理解低功耗蓝牙

BLE(Bluetooth Low Energy)作为一种针对低功耗场景优化的通信协议,已经广泛应用于智能穿戴、工业追踪、智能家居、医疗设备等领域。 本文是《BLE 协议实战详解》系列的第一篇,将从 BLE 的发展历史、协议栈结构、核心机制和应用领域出发,为后续工程实战打下全面认知基础。 …...

【机器学习基础】机器学习入门核心算法:GBDT(Gradient Boosting Decision Tree)

机器学习入门核心算法:GBDT(Gradient Boosting Decision Tree) 1. 算法逻辑2. 算法原理与数学推导2.1 目标函数2.2 负梯度计算2.3 决策树拟合2.4 叶子权重计算2.5 模型更新 3. 模型评估评估指标防止过拟合 4. 应用案例4.1 金融风控4.2 推荐系…...

基于开源AI大模型AI智能名片S2B2C商城小程序源码的销售环节数字化实现路径研究

摘要:在数字化浪潮下,企业销售环节的转型升级已成为提升竞争力的核心命题。本文基于清华大学全球产业研究院《中国企业数字化转型研究报告(2020)》提出的“提升销售率与利润率、打通客户数据、强化营销协同、构建全景用户画像、助…...

Spring Cache核心原理与快速入门指南

文章目录 前言一、Spring Cache核心原理1.1 架构设计思想1.2 运行时执行流程1.3 核心组件协作1.4 关键机制详解1.5 扩展点设计1.6 与Spring事务的协同 二、快速入门实战三、局限性3.1 多级缓存一致性缺陷3.2 分布式锁能力缺失3.3 事务集成陷阱 总结 前言 在当今高并发、低延迟…...

Redisson学习专栏(四):实战应用(分布式会话管理,延迟队列)

文章目录 前言一、为什么需要分布式会话管理?1.1 使用 Redisson 实现 Session 共享 二、订单超时未支付?用延迟队列精准处理2.1 RDelayedQueue 核心机制2.2 订单超时处理实战 总结 前言 在现代分布式系统中,会话管理和延迟任务处理是两个核心…...

java程序从服务器端到Lambda函数的迁移与优化

source:https://www.jfokus.se/jfokus24-preso/From-Serverful-to-Serverless-Java.pdf 从传统的服务器端Java应用,到如今的无服务器架构。这不仅仅是技术名词的改变,更是开发模式和运维理念的一次深刻变革。先快速回顾一下我们熟悉的“服务…...

使用yocto搭建qemuarm64环境

环境 yocto下载 # 源码下载 git clone git://git.yoctoproject.org/poky git reset --hard b223b6d533a6d617134c1c5bec8ed31657dd1268 构建 # 编译镜像 export MACHINE"qemuarm64" . oe-init-build-env bitbake core-image-full-cmdline 运行 # 跑虚拟机 export …...

Vue 3前沿生态整合:WebAssembly与TypeScript深度实践

一、Vue 3 WebAssembly:突破性能天花板 01、WebAssembly:浏览器中的原生性能 WebAssembly(Wasm)是一种可在现代浏览器中运行的二进制指令格式,其性能接近原生代码。结合Vue 3的响应式架构,我们可以在前端…...

Linux系统下安装配置 Nginx

Windows Nginx https://nginx.org/en/download.htmlLinux Nginx https://nginx.org/download/nginx-1.24.0.tar.gz解压 tar -zxvf tar -zxvf nginx-1.18.0.tar.gz #解压安装依赖(如未安装) yum groupinstall "Development Tools" -y yum…...

Kotlin 中集合遍历有哪几种方式?

1 for-in 循环(最常用) val list listOf("A", "B", "C") for (item in list) {print("$item ") }// A B C 2 forEach 高阶函数 val list listOf("A", "B", "C") list.forEac…...

图像卷积OpenCV C/C++ 核心操作

图像卷积:OpenCV C 核心操作 图像卷积是图像处理和计算机视觉领域最基本且最重要的操作之一。它通过一个称为卷积核(或滤波器)的小矩阵,在输入图像上滑动,并对核覆盖的图像区域执行元素对应相乘后求和的运算&#xff…...

LiveGBS作为下级平台GB28181国标级联2016|2022对接海康大华宇视华为政务公安内网等GB28181国标平台查看级联状态及会话

LiveGBS作为下级平台GB28181国标级联2016|2022对接海康大华宇视华为政务公安内网等GB28181国标平台查看级联状态及会话 1、GB/T28181级联概述2、搭建GB28181国标流媒体平台3、获取上级平台接入信息3.1、向下级提供信息3.2、上级国标平台添加下级域3.3、接入LiveGBS示例 4、配置…...

leetcode17.电话号码的字母组合:字符串映射与回溯的巧妙联动

一、题目深度解析与字符映射逻辑 题目描述 给定一个仅包含数字 2-9 的字符串 digits,返回所有它能表示的字母组合。数字与字母的映射关系如下(与电话按键相同): 2: "abc", 3: "def", 4: "ghi", …...

Gartner《2025 年软件工程规划指南》报告学习心得

一、引言 软件工程领域正面临着前所未有的变革与挑战。随着生成式人工智能(GenAI)等新兴技术的涌现、市场环境的剧烈动荡以及企业对软件工程效能的更高追求,软件工程师们必须不断适应和拥抱变化,以提升自身竞争力并推动业务发展。Gartner 公司发布的《2025 年软件工程规划…...

数据库 | 使用timescaledb和大模型进行数据分析

时序数据库:timescaledb 大模型:通义千问2.5 对话开始前提示词: 我正在做数据分析,以下是已知信息: 数据库:timescaledb,表名:dm_tag_value,tag_name列是位号名,app_time列是时间,…...

快速阅读源码

Doxygen 轻松生成包含类图、调用关系图的 HTML 和 PDF 文档, Graphviz 可以用来生成类图、调用图 sudo apt-get install doxygen graphviz brew install doxygen graphviz#HTML 文档: open docs/html/index.html一、Doxyfile配置: Doxyfile 文件 doxygen Doxyfile P…...

linux创建虚拟网卡和配置多ip

1.展示当前网卡信息列表: linux上: ip a ifconfigwindows上: ipconfig 2.创建虚拟网卡对: sudo ip link add name veth0 type veth peer name veth1 在 ip link add 命令中,type 参数可以指定多种虚拟网络设备类型&…...

Java Class类文件结构

🧑 博主简介:CSDN博客专家,历代文学网(PC端可以访问:https://literature.sinhy.com/#/literature?__c1000,移动端可微信小程序搜索“历代文学”)总架构师,15年工作经验,…...

AI问答-Vue3+TS:reactive创建一个响应式数组,用一个新的数组对象来替换它,同时保持响应性

在 Vue 3 中,当你使用 reactive 创建一个响应式数组后,如果你想用一个新的数组对象来替换它,同时保持响应性,有几种方法可以实现 方法一:直接替换整个数组(推荐) import { reactive } from vu…...

quasar electron mode如何打包无边框桌面应用程序

预览 开源项目Tokei Kun 一款简洁的周年纪念app,现已发布APK(安卓)和 EXE(Windows) 项目仓库地址:Github Repo 应用下载链接:Github Releases Preparation for Electron quasar dev -m elect…...

【HW系列】—Windows日志与Linux日志分析

文章目录 一、Windows日志1. Windows事件日志2. 核心日志类型3. 事件日志分析实战详细分析步骤 二、Linux日志1. 常见日志文件2. 关键日志解析3. 登录爆破检测方法日志分析核心要点 一、Windows日志 1. Windows事件日志 介绍:记录系统、应用程序及安全事件&#x…...

VIN码识别解析接口如何用C#进行调用?

一、什么是VIN码识别解析接口? VIN码不仅是车辆的“身份证”,更是连接制造、销售、维修、保险、金融等多个环节的数字纽带。而VIN码查询API,正是打通这一链条的关键工具。 无论是汽车电商平台、二手车商、维修厂,还是保险公司、金…...

动态规划之网格图模型(一)

文章目录 动态规划之网格图模型(一)LeetCode 64. 最小路径和思路Golang 代码 LeetCode 62. 不同路径思路Golang 代码 LeetCode 63. 不同路径 II思路Golang 代码 LeetCode 120. 三角形最小路径和思路Golang 代码 LeetCode 3393. 统计异或值为给定值的路径…...

PCB设计实践(三十)地平面完整性

在高速数字电路和混合信号系统设计中,地平面完整性是决定PCB性能的核心要素之一。本文将从电磁场理论、信号完整性、电源分配系统等多个维度深入剖析地平面设计的关键要点,并提出系统性解决方案。 一、地平面完整性的电磁理论基础 电流回流路径分析 在PC…...