当前位置: 首页 > news >正文

optuna和 lightgbm

文章目录

  • optuna使用
    • 1.导入相关包
    • 2.定义模型可选参数
    • 3.定义训练代码和评估代码
    • 4.定义目标函数
    • 5.运行程序
    • 6.可视化
    • 7.超参数的重要性
    • 8.查看相关信息
    • 9.可视化的一个完整示例
    • 10.lightgbm实验

optuna使用

1.导入相关包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from fvcore.nn import FlopCountAnalysisimport optunaDEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DIR = ".."
BATCHSIZE = 128
N_TRAIN_EXAMPLES = BATCHSIZE * 30   # 128 * 30个训练
N_VALID_EXAMPLES = BATCHSIZE * 10   # 128 * 10个预测

2.定义模型可选参数

optuna支持很多种搜索方式:
(1)trial.suggest_categorical(‘optimizer’, [‘MomentumSGD’, ‘Adam’]):表示从SGD和adam里选一个使用;
(2)trial.suggest_int(‘num_layers’, 1, 3):从1~3范围内的int里选;
(3)trial.suggest_uniform(‘dropout_rate’, 0.0, 1.0):从0~1内的uniform分布里选;
(4)trial.suggest_loguniform(‘learning_rate’, 1e-5, 1e-2):从1e-5~1e-2的log uniform分布里选;
(5)trial.suggest_discrete_uniform(‘drop_path_rate’, 0.0, 1.0, 0.1):从0~1且step为0.1的离散uniform分布里选;

def define_model(trial):n_layers = trial.suggest_int("n_layers", 1, 3) # 从[1,3]范围里面选一个layers = []in_features = 28 * 28for i in range(n_layers):out_features = trial.suggest_int("n_units_l{}".format(i), 4, 128)layers.append(nn.Linear(in_features, out_features))layers.append(nn.ReLU())p = trial.suggest_float("dropout_{}".format(i), 0.2, 0.5)layers.append(nn.Dropout(p))in_features = out_featureslayers.append(nn.Linear(in_features, 10))layers.append(nn.LogSoftmax(dim=1))return nn.Sequential(*layers)

3.定义训练代码和评估代码

# Defines training and evaluation.
def train_model(model, optimizer, train_loader):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)optimizer.zero_grad()F.nll_loss(model(data), target).backward()optimizer.step()def eval_model(model, valid_loader):model.eval()correct = 0with torch.no_grad():for batch_idx, (data, target) in enumerate(valid_loader):data, target = data.view(-1, 28 * 28).to(DEVICE), target.to(DEVICE)pred = model(data).argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()accuracy = correct / N_VALID_EXAMPLESflops = FlopCountAnalysis(model, inputs=(torch.randn(1, 28 * 28).to(DEVICE),)).total()return flops, accuracy

4.定义目标函数

def objective(trial):train_dataset = torchvision.datasets.FashionMNIST(DIR, train=True, download=True, transform=torchvision.transforms.ToTensor())train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),batch_size=BATCHSIZE,shuffle=True,)val_dataset = torchvision.datasets.FashionMNIST(DIR, train=False, transform=torchvision.transforms.ToTensor())val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),batch_size=BATCHSIZE,shuffle=True,)model = define_model(trial).to(DEVICE)optimizer = torch.optim.Adam(model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True))for epoch in range(10):train_model(model, optimizer, train_loader)flops, accuracy = eval_model(model, val_loader)return flops, accuracy

5.运行程序

运行30次实验,每次实验返回 flops,accuracy

study = optuna.create_study(directions=["minimize", "maximize"]) # flops 最小化, accuracy 最大化
study.optimize(objective, n_trials=30, timeout=300)print("Number of finished trials: ", len(study.trials))

6.可视化

flops, accuracy 二维图
optuna.visualization.plot_pareto_front(study, target_names=[“FLOPS”, “accuracy”])

在这里插入图片描述

7.超参数的重要性

对于flops
optuna.visualization.plot_param_importances(
study, target=lambda t: t.values[0], target_name=“flops”
)

对于accuracy
optuna.visualization.plot_param_importances(
study, target=lambda t: t.values[1], target_name=“accuracy”
)

在这里插入图片描述

8.查看相关信息

# https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/002_multi_objective.html
# 利用pytorch mnist 识别
# 设置了一些超参数,lr, layer number, feature_number等
# 然后目标是 flops 和 accurary# 最后是可视化:
# 显示试验的一些结果:
# optuna.visualization.plot_pareto_front(study, target_names=["FLOPS", "accuracy"])
# 左上角是最好的# 显示重要性:
# optuna.visualization.plot_param_importances(
#     study, target=lambda t: t.values[0], target_name="flops"
# )
# optuna.visualization.plot_param_importances(
#     study, target=lambda t: t.values[1], target_name="accuracy"
# )# trials的属性:
print(f"Number of trials on the Pareto front: {len(study.best_trials)}")trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[1])
print(f"Trial with highest accuracy: ")
print(f"\tnumber: {trial_with_highest_accuracy.number}")
print(f"\tparams: {trial_with_highest_accuracy.params}")
print(f"\tvalues: {trial_with_highest_accuracy.values}")

9.可视化的一个完整示例

# You can use Matplotlib instead of Plotly for visualization by simply replacing `optuna.visualization` with
# `optuna.visualization.matplotlib` in the following examples.
from optuna.visualization import plot_contour
from optuna.visualization import plot_edf
from optuna.visualization import plot_intermediate_values
from optuna.visualization import plot_optimization_history
from optuna.visualization import plot_parallel_coordinate
from optuna.visualization import plot_param_importances
from optuna.visualization import plot_rank
from optuna.visualization import plot_slice
from optuna.visualization import plot_timelinedef objective(trial):train_dataset = torchvision.datasets.FashionMNIST(DIR, train=True, download=True, transform=torchvision.transforms.ToTensor())train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(train_dataset, list(range(N_TRAIN_EXAMPLES))),batch_size=BATCHSIZE,shuffle=True,)val_dataset = torchvision.datasets.FashionMNIST(DIR, train=False, transform=torchvision.transforms.ToTensor())val_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(val_dataset, list(range(N_VALID_EXAMPLES))),batch_size=BATCHSIZE,shuffle=True,)model = define_model(trial).to(DEVICE)optimizer = torch.optim.Adam(model.parameters(), trial.suggest_float("lr", 1e-5, 1e-1, log=True))for epoch in range(10):train_model(model, optimizer, train_loader)val_accuracy = eval_model(model, val_loader)trial.report(val_accuracy, epoch)if trial.should_prune():raise optuna.exceptions.TrialPruned()return val_accuracystudy = optuna.create_study(direction="maximize",sampler=optuna.samplers.TPESampler(seed=SEED),pruner=optuna.pruners.MedianPruner(),
)
study.optimize(objective, n_trials=30, timeout=300)

运行之后可视化:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

10.lightgbm实验

"""
Optuna example that optimizes a classifier configuration for cancer dataset using LightGBM.In this example, we optimize the validation accuracy of cancer detection using LightGBM.
We optimize both the choice of booster model and their hyperparameters."""import numpy as np
import optunaimport lightgbm as lgb
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split# FYI: Objective functions can take additional arguments
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
def objective(trial):data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.25)dtrain = lgb.Dataset(train_x, label=train_y)param = {"objective": "binary","metric": "binary_logloss","verbosity": -1,"boosting_type": "gbdt","lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),"lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),"num_leaves": trial.suggest_int("num_leaves", 2, 256),"feature_fraction": trial.suggest_float("feature_fraction", 0.4, 1.0),"bagging_fraction": trial.suggest_float("bagging_fraction", 0.4, 1.0),"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),}gbm = lgb.train(param, dtrain)preds = gbm.predict(valid_x)pred_labels = np.rint(preds)accuracy = sklearn.metrics.accuracy_score(valid_y, pred_labels)return accuracyif __name__ == "__main__":study = optuna.create_study(direction="maximize")study.optimize(objective, n_trials=100)print("Number of finished trials: {}".format(len(study.trials)))print("Best trial:")trial = study.best_trialprint("  Value: {}".format(trial.value))print("  Params: ")for key, value in trial.params.items():print("    {}: {}".format(key, value))

运行结果:
在这里插入图片描述

https://github.com/microsoft/LightGBM/tree/master/examples

https://blog.csdn.net/yang1015661763/article/details/131364826

相关文章:

optuna和 lightgbm

文章目录 optuna使用1.导入相关包2.定义模型可选参数3.定义训练代码和评估代码4.定义目标函数5.运行程序6.可视化7.超参数的重要性8.查看相关信息9.可视化的一个完整示例10.lightgbm实验 optuna使用 1.导入相关包 import torch import torch.nn as nn import torch.nn.functi…...

Android 设置铃声和闹钟

Android设置铃声和闹钟使用的方法是一样的,但是要区别的去获取对应的权限。 统一权限,不管是设置闹钟还是铃声,他们都需要一个系统设置权限如下: //高版本需要WRITE_SETTINGS权限//此权限是敏感权限,无法动态申请,需要…...

自动化测试模型(一)

8.8.1 自动化测试模型概述 在自动化测试运用于测试工作的过程中,测试人员根据不同自动化测试工具、测试框架等所进行的测试活动进行了抽象,总结出线性测试、模块化驱动测试、数据驱动测试和关键字驱动测试这4种自动化测试模型。 线性测试 首先&#…...

解决nuxt3下载慢下载报错问题

在下载nuxt3时总是下不下来,最后还报错了。即使改成国内镜像源也不行。 解决方法: 直接去github上下载 https://github.com/nuxt/starter/tree/v3 解压后得到如下目录: 手动修改项目名和文件夹名 安装依赖 npm install可能会比较慢或下不…...

Ubuntu修改swap大小

查看swap位置和大小: swapon -s 方案一:修改原有文件大小方式 第一步:进入系统根目录cd /; 第二步:执行:sudo dd if/dev/zero of/swap bs1M count16384 //每段块1M 共16384块,即16G 第三步:执行…...

[C#] 复数乘法的跨平台SIMD硬件加速向量算法(不仅支持X86的Sse、Avx、Avx512,还支持Arm的AdvSimd)

文章目录 一、简单算法二、向量算法2.1 算法思路2.1.1 复数乘法的数学定义2.1.2 复数的数据布局2.1.3 第1步:计算 (a*c) (-b*d)i2.1.4 第2步:计算 (a*d) (b*c)i2.1.5 第3步:计算结果合并 2.2 算法实现(UseVectors)2.…...

C#WPF基础介绍/第一个WPF程序

什么是WPF WPF(Windows Presentation Foundation)是微软公司推出的一种用于创建窗口应用程序的界面框架。它是.NET Framework的一部分,提供了一套先进的用户界面设计工具和功能,可以实现丰富的图形、动画和多媒体效果。 WPF 使用…...

强大的接口测试可视化工具:Postman Flows

Postman Flows是一种接口测试可视化工具,可以使用流的形式在Postman工作台将请求接口、数据处理和创建实际流程整合到一起。如下图所示 Postman Flows是以API为中心的可视化应用程序开发界面。它提供了一个无限的画布用于编排和串连API,数据可视化来显示…...

系统设计及解决方案

发送验证码 1:根据手机号从Redis中获取value(验证码_时间戳) 2:如果value不为空,并且时间戳与当前时间戳的间隔小于60秒,则返回一个错误信息 3:生成随机验证码 4:调用阿里云短信服务API给用户发送短信验证码…...

从0入门自主空中机器人-2-2【无人机硬件选型-PX4篇】

1. 常用资料以及官方网站 无人机飞控PX4用户使用手册(无人机基本设置、地面站使用教程、软硬件搭建等):https://docs.px4.io/main/en/ PX4固件开源地址:https://github.com/PX4/PX4-Autopilot 飞控硬件、数传模块、GPS、分电板等…...

Linux之ARM(MX6U)裸机篇----2.汇编LED驱动实验

一,alpha的LED灯硬件原理分析 STM32 IO初始化流程 ①,使能GPIO时钟 ②,设置IO复用,复用为GPIO ③,配置GPIO的电气属性推挽,上拉下拉 ④,使用GPIO,输出高/低电平 MX6ULL IO初始化…...

e3 1220lv3 cpu-z分数

e3 1220lv3 双核四线程,1.1G频率,最低可在800MHZ运行,TDP 13W。 使用PE启动后测试cpu-z分数。 现在e3 1220lv3的价格落到69元。...

HTML5适配手机

要使 HTML5 网站适配手机设备&#xff0c;您可以遵循以下几个步骤和最佳实践&#xff1a; 1. 使用视口&#xff08;Viewport&#xff09; 在 HTML 文档的 <head> 部分添加视口元标签&#xff0c;以确保页面在移动设备上正确缩放和显示&#xff1a; <meta name"…...

C# 中使用 MassTransit

在生产环境中使用 MassTransit 时&#xff0c;通常需要进行详细的配置&#xff0c;包括设置连接字符串、配置队列、配置消费者、处理重试和错误队列等。以下是一个完整的示例&#xff0c;展示了如何在 ASP.NET Core 应用程序中配置 MassTransit&#xff0c;包括请求/响应模式和…...

网络编程 实现联网 b+Tree

网络编程是客户端和服务器之间通信的基础&#xff0c;也是现代应用开发中不可或缺的技能。在 Unity 中实现网络功能&#xff0c;需要结合计算机网络原理、数据结构与算法&#xff0c;以及网络协议的实际应用。以下是对这一块内容的详细介绍&#xff0c;包括每个涉及到的知识点&…...

zentao ubuntu上安装

#下载ZenTaoPMS-21.2-zbox_amd64.tar.gz&#xff08;https://www.zentao.net/downloads.html&#xff09; https://dl.zentao.net/zentao/21.2/ZenTaoPMS-21.2-zbox_amd64.tar.gzcd /opt tar -zxvf ZenTaoPMS-21.2-zbox_amd64.tar.gz#启动 /opt/zbox/zbox start /opt/zbox/zbox…...

Java 网络原理 ①-IO多路复用 || 自定义协议 || XML || JSON

这里是Themberfue 在学习完简单的网络编程后&#xff0c;我们将更加深入网络的学习——HTTP协议、TCP协议、UDP协议、IP协议........... IO多路复用 ✨在上一节基于 TCP 协议 编写应用层代码时&#xff0c;我们通过一个线程处理连接的申请&#xff0c;随后通过多线程或者线程…...

Bash Shell知识合集

1. chmod命令 创建一个bash shell脚本 hello.sh ~script $ touch hello.sh脚本创建完成后并不能直接执行&#xff0c;我们要用chmod命令授予它可执行的权限&#xff1a; ~script $ chmod 755 hello.sh授权后的脚本可以直接执行&#xff1a; ~script $ ./hello.sh2.指定运行…...

从0入门自主空中机器人-1【课程介绍】

关于本课程&#xff1a; 本次课程是一套面向对自主空中机器人感兴趣的学生、爱好者、相关从业人员的免费课程&#xff0c;包含了从硬件组装、机载电脑环境设置、代码部署、实机实验等全套详细流程&#xff0c;带你从0开始&#xff0c;组装属于自己的自主无人机&#xff0c;并让…...

Doris使用注意点

自己学习过程中整理&#xff0c;非官方 dws等最后用于查询的表可以考虑使用row存储加快查询&#xff0c;即用空间换时间duplicate key的选择要考虑最常查询使用适当使用bloomfilter 加速查询适当使用aggregate 模式降低max&#xff0c;avg&#xff0c;min之类的计算并加快查询…...

powershell 安装 .netframework3.5

在 PowerShell 中安装 .NET Framework 3.5 可以通过几种不同的方法实现&#xff0c;取决于你的操作系统版本。以下是几种常见的方法&#xff1a; 方法1&#xff1a;使用 DISM 命令 对于 Windows 10 和 Windows 8.1&#xff0c;你可以使用 DISM&#xff08;Deployment Image Se…...

架构师级考验!飞算 JavaAI 炫技赛:AI 辅助编程解决老项目难题

当十年前 Hibernate 框架的 N1 查询隐患在深夜持续困扰排查&#xff0c;当 SpringMVC 控制器中错综复杂的业务逻辑在跨语言迁移时令人抓狂&#xff0c;企业数字化进程中的百万行老系统&#xff0c;已然成为暗藏危机的 “技术债冰山”。而此刻&#xff0c;飞算科技全新发布的 Ja…...

快捷键的记录

下面对应的ATL数字 ATL4 显示编译输出 CTRL B 编译 CTRLR 运行exe 菜单栏 ALTF ALTE ALTB ALTD ALTH...

深入理解 Spring IOC:从概念到实践

目录 一、引言 二、什么是 IOC&#xff1f; 2.1 控制反转的本质 2.2 类比理解 三、Spring IOC 的核心组件 3.1 IOC 容器的分类 3.2 Bean 的生命周期 四、依赖注入&#xff08;DI&#xff09;的三种方式 4.1 构造器注入 4.2 Setter 方法注入 4.3 注解注入&#xff08;…...

【深度学习-Day 24】过拟合与欠拟合:深入解析模型泛化能力的核心挑战

Langchain系列文章目录 01-玩转LangChain&#xff1a;从模型调用到Prompt模板与输出解析的完整指南 02-玩转 LangChain Memory 模块&#xff1a;四种记忆类型详解及应用场景全覆盖 03-全面掌握 LangChain&#xff1a;从核心链条构建到动态任务分配的实战指南 04-玩转 LangChai…...

【envoy】-1.安装与下载源码

1.安装 建议使用ubuntu2004&#xff0c;对glibc有要求。上个ti子更快。 wget -O- https://apt.envoyproxy.io/signing.key | sudo gpg --dearmor -o /etc/apt/keyrings/envoy-keyring.gpg $ echo "deb [arch$(dpkg --print-architecture) signed-by/etc/apt/keyrings/envo…...

鸿蒙仓颉语言开发实战教程:购物车页面

大家上午好&#xff0c;仓颉语言商城应用的开发进程已经过半&#xff0c;不知道大家通过这一系列的教程对仓颉开发是否有了进一步的了解。今天要分享的购物车页面&#xff1a; 看到这个页面&#xff0c;我们首先要对它简单的分析一下。这个页面一共分为三部分&#xff0c;分别是…...

如何从浏览器中导出网站证书

以导出 GitHub 证书为例&#xff0c;点击 小锁 点击 导出 注意&#xff1a;这里需要根据你想要证书格式手动加上后缀名&#xff0c;我的是加 .crt 双击文件打开...

全生命周期的智慧城市管理

前言 全生命周期的智慧城市管理。未来&#xff0c;城市将在 实现从基础设施建设、日常运营到数据管理的 全生命周期统筹。这将避免过去智慧城市建设 中出现的“碎片化”问题&#xff0c;实现资源的高效配 置和项目的协调发展。城市管理者将运用先进 的信息技术&#xff0c;如物…...

el-tabs 切换时数据不更新的问题

最近业务需求&#xff0c;需要在页面中使用tabs&#xff0c;使用过程中出现tabs切换&#xff0c;数据不更新的问题&#xff0c;以下是思路和解决办法。 Vue 会追踪你在模板中绑定的数据&#xff0c;并在数据发生变化时重新渲染相应的部分。但在使用 el-tabs 时&#xff0c;有时…...