当前位置: 首页 > news >正文

基于pytorch的手写数字识别

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签# 绘制图像
plt.figure(figsize=(10, 10))
for i in range(50):plt.subplot(10, 5, i + 1)  # 10行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'Predicted: {predicted_labels[i]}', fontsize=8)plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()

Iteration 0, Loss: 0.8472495079040527
Iteration 20, Loss: 0.014742681756615639
Iteration 40, Loss: 0.00011596851982176304
Iteration 60, Loss: 9.278443030780181e-05
Iteration 80, Loss: 1.3701709576707799e-05
Iteration 100, Loss: 5.019319928578625e-07
Iteration 120, Loss: 0.0
Iteration 140, Loss: 0.0
Iteration 160, Loss: 1.2548344585638915e-08
Iteration 180, Loss: 1.700657230685465e-05
预测准确率: 100.00%

下面使用已经训练好的模型,进行再次测试:

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签plt.figure(figsize=(16, 10))
for i in range(20):plt.subplot(4, 5, i + 1)  # 4行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'True: {sample_labels[i]}, Pred: {predicted_labels[i]}', fontsize=12)  # 标题中显示真实值和预测值plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()

相关文章:

基于pytorch的手写数字识别

import pandas as pd import numpy as np import torch import matplotlib import matplotlib.pyplot as plt from torch.utils.data import TensorDataset, DataLoadermatplotlib.use(tkAgg)# 设置图形配置 config {"font.family": serif,"mathtext.fontset&q…...

MySQL 实验 7:索引的操作

MySQL 实验 7&#xff1a;索引的操作 索引是对数据表中一列或多列的值进行排序的一种结构&#xff0c;索引可以大大提高 MySQL 的检索速度。合理使用索引&#xff0c;可以大大提升 SQL 查询的性能。 索引好比是一本书前面的目录&#xff0c;假如我们需要从书籍查找与 xx 相关…...

为Floorp浏览器添加搜索引擎及搜索栏相关设置. 2024-10-05

Floorp浏览器开源项目地址: https://github.com/floorp-Projects/floorp/ 1.第一步 为Floorp浏览器添加搜索栏 (1.工具栏空白处 次键选择 定制工具栏 (2. 把 搜索框 拖动至工具栏 2.添加搜索引擎 以添加 搜狗搜索 为例 (1.访问 搜索引擎网址 搜狗搜索引擎 - 上网从搜狗开始 (2…...

如何设置WSL Ubuntu在Windows开机时自动启动

如何设置WSL Ubuntu在Windows开机时自动启动 步骤详解1. 创建批处理脚本2. 添加到Windows启动项 注意事项结语 在使用Windows Subsystem for Linux (WSL) 时,我们可能希望Ubuntu能够在Windows启动时自动运行。本文将介绍如何实现这一功能,让您的开发环境更加便捷。 步骤详解 …...

使用TensorBoard可视化模型

目录 TensorBoard简介 神经网络模型 可视化 轮次-损失曲线 轮次-准确率曲线 轮次-学习率曲线 迭代-评估准确率曲线 迭代-评估损失曲线 TensorBoard简介 TensorBoard是一款出色的交互式的模型可视化工具。安装TensorFlow时,会自动安装TensorBoard。如图: TensorFlow可…...

《深度学习》OpenCV 图像拼接 原理、参数解析、案例实现

目录 一、图像拼接 1、直接看案例 图1与图2展示&#xff1a; 合并完结果&#xff1a; 2、什么是图像拼接 3、图像拼接步骤 1&#xff09;加载图像 2&#xff09;特征点检测与描述 3&#xff09;特征点匹配 4&#xff09;图像配准 5&#xff09;图像变换和拼接 6&am…...

Hive数仓操作(三)

一、Hive 数据库操作 1. 创建数据库 基本创建数据库命令&#xff1a; CREATE DATABASE bigdata;说明&#xff1a; 数据库会在 HDFS 中以目录的形式创建和保存&#xff0c;数据库名称会存储在 Hive 的元数据中。如果不指定目录&#xff0c;数据库将在 /user/hive/warehouse 下…...

TDSQL-C电商可视化,重塑电商决策新纪元

前言&#xff1a; 在数字化浪潮席卷全球的今天&#xff0c;电子商务行业以其独特的魅力和无限潜力&#xff0c;成为了推动全球经济增长的重要引擎。然而&#xff0c;随着业务规模的急剧扩张&#xff0c;海量数据的涌现给电商企业带来了前所未有的挑战与机遇。如何高效地处理、…...

翔云 OCR:发票识别与验真

在数字化时代&#xff0c;高效处理大量文档和数据成为企业和个人的迫切需求。翔云 OCR 作为一款强大的光学字符识别工具&#xff0c;在发票识别及验真方面表现出色&#xff0c;为我们带来了极大的便利。 一、翔云 OCR 简介 翔云 OCR 是一款基于先进的人工智能技术开发的文字识别…...

HTML ASCII:Web 开发中的字符编码基础

HTML ASCII&#xff1a;Web 开发中的字符编码基础 ASCII&#xff0c;全称为美国信息交换标准代码&#xff08;American Standard Code for Information Interchange&#xff09;&#xff0c;是一种用于电子通信的字符编码标准。它最初于1963年提出&#xff0c;用于在不同的计算…...

Meta 首个多模态大模型一键启动!首个多针刺绣数据集上线,含超 30k 张图片

小扎在 Meta Connect 2024 主题演讲中宣布推出首个多模态大模型 Llama 3.2 vision&#xff01;该模型有 11B 和 90B 两个版本&#xff0c;成为首批支持多模态任务的 Llama 系列模型&#xff0c;根据官方数据&#xff0c;这两个开原模型的性能已超越闭源模型。 小编已经迫不及待…...

阿里云ECS服务器仿真

1.首先使用qemu-img对RAW镜像进行转换&#xff0c;qemu-img convert -O vmdk 1.raw 2.vmdk 2.使用WinHex对镜像的root密码进行删除 3.由于这次阿里云ECS使用了CONFIG_SYSTEM_TRUSTED_KEYS验证&#xff0c;无法直接仿真&#xff0c;需使用live系统对内核进行修改。分为以下几步&…...

如何为树莓派安装操作系统,以及远程操控树莓派的两种方法,无线操控和插网线操控

文章目录 一、下载树莓派的系统二、将文件下载到SD卡中1.使用官方软件2.其他选择 三、远程连接电脑安装vnc-viewer1.无线操作&#xff08;配置树莓派&#xff0c;开启VNC&#xff09;电脑远程配置2.有线连接&#xff08;需要一根网线&#xff09; 总结 一、下载树莓派的系统 下…...

【最新华为OD机试E卷-支持在线评测】简单的自动曝光(100分)多语言题解-(Python/C/JavaScript/Java/Cpp)

🍭 大家好这里是春秋招笔试突围 ,一枚热爱算法的程序员 💻 ACM金牌🏅️团队 | 大厂实习经历 | 多年算法竞赛经历 ✨ 本系列打算持续跟新华为OD-E/D卷的多语言AC题解 🧩 大部分包含 Python / C / Javascript / Java / Cpp 多语言代码 👏 感谢大家的订阅➕ 和 喜欢�…...

每日一练:等差数列划分

413. 等差数列划分 - 力扣&#xff08;LeetCode&#xff09; 题目要求&#xff1a; 如果一个数列 至少有三个元素 &#xff0c;并且任意两个相邻元素之差相同&#xff0c;则称该数列为等差数列。 例如&#xff0c;[1,3,5,7,9]、[7,7,7,7] 和 [3,-1,-5,-9] 都是等差数列。 给…...

Kotlin真·全平台——Kotlin Compose Multiplatform Mobile(kotlin跨平台方案、KMP、KMM)

前言 随着kotlin代码跨平台方案的推出&#xff0c;kotlin跨平台一度引起不少波澜。但波澜终归没有掀起太大的风浪&#xff0c;作为一个敏捷型开发的公司&#xff0c;依然少不了Android和iOS的同步开发&#xff0c;实际成本和效益并没有太多变化。所以对于大多数公司来说依然风平…...

unity 默认渲染管线材质球的材质通道,材质球的材质通道

标准渲染管线——材质球的材质通道 文档&#xff0c;与内容无关&#xff0c;是介绍材质球的属性的。 https://docs.unity3d.com/2022.1/Documentation/Manual/StandardShaderMaterialParameters.html游戏资源中常见的贴图类型 https://zhuanlan.zhihu.com/p/260973533 十大贴图…...

PostgreSQL升级:使用pg_upgrade进行大版本(16.3)升级(17.0)

1.pg_upgrade工具介绍 pg_upgrade 会创建新的系统表&#xff0c;并以重用旧的数据文件的方式进行升级。 pg_upgrade 的参数选项如下&#xff1a; -b bindir&#xff0c;--old-bindirbindir&#xff1a;旧的 PostgreSQL 可执行文件目录&#xff1b; -B bindir&#xff0c;--new-…...

userdel命令:删除指定Linux用户

一、命令简介 ​userdel​ 命令用于删除 Linux 系统中的用户账号。当您不再需要某个用户账号时&#xff0c;可以使用 userdel​ 命令将其从系统中删除。 ‍ 二、命令参数 userdel [选项] 用户名一些常用的选项包括&#xff1a; -r, --remove: 删除用户的家目录及邮件目录。…...

QT系统学习篇(1)

一、什么是Qt、Qt的优势 QT是一个跨平台的C图形用户界面库&#xff0c;目前包括Qt Creator、Qt Designer等等快速开发工具。支持所有Linux/Unix系统&#xff0c;还支持windows平台。Qt很容易扩展&#xff0c;并且允许真正的组件编程。&#xff08;军工企业项目开发基本离不开Q…...

Python 3.14 JIT为何在ARM64上降频17%?源码级定位_pyltopt_arch.c中2个未对齐的寄存器分配bug(已提交CPython PR#12894)

第一章&#xff1a;Python 3.14 JIT编译器性能降频现象概览Python 3.14 引入的实验性 JIT 编译器&#xff08;基于 Pyjion 与新式 AST 优化管道&#xff09;在部分工作负载下表现出非预期的性能降频现象——即启用 JIT 后&#xff0c;某些计算密集型循环或 I/O 绑定协程的执行耗…...

FileConverter:重构文件格式转换流程,实现设计师与教育工作者的效率突破

FileConverter&#xff1a;重构文件格式转换流程&#xff0c;实现设计师与教育工作者的效率突破 【免费下载链接】FileConverter File Converter is a very simple tool which allows you to convert and compress files using the context menu in windows explorer. 项目地…...

RePKG:突破动态壁纸资源壁垒的开源工具

RePKG&#xff1a;突破动态壁纸资源壁垒的开源工具 【免费下载链接】repkg Wallpaper engine PKG extractor/TEX to image converter 项目地址: https://gitcode.com/gh_mirrors/re/repkg 当你面对一个包含丰富素材的动态壁纸资源包&#xff08;PKG文件&#xff09;却无…...

HiOmics平台:零代码实现ChIP-Seq数据可视化与深度解析

1. 为什么科研人员需要零代码ChIP-Seq分析工具 做表观遗传学研究的朋友们应该都深有体会&#xff0c;ChIP-Seq数据分析就像一场马拉松——从原始数据清洗、序列比对、peak calling到功能注释&#xff0c;每个环节都需要不同的工具和脚本。我刚开始接触这个领域时&#xff0c;光…...

知乎上线求职工具,助力毕业生破困局

知乎上线求职利器&#xff0c;直击毕业生痛点2026届全国普通高校毕业生预计达1270万人&#xff0c;再创历史新高。与此同时&#xff0c;AI技术加速行业重构&#xff0c;部分传统岗位需求收缩&#xff0c;大量毕业生陷入“海投”困境&#xff0c;难以精准定位自身。在此背景下&a…...

Phi-4-mini-reasoning部署教程:Nginx反向代理+Basic Auth安全加固

Phi-4-mini-reasoning部署教程&#xff1a;Nginx反向代理Basic Auth安全加固 1. 项目介绍 Phi-4-mini-reasoning是一款由微软开源的轻量级AI模型&#xff0c;专注于数学推理、逻辑推导和多步解题等强逻辑任务。这个3.8B参数的模型虽然体积小巧&#xff0c;但在推理能力上表现…...

你的文件真的‘上传’了吗?聊聊阿里云盘‘秒传’背后的隐私与安全考量

你的文件真的“上传”了吗&#xff1f;揭秘秒传技术背后的隐私博弈 第一次在阿里云盘体验“秒传”功能时&#xff0c;那种近乎魔法的速度确实令人惊叹——几个GB的文件眨眼间就完成了“上传”。但惊喜之余&#xff0c;一个更根本的问题浮现出来&#xff1a;我的文件真的被上传了…...

威纶通宏指令实战:从零构建中文输入与智能配方检索系统

1. 威纶通触摸屏的中文输入困境与破解之道 第一次接触威纶通中低端触摸屏时&#xff0c;我就被它缺乏中文输入支持的问题给难住了。当时接了个食品包装机的项目&#xff0c;客户要求操作界面必须支持中文输入&#xff0c;方便工人记录生产批号和产品信息。市面上常见的中高端HM…...

从Solid模块到轨迹规划:一个完整机械臂SimMechanics仿真项目的保姆级拆解

从Solid模块到轨迹规划&#xff1a;一个完整机械臂SimMechanics仿真项目的保姆级拆解 机械臂仿真一直是工业自动化和机器人研究中的核心课题。不同于传统Adams等专业仿真软件&#xff0c;SimMechanics凭借其与Matlab/Simulink的无缝集成&#xff0c;为工程师提供了从建模到控制…...

SAP BAPI实战指南:核心模块高频接口速查与应用解析

1. SAP BAPI入门&#xff1a;为什么开发者需要这份速查手册 第一次接触SAP BAPI时&#xff0c;我盯着满屏的接口文档差点崩溃——光是FICO模块就有二十多个常用BAPI&#xff0c;每个接口的参数列表长得像毕业论文。后来在项目上踩过几次坑才明白&#xff0c;BAPI的难点不在于技…...