Boosting三巨头:XGBoost、LightGBM和CatBoost(发展、原理、区别和联系,附代码和案例)
Boosting三巨头:XGBoost、LightGBM和CatBoost(发展、原理、区别和联系,附代码和案例)
机器学习中,提高模型精度是研究的重点之一,而模型融合技术中,Boosting算法是一种常用的方法。在Boosting算法中,XGBoost、LightGBM和CatBoost是三个最为流行的框架。它们在实际使用中有各自的优势和适用场景,下面将会介绍它们的区别与联系。
1. 算法原理
1.1 XGBoost
XGBoost是由陈天奇等人提出的一个优化的Gradient Boosting算法,以其出色的表现和可扩展性而受到广泛关注。XGBoost使用了C++实现,可以运行在多个平台上,并支持多种编程语言,如Python、R、Java等。其原理可以概括为将弱学习器依次加入到一个全局的加权模型中,每一轮迭代都在损失函数的梯度方向上优化模型。它在原有GBDT的基础上,添加了正则化项和缺失值处理,使得模型更加稳定和准确。其原理如下:
首先,假设有n个训练样本(xi,yi)(x_{i}, y_{i})(xi,yi),其中xix_{i}xi为输入特征,yiy_{i}yi为输出值。那么,目标就是找到一个函数f(x),使得f(x)f(x)f(x)可以预测yyy的值。
其次,定义损失函数L(y,f(x))L(y, f(x))L(y,f(x)),用来度量f(x)f(x)f(x)的预测值与实际值之间的误差。
再次,我们使用Boosting算法来不断迭代提高模型精度。假设现在已经有了一个弱分类器fm−1(x)f_{m-1}(x)fm−1(x),那么我们希望找到一个新的弱分类器fm(x)f_{m}(x)fm(x)来减少L(y,f(x))L(y, f(x))L(y,f(x))的值。于是我们在已有的弱分类器fm−1(x)f_{m-1}(x)fm−1(x)基础上,加上一个新的弱分类器fm(x)f_{m}(x)fm(x),最终得到新的分类器fm(x)=fm−1(x)+γhm(x)f_{m}(x)=f_{m-1}(x)+\gamma h_{m}(x)fm(x)=fm−1(x)+γhm(x),其中γ\gammaγ为学习率,hm(x)h_{m}(x)hm(x)为新的弱分类器。
最后,由于XGBoost使用了正则化项来控制模型的复杂度,并采用了特殊的梯度下降方法进行训练,使得其在处理高维稀疏数据时,具有较好的效果。
1.2 LightGBM
LightGBM是由微软提出的一种基于Histogram算法的Gradient Boosting框架。它通过对样本特征值进行离散化,将连续特征离散化为有限个整数,从而将高维稀疏数据转化为低维稠密数据,从而加速了训练速度。相比XGBoost,LightGBM的最大优势在于其快速的训练速度和较小的内存占用,这主要得益于其采用了基于直方图的决策树算法和局部优化等技术。LightGBM的核心思想是在构造决策树时,将连续特征离散化为若干个桶,然后将每个桶作为一个离散特征对待,从而加速树的构建和训练过程。其原理如下:
首先,对于每个特征,我们需要将其离散化为一些桶,每个桶中包含一些连续的特征值。在训练时,我们只需要计算每个桶中的样本的统计信息(如平均值和方差),而不需要计算每个样本的特征值,从而减少了计算量。
其次,对于每个样本,我们根据离散化后的特征值,将其归入对应的桶中,然后计算桶中样本的统计信息。接着,我们通过梯度单边采样(GOSS)算法,选择一部分样本进行训练,这些样本中包含了大部分的梯度信息,从而保证了训练的准确性和效率。
最后,LightGBM还使用了基于直方图的决策树算法,使得在处理高维稀疏数据时,具有较好的效果。
1.3 CatBoost
CatBoost是由Yandex提出的一种基于梯度提升算法的开源机器学习框架。它在处理分类问题时,可以自动处理类别特征,无需手动进行特征编码。CatBoost的原理与XGBoost和LightGBM类似,同样是通过将多个弱学习器组合成一个强学习器。不同之处在于,CatBoost使用了一种新的损失函数,即加权交叉熵损失函数,可以有效地处理类别不平衡问题。其原理如下:
首先,CatBoost使用了一种称为Ordered Boosting的算法来提高模型精度。Ordered Boosting可以看做是一种特殊的特征选择方法,它将训练样本按照特征值大小排序,然后使用分段线性模型拟合每一段特征值的梯度,从而提高了模型的拟合能力。
其次,CatBoost在处理分类问题时,可以自动处理类别特征。它使用了一种称为Target Encoding的方法,将类别特征转化为一组实数值,从而避免了手动进行特征编码的麻烦。
最后,CatBoost还使用了基于对称树的决策树算法,使得在处理高维稀疏数据时,具有较好的效果。
2. 发展前景和应用场景
XGBoost、LightGBM和CatBoost作为目前最先进的梯度提升算法,在许多数据科学竞赛和实际应用中都取得了很好的效果。随着大数据时代的到来,这三种算法的应用场景也越来越广泛。
其中,XGBoost在传统机器学习领域仍然是最常用的算法之一,特别是在结构化数据的分类、回归和排序任务中表现突出。LightGBM在大规模数据集和高维度数据上表现更佳,适用于处理文本分类、图像分类、推荐系统等领域的数据。CatBoost在处理类别特征和缺失值方面表现出色,适用于电商推荐、医疗预测、金融风控等领域的数据。
总的来说,XGBoost、LightGBM和CatBoost作为梯度提升算法的代表,都具有自身的优势和适用场景,随着数据和计算能力的不断提升,它们的应用前景也会越来越广阔。
3. 使用案例
3.1 XGBoost
XGBoost可以应用于多种场景,如回归、分类、排序等。下面以Kaggle竞赛中的房价预测问题为例,展示如何使用XGBoost进行模型训练和预测。
首先,我们使用Pandas读取数据集,并将其划分为训练集和测试集。
import pandas as pd
from sklearn.model_selection import train_test_splitdf = pd.read_csv('train.csv')
X = df.drop('SalePrice', axis=1)
y = df['SalePrice']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
接着,我们使用XGBoost的Python接口进行模型训练和预测。
import xgboost as xgb
from sklearn.metrics import mean_squared_errordtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)params = {'max_depth': 3, 'learning_rate': 0.1, 'objective': 'reg:squarederror'}
num_rounds = 100model = xgb.train(params, dtrain, num_rounds)
y_pred = model.predict(dtest)mse = mean_squared_error(y_test, y_pred)
print("Mean Squared Error:", mse)
上述代码中,我们首先使用xgb.DMatrix将训练数据转化为DMatrix格式。接着,我们定义了模型参数,并设置了迭代次数为100。然后,我们使用xgb.train函数进行模型训练,并使用model.predict函数进行模型预测。最后,我们使用sklearn.metrics.mean_squared_error函数计算了模型的均方误差。
3.2 LightGBM
LightGBM可以应用于多种场景,如回归、分类、排序等。下面以Kaggle竞赛中的鸢尾花分类问题为例,展示如何使用LightGBM进行模型训练和预测。
首先,我们使用Pandas读取数据集,并将其划分为训练集和测试集。
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_splitdata = load_iris()
df = pd.DataFrame(data.data, columns=data.feature_names)
df['target'] = data.targetX = df.drop('target', axis=1)
y = df['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
接着,我们使用LightGBM的Python接口进行模型训练和预测。
import lightgbm as lgb
from sklearn.metrics import accuracy_scoredtrain = lgb.Dataset(X_train, label=y_train)
dtest = lgb.Dataset(X_test, label=y_test)params = {'objective': 'multiclass', 'num_class': 3, 'metric': 'multi_logloss'}
num_rounds = 100model = lgb.train(params, dtrain, num_rounds)
y_pred = model.predict(X_test)y_pred = [list(x).index(max(x)) for x in y_pred]
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
上述代码中,我们首先使用lgb.Dataset将训练数据转化为Dataset格式。接着,我们定义了模型参数,并设置了迭代次数为100。然后,我们使用lgb.train函数进行模型训练,并使用model.predict函数进行模型预测。最后,我们使用sklearn.metrics.accuracy_score函数计算了模型的准确率。
3.3 CatBoost
import catboost as cb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_scoreiris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)train_data = cb.Pool(X_train, label=y_train)
test_data = cb.Pool(X_test, label=y_test)params = {'loss_function': 'MultiClass', 'num_class': 3, 'eval_metric': 'Accuracy'}num_rounds = 20
bst = cb.train(params, train_data, num_rounds)preds = bst.predict(X_test)
y_pred = [np.argmax(pred) for pred in preds]acc = accuracy_score(y_test, y_pred)
print("Accuracy:", acc)
综合展示
为了更好地展示XGBoost、LightGBM和CatBoost的应用场景和效果,我们以波士顿房价预测数据集为例进行实验。
首先,我们使用sklearn库中的load_boston函数加载数据集,并对数据进行划分,80%用于训练,20%用于测试。
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_splitdata = load_boston()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target, test_size=0.2, random_state=42)
然后,我们依次使用XGBoost、LightGBM和CatBoost训练模型,并对模型进行评估。
import xgboost as xgb
from lightgbm import LGBMRegressor
from catboost import CatBoostRegressor
from sklearn.metrics import mean_squared_error# XGBoost
xgb_model = xgb.XGBRegressor(n_estimators=100, max_depth=5, learning_rate=0.1, random_state=42)
xgb_model.fit(X_train, y_train)
xgb_pred = xgb_model.predict(X_test)
xgb_rmse = mean_squared_error(y_test, xgb_pred, squared=False)# LightGBM
lgb_model = LGBMRegressor(n_estimators=100, max_depth=5, learning_rate=0.1, random_state=42)
lgb_model.fit(X_train, y_train)
lgb_pred = lgb_model.predict(X_test)
lgb_rmse = mean_squared_error(y_test, lgb_pred, squared=False)# CatBoost
cat_model = CatBoostRegressor(n_estimators=100, max_depth=5, learning_rate=0.1, random_seed=42, silent=True)
cat_model.fit(X_train, y_train)
cat_pred = cat_model.predict(X_test)
cat_rmse = mean_squared_error(y_test, cat_pred, squared=False)print("XGBoost RMSE: {:.2f}".format(xgb_rmse))
print("LightGBM RMSE: {:.2f}".format(lgb_rmse))
print("CatBoost RMSE: {:.2f}".format(cat_rmse))
参考文献
[1] Chen, T., & Guestrin, C. (2016). Xgboost: A scalable tree boosting system. In Proceedings of the 22nd acm sigkdd international conference on knowledge discovery and data mining (pp. 785-794).
[2] Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., … & Liu,
G. (2017). Lightgbm: A highly efficient gradient boosting decision tree. In Advances in Neural Information Processing Systems (pp. 3146-3154).
[3] Prokhorenkova, L., Gusev, G., Vorobev, A., Dorogush, A. V., & Gulin, A. (2018). CatBoost: unbiased boosting with categorical features. In Advances in neural information processing systems (pp. 6638-6648).
[4] XGBoost官方文档:https://xgboost.readthedocs.io/en/latest/
[5] LightGBM官方文档:https://lightgbm.readthedocs.io/en/latest/
[6] CatBoost官方文档:https://catboost.ai/docs/
[7] 《Python机器学习基础教程》(吴斌):介绍了XGBoost、LightGBM和CatBoost的使用方法和实例。
[8] 《Applied Machine Learning》(Kelleher, John D.):介绍了各种机器学习算法,其中也包括了梯度提升算法和其变种。
[9] 《Hands-On Gradient Boosting with XGBoost and scikit-learn》(Villalba, Benjamin):详细介绍了XGBoost和scikit-learn库的梯度提升实现。
[10] 《Gradient Boosting》(Friedman, Jerome H.):介绍了梯度提升算法的基本思想和实现原理。
相关文章:

Boosting三巨头:XGBoost、LightGBM和CatBoost(发展、原理、区别和联系,附代码和案例)
❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博…...
设计模式~模板方法模式(Template method)-10
目录 (1)优点: (2)缺点: (3)使用场景: (4)注意事项: (5)应用实例: (6)Servlet Api & Spring 中的应用 代码 (钩子函数)在模板模式(Template Pattern)中,一个抽象类公开定…...

【WebSocket】在SSM项目中配置websocket
在SSM项目中配置websocket 最近在ssm项目中配置了websocket,踩了很多坑,来分享一下 本文暂不提供发送消息等内容的代码逻辑(后续也许会补充),如果你直接复制这类可能会对配置造成更大的麻烦(博主就是复制…...
node-red中创建自定义节点 JavaScript 文件API编写详解
前言 在node-red中如果你没有找到自己需要的节点时,那么你可以自定义一个节点来满足自己的需求。之前的文章中,我有简单介绍过如何创建一个节点,并以转换大小写来举例。例子虽然简单,但可以让大家了解创建自定义节点的步骤以及一个节点的组成部分。那么本篇将会聚焦在自定…...
华为OD机试 - 寻找路径 or 数组二叉树(C 语言解题)【独家】
最近更新的博客 华为od 2023 | 什么是华为od,od 薪资待遇,od机试题清单华为OD机试真题大全,用 Python 解华为机试题 | 机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南)华为od机试,独家整理 已参加机试人员的实战技巧文章目录 使用说明本期题目:寻找路径…...

YOLOv7、YOLOv5改进之打印热力图可视化:适用于自定义模型,丰富实验数据
💡该教程为改进YOLO高阶指南,属于《芒果书》📚系列,包含大量的原创改进方式🚀 💡更多改进内容📚可以点击查看:YOLOv5改进、YOLOv7改进、YOLOv8改进、YOLOX改进原创目录 | 唐宇迪老师联袂推荐🏆 💡🚀🚀🚀内含改进源代码 按步骤操作运行改进后的代码即可�…...
【Java代码与架构之完美优化】篇1:代码质量优化通用准则
工欲善其事,必先利其器 1. 避免使用空块 常见空块一般有以下几种情况: 多余的分号:if(xxx);多余的大括号:if(xxx){这里没有内容}空finall语句:try{...}catch(...){...}finally{这里没有内容} 空块的存在࿰…...

Linux进程间通信详解(最全)
进程间的五种通信方式介绍 进程间通信(IPC,InterProcess Communication)是指在不同进程之间传播或交换信息。IPC的方式通常有管道(包括无名管道和命名管道)、消息队列、信号量、共享内存、Socket(套接字&a…...

ROS 摄像头的使用
参考: youtubeArticulated Robotics 作者Josh Newans博客 建议: 这个只是我的看法,强烈建议看原视频或博客 png:无损压缩 jpeg:有损压缩 Driver Node 负责连接硬件设备,读取摄像头数据"ima…...

VR全景云展厅,实现7*24小时的线上宣传能力!
数字化时代,虚拟现实技术的应用越来越广泛,其中VR全景云展厅是一种新兴的展示方式,具有独特的展示优势。随着VR技术的不断发展,越来越多的企业、机构和个人开始使用VR全景云展厅来展示他们的产品和服务。一、展厅营销痛点1、实地到…...

RK3568平台开发系列讲解(显示篇) DRM显示系统组成分析
🚀返回专栏总目录 文章目录 一、DRM Framebuffer二、CRTC三、Planes四、Encoder五、Connector沉淀、分享、成长,让自己和他人都能有所收获!😄 📢让我们分析一下绿框中的五个部件,以及他们的联动。 一、DRM Framebuffer 与 framebuffer一样,是一片存放图像的内存区域,…...

WPF DataGrid控件的使用 使用列模板来进行数据格式的美化
<Grid><Grid.RowDefinitions><RowDefinition Height"0.1*" /><RowDefinition /></Grid.RowDefinitions><Button Content"刷新"FontSize"25"Command"{Binding ExecuteRefreshCommand}" /><Dat…...

elasticsearch自定义企业词典
我们中文分词用的是ik,但是ik只是对基本的中文词进行了分词,而对于企业或者人名没有进行分词。比如,我搜索中国平安,那么ik只能分成中国、平安如果这样,这肯定是不行滴!接下来,俺就教你…...

【AcWing】学了一坤时才明白的一道题
🎆音乐分享 (点击链接可以听哦) The Right Path - Thomas Greenberg 这道题小吉花了一坤时才弄明白,虽然花的时间有点长 但是至少是明白了 😎😎😎😎😎😎 …...
ES6的export和import
ES6中的模块加载ES6 模块是编译时加载,编译时就能确定模块的依赖关系,以及输入和输出的变量,相比于CommonJS 和 AMD 模块都只能在运行时确定输入输出变量的加载效率要高。严格模式ES6 的模块自动采用严格模式,不管你有没有在模块头…...

ASEMI高压MOS管20N60参数,20N60尺寸,20N60体积
编辑-Z ASEMI高压MOS管20N60参数: 型号:20N60 漏极-源极电压(VDS):600V 栅源电压(VGS):30V 漏极电流(ID):20A 功耗(PDÿ…...

【备战面试】TCP的三次握手与四次挥手
本篇总结的是计算机网络知识相关的面试题,后续也会更新其他相关内容 文章目录1、TCP头部结构2、三次握手3、四次挥手4、为什么TCP连接的时候是三次?两次是否可以?5、为什么TCP连接的时候是三次,关闭的时候却是四次?6、…...

【模板进阶】
目录 1. 非类型模板参数 2. 模板的特化 2.1 概念 2.2 函数模板特化 2.3 类模板特化 2.3.1 全特化 3 模板分离编译 3.1 什么是分离编译 3.2 模板的分离编译 4. 模板总结 有需要的老哥可以先看看模板的介绍:http://t.csdn.cn/2TkUYhttp://t.csdn.cn/2TkUY 1. …...

Tech Talk | 电致变色技术带来的智能AR体验
2023年2月27日,小米在2023MWC世界移动通信大会上,正式发布了小米无线AR眼镜探索版。这款产品搭载了创新的数控电致变色镜片,能适应不同光环境,遮光模式可以在观影时更沉浸,通透模式又能让AR虚实结合的体验更生动。“ 本…...

ACWING蓝桥杯每日一题python(持续更新
ACWing蓝桥杯每日一题 一直没时间去总结算法,终于有空可以总结一下刷的acwing了,因为没时间所以最近只刷了ACWING的蓝桥杯每日一题。。。真是该死 1.截断数组 首先我们要知道,如果sum(a)不能被3整除或者len(a) < 3 ,那么他肯…...

渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet: https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...

如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...
Nginx server_name 配置说明
Nginx 是一个高性能的反向代理和负载均衡服务器,其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机(Virtual Host)。 1. 简介 Nginx 使用 server_name 指令来确定…...

SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现
摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...

python执行测试用例,allure报乱码且未成功生成报告
allure执行测试用例时显示乱码:‘allure’ �����ڲ����ⲿ���Ҳ���ǿ�&am…...

[大语言模型]在个人电脑上部署ollama 并进行管理,最后配置AI程序开发助手.
ollama官网: 下载 https://ollama.com/ 安装 查看可以使用的模型 https://ollama.com/search 例如 https://ollama.com/library/deepseek-r1/tags # deepseek-r1:7bollama pull deepseek-r1:7b改token数量为409622 16384 ollama命令说明 ollama serve #:…...

抽象类和接口(全)
一、抽象类 1.概念:如果⼀个类中没有包含⾜够的信息来描绘⼀个具体的对象,这样的类就是抽象类。 像是没有实际⼯作的⽅法,我们可以把它设计成⼀个抽象⽅法,包含抽象⽅法的类我们称为抽象类。 2.语法 在Java中,⼀个类如果被 abs…...
Vue 3 + WebSocket 实战:公司通知实时推送功能详解
📢 Vue 3 WebSocket 实战:公司通知实时推送功能详解 📌 收藏 点赞 关注,项目中要用到推送功能时就不怕找不到了! 实时通知是企业系统中常见的功能,比如:管理员发布通知后,所有用户…...

【版本控制】GitHub Desktop 入门教程与开源协作全流程解析
目录 0 引言1 GitHub Desktop 入门教程1.1 安装与基础配置1.2 核心功能使用指南仓库管理日常开发流程分支管理 2 GitHub 开源协作流程详解2.1 Fork & Pull Request 模型2.2 完整协作流程步骤步骤 1: Fork(创建个人副本)步骤 2: Clone(克隆…...
记一次spark在docker本地启动报错
1,背景 在docker中部署spark服务和调用spark服务的微服务,微服务之间通过fegin调用 2,问题,docker容器中服务器来后,注册中心都有,调用服务也正常,但是调用spark启动任务后报错,报错…...