1. pytorch手写数字预测
1. pytorch手写数字预测
- 1.背景
- 2.准备数据集
- 2.定义模型
- 3.dataloader和训练
- 4.训练模型
- 5.测试模型
- 6.保存模型
1.背景
因为自身的研究方向是多模态目标跟踪,突然对其他的视觉方向产生了兴趣,所以心血来潮的回到最经典的视觉任务手写数字预测上来,所以这份教程并不是一份非常详尽的教程,是在一部分pytorch,深度学习基础上的教程,如果需要的是非常保姆级的教程建议看别的文章
2.准备数据集
这里我才用了直接导torchvision中的dataset包来下载Mnist数据集,也算是一个非常经典的数据集了
# 导入数据集
from torchvision.datasets import MNIST
import torch# 设置随机种子
torch.manual_seed(3306)# 数据预处理
from torchvision import transforms
# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(), # 转换为 Tensortransforms.Normalize((0.1307,), (0.3081,)) # 标准化
])# 下载 MNIST 数据集
mnist_train = MNIST(root='./dataset_file/mnist_raw', train=True, download=True,transform=transform)
mnist_test = MNIST(root='./dataset_file/mnist_raw', train=False, download=True,transform=transform)
# 查看数据集大小
print(f"MNIST train dataset size: {len(mnist_train)}")
print(f"MNIST test dataset size: {len(mnist_test)}")
其中,MNIST()中的root代表的是数据集存放的位置,download代表是如果当前位置没有数据集是否需要下载。
transformer则是对数据的处理方式,我这里采用了简单地转成tensor和简单地标准化。
不过这样子下载下来的数据集是二进制格式的,无法直接查看图片,当然,如果你需要查看图片,也有办法。
# 查看图片
import matplotlib.pyplot as pltdef show_image(id):img, label = mnist_train[id]img = img.squeeze().numpy() # 去掉通道维度print(img.shape)# print(img)plt.imshow(img, cmap='gray')plt.title(f"Label: {label}")plt.axis('off')plt.show()show_image(1)
效果
又或者你想要下载的数据集是图片格式,我这里也准备了代码
代码是在别人的基础上改的,其中数据集存放路径是dataset_dir,如果需要修改自行打印然后修改位置就好了。
#!/usr/bin/env python3
# -*- encoding utf-8 -*-'''
@File: save_mnist_to_jpg.py
@Date: 2024-08-23
@Author: KRISNAT
@Version: 0.0.0
@Email: ****
@Copyright: (C)Copyright 2024, KRISNAT
@Desc:1. 通过 torchvision.datasets.MNIST 下载、解压和读取 MNIST 数据集;2. 使用 PIL.Image.save 将 MNIST 数据集中的灰度图片以 JPEG 格式保存。
'''import sys, os
sys.path.insert(0, os.getcwd())from torchvision.datasets import MNIST
import PIL
from tqdm import tqdmif __name__ == "__main__":home_dir = os.path.abspath('.')root = os.path.abspath(os.path.join(home_dir, '../dataset_file'))print(root)# exit(0)# 图片保存路径dataset_dir = os.path.join(root, 'mnist_jpg')if not os.path.exists(dataset_dir):os.makedirs(dataset_dir)# 从网络上下载或从本地加载MNIST数据集# 训练集60K、测试集10K# torchvision.datasets.MNIST接口下载的数据一组元组# 每个元组的结构是: (PIL.Image.Image image model=L size=28x28, 标签数字 int)training_dataset = MNIST(root='mnist',train=True,download=True,)test_dataset = MNIST(root='mnist',train=False,download=True,)# 保存训练集图片with tqdm(total=len(training_dataset), ncols=150) as pro_bar:for idx, (X, y) in enumerate(training_dataset):f = dataset_dir + "/" + "training_" + str(idx) + \"_" + str(training_dataset[idx][1] ) + ".jpg" # 文件路径training_dataset[idx][0].save(f)pro_bar.update(n=1)# 保存测试集图片with tqdm(total=len(test_dataset), ncols=150) as pro_bar:for idx, (X, y) in enumerate(test_dataset):f = dataset_dir + "/" + "test_" + str(idx) + \"_" + str(test_dataset[idx][1] ) + ".jpg" # 文件路径test_dataset[idx][0].save(f)pro_bar.update(n=1)
2.定义模型
这里我准备了两个模型,一个MLP模型和一个简单地CNN模型,其中MLP模型参数量1M,CNN模型参数量大概8M,当然这俩模型也没有很仔细的规划
import torch
import torch.nn as nnclass DigitLinear(nn.Module):def __init__(self):super(DigitLinear, self).__init__()self.fc1 = nn.Linear(28 * 28, 1000)self.fc2 = nn.Linear(1000, 500)self.dropout = nn.Dropout(0.3)self.fc3 = nn.Linear(500, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = self.fc1(x)x = torch.relu(x)x = self.dropout(x)x = self.fc2(x)x = torch.relu(x)x = self.fc3(x)return xclass DigitCNN(nn.Module):def __init__(self):super(DigitCNN,self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64*28*28, 128)self.dropout = nn.Dropout(0.1)self.fc2 = nn.Linear(128, 10)def forward(self, x):# print("x.shape:", x.shape)B,N,H,W = x.shapex = self.conv1(x)x = torch.relu(x)x = self.conv2(x)x = torch.relu(x)x = x.view(B, -1) # 展平x = self.fc1(x)x = torch.relu(x)x = self.dropout(x)x = self.fc2(x)return x
3.dataloader和训练
这里的代码就很简单了,就是一些参数的选择,例如epoch,batchsize。其中的训练函数我写的买有很全面,只是勉强满足了训练功能,还有好多可以优化的点,比如打印fps,断点续训练啥的,不过这个任务提不起劲去干这事,大家可以自行优化。
# 数据加载器
from torch.utils.data import DataLoader
from lib.model.DigitModel import DigitLinear,DigitCNN
# 定义数据加载器
batch_size = 256
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)epoch = 50
# 训练模型
net = DigitLinear() # 参数量1M 97.50%
# net = DigitCNN() # 参数量8M 98.81%
net.cuda()# 定义损失函数和优化器
import torch.optim as optim
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练函数def train_model(model, train_loader, criterion, optimizer, num_epochs=10):model.train() # 设置模型为训练模式for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (inputs, labels) in enumerate(train_loader):inputs= inputs.cuda()y = torch.tensor(torch.zeros((inputs.shape[0],10), dtype=torch.float)).cuda()y[torch.arange(inputs.shape[0]), labels] = 1optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, y)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels.cuda()).sum().item()epoch_loss = running_loss / len(train_loader)epoch_acc = 100. * correct / totalprint(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 训练模型
train_model(net, train_loader, criterion, optimizer, num_epochs=epoch)
4.训练模型
有了上面的代码就可以开始训练了,我这里训练的截图是我的MLP模型,效果不是很好,CNN的效果稍微好一点,比MLP高1%,但是图忘记截了。反正够用了,因为本身MNIST的数据就不是很完美,有很多类似于噪声的数据例如:
这些数字我人眼都分不出是什么玩意。
训练效果如下
5.测试模型
训练完当然是测试了
最后我的MLP模型跑了97.50%的准确率
代码如下
# 测试模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.eval()
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device).float(), labels.to(device).float()outputs = net(inputs)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels.cuda()).sum().item()# print(f"Predicted: {predicted}, Ground Truth: {targets}")print(f"Accuracy: {correct / total * 100:.4f} %")
6.保存模型
保存模型代码就更简单了
# 保存模型
torch.save(net.state_dict(), './digit_model.pth')
相关文章:

1. pytorch手写数字预测
1. pytorch手写数字预测 1.背景2.准备数据集2.定义模型3.dataloader和训练4.训练模型5.测试模型6.保存模型 1.背景 因为自身的研究方向是多模态目标跟踪,突然对其他的视觉方向产生了兴趣,所以心血来潮的回到最经典的视觉任务手写数字预测上来࿰…...
vs中添加三方库的流程
在Visual Studio(VS)中添加第三方库(如OpenCV、PCL等)的流程可以分为以下几个步骤:安装库、配置项目、编写代码。以下是详细的步骤说明: 1. 安装第三方库 首先,需要下载并安装所需的第三方库。…...
JAVASE面相对象进阶之static
JavaSE 面向对象进阶之 static 一、static 的核心作用 static 是 Java 中用于修饰成员(属性/方法)的关键字,作用是让成员与类直接关联,而非依赖对象存在。 二、static 修饰属性(静态变量) 特点…...
深入解析 Redis Cluster 架构与实现(一)
#作者:stackofumbrella 文章目录 Redis Cluster特点Redis Cluster与其它集群模式的区别集群目标性能hash tagsMutli-key操作Cluster Bus安全写入(write safety)集群节点的属性集群拓扑节点间handshake重定向与reshardingMOVED重定向ASK重定向…...
(12)java+ selenium->元素定位大法之By_link_text
1.简介 本章节介绍元素定位中的link_text,顾名思义是通过链接定位的(官方说法:超链接文本定位)。什么是link_text呢,就是我们在任何一个网页上都可以看到有一个或者多个链接,上面有一个文字描述,点击这个文字,就可以跳转到其他页面。这个就是link_Text。 注意:link_t…...
数据库MySQL集群MGR
一、MGR原理 一、基本定义 MGR(MySQL Group Replication) 是 MySQL 官方推出的一种高可用、高可靠的数据库集群解决方案,基于分布式系统理论(如 Paxos 协议变种)实现,主要用于构建强一致性的主从复制集群…...
Ubuntu22.04 安装 ROS2 Humble
ROS2 Documentation: Humble Ubuntu 22.04 对应的 ROS 2 版本是 ROS 2 Humble Hawksbill (LTS)。 1.设置系统区域 确保区域设置支持UTF-8 sudo apt update && sudo apt install locales sudo locale-gen en_US en_US.UTF-8 sudo update-locale LC_ALLen_US.UTF-8 L…...
Spring Boot,注解,@RestController
RestController 是 Spring MVC 中用于创建 RESTful Web 服务的核心注解。 RestController 核心知识点 REST 作用: RestController 是一个方便的组合注解,它结合了 Controller 和 ResponseBody 两个注解。 Controller: 将类标记为一个控制器,使其能够处理…...
C++中新式类型转换static_cast、const_cast、dynamic_cast、reinterpret_cast
C中新式类型转换static_cast、const_cast、dynamic_cast、reinterpret_cast 在C中,新式类型转换(也称为强制类型转换)是C标准引入的一种更安全、更明确的类型转换方式,用以替代C语言风格的类型转换。C提供了四种新式类型转换操作…...

AXI 协议补充(二)
axi协议存在slave 和master 之间的数据交互,在ahb ,axi-stream 高速接口 ,叠加大位宽代码逻辑中,往往有时序问题,valid 和ready 的组合电路中的问题引发的时序问题较多。 本文根据axi 协议和现有解决反压造成的时序问题的方法做一个详细的科普。 1. 解决时序问题的方法:…...

Linux 基础指令入门指南:解锁命令行的实用密码
文章目录 引言:Linux 下基本指令常用选项ls 指令pwd 命令cd 指令touch 指令mkdir 指令rmdir 指令 && rm 指令man 指令cp 指令mv 指令cat 指令more 指令less 指令head 指令tail 指令date 指令cal 指令find 指令按文件名搜索按文件大小搜索按修改时间搜索按文…...

标准精读:2025 《可信数据空间 技术架构》【附全文阅读】
《可信数据空间 技术架构》规范了可信数据空间的技术架构,明确其作为国家数据基础设施的定位,以数字合约和使用控制技术为核心,涵盖功能架构(含服务平台与接入连接器的身份管理、目录管理、数字合约管理等功能)、业务流程(登记、发现、创建空间及数据流通利用)及安全要求…...

山东大学软件学院项目实训-基于大模型的模拟面试系统-面试官和面试记录的分享功能(2)
本文记录在发布文章时,可以添加自己创建的面试官和面试记录到文章中这一功能的实现。 前端 首先是在原本的界面的底部添加了两个多选框(后期需要美化调整) 实现的代码: <el-col style"margin-top: 1rem;"><e…...

Webug4.0靶场通关笔记05- 第5关SQL注入之过滤关键字
目录 一、代码审计 1、源码分析 2、SQL注入分析 (1)大小写绕过 (2)双写绕过 二、第05关 过滤型注入 1、进入靶场 2、sqlmap渗透 (1)bp抓包保存报文 (2)sqlmap渗透 &…...

ONLYOFFICE文档API:更强的安全功能
在数字化办公时代,文档的安全性与隐私保护已成为企业和个人用户的核心关切。如何确保信息在存储、传输及协作过程中的安全,是开发者与IT管理者亟需解决的问题。ONLYOFFICE作为一款功能强大的开源办公套件,不仅提供了高效的文档编辑与协作体验…...
深入浅出MQTT协议:从物联网基础到实战应用全解析
深入浅出MQTT协议:从物联网基础到实战应用全解析 作为一名在物联网领域摸爬滚打多年的老程序员,今天来和大家聊聊物联网通信中最核心的技术之一——MQTT协议。无论是Java后端开发还是嵌入式硬件开发,掌握MQTT都能让你在物联网项目中如鱼得水…...

解析楼宇自控系统:分布式结构的核心特点与优势展现
在建筑智能化发展的进程中,楼宇自控系统作为实现建筑高效运行与管理的关键,其系统结构的选择至关重要。传统的集中式楼宇自控系统在面对日益复杂的建筑环境和多样化的管理需求时,逐渐暴露出诸多弊端,如可靠性低、扩展性差、响应速…...

C#数字图像处理(三)
文章目录 前言1.图像平移1.1 图像平移定义1.2 图像平移编程实例 2.图像镜像2.1 图像镜像定义2.2 图像镜像编程实例 3.图像缩放3.1 图像缩放定义3.2 灰度插值法3.3 图像缩放编程实例 4.图像旋转4.1 图像旋转定义4.2 图像旋转编程实例 前言 在某种意义上来说,图像的几…...
STM32 智能小车项目 L298N 电机驱动模块
今天开始着手做智能小车的项目了 在智能小车或机器人项目中,我们经常会听到一个词叫 “H 桥电机驱动”,尤其是常见的 L298N 模块,就是基于“双 H 桥”原理设计的。那么,“H 桥”到底是什么?为什么要用“双 H 桥”来驱动…...

SQL Transactions(事务)、隔离机制
目录 Why Transactions? Example: Bad Interaction Transactions ACID Transactions COMMIT ROLLBACK How the Transaction Log Works How Data Is Stored Example: Interacting Processes Interleaving of Statements Example: Strange Interleaving Fixing the…...
【动画】unity中实现骨骼蒙皮动画
我是一名资深的游戏客户端,没事的时候我就想手搓轮子 本文目标 搓一个骨骼动画的核心实现,促进理解骨骼动画本质 骨骼动画简介 官方解释上网搜或者问豆包 快速理解 想知道骨骼动画怎么个事要先知道模型是怎么个事 简单来说:模型 顶点数…...
VSCODE的终端无法执行npm命令
问题原因:PowerShell 默认可能限制脚本执行。 解决方法: 在 PowerShell 中运行以下命令,查看当前策略: Get-ExecutionPolicy 如果结果是 Restricted,改为 RemoteSigned: Set-ExecutionPolicy RemoteSigne…...
Langchian - 自定义提示词模板 提取结构化的数据
场景:从自然语言中提取固定结构信息返回 例如:根据一段文字,提取文字中人的具体特征 马路上走来一个1米7的女生,她一头乌黑的长发披在肩上随风飘动,在她旁边的是她的男朋友,叫:刘山;比她高10厘米 如果想要提取上面这句话中人的身高及头发的颜色,并以固定的格式返回,…...

【机器学习基础】机器学习入门核心:Jaccard相似度 (Jaccard Index) 和 Pearson相似度 (Pearson Correlation)
机器学习入门核心:Jaccard相似度 (Jaccard Index) 和 Pearson相似度 (Pearson Correlation) 一、算法逻辑Jaccard相似度 (Jaccard Index)**Pearson相似度 (Pearson Correlation)** 二、算法原理与数学推导1. Jaccard相…...

QT之头像剪裁效果实现
文章目录 源码地址,环境:QT5.15,MinGW32位效果演示导入图片设置剪裁区域创建剪裁小窗口重写剪裁小窗口的鼠标事件mousePressEventmouseMoveEventmouseReleaseEvent 小窗口移动触发父窗口的重绘事件剪裁效果实现 源码地址,环境&…...
apptrace 视角下移动端深度链接技术与优势
官网链接:AppTrace - 专业的移动应用推广追踪平台 App 拉起,本质上是移动端深度链接技术的具象化呈现。在这一领域,apptrace 凭借前沿技术与创新理念,实现从 H5 网页到 App 的无缝跳转,精准定位 App 内指定页面&#…...
微前端之micro-app数据通信
在这之前如果还没接触过微前端,可以找一些视频、资料先去了解一下,就不在这里赘述了。 现在常见的微前端框架包括: single-spa micro-app qiankun EMP 无界 目前了解到的基本上是这些哈,大家感兴趣可以自行去了解一下,看下它们之间的区别。 因为我目前使用的是mic…...

【GPT入门】第40课 vllm与ollama特性对比,与模型部署
【GPT入门】第40课 vllm与ollama特性对比,与模型部署 1.两种部署1.1 vllm与ollama特性对比2. vllm部署2.1 服务器准备2.1 下载模型2.2 提供模型服务 1.两种部署 1.1 vllm与ollama特性对比 2. vllm部署 2.1 服务器准备 在autodl 等大模型服务器提供商,…...

unity开发棋牌游戏
使用unity开发的棋牌游戏,目前包含麻将、斗地主、比鸡、牛牛四种玩法游戏。 相关技术 客户端:unity 热更新:xlua 服务器:c Web服务器:ruoyi 游戏视频 unity开发棋牌游戏 游戏截图...

Nat Commun项目文章 ▏小麦CUTTag助力解析转录因子TaTCP6调控小麦氮磷高效利用机制
今年2月份发表在《Nature Communications》(IF14.4)的“TaTCP6 is required for efficientand balanced utilization of nitrate and phosphorus in wheat”揭示了TaTCP6在小麦氮磷利用中的关键调控作用,为优化肥料利用和提高作物产量提供了理…...