当前位置: 首页 > news >正文

tf.Keras (tf-1.15)使用记录3-model.compile方法

model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。

注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()

1 基本用法

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

1) optimizer: 优化器

可以是预定义优化器的字符串(如 'adam', 'sgd' 等),也可以是 tf.keras.optimizers 下的优化器实例。优化器负责调整模型的权重以最小化损失函数。

以下是可以使用的字符串参数:

  1. 'sgd': 随机梯度下降优化器
  2. 'adam': Adam 优化器
  3. 'rmsprop': RMSprop 优化器
  4. 'adagrad': Adagrad 优化器
  5. 'adadelta': Adadelta 优化器
  6. 'adamax': Adamax 优化器
  7. 'nadam': Nadam 优化器
  8. 'ftrl': Ftrl 优化器

需要注意的是:

  1. 这些字符串参数是不区分大小写的。例如,‘Adam’ 和 ‘adam’ 都是有效的。

  2. 使用字符串参数时,优化器会使用其默认参数值。如果你需要自定义优化器的参数(如学习率),最好直接使用优化器类:

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    
  3. ‘adam’ 通常是一个很好的默认选择,因为它在各种问题上都表现良好。但对于特定问题,其他优化器可能会表现得更好。

  4. 在实践中,选择合适的优化器和调整其参数(如学习率)往往比选择特定的优化器算法更重要。

2) loss: 损失函数

用于计算模型的预测值和真实值之间的差异。可以是字符串(预定义损失函数的名称),也可以是 tf.keras.losses 下的损失函数对象。对于不同类型的问题(如分类、回归等),需要选择合适的损失函数。

以下是一些常用的字符串参数对应的损失函数:

  1. 'binary_crossentropy': 用于二分类问题的交叉熵损失。
  2. 'categorical_crossentropy': 用于多分类问题的交叉熵损失,要求标签为 one-hot 编码。
  3. 'sparse_categorical_crossentropy': 用于多分类问题的交叉熵损失,标签为整数。
  4. 'mean_squared_error''mse': 均方误差损失,用于回归问题。
  5. 'mean_absolute_error''mae': 平均绝对误差损失,用于回归问题。
  6. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  7. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题,对小差异不敏感。
  8. 'poisson': 泊松损失,适用于计数问题或其他泊松分布问题。
  9. 'kullback_leibler_divergence''kld': Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。
  10. 'hinge': 用于“最大间隔”分类问题的铰链损失。
  11. 'squared_hinge': 铰链损失的平方版本。
  12. 'logcosh': 对数双曲余弦损失,用于回归问题,对异常值不敏感。

3) metrics: 评估指标列表,用于评估模型的性能

这些指标在训练过程中不会用于梯度计算,仅用于观察。常见的指标包括 'accuracy''precision''recall' 等。

model.compile() 方法中,metrics 参数用于指定在训练和评估期间模型将评估哪些指标。这些指标不会用于训练过程中的反向传播和权重更新,仅用于观察模型的性能。以下是一些可以通过字符串参数传入的常用指标:

  1. 'accuracy''acc': 准确率,用于分类问题。
  2. 'binary_accuracy': 二分类准确率。
  3. 'categorical_accuracy': 多分类准确率,要求标签为 one-hot 编码。
  4. 'sparse_categorical_accuracy': 多分类准确率,标签为整数。
  5. 'top_k_categorical_accuracy': Top-k 准确率,即目标类别在模型预测的前 k 个最可能的类别中的准确率,用于多分类问题。
  6. 'sparse_top_k_categorical_accuracy': 与 'top_k_categorical_accuracy' 类似,但适用于标签为整数的情况。
  7. 'mean_squared_error''mse': 均方误差,用于回归问题。
  8. 'mean_absolute_error''mae': 平均绝对误差,用于回归问题。
  9. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  10. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题。
  11. 'cosine_similarity': 余弦相似度,用于回归问题或多标签分类问题。
  12. 'precision': 精确率,用于二分类或多标签分类问题。
  13. 'recall': 召回率,用于二分类或多标签分类问题。
  14. 'auc': 曲线下面积(Area Under the Curve),用于二分类问题。

使用示例:

# 二分类问题
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy', 'precision', 'recall'])# 多分类问题
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy', 'top_k_categorical_accuracy'])# 回归问题
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae', 'mse'])

对于一些特定的指标(如 'precision', 'recall', 'auc' 等),可能需要使用 tf.keras.metrics 下的类实例来获得更多的配置选项,例如设置阈值或为多标签分类问题指定平均方法。

from tensorflow.keras.metrics import Precision, Recallmodel.compile(optimizer='adam',loss='binary_crossentropy',metrics=[Precision(thresholds=0.5), Recall(thresholds=0.5)])

2 高级用法

  • 使用自定义优化器:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  • 使用自定义损失函数:
def custom_loss(y_true, y_pred):# 自定义损失计算逻辑return tf.reduce_mean(tf.square(y_true - y_pred))model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
  • 使用多个损失函数和评估指标:

如果模型有多个输出,你可以为每个输出指定不同的损失函数和评估指标。

model.compile(optimizer='adam',loss={'output_a': 'sparse_categorical_crossentropy', 'output_b': 'mse'},metrics={'output_a': ['accuracy'], 'output_b': ['mae', 'mse']})
  • 使用学习率衰减:
from tensorflow.keras.optimizers.schedules import ExponentialDecaylr_schedule = ExponentialDecay(initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

相关文章:

tf.Keras (tf-1.15)使用记录3-model.compile方法

model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。 注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()。 1 基本…...

Prometheus 中的 Exporter

在 Prometheus 生态系统中,Exporter 扮演着至关重要的角色,它们负责从不同的服务或系统中收集和暴露度量数据。本文将详细介绍 Exporter 的概念、类型以及如何有效使用它们将 Prometheus 集成到各种系统中进行监控。 什么是 Exporter? Exporter 是一段软件,它从应用程序或…...

网工_HDLC协议

2025.01.25:网工老姜学习笔记 第9节 HDLC协议 9.1 HDLC高级数据链路控制9.2 HDLC帧格式(*控制字段)9.2.1 信息帧(承载用户数据,0开头)9.2.2 监督帧(帮助信息可靠传输,10开头&#xf…...

leetcode 2563. 统计公平数对的数目

题目如下 数据范围 显然数组长度最大可以到10的5次方n方的复杂度必然超时,阅读题目实际上就是寻找两个位置不同的数满足不等式即可(实际上i j无所谓是哪个 我们只要把位置小的想成i就行)。 按照上面的思路我们只需要排序数组然后从前往后遍历数组然后利用二分查找…...

Debian 10 中 Linux 4.19 内核在 x86_64 架构上对中断嵌套的支持情况

一、中断嵌套的定义与原理 中断嵌套是指在一个中断处理程序(ISR)正在执行的过程中,另一个更高优先级的中断请求到来,系统暂停当前中断处理程序,转而处理新的高优先级中断。处理完高优先级中断后,系统返回到原来的中断处理程序继续执行。这种机制允许系统更高效地响应紧急…...

FLTK - FLTK1.4.1 - demo - bitmap

文章目录 FLTK - FLTK1.4.1 - demo - bitmap概述笔记END FLTK - FLTK1.4.1 - demo - bitmap 概述 // 功能 : 演示位图数据在按钮上的显示 // * 以按钮为范围或者以窗口为范围移动 // * 上下左右, 文字和图像的相对位置 // 失能按钮,使能按钮 // 知识点 // FLTK可…...

数据结构 树1

目录 前言 一,树的引论 二,二叉树 三,二叉树的详细理解 四,二叉搜索树 五,二分法与二叉搜索树的效率 六,二叉搜索树的实现 七,查找最大值和最小值 指针传递 vs 传引用 为什么指针按值传递不会修…...

android主题设置为..DarkActionBar.Bridge时自定义DatePicker选中日期颜色

安卓自定义DatePicker选中日期颜色 背景:解决方案:方案一:方案二:实践效果: 背景: 最近在尝试用原生安卓实现仿element-ui表单校验功能,其中的的选择日期涉及到安卓DatePicker组件的使用&#…...

MySQL 如何深度分页问题

在实际的数据库应用场景中,我们常常会遇到需要进行分页查询的需求。对于少量数据的分页查询,MySQL 可以轻松应对。然而,当我们需要进行深度分页(即从大量数据的中间位置开始获取少量数据)时,就会面临性能严…...

1.攻防世界easyphp

进入题目页面如下 是一段PHP代码进行代码审计 <?php // 高亮显示PHP文件源代码 highlight_file(__FILE__);// 初始化变量$key1和$key2为0 $key1 0; $key2 0;// 从GET请求中获取参数a的值&#xff0c;并赋值给变量$a $a $_GET[a]; // 从GET请求中获取参数b的值&#xff…...

深度学习 Pytorch 神经网络的学习

本节将从梯度下降法向外拓展&#xff0c;介绍更常用的优化算法&#xff0c;实现神经网络的学习和迭代。在本节课结束将完整实现一个神经网络训练的全流程。 对于像神经网络这样的复杂模型&#xff0c;可能会有数百个 w w w的存在&#xff0c;同时如果我们使用的是像交叉熵这样…...

如何利用天赋实现最大化的价值输出-补

原文&#xff1a; https://blog.csdn.net/ZhangRelay/article/details/145408621 ​​​​​​如何利用天赋实现最大化的价值输出-CSDN博客 如何利用天赋实现最大化的价值输出-CSDN博客 引用视频差异 第一段视频目标明确&#xff0c;建议也非常明确。 录制视频的人是主动性…...

Vue简介

目录 Vue是什么&#xff1f;为什么要使用Vue&#xff1f;Vue的三种加载方式拓展&#xff1a;什么是渐进式框架&#xff1f; Vue是什么&#xff1f; Vue是一套用于构建用户界面的渐进式 JavaScript (主张最少)框架 &#xff0c;开发者只需关注视图层。另一方面&#xff0c;当与…...

three.js+WebGL踩坑经验合集(6.2):负缩放,负定矩阵和行列式的关系(3D版本)

本篇将紧接上篇的2D版本对3D版的负缩放矩阵进行解读。 (6.1):负缩放&#xff0c;负定矩阵和行列式的关系&#xff08;2D版本&#xff09; 既然three.js对3D版的负缩放也使用行列式进行判断&#xff0c;那么&#xff0c;2D版的结论用到3D上其实是没毛病的&#xff0c;THREE.Li…...

使用 OpenResty 构建高效的动态图片水印代理服务20250127

使用 OpenResty 构建高效的动态图片水印代理服务 在当今数字化的时代&#xff0c;图片在各种业务场景中广泛应用。为了保护版权、统一品牌形象&#xff0c;动态图片水印功能显得尤为重要。然而&#xff0c;直接在后端服务中集成水印功能&#xff0c;往往会带来代码复杂度增加、…...

Kafka下载

一、Kafka下载 下载地址&#xff1a;https://kafka.apache.org/downloads 二、Kafka安装 因为选择下载的是 .zip 文件&#xff0c;直接跳过安装&#xff0c;一步到位。 选择在任一磁盘创建空文件夹&#xff08;不要使用中文路径&#xff09;&#xff0c;解压之后把文件夹内容…...

【C++语言】卡码网语言基础课系列----5. A+B问题VIII

文章目录 练习题目AB问题VIII具体代码实现 小白寄语诗词共勉 练习题目 AB问题VIII 题目描述&#xff1a; 你的任务是计算若干整数的和。 输入描述&#xff1a; 输入的第一行为一个整数N&#xff0c;接下来N行每行先输入一个整数M&#xff0c;然后在同一行内输入M个整数。 输出…...

IP服务模型

1. IP数据报 IP数据报中除了包含需要传输的数据外&#xff0c;还包括目标终端的IP地址和发送终端的IP地址。 数据报通过网络从一台路由器跳到另一台路由器&#xff0c;一路从IP源地址传递到IP目标地址。每个路由器都包含一个转发表&#xff0c;该表告诉它在匹配到特定目标地址…...

仿真设计|基于51单片机的温湿度、一氧化碳、甲醛检测报警系统

目录 具体实现功能 设计介绍 51单片机简介 资料内容 仿真实现&#xff08;protues8.7&#xff09; 程序&#xff08;Keil5&#xff09; 全部内容 资料获取 具体实现功能 &#xff08;1&#xff09;温湿度传感器、CO传感器、甲醛传感器实时检测温湿度值、CO值和甲醛值进…...

QModbusTCPClient 服务器断开引起的程序崩溃

最近使用QModbusTCPClient 与一套设备通信&#xff0c;有一个QTimer频繁的通过读取设备寄存器。程序运行良好&#xff0c;但是有个问题&#xff1a;正常进行中设备断电了&#xff0c;整个程序都会崩溃。解决过程如下&#xff1a; 1.失败方案一 在QModbusTCPClient的errorOccu…...

linux之kylin系统nginx的安装

一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源&#xff08;HTML/CSS/图片等&#xff09;&#xff0c;响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址&#xff0c;提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

C++初阶-list的底层

目录 1.std::list实现的所有代码 2.list的简单介绍 2.1实现list的类 2.2_list_iterator的实现 2.2.1_list_iterator实现的原因和好处 2.2.2_list_iterator实现 2.3_list_node的实现 2.3.1. 避免递归的模板依赖 2.3.2. 内存布局一致性 2.3.3. 类型安全的替代方案 2.3.…...

stm32G473的flash模式是单bank还是双bank?

今天突然有人stm32G473的flash模式是单bank还是双bank&#xff1f;由于时间太久&#xff0c;我真忘记了。搜搜发现&#xff0c;还真有人和我一样。见下面的链接&#xff1a;https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合

强化学习&#xff08;Reinforcement Learning, RL&#xff09;是工业领域智能控制的重要方法。它的基本原理是将最优控制问题建模为马尔可夫决策过程&#xff0c;然后使用强化学习的Actor-Critic机制&#xff08;中文译作“知行互动”机制&#xff09;&#xff0c;逐步迭代求解…...

Objective-C常用命名规范总结

【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名&#xff08;Class Name)2.协议名&#xff08;Protocol Name)3.方法名&#xff08;Method Name)4.属性名&#xff08;Property Name&#xff09;5.局部变量/实例变量&#xff08;Local / Instance Variables&…...

连锁超市冷库节能解决方案:如何实现超市降本增效

在连锁超市冷库运营中&#xff0c;高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术&#xff0c;实现年省电费15%-60%&#xff0c;且不改动原有装备、安装快捷、…...

Cloudflare 从 Nginx 到 Pingora:性能、效率与安全的全面升级

在互联网的快速发展中&#xff0c;高性能、高效率和高安全性的网络服务成为了各大互联网基础设施提供商的核心追求。Cloudflare 作为全球领先的互联网安全和基础设施公司&#xff0c;近期做出了一个重大技术决策&#xff1a;弃用长期使用的 Nginx&#xff0c;转而采用其内部开发…...

爬虫基础学习day2

# 爬虫设计领域 工商&#xff1a;企查查、天眼查短视频&#xff1a;抖音、快手、西瓜 ---> 飞瓜电商&#xff1a;京东、淘宝、聚美优品、亚马逊 ---> 分析店铺经营决策标题、排名航空&#xff1a;抓取所有航空公司价格 ---> 去哪儿自媒体&#xff1a;采集自媒体数据进…...

Redis数据倾斜问题解决

Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中&#xff0c;部分节点存储的数据量或访问量远高于其他节点&#xff0c;导致这些节点负载过高&#xff0c;影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...