XGBOOST算法Python实现(保姆级)
摘要
XGBoost算法(eXtreme Gradient Boosting)在目前的Kaggle、数学建模和大数据应用等竞赛中非常流行。本文将会从XGBOOST算法原理、Python实现、敏感性分析和实际应用进行详细说明。
目录
0 绪论
一、材料准备
二、算法原理
三、算法Python实现
3.1 数据加载
3.2 将目标变量的定类数据分类编码
3.3 将数据分为训练数据和测试数据
3.4训练XGBOOST模型
3.5 测试模型
3.6 输出模型的预测混淆矩阵(结果矩阵)
3.7 输出模型准确率
3.8 绘制混淆矩阵图
3.9 完整实现代码
3.10 结果输出示例
四、 XGBOOST算法的敏感性分析和实际应用
4.1 敏感性分析
4.2 算法应用
五、结论
六、备注
0 绪论
数据挖掘和数学建模等比赛中,除了算法的实现,还需要对数据进行较为合理的预处理,包括缺失值处理、异常值处理、定类数据特征编码和冗余特征的删除等等,本文默认读者的数据均已完成数据预处理,如有需要,后续会将数据预处理的方法也进行发布。
一、材料准备
Python编译器:Pycharm社区版或个人版等
训练数据集:此处使用2022年数维杯国际大学生数学建模竞赛C题的附件数据为例。
数据处理:经过初步数据清洗和相关性分析等操作得到初步的特征,并利用决策树进行特征重要性分析,完成二次特征降维,得到'CDRSB_bl', 'PIB_bl', 'FBB_bl'三个自变量特征,DX_bl为分类特征。
二、算法原理
XGBOOST算法基于决策树的集成方法,主要采用了Boosting的思想,是Gradient Boosting算法的扩展,并使用梯度提升技术来提高模型的准确性和泛化能力。
首先将基分类器层层叠加,然后每一层在训练的时候,对前一层基分类器分错的样本,给予更高的权重,XGBOOST的目标函数为:

(1)
其中,
为损失函数;
为正则项,用于控制树的复杂度;
为常数项,
为新树的预测值
,它是将树的个数的结果进行求和。
三、算法Python实现
3.1 数据加载
此处导入本文所需数据,DataX为自变量数据,DataY为目标变量数据(DX_bl)。
import pandas as pd
X = pd.DataFrame(pd.read_excel('DataX.xlsx')).values # 输入特征
y = pd.DataFrame(pd.read_excel('DataY.xlsx')).values # 目标变量
3.2 将目标变量的定类数据分类编码
此处仅用0-4来代替五类数据,因为此处仅做预测,并不涉及相关性分析等其他操作,所以普通的分类编码就可以。如果需要用来做相关性分析或其他计算型操作,建议使用独热编码(OneHot- Encoding)。
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y = le.fit_transform(y)
label_mapping = {0: 'AD', 1: 'CN', 2: 'EMCI', 3: 'LMCI', 4: 'SMC'}
#此处为了后续输出混淆矩阵时,用原始数据输出
3.3 将数据分为训练数据和测试数据
本文将原始样本数据通过随机洗牌,并将70%的样本数据作为训练数据,30%的样本数据作为测试数据。这是一个较为常见的拆分方法,读者可通过不同的拆分测试最佳准确率和F1-score。
from sklearn.model_selection import train_test_split
# 将数据分为训练数据和测试数据
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, train_size=0.7, random_state=42)
3.4训练XGBOOST模型
基于70%的样本数据进行训练建模,python有XGBOOST算法的库,所以很方便就可以调用。
import xgboost as xgb
# 训练XGBoost分类器
model = xgb.XGBClassifier()
model.fit(X_train, y_train)
#xgb.plot_tree(model)
3.5 测试模型
利用另外的30%样本数据进行测试模型准确率、精确率、召回率和F1度量值。
# 使用测试数据预测类别
y_pred = model.predict(X_test)
3.6 输出模型的预测混淆矩阵(结果矩阵)
此处输出混淆矩阵的方法和之前的随机森林、KNN算法都有点不同,因为随机森拉算法不需要将定类数据进行分类编码就可以直接调用随机森林算法模型。
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
cm = confusion_matrix(y_test, y_pred)
# 输出混淆矩阵
for i, true_label in enumerate(label_mapping.values()):row = ''for j, pred_label in enumerate(label_mapping.values()):row += f'{cm[i, j]} ({pred_label})\t'print(f'{row} | {true_label}')# 输出混淆矩阵
print(classification_report(y_test, y_pred,target_names=['AD', 'CN', 'EMCI', 'LMCI', 'SMC'])) # 输出混淆矩阵
3.7 输出模型准确率
#此处的导库在上一个代码段中已引入
print("Accuracy:")
print(accuracy_score(y_test, y_pred))
3.8 绘制混淆矩阵图
将混淆矩阵结果图绘制并输出,可以将这一结果图放在论文中,提升论文美感和信服度。
import matplotlib.pyplot as plt
import numpy as np
label_names = ['AD', 'CN', 'EMCI', 'LMCI', 'SMC']
cm = confusion_matrix(y_test, y_pred)# 绘制混淆矩阵图
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),yticks=np.arange(cm.shape[0]),xticklabels=label_names, yticklabels=label_names,title='Confusion matrix',ylabel='True label',xlabel='Predicted label')# 在矩阵图中显示数字标签
thresh = cm.max() / 2.
for i in range(cm.shape[0]):for j in range(cm.shape[1]):ax.text(j, i, format(cm[i, j], 'd'),ha="center", va="center",color="white" if cm[i, j] > thresh else "black")fig.tight_layout()
#plt.show()
plt.savefig('XGBoost_Conclusion.png', dpi=300)
上面的代码首先计算混淆矩阵,然后使用 matplotlib 库中的 imshow 函数将混淆矩阵可视化,最后通过 text 函数在混淆矩阵上添加数字,并使用 show/savefig 函数显示图像,结果输出如图3.1所示。

图3.1 混淆矩阵结果图
3.9 完整实现代码
# 导入需要的库
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import numpy as nple = LabelEncoder()
label_mapping = {0: 'AD', 1: 'CN', 2: 'EMCI', 3: 'LMCI', 4: 'SMC'}
X = pd.DataFrame(pd.read_excel('DataX.xlsx')).values # 输入特征
y = pd.DataFrame(pd.read_excel('DataY.xlsx')).values # 目标变量
y = le.fit_transform(y)
# 将数据分为训练数据和测试数据
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, train_size=0.7, random_state=42)
# 训练XGBoost分类器
model = xgb.XGBClassifier()
model.fit(X_train, y_train)
#xgb.plot_tree(model)
# 使用测试数据预测类别
y_pred = model.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
# 输出混淆矩阵
for i, true_label in enumerate(label_mapping.values()):row = ''for j, pred_label in enumerate(label_mapping.values()):row += f'{cm[i, j]} ({pred_label})\t'print(f'{row} | {true_label}')# 输出混淆矩阵
print(classification_report(y_test, y_pred,target_names=['AD', 'CN', 'EMCI', 'LMCI', 'SMC'])) # 输出混淆矩阵
print("Accuracy:")
print(accuracy_score(y_test, y_pred))# label_names 是分类变量的取值名称列表
label_names = ['AD', 'CN', 'EMCI', 'LMCI', 'SMC']
cm = confusion_matrix(y_test, y_pred)# 绘制混淆矩阵图
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),yticks=np.arange(cm.shape[0]),xticklabels=label_names, yticklabels=label_names,title='Confusion matrix',ylabel='True label',xlabel='Predicted label')# 在矩阵图中显示数字标签
thresh = cm.max() / 2.
for i in range(cm.shape[0]):for j in range(cm.shape[1]):ax.text(j, i, format(cm[i, j], 'd'),ha="center", va="center",color="white" if cm[i, j] > thresh else "black")fig.tight_layout()
#plt.show()
plt.savefig('XGBoost_Conclusion.png', dpi=300)
# 上面的代码首先计算混淆矩阵,然后使用 matplotlib 库中的 imshow 函数将混淆矩阵可视化,最后通过 text 函数在混淆矩阵上添加数字,并使用 show/savefig 函数显示图像。
3.10 结果输出示例

图3.2 结果输出示例
四、 XGBOOST算法的敏感性分析和实际应用
4.1 敏感性分析
敏感性分析也叫做稳定性分析,可以基于统计学思想,通过百次测试,记录其准确率、精确率、召回率和F1-Score的数据,统计其中位数、平均值、最大值和最小值等数据,从而进行对应的敏感性分析。结果表明符合原模型成立,则通过了敏感性分析。前面的随机森林算法和KNN算法也是如此。
4.2 算法应用
XGBOOST算法可应用于大数据分析、预测等方面,尤其是大数据竞赛(Kaggle、阿里天池等竞赛中)特别常用,也是本人目前认为最好用的一个算法。
五、结论
本文基于XGBOOST算法,从数据预处理、算法原理、算法实现、敏感性分析和算法应用都做了具体的分析,可适用于大部分机器学习算法初学者。
六、备注
本文为原创文章,禁止转载,违者必究。如需原始数据,可点赞+收藏,然后私聊作者或在评论区中留下你的邮箱,即可获得训练数据一份。
相关文章:
XGBOOST算法Python实现(保姆级)
摘要 XGBoost算法(eXtreme Gradient Boosting)在目前的Kaggle、数学建模和大数据应用等竞赛中非常流行。本文将会从XGBOOST算法原理、Python实现、敏感性分析和实际应用进行详细说明。 目录 0 绪论 一、材料准备 二、算法原理 三、算法Python实现 3…...
JDK、MAVEN与IDEA的安装与配置
1.认识JDK、MAVEN与IDEA JDK 提供了编译和运行Java程序的基本环境。Maven 帮助管理项目的构建和依赖。IDEA 提供了一个强大的开发环境,使得编写、调试和运行Java程序更加高效。 2. 安装与环境配置 2.1 官网地址 选择你需要的版本下载: MAVEN下载传送…...
输出比较简介
输出比较简介 主要是用来输出PWM波形,这个波形是驱动电机的(智能车和机器人等)必要条件 OC(Output Compare)输出比较,还有IC,全称是Input Capture,意为输入捕获,还有CC…...
什么是反向 DNS 查找以及它的作用是什么?
反向DNS查询(rDNS)是一种技术,用于确定与某个IP地址对应的域名。当我们对一个IP地址进行反向DNS查询时,实际上是向域名系统(DNS)的特殊部分请求信息,这部分被称为PTR记录。PTR记录会返回与这个I…...
集群聊天服务器(13)redis环境安装和发布订阅命令
目录 环境安装订阅redis发布-订阅的客户端编程环境配置客户端编程 功能测试 环境安装 sudo apt-get install redis-server 先启动redis服务 /etc/init.d/redis-server start默认在6379端口上 redis是存键值对的,还可以存链表、数组等等复杂数据结构 而且数据是在…...
[ubuntu]编译共享内存读取出现read.c:(.text+0x1a): undefined reference to `shm_open‘问题解决方案
问题log /tmp/ccByifPx.o: In function main: read.c:(.text0x1a): undefined reference to shm_open read.c:(.text0xd9): undefined reference to shm_unlink collect2: error: ld returned 1 exit status 程序代码 #include <stdio.h> #include <stdlib.h> #…...
Python Matplotlib 安装指南:使用 Miniconda 实现跨 Linux、macOS 和 Windows 平台安装
Python Matplotlib 安装指南:使用 Miniconda 实现跨 Linux、macOS 和 Windows 平台安装 Matplotlib是Python最常用的数据可视化工具之一,结合Miniconda可以轻松管理安装和依赖项。在这篇文章中,我们将详细介绍如何使用Miniconda在Linux、mac…...
DimensionX 部署笔记
目录 生成视频用CogVideoX-5b-I2V 推理代码: DimensionX 生成视频用CogVideoX-5b-I2V 推理代码: 可以生成,从左向右旋转的,也可以生成从上往下旋转的: import torch from diffusers import CogVideoXImageToVideo…...
django从入门到精通(五)——表单与模型
好的,下面将详细介绍 Django 的表单与模型,包括它们的定义、使用、如何在 Django Admin 中结合使用,以及相关的字段类型和验证机制。 Django 模型与表单 1. Django 模型 Django 模型是一个 Python 类,用于定义数据库中的数据结…...
C语言Day 03 学习总结
Day 03 学习总结 流程控制语句 顺序结构 程序从上到下依次执行。每一条语句顺序执行,直到结束。 选择结构 程序根据条件选择执行某一条分支。包括 if-else 和 switch-case。 循环结构 程序反复执行某段代码。包括 for、while、do-while。 跳转结构 控制程序直接跳…...
kafka中是如何快速定位到一个offset的
定位到具体的segment日志文件,采用二分法先定位到index索引文件计算查找的offset在日志文件的相对偏移量 1、分区和日志段: 每个主题的分区(Partition)被划分为多个日志段(Log Segment)。每个日志段是一个…...
视频对接rtsp协议学习
RTSP协议在视频平台中的应用 RTSP(Real Time Streaming Protocol)是一种基于TCP/IP的应用层协议,主要用于控制流媒体数据的传输和播放。它通过定义一系列命令和请求,实现对流媒体服务器的远程控制,但不传输媒体数据…...
【系统架构设计师】真题论文: 论企业信息化规划的实施与应用(包括解题思路和素材)
更多内容请见: 备考系统架构设计师-专栏介绍和目录 文章目录 真题题目(2012年 试题4)解题思路论文素材参考企业信息化规划概念与主要内容企业信息化规划实施的步骤企业信息化规划的应用案例真题题目(2012年 试题4) 企业信息化建设是一项长期而艰巨的任务,不可能在短时间…...
【ARM Coresight OpenOCD 系列 6.1 -- JTAG Commands】
请阅读【嵌入式开发学习必备专栏】 文章目录 JTAG Transport使用场景配置示例JTAG Speed配置示例初始化过程中的速度调整自适应时钟选择合适的速度Low Level JTAG Commandsdrscanflush_countirscanpathmoveruntestverify_ircaptureverify_jtagJTAG Transport OpenOCD 是一个强…...
开源许可协议
何同学推动了开源协议的认识,功不可没,第一次对开源有了清晰的认识,最宽松的MIT开源协议 源自OSC开源社区:何同学使用开源软件“翻车”,都别吵了!扯什么违反MIT...
241121学习日志——[CSDIY] [InternStudio] 大模型训练营 [11]
CSDIY:这是一个非科班学生的努力之路,从今天开始这个系列会长期更新,(最好做到日更),我会慢慢把自己目前对CS的努力逐一上传,帮助那些和我一样有着梦想的玩家取得胜利!!&…...
跟千里马学框架 遇到的坑
在编译 aosp 的 所有的东西都是和他一样的, 但是出现了这个问题 emulator: command not found 明明所有的都是一样的但是出现了这个问题 , 啥情况 。 首先你的 ubuntu 要开启虚拟机 。 这个自己百度去 重新进行这些步骤 1、 . build/envsetup.s…...
Swift从0开始学习 协议和扩展 day5
协议:定义行为的契约 协议类似于其他语言中的接口。它们定义了一组方法、属性或其他需求,供结构体、类、枚举等类型去遵循和实现。协议并不实现这些需求,而是作为一种约定或合同,确保实现协议的类型会遵循特定的行为。 协议的定义和遵循 在 Swift 中,使用 protocol 关键…...
javaScript交互案例
1、模态框(弹出框) (1)、需求: 点击弹出层,会弹出模态框,并且显示灰色半透明的遮挡层点击关闭按钮,可以关闭模态框,并且同时关闭半透明遮挡层鼠标放在模态框最上面一行,可以按住鼠…...
【自动驾驶】数据集合集!
本文将为您介绍经典、热门的数据集,希望对您在选择适合的数据集时有所帮助。 1 Automatic-driving-Test 更新时间:2024-07-26 访问地址: GitHub 描述: 该模型使用 ultralytics yolo v8 和 deepsort 方法来检测车道与车轮的碰撞并跟踪车辆。…...
ubuntu搭建nfs服务centos挂载访问
在Ubuntu上设置NFS服务器 在Ubuntu上,你可以使用apt包管理器来安装NFS服务器。打开终端并运行: sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享,例如/shared: sudo mkdir /shared sud…...
Java如何权衡是使用无序的数组还是有序的数组
在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...
Python爬虫实战:研究feedparser库相关技术
1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...
JVM垃圾回收机制全解析
Java虚拟机(JVM)中的垃圾收集器(Garbage Collector,简称GC)是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象,从而释放内存空间,避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...
Java多线程实现之Callable接口深度解析
Java多线程实现之Callable接口深度解析 一、Callable接口概述1.1 接口定义1.2 与Runnable接口的对比1.3 Future接口与FutureTask类 二、Callable接口的基本使用方法2.1 传统方式实现Callable接口2.2 使用Lambda表达式简化Callable实现2.3 使用FutureTask类执行Callable任务 三、…...
Linux-07 ubuntu 的 chrome 启动不了
文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了,报错如下四、启动不了,解决如下 总结 问题原因 在应用中可以看到chrome,但是打不开(说明:原来的ubuntu系统出问题了,这个是备用的硬盘&a…...
前端开发面试题总结-JavaScript篇(一)
文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包(Closure)?闭包有什么应用场景和潜在问题?2.解释 JavaScript 的作用域链(Scope Chain) 二、原型与继承3.原型链是什么?如何实现继承&a…...
IT供电系统绝缘监测及故障定位解决方案
随着新能源的快速发展,光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域,IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选,但在长期运行中,例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...
Android第十三次面试总结(四大 组件基础)
Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成,用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机: onCreate() 调用时机:Activity 首次创建时调用。…...
