梯度提升算法决策过程的逐步可视化
梯度提升算法是最常用的集成机器学习技术之一,该模型使用弱决策树序列来构建强学习器。这也是XGBoost和LightGBM模型的理论基础,所以在这篇文章中,我们将从头开始构建一个梯度增强模型并将其可视化。
梯度提升算法介绍
梯度提升算法(Gradient Boosting)是一种集成学习算法,它通过构建多个弱分类器,然后将它们组合成一个强分类器来提高模型的预测准确率。
梯度提升算法的原理可以分为以下几个步骤:
- 初始化模型:一般来说,我们可以使用一个简单的模型(比如说决策树)作为初始的分类器。
- 计算损失函数的负梯度:计算出每个样本点在当前模型下的损失函数的负梯度。这相当于是让新的分类器去拟合当前模型下的误差。
- 训练新的分类器:用这些负梯度作为目标变量,训练一个新的弱分类器。这个弱分类器可以是任意的分类器,比如说决策树、线性模型等。
- 更新模型:将新的分类器加入到原来的模型中,可以用加权平均或者其他方法将它们组合起来。
- 重复迭代:重复上述步骤,直到达到预设的迭代次数或者达到预设的准确率。
由于梯度提升算法是一种串行算法,所以它的训练速度可能会比较慢,我们以一个实际的例子来介绍:
假设我们有一个特征集Xi和值Yi,要计算y的最佳估计
我们从y的平均值开始
每一步我们都想让F_m(x)更接近y|x。
在每一步中,我们都想要F_m(x)一个更好的y给定x的近似。
首先,我们定义一个损失函数
然后,我们向损失函数相对于学习者Fm下降最快的方向前进:
因为我们不能为每个x计算y,所以不知道这个梯度的确切值,但是对于训练数据中的每一个x_i,梯度完全等于步骤m的残差:r_i!
所以我们可以用弱回归树h_m来近似梯度函数g_m,对残差进行训练:
然后,我们更新学习器
这就是梯度提升,我们不是使用损失函数相对于当前学习器的真实梯度g_m来更新当前学习器F_{m},而是使用弱回归树h_m来更新它。
也就是重复下面的步骤
1、计算残差:
2、将回归树h_m拟合到训练样本及其残差(x_i, r_i)上
3、用步长\alpha更新模型
看着很复杂对吧,下面我们可视化一下这个过程就会变得非常清晰了
决策过程可视化
这里我们使用sklearn的moons 数据集,因为这是一个经典的非线性分类数据
import numpy as npimport sklearn.datasets as dsimport pandas as pdimport matplotlib.pyplot as pltimport matplotlib as mplfrom sklearn import treefrom itertools import product,isliceimport seaborn as snsmoonDS = ds.make_moons(200, noise = 0.15, random_state=16)moon = moonDS[0]color = -1*(moonDS[1]*2-1)df =pd.DataFrame(moon, columns = ['x','y'])df['z'] = colordf['f0'] =df.y.mean()df['r0'] = df['z'] - df['f0']df.head(10)
让我们可视化数据:
下图可以看到,该数据集是可以明显的区分出分类的边界的,但是因为他是非线性的,所以使用线性算法进行分类时会遇到很大的困难。
那么我们先编写一个简单的梯度增强模型:
def makeiteration(i:int):"""Takes the dataframe ith f_i and r_i and approximated r_i from the features, then computes f_i+1 and r_i+1"""clf = tree.DecisionTreeRegressor(max_depth=1)clf.fit(X=df[['x','y']].values, y = df[f'r{i-1}'])df[f'r{i-1}hat'] = clf.predict(df[['x','y']].values)eta = 0.9df[f'f{i}'] = df[f'f{i-1}'] + eta*df[f'r{i-1}hat']df[f'r{i}'] = df['z'] - df[f'f{i}']rmse = (df[f'r{i}']**2).sum()clfs.append(clf)rmses.append(rmse)
上面代码执行3个简单步骤:
将决策树与残差进行拟合:
clf.fit(X=df[['x','y']].values, y = df[f'r{i-1}'])df[f'r{i-1}hat'] = clf.predict(df[['x','y']].values)
然后,我们将这个近似的梯度与之前的学习器相加:
df[f'f{i}'] = df[f'f{i-1}'] + eta*df[f'r{i-1}hat']
最后重新计算残差:
df[f'r{i}'] = df['z'] - df[f'f{i}']
步骤就是这样简单,下面我们来一步一步执行这个过程。
第1次决策
Tree Split for 0 and level 1.563690960407257
第2次决策
Tree Split for 1 and level 0.5143677890300751
第3次决策
Tree Split for 0 and level -0.6523728966712952
第4次决策
Tree Split for 0 and level 0.3370491564273834
第5次决策
Tree Split for 0 and level 0.3370491564273834
第6次决策
Tree Split for 1 and level 0.022058885544538498
第7次决策
Tree Split for 0 and level -0.3030575215816498
第8次决策
Tree Split for 0 and level 0.6119407713413239
第9次决策
可以看到通过9次的计算,基本上已经把上面的分类进行了区分
我们这里的学习器都是非常简单的决策树,只沿着一个特征分裂!但整体模型在每次决策后边的越来越复杂,并且整体误差逐渐减小。
plt.plot(rmses)
这也就是上图中我们看到的能够正确区分出了大部分的分类
如果你感兴趣可以使用下面代码自行实验:
https://avoid.overfit.cn/post/533a0736b7554ef6b8464a5d8ba964ab
作者:Tanguy Renaudie
相关文章:

梯度提升算法决策过程的逐步可视化
梯度提升算法是最常用的集成机器学习技术之一,该模型使用弱决策树序列来构建强学习器。这也是XGBoost和LightGBM模型的理论基础,所以在这篇文章中,我们将从头开始构建一个梯度增强模型并将其可视化。 梯度提升算法介绍 梯度提升算法&#x…...
Linux系统调用之文件属性操作函数
前言 如果,想要深入的学习Linux系统调用中access,chmod,chown,truncate这些有关于文件属性的操作函数,还是需要去自己阅读Linux系统中的帮助文档。 具体输入命令: man 2 access/chmod/chown/truncate 即可…...
VMware 安装 银河麒麟高级服务器操作系统 V10 + QT 开发环境搭建
下载并安装vmware 下载银河麒麟操作烯烃服务器版v10的镜像文件从官网下载,因为是x86的电脑芯片,选择AMD64版,即vmare 安装麒麟操作系统注意事项:安装位置选择自动分区网络和主机名设置打开网络,ip4就不用再设置了创建一…...

2023年疫情开放,国内程序员薪资涨了还是跌了?大数据告诉你答案
自从疫情开放,国内各个行业都开始有复苏的迹象,尤其是旅游行业更是空前暴涨,那么互联网行业如何? 有人说今年好找工作多了,有人说依然是内卷得一塌糊涂,那么今年开春以来,各个岗位的程序员工资…...

太赫兹频段耦合器设计相关经验总结
1拿到耦合器的频段后,确定中心频率和波导的宽度和高度 此处贴一张不同频段对应的波导尺寸图 需要注意的是1英寸 2.54厘米,需注意换算 具体网址:矩形波导尺寸 | 扩维 (qualwave.com) 仅列举我比较常用的太赫兹频段部分 2.以220~320GHz频段&a…...

反弹shell数据不回显带外查询pikaqiu靶场搭建
P1 文件上传下载(解决无图形化和解决数据传输) 解决无图形化: 当我们想下载一个文件时,通常是通过浏览器的一个链接直接访问网站点击下载的,但是在实际的安全测试中,我们获取的权限只是一个执行命令的窗口…...

按键修改阈值功能、报警功能、空气质量功能实现
按键修改阈值功能 要使用按键,首先要定义按键。通过查阅资料,可知按键的引脚如图所示:按键1(S1)通过KEY0与PA0连接,按键2(S2)通过KEY1与PE2连接,按键3(S3&…...

spring重点整理篇--springMVC(嘿嘿,开心哟)
Spring MVC是的基于JavaWeb的MVC框架,是Spring框架中的一个组成部分(WEB模块) MVC设计模式: Controller(控制器) Model(模型) View(视图) 重点来了😄 SpringMVC的工作机制…...
图像融合评估指标Python版
图像融合评估指标Python版 这篇博客利用Python把大部分图像融合指标基于图像融合评估指标复现了,从而方便大家更好的使用Python进行指标计算,以及一些I/O 操作。除了几个特征互信息的指标没有成功复现之外,其他指标均可以通过这篇博客提到的P…...
20230303----重返学习-函数概念-函数组成-函数调用-形参及匿名函数及自调用函数
day-019-nineteen-20230303-函数概念-函数组成-函数调用-形参及匿名函数及自调用函数 变量 变量声明 变量 声明定义(赋值) var num;num 100; 声明与赋值分开var num 100; 声明时就赋值 赋值只能声明一次,可以赋值无数次 变量声明关键词 varconstletclassfunctio…...

Java面试题总结
文章目录前言1、JDK1.8 的新特性有哪些?2、JDK 和 JRE 有什么区别?3、String,StringBuilder,StringBuffer 三者的区别?4、为什么 String 拼接的效率低?5、ArrayList 和 LinkedList 有哪些区别?6…...

深圳大学计软《面向对象的程序设计》实验7 拷贝构造函数与复合类
A. Point&Circle(复合类与构造) 题目描述 类Point是我们写过的一个类,类Circle是一个新的类,Point作为其成员对象,请完成类Circle的成员函数的实现。 在主函数中生成一个圆和若干个点,判断这些点与圆的位置关系,…...
Java的JVM(Java虚拟机)参数配置
JVM原理 (1)jvm是java的核心和基础,在java编译器和os平台之间的虚拟处理器,可在上面执行字节码程序。 (2)java编译器只要面向jvm,生成jvm能理解的字节码文件。java源文件经编译成字节码程序&a…...
leetcode 困难 —— 数据流的中位数(优先队列)
题目: 中位数是有序整数列表中的中间值。如果列表的大小是偶数,则没有中间值,中位数是两个中间值的平均值。 例如 arr [2,3,4] 的中位数是 3 。 例如 arr [2,3] 的中位数是 (2 3) / 2 2.5 。 实现 MedianFinder 类: MedianFinder() 初始化…...

7个常用的原生JS数组方法
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 7个常用的原生JS数组方法一、Array.map()二、Array.filter三、Array.reduce四、Array.forEach五、Array.find六、Array.every七、Array.some总结一、Array.map() 作用&#…...

一、一篇文章打好高数基础-函数
1.连续函数的性质考点分析函数的连续性主要考察函数的奇偶性、有界性、单调性、周期性。例题判断函数的奇偶性的有界区间为() A.(-1,0) B(0,1) C(1,2) D(2,3)2.闭区间上连续函数的性质考点分析闭区间上连续函数的性质主要考察函数的最大最小值定理、零点…...

pipenv的基本使用
一. pipenv 基础 pipenv安装: pip install pipenvpipenv常用命令 pipenv --python 3 # 创建python3虚拟环境 pipenv --venv # 查看创建的虚拟环境 pipenv install 包名 # 安装包 pipenv shell # 切换到虚拟环境中 pip list # 查看当前已经安装的包࿰…...

OpenCV入门(三)快速学会OpenCV2图像处理基础
OpenCV入门(三)快速学会OpenCV2图像处理基础 1.颜色变换cvtColor imgproc的模块名称是由image(图像)和process(处理)两个单词的缩写组合而成的,是重要的图像处理模块,主要包括图像…...

基于PySide6的MySql数据库快照备份与恢复软件
db-camera 软件介绍 db-camera是一款MySql数据库备份(快照保存)与恢复软件。功能上与dump类似,但是提供了相对有好的交互界面,能够有效地管理导出的sql文件。 使用场景 开发阶段、测试阶段,尤其适合单人开发的小项目…...

BI不是报表,千万不要混淆
商业智能BI作为商业世界的新宠儿,在市场上实现了高速增长并获得了各领域企业的口碑赞誉。 很多企业把商业智能BI做成了纯报表,二维表格的数据展现形式,也有一些简单的图表可视化。但是这些简单的商业智能BI可视化报表基本上只服务到了一线的…...

工业安全零事故的智能守护者:一体化AI智能安防平台
前言: 通过AI视觉技术,为船厂提供全面的安全监控解决方案,涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面,能够实现对应负责人反馈机制,并最终实现数据的统计报表。提升船厂…...
Qt Widget类解析与代码注释
#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...

YSYX学习记录(八)
C语言,练习0: 先创建一个文件夹,我用的是物理机: 安装build-essential 练习1: 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件,随机修改或删除一部分,之后…...

前端导出带有合并单元格的列表
// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...
c++ 面试题(1)-----深度优先搜索(DFS)实现
操作系统:ubuntu22.04 IDE:Visual Studio Code 编程语言:C11 题目描述 地上有一个 m 行 n 列的方格,从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子,但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序
一、开发环境准备 工具安装: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 项目初始化: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心
当仓库学会“思考”,物流的终极形态正在诞生 想象这样的场景: 凌晨3点,某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径;AI视觉系统在0.1秒内扫描包裹信息;数字孪生平台正模拟次日峰值流量压力…...

select、poll、epoll 与 Reactor 模式
在高并发网络编程领域,高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表,以及基于它们实现的 Reactor 模式,为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。 一、I…...
OpenLayers 分屏对比(地图联动)
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能,和卷帘图层不一样的是,分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...
laravel8+vue3.0+element-plus搭建方法
创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...