第T7周:Tensorflow实现咖啡豆识别
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目标
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框 架:
(二)具体步骤
1. 使用GPU
--------------------------utils.py-------------------
import tensorflow as tf
import PIL
import matplotlib.pyplot as plt def GPU_ON(): # 查询tensorflow版本 print("Tensorflow Version:", tf.__version__) # 设置使用GPU gpus = tf.config.list_physical_devices("GPU") print(gpus) if gpus: gpu0 = gpus[0] # 如果有多个GPU,仅使用第0个GPU tf.config.experimental.set_memory_growth(gpu0, True) # 设置GPU显存按需使用 tf.config.set_visible_devices([gpu0], "GPU")
使用GPU并查看数据
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import os, PIL, pathlib
from utils import GPU_ONGPU_ON() data_dir = "./datasets/coffee/"
data_dir = pathlib.Path(data_dir) image_count = len(list(data_dir.glob("*/*.png")))
print("图片总数量为:", image_count)
------------------
图片总数量为: 1200
2. 加载数据
# 加载数据
batch_size = 32
img_height, img_width = 224, 224 train_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, validation_split=0.2, subset="training", seed=123, image_size=(img_height, img_width), batch_size=batch_size,
) val_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, validation_split=0.2, subset="validation", seed=123, image_size=(img_height, img_width), batch_size=batch_size,
)
--------------------
Found 1200 files belonging to 4 classes.
Using 960 files for training.
Found 1200 files belonging to 4 classes.
Using 240 files for validation.
获取标签:
# 获取标签
class_names = train_ds.class_names
print(class_names)
------------------
['Dark', 'Green', 'Light', 'Medium']
可视化数据:
# 可视化数据
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(2): for i in range(30): ax = plt.subplot(5, 6, i+1) plt.imshow(images[i].numpy().astype("uint8")) plt.title(class_names[labels[i]]) plt.axis("off")
plt.show()
检查一下数据:
# 检查一下数据
for image_batch, labels_batch in train_ds: print(image_batch.shape) print(labels_batch.shape) break
----------------------------
(32, 224, 224, 3)
(32,)
**3.**配置数据集
# 配置数据集
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE) normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) image_batch, labels_batch = next(iter(train_ds))
first_image = image_batch[0] # 查看归一化后的数据
print(np.min(first_image), np.max(first_image))
--------------------
0.0 1.0
4.搭建VGG-16网络
本次准备直接调用官方模型
# 搭建VGG-16网络模型
model = tf.keras.applications.VGG16(weights="imagenet")
print(model.summary())
-------------------------------
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5
553467096/553467096 [==============================] - 14s 0us/step
Model: "vgg16"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================input_1 (InputLayer) [(None, 224, 224, 3)] 0 block1_conv1 (Conv2D) (None, 224, 224, 64) 1792 block1_conv2 (Conv2D) (None, 224, 224, 64) 36928 block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 block2_conv1 (Conv2D) (None, 112, 112, 128) 73856 block2_conv2 (Conv2D) (None, 112, 112, 128) 147584 block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 block3_conv1 (Conv2D) (None, 56, 56, 256) 295168 block3_conv2 (Conv2D) (None, 56, 56, 256) 590080 block3_conv3 (Conv2D) (None, 56, 56, 256) 590080 block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160 block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808 block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808 block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808 block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808 block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808 block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 flatten (Flatten) (None, 25088) 0 fc1 (Dense) (None, 4096) 102764544 fc2 (Dense) (None, 4096) 16781312 predictions (Dense) (None, 1000) 4097000 =================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________
简简单单1亿的参数的模型。哈哈。
编译一下:
# 编译模型
# 设置初始学习率
initial_learning_rate = 1e-4
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=initial_learning_rate, decay_steps=30, decay_rate=0.92, staircase=True
) # 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate) model.compile( optimizer=opt, loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['accuracy']
)
训练模型:
# 训练模型
epochs = 20
history = model.fit( train_ds, validation_data=val_ds, epochs=epochs,
)
训练效果不错,可视化看看:
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
果然超赞。
改成动态学习率的结果:
opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
5. 手动搭建VGG-16模型
VGG-16的网络 有13个卷积层(被5个max-pooling层分割)和3个全连接层(FC),所有卷积层过滤器的大小都是3X3,步长为1,进行padding。5个max-pooling层分别在第2、4、7、10,13卷积层后面。每次进行池化(max-pooling)后,特征图的长宽都缩小一半,但是channel都翻倍了,一直到512。最后三个全连接层大小分别是4096,4096, 1000,我们使用的是咖啡豆识别,根据数据集的类别数量修改最后的分类数量(即从1000改成len(class_names))
-----------------------------
Model: "model"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================input_1 (InputLayer) [(None, 224, 224, 3)] 0 block1_conv1 (Conv2D) (None, 224, 224, 64) 1792 block1_conv2 (Conv2D) (None, 224, 224, 64) 36928 block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 block2_conv1 (Conv2D) (None, 112, 112, 128) 73856 block2_conv2 (Conv2D) (None, 112, 112, 128) 147584 block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 block3_conv1 (Conv2D) (None, 56, 56, 256) 295168 block3_conv2 (Conv2D) (None, 56, 56, 256) 590080 block3_conv3 (Conv2D) (None, 56, 56, 256) 590080 block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160 block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808 block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808 block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808 block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808 block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808 block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 flatten (Flatten) (None, 25088) 0 fc1 (Dense) (None, 4096) 102764544 fc2 (Dense) (None, 4096) 16781312 predictions (Dense) (None, 4) 16388 =================================================================
Total params: 134,276,932
Trainable params: 134,276,932
Non-trainable params: 0
_________________________________________________________________
(三)总结
相关文章:

第T7周:Tensorflow实现咖啡豆识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: PyCharm 框 架: (二)具体步骤 1. 使…...

imagehash图片去重:保留图片文件名
简介 在日常工作中,我们可能需要管理大量图片,这些图片中可能存在图像相同文件名不同的情况。手动删除这些重复的图片既费时又费力。为了解决这个问题,我们可以编写一个Python脚本来自动化这个过程。 准备工作 在开始之前,请确保…...

在Docker环境下为Nginx配置HTTPS
前言 配置HTTPS已经成为网站部署的必要步骤。本教程将详细介绍如何在Docker环境下为Nginx配置HTTPS,使用自签名证书来实现加密通信。虽然在生产环境中建议使用权威CA机构颁发的证书,但在开发测试或内网环境中,自签名证书是一个很好的选择。 …...

vue面试题9|[2024-11-15]
问题1:scoped原理 1.作用:让样式在本组件中生效,不影响其他组件 2.原理:给节点新增自定义属性,然后css根据属性选择器添加样式。 问题2:让css只在当前组件生效 <style scoped> 问题3:scss…...

大数据技术在金融风控中的应用
💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 大数据技术在金融风控中的应用 大数据技术在金融风控中的应用 大数据技术在金融风控中的应用 引言 大数据技术概述 定义与原理 发…...

安装一键式重置密码插件(Linux)-CloudResetPwdAgent
为了保证使用镜像创建的裸金属服务器可以实现一键式密码重置功能,建议您在制作镜像时安装重置密码插件“CloudResetPwdAgent”。 前提条件 需保证虚拟机根目录可写入,且剩余空间大于600MB。 1.下载插件包 华为云已提供下载包连接 在PC机里下载好软件…...

如何平滑切换Containerd数据目录
如何平滑切换Containerd数据目录 大家好,我是秋意零。 这是工作中遇到的一个问题。搭建的服务平台,在使用的过程中频繁出现镜像本地拉取不到问题(在项目群聊中老是被人出来😅)原因是由于/目录空间不足导致࿰…...

月影和米家大路灯哪个好?书客、月影、米家谁会更胜一筹!
月影和米家大路灯哪个好?近两年以来,护眼大路灯以良好的品质走进大众的视线,成为许多用眼人群的刚需品,不少用户说可以改善光线质量,视觉疲劳感夜可以减少,但又有人说护眼大路灯是“幌子、智商税”…...

instanceof 的模式匹配(二)
在经过了JEP305(jdk14)和JEP375(jdk15)的两轮预览之后,模式匹配终于迎来了他的交付日期,在2022年发布的JDK16中,伴随着JEP 394的发布,预览结束了,我们来看一下这个特性的结束点到底说了什么。 在这次预览之中ÿ…...

【Spring】Bean的作用域和Spring的执行流程
目录 1.Bean的作用域 1.1 Singleton(单例) 1.2 Prototype(原型) 1.3 适用于SpringMVC的作用域 2.Spring的执行流程 2.1 Spring容器的初始化 2.2 Bean的创建和装配 2.3 Bean的生命周期管理 2.4 其他重要概念 3. Spring的执行流程简洁版 1.Bean的作用域 Spring Bean的…...

自动驾驶系列—从数据采集到存储:解密自动驾驶传感器数据采集盒子的关键技术
🌟🌟 欢迎来到我的技术小筑,一个专为技术探索者打造的交流空间。在这里,我们不仅分享代码的智慧,还探讨技术的深度与广度。无论您是资深开发者还是技术新手,这里都有一片属于您的天空。让我们在知识的海洋中…...

QtWebServer
QtWebServer 是创建基于 Qt 的高性能 Web 应用程序服务器的尝试,即。运行本机 C/Qt 代码以交付网站的 Web 服务器。 一个完美的用例是为较小的服务提供 REST API。 在 Qt 应用程序中,您可以设置资源并将其绑定到物理提供程序,例如文件或数据…...

网络基础概念与应用:深入理解计算机网络
引言 计算机网络作为现代信息技术的重要支柱,是连接世界各地的重要纽带。它使得计算机能够相互通信、协同工作,从而极大地提高了我们的工作效率和生活质量。本篇文章将深入探讨计算机网络的基础概念,覆盖网络的分层模型、协议、数据传输原理…...

<el-select> :remote-method用法
el-select :remote-method用法 说明代码实现单选多选 说明 在 Vue.js 中, 是 Element UI 库提供的一个下拉选择框组件。:remote-method 是 组件的一个属性,用于指定一个远程方法,该方法将在用户输入时被调用,以获取下拉列表的选项…...

CKA认证 | Day3 K8s管理应用生命周期(上)
第四章 应用程序生命周期管理(上) 1、在Kubernetes中部署应用流程 1.1 使用Deployment部署Java应用 在 Kubernetes 中,Deployment 是一种控制器,用于管理 Pod 的部署和更新。以下是使用 Deployment 部署 Java 应用的步骤&#x…...

JavaWeb——HTML、CSS
目录 1.概述 2.HTML a.HTML结构标签 b.图片标签 c.标题标签 d.水平线标签 e.布局标签 f.超链接标签 e.视频标签 f.音频标签 e.换行标签 f.段落标签 g.加粗标签 h.表格 1.声明表格 2.表行 3.普通表格 4.加粗表格 i.表单标签 1.声明表单 2. 表单 3.下拉列表…...

springboot如何获取控制层get和Post入参
一、在 Spring 配置中创建一个过滤器,将 HttpServletRequest 包装为 ContentCachingRequestWrapper import org.springframework.stereotype.Component; import org.springframework.web.filter.OncePerRequestFilter; import javax.servlet.FilterChain; import j…...

30 秒!用通义灵码画 SpaceX 星链发射流程图
不想读前人“骨灰级”代码, 不想当“牛马”程序员, 想像看图片一样快速读复杂代码和架构? 来了,灵码又加新 buff!! 通义灵码支持代码逻辑可视化, 可以把你的每段代码画成流程图。 你可以把…...

设计模式之组合模式(营销差异化人群发券,决策树引擎搭建场景)
前言: 往往很多大需求都是通过增删改查堆出来的,今天要一个需求if一下,明天加个内容else扩展一下。日积月累需求也就越来越大,扩展和维护的成本也就越来越高。往往大部分研发是不具备产品思维和整体业务需求导向的,总以…...

关于做完 C# 项目的问题总结
1. .Any()方法使用 可以与其他LINQ方法结合使用,以构建更复杂的查询。例如,你可以首先过滤集合,然后检查过滤后的集合是否包含任何元素: List<string> fruits new List<string> { "Apple", "Banana&q…...

CSS响应式布局实现1920屏幕1rem等于100px
代码解析与实现 设置根元素的 font-size 为 5.208333vw 假设你想让根元素的 font-size 基于视口宽度来动态调整。我们可以通过设置 font-size 为 5.208333vw 来让 1rem 相当于视口宽度的 5.208333%。 计算 5.208333vw: 当屏幕宽度为 1920px 时,5.208333vw 等于 5…...

根据当前浏览器版本,下载或更新驱动文件为对应的版本
以前通过ChromeDriverManager().install()的方式自动下载驱动的方式,现在行不通了,访问不通下载网址,会报错:requests.exceptions.ConnectionError: Could not reach host. Are you offline? 所以想着换一个下载地址和方式&…...

【轻量化】YOLOv10 更换骨干网络之 MobileNetv4 | 模块化加法!非 timm 包!
之前咱们在这个文章中讲了timm包的加法,不少同学反馈要模块化的加法,那么这篇就讲解下模块化的加法,值得注意的是,这样改加载不了mobilebnetv4官方开源的权重了~ 论文地址:https://arxiv.org/pdf/2404.10518 代码地址:https://github.com/tensorflow/models/blob/master…...

人体存在感应器设置时间开启感应人存在开灯,失效
环境: 领普人体存在感应器 问题描述: 人体存在感应器设置时间开启感应人存在开灯,失效,设置下午5点,如果有人在5点前一直在这个区域,这个时候到了5点,就触发不了感应自动打开灯光。 解决方案:…...

2024年09月CCF-GESP编程能力等级认证Python编程二级真题解析
本文收录于专栏《Python等级认证CCF-GESP真题解析》,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 一、单选题(每题 2 分,共 30 分) 第 1 题 据有关资料,山东大学于1972年研制成功DJL-1计算机,并于1973年投入运行,其综合性能居当时全国第三位。DJL-1计算机运算控…...

Vuex vs Pinia:新一代Vue状态管理方案对比
引言 随着Vue生态系统的不断发展,状态管理已经成为现代Vue应用程序中不可或缺的一部分。Vuex作为Vue官方的状态管理方案,一直是开发者的首选。然而,随着Pinia的出现,为Vue开发者带来了新的选择。本文将深入对比这两个状态管理方案…...

es查询报错:too_many_buckets_exception
故障排除 es查询报错:too_many_buckets_exception {"error":{"root_cause":[],"type":"search_phase_execution_exception","reason":"","phase":"fetch","grouped":…...

outlook邮箱关闭垃圾邮件——PowerAutomate自动化任务
微软邮箱反垃圾已经很强大了非常敏感,自家的域名的邮件都能给扔到垃圾邮箱里,但还是在本地增加了一层垃圾邮箱功能,然后垃圾邮箱并没有提示,导致错过很多通知,本身并没有提供关闭的功能,但微软有个Microsof…...

机器学习(七)——集成学习(个体与集成、Boosting、Bagging、随机森林RF、结合策略、多样性增强、多样性度量、Python源码)
目录 关于1 个体与集成2 Boosting3 Bagging与随机森林4 结合策略5 多样性X 案例代码X.1 分类任务-Adaboost-SVMX.1.1 源码X.1.2 数据集(鸢尾花数据集)X.1.3 模型效果 X.2 分类任务-随机森林RFX.2.1 源码X.2.2 数据集(鸢尾花数据集)…...

vue跳转传参
path 跳转只能使用 query 传参 ,name 跳转都可以 params :获取来自动态路由的参数 query :获取来自 search 部分的参数...