当前位置: 首页 > news >正文

深度学习手写字符识别:训练模型

说明

本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别。
第一个深度学习实例手写字符识别

深度学习环境配置

可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。
Windows11搭建GPU版本PyTorch环境详细过程

数据集

手写字符识别用到的数据集是MNIST数据集(Mixed National Institute of Standards and Technology database);MNIST是一个用来训练各种图像处理系统二进制图像数据集,广泛应用到机器学习中的训练和测试。
作为一个入门级的计算机视觉数据集,发布20多年来,它已经被无数机器学习入门者应用无数遍,是最受欢迎的深度学习数据集之一。

序号说明
发布方National Institute of Standards and Technology(美国国家标准技术研究所,简称NIST)
发布时间1998
背景该数据集的论文想要证明在模式识别问题上,基于CNN的方法可以取代之前的基于手工特征的方法,所以作者创建了一个手写数字的数据集,以手写数字识别作为例子证明CNN在模式识别问题上的优越性。
简介MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

跟着视频跑源码

  1. 下载源码:mivlab/AI_course (github.com)
  2. 下载数据集:https://opendatalab.com/MNIST;网上下载的地址比较多,也可以直接下载B站中国计量大学杨老师的百度网盘位置里的MNIST。

运行源码

  1. 在Pycharm中打开AI_course项目,运行classify_pytorch文件目录里train_mnist.py的Python文件。
    在这里插入图片描述
    train_mnist.py具体的源码如下:
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms, models
import argparse
import os
from torch.utils.data import DataLoaderfrom dataloader import mnist_loader as ml
from models.cnn import Net
from toonnx import to_onnxparser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train')
parser.add_argument('--use_cuda', default=False, help='using CUDA for training')args = parser.parse_args()
args.cuda = args.use_cuda and torch.cuda.is_available()
if args.cuda:torch.backends.cudnn.benchmark = Truedef train():os.makedirs('./output', exist_ok=True)if True: #not os.path.exists('output/total.txt'):ml.image_list(args.datapath, 'output/total.txt')ml.shuffle_split('output/total.txt', 'output/train.txt', 'output/val.txt')train_data = ml.MyDataset(txt='output/train.txt', transform=transforms.ToTensor())val_data = ml.MyDataset(txt='output/val.txt', transform=transforms.ToTensor())train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size)model = Net(10)#model = models.vgg16(num_classes=10)#model = models.resnet18(num_classes=10)  # 调用内置模型#model.load_state_dict(torch.load('./output/params_10.pth'))#from torchsummary import summary#summary(model, (3, 28, 28))if args.cuda:print('training with cuda')model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 30], 0.1)loss_func = nn.CrossEntropyLoss()for epoch in range(args.epochs):# training-----------------------------------model.train()train_loss = 0train_acc = 0for batch, (batch_x, batch_y) in enumerate(train_loader):if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)  # 256x3x28x28  out 256x10loss = loss_func(out, batch_y)train_loss += loss.item()pred = torch.max(out, 1)[1]train_correct = (pred == batch_y).sum()train_acc += train_correct.item()print('epoch: %2d/%d batch %3d/%d  Train Loss: %.3f, Acc: %.3f'% (epoch + 1, args.epochs, batch, math.ceil(len(train_data) / args.batch_size),loss.item(), train_correct.item() / len(batch_x)))optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()  # 更新learning rateprint('Train Loss: %.6f, Acc: %.3f' % (train_loss / (math.ceil(len(train_data)/args.batch_size)),train_acc / (len(train_data))))# evaluation--------------------------------model.eval()eval_loss = 0eval_acc = 0for batch_x, batch_y in val_loader:if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)loss = loss_func(out, batch_y)eval_loss += loss.item()pred = torch.max(out, 1)[1]num_correct = (pred == batch_y).sum()eval_acc += num_correct.item()print('Val Loss: %.6f, Acc: %.3f' % (eval_loss / (math.ceil(len(val_data)/args.batch_size)),eval_acc / (len(val_data))))# 保存模型。每隔多少帧存模型,此处可修改------------if (epoch + 1) % 1 == 0:# torch.save(model, 'output/model_' + str(epoch+1) + '.pth')torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')#to_onnx(model, 3, 28, 28, 'params.onnx')if __name__ == '__main__':train()
  1. 报错:没有cv2,即没有安装OpenCV库。
    在这里插入图片描述
  2. 安装OpenCV库,可以命令行安装,也可以Pycharm中安装。
  • 命令行激活虚拟环境:conda activate deeplearning
  • 命令行安装: pip install opencv-python(也可以Pycharm中下载,可能上梯子安装更快)
    在这里插入图片描述
  1. 再次运行,出现如下图提示,表明需要将下载好的数据集配置到configure中。
    在这里插入图片描述
  2. 加载下载好的数据集,即--datapath=数据集的路径
    在这里插入图片描述
  3. 点击“Run”,开始训练,损失和准确率在一直更新,持续训练,直到模型完成,未改动源码的情况下,训练时间可能需要较长。
    在这里插入图片描述
  4. 在小编的拯救者笔记本电脑上持续训练了10小时才完成最终的模型训练,可以看到训练损失已经很低了,准确度很高水平。
    在这里插入图片描述
  5. 在项目中output文件夹中可以看到已经训练好了很多模型;后面可以利用模型进行推理了。
    在这里插入图片描述

参考

https://zhuanlan.zhihu.com/p/681236488

相关文章:

深度学习手写字符识别:训练模型

说明 本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别。 第一个深度学习实例手写字符识别 深度学习环境配置 可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。 Windows11搭建GPU版本PyTorch环境详细过程 数…...

Day 1. 学习linux高级编程之Shell命令和IO

1.C语言基础 现阶段学习安排 2.IO编程 多任务编程(进程、线程) 网络编程 数据库编程 3.数据结构 linux软件编程 1.linux: 操作系统:linux其实是操作系统的内核 系统调用:linux内核的函数接口 操作流程&#xff…...

STM32--SPI通信协议(1)SPI基础知识总结

前言 I2C (Inter-Integrated Circuit)和SPI (Serial Peripheral Interface)是两种常见的串行通信协议,用于连接集成电路芯片之间的通信,选择I2C或SPI取决于具体的应用需求。如果需要较高的传输速度和简单的接口,可以选择SPI。如果需要连接多…...

Debezium系列之:MariaDB10.5以上版本赋予数据库账号读取binlog权限的变化

Debezium系列之:MariaDB10.5以上版本赋予数据库账号读取binlog权限的变化 一、背景二、BINLOG MONITOR权限三、BINLOG MONITOR和REPLICA MONITOR的区别四、MariaDB版本升级的影响五、总结一、背景 数据接入会检测账号是否具有REPLICATION SLAVE、REPLICATION CLIENT的权限Mari…...

迅为STM32MP157开发板底板板载4G接口(选配)、千兆以太网、WIFI蓝牙模块

底板扩展接口丰富 底板板载4G接口(选配)、千兆以太网、WIFI蓝牙模块HDMI、CAN、RS485、LVDS接口、温湿度传感器(选配)光环境传感器、六轴传感器、2路USB OTG、3路串口CAMERA接口、ADC电位器、SPDIF、SDIO接口等。 支持多种显示屏 迅为在MP157开发板支持了多种屏幕&#xff0…...

「实用分享」用界面组件Telerik UI for Blazor增强你的财务图表!

Telerik UI for Blazor拥有110个原生的、易于定制的Blazor UI组件和高性能网格组件,能节约一半的时间开发全新的Blazor应用程序并使传统web项目现代化,其中囊括了设计和生成工具等。Telerik UI for Blazor控件提供的控件,可轻松满足应用程序对…...

使用org.openscada.utgard java opcda库做opc客户端时长期运行存在的若干问题

牛11月09日反馈东区存在以下问题,由于在现场未来得及处理。11月10日反馈西区亦存在此问题。经排查此问题已存在相当长一段时间(最长为9月底即存在)。 1、读报错Value: [[org.jinterop.dcom.core.VariantBody$EMPTY212c0aff]], Timestamp: Mo…...

杰克与魔法树的冒险

从前有一个小村庄,里面住着一个善良勇敢的小男孩叫杰克。杰克非常喜欢冒险和探索未知的事物。 一天,杰克听说村庄附近的森林里有一个神奇的魔法树,树上结满了金色的苹果。他决定去寻找这棵魔法树,并带回一些金苹果给村庄的居民们。…...

第九节HarmonyOS 常用基础组件22-Marquee

1、描述 跑马灯组件,用于滚动展示一段单行文本,仅当文本内容宽度超过跑马灯组件宽度时滚动。 2、接口 Marquee(value:{start:boolean, step?:number, loop?:number, fromStart?: boolean ,src:string}) 3、参数 参数名 参数类型 必填 描述 st…...

烽火传递

看似很简单的单调队列优化DP 但是如果状态是表示前\(i\)个烽火台被处理完的最小代价(即不知道最后一个烽火台在哪里)就无法降低复杂度 因为假设你在区间\([i-m1,i]\)中枚举最后一个烽火台(设为\(k\)),你前面的状态并不是\(f[k-1]\),因为此时\(k\)已经可以…...

《深入浅出Go语言》大纲

目录 为什么选择《深入浅出Go语言》? 基础核心模块 为什么选择《深入浅出Go语言》? 🚀 全面的基础知识体系 从环境搭建开始,对Go语言核心知识点进行深入探讨,深度挖掘每个基础知识的本质,为后续深入学习…...

flv视频格式批量截取封面图(不占内存版)--其他视频格式也通用

flv视频格式批量截取封面图(不占内存版)--其他视频格式也通用 需求(实现的效果)功能实现htmlcssjs 需求(实现的效果) 批量显示视频,后端若返回有imgUrl,则直接显示图1, 若无&#xf…...

【鸿蒙】大模型对话应用(三):跨Ability跳转页面

Demo介绍 本demo对接阿里云和百度的大模型API,实现一个简单的对话应用。 DecEco Studio版本:DevEco Studio 3.1.1 Release HarmonyOS SDK版本:API9 关键点:ArkTS、ArkUI、UIAbility、网络http请求、列表布局、层叠布局 页面跳…...

明道云入选亿欧智库《AIGC入局与低代码产品市场的发展研究》

2023年12月27日,亿欧智库正式发布**《AIGC入局与低代码产品市场的发展研究》**。该报告剖析了低代码/零代码市场的现状和发展趋势,深入探讨了大模型技术对此领域的影响和发展洞察。其中,亿欧智库将明道云作为标杆产品进行了研究和分析。 明…...

【深度学习】SDXL TensorRT Dockerfile Docker容器

文章目录 过程SDXL TensorRT构建SDXL TensorRT LCM 调度器过程 docker push kevinchina/deeplearning:cuda12.1torch2.1.1 FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04 ENV DEBIAN_FRONTEND=noninteractive# 安装基本软件包 RUN apt-get update && \apt-get u…...

深入了解 Ansible:全面掌握自动化 IT 环境的利器

本文以详尽的篇幅介绍了 Ansible 的方方面面,旨在帮助读者从入门到精通。无论您是初学者还是有一定经验的 Ansible 用户,都可以在本文中找到对应的内容,加深对 Ansible 的理解和应用。愿本文能成为您在 Ansible 自动化旅程中的良师益友&#…...

PPT、PDF全文档翻译相关产品调研笔记

主要找一下是否有比较给力的全文档翻译 文章目录 1 百度翻译2 小牛翻译3 腾讯交互翻译4 DeepL5 languagex6 云译科技7 快翻:qtrans8 simplifyai9 officetranslator10 火山引擎翻译-无文档翻译1 百度翻译 地址: https://fanyi.baidu.com/ 配套的比较完善,对于不同行业也有区…...

JavaScript 垃圾回收的常用策略和内存管理

垃圾回收 ​ JavaScript 是使用垃圾回收的语言,也就是说执行环境负责在代码执行时管理内存。在 C 和 C等语言中,跟踪内存使用对开发者来说是个很大的负担,也是很多问题的来源。JavaScript 为开发者卸下了这个负担,通过自动内存管…...

如何结合ChatGPT生成个人魔法咒语词库

3.6.1 ChatGPT辅助力AI绘画 3.6.1.1 给定主题让ChatGPT直接描述 上面给了一个简易主题演示一下,这是完全我没有细化的提问,然后把直接把这些关键词组合在一起。 关键词: 黄山的美景,生机勃勃,湛蓝天空,青…...

瑞_23种设计模式_抽象工厂模式

文章目录 1 抽象工厂模式(Abstract Factory Pattern)1.1 概念1.2 介绍1.3 小结1.4 结构 2 案例一2.1 案例需求2.2 代码实现 3 案例二3.1 需求3.2 实现 4 总结4.1 抽象工厂模式优缺点4.2 抽象工厂模式使用场景4.3 抽象工厂模式VS工厂方法模式4.4 抽象工厂…...

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...

挑战杯推荐项目

“人工智能”创意赛 - 智能艺术创作助手:借助大模型技术,开发能根据用户输入的主题、风格等要求,生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用,帮助艺术家和创意爱好者激发创意、提高创作效率。 ​ - 个性化梦境…...

利用ngx_stream_return_module构建简易 TCP/UDP 响应网关

一、模块概述 ngx_stream_return_module 提供了一个极简的指令&#xff1a; return <value>;在收到客户端连接后&#xff0c;立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量&#xff08;如 $time_iso8601、$remote_addr 等&#xff09;&a…...

进程地址空间(比特课总结)

一、进程地址空间 1. 环境变量 1 &#xff09;⽤户级环境变量与系统级环境变量 全局属性&#xff1a;环境变量具有全局属性&#xff0c;会被⼦进程继承。例如当bash启动⼦进程时&#xff0c;环 境变量会⾃动传递给⼦进程。 本地变量限制&#xff1a;本地变量只在当前进程(ba…...

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...

CMake 从 GitHub 下载第三方库并使用

有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...

Spring是如何解决Bean的循环依赖:三级缓存机制

1、什么是 Bean 的循环依赖 在 Spring框架中,Bean 的循环依赖是指多个 Bean 之间‌互相持有对方引用‌,形成闭环依赖关系的现象。 多个 Bean 的依赖关系构成环形链路,例如: 双向依赖:Bean A 依赖 Bean B,同时 Bean B 也依赖 Bean A(A↔B)。链条循环: Bean A → Bean…...

ZYNQ学习记录FPGA(一)ZYNQ简介

一、知识准备 1.一些术语,缩写和概念&#xff1a; 1&#xff09;ZYNQ全称&#xff1a;ZYNQ7000 All Pgrammable SoC 2&#xff09;SoC:system on chips(片上系统)&#xff0c;对比集成电路的SoB&#xff08;system on board&#xff09; 3&#xff09;ARM&#xff1a;处理器…...

电脑桌面太单调,用Python写一个桌面小宠物应用。

下面是一个使用Python创建的简单桌面小宠物应用。这个小宠物会在桌面上游荡&#xff0c;可以响应鼠标点击&#xff0c;并且有简单的动画效果。 import tkinter as tk import random import time from PIL import Image, ImageTk import os import sysclass DesktopPet:def __i…...

验证redis数据结构

一、功能验证 1.验证redis的数据结构&#xff08;如字符串、列表、哈希、集合、有序集合等&#xff09;是否按照预期工作。 2、常见的数据结构验证方法&#xff1a; ①字符串&#xff08;string&#xff09; 测试基本操作 set、get、incr、decr 验证字符串的长度和内容是否正…...