当前位置: 首页 > news >正文

tf.Keras (tf-1.15)使用记录3-model.compile方法

model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。

注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()

1 基本用法

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

1) optimizer: 优化器

可以是预定义优化器的字符串(如 'adam', 'sgd' 等),也可以是 tf.keras.optimizers 下的优化器实例。优化器负责调整模型的权重以最小化损失函数。

以下是可以使用的字符串参数:

  1. 'sgd': 随机梯度下降优化器
  2. 'adam': Adam 优化器
  3. 'rmsprop': RMSprop 优化器
  4. 'adagrad': Adagrad 优化器
  5. 'adadelta': Adadelta 优化器
  6. 'adamax': Adamax 优化器
  7. 'nadam': Nadam 优化器
  8. 'ftrl': Ftrl 优化器

需要注意的是:

  1. 这些字符串参数是不区分大小写的。例如,‘Adam’ 和 ‘adam’ 都是有效的。

  2. 使用字符串参数时,优化器会使用其默认参数值。如果你需要自定义优化器的参数(如学习率),最好直接使用优化器类:

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    
  3. ‘adam’ 通常是一个很好的默认选择,因为它在各种问题上都表现良好。但对于特定问题,其他优化器可能会表现得更好。

  4. 在实践中,选择合适的优化器和调整其参数(如学习率)往往比选择特定的优化器算法更重要。

2) loss: 损失函数

用于计算模型的预测值和真实值之间的差异。可以是字符串(预定义损失函数的名称),也可以是 tf.keras.losses 下的损失函数对象。对于不同类型的问题(如分类、回归等),需要选择合适的损失函数。

以下是一些常用的字符串参数对应的损失函数:

  1. 'binary_crossentropy': 用于二分类问题的交叉熵损失。
  2. 'categorical_crossentropy': 用于多分类问题的交叉熵损失,要求标签为 one-hot 编码。
  3. 'sparse_categorical_crossentropy': 用于多分类问题的交叉熵损失,标签为整数。
  4. 'mean_squared_error''mse': 均方误差损失,用于回归问题。
  5. 'mean_absolute_error''mae': 平均绝对误差损失,用于回归问题。
  6. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  7. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题,对小差异不敏感。
  8. 'poisson': 泊松损失,适用于计数问题或其他泊松分布问题。
  9. 'kullback_leibler_divergence''kld': Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。
  10. 'hinge': 用于“最大间隔”分类问题的铰链损失。
  11. 'squared_hinge': 铰链损失的平方版本。
  12. 'logcosh': 对数双曲余弦损失,用于回归问题,对异常值不敏感。

3) metrics: 评估指标列表,用于评估模型的性能

这些指标在训练过程中不会用于梯度计算,仅用于观察。常见的指标包括 'accuracy''precision''recall' 等。

model.compile() 方法中,metrics 参数用于指定在训练和评估期间模型将评估哪些指标。这些指标不会用于训练过程中的反向传播和权重更新,仅用于观察模型的性能。以下是一些可以通过字符串参数传入的常用指标:

  1. 'accuracy''acc': 准确率,用于分类问题。
  2. 'binary_accuracy': 二分类准确率。
  3. 'categorical_accuracy': 多分类准确率,要求标签为 one-hot 编码。
  4. 'sparse_categorical_accuracy': 多分类准确率,标签为整数。
  5. 'top_k_categorical_accuracy': Top-k 准确率,即目标类别在模型预测的前 k 个最可能的类别中的准确率,用于多分类问题。
  6. 'sparse_top_k_categorical_accuracy': 与 'top_k_categorical_accuracy' 类似,但适用于标签为整数的情况。
  7. 'mean_squared_error''mse': 均方误差,用于回归问题。
  8. 'mean_absolute_error''mae': 平均绝对误差,用于回归问题。
  9. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  10. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题。
  11. 'cosine_similarity': 余弦相似度,用于回归问题或多标签分类问题。
  12. 'precision': 精确率,用于二分类或多标签分类问题。
  13. 'recall': 召回率,用于二分类或多标签分类问题。
  14. 'auc': 曲线下面积(Area Under the Curve),用于二分类问题。

使用示例:

# 二分类问题
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy', 'precision', 'recall'])# 多分类问题
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy', 'top_k_categorical_accuracy'])# 回归问题
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae', 'mse'])

对于一些特定的指标(如 'precision', 'recall', 'auc' 等),可能需要使用 tf.keras.metrics 下的类实例来获得更多的配置选项,例如设置阈值或为多标签分类问题指定平均方法。

from tensorflow.keras.metrics import Precision, Recallmodel.compile(optimizer='adam',loss='binary_crossentropy',metrics=[Precision(thresholds=0.5), Recall(thresholds=0.5)])

2 高级用法

  • 使用自定义优化器:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  • 使用自定义损失函数:
def custom_loss(y_true, y_pred):# 自定义损失计算逻辑return tf.reduce_mean(tf.square(y_true - y_pred))model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
  • 使用多个损失函数和评估指标:

如果模型有多个输出,你可以为每个输出指定不同的损失函数和评估指标。

model.compile(optimizer='adam',loss={'output_a': 'sparse_categorical_crossentropy', 'output_b': 'mse'},metrics={'output_a': ['accuracy'], 'output_b': ['mae', 'mse']})
  • 使用学习率衰减:
from tensorflow.keras.optimizers.schedules import ExponentialDecaylr_schedule = ExponentialDecay(initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

相关文章:

tf.Keras (tf-1.15)使用记录3-model.compile方法

model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。 注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()。 1 基本…...

Prometheus 中的 Exporter

在 Prometheus 生态系统中,Exporter 扮演着至关重要的角色,它们负责从不同的服务或系统中收集和暴露度量数据。本文将详细介绍 Exporter 的概念、类型以及如何有效使用它们将 Prometheus 集成到各种系统中进行监控。 什么是 Exporter? Exporter 是一段软件,它从应用程序或…...

网工_HDLC协议

2025.01.25:网工老姜学习笔记 第9节 HDLC协议 9.1 HDLC高级数据链路控制9.2 HDLC帧格式(*控制字段)9.2.1 信息帧(承载用户数据,0开头)9.2.2 监督帧(帮助信息可靠传输,10开头&#xf…...

leetcode 2563. 统计公平数对的数目

题目如下 数据范围 显然数组长度最大可以到10的5次方n方的复杂度必然超时,阅读题目实际上就是寻找两个位置不同的数满足不等式即可(实际上i j无所谓是哪个 我们只要把位置小的想成i就行)。 按照上面的思路我们只需要排序数组然后从前往后遍历数组然后利用二分查找…...

Debian 10 中 Linux 4.19 内核在 x86_64 架构上对中断嵌套的支持情况

一、中断嵌套的定义与原理 中断嵌套是指在一个中断处理程序(ISR)正在执行的过程中,另一个更高优先级的中断请求到来,系统暂停当前中断处理程序,转而处理新的高优先级中断。处理完高优先级中断后,系统返回到原来的中断处理程序继续执行。这种机制允许系统更高效地响应紧急…...

FLTK - FLTK1.4.1 - demo - bitmap

文章目录 FLTK - FLTK1.4.1 - demo - bitmap概述笔记END FLTK - FLTK1.4.1 - demo - bitmap 概述 // 功能 : 演示位图数据在按钮上的显示 // * 以按钮为范围或者以窗口为范围移动 // * 上下左右, 文字和图像的相对位置 // 失能按钮,使能按钮 // 知识点 // FLTK可…...

数据结构 树1

目录 前言 一,树的引论 二,二叉树 三,二叉树的详细理解 四,二叉搜索树 五,二分法与二叉搜索树的效率 六,二叉搜索树的实现 七,查找最大值和最小值 指针传递 vs 传引用 为什么指针按值传递不会修…...

android主题设置为..DarkActionBar.Bridge时自定义DatePicker选中日期颜色

安卓自定义DatePicker选中日期颜色 背景:解决方案:方案一:方案二:实践效果: 背景: 最近在尝试用原生安卓实现仿element-ui表单校验功能,其中的的选择日期涉及到安卓DatePicker组件的使用&#…...

MySQL 如何深度分页问题

在实际的数据库应用场景中,我们常常会遇到需要进行分页查询的需求。对于少量数据的分页查询,MySQL 可以轻松应对。然而,当我们需要进行深度分页(即从大量数据的中间位置开始获取少量数据)时,就会面临性能严…...

1.攻防世界easyphp

进入题目页面如下 是一段PHP代码进行代码审计 <?php // 高亮显示PHP文件源代码 highlight_file(__FILE__);// 初始化变量$key1和$key2为0 $key1 0; $key2 0;// 从GET请求中获取参数a的值&#xff0c;并赋值给变量$a $a $_GET[a]; // 从GET请求中获取参数b的值&#xff…...

深度学习 Pytorch 神经网络的学习

本节将从梯度下降法向外拓展&#xff0c;介绍更常用的优化算法&#xff0c;实现神经网络的学习和迭代。在本节课结束将完整实现一个神经网络训练的全流程。 对于像神经网络这样的复杂模型&#xff0c;可能会有数百个 w w w的存在&#xff0c;同时如果我们使用的是像交叉熵这样…...

如何利用天赋实现最大化的价值输出-补

原文&#xff1a; https://blog.csdn.net/ZhangRelay/article/details/145408621 ​​​​​​如何利用天赋实现最大化的价值输出-CSDN博客 如何利用天赋实现最大化的价值输出-CSDN博客 引用视频差异 第一段视频目标明确&#xff0c;建议也非常明确。 录制视频的人是主动性…...

Vue简介

目录 Vue是什么&#xff1f;为什么要使用Vue&#xff1f;Vue的三种加载方式拓展&#xff1a;什么是渐进式框架&#xff1f; Vue是什么&#xff1f; Vue是一套用于构建用户界面的渐进式 JavaScript (主张最少)框架 &#xff0c;开发者只需关注视图层。另一方面&#xff0c;当与…...

three.js+WebGL踩坑经验合集(6.2):负缩放,负定矩阵和行列式的关系(3D版本)

本篇将紧接上篇的2D版本对3D版的负缩放矩阵进行解读。 (6.1):负缩放&#xff0c;负定矩阵和行列式的关系&#xff08;2D版本&#xff09; 既然three.js对3D版的负缩放也使用行列式进行判断&#xff0c;那么&#xff0c;2D版的结论用到3D上其实是没毛病的&#xff0c;THREE.Li…...

使用 OpenResty 构建高效的动态图片水印代理服务20250127

使用 OpenResty 构建高效的动态图片水印代理服务 在当今数字化的时代&#xff0c;图片在各种业务场景中广泛应用。为了保护版权、统一品牌形象&#xff0c;动态图片水印功能显得尤为重要。然而&#xff0c;直接在后端服务中集成水印功能&#xff0c;往往会带来代码复杂度增加、…...

Kafka下载

一、Kafka下载 下载地址&#xff1a;https://kafka.apache.org/downloads 二、Kafka安装 因为选择下载的是 .zip 文件&#xff0c;直接跳过安装&#xff0c;一步到位。 选择在任一磁盘创建空文件夹&#xff08;不要使用中文路径&#xff09;&#xff0c;解压之后把文件夹内容…...

【C++语言】卡码网语言基础课系列----5. A+B问题VIII

文章目录 练习题目AB问题VIII具体代码实现 小白寄语诗词共勉 练习题目 AB问题VIII 题目描述&#xff1a; 你的任务是计算若干整数的和。 输入描述&#xff1a; 输入的第一行为一个整数N&#xff0c;接下来N行每行先输入一个整数M&#xff0c;然后在同一行内输入M个整数。 输出…...

IP服务模型

1. IP数据报 IP数据报中除了包含需要传输的数据外&#xff0c;还包括目标终端的IP地址和发送终端的IP地址。 数据报通过网络从一台路由器跳到另一台路由器&#xff0c;一路从IP源地址传递到IP目标地址。每个路由器都包含一个转发表&#xff0c;该表告诉它在匹配到特定目标地址…...

仿真设计|基于51单片机的温湿度、一氧化碳、甲醛检测报警系统

目录 具体实现功能 设计介绍 51单片机简介 资料内容 仿真实现&#xff08;protues8.7&#xff09; 程序&#xff08;Keil5&#xff09; 全部内容 资料获取 具体实现功能 &#xff08;1&#xff09;温湿度传感器、CO传感器、甲醛传感器实时检测温湿度值、CO值和甲醛值进…...

QModbusTCPClient 服务器断开引起的程序崩溃

最近使用QModbusTCPClient 与一套设备通信&#xff0c;有一个QTimer频繁的通过读取设备寄存器。程序运行良好&#xff0c;但是有个问题&#xff1a;正常进行中设备断电了&#xff0c;整个程序都会崩溃。解决过程如下&#xff1a; 1.失败方案一 在QModbusTCPClient的errorOccu…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet&#xff0c;点击确认后如下提示 最终上报fail 解决方法 内核升级导致&#xff0c;需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

Leetcode 3577. Count the Number of Computer Unlocking Permutations

Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接&#xff1a;3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯&#xff0c;要想要能够将所有的电脑解锁&#x…...

CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云

目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...

Unit 1 深度强化学习简介

Deep RL Course ——Unit 1 Introduction 从理论和实践层面深入学习深度强化学习。学会使用知名的深度强化学习库&#xff0c;例如 Stable Baselines3、RL Baselines3 Zoo、Sample Factory 和 CleanRL。在独特的环境中训练智能体&#xff0c;比如 SnowballFight、Huggy the Do…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

学习STC51单片机32(芯片为STC89C52RCRC)OLED显示屏2

每日一言 今天的每一份坚持&#xff0c;都是在为未来积攒底气。 案例&#xff1a;OLED显示一个A 这边观察到一个点&#xff0c;怎么雪花了就是都是乱七八糟的占满了屏幕。。 解释 &#xff1a; 如果代码里信号切换太快&#xff08;比如 SDA 刚变&#xff0c;SCL 立刻变&#…...

Java 二维码

Java 二维码 **技术&#xff1a;**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配

目录 一、C 内存的基本概念​ 1.1 内存的物理与逻辑结构​ 1.2 C 程序的内存区域划分​ 二、栈内存分配​ 2.1 栈内存的特点​ 2.2 栈内存分配示例​ 三、堆内存分配​ 3.1 new和delete操作符​ 4.2 内存泄漏与悬空指针问题​ 4.3 new和delete的重载​ 四、智能指针…...

jmeter聚合报告中参数详解

sample、average、min、max、90%line、95%line,99%line、Error错误率、吞吐量Thoughput、KB/sec每秒传输的数据量 sample&#xff08;样本数&#xff09; 表示测试中发送的请求数量&#xff0c;即测试执行了多少次请求。 单位&#xff0c;以个或者次数表示。 示例&#xff1a;…...

基于Java+VUE+MariaDB实现(Web)仿小米商城

仿小米商城 环境安装 nodejs maven JDK11 运行 mvn clean install -DskipTestscd adminmvn spring-boot:runcd ../webmvn spring-boot:runcd ../xiaomi-store-admin-vuenpm installnpm run servecd ../xiaomi-store-vuenpm installnpm run serve 注意&#xff1a;运行前…...