【Python篇】深入机器学习核心:XGBoost 从入门到实战
文章目录
- XGBoost 完整学习指南:从零开始掌握梯度提升
- 1. 前言
- 2. 什么是XGBoost?
- 2.1 梯度提升简介
- 3. 安装 XGBoost
- 4. 数据准备
- 4.1 加载数据
- 4.2 数据集划分
- 5. XGBoost 基础操作
- 5.1 转换为 DMatrix 格式
- 5.2 设置参数
- 5.3 模型训练
- 5.4 预测
- 6. 模型评估
- 7. 超参数调优
- 7.1 常用超参数
- 7.2 网格搜索
- 8. XGBoost 特征重要性分析
- 9. 高级功能扩展
- 9.1 模型解释与可解释性
- 9.2 XGBoost 与交叉验证
- 9.3 处理缺失值
- 10. XGBoost 在不同任务中的应用
- 10.1 回归任务
- 10.2 二分类任务
- 11. 分布式训练
- 12. 实战案例:XGBoost 与 Kaggle 竞赛
- 总结
XGBoost 完整学习指南:从零开始掌握梯度提升
💬 欢迎讨论:如果你在学习过程中有任何问题或想法,欢迎在评论区留言,我们一起交流学习。你的支持是我继续创作的动力!
👍 点赞、收藏与分享:觉得这篇文章对你有帮助吗?别忘了点赞、收藏并分享给更多的小伙伴哦!你们的支持是我不断进步的动力!
🚀分享给更多人:如果你觉得这篇文章对你有帮助,欢迎分享给更多对C++感兴趣的朋友,让我们一起进步!
1. 前言
在机器学习中,XGBoost 是一种基于梯度提升的决策树(GBDT)实现,因其卓越的性能和速度,广泛应用于分类、回归等任务。尤其在Kaggle竞赛中,XGBoost以其强大的表现受到开发者青睐。
本文将带你从安装、基本概念到模型调优,全面掌握 XGBoost 的使用。
2. 什么是XGBoost?
2.1 梯度提升简介
XGBoost是基于梯度提升框架的一个优化版本。梯度提升是一种迭代的集成算法,通过不断构建新的树来补充之前模型的错误。它依赖多个决策树的集成效果,来提高最终模型的预测能力。
- Boosting:通过组合多个弱分类器来生成强分类器。
- 梯度提升:使用损失函数的梯度信息来逐步优化模型。
XGBoost 提供了对内存效率、计算速度、并行化的优化,是一个非常适合大数据和高维数据集的工具。
3. 安装 XGBoost
首先,我们需要安装 XGBoost 库。可以通过 pip
安装:
pip install xgboost
如果你使用的是 Jupyter Notebook,可以通过以下命令安装:
!pip install xgboost
安装完成后,使用以下代码验证:
import xgboost as xgb
print(xgb.__version__) # 显示安装的版本号
如果正确输出版本号,则表示安装成功。
4. 数据准备
在机器学习中,数据预处理至关重要。我们将使用经典的鸢尾花数据集(Iris dataset),这是一个用于分类任务的多类数据集。
4.1 加载数据
通过 Scikit-learn 轻松获取鸢尾花数据:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split# 加载数据
iris = load_iris()
X, y = iris.data, iris.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
4.2 数据集划分
为了评估模型性能,我们将数据集分为训练集和测试集,训练集用于模型训练,测试集用于性能评估。
# 查看训练集和测试集的大小
print(X_train.shape, X_test.shape)
5. XGBoost 基础操作
XGBoost 的核心数据结构是 DMatrix
,它是经过优化的内部数据格式,具有更高的内存和计算效率。
5.1 转换为 DMatrix 格式
我们将训练集和测试集转换为 DMatrix 格式:
# 转换为 DMatrix 格式
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)
DMatrix 支持稀疏矩阵,可以显著提升大型数据集的内存效率。
5.2 设置参数
XGBoost 提供了大量的超参数可以调节。我们从一些基本参数开始:
# 设置参数
params = {'objective': 'multi:softmax', # 多分类问题'num_class': 3, # 类别数量'max_depth': 4, # 树的最大深度'eta': 0.3, # 学习率'seed': 42
}
- objective:损失函数,这里我们选择的是多分类的
softmax
。 - num_class:类别的数量。
- max_depth:树的最大深度,越深的树更复杂,但容易过拟合。
- eta:学习率,用于控制每棵树对最终模型影响的大小。
5.3 模型训练
通过以下代码训练模型:
# 训练模型
num_round = 10 # 迭代次数
bst = xgb.train(params, dtrain, num_boost_round=num_round)
5.4 预测
训练完成后,我们可以使用测试集进行预测:
# 预测
preds = bst.predict(dtest)
print(preds)
此时输出的是模型对每个样本的预测类别。
6. 模型评估
XGBoost 支持多种评估指标。我们可以使用 Scikit-learn
提供的 accuracy_score
来评估模型的准确性。
from sklearn.metrics import accuracy_score# 计算准确率
accuracy = accuracy_score(y_test, preds)
print(f"模型准确率: {accuracy:.2f}")
假设输出为:
模型准确率: 0.98
98% 的准确率表示模型在鸢尾花数据集上的表现非常好。
7. 超参数调优
XGBoost 提供了丰富的超参数,适当的调优可以显著提升模型性能。我们可以使用 GridSearchCV
进行超参数搜索。
7.1 常用超参数
max_depth
:树的深度,影响模型复杂度和过拟合风险。learning_rate
(或eta
):学习率,控制每次迭代的步长。n_estimators
:提升树的数量,即训练的轮数。
7.2 网格搜索
我们使用 GridSearchCV
来对这些超参数进行调优:
from sklearn.model_selection import GridSearchCV
from xgboost import XGBClassifier# 创建模型
model = XGBClassifier()# 定义参数网格
param_grid = {'max_depth': [3, 4, 5],'n_estimators': [50, 100, 200],'learning_rate': [0.1, 0.3, 0.5]
}# 使用网格搜索
grid_search = GridSearchCV(model, param_grid, scoring='accuracy', cv=3)
grid_search.fit(X_train, y_train)# 输出最佳参数
print("最佳参数组合:", grid_search.best_params_)
网格搜索会自动尝试不同的参数组合,最后返回最优组合。
8. XGBoost 特征重要性分析
XGBoost 提供了内置的方法来分析特征的重要性。这有助于理解哪些特征对模型影响最大。
# 绘制特征重要性
xgb.plot_importance(bst)
plt.show()
特征重要性图将显示每个特征对模型的影响,帮助开发者进一步优化模型。
9. 高级功能扩展
9.1 模型解释与可解释性
对于生产环境中的应用,解释模型预测结果至关重要。你可以使用 SHAP (SHapley Additive exPlanations) 来解释 XGBoost 模型的预测。它帮助我们理解特征对预测结果的影响。
安装并使用 SHAP:
pip install shap
import shap# 使用 SHAP 解释模型
explainer = shap.TreeExplainer(bst)
shap_values = explainer.shap_values(dtest)# 可视化 SHAP 值
shap.summary_plot(shap_values, X_test)
这个图表将展示每个特征如何影响预测输出,红色表示正向影响,蓝色表示负向影响。
9.2 XGBoost 与交叉验证
交叉验证(Cross-Validation, CV)是一种常见的评估方法,用来减少过拟合的风险。XGBoost 提供了内置的交叉验证功能:
cv_results = xgb.cv(params, dtrain, num_boost_round=50, nfold=5, metrics="mlogloss", as_pandas=True, seed=42
)# 输出交叉验证结果
print(cv_results)
通过 xgb.cv
,我们可以在不同的参数组合下进行多次训练,计算出平均损失值或准确率,从而找到最优的超参数。
9.3 处理缺失值
XGBoost 具有强大的处理缺失值能力,它会在训练过程中自动处理数据中的缺失值,选择最优的分裂方式。这使得它非常适合应用在含有缺失值的真实数据集上。
例如,如果数据中有缺失值,XGBoost 不需要手动填补:
import numpy as np
# 假设数据集中有 NaN 值
X_train[0, 0] = np.nan
dtrain = xgb.DMatrix(X_train, label=y_train)
10. XGBoost 在不同任务中的应用
10.1 回归任务
XGBoost 不仅适用于分类问题,也可以处理回归问题。在回归任务中,目标函数可以设置为 reg:squarederror
,这是最常见的回归目标:
params = {'objective': 'reg:squarederror', # 回归任务'max_depth': 4,'eta': 0.1,
}# 加载样例数据(例如房价预测)
from sklearn.datasets import load_boston
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)# 训练回归模型
bst = xgb.train(params, dtrain, num_boost_round=100)# 进行预测
preds = bst.predict(dtest)
print(preds)
10.2 二分类任务
对于二分类问题,我们可以将目标函数设置为 binary:logistic
,输出预测值为一个概率。
params = {'objective': 'binary:logistic','max_depth': 4,'eta': 0.3,
}# 假设我们有一个二分类数据集
from sklearn.datasets import make_classification
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)# 训练模型
bst = xgb.train(params, dtrain, num_boost_round=100)# 进行预测
preds = bst.predict(dtest)
11. 分布式训练
XGBoost 支持多机多 GPU 的分布式训练,这使得它在大规模数据集上具有很高的可扩展性。要启用分布式训练,首先需要搭建集群,并配置相应的参数。
XGBoost 通过 Rabit
框架进行节点间的通信,支持通过 Spark、Dask 等框架实现分布式训练。你可以在大规模数据集上使用 XGBoost 高效地进行训练。
12. 实战案例:XGBoost 与 Kaggle 竞赛
XGBoost 在许多 Kaggle 竞赛中取得了优异的成绩。以下是一个实际案例:我们将使用泰坦尼克号乘客生存预测数据集,进行完整的模型训练与评估。
import pandas as pd# 加载泰坦尼克号数据
train = pd.read_csv('train.csv')
test = pd.read_csv('test.csv')# 数据预处理
train['Age'].fillna(train['Age'].mean(), inplace=True)
train['Embarked'].fillna('S', inplace=True)
train['Fare'].fillna(train['Fare'].mean(), inplace=True)# 特征处理
train['Sex'] = train['Sex'].map({'male': 0, 'female': 1})
train['Embarked'] = train['Embarked'].map({'S': 0, 'C': 1, 'Q': 2})# 特征和标签
X_train = train[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']]
y_train = train['Survived']dtrain = xgb.DMatrix(X_train, label=y_train)# 设置参数
params = {'objective': 'binary:logistic','max_depth': 3,'eta': 0.1,'eval_metric': 'logloss'
}# 训练模型
bst = xgb.train(params, dtrain, num_boost_round=100)# 对测试集进行预测
dtest = xgb.DMatrix(test[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']])
preds = bst.predict(dtest)
这是一个简单的例子,展示了如何使用 XGBoost 处理分类任务并进行模型预测。根据任务复杂度,可以通过特征工程和调参来提升模型表现。
总结
在本教程中,我们详细介绍了 XGBoost 的各个方面,从基础到高级应用,包括分类、回归、特征重要性、调参、分布式训练等。XGBoost 作为高效的梯度提升工具,在各种机器学习任务中都表现优异。通过不断的实践和优化,你可以让 XGBoost 在实际项目中发挥更大的作用。
以上就是关于【Python篇】深入机器学习核心:XGBoost 从入门到实战的内容啦,各位大佬有什么问题欢迎在评论区指正,或者私信我也是可以的啦,您的支持是我创作的最大动力!❤️
相关文章:

【Python篇】深入机器学习核心:XGBoost 从入门到实战
文章目录 XGBoost 完整学习指南:从零开始掌握梯度提升1. 前言2. 什么是XGBoost?2.1 梯度提升简介 3. 安装 XGBoost4. 数据准备4.1 加载数据4.2 数据集划分 5. XGBoost 基础操作5.1 转换为 DMatrix 格式5.2 设置参数5.3 模型训练5.4 预测 6. 模型评估7. 超…...

简单学习 原码反码补码 学会了你才是真正的程序员了
一、简单介绍原码反码补码 首先我们需要知道的是原码反码补码是一个人为的行为,因为机器看的都是所谓的补码,这个反码只是作为补码的到原码也就是人能看懂的跳板,所以计算机无论是计算器里面的东西还是他底层运行的二进制代码都是补码&#x…...

基于规则的命名实体识别
基于规则的命名实体识别(Rule-Based Named Entity Recognition, NER)是一种通过预定义的模式或规则来识别文本中特定实体的方法。这种方法通常使用正则表达式来匹配文本中的实体。下面是一个更完整的示例,展示了如何使用正则表达式来识别文本…...

C语言从头学63—学习头文件stdlib.h(二)
6、随机数函数rand() 功能:产生0~RAND_MAX 之间的随机整数。 使用格式:rand(); //无参 返回值:返回随机整数 说明: a.RAND_MAX是一个定义在stdlib.h里面的宏,表示可以产生的最大随机整数&am…...

js判断一个对象里有没有某个属性
1. 使用in操作符 in操作符可以用来检测属性是否存在于对象或其原型链中。 const obj {a: 1, b: 2}; if (a in obj) { console.log(属性a存在于obj中); } else { console.log(属性a不存在于obj中); } 2. 使用hasOwnProperty()方法 hasOwnProperty()方法用来检测一个…...

Python(爬虫)正则表达式
正则表达式是文本匹配模式,也就是按照固定模式匹配文本 一、导入 re模块是Python环境的内置模块,所以无需手动安装。直接在文件中导入即可: import re 二、正则表达式基础知识 . 匹配除换行符以外的任意字符 ^ 匹配字符串的开始 $ 匹配字…...

Linux:进程(二)
目录 一、cwd的理解 二、fork的理解 1.代码共享 2.各司其职 3.fork的返回值 三、进程状态 1.进程排队 2.进程状态 运行状态 阻塞状态 挂起状态 一、cwd的理解 cwd(current working directory)。译为当前工作目录。 在C语言中,使用…...

【UE5】将2D切片图渲染为体积纹理,最终实现使用RT实时绘制体积纹理【第二篇-着色器制作】
在上一篇文章中,我们已经理顺了实现流程。 接下来,我们将在UE5中,从头开始一步一步地构建一次流程。 通过这种方法,我们可以借助一个熟悉的开发环境,使那些对着色器不太熟悉的朋友们更好地理解着色器的工作原理。 这篇…...

【OceanBase 诊断调优】—— GC问题根因分析
GC 流程涉及到 RS 的状态切换和 LS 的资源安全回收,流程上较长。且 GC 线程每个租户仅有一个,某个日志流 GC Hang 死时会卡住所有其余日志流的 GC,进而造成更大的影响。 本文档会帮助大家快速定位到 GC 故障的模块,直达问题核心。…...

图像面积计算一般方法及MATLAB实现
一、引言 在数字图像处理中,经常需要获取感兴趣区域的面积属性,下面给出图像处理的一般步骤。 1.读入的彩色图像 2.将彩色图像转化为灰度图像 3.灰度图像转化为二值图像 4.区域标记 5.对每个区域的面积进行计算和显示 二、程序代码 %面积计算 cle…...

指挥平台在应急场所中的主要表现有哪些
在应对自然灾害、公共安全事件等突发危机时,指挥平台作为应急管理体系的核心枢纽,其重要性不言而喻。它不仅承载着信息的快速汇聚、精准分析与高效调度功能,更在应急场所中有一定的关键表现。接下来就跟着北京嘉德立一起了解一下。 一、信息集…...

智能养殖场人机交互检测系统源码分享
智能养殖场人机交互检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Co…...

数据集-目标检测系列-海洋鱼类检测数据集 fish>> DataBall
数据集-目标检测系列-海洋鱼类检测数据集 fish>> DataBall 数据集-目标检测系列-海洋鱼类检测数据集 fish 数据量:1W 数据项目地址: gitcode: https://gitcode.com/DataBall/DataBall-detections-100s/overview github: https://github.com/…...

网络威慑战略带来的影响
文章目录 前言一、网络威慑的出现1、人工智能带来的机遇二、网络空间的威慑困境1、威慑概念的提出2、网络威慑的限度3、人类对网络威胁的认知变化4、网络空间的脆弱性总结前言 网络威慑是国家为应对网络空间风险和威胁而采取的战略。冷战时期核威慑路径难以有效复制至网络空间…...

决策树算法在机器学习中的应用
决策树算法在机器学习中的应用 决策树(Decision Tree)算法是一种基本的分类与回归方法,它通过树状结构对数据进行建模,以解决分类和回归问题。决策树算法在机器学习中具有广泛的应用,其直观性、易于理解和实现的特点使…...

Leetcode面试经典150题-39.组合总数进阶:40.组合总和II
本题是扩展题,真实考过,看这个题之前先看一下39题 Leetcode面试经典150题-39.组合总数-CSDN博客 给定一个候选人编号的集合 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。 candidates 中的每个数…...

ProcessOn为什么导出有水印!!!(利用SVG转PNG)
processon-svg2png ProcessOn 一个非常好用的思维导图网站,但是为什么导出有水印!!!。 功能 支持按钮拖拽支持将流程图svg 转成 png下载支持修改自定义文字下载svg(开发中) 安装/使用方法 安装并使用…...

插入、更新与删除MySQL记录
在现代应用开发中,数据库操作是非常重要的一环。作为程序员,熟练掌握数据库的增删改功能,能够更有效地管理数据并提高开发效率。 本课程将围绕插入、更新与删除记录这三个操作展开,涵盖SQL中的常见语句:INSERT INTO、UPDATE 和 DELETE,并结合实际应用中的常见问题讨论如…...

【ARM】armv8的虚拟化深度解读
Type-1 hypervisor Type-1虚拟化也叫做Bare metal, standalone, Type1 Type2 hypervisor Type-2虚拟化也叫做hosted, Type-2 VM和vCPU(虚拟机和虚拟cpu) 在一个VM(虚拟机)中有多个vCPU,多个vCPU可能属于同一个Vritual Processor。 EL2…...

9/24作业
1. 分文件编译 分什么要分文件编译? 防止主文件过大,不好修改,简化编译流程 1) 分那些文件 头文件:所有需要提前导入的库文件,函数声明 功能函数:所有功能函数的定义 主函数:main函数&…...

Leetcode 106. 从中序与后序遍历序列构造二叉树
给定两个整数数组 inorder 和 postorder ,其中 inorder 是二叉树的中序遍历, postorder 是同一棵树的后序遍历,请你构造并返回这颗 二叉树 。 示例 1: 输入:inorder [9,3,15,20,7], postorder [9,15,7,20,3] 输出:[3…...

针对考研的C语言学习(定制化快速掌握重点1)
1.printf函数的几个要点 printf函数中所有的输出都是右对齐的,除非在%后面添加负号,则表示左对齐 #include<stdio.h> int main() {int num 10;int nums 100;float f 1000.2333333333;printf("%3d\n", nums);//%3d表示输出的总宽度至…...

【大数据入门 | Hive】DDL数据定义语言(数据库DataBase)
1. 数据库(DataBase) 1.1 创建数据库 语法: CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment] [LOCATION hdfs_path] [WITH DBPROPERTIES (property_nameproperty_value, ...)]; 案例: (1)创建一个…...

CNVD漏洞和证书挖掘经验总结
前言 本篇文章主要是分享一下本人挖掘CVND漏洞碰到的一些问题,根据过往成功归档的漏洞和未归档的漏洞总结出的经验,也确实给审核的大佬们添了很多麻烦(主要真的没人教一下,闷着头尝试犯了好很多错误,希望各位以后交一个…...

阿里rtc旁路推流TypeScript版NODE运行
阿里云音视频服务云端录制typescript版本; 编译后可以使用 node index.js运行 package.json 版本 // npm install --save alicloud/rtc201801112.3.0 "alicloud/rtc20180111": "^2.3.0",引入 import Client, { StartCloudRecordRequest, StopCloudRecord…...

计算机书籍分享
0.简介 数据库系统概念、深入理解计算机系统、领域驱动设计、Linux高性能服务器编程 高清版本pdf 1.链接 数据库系统概念: 链接: https://pan.baidu.com/s/17zz7QFevV2Eni9qHJyLEGA 提取码: wfrx 深入理解计算机系统 链接: https://pan.baidu.com/s/19yiJG8GqHJR…...

处理ASAM-MDF格式的开源python库asammdf
asammdf是一个强大的Python库,专为处理ASAM(Association for Standardization of Automation and Measuring Systems)MDF(Measurement Data Format)文件而设计。MDF是一种用于存储测量和诊断数据的标准格式,…...

物业管理小程序开发
物业小程序的开发是一个综合性的项目,旨在提升物业管理效率和增强业主的服务体验。以下是关于物业小程序开发的一些关键方面: 一、需求分析 目标用户:识别主要用户群体,包括业主、租户、物业管理人员等。 功能需求: 物…...

【Vue】Pinia
系列文章目录 第八章 Pinia 文章目录 系列文章目录前言一、安装和配置:二、基本使用三、一个更真实的例子 前言 Pinia是Vue.js应用程序的状态管理库。它提供了一种简单,轻量级的解决方案,用于在Vue应用程序中管理和维护状态。Pinia库的特点…...

帕金森病患者的生命长度:科学管理与乐观心态是关键
在快节奏的现代生活中,健康成为了我们最宝贵的财富之一。然而,当“帕金森病”这个名词悄然进入我们的视野时,不少人心中难免会涌起一丝不安与担忧。帕金森病,作为一种常见的神经系统退行性疾病,确实给患者的日常生活带…...