基于pytorch的手写数字识别-训练+使用
import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval() # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False) # 随机选择50个样本
sample_images = x[sample_indices].to(device) # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy() # 真实标签# 进行预测
with torch.no_grad(): # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy() # 获取预测的标签# 绘制图像
plt.figure(figsize=(10, 10))
for i in range(50):plt.subplot(10, 5, i + 1) # 10行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray') # 还原为20x20图像plt.title(f'Predicted: {predicted_labels[i]}', fontsize=8)plt.axis('off') # 关闭坐标轴plt.tight_layout() # 调整子图间距
plt.show()
Iteration 0, Loss: 0.8472495079040527
Iteration 20, Loss: 0.014742681756615639
Iteration 40, Loss: 0.00011596851982176304
Iteration 60, Loss: 9.278443030780181e-05
Iteration 80, Loss: 1.3701709576707799e-05
Iteration 100, Loss: 5.019319928578625e-07
Iteration 120, Loss: 0.0
Iteration 140, Loss: 0.0
Iteration 160, Loss: 1.2548344585638915e-08
Iteration 180, Loss: 1.700657230685465e-05
预测准确率: 100.00%

下面使用已经训练好的模型,进行再次测试:
import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval() # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False) # 随机选择50个样本
sample_images = x[sample_indices].to(device) # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy() # 真实标签# 进行预测
with torch.no_grad(): # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy() # 获取预测的标签plt.figure(figsize=(16, 10))
for i in range(20):plt.subplot(4, 5, i + 1) # 4行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray') # 还原为20x20图像plt.title(f'True: {sample_labels[i]}, Pred: {predicted_labels[i]}', fontsize=12) # 标题中显示真实值和预测值plt.axis('off') # 关闭坐标轴plt.tight_layout() # 调整子图间距
plt.show()

相关文章:
基于pytorch的手写数字识别-训练+使用
import pandas as pd import numpy as np import torch import matplotlib import matplotlib.pyplot as plt from torch.utils.data import TensorDataset, DataLoadermatplotlib.use(tkAgg)# 设置图形配置 config {"font.family": serif,"mathtext.fontset&q…...
SpringBoot接收前端传递参数
1)URL 参数 参数直接 拼接在URL的后面,使用 ? 进行分隔,多个参数之间用 & 符号分隔。例如:http://localhost:8080/user?namezhangsan&id1后端接收(在Controller方法的参数列表中使用 RequestParam 注解&…...
【LeetCode周赛】第 418 场
3309. 连接二进制表示可形成的最大数值 给你一个长度为 3 的整数数组 nums。 现以某种顺序 连接 数组 nums 中所有元素的 二进制表示 ,请你返回可以由这种方法形成的 最大 数值。 注意 任何数字的二进制表示 不含 前导零 思路:暴力枚举 class Soluti…...
Android学习7 -- NDK2 -- 几个例子
学习 Android 的 NDK(Native Development Kit)可以帮助你用 C/C 来开发高性能的 Android 应用,特别适合对性能要求较高的任务,如音视频处理、游戏开发和硬件驱动等。下面是学习 NDK 的建议步骤和具体例子: ### 1. **准…...
问:说说JVM不同版本的变化和差异?
在Java程序的执行过程中,Java虚拟机(JVM)扮演着至关重要的角色。它不仅负责解释和执行Java字节码,还管理着程序运行时的内存。根据JVM规范,JVM将其所管理的内存划分为多个不同的数据区域,包括程序计数器、J…...
计算机毕业设计 基于Python的社交音乐分享平台的设计与实现 Python+Django+Vue 前后端分离 附源码 讲解 文档
🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…...
51单片机的水位检测系统【proteus仿真+程序+报告+原理图+演示视频】
1、主要功能 该系统由AT89C51/STC89C52单片机LCD1602显示模块水位传感器继电器LED、按键和蜂鸣器等模块构成。适用于水位监测、水位控制、水位检测相似项目。 可实现功能: 1、LCD1602实时显示水位高度 2、水位传感器采集水位高度 3、按键可设置水位的下限 4、按键可手动加…...
Python和R及Julia妊娠相关疾病生物剖析算法
🎯要点 算法使用了矢量投影、现代优化线性代数、空间分区技术和大数据编程利用相应向量空间中标量积和欧几里得距离的紧密关系来计算使用妊娠相关疾病(先兆子痫)、健康妊娠和癌症测试算法模型使用相关性投影利用相关性和欧几里得距离之间的关…...
Web安全 - 重放攻击(Replay Attack)
文章目录 OWASP 2023 TOP 10导图1. 概述2. 重放攻击的原理攻击步骤 3. 常见的重放攻击场景4. 防御重放攻击的技术措施4.1 使用时效性验证(Time-Based Tokens)4.2 单次令牌机制(Nonce)4.3 TLS/SSL 协议4.4 HMAC(哈希消息…...
Python项目文档生成常用工具对比
写在前面: 通过阅读本片文章,你将了解:主流的Python项目文档生成工具(Sphinx,MkDocs,pydoc,Pdoc)简介及对比,本文档不涉及相关工具的使用。 概述 近期,由于…...
教育领域的技术突破:SpringBoot系统实现
2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常…...
RabbitMQ入门3—virtual host参数详解
在 RabbitMQ 中,创建 Virtual Host 时会涉及到一些参数配置,比如 tags 和 Default Queue Type。下面是对这两个参数的详细解释: 1. Tags Tags 是 Virtual Host 的标记,用来为 Virtual Host 添加元数据,帮助你管理和组…...
【Nacos入门到实战十四】Nacos配置管理:集群部署与高可用策略
个人名片 🎓作者简介:java领域优质创作者 🌐个人主页:码农阿豪 📞工作室:新空间代码工作室(提供各种软件服务) 💌个人邮箱:[2435024119qq.com] 📱…...
UE5+ChatGPT实现3D AI虚拟人综合实战
第11章 综合实战:UE5ChatGPT实现3D AI虚拟人 通过结合Unreal Engine 5(UE5)的强大渲染能力和ChatGPT的自然语言处理能力,我们可以实现一个高度交互性的AI虚拟人。本文将详细介绍如何在UE5中安装必要的插件,配置OpenAI…...
[图形学]smallpt代码详解(2)
一、简介 本文紧接在[图形学]smallpt代码详解(1)之后,继续详细讲解smallpt中的代码,包括自定义函数(第41到47行)和递归路径跟踪函数(第48到74行)部分。 二、smallpt代码详解 1.自…...
vmstat命令:系统性能监控
一、命令简介 vmstat 是一种在类 Unix 系统上常用的性能监控工具,它可以报告虚拟内存统计信息,包括进程、内存、分页、块 IO、陷阱(中断)和 CPU 活动等。 二、命令参数 2.1 命令格式 vmstat [选项] [ 延迟 [次数] ]2…...
linux部署NFS和autofs自动挂载
目录 (一)NFS: 1. 什么是NFS 2. NFS守护进程 3. RPC服务 4. 原理 5. 部署 5.1 安装NFS服务 5.2 配置防火墙 5.3 创建服务端共享目录 5.4 修改服务端配置文件 (1). /etc/exports (2). nfs.conf 5.5 启动nfs并加入自启 5.6 客户端…...
WPF RadioButton 绑定boolean值
<RadioButtonMargin"5"Content"替换"IsChecked"{Binding CorrectionOption.ReCorrectionMode}" /> <RadioButtonMargin"5"Content"平均"IsChecked"{Binding CorrectionOption.ReCorrectionMode, Converter{St…...
2024 ciscn WP
一、MISC 1.火锅链观光打卡 打开后连接自己的钱包,然后点击开始游戏,答题八次后点击获取NFT,得到有flag的图片 没什么多说的,知识问答题 兑换 NFT Flag{y0u_ar3_hotpot_K1ng} 2.Power Trajectory Diagram 方法1: 使用p…...
代码随想录--字符串--重复的子字符串
题目 给定一个非空的字符串,判断它是否可以由它的一个子串重复多次构成。给定的字符串只含有小写英文字母,并且长度不超过10000。 示例 1: 输入: "abab" 输出: True 解释: 可由子字符串 "ab" 重复两次构成。示例 2: 输入: "…...
深度解析Scarab:空洞骑士模组管理器的专业实现与架构设计
深度解析Scarab:空洞骑士模组管理器的专业实现与架构设计 【免费下载链接】Scarab An installer for Hollow Knight mods written with Avalonia. 项目地址: https://gitcode.com/gh_mirrors/sc/Scarab 空洞骑士模组管理器Scarab为玩家提供了高效、专业的模组…...
深部空间专属孪生,打造密闭硐室独有不可替代透明体系技术白皮书
深部空间专属孪生,打造密闭硐室独有不可替代透明体系技术白皮书副标题:井下专用暗光算法实现三维实时重建,搭配地下专属无感定位、多盲区跨镜穿透追踪、身体指纹特征识别,场景适配独一无二,行业无同类对标方案前言矿山…...
低多边形≠简陋!掌握这7个结构化Prompt技巧,3分钟产出可商用IP形象(附Figma网格对齐校验表)
更多请点击: https://intelliparadigm.com 第一章:低多边形设计的认知革命:从“简陋感”到“结构化美学” 低多边形(Low-Poly)设计曾长期被误读为建模能力不足的妥协产物,但其本质是一场对数字视觉语法的系…...
TPU柔性材料3D打印GoPro车载支架:从减震原理到实战拍摄全指南
1. 项目概述与设计思路我一直对第一人称视角(FPV)拍摄很着迷,尤其是那种能贴着地面、模拟小车视角疾驰的画面,动态感和沉浸感是手持拍摄无法比拟的。市面上的运动相机车载支架要么是硬连接,颠簸起来画面抖动得厉害&…...
AI驱动工作流自动化:从原理到实践,构建智能效率引擎
1. 项目概述:当AI遇上工作流,一场效率革命正在发生最近在GitHub上看到一个名为“WorkflowAI/WorkflowAI”的项目,这个名字本身就充满了想象空间。作为一个长期与各种自动化工具和效率方法论打交道的人,我立刻意识到,这…...
82.人工智能实战:大模型多环境治理怎么做?从开发、测试、预发到生产的 Prompt、模型、知识库隔离方案
人工智能实战:大模型多环境治理怎么做?从开发、测试、预发到生产的 Prompt、模型、知识库隔离方案 一、问题场景:测试环境改了 Prompt,结果生产回答变了 很多大模型项目早期只有一个环境: 一套 Prompt 一个知识库 一个模型地址 一个配置表开发、测试、运营都在同一套配置…...
跨平台鼠标控制库ez-cursor-free:原理、实现与自动化实战
1. 项目概述与核心价值如果你是一名开发者,尤其是经常需要处理跨平台UI自动化、游戏脚本或者桌面应用交互的开发者,那么你一定对“鼠标控制”这个基础但又充满细节的环节感到过头疼。不同的操作系统(Windows, macOS, Linux)提供了…...
CircuitPython REPL与库管理:嵌入式开发的效率利器
1. CircuitPython REPL:你的嵌入式开发“瑞士军刀” 如果你玩过Arduino,肯定对“上传-编译-看结果”这个循环不陌生。每次改一行代码,都得重新编译、上传,然后盯着串口看输出,效率低得让人抓狂。CircuitPython带来的R…...
基于BLE与UriBeacon标准,打造低成本物理网页信标实践指南
1. 项目概述:从蓝牙信标到物理网页的进化 几年前,当我第一次接触iBeacon时,就被这种“静默广播、主动感知”的物联网交互模式吸引了。一个小小的硬件,不用配对,就能让周围的手机知道它的存在,并触发相应的…...
在济宁,随着设备搬运服务需求的持续增长,市面上涌现出众多设
在济宁,设备搬运服务需求不断增加,众多厂家纷纷涌现,选择一家口碑良好的设备搬运厂家成为不少人的关注焦点。本次测评旨在通过客观的评估,为对济宁设备搬运厂家感兴趣的人群提供有价值的参考。参与本次测评的厂家为山东荣上机械设…...
