当前位置: 首页 > news >正文

XGBOOST算法Python实现(保姆级)

摘要

        XGBoost算法(eXtreme Gradient Boosting)在目前的Kaggle、数学建模和大数据应用等竞赛中非常流行。本文将会从XGBOOST算法原理、Python实现、敏感性分析和实际应用进行详细说明。

目录

0 绪论

一、材料准备

二、算法原理

三、算法Python实现

        3.1 数据加载

        3.2 将目标变量的定类数据分类编码

        3.3 将数据分为训练数据和测试数据

        3.4训练XGBOOST模型

        3.5 测试模型

        3.6 输出模型的预测混淆矩阵(结果矩阵)

        3.7 输出模型准确率

        3.8 绘制混淆矩阵图

         3.9 完整实现代码

        3.10 结果输出示例

四、 XGBOOST算法的敏感性分析和实际应用

        4.1 敏感性分析

        4.2 算法应用

五、结论

六、备注

0 绪论

        数据挖掘和数学建模等比赛中,除了算法的实现,还需要对数据进行较为合理的预处理,包括缺失值处理、异常值处理、定类数据特征编码和冗余特征的删除等等,本文默认读者的数据均已完成数据预处理,如有需要,后续会将数据预处理的方法也进行发布。

一、材料准备

        Python编译器:Pycharm社区版或个人版等

        训练数据集:此处使用2022年数维杯国际大学生数学建模竞赛C题的附件数据为例。

        数据处理:经过初步数据清洗和相关性分析等操作得到初步的特征,并利用决策树进行特征重要性分析,完成二次特征降维,得到'CDRSB_bl', 'PIB_bl', 'FBB_bl'三个自变量特征,DX_bl为分类特征。

二、算法原理

     XGBOOST算法基于决策树的集成方法,主要采用了Boosting的思想,是Gradient Boosting算法的扩展,并使用梯度提升技术来提高模型的准确性和泛化能力。

        首先将基分类器层层叠加,然后每一层在训练的时候,对前一层基分类器分错的样本,给予更高的权重,XGBOOST的目标函数为:

    (1)

        其中,为损失函数;为正则项,用于控制树的复杂度;为常数项,为新树的预测值,它是将树的个数的结果进行求和。

三、算法Python实现

3.1 数据加载

        此处导入本文所需数据,DataX为自变量数据,DataY为目标变量数据(DX_bl)。

import pandas as pd
X = pd.DataFrame(pd.read_excel('DataX.xlsx')).values  # 输入特征
y = pd.DataFrame(pd.read_excel('DataY.xlsx')).values  # 目标变量

 3.2 将目标变量的定类数据分类编码

此处仅用0-4来代替五类数据,因为此处仅做预测,并不涉及相关性分析等其他操作,所以普通的分类编码就可以。如果需要用来做相关性分析或其他计算型操作,建议使用独热编码(OneHot- Encoding)。

from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y = le.fit_transform(y)
label_mapping = {0: 'AD', 1: 'CN', 2: 'EMCI', 3: 'LMCI', 4: 'SMC'}
#此处为了后续输出混淆矩阵时,用原始数据输出

 3.3 将数据分为训练数据和测试数据

        本文将原始样本数据通过随机洗牌,并将70%的样本数据作为训练数据,30%的样本数据作为测试数据。这是一个较为常见的拆分方法,读者可通过不同的拆分测试最佳准确率和F1-score。

from sklearn.model_selection import train_test_split
# 将数据分为训练数据和测试数据
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, train_size=0.7, random_state=42)

 3.4训练XGBOOST模型

        基于70%的样本数据进行训练建模,python有XGBOOST算法的库,所以很方便就可以调用。

import xgboost as xgb
# 训练XGBoost分类器
model = xgb.XGBClassifier()
model.fit(X_train, y_train)
#xgb.plot_tree(model)

 3.5 测试模型

        利用另外的30%样本数据进行测试模型准确率、精确率、召回率和F1度量值。

# 使用测试数据预测类别
y_pred = model.predict(X_test)

 3.6 输出模型的预测混淆矩阵(结果矩阵)

        此处输出混淆矩阵的方法和之前的随机森林、KNN算法都有点不同,因为随机森拉算法不需要将定类数据进行分类编码就可以直接调用随机森林算法模型。

from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
cm = confusion_matrix(y_test, y_pred)
# 输出混淆矩阵
for i, true_label in enumerate(label_mapping.values()):row = ''for j, pred_label in enumerate(label_mapping.values()):row += f'{cm[i, j]} ({pred_label})\t'print(f'{row} | {true_label}')# 输出混淆矩阵
print(classification_report(y_test, y_pred,target_names=['AD', 'CN', 'EMCI', 'LMCI', 'SMC']))  # 输出混淆矩阵

 3.7 输出模型准确率

#此处的导库在上一个代码段中已引入
print("Accuracy:")
print(accuracy_score(y_test, y_pred))

 3.8 绘制混淆矩阵图

        将混淆矩阵结果图绘制并输出,可以将这一结果图放在论文中,提升论文美感和信服度。

import matplotlib.pyplot as plt
import numpy as np
label_names = ['AD', 'CN', 'EMCI', 'LMCI', 'SMC']
cm = confusion_matrix(y_test, y_pred)# 绘制混淆矩阵图
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),yticks=np.arange(cm.shape[0]),xticklabels=label_names, yticklabels=label_names,title='Confusion matrix',ylabel='True label',xlabel='Predicted label')# 在矩阵图中显示数字标签
thresh = cm.max() / 2.
for i in range(cm.shape[0]):for j in range(cm.shape[1]):ax.text(j, i, format(cm[i, j], 'd'),ha="center", va="center",color="white" if cm[i, j] > thresh else "black")fig.tight_layout()
#plt.show()
plt.savefig('XGBoost_Conclusion.png', dpi=300)

        上面的代码首先计算混淆矩阵,然后使用 matplotlib 库中的 imshow 函数将混淆矩阵可视化,最后通过 text 函数在混淆矩阵上添加数字,并使用 show/savefig 函数显示图像,结果输出如图3.1所示。

图3.1 混淆矩阵结果图

  3.9 完整实现代码

# 导入需要的库
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import numpy as nple = LabelEncoder()
label_mapping = {0: 'AD', 1: 'CN', 2: 'EMCI', 3: 'LMCI', 4: 'SMC'}
X = pd.DataFrame(pd.read_excel('DataX.xlsx')).values  # 输入特征
y = pd.DataFrame(pd.read_excel('DataY.xlsx')).values  # 目标变量
y = le.fit_transform(y)
# 将数据分为训练数据和测试数据
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, train_size=0.7, random_state=42)
# 训练XGBoost分类器
model = xgb.XGBClassifier()
model.fit(X_train, y_train)
#xgb.plot_tree(model)
# 使用测试数据预测类别
y_pred = model.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
# 输出混淆矩阵
for i, true_label in enumerate(label_mapping.values()):row = ''for j, pred_label in enumerate(label_mapping.values()):row += f'{cm[i, j]} ({pred_label})\t'print(f'{row} | {true_label}')# 输出混淆矩阵
print(classification_report(y_test, y_pred,target_names=['AD', 'CN', 'EMCI', 'LMCI', 'SMC']))  # 输出混淆矩阵
print("Accuracy:")
print(accuracy_score(y_test, y_pred))# label_names 是分类变量的取值名称列表
label_names = ['AD', 'CN', 'EMCI', 'LMCI', 'SMC']
cm = confusion_matrix(y_test, y_pred)# 绘制混淆矩阵图
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),yticks=np.arange(cm.shape[0]),xticklabels=label_names, yticklabels=label_names,title='Confusion matrix',ylabel='True label',xlabel='Predicted label')# 在矩阵图中显示数字标签
thresh = cm.max() / 2.
for i in range(cm.shape[0]):for j in range(cm.shape[1]):ax.text(j, i, format(cm[i, j], 'd'),ha="center", va="center",color="white" if cm[i, j] > thresh else "black")fig.tight_layout()
#plt.show()
plt.savefig('XGBoost_Conclusion.png', dpi=300)
# 上面的代码首先计算混淆矩阵,然后使用 matplotlib 库中的 imshow 函数将混淆矩阵可视化,最后通过 text 函数在混淆矩阵上添加数字,并使用 show/savefig 函数显示图像。

 3.10 结果输出示例

       

 图3.2 结果输出示例

四、 XGBOOST算法的敏感性分析和实际应用

 4.1 敏感性分析

         敏感性分析也叫做稳定性分析,可以基于统计学思想,通过百次测试,记录其准确率、精确率、召回率和F1-Score的数据,统计其中位数、平均值、最大值和最小值等数据,从而进行对应的敏感性分析。结果表明符合原模型成立,则通过了敏感性分析。前面的随机森林算法和KNN算法也是如此。

 4.2 算法应用

         XGBOOST算法可应用于大数据分析、预测等方面,尤其是大数据竞赛(Kaggle、阿里天池等竞赛中)特别常用,也是本人目前认为最好用的一个算法。

五、结论

        本文基于XGBOOST算法,从数据预处理、算法原理、算法实现、敏感性分析和算法应用都做了具体的分析,可适用于大部分机器学习算法初学者。

六、备注

        本文为原创文章,禁止转载,违者必究。如需原始数据,可点赞+收藏,然后私聊作者或在评论区中留下你的邮箱,即可获得训练数据一份。

相关文章:

XGBOOST算法Python实现(保姆级)

摘要 XGBoost算法(eXtreme Gradient Boosting)在目前的Kaggle、数学建模和大数据应用等竞赛中非常流行。本文将会从XGBOOST算法原理、Python实现、敏感性分析和实际应用进行详细说明。 目录 0 绪论 一、材料准备 二、算法原理 三、算法Python实现 3…...

JDK、MAVEN与IDEA的安装与配置

1.认识JDK、MAVEN与IDEA JDK 提供了编译和运行Java程序的基本环境。Maven 帮助管理项目的构建和依赖。IDEA 提供了一个强大的开发环境,使得编写、调试和运行Java程序更加高效。 2. 安装与环境配置 2.1 官网地址 选择你需要的版本下载: MAVEN下载传送…...

输出比较简介

输出比较简介 主要是用来输出PWM波形,这个波形是驱动电机的(智能车和机器人等)必要条件 OC(Output Compare)输出比较,还有IC,全称是Input Capture,意为输入捕获,还有CC…...

什么是反向 DNS 查找以及它的作用是什么?

反向DNS查询(rDNS)是一种技术,用于确定与某个IP地址对应的域名。当我们对一个IP地址进行反向DNS查询时,实际上是向域名系统(DNS)的特殊部分请求信息,这部分被称为PTR记录。PTR记录会返回与这个I…...

集群聊天服务器(13)redis环境安装和发布订阅命令

目录 环境安装订阅redis发布-订阅的客户端编程环境配置客户端编程 功能测试 环境安装 sudo apt-get install redis-server 先启动redis服务 /etc/init.d/redis-server start默认在6379端口上 redis是存键值对的,还可以存链表、数组等等复杂数据结构 而且数据是在…...

[ubuntu]编译共享内存读取出现read.c:(.text+0x1a): undefined reference to `shm_open‘问题解决方案

问题log /tmp/ccByifPx.o: In function main: read.c:(.text0x1a): undefined reference to shm_open read.c:(.text0xd9): undefined reference to shm_unlink collect2: error: ld returned 1 exit status 程序代码 #include <stdio.h> #include <stdlib.h> #…...

Python Matplotlib 安装指南:使用 Miniconda 实现跨 Linux、macOS 和 Windows 平台安装

Python Matplotlib 安装指南&#xff1a;使用 Miniconda 实现跨 Linux、macOS 和 Windows 平台安装 Matplotlib是Python最常用的数据可视化工具之一&#xff0c;结合Miniconda可以轻松管理安装和依赖项。在这篇文章中&#xff0c;我们将详细介绍如何使用Miniconda在Linux、mac…...

DimensionX 部署笔记

目录 生成视频用CogVideoX-5b-I2V 推理代码&#xff1a; DimensionX 生成视频用CogVideoX-5b-I2V 推理代码&#xff1a; 可以生成&#xff0c;从左向右旋转的&#xff0c;也可以生成从上往下旋转的&#xff1a; import torch from diffusers import CogVideoXImageToVideo…...

django从入门到精通(五)——表单与模型

好的&#xff0c;下面将详细介绍 Django 的表单与模型&#xff0c;包括它们的定义、使用、如何在 Django Admin 中结合使用&#xff0c;以及相关的字段类型和验证机制。 Django 模型与表单 1. Django 模型 Django 模型是一个 Python 类&#xff0c;用于定义数据库中的数据结…...

C语言Day 03 学习总结

Day 03 学习总结 流程控制语句 顺序结构 程序从上到下依次执行。每一条语句顺序执行&#xff0c;直到结束。 选择结构 程序根据条件选择执行某一条分支。包括 if-else 和 switch-case。 循环结构 程序反复执行某段代码。包括 for、while、do-while。 跳转结构 控制程序直接跳…...

kafka中是如何快速定位到一个offset的

定位到具体的segment日志文件&#xff0c;采用二分法先定位到index索引文件计算查找的offset在日志文件的相对偏移量 1、分区和日志段&#xff1a; 每个主题的分区&#xff08;Partition&#xff09;被划分为多个日志段&#xff08;Log Segment&#xff09;。每个日志段是一个…...

视频对接rtsp协议学习

RTSP协议在视频平台中的应用‌ RTSP&#xff08;Real Time Streaming Protocol&#xff09;是一种基于TCP/IP的应用层协议&#xff0c;主要用于控制流媒体数据的传输和播放。它通过定义一系列命令和请求&#xff0c;实现对流媒体服务器的远程控制&#xff0c;但不传输媒体数据…...

【系统架构设计师】真题论文: 论企业信息化规划的实施与应用(包括解题思路和素材)

更多内容请见: 备考系统架构设计师-专栏介绍和目录 文章目录 真题题目(2012年 试题4)解题思路论文素材参考企业信息化规划概念与主要内容企业信息化规划实施的步骤企业信息化规划的应用案例真题题目(2012年 试题4) 企业信息化建设是一项长期而艰巨的任务,不可能在短时间…...

【ARM Coresight OpenOCD 系列 6.1 -- JTAG Commands】

请阅读【嵌入式开发学习必备专栏】 文章目录 JTAG Transport使用场景配置示例JTAG Speed配置示例初始化过程中的速度调整自适应时钟选择合适的速度Low Level JTAG Commandsdrscanflush_countirscanpathmoveruntestverify_ircaptureverify_jtagJTAG Transport OpenOCD 是一个强…...

开源许可协议

何同学推动了开源协议的认识&#xff0c;功不可没&#xff0c;第一次对开源有了清晰的认识&#xff0c;最宽松的MIT开源协议 源自OSC开源社区&#xff1a;何同学使用开源软件“翻车”&#xff0c;都别吵了&#xff01;扯什么违反MIT...

241121学习日志——[CSDIY] [InternStudio] 大模型训练营 [11]

CSDIY&#xff1a;这是一个非科班学生的努力之路&#xff0c;从今天开始这个系列会长期更新&#xff0c;&#xff08;最好做到日更&#xff09;&#xff0c;我会慢慢把自己目前对CS的努力逐一上传&#xff0c;帮助那些和我一样有着梦想的玩家取得胜利&#xff01;&#xff01;&…...

跟千里马学框架 遇到的坑

在编译 aosp 的 所有的东西都是和他一样的&#xff0c; 但是出现了这个问题 emulator: command not found 明明所有的都是一样的但是出现了这个问题 &#xff0c; 啥情况 。 首先你的 ubuntu 要开启虚拟机 。 这个自己百度去 重新进行这些步骤 1、 . build/envsetup.s…...

Swift从0开始学习 协议和扩展 day5

协议:定义行为的契约 协议类似于其他语言中的接口。它们定义了一组方法、属性或其他需求,供结构体、类、枚举等类型去遵循和实现。协议并不实现这些需求,而是作为一种约定或合同,确保实现协议的类型会遵循特定的行为。 协议的定义和遵循 在 Swift 中,使用 protocol 关键…...

javaScript交互案例

1、模态框(弹出框) &#xff08;1&#xff09;、需求&#xff1a; 点击弹出层&#xff0c;会弹出模态框&#xff0c;并且显示灰色半透明的遮挡层点击关闭按钮&#xff0c;可以关闭模态框&#xff0c;并且同时关闭半透明遮挡层鼠标放在模态框最上面一行&#xff0c;可以按住鼠…...

【自动驾驶】数据集合集!

本文将为您介绍经典、热门的数据集&#xff0c;希望对您在选择适合的数据集时有所帮助。 1 Automatic-driving-Test 更新时间&#xff1a;2024-07-26 访问地址: GitHub 描述&#xff1a; 该模型使用 ultralytics yolo v8 和 deepsort 方法来检测车道与车轮的碰撞并跟踪车辆。…...

el-table表头前几列固定,后面几列根据接口返回的值不同展示不同

在使用 Element UI 的 el-table 组件时&#xff0c;如果想要实现表头的前几列固定&#xff0c;而后面的列根据接口返回的数据动态展示&#xff0c;可以通过以下步骤来实现&#xff1a; 1. 固定表头前几列 在 el-table-column 中使用 fixed 属性来固定表头的前几列。例如&…...

【Redis】redis缓存击穿,缓存雪崩,缓存穿透

一、什么是缓存&#xff1f; 缓存就是与数据交互中的缓冲区&#xff0c;它一般存储在内存中且读写效率高&#xff0c;提高响应时间提高并发性能&#xff0c;如果访问数据的话可以先访问缓存&#xff0c;避免数据查询直接操作数据库&#xff0c;造成后端压力过大。 但是可能会面…...

HBase Flink操作

Apache Flink 是一个开源的分布式流处理框架&#xff0c;能够高效地处理和分析实时数据流以及批数据。HBase 是一个分布式、面向列的开源数据库&#xff0c;是 Hadoop 项目的子项目&#xff0c;适合非结构化数据结构的存储&#xff0c;并提供实时读写能力。以下是关于 Flink 对…...

C# .Net Core通过StreamLoad向Doris写入CSV数据

以下代码可以只关注StreamLoad具体实现。 1.创建StreamLoad返回值Model public class StreamLoadResponse {public long TxnId { get; set; }public string Label { get; set; }public string Comment { get; set; }public string TwoPhaseCommit { get; set; }public string…...

React-自定义Hook与逻辑共享

#题引&#xff1a;我认为跟着官方文档学习不会走歪路 在 React 中&#xff0c;自定义 Hook 是一种复用逻辑的方式。自定义 Hook 是一个 JavaScript 函数&#xff0c;名称以 use 开头&#xff0c;可以调用其他的 Hook, 可以返回任意值。 创建自定义Hook 假设你正在开发一款重…...

蓝桥杯每日真题 - 第17天

题目&#xff1a;&#xff08;最大数字&#xff09; 题目描述&#xff08;13届 C&C B组D题&#xff09; 题目分析&#xff1a; 操作规则&#xff1a; 1号操作&#xff1a;将数字加1&#xff08;如果该数字为9&#xff0c;变为0&#xff09;。 2号操作&#xff1a;将数字…...

游戏开发实现简易实用的ui框架

游戏开发实现简易实用的ui框架 本文使用cocos引擎实现&#xff0c;框架代码本质上不依赖某一个引擎&#xff0c;稍作修改也能作为其他引擎的实现 1.1 UI管理框架的核心需求剖析 分层与类型管理 对不同类型UI需要进行分层管理。不同层级的UI需要有不同的父节点&#xff0c;保证渲…...

vue3的attr透传属性详解和使用法方式。以及在css样式的伪元素中实现

在 Vue 3 和 TypeScript 中&#xff0c;属性透传&#xff08;attr pass-through&#xff09;是指将组件的属性传递到其根元素或某个子元素中。这个概念在开发可复用的组件时非常有用&#xff0c;尤其是当你希望将父组件的属性动态地传递给子组件的某个 DOM 元素时。 在 Vue 3 …...

【仿真建模-MESA】框架简介

1. 简介 Mesa是一个基于Python3的开源项目&#xff0c;旨在提供一个现代、易用的多智能体仿真环境。它借鉴了NetLogo、Repast和MASON等多智能体仿真框架的优点&#xff0c;并结合Python语言的强大功能&#xff0c;为用户提供了丰富的建模和仿真工具。 《官方文档》 2. 核心组件…...

Linux环境基础开发工具的使用(yum、vim、gcc、g++、gdb、make/Makefile)

目录 Linux软件包管理器 - yum Linux下安装软件包的方式 认识yum 查找软件包 安装软件 如何实现本地机器和云服务器之间的文件互传 卸载软件 Linux编辑器 - vim vim的基本概念 vim下各模式的切换 批量化注释 vim的简单配置 Linux编译器 - gcc/g gcc/g的作用 gcc/g语…...