利用InceptionV3实现图像分类
最近在做一个机审的项目,初步希望实现图像的四分类,即:正常(neutral)、涉政(political)、涉黄(porn)、涉恐(terrorism)。有朋友给推荐了个github上面的文章,浏览量还挺大的。地址如下:
https://github.com/xqtbox/generalImageClassification
我导入试了一下,发现博主没有放他训练的模型文件my_model.h5,所以代码trainMyDataWithKerasModel.py不能直接运行。必须先自己训练个模型才行,所以只好自己搞了。我开发电脑上安装的python版本是3.9.12,这个版本通常会遇到兼容性的问题,所以我决定先搭建个虚拟环境来测试一下。虚拟环境就用3.7.16了。
1、执行:conda create -n InceptionV3 python=3.7

在C:\Users\用户名\anaconda3\envs目录下创建虚拟环境InceptionV3目录。
2、执行:conda activate InceptionV3

启动InceptionV3虚拟环境。
3、执行:pip install -i https://pypi.douban.com/simple/ tensorflow==1.14.0

我的显卡是Nvidia GeForce RTX 3060的,CUDA是11.8,Cudnn是8.7.0,查了一下对应的。查了一下对应tensorflow版本是1.14.0,所以就安装这个。
4、执行:pip install -i https://pypi.douban.com/simple/ protobuf==3.19.0

5、执行:pip install -i https://pypi.douban.com/simple/ tensorflow_hub==0.9.0

6、执行:pip install -i https://pypi.douban.com/simple/ opencv-python

7、执行:pip install -i https://pypi.douban.com/simple/ scikit-learn

8、执行:pip install -i https://pypi.douban.com/simple/ albumentations==1.2.0

9、执行:pip install -i https://pypi.douban.com/simple/ h5py==2.10.0

10、执行:pip install -i https://pypi.douban.com/simple/ matplotlib

11、执行:pip install -i https://pypi.douban.com/simple/ Tensorflow-gpu==2.4.0

12、执行:pip install -i https://pypi.douban.com/simple/ keras==2.6.0

13、下面是训练代码,文件名是train1.py
import numpy as np
from tensorflow.keras.optimizers import Adamimport cv2
from tensorflow.keras.preprocessing.image import img_to_array
from sklearn.model_selection import train_test_splitfrom tensorflow.python.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.applications import InceptionV3
import os
import tensorflow as tffrom tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.models import Sequentialimport albumentations
norm_size = 224
datapath = 'data/train'
EPOCHS = 20
INIT_LR = 3e-4
labelList = []# 这里是分类详情
dicClass = {'neutral':0, 'political':1, 'porn':2, 'terrorism':3}
# 这是分类个数
classnum = 4batch_size = 2
np.random.seed(42)# tf.config.list_physical_devices('GPU')
# tf.test.is_gpu_available()def loadImageData():imageList = []listClasses = os.listdir(datapath) # 类别文件夹print(listClasses)for class_name in listClasses:label_id = dicClass[class_name]class_path = os.path.join(datapath, class_name)image_names = os.listdir(class_path)for image_name in image_names:image_full_path = os.path.join(class_path, image_name)labelList.append(label_id)imageList.append(image_full_path)return imageListprint("开始加载数据")
imageArr = loadImageData()
labelList = np.array(labelList)
print("加载数据完成")
print(labelList)
trainX, valX, trainY, valY = train_test_split(imageArr, labelList, test_size=0.3, random_state=42)train_transform = albumentations.Compose([albumentations.OneOf([albumentations.RandomGamma(gamma_limit=(60, 120), p=0.9),albumentations.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.9),albumentations.CLAHE(clip_limit=4.0, tile_grid_size=(4, 4), p=0.9),]),albumentations.HorizontalFlip(p=0.5),albumentations.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=20,interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, p=1),albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)])
val_transform = albumentations.Compose([albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)])def generator(file_pathList, labels, batch_size, train_action=False):L = len(file_pathList)while True:input_labels = []input_samples = []for row in range(0, batch_size):temp = np.random.randint(0, L)X = file_pathList[temp]Y = labels[temp]image = cv2.imdecode(np.fromfile(X, dtype=np.uint8), -1)if image.shape[2] > 3:image = image[:,:,:3]if train_action:image = train_transform(image=image)['image']else:image = val_transform(image=image)['image']image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)image = img_to_array(image)input_samples.append(image)input_labels.append(Y)batch_x = np.asarray(input_samples)batch_y = np.asarray(input_labels)yield (batch_x, batch_y)checkpointer = ModelCheckpoint(filepath='best_model.hdf5',monitor='val_acc', verbose=1, save_best_only=True, mode='max')reduce = ReduceLROnPlateau(monitor='val_acc', patience=10,verbose=1,factor=0.5,min_lr=1e-6)model = Sequential()
model.add(InceptionV3(include_top=False, pooling='avg', weights='imagenet'))
model.add(Dense(classnum, activation='softmax'))optimizer = Adam(learning_rate=INIT_LR)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['acc'])# print('trainX = ' + str(trainX))
# print('trainY = ' + str(trainY))model.add(tf.keras.layers.BatchNormalization())history = model.fit(generator(trainX, trainY, batch_size, train_action=True),steps_per_epoch=len(trainX) / batch_size,validation_data=generator(valX, valY, batch_size, train_action=False),epochs=EPOCHS,validation_steps=len(valX) / batch_size,callbacks=[checkpointer, reduce])
model.save('my_model.h5')
print(history)loss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
import matplotlib.pyplot as pltprint("Now,we start drawing the loss and acc trends graph...")
# summarize history for acc
fig = plt.figure(1)
plt.plot(history.history["acc"])
plt.plot(history.history["val_acc"])
plt.title("Model acc")
plt.ylabel("acc")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for loss
fig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")
13.1、norm_size = 224 设置输入图像的大小,InceptionV3默认的图片尺寸是224×224。但是我的图片有300px以上的,好像也没什么问题
13.2、datapath = ‘data/train’ 设置图片存放的路径
13.3、EPOCHS = 20 epochs的数量,关于epoch的设置多少合适,这个问题很纠结,一般情况设置300足够了,如果感觉没有训练好,再载入模型训练。
13.4、INIT_LR = 1e-3 学习率,一般情况从0.001开始逐渐降低,也别太小了到1e-6就可以了。
13.5、classnum = 12 类别数量,数据集有两个类别,所有就分为两类。
13.6、batch_size = 4 batchsize,根据硬件的情况和数据集的大小设置,太小了loss浮动太大,太大了收敛不好,根据经验来,一般设置为2的次方。windows可以通过任务管理器查看显存的占用情况。
14、工程目录的文件如下图:

其中train1.py是训练程序;test.py是检测程序,本文后面会再详细讲怎么用;FormatImages.py是格式化图片的程序,功能就是把从网上爬下来比较大的图片等比压缩成300px以内。
data目录存放的就是训练用的数据,如下图:

其中train存放的是训练图片,test存放的是测试图片。train下的目录如下图:

可以看到,图中的train目录中的文件夹名要与train1.py中dicClass的值对应起来,训练数据放到对应目录下就可以了。如下图:

15、下面开始训练了,在训练之前有几个事情要做一下。
首先检查一下自己的cuda安装好没有,方法是在cmd下面输入命令nvcc -V,如果显示版本号就没问题了,如下图:

如果还没有安装也没关系,先看看自己显卡的cuda版本,如下图:

然后去https://developer.nvidia.com/cuda-toolkit-archive下载显卡对应版本的cuda工具包。如下图:

下载完成后安装到默认目录就行,一般是安装在C:\Program Files\NVIDIA GPU Computing Toolkit,如下图:











安装完成后在到https://developer.nvidia.com/rdp/cudnn-download去下载cudnn

下载完成后解压缩,把解压缩后的目录cudnn-windows-x86_64-8.8.0.121_cuda12-archive下的bin、include、lib三个目录里的文件分别复制到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8的bin、include、lib三个目录里。如下图:


最后到https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#install-zlib-windows下载ZLIB.DLL。如下图:

下载完成后解压缩,把解压后zlib123dllx64\dll_x64\zlibwapi.dll文件复制到C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin目录下



现在,在train1.py目录下执行:python train1.py

可以看一下任务管理器,压力应该都在GPU上:

16、训练完成后,可以看到train1.py目录下多了几个文件,如下图:

其中my_model.h5就是咱们训练出来的模型文件。WW_acc.jpg和WW_loss.jpg是训练结果保存的图,看了一下觉得还不错。


17、接下来要验证一下模型的效果,现在data\test\放一张用于预测的图。如下图:

18、下面是测试代码,文件名是test.py:
import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
import time
import albumentations
norm_size = 224
imagelist = []emotion_labels = {0: 'neutral',1: 'political',2: 'porn',3: 'terrorism',
}val_transform = albumentations.Compose([albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)])
emotion_classifier = load_model("best_model.hdf5")
t1 = time.time()
image = cv2.imdecode(np.fromfile('data/test/01.jpg', dtype=np.uint8), -1)
image = val_transform(image=image)['image']
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype="float")
out = emotion_classifier.predict(imageList)
print(out)
pre = np.argmax(out)
emotion = emotion_labels[pre]
t2 = time.time()
print(emotion)
t3 = t2 - t1
print(t3)
其中emotion_labels是分类,填上与训练文件中一致的内容。
在image = cv2.imdecode(np.fromfile('data/test/01.jpg', dtype=np.uint8), -1)这行修改路径,指向到用于预测的图片位置。
19、执行python test.py

可以看到,data/test/01.jpg被预测成为terrorism,验证正确。至此大功告成。
后记:我是python的领域的新兵,在开发过程中遇到最麻烦的事情就是版本的问题。tensorflow最新版本已经2.11.0了,但是使用起来会有各种问题。我尝试了很多版本,查了不少资料,最后才确定了能用的这个组合。尤其是过程中gpu一直利用不上,程序总是使用cpu在训练,经过一顿折腾总算是能用了,但是为什么这么组合,我也没有找到一个清晰的说明,希望能有大神能给解释一下CUDA、Cudnn、tensorflow、tensorflow-gpu的版本怎么组合最合理。下面把我虚拟环境的配置发上来供大家参考:
Package Version
----------------------- ---------
absl-py 0.15.0
albumentations 1.2.0
astor 0.8.1
astunparse 1.6.3
cachetools 5.3.0
certifi 2022.12.7
charset-normalizer 3.0.1
cycler 0.11.0
flatbuffers 1.12
fonttools 4.38.0
gast 0.3.3
google-auth 2.16.1
google-auth-oauthlib 0.4.6
google-pasta 0.2.0
grpcio 1.32.0
h5py 2.10.0
idna 3.4
imageio 2.25.1
importlib-metadata 6.0.0
joblib 1.2.0
keras 2.6.0
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
kiwisolver 1.4.4
Markdown 3.4.1
MarkupSafe 2.1.2
matplotlib 3.5.3
networkx 2.6.3
numpy 1.19.5
oauthlib 3.2.2
opencv-python 4.7.0.68
opencv-python-headless 4.7.0.68
opt-einsum 3.3.0
packaging 23.0
Pillow 9.4.0
pip 22.3.1
protobuf 3.19.0
pyasn1 0.4.8
pyasn1-modules 0.2.8
pyparsing 3.0.9
python-dateutil 2.8.2
PyWavelets 1.3.0
PyYAML 6.0
qudida 0.0.4
requests 2.28.2
requests-oauthlib 1.3.1
rsa 4.9
scikit-image 0.18.3
scikit-learn 1.0.2
scipy 1.7.3
setuptools 65.6.3
six 1.15.0
tensorboard 2.11.2
tensorboard-data-server 0.6.1
tensorboard-plugin-wit 1.8.1
tensorflow 1.14.0
tensorflow-estimator 2.4.0
tensorflow-gpu 2.4.0
tensorflow-hub 0.9.0
termcolor 1.1.0
threadpoolctl 3.1.0
tifffile 2021.11.2
typing-extensions 3.7.4.3
urllib3 1.26.14
Werkzeug 2.2.3
wheel 0.38.4
wincertstore 0.2
wrapt 1.12.1
zipp 3.14.0
相关文章:

利用InceptionV3实现图像分类
最近在做一个机审的项目,初步希望实现图像的四分类,即:正常(neutral)、涉政(political)、涉黄(porn)、涉恐(terrorism)。有朋友给推荐了个github上…...

【Java】CAS锁
一、什么是CAS机制(compare and swap) 1.概述 CAS的全称为Compare-And-Swap,直译就是对比交换。是一条CPU的原子指令,其作用是让CPU先进行比较两个值是否相等,然后原子地更新某个位置的值。经过调查发现,…...
Linux服务器配置系统安全加固方法
1. SSH空闲超时时间建议为: 600-900 解决方案: 在【/etc/ssh/sshd_config】文件中设置【ClientAliveInterval】设置为600到900之间 vim /etc/ssh/sshd_config #将 ClientAliveInterval 参数值设置为 900 2. 修改检查SSH密码修改最小间隔 解决方案: 在【/etc/login.defs】文件…...

Codeforces Round #850 (Div. 2, based on VK Cup 2022 - Final Round)(A~E)
t宝酱紫喜欢出这种分类讨论的题?!A1. Non-alternating Deck (easy version)给出n张牌,按照题目给的顺序分给两人,问最后两人手中各有几张牌。思路:模拟。AC Code:#include <bits/stdc.h>typedef long…...

qt源码--信号槽
本篇主要从Qt信号槽的连接、断开、调用、对象释放等方面展开; 1.信号建立连接过程 connect有多个重载函数,主要是为了方便使用者,比较常用的有2种方式: a. QObject::connect(&timer, &QTimer::timeout, &loop, &am…...

RecycleView详解
listview缓存请看: listview优化和详解RecycleView 和 ListView对比:使用方法上ListView:继承重写 BaseAdapter,自定义 ViewHolder 与 converView优化。RecyclerView: 继承重写 RecyclerView.Adapter 与 RecyclerView.ViewHolder。设置 Layou…...

【算法】最短路算法
😀大家好,我是白晨,一个不是很能熬夜😫,但是也想日更的人✈。如果喜欢这篇文章,点个赞👍,关注一下👀白晨吧!你的支持就是我最大的动力!Ǵ…...

< Linux > 进程间通信
目录 1、进程间通信介绍 进程间通信的概念 进程间通信的本质 进程间通信的分类 2、管道 2.1、什么是管道 2.2、匿名管道 匿名管道的原理 pipe函数 匿名管道使用步骤 2.3、管道的读写规则 2.4、管道的特点 2.5、命名管道 命名管道的原理 使用命令创建命名管道 mkfifo创建命名管…...

学习 Python 之 Pygame 开发魂斗罗(二)
学习 Python 之 Pygame 开发魂斗罗(二)魂斗罗的需求开始编写魂斗罗1. 搭建主类框架2. 设置游戏运行遍历和创建窗口3. 获取窗口中的事件4. 创建角色5. 完成角色更新函数魂斗罗的需求 魂斗罗游戏中包含很多个物体,现在要对这些物体进行总结 类…...

户籍管理系统测试用例
目录 一、根据页面的不同分别设计测试用例 登录页面 用户信息列表 用户编辑页面 用户更新页面 二、根据目的不同分别设计测试用例 一、根据页面的不同分别设计测试用例 上图是针对一个网站的测试,按照页面的不同分别来设计对应的测试用例。 登录页面 用户信息列…...

(三)代表性物质点邻域的变形分析
本文主要内容如下:1. 伸长张量与Cauchy-Green 张量2. 线元长度的改变2.1. 初始/当前构型下的长度比2.2. 主长度比与 Lagrange/Euler 主方向2.3. 初始/当前构型下任意方向的长度比3. 线元夹角的改变4. 面元的改变5. 体元的改变1. 伸长张量与Cauchy-Green 张量 由于变…...

Stream操作流 练习
基础数据:Data AllArgsConstructor NoArgsConstructor public class User {private String name;private int age;private String sex;private String city;private Integer money; static List<User> users new ArrayList<>();public static void m…...

【模拟集成电路】宽摆幅压控振荡器(VCO)设计
鉴频鉴相器设计(Phase Frequency Detector,PFD)前言一、VCO工作原理二、VCO电路设计VCO原理图三、压控振荡器(VCO)测试VCO测试电路图瞬态测试(1)瞬态输出(2)局部放大图&a…...
《英雄编程体验课》第 13 课 | 双指针
文章目录 零、写在前面一、最长不重复子串1、初步分析2、朴素算法3、优化算法二、双指针1、算法定义2、算法描述3、条件1)单调性2)时效性三、双指针的应用1、前缀和问题2、哈希问题3、K 大数问题零、写在前面 该章节节选自 《夜深人静写算法》,主要讲解最基础的枚举算法 ——…...

DS期末复习卷(十)
一、选择题(24分) 1.下列程序段的时间复杂度为( A )。 i0,s0; while (s<n) {ssi;i;} (A) O(n^1/2) (B) O(n ^1/3) © O(n) (D) O(n ^2) 12…xn xn^1/2 2.设某链表中最常用的…...

QT+OpenGL模板测试和混合
QTOpenGL模板测试和混合 本篇完整工程见gitee:QtOpenGL 对应点的tag,由turbolove提供技术支持,您可以关注博主或者私信博主 模板测试 当片段着色器处理完一个片段之后,模板测试会开始执行。和深度测试一样,它可能会丢弃片段&am…...
《英雄编程体验课》第 11 课 | 前缀和
文章目录 零、写在前面一、概念定义1、部分和2、朴素做法3、前缀和4、前缀和的边界值5、边界处理6、再看部分和二、题目描述1、定义2、求解三、算法详解四、源码剖析五、推荐专栏六、习题练习零、写在前面 该章节节选自 《算法零基础100讲》,主要讲解最基础的算法 —— 前缀和…...
Java学习--多线程2
2.线程同步 2.1卖票【应用】 案例需求 某电影院目前正在上映国产大片,共有100张票,而它有3个窗口卖票,请设计一个程序模拟该电影院卖票 实现步骤 定义一个类SellTicket实现Runnable接口,里面定义一个成员变量:privat…...

【Virtualization】Windows11安装VMware Workstation后异常处置
安装环境 Windows 11 专业版 22H2 build 22621.1265 VMware Workstation 17 Pro 17.0.0 build-20800274 存在问题 原因分析 1、BIOS未开启虚拟化。 2、操作系统启用的虚拟化与Workstation冲突。 3、操作系统启用内核隔离-内存完整性保护。 处置思路 1、打开“资源管理器”…...

第四章.神经网络—BP神经网络
第四章.神经网络 4.3 BP神经网络 BP神经网络(误差反向传播算法)是整个人工神经网络体系中的精华,广泛应用于分类识别,逼近,回归,压缩等领域,在实际应用中,大约80%的神经网络模型都采用BP网络或BP网络的变化…...

铭豹扩展坞 USB转网口 突然无法识别解决方法
当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...

国防科技大学计算机基础课程笔记02信息编码
1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)
HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

大数据零基础学习day1之环境准备和大数据初步理解
学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 (1)设置网关 打开VMware虚拟机,点击编辑…...
Element Plus 表单(el-form)中关于正整数输入的校验规则
目录 1 单个正整数输入1.1 模板1.2 校验规则 2 两个正整数输入(联动)2.1 模板2.2 校验规则2.3 CSS 1 单个正整数输入 1.1 模板 <el-formref"formRef":model"formData":rules"formRules"label-width"150px"…...

dify打造数据可视化图表
一、概述 在日常工作和学习中,我们经常需要和数据打交道。无论是分析报告、项目展示,还是简单的数据洞察,一个清晰直观的图表,往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server,由蚂蚁集团 AntV 团队…...

均衡后的SNRSINR
本文主要摘自参考文献中的前两篇,相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程,其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt 根发送天线, n r n_r nr 根接收天线的 MIMO 系…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”
2025年#高考 将在近日拉开帷幕,#AI 监考一度冲上热搜。当AI深度融入高考,#时间同步 不再是辅助功能,而是决定AI监考系统成败的“生命线”。 AI亮相2025高考,40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕,江西、…...

HashMap中的put方法执行流程(流程图)
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合
在汽车智能化的汹涌浪潮中,车辆不再仅仅是传统的交通工具,而是逐步演变为高度智能的移动终端。这一转变的核心支撑,来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒(T-Box)方案:NXP S32K146 与…...