梯度提升用于高效的分类与回归
使用 决策树(Decision Tree) 实现 梯度提升(Gradient Boosting) 主要是模拟 GBDT(Gradient Boosting Decision Trees) 的原理,即:
- 第一棵树拟合原始数据
- 计算残差(负梯度方向)
- 用新的树去拟合残差
- 累加所有树的预测值
- 重复步骤 2-4,直至达到指定轮数
下面是一个 纯 Python + PyTorch 实现 GBDT(梯度提升决策树) 的代码示例。
1. 纯 Python 实现梯度提升决策树
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split# 生成数据
X, y = make_regression(n_samples=1000, n_features=5, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 参数
n_trees = 50 # 多少棵树
learning_rate = 0.1 # 学习率# 初始化预测值(全部为 0)
y_pred_train = np.zeros_like(y_train)
y_pred_test = np.zeros_like(y_test)# 训练梯度提升决策树
trees = []
for i in range(n_trees):residuals = y_train - y_pred_train # 计算残差(负梯度方向)tree = DecisionTreeRegressor(max_depth=3) # 这里使用较浅的树tree.fit(X_train, residuals) # 让树学习残差trees.append(tree)# 更新预测值(累加弱学习器的结果)y_pred_train += learning_rate * tree.predict(X_train)y_pred_test += learning_rate * tree.predict(X_test)# 计算损失mse = mean_squared_error(y_train, y_pred_train)print(f"Iteration {i+1}: MSE = {mse:.4f}")# 计算最终测试集误差
final_mse = mean_squared_error(y_test, y_pred_test)
print(f"\nFinal Test MSE: {final_mse:.4f}")
代码解析
- 第一步:构建一个基础决策树
DecisionTreeRegressor(max_depth=3)。 - 第二步:每棵树学习前面所有树的残差(负梯度方向)。
- 第三步:训练
n_trees棵树,每棵树的预测结果乘以learning_rate累加到最终预测值。 - 第四步:每次迭代后更新预测值,减少误差。
2. 用 PyTorch 实现 GBDT
虽然 GBDT 主要基于决策树,但如果你希望用 PyTorch 计算梯度并模拟 GBDT,可以如下操作:
- 用 PyTorch 计算 损失函数的梯度
- 用
sklearn.tree.DecisionTreeRegressor拟合梯度 - 用 PyTorch 计算最终误差
import torch
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split# 生成数据
X, y = make_regression(n_samples=1000, n_features=5, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 参数
n_trees = 50 # 多少棵树
learning_rate = 0.1 # 学习率# 转换数据为 PyTorch 张量
X_train_torch = torch.tensor(X_train, dtype=torch.float32)
y_train_torch = torch.tensor(y_train, dtype=torch.float32)# 初始化预测值
y_pred_train = torch.zeros_like(y_train_torch)# 训练 GBDT
trees = []
for i in range(n_trees):# 计算梯度(残差)residuals = y_train_torch - y_pred_train# 用决策树拟合梯度tree = DecisionTreeRegressor(max_depth=3)tree.fit(X_train, residuals.numpy())trees.append(tree)# 更新预测值y_pred_train += learning_rate * torch.tensor(tree.predict(X_train), dtype=torch.float32)# 计算损失mse = mean_squared_error(y_train, y_pred_train.numpy())print(f"Iteration {i+1}: MSE = {mse:.4f}")
PyTorch 实现的关键点
y_train_torch - y_pred_train计算 损失的梯度DecisionTreeRegressor作为弱学习器,拟合梯度- 预测值
+= learning_rate * tree.predict(X_train)
3. 结合 PyTorch 和 XGBoost
如果你要 结合 PyTorch 和 GBDT,可以先用 XGBoost 训练 GBDT,再用 PyTorch 进行深度学习:
import xgboost as xgb
import torch.nn as nn
import torch.optim as optim
import torch
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split# 生成数据
X, y = make_regression(n_samples=1000, n_features=5, noise=0.1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 训练 XGBoost 作为特征提取器
xgb_model = xgb.XGBRegressor(n_estimators=50, max_depth=3, learning_rate=0.1)
xgb_model.fit(X_train, y_train)# 提取 XGBoost 叶子节点特征
X_train_leaves = xgb_model.apply(X_train)
X_test_leaves = xgb_model.apply(X_test)# 定义 PyTorch 神经网络
class NeuralNet(nn.Module):def __init__(self, input_size):super(NeuralNet, self).__init__()self.fc = nn.Linear(input_size, 1)def forward(self, x):return self.fc(x)# 训练 PyTorch 神经网络
model = NeuralNet(X_train_leaves.shape[1])
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()X_train_tensor = torch.tensor(X_train_leaves, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)for epoch in range(100):optimizer.zero_grad()output = model(X_train_tensor)loss = loss_fn(output, y_train_tensor)loss.backward()optimizer.step()print("Training complete!")
结论
| 方法 | 适用场景 | 备注 |
|---|---|---|
| 纯 Python GBDT | 适合小规模数据 | 使用 sklearn.tree.DecisionTreeRegressor |
| PyTorch 计算梯度 + GBDT | 适合梯度优化实验 | 计算梯度后用 DecisionTreeRegressor 训练 |
| XGBoost + PyTorch | 适合大规模数据 | 先用 XGBoost 提取特征,再用 PyTorch 训练 |
如果你的数据是结构化的(如 表格数据),建议 直接使用 XGBoost/LightGBM,再结合 PyTorch 进行特征工程或后处理。
相关文章:
梯度提升用于高效的分类与回归
使用 决策树(Decision Tree) 实现 梯度提升(Gradient Boosting) 主要是模拟 GBDT(Gradient Boosting Decision Trees) 的原理,即: 第一棵树拟合原始数据计算残差(负梯度…...
【单细胞第二节:单细胞示例数据分析-GSE218208】
GSE218208 1.创建Seurat对象 #untar(“GSE218208_RAW.tar”) rm(list ls()) a data.table::fread("GSM6736629_10x-PBMC-1_ds0.1974_CountMatrix.tsv.gz",data.table F) a[1:4,1:4] library(tidyverse) a$alias:gene str_split(a$alias:gene,":",si…...
设计模式 - 行为模式_Template Method Pattern模板方法模式在数据处理中的应用
文章目录 概述1. 核心思想2. 结构3. 示例代码4. 优点5. 缺点6. 适用场景7. 案例:模板方法模式在数据处理中的应用案例背景UML搭建抽象基类 - 数据处理的 “总指挥”子类定制 - 适配不同供应商供应商 A 的数据处理器供应商 B 的数据处理器 在业务代码中整合运用 8. 总…...
新春登蛇山:告别岁月,启航未来
大年初一,晨曦透过薄雾,温柔地洒在武汉的大街小巷。2025 年的蛇年春节,带着新春的喜气与希望悄然而至。我站在蛇山脚下,心中涌动着复杂的情感,因为今天,我不仅将与家人一起登山揽胜,更将在这一天…...
hive:基本数据类型,关于表和列语法
基本数据类型 Hive 的数据类型分为基本数据类型和复杂数据类型 加粗的是常用数据类型 BOOLEAN出现ture和false外的其他值会变成NULL值 没有number,decimal类似number 如果输入的数据不符合数据类型, 映射时会变成NULL, 但是数据本身并没有被修改 创建表 创建表的本质其实就是在…...
安装最小化的CentOS7后,执行yum命令报错Could not resolve host mirrorlist.centos.org; 未知的错误
文章目录 安装最小化的CentOS7后,执行yum命令报错"Could not resolve host: mirrorlist.centos.org; 未知的错误"错误解决方案: 安装最小化的CentOS7后,执行yum命令报错"Could not resolve host: mirrorlist.centos.org; 未知…...
图论——spfa判负环
负环 图 G G G中存在一个回路,该回路边权之和为负数,称之为负环。 spfa求负环 方法1:统计每个点入队次数, 如果某个点入队n次, 说明存在负环。 证明:一个点入队n次,即被更新了n次。一个点每次被更新时所对应最短路的边数一定是…...
软件工程概论试题三
一、单选 1.需求确认主要检査五个方面的内容,其中那一项是为了保证文档中的需求不互相冲突(即不应该有相互矛盾的约束或者对同一个系统功能有不同的描述)。 A.现实性 B. 可验证性 C.一致性 D.正确性 E.完整性 正答:C 2.下列开发方法中,( )不…...
21.3-启动流程、编码风格(了解) 第21章-FreeRTOS项目实战--基础知识之新建任务、启动流程、编码风格、系统配置 文件组成和编码风格(了解)
21.3-启动流程、编码风格(了解) 启动流程 第一种启动流程(我们就使用这个): 在main函数中将硬件初始化、RTOS系统初始化,同时创建所有任务,再启动RTOS调度器。 第二种启动流程: 在main函数中将硬件初始化、RTOS系统初始化,只…...
未来无线技术的发展方向
未来无线技术的发展趋势呈现出多样化、融合化的特点,涵盖速度、覆盖范围、应用领域、频段利用、安全性等多个方面。这些趋势将深刻改变人们的生活和社会的运行方式。 传输速度提升:Wi-Fi 技术迭代加快,如 Wi-Fi7 理论峰值速率达 46Gbps&#…...
Qt5离线安装包无法下载问题解决办法
想在电脑里装一个Qt,但是直接报错。果然还是有解决办法滴。 qt download from your ip is not allowed Qt5安装包下载办法 方法一:简单直接,直接科学一下,不过违法行为咱不做,遵纪守法好公民(不过没办法阻…...
qt-C++笔记之QLine、QRect、QPainterPath、和自定义QGraphicsPathItem、QGraphicsRectItem的区别
qt-C笔记之QLine、QRect、QPainterPath、和自定义QGraphicsPathItem、QGraphicsRectItem的区别 code review! 参考笔记 1.qt-C笔记之重写QGraphicsItem的paint方法(自定义QGraphicsItem) 文章目录 qt-C笔记之QLine、QRect、QPainterPath、和自定义QGraphicsPathItem、QGraphic…...
doris:导入时实现数据转换
Doris 在数据导入时提供了强大的数据转换能力,可以简化部分数据处理流程,减少对额外 ETL 工具的依赖。主要支持以下四种转换方式: 列映射:将源数据列映射到目标表的不同列。 列变换:使用函数和表达式对源数据进行实时…...
新版231普通阿里滑块 自动化和逆向实现 分析
声明: 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 逆向过程 补环境逆向 部分补环境 …...
如何构建树状的思维棱镜认知框架
在思维与知识管理中,“树状思维棱镜”通常指一种层级式、可多维度展开和不断深入(下钻)的认知框架。它不仅仅是普通的树状结构(如传统思维导图),更强调“棱镜”所体现的多视角、多维度切换与综合分析的能力…...
openRv1126 AI算法部署实战之——ONNX模型部署实战
在RV1126开发板上部署ONNX算法,实时目标检测RTSP传输。视频演示地址 rv1126 yolov5 实时目标检测 rtsp传输_哔哩哔哩_bilibili 一、准备工作 1.从官网下载YOLOv5-v7.0工程(YOLOv5的第7个版本) 手动在线下载: Releases ultraly…...
Vue 组件开发:构建高效可复用的前端界面要素
1 引言 在现代 Web 开发中,构建高效且可复用的前端界面要素是提升开发效率和用户体验的关键。Vue.js 作为一种轻量级且功能强大的前端框架,提供了丰富的工具和机制,帮助开发者快速构建高质量的应用程序。通过合理设计和封装 Vue 组件,我们可以实现组件的高效复用,提高开发…...
Vue.js组件开发-实现全屏平滑移动、自适应图片全屏滑动切换
使用Vue实现全屏平滑移动、自适应图片全屏滑动切换的功能。使用Vue 3和Vue Router,并结合一些CSS样式来完成这个效果。 步骤 创建Vue项目:使用Vue CLI创建一个新的Vue项目。准备图片:将需要展示的图片放在项目的public目录下。创建组件&…...
水果实体店品牌数字化:RWA + 智能体落地方案
一、方案背景 随着数字化技术的迅猛发展,实体零售行业正面临前所未有的挑战与机遇。传统的零售模式难以满足消费者对个性化、便捷化、智能化的需求,尤其是在水果等生鲜商品领域,如何通过技术手段提升运营效率、增强顾客体验、拓宽盈利模式&a…...
DeepSeek模型:开启人工智能的新篇章
DeepSeek模型:开启人工智能的新篇章 在当今快速发展的技术浪潮中,人工智能(AI)已经成为了推动社会进步和创新的核心力量之一。而DeepSeek模型,作为AI领域的一颗璀璨明珠,正以其强大的功能和灵活的用法&…...
Landsat8温度反演结果不准?可能是这5个参数没搞对(ENVI实战经验分享)
Landsat8温度反演精度提升:5个关键参数优化与ENVI实战解析 当你在深夜盯着屏幕上那些明显偏离预期的温度反演结果时,是否曾怀疑过ENVI软件出了问题?事实上,90%的温度反演误差都源于几个关键参数的设置不当。作为一位经历过数十个遥…...
手把手教你用TI F28P65X开发板实现LED定时闪烁(基于CPU Timer2,含完整源码)
从零玩转TI F28P65X开发板:CPU Timer2实现可调频LED闪烁实战指南 刚拿到TI F28P65X开发板时,面对密密麻麻的引脚和复杂的开发环境,很多嵌入式新手会感到无从下手。本文将带你用最直观的方式,通过控制LED闪烁这个经典入门项目&…...
图解Linux内核DRM框架:从用户态ioctl到plane更新的完整数据流(以4.14版本为例)
图解Linux内核DRM框架:从用户态ioctl到plane更新的完整数据流(以4.14版本为例) 在图形显示技术领域,Linux内核的DRM(Direct Rendering Manager)框架扮演着核心角色。本文将聚焦于DRM_IOCTL_MODE_SETPLANE这…...
如何让Windows 11告别臃肿?Win11Debloat完整指南帮你一键优化系统
如何让Windows 11告别臃肿?Win11Debloat完整指南帮你一键优化系统 【免费下载链接】Win11Debloat A simple, lightweight PowerShell script that allows you to remove pre-installed apps, disable telemetry, as well as perform various other changes to declu…...
从LC谐振到信号振铃:用Multisim仿真带你理解PCB上的阻尼振荡
从LC谐振到信号振铃:用Multisim仿真揭示PCB阻尼振荡的本质 1. 振铃现象:硬件工程师的"噩梦" 第一次在示波器上看到信号边沿那些诡异的振荡波形时,我差点以为自己的电路板被某种神秘力量干扰了。这种被称为"振铃"的现象…...
CTF逆向实战:从RC4到Base64,手把手拆解CTFshow赛题
1. RC4加密实战:从文件分析到密钥破解 第一次接触CTF逆向题时,看到RC4加密可能会觉得无从下手。但实际拆解后你会发现,这类题目往往藏着明显的突破口。就拿CTFshow这道re2赛题来说,整个解题过程就像在玩解谜游戏。 用IDA打开题目…...
StructBERT在嵌入式Linux设备上的轻量化部署方案
StructBERT在嵌入式Linux设备上的轻量化部署方案 1. 为什么要在树莓派上跑StructBERT 你可能已经试过在笔记本或服务器上运行大模型,但有没有想过让AI在树莓派这样的小设备上工作?不是为了炫技,而是因为很多实际场景根本用不上那么大的机器…...
解锁智能OCR新范式:Pix2Text多模态内容识别技术全解析
解锁智能OCR新范式:Pix2Text多模态内容识别技术全解析 【免费下载链接】Pix2Text Pix In, Latex & Text Out. Recognize Chinese, English Texts, and Math Formulas from Images. 项目地址: https://gitcode.com/gh_mirrors/pi/Pix2Text Pix2Text是一款…...
Visio高效绘制神经网络卷积层:从基础到三维呈现
1. Visio绘制神经网络卷积层的入门指南 第一次用Visio画神经网络结构时,我盯着满屏的工具栏发懵——这玩意儿比Photoshop的图层还复杂。但摸索半天后发现,只要掌握几个核心功能,画卷积层其实比用PPT简单十倍。先说说最基础的形状选择…...
res-downloader:智能资源捕获工具的技术实现与高效工作流指南
res-downloader:智能资源捕获工具的技术实现与高效工作流指南 【免费下载链接】res-downloader 视频号、小程序、抖音、快手、小红书、直播流、m3u8、酷狗、QQ音乐等常见网络资源下载! 项目地址: https://gitcode.com/GitHub_Trending/re/res-downloader 资源…...
