当前位置: 首页 > news >正文

python实现——分类类型数据挖掘任务(图形识别分类任务)

  1. 分类类型数据挖掘任务

基于卷积神经网络(CNN)的岩石图像分类。有一岩石图片数据集,共300张岩石图片,图片尺寸224x224。岩石种类有砾岩(Conglomerate)、安山岩(Andesite)、花岗岩(Granite)、石灰岩(Limestone)、石英岩(Quartzite)和5种,每种岩石图片各50张,共250张。请选择合适模型对该数据集进行建模,训练优化模型并给出模型评估指标,再利用GUI框架开发岩石图片分类界面。

1.1总体流程

1.2数据增强

定义:数据增强是利用现有数据生成新的数据来增加数据量的过程,能够有效地扩充训练数据集的大小,提高模型的泛化能力,同时也能够有效地防止过拟合现象的发生。

本项目采用的数据增强方法:

(1)水平翻转

(2)缩放

(3)旋转

(4)添加高斯噪音

(5)调整对比度和亮度

通过数据增强,数据集从之前的250张扩充至1500张,数据量为之前的6倍。

参考代码:

import cv2
import os
import glob
# 数据增强函数
def augment_data(img, save_path):rows, cols, _ = img.shape# 水平翻转图像img_flip = cv2.flip(img, 1)img_name = os.path.splitext(save_path)[0] + "_flip.jpg"cv2.imwrite(img_name, img_flip)print("Saved augmented image:", img_name)# 随机缩放图像scale = np.random.uniform(0.9, 1.1)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 0, scale)img_transformed = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_transform.jpg"cv2.imwrite(img_name, img_transformed)print("Saved augmented image:", img_name)# 随机旋转图像angle = np.random.randint(-10, 10)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)img_rotated = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_rotated.jpg"cv2.imwrite(img_name, img_rotated)print("Saved augmented image:", img_name)# 添加高斯噪音mean = 0std = np.random.uniform(5, 15)noise = np.zeros(img.shape, np.float32)cv2.randn(noise, mean, std)noise = np.uint8(noise)img_noisy = cv2.add(img, noise)img_name = os.path.splitext(save_path)[0] + "_noisy.jpg"cv2.imwrite(img_name, img_noisy)print("Saved augmented image:", img_name)# 随机调整对比度和亮度alpha = np.random.uniform(0.8, 1.2)beta = np.random.randint(-10, 10)img_contrast = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)img_name = os.path.splitext(save_path)[0] + "_contrast.jpg"cv2.imwrite(img_name, img_contrast)print("Saved augmented image:", img_name)return img
# 读取 data 文件夹中的所有图片,并进行数据增强
data_dir = r"images"
save_dir = r"images2"
if not os.path.exists(save_dir):os.makedirs(save_dir)
# 使用 glob 库来遍历 data 文件夹中所有图像
for img_path in glob.glob(os.path.join(data_dir, "*.jpg")):img = cv2.imread(img_path)if img is None:print("Error: Unable to read image at", img_path)continue# 获取保存增强后的图片文件名img_name = os.path.basename(img_path)save_path = os.path.join(save_dir, img_name)# 数据增强augmented_img = augment_data(img, save_path)if augmented_img is not None:# 保存原始图片cv2.imwrite(save_path, img)print("Saved original image:", save_path)

 结果:

1.3数据预处理

将1500张图片依次读入并转化为可训练的数据(特征变量(X)和标签(Y))

代码:

import os
import cv2
import numpy as np
from PIL import Image
# 设置图片文件夹路径
image_folder = r"images2"
# 获取所有类别的文件夹名(假设每个文件夹是一个类别)
categories = os.listdir(image_folder)# 初始化特征变量 X 和标签 Y 的列表
X_list = np.zeros((len(categories), 224, 224, 3))
Y_list = np.zeros((len(categories)))i=0
for name in categories:img = Image.open(image_folder + '\\' +name)img_rgb = img.split()X_list[i,:,:,0] = np.array(img_rgb[0])/255X_list[i,:,:,1] = np.array(img_rgb[1])/255X_list[i,:,:,2] = np.array(img_rgb[2])/255Y_list[i] = name.split('_')[0]i+=1
# 将特征变量 X 和标签 Y 的列表转化为 NumPy 数组
X = np.array(X_list)
Y = np.array(Y_list)# 打印特征变量 X 和标签 Y 的形状
print('特征变量 X 的形状:', X)
print('标签 Y 的形状:', Y)

1.4模型构建

1.4.1模型结构定义

模型参数:

参考代码:

from sklearn.model_selection import train_test_split
import seaborn as sns  
import matplotlib.pyplot as plt  
import tensorflow as tf
from sklearn.metrics import confusion_matrix  
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 5个类别  
num_classes = 5  
# 输入图像的大小是224x224,有3个颜色通道(对于彩色图像)  
input_shape = (224, 224, 3)  
# 假设X和Y是您的原始数据  
# X: 图像数据,形状为(num_samples, 224, 224, 3)  
# Y: 标签数据,形状为(num_samples,) 并且是整数形式的标签(从0到4)  
# 将数据划分为训练集和测试集(只执行一次)  
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)  
# 构建模型  
model = tf.keras.models.Sequential([  tf.keras.layers.Conv2D(6, (5, 5), strides=(1,1), activation='relu', input_shape=input_shape),  tf.keras.layers.MaxPooling2D((2,2), strides=2),  tf.keras.layers.Conv2D(16, (5,5), activation='relu'),  tf.keras.layers.MaxPooling2D((2,2), strides=2),  tf.keras.layers.Conv2D(120, (5,5), activation='relu'),  tf.keras.layers.Flatten(),  tf.keras.layers.Dense(84, activation='relu'),  tf.keras.layers.Dropout(0.3),  tf.keras.layers.Dense(num_classes, activation='softmax')  # 确保输出层的神经元数量与类别数量匹配  
])  # 编译模型  
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),# 使用sparse categorical crossentropy损失函数   optimizer=tf.keras.optimizers.Adam(),  # 使用Adam优化器  metrics=['sparse_categorical_accuracy'])  # 监控准确率  # 打印模型概述  
model.summary()  # 使用model.fit()函数训练模型  
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)  

 

1.4.2模型译

编译参数参考:

# 优化器optimizer='adam'# 损失函数loss='sparse_categorical_crossentropy'# 评估指标metrics=['sparse_categorical_accuracy']

1.5模型训练

1.5.1划分训练集和测试集

按照训练集:测试集=8:2的比例对数据集进行划分,建议使用sklearn库中的train_test_split函数。

1.5.2训练

使用fit函数对训练集进行拟合训练,并将训练过程中产生的历史数据history保存至变量中。

训练参数参考:

# 迭代次数epochs=20# 验证集比例validation_split=0.2

1.5.3训练过程可视化

对history中保存下来的训练过程中的loss和sparse_categorical_accuracy的变化情况进行绘图。

参考代码:

# 获取训练和验证的准确率和损失  
acc = history.history['sparse_categorical_accuracy']  
val_acc = history.history['sparse_categorical_accuracy']  
loss = history.history['loss']  
val_loss = history.history['val_loss']  # 使用model.evaluate()函数评估模型在测试集上的性能  
test_loss, test_accuracy = model.evaluate(x_test, y_test)  
print(f'Test accuracy: {test_accuracy}')  # 使用model.predict()函数对新的图像进行预测。
plt.figure(figsize=(15,10))
plt.plot(history.epoch, history.history['loss'],label='loss')
plt.plot(history.epoch, history.history['val_loss'],label='var_loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')plt.figure(figsize=(15,10))
plt.plot(history.epoch,history.history['sparse_categorical_accuracy'],label='sparse_categorical_accuracy')
plt.plot(history.epoch,history.history['val_sparse_categorical_accuracy'],label='val_sparse_categorical_accuracy')
plt.xlabel('Epoch')
plt.ylabel('sparse_categorical_accuracy')
plt.legend(loc='upper right')
plt.show()plt.rcParams['font.sans-serif'] = ['SimHei'] 
y_pred = np.argmax(model.predict(x_test),axis=1)
cm = confusion_matrix(y_test, y_pred,labels=[0,1,2,3,4])
sns.heatmap(cm,annot=True,cmap="Blues",cbar=False,linewidths=2,linecolor='white',square=True,xticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'],yticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'])
plt.show

 

 

1.6.3保存模型

使用save函数对训练好的模型进行保存,方便后续使用。

参考代码:

model.save('roch_classification_cnn.h5')

1.7图形用户界面(GUI)开发

1.7.1配置开发工具

在PyCharm中配置QtDesigner和PyUIC工具。

注意:需提前在python环境中安装好PyQt5和PyQt5-tools库。

  1. 配置QtDesigner

Program:(对应designer.exe的路径)

Working directory: $FileDir$

  1. 配置PyUCI

Program:(对应pyuic5.exe的路径)

Arguments: $FileName$ -o $FileNameWithoutExtension$.py

Working directory: $FileDir$

配置完成后的界面:

1.7.2设计图形用户界面

在PyCharm中“Tools”—“External Tools”中打开QtDesigner

在QtDesigner主界面中选择创建Main Window,然后根据需求选择相应的控件进行设计。

设计界面参考:

设计好之后保存为.ui文件。

1.7.3 ui文件转换为代码

在PyCharm中右键点击.ui文件并使用PyUCI工具进行转换。

1.7.4代码与模型结合

将转化后的代码与之前训练的模型相结合。

参考代码:

# -*- coding: utf-8 -*-
import osfrom PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
import tensorflow as tf
from PIL import Image
import numpy as np
import sys
class Ui_MainWindow(object):def setupUi(self, MainWindow):MainWindow.setObjectName("MainWindow")MainWindow.resize(800, 600)self.centralwidget = QtWidgets.QWidget(MainWindow)self.centralwidget.setObjectName("centralwidget")self.label = QtWidgets.QLabel(self.centralwidget)self.label.setGeometry(QtCore.QRect(220, 20, 291, 61))self.label.setScaledContents(False)self.label.setObjectName("label")self.pushButton = QtWidgets.QPushButton(self.centralwidget)self.pushButton.setGeometry(QtCore.QRect(160, 430, 93, 28))self.pushButton.setObjectName("pushButton")self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)self.pushButton_2.setGeometry(QtCore.QRect(440, 430, 93, 28))self.pushButton_2.setObjectName("pushButton_2")self.label_2 = QtWidgets.QLabel(self.centralwidget)self.label_2.setGeometry(QtCore.QRect(150, 90, 381, 321))self.label_2.setText("")self.label_2.setObjectName("label_2")self.label_3 = QtWidgets.QLabel(self.centralwidget)self.label_3.setGeometry(QtCore.QRect(550, 130, 141, 51))self.label_3.setText("")self.label_3.setObjectName("label_3")self.label_4 = QtWidgets.QLabel(self.centralwidget)self.label_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.label_4.setObjectName("label_4")self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser.setGeometry(QtCore.QRect(150, 90, 381, 321))self.textBrowser.setObjectName("textBrowser")self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_2.setGeometry(QtCore.QRect(550, 130, 141, 51))self.textBrowser_2.setObjectName("textBrowser_2")self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_3.setGeometry(QtCore.QRect(220, 20, 291, 61))self.textBrowser_3.setObjectName("textBrowser_3")self.textBrowser_4 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.textBrowser_4.setObjectName("textBrowser_4")self.textBrowser_2.raise_()self.label.raise_()self.textBrowser.raise_()self.textBrowser_3.raise_()self.pushButton.raise_()self.pushButton_2.raise_()self.label_2.raise_()self.label_4.raise_()self.textBrowser_4.raise_()self.label_3.raise_()MainWindow.setCentralWidget(self.centralwidget)self.menubar = QtWidgets.QMenuBar(MainWindow)self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26))self.menubar.setObjectName("menubar")MainWindow.setMenuBar(self.menubar)self.statusbar = QtWidgets.QStatusBar(MainWindow)self.statusbar.setObjectName("statusbar")MainWindow.setStatusBar(self.statusbar)self.toolBar = QtWidgets.QToolBar(MainWindow)self.toolBar.setObjectName("toolBar")MainWindow.addToolBar(QtCore.Qt.TopToolBarArea, self.toolBar)self.retranslateUi(MainWindow)QtCore.QMetaObject.connectSlotsByName(MainWindow)# 模型相关变量初始化self.model = tf.keras.models.load_model(r'C:\Users\zjl15\PycharmProjects\pythonProject1\roch_classification_cnn.h5')self.path = ''self.rock_types = ['砾岩','安山岩','花岗岩','石灰岩','石英岩']# 将“导入图片”按钮与openImage函数绑定self.pushButton.clicked.connect(self.openImage)# 将“岩石分类”按钮与classify函数绑定self.pushButton_2.clicked.connect(self.classify)def retranslateUi(self, MainWindow):_translate = QtCore.QCoreApplication.translateMainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))self.label.setText(_translate("MainWindow", "岩石图像分类"))self.pushButton.setText(_translate("MainWindow", "导入图像"))self.pushButton_2.setText(_translate("MainWindow", "岩石分类"))self.label_4.setText(_translate("MainWindow", "分类结果"))self.textBrowser_3.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:24pt;\">岩石图像识别</span></p></body></html>"))self.textBrowser_4.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:11pt;\">分类结果</span></p></body></html>"))self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))# 导入图片函数def resource_path(relative):if hasattr(sys, "_MEIPASS"):absolute_path = os.path.join(sys._MEIPASS, relative)else:absolute_path = os.path.join(relative)return absolute_path# 在原来引用该文件的地方加上这个函数 (resource_path("文件名"))def openImage(self):imgPath, imgType = QFileDialog.getOpenFileName(None, "导入图片", "", "*.jpg;;*.png;;All Files(*)")jpg = QtGui.QPixmap(imgPath).scaled(self.label_2.width(), self.label_2.height())self.label_2.setPixmap(jpg)self.path=imgPathself.label_3.setText('')def classify(self):img = Image.open(self.path)  # 读取图像img_rgb = img.split()x = np.zeros((1, 224, 224, 3))x[0,:, :, 0] = np.array(img_rgb[0]) / 255x[0,:, :, 1] = np.array(img_rgb[1]) / 255x[0,:, :, 2] = np.array(img_rgb[2]) / 255y = self.model.predict(x)result = self.rock_types[np.argmax(y)]self.label_3.setText(result)
if __name__=='__main__':QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)app=QtWidgets.QApplication(sys.argv)MainWindow=QtWidgets.QMainWindow()ui_test=Ui_MainWindow()ui_test.setupUi(MainWindow)MainWindow.show()sys.exit(app.exec_())

1.7.5测试

执行程序测试“导入图片”和“鉴定分类”功能。

1.8打包可执行文件(exe)

在命令窗口中使用如下指令对上一步的程序进行打包。

Pyinstaller -F -w xxxxx.py

运行生成的.exe文件并测试功能。

打完包之后可能出现错误

报错信息:

=============================================================

A RecursionError (maximum recursion depth exceeded) occurred.

For working around please follow these instructions

=============================================================

1. In your program's .spec file add this line near the top::

     import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)

2. Build your program by running PyInstaller with the .spec file as

   argument::

     pyinstaller myprog.spec

3. If this fails, you most probably hit an endless recursion in

   PyInstaller. Please try to track this down has far as possible,

   create a minimal example so we can reproduce and open an issue at

   https://github.com/pyinstaller/pyinstaller/issues following the

   instructions in the issue template. Many thanks.

Explanation: Python's stack-limit is a safety-belt against endless recursion,

eating up memory. PyInstaller imports modules recursively. If the structure

how modules are imported within your program is awkward, this leads to the

nesting being too deep and hitting Python's stack-limit.

With the default recursion limit (1000), the recursion error occurs at about

115 nested imported, with limit 2000 at about 240, with limit 5000 at about

660.

————————————————

你打包目录下会生成如下文件

打开你的main.spec文件

在顶端添加代码:

import sys

sys.setrecursionlimit(sys.getrecursionlimit() * 5)

然后在运行命令(对应的文件名)

pyinstaller 你的文件名.spec

然后就完成了

打完包之的运行闪退问题:

先安装一个新的第三方库ordereddict

安装命令:

pip install ordereddict

注意自己python代码的文件引入路径(确保对应的路径下有对应的文件,我这里设置的是根目录下)

重新打包

完成之后

打开对应的文件夹双击就可以了

完整代码:

import cv2
import os
import glob
# 数据增强函数
def augment_data(img, save_path):rows, cols, _ = img.shape# 水平翻转图像img_flip = cv2.flip(img, 1)img_name = os.path.splitext(save_path)[0] + "_flip.jpg"cv2.imwrite(img_name, img_flip)print("Saved augmented image:", img_name)# 随机缩放图像scale = np.random.uniform(0.9, 1.1)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 0, scale)img_transformed = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_transform.jpg"cv2.imwrite(img_name, img_transformed)print("Saved augmented image:", img_name)# 随机旋转图像angle = np.random.randint(-10, 10)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)img_rotated = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_rotated.jpg"cv2.imwrite(img_name, img_rotated)print("Saved augmented image:", img_name)# 添加高斯噪音mean = 0std = np.random.uniform(5, 15)noise = np.zeros(img.shape, np.float32)cv2.randn(noise, mean, std)noise = np.uint8(noise)img_noisy = cv2.add(img, noise)img_name = os.path.splitext(save_path)[0] + "_noisy.jpg"cv2.imwrite(img_name, img_noisy)print("Saved augmented image:", img_name)# 随机调整对比度和亮度alpha = np.random.uniform(0.8, 1.2)beta = np.random.randint(-10, 10)img_contrast = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)img_name = os.path.splitext(save_path)[0] + "_contrast.jpg"cv2.imwrite(img_name, img_contrast)print("Saved augmented image:", img_name)return img
# 读取 data 文件夹中的所有图片,并进行数据增强
data_dir = r"images"
save_dir = r"images2"
if not os.path.exists(save_dir):os.makedirs(save_dir)
# 使用 glob 库来遍历 data 文件夹中所有图像
for img_path in glob.glob(os.path.join(data_dir, "*.jpg")):img = cv2.imread(img_path)if img is None:print("Error: Unable to read image at", img_path)continue# 获取保存增强后的图片文件名img_name = os.path.basename(img_path)save_path = os.path.join(save_dir, img_name)# 数据增强augmented_img = augment_data(img, save_path)if augmented_img is not None:# 保存原始图片cv2.imwrite(save_path, img)print("Saved original image:", save_path)
#%%
import os
import cv2
import numpy as np
from PIL import Image
# 设置图片文件夹路径
image_folder = r"images2"
# 获取所有类别的文件夹名(假设每个文件夹是一个类别)
categories = os.listdir(image_folder)# 初始化特征变量 X 和标签 Y 的列表
X_list = np.zeros((len(categories), 224, 224, 3))
Y_list = np.zeros((len(categories)))i=0
for name in categories:img = Image.open(image_folder + '\\' +name)img_rgb = img.split()X_list[i,:,:,0] = np.array(img_rgb[0])/255X_list[i,:,:,1] = np.array(img_rgb[1])/255X_list[i,:,:,2] = np.array(img_rgb[2])/255Y_list[i] = name.split('_')[0]i+=1
# 将特征变量 X 和标签 Y 的列表转化为 NumPy 数组
X = np.array(X_list)
Y = np.array(Y_list)# 打印特征变量 X 和标签 Y 的形状
print('特征变量 X 的形状:', X)
print('标签 Y 的形状:', Y)
#%%
from sklearn.model_selection import train_test_split
import seaborn as sns  
import matplotlib.pyplot as plt  
import tensorflow as tf
from sklearn.metrics import confusion_matrix  
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 5个类别  
num_classes = 5  
# 输入图像的大小是224x224,有3个颜色通道(对于彩色图像)  
input_shape = (224, 224, 3)  
# 假设X和Y是您的原始数据  
# X: 图像数据,形状为(num_samples, 224, 224, 3)  
# Y: 标签数据,形状为(num_samples,) 并且是整数形式的标签(从0到4)  
# 将数据划分为训练集和测试集(只执行一次)  
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)  
# 构建模型  
model = tf.keras.models.Sequential([  tf.keras.layers.Conv2D(6, (5, 5), strides=(1,1), activation='relu', input_shape=input_shape),  tf.keras.layers.MaxPooling2D((2,2), strides=2),  tf.keras.layers.Conv2D(16, (5,5), activation='relu'),  tf.keras.layers.MaxPooling2D((2,2), strides=2),  tf.keras.layers.Conv2D(120, (5,5), activation='relu'),  tf.keras.layers.Flatten(),  tf.keras.layers.Dense(84, activation='relu'),  tf.keras.layers.Dropout(0.3),  tf.keras.layers.Dense(num_classes, activation='softmax')  # 确保输出层的神经元数量与类别数量匹配  
])  # 编译模型  
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),# 使用sparse categorical crossentropy损失函数   optimizer=tf.keras.optimizers.Adam(),  # 使用Adam优化器  metrics=['sparse_categorical_accuracy'])  # 监控准确率  # 打印模型概述  
model.summary()  # 使用model.fit()函数训练模型  
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)  #%%
y_pred = model.predict(x_test) 
print(y_pred)
#%%#%%
# 获取训练和验证的准确率和损失  
acc = history.history['sparse_categorical_accuracy']  
val_acc = history.history['sparse_categorical_accuracy']  
loss = history.history['loss']  
val_loss = history.history['val_loss']  # 使用model.evaluate()函数评估模型在测试集上的性能  
test_loss, test_accuracy = model.evaluate(x_test, y_test)  
print(f'Test accuracy: {test_accuracy}')  # 使用model.predict()函数对新的图像进行预测。
plt.figure(figsize=(15,10))
plt.plot(history.epoch, history.history['loss'],label='loss')
plt.plot(history.epoch, history.history['val_loss'],label='var_loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')plt.figure(figsize=(15,10))
plt.plot(history.epoch,history.history['sparse_categorical_accuracy'],label='sparse_categorical_accuracy')
plt.plot(history.epoch,history.history['val_sparse_categorical_accuracy'],label='val_sparse_categorical_accuracy')
plt.xlabel('Epoch')
plt.ylabel('sparse_categorical_accuracy')
plt.legend(loc='upper right')
plt.show()plt.rcParams['font.sans-serif'] = ['SimHei'] 
y_pred = np.argmax(model.predict(x_test),axis=1)
cm = confusion_matrix(y_test, y_pred,labels=[0,1,2,3,4])
sns.heatmap(cm,annot=True,cmap="Blues",cbar=False,linewidths=2,linecolor='white',square=True,xticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'],yticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'])
plt.show
#%%
model.save('roch_classification_cnn.h5')

# -*- coding: utf-8 -*-
import osfrom PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
import tensorflow as tf
from PIL import Image
import numpy as np
import sys
class Ui_MainWindow(object):def setupUi(self, MainWindow):MainWindow.setObjectName("MainWindow")MainWindow.resize(800, 600)self.centralwidget = QtWidgets.QWidget(MainWindow)self.centralwidget.setObjectName("centralwidget")self.label = QtWidgets.QLabel(self.centralwidget)self.label.setGeometry(QtCore.QRect(220, 20, 291, 61))self.label.setScaledContents(False)self.label.setObjectName("label")self.pushButton = QtWidgets.QPushButton(self.centralwidget)self.pushButton.setGeometry(QtCore.QRect(160, 430, 93, 28))self.pushButton.setObjectName("pushButton")self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)self.pushButton_2.setGeometry(QtCore.QRect(440, 430, 93, 28))self.pushButton_2.setObjectName("pushButton_2")self.label_2 = QtWidgets.QLabel(self.centralwidget)self.label_2.setGeometry(QtCore.QRect(150, 90, 381, 321))self.label_2.setText("")self.label_2.setObjectName("label_2")self.label_3 = QtWidgets.QLabel(self.centralwidget)self.label_3.setGeometry(QtCore.QRect(550, 130, 141, 51))self.label_3.setText("")self.label_3.setObjectName("label_3")self.label_4 = QtWidgets.QLabel(self.centralwidget)self.label_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.label_4.setObjectName("label_4")self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser.setGeometry(QtCore.QRect(150, 90, 381, 321))self.textBrowser.setObjectName("textBrowser")self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_2.setGeometry(QtCore.QRect(550, 130, 141, 51))self.textBrowser_2.setObjectName("textBrowser_2")self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_3.setGeometry(QtCore.QRect(220, 20, 291, 61))self.textBrowser_3.setObjectName("textBrowser_3")self.textBrowser_4 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.textBrowser_4.setObjectName("textBrowser_4")self.textBrowser_2.raise_()self.label.raise_()self.textBrowser.raise_()self.textBrowser_3.raise_()self.pushButton.raise_()self.pushButton_2.raise_()self.label_2.raise_()self.label_4.raise_()self.textBrowser_4.raise_()self.label_3.raise_()MainWindow.setCentralWidget(self.centralwidget)self.menubar = QtWidgets.QMenuBar(MainWindow)self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26))self.menubar.setObjectName("menubar")MainWindow.setMenuBar(self.menubar)self.statusbar = QtWidgets.QStatusBar(MainWindow)self.statusbar.setObjectName("statusbar")MainWindow.setStatusBar(self.statusbar)self.toolBar = QtWidgets.QToolBar(MainWindow)self.toolBar.setObjectName("toolBar")MainWindow.addToolBar(QtCore.Qt.TopToolBarArea, self.toolBar)self.retranslateUi(MainWindow)QtCore.QMetaObject.connectSlotsByName(MainWindow)# 模型相关变量初始化self.model = tf.keras.models.load_model(r'C:\Users\zjl15\PycharmProjects\pythonProject1\roch_classification_cnn.h5')self.path = ''self.rock_types = ['砾岩','安山岩','花岗岩','石灰岩','石英岩']# 将“导入图片”按钮与openImage函数绑定self.pushButton.clicked.connect(self.openImage)# 将“岩石分类”按钮与classify函数绑定self.pushButton_2.clicked.connect(self.classify)def retranslateUi(self, MainWindow):_translate = QtCore.QCoreApplication.translateMainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))self.label.setText(_translate("MainWindow", "岩石图像分类"))self.pushButton.setText(_translate("MainWindow", "导入图像"))self.pushButton_2.setText(_translate("MainWindow", "岩石分类"))self.label_4.setText(_translate("MainWindow", "分类结果"))self.textBrowser_3.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:24pt;\">岩石图像识别</span></p></body></html>"))self.textBrowser_4.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:11pt;\">分类结果</span></p></body></html>"))self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))# 导入图片函数def resource_path(relative):if hasattr(sys, "_MEIPASS"):absolute_path = os.path.join(sys._MEIPASS, relative)else:absolute_path = os.path.join(relative)return absolute_path# 在原来引用该文件的地方加上这个函数 (resource_path("文件名"))def openImage(self):imgPath, imgType = QFileDialog.getOpenFileName(None, "导入图片", "", "*.jpg;;*.png;;All Files(*)")jpg = QtGui.QPixmap(imgPath).scaled(self.label_2.width(), self.label_2.height())self.label_2.setPixmap(jpg)self.path=imgPathself.label_3.setText('')def classify(self):img = Image.open(self.path)  # 读取图像img_rgb = img.split()x = np.zeros((1, 224, 224, 3))x[0,:, :, 0] = np.array(img_rgb[0]) / 255x[0,:, :, 1] = np.array(img_rgb[1]) / 255x[0,:, :, 2] = np.array(img_rgb[2]) / 255y = self.model.predict(x)result = self.rock_types[np.argmax(y)]self.label_3.setText(result)
if __name__=='__main__':QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)app=QtWidgets.QApplication(sys.argv)MainWindow=QtWidgets.QMainWindow()ui_test=Ui_MainWindow()ui_test.setupUi(MainWindow)MainWindow.show()sys.exit(app.exec_())

相关文章:

python实现——分类类型数据挖掘任务(图形识别分类任务)

分类类型数据挖掘任务 基于卷积神经网络&#xff08;CNN&#xff09;的岩石图像分类。有一岩石图片数据集&#xff0c;共300张岩石图片&#xff0c;图片尺寸224x224。岩石种类有砾岩&#xff08;Conglomerate&#xff09;、安山岩&#xff08;Andesite&#xff09;、花岗岩&am…...

【安卓跨进程通信IPC】-- Binder

目录 BinderBinder是什么&#xff1f;进程空间分配进程隔离Binder跨进程通信机制模型优点AIDL常见面试题 Binder 夯实基础之超详解Android Binder的工作方式与原理以及aidl示例代码 比较详细的介绍&#xff1a;Android跨进程通信&#xff1a;图文详解 Binder机制 原理 操作系统…...

大数据之Schedule调度错误(一)

当我们在利用ooize发起整个任务的调度过程中,如果多个调度任务同时运行并且多个调度任务操作了相同的表,那么就会出现如下的错误关系: Invalid path hdfs://iZh5w01l7f8lnog055cpXXX:8000/user/admin/xxx: No files matching path hdfs://iZh5w01l7f8lnog055cpXXX:8000/user/ad…...

DiffIR论文阅读笔记

ICCV2023的一篇用diffusion模型做Image Restoration的论文&#xff0c;一作是清华的教授&#xff0c;还在NIPS2023上一作发表了Hierarchical Integration Diffusion Model for Realistic Image Deblurring&#xff0c;作者里甚至有Luc Van Gool大佬。模型分三个部分&#xff0c…...

prometheus+alertmanager+webhook钉钉机器人告警

版本&#xff1a;centos7.9 python3.9.5 alertmanager0.25.0 prometheus2.46.0 安装alertmanager prometheus 配置webhook # 解压&#xff1a; tar -xvf alertmanager-0.25.0.linux-amd64.tar.gz tar -xvf prometheus-2.46.0.linux-amd64.tar.gz mv alertmanager-0.25.0.linu…...

ctfshow 年CTF web

除夕 Notice: Undefined index: year in /var/www/html/index.php on line 16 <?phpinclude "flag.php";$year $_GET[year];if($year2022 && $year1!2023){echo $flag; }else{highlight_file(__FILE__); } 弱比较绕过很简单&#xff0c;连函数都没有直…...

原型链、闭包、手写一个闭包函数、 闭包有哪些优缺点、原型链继承

什么是原型链&#xff1f; 原型链是一种查找规则 为对象成员查找机制提供一个方向 因为构造函数的 prototype 和其实例的 __ proto __ 都是指向原型对象的 所以可以通过__proto__ 查找当前的原型对象有没有该属性, 没有就找原型的原型, 依次类推一直找到Object( null ) 为…...

linux中SSH_ASKPASS全局变量的作用

在工作中遇到一段代码&#xff0c;通过SSH_ASKPASS全局变量实现了ssh登录远程IP时的密码输入&#xff0c;chatgpt搜索了一下&#xff0c;其解释大致如下所示&#xff1a; SSH_ASKPASS 是一个环境变量&#xff0c;它在 SSH 客户端需要用户输入密码时起作用。当 SSH 客户端检测到…...

9 -力扣高频 SQL 50 题(基础版)

9 - 上升的温度 -- 找出与之前&#xff08;昨天的&#xff09;日期相比温度更高的所有日期的 id -- DATEDIFF(2007-12-31,2007-12-30); # 1 -- DATEDIFF(2010-12-30,2010-12-31); # -1select w1.id from Weather w1, Weather w2 wheredatediff(w1.recordDate,w2.recordDat…...

TCP的重传机制

TCP 是一个可靠的传输协议&#xff0c;解决了IP层的丢包、乱序、重复等问题。这其中&#xff0c;TCP的重传机制起到重要的作用。 序列号和确认号 之前我们在讲解TCP三次握手时&#xff0c;提到过TCP包头结构&#xff0c;其中有序列号和确认号&#xff0c; 而TCP 实现可靠传输…...

pg 数据库,获取时间字段值的具体小时,赋值给其他字段

目录 1 问题2 实现 1 问题 pg 数据库&#xff0c;有一个表&#xff0c;其中有2个字段 一个是时间字段obstime &#xff0c;一个是时次ltime字段&#xff0c;int 类型&#xff0c;现在这个表里面是obstime 里面有数据&#xff0c;ltime字段 没有数据&#xff0c;现在就是批量获…...

做视频号小店什么类目最容易爆单?其实,弄懂这三点就会选品了

大家好&#xff0c;我是电商花花。 我们做视频号小店做什么类目最容易爆单&#xff1f; 其实任何类目都有属于自己的受众人群和客户&#xff0c;都非常容易爆单&#xff0c;我们想要爆单&#xff0c;就要选对类目&#xff0c;选对产品。 视频号上所有的类目基本上可以分为标…...

Nginx作为下载站点

grep -Ev ^$|# /usr/local/nginx/conf/nginx.conf > /opt/nginx.txt cat /opt/nginx.txt > /usr/local/nginx/conf/nginx.conf用上面的指令提取最小化的配置文件 vim /usr/local/nginx/conf/nginx.conf [rootlocalhost ~]# cat /usr/local/nginx/conf/nginx.conf worker…...

vue3简单快速实现主题切换功能

⛰️个人主页: 蒾酒 &#x1f525;系列专栏&#xff1a;《vue3实战》 目录 内容概要 实现步骤 1.定义不同主题的css样式变量 2.入口main.ts中引入这个样式文件 3.主题样式css变量引用 4.设置默认主题样式 5.实现点击按钮主题切换 总结 最近发现了一个巨牛的人工智…...

国联易安:网络反不正当竞争,要防患于未然

据市场监管总局官网消息&#xff0c;为预防和制止网络不正当竞争&#xff0c;维护公平竞争的市场秩序&#xff0c;鼓励创新&#xff0c;保护经营者和消费者的合法权益&#xff0c;促进数字经济规范健康持续发展&#xff0c;市场监管总局近日发布《网络反不正当竞争暂行规定》&a…...

Linux 网络配置 01

基本命令 1、查看网络接口信息ifconfig ifconfig&#xff1a;当前设备正在工作的网卡&#xff0c;启动的设备 ifconfig -a &#xff1a;所网络设备 ifconfig信息解析&#xff1a; ens33: flags4163<UP,BROADCAST,RUNNING,MULTICAST> mtu 1500inet 192.168.10.10 n…...

快速入门C++正则表达式

正则表达式&#xff08;Regular Expression&#xff0c;简称 Regex&#xff09;是一种强大的文本处理工具&#xff0c;广泛用于字符串的搜索、替换、分析等操作。它基于一种表达式语言&#xff0c;使用单个字符串来描述、匹配一系列符合某个句法规则的字符串。正则表达式不仅在…...

java —— 缓冲字符输入流/缓冲字符输出流

缓冲字符输入流/缓冲字符输出流是对字符输入流/字符输出流的加强&#xff0c;在使用中仍旧要借助于字符输入流/字符输出流才能完成实现。与字符输入流/字符输出流按照字符为单位进行输入/输出不同的是&#xff0c;缓冲字符输入流/缓冲字符输出流能够以行为单位进行读取和写入。…...

blender从视频中动作捕捉,绑定到人物模型

总共分为3个步骤&#xff1a; 1、从视频中捕捉动作模型 小K动画网-AIGC视频动捕平台 地址&#xff1a;https://xk.yunbovtb.com/ 需要注册 生成的FBX文件&#xff0c;不能直接导入到blender中&#xff0c; 方法有2种&#xff1a; 第一种&#xff1a;需要转换一下&#x…...

掘金滑块验证码安全升级,继续破解

去年发过一篇文章&#xff0c;《使用前端技术破解掘金滑块验证码》&#xff0c;我很佩服掘金官方的气度&#xff0c;不但允许我发布这篇文章&#xff0c;还同步发到了官方公众号。最近发现掘金的滑块验证码升级了&#xff0c;也许是我那篇文章起到了一些作用&#xff0c;逼迫官…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

golang循环变量捕获问题​​

在 Go 语言中&#xff0c;当在循环中启动协程&#xff08;goroutine&#xff09;时&#xff0c;如果在协程闭包中直接引用循环变量&#xff0c;可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下&#xff1a; 问题背景 看这个代码片段&#xff1a; fo…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…...

C++中string流知识详解和示例

一、概览与类体系 C 提供三种基于内存字符串的流&#xff0c;定义在 <sstream> 中&#xff1a; std::istringstream&#xff1a;输入流&#xff0c;从已有字符串中读取并解析。std::ostringstream&#xff1a;输出流&#xff0c;向内部缓冲区写入内容&#xff0c;最终取…...

中医有效性探讨

文章目录 西医是如何发展到以生物化学为药理基础的现代医学&#xff1f;传统医学奠基期&#xff08;远古 - 17 世纪&#xff09;近代医学转型期&#xff08;17 世纪 - 19 世纪末&#xff09;​现代医学成熟期&#xff08;20世纪至今&#xff09; 中医的源远流长和一脉相承远古至…...

【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的“no matching...“系列算法协商失败问题

【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的"no matching..."系列算法协商失败问题 摘要&#xff1a; 近期&#xff0c;在使用较新版本的OpenSSH客户端连接老旧SSH服务器时&#xff0c;会遇到 "no matching key exchange method found"​, "n…...

【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论

路径问题的革命性重构&#xff1a;基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中&#xff08;图1&#xff09;&#xff1a; mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...

iview框架主题色的应用

1.下载 less要使用3.0.0以下的版本 npm install less2.7.3 npm install less-loader4.0.52./src/config/theme.js文件 module.exports {yellow: {theme-color: #FDCE04},blue: {theme-color: #547CE7} }在sass中使用theme配置的颜色主题&#xff0c;无需引入&#xff0c;直接可…...

4. TypeScript 类型推断与类型组合

一、类型推断 (一) 什么是类型推断 TypeScript 的类型推断会根据变量、函数返回值、对象和数组的赋值和使用方式&#xff0c;自动确定它们的类型。 这一特性减少了显式类型注解的需要&#xff0c;在保持类型安全的同时简化了代码。通过分析上下文和初始值&#xff0c;TypeSc…...