当前位置: 首页 > news >正文

XGBOOST算法Python实现(保姆级)

摘要

        XGBoost算法(eXtreme Gradient Boosting)在目前的Kaggle、数学建模和大数据应用等竞赛中非常流行。本文将会从XGBOOST算法原理、Python实现、敏感性分析和实际应用进行详细说明。

目录

0 绪论

一、材料准备

二、算法原理

三、算法Python实现

        3.1 数据加载

        3.2 将目标变量的定类数据分类编码

        3.3 将数据分为训练数据和测试数据

        3.4训练XGBOOST模型

        3.5 测试模型

        3.6 输出模型的预测混淆矩阵(结果矩阵)

        3.7 输出模型准确率

        3.8 绘制混淆矩阵图

         3.9 完整实现代码

        3.10 结果输出示例

四、 XGBOOST算法的敏感性分析和实际应用

        4.1 敏感性分析

        4.2 算法应用

五、结论

六、备注

0 绪论

        数据挖掘和数学建模等比赛中,除了算法的实现,还需要对数据进行较为合理的预处理,包括缺失值处理、异常值处理、定类数据特征编码和冗余特征的删除等等,本文默认读者的数据均已完成数据预处理,如有需要,后续会将数据预处理的方法也进行发布。

一、材料准备

        Python编译器:Pycharm社区版或个人版等

        训练数据集:此处使用2022年数维杯国际大学生数学建模竞赛C题的附件数据为例。

        数据处理:经过初步数据清洗和相关性分析等操作得到初步的特征,并利用决策树进行特征重要性分析,完成二次特征降维,得到'CDRSB_bl', 'PIB_bl', 'FBB_bl'三个自变量特征,DX_bl为分类特征。

二、算法原理

     XGBOOST算法基于决策树的集成方法,主要采用了Boosting的思想,是Gradient Boosting算法的扩展,并使用梯度提升技术来提高模型的准确性和泛化能力。

        首先将基分类器层层叠加,然后每一层在训练的时候,对前一层基分类器分错的样本,给予更高的权重,XGBOOST的目标函数为:

    (1)

        其中,为损失函数;为正则项,用于控制树的复杂度;为常数项,为新树的预测值,它是将树的个数的结果进行求和。

三、算法Python实现

3.1 数据加载

        此处导入本文所需数据,DataX为自变量数据,DataY为目标变量数据(DX_bl)。

import pandas as pd
X = pd.DataFrame(pd.read_excel('DataX.xlsx')).values  # 输入特征
y = pd.DataFrame(pd.read_excel('DataY.xlsx')).values  # 目标变量

 3.2 将目标变量的定类数据分类编码

此处仅用0-4来代替五类数据,因为此处仅做预测,并不涉及相关性分析等其他操作,所以普通的分类编码就可以。如果需要用来做相关性分析或其他计算型操作,建议使用独热编码(OneHot- Encoding)。

from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y = le.fit_transform(y)
label_mapping = {0: 'AD', 1: 'CN', 2: 'EMCI', 3: 'LMCI', 4: 'SMC'}
#此处为了后续输出混淆矩阵时,用原始数据输出

 3.3 将数据分为训练数据和测试数据

        本文将原始样本数据通过随机洗牌,并将70%的样本数据作为训练数据,30%的样本数据作为测试数据。这是一个较为常见的拆分方法,读者可通过不同的拆分测试最佳准确率和F1-score。

from sklearn.model_selection import train_test_split
# 将数据分为训练数据和测试数据
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, train_size=0.7, random_state=42)

 3.4训练XGBOOST模型

        基于70%的样本数据进行训练建模,python有XGBOOST算法的库,所以很方便就可以调用。

import xgboost as xgb
# 训练XGBoost分类器
model = xgb.XGBClassifier()
model.fit(X_train, y_train)
#xgb.plot_tree(model)

 3.5 测试模型

        利用另外的30%样本数据进行测试模型准确率、精确率、召回率和F1度量值。

# 使用测试数据预测类别
y_pred = model.predict(X_test)

 3.6 输出模型的预测混淆矩阵(结果矩阵)

        此处输出混淆矩阵的方法和之前的随机森林、KNN算法都有点不同,因为随机森拉算法不需要将定类数据进行分类编码就可以直接调用随机森林算法模型。

from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
cm = confusion_matrix(y_test, y_pred)
# 输出混淆矩阵
for i, true_label in enumerate(label_mapping.values()):row = ''for j, pred_label in enumerate(label_mapping.values()):row += f'{cm[i, j]} ({pred_label})\t'print(f'{row} | {true_label}')# 输出混淆矩阵
print(classification_report(y_test, y_pred,target_names=['AD', 'CN', 'EMCI', 'LMCI', 'SMC']))  # 输出混淆矩阵

 3.7 输出模型准确率

#此处的导库在上一个代码段中已引入
print("Accuracy:")
print(accuracy_score(y_test, y_pred))

 3.8 绘制混淆矩阵图

        将混淆矩阵结果图绘制并输出,可以将这一结果图放在论文中,提升论文美感和信服度。

import matplotlib.pyplot as plt
import numpy as np
label_names = ['AD', 'CN', 'EMCI', 'LMCI', 'SMC']
cm = confusion_matrix(y_test, y_pred)# 绘制混淆矩阵图
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),yticks=np.arange(cm.shape[0]),xticklabels=label_names, yticklabels=label_names,title='Confusion matrix',ylabel='True label',xlabel='Predicted label')# 在矩阵图中显示数字标签
thresh = cm.max() / 2.
for i in range(cm.shape[0]):for j in range(cm.shape[1]):ax.text(j, i, format(cm[i, j], 'd'),ha="center", va="center",color="white" if cm[i, j] > thresh else "black")fig.tight_layout()
#plt.show()
plt.savefig('XGBoost_Conclusion.png', dpi=300)

        上面的代码首先计算混淆矩阵,然后使用 matplotlib 库中的 imshow 函数将混淆矩阵可视化,最后通过 text 函数在混淆矩阵上添加数字,并使用 show/savefig 函数显示图像,结果输出如图3.1所示。

图3.1 混淆矩阵结果图

  3.9 完整实现代码

# 导入需要的库
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import numpy as nple = LabelEncoder()
label_mapping = {0: 'AD', 1: 'CN', 2: 'EMCI', 3: 'LMCI', 4: 'SMC'}
X = pd.DataFrame(pd.read_excel('DataX.xlsx')).values  # 输入特征
y = pd.DataFrame(pd.read_excel('DataY.xlsx')).values  # 目标变量
y = le.fit_transform(y)
# 将数据分为训练数据和测试数据
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, train_size=0.7, random_state=42)
# 训练XGBoost分类器
model = xgb.XGBClassifier()
model.fit(X_train, y_train)
#xgb.plot_tree(model)
# 使用测试数据预测类别
y_pred = model.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
# 输出混淆矩阵
for i, true_label in enumerate(label_mapping.values()):row = ''for j, pred_label in enumerate(label_mapping.values()):row += f'{cm[i, j]} ({pred_label})\t'print(f'{row} | {true_label}')# 输出混淆矩阵
print(classification_report(y_test, y_pred,target_names=['AD', 'CN', 'EMCI', 'LMCI', 'SMC']))  # 输出混淆矩阵
print("Accuracy:")
print(accuracy_score(y_test, y_pred))# label_names 是分类变量的取值名称列表
label_names = ['AD', 'CN', 'EMCI', 'LMCI', 'SMC']
cm = confusion_matrix(y_test, y_pred)# 绘制混淆矩阵图
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),yticks=np.arange(cm.shape[0]),xticklabels=label_names, yticklabels=label_names,title='Confusion matrix',ylabel='True label',xlabel='Predicted label')# 在矩阵图中显示数字标签
thresh = cm.max() / 2.
for i in range(cm.shape[0]):for j in range(cm.shape[1]):ax.text(j, i, format(cm[i, j], 'd'),ha="center", va="center",color="white" if cm[i, j] > thresh else "black")fig.tight_layout()
#plt.show()
plt.savefig('XGBoost_Conclusion.png', dpi=300)
# 上面的代码首先计算混淆矩阵,然后使用 matplotlib 库中的 imshow 函数将混淆矩阵可视化,最后通过 text 函数在混淆矩阵上添加数字,并使用 show/savefig 函数显示图像。

 3.10 结果输出示例

       

 图3.2 结果输出示例

四、 XGBOOST算法的敏感性分析和实际应用

 4.1 敏感性分析

         敏感性分析也叫做稳定性分析,可以基于统计学思想,通过百次测试,记录其准确率、精确率、召回率和F1-Score的数据,统计其中位数、平均值、最大值和最小值等数据,从而进行对应的敏感性分析。结果表明符合原模型成立,则通过了敏感性分析。前面的随机森林算法和KNN算法也是如此。

 4.2 算法应用

         XGBOOST算法可应用于大数据分析、预测等方面,尤其是大数据竞赛(Kaggle、阿里天池等竞赛中)特别常用,也是本人目前认为最好用的一个算法。

五、结论

        本文基于XGBOOST算法,从数据预处理、算法原理、算法实现、敏感性分析和算法应用都做了具体的分析,可适用于大部分机器学习算法初学者。

六、备注

        本文为原创文章,禁止转载,违者必究。如需原始数据,可点赞+收藏,然后私聊作者或在评论区中留下你的邮箱,即可获得训练数据一份。

相关文章:

XGBOOST算法Python实现(保姆级)

摘要 XGBoost算法(eXtreme Gradient Boosting)在目前的Kaggle、数学建模和大数据应用等竞赛中非常流行。本文将会从XGBOOST算法原理、Python实现、敏感性分析和实际应用进行详细说明。 目录 0 绪论 一、材料准备 二、算法原理 三、算法Python实现 3…...

JDK、MAVEN与IDEA的安装与配置

1.认识JDK、MAVEN与IDEA JDK 提供了编译和运行Java程序的基本环境。Maven 帮助管理项目的构建和依赖。IDEA 提供了一个强大的开发环境,使得编写、调试和运行Java程序更加高效。 2. 安装与环境配置 2.1 官网地址 选择你需要的版本下载: MAVEN下载传送…...

输出比较简介

输出比较简介 主要是用来输出PWM波形,这个波形是驱动电机的(智能车和机器人等)必要条件 OC(Output Compare)输出比较,还有IC,全称是Input Capture,意为输入捕获,还有CC…...

什么是反向 DNS 查找以及它的作用是什么?

反向DNS查询(rDNS)是一种技术,用于确定与某个IP地址对应的域名。当我们对一个IP地址进行反向DNS查询时,实际上是向域名系统(DNS)的特殊部分请求信息,这部分被称为PTR记录。PTR记录会返回与这个I…...

集群聊天服务器(13)redis环境安装和发布订阅命令

目录 环境安装订阅redis发布-订阅的客户端编程环境配置客户端编程 功能测试 环境安装 sudo apt-get install redis-server 先启动redis服务 /etc/init.d/redis-server start默认在6379端口上 redis是存键值对的,还可以存链表、数组等等复杂数据结构 而且数据是在…...

[ubuntu]编译共享内存读取出现read.c:(.text+0x1a): undefined reference to `shm_open‘问题解决方案

问题log /tmp/ccByifPx.o: In function main: read.c:(.text0x1a): undefined reference to shm_open read.c:(.text0xd9): undefined reference to shm_unlink collect2: error: ld returned 1 exit status 程序代码 #include <stdio.h> #include <stdlib.h> #…...

Python Matplotlib 安装指南:使用 Miniconda 实现跨 Linux、macOS 和 Windows 平台安装

Python Matplotlib 安装指南&#xff1a;使用 Miniconda 实现跨 Linux、macOS 和 Windows 平台安装 Matplotlib是Python最常用的数据可视化工具之一&#xff0c;结合Miniconda可以轻松管理安装和依赖项。在这篇文章中&#xff0c;我们将详细介绍如何使用Miniconda在Linux、mac…...

DimensionX 部署笔记

目录 生成视频用CogVideoX-5b-I2V 推理代码&#xff1a; DimensionX 生成视频用CogVideoX-5b-I2V 推理代码&#xff1a; 可以生成&#xff0c;从左向右旋转的&#xff0c;也可以生成从上往下旋转的&#xff1a; import torch from diffusers import CogVideoXImageToVideo…...

django从入门到精通(五)——表单与模型

好的&#xff0c;下面将详细介绍 Django 的表单与模型&#xff0c;包括它们的定义、使用、如何在 Django Admin 中结合使用&#xff0c;以及相关的字段类型和验证机制。 Django 模型与表单 1. Django 模型 Django 模型是一个 Python 类&#xff0c;用于定义数据库中的数据结…...

C语言Day 03 学习总结

Day 03 学习总结 流程控制语句 顺序结构 程序从上到下依次执行。每一条语句顺序执行&#xff0c;直到结束。 选择结构 程序根据条件选择执行某一条分支。包括 if-else 和 switch-case。 循环结构 程序反复执行某段代码。包括 for、while、do-while。 跳转结构 控制程序直接跳…...

kafka中是如何快速定位到一个offset的

定位到具体的segment日志文件&#xff0c;采用二分法先定位到index索引文件计算查找的offset在日志文件的相对偏移量 1、分区和日志段&#xff1a; 每个主题的分区&#xff08;Partition&#xff09;被划分为多个日志段&#xff08;Log Segment&#xff09;。每个日志段是一个…...

视频对接rtsp协议学习

RTSP协议在视频平台中的应用‌ RTSP&#xff08;Real Time Streaming Protocol&#xff09;是一种基于TCP/IP的应用层协议&#xff0c;主要用于控制流媒体数据的传输和播放。它通过定义一系列命令和请求&#xff0c;实现对流媒体服务器的远程控制&#xff0c;但不传输媒体数据…...

【系统架构设计师】真题论文: 论企业信息化规划的实施与应用(包括解题思路和素材)

更多内容请见: 备考系统架构设计师-专栏介绍和目录 文章目录 真题题目(2012年 试题4)解题思路论文素材参考企业信息化规划概念与主要内容企业信息化规划实施的步骤企业信息化规划的应用案例真题题目(2012年 试题4) 企业信息化建设是一项长期而艰巨的任务,不可能在短时间…...

【ARM Coresight OpenOCD 系列 6.1 -- JTAG Commands】

请阅读【嵌入式开发学习必备专栏】 文章目录 JTAG Transport使用场景配置示例JTAG Speed配置示例初始化过程中的速度调整自适应时钟选择合适的速度Low Level JTAG Commandsdrscanflush_countirscanpathmoveruntestverify_ircaptureverify_jtagJTAG Transport OpenOCD 是一个强…...

开源许可协议

何同学推动了开源协议的认识&#xff0c;功不可没&#xff0c;第一次对开源有了清晰的认识&#xff0c;最宽松的MIT开源协议 源自OSC开源社区&#xff1a;何同学使用开源软件“翻车”&#xff0c;都别吵了&#xff01;扯什么违反MIT...

241121学习日志——[CSDIY] [InternStudio] 大模型训练营 [11]

CSDIY&#xff1a;这是一个非科班学生的努力之路&#xff0c;从今天开始这个系列会长期更新&#xff0c;&#xff08;最好做到日更&#xff09;&#xff0c;我会慢慢把自己目前对CS的努力逐一上传&#xff0c;帮助那些和我一样有着梦想的玩家取得胜利&#xff01;&#xff01;&…...

跟千里马学框架 遇到的坑

在编译 aosp 的 所有的东西都是和他一样的&#xff0c; 但是出现了这个问题 emulator: command not found 明明所有的都是一样的但是出现了这个问题 &#xff0c; 啥情况 。 首先你的 ubuntu 要开启虚拟机 。 这个自己百度去 重新进行这些步骤 1、 . build/envsetup.s…...

Swift从0开始学习 协议和扩展 day5

协议:定义行为的契约 协议类似于其他语言中的接口。它们定义了一组方法、属性或其他需求,供结构体、类、枚举等类型去遵循和实现。协议并不实现这些需求,而是作为一种约定或合同,确保实现协议的类型会遵循特定的行为。 协议的定义和遵循 在 Swift 中,使用 protocol 关键…...

javaScript交互案例

1、模态框(弹出框) &#xff08;1&#xff09;、需求&#xff1a; 点击弹出层&#xff0c;会弹出模态框&#xff0c;并且显示灰色半透明的遮挡层点击关闭按钮&#xff0c;可以关闭模态框&#xff0c;并且同时关闭半透明遮挡层鼠标放在模态框最上面一行&#xff0c;可以按住鼠…...

【自动驾驶】数据集合集!

本文将为您介绍经典、热门的数据集&#xff0c;希望对您在选择适合的数据集时有所帮助。 1 Automatic-driving-Test 更新时间&#xff1a;2024-07-26 访问地址: GitHub 描述&#xff1a; 该模型使用 ultralytics yolo v8 和 deepsort 方法来检测车道与车轮的碰撞并跟踪车辆。…...

IDEA运行Tomcat出现乱码问题解决汇总

最近正值期末周&#xff0c;有很多同学在写期末Java web作业时&#xff0c;运行tomcat出现乱码问题&#xff0c;经过多次解决与研究&#xff0c;我做了如下整理&#xff1a; 原因&#xff1a; IDEA本身编码与tomcat的编码与Windows编码不同导致&#xff0c;Windows 系统控制台…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module

1、为什么要修改 CONNECT 报文&#xff1f; 多租户隔离&#xff1a;自动为接入设备追加租户前缀&#xff0c;后端按 ClientID 拆分队列。零代码鉴权&#xff1a;将入站用户名替换为 OAuth Access-Token&#xff0c;后端 Broker 统一校验。灰度发布&#xff1a;根据 IP/地理位写…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例&#xff0c;也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下&#xff1a; 定义实例工厂类&#xff08;Java代码&#xff09;&#xff0c;定义实例工厂&#xff08;xml&#xff09;&#xff0c;定义调用实例工厂&#xff…...

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…...

Spring Boot面试题精选汇总

&#x1f91f;致敬读者 &#x1f7e9;感谢阅读&#x1f7e6;笑口常开&#x1f7ea;生日快乐⬛早点睡觉 &#x1f4d8;博主相关 &#x1f7e7;博主信息&#x1f7e8;博客首页&#x1f7eb;专栏推荐&#x1f7e5;活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...

前端开发面试题总结-JavaScript篇(一)

文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包&#xff08;Closure&#xff09;&#xff1f;闭包有什么应用场景和潜在问题&#xff1f;2.解释 JavaScript 的作用域链&#xff08;Scope Chain&#xff09; 二、原型与继承3.原型链是什么&#xff1f;如何实现继承&a…...

九天毕昇深度学习平台 | 如何安装库?

pip install 库名 -i https://pypi.tuna.tsinghua.edu.cn/simple --user 举个例子&#xff1a; 报错 ModuleNotFoundError: No module named torch 那么我需要安装 torch pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple --user pip install 库名&#x…...

云原生安全实战:API网关Kong的鉴权与限流详解

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关&#xff08;API Gateway&#xff09; API网关是微服务架构中的核心组件&#xff0c;负责统一管理所有API的流量入口。它像一座…...

JS手写代码篇----使用Promise封装AJAX请求

15、使用Promise封装AJAX请求 promise就有reject和resolve了&#xff0c;就不必写成功和失败的回调函数了 const BASEURL ./手写ajax/test.jsonfunction promiseAjax() {return new Promise((resolve, reject) > {const xhr new XMLHttpRequest();xhr.open("get&quo…...