当前位置: 首页 > news >正文

基于飞桨paddle的极简方案构建手写数字识别模型测试代码

基于飞桨paddle的极简方案构建手写数字识别模型测试代码
在这里插入图片描述
原始测试图片为255X252的图片
因为是极简方案采用的是线性回归模型,所以预测结果数字不一致
本次预测的数字是 [[3]]
测试结果:

PS E:\project\python> & D:/Python39/python.exe e:/project/python/MNIST.py
10.0.0
2.4.2
图像数据形状和对应数据为: (28, 28)
图像标签形状和对应数据为: (1,) [5]打印第一个batch的第一个图像,对应标签数字为[5]
epoch_id: 0, batch_id: 0, loss is: [34.4626]
epoch_id: 0, batch_id: 1000, loss is: [7.599941]
epoch_id: 0, batch_id: 2000, loss is: [4.583123]
epoch_id: 0, batch_id: 3000, loss is: [2.8974648]
epoch_id: 1, batch_id: 0, loss is: [3.610869]
epoch_id: 1, batch_id: 1000, loss is: [5.6290216]
epoch_id: 1, batch_id: 2000, loss is: [1.9465038]
epoch_id: 1, batch_id: 3000, loss is: [2.1046467]
epoch_id: 7, batch_id: 2000, loss is: [4.63013]
epoch_id: 7, batch_id: 3000, loss is: [4.4638147]
epoch_id: 8, batch_id: 0, loss is: [3.0043283]
epoch_id: 8, batch_id: 1000, loss is: [1.633965]
epoch_id: 8, batch_id: 2000, loss is: [3.1906333]
epoch_id: 8, batch_id: 3000, loss is: [2.4461133]
epoch_id: 9, batch_id: 0, loss is: [3.9595613]
epoch_id: 9, batch_id: 1000, loss is: [1.3417265]
epoch_id: 9, batch_id: 2000, loss is: [2.3505783]
epoch_id: 9, batch_id: 3000, loss is: [2.0194921]
原始图像shape:  (252, 255)
采样后图片shape:  (28, 28)
result Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,[[3.94108272]])
本次预测的数字是 [[3]]
PS E:\project\python>

测试代码如下所示:

#加载飞桨和相关类库
import paddle
from paddle.nn import Linear
import paddle.nn.functional as F
import os
import numpy as np
import matplotlib.pyplot as plt
# 导入图像读取第三方库
from PIL import Image,ImageFilter
print(Image.__version__)    #10.0.0
#原来是在pillow的10.0.0版本中,ANTIALIAS方法被删除了,使用新的方法即可Image.LANCZOS
#或降级版本为9.5.0,安装pip install Pillow==9.5.0
print(paddle.__version__)   #2.4.2#飞桨提供了多个封装好的数据集API,涵盖计算机视觉、自然语言处理、推荐系统等多个领域,
# 帮助读者快速完成深度学习任务。
# 如在手写数字识别任务中,
# 通过paddle.vision.datasets.MNIST可以直接获取处理好的MNIST训练集、测试集,
# 飞桨API支持如下常见的学术数据集:
'''
mnist
cifar
Conll05
imdb
imikolov
movielens
sentiment
uci_housing
wmt14
wmt16
'''#数据处理
# 设置数据读取器,API自动读取MNIST数据训练集
train_dataset = paddle.vision.datasets.MNIST(mode='train')train_data0 = np.array(train_dataset[0][0])
train_label_0 = np.array(train_dataset[0][1])# 显示第一batch的第一个图像
'''
import matplotlib.pyplot as plt
plt.figure("Image") # 图像窗口名称
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
plt.axis('on') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()
'''print("图像数据形状和对应数据为:", train_data0.shape)                          #(28, 28)
print("图像标签形状和对应数据为:", train_label_0.shape, train_label_0)         #(1,) [5]
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(train_label_0))   # [5]#飞桨将维度是28×28的手写数字图像转成向量形式存储,
# 因此使用飞桨数据加载器读取到的手写数字图像是长度为784(28×28)的向量。#模型设计
#模型的输入为784维(28×28)数据,输出为1维数据,# 定义mnist数据识别网络结构,同房价预测网络
#===========================================
class MNIST(paddle.nn.Layer):def __init__(self):super(MNIST, self).__init__()# 定义一层全连接层,输出维度是1self.fc = paddle.nn.Linear(in_features=784, out_features=1)# 定义网络结构的前向计算过程def forward(self, inputs):outputs = self.fc(inputs)return outputs
#===========================================#训练配置
# 声明网络结构
model = MNIST()
def train(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
#===========================================
# 图像归一化函数,将数据范围为[0, 255]的图像归一化到[0, 1]
def norm_img(img):# 验证传入数据格式是否正确,img的shape为[batch_size, 28, 28]assert len(img.shape) == 3batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]# 归一化图像数据img = img / 255# 将图像形式reshape为[batch_size, 784]img = paddle.reshape(img, [batch_size, img_h*img_w])return img  
#===========================================   
import paddle
# 确保从paddle.vision.datasets.MNIST中加载的图像数据是np.ndarray类型
paddle.vision.set_image_backend('cv2')# 声明网络结构
model = MNIST()
#===========================================
def run(model):# 启动训练模式model.train()# 加载训练集 batch_size 设为 16train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), batch_size=16, shuffle=True)# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())EPOCH_NUM = 10for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):images = norm_img(data[0]).astype('float32')labels = data[1].astype('float32')#前向计算的过程predicts = model(images)# 计算损失loss = F.square_error_cost(predicts, labels)avg_loss = paddle.mean(loss)#每训练了1000批次的数据,打印下当前Loss的情况if batch_id % 1000 == 0:print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))#后向传播,更新参数的过程avg_loss.backward()opt.step()opt.clear_grad()
#===========================================
#调用训练            
run(model)
paddle.save(model.state_dict(), './mnist.pdparams')  #模型测试#===========================================
def showImage(im):#img_path = 'example_0.jpg'# 读取原始图像并显示#im = Image.open('example_0.jpg')plt.imshow(im)plt.show()# 将原始图像转为灰度图im = im.convert('L')print('原始图像shape: ', np.array(im).shape)# 使用Image.ANTIALIAS方式采样原始图片im = im.resize((28, 28), Image.LANCZOS)plt.imshow(im)plt.show()print("采样后图片shape: ", np.array(im).shape)
#===========================================
im = Image.open('example_0.jpg')
showImage(im)# 读取一张本地的样例图片,转变成模型输入的格式
#=========================================== 
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')# print(np.array(im))im = im.resize((28, 28), Image.LANCZOS)im = np.array(im).reshape(1, -1).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 255return im
#=========================================== 
# 定义预测过程
def test():model = MNIST()params_file_path = 'mnist.pdparams'img_path = 'example_0.jpg'# 加载模型参数param_dict = paddle.load(params_file_path)model.load_dict(param_dict)# 灌入数据model.eval()tensor_img = load_image(img_path)  result = model(paddle.to_tensor(tensor_img))print('result',result)#  预测输出取整,即为预测的数字,打印结果print("本次预测的数字是", result.numpy().astype('int32'))
#=========================================== 
test(); 

相关文章:

基于飞桨paddle的极简方案构建手写数字识别模型测试代码

基于飞桨paddle的极简方案构建手写数字识别模型测试代码 原始测试图片为255X252的图片 因为是极简方案采用的是线性回归模型,所以预测结果数字不一致 本次预测的数字是 [[3]] 测试结果: PS E:\project\python> & D:/Python39/python.exe e:/pro…...

soft ip与hard ip

ip分soft和hard两种,soft就是纯代码,买过来要自己综合自己pr。hard ip如mem和analog与工艺有关。 mem的lib和lef是memory compiler产生的,基于bitcell,是foundry给的。 我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起…...

MyBatisPlus从入门到精通-2

接着上一讲的Mp的分页功能 下面我们讲解条件查询功能和其他功能 解决一下日志输出和banner问题 每次卞就会输出这些日志 很不美观,现在我们关闭一下 这样建个xml,文件名为logback.xml 文件内容改成这样 配置了logback但是里面什么都没写就不会说有日…...

AI面试官:Asp.Net 中使用Log4Net (一)

AI面试官:Asp.Net 中使用Log4Net (一) 当面试涉及到使用log4net日志记录框架的相关问题时,通常会聚焦在如何在.NET或.NET Core应用程序中集成和使用log4net。以下是一些关于log4net的面试题目,以及相应的解答、案例和代码: 文章目…...

Selenium自动化元素定位方式与浏览器测试脚本

Selenium八大元素定位方法 Selenium可以驱动浏览器完成各种操作,比如模拟点击等。要想操作一个元素,首先应该识别这个元素。人有各种的特征(属性),我们可以通过其特征找到人,如通过身份证号、姓名、家庭住…...

人机交互与人机混合智能的区别

人机交互和人机融合智能是两个相关但不完全相同的概念: 人机交互是指人与计算机之间的信息交流和互动过程。它关注的是如何设计和实现用户友好的界面,以便人们能够方便、高效地与计算机进行沟通和操作。人机交互通常强调用户体验和界面设计,旨…...

【项目】轻量级HTTP服务器

文章目录 一、项目介绍二、前置知识2.1 URI、URL、URN2.2 CGI2.2.1 CGI的概念2.2.2 CGI模式的实现2.2.3 CGI的意义 三、项目设计3.1 日志的编写3.2 套接字编写3.3 HTTP服务器实现3.4 HTTP请求与响应结构3.5 EndPoint类的实现3.5.1 EndPoint的基本逻辑3.5.2 读取请求3.5.3 构建响…...

sketch如何在线打开?有没有什么软件可以辅助

Sketch 在线打开的方法有哪些?这个问题和我之前回答过的「Sketch 可以在线编辑吗?」是一样的答案,没有。很遗憾,Sketch 没有在线打开的方法,Sketch 也做不到可以在线编辑。那么,那些广告里出现的设计软件工…...

CSS Flex 笔记

1. Flexbox 术语 Flex 容器可以是<div> 等&#xff0c;对其设置属性&#xff1a;display: flex, justify-content 是沿主轴方向调整元素&#xff0c;align-items 是沿交叉轴对齐元素。 2. Cheatsheet 2.1 设置 Flex 容器&#xff0c;加粗的属性为默认值 2.1.1 align-it…...

Markdown常用标签及其用途-有示例

Markdown常用标签及其用途 Markdown是一种轻量级标记语言&#xff0c;具有简洁易读的特点。下面是一些常用的Markdown标签以及它们的用途&#xff0c;并附带一些示例&#xff1a; 标题 用于创建不同级别的标题&#xff0c;可通过添加一到六个#符号来表示不同级别的标题。 #…...

25.1 Knife4j-Swagger的增强插件

1.Knife4j概述 Knife4j是一款基于Swagger UI的增强插件&#xff0c;它可以为Spring Boot项目生成美观且易于使用的API文档界面。它是Swagger UI的增强版&#xff0c;提供了更多的功能和定制选项&#xff0c;使API文档更加易读和易于理解。 2.Knife4j使用 Knife4j 集Swagger2…...

用flask run代替flask run --debug

安装python-dotenv依赖。 在项目根目录下新建.flaskenv文件&#xff0c;并作如下配置&#xff1a; FLASK_ENVdevelopment FLASK_DEBUG1...

python_day14_综合案例

文件内容 导包配置 import jsonfrom pyspark import SparkContext, SparkConf import osos.environ["PYSPARK_PYTHON"] "D:/dev/python/python3.10.4/python.exe" os.environ["HADOOP_HOME"] "D:/dev/hadoop-3.0.0" conf SparkC…...

【算法题】2779. 数组的最大美丽值

题目&#xff1a; 给你一个下标从 0 开始的整数数组 nums 和一个 非负 整数 k 。 在一步操作中&#xff0c;你可以执行下述指令&#xff1a; 在范围 [0, nums.length - 1] 中选择一个 此前没有选过 的下标 i 。 将 nums[i] 替换为范围 [nums[i] - k, nums[i] k] 内的任一整…...

文件上传之PHP

别怕,我会一直陪着你 一.知识二.实例1.phtml, <?简单过滤2.前端验证, phtml3 \.htaccess 一.知识 绕过后缀的有文件格式有php,php3,php4,php5,phtml.pht 二.实例 1.phtml, <?简单过滤 (1)一句话木马 故意使用了post和get用来迷惑人 https://127.0.0.1/shy.php?POS…...

人脸检测实战-insightface

目录 简介 一、InsightFace介绍 二、安装 三、快速体验 四、代码实战 1、人脸检测 2、人脸识别 五、代码及示例图片链接 简介 目前github有非常多的人脸识别开源项目&#xff0c;下面列出几个常用的开源项目&#xff1a; 1、deepface 2、CompreFace 3、face_recogn…...

Linux工具【1】(编辑器vim、编译器gcc与g++)

vim详解 引言vimVim的三种模式及模式切换普通模式下操作底行模式下操作 gcc与ggcc的使用&#xff08;g类似&#xff09;预编译编译汇编链接静态库与动态库 总结 引言 vim&#xff08;vi improved&#xff09;编辑器是从 vi 发展出来的一个文本编辑器。 代码补全、编译及错误跳…...

基于Java+SpringBoot+vue前后端分离古典舞在线交流平台设计实现

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…...

MQ - 闲聊MQ一二事儿 (Kafka、RocketMQ 、Pulsar )

文章目录 MQ的发展史阶段一&#xff1a;追求解耦阶段二&#xff1a;追求吞吐量与一致性阶段三&#xff1a;追求平台化 MQ的通用架构主题topic、生产者producer、消费者consumer分区partition MQ 存储KafkaGood Design ---> 磁盘顺序写盘Poor Impact---> topic 数量不能过…...

Qt中的 QIODevice类(包含:随机访问、顺序访问设备)

QIODevice类 一、简介 QIODevice用于对输入输出设备进行管理&#xff0c;是Qt中所有I/O设备的基接口类。为支持读写数据块的设备(如QFile、QBuffer和QTcpSocket)提供了通用实现和抽象接口。 输入设备有2种类型&#xff1a; 一种是随机访问设备&#xff0c;QFile(文件)和QBuff…...

【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15

缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下&#xff1a; struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...

从深圳崛起的“机器之眼”:赴港乐动机器人的万亿赛道赶考路

进入2025年以来&#xff0c;尽管围绕人形机器人、具身智能等机器人赛道的质疑声不断&#xff0c;但全球市场热度依然高涨&#xff0c;入局者持续增加。 以国内市场为例&#xff0c;天眼查专业版数据显示&#xff0c;截至5月底&#xff0c;我国现存在业、存续状态的机器人相关企…...

Linux-07 ubuntu 的 chrome 启动不了

文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了&#xff0c;报错如下四、启动不了&#xff0c;解决如下 总结 问题原因 在应用中可以看到chrome&#xff0c;但是打不开(说明&#xff1a;原来的ubuntu系统出问题了&#xff0c;这个是备用的硬盘&a…...

关于 WASM:1. WASM 基础原理

一、WASM 简介 1.1 WebAssembly 是什么&#xff1f; WebAssembly&#xff08;WASM&#xff09; 是一种能在现代浏览器中高效运行的二进制指令格式&#xff0c;它不是传统的编程语言&#xff0c;而是一种 低级字节码格式&#xff0c;可由高级语言&#xff08;如 C、C、Rust&am…...

根据万维钢·精英日课6的内容,使用AI(2025)可以参考以下方法:

根据万维钢精英日课6的内容&#xff0c;使用AI&#xff08;2025&#xff09;可以参考以下方法&#xff1a; 四个洞见 模型已经比人聪明&#xff1a;以ChatGPT o3为代表的AI非常强大&#xff0c;能运用高级理论解释道理、引用最新学术论文&#xff0c;生成对顶尖科学家都有用的…...

python报错No module named ‘tensorflow.keras‘

是由于不同版本的tensorflow下的keras所在的路径不同&#xff0c;结合所安装的tensorflow的目录结构修改from语句即可。 原语句&#xff1a; from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后&#xff1a; from tensorflow.python.keras.lay…...

【网络安全】开源系统getshell漏洞挖掘

审计过程&#xff1a; 在入口文件admin/index.php中&#xff1a; 用户可以通过m,c,a等参数控制加载的文件和方法&#xff0c;在app/system/entrance.php中存在重点代码&#xff1a; 当M_TYPE system并且M_MODULE include时&#xff0c;会设置常量PATH_OWN_FILE为PATH_APP.M_T…...

Git 3天2K星标:Datawhale 的 Happy-LLM 项目介绍(附教程)

引言 在人工智能飞速发展的今天&#xff0c;大语言模型&#xff08;Large Language Models, LLMs&#xff09;已成为技术领域的焦点。从智能写作到代码生成&#xff0c;LLM 的应用场景不断扩展&#xff0c;深刻改变了我们的工作和生活方式。然而&#xff0c;理解这些模型的内部…...

PHP 8.5 即将发布:管道操作符、强力调试

前不久&#xff0c;PHP宣布了即将在 2025 年 11 月 20 日 正式发布的 PHP 8.5&#xff01;作为 PHP 语言的又一次重要迭代&#xff0c;PHP 8.5 承诺带来一系列旨在提升代码可读性、健壮性以及开发者效率的改进。而更令人兴奋的是&#xff0c;借助强大的本地开发环境 ServBay&am…...

HubSpot推出与ChatGPT的深度集成引发兴奋与担忧

上周三&#xff0c;HubSpot宣布已构建与ChatGPT的深度集成&#xff0c;这一消息在HubSpot用户和营销技术观察者中引发了极大的兴奋&#xff0c;但同时也存在一些关于数据安全的担忧。 许多网络声音声称&#xff0c;这对SaaS应用程序和人工智能而言是一场范式转变。 但向任何技…...