第T7周:Tensorflow实现咖啡豆识别
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
目标
具体实现
(一)环境
语言环境:Python 3.10
编 译 器: PyCharm
框 架:
(二)具体步骤
1. 使用GPU
--------------------------utils.py-------------------
import tensorflow as tf
import PIL
import matplotlib.pyplot as plt def GPU_ON(): # 查询tensorflow版本 print("Tensorflow Version:", tf.__version__) # 设置使用GPU gpus = tf.config.list_physical_devices("GPU") print(gpus) if gpus: gpu0 = gpus[0] # 如果有多个GPU,仅使用第0个GPU tf.config.experimental.set_memory_growth(gpu0, True) # 设置GPU显存按需使用 tf.config.set_visible_devices([gpu0], "GPU")
使用GPU并查看数据
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import os, PIL, pathlib
from utils import GPU_ONGPU_ON() data_dir = "./datasets/coffee/"
data_dir = pathlib.Path(data_dir) image_count = len(list(data_dir.glob("*/*.png")))
print("图片总数量为:", image_count)
------------------
图片总数量为: 1200
2. 加载数据
# 加载数据
batch_size = 32
img_height, img_width = 224, 224 train_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, validation_split=0.2, subset="training", seed=123, image_size=(img_height, img_width), batch_size=batch_size,
) val_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, validation_split=0.2, subset="validation", seed=123, image_size=(img_height, img_width), batch_size=batch_size,
)
--------------------
Found 1200 files belonging to 4 classes.
Using 960 files for training.
Found 1200 files belonging to 4 classes.
Using 240 files for validation.
获取标签:
# 获取标签
class_names = train_ds.class_names
print(class_names)
------------------
['Dark', 'Green', 'Light', 'Medium']
可视化数据:
# 可视化数据
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(2): for i in range(30): ax = plt.subplot(5, 6, i+1) plt.imshow(images[i].numpy().astype("uint8")) plt.title(class_names[labels[i]]) plt.axis("off")
plt.show()

检查一下数据:
# 检查一下数据
for image_batch, labels_batch in train_ds: print(image_batch.shape) print(labels_batch.shape) break
----------------------------
(32, 224, 224, 3)
(32,)
**3.**配置数据集
# 配置数据集
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE) normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) image_batch, labels_batch = next(iter(train_ds))
first_image = image_batch[0] # 查看归一化后的数据
print(np.min(first_image), np.max(first_image))
--------------------
0.0 1.0
4.搭建VGG-16网络
本次准备直接调用官方模型
# 搭建VGG-16网络模型
model = tf.keras.applications.VGG16(weights="imagenet")
print(model.summary())
-------------------------------
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5
553467096/553467096 [==============================] - 14s 0us/step
Model: "vgg16"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================input_1 (InputLayer) [(None, 224, 224, 3)] 0 block1_conv1 (Conv2D) (None, 224, 224, 64) 1792 block1_conv2 (Conv2D) (None, 224, 224, 64) 36928 block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 block2_conv1 (Conv2D) (None, 112, 112, 128) 73856 block2_conv2 (Conv2D) (None, 112, 112, 128) 147584 block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 block3_conv1 (Conv2D) (None, 56, 56, 256) 295168 block3_conv2 (Conv2D) (None, 56, 56, 256) 590080 block3_conv3 (Conv2D) (None, 56, 56, 256) 590080 block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160 block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808 block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808 block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808 block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808 block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808 block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 flatten (Flatten) (None, 25088) 0 fc1 (Dense) (None, 4096) 102764544 fc2 (Dense) (None, 4096) 16781312 predictions (Dense) (None, 1000) 4097000 =================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
_________________________________________________________________
简简单单1亿的参数的模型。哈哈。
编译一下:
# 编译模型
# 设置初始学习率
initial_learning_rate = 1e-4
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( initial_learning_rate=initial_learning_rate, decay_steps=30, decay_rate=0.92, staircase=True
) # 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate) model.compile( optimizer=opt, loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['accuracy']
)
训练模型:
# 训练模型
epochs = 20
history = model.fit( train_ds, validation_data=val_ds, epochs=epochs,
)

训练效果不错,可视化看看:
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

果然超赞。
改成动态学习率的结果:
opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

5. 手动搭建VGG-16模型



VGG-16的网络 有13个卷积层(被5个max-pooling层分割)和3个全连接层(FC),所有卷积层过滤器的大小都是3X3,步长为1,进行padding。5个max-pooling层分别在第2、4、7、10,13卷积层后面。每次进行池化(max-pooling)后,特征图的长宽都缩小一半,但是channel都翻倍了,一直到512。最后三个全连接层大小分别是4096,4096, 1000,我们使用的是咖啡豆识别,根据数据集的类别数量修改最后的分类数量(即从1000改成len(class_names))
-----------------------------
Model: "model"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================input_1 (InputLayer) [(None, 224, 224, 3)] 0 block1_conv1 (Conv2D) (None, 224, 224, 64) 1792 block1_conv2 (Conv2D) (None, 224, 224, 64) 36928 block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 block2_conv1 (Conv2D) (None, 112, 112, 128) 73856 block2_conv2 (Conv2D) (None, 112, 112, 128) 147584 block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 block3_conv1 (Conv2D) (None, 56, 56, 256) 295168 block3_conv2 (Conv2D) (None, 56, 56, 256) 590080 block3_conv3 (Conv2D) (None, 56, 56, 256) 590080 block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160 block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808 block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808 block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808 block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808 block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808 block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 flatten (Flatten) (None, 25088) 0 fc1 (Dense) (None, 4096) 102764544 fc2 (Dense) (None, 4096) 16781312 predictions (Dense) (None, 4) 16388 =================================================================
Total params: 134,276,932
Trainable params: 134,276,932
Non-trainable params: 0
_________________________________________________________________
(三)总结
相关文章:
第T7周:Tensorflow实现咖啡豆识别
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 目标 具体实现 (一)环境 语言环境:Python 3.10 编 译 器: PyCharm 框 架: (二)具体步骤 1. 使…...
imagehash图片去重:保留图片文件名
简介 在日常工作中,我们可能需要管理大量图片,这些图片中可能存在图像相同文件名不同的情况。手动删除这些重复的图片既费时又费力。为了解决这个问题,我们可以编写一个Python脚本来自动化这个过程。 准备工作 在开始之前,请确保…...
在Docker环境下为Nginx配置HTTPS
前言 配置HTTPS已经成为网站部署的必要步骤。本教程将详细介绍如何在Docker环境下为Nginx配置HTTPS,使用自签名证书来实现加密通信。虽然在生产环境中建议使用权威CA机构颁发的证书,但在开发测试或内网环境中,自签名证书是一个很好的选择。 …...
vue面试题9|[2024-11-15]
问题1:scoped原理 1.作用:让样式在本组件中生效,不影响其他组件 2.原理:给节点新增自定义属性,然后css根据属性选择器添加样式。 问题2:让css只在当前组件生效 <style scoped> 问题3:scss…...
大数据技术在金融风控中的应用
💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 大数据技术在金融风控中的应用 大数据技术在金融风控中的应用 大数据技术在金融风控中的应用 引言 大数据技术概述 定义与原理 发…...
安装一键式重置密码插件(Linux)-CloudResetPwdAgent
为了保证使用镜像创建的裸金属服务器可以实现一键式密码重置功能,建议您在制作镜像时安装重置密码插件“CloudResetPwdAgent”。 前提条件 需保证虚拟机根目录可写入,且剩余空间大于600MB。 1.下载插件包 华为云已提供下载包连接 在PC机里下载好软件…...
如何平滑切换Containerd数据目录
如何平滑切换Containerd数据目录 大家好,我是秋意零。 这是工作中遇到的一个问题。搭建的服务平台,在使用的过程中频繁出现镜像本地拉取不到问题(在项目群聊中老是被人出来😅)原因是由于/目录空间不足导致࿰…...
月影和米家大路灯哪个好?书客、月影、米家谁会更胜一筹!
月影和米家大路灯哪个好?近两年以来,护眼大路灯以良好的品质走进大众的视线,成为许多用眼人群的刚需品,不少用户说可以改善光线质量,视觉疲劳感夜可以减少,但又有人说护眼大路灯是“幌子、智商税”…...
instanceof 的模式匹配(二)
在经过了JEP305(jdk14)和JEP375(jdk15)的两轮预览之后,模式匹配终于迎来了他的交付日期,在2022年发布的JDK16中,伴随着JEP 394的发布,预览结束了,我们来看一下这个特性的结束点到底说了什么。 在这次预览之中ÿ…...
【Spring】Bean的作用域和Spring的执行流程
目录 1.Bean的作用域 1.1 Singleton(单例) 1.2 Prototype(原型) 1.3 适用于SpringMVC的作用域 2.Spring的执行流程 2.1 Spring容器的初始化 2.2 Bean的创建和装配 2.3 Bean的生命周期管理 2.4 其他重要概念 3. Spring的执行流程简洁版 1.Bean的作用域 Spring Bean的…...
自动驾驶系列—从数据采集到存储:解密自动驾驶传感器数据采集盒子的关键技术
🌟🌟 欢迎来到我的技术小筑,一个专为技术探索者打造的交流空间。在这里,我们不仅分享代码的智慧,还探讨技术的深度与广度。无论您是资深开发者还是技术新手,这里都有一片属于您的天空。让我们在知识的海洋中…...
QtWebServer
QtWebServer 是创建基于 Qt 的高性能 Web 应用程序服务器的尝试,即。运行本机 C/Qt 代码以交付网站的 Web 服务器。 一个完美的用例是为较小的服务提供 REST API。 在 Qt 应用程序中,您可以设置资源并将其绑定到物理提供程序,例如文件或数据…...
网络基础概念与应用:深入理解计算机网络
引言 计算机网络作为现代信息技术的重要支柱,是连接世界各地的重要纽带。它使得计算机能够相互通信、协同工作,从而极大地提高了我们的工作效率和生活质量。本篇文章将深入探讨计算机网络的基础概念,覆盖网络的分层模型、协议、数据传输原理…...
<el-select> :remote-method用法
el-select :remote-method用法 说明代码实现单选多选 说明 在 Vue.js 中, 是 Element UI 库提供的一个下拉选择框组件。:remote-method 是 组件的一个属性,用于指定一个远程方法,该方法将在用户输入时被调用,以获取下拉列表的选项…...
CKA认证 | Day3 K8s管理应用生命周期(上)
第四章 应用程序生命周期管理(上) 1、在Kubernetes中部署应用流程 1.1 使用Deployment部署Java应用 在 Kubernetes 中,Deployment 是一种控制器,用于管理 Pod 的部署和更新。以下是使用 Deployment 部署 Java 应用的步骤&#x…...
JavaWeb——HTML、CSS
目录 1.概述 2.HTML a.HTML结构标签 b.图片标签 c.标题标签 d.水平线标签 e.布局标签 f.超链接标签 e.视频标签 f.音频标签 e.换行标签 f.段落标签 g.加粗标签 h.表格 1.声明表格 2.表行 3.普通表格 4.加粗表格 i.表单标签 1.声明表单 2. 表单 3.下拉列表…...
springboot如何获取控制层get和Post入参
一、在 Spring 配置中创建一个过滤器,将 HttpServletRequest 包装为 ContentCachingRequestWrapper import org.springframework.stereotype.Component; import org.springframework.web.filter.OncePerRequestFilter; import javax.servlet.FilterChain; import j…...
30 秒!用通义灵码画 SpaceX 星链发射流程图
不想读前人“骨灰级”代码, 不想当“牛马”程序员, 想像看图片一样快速读复杂代码和架构? 来了,灵码又加新 buff!! 通义灵码支持代码逻辑可视化, 可以把你的每段代码画成流程图。 你可以把…...
设计模式之组合模式(营销差异化人群发券,决策树引擎搭建场景)
前言: 往往很多大需求都是通过增删改查堆出来的,今天要一个需求if一下,明天加个内容else扩展一下。日积月累需求也就越来越大,扩展和维护的成本也就越来越高。往往大部分研发是不具备产品思维和整体业务需求导向的,总以…...
关于做完 C# 项目的问题总结
1. .Any()方法使用 可以与其他LINQ方法结合使用,以构建更复杂的查询。例如,你可以首先过滤集合,然后检查过滤后的集合是否包含任何元素: List<string> fruits new List<string> { "Apple", "Banana&q…...
从WWDC看苹果产品发展的规律
WWDC 是苹果公司一年一度面向全球开发者的盛会,其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具,对过去十年 WWDC 主题演讲内容进行了系统化分析,形成了这份…...
在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:
在 HarmonyOS 应用开发中,手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力,既支持点击、长按、拖拽等基础单一手势的精细控制,也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档,…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...
转转集团旗下首家二手多品类循环仓店“超级转转”开业
6月9日,国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解,“超级…...
Robots.txt 文件
什么是robots.txt? robots.txt 是一个位于网站根目录下的文本文件(如:https://example.com/robots.txt),它用于指导网络爬虫(如搜索引擎的蜘蛛程序)如何抓取该网站的内容。这个文件遵循 Robots…...
什么是EULA和DPA
文章目录 EULA(End User License Agreement)DPA(Data Protection Agreement)一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA(End User License Agreement) 定义: EULA即…...
Device Mapper 机制
Device Mapper 机制详解 Device Mapper(简称 DM)是 Linux 内核中的一套通用块设备映射框架,为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程,并配以详细的…...
C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...
【Java学习笔记】BigInteger 和 BigDecimal 类
BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点:传参类型必须是类对象 一、BigInteger 1. 作用:适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...
【Redis】笔记|第8节|大厂高并发缓存架构实战与优化
缓存架构 代码结构 代码详情 功能点: 多级缓存,先查本地缓存,再查Redis,最后才查数据库热点数据重建逻辑使用分布式锁,二次查询更新缓存采用读写锁提升性能采用Redis的发布订阅机制通知所有实例更新本地缓存适用读多…...
