当前位置: 首页 > news >正文

[Pytorch]手写数字识别——真·手写!

Github网址:https://github.com/diaoquesang/pytorchTutorials/tree/main

本教程创建于2023/7/31,几乎所有代码都有对应的注释,帮助初学者理解dataset、dataloader、transform的封装,初步体验调参的过程,初步掌握opencv、pandas、os等库的使用,😋纯手撸手写数字识别项目(为减少代码量简化了部分数据集相关操作),全流程跑通Pytorch!❤️❤️❤️
This tutorial was created on 2023/7/31. Almost all the code has corresponding comments, to help beginners understand dataset, dataloader, transform packaging, preliminary experience of the process of tuning the parameters, the initial grasp of the use of libraries such as opencv, pandas, os, etc., 😋 and get involved in this handwritten digit recognition project (we simplified some dataset-related operations in order to reduce the amount of code). Enjoy the whole process of running Pytorch!❤️❤️❤️

如果喜欢本项目的话,留下你的⭐吧!
Give me a ⭐ if you like this project!

一、train.py

import torch
import torchvisionfrom torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transformsimport os
import cv2 as cv
import pandas as pdclass myDataset(Dataset):  # 定义数据集类def __init__(self, annotations_file, img_dir, transform=None,target_transform=None):  # 传入参数(标签路径,图像路径,图像预处理方式,标签预处理方式)self.img_labels = pd.read_csv(annotations_file, sep=" ", header=None)# 从标签路径中读取标签,sep为划分间隔符,header为列标题的行位置self.img_dir = img_dir  # 读取图像路径self.transform = transform  # 读取图像预处理方式self.target_transform = target_transform  # 读取标签预处理方式def __len__(self):return len(self.img_labels)  # 读取标签数量作为数据集长度def __getitem__(self, idx):  # 从数据集中取出数据img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])# 从标签对象中取出第idx行第0列(第0列为图像位置所在列)的值(numberImages\5.bmp),并与图像路径(numberImages)进行拼接image = cv.imread(img_path)  # 用openCV的imread函数读取图像label = self.img_labels.iloc[idx, 1]  # 从标签对象中取出第idx行第1列(第1列为图像标签所在列)的值(5)if self.transform:image = self.transform(image)  # 图像预处理if self.target_transform:label = self.target_transform(label)  # 标签预处理return image, label  # 返回图像和标签class myTransformMethod1():  # Python3默认继承object类def __call__(self, img):  # __call___,让类实例变成一个可以被调用的对象,像函数img = cv.resize(img, (28, 28))  # 改变图像大小img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # 将BGR(openCV默认读取为BGR)改为RGBreturn img  # 返回预处理后的图像# 测试函数
# print(pd.read_csv("annotations.txt", sep=" ", header=None))
# print(os.path.join("numberImages", pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 0]))
# print(pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 1])
# cv.imshow("1",cv.imread(os.path.join("numberImages", pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 0])))
# cv.waitKey(0)class myNetwork(nn.Module):  # 定义神经网络def __init__(self):super().__init__()  # 继承nn.Module的构造器self.flatten = nn.Flatten(-3, -1)# 继承nn.Module的Flatten函数并改为flatten,考虑到推理时没有batch(CHW),若使用默认值(1,-1)会导致C没有被flatten,故使用(-3,-1)self.linear_relu_stack = nn.Sequential(  # 定义前向传播序列nn.Linear(3 * 28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):  # 定义前向传播方法x = self.flatten(x)logits = self.linear_relu_stack(x)return logits# 设置运行环境,默认为cuda,若cuda不可用则改为mps,若mps也不可用则改为cpu
device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"Using {device} device")  # 输出运行环境model = myNetwork().to(device)  # 创建神经网络模型实例# 设置超参数
learning_rate = 1e-5  # 学习率
batch_size = 8  # 每批数据数量
epochs = 3000  # 总轮数img_path = "./numberImages"  # 设置图像路径
label_path = "./annotations.txt"  # 设置标签路径myTransform = transforms.Compose([myTransformMethod1(), transforms.ToTensor()])
# 定义图像预处理组合,ToTensor()中Pytorch将HWC(openCV默认读取为height,width,channel)改为CHW,并将值[0,255]除以255进行归一化[0,1]myDataset = myDataset(label_path, img_path, myTransform)  # 创建数据集实例myDataLoader = DataLoader(myDataset, batch_size=batch_size,shuffle=True)
# 创建数据读取器(可对训练集和测试集分别创建),batch_size为每批数据数量(一般为2的n次幂以提高运行速度),shuffle为随机打乱数据def train():# 根据epochs(总轮数)训练for epoch in range(epochs):totalLoss = 0# 分批读取数据for batch, (images, labels) in enumerate(myDataLoader):# 数据转换到对应运行环境images = images.to(device)labels = labels.to(device)pred = model(images)  # 前向传播myLoss = nn.CrossEntropyLoss()  # 定义损失函数(交叉熵)optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # 定义优化器loss = myLoss(pred, labels)  # 计算损失函数totalLoss += loss  # 计入总损失函数loss.backward()  # 反向传播optimizer.step()  # 更新权重optimizer.zero_grad()  # 清空梯度if batch % 1 == 0:  # 每隔1个batch输出1次lossloss, current = loss.item(), min((batch + 1) * batch_size,len(myDataset))print(f"epoch: {epoch:>5d} loss: {loss:>7f}  [{current:>5d}/{len(myDataset):>5d}]")if epoch == 0:minTotalLoss = totalLossif totalLoss < minTotalLoss:print("······························模型已保存······························")minTotalLoss = totalLosstorch.save(model, "./myModel.pth")  # 保存性能最好的模型if __name__ == "__main__":model.train()  # 设置训练模式train()

二、eval.py

import torch
import torchvisionfrom torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transformsimport os
import cv2 as cv
import pandas as pdclass myTransformMethod1():  # Python3默认继承object类def __call__(self, img):  # __call___,让类实例变成一个可以被调用的对象,像函数img = cv.resize(img, (28, 28))  # 改变图像大小img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # 将BGR(openCV默认读取为BGR)改为RGBreturn img  # 返回预处理后的图像class myNetwork(nn.Module):  # 定义神经网络def __init__(self):super().__init__()  # 继承nn.Module的构造器self.flatten = nn.Flatten(-3, -1)# 继承nn.Module的Flatten函数并改为flatten,考虑到推理时没有batch(CHW),若使用默认值(1,-1)会导致C没有被flatten,故使用(-3,-1)self.linear_relu_stack = nn.Sequential(  # 定义前向传播序列nn.Linear(3 * 28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):  # 定义前向传播方法x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsif __name__ == "__main__":model = torch.load("./myModel.pth").to("cuda")  # 载入模型model.eval()  # 设置推理模式myTransform = transforms.Compose([myTransformMethod1(), transforms.ToTensor()])# 定义图像预处理组合,ToTensor()中Pytorch将HWC(openCV默认读取为height,width,channel)改为CHW,并将值[0,255]除以255进行归一化[0,1]for i in range(10):img = cv.imread("./numberImages/"+str(i)+".bmp")  # 用openCV的imread函数读取图像img = myTransform(img).to("cuda")  # 图像预处理print(torch.argmax(model(img)))

三、其余资料详见Github

相关文章:

[Pytorch]手写数字识别——真·手写!

Github网址&#xff1a;https://github.com/diaoquesang/pytorchTutorials/tree/main 本教程创建于2023/7/31&#xff0c;几乎所有代码都有对应的注释&#xff0c;帮助初学者理解dataset、dataloader、transform的封装&#xff0c;初步体验调参的过程&#xff0c;初步掌握openc…...

android studio 找不到符号类 Canvas 或者 错误: 程序包java.awt不存在

android studio开发提示 解决办法是&#xff1a; import android.graphics.Canvas; import android.graphics.Color; 而不是 //import java.awt.Canvas; //import java.awt.Color;...

AWS——02篇(AWS之服务存储EFS在Amazon EC2上的挂载——针对EC2进行托管文件存储)

AWS——02篇&#xff08;AWS之服务存储EFS在Amazon EC2上的挂载——针对EC2进行托管文件存储&#xff09; 1. 前言2. 关于Amazon EFS2.1 Amazon EFS全称2.2 什么是Amazon EFS2.3 优点和功能2.4 参考官网 3. 创建文件系统3.1 创建 EC2 实例3.2 创建文件系统 4. 在Linux实例上挂载…...

FFmpeg 打包mediacodec 编码帧 MPEGTS

在Android平台上合成视频一般使用MediaCodec进行硬编码&#xff0c;使用MediaMuxer进行封装&#xff0c;但是因为MediaMuxer支持格式有限&#xff0c;一般会采用ffmpeg封装&#xff0c;比如监控一般使用mpeg2ts格式而非MP4,这是因为两者对帧时pts等信息封装差异导致应用场景不同…...

软件测试如何推进项目进度?

在软件研发中&#xff0c;有一种思想叫TDD&#xff0c;即测试驱动开发&#xff0c;TDD是敏捷方法中的一项核心实践&#xff0c;其原理是在开发功能代码之前&#xff0c;先编写单元测试用例代码&#xff0c;对要编写的函数或类明确测试方法后&#xff0c;再进行设计与编码。 本…...

首次尝试鸿蒙开发!

今天是我第一次尝试鸿蒙开发&#xff0c;是因为身边的学长有搞这个的&#xff0c;而我也觉得我也该拓宽一下技术栈&#xff01; 首先配置环境&#xff0c;唉~真的是非常心累&#xff0c;下载一个DevEco Studio 3.0.0.993&#xff0c;然后配置环境变量这些操作不用多说&#xff…...

前端面试题-react

1 React 中 keys 的作⽤是什么&#xff1f; Keys 是 React ⽤于追踪哪些列表中元素被修改、被添加或者被移除的辅助标识在开发过程中&#xff0c;我们需要保证某个元素的 key 在其同级元素中具有唯⼀性。在 React Diff 算法中 React 会借助元素的 Key 值来判断该元素是新近创建…...

EIP-2535 Diamond standard 实用工具分享

前段时间工作对接到了这标准的协议&#xff0c;于是简单介绍下这个标准分享下方便前端er使用的调用工具 一、标准的诞生 在写复杂逻辑的solidity智能合约时&#xff0c;经常会碰到两个问题&#xff0c;升级和合约大小限制。 升级目前有几种proxy模式&#xff0c;通过delegateca…...

【LangChain】向量存储(Vector stores)

LangChain学习文档 【LangChain】向量存储(Vector stores)【LangChain】向量存储之FAISS 概要 存储和搜索非结构化数据的最常见方法之一是嵌入它并存储生成的嵌入向量&#xff0c;然后在查询时嵌入非结构化查询并检索与嵌入查询“最相似”的嵌入向量。向量存储负责存储嵌入数…...

Debian/Ubuntu 安装 Chrome 和 Chrome Driver 并使用 selenium 自动化测试

截至目前&#xff0c;Chrome 仍是最好用的浏览器&#xff0c;没有之一。Chrome 不仅是日常使用的利器&#xff0c;通过 Chrome Driver 驱动和 selenium 等工具包&#xff0c;在执行自动任务中也是一绝。相信大家对 selenium 在 Windows 的配置使用已经有所了解了&#xff0c;下…...

[SQL挖掘机] - 窗口函数 - 合计: with rollup

介绍: 在sql中&#xff0c;with rollup 是一种用于在查询结果中生成小计和总计的选项。它可以与 group by 子句一起使用&#xff0c;用于在分组查询的结果中添加附加行。 with rollup 的作用是为每个指定的分组列生成小计&#xff0c;并在最后添加一行总计。这样&#xff0c;…...

远程控制平台一之推拉流的实现

确定框架 在选用推拉流框架的时候,有了解过nginx+rtmp/rtsp,Janus,以及其他开源的推拉流框架,要么是延迟严重(延迟一分多钟),要么配置复杂,而且这些框架对于只是转发远程画面这个简单需求来说,过于庞大了。机缘巧合之下,我了解到了一个简单易用的框架,就是ZeroMQ的…...

RTT(RT-Thread)线程管理(1.2W字详细讲解)

目录 RTT线程管理 线程管理特点 线程工作机制 线程控制块 线程属性 线程状态之间切换 线程相关操作 创建和删除线程 创建线程 删除线程 动态创建线程实例 启动线程 初始化和脱离线程 初始化线程 脱离线程 静态创建线程实例 线程辅助函数 获得当前线程 让出处…...

你真的会自动化吗?Web自动化测试-PO模式实战,一文通透...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 PO模式 Page Obj…...

C# 使用堆栈实现队列

232 使用堆栈实现队列 请你仅使用两个栈实现先入先出队列。队列应当支持一般队列支持的所有操作&#xff08;、、、&#xff09;&#xff1a;pushpoppeekempty 实现 类&#xff1a;MyQueue void push(int x)将元素 x 推到队列的末尾 int pop()从队列的开头移除并返回元素 in…...

git操作:修改本地的地址

Windows下git如何修改本地默认下载仓库地址 - 简书 (jianshu.com) 详细解释&#xff1a; 打开终端拉取git时&#xff0c;会默认在git安装的地方&#xff0c;也就是终端前面的地址。 需要将代码 拉取到D盘的话&#xff0c;现在D盘创建好需要安放代码的文件夹&#xff0c;然后…...

【以图搜图】Python实现根据图片批量匹配(查找)相似图片

目的&#xff1a;可以解决在本地实现根据图片查找相似图片的功能 背景&#xff1a;由于需要查找别人代码保存的图像的命名&#xff0c;但由于数据集是cifa10图像又小又多&#xff0c;所以直接找很费眼睛&#xff0c;所以实现用该代码根据图像查找图像&#xff0c;从而得到保存…...

【无标题】JSP--Java的服务器页面

jsp是什么&#xff1f; jsp的全称是Java server pages,翻译过来就是java的服务器页面。 jsp有什么作用&#xff1f; jsp的主要作用是代替Servlet程序回传html页面的数据&#xff0c;因为Servlet程序回传html页面数据是一件非常繁琐的事情&#xff0c;开发成本和维护成本都非常高…...

【Linux】进程间通信——system V共享内存 | 消息队列 | 信号量

文章目录 一、system V共享内存1. 共享内存的原理2. 共享内存相关函数3. 共享内存实现通信4. 共享内存的特点 二、system V消息队列&#xff08;了解&#xff09;三、system V信号量&#xff08;信号量&#xff09; 一、system V共享内存 1. 共享内存的原理 共享内存是一种在…...

CentOS实现html转pdf

CentOS使用实现html转PDF&#xff0c;需安装以下软件&#xff1a; yum install wkhtmltopdf # 转换工具&#xff0c;将HTML文件或网页转换为PDFyum install xorg-x11-server-Xvfb # 虚拟的X服务器&#xff0c;在无图形界面环境下运行图形应用程yum install wqy-zenhei-fonts #…...

shell脚本--常见案例

1、自动备份文件或目录 2、批量重命名文件 3、查找并删除指定名称的文件&#xff1a; 4、批量删除文件 5、查找并替换文件内容 6、批量创建文件 7、创建文件夹并移动文件 8、在文件夹中查找文件...

无法与IP建立连接,未能下载VSCode服务器

如题&#xff0c;在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈&#xff0c;发现是VSCode版本自动更新惹的祸&#xff01;&#xff01;&#xff01; 在VSCode的帮助->关于这里发现前几天VSCode自动更新了&#xff0c;我的版本号变成了1.100.3 才导致了远程连接出…...

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

Nginx server_name 配置说明

Nginx 是一个高性能的反向代理和负载均衡服务器&#xff0c;其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机&#xff08;Virtual Host&#xff09;。 1. 简介 Nginx 使用 server_name 指令来确定…...

论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)

笔记整理&#xff1a;刘治强&#xff0c;浙江大学硕士生&#xff0c;研究方向为知识图谱表示学习&#xff0c;大语言模型 论文链接&#xff1a;http://arxiv.org/abs/2407.16127 发表会议&#xff1a;ISWC 2024 1. 动机 传统的知识图谱补全&#xff08;KGC&#xff09;模型通过…...

拉力测试cuda pytorch 把 4070显卡拉满

import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试&#xff0c;通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小&#xff0c;增大可提高计算复杂度duration: 测试持续时间&#xff08;秒&…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现&#xff08;两者等价&#xff09;&#xff0c;用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例&#xff1a; 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

(转)什么是DockerCompose?它有什么作用?

一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用&#xff0c;而无需手动一个个创建和运行容器。 Compose文件是一个文本文件&#xff0c;通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...

全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比

目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec&#xff1f; IPsec VPN 5.1 IPsec传输模式&#xff08;Transport Mode&#xff09; 5.2 IPsec隧道模式&#xff08;Tunne…...

深度学习习题2

1.如果增加神经网络的宽度&#xff0c;精确度会增加到一个特定阈值后&#xff0c;便开始降低。造成这一现象的可能原因是什么&#xff1f; A、即使增加卷积核的数量&#xff0c;只有少部分的核会被用作预测 B、当卷积核数量增加时&#xff0c;神经网络的预测能力会降低 C、当卷…...