python与深度学习(八):CNN和fashion_mnist二
目录
- 1. 说明
- 2. fashion_mnist的CNN模型测试
- 2.1 导入相关库
- 2.2 加载数据和模型
- 2.3 设置保存图片的路径
- 2.4 加载图片
- 2.5 图片预处理
- 2.6 对图片进行预测
- 2.7 显示图片
- 3. 完整代码和显示结果
- 4. 多张图片进行测试的完整代码以及结果
1. 说明
本篇文章是对上篇文章训练的模型进行测试。首先是将训练好的模型进行重新加载,然后采用opencv对图片进行加载,最后将加载好的图片输送给模型并且显示结果。
2. fashion_mnist的CNN模型测试
2.1 导入相关库
在这里导入需要的第三方库如cv2,如果没有,则需要自行下载。
from tensorflow import keras
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw # PIL就是pillow包(保存图像)
import numpy as np
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist
2.2 加载数据和模型
把fashion_mnist数据集进行加载,并且把训练好的模型也加载进来。
# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载fashion数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载cnn_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('cnn_fashion.h5')
2.3 设置保存图片的路径
将数据集的某个数据以图片的形式进行保存,便于测试的可视化。
在这里设置图片存储的位置。
# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)
在书写完上述代码后,需要在代码的当前路径下新建一个imgs的文件夹用于存储图片,如下。
执行完上述代码后就会在imgs的文件中可以发现多了一张图片,如下(下面测试了很多次)。
2.4 加载图片
采用cv2对图片进行加载,下面最后一行代码取一个通道的原因是用opencv库也就是cv2读取图片的时候,图片是三通道的,而训练的模型是单通道的,因此取单通道。
# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
2.5 图片预处理
对图片进行预处理,即进行归一化处理和改变形状处理,这是为了便于将图片输入给训练好的模型进行预测。
# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)
2.6 对图片进行预测
将图片输入给训练好我的模型并且进行预测。
预测的结果是10个概率值,所以需要进行处理, np.argmax()是得到概率值最大值的序号,也就是预测的数字。
# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别:', class_names[class_id])
text = str(class_names[class_id])
2.7 显示图片
对预测的图片进行显示,把预测的数字显示在图片上。
下面5行代码分别是创建窗口,设定窗口大小,显示图片,停留图片,清除内存。
# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500) # 自己设定窗口图片的大小
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()
3. 完整代码和显示结果
以下是完整的代码和图片显示结果。
from tensorflow import keras
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw # PIL就是pillow包(保存图像)
import numpy as np
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist
# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载fashion数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载cnn_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('cnn_fashion.h5')
# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)
# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)
# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别:', class_names[class_id])
text = str(class_names[class_id])
# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500) # 自己设定窗口图片的大小
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
1/1 [==============================] - 0s 168ms/step
test.png的预测概率: [[2.9672831e-04 7.3040414e-05 1.4721525e-04 9.9842703e-01 4.7597905e-068.9959512e-06 1.0416918e-03 8.6147125e-09 4.2549357e-07 1.2974965e-07]]
test.png的预测概率: 0.99842703
test.png的所属类别: Dress
4. 多张图片进行测试的完整代码以及结果
为了测试更多的图片,引入循环进行多次测试,效果更好。
from tensorflow import keras
from keras.datasets import fashion_mnist
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw # PIL就是pillow包(保存图像)
import numpy as np# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载mnist数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载cnn_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('cnn_fashion.h5')prepicture = int(input("input the number of test picture :"))
for i in range(prepicture):path1 = input("input the test picture path:")# 创建图片保存路径test_file_path = os.path.join(sys.path[0], 'imgs', path1)# 存储测试数据的任意一个num = int(input("input the test picture num:"))Image.fromarray(x_test[num]).save(test_file_path)# 加载本地test.png图像image = cv2.imread(test_file_path)# 复制图片test_img = image.copy()# 将图片大小转换成(28,28)test_img = cv2.resize(test_img, (28, 28))# 取单通道值test_img = test_img[:, :, 0]# 预处理: 归一化 + reshapenew_test_img = (test_img/255.0).reshape(1, 28, 28, 1)# 预测y_pre_pro = recons_model.predict(new_test_img, verbose=1)# 哪一类数字class_id = np.argmax(y_pre_pro, axis=1)[0]print('test.png的预测概率:', y_pre_pro)print('test.png的预测概率:', y_pre_pro[0, class_id])print('test.png的所属类别:', class_names[class_id])text = str(class_names[class_id])# # 显示cv2.namedWindow('img', 0)cv2.resizeWindow('img', 500, 500) # 自己设定窗口图片的大小cv2.imshow('img', image)cv2.waitKey()cv2.destroyAllWindows()
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
input the number of test picture :2
input the test picture path:101.jpg
input the test picture num:1
1/1 [==============================] - 0s 145ms/step
test.png的预测概率: [[5.1000708e-05 2.9449904e-13 9.9993873e-01 5.5402721e-11 4.8696438e-061.2649738e-12 5.3379590e-06 6.5959898e-17 7.1223938e-10 4.0113624e-12]]
test.png的预测概率: 0.9999387
test.png的所属类别: Pullover
input the test picture path:102.jpg
input the test picture num:2
1/1 [==============================] - 0s 21ms/step
test.png的预测概率: [[3.01315001e-10 1.00000000e+00 1.03142118e-14 8.63922683e-114.10812981e-11 6.07313693e-22 2.31636132e-09 5.08595438e-251.02018335e-13 8.82350167e-28]]
test.png的预测概率: 1.0
test.png的所属类别: Trouser
相关文章:

python与深度学习(八):CNN和fashion_mnist二
目录 1. 说明2. fashion_mnist的CNN模型测试2.1 导入相关库2.2 加载数据和模型2.3 设置保存图片的路径2.4 加载图片2.5 图片预处理2.6 对图片进行预测2.7 显示图片 3. 完整代码和显示结果4. 多张图片进行测试的完整代码以及结果 1. 说明 本篇文章是对上篇文章训练的模型进行测…...

开发一个RISC-V上的操作系统(五)—— 协作式多任务
目录 往期文章传送门 一、什么是多任务 二、代码实现 三、测试 往期文章传送门 开发一个RISC-V上的操作系统(一)—— 环境搭建_riscv开发环境_Patarw_Li的博客-CSDN博客 开发一个RISC-V上的操作系统(二)—— 系统引导程序&a…...

Mybatis-plus集合
目录 mybatis-plus集合1、简介2、特性3、开始使用4、QueryWrapper的使用5、补充 mybatis-plus集合 1、简介 MyBatis-Plus (简称 MP)是一个 MyBatis的增强工具,在 MyBatis 的基础上只做增强不做改变,为简化开发、提高效率而生。 m…...

C++ 结构体和联合体
1.结构体 结构体是一种特殊形态的类,它和类一样,可以有自己的数据成员和函数成员,可以有自己的构造函数和析构函数,可以控制访问权限,可以继承,支持包含多态,结构体定义的语法和类的定义语法几…...

使用TensorFlow训练深度学习模型实战(下)
大家好,本文接TensorFlow训练深度学习模型的上半部分继续进行讲述,下面将介绍有关定义深度学习模型、训练模型和评估模型的内容。 定义深度学习模型 数据准备完成后,下一步是使用TensorFlow搭建神经网络模型,搭建模型有两个选项…...

lucene、solr、es的区别以及应用场景
目录 1. Lucene:2. Solr:3. Elasticsearch: Lucene、Solr 和 Elasticsearch(ES) 都是基于 Lucene 引擎的搜索引擎,它们之间有相似之处,但也有一些不同之处。 Lucene 是一个低级别的搜索引擎库,它提供了一种用于创建和维护全文索引的 API&…...

Java方法的使用(重点:形参和实参的关系、方法重载、递归)
目录 一、Java方法 * 有返回类型,在方法体里就一定要返回相应类型的数据。没有返回类型(void),就不要返回!! * 方法没有声明一说。与C语言不同(C语言是自顶向下读取代码)&#…...

登录页的具体实现 (小兔鲜儿)【Vue3】
登录页 整体认识和路由配置 整体认识 登录页面的主要功能就是表单校验和登录登出业务 准备模板 <script setup></script><template><div><header class"login-header"><div class"container m-top-20"><h1 cl…...

大学如何自学嵌入式开发?
1. C语言:C语言是基础中的基础,刚开始学习不用太深入,一本常用的C语言的教材即可,注意不是当教科书看,而是看完一节过后,打开电脑把后面的习题都写出来,并且编译运行一遍,一定要动手…...

pytorch学习——线性神经网络——1线性回归
概要:线性神经网络是一种最简单的神经网络模型,它由若干个线性变换和非线性变换组成。线性变换通常表示为矩阵乘法,非线性变换通常是一个逐元素的非线性函数。线性神经网络通常用于解决回归和分类问题。 一.线性回归 线性回归是一种常见的机…...

00 - RAP 开发环境配置
文章目录 [1] Eclipse - ADT[2] BTP / S4HC[3] Add ABAP Env. Service[4] Conn. to BTP [1] Eclipse - ADT 关于如何安装配置,参见文章: Install ABAP Development Tools (ADT) and abapGit Plugin Eclipse Eclipse - ADT Eclipse - abapGit Plugin [2] BTP / S4…...

山西电力市场日前价格预测【2023-08-01】
日前价格预测 预测明日(2023-08-01)山西电力市场全天平均日前电价为310.15元/MWh。其中,最高日前电价为335.18元/MWh,预计出现在19: 45。最低日前电价为288.85元/MWh,预计出现在14: 00。 价差方向预测 1:实…...

QT--day5(网络聊天室、学生信息管理系统)
服务器: #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);//给服务器指针实例化空间servernew QTcpServer(this); }Widget::~Widget() {delete ui; …...

【用IDEA基于Scala2.12.18开发Spark 3.4.1 项目】
目录 使用IDEA创建Spark项目设置sbt依赖创建Spark 项目结构新建Scala代码 使用IDEA创建Spark项目 打开IDEA后选址新建项目 选址sbt选项 配置JDK debug 解决方案 相关的依赖下载出问题多的话,可以关闭idea,重启再等等即可。 设置sbt依赖 将sbt…...

HEVC 速率控制(码控)介绍
视频编码速率控制 速率控制: 通过选择一系列编码参数,使得视频编码后的比特率满足所有需要的速率限制,并且使得编码失真尽量小。速率控制属于率失真优化的范畴,速率控制算法的重点是确定与速率相关的量化参数(Quantiz…...

四大软件测试策略的特点和区别(单元测试、集成测试、确认测试和系统测试)
四大软件测试策略分别是单元测试、集成测试、确认测试和系统测试。 一、单元测试 单元测试也称为模块测试,它针对软件中的最小单元(如函数、方法、类、模块等)进行测试,以验证其是否符合预期的行为和结果。单元测试通常由开发人…...

ingress-nginx controller安装
文章目录 一、ingress-nginx controller安装环境 1.1 部署yaml1.2 镜像1.3 安装操作 一、ingress-nginx controller安装 环境 kubernetes版本:1.27.1操作系统:CentOS7.9 1.1 部署yaml deploy.yaml apiVersion: v1 kind: Namespace metadata:labels:…...

开源快速开发平台:做好数据管理,实现流程化办公!
做好数据管理,可以提升企业的办公协作效率,实现数字化转型。开源快速开发平台是深受企业喜爱的低代码开发平台,拥有多项典型功能,是可以打造自主可控快速开发平台,实现一对一框架定制的软件平台。在快节奏的社会中&…...

基于深度学习的裂纹图像分类研究(Matlab代码实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

TypeScript入门学习汇总
1.快速入门 1.1 简介 TypeScript 是 JavaScript 的一个超集,支持 ECMAScript 6 标准。 TypeScript 由微软开发的自由和开源的编程语言。 TypeScript 设计目标是开发大型应用,它可以编译成纯 JavaScript,编译出来的 JavaScript 可以运行在…...

Vue3使用vxetable进行表格的编辑、删除与新增
效果图如下: vxetable4传送门 一、引入插件 package.json中加入"vxe-table": "4.0.23",终端中执行npm i导入import {VXETable, VxeTableInstance...

JUC 并发编程之JMM
目录 1. 内存模型JMM 1. 1 主内存和工作内存 1.2 重排序 1. 内存模型JMM Java内存模型是Java虚拟机(JVM)规范中定义的一组规则,用于屏蔽各种硬件和操作系统的内存访问差异,保证多线程情况下程序的正确执行。Java内存模型规定了…...

k8s集群中安装kibana 7.x 踩坑
1. FATAL ValidationError: child "server" fails because [child "port" fails because ["port" must be a number]] 解决办法: 在环境变量中指定端口: - name: SERVER_PORTvalue: 5601 2. Kibana FATAL Error: [elast…...

CSS的一些基础知识
选择器: 选择器用于选择要应用样式的HTML元素。常见的选择器包括标签选择器(如 div、p)、类选择器(如 .class)、ID选择器(如 #id)和伪类选择器(如 :hover)。选择器可以根…...

解决多线程环境下单例模式同时访问生成多个实例
如何满足单例:1.构造方法是private、static方法、if语句判断 ①、单线程 Single类 //Single类,定义一个GetInstance操作,允许客户访问它的唯一实例。GetInstance是一个静态方法,主要负责创建自己的唯一实例 public class LazySi…...

转转闲鱼交易猫源码搭建
后台一键生成链接,独立后台管理 教程:修改数据库config/Conn.php 不会可以看源码里有教程 下载程序:https://pan.baidu.com/s/16lN3gvRIZm7pqhvVMYYecQ?pwd6zw3...

设计模式精华版汇总
以下是个人整理的设计模式汇总,将会持续更新工作和面试中经常用到的设计模式。 设计模式-装饰者模式(包装模式)- 案例分析和源码分析 设计模式-代理模式:控制访问的设计模式 - 案例分析 设计模式-门面模式…...

uniapp实现带参数二维码
view <view class"canvas"><!-- 二维码插件 width height设置宽高 --><canvas canvas-id"qrcode" :style"{width: ${qrcodeSize}px, height: ${qrcodeSize}px}" /></view> script import uQRCode from /utils/uqrcod…...

金融行业软件测试面试题及其答案
下面是一些常见的金融行业软件测试面试题及其答案: 1. 什么是金融行业软件测试? 金融行业软件测试是针对金融领域的软件系统进行验证和确认的过程,旨在确保软件在安全、稳定、可靠和符合法规要求的条件下运行。 2. 解释一下金融软件中的风险…...

强化学习QLearning 进行迷宫游戏和代码
强化学习是机器学习里面的一个分支。它强调基于环境而探索行动、学习,以取得最大化的预期收益。其灵感来源于心理学中的行为主义理论,既有机体如何在环境给予的奖励或者惩罚的刺激下,逐步形成对刺激的预期,产生能够最大利益的习惯…...