[Pytorch]手写数字识别——真·手写!
Github网址:https://github.com/diaoquesang/pytorchTutorials/tree/main
本教程创建于2023/7/31,几乎所有代码都有对应的注释,帮助初学者理解dataset、dataloader、transform的封装,初步体验调参的过程,初步掌握opencv、pandas、os等库的使用,😋纯手撸手写数字识别项目(为减少代码量简化了部分数据集相关操作),全流程跑通Pytorch!❤️❤️❤️
This tutorial was created on 2023/7/31. Almost all the code has corresponding comments, to help beginners understand dataset, dataloader, transform packaging, preliminary experience of the process of tuning the parameters, the initial grasp of the use of libraries such as opencv, pandas, os, etc., 😋 and get involved in this handwritten digit recognition project (we simplified some dataset-related operations in order to reduce the amount of code). Enjoy the whole process of running Pytorch!❤️❤️❤️如果喜欢本项目的话,留下你的⭐吧!
Give me a ⭐ if you like this project!
一、train.py
import torch
import torchvisionfrom torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transformsimport os
import cv2 as cv
import pandas as pdclass myDataset(Dataset): # 定义数据集类def __init__(self, annotations_file, img_dir, transform=None,target_transform=None): # 传入参数(标签路径,图像路径,图像预处理方式,标签预处理方式)self.img_labels = pd.read_csv(annotations_file, sep=" ", header=None)# 从标签路径中读取标签,sep为划分间隔符,header为列标题的行位置self.img_dir = img_dir # 读取图像路径self.transform = transform # 读取图像预处理方式self.target_transform = target_transform # 读取标签预处理方式def __len__(self):return len(self.img_labels) # 读取标签数量作为数据集长度def __getitem__(self, idx): # 从数据集中取出数据img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])# 从标签对象中取出第idx行第0列(第0列为图像位置所在列)的值(numberImages\5.bmp),并与图像路径(numberImages)进行拼接image = cv.imread(img_path) # 用openCV的imread函数读取图像label = self.img_labels.iloc[idx, 1] # 从标签对象中取出第idx行第1列(第1列为图像标签所在列)的值(5)if self.transform:image = self.transform(image) # 图像预处理if self.target_transform:label = self.target_transform(label) # 标签预处理return image, label # 返回图像和标签class myTransformMethod1(): # Python3默认继承object类def __call__(self, img): # __call___,让类实例变成一个可以被调用的对象,像函数img = cv.resize(img, (28, 28)) # 改变图像大小img = cv.cvtColor(img, cv.COLOR_BGR2RGB) # 将BGR(openCV默认读取为BGR)改为RGBreturn img # 返回预处理后的图像# 测试函数
# print(pd.read_csv("annotations.txt", sep=" ", header=None))
# print(os.path.join("numberImages", pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 0]))
# print(pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 1])
# cv.imshow("1",cv.imread(os.path.join("numberImages", pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 0])))
# cv.waitKey(0)class myNetwork(nn.Module): # 定义神经网络def __init__(self):super().__init__() # 继承nn.Module的构造器self.flatten = nn.Flatten(-3, -1)# 继承nn.Module的Flatten函数并改为flatten,考虑到推理时没有batch(CHW),若使用默认值(1,-1)会导致C没有被flatten,故使用(-3,-1)self.linear_relu_stack = nn.Sequential( # 定义前向传播序列nn.Linear(3 * 28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x): # 定义前向传播方法x = self.flatten(x)logits = self.linear_relu_stack(x)return logits# 设置运行环境,默认为cuda,若cuda不可用则改为mps,若mps也不可用则改为cpu
device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"Using {device} device") # 输出运行环境model = myNetwork().to(device) # 创建神经网络模型实例# 设置超参数
learning_rate = 1e-5 # 学习率
batch_size = 8 # 每批数据数量
epochs = 3000 # 总轮数img_path = "./numberImages" # 设置图像路径
label_path = "./annotations.txt" # 设置标签路径myTransform = transforms.Compose([myTransformMethod1(), transforms.ToTensor()])
# 定义图像预处理组合,ToTensor()中Pytorch将HWC(openCV默认读取为height,width,channel)改为CHW,并将值[0,255]除以255进行归一化[0,1]myDataset = myDataset(label_path, img_path, myTransform) # 创建数据集实例myDataLoader = DataLoader(myDataset, batch_size=batch_size,shuffle=True)
# 创建数据读取器(可对训练集和测试集分别创建),batch_size为每批数据数量(一般为2的n次幂以提高运行速度),shuffle为随机打乱数据def train():# 根据epochs(总轮数)训练for epoch in range(epochs):totalLoss = 0# 分批读取数据for batch, (images, labels) in enumerate(myDataLoader):# 数据转换到对应运行环境images = images.to(device)labels = labels.to(device)pred = model(images) # 前向传播myLoss = nn.CrossEntropyLoss() # 定义损失函数(交叉熵)optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # 定义优化器loss = myLoss(pred, labels) # 计算损失函数totalLoss += loss # 计入总损失函数loss.backward() # 反向传播optimizer.step() # 更新权重optimizer.zero_grad() # 清空梯度if batch % 1 == 0: # 每隔1个batch输出1次lossloss, current = loss.item(), min((batch + 1) * batch_size,len(myDataset))print(f"epoch: {epoch:>5d} loss: {loss:>7f} [{current:>5d}/{len(myDataset):>5d}]")if epoch == 0:minTotalLoss = totalLossif totalLoss < minTotalLoss:print("······························模型已保存······························")minTotalLoss = totalLosstorch.save(model, "./myModel.pth") # 保存性能最好的模型if __name__ == "__main__":model.train() # 设置训练模式train()
二、eval.py
import torch
import torchvisionfrom torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transformsimport os
import cv2 as cv
import pandas as pdclass myTransformMethod1(): # Python3默认继承object类def __call__(self, img): # __call___,让类实例变成一个可以被调用的对象,像函数img = cv.resize(img, (28, 28)) # 改变图像大小img = cv.cvtColor(img, cv.COLOR_BGR2RGB) # 将BGR(openCV默认读取为BGR)改为RGBreturn img # 返回预处理后的图像class myNetwork(nn.Module): # 定义神经网络def __init__(self):super().__init__() # 继承nn.Module的构造器self.flatten = nn.Flatten(-3, -1)# 继承nn.Module的Flatten函数并改为flatten,考虑到推理时没有batch(CHW),若使用默认值(1,-1)会导致C没有被flatten,故使用(-3,-1)self.linear_relu_stack = nn.Sequential( # 定义前向传播序列nn.Linear(3 * 28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x): # 定义前向传播方法x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsif __name__ == "__main__":model = torch.load("./myModel.pth").to("cuda") # 载入模型model.eval() # 设置推理模式myTransform = transforms.Compose([myTransformMethod1(), transforms.ToTensor()])# 定义图像预处理组合,ToTensor()中Pytorch将HWC(openCV默认读取为height,width,channel)改为CHW,并将值[0,255]除以255进行归一化[0,1]for i in range(10):img = cv.imread("./numberImages/"+str(i)+".bmp") # 用openCV的imread函数读取图像img = myTransform(img).to("cuda") # 图像预处理print(torch.argmax(model(img)))
三、其余资料详见Github
相关文章:
[Pytorch]手写数字识别——真·手写!
Github网址:https://github.com/diaoquesang/pytorchTutorials/tree/main 本教程创建于2023/7/31,几乎所有代码都有对应的注释,帮助初学者理解dataset、dataloader、transform的封装,初步体验调参的过程,初步掌握openc…...
android studio 找不到符号类 Canvas 或者 错误: 程序包java.awt不存在
android studio开发提示 解决办法是: import android.graphics.Canvas; import android.graphics.Color; 而不是 //import java.awt.Canvas; //import java.awt.Color;...
AWS——02篇(AWS之服务存储EFS在Amazon EC2上的挂载——针对EC2进行托管文件存储)
AWS——02篇(AWS之服务存储EFS在Amazon EC2上的挂载——针对EC2进行托管文件存储) 1. 前言2. 关于Amazon EFS2.1 Amazon EFS全称2.2 什么是Amazon EFS2.3 优点和功能2.4 参考官网 3. 创建文件系统3.1 创建 EC2 实例3.2 创建文件系统 4. 在Linux实例上挂载…...
FFmpeg 打包mediacodec 编码帧 MPEGTS
在Android平台上合成视频一般使用MediaCodec进行硬编码,使用MediaMuxer进行封装,但是因为MediaMuxer支持格式有限,一般会采用ffmpeg封装,比如监控一般使用mpeg2ts格式而非MP4,这是因为两者对帧时pts等信息封装差异导致应用场景不同…...
软件测试如何推进项目进度?
在软件研发中,有一种思想叫TDD,即测试驱动开发,TDD是敏捷方法中的一项核心实践,其原理是在开发功能代码之前,先编写单元测试用例代码,对要编写的函数或类明确测试方法后,再进行设计与编码。 本…...
首次尝试鸿蒙开发!
今天是我第一次尝试鸿蒙开发,是因为身边的学长有搞这个的,而我也觉得我也该拓宽一下技术栈! 首先配置环境,唉~真的是非常心累,下载一个DevEco Studio 3.0.0.993,然后配置环境变量这些操作不用多说ÿ…...
前端面试题-react
1 React 中 keys 的作⽤是什么? Keys 是 React ⽤于追踪哪些列表中元素被修改、被添加或者被移除的辅助标识在开发过程中,我们需要保证某个元素的 key 在其同级元素中具有唯⼀性。在 React Diff 算法中 React 会借助元素的 Key 值来判断该元素是新近创建…...
EIP-2535 Diamond standard 实用工具分享
前段时间工作对接到了这标准的协议,于是简单介绍下这个标准分享下方便前端er使用的调用工具 一、标准的诞生 在写复杂逻辑的solidity智能合约时,经常会碰到两个问题,升级和合约大小限制。 升级目前有几种proxy模式,通过delegateca…...
【LangChain】向量存储(Vector stores)
LangChain学习文档 【LangChain】向量存储(Vector stores)【LangChain】向量存储之FAISS 概要 存储和搜索非结构化数据的最常见方法之一是嵌入它并存储生成的嵌入向量,然后在查询时嵌入非结构化查询并检索与嵌入查询“最相似”的嵌入向量。向量存储负责存储嵌入数…...
Debian/Ubuntu 安装 Chrome 和 Chrome Driver 并使用 selenium 自动化测试
截至目前,Chrome 仍是最好用的浏览器,没有之一。Chrome 不仅是日常使用的利器,通过 Chrome Driver 驱动和 selenium 等工具包,在执行自动任务中也是一绝。相信大家对 selenium 在 Windows 的配置使用已经有所了解了,下…...
[SQL挖掘机] - 窗口函数 - 合计: with rollup
介绍: 在sql中,with rollup 是一种用于在查询结果中生成小计和总计的选项。它可以与 group by 子句一起使用,用于在分组查询的结果中添加附加行。 with rollup 的作用是为每个指定的分组列生成小计,并在最后添加一行总计。这样,…...
远程控制平台一之推拉流的实现
确定框架 在选用推拉流框架的时候,有了解过nginx+rtmp/rtsp,Janus,以及其他开源的推拉流框架,要么是延迟严重(延迟一分多钟),要么配置复杂,而且这些框架对于只是转发远程画面这个简单需求来说,过于庞大了。机缘巧合之下,我了解到了一个简单易用的框架,就是ZeroMQ的…...
RTT(RT-Thread)线程管理(1.2W字详细讲解)
目录 RTT线程管理 线程管理特点 线程工作机制 线程控制块 线程属性 线程状态之间切换 线程相关操作 创建和删除线程 创建线程 删除线程 动态创建线程实例 启动线程 初始化和脱离线程 初始化线程 脱离线程 静态创建线程实例 线程辅助函数 获得当前线程 让出处…...
你真的会自动化吗?Web自动化测试-PO模式实战,一文通透...
目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 PO模式 Page Obj…...
C# 使用堆栈实现队列
232 使用堆栈实现队列 请你仅使用两个栈实现先入先出队列。队列应当支持一般队列支持的所有操作(、、、):pushpoppeekempty 实现 类:MyQueue void push(int x)将元素 x 推到队列的末尾 int pop()从队列的开头移除并返回元素 in…...
git操作:修改本地的地址
Windows下git如何修改本地默认下载仓库地址 - 简书 (jianshu.com) 详细解释: 打开终端拉取git时,会默认在git安装的地方,也就是终端前面的地址。 需要将代码 拉取到D盘的话,现在D盘创建好需要安放代码的文件夹,然后…...
【以图搜图】Python实现根据图片批量匹配(查找)相似图片
目的:可以解决在本地实现根据图片查找相似图片的功能 背景:由于需要查找别人代码保存的图像的命名,但由于数据集是cifa10图像又小又多,所以直接找很费眼睛,所以实现用该代码根据图像查找图像,从而得到保存…...
【无标题】JSP--Java的服务器页面
jsp是什么? jsp的全称是Java server pages,翻译过来就是java的服务器页面。 jsp有什么作用? jsp的主要作用是代替Servlet程序回传html页面的数据,因为Servlet程序回传html页面数据是一件非常繁琐的事情,开发成本和维护成本都非常高…...
【Linux】进程间通信——system V共享内存 | 消息队列 | 信号量
文章目录 一、system V共享内存1. 共享内存的原理2. 共享内存相关函数3. 共享内存实现通信4. 共享内存的特点 二、system V消息队列(了解)三、system V信号量(信号量) 一、system V共享内存 1. 共享内存的原理 共享内存是一种在…...
CentOS实现html转pdf
CentOS使用实现html转PDF,需安装以下软件: yum install wkhtmltopdf # 转换工具,将HTML文件或网页转换为PDFyum install xorg-x11-server-Xvfb # 虚拟的X服务器,在无图形界面环境下运行图形应用程yum install wqy-zenhei-fonts #…...
大数据学习栈记——Neo4j的安装与使用
本文介绍图数据库Neofj的安装与使用,操作系统:Ubuntu24.04,Neofj版本:2025.04.0。 Apt安装 Neofj可以进行官网安装:Neo4j Deployment Center - Graph Database & Analytics 我这里安装是添加软件源的方法 最新版…...
Vue记事本应用实现教程
文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展:显示创建时间8. 功能扩展:记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...
Appium+python自动化(十六)- ADB命令
简介 Android 调试桥(adb)是多种用途的工具,该工具可以帮助你你管理设备或模拟器 的状态。 adb ( Android Debug Bridge)是一个通用命令行工具,其允许您与模拟器实例或连接的 Android 设备进行通信。它可为各种设备操作提供便利,如安装和调试…...
在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:
在 HarmonyOS 应用开发中,手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力,既支持点击、长按、拖拽等基础单一手势的精细控制,也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档,…...
Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)
概述 在 Swift 开发语言中,各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过,在涉及到多个子类派生于基类进行多态模拟的场景下,…...
使用分级同态加密防御梯度泄漏
抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...
电脑插入多块移动硬盘后经常出现卡顿和蓝屏
当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...
laravel8+vue3.0+element-plus搭建方法
创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...
佰力博科技与您探讨热释电测量的几种方法
热释电的测量主要涉及热释电系数的测定,这是表征热释电材料性能的重要参数。热释电系数的测量方法主要包括静态法、动态法和积分电荷法。其中,积分电荷法最为常用,其原理是通过测量在电容器上积累的热释电电荷,从而确定热释电系数…...
智能AI电话机器人系统的识别能力现状与发展水平
一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...
