基于飞桨paddle的极简方案构建手写数字识别模型测试代码
基于飞桨paddle的极简方案构建手写数字识别模型测试代码

原始测试图片为255X252的图片
因为是极简方案采用的是线性回归模型,所以预测结果数字不一致
本次预测的数字是 [[3]]
测试结果:
PS E:\project\python> & D:/Python39/python.exe e:/project/python/MNIST.py
10.0.0
2.4.2
图像数据形状和对应数据为: (28, 28)
图像标签形状和对应数据为: (1,) [5]打印第一个batch的第一个图像,对应标签数字为[5]
epoch_id: 0, batch_id: 0, loss is: [34.4626]
epoch_id: 0, batch_id: 1000, loss is: [7.599941]
epoch_id: 0, batch_id: 2000, loss is: [4.583123]
epoch_id: 0, batch_id: 3000, loss is: [2.8974648]
epoch_id: 1, batch_id: 0, loss is: [3.610869]
epoch_id: 1, batch_id: 1000, loss is: [5.6290216]
epoch_id: 1, batch_id: 2000, loss is: [1.9465038]
epoch_id: 1, batch_id: 3000, loss is: [2.1046467]
epoch_id: 7, batch_id: 2000, loss is: [4.63013]
epoch_id: 7, batch_id: 3000, loss is: [4.4638147]
epoch_id: 8, batch_id: 0, loss is: [3.0043283]
epoch_id: 8, batch_id: 1000, loss is: [1.633965]
epoch_id: 8, batch_id: 2000, loss is: [3.1906333]
epoch_id: 8, batch_id: 3000, loss is: [2.4461133]
epoch_id: 9, batch_id: 0, loss is: [3.9595613]
epoch_id: 9, batch_id: 1000, loss is: [1.3417265]
epoch_id: 9, batch_id: 2000, loss is: [2.3505783]
epoch_id: 9, batch_id: 3000, loss is: [2.0194921]
原始图像shape: (252, 255)
采样后图片shape: (28, 28)
result Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,[[3.94108272]])
本次预测的数字是 [[3]]
PS E:\project\python>
测试代码如下所示:
#加载飞桨和相关类库
import paddle
from paddle.nn import Linear
import paddle.nn.functional as F
import os
import numpy as np
import matplotlib.pyplot as plt
# 导入图像读取第三方库
from PIL import Image,ImageFilter
print(Image.__version__) #10.0.0
#原来是在pillow的10.0.0版本中,ANTIALIAS方法被删除了,使用新的方法即可Image.LANCZOS
#或降级版本为9.5.0,安装pip install Pillow==9.5.0
print(paddle.__version__) #2.4.2#飞桨提供了多个封装好的数据集API,涵盖计算机视觉、自然语言处理、推荐系统等多个领域,
# 帮助读者快速完成深度学习任务。
# 如在手写数字识别任务中,
# 通过paddle.vision.datasets.MNIST可以直接获取处理好的MNIST训练集、测试集,
# 飞桨API支持如下常见的学术数据集:
'''
mnist
cifar
Conll05
imdb
imikolov
movielens
sentiment
uci_housing
wmt14
wmt16
'''#数据处理
# 设置数据读取器,API自动读取MNIST数据训练集
train_dataset = paddle.vision.datasets.MNIST(mode='train')train_data0 = np.array(train_dataset[0][0])
train_label_0 = np.array(train_dataset[0][1])# 显示第一batch的第一个图像
'''
import matplotlib.pyplot as plt
plt.figure("Image") # 图像窗口名称
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
plt.axis('on') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()
'''print("图像数据形状和对应数据为:", train_data0.shape) #(28, 28)
print("图像标签形状和对应数据为:", train_label_0.shape, train_label_0) #(1,) [5]
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(train_label_0)) # [5]#飞桨将维度是28×28的手写数字图像转成向量形式存储,
# 因此使用飞桨数据加载器读取到的手写数字图像是长度为784(28×28)的向量。#模型设计
#模型的输入为784维(28×28)数据,输出为1维数据,# 定义mnist数据识别网络结构,同房价预测网络
#===========================================
class MNIST(paddle.nn.Layer):def __init__(self):super(MNIST, self).__init__()# 定义一层全连接层,输出维度是1self.fc = paddle.nn.Linear(in_features=784, out_features=1)# 定义网络结构的前向计算过程def forward(self, inputs):outputs = self.fc(inputs)return outputs
#===========================================#训练配置
# 声明网络结构
model = MNIST()
def train(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
#===========================================
# 图像归一化函数,将数据范围为[0, 255]的图像归一化到[0, 1]
def norm_img(img):# 验证传入数据格式是否正确,img的shape为[batch_size, 28, 28]assert len(img.shape) == 3batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]# 归一化图像数据img = img / 255# 将图像形式reshape为[batch_size, 784]img = paddle.reshape(img, [batch_size, img_h*img_w])return img
#===========================================
import paddle
# 确保从paddle.vision.datasets.MNIST中加载的图像数据是np.ndarray类型
paddle.vision.set_image_backend('cv2')# 声明网络结构
model = MNIST()
#===========================================
def run(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())EPOCH_NUM = 10for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):images = norm_img(data[0]).astype('float32')labels = data[1].astype('float32')#前向计算的过程predicts = model(images)# 计算损失loss = F.square_error_cost(predicts, labels)avg_loss = paddle.mean(loss)#每训练了1000批次的数据,打印下当前Loss的情况if batch_id % 1000 == 0:print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))#后向传播,更新参数的过程avg_loss.backward()opt.step()opt.clear_grad()
#===========================================
#调用训练
run(model)
paddle.save(model.state_dict(), './mnist.pdparams') #模型测试#===========================================
def showImage(im):#img_path = 'example_0.jpg'# 读取原始图像并显示#im = Image.open('example_0.jpg')plt.imshow(im)plt.show()# 将原始图像转为灰度图im = im.convert('L')print('原始图像shape: ', np.array(im).shape)# 使用Image.ANTIALIAS方式采样原始图片im = im.resize((28, 28), Image.LANCZOS)plt.imshow(im)plt.show()print("采样后图片shape: ", np.array(im).shape)
#===========================================
im = Image.open('example_0.jpg')
showImage(im)# 读取一张本地的样例图片,转变成模型输入的格式
#===========================================
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')# print(np.array(im))im = im.resize((28, 28), Image.LANCZOS)im = np.array(im).reshape(1, -1).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 255return im
#===========================================
# 定义预测过程
def test():model = MNIST()params_file_path = 'mnist.pdparams'img_path = 'example_0.jpg'# 加载模型参数param_dict = paddle.load(params_file_path)model.load_dict(param_dict)# 灌入数据model.eval()tensor_img = load_image(img_path) result = model(paddle.to_tensor(tensor_img))print('result',result)# 预测输出取整,即为预测的数字,打印结果print("本次预测的数字是", result.numpy().astype('int32'))
#===========================================
test();
相关文章:
基于飞桨paddle的极简方案构建手写数字识别模型测试代码
基于飞桨paddle的极简方案构建手写数字识别模型测试代码 原始测试图片为255X252的图片 因为是极简方案采用的是线性回归模型,所以预测结果数字不一致 本次预测的数字是 [[3]] 测试结果: PS E:\project\python> & D:/Python39/python.exe e:/pro…...
soft ip与hard ip
ip分soft和hard两种,soft就是纯代码,买过来要自己综合自己pr。hard ip如mem和analog与工艺有关。 mem的lib和lef是memory compiler产生的,基于bitcell,是foundry给的。 我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起…...
MyBatisPlus从入门到精通-2
接着上一讲的Mp的分页功能 下面我们讲解条件查询功能和其他功能 解决一下日志输出和banner问题 每次卞就会输出这些日志 很不美观,现在我们关闭一下 这样建个xml,文件名为logback.xml 文件内容改成这样 配置了logback但是里面什么都没写就不会说有日…...
AI面试官:Asp.Net 中使用Log4Net (一)
AI面试官:Asp.Net 中使用Log4Net (一) 当面试涉及到使用log4net日志记录框架的相关问题时,通常会聚焦在如何在.NET或.NET Core应用程序中集成和使用log4net。以下是一些关于log4net的面试题目,以及相应的解答、案例和代码: 文章目…...
Selenium自动化元素定位方式与浏览器测试脚本
Selenium八大元素定位方法 Selenium可以驱动浏览器完成各种操作,比如模拟点击等。要想操作一个元素,首先应该识别这个元素。人有各种的特征(属性),我们可以通过其特征找到人,如通过身份证号、姓名、家庭住…...
人机交互与人机混合智能的区别
人机交互和人机融合智能是两个相关但不完全相同的概念: 人机交互是指人与计算机之间的信息交流和互动过程。它关注的是如何设计和实现用户友好的界面,以便人们能够方便、高效地与计算机进行沟通和操作。人机交互通常强调用户体验和界面设计,旨…...
【项目】轻量级HTTP服务器
文章目录 一、项目介绍二、前置知识2.1 URI、URL、URN2.2 CGI2.2.1 CGI的概念2.2.2 CGI模式的实现2.2.3 CGI的意义 三、项目设计3.1 日志的编写3.2 套接字编写3.3 HTTP服务器实现3.4 HTTP请求与响应结构3.5 EndPoint类的实现3.5.1 EndPoint的基本逻辑3.5.2 读取请求3.5.3 构建响…...
sketch如何在线打开?有没有什么软件可以辅助
Sketch 在线打开的方法有哪些?这个问题和我之前回答过的「Sketch 可以在线编辑吗?」是一样的答案,没有。很遗憾,Sketch 没有在线打开的方法,Sketch 也做不到可以在线编辑。那么,那些广告里出现的设计软件工…...
CSS Flex 笔记
1. Flexbox 术语 Flex 容器可以是<div> 等,对其设置属性:display: flex, justify-content 是沿主轴方向调整元素,align-items 是沿交叉轴对齐元素。 2. Cheatsheet 2.1 设置 Flex 容器,加粗的属性为默认值 2.1.1 align-it…...
Markdown常用标签及其用途-有示例
Markdown常用标签及其用途 Markdown是一种轻量级标记语言,具有简洁易读的特点。下面是一些常用的Markdown标签以及它们的用途,并附带一些示例: 标题 用于创建不同级别的标题,可通过添加一到六个#符号来表示不同级别的标题。 #…...
25.1 Knife4j-Swagger的增强插件
1.Knife4j概述 Knife4j是一款基于Swagger UI的增强插件,它可以为Spring Boot项目生成美观且易于使用的API文档界面。它是Swagger UI的增强版,提供了更多的功能和定制选项,使API文档更加易读和易于理解。 2.Knife4j使用 Knife4j 集Swagger2…...
用flask run代替flask run --debug
安装python-dotenv依赖。 在项目根目录下新建.flaskenv文件,并作如下配置: FLASK_ENVdevelopment FLASK_DEBUG1...
python_day14_综合案例
文件内容 导包配置 import jsonfrom pyspark import SparkContext, SparkConf import osos.environ["PYSPARK_PYTHON"] "D:/dev/python/python3.10.4/python.exe" os.environ["HADOOP_HOME"] "D:/dev/hadoop-3.0.0" conf SparkC…...
【算法题】2779. 数组的最大美丽值
题目: 给你一个下标从 0 开始的整数数组 nums 和一个 非负 整数 k 。 在一步操作中,你可以执行下述指令: 在范围 [0, nums.length - 1] 中选择一个 此前没有选过 的下标 i 。 将 nums[i] 替换为范围 [nums[i] - k, nums[i] k] 内的任一整…...
文件上传之PHP
别怕,我会一直陪着你 一.知识二.实例1.phtml, <?简单过滤2.前端验证, phtml3 \.htaccess 一.知识 绕过后缀的有文件格式有php,php3,php4,php5,phtml.pht 二.实例 1.phtml, <?简单过滤 (1)一句话木马 故意使用了post和get用来迷惑人 https://127.0.0.1/shy.php?POS…...
人脸检测实战-insightface
目录 简介 一、InsightFace介绍 二、安装 三、快速体验 四、代码实战 1、人脸检测 2、人脸识别 五、代码及示例图片链接 简介 目前github有非常多的人脸识别开源项目,下面列出几个常用的开源项目: 1、deepface 2、CompreFace 3、face_recogn…...
Linux工具【1】(编辑器vim、编译器gcc与g++)
vim详解 引言vimVim的三种模式及模式切换普通模式下操作底行模式下操作 gcc与ggcc的使用(g类似)预编译编译汇编链接静态库与动态库 总结 引言 vim(vi improved)编辑器是从 vi 发展出来的一个文本编辑器。 代码补全、编译及错误跳…...
基于Java+SpringBoot+vue前后端分离古典舞在线交流平台设计实现
博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…...
MQ - 闲聊MQ一二事儿 (Kafka、RocketMQ 、Pulsar )
文章目录 MQ的发展史阶段一:追求解耦阶段二:追求吞吐量与一致性阶段三:追求平台化 MQ的通用架构主题topic、生产者producer、消费者consumer分区partition MQ 存储KafkaGood Design ---> 磁盘顺序写盘Poor Impact---> topic 数量不能过…...
Qt中的 QIODevice类(包含:随机访问、顺序访问设备)
QIODevice类 一、简介 QIODevice用于对输入输出设备进行管理,是Qt中所有I/O设备的基接口类。为支持读写数据块的设备(如QFile、QBuffer和QTcpSocket)提供了通用实现和抽象接口。 输入设备有2种类型: 一种是随机访问设备,QFile(文件)和QBuff…...
设计模式和设计原则回顾
设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...
Flask RESTful 示例
目录 1. 环境准备2. 安装依赖3. 修改main.py4. 运行应用5. API使用示例获取所有任务获取单个任务创建新任务更新任务删除任务 中文乱码问题: 下面创建一个简单的Flask RESTful API示例。首先,我们需要创建环境,安装必要的依赖,然后…...
盘古信息PCB行业解决方案:以全域场景重构,激活智造新未来
一、破局:PCB行业的时代之问 在数字经济蓬勃发展的浪潮中,PCB(印制电路板)作为 “电子产品之母”,其重要性愈发凸显。随着 5G、人工智能等新兴技术的加速渗透,PCB行业面临着前所未有的挑战与机遇。产品迭代…...
基于Flask实现的医疗保险欺诈识别监测模型
基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...
【机器视觉】单目测距——运动结构恢复
ps:图是随便找的,为了凑个封面 前言 在前面对光流法进行进一步改进,希望将2D光流推广至3D场景流时,发现2D转3D过程中存在尺度歧义问题,需要补全摄像头拍摄图像中缺失的深度信息,否则解空间不收敛…...
【git】把本地更改提交远程新分支feature_g
创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...
【服务器压力测试】本地PC电脑作为服务器运行时出现卡顿和资源紧张(Windows/Linux)
要让本地PC电脑作为服务器运行时出现卡顿和资源紧张的情况,可以通过以下几种方式模拟或触发: 1. 增加CPU负载 运行大量计算密集型任务,例如: 使用多线程循环执行复杂计算(如数学运算、加密解密等)。运行图…...
QT: `long long` 类型转换为 `QString` 2025.6.5
在 Qt 中,将 long long 类型转换为 QString 可以通过以下两种常用方法实现: 方法 1:使用 QString::number() 直接调用 QString 的静态方法 number(),将数值转换为字符串: long long value 1234567890123456789LL; …...
OpenLayers 分屏对比(地图联动)
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能,和卷帘图层不一样的是,分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...
sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!
简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求,并检查收到的响应。它以以下模式之一…...
