当前位置: 首页 > news >正文

深度学习笔记11-优化器对比实验(Tensorflow)

  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

目录

一、导入数据并检查

二、配置数据集

三、数据可视化

四、构建模型

五、训练模型

六、模型对比评估

七、总结


一、导入数据并检查

import pathlib,PIL
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签data_dir    = pathlib.Path("./T6")
image_count = len(list(data_dir.glob('*/*')))
batch_size = 16
img_height = 336
img_width  = 336
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break

二、配置数据集

AUTOTUNE = tf.data.AUTOTUNE
#归一化处理
def train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

三、数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

四、构建模型

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Modeldef create_model(optimizer='adam'):# 加载预训练模型vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',include_top=False,#不包含顶层的全连接层input_shape=(img_width, img_height, 3),pooling='avg')#平均池化层替代顶层的全连接层for layer in vgg16_base_model.layers:layer.trainable = False  #将 trainable属性设置为 False 意味着在训练过程中,这些层的权重不会更新X = vgg16_base_model.outputX = Dense(170, activation='relu')(X)X = BatchNormalization()(X)X = Dropout(0.5)(X)output = Dense(len(class_names), activation='softmax')(X)#神经元数量等于类别数vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)vgg16_model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['accuracy'])return vgg16_modelmodel1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())#随机梯度下降(SGD)优化器的
model2.summary()

五、训练模型

NO_EPOCHS = 20history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)

六、模型对比评估

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']epochs_range = range(len(acc1))plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))plt.show()

可以看出,在这个实例中,Adam优化器的效果优于SGD优化器

七、总结

      通过本次实验,学会了比较不同优化器(Adam和SGD)在训练过程中的性能表现,可视化训练过程的损失曲线和准确率等指标。这是一项非常重要的技能,在研究论文中,可以通过这些优化方法可以提高工作量。

相关文章:

深度学习笔记11-优化器对比实验(Tensorflow)

🍨 本文为🔗365天深度学习训练营中的学习记录博客🍖 原作者:K同学啊 目录 一、导入数据并检查 二、配置数据集 三、数据可视化 四、构建模型 五、训练模型 六、模型对比评估 七、总结 一、导入数据并检查 import pathlib,…...

【掌握 JavaScript 数组迭代:map 和 includes 的使用技巧】

map map()方法是数组原型的一个函数,用于对数组的每个元素执行一个函数,并返回一个新的数组,其中包含么哦一个元素执行的结果。 语法 const newArray array.map(callback(currentValue, index, arr), thisValue)参数 callback&#xff1…...

深入浅出 Android AES 加密解密:从理论到实战

深入浅出 Android AES 加密解密:从理论到实战 在现代移动应用中,数据安全是不可忽视的一环。无论是用户隐私保护,还是敏感信息的存储与传输,加密技术都扮演着重要角色。本文将以 AES(Advanced Encryption Standard&am…...

Clickhouse基础(一)

数据存储的目录,在存储数据时是先经过压缩后再存储的,压缩效率很高 操作命令: sudo clickhouse start sudo clickhouse restart sudo clickhouse status进入clickhouse clickhouse-client -mCREATE TABLE db_13.t_assist (modelId UInt64,…...

深度学习|表示学习|一个神经元可以干什么|02

如是我闻: 如果我们只有一个神经元(即一个单一的线性或非线性函数),仍然可以完成一些简单的任务。以下是一个神经元可以实现的功能和应用: 1. 实现简单的线性分类 输入:一组特征向量 x x x 输出&#xff…...

ubuntu22.04降级安装CUDA11.3

环境:主机x64的ubuntu22.04,原有CUDA12.1,但是现在需要CUDA11.3,本篇文章介绍步骤。 一、下载CUDA11.3的run文件 下载网址:https://developer.nvidia.com/cuda-11-3-1-download-archive?target_osLinux&target_…...

为AI聊天工具添加一个知识系统 之32 三“中”全“会”:推理式的ISA(父类)和IOS(母本)以及生成式CMN (双亲委派)之1

本文要点和问题 要点 三“中”全“会”:推理式的ISA的(父类-父类源码)和IOS的(母本-母类脚本)以及生成式 CMN (双亲委派-子类实例)。 数据中台三端架构的中间端(信息系统架构ISA &#xff1a…...

Python----Python高级(函数基础,形参和实参,参数传递,全局变量和局部变量,匿名函数,递归函数,eval()函数,LEGB规则)

一、函数基础 1.1、函数的用法和底层分析 函数是可重用的程序代码块。 函数的作用,不仅可以实现代码的复用,更能实现代码的一致性。一致性指的是,只要修改函数的代码,则所有调用该函数的地方都能得到体现。 在编写函数时&#xf…...

spring解决循环依赖的通俗理解

目录标题 1、什么是循环依赖2、解决循环依赖的原理3、Spring通过三级缓存解决循环依赖4、为什么要使用三级缓存而不是二级缓存?5、三级缓存中存放的是lambda表达式而不是一个半成品对象 1、什么是循环依赖 众所周知,Spring的容器中管理整个体系的bean对…...

用 Python 从零开始创建神经网络(十九):真实数据集

真实数据集 引言数据准备数据加载数据预处理数据洗牌批次(Batches)训练(Training)到目前为止的全部代码: 引言 在实践中,深度学习通常涉及庞大的数据集(通常以TB甚至更多为单位)&am…...

介绍PyTorch张量

介绍PyTorch张量 介绍PyTorch张量 PyTorch张量是我们在PyTorch中编程神经网络时将使用的数据结构。 在编程神经网络时,数据预处理通常是整个过程的第一步,数据预处理的一个目标是将原始输入数据转换为张量形式。 torch.Tensor​类的实例 PyTorch张量…...

Vision Transformer (ViT)原理

Vision Transformer (ViT)原理 flyfish Transformer缺乏卷积神经网络(CNNs)的归纳偏差(inductive biases),比如平移不变性和局部受限的感受野。不变性意味着即使实体entity(即对象)的外观或位…...

移动云自研云原生数据库入围国采!

近日,中央国家机关2024年度事务型数据库软件框架协议联合征集采购项目产品名单正式公布,移动云自主研发的云原生数据库产品顺利入围。这一成就不仅彰显了移动云在数据库领域深耕多年造就的领先技术优势,更标志着国家权威评审机构对移动云在数…...

Unity中对象池的使用(用一个简单粗暴的例子)

问题描述:Unity在创建和销毁对象的时候是很消耗性能的,所以我们在销毁一个对象的时候,可以不用Destroy,而是将这个物体隐藏后放到回收池里面,当再次需要的时候如果回收池里面有之前回收的对象,就直接拿来用…...

linux命令行连接Postgresql常用命令

1.linux系统命令行连接数据库命令 psql -h hostname -p port -U username -d databasename -h 主机名或IP地址 -p 端口 -U 用户名 -d 连接的数据库 2.查询数据库表命令 select version() #查看版本号 \dg #查看用户 \l #查询数据库 \c mydb #切换…...

每日一题-单链表排序

为了对给定的单链表按升序排序,我们可以考虑以下解决方法: 思路 归并排序(Merge Sort):由于归并排序的时间复杂度为 O ( n log ⁡ n ) O(n \log n) O(nlogn),并且归并排序不需要额外的空间(空…...

webpack04服务器配置

webpack配置 entryoutput filenamepathpublicPath 。。 打包引入的基本路径,,,比如引入一个bundle.js,。引用之后的路径就是 publicPathfilename -devServer:static : 静态文件的位置。。。hostportopencompress : 静态资源是否用gzip压缩hi…...

JDK下载安装配置

一.JDK安装配置。 1.安装注意路径,其他直接下一步。 2.配置。 下接第4步. 或者 代码复制: JAVA_HOME D:\Program Files\Java\jdk1.8.0_91 %JAVA_HOME%\bin 或者直接配置 D:\Program Files\Java\jdk1.8.0_91\bin 3.验证(CMD)。 java javac java -version javac -version 二.下…...

30_Redis哨兵模式

在Redis主从复制模式中,因为系统不具备自动恢复的功能,所以当主服务器(master)宕机后,需要手动把一台从服务器(slave)切换为主服务器。在这个过程中,不仅需要人为干预,而且还会造成一段时间内服务器处于不可用状态,同时数据安全性也得不到保障,因此主从模式的可用性…...

NLP三大特征抽取器:CNN、RNN与Transformer全面解析

引言 自然语言处理(NLP)领域的快速发展离不开深度学习技术的推动。随着应用需求的不断增加,如何高效地从文本中抽取特征成为NLP研究中的核心问题。深度学习中三大主要特征抽取器——卷积神经网络(Convolutional Neural Network, …...

React 第五十五节 Router 中 useAsyncError的使用详解

前言 useAsyncError 是 React Router v6.4 引入的一个钩子,用于处理异步操作(如数据加载)中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误:捕获在 loader 或 action 中发生的异步错误替…...

selenium学习实战【Python爬虫】

selenium学习实战【Python爬虫】 文章目录 selenium学习实战【Python爬虫】一、声明二、学习目标三、安装依赖3.1 安装selenium库3.2 安装浏览器驱动3.2.1 查看Edge版本3.2.2 驱动安装 四、代码讲解4.1 配置浏览器4.2 加载更多4.3 寻找内容4.4 完整代码 五、报告文件爬取5.1 提…...

【C++进阶篇】智能指针

C内存管理终极指南:智能指针从入门到源码剖析 一. 智能指针1.1 auto_ptr1.2 unique_ptr1.3 shared_ptr1.4 make_shared 二. 原理三. shared_ptr循环引用问题三. 线程安全问题四. 内存泄漏4.1 什么是内存泄漏4.2 危害4.3 避免内存泄漏 五. 最后 一. 智能指针 智能指…...

通过MicroSip配置自己的freeswitch服务器进行调试记录

之前用docker安装的freeswitch的,启动是正常的, 但用下面的Microsip连接不上 主要原因有可能一下几个 1、通过下面命令可以看 [rootlocalhost default]# docker exec -it freeswitch fs_cli -x "sofia status profile internal"Name …...

认识CMake并使用CMake构建自己的第一个项目

1.CMake的作用和优势 跨平台支持:CMake支持多种操作系统和编译器,使用同一份构建配置可以在不同的环境中使用 简化配置:通过CMakeLists.txt文件,用户可以定义项目结构、依赖项、编译选项等,无需手动编写复杂的构建脚本…...

深度剖析 DeepSeek 开源模型部署与应用:策略、权衡与未来走向

在人工智能技术呈指数级发展的当下,大模型已然成为推动各行业变革的核心驱动力。DeepSeek 开源模型以其卓越的性能和灵活的开源特性,吸引了众多企业与开发者的目光。如何高效且合理地部署与运用 DeepSeek 模型,成为释放其巨大潜力的关键所在&…...

Linux 下 DMA 内存映射浅析

序 系统 I/O 设备驱动程序通常调用其特定子系统的接口为 DMA 分配内存,但最终会调到 DMA 子系统的dma_alloc_coherent()/dma_alloc_attrs() 等接口。 关于 dma_alloc_coherent 接口详细的代码讲解、调用流程,可以参考这篇文章,我觉得写的非常…...

CVE-2023-25194源码分析与漏洞复现(Kafka JNDI注入)

漏洞概述 漏洞名称:Apache Kafka Connect JNDI注入导致的远程代码执行漏洞 CVE编号:CVE-2023-25194 CVSS评分:8.8 影响版本:Apache Kafka 2.3.0 - 3.3.2 修复版本:≥ 3.4.0 漏洞类型:反序列化导致的远程代…...

项目进度管理软件是什么?项目进度管理软件有哪些核心功能?

无论是建筑施工、软件开发,还是市场营销活动,项目往往涉及多个团队、大量资源和严格的时间表。如果没有一个系统化的工具来跟踪和管理这些元素,项目很容易陷入混乱,导致进度延误、成本超支,甚至失败。 项目进度管理软…...

WEB3全栈开发——面试专业技能点P8DevOps / 区块链部署

一、Hardhat / Foundry 进行合约部署 概念介绍 Hardhat 和 Foundry 都是以太坊智能合约开发的工具套件,支持合约的编译、测试和部署。 它们允许开发者在本地或测试网络快速开发智能合约,并部署到链上(测试网或主网)。 部署过程…...