当前位置: 首页 > news >正文

深度学习手写字符识别:训练模型

说明

本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别。
第一个深度学习实例手写字符识别

深度学习环境配置

可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。
Windows11搭建GPU版本PyTorch环境详细过程

数据集

手写字符识别用到的数据集是MNIST数据集(Mixed National Institute of Standards and Technology database);MNIST是一个用来训练各种图像处理系统二进制图像数据集,广泛应用到机器学习中的训练和测试。
作为一个入门级的计算机视觉数据集,发布20多年来,它已经被无数机器学习入门者应用无数遍,是最受欢迎的深度学习数据集之一。

序号说明
发布方National Institute of Standards and Technology(美国国家标准技术研究所,简称NIST)
发布时间1998
背景该数据集的论文想要证明在模式识别问题上,基于CNN的方法可以取代之前的基于手工特征的方法,所以作者创建了一个手写数字的数据集,以手写数字识别作为例子证明CNN在模式识别问题上的优越性。
简介MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

跟着视频跑源码

  1. 下载源码:mivlab/AI_course (github.com)
  2. 下载数据集:https://opendatalab.com/MNIST;网上下载的地址比较多,也可以直接下载B站中国计量大学杨老师的百度网盘位置里的MNIST。

运行源码

  1. 在Pycharm中打开AI_course项目,运行classify_pytorch文件目录里train_mnist.py的Python文件。
    在这里插入图片描述
    train_mnist.py具体的源码如下:
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms, models
import argparse
import os
from torch.utils.data import DataLoaderfrom dataloader import mnist_loader as ml
from models.cnn import Net
from toonnx import to_onnxparser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train')
parser.add_argument('--use_cuda', default=False, help='using CUDA for training')args = parser.parse_args()
args.cuda = args.use_cuda and torch.cuda.is_available()
if args.cuda:torch.backends.cudnn.benchmark = Truedef train():os.makedirs('./output', exist_ok=True)if True: #not os.path.exists('output/total.txt'):ml.image_list(args.datapath, 'output/total.txt')ml.shuffle_split('output/total.txt', 'output/train.txt', 'output/val.txt')train_data = ml.MyDataset(txt='output/train.txt', transform=transforms.ToTensor())val_data = ml.MyDataset(txt='output/val.txt', transform=transforms.ToTensor())train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size)model = Net(10)#model = models.vgg16(num_classes=10)#model = models.resnet18(num_classes=10)  # 调用内置模型#model.load_state_dict(torch.load('./output/params_10.pth'))#from torchsummary import summary#summary(model, (3, 28, 28))if args.cuda:print('training with cuda')model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 30], 0.1)loss_func = nn.CrossEntropyLoss()for epoch in range(args.epochs):# training-----------------------------------model.train()train_loss = 0train_acc = 0for batch, (batch_x, batch_y) in enumerate(train_loader):if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)  # 256x3x28x28  out 256x10loss = loss_func(out, batch_y)train_loss += loss.item()pred = torch.max(out, 1)[1]train_correct = (pred == batch_y).sum()train_acc += train_correct.item()print('epoch: %2d/%d batch %3d/%d  Train Loss: %.3f, Acc: %.3f'% (epoch + 1, args.epochs, batch, math.ceil(len(train_data) / args.batch_size),loss.item(), train_correct.item() / len(batch_x)))optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()  # 更新learning rateprint('Train Loss: %.6f, Acc: %.3f' % (train_loss / (math.ceil(len(train_data)/args.batch_size)),train_acc / (len(train_data))))# evaluation--------------------------------model.eval()eval_loss = 0eval_acc = 0for batch_x, batch_y in val_loader:if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)loss = loss_func(out, batch_y)eval_loss += loss.item()pred = torch.max(out, 1)[1]num_correct = (pred == batch_y).sum()eval_acc += num_correct.item()print('Val Loss: %.6f, Acc: %.3f' % (eval_loss / (math.ceil(len(val_data)/args.batch_size)),eval_acc / (len(val_data))))# 保存模型。每隔多少帧存模型,此处可修改------------if (epoch + 1) % 1 == 0:# torch.save(model, 'output/model_' + str(epoch+1) + '.pth')torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')#to_onnx(model, 3, 28, 28, 'params.onnx')if __name__ == '__main__':train()
  1. 报错:没有cv2,即没有安装OpenCV库。
    在这里插入图片描述
  2. 安装OpenCV库,可以命令行安装,也可以Pycharm中安装。
  • 命令行激活虚拟环境:conda activate deeplearning
  • 命令行安装: pip install opencv-python(也可以Pycharm中下载,可能上梯子安装更快)
    在这里插入图片描述
  1. 再次运行,出现如下图提示,表明需要将下载好的数据集配置到configure中。
    在这里插入图片描述
  2. 加载下载好的数据集,即--datapath=数据集的路径
    在这里插入图片描述
  3. 点击“Run”,开始训练,损失和准确率在一直更新,持续训练,直到模型完成,未改动源码的情况下,训练时间可能需要较长。
    在这里插入图片描述
  4. 在小编的拯救者笔记本电脑上持续训练了10小时才完成最终的模型训练,可以看到训练损失已经很低了,准确度很高水平。
    在这里插入图片描述
  5. 在项目中output文件夹中可以看到已经训练好了很多模型;后面可以利用模型进行推理了。
    在这里插入图片描述

参考

https://zhuanlan.zhihu.com/p/681236488

相关文章:

深度学习手写字符识别:训练模型

说明 本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别。 第一个深度学习实例手写字符识别 深度学习环境配置 可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。 Windows11搭建GPU版本PyTorch环境详细过程 数…...

Day 1. 学习linux高级编程之Shell命令和IO

1.C语言基础 现阶段学习安排 2.IO编程 多任务编程(进程、线程) 网络编程 数据库编程 3.数据结构 linux软件编程 1.linux: 操作系统:linux其实是操作系统的内核 系统调用:linux内核的函数接口 操作流程&#xff…...

STM32--SPI通信协议(1)SPI基础知识总结

前言 I2C (Inter-Integrated Circuit)和SPI (Serial Peripheral Interface)是两种常见的串行通信协议,用于连接集成电路芯片之间的通信,选择I2C或SPI取决于具体的应用需求。如果需要较高的传输速度和简单的接口,可以选择SPI。如果需要连接多…...

Debezium系列之:MariaDB10.5以上版本赋予数据库账号读取binlog权限的变化

Debezium系列之:MariaDB10.5以上版本赋予数据库账号读取binlog权限的变化 一、背景二、BINLOG MONITOR权限三、BINLOG MONITOR和REPLICA MONITOR的区别四、MariaDB版本升级的影响五、总结一、背景 数据接入会检测账号是否具有REPLICATION SLAVE、REPLICATION CLIENT的权限Mari…...

迅为STM32MP157开发板底板板载4G接口(选配)、千兆以太网、WIFI蓝牙模块

底板扩展接口丰富 底板板载4G接口(选配)、千兆以太网、WIFI蓝牙模块HDMI、CAN、RS485、LVDS接口、温湿度传感器(选配)光环境传感器、六轴传感器、2路USB OTG、3路串口CAMERA接口、ADC电位器、SPDIF、SDIO接口等。 支持多种显示屏 迅为在MP157开发板支持了多种屏幕&#xff0…...

「实用分享」用界面组件Telerik UI for Blazor增强你的财务图表!

Telerik UI for Blazor拥有110个原生的、易于定制的Blazor UI组件和高性能网格组件,能节约一半的时间开发全新的Blazor应用程序并使传统web项目现代化,其中囊括了设计和生成工具等。Telerik UI for Blazor控件提供的控件,可轻松满足应用程序对…...

使用org.openscada.utgard java opcda库做opc客户端时长期运行存在的若干问题

牛11月09日反馈东区存在以下问题,由于在现场未来得及处理。11月10日反馈西区亦存在此问题。经排查此问题已存在相当长一段时间(最长为9月底即存在)。 1、读报错Value: [[org.jinterop.dcom.core.VariantBody$EMPTY212c0aff]], Timestamp: Mo…...

杰克与魔法树的冒险

从前有一个小村庄,里面住着一个善良勇敢的小男孩叫杰克。杰克非常喜欢冒险和探索未知的事物。 一天,杰克听说村庄附近的森林里有一个神奇的魔法树,树上结满了金色的苹果。他决定去寻找这棵魔法树,并带回一些金苹果给村庄的居民们。…...

第九节HarmonyOS 常用基础组件22-Marquee

1、描述 跑马灯组件,用于滚动展示一段单行文本,仅当文本内容宽度超过跑马灯组件宽度时滚动。 2、接口 Marquee(value:{start:boolean, step?:number, loop?:number, fromStart?: boolean ,src:string}) 3、参数 参数名 参数类型 必填 描述 st…...

烽火传递

看似很简单的单调队列优化DP 但是如果状态是表示前\(i\)个烽火台被处理完的最小代价(即不知道最后一个烽火台在哪里)就无法降低复杂度 因为假设你在区间\([i-m1,i]\)中枚举最后一个烽火台(设为\(k\)),你前面的状态并不是\(f[k-1]\),因为此时\(k\)已经可以…...

《深入浅出Go语言》大纲

目录 为什么选择《深入浅出Go语言》? 基础核心模块 为什么选择《深入浅出Go语言》? 🚀 全面的基础知识体系 从环境搭建开始,对Go语言核心知识点进行深入探讨,深度挖掘每个基础知识的本质,为后续深入学习…...

flv视频格式批量截取封面图(不占内存版)--其他视频格式也通用

flv视频格式批量截取封面图(不占内存版)--其他视频格式也通用 需求(实现的效果)功能实现htmlcssjs 需求(实现的效果) 批量显示视频,后端若返回有imgUrl,则直接显示图1, 若无&#xf…...

【鸿蒙】大模型对话应用(三):跨Ability跳转页面

Demo介绍 本demo对接阿里云和百度的大模型API,实现一个简单的对话应用。 DecEco Studio版本:DevEco Studio 3.1.1 Release HarmonyOS SDK版本:API9 关键点:ArkTS、ArkUI、UIAbility、网络http请求、列表布局、层叠布局 页面跳…...

明道云入选亿欧智库《AIGC入局与低代码产品市场的发展研究》

2023年12月27日,亿欧智库正式发布**《AIGC入局与低代码产品市场的发展研究》**。该报告剖析了低代码/零代码市场的现状和发展趋势,深入探讨了大模型技术对此领域的影响和发展洞察。其中,亿欧智库将明道云作为标杆产品进行了研究和分析。 明…...

【深度学习】SDXL TensorRT Dockerfile Docker容器

文章目录 过程SDXL TensorRT构建SDXL TensorRT LCM 调度器过程 docker push kevinchina/deeplearning:cuda12.1torch2.1.1 FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive# 安装基本软件包 RUN apt-get update && \apt-get u…...

深入了解 Ansible:全面掌握自动化 IT 环境的利器

本文以详尽的篇幅介绍了 Ansible 的方方面面,旨在帮助读者从入门到精通。无论您是初学者还是有一定经验的 Ansible 用户,都可以在本文中找到对应的内容,加深对 Ansible 的理解和应用。愿本文能成为您在 Ansible 自动化旅程中的良师益友&#…...

PPT、PDF全文档翻译相关产品调研笔记

主要找一下是否有比较给力的全文档翻译 文章目录 1 百度翻译2 小牛翻译3 腾讯交互翻译4 DeepL5 languagex6 云译科技7 快翻:qtrans8 simplifyai9 officetranslator10 火山引擎翻译-无文档翻译1 百度翻译 地址: https://fanyi.baidu.com/ 配套的比较完善,对于不同行业也有区…...

JavaScript 垃圾回收的常用策略和内存管理

垃圾回收 ​ JavaScript 是使用垃圾回收的语言,也就是说执行环境负责在代码执行时管理内存。在 C 和 C等语言中,跟踪内存使用对开发者来说是个很大的负担,也是很多问题的来源。JavaScript 为开发者卸下了这个负担,通过自动内存管…...

如何结合ChatGPT生成个人魔法咒语词库

3.6.1 ChatGPT辅助力AI绘画 3.6.1.1 给定主题让ChatGPT直接描述 上面给了一个简易主题演示一下,这是完全我没有细化的提问,然后把直接把这些关键词组合在一起。 关键词: 黄山的美景,生机勃勃,湛蓝天空,青…...

瑞_23种设计模式_抽象工厂模式

文章目录 1 抽象工厂模式(Abstract Factory Pattern)1.1 概念1.2 介绍1.3 小结1.4 结构 2 案例一2.1 案例需求2.2 代码实现 3 案例二3.1 需求3.2 实现 4 总结4.1 抽象工厂模式优缺点4.2 抽象工厂模式使用场景4.3 抽象工厂模式VS工厂方法模式4.4 抽象工厂…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装(Encapsulation) 定义:将数据(属性)和操作数据的方法绑定在一起,通过访问控制符(private、protected、public)隐藏内部实现细节。示例: public …...

【Oracle APEX开发小技巧12】

有如下需求: 有一个问题反馈页面,要实现在apex页面展示能直观看到反馈时间超过7天未处理的数据,方便管理员及时处理反馈。 我的方法:直接将逻辑写在SQL中,这样可以直接在页面展示 完整代码: SELECTSF.FE…...

vscode(仍待补充)

写于2025 6.9 主包将加入vscode这个更权威的圈子 vscode的基本使用 侧边栏 vscode还能连接ssh? debug时使用的launch文件 1.task.json {"tasks": [{"type": "cppbuild","label": "C/C: gcc.exe 生成活动文件"…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站:https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本: Windows版(推荐下载标准版) Windows系统安装步骤 运行安装程序: 双击下载的.exe安装文件 如果出现安全提示&…...

12.找到字符串中所有字母异位词

🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...

IT供电系统绝缘监测及故障定位解决方案

随着新能源的快速发展,光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域,IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选,但在长期运行中,例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...

云原生玩法三问:构建自定义开发环境

云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

Linux nano命令的基本使用

参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时,显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...

Docker拉取MySQL后数据库连接失败的解决方案

在使用Docker部署MySQL时,拉取并启动容器后,有时可能会遇到数据库连接失败的问题。这种问题可能由多种原因导致,包括配置错误、网络设置问题、权限问题等。本文将分析可能的原因,并提供解决方案。 一、确认MySQL容器的运行状态 …...

stm32wle5 lpuart DMA数据不接收

配置波特率9600时,需要使用外部低速晶振...