tf.Keras (tf-1.15)使用记录3-model.compile方法
model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。
注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()。
1 基本用法
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
1) optimizer: 优化器
可以是预定义优化器的字符串(如 'adam', 'sgd' 等),也可以是 tf.keras.optimizers 下的优化器实例。优化器负责调整模型的权重以最小化损失函数。
以下是可以使用的字符串参数:
'sgd': 随机梯度下降优化器'adam': Adam 优化器'rmsprop': RMSprop 优化器'adagrad': Adagrad 优化器'adadelta': Adadelta 优化器'adamax': Adamax 优化器'nadam': Nadam 优化器'ftrl': Ftrl 优化器
需要注意的是:
-
这些字符串参数是不区分大小写的。例如,‘Adam’ 和 ‘adam’ 都是有效的。
-
使用字符串参数时,优化器会使用其默认参数值。如果你需要自定义优化器的参数(如学习率),最好直接使用优化器类:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy']) -
‘adam’ 通常是一个很好的默认选择,因为它在各种问题上都表现良好。但对于特定问题,其他优化器可能会表现得更好。
-
在实践中,选择合适的优化器和调整其参数(如学习率)往往比选择特定的优化器算法更重要。
2) loss: 损失函数
用于计算模型的预测值和真实值之间的差异。可以是字符串(预定义损失函数的名称),也可以是 tf.keras.losses 下的损失函数对象。对于不同类型的问题(如分类、回归等),需要选择合适的损失函数。
以下是一些常用的字符串参数对应的损失函数:
'binary_crossentropy': 用于二分类问题的交叉熵损失。'categorical_crossentropy': 用于多分类问题的交叉熵损失,要求标签为 one-hot 编码。'sparse_categorical_crossentropy': 用于多分类问题的交叉熵损失,标签为整数。'mean_squared_error'或'mse': 均方误差损失,用于回归问题。'mean_absolute_error'或'mae': 平均绝对误差损失,用于回归问题。'mean_absolute_percentage_error'或'mape': 平均绝对百分比误差,用于回归问题。'mean_squared_logarithmic_error'或'msle': 均方对数误差,用于回归问题,对小差异不敏感。'poisson': 泊松损失,适用于计数问题或其他泊松分布问题。'kullback_leibler_divergence'或'kld': Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。'hinge': 用于“最大间隔”分类问题的铰链损失。'squared_hinge': 铰链损失的平方版本。'logcosh': 对数双曲余弦损失,用于回归问题,对异常值不敏感。
3) metrics: 评估指标列表,用于评估模型的性能
这些指标在训练过程中不会用于梯度计算,仅用于观察。常见的指标包括 'accuracy'、'precision'、'recall' 等。
在 model.compile() 方法中,metrics 参数用于指定在训练和评估期间模型将评估哪些指标。这些指标不会用于训练过程中的反向传播和权重更新,仅用于观察模型的性能。以下是一些可以通过字符串参数传入的常用指标:
'accuracy'或'acc': 准确率,用于分类问题。'binary_accuracy': 二分类准确率。'categorical_accuracy': 多分类准确率,要求标签为 one-hot 编码。'sparse_categorical_accuracy': 多分类准确率,标签为整数。'top_k_categorical_accuracy': Top-k 准确率,即目标类别在模型预测的前 k 个最可能的类别中的准确率,用于多分类问题。'sparse_top_k_categorical_accuracy': 与'top_k_categorical_accuracy'类似,但适用于标签为整数的情况。'mean_squared_error'或'mse': 均方误差,用于回归问题。'mean_absolute_error'或'mae': 平均绝对误差,用于回归问题。'mean_absolute_percentage_error'或'mape': 平均绝对百分比误差,用于回归问题。'mean_squared_logarithmic_error'或'msle': 均方对数误差,用于回归问题。'cosine_similarity': 余弦相似度,用于回归问题或多标签分类问题。'precision': 精确率,用于二分类或多标签分类问题。'recall': 召回率,用于二分类或多标签分类问题。'auc': 曲线下面积(Area Under the Curve),用于二分类问题。
使用示例:
# 二分类问题
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy', 'precision', 'recall'])# 多分类问题
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy', 'top_k_categorical_accuracy'])# 回归问题
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae', 'mse'])
对于一些特定的指标(如 'precision', 'recall', 'auc' 等),可能需要使用 tf.keras.metrics 下的类实例来获得更多的配置选项,例如设置阈值或为多标签分类问题指定平均方法。
from tensorflow.keras.metrics import Precision, Recallmodel.compile(optimizer='adam',loss='binary_crossentropy',metrics=[Precision(thresholds=0.5), Recall(thresholds=0.5)])
2 高级用法
- 使用自定义优化器:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
- 使用自定义损失函数:
def custom_loss(y_true, y_pred):# 自定义损失计算逻辑return tf.reduce_mean(tf.square(y_true - y_pred))model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
- 使用多个损失函数和评估指标:
如果模型有多个输出,你可以为每个输出指定不同的损失函数和评估指标。
model.compile(optimizer='adam',loss={'output_a': 'sparse_categorical_crossentropy', 'output_b': 'mse'},metrics={'output_a': ['accuracy'], 'output_b': ['mae', 'mse']})
- 使用学习率衰减:
from tensorflow.keras.optimizers.schedules import ExponentialDecaylr_schedule = ExponentialDecay(initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
相关文章:
tf.Keras (tf-1.15)使用记录3-model.compile方法
model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。 注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()。 1 基本…...
Prometheus 中的 Exporter
在 Prometheus 生态系统中,Exporter 扮演着至关重要的角色,它们负责从不同的服务或系统中收集和暴露度量数据。本文将详细介绍 Exporter 的概念、类型以及如何有效使用它们将 Prometheus 集成到各种系统中进行监控。 什么是 Exporter? Exporter 是一段软件,它从应用程序或…...
网工_HDLC协议
2025.01.25:网工老姜学习笔记 第9节 HDLC协议 9.1 HDLC高级数据链路控制9.2 HDLC帧格式(*控制字段)9.2.1 信息帧(承载用户数据,0开头)9.2.2 监督帧(帮助信息可靠传输,10开头…...
leetcode 2563. 统计公平数对的数目
题目如下 数据范围 显然数组长度最大可以到10的5次方n方的复杂度必然超时,阅读题目实际上就是寻找两个位置不同的数满足不等式即可(实际上i j无所谓是哪个 我们只要把位置小的想成i就行)。 按照上面的思路我们只需要排序数组然后从前往后遍历数组然后利用二分查找…...
Debian 10 中 Linux 4.19 内核在 x86_64 架构上对中断嵌套的支持情况
一、中断嵌套的定义与原理 中断嵌套是指在一个中断处理程序(ISR)正在执行的过程中,另一个更高优先级的中断请求到来,系统暂停当前中断处理程序,转而处理新的高优先级中断。处理完高优先级中断后,系统返回到原来的中断处理程序继续执行。这种机制允许系统更高效地响应紧急…...
FLTK - FLTK1.4.1 - demo - bitmap
文章目录 FLTK - FLTK1.4.1 - demo - bitmap概述笔记END FLTK - FLTK1.4.1 - demo - bitmap 概述 // 功能 : 演示位图数据在按钮上的显示 // * 以按钮为范围或者以窗口为范围移动 // * 上下左右, 文字和图像的相对位置 // 失能按钮,使能按钮 // 知识点 // FLTK可…...
数据结构 树1
目录 前言 一,树的引论 二,二叉树 三,二叉树的详细理解 四,二叉搜索树 五,二分法与二叉搜索树的效率 六,二叉搜索树的实现 七,查找最大值和最小值 指针传递 vs 传引用 为什么指针按值传递不会修…...
android主题设置为..DarkActionBar.Bridge时自定义DatePicker选中日期颜色
安卓自定义DatePicker选中日期颜色 背景:解决方案:方案一:方案二:实践效果: 背景: 最近在尝试用原生安卓实现仿element-ui表单校验功能,其中的的选择日期涉及到安卓DatePicker组件的使用&#…...
MySQL 如何深度分页问题
在实际的数据库应用场景中,我们常常会遇到需要进行分页查询的需求。对于少量数据的分页查询,MySQL 可以轻松应对。然而,当我们需要进行深度分页(即从大量数据的中间位置开始获取少量数据)时,就会面临性能严…...
1.攻防世界easyphp
进入题目页面如下 是一段PHP代码进行代码审计 <?php // 高亮显示PHP文件源代码 highlight_file(__FILE__);// 初始化变量$key1和$key2为0 $key1 0; $key2 0;// 从GET请求中获取参数a的值,并赋值给变量$a $a $_GET[a]; // 从GET请求中获取参数b的值ÿ…...
深度学习 Pytorch 神经网络的学习
本节将从梯度下降法向外拓展,介绍更常用的优化算法,实现神经网络的学习和迭代。在本节课结束将完整实现一个神经网络训练的全流程。 对于像神经网络这样的复杂模型,可能会有数百个 w w w的存在,同时如果我们使用的是像交叉熵这样…...
如何利用天赋实现最大化的价值输出-补
原文: https://blog.csdn.net/ZhangRelay/article/details/145408621 如何利用天赋实现最大化的价值输出-CSDN博客 如何利用天赋实现最大化的价值输出-CSDN博客 引用视频差异 第一段视频目标明确,建议也非常明确。 录制视频的人是主动性…...
Vue简介
目录 Vue是什么?为什么要使用Vue?Vue的三种加载方式拓展:什么是渐进式框架? Vue是什么? Vue是一套用于构建用户界面的渐进式 JavaScript (主张最少)框架 ,开发者只需关注视图层。另一方面,当与…...
three.js+WebGL踩坑经验合集(6.2):负缩放,负定矩阵和行列式的关系(3D版本)
本篇将紧接上篇的2D版本对3D版的负缩放矩阵进行解读。 (6.1):负缩放,负定矩阵和行列式的关系(2D版本) 既然three.js对3D版的负缩放也使用行列式进行判断,那么,2D版的结论用到3D上其实是没毛病的,THREE.Li…...
使用 OpenResty 构建高效的动态图片水印代理服务20250127
使用 OpenResty 构建高效的动态图片水印代理服务 在当今数字化的时代,图片在各种业务场景中广泛应用。为了保护版权、统一品牌形象,动态图片水印功能显得尤为重要。然而,直接在后端服务中集成水印功能,往往会带来代码复杂度增加、…...
Kafka下载
一、Kafka下载 下载地址:https://kafka.apache.org/downloads 二、Kafka安装 因为选择下载的是 .zip 文件,直接跳过安装,一步到位。 选择在任一磁盘创建空文件夹(不要使用中文路径),解压之后把文件夹内容…...
【C++语言】卡码网语言基础课系列----5. A+B问题VIII
文章目录 练习题目AB问题VIII具体代码实现 小白寄语诗词共勉 练习题目 AB问题VIII 题目描述: 你的任务是计算若干整数的和。 输入描述: 输入的第一行为一个整数N,接下来N行每行先输入一个整数M,然后在同一行内输入M个整数。 输出…...
IP服务模型
1. IP数据报 IP数据报中除了包含需要传输的数据外,还包括目标终端的IP地址和发送终端的IP地址。 数据报通过网络从一台路由器跳到另一台路由器,一路从IP源地址传递到IP目标地址。每个路由器都包含一个转发表,该表告诉它在匹配到特定目标地址…...
仿真设计|基于51单片机的温湿度、一氧化碳、甲醛检测报警系统
目录 具体实现功能 设计介绍 51单片机简介 资料内容 仿真实现(protues8.7) 程序(Keil5) 全部内容 资料获取 具体实现功能 (1)温湿度传感器、CO传感器、甲醛传感器实时检测温湿度值、CO值和甲醛值进…...
QModbusTCPClient 服务器断开引起的程序崩溃
最近使用QModbusTCPClient 与一套设备通信,有一个QTimer频繁的通过读取设备寄存器。程序运行良好,但是有个问题:正常进行中设备断电了,整个程序都会崩溃。解决过程如下: 1.失败方案一 在QModbusTCPClient的errorOccu…...
利用最小二乘法找圆心和半径
#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...
MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
逻辑回归:给不确定性划界的分类大师
想象你是一名医生。面对患者的检查报告(肿瘤大小、血液指标),你需要做出一个**决定性判断**:恶性还是良性?这种“非黑即白”的抉择,正是**逻辑回归(Logistic Regression)** 的战场&a…...
对WWDC 2025 Keynote 内容的预测
借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...
屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!
5月28日,中天合创屋面分布式光伏发电项目顺利并网发电,该项目位于内蒙古自治区鄂尔多斯市乌审旗,项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站,总装机容量为9.96MWp。 项目投运后,每年可节约标煤3670…...
SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现
摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...
css的定位(position)详解:相对定位 绝对定位 固定定位
在 CSS 中,元素的定位通过 position 属性控制,共有 5 种定位模式:static(静态定位)、relative(相对定位)、absolute(绝对定位)、fixed(固定定位)和…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存
文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...
Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析
Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析 一、第一轮提问(基础概念问题) 1. 请解释Spring框架的核心容器是什么?它在Spring中起到什么作用? Spring框架的核心容器是IoC容器&#…...
