当前位置: 首页 > news >正文

基于飞桨paddle的极简方案构建手写数字识别模型测试代码

基于飞桨paddle的极简方案构建手写数字识别模型测试代码
在这里插入图片描述
原始测试图片为255X252的图片
因为是极简方案采用的是线性回归模型,所以预测结果数字不一致
本次预测的数字是 [[3]]
测试结果:

PS E:\project\python> & D:/Python39/python.exe e:/project/python/MNIST.py
10.0.0
2.4.2
图像数据形状和对应数据为: (28, 28)
图像标签形状和对应数据为: (1,) [5]打印第一个batch的第一个图像,对应标签数字为[5]
epoch_id: 0, batch_id: 0, loss is: [34.4626]
epoch_id: 0, batch_id: 1000, loss is: [7.599941]
epoch_id: 0, batch_id: 2000, loss is: [4.583123]
epoch_id: 0, batch_id: 3000, loss is: [2.8974648]
epoch_id: 1, batch_id: 0, loss is: [3.610869]
epoch_id: 1, batch_id: 1000, loss is: [5.6290216]
epoch_id: 1, batch_id: 2000, loss is: [1.9465038]
epoch_id: 1, batch_id: 3000, loss is: [2.1046467]
epoch_id: 7, batch_id: 2000, loss is: [4.63013]
epoch_id: 7, batch_id: 3000, loss is: [4.4638147]
epoch_id: 8, batch_id: 0, loss is: [3.0043283]
epoch_id: 8, batch_id: 1000, loss is: [1.633965]
epoch_id: 8, batch_id: 2000, loss is: [3.1906333]
epoch_id: 8, batch_id: 3000, loss is: [2.4461133]
epoch_id: 9, batch_id: 0, loss is: [3.9595613]
epoch_id: 9, batch_id: 1000, loss is: [1.3417265]
epoch_id: 9, batch_id: 2000, loss is: [2.3505783]
epoch_id: 9, batch_id: 3000, loss is: [2.0194921]
原始图像shape:  (252, 255)
采样后图片shape:  (28, 28)
result Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,[[3.94108272]])
本次预测的数字是 [[3]]
PS E:\project\python>

测试代码如下所示:

#加载飞桨和相关类库
import paddle
from paddle.nn import Linear
import paddle.nn.functional as F
import os
import numpy as np
import matplotlib.pyplot as plt
# 导入图像读取第三方库
from PIL import Image,ImageFilter
print(Image.__version__)    #10.0.0
#原来是在pillow的10.0.0版本中,ANTIALIAS方法被删除了,使用新的方法即可Image.LANCZOS
#或降级版本为9.5.0,安装pip install Pillow==9.5.0
print(paddle.__version__)   #2.4.2#飞桨提供了多个封装好的数据集API,涵盖计算机视觉、自然语言处理、推荐系统等多个领域,
# 帮助读者快速完成深度学习任务。
# 如在手写数字识别任务中,
# 通过paddle.vision.datasets.MNIST可以直接获取处理好的MNIST训练集、测试集,
# 飞桨API支持如下常见的学术数据集:
'''
mnist
cifar
Conll05
imdb
imikolov
movielens
sentiment
uci_housing
wmt14
wmt16
'''#数据处理
# 设置数据读取器,API自动读取MNIST数据训练集
train_dataset = paddle.vision.datasets.MNIST(mode='train')train_data0 = np.array(train_dataset[0][0])
train_label_0 = np.array(train_dataset[0][1])# 显示第一batch的第一个图像
'''
import matplotlib.pyplot as plt
plt.figure("Image") # 图像窗口名称
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
plt.axis('on') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()
'''print("图像数据形状和对应数据为:", train_data0.shape)                          #(28, 28)
print("图像标签形状和对应数据为:", train_label_0.shape, train_label_0)         #(1,) [5]
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(train_label_0))   # [5]#飞桨将维度是28×28的手写数字图像转成向量形式存储,
# 因此使用飞桨数据加载器读取到的手写数字图像是长度为784(28×28)的向量。#模型设计
#模型的输入为784维(28×28)数据,输出为1维数据,# 定义mnist数据识别网络结构,同房价预测网络
#===========================================
class MNIST(paddle.nn.Layer):def __init__(self):super(MNIST, self).__init__()# 定义一层全连接层,输出维度是1self.fc = paddle.nn.Linear(in_features=784, out_features=1)# 定义网络结构的前向计算过程def forward(self, inputs):outputs = self.fc(inputs)return outputs
#===========================================#训练配置
# 声明网络结构
model = MNIST()
def train(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
#===========================================
# 图像归一化函数,将数据范围为[0, 255]的图像归一化到[0, 1]
def norm_img(img):# 验证传入数据格式是否正确,img的shape为[batch_size, 28, 28]assert len(img.shape) == 3batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]# 归一化图像数据img = img / 255# 将图像形式reshape为[batch_size, 784]img = paddle.reshape(img, [batch_size, img_h*img_w])return img  
#===========================================   
import paddle
# 确保从paddle.vision.datasets.MNIST中加载的图像数据是np.ndarray类型
paddle.vision.set_image_backend('cv2')# 声明网络结构
model = MNIST()
#===========================================
def run(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())EPOCH_NUM = 10for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):images = norm_img(data[0]).astype('float32')labels = data[1].astype('float32')#前向计算的过程predicts = model(images)# 计算损失loss = F.square_error_cost(predicts, labels)avg_loss = paddle.mean(loss)#每训练了1000批次的数据,打印下当前Loss的情况if batch_id % 1000 == 0:print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))#后向传播,更新参数的过程avg_loss.backward()opt.step()opt.clear_grad()
#===========================================
#调用训练            
run(model)
paddle.save(model.state_dict(), './mnist.pdparams')  #模型测试#===========================================
def showImage(im):#img_path = 'example_0.jpg'# 读取原始图像并显示#im = Image.open('example_0.jpg')plt.imshow(im)plt.show()# 将原始图像转为灰度图im = im.convert('L')print('原始图像shape: ', np.array(im).shape)# 使用Image.ANTIALIAS方式采样原始图片im = im.resize((28, 28), Image.LANCZOS)plt.imshow(im)plt.show()print("采样后图片shape: ", np.array(im).shape)
#===========================================
im = Image.open('example_0.jpg')
showImage(im)# 读取一张本地的样例图片,转变成模型输入的格式
#=========================================== 
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')# print(np.array(im))im = im.resize((28, 28), Image.LANCZOS)im = np.array(im).reshape(1, -1).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 255return im
#=========================================== 
# 定义预测过程
def test():model = MNIST()params_file_path = 'mnist.pdparams'img_path = 'example_0.jpg'# 加载模型参数param_dict = paddle.load(params_file_path)model.load_dict(param_dict)# 灌入数据model.eval()tensor_img = load_image(img_path)  result = model(paddle.to_tensor(tensor_img))print('result',result)#  预测输出取整,即为预测的数字,打印结果print("本次预测的数字是", result.numpy().astype('int32'))
#=========================================== 
test(); 

相关文章:

基于飞桨paddle的极简方案构建手写数字识别模型测试代码

基于飞桨paddle的极简方案构建手写数字识别模型测试代码 原始测试图片为255X252的图片 因为是极简方案采用的是线性回归模型,所以预测结果数字不一致 本次预测的数字是 [[3]] 测试结果: PS E:\project\python> & D:/Python39/python.exe e:/pro…...

soft ip与hard ip

ip分soft和hard两种,soft就是纯代码,买过来要自己综合自己pr。hard ip如mem和analog与工艺有关。 mem的lib和lef是memory compiler产生的,基于bitcell,是foundry给的。 我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起…...

MyBatisPlus从入门到精通-2

接着上一讲的Mp的分页功能 下面我们讲解条件查询功能和其他功能 解决一下日志输出和banner问题 每次卞就会输出这些日志 很不美观,现在我们关闭一下 这样建个xml,文件名为logback.xml 文件内容改成这样 配置了logback但是里面什么都没写就不会说有日…...

AI面试官:Asp.Net 中使用Log4Net (一)

AI面试官:Asp.Net 中使用Log4Net (一) 当面试涉及到使用log4net日志记录框架的相关问题时,通常会聚焦在如何在.NET或.NET Core应用程序中集成和使用log4net。以下是一些关于log4net的面试题目,以及相应的解答、案例和代码: 文章目…...

Selenium自动化元素定位方式与浏览器测试脚本

Selenium八大元素定位方法 Selenium可以驱动浏览器完成各种操作,比如模拟点击等。要想操作一个元素,首先应该识别这个元素。人有各种的特征(属性),我们可以通过其特征找到人,如通过身份证号、姓名、家庭住…...

人机交互与人机混合智能的区别

人机交互和人机融合智能是两个相关但不完全相同的概念: 人机交互是指人与计算机之间的信息交流和互动过程。它关注的是如何设计和实现用户友好的界面,以便人们能够方便、高效地与计算机进行沟通和操作。人机交互通常强调用户体验和界面设计,旨…...

【项目】轻量级HTTP服务器

文章目录 一、项目介绍二、前置知识2.1 URI、URL、URN2.2 CGI2.2.1 CGI的概念2.2.2 CGI模式的实现2.2.3 CGI的意义 三、项目设计3.1 日志的编写3.2 套接字编写3.3 HTTP服务器实现3.4 HTTP请求与响应结构3.5 EndPoint类的实现3.5.1 EndPoint的基本逻辑3.5.2 读取请求3.5.3 构建响…...

sketch如何在线打开?有没有什么软件可以辅助

Sketch 在线打开的方法有哪些?这个问题和我之前回答过的「Sketch 可以在线编辑吗?」是一样的答案,没有。很遗憾,Sketch 没有在线打开的方法,Sketch 也做不到可以在线编辑。那么,那些广告里出现的设计软件工…...

CSS Flex 笔记

1. Flexbox 术语 Flex 容器可以是<div> 等&#xff0c;对其设置属性&#xff1a;display: flex, justify-content 是沿主轴方向调整元素&#xff0c;align-items 是沿交叉轴对齐元素。 2. Cheatsheet 2.1 设置 Flex 容器&#xff0c;加粗的属性为默认值 2.1.1 align-it…...

Markdown常用标签及其用途-有示例

Markdown常用标签及其用途 Markdown是一种轻量级标记语言&#xff0c;具有简洁易读的特点。下面是一些常用的Markdown标签以及它们的用途&#xff0c;并附带一些示例&#xff1a; 标题 用于创建不同级别的标题&#xff0c;可通过添加一到六个#符号来表示不同级别的标题。 #…...

25.1 Knife4j-Swagger的增强插件

1.Knife4j概述 Knife4j是一款基于Swagger UI的增强插件&#xff0c;它可以为Spring Boot项目生成美观且易于使用的API文档界面。它是Swagger UI的增强版&#xff0c;提供了更多的功能和定制选项&#xff0c;使API文档更加易读和易于理解。 2.Knife4j使用 Knife4j 集Swagger2…...

用flask run代替flask run --debug

安装python-dotenv依赖。 在项目根目录下新建.flaskenv文件&#xff0c;并作如下配置&#xff1a; FLASK_ENVdevelopment FLASK_DEBUG1...

python_day14_综合案例

文件内容 导包配置 import jsonfrom pyspark import SparkContext, SparkConf import osos.environ["PYSPARK_PYTHON"] "D:/dev/python/python3.10.4/python.exe" os.environ["HADOOP_HOME"] "D:/dev/hadoop-3.0.0" conf SparkC…...

【算法题】2779. 数组的最大美丽值

题目&#xff1a; 给你一个下标从 0 开始的整数数组 nums 和一个 非负 整数 k 。 在一步操作中&#xff0c;你可以执行下述指令&#xff1a; 在范围 [0, nums.length - 1] 中选择一个 此前没有选过 的下标 i 。 将 nums[i] 替换为范围 [nums[i] - k, nums[i] k] 内的任一整…...

文件上传之PHP

别怕,我会一直陪着你 一.知识二.实例1.phtml, <?简单过滤2.前端验证, phtml3 \.htaccess 一.知识 绕过后缀的有文件格式有php,php3,php4,php5,phtml.pht 二.实例 1.phtml, <?简单过滤 (1)一句话木马 故意使用了post和get用来迷惑人 https://127.0.0.1/shy.php?POS…...

人脸检测实战-insightface

目录 简介 一、InsightFace介绍 二、安装 三、快速体验 四、代码实战 1、人脸检测 2、人脸识别 五、代码及示例图片链接 简介 目前github有非常多的人脸识别开源项目&#xff0c;下面列出几个常用的开源项目&#xff1a; 1、deepface 2、CompreFace 3、face_recogn…...

Linux工具【1】(编辑器vim、编译器gcc与g++)

vim详解 引言vimVim的三种模式及模式切换普通模式下操作底行模式下操作 gcc与ggcc的使用&#xff08;g类似&#xff09;预编译编译汇编链接静态库与动态库 总结 引言 vim&#xff08;vi improved&#xff09;编辑器是从 vi 发展出来的一个文本编辑器。 代码补全、编译及错误跳…...

基于Java+SpringBoot+vue前后端分离古典舞在线交流平台设计实现

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…...

MQ - 闲聊MQ一二事儿 (Kafka、RocketMQ 、Pulsar )

文章目录 MQ的发展史阶段一&#xff1a;追求解耦阶段二&#xff1a;追求吞吐量与一致性阶段三&#xff1a;追求平台化 MQ的通用架构主题topic、生产者producer、消费者consumer分区partition MQ 存储KafkaGood Design ---> 磁盘顺序写盘Poor Impact---> topic 数量不能过…...

Qt中的 QIODevice类(包含:随机访问、顺序访问设备)

QIODevice类 一、简介 QIODevice用于对输入输出设备进行管理&#xff0c;是Qt中所有I/O设备的基接口类。为支持读写数据块的设备(如QFile、QBuffer和QTcpSocket)提供了通用实现和抽象接口。 输入设备有2种类型&#xff1a; 一种是随机访问设备&#xff0c;QFile(文件)和QBuff…...

三维GIS开发cesium智慧地铁教程(5)Cesium相机控制

一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点&#xff1a; 路径验证&#xff1a;确保相对路径.…...

pam_env.so模块配置解析

在PAM&#xff08;Pluggable Authentication Modules&#xff09;配置中&#xff0c; /etc/pam.d/su 文件相关配置含义如下&#xff1a; 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块&#xff0c;负责验证用户身份&am…...

ETLCloud可能遇到的问题有哪些?常见坑位解析

数据集成平台ETLCloud&#xff0c;主要用于支持数据的抽取&#xff08;Extract&#xff09;、转换&#xff08;Transform&#xff09;和加载&#xff08;Load&#xff09;过程。提供了一个简洁直观的界面&#xff0c;以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析&#xff1a;CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展&#xff0c;AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者&#xff0c;分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

html-<abbr> 缩写或首字母缩略词

定义与作用 <abbr> 标签用于表示缩写或首字母缩略词&#xff0c;它可以帮助用户更好地理解缩写的含义&#xff0c;尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时&#xff0c;会显示一个提示框。 示例&#x…...

4. TypeScript 类型推断与类型组合

一、类型推断 (一) 什么是类型推断 TypeScript 的类型推断会根据变量、函数返回值、对象和数组的赋值和使用方式&#xff0c;自动确定它们的类型。 这一特性减少了显式类型注解的需要&#xff0c;在保持类型安全的同时简化了代码。通过分析上下文和初始值&#xff0c;TypeSc…...

比较数据迁移后MySQL数据库和OceanBase数据仓库中的表

设计一个MySQL数据库和OceanBase数据仓库的表数据比较的详细程序流程,两张表是相同的结构,都有整型主键id字段,需要每次从数据库分批取得2000条数据,用于比较,比较操作的同时可以再取2000条数据,等上一次比较完成之后,开始比较,直到比较完所有的数据。比较操作需要比较…...

Python 实现 Web 静态服务器(HTTP 协议)

目录 一、在本地启动 HTTP 服务器1. Windows 下安装 node.js1&#xff09;下载安装包2&#xff09;配置环境变量3&#xff09;安装镜像4&#xff09;node.js 的常用命令 2. 安装 http-server 服务3. 使用 http-server 开启服务1&#xff09;使用 http-server2&#xff09;详解 …...

Ubuntu系统多网卡多相机IP设置方法

目录 1、硬件情况 2、如何设置网卡和相机IP 2.1 万兆网卡连接交换机&#xff0c;交换机再连相机 2.1.1 网卡设置 2.1.2 相机设置 2.3 万兆网卡直连相机 1、硬件情况 2个网卡n个相机 电脑系统信息&#xff0c;系统版本&#xff1a;Ubuntu22.04.5 LTS&#xff1b;内核版本…...

6️⃣Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙

Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙 一、前言:离区块链还有多远? 区块链听起来可能遥不可及,似乎是只有密码学专家和资深工程师才能涉足的领域。但事实上,构建一个区块链的核心并不复杂,尤其当你已经掌握了一门系统编程语言,比如 Go。 要真正理解区…...