如何在Spark中使用gbdt模型分布式预测
这目录
- 1 训练gbdt模型
- 2 第三方包python环境打包
- 3 Spark中使用gbdt模型
- 3.1 spark配置文件
- 3.2 主函数main.py
- 4 spark任务提交
1 训练gbdt模型
我们可以基于lightgbm快速的训练一个gbdt模型,训练相对比较简单,只要把训练样本处理好,几行代码可以快速训练好模型,如下是训练一个多分类模型训练核心代码如下:
import lightgbm as lgb
import joblib
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
#假设处理好的训练样本为train.csv
df = pd.read_csv('./train.csv')
X = pd.drop(['label'],axis=1)
Y = df.label
# split data for val
x_train,x_val,y_train,y_val = train_test_split(X,Y,test_size=0.2,random_state=123)
# train model
cate_features=['sex','brand']
train_data = train_data = lgb.Dataset(x_train,label=y_train,categoryical_featrues=cate_features)
params = {'objective':'multiclass','learning_rate':0.1,'n_estimators':100,'num_class':23}
model = lgb.train(params, train_data,100)
#predict val
y_pred = model.predict(x_val)
y_pred = y_pred.argmax(axis=1)# acc
acc = accuracy_score(y_val, y_pred)
print(acc)# feature importance
feature_name = model.feature_name()
feature_importance = model.feature_importance()
feature_score = dict(zip(feature_name, feature_importance))
feature_score_sort = sorted(feature_score.items(),key=lambda x:x[1], reverse=True)# save model
joblib.dump(model, 'model.pkl')
上述就是基于lightgbm训练gbdt模型的代码,训练完后我们通过joblib保存了我们训练好的模型,这个模型接下来我们可以在spark进行分布式预测。
2 第三方包python环境打包
在使用spark的时候,我们可以自定义python环境,并且把我们需要的第三方包都可以安装该python环境里,这样在spark里我们就可以用python第三方包,比如等会我们需要的joblib, numpy等。具体如何配置python环境和第三方包,可以参考我上一篇博客:如何在spark中使用scikit-learn和tensorflow等第三方python包
3 Spark中使用gbdt模型
通过上述步骤,把需要的python环境和第三方包制作好了,包名为python39.zip,接下来我们介绍一下如何在spark中使用我们刚才训练好的gbdt模型进行分布式快速预测。
3.1 spark配置文件
提交spark任务的时候,配置文件这块也需要稍微修改一下,配置文件信息如下:
$SPARK_HOME/bin/spark-submit \
--master yarn \
--deploy-memory 12G
--executor-memory 20G \
--executor-cores 4 \
--queue root.your_queue_name \
--archives ./python39.zip#python39 \
--conf spark.yarn.appMasterEnv.PYSPARK_PYTHON=./python39/python39/bin/python3.9 \
--conf spark.yarn.appMasterEnv,HADOOP_USER_NAME=your_hduser_name \
--conf spark.shuffle.service.enabled=true \
--conf spark.dynamicAllocation.enbled=true \
--conf spark.dynamicAllocation.maxExecutors=50 \
--conf spark.dynamicAllocation.minExecutors=50 \
--conf spark.braodcast.compress=True \
--conf saprk.network.timeout=1000s \
--conf spark.sql.hive.mergeFiles=true \
--conf spark.speculation=false \
--conf spark.yarn.executor.memoryOverhead=4096 \
--files $HIVE_CONF_DIR/hive-site.xml \
--py-files ./model.pkl \
$@
上述是基本的提交spark任务的配置文件,其中
–archives ./python39.zip#python39 \
–archives参数用于在Spark应用程序运行期间将本地压缩档案文件解压到YARN集群节点上。#python39 是为档案文件定义的别名,这将在Spark应用程序中使用。
这个参数的目的是将名为python39.zip的压缩文件解压到YARN集群节点,并将其路径设置为python39,以供Spark应用程序使用。这通常用于指定特定版本的Python环境,以便在Spark任务中使用。
–conf spark.yarn.appMasterEnv.PYSPARK_PYTHON=./python39/python39/bin/python3.9
–conf参数用于设置Spark配置属性。
spark.yarn.appMasterEnv.PYSPARK_PYTHON 是一个Spark配置属性,它指定了YARN应用程序的主节点(ApplicationMaster)使用的Python解释器。
./python39/python39/bin/python3.9 是实际Python解释器的路径,它将在YARN应用程序的主节点上执行。
这个参数的目的是告诉Spark应用程序在YARN的主节点上使用特定的Python解释器即./python39/python39/bin/python3.9。这通常用于确保Spark应用程序使用正确的Python版本和环境来运行任务。
–py-files ./model.pkl \
–py-file是 Spark 提交任务时的一个参数,用于将指定的 .py 文件、.zip 文件或 .egg 文件分发到集群的所有 Worker 节点。Spark 会将这些文件自动添加到 Python 的模块路径中(即 sys.path),使得这些文件可以被任务中的代码引用。所以在这里我们将 model.pkl 模型文件分发到 Spark 集群的每个节点,确保每个节点在运行任务时都能访问并使用这个模型。
3.2 主函数main.py
接下来,我们来看下如何在在spark调用我们训练好的gbdt模型进行预测,核心代码主要如下:
1)import基础函数功能包等
# -*-coding:utf8 -*-
import sys
from pyspark.sql.types import Row
from pyspark.sql.types import *
from pyspark import SparkConf, SparkContext, HiveContext
import datetime
import numpy as np
import joblibsave_table='your_target_table_name'
source_table = 'your_source_predict_table_name'
index_table = 'your_index_table_name'
# define saved table schema
schema = StructType([StructFiled('userid', StringType(), True),StructFiled('names', ArrayType(StringType()), True),StructFiled('scores', ArrayType(FloatType()), True)])
- main执行入口基础配置和执行流程
if __name__=='__main__':conf = SparkConf()sc = SparkContext(conf=conf, appName='gbdt_spark_predict')sc.setLogLevel("WARN")hiveCtx = HiveContext(sc)#hive基础配置hiveCtx.setConf('spark.shuffle.consolidateFiles','true')hiveCtx.setConf('spark.shuffle.memoryFraction','0.4')hiveCtx.setConf('spark.sql.shuffle.partitions','1000')if len(sys.argv) == 1:dt = datetime.datetime.now() + datetime.timedelta(-1)else:dt = datetime.datetime.strptime(sys.argv[1],"%Y%m%d').date()dt_str = dt.strftime('%Y-%m-%d')hiveCtx.sql("use your_datebase_name")#注册函数,在sql时候可以使用hiveCtx.registerFunction('null_chage', null_chage, StringType())#创建目标表create_table(hiveCtx)#主函数get_predict(hiveCtx)
上面主函数给出了一个基本的流程步骤:1)spark, hive context等初始化 2)注册函数可以直接在sql中使用,方便数据处理 3)建立目标hive表 4)执行功能函数。
- 函数功能模块实现
在第2步骤里,我们主要有三个函数需要编写,一个是可以在sql中调用的基础函数,第二个就是创建表函数,第三个就是功能函数,我们接下来实现这三个的基本功能:
#用在sql中的基本数据操作处理
def null_chage(x):return 'unknow' if x is None else x#创建目标表
def create_table(hiveCtx):create_tbl = """CREATE EXTERNAL TABLE IF NOT EXISTS your_database_name.{table_name} (userid string COMMENT 'user id';names array<string> COMMENT 'predict label names')scores array<float> COMMENT 'predict socre')PARTITIONED BY(dt string, dp string)STORED AS ORCLOCATION 'hdfs://your_database_name.db/{table_name}'TBLPROPERTIES('orc.compress'='SNAPPY','comment'='gbdt predict user score')""".format(table_name=save_table)
# 功能函数
def get_predict(hiveCtx):# get label and idex datasql="""select index, valuefrom {index_table}where dt='active;""".format(index_table=index_table)print(sql)vocab = hiveCtx.sql(sql).rdd.collect()vocab_dict = dict()for x in vocab:vocab_dict.setdefault(x[0],x[1])# broadcastbr_vocab_dict = sc.broadcast(vocab_dict)# get predict datasql="""select null_chage(userid) as userid, featuresfrom {source_table}where dt='active'""".format(source_table=source_table)print(sql)hiveCtx.sql(sql).rdd.mapPartitions(lambda rows: main_func(rows, br_vocab_dict)) \.toDF(schema=schema) \.registerTempTable('final_tbl')# insert tableinsert_sql = """insert overwrite table {save_table} partition (dt='{dt}')select * from final_tbl""".format(save_table=save_table,dt='active')print(insert_sql)hiveCtx.sql(insert_sql)
接下来,我们来看下main_func函数的实现:
def main_func(rows, br_vocab_dict):# load modelmodel = joblib.load('./model.pkl')vocab_dict = br_vobab_dict.valuefor row in rows:userid, features = rowfeatures = np.array(features)predict = model.predict(features)predict_sort = np.argsort(-predict[0])names = [vocab_dict[idx] for idx in predict_sort]scores = [float(predict[0][idx]) for idx in predict_sort]yield userid, names, scores
整个代码的实现我们在这里就写完了,整体实现逻辑是比较清晰易懂的,按照这个流程来,我们可以很高效快速的基于spark分布式的跑一些数据处理和模型预测性的任务。
4 spark任务提交
接下来,就是提交我们的spark任务了,在工作环境目录如下文件信息:
- 提前准备好的python环境包python39.zip
- spark config文件 run_spark_arg.sh
- 主函数代码 main.py
- gbdt模型文件model.pkl
最后环节就是提交spark任务,我么可以在服务器提交命令如下:
nohup sh run_spark_arg.sh main.py >log.txt 2>&1 &
相关文章:
如何在Spark中使用gbdt模型分布式预测
这目录 1 训练gbdt模型2 第三方包python环境打包3 Spark中使用gbdt模型3.1 spark配置文件3.2 主函数main.py 4 spark任务提交 1 训练gbdt模型 我们可以基于lightgbm快速的训练一个gbdt模型,训练相对比较简单,只要把训练样本处理好,几行代码可…...

Qt-5.14.2 example
官方历程很丰富,modbus、串口、chart图表、3D、视频 共享方便使用 Building and Running an Example You can test that your Qt installation is successful by opening an existing example application project. To run an example application on an Android …...

virtualbox给Ubuntu22创建共享文件夹
1.在windows上的操作,创建共享文件夹Share 2.Ubuntu22上的操作,创建共享文件夹LinuxShare 3.在virtualbox虚拟机设置里,设置共享文件夹 共享文件夹路径:选择Windows系统中你需要共享的文件夹 共享文件夹名称:挂载至wi…...
GPT打字机效果—— fetchEventSouce进行sse流式请求
EventStream基本用法 与 WebSocket 不同的是,服务器发送事件是单向的。数据消息只能从服务端到发送到客户端(如用户的浏览器)。这使其成为不需要从客户端往服务器发送消息的情况下的最佳选择。 const evtSource new EventSource(“/api/v1/…...

SpringBoot 在线家具商城:设计考量与实现细节聚焦
第4章 系统设计 市面上设计比较好的系统都有一个共同特征,就是主题鲜明突出。通过对页面简洁清晰的布局,让页面的内容,包括文字语言,或者视频图片等元素可以清晰表达出系统的主题。让来访用户无需花费过多精力和时间找寻需要的内容…...
每日速记10道java面试题07
其他资料: 每日速记10道java面试题01-CSDN博客 每日速记10道java面试题02-CSDN博客 每日速记10道java面试题03-CSDN博客 每日速记10道java面试题04-CSDN博客 每日速记10道java面试题05-CSDN博客 每日速记10道java面试题06-CSDN博客 目录 1.线程的生命周期在j…...
前端面试热门题(二)[html\css\js\node\vue)
Vue 性能优化的方法 Vue 性能优化的方法多种多样,以下是一些常用的策略: 使用v-show替换v-if:v-show是通过CSS控制元素的显示与隐藏,而v-if是通过操作DOM来控制元素的显示与隐藏,频繁操作DOM会导致性能下降。因此&am…...
mvc基础及搭建一个静态网站
mvc asp.net core mvc环境 .net8vscode * Asp.Net Core 基础* .net8* 前辈* .net 4.9 非跨平台版本 VC* 跨平台版本* 1.0* 2.0* 2.1* 3.1* 5* 语言* C#* F# * Visual Basic* 框架* web应用* asp应用* WebFrom* mvc应用* 桌面应用* Winform* WPF* Web Api api应用或者叫服务* …...

AOSP的同步问题
repo sync同步时提示出错: error: .repo/manifests/: contains uncommitted changesRepo command failed due to the following UpdateManifestError errors: contains uncommitted changes解决方法: 1、cd 进入.repo/manifests cd .repo/manifests2、执行如下三…...

HarmonyOS4+NEXT星河版入门与项目实战(23)------实现手机游戏摇杆功能
文章目录 1、案例效果2、案例实现1、代码实现2、代码解释4、总结1、案例效果 2、案例实现 1、代码实现 代码如下(示例): import router from @ohos.router import {ResizeDirection } from @ohos.UiTest import curves...

Logistic Regression(逻辑回归)、Maximum Likelihood Estimatio(最大似然估计)
Logistic Regression(逻辑回归)、Maximum Likelihood Estimatio(最大似然估计) 逻辑回归(Logistic Regression,LR)逻辑回归的基本思想逻辑回归模型逻辑回归的目标最大似然估计优化方法 逻辑回归…...
Vue文字转语音实现
在开发流程中,面对语音支持的需求,小规模语音内容或许可以通过预处理后播放来轻松应对,但当涉及大量语音时,这一方法就显得繁琐低效了。为此,智慧的开发者们总能找到便捷的解决方案——利用Web技术实现语音播放&#x…...
Docker快速部署RabbitMq
在外网服务器拉取镜像 docker pull arm64v8/rabbitmq:3.8.9-management或者拉去我的服务器的 docker pull registry.cn-hangzhou.aliyuncs.com/qiluo-images/linux_arm64_rabbitmq:3.8.9-management重新命名 docker tag registry.cn-hangzhou.aliyuncs.com/qiluo-images/lin…...

glog在vs2022 hello world中使用
准备工作 设置dns为阿里云dns 223.5.5.5,下载cmake,vs2022,git git clone https://github.com/google/glog.git cd glog mkdir build cd build cmake .. 拷贝文件 新建hello world并设置 设置预处理器增加GLOG_USE_GLOG_EXPORT;GLOG_NO_AB…...

[241129] Docker Desktop 4.36 发布:企业级管理功能、WSL 2 增强 | Smile v4.0.0 发布
目录 Docker Desktop 4.36 发布:企业级管理功能、WSL 2 和 ECI 增强Smile v4.0.0 发布!Java 机器学习库迎来重大升级 Docker Desktop 4.36 发布:企业级管理功能、WSL 2 和 ECI 增强 Docker Desktop 4.36 带来了强大的更新,简化了…...
CentOS使用chrony服务进行时间同步源设置脚本
CentOS使用chrony服务进行时间同步源设置脚本 #!/bin/bash# Created: 2024-11-26 # Function: Check and Set OS time sync source to 10.0.11.100 # FileName: centos_set_time_source_to_ad.sh # Creator: Anster # Usage: # curl http://webserver-ip/scripts/centos_set…...
Git仓库迁移到远程仓库(源码、分支、提交)
单个迁移仓库 一、迁移仓库 1.准备工作 > 手动在电脑创建一个临时文件夹,CMD进入该目录 > 远程仓库上创建一个同名的空仓库 2.CMD命令:拉取旧Git仓库(包含提交、分支、源码) $ git clone --bare http://git.domain.cn/…...

【算法刷题指南】优先级队列
🌈个人主页: 南桥几晴秋 🌈C专栏: 南桥谈C 🌈C语言专栏: C语言学习系列 🌈Linux学习专栏: 南桥谈Linux 🌈数据结构学习专栏: 数据结构杂谈 🌈数据…...

使用pymupdf提取PDF文档中的文字和其颜色
最近我在捣鼓一个PDF文件,想把它里面的文字和文字颜色给提取出来。后来发现有个叫pymupdf的库能搞定这事儿。操作起来挺简单的,pymupdf的示例文档里就有现成的代码可以参考。 how-to-extract-text-with-color 我本地的测试代码如下: impor…...

贪心算法题
0简介 0.1什么是贪心算法 贪心算法是用贪婪(鼠目寸光)的角度,找到解决问题的最优解 贪心策略:(从局部最优 --> 整体最优) 1把解决问题的过程分为若干步; 2解决每一个问题时,都选择当前“看上去”最优的解法; 3“…...
脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)
一、数据处理与分析实战 (一)实时滤波与参数调整 基础滤波操作 60Hz 工频滤波:勾选界面右侧 “60Hz” 复选框,可有效抑制电网干扰(适用于北美地区,欧洲用户可调整为 50Hz)。 平滑处理&…...
前端倒计时误差!
提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...
Qwen3-Embedding-0.6B深度解析:多语言语义检索的轻量级利器
第一章 引言:语义表示的新时代挑战与Qwen3的破局之路 1.1 文本嵌入的核心价值与技术演进 在人工智能领域,文本嵌入技术如同连接自然语言与机器理解的“神经突触”——它将人类语言转化为计算机可计算的语义向量,支撑着搜索引擎、推荐系统、…...

【Zephyr 系列 10】实战项目:打造一个蓝牙传感器终端 + 网关系统(完整架构与全栈实现)
🧠关键词:Zephyr、BLE、终端、网关、广播、连接、传感器、数据采集、低功耗、系统集成 📌目标读者:希望基于 Zephyr 构建 BLE 系统架构、实现终端与网关协作、具备产品交付能力的开发者 📊篇幅字数:约 5200 字 ✨ 项目总览 在物联网实际项目中,**“终端 + 网关”**是…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
Python ROS2【机器人中间件框架】 简介
销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...

用机器学习破解新能源领域的“弃风”难题
音乐发烧友深有体会,玩音乐的本质就是玩电网。火电声音偏暖,水电偏冷,风电偏空旷。至于太阳能发的电,则略显朦胧和单薄。 不知你是否有感觉,近两年家里的音响声音越来越冷,听起来越来越单薄? —…...
C++.OpenGL (14/64)多光源(Multiple Lights)
多光源(Multiple Lights) 多光源渲染技术概览 #mermaid-svg-3L5e5gGn76TNh7Lq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-3L5e5gGn76TNh7Lq .error-icon{fill:#552222;}#mermaid-svg-3L5e5gGn76TNh7Lq .erro…...

CVE-2020-17519源码分析与漏洞复现(Flink 任意文件读取)
漏洞概览 漏洞名称:Apache Flink REST API 任意文件读取漏洞CVE编号:CVE-2020-17519CVSS评分:7.5影响版本:Apache Flink 1.11.0、1.11.1、1.11.2修复版本:≥ 1.11.3 或 ≥ 1.12.0漏洞类型:路径遍历&#x…...

算法:模拟
1.替换所有的问号 1576. 替换所有的问号 - 力扣(LeetCode) 遍历字符串:通过外层循环逐一检查每个字符。遇到 ? 时处理: 内层循环遍历小写字母(a 到 z)。对每个字母检查是否满足: 与…...