【机器学习】揭秘GBDT:梯度提升决策树
目录
🍔 提升树
🍔 梯度提升树
🍔 举例介绍
3.1 初始化弱学习器(CART树)
3.2 构建第一个弱学习器(CART树)
3.3 构建第二个弱学习器(CART树)
3.4 构建第三个弱学习器(CART树)
3.5 最终强学习器
🍔 GBDT算法
🍔 泰坦尼克号案例实战
5.1 导包并选取特征
5.2 切分数据及特征处理
5.3 三种分类器训练及预测
5.4 三种分类器性能评估
🍔 集成算法多样性
6.1 数据样本扰动
6.2 输入属性的扰动
6.3 算法参数的扰动
🍔 小结
学习目标
🍀 掌握提升树的算法原理思想
🍀 了解梯度提升树的原理思想
🍔 提升树
梯度提升树(Grandient Boosting)是提升树(Boosting Tree)的一种改进算法,所以在讲梯度提升树之前先来说一下提升树。
先来个通俗理解:假如有个人30岁,我们首先用20岁去拟合,发现损失有10岁,这时我们用6岁去拟合剩下的损失,发现差距还有4岁,第三轮我们用3岁拟合剩下的差距,差距就只有一岁了。如果我们的迭代轮数还没有完,可以继续迭代下面,每一轮迭代,拟合的岁数误差都会减小。最后将每次拟合的岁数加起来便是模型输出的结果。
上面提到的残差是什么呢?
假设:
-
我们前一轮迭代得到的强学习器是:ft-1(x)
-
损失函数是:L(y,ft−1(x))
-
本轮迭代的目标是找到一个弱学习器:ht(x)
-
让本轮的损失最小化: L(y, ft(x))=L(y, ft−1(x)) + ht(x))
当采用平方损失函数时:
则:
令:R = y - ft-1(x),则:
此处,R 是当前模型拟合数据的残差(residual)
所以,对于提升树来说只需要简单地拟合当前模型的残差。
🍔 梯度提升树
GBDT,全称为Gradient Boosting Decision Tree,即梯度提升决策树(梯度提升树),是一种迭代的决策树算法,也被称作MART(Multiple Additive Regression Tree)。它通过将多个决策树(弱学习器)的结果进行累加来得到最终的预测输出,是集成学习算法的一种,具体属于Boosting类型。
梯度提升树不再使用拟合残差,而是利用最速下降的近似方法,利用损失函数的负梯度作为提升树算法中的残差近似值。
假设: 损失函数仍然为平方损失, 则每个样本要拟合的负梯度为:
此时, 我们发现 GBDT 拟合的负梯度就是残差,或者说对于回归问题,拟合的目标值就是残差。
如果我们的 GBDT 进行的是分类问题,则损失函数变为 logloss,此时拟合的目标值就是该损失函数的负梯度值。
🍔 举例介绍
3.1 初始化弱学习器(CART树)
我们通过计算当模型预测值为何值时,会使得第一个基学习器的平方误差最小,即:求损失函数对 f(xi) 的导数,并令导数为0.
3.2 构建第一个弱学习器(CART树)
由于我们拟合的是样本的负梯度,即:
由此得到数据表如下:
上表中平方损失计算过程说明(以切分点1.5为例):
切分点1.5 将数据集分成两份 [5.56],[5.56 5.7 5.91 6.4 6.8 7.05 8.9 8.7 9. 9.05]
第一份的平均值为5.56 第二份数据的平均值为(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)/9 = 7.5011
由于是回归树,每份数据的平均值即为预测值,则可以计算误差,第一份数据的误差为0,第二份数据的平方误差为 :
$(5.70-7.5011)^2+(5.91-7.5011)^2+...+(9.05-7.5011)^2 = 15.72308$
以 6.5 作为切分点损失最小,构建决策树如下:
3.3 构建第二个弱学习器(CART树)
以 3.5 作为切分点损失最小,构建决策树如下:
3.4 构建第三个弱学习器(CART树)
以 6.5 作为切分点损失最小,构建决策树如下:
3.5 最终强学习器
🍔 GBDT算法
1.初始化弱学习器
2.对$m=1,2,\cdots,M$有:
(a)对每个样本$i=1,2,\cdots,N$,计算负梯度,即残差
(b)将上步得到的残差作为样本新的真实值,并将数据$(x_i,r{im}), i=1,2,..N$作为下棵树的训练数据,得到一颗新的回归树$f{m} (x)$其对应的叶子节点区域为$R_{jm}, j =1,2,\cdots,J$。其中J为回归树t的叶子节点的个数。
(c)对叶子区域$j=1,2,\cdots,J$计算最佳拟合值
(d)更新强学习器
(3)得到最终学习器
🍔 泰坦尼克号案例实战
该案例是在随机森林的基础上修改的,可以对比讲解。
数据地址:
http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt
5.1 导包并选取特征
1.数据导入
# 导入数据
import pandas as pd
# 利用pandas的read.csv模块从互联网中收集泰坦尼克号数据集
titanic=pd.read_csv("data/titanic.csv")
titanic.info() #查看信息
2.人工选择特征pclass,age,sex
X=titanic[['pclass','age','sex']]
y=titanic['survived']
3.特征工程
# 数据的填补
X['age'].fillna(X['age'].mean(),inplace=True)
5.2 切分数据及特征处理
数据的切分
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test =train_test_split(X,y,test_size=0.25,random_state=22)
将数据转化为特征向量
from sklearn.feature_extraction import DictVectorizer
vec=DictVectorizer(sparse=False)
X_train=vec.fit_transform(X_train.to_dict(orient='records'))
X_test=vec.transform(X_test.to_dict(orient='records'))
5.3 三种分类器训练及预测
4.使用单一的决策树进行模型的训练及预测分析
from sklearn.tree import DecisionTreeClassifier
dtc=DecisionTreeClassifier()
dtc.fit(X_train,y_train)
dtc_y_pred=dtc.predict(X_test)
print("score",dtc.score(X_test,y_test))
5.随机森林进行模型的训练和预测分析
from sklearn.ensemble import RandomForestClassifier
rfc=RandomForestClassifier(random_state=9)
rfc.fit(X_train,y_train)
rfc_y_pred=rfc.predict(X_test)
print("score:forest",rfc.score(X_test,y_test))
6.GBDT进行模型的训练和预测分析
from sklearn.ensemble import GradientBoostingClassifier
gbc=GradientBoostingClassifier()
gbc.fit(X_train,y_train)
gbc_y_pred=gbc.predict(X_test)
print("score:GradientBoosting",gbc.score(X_test,y_test))
5.4 三种分类器性能评估
7.性能评估
from sklearn.metrics import classification_report
print("dtc_report:",classification_report(dtc_y_pred,y_test))
print("rfc_report:",classification_report(rfc_y_pred,y_test))
print("gbc_report:",classification_report(gbc_y_pred,y_test))
🍔 集成算法多样性
集成学习中,个体学习器多样性越大越好。通常为了增大个体学习器的多样性,在学习过程中引入随机性。常用的方法包括:对数据样本进行扰动、对输入属性进行扰动、对算法参数进行扰动。
6.1 数据样本扰动
给定数据集,可以使用采样法从中产生出不同的数据子集。然后在利用不同的数据子集训练出不同的个体学习器。
该方法简单有效,使用广泛。
(1)数据样本扰动对于“不稳定学习器”很有效。“不稳定学习器”是这样一类学习器:训练样本稍加变化就会导致学习器有显著的变动,如决策树和神经网络等。
(2)数据样本扰动对于“稳定学习器”无效。“稳定学习器”是这样一类学习器:学习器对于数据样本的扰动不敏感,如线性学习器、支持向量机、朴素贝叶斯、K近邻学习器等。
如Bagging算法就是利用Bootstrip抽样完成对数据样本的自助采样。
6.2 输入属性的扰动
训练样本通常由一组属性描述,可以基于这些属性的不同组合产生不同的数据子集,然后在利用这些数据子集训练出不同的个体学习器。
(1)若数据包含了大量冗余的属性,则输入属性扰动效果较好。此时不仅训练出了多样性大的个体,还会因为属性数量的减少而大幅节省时间开销。同时由于冗余属性多,即使减少一些属性,训练个体学习器也不会很差。
(2)若数据值包含少量属性,则不宜采用输入属性扰动法。
6.3 算法参数的扰动
通常可以通过随机设置不用的参数,比如对模型参数加入小范围的随机扰动,从而产生差别较大的个体学习器。
在使用交叉验证法(GridSearch网格搜索)来确定基学习器的参数时,实际上就是用不同的参数训练出来了多个学习器,然后从中挑选出效果最好的学习器。集成学习相当于将所有这些学习器利用起来了。
随机森林学习器就结合了数据样本的扰动及输入属性的扰动。
🍔 小结
🍬 提升树中的每一个弱学习器通过拟合残差来构建强学习器
🍬 梯度提升树中的每一个弱学习器通过拟合负梯度来构建强学习器
相关文章:

【机器学习】揭秘GBDT:梯度提升决策树
目录 🍔 提升树 🍔 梯度提升树 🍔 举例介绍 3.1 初始化弱学习器(CART树) 3.2 构建第一个弱学习器(CART树) 3.3 构建第二个弱学习器(CART树) 3.4 构建第三个弱学习…...

Android Studio 2024 安装、项目创建、加速、优化
文章目录 Android Studio安装Android Studio项目创建Android Studio加速修改GRADLE_USER_HOME位置减少C盘占用空间GRADLE加速 修改模拟器位置减少C盘占用空间参考资料 Android Studio安装 下载android studio download android-studio-2024.1.2.12-windows.exe 或者 android-…...

JSP(Java Server Pages)基础使用
首先在web文件夹中新建一个jsp/jspx文件,这个文件就是jsp文件 <%--Created by IntelliJ IDEA.User: ***Date: 2024/9/23Time: 18:43To change this template use File | Settings | File Templates. --%> <% page contentType"text/html;charsetUTF-…...

数据结构 - 概述及其术语
经过上一章节《数据结构与算法之间有何关系?》的阐述,相信大家对数据结构多少有了点了解,今天我们将进入数据结构的正式学习中。 在计算机科学中,数据结构是一种数据管理、组织和存储的格式。它是相互之间存在一种或多种特定关系的…...
UE5——在线子系统
Unreal Engine 5 (UE5) 的在线子系统(Online Subsystem)实现多人在线游戏的原理涉及到网络编程和分布式系统设计中的多个方面。以下是该系统工作的一些核心概念和技术: 1. 客户端-服务器架构: - 大多数现代多人在线游戏采用客户端-服务器模型…...

9.23-部署项目
部署项目 一、先部署mariadb [rootk8s-master ~]# mkdir aaa [rootk8s-master ~]# cd aaa/ [rootk8s-master aaa]# # 先部署mariadb [rootk8s-master aaa]# # configmap [rootk8s-master aaa]# vim mariadb-configmap.yaml apiVersion: v1 kind: ConfigMap metadata:name: ma…...
非标独立设计选型--二十六--电磁阀的选型件算
电磁阀:电磁控制---自动化的关键 PLC ---- 继电器----电磁阀----调速阀----气缸 供气源--- 【电磁阀主要负责:换向,实现气缸的动作变化】 电磁阀有哪些参数是会影响到使用的? …...
flume系列之:出现数据堆积时临时增大sink端消费能力
flume系列之:出现数据堆积时临时增大sink端消费能力 一、背景二、增大sink端消费能力flume系列之:flume生产环境sink重要参数理解 一、背景 flume出现数据堆积,消费的数据持续堆积在channel中参数org_apache_flume_channel_channel1_channelfillpercentage的值大于0,并且持…...
SQL Server全方位指南:从入门到高级详解
本文将分为三大部分,逐步深入SQL Server的基础知识、进阶技巧和高级特性,旨在帮助从初学者到经验丰富的开发人员深入理解和使用SQL Server。 一、入门篇 1.1 什么是SQL Server? SQL Server 是由微软开发的关系型数据库管理系统(…...
【JavaSE】IO模型
IO,英文全称是 Input/Output,翻译过来就是输入/输出。我们听得挺多,就是磁盘 IO,网络 IO 等。IO 即输入/输出,到底谁是输入?谁是输出?IO 如果脱离了主体,会让人疑惑。 计算机角度的…...

手术缝合线合格品检测项目众多 线径又是其重要一环!
手术缝合线的合格与否,关系着使用及恢复情况,其品质的优劣非常重要,而要想得到合格的手术缝合线,则需要多种类型的仪器进行检测。其中线径就是重要一环,下面来看看线径检测仪,并简单介绍一下其他所需检测信…...

STM32 单片机最小系统全解析
STM32 单片机最小系统全解析 本文详细介绍了 STM32 单片机最小系统,包括其各个组成部分及设计要点与注意事项。STM32 最小系统在嵌入式开发中至关重要,由电源、时钟、复位、调试接口和启动电路等组成。 在电源电路方面,采用 3.3V 直流电源供…...
深度解析ElasticSearch:构建高效搜索与分析的基石原创
引言 在数据爆炸的时代,如何快速、准确地从海量数据中检索出有价值的信息成为了企业面临的重要挑战。ElasticSearch,作为一款基于Lucene的开源分布式搜索和分析引擎,凭借其强大的实时搜索、分析和扩展能力,成为了众多企业的首选。…...

【Python常用模块】_PyMySQL模块详解
课 程 推 荐我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈虚 拟 环 境 搭 建 :👉👉 Python项目虚拟环境(超详细讲解) 👈👈PyQt5 系 列 教 程:👉👉 Python GUI(PyQt5)教程合集 👈👈…...

【算法思想·二叉树】最近公共祖先问题
本文参考labuladong算法笔记[拓展:最近公共祖先系列解题框架 | labuladong 的算法笔记] 0、引言 如果说笔试的时候经常遇到各种动归回溯这类稍有难度的题目,那么面试会倾向于一些比较经典的问题,难度不算大,而且也比较实用。 本…...

如何合并pdf文件,四款软件,三步搞定!
在数字化办公的浪潮中,PDF文档因其跨平台兼容性和安全性,成为了我们日常工作中不可或缺的一部分。然而,面对多个PDF文件需要整合成一个文件时,不少小伙伴可能会感到头疼。别担心,今天我们就来揭秘四款高效PDF合并软件&…...

仪表放大器AD620
AD623 是一款低功耗、高精度的仪表放大器,而不是轨到轨运算放大器。它的输入电压范围并不覆盖整个电源电压(轨到轨),但在单电源供电下可以处理接近地电位的输入信号。 AD620 和 AD623 都是仪表放大器,但它们在一些关键…...

【Qt网络编程】Tcp多线程并发服务器和客户端通信
目录 一、编写思路 1、服务器 (1)总体思路widget.c(主线程) (2)详细流程widget.c(主线程) (1)总体思路chat_thread.c(处理聊天逻辑线程&…...
SkyWalking 简介
SkyWalking是什么 skywalking是一个国产开源框架,2015年由吴晟开源 , 2017年加入Apache孵化器。skywalking是分布式系统的应用 程序性能监视工具,专为微服务、云原生架构和基于容器(Docker、K8s、Mesos)架构而设计。它是一款优秀的 APM(Application Performance Manag…...

语音合成(自然、非自然)
1.环境 Python 3.10.14 2.完成代码 2.1简陋版 import pyttsx3# 初始化tts引擎 engine pyttsx3.init()# 设置语音速度 rate engine.getProperty(rate) engine.setProperty(rate, rate - 50)# 设置语音音量 volume engine.getProperty(volume) engine.setProperty(volume, …...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...
多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验
一、多模态商品数据接口的技术架构 (一)多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如,当用户上传一张“蓝色连衣裙”的图片时,接口可自动提取图像中的颜色(RGB值&…...

Nuxt.js 中的路由配置详解
Nuxt.js 通过其内置的路由系统简化了应用的路由配置,使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...

如何在网页里填写 PDF 表格?
有时候,你可能希望用户能在你的网站上填写 PDF 表单。然而,这件事并不简单,因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件,但原生并不支持编辑或填写它们。更糟的是,如果你想收集表单数据ÿ…...
Xen Server服务器释放磁盘空间
disk.sh #!/bin/bashcd /run/sr-mount/e54f0646-ae11-0457-b64f-eba4673b824c # 全部虚拟机物理磁盘文件存储 a$(ls -l | awk {print $NF} | cut -d. -f1) # 使用中的虚拟机物理磁盘文件 b$(xe vm-disk-list --multiple | grep uuid | awk {print $NF})printf "%s\n"…...

C++ 设计模式 《小明的奶茶加料风波》
👨🎓 模式名称:装饰器模式(Decorator Pattern) 👦 小明最近上线了校园奶茶配送功能,业务火爆,大家都在加料: 有的同学要加波霸 🟤,有的要加椰果…...
作为测试我们应该关注redis哪些方面
1、功能测试 数据结构操作:验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化:测试aof和aof持久化机制,确保数据在开启后正确恢复。 事务:检查事务的原子性和回滚机制。 发布订阅:确保消息正确传递。 2、性…...
【Kafka】Kafka从入门到实战:构建高吞吐量分布式消息系统
Kafka从入门到实战:构建高吞吐量分布式消息系统 一、Kafka概述 Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发,后成为Apache顶级项目。它被设计用于高吞吐量、低延迟的消息处理,能够处理来自多个生产者的海量数据,并将这些数据实时传递给消费者。 Kafka核心特…...
电脑桌面太单调,用Python写一个桌面小宠物应用。
下面是一个使用Python创建的简单桌面小宠物应用。这个小宠物会在桌面上游荡,可以响应鼠标点击,并且有简单的动画效果。 import tkinter as tk import random import time from PIL import Image, ImageTk import os import sysclass DesktopPet:def __i…...