当前位置: 首页 > news >正文

基于pytorch的手写数字识别-训练+使用

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签# 绘制图像
plt.figure(figsize=(10, 10))
for i in range(50):plt.subplot(10, 5, i + 1)  # 10行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'Predicted: {predicted_labels[i]}', fontsize=8)plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()

Iteration 0, Loss: 0.8472495079040527
Iteration 20, Loss: 0.014742681756615639
Iteration 40, Loss: 0.00011596851982176304
Iteration 60, Loss: 9.278443030780181e-05
Iteration 80, Loss: 1.3701709576707799e-05
Iteration 100, Loss: 5.019319928578625e-07
Iteration 120, Loss: 0.0
Iteration 140, Loss: 0.0
Iteration 160, Loss: 1.2548344585638915e-08
Iteration 180, Loss: 1.700657230685465e-05
预测准确率: 100.00%

下面使用已经训练好的模型,进行再次测试:

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签plt.figure(figsize=(16, 10))
for i in range(20):plt.subplot(4, 5, i + 1)  # 4行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'True: {sample_labels[i]}, Pred: {predicted_labels[i]}', fontsize=12)  # 标题中显示真实值和预测值plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()

相关文章:

基于pytorch的手写数字识别-训练+使用

import pandas as pd import numpy as np import torch import matplotlib import matplotlib.pyplot as plt from torch.utils.data import TensorDataset, DataLoadermatplotlib.use(tkAgg)# 设置图形配置 config {"font.family": serif,"mathtext.fontset&q…...

SpringBoot接收前端传递参数

1&#xff09;URL 参数 参数直接 拼接在URL的后面&#xff0c;使用 ? 进行分隔&#xff0c;多个参数之间用 & 符号分隔。例如&#xff1a;http://localhost:8080/user?namezhangsan&id1后端接收&#xff08;在Controller方法的参数列表中使用 RequestParam 注解&…...

【LeetCode周赛】第 418 场

3309. 连接二进制表示可形成的最大数值 给你一个长度为 3 的整数数组 nums。 现以某种顺序 连接 数组 nums 中所有元素的 二进制表示 &#xff0c;请你返回可以由这种方法形成的 最大 数值。 注意 任何数字的二进制表示 不含 前导零 思路&#xff1a;暴力枚举 class Soluti…...

Android学习7 -- NDK2 -- 几个例子

学习 Android 的 NDK&#xff08;Native Development Kit&#xff09;可以帮助你用 C/C 来开发高性能的 Android 应用&#xff0c;特别适合对性能要求较高的任务&#xff0c;如音视频处理、游戏开发和硬件驱动等。下面是学习 NDK 的建议步骤和具体例子&#xff1a; ### 1. **准…...

问:说说JVM不同版本的变化和差异?

在Java程序的执行过程中&#xff0c;Java虚拟机&#xff08;JVM&#xff09;扮演着至关重要的角色。它不仅负责解释和执行Java字节码&#xff0c;还管理着程序运行时的内存。根据JVM规范&#xff0c;JVM将其所管理的内存划分为多个不同的数据区域&#xff0c;包括程序计数器、J…...

计算机毕业设计 基于Python的社交音乐分享平台的设计与实现 Python+Django+Vue 前后端分离 附源码 讲解 文档

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…...

51单片机的水位检测系统【proteus仿真+程序+报告+原理图+演示视频】

1、主要功能 该系统由AT89C51/STC89C52单片机LCD1602显示模块水位传感器继电器LED、按键和蜂鸣器等模块构成。适用于水位监测、水位控制、水位检测相似项目。 可实现功能: 1、LCD1602实时显示水位高度 2、水位传感器采集水位高度 3、按键可设置水位的下限 4、按键可手动加…...

Python和R及Julia妊娠相关疾病生物剖析算法

&#x1f3af;要点 算法使用了矢量投影、现代优化线性代数、空间分区技术和大数据编程利用相应向量空间中标量积和欧几里得距离的紧密关系来计算使用妊娠相关疾病&#xff08;先兆子痫&#xff09;、健康妊娠和癌症测试算法模型使用相关性投影利用相关性和欧几里得距离之间的关…...

Web安全 - 重放攻击(Replay Attack)

文章目录 OWASP 2023 TOP 10导图1. 概述2. 重放攻击的原理攻击步骤 3. 常见的重放攻击场景4. 防御重放攻击的技术措施4.1 使用时效性验证&#xff08;Time-Based Tokens&#xff09;4.2 单次令牌机制&#xff08;Nonce&#xff09;4.3 TLS/SSL 协议4.4 HMAC&#xff08;哈希消息…...

Python项目文档生成常用工具对比

写在前面&#xff1a; 通过阅读本片文章&#xff0c;你将了解&#xff1a;主流的Python项目文档生成工具&#xff08;Sphinx&#xff0c;MkDocs&#xff0c;pydoc&#xff0c;Pdoc&#xff09;简介及对比&#xff0c;本文档不涉及相关工具的使用。 概述 近期&#xff0c;由于…...

教育领域的技术突破:SpringBoot系统实现

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…...

RabbitMQ入门3—virtual host参数详解

在 RabbitMQ 中&#xff0c;创建 Virtual Host 时会涉及到一些参数配置&#xff0c;比如 tags 和 Default Queue Type。下面是对这两个参数的详细解释&#xff1a; 1. Tags Tags 是 Virtual Host 的标记&#xff0c;用来为 Virtual Host 添加元数据&#xff0c;帮助你管理和组…...

【Nacos入门到实战十四】Nacos配置管理:集群部署与高可用策略

个人名片 &#x1f393;作者简介&#xff1a;java领域优质创作者 &#x1f310;个人主页&#xff1a;码农阿豪 &#x1f4de;工作室&#xff1a;新空间代码工作室&#xff08;提供各种软件服务&#xff09; &#x1f48c;个人邮箱&#xff1a;[2435024119qq.com] &#x1f4f1…...

UE5+ChatGPT实现3D AI虚拟人综合实战

第11章 综合实战&#xff1a;UE5ChatGPT实现3D AI虚拟人 通过结合Unreal Engine 5&#xff08;UE5&#xff09;的强大渲染能力和ChatGPT的自然语言处理能力&#xff0c;我们可以实现一个高度交互性的AI虚拟人。本文将详细介绍如何在UE5中安装必要的插件&#xff0c;配置OpenAI…...

[图形学]smallpt代码详解(2)

一、简介 本文紧接在[图形学]smallpt代码详解&#xff08;1&#xff09;之后&#xff0c;继续详细讲解smallpt中的代码&#xff0c;包括自定义函数&#xff08;第41到47行&#xff09;和递归路径跟踪函数&#xff08;第48到74行&#xff09;部分。 二、smallpt代码详解 1.自…...

vmstat命令:系统性能监控

一、命令简介 ​vmstat​ 是一种在类 Unix 系统上常用的性能监控工具&#xff0c;它可以报告虚拟内存统计信息&#xff0c;包括进程、内存、分页、块 IO、陷阱&#xff08;中断&#xff09;和 CPU 活动等。 ‍ 二、命令参数 2.1 命令格式 vmstat [选项] [ 延迟 [次数] ]2…...

linux部署NFS和autofs自动挂载

目录 &#xff08;一&#xff09;NFS&#xff1a; 1. 什么是NFS 2. NFS守护进程 3. RPC服务 4. 原理 5. 部署 5.1 安装NFS服务 5.2 配置防火墙 5.3 创建服务端共享目录 5.4 修改服务端配置文件 (1). /etc/exports (2). nfs.conf 5.5 启动nfs并加入自启 5.6 客户端…...

WPF RadioButton 绑定boolean值

<RadioButtonMargin"5"Content"替换"IsChecked"{Binding CorrectionOption.ReCorrectionMode}" /> <RadioButtonMargin"5"Content"平均"IsChecked"{Binding CorrectionOption.ReCorrectionMode, Converter{St…...

2024 ciscn WP

一、MISC 1.火锅链观光打卡 打开后连接自己的钱包&#xff0c;然后点击开始游戏&#xff0c;答题八次后点击获取NFT&#xff0c;得到有flag的图片 没什么多说的&#xff0c;知识问答题 兑换 NFT Flag{y0u_ar3_hotpot_K1ng} 2.Power Trajectory Diagram 方法1&#xff1a; 使用p…...

代码随想录--字符串--重复的子字符串

题目 给定一个非空的字符串&#xff0c;判断它是否可以由它的一个子串重复多次构成。给定的字符串只含有小写英文字母&#xff0c;并且长度不超过10000。 示例 1: 输入: "abab" 输出: True 解释: 可由子字符串 "ab" 重复两次构成。示例 2: 输入: "…...

华为云AI开发平台ModelArts

华为云ModelArts&#xff1a;重塑AI开发流程的“智能引擎”与“创新加速器”&#xff01; 在人工智能浪潮席卷全球的2025年&#xff0c;企业拥抱AI的意愿空前高涨&#xff0c;但技术门槛高、流程复杂、资源投入巨大的现实&#xff0c;却让许多创新构想止步于实验室。数据科学家…...

在Ubuntu中设置开机自动运行(sudo)指令的指南

在Ubuntu系统中&#xff0c;有时需要在系统启动时自动执行某些命令&#xff0c;特别是需要 sudo权限的指令。为了实现这一功能&#xff0c;可以使用多种方法&#xff0c;包括编写Systemd服务、配置 rc.local文件或使用 cron任务计划。本文将详细介绍这些方法&#xff0c;并提供…...

ServerTrust 并非唯一

NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...

css的定位(position)详解:相对定位 绝对定位 固定定位

在 CSS 中&#xff0c;元素的定位通过 position 属性控制&#xff0c;共有 5 种定位模式&#xff1a;static&#xff08;静态定位&#xff09;、relative&#xff08;相对定位&#xff09;、absolute&#xff08;绝对定位&#xff09;、fixed&#xff08;固定定位&#xff09;和…...

OpenLayers 分屏对比(地图联动)

注&#xff1a;当前使用的是 ol 5.3.0 版本&#xff0c;天地图使用的key请到天地图官网申请&#xff0c;并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能&#xff0c;和卷帘图层不一样的是&#xff0c;分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...

2025季度云服务器排行榜

在全球云服务器市场&#xff0c;各厂商的排名和地位并非一成不变&#xff0c;而是由其独特的优势、战略布局和市场适应性共同决定的。以下是根据2025年市场趋势&#xff0c;对主要云服务器厂商在排行榜中占据重要位置的原因和优势进行深度分析&#xff1a; 一、全球“三巨头”…...

Mysql中select查询语句的执行过程

目录 1、介绍 1.1、组件介绍 1.2、Sql执行顺序 2、执行流程 2.1. 连接与认证 2.2. 查询缓存 2.3. 语法解析&#xff08;Parser&#xff09; 2.4、执行sql 1. 预处理&#xff08;Preprocessor&#xff09; 2. 查询优化器&#xff08;Optimizer&#xff09; 3. 执行器…...

AI+无人机如何守护濒危物种?YOLOv8实现95%精准识别

【导读】 野生动物监测在理解和保护生态系统中发挥着至关重要的作用。然而&#xff0c;传统的野生动物观察方法往往耗时耗力、成本高昂且范围有限。无人机的出现为野生动物监测提供了有前景的替代方案&#xff0c;能够实现大范围覆盖并远程采集数据。尽管具备这些优势&#xf…...

C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...

代码规范和架构【立芯理论一】(2025.06.08)

1、代码规范的目标 代码简洁精炼、美观&#xff0c;可持续性好高效率高复用&#xff0c;可移植性好高内聚&#xff0c;低耦合没有冗余规范性&#xff0c;代码有规可循&#xff0c;可以看出自己当时的思考过程特殊排版&#xff0c;特殊语法&#xff0c;特殊指令&#xff0c;必须…...