Keras三种主流模型构建方式:序列模型、函数模型、子类模型开发实践,以真实烟雾识别场景数据为例
Keras和PyTorch是两个常用的深度学习框架,它们都提供了用于构建和训练神经网络的高级API。
Keras:
Keras是一个高级神经网络API,可以在多个底层深度学习框架上运行,如TensorFlow和CNTK。以下是Keras的特点和优点:
优点:
- 简单易用:Keras具有简洁的API设计,易于上手和使用,适合快速原型设计和实验。
- 灵活性:Keras提供了高级API和模块化的架构,可以灵活地构建各种类型的神经网络模型。
- 复用性:Keras模型可以轻松保存和加载,可以方便地共享、部署和迁移模型。
- 社区支持:Keras拥有庞大的社区支持和活跃的开发者社区,提供了大量的文档、教程和示例代码。
缺点:
- 功能限制:相比于底层框架如TensorFlow和PyTorch,Keras在某些高级功能和自定义性方面可能有所限制。
- 可扩展性:虽然Keras提供了易于使用的API,但在需要大量定制化和扩展性的复杂模型上可能会有限制。
- 灵活程度:Keras主要设计用于简单的流程,当需要处理复杂的非标准任务时,使用Keras的灵活性较差。
适用场景:
- 初学者:对于新手来说,Keras是一个理想的选择,因为它简单易用,有丰富的文档和示例来帮助快速入门。
- 快速原型设计:Keras可以快速搭建和迭代模型,适用于快速原型设计和快速实验验证。
- 常规计算机视觉和自然语言处理任务:Keras提供了大量用于计算机视觉和自然语言处理的预训练模型和工具,适用于常规任务的开发与应用。
PyTorch:
PyTorch是一个动态图深度学习框架,强调易于使用和低延迟的调试功能。以下是PyTorch的特点和优点:
优点:
- 动态图:PyTorch使用动态图,使得模型构建和调试更加灵活和直观,可以实时查看和调试模型。
- 自由控制:相比于静态图框架,PyTorch能够更自由地控制模型的复杂逻辑和探索新的网络架构。
- 算法开发:PyTorch提供了丰富的数学运算库和自动求导功能,适用于算法研究和定制化模型开发。
- 社区支持:PyTorch拥有活跃的社区和大量的开源项目,提供了丰富的资源和支持。
缺点:
- 部署复杂性:相比于Keras等高级API框架,PyTorch需要开发者更多地处理模型的部署和生产环境的问题。
- 静态优化:相对于静态图框架,如TensorFlow,PyTorch无法进行静态图优化,可能在性能方面略逊一筹。
- 入门门槛:相比于Keras,PyTorch对初学者来说可能有一些陡峭的学习曲线。
适用场景:
- 研究和定制化模型:PyTorch适合进行研究和实验,以及需要灵活性和自由度较高的定制化模型开发。
- 高级计算机视觉和自然语言处理任务:PyTorch在计算机视觉和自然语言处理领域有广泛的应用,并且各类预训练模型和资源丰富。
在前面的两篇文章中整体系统总结记录了Keras和PyTroch这两大主流框架各自开发构建模型的三大主流方式,并对应给出来的基础的实例实现,感兴趣的话可以自行移步阅读即可:
《总结记录Keras开发构建神经网络模型的三种主流方式:序列模型、函数模型、子类模型》
《总结记录PyTorch构建神经网络模型的三种主流方式:nn.Sequential按层顺序构建模型、继承nn.Module基类构建自定义模型、继承nn.Module基类构建模型并辅助应用模型容器来封装》
本文的主要目的就是想要基于真实业务数据场景来实地开发实践这三种不同类型的模型构建方式,并对结果进行对比分析。
首先来看下数据集:


这里模型结构的话可以自行构建设计层数都是没有关系的,我这里主要是参考了VGG的网络结构来搭建的网络模型,首先来看序列模型构建实现:
def initModel(h=100, w=100, way=3):"""列模型"""input_shape = (h, w, way)model = Sequential()model.add(Conv2D(64,(3, 3),strides=(1, 1),input_shape=input_shape,padding="same",activation="relu",kernel_initializer="uniform",))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(128,(3, 2),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(256,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(512,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(512,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Flatten())model.add(Dense(820, activation="relu"))model.add(Dropout(0.1))model.add(Dense(820, activation="relu"))model.add(Dropout(0.1))model.add(Dense(numbers, activation="softmax"))return model
网络结构输出如下所示:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 100, 100, 64) 1792
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 50, 50, 64) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 50, 50, 128) 49280
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 25, 25, 128) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 25, 25, 256) 295168
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 12, 12, 256) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 12, 12, 512) 1180160
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 6, 6, 512) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 6, 6, 512) 2359808
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 3, 3, 512) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 4608) 0
_________________________________________________________________
dense_1 (Dense) (None, 820) 3779380
_________________________________________________________________
dropout_1 (Dropout) (None, 820) 0
_________________________________________________________________
dense_2 (Dense) (None, 820) 673220
_________________________________________________________________
dropout_2 (Dropout) (None, 820) 0
_________________________________________________________________
dense_3 (Dense) (None, 2) 1642
=================================================================
Total params: 8,340,450
Trainable params: 8,340,450
Non-trainable params: 0
_________________________________________________________________
接下来是函数模型代码实现,如下所示:
def initModel(h=100, w=100, way=3):"""函数模型"""input_shape = (h, w, way)inputs = Input(shape=input_shape)X = Conv2D(64,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)(inputs)X = Conv2D(64,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)(X)X = MaxPooling2D(pool_size=(2, 2))(X)X = Conv2D(128,(3, 2),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)(X)X = MaxPooling2D(pool_size=(2, 2))(X)X = Conv2D(256,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)(X)X = MaxPooling2D(pool_size=(2, 2))(X)X = Conv2D(512,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)(X)X = MaxPooling2D(pool_size=(2, 2))(X)X = Conv2D(512,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)(X)X = MaxPooling2D(pool_size=(2, 2))(X)X = Flatten()(X)X = Dense(820, activation="relu")(X)X = Dropout(0.1)(X)X = Dense(820, activation="relu")(X)X = Dropout(0.1)(X)outputs = Dense(2, activation="sigmoid")(X)model = Model(input=inputs, output=outputs)return model
模型结构信息输出如下所示:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 100, 100, 3) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 100, 100, 64) 1792
_________________________________________________________________
conv2d_7 (Conv2D) (None, 100, 100, 64) 36928
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 50, 50, 64) 0
_________________________________________________________________
conv2d_8 (Conv2D) (None, 50, 50, 128) 49280
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 25, 25, 128) 0
_________________________________________________________________
conv2d_9 (Conv2D) (None, 25, 25, 256) 295168
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 12, 12, 256) 0
_________________________________________________________________
conv2d_10 (Conv2D) (None, 12, 12, 512) 1180160
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 6, 6, 512) 0
_________________________________________________________________
conv2d_11 (Conv2D) (None, 6, 6, 512) 2359808
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 3, 3, 512) 0
_________________________________________________________________
flatten_2 (Flatten) (None, 4608) 0
_________________________________________________________________
dense_4 (Dense) (None, 820) 3779380
_________________________________________________________________
dropout_3 (Dropout) (None, 820) 0
_________________________________________________________________
dense_5 (Dense) (None, 820) 673220
_________________________________________________________________
dropout_4 (Dropout) (None, 820) 0
_________________________________________________________________
dense_6 (Dense) (None, 2) 1642
=================================================================
Total params: 8,377,378
Trainable params: 8,377,378
Non-trainable params: 0
_________________________________________________________________
最后是子类模型代码实现,如下所示:
class initModel(Model):"""子类模型"""def __init__(self):super(initModel, self).__init__()self.conv2d1 = Conv2D(64,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)self.conv2d2 = Conv2D(64,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)self.pool1 = MaxPooling2D(pool_size=(2, 2))self.conv2d3 = Conv2D(128,(3, 2),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)self.pool2 = MaxPooling2D(pool_size=(2, 2))self.conv2d4 = Conv2D(256,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)self.pool3 = MaxPooling2D(pool_size=(2, 2))self.conv2d5 = Conv2D(512,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)self.pool4 = MaxPooling2D(pool_size=(2, 2))self.conv2d6 = Conv2D(512,(3, 3),strides=(1, 1),padding="same",activation="relu",kernel_initializer="uniform",)self.pool5 = MaxPooling2D(pool_size=(2, 2))self.flatten = Flatten()self.dense1 = Dense(820, activation="relu")self.dropout1 = Dropout(0.1)self.dense2 = Dense(820, activation="relu")self.dropout2 = Dropout(0.1)self.dense3 = Dense(2, activation="sigmoid")def call(self, inputs):"""回调"""x = self.conv2d1(inputs)x = self.conv2d2(x)x = self.pool1(x)x = self.conv2d3(x)x = self.pool2(x)x = self.conv2d4(x)x = self.pool3(x)x = self.conv2d5(x)x = self.pool4(x)x = self.conv2d6(x)x = self.pool5(x)x = self.flatten(x)x = self.dense1(x)x = self.dropout1(x)x = self.dense2(x)x = self.dropout2(x)y = self.dense3(x)return y
完成模型的搭建之后就可以加载对应的数据集开始训练模型了,数据集加载仿照mnist数据集的形式即可,这里就不再赘述了,在我之前的文章中也都有对应的实现,如下所示:
# 数据加载
X_train, X_test, y_train, y_test = loadData()
X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
# 数据归一化
X_train /= 255
X_test /= 255
# 数据打乱
np.random.seed(200)
np.random.shuffle(X_train)
np.random.seed(200)
np.random.shuffle(y_train)
np.random.seed(200)
np.random.shuffle(X_test)
np.random.seed(200)
np.random.shuffle(y_test)
# 模型
model=initModel()
model.compile(loss="binary_crossentropy", optimizer="sgd", metrics=["accuracy"])
模型评估测试可视化实现如下所示:
# 可视化
plt.clf()
plt.plot(history.history["acc"])
plt.plot(history.history["val_acc"])
plt.title("model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epochs")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(saveDir + "train_validation_acc.png")
plt.clf()
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("model loss")
plt.ylabel("loss")
plt.xlabel("epochs")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(saveDir + "train_validation_loss.png")
scores = model.evaluate(X_test, y_test, verbose=0)
print("Accuracy: %.2f%%" % (scores[1] * 100))
接下来看下结果:
【序列模型】

【函数模型】

【子类模型】

结果上有略微的差异,这个应该跟训练有关系。
可视化结果如下所示:

其实三种方法也是本质一样的,只要熟练熟悉了某一种,其他的构建方式都是可以基于当前的构建方式转化完成的。没有绝对唯一的选择,只有最适合自己的选择。
相关文章:
Keras三种主流模型构建方式:序列模型、函数模型、子类模型开发实践,以真实烟雾识别场景数据为例
Keras和PyTorch是两个常用的深度学习框架,它们都提供了用于构建和训练神经网络的高级API。 Keras: Keras是一个高级神经网络API,可以在多个底层深度学习框架上运行,如TensorFlow和CNTK。以下是Keras的特点和优点: 优点…...
objective-v 获取iPhone系统当前时间字符串适配12小时制和24小时制
我们最开始获取系统当前时间,如下,这种方式存在一个问题,当iPhone关闭了24小时制时,获取的时间格式是:iPhone11上:20230822下午210568760;iPhone7 plus上:2023082240043851 PM&#…...
并查集及其简单应用
文章目录 一.并查集二.并查集的实现三.并查集的基本应用 一.并查集 并查集的逻辑结构:由多颗不相连通的多叉树构成的森林(一个这样的多叉树就是森林的一个连通分量) 并查集的元素(树节点)用0~9的整数表示,并查集可以表示如下: 并查集的物理存储结构:并查集一般采用顺序结构实…...
基于web的服装商城系统java网上购物商店jsp源代码mysql
本项目为前几天收费帮学妹做的一个项目,Java EE JSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。 一、项目描述 基于web的服装商城系统 系统有1权限:前台…...
.NET Core发布到IIS
项目介绍 1、开发工具Visual Studio 2017,语言C#,SQL SERVER,WIN10 2、本地IIS,手机上或其他用户在和本地在同一个局域网内访问,同时要把防火墙关掉 3、IIS全名Internet Information Services,用来发布网站 先决条件 安…...
Spring的基本概念
前言 Spring 究竟是什么?其实Spring简单来说就是一个包含众多工具方法的IOC容器。 那么什么是IOC呢? IoC Inversion of Control 翻译成中⽂是“控制反转”的意思. 既然Spring 是⼀个IoC(控制反转)容器,重点还在“容…...
设计模式之原型模式
文章目录 一、介绍二、实现步骤三、案例四、应用五、细胞分裂六、改造细胞分裂逻辑七、总结 一、介绍 原型模式属于创建型设计模式,用于创建重复的对象,且同时又保证了性能。 该设计模式的好处是将对象的创建与调用方分离。 其目的就是**根据一个对象…...
正则表达式在网页处理中的应用四则
正则表达式在网页处理中的应用四则 正则表达式(Regular Expression)为字符串模式匹配提供了一种高效、方便的方法。几乎所有高级语言都提供了对正则表达式的支持,或者提供了现成的代码库供调用。本文以ASP环境中常见的处理任务为例,介绍正则表达式的应用技巧。 一、检验密…...
ping使用方法
文章目录 1、Ping的基础知识2、Ping命令详解3、怎样使用Ping这命令来测试网络连通?4、如何用Ping命令来判断一条链路好坏?5、对Ping后返回信息的分析1.Request timed out2.Destination host Unreachable 1、Ping的基础知识 ping命令相信大家已经再熟悉不…...
“心理健康人工智能产学研创新联盟”揭牌成立|深兰科技
8月14日上午,“2023树洞救援年会”在上海举行,会上举行了“心理健康人工智能产学研创新联盟”的签约和揭牌仪式。“树洞行动救援团”创始人深兰科技科学院智能科学首席科学家、荷兰阿姆斯特丹自由大学人工智能系终身教授黄智生,深兰科技集团创…...
FastDFS+Nginx - 本地搭建文件服务器同时实现在外远程访问「端口映射」
文章目录 前言1. 本地搭建FastDFS文件系统1.1 环境安装1.2 安装libfastcommon1.3 安装FastDFS1.4 配置Tracker1.5 配置Storage1.6 测试上传下载1.7 与Nginx整合1.8 安装Nginx1.9 配置Nginx 2. 局域网测试访问FastDFS3. 安装cpolar内网穿透4. 配置公网访问地址5. 固定公网地址5.…...
Mybatis-动态sql和分页
目录 一.什么是Mybatis动态分页 二.mybatis中的动态SQL 在BookMaaper.xml中写sql BookMapper BookBiz接口类 BookBizImpl实现接口类 demo测试类 编辑 测试结果 三.mybatis中的模糊查询 mybatis中的#与$有是什么区别 在BookMapper.xml里面建立三个模糊查询 编辑 …...
基于YOLOV8模型的西红柿目标检测系统(PyTorch+Pyside6+YOLOv8模型)
摘要:基于YOLOV8模型的西红柿目标检测系统可用于日常生活中检测与定位西红柿目标,利用深度学习算法可实现图片、视频、摄像头等方式的目标检测,另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测算法训练数…...
数学建模及数据分析 || 4. 深度学习应用案例分享
PyTorch 深度学习全连接网络分类 文章目录 PyTorch 深度学习全连接网络分类1. 非线性二分类2. 泰坦尼克号数据分类2.1 数据的准备工作2.2 全连接网络的搭建2.3 结果的可视化 1. 非线性二分类 import sklearn.datasets #数据集 import numpy as np import matplotlib.pyplot as…...
数据分析15——office中的Excel基础技术汇总
0、前言: 这部分总结就是总结每个基础技术的定义,在了解基础技术名称和定义后,方便对相关技术进行检索学习。笔记不会详细到所有操作都说明,但会把基础操作的名称及作用说明,可自行检索。本文对于大部分读者有以下作用…...
C语言好题解析(四)
目录 选择题一选择题二选择题三选择题四选择题五编程题一 选择题一 已知函数的原型是: int fun(char b[10], int *a); 设定义: char c[10];int d; ,正确的调用语句是( ) A: fun(c,&d); B: fun(c,d); C: fun(&…...
英语——主谓一致
主谓一致是指句子的谓语动词与其主语在数上必须保持一致,一般遵循以下三个原则: 一、语法形式上一致,即单复数形式与谓语要一致。 二、意义上一致,即主语意义上的单复数要与谓语的单复数形式一致。 三、就近以及就远原则,即谓语动词的单复形式取决于最靠近它的词语或者离它…...
属性字符串解析
连续的KV的字符串,每个KV之间用","分隔,V中可嵌套KV的连续字符串结构,例如“ key1value1,key2value2,key3[key4value4,key5value5,key6[key7value7]],key8value8 请编写如下函数,给定字符串,输出嵌套结构的H…...
【C++初阶】vector容器
👦个人主页:Weraphael ✍🏻作者简介:目前学习C和算法 ✈️专栏:C航路 🐋 希望大家多多支持,咱一起进步!😁 如果文章对你有帮助的话 欢迎 评论💬 点赞…...
ThreadLocal深度解析
简介 在并发编程中,导致并发bug的问题都会归结于对共享变量的操作不当。多个线程同时读写同一共享变量存在并发问题,我们可以利用写时复制、不变性来突破对原数据的写操作,没有写就没有并发问题,而本篇文章所介绍的技术是突破共享…...
基于FPGA的PID算法学习———实现PID比例控制算法
基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容:参考网站: PID算法控制 PID即:Proportional(比例)、Integral(积分&…...
rknn优化教程(二)
文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...
渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet: https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...
vue3 定时器-定义全局方法 vue+ts
1.创建ts文件 路径:src/utils/timer.ts 完整代码: import { onUnmounted } from vuetype TimerCallback (...args: any[]) > voidexport function useGlobalTimer() {const timers: Map<number, NodeJS.Timeout> new Map()// 创建定时器con…...
css的定位(position)详解:相对定位 绝对定位 固定定位
在 CSS 中,元素的定位通过 position 属性控制,共有 5 种定位模式:static(静态定位)、relative(相对定位)、absolute(绝对定位)、fixed(固定定位)和…...
AI书签管理工具开发全记录(十九):嵌入资源处理
1.前言 📝 在上一篇文章中,我们完成了书签的导入导出功能。本篇文章我们研究如何处理嵌入资源,方便后续将资源打包到一个可执行文件中。 2.embed介绍 🎯 Go 1.16 引入了革命性的 embed 包,彻底改变了静态资源管理的…...
C++使用 new 来创建动态数组
问题: 不能使用变量定义数组大小 原因: 这是因为数组在内存中是连续存储的,编译器需要在编译阶段就确定数组的大小,以便正确地分配内存空间。如果允许使用变量来定义数组的大小,那么编译器就无法在编译时确定数组的大…...
Selenium常用函数介绍
目录 一,元素定位 1.1 cssSeector 1.2 xpath 二,操作测试对象 三,窗口 3.1 案例 3.2 窗口切换 3.3 窗口大小 3.4 屏幕截图 3.5 关闭窗口 四,弹窗 五,等待 六,导航 七,文件上传 …...
【学习笔记】erase 删除顺序迭代器后迭代器失效的解决方案
目录 使用 erase 返回值继续迭代使用索引进行遍历 我们知道类似 vector 的顺序迭代器被删除后,迭代器会失效,因为顺序迭代器在内存中是连续存储的,元素删除后,后续元素会前移。 但一些场景中,我们又需要在执行删除操作…...
TSN交换机正在重构工业网络,PROFINET和EtherCAT会被取代吗?
在工业自动化持续演进的今天,通信网络的角色正变得愈发关键。 2025年6月6日,为期三天的华南国际工业博览会在深圳国际会展中心(宝安)圆满落幕。作为国内工业通信领域的技术型企业,光路科技(Fiberroad&…...
