卷积神经网络(AlexNet)鸟类识别
文章目录
- 一、前言
- 二、前期工作
- 1. 设置GPU(如果使用的是CPU可以忽略这步)
- 2. 导入数据
- 3. 查看数据
- 二、数据预处理
- 1. 加载数据
- 2. 可视化数据
- 3. 再次检查数据
- 4. 配置数据集
- 三、AlexNet (8层)介绍
- 四、构建AlexNet (8层)网络模型
- 五、编译
- 六、训练模型
- 七、模型评估
- 八、保存and加载模型
- 九、预测
一、前言
我的环境:
- 语言环境:Python3.6.5
- 编译器:jupyter notebook
- 深度学习环境:TensorFlow2.4.1
往期精彩内容:
- 卷积神经网络(CNN)实现mnist手写数字识别
- 卷积神经网络(CNN)多种图片分类的实现
- 卷积神经网络(CNN)衣服图像分类的实现
- 卷积神经网络(CNN)鲜花识别
- 卷积神经网络(CNN)天气识别
- 卷积神经网络(VGG-16)识别海贼王草帽一伙
- 卷积神经网络(ResNet-50)鸟类识别
来自专栏:机器学习与深度学习算法推荐
二、前期工作
1. 设置GPU(如果使用的是CPU可以忽略这步)
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")
2. 导入数据
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号import os,PIL# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)import pathlib
data_dir = "bird_photos"data_dir = pathlib.Path(data_dir)
3. 查看数据
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 565
二、数据预处理
| 文件夹 | 数量 |
|---|---|
| Bananaquit | 166 张 |
| Black Throated Bushtiti | 111 张 |
| Black skimmer | 122 张 |
| Cockatoo | 166张 |
1. 加载数据
使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中
batch_size = 8
img_height = 227
img_width = 227
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 565 files belonging to 4 classes.
Using 452 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)
Found 565 files belonging to 4 classes.
Using 113 files for validation.
我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。
class_names = train_ds.class_names
print(class_names)
['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
2. 可视化数据
plt.figure(figsize=(10, 5)) # 图形的宽为10高为5for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1) plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
plt.imshow(images[1].numpy().astype("uint8"))
3. 再次检查数据
for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(8, 227, 227, 3)
(8,)
Image_batch是形状的张量(8, 224, 224, 3)。这是一批形状240x240x3的8张图片(最后一维指的是彩色通道RGB)。Label_batch是形状(8,)的张量,这些标签对应8张图片
4. 配置数据集
AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
三、AlexNet (8层)介绍
AleXNet使用了ReLU方法加快训练速度,并且使用Dropout来防止过拟合
AleXNet (8层)是首次把卷积神经网络引入计算机视觉领域并取得突破性成绩的模型。获得了ILSVRC 2012年的冠军,再top-5项目中错误率仅仅15.3%,相对于使用传统方法的亚军26.2%的成绩优良重大突破。和之前的LeNet相比,AlexNet通过堆叠卷积层使得模型更深更宽。

四、构建AlexNet (8层)网络模型
from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout,BatchNormalization,Activationimport numpy as np
seed = 7
np.random.seed(seed)def AlexNet(nb_classes, input_shape):input_tensor = Input(shape=input_shape)# 1st blockx = Conv2D(96, (11,11), strides=4, name='block1_conv1')(input_tensor)x = BatchNormalization()(x)x = Activation('relu')(x)x = MaxPooling2D((3,3), strides=2, name = 'block1_pool')(x)# 2nd blockx = Conv2D(256, (5,5), padding='same', name='block2_conv1')(x)x = BatchNormalization()(x)x = Activation('relu')(x)x = MaxPooling2D((3,3), strides=2, name='block2_pool')(x)# 3rd blockx = Conv2D(384, (3,3), activation='relu', padding='same',name='block3_conv1')(x)# 4th blockx = Conv2D(384, (3,3), activation='relu', padding='same',name='block4_conv1')(x)# 5th blockx = Conv2D(256, (3,3), activation='relu', padding='same',name='block5_conv1')(x)x = MaxPooling2D((3,3), strides=2, name = 'block5_pool')(x)# full connectionx = Flatten()(x)x = Dense(4096, activation='relu', name='fc1')(x)x = Dropout(0.5)(x)x = Dense(4096, activation='relu', name='fc2')(x)x = Dropout(0.5)(x)output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)model = Model(input_tensor, output_tensor)return modelmodel=AlexNet(1000, (img_width, img_height, 3))
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 227, 227, 3)] 0
_________________________________________________________________
block1_conv1 (Conv2D) (None, 55, 55, 96) 34944
_________________________________________________________________
batch_normalization (BatchNo (None, 55, 55, 96) 384
_________________________________________________________________
activation (Activation) (None, 55, 55, 96) 0
_________________________________________________________________
block1_pool (MaxPooling2D) (None, 27, 27, 96) 0
_________________________________________________________________
block2_conv1 (Conv2D) (None, 27, 27, 256) 614656
_________________________________________________________________
batch_normalization_1 (Batch (None, 27, 27, 256) 1024
_________________________________________________________________
activation_1 (Activation) (None, 27, 27, 256) 0
_________________________________________________________________
block2_pool (MaxPooling2D) (None, 13, 13, 256) 0
_________________________________________________________________
block3_conv1 (Conv2D) (None, 13, 13, 384) 885120
_________________________________________________________________
block4_conv1 (Conv2D) (None, 13, 13, 384) 1327488
_________________________________________________________________
block5_conv1 (Conv2D) (None, 13, 13, 256) 884992
_________________________________________________________________
block5_pool (MaxPooling2D) (None, 6, 6, 256) 0
_________________________________________________________________
flatten (Flatten) (None, 9216) 0
_________________________________________________________________
fc1 (Dense) (None, 4096) 37752832
_________________________________________________________________
dropout (Dropout) (None, 4096) 0
_________________________________________________________________
fc2 (Dense) (None, 4096) 16781312
_________________________________________________________________
dropout_1 (Dropout) (None, 4096) 0
_________________________________________________________________
predictions (Dense) (None, 1000) 4097000
=================================================================
Total params: 62,379,752
Trainable params: 62,379,048
Non-trainable params: 704
_________________________________________________________________
五、编译
在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:
- 损失函数(loss):用于衡量模型在训练期间的准确率。
- 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
- 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 设置优化器,我这里改变了学习率。
# opt = tf.keras.optimizers.Adam(learning_rate=1e-7)model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])
六、训练模型
epochs = 20history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)
Epoch 1/20
57/57 [==============================] - 5s 30ms/step - loss: 9.2789 - accuracy: 0.2166 - val_loss: 3.2340 - val_accuracy: 0.3363
Epoch 2/20
57/57 [==============================] - 1s 14ms/step - loss: 0.9329 - accuracy: 0.6224 - val_loss: 1.1778 - val_accuracy: 0.5310
Epoch 3/20
57/57 [==============================] - 1s 14ms/step - loss: 0.7438 - accuracy: 0.6747 - val_loss: 1.9651 - val_accuracy: 0.5133
Epoch 4/20
57/57 [==============================] - 1s 14ms/step - loss: 0.8875 - accuracy: 0.7025 - val_loss: 1.5589 - val_accuracy: 0.4602
Epoch 5/20
57/57 [==============================] - 1s 14ms/step - loss: 0.6116 - accuracy: 0.7424 - val_loss: 0.9914 - val_accuracy: 0.4956
Epoch 6/20
57/57 [==============================] - 1s 15ms/step - loss: 0.6258 - accuracy: 0.7520 - val_loss: 1.1103 - val_accuracy: 0.5221
Epoch 7/20
57/57 [==============================] - 1s 13ms/step - loss: 0.5138 - accuracy: 0.8034 - val_loss: 0.7832 - val_accuracy: 0.6726
Epoch 8/20
57/57 [==============================] - 1s 14ms/step - loss: 0.5343 - accuracy: 0.7940 - val_loss: 6.1064 - val_accuracy: 0.4602
Epoch 9/20
57/57 [==============================] - 1s 14ms/step - loss: 0.8667 - accuracy: 0.7606 - val_loss: 0.6869 - val_accuracy: 0.7965
Epoch 10/20
57/57 [==============================] - 1s 16ms/step - loss: 0.5785 - accuracy: 0.8141 - val_loss: 1.3631 - val_accuracy: 0.5310
Epoch 11/20
57/57 [==============================] - 1s 15ms/step - loss: 0.4929 - accuracy: 0.8109 - val_loss: 0.7191 - val_accuracy: 0.7345
Epoch 12/20
57/57 [==============================] - 1s 15ms/step - loss: 0.4141 - accuracy: 0.8507 - val_loss: 0.4962 - val_accuracy: 0.8496
Epoch 13/20
57/57 [==============================] - 1s 15ms/step - loss: 0.2591 - accuracy: 0.9148 - val_loss: 0.8015 - val_accuracy: 0.8053
Epoch 14/20
57/57 [==============================] - 1s 15ms/step - loss: 0.2683 - accuracy: 0.9079 - val_loss: 0.5451 - val_accuracy: 0.8142
Epoch 15/20
57/57 [==============================] - 1s 14ms/step - loss: 0.2925 - accuracy: 0.9096 - val_loss: 0.6668 - val_accuracy: 0.8584
Epoch 16/20
57/57 [==============================] - 1s 14ms/step - loss: 0.4009 - accuracy: 0.8804 - val_loss: 1.1609 - val_accuracy: 0.6372
Epoch 17/20
57/57 [==============================] - 1s 14ms/step - loss: 0.4375 - accuracy: 0.8446 - val_loss: 0.9854 - val_accuracy: 0.7965
Epoch 18/20
57/57 [==============================] - 1s 14ms/step - loss: 0.3085 - accuracy: 0.8926 - val_loss: 0.6477 - val_accuracy: 0.8761
Epoch 19/20
57/57 [==============================] - 1s 15ms/step - loss: 0.1200 - accuracy: 0.9538 - val_loss: 1.8996 - val_accuracy: 0.5398
Epoch 20/20
57/57 [==============================] - 1s 15ms/step - loss: 0.3378 - accuracy: 0.9095 - val_loss: 0.9337 - val_accuracy: 0.8053
七、模型评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
八、保存and加载模型
保存模型
model.save('model/my_model.h5')
# 加载模型
new_model = tf.keras.models.load_model('model/my_model.h5')
九、预测
# 采用加载的模型(new_model)来看预测结果plt.figure(figsize=(10, 5)) # 图形的宽为10高为5for images, labels in val_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i + 1) # 显示图片plt.imshow(images[i].numpy().astype("uint8"))# 需要给图片增加一个维度img_array = tf.expand_dims(images[i], 0) # 使用模型预测图片中的人物predictions = new_model.predict(img_array)plt.title(class_names[np.argmax(predictions)])plt.axis("off")
相关文章:
卷积神经网络(AlexNet)鸟类识别
文章目录 一、前言二、前期工作1. 设置GPU(如果使用的是CPU可以忽略这步)2. 导入数据3. 查看数据 二、数据预处理1. 加载数据2. 可视化数据3. 再次检查数据4. 配置数据集 三、AlexNet (8层)介绍四、构建AlexNet (8层)网络模型五、…...
hive 报错return code 40000 from org.apache.hadoop.hive.ql.exec.MoveTask解决思路
参考学习 https://github.com/apache/hive/blob/2b57dd27ad61e552f93817ac69313066af6562d9/ql/src/java/org/apache/hadoop/hive/ql/ErrorMsg.java#L47 为啥学习error code 开发过程中遇到以下错误,大家觉得应该怎么办?从哪方面入手呢? 1.百…...
Java Web——XML
1. XML概述 XML是EXtensible Markup Language的缩写,翻译过来就是可扩展标记语言。XML是一种用于存储和传输数据的语言,它使用标签来标记数据,以便于计算机处理和我们人来阅读。 “可扩展”三个字表明XML可以根据需要进行扩展和定制。这意味…...
【.NET Core】Task应用详解
【.NET Core】Task应用详解 文章目录 【.NET Core】Task应用详解一、概述二、Task用法应用2.1 通过New实例化Task2.2 通过Factory中StartNew方法2.3 通过Run方法 三、让Task任务按顺序执行四、通过异步Run方法异步执行顺序Task五、创建带有返回值的Task<TResult>六、Task…...
convertRect:toView 方法注意事项
这是在网上找到的一张图 我们开发中有时候会用到左边转换,convertRect:toView 通常情况下,我们回这样使用 CGRect newRect [a convertRect:originframe toView:c];其中newRect和 originframe的size相同,只改变origin newRect.origin a…...
Java实现王者荣耀小游戏
主要功能 键盘W,A,S,D键:控制玩家上下左右移动。按钮一:控制英雄发射一个矩形攻击红方小兵。按钮控制英雄发射魅惑技能,伤害小兵并让小兵停止移动。技能三:攻击多个敌人并让小兵停止移动。普攻:对小兵造成基础伤害。小…...
【黑马甄选离线数仓day04_维度域开发】
1. 维度主题表数据导出 1.1 PostgreSQL介绍 PostgreSQL 是一个功能强大的开源对象关系数据库系统,它使用和扩展了 SQL 语言,并结合了许多安全存储和扩展最复杂数据工作负载的功能。 官方网址:PostgreSQL: The worlds most advanced open s…...
C# 中using关键字的使用
在C#中我们还是很有必要掌握using关键字的。 比如这样: string path “D:\data.txt”; if (!File.Exists(path )) {File.Create(path); File.WriteAllText(path,"OK"); } 首先我创建…...
16 redis高可用读写分离方案
在前面说的JedisSentinelPool只能实现主从的切换,而无法实现读写的分离。 1.哨兵的客户端实现主从切换方案 <dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</arti…...
Nginx模块开发之http handler实现流量统计(2)
文章目录 一、概述二、Nginx handler模块开发2.1、代码实现2.2、编写config文件2.3、编译模块到Nginx源码中2.4、修改conf文件2.5、执行效果 总结 一、概述 上一篇【Nginx模块开发之http handler实现流量统计(1)】使用数组在单进程实现了IP的流量统计&a…...
案例012:Java+SSM+uniapp基于微信小程序的科创微应用平台设计与实现
文末获取源码 开发语言:Java 框架:SSM JDK版本:JDK1.8 数据库:mysql 5.7 开发软件:eclipse/myeclipse/idea Maven包:Maven3.5.4 小程序框架:uniapp 小程序开发软件:HBuilder X 小程序…...
vue3+elementPlus登录向后端服务器发起数据请求Ajax
后端的url登录接口 先修改main.js文件 // 导入Ajax 前后端数据传输 import axios from "axios"; const app createApp(App) //vue3.0使用app.config.globalProperties.$http app.config.globalProperties.$http axios app.mount(#app); login.vue 页面显示部分…...
存储区域
将应用程序加载到内存空间执行时,操作系统负责代码段、数据段和BSS段的加载,并在内存中为这些段分配空间。 栈段亦由操作系统分配和管理,而不需要程序员显示地管理;堆段由程序员自己管理,即显示地申请和释放空间。 进…...
C#串口通信从入门到精通(27)——高速通信下解决数据处理慢的问题(20ms以内)
前言 我们在开发串口通信程序时,有时候会遇到比如单片机或者传感器发送的数据速度特别快,比如10ms、20ms发送一次,并且每次发送的数据量还比较大,如果按照常规的写法,我们会发现接收的数据还没处理完,新的数据又发送过来了,这就会导致处理数据滞后,软件始终处理的不是…...
Redis-Redis高可用集群之水平扩展
Redis3.0以后的版本虽然有了集群功能,提供了比之前版本的哨兵模式更高的性能与可用性,但是集群的水平扩展却比较麻烦,今天就来带大家看看redis高可用集群如何做水平扩展,原始集群(见下图)由6个节点组成,6个节点分布在三…...
2023全球数字贸易创新大赛-人工智能元宇宙-4-10
目录 竞赛感悟: 创业的话 好的项目 数字工厂,智慧制造:集群控制的安全问题...
go defer用法_类似与python_java_finially
defer 执行 时间 defer 一般 定义在 函数 开头, 但是 他会 最后 被执行 A defer statement defers the execution of a function until the surrounding function returns. 如果说 为什么 不在 末尾 定义 defer 呢, 因为 当 错误 发生时, 程序 执行 不到 末尾 就会 崩溃. d…...
Log4j2.xml不生效:WARN StatusLogger Multiple logging implementations found:
背景 将 -Dlog4j.debug 添加到IDEA的类的启动配置中 运行上图代码,这里log4j2.xml控制的日志级别是info,很明显是没生效。 DEBUG StatusLogger org.slf4j.helpers.Log4jLoggerFactory is not on classpath. Good! DEBUG StatusLogger Using Shutdow…...
【LeetCode】挑战100天 Day14(热题+面试经典150题)
【LeetCode】挑战100天 Day14(热题面试经典150题) 一、LeetCode介绍二、LeetCode 热题 HOT 100-162.1 题目2.2 题解 三、面试经典 150 题-163.1 题目3.2 题解 一、LeetCode介绍 LeetCode是一个在线编程网站,提供各种算法和数据结构的题目&…...
VMware安装windows操作系统
一、下载镜像包 地址:镜像包地址。 找到需要的版本下载镜像包。 二、安装 打开VMware新建虚拟机,选择用镜像文件。将下载的镜像包加载进去即可。...
国防科技大学计算机基础课程笔记02信息编码
1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制,因此这个了16进制的数据既可以翻译成为这个机器码,也可以翻译成为这个国标码,所以这个时候很容易会出现这个歧义的情况; 因此,我们的这个国…...
测试markdown--肇兴
day1: 1、去程:7:04 --11:32高铁 高铁右转上售票大厅2楼,穿过候车厅下一楼,上大巴车 ¥10/人 **2、到达:**12点多到达寨子,买门票,美团/抖音:¥78人 3、中饭&a…...
屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!
5月28日,中天合创屋面分布式光伏发电项目顺利并网发电,该项目位于内蒙古自治区鄂尔多斯市乌审旗,项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站,总装机容量为9.96MWp。 项目投运后,每年可节约标煤3670…...
从零实现STL哈希容器:unordered_map/unordered_set封装详解
本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说,直接开始吧! 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...
土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等
🔍 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术,可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势,还能有效评价重大生态工程…...
LeetCode - 199. 二叉树的右视图
题目 199. 二叉树的右视图 - 力扣(LeetCode) 思路 右视图是指从树的右侧看,对于每一层,只能看到该层最右边的节点。实现思路是: 使用深度优先搜索(DFS)按照"根-右-左"的顺序遍历树记录每个节点的深度对于…...
C++.OpenGL (14/64)多光源(Multiple Lights)
多光源(Multiple Lights) 多光源渲染技术概览 #mermaid-svg-3L5e5gGn76TNh7Lq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-3L5e5gGn76TNh7Lq .error-icon{fill:#552222;}#mermaid-svg-3L5e5gGn76TNh7Lq .erro…...
【Android】Android 开发 ADB 常用指令
查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...
云原生安全实战:API网关Envoy的鉴权与限流详解
🔥「炎码工坊」技术弹药已装填! 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关 作为微服务架构的统一入口,负责路由转发、安全控制、流量管理等核心功能。 2. Envoy 由Lyft开源的高性能云原生…...
五、jmeter脚本参数化
目录 1、脚本参数化 1.1 用户定义的变量 1.1.1 添加及引用方式 1.1.2 测试得出用户定义变量的特点 1.2 用户参数 1.2.1 概念 1.2.2 位置不同效果不同 1.2.3、用户参数的勾选框 - 每次迭代更新一次 总结用户定义的变量、用户参数 1.3 csv数据文件参数化 1、脚本参数化 …...
