当前位置: 首页 > news >正文

30分钟吃掉wandb可视化自动调参

wandb.sweep: 低代码,可视化,分布式 自动调参工具。

使用wandb 的 sweep 进行超参调优,具有以下优点。

(1)低代码:只需配置一个sweep.yaml配置文件,或者定义一个配置dict,几乎不用编写调参相关代码。

(2)可视化:在wandb网页中可以实时监控调参过程中每次尝试,并可视化地分析调参任务的目标值分布,超参重要性等。

(3)分布式:sweep采用类似master-workers的controller-agents架构,controller在wandb的服务器机器上运行,agents在用户机器上运行,controller和agents之间通过互联网进行通信。同时启动多个agents即可轻松实现分布式超参搜索。

公众号后台回复关键词:wandb,获取本文notebook代码和B站视频演示。

使用 wandb 的sweep 调参的缺点:

需要联网:由于wandb的controller位于wandb的服务器机器上,wandb日志也需要联网上传,在没有互联网的环境下无法正常使用wandb 进行模型跟踪 以及 wandb sweep 可视化调参。

d6731f3afe349a385fa50a5eb394b50a.png

〇,使用Sweep的3步骤

  1. 配置 sweep_config

配置调优算法,调优目标,需要优化的超参数列表 等等。
  1. 初始化 sweep controller:

sweep_id = wandb.sweep(sweep_config,project)
  1. 启动 sweep agents:

wandb.agent(sweep_id, function=train)
import os,PIL 
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch 
from torch import nn 
import torchvision 
from torchvision import transforms
import datetime
import wandb wandb.login()
from argparse import Namespacedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#初始化参数配置
config = Namespace(project_name = 'wandb_demo',batch_size = 512,hidden_layer_width = 64,dropout_p = 0.1,lr = 1e-4,optim_type = 'Adam',epochs = 15,ckpt_path = 'checkpoint.pt'
)

一. 配置 Sweep config

详细配置文档可以参考:https://docs.wandb.ai/guides/sweeps/define-sweep-configuration

1,选择一个调优算法

Sweep支持如下3种调优算法:

(1)网格搜索:grid. 遍历所有可能得超参组合,只在超参空间不大的时候使用,否则会非常慢。

(2)随机搜索:random. 每个超参数都选择一个随机值,非常有效,一般情况下建议使用。

(3)贝叶斯搜索:bayes. 创建一个概率模型估计不同超参数组合的效果,采样有更高概率提升优化目标的超参数组合。对连续型的超参数特别有效,但扩展到非常高维度的超参数时效果不好。

sweep_config = {'method': 'random'}

2,定义调优目标

设置优化指标,以及优化方向。

sweep agents 通过 wandb.log 的形式向 sweep controller 传递优化目标的值。

metric = {'name': 'val_acc','goal': 'maximize'   }
sweep_config['metric'] = metric

3,定义超参空间

超参空间可以分成 固定型,离散型和连续型。

  • 固定型:指定 value

  • 离散型:指定 values,列出全部候选取值。

  • 连续性:需要指定 分布类型 distribution, 和范围 min, max。用于 random 或者 bayes采样。

sweep_config['parameters'] = {}# 固定不变的超参
sweep_config['parameters'].update({'project_name':{'value':'wandb_demo'},'epochs': {'value': 10},'ckpt_path': {'value':'checkpoint.pt'}})# 离散型分布超参
sweep_config['parameters'].update({'optim_type': {'values': ['Adam', 'SGD','AdamW']},'hidden_layer_width': {'values': [16,32,48,64,80,96,112,128]}})# 连续型分布超参
sweep_config['parameters'].update({'lr': {'distribution': 'log_uniform_values','min': 1e-6,'max': 0.1},'batch_size': {'distribution': 'q_uniform','q': 8,'min': 32,'max': 256,},'dropout_p': {'distribution': 'uniform','min': 0,'max': 0.6,}
})

4,定义剪枝策略 (可选)

可以定义剪枝策略,提前终止那些没有希望的任务。

sweep_config['early_terminate'] = {'type':'hyperband','min_iter':3,'eta':2,'s':3
} #在step=3, 6, 12 时考虑是否剪枝
from pprint import pprint
pprint(sweep_config)

二. 初始化 sweep controller

sweep_id = wandb.sweep(sweep_config, project=config.project_name)

三, 启动 Sweep agent

我们需要把模型训练相关的全部代码整理成一个 train函数。

def create_dataloaders(config):transform = transforms.Compose([transforms.ToTensor()])ds_train = torchvision.datasets.MNIST(root="./mnist/",train=True,download=True,transform=transform)ds_val = torchvision.datasets.MNIST(root="./mnist/",train=False,download=True,transform=transform)ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))dl_train =  torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True,num_workers=2,drop_last=True)dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, num_workers=2,drop_last=True)return dl_train,dl_val
def create_net(config):net = nn.Sequential()net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=config.hidden_layer_width,kernel_size = 3))net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)) net.add_module("conv2",nn.Conv2d(in_channels=config.hidden_layer_width,out_channels=config.hidden_layer_width,kernel_size = 5))net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))net.add_module("dropout",nn.Dropout2d(p = config.dropout_p))net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))net.add_module("flatten",nn.Flatten())net.add_module("linear1",nn.Linear(config.hidden_layer_width,config.hidden_layer_width))net.add_module("relu",nn.ReLU())net.add_module("linear2",nn.Linear(config.hidden_layer_width,10))return net
def train_epoch(model,dl_train,optimizer):model.train()for step, batch in enumerate(dl_train):features,labels = batchfeatures,labels = features.to(device),labels.to(device)preds = model(features)loss = nn.CrossEntropyLoss()(preds,labels)loss.backward()optimizer.step()optimizer.zero_grad()return model
def eval_epoch(model,dl_val):model.eval()accurate = 0num_elems = 0for batch in dl_val:features,labels = batchfeatures,labels = features.to(device),labels.to(device)with torch.no_grad():preds = model(features)predictions = preds.argmax(dim=-1)accurate_preds =  (predictions==labels)num_elems += accurate_preds.shape[0]accurate += accurate_preds.long().sum()val_acc = accurate.item() / num_elemsreturn val_acc
def train(config = config):dl_train, dl_val = create_dataloaders(config)model = create_net(config); optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)#======================================================================nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')wandb.init(project=config.project_name, config = config.__dict__, name = nowtime, save_code=True)model.run_id = wandb.run.id#======================================================================model.best_metric = -1.0for epoch in range(1,config.epochs+1):model = train_epoch(model,dl_train,optimizer)val_acc = eval_epoch(model,dl_val)if val_acc>model.best_metric:model.best_metric = val_acctorch.save(model.state_dict(),config.ckpt_path)   nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")#======================================================================wandb.log({'epoch':epoch, 'val_acc': val_acc, 'best_val_acc':model.best_metric})#======================================================================        #======================================================================wandb.finish()#======================================================================return model   #model = train(config)

一切准备妥当,点火🔥🔥。

# 该agent 随机搜索 尝试5次
wandb.agent(sweep_id, train, count=5)

四,调参可视化和跟踪

1,平行坐标系图

可以直观展示哪些超参数组合更加容易获取更好的结果。

7366fd427d9e456030174fd9764948ec.png


2,超参数重要性图

可以显示超参数和优化目标最终取值的重要性,和相关性方向。

79447d4928b8ac70a543e57a9c7141f7.png


caa2b93c396396eb37a888a052bd7cdf.png

相关文章:

30分钟吃掉wandb可视化自动调参

wandb.sweep: 低代码,可视化,分布式 自动调参工具。使用wandb 的 sweep 进行超参调优,具有以下优点。(1)低代码:只需配置一个sweep.yaml配置文件,或者定义一个配置dict,几乎不用编写调参相关代码。(2)可视化…...

【8】AMBA_SOC项目自学IC验证项目-仿真平台脚本使用讲解

仿真平台文件介绍和脚本使用说明 1、项目路径:2、文件夹说明:3、仿真运行命令:第一步:进入项目路径第二步:设置环境第三步:运行仿真第四步:查看波形1、项目路径: 位置:/tool/project/axi 2、文件夹说明: a、env就是放的我们uvm环境相关的env文件; b、out就是我们…...

智慧水务未来技术发展方向预测探讨

随着科技的不断发展和城市化的加速,智慧水务作为一种新的水务模式,逐渐受到广泛关注。未来,智慧水务将会面临更多的技术挑战和商机。本博客将对智慧水务的未来技术发展方向进行预测,以探讨智慧水务未来可能的技术重点。 1. 人工…...

数据结构 | 栈与队列

🔥Go for it!🔥 📝个人主页:按键难防 📫 如果文章知识点有错误的地方,请指正!和大家一起学习,一起进步👀 📖系列专栏:数据结构与算法 &#x1f52…...

Redux 源码分析

Redux 目录结构 redux ├─ .babelrc.js ├─ .editorconfig ├─ .gitignore …...

第五十二章 BFS进阶(二)——双向广搜

第五十二章 BFS进阶(二)——双向广搜一、双向广搜1、优越之处2、实现逻辑3、复杂度分析二、例题1、问题2、分析3、代码一、双向广搜 1、优越之处 双向广搜是指我们从终点和起点同时开始搜索,当二者到达同一个中间状态的时候,即相…...

业务建模题

一. 单选题:1.在活动图中负责在一个活动节点执行完毕后切换到另一个节点的元素是( A)。A.控制流 B.对象流 C.判断节点 D.扩展区城2.以下说法错误的是(C)。A.活动图中的开始标记一般只有一一个,而终止标记可能有多个B.判断节点的出口条件必须保证不互相重复,并且不缺…...

电子秤专用模拟数字(AD)转换器芯片HX711介绍

HX711简介HX711是一款专为高精度电子秤而设计的24 位A/D 转换器芯片。与同类型其它芯片相比,该芯片集成了包括稳压电源、片内时钟振荡器等其它同类型芯片所需要的外围电路,具有集成度高、响应速度快、抗干扰性强等优点。降低了电子秤的整机成本&#xff…...

微服务 RocketMQ-延时消息 消息过滤 管控台搜索问题

~~微服务 RocketMQ-延时消息 消息过滤 管控台搜索问题~~ RocketMQ-延时消息实现延时消息RocketMQ-消息过滤Tag标签过滤SQL标签过滤管控台搜索问题RocketMQ-延时消息 给消息设置延时时间,到一定时间,消费者才能消费的到,中间件内部通过每秒钟扫…...

js发送邮件(node.js)

以前看别人博客留言或者评论文章时必须填写邮箱信息,感觉甚是麻烦。 后来才知道是为了在博主回复后让访客收到邮件,用心良苦。 于是我也在新增留言和文章评论的接口里,新增了给自己发送邮件提醒的功能。 我用的QQ邮箱,具体如下…...

English Learning - Day58 一周高频问题汇总 2023.2.12 周日

English Learning - Day58 一周高频问题汇总 2023.2.12 周日这周主要内容继续说说状语从句结果状语从句这周主要内容 DAY58【周日总结】 一周高频问题汇总 (打卡作业详见 Day59) 一近期主要讲了 一 01.主动脉修饰 以下是最常问到的知识点拓展&#xff…...

【微电网】基于风光储能和需求响应的微电网日前经济调度(Python代码实现)

目录 1 概述 2 知识点及数学模型 3 算例实现 3.1算例介绍 3.2风光参与的模型求解 3.3 风光和储能参与的模型求解 3.5 风光储能和需求响应都参与模型求解 3.6 结果分析对比 4 Python代码及算例数据 1 概述 近年来,微电网、清洁能源等已成为全球关注的热点…...

四种方式的MySQL安装

mysql安装常见的方法有四种序号 安装方式 说明1 yum\rpm简单、快速,不能定制参数2二进制 解压,简单配置就可使用 免安装 mysql-a.b.c-linux2.x-x86_64.tar.gz3源码编译 可以定制参数,安装时间长 mysql-a.b.c.tar.gz4源码制成rpm包 把源码制…...

软考高级信息系统项目管理师系列之九:项目范围管理

软考高级信息系统项目管理师系列之九:项目范围管理 一、范围管理输入、输出、工具和技术表二、范围管理概述三、规划范围管理四、收集需求1.收集需求:2.需求分类3.收集需求的工具与技术4.收集需求过程主要输出5.需求文件内容6.需求管理7.可跟踪性8.双向可跟踪性9.需求跟踪矩阵…...

【项目精选】javaEE健康管理系统(论文+开题报告+答辩PPT+源代码+数据库+讲解视频)

点击下载源码 javaEE健康管理系统主要功能包括:教师登录退出、教师饮食管理、教师健康日志、体检管理等等。本系统结构如下: (1)用户模块: 实现登录功能 实现用户登录的退出 实现用户注册 (2)教…...

ctfshow nodejs

web 334 大小写转换特殊字符绕过。 “ı”.toUpperCase() ‘I’,“ſ”.toUpperCase() ‘S’。 “K”.toLowerCase() ‘k’. payload: CTFſHOW 123456web 335 通过源码可知 eval(xxx),eval 中可以执行 js 代码,那么我们可以依此执行系…...

无线传感器原理及方法|重点理论知识|2021年19级|期末考试

Min-Max定位 【P63】 最小最大法的基本思想是依据未知节点到各锚节点的距离测量值及锚节点的坐标构造若干个边界框,即以参考节点为圆心,未知节点到该锚节点的距离测量值为半径所构成圆的外接矩形,计算外接矩形的质心为未知节点的估计坐标。 多边定位法的浮点运算量大,计算代…...

带你写出符合 Promise/A+ 规范 Promise 的源码

Promise是前端面试中的高频问题,如果你能根据PromiseA的规范,写出符合规范的源码,那么我想,对于面试中的Promise相关的问题,都能够给出比较完美的答案。 我的建议是,对照规范多写几次实现,也许…...

回流与重绘

触发回流与重绘条件👉回流当渲染树中部分或者全部元素的尺寸、结构或者属性发生变化时,浏览器会重新渲染部分或者全部文档的过程就称为 回流。引起回流原因1.页面的首次渲染2.浏览器的窗口大小发生变化3.元素的内容发生变化4.元素的尺寸或者位置发生变化…...

openpyxl表格的简单实用

示例:创建简单的电子表格和条形图 在这个例子中,我们将从头开始创建一个工作表并添加一些数据,然后绘制它。我们还将探索一些有限的单元格样式和格式。 我们将在工作表上输入的数据如下: 首先,让我们加载 openpyxl 并创建一个新工作簿。并获取活动表。我们还将输入我们…...

Spark 之 入门讲解详细版(1)

1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处&…...

R语言AI模型部署方案:精准离线运行详解

R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...

深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法

深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...

2023赣州旅游投资集团

单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...

MySQL 索引底层结构揭秘:B-Tree 与 B+Tree 的区别与应用

文章目录 一、背景知识:什么是 B-Tree 和 BTree? B-Tree(平衡多路查找树) BTree(B-Tree 的变种) 二、结构对比:一张图看懂 三、为什么 MySQL InnoDB 选择 BTree? 1. 范围查询更快 2…...

C语言中提供的第三方库之哈希表实现

一. 简介 前面一篇文章简单学习了C语言中第三方库(uthash库)提供对哈希表的操作,文章如下: C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...

TSN交换机正在重构工业网络,PROFINET和EtherCAT会被取代吗?

在工业自动化持续演进的今天,通信网络的角色正变得愈发关键。 2025年6月6日,为期三天的华南国际工业博览会在深圳国际会展中心(宝安)圆满落幕。作为国内工业通信领域的技术型企业,光路科技(Fiberroad&…...

Leetcode33( 搜索旋转排序数组)

题目表述 整数数组 nums 按升序排列&#xff0c;数组中的值 互不相同 。 在传递给函数之前&#xff0c;nums 在预先未知的某个下标 k&#xff08;0 < k < nums.length&#xff09;上进行了 旋转&#xff0c;使数组变为 [nums[k], nums[k1], …, nums[n-1], nums[0], nu…...

Axure 下拉框联动

实现选省、选完省之后选对应省份下的市区...

负载均衡器》》LVS、Nginx、HAproxy 区别

虚拟主机 先4&#xff0c;后7...