当前位置: 首页 > news >正文

基于tensorflow使用VGG16实现猫狗识别

import tensorflow as tf
import numpy as np
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator# 定义 VGG16 模型
class VGG16(tf.keras.Model):def __init__(self, num_classes=2):super(VGG16, self).__init__()self.features = models.Sequential([layers.Conv2D(64, (3, 3), padding='same', activation='relu', input_shape=(224, 224, 3)),layers.MaxPooling2D((2, 2), strides=(2, 2)),layers.Conv2D(128, (3, 3), padding='same', activation='relu'),layers.MaxPooling2D((2, 2), strides=(2, 2)),layers.Conv2D(256, (3, 3), padding='same', activation='relu'),layers.Conv2D(256, (3, 3), padding='same', activation='relu'),layers.MaxPooling2D((2, 2), strides=(2, 2)),layers.Conv2D(512, (3, 3), padding='same', activation='relu'),layers.Conv2D(512, (3, 3), padding='same', activation='relu'),layers.MaxPooling2D((2, 2), strides=(2, 2)),layers.Conv2D(512, (3, 3), padding='same', activation='relu'),layers.Conv2D(512, (3, 3), padding='same', activation='relu'),layers.MaxPooling2D((2, 2), strides=(2, 2)),])self.classifier = models.Sequential([layers.Flatten(),layers.Dense(4096, activation='relu'),layers.Dropout(0.5),layers.Dense(4096, activation='relu'),layers.Dropout(0.5),layers.Dense(num_classes, activation='softmax'),])def call(self, x):x = self.features(x)x = self.classifier(x)return x# 使用 ImageDataGenerator 加载并预处理数据集
data_dir = 'data'
input_shape = (224, 224)
batch_size = 4train_datagen = ImageDataGenerator(rescale=1.0/255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True
)
val_datagen = ImageDataGenerator(rescale=1.0/255)train_gen = train_datagen.flow_from_directory(directory=f'{data_dir}/train',target_size=input_shape,batch_size=batch_size,class_mode='binary'
)val_gen = val_datagen.flow_from_directory(directory=f'{data_dir}/validation',target_size=input_shape,batch_size=batch_size,class_mode='binary'
)# 初始化模型、优化器和损失函数
model = VGG16(num_classes=2)
# 构建模型结构(明确指定输入形状)
model.build(input_shape=(None, 224, 224, 3))  # None 表示动态批次大小# 查看模型结构
model.summary()
model.compile(optimizer=optimizers.Adam(learning_rate=0.0001),loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练循环
epochs = 20
steps_per_epoch = train_gen.samples // batch_size
validation_steps = val_gen.samples // batch_sizefor epoch in range(epochs):print(f"=========== Epoch {epoch + 1} ==============")history = model.fit(train_gen,steps_per_epoch=steps_per_epoch,validation_data=val_gen,validation_steps=validation_steps,epochs=1)train_loss = history.history['loss'][0]val_loss = history.history['val_loss'][0]val_accuracy = history.history['val_accuracy'][0]print(f"训练集上的损失:{train_loss}")print(f"验证集上的损失:{val_loss}")print(f"验证集上的精度:{val_accuracy:.1%}")# 保存模型model.save_weights(f"Adogandcat_epoch_{epoch + 1}.h5")print("模型权重已保存。")#预测部分
# 定义和加载 VGG16 模型
vgg16 = VGG16(num_classes=2)
vgg16.build(input_shape=(None, 224, 224, 3))
vgg16.load_weights('Adogandcat_epoch_20.h5')  # 替换为训练好的 VGG16 权重路径# 加载和预处理图像
def load_and_preprocess_image(image_path, target_size=(224, 224)):img = load_img(image_path, target_size=target_size)  # 加载图像并调整大小img_array = img_to_array(img)  # 转换为 NumPy 数组img_array = np.expand_dims(img_array, axis=0)  # 添加批次维度img_array = preprocess_input(img_array)  # VGG16 所需的标准化return img, img_array# 预测和显示图像
def predict_and_display(image_path, model, model_name):# 加载图像original_img, processed_img = load_and_preprocess_image(image_path)# 预测类别predictions = model(processed_img, training=False)predicted_class = np.argmax(predictions, axis=1)[0]confidence = predictions[0][predicted_class]# 显示结果plt.figure(figsize=(6, 6))plt.imshow(original_img)plt.axis('off')plt.title(f"Model: {model_name}\nPredicted Class: {predicted_class}\nConfidence: {confidence:.2f}")plt.show()# 测试图像路径
image_path = 'data/test/1.jpg'  # 替换为实际图像路径# 使用 CustomCNN 预测
predict_and_display(image_path, custom_cnn, "Custom CNN")# 使用 VGG16 预测
predict_and_display(image_path, vgg16, "VGG16")

训练结果:

运行时间:46 mins.

Found 18750 images belonging to 2 classes.

Found 6250 images belonging to 2 classes.

2024-11-22 00:51:58.211251: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.

2024-11-22 00:51:58.612873: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38404 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:4b:00.0, compute capability: 8.0

=========== Epoch 1 ==============

4687/4687 [==============================] - 140s 29ms/step - loss: 0.6943 - accuracy: 0.5031 - val_loss: 0.6932 - val_accuracy: 0.5010

训练集上的损失:0.6943416595458984

验证集上的损失:0.6932020783424377

验证集上的精度:50.1%

模型已保存。

=========== Epoch 2 ==============

4687/4687 [==============================] - 138s 29ms/step - loss: 0.6843 - accuracy: 0.5500 - val_loss: 0.6604 - val_accuracy: 0.6079

训练集上的损失:0.6842659711837769

验证集上的损失:0.660356879234314

验证集上的精度:60.8%

模型已保存。

=========== Epoch 3 ==============

4687/4687 [==============================] - 138s 29ms/step - loss: 0.6018 - accuracy: 0.6783 - val_loss: 0.4973 - val_accuracy: 0.7553

训练集上的损失:0.6018291711807251

验证集上的损失:0.4972682297229767

验证集上的精度:75.5%

模型已保存。

=========== Epoch 4 ==============

4687/4687 [==============================] - 139s 30ms/step - loss: 0.4743 - accuracy: 0.7762 - val_loss: 0.4171 - val_accuracy: 0.8156

训练集上的损失:0.4742658734321594

验证集上的损失:0.4170766770839691

验证集上的精度:81.6%

模型已保存。

=========== Epoch 5 ==============

4687/4687 [==============================] - 138s 29ms/step - loss: 0.3851 - accuracy: 0.8273 - val_loss: 0.3572 - val_accuracy: 0.8489

训练集上的损失:0.3850820064544678

验证集上的损失:0.3571555018424988

验证集上的精度:84.9%

模型已保存。

=========== Epoch 6 ==============

4687/4687 [==============================] - 138s 29ms/step - loss: 0.3096 - accuracy: 0.8676 - val_loss: 0.2901 - val_accuracy: 0.8841

训练集上的损失:0.30961713194847107

验证集上的损失:0.29008445143699646

验证集上的精度:88.4%

模型已保存。

=========== Epoch 7 ==============

4687/4687 [==============================] - 139s 30ms/step - loss: 0.2486 - accuracy: 0.8966 - val_loss: 0.2143 - val_accuracy: 0.9088

训练集上的损失:0.2486010491847992

验证集上的损失:0.2143394649028778

验证集上的精度:90.9%

模型已保存。

=========== Epoch 8 ==============

4687/4687 [==============================] - 139s 30ms/step - loss: 0.2155 - accuracy: 0.9101 - val_loss: 0.1907 - val_accuracy: 0.9205

训练集上的损失:0.2155471295118332

验证集上的损失:0.1906772404909134

验证集上的精度:92.0%

模型已保存。

=========== Epoch 9 ==============

4687/4687 [==============================] - 139s 30ms/step - loss: 0.1929 - accuracy: 0.9192 - val_loss: 0.1902 - val_accuracy: 0.9214

训练集上的损失:0.19291609525680542

验证集上的损失:0.19024263322353363

验证集上的精度:92.1%

模型已保存。

=========== Epoch 10 ==============

4687/4687 [==============================] - 143s 30ms/step - loss: 0.1751 - accuracy: 0.9284 - val_loss: 0.1607 - val_accuracy: 0.9337

训练集上的损失:0.17511127889156342

验证集上的损失:0.1606709510087967

验证集上的精度:93.4%

模型已保存。

=========== Epoch 11 ==============

4687/4687 [==============================] - 138s 29ms/step - loss: 0.1559 - accuracy: 0.9391 - val_loss: 0.1388 - val_accuracy: 0.9416

训练集上的损失:0.155866801738739

验证集上的损失:0.13884252309799194

验证集上的精度:94.2%

模型已保存。

=========== Epoch 12 ==============

4687/4687 [==============================] - 137s 29ms/step - loss: 0.1457 - accuracy: 0.9401 - val_loss: 0.1550 - val_accuracy: 0.9470

训练集上的损失:0.14570224285125732

验证集上的损失:0.15503031015396118

验证集上的精度:94.7%

模型已保存。

=========== Epoch 13 ==============

4687/4687 [==============================] - 138s 30ms/step - loss: 0.1359 - accuracy: 0.9451 - val_loss: 0.1201 - val_accuracy: 0.9520

训练集上的损失:0.1359049379825592

验证集上的损失:0.12010601162910461

验证集上的精度:95.2%

模型已保存。

=========== Epoch 14 ==============

4687/4687 [==============================] - 138s 30ms/step - loss: 0.1293 - accuracy: 0.9489 - val_loss: 0.1366 - val_accuracy: 0.9486

训练集上的损失:0.12929560244083405

验证集上的损失:0.13661223649978638

验证集上的精度:94.9%

模型已保存。

=========== Epoch 15 ==============

4687/4687 [==============================] - 139s 30ms/step - loss: 0.1206 - accuracy: 0.9516 - val_loss: 0.1472 - val_accuracy: 0.9478

训练集上的损失:0.12062755972146988

验证集上的损失:0.1471676379442215

验证集上的精度:94.8%

模型已保存。

=========== Epoch 16 ==============

4687/4687 [==============================] - 137s 29ms/step - loss: 0.1174 - accuracy: 0.9544 - val_loss: 0.1282 - val_accuracy: 0.9475

训练集上的损失:0.11741997301578522

验证集上的损失:0.1282137632369995

验证集上的精度:94.8%

模型已保存。

=========== Epoch 17 ==============

4687/4687 [==============================] - 138s 29ms/step - loss: 0.1139 - accuracy: 0.9552 - val_loss: 0.1264 - val_accuracy: 0.9563

训练集上的损失:0.11387941241264343

验证集上的损失:0.12638334929943085

验证集上的精度:95.6%

模型已保存。

=========== Epoch 18 ==============

4687/4687 [==============================] - 138s 29ms/step - loss: 0.1123 - accuracy: 0.9557 - val_loss: 0.1192 - val_accuracy: 0.9585

训练集上的损失:0.11233851313591003

验证集上的损失:0.11923010647296906

验证集上的精度:95.9%

模型已保存。

=========== Epoch 19 ==============

4687/4687 [==============================] - 138s 29ms/step - loss: 0.1124 - accuracy: 0.9579 - val_loss: 0.1174 - val_accuracy: 0.9534

训练集上的损失:0.11243574321269989

验证集上的损失:0.11737789213657379

验证集上的精度:95.3%

模型已保存。

=========== Epoch 20 ==============

4687/4687 [==============================] - 150s 32ms/step - loss: 0.0986 - accuracy: 0.9616 - val_loss: 0.1180 - val_accuracy: 0.9539

训练集上的损失:0.09860984236001968

验证集上的损失:0.11801616102457047

验证集上的精度:95.4%

模型已保存。

原文使用的是pytorch:

VGG网络详解(实现猫猫和狗狗识别)_vgg卷积神经网络猫狗实验-CSDN博客

相关文章:

基于tensorflow使用VGG16实现猫狗识别

import tensorflow as tf import numpy as np from tensorflow.keras import layers, models, optimizers from tensorflow.keras.preprocessing.image import ImageDataGenerator# 定义 VGG16 模型 class VGG16(tf.keras.Model):def __init__(self, num_classes2):super(VGG16…...

第18章 EXISTS 与 NOT EXISTS 关键字

一、EXISTS 关键字介绍 关键字介绍EXISTS 关联子查询通常也会和 EXISTS操作符一起来使用,用来检查在子查询中是否存在满足条件的行。 如果在子查询中当前的行不满足条件:返回 FALSE,继续在子查询中查找 如果在子查询中当前的行满足条件&…...

Windows多JDK版本管理工具JVMs

Windows多JDK版本管理工具JVMs 官网安装使用手动下载jdk 官网 https://github.com/ystyle/jvms 下载 https://github.com/ystyle/jvms/releases 当前下载版本为v2.1.6 安装 下载后,解压到某个目录。 比如:D:\soft\JVMs\jvms_v2.1.6_amd64 把这个目录…...

【C++】初始化列表、类型转换

目录: 一、const成员函数 二、初始化列表 三、类型转换 正文 一、const成员函数 (1)将const修饰的成员函数称之为const成员函数,const修饰成员函数放到成员函数参数列表的后⾯。至于为什么这么放是语法规定。 (2&a…...

创新设计,精准仿真|SOLIDWORKS Simulation 2025新功能

SOLIDWORKS Simulation 2025 带来了多项新功能,不仅提高了工作效率,还增强了仿真的精确度。以下是五大新功能的详细介绍,帮助您更好地利用这些新特性提升设计仿真能力。 1. 从分析中排除实体 在复杂的装配体仿真中,有时需要排除某…...

vue3封装Element Plus table表格组件

支持绝大部分Element Plus原有设置属性&#xff0c;支持分页&#xff0c;支持动态适配高度 效果展示 组件代码&#xff1a; <template><div class"table-wrap" ref"tableWrap"><el-tableclass"w100 h100":data"tableInfo.…...

Qt之QWidget相关

Qt概述 Qt 是一个跨平台的 C 开发框架。 跨平台支持&#xff1a;可以用于开发 Windows、macOS、Linux、Android、iOS 等多种操作系统下的应用程序。这意味着开发者使用 Qt 编写的代码&#xff0c;在经过适当的编译和配置后&#xff0c;能够在不同平台上运行&#xff0c;减少了…...

用web前端写出一个高校官网

所实现的效果如链接&#xff1a; http://127.0.0.1:5500/school.html <!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>xigongshang</title> <style> * {margin: 0;padding: 0;} a{ text-decoration: none…...

【笔记】Android Gradle Plugin配置文件相关说明-libs.versions.toml

版本号 文件路径&#xff1a;Project\gradle\libs.versions.toml 直接搜索versions.agp是找不到的&#xff0c;这是变量引用的写法&#xff0c;查询 agp版本可以直接查版本号。 [versions] agp "8.5.0-alpha08" junit "4.13.2" junitVersion "1.…...

如何修复WordPress卡在维护模式

当你管理WordPress网站时&#xff0c;没有什么比看到它卡在维护模式更令人沮丧的了。特别是在你进行重要更新或期望大量流量的时候&#xff0c;这种情况会更加令人不安。 维护模式可能由许多因素引起&#xff0c;从简单的文件损坏到更复杂的插件冲突或存在的.maintenance文件。…...

glob三个函数的效果

glob 允许你给一个从dirname到basename pattern的整字符串路径&#xff0c;然后匹配一切符合要求的路径。 glob0 dirname和basename分开传&#xff0c;只返回basename。可见不支持pattern匹配。 glob1 dirname和basename pattern分开传&#xff0c;只返回basename。支持pa…...

FreeRTOS:事件标志组与任务通知

目录 一、事件标志组&#xff08;Event Groups&#xff09; 1、事件标志组的特点 2、事件标志组与队列、信号量的区别 3、关键API函数 4、示例代码 5、优缺点 6、总结 二、任务通知&#xff08;Task Notifications&#xff09; 1、任务通知的特点 2、关键API函数 3、…...

c++11的动态类型

c17引入了any 和 variant&#xff0c;可以将任意数据类型统一用any或variant类型表示&#xff0c;在开发中还是能够带来很多便利的。在c11版本中&#xff0c;可以用下面这个例子&#xff0c;仿照实现一个Any类型。 #include <iostream> #include <stdexcept> #inc…...

付费会员渗透难,腾讯音乐的触顶挑战

腾讯音乐付费用户增长背后&#xff0c;月活跃用户数下滑3%&#xff0c;超级会员渗透率仅1.8%。 转载|原创新熵 作者丨晓伊 编辑丨蕨影 腾讯音乐此次营收同比正增长的到来&#xff0c;殊为不易。要知道&#xff0c;此前已连续四个季度出现营收同比下滑的态势。 11月12日&#x…...

内网安全隧道搭建-ngrok-frp-nps-sapp

1.ngrok 建立内网主机与公网跳板机的连接&#xff1a; 内网主机为客户机&#xff1a; 下载客户端执行 2.frp &#xff08;1&#xff09;以下为内网穿透端口转发 frp服务端配置&#xff1a; bindPort 为frp运行端口 服务端运行 ./frps -c frps.ini frp客户端配置&#xf…...

Load-Balanced-Online-OJ(负载均衡式在线OJ)

负载均衡式在线OJ 1. 项目介绍2. 项目说明4. 项目代码5. 项目演示 1. 项目介绍 2. 项目说明 4. 项目代码 5. 项目演示...

Postman之变量操作

系列文章目录 1.Postman之安装及汉化基本使用介绍 2.Postman之变量操作 3.Postman之数据提取 4.Postman之pm.test断言操作 5.Postman之newman Postman之变量操作 1.pm.globals全局变量2.pm.environment环境变量3.pm.collectionVariables集合变量4.pm.variables5.提取数据-设置变…...

查找字符串中某个字符返回字符位置

当然有正则表达式就非常简单,没有的话也不用担心,我们自己写算法完成这个功能. 早期的vs版本不支持vs,当然也可以下载boost来实现,关键还是不想下载,那么就自己写吧. 1.要求,查找字符串中同一个字符,并找出字符的位置. 2.根据字符位置,计算出对应的x,y坐标. 算法第一步,查找字…...

《数学物理学报》

作者指南 一、目的与范围 《数学物理学报》以刊登数学与物理科学的边缘学科中具有创造性的代表学科水平的科研成果为主的综合性学术刊物。其目的旨在向专业读者&#xff08;研究生水平以上&#xff09;提供数理学科领域的重要的、独创的成果。 二、投稿 要求文章内容具有创新…...

39页PDF | 毕马威_数据资产运营白皮书(限免下载)

一、前言 《毕马威数据资产运营白皮书》探讨了数据作为新型生产要素在企业数智化转型中的重要性&#xff0c;提出了数据资产运营的“三要素”&#xff08;组织与意识、流程与规范、平台与工具&#xff09;和“四重奏”&#xff08;数据资产盘点、评估、治理、共享&#xff09;…...

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...

k8s从入门到放弃之Ingress七层负载

k8s从入门到放弃之Ingress七层负载 在Kubernetes&#xff08;简称K8s&#xff09;中&#xff0c;Ingress是一个API对象&#xff0c;它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress&#xff0c;你可…...

深入理解JavaScript设计模式之单例模式

目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式&#xff08;Singleton Pattern&#…...

django filter 统计数量 按属性去重

在Django中&#xff0c;如果你想要根据某个属性对查询集进行去重并统计数量&#xff0c;你可以使用values()方法配合annotate()方法来实现。这里有两种常见的方法来完成这个需求&#xff1a; 方法1&#xff1a;使用annotate()和Count 假设你有一个模型Item&#xff0c;并且你想…...

MMaDA: Multimodal Large Diffusion Language Models

CODE &#xff1a; https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA&#xff0c;它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构&#xf…...

在Ubuntu中设置开机自动运行(sudo)指令的指南

在Ubuntu系统中&#xff0c;有时需要在系统启动时自动执行某些命令&#xff0c;特别是需要 sudo权限的指令。为了实现这一功能&#xff0c;可以使用多种方法&#xff0c;包括编写Systemd服务、配置 rc.local文件或使用 cron任务计划。本文将详细介绍这些方法&#xff0c;并提供…...

PL0语法,分析器实现!

简介 PL/0 是一种简单的编程语言,通常用于教学编译原理。它的语法结构清晰,功能包括常量定义、变量声明、过程(子程序)定义以及基本的控制结构(如条件语句和循环语句)。 PL/0 语法规范 PL/0 是一种教学用的小型编程语言,由 Niklaus Wirth 设计,用于展示编译原理的核…...

【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)

升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点&#xff0c;但无自动故障转移能力&#xff0c;Master宕机后需人工切换&#xff0c;期间消息可能无法读取。Slave仅存储数据&#xff0c;无法主动升级为Master响应请求&#xff…...

NLP学习路线图(二十三):长短期记忆网络(LSTM)

在自然语言处理(NLP)领域,我们时刻面临着处理序列数据的核心挑战。无论是理解句子的结构、分析文本的情感,还是实现语言的翻译,都需要模型能够捕捉词语之间依时序产生的复杂依赖关系。传统的神经网络结构在处理这种序列依赖时显得力不从心,而循环神经网络(RNN) 曾被视为…...

今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存

文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...