当前位置: 首页 > news >正文

XGB-20:XGBoost中不同参数的预测函数

有许多在XGBoost中具有不同参数的预测函数。

预测选项

xgboost.Booster.predict() 方法有许多不同的预测选项,从 pred_contribspred_leaf 不等。输出形状取决于预测的类型。对于多类分类问题,XGBoost为每个类构建一棵树,每个类的树称为树的“组”,因此输出维度可能会因所使用的模型而改变。

在1.4版本后,添加了 strict_shape 的新参数。可以将其设置为 True,以指示希望获得更受限制的输出。假设正在使用 xgboost.Booster,以下是可能的返回列表:

  • 使用 strict_shape 设置为 True 进行正常预测时:

    • 输出是一个2维数组,第一维是行数,第二维是组数。对于回归/生存/排序/二分类,这相当于一个形状为shape[1] == 1的列向量。但对于多类别问题,使用 multi:softprob 时,列数等于类别数。如果 strict_shape 设置为 False,输出1维或2维数组
  • 使用 output_margin 避免转换且 strict_shape 设置为 True 时:

    • 输出是一个2维数组,除了 multi:softmax 由于去掉了转换而具有与 multi:softprob 相等的输出形状。如果 strict_shape 设置为 False,则输出可以具有1维或2维,具体取决于所使用的模型
  • 使用 pred_contribsstrict_shape 设置为 True 时:

    • 输出是一个3维数组,形状为(行数,组数,列数+1)。是否使用 approx_contribs 不会改变输出形状。如果未设置 strict_shape 参数,则它可以是2维或3维数组,具体取决于是否使用多类别模型
  • 使用 pred_interactionsstrict_shape 设置为 True 时:

    • 输出是一个4维数组,形状为(行数,组数,列数+1,列数+1)。是否使用 approx_contribs 不会改变输出形状。如果 strict_shape 设置为 False,则它可以具有3维或4维,具体取决于底层模型
  • 使用 pred_leafstrict_shape 设置为 True 时:

    • 输出是一个4维数组,形状为(n_samples, n_iterations, n_classes, n_trees_in_forest)。 n_trees_in_forest 在训练过程中由 num_parallel_tree 指定。当 strict_shape 设置为 False 时,输出是一个2维数组,最后3维连接成1维。如果最后一维等于1,则会删除最后一维。

对于 R 包,当指定 strict_shape 时,将返回一个数组,其值与 Python 相同, R 数组是列主序的,而 Python 的 numpy 数组是行主序的,因此所有维度都被颠倒。例如,对于在 strict_shape=True 的情况下通过 Python predict_leaf 获得的输出有4个维度:(n_samples, n_iterations, n_classes, n_trees_in_forest),而在 R 中 strict_shape=TRUE 的输出是 (n_trees_in_forest, n_classes, n_iterations, n_samples)。

除了这些预测类型之外,还有一个称为 iteration_range 的参数,类似于模型切片。但与实际将模型拆分为多个堆栈不同,它只是返回由范围内的树形成的预测。每次迭代创建的树的数量等于num_parallel_tree。因此,如果正在训练大小为4的增强随机森林,对于3类别分类数据集,并且想要使用前2次迭代的树进行预测,需要提供 iteration_range=(0, 2)。然后将在此预测中使用前
棵树。

提前停止Early Stopping

在使用提前停止进行训练时,原生 Python 接口和 sklearn/R 接口之间存在一种不一致的行为。默认情况下,在 R 和 sklearn 接口上,会自动使用 best_iteration,因此预测将来自最佳模型。但是在原生 Python 接口中,xgboost.Booster.predict()xgboost.Booster.inplace_predict() 默认使用完整模型。用户可以使用 iteration_range 参数和 best_iteration 属性来实现相同的行为。此外,xgboost.callback.EarlyStoppingsave_best 参数可能会很有用。

基准分数Base Margin

XGBoost 中有一个名为 base_score 的训练参数,以及一个 DMatrix 的元数据称为 base_margin。它们指定了增强模型的全局偏差。如果提供了后者,则会忽略前者。base_margin 可用于基于其他模型训练 XGBoost 模型。

阶段性预测

使用 DMatrix 的原生接口,可以对预测进行阶段性(或缓存)。例如,可以首先对前4棵树进行预测,然后在8棵树上运行预测。在运行第一个预测后,前4棵树的结果被缓存,因此当您在8棵树上运行预测时,XGBoost 可以重复使用先前预测的结果。缓存会在下一次预测、训练或评估时自动过期,如果缓存的 DMatrix 对象已过期(例如,超出作用域并被语言环境中的垃圾回收器收集)。

阶段性预测

使用原生接口和 DMatrix,可以对预测进行阶段性(或缓存)。例如,可以首先对前4棵树进行预测,然后在8棵树上运行预测。在运行第一个预测后,前4棵树的结果被缓存,因此当在8棵树上运行预测时,XGBoost 可以重复使用先前预测的结果。如果缓存的 DMatrix 对象已过期(例如,超出作用域并被语言环境中的垃圾回收器收集),则缓存会在下一次预测、训练或评估时自动过期

In-place预测

传统上,XGBoost 只接受 DMatrix 进行预测,使用诸如 scikit-learn 接口之类的包装器时,构建过程会在内部发生。添加了对就地预测的支持,以绕过 DMatrix 的构建,这种构建方式速度较慢且占用内存。新的预测函数具有有限的功能,但通常对于简单的推断任务已经足够。它接受 Python 中一些常见的数据类型,如 numpy.ndarrayscipy.sparse.csr_matrixcudf.DataFrame,而不是 xgboost.DMatrix。可以调用 xgboost.Booster.inplace_predict() 来使用它。请注意,就地预测的输出取决于输入数据类型,当输入在 GPU 数据上时,输出为 cupy.ndarray,否则返回 numpy.ndarray

线程安全

在 1.4 版本之后,所有的预测函数,包括具有各种参数的正常预测(如 shap 值计算和 inplace_predict),在底层 booster 为 gbtree 或 dart 时是线程安全的,这意味着只要使用树模型,预测本身就应该是线程安全的。但是安全性仅在预测方面得到保证。如果尝试在一个线程中训练模型,并在另一个线程中使用相同的模型进行预测,则行为是未定义的。这比人们可能期望的更容易发生,例如可能会在预测函数内部意外地调用 clf.set_params()

def predict_fn(clf: xgb.XGBClassifier, X):X = preprocess(X)clf.set_params(n_jobs=1)  # NOT safe!return clf.predict_proba(X, iteration_range=(0, 10))with ThreadPoolExecutor(max_workers=10) as e:e.submit(predict_fn, ...)

隐私保护预测

Concrete ML 是由 Zama 开发的第三方开源库,提供了类似于梯度提升类,但直接在加密数据上进行预测的功能,这得益于全同态加密。一个简单的例子如下:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from concrete.ml.sklearn import XGBClassifierx, y = make_classification(n_samples=100, class_sep=2, n_features=30, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=10, random_state=42
)# Train in the clear and quantize the weights
model = XGBClassifier()
model.fit(X_train, y_train)# Simulate the predictions in the clear
y_pred_clear = model.predict(X_test)# Compile in FHE
model.compile(X_train)# Generate keys
model.fhe_circuit.keygen()# Run the inference on encrypted inputs!
y_pred_fhe = model.predict(X_test, fhe="execute")print("In clear:", y_pred_clear)
print("In FHE:", y_pred_fhe)
print(f"Similarity: {int((y_pred_fhe == y_pred_clear).mean()*100)}%")

参考

  • https://xgboost.readthedocs.io/en/latest/prediction.html

相关文章:

XGB-20:XGBoost中不同参数的预测函数

有许多在XGBoost中具有不同参数的预测函数。 预测选项 xgboost.Booster.predict() 方法有许多不同的预测选项,从 pred_contribs 到 pred_leaf 不等。输出形状取决于预测的类型。对于多类分类问题,XGBoost为每个类构建一棵树,每个类的树称为…...

websocket 使用示例

websocket 使用示例 前言html中使用vue3中使用1、安装websocket依赖2、代码 vue2中使用1、安装websocket依赖2、代码 前言 即时通讯webSocket 的使用 html中使用 以下是一个简单的 HTML 页面示例,它连接到 WebSocket 服务器并包含一个文本框、一个发送按钮以及 …...

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的水下目标检测系统(深度学习模型+UI界面+训练数据集)

摘要:本研究详述了一种采用深度学习技术的水下目标检测系统,该系统集成了最新的YOLOv8算法,并与YOLOv7、YOLOv6、YOLOv5等早期算法进行了性能评估对比。该系统能够在各种媒介——包括图像、视频文件、实时视频流及批量文件中——准确地识别水…...

中间件 Redis 服务集群的部署方案

前言 在互联网业务发展非常迅猛的早期,如果预算不是问题,强烈建议使用“增强单机硬件性能”的方式提升系统并发能力,因为这个阶段,公司的战略往往是发展业务抢时间,而“增强单机硬件性能”往往是最快的方法。 正是在这…...

生成哈夫曼树C卷(JavaPythonC++Node.jsC语言)

给定长度为n的无序的数字数组,每个数字代表二叉树的叶子节点的权值,数字数组的值均大于等于1。请完成一个函数,根据输入的数字数组,生成哈夫曼树,并将哈夫曼树按照中序遍历输出。 为了保证输出的二又树中序遍历结果统一,增加以下限制:二叉树节点中,左节点权值小于等于右…...

Java代码审计安全篇-SSRF(服务端请求伪造)漏洞

前言: 堕落了三个月,现在因为被找实习而困扰,着实自己能力不足,从今天开始 每天沉淀一点点 ,准备秋招 加油 注意: 本文章参考qax的网络安全java代码审计,记录自己的学习过程,还希望各…...

入门可解释机器学习和可解释性【内容分享和实战分析】

本篇文章为天池三月场读书会《可解释机器学习》的内容概述和项目实战分享,旨在为推广机器学习可解释性的应用提供一定帮助。 本次直播分享视频和实践代码以及PP获取地址:https://tianchi.aliyun.com/specials/promotion/activity/bookclub 目录 内容分…...

Promise其实也不难

难点图解:then()方法 ES6学习网站:ES6 入门教程 解决:回调地狱(回调函数中嵌套回调) 两个特点: (1)对象的状态不受外界影响。Promise对象代表一个异步操作&…...

吴恩达 x Open AI ChatGPT ——如何写出好的提示词视频核心笔记

核心知识点脑图如下: 1、第一讲:课程介绍 要点1: 上图展示了两种大型语言模型(LLMs)的对比:基础语言模型(Base LLM)和指令调整语言模型(Instruction Tuned LLM&#xff0…...

JVM从1%到99%【精选】-【初步认识】

目录 1.java虚拟机 2.JVM的位置 3.代码的执行流程 4.JVM的架构模型 5.JVM的生命周期 6.JVM的整体结构 1.java虚拟机 Java虚拟机是一台执行Java字节码的虚拟计算机,它拥有独立的运行机制,其运行的Java字节码也未必由Java语言编译而成。JVM平台的各种语言可以共享Java…...

pdf转图片(利用pdf2image包)

参考: pdf2image pip install pdf2image代码: from pdf2image import convert_from_path, convert_from_bytes import osoutput_folder ./xx/ dpi_value 600 pdf_start_page 1 # pdf显示的第一页 start_page 1 # 真实页码 prex # 图像前缀def to_…...

SwiftUI的转场动画

SwiftUI的转场动画 记录一下SwiftUI中的一些弹窗动画 import SwiftUIstruct TransitionBootCamp: View {State var showView falselet screenWidth UIScreen.main.bounds.widthlet screenHeight UIScreen.main.bounds.heightvar body: some View {ZStack(alignment: .botto…...

Trust Region Policy Optimization (TRPO)

Trust Region Policy Optimization (TRPO) 是一种强化学习算法,专门设计来改善策略梯度方法在稳定性和效率方面的表现。由 John Schulman 等人在 2015 年提出,TRPO 的核心思想是在策略优化过程中引入一个信任区域(trust region)&a…...

消息服务--Kafka的简介和使用

消息服务--Kafka的简介和使用 前言异步解耦削峰缓存1、消息队列2、kafka工作原理3、springBoot KafKa整合3.1 添加插件3.2 kafKa的自动配置类3.21 配置kafka地址3.22 如果需要发送对象配置kafka值的序列化器3.3 测试发送消息3.31 在发送测试消息的时候由于是开发环境中会遇到的…...

【c++11线程库的使用】

#include<iostream> #include<thread> #include<string> using namespace std; void hello(string msg) { for (int i 0; i < 1000; i) { cout << i; cout << endl; } } int main() { //1.创建线程 thread …...

无限debugger的几种处理方式

不少网站会在代码中加入‘debugger’&#xff0c;使你F12时一直卡在debugger&#xff0c;这种措施会让新手朋友束手无策。 js中创建debugger的方式有很多&#xff0c;基础的形式有&#xff1a; ①直接创建debugger debugger; ②通过eval创建debugger&#xff08;在虚拟机中…...

数据库基础理论知识

1.基本概念 数据(Data)&#xff1a;数据库存储的基本对象。数字、字符串、图形、图像、音频、视频等数据库(DB)&#xff1a;在计算机内&#xff0c;永久存储、有组织、可共享的数据集合数据库管理系统(DBMS)&#xff1a;管理数据库的系统软件数据库系统(DBS):DBDBMSDBADBAP 数…...

华为OD机试真题-模拟目录管理-2024年OD统一考试(C卷)

题目描述: 实现一个模拟目录管理功能的软件,输入一个命令序列,输出最后一条命令运行结果。 支持命令: 1)创建目录命令:mkdir 目录名称,如mkdir abc为在当前目录创建abc目录,如果已存在同名目录则不执行任何操作。此命令无输出。 2)进入目录命令:cd 目录名称, 如cd …...

yield代码解释

目录 我们的post请求爬取百度翻译的代码 详细解释 解释一 解释二 再说一下callback 总结 发现了很多人对存在有yield的代码都不理解&#xff0c;那就来详细的解释一下 我们的post请求爬取百度翻译的代码 import scrapy import jsonclass TestpostSpider(scrapy.Spider):…...

C#四部曲(知识补充)

Unity跨平台原理 .Net相关 只要编写的时候遵循.NET的这些规则&#xff0c;就能在.NET平台下通用 各种源码→根据.NET规范编写→(虚拟机)生成CIL中间码(保存在程序集中)→转成操作系统原代码 跨语言← 跨平台↓ Unity跨平台原理&#xff08;Mono&#xff09; c#脚本→MonoC#编…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》

引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...

逻辑回归:给不确定性划界的分类大师

想象你是一名医生。面对患者的检查报告&#xff08;肿瘤大小、血液指标&#xff09;&#xff0c;你需要做出一个**决定性判断**&#xff1a;恶性还是良性&#xff1f;这种“非黑即白”的抉择&#xff0c;正是**逻辑回归&#xff08;Logistic Regression&#xff09;** 的战场&a…...

8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂

蛋白质结合剂&#xff08;如抗体、抑制肽&#xff09;在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上&#xff0c;高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术&#xff0c;但这类方法普遍面临资源消耗巨大、研发周期冗长…...

【决胜公务员考试】求职OMG——见面课测验1

2025最新版&#xff01;&#xff01;&#xff01;6.8截至答题&#xff0c;大家注意呀&#xff01; 博主码字不易点个关注吧,祝期末顺利~~ 1.单选题(2分) 下列说法错误的是:&#xff08; B &#xff09; A.选调生属于公务员系统 B.公务员属于事业编 C.选调生有基层锻炼的要求 D…...

Python如何给视频添加音频和字幕

在Python中&#xff0c;给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加&#xff0c;包括必要的代码示例和详细解释。 环境准备 在开始之前&#xff0c;需要安装以下Python库&#xff1a;…...

Java入门学习详细版(一)

大家好&#xff0c;Java 学习是一个系统学习的过程&#xff0c;核心原则就是“理论 实践 坚持”&#xff0c;并且需循序渐进&#xff0c;不可过于着急&#xff0c;本篇文章推出的这份详细入门学习资料将带大家从零基础开始&#xff0c;逐步掌握 Java 的核心概念和编程技能。 …...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

python执行测试用例,allure报乱码且未成功生成报告

allure执行测试用例时显示乱码&#xff1a;‘allure’ &#xfffd;&#xfffd;&#xfffd;&#xfffd;&#xfffd;ڲ&#xfffd;&#xfffd;&#xfffd;&#xfffd;ⲿ&#xfffd;&#xfffd;&#xfffd;Ҳ&#xfffd;&#xfffd;&#xfffd;ǿ&#xfffd;&am…...

Caliper 负载(Workload)详细解析

Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...

在 Spring Boot 中使用 JSP

jsp&#xff1f; 好多年没用了。重新整一下 还费了点时间&#xff0c;记录一下。 项目结构&#xff1a; pom: <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://ww…...