当前位置: 首页 > news >正文

【机器学习】揭秘GBDT:梯度提升决策树

 

目录

🍔 提升树

🍔 梯度提升树

🍔 举例介绍

3.1 初始化弱学习器(CART树)

3.2 构建第一个弱学习器(CART树)

3.3 构建第二个弱学习器(CART树)

3.4 构建第三个弱学习器(CART树)

3.5 最终强学习器

🍔 GBDT算法

🍔 泰坦尼克号案例实战

5.1 导包并选取特征

5.2 切分数据及特征处理

5.3 三种分类器训练及预测

5.4 三种分类器性能评估

🍔 集成算法多样性

6.1 数据样本扰动

6.2 输入属性的扰动

6.3 算法参数的扰动

🍔 小结


学习目标

🍀 掌握提升树的算法原理思想

🍀 了解梯度提升树的原理思想

🍔 提升树

梯度提升树(Grandient Boosting)是提升树(Boosting Tree)的一种改进算法,所以在讲梯度提升树之前先来说一下提升树。

先来个通俗理解:假如有个人30岁,我们首先用20岁去拟合,发现损失有10岁,这时我们用6岁去拟合剩下的损失,发现差距还有4岁,第三轮我们用3岁拟合剩下的差距,差距就只有一岁了。如果我们的迭代轮数还没有完,可以继续迭代下面,每一轮迭代,拟合的岁数误差都会减小。最后将每次拟合的岁数加起来便是模型输出的结果。

上面提到的残差是什么呢?

假设:

  1. 我们前一轮迭代得到的强学习器是:ft-1(x)

  2. 损失函数是:L(y,f​t−1(x))

  3. 本轮迭代的目标是找到一个弱学习器:ht(x)

  4. 让本轮的损失最小化: L(y, ft(x))=L(y, ft−1(x)) + ht(x))

当采用平方损失函数时:

则:

令:R = y - ft-1(x),则:

此处,R 是当前模型拟合数据的残差(residual)

所以,对于提升树来说只需要简单地拟合当前模型的残差。

🍔 梯度提升树

GBDT,全称为Gradient Boosting Decision Tree,即梯度提升决策树(梯度提升树),是一种迭代的决策树算法,也被称作MART(Multiple Additive Regression Tree)。它通过将多个决策树(弱学习器)的结果进行累加来得到最终的预测输出,是集成学习算法的一种,具体属于Boosting类型。

梯度提升树不再使用拟合残差,而是利用最速下降的近似方法,利用损失函数的负梯度作为提升树算法中的残差近似值。

假设: 损失函数仍然为平方损失, 则每个样本要拟合的负梯度为:

此时, 我们发现 GBDT 拟合的负梯度就是残差,或者说对于回归问题,拟合的目标值就是残差。

如果我们的 GBDT 进行的是分类问题,则损失函数变为 logloss,此时拟合的目标值就是该损失函数的负梯度值。

🍔 举例介绍

3.1 初始化弱学习器(CART树)

我们通过计算当模型预测值为何值时,会使得第一个基学习器的平方误差最小,即:求损失函数对 f(xi) 的导数,并令导数为0.


3.2 构建第一个弱学习器(CART树)

由于我们拟合的是样本的负梯度,即:

由此得到数据表如下:

上表中平方损失计算过程说明(以切分点1.5为例):

  1. 切分点1.5 将数据集分成两份 [5.56],[5.56 5.7 5.91 6.4 6.8 7.05 8.9 8.7 9. 9.05]

  2. 第一份的平均值为5.56 第二份数据的平均值为(5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)/9 = 7.5011

  3. 由于是回归树,每份数据的平均值即为预测值,则可以计算误差,第一份数据的误差为0,第二份数据的平方误差为 :

$(5.70-7.5011)^2+(5.91-7.5011)^2+...+(9.05-7.5011)^2 = 15.72308$

以 6.5 作为切分点损失最小,构建决策树如下:

3.3 构建第二个弱学习器(CART树)

以 3.5 作为切分点损失最小,构建决策树如下:

3.4 构建第三个弱学习器(CART树)

以 6.5 作为切分点损失最小,构建决策树如下:

3.5 最终强学习器

🍔 GBDT算法

1.初始化弱学习器

2.对$m=1,2,\cdots,M$有:

(a)对每个样本$i=1,2,\cdots,N$,计算负梯度,即残差

(b)将上步得到的残差作为样本新的真实值,并将数据$(x_i,r{im}), i=1,2,..N$作为下棵树的训练数据,得到一颗新的回归树$f{m} (x)$其对应的叶子节点区域为$R_{jm}, j =1,2,\cdots,J$。其中J为回归树t的叶子节点的个数。

(c)对叶子区域$j=1,2,\cdots,J$计算最佳拟合值

(d)更新强学习器

(3)得到最终学习器

🍔 泰坦尼克号案例实战

该案例是在随机森林的基础上修改的,可以对比讲解。

数据地址:

http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt

5.1 导包并选取特征

1.数据导入
# 导入数据
import  pandas as pd
# 利用pandas的read.csv模块从互联网中收集泰坦尼克号数据集
titanic=pd.read_csv("data/titanic.csv")
titanic.info() #查看信息
2.人工选择特征pclass,age,sex
X=titanic[['pclass','age','sex']]
y=titanic['survived']
3.特征工程
# 数据的填补
X['age'].fillna(X['age'].mean(),inplace=True)

5.2 切分数据及特征处理

数据的切分
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test =train_test_split(X,y,test_size=0.25,random_state=22)
将数据转化为特征向量
from sklearn.feature_extraction import DictVectorizer
vec=DictVectorizer(sparse=False)
X_train=vec.fit_transform(X_train.to_dict(orient='records'))
X_test=vec.transform(X_test.to_dict(orient='records'))

5.3 三种分类器训练及预测

4.使用单一的决策树进行模型的训练及预测分析
from sklearn.tree import DecisionTreeClassifier
dtc=DecisionTreeClassifier()
dtc.fit(X_train,y_train)
dtc_y_pred=dtc.predict(X_test)
print("score",dtc.score(X_test,y_test))
5.随机森林进行模型的训练和预测分析
from sklearn.ensemble import RandomForestClassifier
rfc=RandomForestClassifier(random_state=9)
rfc.fit(X_train,y_train)
rfc_y_pred=rfc.predict(X_test)
print("score:forest",rfc.score(X_test,y_test))
6.GBDT进行模型的训练和预测分析
from sklearn.ensemble import GradientBoostingClassifier
gbc=GradientBoostingClassifier()
gbc.fit(X_train,y_train)
gbc_y_pred=gbc.predict(X_test)
print("score:GradientBoosting",gbc.score(X_test,y_test))

5.4 三种分类器性能评估

7.性能评估
from sklearn.metrics import classification_report
print("dtc_report:",classification_report(dtc_y_pred,y_test))
print("rfc_report:",classification_report(rfc_y_pred,y_test))
print("gbc_report:",classification_report(gbc_y_pred,y_test))

🍔 集成算法多样性

集成学习中,个体学习器多样性越大越好。通常为了增大个体学习器的多样性,在学习过程中引入随机性。常用的方法包括:对数据样本进行扰动、对输入属性进行扰动、对算法参数进行扰动。

6.1 数据样本扰动

给定数据集,可以使用采样法从中产生出不同的数据子集。然后在利用不同的数据子集训练出不同的个体学习器。

该方法简单有效,使用广泛。

(1)数据样本扰动对于“不稳定学习器”很有效。“不稳定学习器”是这样一类学习器:训练样本稍加变化就会导致学习器有显著的变动,如决策树和神经网络等。

(2)数据样本扰动对于“稳定学习器”无效。“稳定学习器”是这样一类学习器:学习器对于数据样本的扰动不敏感,如线性学习器、支持向量机、朴素贝叶斯、K近邻学习器等。

如Bagging算法就是利用Bootstrip抽样完成对数据样本的自助采样。

6.2 输入属性的扰动

训练样本通常由一组属性描述,可以基于这些属性的不同组合产生不同的数据子集,然后在利用这些数据子集训练出不同的个体学习器。

(1)若数据包含了大量冗余的属性,则输入属性扰动效果较好。此时不仅训练出了多样性大的个体,还会因为属性数量的减少而大幅节省时间开销。同时由于冗余属性多,即使减少一些属性,训练个体学习器也不会很差。

(2)若数据值包含少量属性,则不宜采用输入属性扰动法。

6.3 算法参数的扰动

通常可以通过随机设置不用的参数,比如对模型参数加入小范围的随机扰动,从而产生差别较大的个体学习器。

在使用交叉验证法(GridSearch网格搜索)来确定基学习器的参数时,实际上就是用不同的参数训练出来了多个学习器,然后从中挑选出效果最好的学习器。集成学习相当于将所有这些学习器利用起来了。

随机森林学习器就结合了数据样本的扰动及输入属性的扰动。

🍔 小结

🍬 提升树中的每一个弱学习器通过拟合残差来构建强学习器

🍬 梯度提升树中的每一个弱学习器通过拟合负梯度来构建强学习器

相关文章:

【机器学习】揭秘GBDT:梯度提升决策树

目录 🍔 提升树 🍔 梯度提升树 🍔 举例介绍 3.1 初始化弱学习器(CART树) 3.2 构建第一个弱学习器(CART树) 3.3 构建第二个弱学习器(CART树) 3.4 构建第三个弱学习…...

Android Studio 2024 安装、项目创建、加速、优化

文章目录 Android Studio安装Android Studio项目创建Android Studio加速修改GRADLE_USER_HOME位置减少C盘占用空间GRADLE加速 修改模拟器位置减少C盘占用空间参考资料 Android Studio安装 下载android studio download android-studio-2024.1.2.12-windows.exe 或者 android-…...

JSP(Java Server Pages)基础使用

首先在web文件夹中新建一个jsp/jspx文件&#xff0c;这个文件就是jsp文件 <%--Created by IntelliJ IDEA.User: ***Date: 2024/9/23Time: 18:43To change this template use File | Settings | File Templates. --%> <% page contentType"text/html;charsetUTF-…...

数据结构 - 概述及其术语

经过上一章节《数据结构与算法之间有何关系&#xff1f;》的阐述&#xff0c;相信大家对数据结构多少有了点了解&#xff0c;今天我们将进入数据结构的正式学习中。 在计算机科学中&#xff0c;数据结构是一种数据管理、组织和存储的格式。它是相互之间存在一种或多种特定关系的…...

UE5——在线子系统

Unreal Engine 5 (UE5) 的在线子系统&#xff08;Online Subsystem&#xff09;实现多人在线游戏的原理涉及到网络编程和分布式系统设计中的多个方面。以下是该系统工作的一些核心概念和技术&#xff1a; 1. 客户端-服务器架构: - 大多数现代多人在线游戏采用客户端-服务器模型…...

9.23-部署项目

部署项目 一、先部署mariadb [rootk8s-master ~]# mkdir aaa [rootk8s-master ~]# cd aaa/ [rootk8s-master aaa]# # 先部署mariadb [rootk8s-master aaa]# # configmap [rootk8s-master aaa]# vim mariadb-configmap.yaml apiVersion: v1 kind: ConfigMap metadata:name: ma…...

非标独立设计选型--二十六--电磁阀的选型件算

电磁阀&#xff1a;电磁控制---自动化的关键 PLC ---- 继电器----电磁阀----调速阀----气缸 供气源--- 【电磁阀主要负责&#xff1a;换向&#xff0c;实现气缸的动作变化】 电磁阀有哪些参数是会影响到使用的&#xff1f; …...

flume系列之:出现数据堆积时临时增大sink端消费能力

flume系列之:出现数据堆积时临时增大sink端消费能力 一、背景二、增大sink端消费能力flume系列之:flume生产环境sink重要参数理解 一、背景 flume出现数据堆积,消费的数据持续堆积在channel中参数org_apache_flume_channel_channel1_channelfillpercentage的值大于0,并且持…...

SQL Server全方位指南:从入门到高级详解

本文将分为三大部分&#xff0c;逐步深入SQL Server的基础知识、进阶技巧和高级特性&#xff0c;旨在帮助从初学者到经验丰富的开发人员深入理解和使用SQL Server。 一、入门篇 1.1 什么是SQL Server&#xff1f; SQL Server 是由微软开发的关系型数据库管理系统&#xff08…...

【JavaSE】IO模型

IO&#xff0c;英文全称是 Input/Output&#xff0c;翻译过来就是输入/输出。我们听得挺多&#xff0c;就是磁盘 IO&#xff0c;网络 IO 等。IO 即输入/输出&#xff0c;到底谁是输入&#xff1f;谁是输出&#xff1f;IO 如果脱离了主体&#xff0c;会让人疑惑。 计算机角度的…...

手术缝合线合格品检测项目众多 线径又是其重要一环!

手术缝合线的合格与否&#xff0c;关系着使用及恢复情况&#xff0c;其品质的优劣非常重要&#xff0c;而要想得到合格的手术缝合线&#xff0c;则需要多种类型的仪器进行检测。其中线径就是重要一环&#xff0c;下面来看看线径检测仪&#xff0c;并简单介绍一下其他所需检测信…...

STM32 单片机最小系统全解析

STM32 单片机最小系统全解析 本文详细介绍了 STM32 单片机最小系统&#xff0c;包括其各个组成部分及设计要点与注意事项。STM32 最小系统在嵌入式开发中至关重要&#xff0c;由电源、时钟、复位、调试接口和启动电路等组成。 在电源电路方面&#xff0c;采用 3.3V 直流电源供…...

深度解析ElasticSearch:构建高效搜索与分析的基石原创

引言 在数据爆炸的时代&#xff0c;如何快速、准确地从海量数据中检索出有价值的信息成为了企业面临的重要挑战。ElasticSearch&#xff0c;作为一款基于Lucene的开源分布式搜索和分析引擎&#xff0c;凭借其强大的实时搜索、分析和扩展能力&#xff0c;成为了众多企业的首选。…...

【Python常用模块】_PyMySQL模块详解

课 程 推 荐我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈虚 拟 环 境 搭 建 :👉👉 Python项目虚拟环境(超详细讲解) 👈👈PyQt5 系 列 教 程:👉👉 Python GUI(PyQt5)教程合集 👈👈…...

【算法思想·二叉树】最近公共祖先问题

本文参考labuladong算法笔记[拓展&#xff1a;最近公共祖先系列解题框架 | labuladong 的算法笔记] 0、引言 如果说笔试的时候经常遇到各种动归回溯这类稍有难度的题目&#xff0c;那么面试会倾向于一些比较经典的问题&#xff0c;难度不算大&#xff0c;而且也比较实用。 本…...

如何合并pdf文件,四款软件,三步搞定!

在数字化办公的浪潮中&#xff0c;PDF文档因其跨平台兼容性和安全性&#xff0c;成为了我们日常工作中不可或缺的一部分。然而&#xff0c;面对多个PDF文件需要整合成一个文件时&#xff0c;不少小伙伴可能会感到头疼。别担心&#xff0c;今天我们就来揭秘四款高效PDF合并软件&…...

仪表放大器AD620

AD623 是一款低功耗、高精度的仪表放大器&#xff0c;而不是轨到轨运算放大器。它的输入电压范围并不覆盖整个电源电压&#xff08;轨到轨&#xff09;&#xff0c;但在单电源供电下可以处理接近地电位的输入信号。 AD620 和 AD623 都是仪表放大器&#xff0c;但它们在一些关键…...

【Qt网络编程】Tcp多线程并发服务器和客户端通信

目录 一、编写思路 1、服务器 &#xff08;1&#xff09;总体思路widget.c&#xff08;主线程&#xff09; &#xff08;2&#xff09;详细流程widget.c&#xff08;主线程&#xff09; &#xff08;1&#xff09;总体思路chat_thread.c&#xff08;处理聊天逻辑线程&…...

SkyWalking 简介

SkyWalking是什么 skywalking是一个国产开源框架,2015年由吴晟开源 , 2017年加入Apache孵化器。skywalking是分布式系统的应用 程序性能监视工具,专为微服务、云原生架构和基于容器(Docker、K8s、Mesos)架构而设计。它是一款优秀的 APM(Application Performance Manag…...

语音合成(自然、非自然)

1.环境 Python 3.10.14 2.完成代码 2.1简陋版 import pyttsx3# 初始化tts引擎 engine pyttsx3.init()# 设置语音速度 rate engine.getProperty(rate) engine.setProperty(rate, rate - 50)# 设置语音音量 volume engine.getProperty(volume) engine.setProperty(volume, …...

【第二十一章 SDIO接口(SDIO)】

第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

html-<abbr> 缩写或首字母缩略词

定义与作用 <abbr> 标签用于表示缩写或首字母缩略词&#xff0c;它可以帮助用户更好地理解缩写的含义&#xff0c;尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时&#xff0c;会显示一个提示框。 示例&#x…...

深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用

文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么&#xff1f;1.1.2 感知机的工作原理 1.2 感知机的简单应用&#xff1a;基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...

省略号和可变参数模板

本文主要介绍如何展开可变参数的参数包 1.C语言的va_list展开可变参数 #include <iostream> #include <cstdarg>void printNumbers(int count, ...) {// 声明va_list类型的变量va_list args;// 使用va_start将可变参数写入变量argsva_start(args, count);for (in…...

Scrapy-Redis分布式爬虫架构的可扩展性与容错性增强:基于微服务与容器化的解决方案

在大数据时代&#xff0c;海量数据的采集与处理成为企业和研究机构获取信息的关键环节。Scrapy-Redis作为一种经典的分布式爬虫架构&#xff0c;在处理大规模数据抓取任务时展现出强大的能力。然而&#xff0c;随着业务规模的不断扩大和数据抓取需求的日益复杂&#xff0c;传统…...

深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏

一、引言 在深度学习中&#xff0c;我们训练出的神经网络往往非常庞大&#xff08;比如像 ResNet、YOLOv8、Vision Transformer&#xff09;&#xff0c;虽然精度很高&#xff0c;但“太重”了&#xff0c;运行起来很慢&#xff0c;占用内存大&#xff0c;不适合部署到手机、摄…...

redis和redission的区别

Redis 和 Redisson 是两个密切相关但又本质不同的技术&#xff0c;它们扮演着完全不同的角色&#xff1a; Redis: 内存数据库/数据结构存储 本质&#xff1a; 它是一个开源的、高性能的、基于内存的 键值存储数据库。它也可以将数据持久化到磁盘。 核心功能&#xff1a; 提供丰…...

Windows电脑能装鸿蒙吗_Windows电脑体验鸿蒙电脑操作系统教程

鸿蒙电脑版操作系统来了&#xff0c;很多小伙伴想体验鸿蒙电脑版操作系统&#xff0c;可惜&#xff0c;鸿蒙系统并不支持你正在使用的传统的电脑来安装。不过可以通过可以使用华为官方提供的虚拟机&#xff0c;来体验大家心心念念的鸿蒙系统啦&#xff01;注意&#xff1a;虚拟…...

【51单片机】4. 模块化编程与LCD1602Debug

1. 什么是模块化编程 传统编程会将所有函数放在main.c中&#xff0c;如果使用的模块多&#xff0c;一个文件内会有很多代码&#xff0c;不利于组织和管理 模块化编程则是将各个模块的代码放在不同的.c文件里&#xff0c;在.h文件里提供外部可调用函数声明&#xff0c;其他.c文…...

【Java】Ajax 技术详解

文章目录 1. Filter 过滤器1.1 Filter 概述1.2 Filter 快速入门开发步骤:1.3 Filter 执行流程1.4 Filter 拦截路径配置1.5 过滤器链2. Listener 监听器2.1 Listener 概述2.2 ServletContextListener3. Ajax 技术3.1 Ajax 概述3.2 Ajax 快速入门服务端实现:客户端实现:4. Axi…...