当前位置: 首页 > news >正文

tf.Keras (tf-1.15)使用记录3-model.compile方法

model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。

注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()

1 基本用法

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

1) optimizer: 优化器

可以是预定义优化器的字符串(如 'adam', 'sgd' 等),也可以是 tf.keras.optimizers 下的优化器实例。优化器负责调整模型的权重以最小化损失函数。

以下是可以使用的字符串参数:

  1. 'sgd': 随机梯度下降优化器
  2. 'adam': Adam 优化器
  3. 'rmsprop': RMSprop 优化器
  4. 'adagrad': Adagrad 优化器
  5. 'adadelta': Adadelta 优化器
  6. 'adamax': Adamax 优化器
  7. 'nadam': Nadam 优化器
  8. 'ftrl': Ftrl 优化器

需要注意的是:

  1. 这些字符串参数是不区分大小写的。例如,‘Adam’ 和 ‘adam’ 都是有效的。

  2. 使用字符串参数时,优化器会使用其默认参数值。如果你需要自定义优化器的参数(如学习率),最好直接使用优化器类:

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    
  3. ‘adam’ 通常是一个很好的默认选择,因为它在各种问题上都表现良好。但对于特定问题,其他优化器可能会表现得更好。

  4. 在实践中,选择合适的优化器和调整其参数(如学习率)往往比选择特定的优化器算法更重要。

2) loss: 损失函数

用于计算模型的预测值和真实值之间的差异。可以是字符串(预定义损失函数的名称),也可以是 tf.keras.losses 下的损失函数对象。对于不同类型的问题(如分类、回归等),需要选择合适的损失函数。

以下是一些常用的字符串参数对应的损失函数:

  1. 'binary_crossentropy': 用于二分类问题的交叉熵损失。
  2. 'categorical_crossentropy': 用于多分类问题的交叉熵损失,要求标签为 one-hot 编码。
  3. 'sparse_categorical_crossentropy': 用于多分类问题的交叉熵损失,标签为整数。
  4. 'mean_squared_error''mse': 均方误差损失,用于回归问题。
  5. 'mean_absolute_error''mae': 平均绝对误差损失,用于回归问题。
  6. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  7. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题,对小差异不敏感。
  8. 'poisson': 泊松损失,适用于计数问题或其他泊松分布问题。
  9. 'kullback_leibler_divergence''kld': Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。
  10. 'hinge': 用于“最大间隔”分类问题的铰链损失。
  11. 'squared_hinge': 铰链损失的平方版本。
  12. 'logcosh': 对数双曲余弦损失,用于回归问题,对异常值不敏感。

3) metrics: 评估指标列表,用于评估模型的性能

这些指标在训练过程中不会用于梯度计算,仅用于观察。常见的指标包括 'accuracy''precision''recall' 等。

model.compile() 方法中,metrics 参数用于指定在训练和评估期间模型将评估哪些指标。这些指标不会用于训练过程中的反向传播和权重更新,仅用于观察模型的性能。以下是一些可以通过字符串参数传入的常用指标:

  1. 'accuracy''acc': 准确率,用于分类问题。
  2. 'binary_accuracy': 二分类准确率。
  3. 'categorical_accuracy': 多分类准确率,要求标签为 one-hot 编码。
  4. 'sparse_categorical_accuracy': 多分类准确率,标签为整数。
  5. 'top_k_categorical_accuracy': Top-k 准确率,即目标类别在模型预测的前 k 个最可能的类别中的准确率,用于多分类问题。
  6. 'sparse_top_k_categorical_accuracy': 与 'top_k_categorical_accuracy' 类似,但适用于标签为整数的情况。
  7. 'mean_squared_error''mse': 均方误差,用于回归问题。
  8. 'mean_absolute_error''mae': 平均绝对误差,用于回归问题。
  9. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  10. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题。
  11. 'cosine_similarity': 余弦相似度,用于回归问题或多标签分类问题。
  12. 'precision': 精确率,用于二分类或多标签分类问题。
  13. 'recall': 召回率,用于二分类或多标签分类问题。
  14. 'auc': 曲线下面积(Area Under the Curve),用于二分类问题。

使用示例:

# 二分类问题
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy', 'precision', 'recall'])# 多分类问题
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy', 'top_k_categorical_accuracy'])# 回归问题
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae', 'mse'])

对于一些特定的指标(如 'precision', 'recall', 'auc' 等),可能需要使用 tf.keras.metrics 下的类实例来获得更多的配置选项,例如设置阈值或为多标签分类问题指定平均方法。

from tensorflow.keras.metrics import Precision, Recallmodel.compile(optimizer='adam',loss='binary_crossentropy',metrics=[Precision(thresholds=0.5), Recall(thresholds=0.5)])

2 高级用法

  • 使用自定义优化器:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  • 使用自定义损失函数:
def custom_loss(y_true, y_pred):# 自定义损失计算逻辑return tf.reduce_mean(tf.square(y_true - y_pred))model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
  • 使用多个损失函数和评估指标:

如果模型有多个输出,你可以为每个输出指定不同的损失函数和评估指标。

model.compile(optimizer='adam',loss={'output_a': 'sparse_categorical_crossentropy', 'output_b': 'mse'},metrics={'output_a': ['accuracy'], 'output_b': ['mae', 'mse']})
  • 使用学习率衰减:
from tensorflow.keras.optimizers.schedules import ExponentialDecaylr_schedule = ExponentialDecay(initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

相关文章:

tf.Keras (tf-1.15)使用记录3-model.compile方法

model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。 注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()。 1 基本…...

Prometheus 中的 Exporter

在 Prometheus 生态系统中,Exporter 扮演着至关重要的角色,它们负责从不同的服务或系统中收集和暴露度量数据。本文将详细介绍 Exporter 的概念、类型以及如何有效使用它们将 Prometheus 集成到各种系统中进行监控。 什么是 Exporter? Exporter 是一段软件,它从应用程序或…...

网工_HDLC协议

2025.01.25:网工老姜学习笔记 第9节 HDLC协议 9.1 HDLC高级数据链路控制9.2 HDLC帧格式(*控制字段)9.2.1 信息帧(承载用户数据,0开头)9.2.2 监督帧(帮助信息可靠传输,10开头&#xf…...

leetcode 2563. 统计公平数对的数目

题目如下 数据范围 显然数组长度最大可以到10的5次方n方的复杂度必然超时,阅读题目实际上就是寻找两个位置不同的数满足不等式即可(实际上i j无所谓是哪个 我们只要把位置小的想成i就行)。 按照上面的思路我们只需要排序数组然后从前往后遍历数组然后利用二分查找…...

Debian 10 中 Linux 4.19 内核在 x86_64 架构上对中断嵌套的支持情况

一、中断嵌套的定义与原理 中断嵌套是指在一个中断处理程序(ISR)正在执行的过程中,另一个更高优先级的中断请求到来,系统暂停当前中断处理程序,转而处理新的高优先级中断。处理完高优先级中断后,系统返回到原来的中断处理程序继续执行。这种机制允许系统更高效地响应紧急…...

FLTK - FLTK1.4.1 - demo - bitmap

文章目录 FLTK - FLTK1.4.1 - demo - bitmap概述笔记END FLTK - FLTK1.4.1 - demo - bitmap 概述 // 功能 : 演示位图数据在按钮上的显示 // * 以按钮为范围或者以窗口为范围移动 // * 上下左右, 文字和图像的相对位置 // 失能按钮,使能按钮 // 知识点 // FLTK可…...

数据结构 树1

目录 前言 一,树的引论 二,二叉树 三,二叉树的详细理解 四,二叉搜索树 五,二分法与二叉搜索树的效率 六,二叉搜索树的实现 七,查找最大值和最小值 指针传递 vs 传引用 为什么指针按值传递不会修…...

android主题设置为..DarkActionBar.Bridge时自定义DatePicker选中日期颜色

安卓自定义DatePicker选中日期颜色 背景:解决方案:方案一:方案二:实践效果: 背景: 最近在尝试用原生安卓实现仿element-ui表单校验功能,其中的的选择日期涉及到安卓DatePicker组件的使用&#…...

MySQL 如何深度分页问题

在实际的数据库应用场景中,我们常常会遇到需要进行分页查询的需求。对于少量数据的分页查询,MySQL 可以轻松应对。然而,当我们需要进行深度分页(即从大量数据的中间位置开始获取少量数据)时,就会面临性能严…...

1.攻防世界easyphp

进入题目页面如下 是一段PHP代码进行代码审计 <?php // 高亮显示PHP文件源代码 highlight_file(__FILE__);// 初始化变量$key1和$key2为0 $key1 0; $key2 0;// 从GET请求中获取参数a的值&#xff0c;并赋值给变量$a $a $_GET[a]; // 从GET请求中获取参数b的值&#xff…...

深度学习 Pytorch 神经网络的学习

本节将从梯度下降法向外拓展&#xff0c;介绍更常用的优化算法&#xff0c;实现神经网络的学习和迭代。在本节课结束将完整实现一个神经网络训练的全流程。 对于像神经网络这样的复杂模型&#xff0c;可能会有数百个 w w w的存在&#xff0c;同时如果我们使用的是像交叉熵这样…...

如何利用天赋实现最大化的价值输出-补

原文&#xff1a; https://blog.csdn.net/ZhangRelay/article/details/145408621 ​​​​​​如何利用天赋实现最大化的价值输出-CSDN博客 如何利用天赋实现最大化的价值输出-CSDN博客 引用视频差异 第一段视频目标明确&#xff0c;建议也非常明确。 录制视频的人是主动性…...

Vue简介

目录 Vue是什么&#xff1f;为什么要使用Vue&#xff1f;Vue的三种加载方式拓展&#xff1a;什么是渐进式框架&#xff1f; Vue是什么&#xff1f; Vue是一套用于构建用户界面的渐进式 JavaScript (主张最少)框架 &#xff0c;开发者只需关注视图层。另一方面&#xff0c;当与…...

three.js+WebGL踩坑经验合集(6.2):负缩放,负定矩阵和行列式的关系(3D版本)

本篇将紧接上篇的2D版本对3D版的负缩放矩阵进行解读。 (6.1):负缩放&#xff0c;负定矩阵和行列式的关系&#xff08;2D版本&#xff09; 既然three.js对3D版的负缩放也使用行列式进行判断&#xff0c;那么&#xff0c;2D版的结论用到3D上其实是没毛病的&#xff0c;THREE.Li…...

使用 OpenResty 构建高效的动态图片水印代理服务20250127

使用 OpenResty 构建高效的动态图片水印代理服务 在当今数字化的时代&#xff0c;图片在各种业务场景中广泛应用。为了保护版权、统一品牌形象&#xff0c;动态图片水印功能显得尤为重要。然而&#xff0c;直接在后端服务中集成水印功能&#xff0c;往往会带来代码复杂度增加、…...

Kafka下载

一、Kafka下载 下载地址&#xff1a;https://kafka.apache.org/downloads 二、Kafka安装 因为选择下载的是 .zip 文件&#xff0c;直接跳过安装&#xff0c;一步到位。 选择在任一磁盘创建空文件夹&#xff08;不要使用中文路径&#xff09;&#xff0c;解压之后把文件夹内容…...

【C++语言】卡码网语言基础课系列----5. A+B问题VIII

文章目录 练习题目AB问题VIII具体代码实现 小白寄语诗词共勉 练习题目 AB问题VIII 题目描述&#xff1a; 你的任务是计算若干整数的和。 输入描述&#xff1a; 输入的第一行为一个整数N&#xff0c;接下来N行每行先输入一个整数M&#xff0c;然后在同一行内输入M个整数。 输出…...

IP服务模型

1. IP数据报 IP数据报中除了包含需要传输的数据外&#xff0c;还包括目标终端的IP地址和发送终端的IP地址。 数据报通过网络从一台路由器跳到另一台路由器&#xff0c;一路从IP源地址传递到IP目标地址。每个路由器都包含一个转发表&#xff0c;该表告诉它在匹配到特定目标地址…...

仿真设计|基于51单片机的温湿度、一氧化碳、甲醛检测报警系统

目录 具体实现功能 设计介绍 51单片机简介 资料内容 仿真实现&#xff08;protues8.7&#xff09; 程序&#xff08;Keil5&#xff09; 全部内容 资料获取 具体实现功能 &#xff08;1&#xff09;温湿度传感器、CO传感器、甲醛传感器实时检测温湿度值、CO值和甲醛值进…...

QModbusTCPClient 服务器断开引起的程序崩溃

最近使用QModbusTCPClient 与一套设备通信&#xff0c;有一个QTimer频繁的通过读取设备寄存器。程序运行良好&#xff0c;但是有个问题&#xff1a;正常进行中设备断电了&#xff0c;整个程序都会崩溃。解决过程如下&#xff1a; 1.失败方案一 在QModbusTCPClient的errorOccu…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误

HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误&#xff0c;它们的含义、原因和解决方法都有显著区别。以下是详细对比&#xff1a; 1. HTTP 406 (Not Acceptable) 含义&#xff1a; 客户端请求的内容类型与服务器支持的内容类型不匹…...

反向工程与模型迁移:打造未来商品详情API的可持续创新体系

在电商行业蓬勃发展的当下&#xff0c;商品详情API作为连接电商平台与开发者、商家及用户的关键纽带&#xff0c;其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息&#xff08;如名称、价格、库存等&#xff09;的获取与展示&#xff0c;已难以满足市场对个性化、智能…...

SCAU期末笔记 - 数据分析与数据挖掘题库解析

这门怎么题库答案不全啊日 来简单学一下子来 一、选择题&#xff08;可多选&#xff09; 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘&#xff1a;专注于发现数据中…...

在四层代理中还原真实客户端ngx_stream_realip_module

一、模块原理与价值 PROXY Protocol 回溯 第三方负载均衡&#xff08;如 HAProxy、AWS NLB、阿里 SLB&#xff09;发起上游连接时&#xff0c;将真实客户端 IP/Port 写入 PROXY Protocol v1/v2 头。Stream 层接收到头部后&#xff0c;ngx_stream_realip_module 从中提取原始信息…...

如何在看板中有效管理突发紧急任务

在看板中有效管理突发紧急任务需要&#xff1a;设立专门的紧急任务通道、重新调整任务优先级、保持适度的WIP&#xff08;Work-in-Progress&#xff09;弹性、优化任务处理流程、提高团队应对突发情况的敏捷性。其中&#xff0c;设立专门的紧急任务通道尤为重要&#xff0c;这能…...

数据链路层的主要功能是什么

数据链路层&#xff08;OSI模型第2层&#xff09;的核心功能是在相邻网络节点&#xff08;如交换机、主机&#xff09;间提供可靠的数据帧传输服务&#xff0c;主要职责包括&#xff1a; &#x1f511; 核心功能详解&#xff1a; 帧封装与解封装 封装&#xff1a; 将网络层下发…...

【AI学习】三、AI算法中的向量

在人工智能&#xff08;AI&#xff09;算法中&#xff0c;向量&#xff08;Vector&#xff09;是一种将现实世界中的数据&#xff08;如图像、文本、音频等&#xff09;转化为计算机可处理的数值型特征表示的工具。它是连接人类认知&#xff08;如语义、视觉特征&#xff09;与…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾&#xff1a; 在上一篇中&#xff0c;我们成功地为应用集成了数据库&#xff0c;并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了&#xff01;但是&#xff0c;如果你仔细审视那些 API&#xff0c;会发现它们还很“粗糙”&#xff1a;有…...

相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)

【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...