【WB 深度学习实验管理】使用 PyTorch Lightning 实现高效的图像分类实验跟踪
本文使用到的 Jupyter Notebook 可在GitHub仓库002文件夹找到,别忘了给仓库点个小心心~~~
https://github.com/LFF8888/FF-Studio-Resources
在机器学习项目中,实验跟踪和结果可视化是至关重要的环节。无论是调整超参数、优化模型架构,还是监控训练过程中的性能变化,清晰的记录和直观的可视化都能显著提升开发效率。然而,许多开发者在实际操作中往往忽视了这一点,导致实验结果难以复现,或者在项目协作中出现混乱。今天,笔者将介绍如何利用 PyTorch Lightning 和 Weights & Biases 这一强大的工具组合,轻松构建和训练一个图像分类模型。通过本文,你将学会如何高效地组织数据管道、定义模型架构,并利用 W&B 实现实验跟踪和结果可视化,让每一次实验都清晰可溯,每一次优化都有据可依。
使用 PyTorch Lightning ⚡️ 进行图像分类
我们将使用 PyTorch Lightning 构建一个图像分类管道。我们将遵循这个 风格指南 来提高代码的可读性和可重复性。这里有一个很酷的解释:使用 PyTorch Lightning 进行图像分类。
设置 PyTorch Lightning 和 W&B
对于本教程,我们需要 PyTorch Lightning(这不是很明显吗!)和 Weights and Biases。
!pip install lightning torchvision -q
# 安装 weights and biases
!pip install wandb -qU
你需要这些导入。
import lightning.pytorch as pl
# 你最喜欢的机器学习跟踪工具
from lightning.pytorch.loggers import WandbLoggerimport torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoaderfrom torchmetrics import Accuracyfrom torchvision import transforms
from torchvision.datasets import CIFAR10import wandb
现在你需要登录到你的 wandb 账户。
wandb.login()
🔧 DataModule - 我们应得的数据管道
DataModules 是一种将数据相关的钩子与 LightningModule 解耦的方式,以便你可以开发与数据集无关的模型。
它将数据管道组织成一个可共享和可重用的类。一个 datamodule 封装了 PyTorch 中数据处理的五个步骤:
- 下载 / 分词 / 处理。
- 清理并(可能)保存到磁盘。
- 加载到 Dataset 中。
- 应用转换(旋转、分词等)。
- 包装到 DataLoader 中。
了解更多关于 datamodules 的信息 这里。让我们为 Cifar-10 数据集构建一个 datamodule。
class CIFAR10DataModule(pl.LightningDataModule):def __init__(self, batch_size, data_dir: str = './'):super().__init__()self.data_dir = data_dirself.batch_size = batch_sizeself.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])self.num_classes = 10def prepare_data(self):CIFAR10(self.data_dir, train=True, download=True)CIFAR10(self.data_dir, train=False, download=True)def setup(self, stage=None):# 为 dataloaders 分配训练/验证数据集if stage == 'fit' or stage is None:cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])# 为 dataloader(s) 分配测试数据集if stage == 'test' or stage is None:self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)def train_dataloader(self):return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.cifar_val, batch_size=self.batch_size)def test_dataloader(self):return DataLoader(self.cifar_test, batch_size=self.batch_size)
📱 Callbacks
回调是一个独立的程序,可以在项目之间重用。PyTorch Lightning 提供了一些 内置回调,这些回调经常被使用。
了解更多关于 PyTorch Lightning 中的回调 这里。
内置回调
在本教程中,我们将使用 Early Stopping 和 Model Checkpoint 内置回调。它们可以传递给 Trainer
。
自定义回调
如果你熟悉自定义 Keras 回调,那么在 PyTorch 管道中实现相同功能的能力只是锦上添花。
由于我们正在进行图像分类,能够可视化模型对一些样本图像的预测可能很有帮助。这种形式的回调可以帮助在早期阶段调试模型。
class ImagePredictionLogger(pl.callbacks.Callback):def __init__(self, val_samples, num_samples=32):super().__init__()self.num_samples = num_samplesself.val_imgs, self.val_labels = val_samplesdef on_validation_epoch_end(self, trainer, pl_module):# 将张量带到 CPUval_imgs = self.val_imgs.to(device=pl_module.device)val_labels = self.val_labels.to(device=pl_module.device)# 获取模型预测logits = pl_module(val_imgs)preds = torch.argmax(logits, -1)# 将图像记录为 wandb Imagetrainer.logger.experiment.log({"examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")for x, pred, y in zip(val_imgs[:self.num_samples],preds[:self.num_samples],val_labels[:self.num_samples])]})
🎺 LightningModule - 定义系统
LightningModule 定义了一个系统,而不是一个模型。在这里,系统将所有研究代码分组到一个类中,使其自包含。LightningModule
将你的 PyTorch 代码组织成 5 个部分:
- 计算 (
__init__
)。 - 训练循环 (
training_step
) - 验证循环 (
validation_step
) - 测试循环 (
test_step
) - 优化器 (
configure_optimizers
)
因此,可以构建一个与数据集无关的模型,并且可以轻松共享。让我们为 Cifar-10 分类构建一个系统。
class LitModel(pl.LightningModule):def __init__(self, input_shape, num_classes, learning_rate=2e-4):super().__init__()# 记录超参数self.save_hyperparameters()self.learning_rate = learning_rateself.conv1 = nn.Conv2d(3, 32, 3, 1)self.conv2 = nn.Conv2d(32, 32, 3, 1)self.conv3 = nn.Conv2d(32, 64, 3, 1)self.conv4 = nn.Conv2d(64, 64, 3, 1)self.pool1 = torch.nn.MaxPool2d(2)self.pool2 = torch.nn.MaxPool2d(2)n_sizes = self._get_conv_output(input_shape)self.fc1 = nn.Linear(n_sizes, 512)self.fc2 = nn.Linear(512, 128)self.fc3 = nn.Linear(128, num_classes)self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)# 返回从卷积块进入线性层的输出张量的大小。def _get_conv_output(self, shape):batch_size = 1input = torch.autograd.Variable(torch.rand(batch_size, *shape))output_feat = self._forward_features(input)n_size = output_feat.data.view(batch_size, -1).size(1)return n_size# 返回卷积块的特征张量def _forward_features(self, x):x = F.relu(self.conv1(x))x = self.pool1(F.relu(self.conv2(x)))x = F.relu(self.conv3(x))x = self.pool2(F.relu(self.conv4(x)))return x# 将在推理期间使用def forward(self, x):x = self._forward_features(x)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.log_softmax(self.fc3(x), dim=1)return xdef training_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = F.nll_loss(logits, y)# 训练指标preds = torch.argmax(logits, dim=1)acc = self.accuracy(preds, y)self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)return lossdef validation_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = F.nll_loss(logits, y)# 验证指标preds = torch.argmax(logits, dim=1)acc = self.accuracy(preds, y)self.log('val_loss', loss, prog_bar=True)self.log('val_acc', acc, prog_bar=True)return lossdef test_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = F.nll_loss(logits, y)# 验证指标preds = torch.argmax(logits, dim=1)acc = self.accuracy(preds, y)self.log('test_loss', loss, prog_bar=True)self.log('test_acc', acc, prog_bar=True)return lossdef configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)return optimizer
🚋 训练和评估
现在我们已经使用 DataModule
组织了数据管道,并使用 LightningModule
组织了模型架构和训练循环,PyTorch Lightning Trainer
为我们自动化了其他所有内容。
Trainer 自动化了以下内容:
- Epoch 和 batch 迭代
- 调用
optimizer.step()
、backward
、zero_grad()
- 调用
.eval()
,启用/禁用梯度 - 保存和加载权重
- Weights and Biases 日志记录
- 多 GPU 训练支持
- TPU 支持
- 16 位训练支持
dm = CIFAR10DataModule(batch_size=32)
# 要访问 x_dataloader,我们需要调用 prepare_data 和 setup。
dm.prepare_data()
dm.setup()# 自定义 ImagePredictionLogger 回调所需的样本,用于记录图像预测。
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
model = LitModel((3, 32, 32), dm.num_classes)# 初始化 wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')# 初始化 Callbacks
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()# 初始化一个 trainer
trainer = pl.Trainer(max_epochs=2,logger=wandb_logger,callbacks=[early_stop_callback,ImagePredictionLogger(val_samples),checkpoint_callback],)# 训练模型 ⚡🚅⚡
trainer.fit(model, dm)# 在保留的测试集上评估模型 ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())# 关闭 wandb run
wandb.finish()
最终想法
我来自 TensorFlow/Keras 生态系统,发现 PyTorch 虽然是一个优雅的框架,但有点让人不知所措。这只是我的个人经验。在探索 PyTorch Lightning 时,我意识到几乎所有让我远离 PyTorch 的原因都得到了解决。以下是我兴奋的快速总结:
- 过去:传统的 PyTorch 模型定义通常分散在各个地方。模型在某个
model.py
脚本中,训练循环在train.py
文件中。需要来回查看才能理解管道。 - 现在:
LightningModule
作为一个系统,模型定义与training_step
、validation_step
等一起定义。现在它是模块化的且可共享的。 - 过去:TensorFlow/Keras 最棒的部分是输入数据管道。他们的数据集目录丰富且不断增长。PyTorch 的数据管道曾经是最大的痛点。在普通的 PyTorch 代码中,数据下载/清理/准备通常分散在许多文件中。
- 现在:DataModule 将数据管道组织成一个可共享和可重用的类。它只是
train_dataloader
、val_dataloader
(s)、test_dataloader
(s) 以及匹配的转换和数据处理/下载步骤的集合。 - 过去:使用 Keras,可以调用
model.fit
来训练模型,调用model.predict
来运行推理。model.evaluate
提供了一个简单而有效的测试数据评估。这在 PyTorch 中不是这样。通常会找到单独的train.py
和test.py
文件。 - 现在:有了
LightningModule
,Trainer
自动化了一切。只需调用trainer.fit
和trainer.test
来训练和评估模型。 - 过去:TensorFlow 喜欢 TPU,PyTorch…嗯!
- 现在:使用 PyTorch Lightning,可以轻松地在多个 GPU 上训练相同的模型,甚至在 TPU 上。哇!
- 过去:我是回调的忠实粉丝,更喜欢编写自定义回调。像 Early Stopping 这样简单的事情曾经是传统 PyTorch 的讨论点。
- 现在:使用 PyTorch Lightning,使用 Early Stopping 和 Model Checkpointing 是小菜一碟。我甚至可以编写自定义回调。
🎨 结论和资源
我希望你觉得这份报告有帮助。我鼓励你玩一下代码,并使用你选择的数据集训练一个图像分类器。
以下是一些学习更多关于 PyTorch Lightning 的资源:
- 逐步演练 - 这是官方教程之一。他们的文档写得非常好,我强烈推荐它作为学习资源。
- 使用 PyTorch Lightning 与 Weights & Biases - 这是一个快速 colab,你可以通过它学习如何使用 W&B 与 PyTorch Lightning。
相关文章:

【WB 深度学习实验管理】使用 PyTorch Lightning 实现高效的图像分类实验跟踪
本文使用到的 Jupyter Notebook 可在GitHub仓库002文件夹找到,别忘了给仓库点个小心心~~~ https://github.com/LFF8888/FF-Studio-Resources 在机器学习项目中,实验跟踪和结果可视化是至关重要的环节。无论是调整超参数、优化模型架构,还是监…...
SSM开发(十一) mybatis关联关系多表查询(嵌套查询,举例说明)
目录 一、背景介绍 二、一对一查询(嵌套查询) 三、一对多查询(嵌套查询) 四、嵌套查询效率评估 注:关联查询则是指在一个查询中涉及到多个表的联合查询 一、背景介绍 当对数据库的操作涉及到多张表,这在面向对象语言如Java中就涉及到了对象与对象之间的关联关系。针对多…...
The Simulation技术浅析(六):机器学习
机器学习(Machine Learning)是模拟技术(The Simulation)的重要组成部分,通过从数据中自动学习规律和模式,机器学习能够提升模拟系统的智能化水平,增强其预测、决策和优化能力。 一、监督学习(Supervised Learning) 1. 基本原理 监督学习是指利用标注数据(即输入数…...

apache-poi导出excel数据
excel导出 自动设置宽度,设置标题框,设置数据边框。 excel导出 添加依赖 <dependency><groupId>org.apache.poi</groupId><artifactId>poi-ooxml</artifactId><version>5.2.2</version></dependency>…...
唯一值校验的实现思路(续)
本文接着上一篇文章《唯一值校验的实现思路》,在后端实现唯一值校验。用代码实现。 /*** checkUniqueException[唯一值校验]** param entity 新增或编辑的学生实体* param insert 是否新增,如果是传入true;反之传入false* return void* date…...
ffmpeg基本用法
一、用法 ffmpeg [options] [[infile options] -i infile]... {[outfile options] outfile}... 说明: global options:全局选项,应用于整个 FFmpeg 进程,它们通常不受输入或输出部分的限制。 infile options:输入选…...

MYSQL第四次
目录 题目分析 代码实现 一、修改 Student 表中年龄(sage)字段属性,数据类型由 int 改变为 smallint 二、为 Course 表中 Cno 字段设置索引,并查看索引 三、为 SC 表建立按学号(sno)和课程号ÿ…...

联德胜w801开发板(六)手机蓝牙设置wifi名称和密码
一、概述 W801 是一款集成了 Wi-Fi 和蓝牙功能的芯片,本文将介绍如何利用 W801 的蓝牙功能,实现手机 APP 通过蓝牙配置 W801 连接的 Wi-Fi 名称和密码(即配网功能)。 二、文档查看: demo使用手册这里很清楚…...

Linux:库
目录 静态库 动态库 目标文件 ELF文件 ELF形成可执行 ELF可执行加载 ELF加载 全局偏移量表GOT(global offset table) 库是写好的,成熟的,可以复用的代码 现实中每个程序都要依赖很多的基础的底层库,不可能都是从零开始的 库有两种…...

向量数据库简单对比
文章目录 一、Chroma二、Pinecone/腾讯云VectorDB/VikingDB三、redis四、Elasticsearch五、Milvus六、Qdrant七、Weaviate八、Faiss 一、Chroma 官方地址: https://www.trychroma.com/优点 ①简单,非常简单构建服务。 ②此外,Chroma还具有自…...

大模型基本原理(四)——如何武装ChatGPT
传统的LLM存在几个短板:编造事实、计算不准确、数据过时等,为了应对这几个问题,可以借助一些外部工具或数据把AI武装起来。 实现这一思路的框架包括RAG、PAL、ReAct。 1、RAG(检索增强生成) LLM生成的内容会受到训练…...

从零开始:使用Jenkins实现高效自动化部署
在这篇文章中我们将深入探讨如何通过Jenkins构建高效的自动化部署流水线,帮助团队实现从代码提交到生产环境部署的全流程自动化。无论你是Jenkins新手还是有一定经验的开发者,这篇文章都会为你提供实用的技巧和最佳实践,助你在项目部署中走得…...

Spring Cloud工程完善
目录 完善订单服务 启动类 配置文件 实体类 Controller Service Mapper 测试运行 完成商品服务 启动类 配置文件 实体类 Controller Service Mapper 测试运行 远程调用 需求 实现 1.定义RestTemplate 2.修改order-service中的OrderService 测试运行 Rest…...

SSM仓库物品管理系统 附带详细运行指导视频
文章目录 一、项目演示二、项目介绍三、运行截图四、主要代码1.用户登录代码:2.保存物品信息代码:3.删除仓库信息代码: 一、项目演示 项目演示地址: 视频地址 二、项目介绍 项目描述:这是一个基于SSM框架开发的仓库…...
UI自动化测试中如何处理验证码?
在UI自动化测试中处理验证码是常见的技术挑战,以下是分步解决方案及实际应用建议: 一、验证码处理策略对比 方法实现方式优点缺点适用场景禁用验证码测试环境配置关闭验证码生成简单快捷,零成本无法测试验证码功能本身非验证码相关功能测试万…...
华为交换机堆叠配置
一、CSS堆叠集群配置(框式交换机) 1、通过集群卡连接方式组建集群 [SwitchA] set css mode css-card \\配置集群卡连接方式 [SwitchA] set css id 1 \\配置成员交换机的集群ID(缺省值为1) [SwitchA] set css priority 100 \\配…...
Vue 和 dhtmlx-gantt 实现图表构建动态多级甘特图效果 ,横坐标为动态刻度不是日期
注意事项:1、横坐标根据日期转换成时间刻度在( gantt.config.scales);2、获取时间刻度的最大值(findMaxRepairTime);3、甘特图多级列表需注意二级三级每个父子id需要唯一(convertData) 安装依赖 npm install dhtmlx-gantt --save 在当前页引入和配置 dhtmlx-gantt im…...

collabora online+nextcloud+mariadb在线文档协助
1、环境 龙蜥os 8.9 docker 2、安装docker dnf -y install dnf-plugins-core dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo sed -i shttps://download.docker.comhttps://mirrors.tuna.tsinghua.edu.cn/docker-ce /etc/yum.repos.…...

“可通过HTTP获取远端WWW服务信息”漏洞修复
环境说明:①操作系统:windows server;②nginx:1.27.1。 1.漏洞说明 “可通过HTTP获取远端WWW服务信息”。 修复前,在“响应标头”能看到Server信息,如下图所示: 修复后,“响应标头…...

【AI时代】-开发环境准备 之 Conda 创建 Python 环境 (含pip常用命令、jupyter 安装及汉化、自定义文档位置等配置)
一、 安装 Anaconda 1.1 下载并安装 https://www.anaconda.com/download/success 1.2 验证是否成功 CMD输入命令: conda --version注意:找不到命令需要配置环境变量: Path 中 添加 Anaconda 的安装路径: 如果没有修改安装位…...

【Axure高保真原型】引导弹窗
今天和大家中分享引导弹窗的原型模板,载入页面后,会显示引导弹窗,适用于引导用户使用页面,点击完成后,会显示下一个引导弹窗,直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)
HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

【OSG学习笔记】Day 18: 碰撞检测与物理交互
物理引擎(Physics Engine) 物理引擎 是一种通过计算机模拟物理规律(如力学、碰撞、重力、流体动力学等)的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互,广泛应用于 游戏开发、动画制作、虚…...
k8s从入门到放弃之Ingress七层负载
k8s从入门到放弃之Ingress七层负载 在Kubernetes(简称K8s)中,Ingress是一个API对象,它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress,你可…...
【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密
在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...
【磁盘】每天掌握一个Linux命令 - iostat
目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...
质量体系的重要
质量体系是为确保产品、服务或过程质量满足规定要求,由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面: 🏛️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限,形成层级清晰的管理网络…...
聊一聊接口测试的意义有哪些?
目录 一、隔离性 & 早期测试 二、保障系统集成质量 三、验证业务逻辑的核心层 四、提升测试效率与覆盖度 五、系统稳定性的守护者 六、驱动团队协作与契约管理 七、性能与扩展性的前置评估 八、持续交付的核心支撑 接口测试的意义可以从四个维度展开,首…...

HDFS分布式存储 zookeeper
hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架,允许使用简单的变成模型跨计算机对大型集群进行分布式处理(1.海量的数据存储 2.海量数据的计算)Hadoop核心组件 hdfs(分布式文件存储系统)&a…...