当前位置: 首页 > news >正文

30分钟吃掉wandb可视化自动调参

wandb.sweep: 低代码,可视化,分布式 自动调参工具。

使用wandb 的 sweep 进行超参调优,具有以下优点。

(1)低代码:只需配置一个sweep.yaml配置文件,或者定义一个配置dict,几乎不用编写调参相关代码。

(2)可视化:在wandb网页中可以实时监控调参过程中每次尝试,并可视化地分析调参任务的目标值分布,超参重要性等。

(3)分布式:sweep采用类似master-workers的controller-agents架构,controller在wandb的服务器机器上运行,agents在用户机器上运行,controller和agents之间通过互联网进行通信。同时启动多个agents即可轻松实现分布式超参搜索。

公众号后台回复关键词:wandb,获取本文notebook代码和B站视频演示。

使用 wandb 的sweep 调参的缺点:

需要联网:由于wandb的controller位于wandb的服务器机器上,wandb日志也需要联网上传,在没有互联网的环境下无法正常使用wandb 进行模型跟踪 以及 wandb sweep 可视化调参。

d6731f3afe349a385fa50a5eb394b50a.png

〇,使用Sweep的3步骤

  1. 配置 sweep_config

配置调优算法,调优目标,需要优化的超参数列表 等等。
  1. 初始化 sweep controller:

sweep_id = wandb.sweep(sweep_config,project)
  1. 启动 sweep agents:

wandb.agent(sweep_id, function=train)
import os,PIL 
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch 
from torch import nn 
import torchvision 
from torchvision import transforms
import datetime
import wandb wandb.login()
from argparse import Namespacedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#初始化参数配置
config = Namespace(project_name = 'wandb_demo',batch_size = 512,hidden_layer_width = 64,dropout_p = 0.1,lr = 1e-4,optim_type = 'Adam',epochs = 15,ckpt_path = 'checkpoint.pt'
)

一. 配置 Sweep config

详细配置文档可以参考:https://docs.wandb.ai/guides/sweeps/define-sweep-configuration

1,选择一个调优算法

Sweep支持如下3种调优算法:

(1)网格搜索:grid. 遍历所有可能得超参组合,只在超参空间不大的时候使用,否则会非常慢。

(2)随机搜索:random. 每个超参数都选择一个随机值,非常有效,一般情况下建议使用。

(3)贝叶斯搜索:bayes. 创建一个概率模型估计不同超参数组合的效果,采样有更高概率提升优化目标的超参数组合。对连续型的超参数特别有效,但扩展到非常高维度的超参数时效果不好。

sweep_config = {'method': 'random'}

2,定义调优目标

设置优化指标,以及优化方向。

sweep agents 通过 wandb.log 的形式向 sweep controller 传递优化目标的值。

metric = {'name': 'val_acc','goal': 'maximize'   }
sweep_config['metric'] = metric

3,定义超参空间

超参空间可以分成 固定型,离散型和连续型。

  • 固定型:指定 value

  • 离散型:指定 values,列出全部候选取值。

  • 连续性:需要指定 分布类型 distribution, 和范围 min, max。用于 random 或者 bayes采样。

sweep_config['parameters'] = {}# 固定不变的超参
sweep_config['parameters'].update({'project_name':{'value':'wandb_demo'},'epochs': {'value': 10},'ckpt_path': {'value':'checkpoint.pt'}})# 离散型分布超参
sweep_config['parameters'].update({'optim_type': {'values': ['Adam', 'SGD','AdamW']},'hidden_layer_width': {'values': [16,32,48,64,80,96,112,128]}})# 连续型分布超参
sweep_config['parameters'].update({'lr': {'distribution': 'log_uniform_values','min': 1e-6,'max': 0.1},'batch_size': {'distribution': 'q_uniform','q': 8,'min': 32,'max': 256,},'dropout_p': {'distribution': 'uniform','min': 0,'max': 0.6,}
})

4,定义剪枝策略 (可选)

可以定义剪枝策略,提前终止那些没有希望的任务。

sweep_config['early_terminate'] = {'type':'hyperband','min_iter':3,'eta':2,'s':3
} #在step=3, 6, 12 时考虑是否剪枝
from pprint import pprint
pprint(sweep_config)

二. 初始化 sweep controller

sweep_id = wandb.sweep(sweep_config, project=config.project_name)

三, 启动 Sweep agent

我们需要把模型训练相关的全部代码整理成一个 train函数。

def create_dataloaders(config):transform = transforms.Compose([transforms.ToTensor()])ds_train = torchvision.datasets.MNIST(root="./mnist/",train=True,download=True,transform=transform)ds_val = torchvision.datasets.MNIST(root="./mnist/",train=False,download=True,transform=transform)ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))dl_train =  torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True,num_workers=2,drop_last=True)dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, num_workers=2,drop_last=True)return dl_train,dl_val
def create_net(config):net = nn.Sequential()net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=config.hidden_layer_width,kernel_size = 3))net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)) net.add_module("conv2",nn.Conv2d(in_channels=config.hidden_layer_width,out_channels=config.hidden_layer_width,kernel_size = 5))net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))net.add_module("dropout",nn.Dropout2d(p = config.dropout_p))net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))net.add_module("flatten",nn.Flatten())net.add_module("linear1",nn.Linear(config.hidden_layer_width,config.hidden_layer_width))net.add_module("relu",nn.ReLU())net.add_module("linear2",nn.Linear(config.hidden_layer_width,10))return net
def train_epoch(model,dl_train,optimizer):model.train()for step, batch in enumerate(dl_train):features,labels = batchfeatures,labels = features.to(device),labels.to(device)preds = model(features)loss = nn.CrossEntropyLoss()(preds,labels)loss.backward()optimizer.step()optimizer.zero_grad()return model
def eval_epoch(model,dl_val):model.eval()accurate = 0num_elems = 0for batch in dl_val:features,labels = batchfeatures,labels = features.to(device),labels.to(device)with torch.no_grad():preds = model(features)predictions = preds.argmax(dim=-1)accurate_preds =  (predictions==labels)num_elems += accurate_preds.shape[0]accurate += accurate_preds.long().sum()val_acc = accurate.item() / num_elemsreturn val_acc
def train(config = config):dl_train, dl_val = create_dataloaders(config)model = create_net(config); optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)#======================================================================nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')wandb.init(project=config.project_name, config = config.__dict__, name = nowtime, save_code=True)model.run_id = wandb.run.id#======================================================================model.best_metric = -1.0for epoch in range(1,config.epochs+1):model = train_epoch(model,dl_train,optimizer)val_acc = eval_epoch(model,dl_val)if val_acc>model.best_metric:model.best_metric = val_acctorch.save(model.state_dict(),config.ckpt_path)   nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")#======================================================================wandb.log({'epoch':epoch, 'val_acc': val_acc, 'best_val_acc':model.best_metric})#======================================================================        #======================================================================wandb.finish()#======================================================================return model   #model = train(config)

一切准备妥当,点火🔥🔥。

# 该agent 随机搜索 尝试5次
wandb.agent(sweep_id, train, count=5)

四,调参可视化和跟踪

1,平行坐标系图

可以直观展示哪些超参数组合更加容易获取更好的结果。

7366fd427d9e456030174fd9764948ec.png


2,超参数重要性图

可以显示超参数和优化目标最终取值的重要性,和相关性方向。

79447d4928b8ac70a543e57a9c7141f7.png


caa2b93c396396eb37a888a052bd7cdf.png

相关文章:

30分钟吃掉wandb可视化自动调参

wandb.sweep: 低代码,可视化,分布式 自动调参工具。使用wandb 的 sweep 进行超参调优,具有以下优点。(1)低代码:只需配置一个sweep.yaml配置文件,或者定义一个配置dict,几乎不用编写调参相关代码。(2)可视化…...

【8】AMBA_SOC项目自学IC验证项目-仿真平台脚本使用讲解

仿真平台文件介绍和脚本使用说明 1、项目路径:2、文件夹说明:3、仿真运行命令:第一步:进入项目路径第二步:设置环境第三步:运行仿真第四步:查看波形1、项目路径: 位置:/tool/project/axi 2、文件夹说明: a、env就是放的我们uvm环境相关的env文件; b、out就是我们…...

智慧水务未来技术发展方向预测探讨

随着科技的不断发展和城市化的加速,智慧水务作为一种新的水务模式,逐渐受到广泛关注。未来,智慧水务将会面临更多的技术挑战和商机。本博客将对智慧水务的未来技术发展方向进行预测,以探讨智慧水务未来可能的技术重点。 1. 人工…...

数据结构 | 栈与队列

🔥Go for it!🔥 📝个人主页:按键难防 📫 如果文章知识点有错误的地方,请指正!和大家一起学习,一起进步👀 📖系列专栏:数据结构与算法 &#x1f52…...

Redux 源码分析

Redux 目录结构 redux ├─ .babelrc.js ├─ .editorconfig ├─ .gitignore …...

第五十二章 BFS进阶(二)——双向广搜

第五十二章 BFS进阶(二)——双向广搜一、双向广搜1、优越之处2、实现逻辑3、复杂度分析二、例题1、问题2、分析3、代码一、双向广搜 1、优越之处 双向广搜是指我们从终点和起点同时开始搜索,当二者到达同一个中间状态的时候,即相…...

业务建模题

一. 单选题:1.在活动图中负责在一个活动节点执行完毕后切换到另一个节点的元素是( A)。A.控制流 B.对象流 C.判断节点 D.扩展区城2.以下说法错误的是(C)。A.活动图中的开始标记一般只有一一个,而终止标记可能有多个B.判断节点的出口条件必须保证不互相重复,并且不缺…...

电子秤专用模拟数字(AD)转换器芯片HX711介绍

HX711简介HX711是一款专为高精度电子秤而设计的24 位A/D 转换器芯片。与同类型其它芯片相比,该芯片集成了包括稳压电源、片内时钟振荡器等其它同类型芯片所需要的外围电路,具有集成度高、响应速度快、抗干扰性强等优点。降低了电子秤的整机成本&#xff…...

微服务 RocketMQ-延时消息 消息过滤 管控台搜索问题

~~微服务 RocketMQ-延时消息 消息过滤 管控台搜索问题~~ RocketMQ-延时消息实现延时消息RocketMQ-消息过滤Tag标签过滤SQL标签过滤管控台搜索问题RocketMQ-延时消息 给消息设置延时时间,到一定时间,消费者才能消费的到,中间件内部通过每秒钟扫…...

js发送邮件(node.js)

以前看别人博客留言或者评论文章时必须填写邮箱信息,感觉甚是麻烦。 后来才知道是为了在博主回复后让访客收到邮件,用心良苦。 于是我也在新增留言和文章评论的接口里,新增了给自己发送邮件提醒的功能。 我用的QQ邮箱,具体如下…...

English Learning - Day58 一周高频问题汇总 2023.2.12 周日

English Learning - Day58 一周高频问题汇总 2023.2.12 周日这周主要内容继续说说状语从句结果状语从句这周主要内容 DAY58【周日总结】 一周高频问题汇总 (打卡作业详见 Day59) 一近期主要讲了 一 01.主动脉修饰 以下是最常问到的知识点拓展&#xff…...

【微电网】基于风光储能和需求响应的微电网日前经济调度(Python代码实现)

目录 1 概述 2 知识点及数学模型 3 算例实现 3.1算例介绍 3.2风光参与的模型求解 3.3 风光和储能参与的模型求解 3.5 风光储能和需求响应都参与模型求解 3.6 结果分析对比 4 Python代码及算例数据 1 概述 近年来,微电网、清洁能源等已成为全球关注的热点…...

四种方式的MySQL安装

mysql安装常见的方法有四种序号 安装方式 说明1 yum\rpm简单、快速,不能定制参数2二进制 解压,简单配置就可使用 免安装 mysql-a.b.c-linux2.x-x86_64.tar.gz3源码编译 可以定制参数,安装时间长 mysql-a.b.c.tar.gz4源码制成rpm包 把源码制…...

软考高级信息系统项目管理师系列之九:项目范围管理

软考高级信息系统项目管理师系列之九:项目范围管理 一、范围管理输入、输出、工具和技术表二、范围管理概述三、规划范围管理四、收集需求1.收集需求:2.需求分类3.收集需求的工具与技术4.收集需求过程主要输出5.需求文件内容6.需求管理7.可跟踪性8.双向可跟踪性9.需求跟踪矩阵…...

【项目精选】javaEE健康管理系统(论文+开题报告+答辩PPT+源代码+数据库+讲解视频)

点击下载源码 javaEE健康管理系统主要功能包括:教师登录退出、教师饮食管理、教师健康日志、体检管理等等。本系统结构如下: (1)用户模块: 实现登录功能 实现用户登录的退出 实现用户注册 (2)教…...

ctfshow nodejs

web 334 大小写转换特殊字符绕过。 “ı”.toUpperCase() ‘I’,“ſ”.toUpperCase() ‘S’。 “K”.toLowerCase() ‘k’. payload: CTFſHOW 123456web 335 通过源码可知 eval(xxx),eval 中可以执行 js 代码,那么我们可以依此执行系…...

无线传感器原理及方法|重点理论知识|2021年19级|期末考试

Min-Max定位 【P63】 最小最大法的基本思想是依据未知节点到各锚节点的距离测量值及锚节点的坐标构造若干个边界框,即以参考节点为圆心,未知节点到该锚节点的距离测量值为半径所构成圆的外接矩形,计算外接矩形的质心为未知节点的估计坐标。 多边定位法的浮点运算量大,计算代…...

带你写出符合 Promise/A+ 规范 Promise 的源码

Promise是前端面试中的高频问题,如果你能根据PromiseA的规范,写出符合规范的源码,那么我想,对于面试中的Promise相关的问题,都能够给出比较完美的答案。 我的建议是,对照规范多写几次实现,也许…...

回流与重绘

触发回流与重绘条件👉回流当渲染树中部分或者全部元素的尺寸、结构或者属性发生变化时,浏览器会重新渲染部分或者全部文档的过程就称为 回流。引起回流原因1.页面的首次渲染2.浏览器的窗口大小发生变化3.元素的内容发生变化4.元素的尺寸或者位置发生变化…...

openpyxl表格的简单实用

示例:创建简单的电子表格和条形图 在这个例子中,我们将从头开始创建一个工作表并添加一些数据,然后绘制它。我们还将探索一些有限的单元格样式和格式。 我们将在工作表上输入的数据如下: 首先,让我们加载 openpyxl 并创建一个新工作簿。并获取活动表。我们还将输入我们…...

Chapter03-Authentication vulnerabilities

文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...

docker详细操作--未完待续

docker介绍 docker官网: Docker:加速容器应用程序开发 harbor官网:Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台,用于将应用程序及其依赖项(如库、运行时环…...

云计算——弹性云计算器(ECS)

弹性云服务器:ECS 概述 云计算重构了ICT系统,云计算平台厂商推出使得厂家能够主要关注应用管理而非平台管理的云平台,包含如下主要概念。 ECS(Elastic Cloud Server):即弹性云服务器,是云计算…...

Cesium1.95中高性能加载1500个点

一、基本方式&#xff1a; 图标使用.png比.svg性能要好 <template><div id"cesiumContainer"></div><div class"toolbar"><button id"resetButton">重新生成点</button><span id"countDisplay&qu…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

关于nvm与node.js

1 安装nvm 安装过程中手动修改 nvm的安装路径&#xff0c; 以及修改 通过nvm安装node后正在使用的node的存放目录【这句话可能难以理解&#xff0c;但接着往下看你就了然了】 2 修改nvm中settings.txt文件配置 nvm安装成功后&#xff0c;通常在该文件中会出现以下配置&…...

使用分级同态加密防御梯度泄漏

抽象 联邦学习 &#xff08;FL&#xff09; 支持跨分布式客户端进行协作模型训练&#xff0c;而无需共享原始数据&#xff0c;这使其成为在互联和自动驾驶汽车 &#xff08;CAV&#xff09; 等领域保护隐私的机器学习的一种很有前途的方法。然而&#xff0c;最近的研究表明&…...

大数据零基础学习day1之环境准备和大数据初步理解

学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 &#xff08;1&#xff09;设置网关 打开VMware虚拟机&#xff0c;点击编辑…...

css的定位(position)详解:相对定位 绝对定位 固定定位

在 CSS 中&#xff0c;元素的定位通过 position 属性控制&#xff0c;共有 5 种定位模式&#xff1a;static&#xff08;静态定位&#xff09;、relative&#xff08;相对定位&#xff09;、absolute&#xff08;绝对定位&#xff09;、fixed&#xff08;固定定位&#xff09;和…...

2023赣州旅游投资集团

单选题 1.“不登高山&#xff0c;不知天之高也&#xff1b;不临深溪&#xff0c;不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...