python实现——分类类型数据挖掘任务(图形识别分类任务)
- 分类类型数据挖掘任务
基于卷积神经网络(CNN)的岩石图像分类。有一岩石图片数据集,共300张岩石图片,图片尺寸224x224。岩石种类有砾岩(Conglomerate)、安山岩(Andesite)、花岗岩(Granite)、石灰岩(Limestone)、石英岩(Quartzite)和5种,每种岩石图片各50张,共250张。请选择合适模型对该数据集进行建模,训练优化模型并给出模型评估指标,再利用GUI框架开发岩石图片分类界面。
1.1总体流程
1.2数据增强
定义:数据增强是利用现有数据生成新的数据来增加数据量的过程,能够有效地扩充训练数据集的大小,提高模型的泛化能力,同时也能够有效地防止过拟合现象的发生。
本项目采用的数据增强方法:
(1)水平翻转
(2)缩放
(3)旋转
(4)添加高斯噪音
(5)调整对比度和亮度
通过数据增强,数据集从之前的250张扩充至1500张,数据量为之前的6倍。
参考代码:
import cv2
import os
import glob
# 数据增强函数
def augment_data(img, save_path):rows, cols, _ = img.shape# 水平翻转图像img_flip = cv2.flip(img, 1)img_name = os.path.splitext(save_path)[0] + "_flip.jpg"cv2.imwrite(img_name, img_flip)print("Saved augmented image:", img_name)# 随机缩放图像scale = np.random.uniform(0.9, 1.1)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 0, scale)img_transformed = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_transform.jpg"cv2.imwrite(img_name, img_transformed)print("Saved augmented image:", img_name)# 随机旋转图像angle = np.random.randint(-10, 10)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)img_rotated = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_rotated.jpg"cv2.imwrite(img_name, img_rotated)print("Saved augmented image:", img_name)# 添加高斯噪音mean = 0std = np.random.uniform(5, 15)noise = np.zeros(img.shape, np.float32)cv2.randn(noise, mean, std)noise = np.uint8(noise)img_noisy = cv2.add(img, noise)img_name = os.path.splitext(save_path)[0] + "_noisy.jpg"cv2.imwrite(img_name, img_noisy)print("Saved augmented image:", img_name)# 随机调整对比度和亮度alpha = np.random.uniform(0.8, 1.2)beta = np.random.randint(-10, 10)img_contrast = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)img_name = os.path.splitext(save_path)[0] + "_contrast.jpg"cv2.imwrite(img_name, img_contrast)print("Saved augmented image:", img_name)return img
# 读取 data 文件夹中的所有图片,并进行数据增强
data_dir = r"images"
save_dir = r"images2"
if not os.path.exists(save_dir):os.makedirs(save_dir)
# 使用 glob 库来遍历 data 文件夹中所有图像
for img_path in glob.glob(os.path.join(data_dir, "*.jpg")):img = cv2.imread(img_path)if img is None:print("Error: Unable to read image at", img_path)continue# 获取保存增强后的图片文件名img_name = os.path.basename(img_path)save_path = os.path.join(save_dir, img_name)# 数据增强augmented_img = augment_data(img, save_path)if augmented_img is not None:# 保存原始图片cv2.imwrite(save_path, img)print("Saved original image:", save_path)
结果:
1.3数据预处理
将1500张图片依次读入并转化为可训练的数据(特征变量(X)和标签(Y))
代码:
import os
import cv2
import numpy as np
from PIL import Image
# 设置图片文件夹路径
image_folder = r"images2"
# 获取所有类别的文件夹名(假设每个文件夹是一个类别)
categories = os.listdir(image_folder)# 初始化特征变量 X 和标签 Y 的列表
X_list = np.zeros((len(categories), 224, 224, 3))
Y_list = np.zeros((len(categories)))i=0
for name in categories:img = Image.open(image_folder + '\\' +name)img_rgb = img.split()X_list[i,:,:,0] = np.array(img_rgb[0])/255X_list[i,:,:,1] = np.array(img_rgb[1])/255X_list[i,:,:,2] = np.array(img_rgb[2])/255Y_list[i] = name.split('_')[0]i+=1
# 将特征变量 X 和标签 Y 的列表转化为 NumPy 数组
X = np.array(X_list)
Y = np.array(Y_list)# 打印特征变量 X 和标签 Y 的形状
print('特征变量 X 的形状:', X)
print('标签 Y 的形状:', Y)
![]() |
1.4模型构建
1.4.1模型结构定义
模型参数:
参考代码:
from sklearn.model_selection import train_test_split
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import confusion_matrix
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 5个类别
num_classes = 5
# 输入图像的大小是224x224,有3个颜色通道(对于彩色图像)
input_shape = (224, 224, 3)
# 假设X和Y是您的原始数据
# X: 图像数据,形状为(num_samples, 224, 224, 3)
# Y: 标签数据,形状为(num_samples,) 并且是整数形式的标签(从0到4)
# 将数据划分为训练集和测试集(只执行一次)
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
# 构建模型
model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(6, (5, 5), strides=(1,1), activation='relu', input_shape=input_shape), tf.keras.layers.MaxPooling2D((2,2), strides=2), tf.keras.layers.Conv2D(16, (5,5), activation='relu'), tf.keras.layers.MaxPooling2D((2,2), strides=2), tf.keras.layers.Conv2D(120, (5,5), activation='relu'), tf.keras.layers.Flatten(), tf.keras.layers.Dense(84, activation='relu'), tf.keras.layers.Dropout(0.3), tf.keras.layers.Dense(num_classes, activation='softmax') # 确保输出层的神经元数量与类别数量匹配
]) # 编译模型
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),# 使用sparse categorical crossentropy损失函数 optimizer=tf.keras.optimizers.Adam(), # 使用Adam优化器 metrics=['sparse_categorical_accuracy']) # 监控准确率 # 打印模型概述
model.summary() # 使用model.fit()函数训练模型
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)
1.4.2模型译
编译参数参考:
# 优化器optimizer='adam'# 损失函数loss='sparse_categorical_crossentropy'# 评估指标metrics=['sparse_categorical_accuracy']
1.5模型训练
1.5.1划分训练集和测试集
按照训练集:测试集=8:2的比例对数据集进行划分,建议使用sklearn库中的train_test_split函数。
1.5.2训练
使用fit函数对训练集进行拟合训练,并将训练过程中产生的历史数据history保存至变量中。
训练参数参考:
# 迭代次数epochs=20# 验证集比例validation_split=0.2
1.5.3训练过程可视化
对history中保存下来的训练过程中的loss和sparse_categorical_accuracy的变化情况进行绘图。
参考代码:
# 获取训练和验证的准确率和损失
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss'] # 使用model.evaluate()函数评估模型在测试集上的性能
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_accuracy}') # 使用model.predict()函数对新的图像进行预测。
plt.figure(figsize=(15,10))
plt.plot(history.epoch, history.history['loss'],label='loss')
plt.plot(history.epoch, history.history['val_loss'],label='var_loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')plt.figure(figsize=(15,10))
plt.plot(history.epoch,history.history['sparse_categorical_accuracy'],label='sparse_categorical_accuracy')
plt.plot(history.epoch,history.history['val_sparse_categorical_accuracy'],label='val_sparse_categorical_accuracy')
plt.xlabel('Epoch')
plt.ylabel('sparse_categorical_accuracy')
plt.legend(loc='upper right')
plt.show()plt.rcParams['font.sans-serif'] = ['SimHei']
y_pred = np.argmax(model.predict(x_test),axis=1)
cm = confusion_matrix(y_test, y_pred,labels=[0,1,2,3,4])
sns.heatmap(cm,annot=True,cmap="Blues",cbar=False,linewidths=2,linecolor='white',square=True,xticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'],yticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'])
plt.show
1.6.3保存模型
使用save函数对训练好的模型进行保存,方便后续使用。
参考代码:
model.save('roch_classification_cnn.h5')
1.7图形用户界面(GUI)开发
1.7.1配置开发工具
在PyCharm中配置QtDesigner和PyUIC工具。
注意:需提前在python环境中安装好PyQt5和PyQt5-tools库。
- 配置QtDesigner
Program:(对应designer.exe的路径)
Working directory: $FileDir$
- 配置PyUCI
Program:(对应pyuic5.exe的路径)
Arguments: $FileName$ -o $FileNameWithoutExtension$.py
Working directory: $FileDir$
配置完成后的界面:
1.7.2设计图形用户界面
在PyCharm中“Tools”—“External Tools”中打开QtDesigner
在QtDesigner主界面中选择创建Main Window,然后根据需求选择相应的控件进行设计。
设计界面参考:
设计好之后保存为.ui文件。
1.7.3 ui文件转换为代码
在PyCharm中右键点击.ui文件并使用PyUCI工具进行转换。
1.7.4代码与模型结合
将转化后的代码与之前训练的模型相结合。
参考代码:
# -*- coding: utf-8 -*-
import osfrom PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
import tensorflow as tf
from PIL import Image
import numpy as np
import sys
class Ui_MainWindow(object):def setupUi(self, MainWindow):MainWindow.setObjectName("MainWindow")MainWindow.resize(800, 600)self.centralwidget = QtWidgets.QWidget(MainWindow)self.centralwidget.setObjectName("centralwidget")self.label = QtWidgets.QLabel(self.centralwidget)self.label.setGeometry(QtCore.QRect(220, 20, 291, 61))self.label.setScaledContents(False)self.label.setObjectName("label")self.pushButton = QtWidgets.QPushButton(self.centralwidget)self.pushButton.setGeometry(QtCore.QRect(160, 430, 93, 28))self.pushButton.setObjectName("pushButton")self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)self.pushButton_2.setGeometry(QtCore.QRect(440, 430, 93, 28))self.pushButton_2.setObjectName("pushButton_2")self.label_2 = QtWidgets.QLabel(self.centralwidget)self.label_2.setGeometry(QtCore.QRect(150, 90, 381, 321))self.label_2.setText("")self.label_2.setObjectName("label_2")self.label_3 = QtWidgets.QLabel(self.centralwidget)self.label_3.setGeometry(QtCore.QRect(550, 130, 141, 51))self.label_3.setText("")self.label_3.setObjectName("label_3")self.label_4 = QtWidgets.QLabel(self.centralwidget)self.label_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.label_4.setObjectName("label_4")self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser.setGeometry(QtCore.QRect(150, 90, 381, 321))self.textBrowser.setObjectName("textBrowser")self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_2.setGeometry(QtCore.QRect(550, 130, 141, 51))self.textBrowser_2.setObjectName("textBrowser_2")self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_3.setGeometry(QtCore.QRect(220, 20, 291, 61))self.textBrowser_3.setObjectName("textBrowser_3")self.textBrowser_4 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.textBrowser_4.setObjectName("textBrowser_4")self.textBrowser_2.raise_()self.label.raise_()self.textBrowser.raise_()self.textBrowser_3.raise_()self.pushButton.raise_()self.pushButton_2.raise_()self.label_2.raise_()self.label_4.raise_()self.textBrowser_4.raise_()self.label_3.raise_()MainWindow.setCentralWidget(self.centralwidget)self.menubar = QtWidgets.QMenuBar(MainWindow)self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26))self.menubar.setObjectName("menubar")MainWindow.setMenuBar(self.menubar)self.statusbar = QtWidgets.QStatusBar(MainWindow)self.statusbar.setObjectName("statusbar")MainWindow.setStatusBar(self.statusbar)self.toolBar = QtWidgets.QToolBar(MainWindow)self.toolBar.setObjectName("toolBar")MainWindow.addToolBar(QtCore.Qt.TopToolBarArea, self.toolBar)self.retranslateUi(MainWindow)QtCore.QMetaObject.connectSlotsByName(MainWindow)# 模型相关变量初始化self.model = tf.keras.models.load_model(r'C:\Users\zjl15\PycharmProjects\pythonProject1\roch_classification_cnn.h5')self.path = ''self.rock_types = ['砾岩','安山岩','花岗岩','石灰岩','石英岩']# 将“导入图片”按钮与openImage函数绑定self.pushButton.clicked.connect(self.openImage)# 将“岩石分类”按钮与classify函数绑定self.pushButton_2.clicked.connect(self.classify)def retranslateUi(self, MainWindow):_translate = QtCore.QCoreApplication.translateMainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))self.label.setText(_translate("MainWindow", "岩石图像分类"))self.pushButton.setText(_translate("MainWindow", "导入图像"))self.pushButton_2.setText(_translate("MainWindow", "岩石分类"))self.label_4.setText(_translate("MainWindow", "分类结果"))self.textBrowser_3.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:24pt;\">岩石图像识别</span></p></body></html>"))self.textBrowser_4.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:11pt;\">分类结果</span></p></body></html>"))self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))# 导入图片函数def resource_path(relative):if hasattr(sys, "_MEIPASS"):absolute_path = os.path.join(sys._MEIPASS, relative)else:absolute_path = os.path.join(relative)return absolute_path# 在原来引用该文件的地方加上这个函数 (resource_path("文件名"))def openImage(self):imgPath, imgType = QFileDialog.getOpenFileName(None, "导入图片", "", "*.jpg;;*.png;;All Files(*)")jpg = QtGui.QPixmap(imgPath).scaled(self.label_2.width(), self.label_2.height())self.label_2.setPixmap(jpg)self.path=imgPathself.label_3.setText('')def classify(self):img = Image.open(self.path) # 读取图像img_rgb = img.split()x = np.zeros((1, 224, 224, 3))x[0,:, :, 0] = np.array(img_rgb[0]) / 255x[0,:, :, 1] = np.array(img_rgb[1]) / 255x[0,:, :, 2] = np.array(img_rgb[2]) / 255y = self.model.predict(x)result = self.rock_types[np.argmax(y)]self.label_3.setText(result)
if __name__=='__main__':QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)app=QtWidgets.QApplication(sys.argv)MainWindow=QtWidgets.QMainWindow()ui_test=Ui_MainWindow()ui_test.setupUi(MainWindow)MainWindow.show()sys.exit(app.exec_())
1.7.5测试
执行程序测试“导入图片”和“鉴定分类”功能。
1.8打包可执行文件(exe)
在命令窗口中使用如下指令对上一步的程序进行打包。
Pyinstaller -F -w xxxxx.py
运行生成的.exe文件并测试功能。
打完包之后可能出现错误
报错信息:
=============================================================
A RecursionError (maximum recursion depth exceeded) occurred.
For working around please follow these instructions
=============================================================
1. In your program's .spec file add this line near the top::
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
2. Build your program by running PyInstaller with the .spec file as
argument::
pyinstaller myprog.spec
3. If this fails, you most probably hit an endless recursion in
PyInstaller. Please try to track this down has far as possible,
create a minimal example so we can reproduce and open an issue at
https://github.com/pyinstaller/pyinstaller/issues following the
instructions in the issue template. Many thanks.
Explanation: Python's stack-limit is a safety-belt against endless recursion,
eating up memory. PyInstaller imports modules recursively. If the structure
how modules are imported within your program is awkward, this leads to the
nesting being too deep and hitting Python's stack-limit.
With the default recursion limit (1000), the recursion error occurs at about
115 nested imported, with limit 2000 at about 240, with limit 5000 at about
660.
————————————————
你打包目录下会生成如下文件
打开你的main.spec文件
在顶端添加代码:
import sys
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
然后在运行命令(对应的文件名)
pyinstaller 你的文件名.spec
然后就完成了
打完包之的运行闪退问题:
先安装一个新的第三方库ordereddict
安装命令:
pip install ordereddict
注意自己python代码的文件引入路径(确保对应的路径下有对应的文件,我这里设置的是根目录下)
重新打包
完成之后
打开对应的文件夹双击就可以了
完整代码:
import cv2
import os
import glob
# 数据增强函数
def augment_data(img, save_path):rows, cols, _ = img.shape# 水平翻转图像img_flip = cv2.flip(img, 1)img_name = os.path.splitext(save_path)[0] + "_flip.jpg"cv2.imwrite(img_name, img_flip)print("Saved augmented image:", img_name)# 随机缩放图像scale = np.random.uniform(0.9, 1.1)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 0, scale)img_transformed = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_transform.jpg"cv2.imwrite(img_name, img_transformed)print("Saved augmented image:", img_name)# 随机旋转图像angle = np.random.randint(-10, 10)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)img_rotated = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_rotated.jpg"cv2.imwrite(img_name, img_rotated)print("Saved augmented image:", img_name)# 添加高斯噪音mean = 0std = np.random.uniform(5, 15)noise = np.zeros(img.shape, np.float32)cv2.randn(noise, mean, std)noise = np.uint8(noise)img_noisy = cv2.add(img, noise)img_name = os.path.splitext(save_path)[0] + "_noisy.jpg"cv2.imwrite(img_name, img_noisy)print("Saved augmented image:", img_name)# 随机调整对比度和亮度alpha = np.random.uniform(0.8, 1.2)beta = np.random.randint(-10, 10)img_contrast = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)img_name = os.path.splitext(save_path)[0] + "_contrast.jpg"cv2.imwrite(img_name, img_contrast)print("Saved augmented image:", img_name)return img
# 读取 data 文件夹中的所有图片,并进行数据增强
data_dir = r"images"
save_dir = r"images2"
if not os.path.exists(save_dir):os.makedirs(save_dir)
# 使用 glob 库来遍历 data 文件夹中所有图像
for img_path in glob.glob(os.path.join(data_dir, "*.jpg")):img = cv2.imread(img_path)if img is None:print("Error: Unable to read image at", img_path)continue# 获取保存增强后的图片文件名img_name = os.path.basename(img_path)save_path = os.path.join(save_dir, img_name)# 数据增强augmented_img = augment_data(img, save_path)if augmented_img is not None:# 保存原始图片cv2.imwrite(save_path, img)print("Saved original image:", save_path)
#%%
import os
import cv2
import numpy as np
from PIL import Image
# 设置图片文件夹路径
image_folder = r"images2"
# 获取所有类别的文件夹名(假设每个文件夹是一个类别)
categories = os.listdir(image_folder)# 初始化特征变量 X 和标签 Y 的列表
X_list = np.zeros((len(categories), 224, 224, 3))
Y_list = np.zeros((len(categories)))i=0
for name in categories:img = Image.open(image_folder + '\\' +name)img_rgb = img.split()X_list[i,:,:,0] = np.array(img_rgb[0])/255X_list[i,:,:,1] = np.array(img_rgb[1])/255X_list[i,:,:,2] = np.array(img_rgb[2])/255Y_list[i] = name.split('_')[0]i+=1
# 将特征变量 X 和标签 Y 的列表转化为 NumPy 数组
X = np.array(X_list)
Y = np.array(Y_list)# 打印特征变量 X 和标签 Y 的形状
print('特征变量 X 的形状:', X)
print('标签 Y 的形状:', Y)
#%%
from sklearn.model_selection import train_test_split
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import confusion_matrix
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 5个类别
num_classes = 5
# 输入图像的大小是224x224,有3个颜色通道(对于彩色图像)
input_shape = (224, 224, 3)
# 假设X和Y是您的原始数据
# X: 图像数据,形状为(num_samples, 224, 224, 3)
# Y: 标签数据,形状为(num_samples,) 并且是整数形式的标签(从0到4)
# 将数据划分为训练集和测试集(只执行一次)
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
# 构建模型
model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(6, (5, 5), strides=(1,1), activation='relu', input_shape=input_shape), tf.keras.layers.MaxPooling2D((2,2), strides=2), tf.keras.layers.Conv2D(16, (5,5), activation='relu'), tf.keras.layers.MaxPooling2D((2,2), strides=2), tf.keras.layers.Conv2D(120, (5,5), activation='relu'), tf.keras.layers.Flatten(), tf.keras.layers.Dense(84, activation='relu'), tf.keras.layers.Dropout(0.3), tf.keras.layers.Dense(num_classes, activation='softmax') # 确保输出层的神经元数量与类别数量匹配
]) # 编译模型
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),# 使用sparse categorical crossentropy损失函数 optimizer=tf.keras.optimizers.Adam(), # 使用Adam优化器 metrics=['sparse_categorical_accuracy']) # 监控准确率 # 打印模型概述
model.summary() # 使用model.fit()函数训练模型
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2) #%%
y_pred = model.predict(x_test)
print(y_pred)
#%%#%%
# 获取训练和验证的准确率和损失
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss'] # 使用model.evaluate()函数评估模型在测试集上的性能
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_accuracy}') # 使用model.predict()函数对新的图像进行预测。
plt.figure(figsize=(15,10))
plt.plot(history.epoch, history.history['loss'],label='loss')
plt.plot(history.epoch, history.history['val_loss'],label='var_loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')plt.figure(figsize=(15,10))
plt.plot(history.epoch,history.history['sparse_categorical_accuracy'],label='sparse_categorical_accuracy')
plt.plot(history.epoch,history.history['val_sparse_categorical_accuracy'],label='val_sparse_categorical_accuracy')
plt.xlabel('Epoch')
plt.ylabel('sparse_categorical_accuracy')
plt.legend(loc='upper right')
plt.show()plt.rcParams['font.sans-serif'] = ['SimHei']
y_pred = np.argmax(model.predict(x_test),axis=1)
cm = confusion_matrix(y_test, y_pred,labels=[0,1,2,3,4])
sns.heatmap(cm,annot=True,cmap="Blues",cbar=False,linewidths=2,linecolor='white',square=True,xticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'],yticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'])
plt.show
#%%
model.save('roch_classification_cnn.h5')
# -*- coding: utf-8 -*-
import osfrom PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
import tensorflow as tf
from PIL import Image
import numpy as np
import sys
class Ui_MainWindow(object):def setupUi(self, MainWindow):MainWindow.setObjectName("MainWindow")MainWindow.resize(800, 600)self.centralwidget = QtWidgets.QWidget(MainWindow)self.centralwidget.setObjectName("centralwidget")self.label = QtWidgets.QLabel(self.centralwidget)self.label.setGeometry(QtCore.QRect(220, 20, 291, 61))self.label.setScaledContents(False)self.label.setObjectName("label")self.pushButton = QtWidgets.QPushButton(self.centralwidget)self.pushButton.setGeometry(QtCore.QRect(160, 430, 93, 28))self.pushButton.setObjectName("pushButton")self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)self.pushButton_2.setGeometry(QtCore.QRect(440, 430, 93, 28))self.pushButton_2.setObjectName("pushButton_2")self.label_2 = QtWidgets.QLabel(self.centralwidget)self.label_2.setGeometry(QtCore.QRect(150, 90, 381, 321))self.label_2.setText("")self.label_2.setObjectName("label_2")self.label_3 = QtWidgets.QLabel(self.centralwidget)self.label_3.setGeometry(QtCore.QRect(550, 130, 141, 51))self.label_3.setText("")self.label_3.setObjectName("label_3")self.label_4 = QtWidgets.QLabel(self.centralwidget)self.label_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.label_4.setObjectName("label_4")self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser.setGeometry(QtCore.QRect(150, 90, 381, 321))self.textBrowser.setObjectName("textBrowser")self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_2.setGeometry(QtCore.QRect(550, 130, 141, 51))self.textBrowser_2.setObjectName("textBrowser_2")self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_3.setGeometry(QtCore.QRect(220, 20, 291, 61))self.textBrowser_3.setObjectName("textBrowser_3")self.textBrowser_4 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.textBrowser_4.setObjectName("textBrowser_4")self.textBrowser_2.raise_()self.label.raise_()self.textBrowser.raise_()self.textBrowser_3.raise_()self.pushButton.raise_()self.pushButton_2.raise_()self.label_2.raise_()self.label_4.raise_()self.textBrowser_4.raise_()self.label_3.raise_()MainWindow.setCentralWidget(self.centralwidget)self.menubar = QtWidgets.QMenuBar(MainWindow)self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26))self.menubar.setObjectName("menubar")MainWindow.setMenuBar(self.menubar)self.statusbar = QtWidgets.QStatusBar(MainWindow)self.statusbar.setObjectName("statusbar")MainWindow.setStatusBar(self.statusbar)self.toolBar = QtWidgets.QToolBar(MainWindow)self.toolBar.setObjectName("toolBar")MainWindow.addToolBar(QtCore.Qt.TopToolBarArea, self.toolBar)self.retranslateUi(MainWindow)QtCore.QMetaObject.connectSlotsByName(MainWindow)# 模型相关变量初始化self.model = tf.keras.models.load_model(r'C:\Users\zjl15\PycharmProjects\pythonProject1\roch_classification_cnn.h5')self.path = ''self.rock_types = ['砾岩','安山岩','花岗岩','石灰岩','石英岩']# 将“导入图片”按钮与openImage函数绑定self.pushButton.clicked.connect(self.openImage)# 将“岩石分类”按钮与classify函数绑定self.pushButton_2.clicked.connect(self.classify)def retranslateUi(self, MainWindow):_translate = QtCore.QCoreApplication.translateMainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))self.label.setText(_translate("MainWindow", "岩石图像分类"))self.pushButton.setText(_translate("MainWindow", "导入图像"))self.pushButton_2.setText(_translate("MainWindow", "岩石分类"))self.label_4.setText(_translate("MainWindow", "分类结果"))self.textBrowser_3.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:24pt;\">岩石图像识别</span></p></body></html>"))self.textBrowser_4.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:11pt;\">分类结果</span></p></body></html>"))self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))# 导入图片函数def resource_path(relative):if hasattr(sys, "_MEIPASS"):absolute_path = os.path.join(sys._MEIPASS, relative)else:absolute_path = os.path.join(relative)return absolute_path# 在原来引用该文件的地方加上这个函数 (resource_path("文件名"))def openImage(self):imgPath, imgType = QFileDialog.getOpenFileName(None, "导入图片", "", "*.jpg;;*.png;;All Files(*)")jpg = QtGui.QPixmap(imgPath).scaled(self.label_2.width(), self.label_2.height())self.label_2.setPixmap(jpg)self.path=imgPathself.label_3.setText('')def classify(self):img = Image.open(self.path) # 读取图像img_rgb = img.split()x = np.zeros((1, 224, 224, 3))x[0,:, :, 0] = np.array(img_rgb[0]) / 255x[0,:, :, 1] = np.array(img_rgb[1]) / 255x[0,:, :, 2] = np.array(img_rgb[2]) / 255y = self.model.predict(x)result = self.rock_types[np.argmax(y)]self.label_3.setText(result)
if __name__=='__main__':QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)app=QtWidgets.QApplication(sys.argv)MainWindow=QtWidgets.QMainWindow()ui_test=Ui_MainWindow()ui_test.setupUi(MainWindow)MainWindow.show()sys.exit(app.exec_())
相关文章:

python实现——分类类型数据挖掘任务(图形识别分类任务)
分类类型数据挖掘任务 基于卷积神经网络(CNN)的岩石图像分类。有一岩石图片数据集,共300张岩石图片,图片尺寸224x224。岩石种类有砾岩(Conglomerate)、安山岩(Andesite)、花岗岩&am…...
【安卓跨进程通信IPC】-- Binder
目录 BinderBinder是什么?进程空间分配进程隔离Binder跨进程通信机制模型优点AIDL常见面试题 Binder 夯实基础之超详解Android Binder的工作方式与原理以及aidl示例代码 比较详细的介绍:Android跨进程通信:图文详解 Binder机制 原理 操作系统…...

大数据之Schedule调度错误(一)
当我们在利用ooize发起整个任务的调度过程中,如果多个调度任务同时运行并且多个调度任务操作了相同的表,那么就会出现如下的错误关系: Invalid path hdfs://iZh5w01l7f8lnog055cpXXX:8000/user/admin/xxx: No files matching path hdfs://iZh5w01l7f8lnog055cpXXX:8000/user/ad…...

DiffIR论文阅读笔记
ICCV2023的一篇用diffusion模型做Image Restoration的论文,一作是清华的教授,还在NIPS2023上一作发表了Hierarchical Integration Diffusion Model for Realistic Image Deblurring,作者里甚至有Luc Van Gool大佬。模型分三个部分,…...

prometheus+alertmanager+webhook钉钉机器人告警
版本:centos7.9 python3.9.5 alertmanager0.25.0 prometheus2.46.0 安装alertmanager prometheus 配置webhook # 解压: tar -xvf alertmanager-0.25.0.linux-amd64.tar.gz tar -xvf prometheus-2.46.0.linux-amd64.tar.gz mv alertmanager-0.25.0.linu…...
ctfshow 年CTF web
除夕 Notice: Undefined index: year in /var/www/html/index.php on line 16 <?phpinclude "flag.php";$year $_GET[year];if($year2022 && $year1!2023){echo $flag; }else{highlight_file(__FILE__); } 弱比较绕过很简单,连函数都没有直…...
原型链、闭包、手写一个闭包函数、 闭包有哪些优缺点、原型链继承
什么是原型链? 原型链是一种查找规则 为对象成员查找机制提供一个方向 因为构造函数的 prototype 和其实例的 __ proto __ 都是指向原型对象的 所以可以通过__proto__ 查找当前的原型对象有没有该属性, 没有就找原型的原型, 依次类推一直找到Object( null ) 为…...
linux中SSH_ASKPASS全局变量的作用
在工作中遇到一段代码,通过SSH_ASKPASS全局变量实现了ssh登录远程IP时的密码输入,chatgpt搜索了一下,其解释大致如下所示: SSH_ASKPASS 是一个环境变量,它在 SSH 客户端需要用户输入密码时起作用。当 SSH 客户端检测到…...

9 -力扣高频 SQL 50 题(基础版)
9 - 上升的温度 -- 找出与之前(昨天的)日期相比温度更高的所有日期的 id -- DATEDIFF(2007-12-31,2007-12-30); # 1 -- DATEDIFF(2010-12-30,2010-12-31); # -1select w1.id from Weather w1, Weather w2 wheredatediff(w1.recordDate,w2.recordDat…...

TCP的重传机制
TCP 是一个可靠的传输协议,解决了IP层的丢包、乱序、重复等问题。这其中,TCP的重传机制起到重要的作用。 序列号和确认号 之前我们在讲解TCP三次握手时,提到过TCP包头结构,其中有序列号和确认号, 而TCP 实现可靠传输…...
pg 数据库,获取时间字段值的具体小时,赋值给其他字段
目录 1 问题2 实现 1 问题 pg 数据库,有一个表,其中有2个字段 一个是时间字段obstime ,一个是时次ltime字段,int 类型,现在这个表里面是obstime 里面有数据,ltime字段 没有数据,现在就是批量获…...

做视频号小店什么类目最容易爆单?其实,弄懂这三点就会选品了
大家好,我是电商花花。 我们做视频号小店做什么类目最容易爆单? 其实任何类目都有属于自己的受众人群和客户,都非常容易爆单,我们想要爆单,就要选对类目,选对产品。 视频号上所有的类目基本上可以分为标…...

Nginx作为下载站点
grep -Ev ^$|# /usr/local/nginx/conf/nginx.conf > /opt/nginx.txt cat /opt/nginx.txt > /usr/local/nginx/conf/nginx.conf用上面的指令提取最小化的配置文件 vim /usr/local/nginx/conf/nginx.conf [rootlocalhost ~]# cat /usr/local/nginx/conf/nginx.conf worker…...

vue3简单快速实现主题切换功能
⛰️个人主页: 蒾酒 🔥系列专栏:《vue3实战》 目录 内容概要 实现步骤 1.定义不同主题的css样式变量 2.入口main.ts中引入这个样式文件 3.主题样式css变量引用 4.设置默认主题样式 5.实现点击按钮主题切换 总结 最近发现了一个巨牛的人工智…...

国联易安:网络反不正当竞争,要防患于未然
据市场监管总局官网消息,为预防和制止网络不正当竞争,维护公平竞争的市场秩序,鼓励创新,保护经营者和消费者的合法权益,促进数字经济规范健康持续发展,市场监管总局近日发布《网络反不正当竞争暂行规定》&a…...
Linux 网络配置 01
基本命令 1、查看网络接口信息ifconfig ifconfig:当前设备正在工作的网卡,启动的设备 ifconfig -a :所网络设备 ifconfig信息解析: ens33: flags4163<UP,BROADCAST,RUNNING,MULTICAST> mtu 1500inet 192.168.10.10 n…...

快速入门C++正则表达式
正则表达式(Regular Expression,简称 Regex)是一种强大的文本处理工具,广泛用于字符串的搜索、替换、分析等操作。它基于一种表达式语言,使用单个字符串来描述、匹配一系列符合某个句法规则的字符串。正则表达式不仅在…...
java —— 缓冲字符输入流/缓冲字符输出流
缓冲字符输入流/缓冲字符输出流是对字符输入流/字符输出流的加强,在使用中仍旧要借助于字符输入流/字符输出流才能完成实现。与字符输入流/字符输出流按照字符为单位进行输入/输出不同的是,缓冲字符输入流/缓冲字符输出流能够以行为单位进行读取和写入。…...
blender从视频中动作捕捉,绑定到人物模型
总共分为3个步骤: 1、从视频中捕捉动作模型 小K动画网-AIGC视频动捕平台 地址:https://xk.yunbovtb.com/ 需要注册 生成的FBX文件,不能直接导入到blender中, 方法有2种: 第一种:需要转换一下&#x…...

掘金滑块验证码安全升级,继续破解
去年发过一篇文章,《使用前端技术破解掘金滑块验证码》,我很佩服掘金官方的气度,不但允许我发布这篇文章,还同步发到了官方公众号。最近发现掘金的滑块验证码升级了,也许是我那篇文章起到了一些作用,逼迫官…...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版分享
平时用 iPhone 的时候,难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵,或者买了二手 iPhone 却被原来的 iCloud 账号锁住,这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...

高频面试之3Zookeeper
高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个?3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制(过半机制࿰…...
生成 Git SSH 证书
🔑 1. 生成 SSH 密钥对 在终端(Windows 使用 Git Bash,Mac/Linux 使用 Terminal)执行命令: ssh-keygen -t rsa -b 4096 -C "your_emailexample.com" 参数说明: -t rsa&#x…...
Spring AI 入门:Java 开发者的生成式 AI 实践之路
一、Spring AI 简介 在人工智能技术快速迭代的今天,Spring AI 作为 Spring 生态系统的新生力量,正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务(如 OpenAI、Anthropic)的无缝对接&…...

ardupilot 开发环境eclipse 中import 缺少C++
目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...

NFT模式:数字资产确权与链游经济系统构建
NFT模式:数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新:构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议:基于LayerZero协议实现以太坊、Solana等公链资产互通,通过零知…...
高防服务器能够抵御哪些网络攻击呢?
高防服务器作为一种有着高度防御能力的服务器,可以帮助网站应对分布式拒绝服务攻击,有效识别和清理一些恶意的网络流量,为用户提供安全且稳定的网络环境,那么,高防服务器一般都可以抵御哪些网络攻击呢?下面…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
在树莓派上添加音频输入设备的几种方法
在树莓派上添加音频输入设备可以通过以下步骤完成,具体方法取决于设备类型(如USB麦克风、3.5mm接口麦克风或HDMI音频输入)。以下是详细指南: 1. 连接音频输入设备 USB麦克风/声卡:直接插入树莓派的USB接口。3.5mm麦克…...

sshd代码修改banner
sshd服务连接之后会收到字符串: SSH-2.0-OpenSSH_9.5 容易被hacker识别此服务为sshd服务。 是否可以通过修改此banner达到让人无法识别此服务的目的呢? 不能。因为这是写的SSH的协议中的。 也就是协议规定了banner必须这么写。 SSH- 开头,…...