【WB 深度学习实验管理】使用 PyTorch Lightning 实现高效的图像分类实验跟踪
本文使用到的 Jupyter Notebook 可在GitHub仓库002文件夹找到,别忘了给仓库点个小心心~~~
https://github.com/LFF8888/FF-Studio-Resources
在机器学习项目中,实验跟踪和结果可视化是至关重要的环节。无论是调整超参数、优化模型架构,还是监控训练过程中的性能变化,清晰的记录和直观的可视化都能显著提升开发效率。然而,许多开发者在实际操作中往往忽视了这一点,导致实验结果难以复现,或者在项目协作中出现混乱。今天,笔者将介绍如何利用 PyTorch Lightning 和 Weights & Biases 这一强大的工具组合,轻松构建和训练一个图像分类模型。通过本文,你将学会如何高效地组织数据管道、定义模型架构,并利用 W&B 实现实验跟踪和结果可视化,让每一次实验都清晰可溯,每一次优化都有据可依。
使用 PyTorch Lightning ⚡️ 进行图像分类
我们将使用 PyTorch Lightning 构建一个图像分类管道。我们将遵循这个 风格指南 来提高代码的可读性和可重复性。这里有一个很酷的解释:使用 PyTorch Lightning 进行图像分类。
设置 PyTorch Lightning 和 W&B
对于本教程,我们需要 PyTorch Lightning(这不是很明显吗!)和 Weights and Biases。
!pip install lightning torchvision -q
# 安装 weights and biases
!pip install wandb -qU
你需要这些导入。
import lightning.pytorch as pl
# 你最喜欢的机器学习跟踪工具
from lightning.pytorch.loggers import WandbLoggerimport torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoaderfrom torchmetrics import Accuracyfrom torchvision import transforms
from torchvision.datasets import CIFAR10import wandb
现在你需要登录到你的 wandb 账户。
wandb.login()
🔧 DataModule - 我们应得的数据管道
DataModules 是一种将数据相关的钩子与 LightningModule 解耦的方式,以便你可以开发与数据集无关的模型。
它将数据管道组织成一个可共享和可重用的类。一个 datamodule 封装了 PyTorch 中数据处理的五个步骤:
- 下载 / 分词 / 处理。
- 清理并(可能)保存到磁盘。
- 加载到 Dataset 中。
- 应用转换(旋转、分词等)。
- 包装到 DataLoader 中。
了解更多关于 datamodules 的信息 这里。让我们为 Cifar-10 数据集构建一个 datamodule。
class CIFAR10DataModule(pl.LightningDataModule):def __init__(self, batch_size, data_dir: str = './'):super().__init__()self.data_dir = data_dirself.batch_size = batch_sizeself.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])self.num_classes = 10def prepare_data(self):CIFAR10(self.data_dir, train=True, download=True)CIFAR10(self.data_dir, train=False, download=True)def setup(self, stage=None):# 为 dataloaders 分配训练/验证数据集if stage == 'fit' or stage is None:cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])# 为 dataloader(s) 分配测试数据集if stage == 'test' or stage is None:self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)def train_dataloader(self):return DataLoader(self.cifar_train, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.cifar_val, batch_size=self.batch_size)def test_dataloader(self):return DataLoader(self.cifar_test, batch_size=self.batch_size)
📱 Callbacks
回调是一个独立的程序,可以在项目之间重用。PyTorch Lightning 提供了一些 内置回调,这些回调经常被使用。
了解更多关于 PyTorch Lightning 中的回调 这里。
内置回调
在本教程中,我们将使用 Early Stopping 和 Model Checkpoint 内置回调。它们可以传递给 Trainer。
自定义回调
如果你熟悉自定义 Keras 回调,那么在 PyTorch 管道中实现相同功能的能力只是锦上添花。
由于我们正在进行图像分类,能够可视化模型对一些样本图像的预测可能很有帮助。这种形式的回调可以帮助在早期阶段调试模型。
class ImagePredictionLogger(pl.callbacks.Callback):def __init__(self, val_samples, num_samples=32):super().__init__()self.num_samples = num_samplesself.val_imgs, self.val_labels = val_samplesdef on_validation_epoch_end(self, trainer, pl_module):# 将张量带到 CPUval_imgs = self.val_imgs.to(device=pl_module.device)val_labels = self.val_labels.to(device=pl_module.device)# 获取模型预测logits = pl_module(val_imgs)preds = torch.argmax(logits, -1)# 将图像记录为 wandb Imagetrainer.logger.experiment.log({"examples":[wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")for x, pred, y in zip(val_imgs[:self.num_samples],preds[:self.num_samples],val_labels[:self.num_samples])]})
🎺 LightningModule - 定义系统
LightningModule 定义了一个系统,而不是一个模型。在这里,系统将所有研究代码分组到一个类中,使其自包含。LightningModule 将你的 PyTorch 代码组织成 5 个部分:
- 计算 (
__init__)。 - 训练循环 (
training_step) - 验证循环 (
validation_step) - 测试循环 (
test_step) - 优化器 (
configure_optimizers)
因此,可以构建一个与数据集无关的模型,并且可以轻松共享。让我们为 Cifar-10 分类构建一个系统。
class LitModel(pl.LightningModule):def __init__(self, input_shape, num_classes, learning_rate=2e-4):super().__init__()# 记录超参数self.save_hyperparameters()self.learning_rate = learning_rateself.conv1 = nn.Conv2d(3, 32, 3, 1)self.conv2 = nn.Conv2d(32, 32, 3, 1)self.conv3 = nn.Conv2d(32, 64, 3, 1)self.conv4 = nn.Conv2d(64, 64, 3, 1)self.pool1 = torch.nn.MaxPool2d(2)self.pool2 = torch.nn.MaxPool2d(2)n_sizes = self._get_conv_output(input_shape)self.fc1 = nn.Linear(n_sizes, 512)self.fc2 = nn.Linear(512, 128)self.fc3 = nn.Linear(128, num_classes)self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)# 返回从卷积块进入线性层的输出张量的大小。def _get_conv_output(self, shape):batch_size = 1input = torch.autograd.Variable(torch.rand(batch_size, *shape))output_feat = self._forward_features(input)n_size = output_feat.data.view(batch_size, -1).size(1)return n_size# 返回卷积块的特征张量def _forward_features(self, x):x = F.relu(self.conv1(x))x = self.pool1(F.relu(self.conv2(x)))x = F.relu(self.conv3(x))x = self.pool2(F.relu(self.conv4(x)))return x# 将在推理期间使用def forward(self, x):x = self._forward_features(x)x = x.view(x.size(0), -1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.log_softmax(self.fc3(x), dim=1)return xdef training_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = F.nll_loss(logits, y)# 训练指标preds = torch.argmax(logits, dim=1)acc = self.accuracy(preds, y)self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)return lossdef validation_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = F.nll_loss(logits, y)# 验证指标preds = torch.argmax(logits, dim=1)acc = self.accuracy(preds, y)self.log('val_loss', loss, prog_bar=True)self.log('val_acc', acc, prog_bar=True)return lossdef test_step(self, batch, batch_idx):x, y = batchlogits = self(x)loss = F.nll_loss(logits, y)# 验证指标preds = torch.argmax(logits, dim=1)acc = self.accuracy(preds, y)self.log('test_loss', loss, prog_bar=True)self.log('test_acc', acc, prog_bar=True)return lossdef configure_optimizers(self):optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)return optimizer
🚋 训练和评估
现在我们已经使用 DataModule 组织了数据管道,并使用 LightningModule 组织了模型架构和训练循环,PyTorch Lightning Trainer 为我们自动化了其他所有内容。
Trainer 自动化了以下内容:
- Epoch 和 batch 迭代
- 调用
optimizer.step()、backward、zero_grad() - 调用
.eval(),启用/禁用梯度 - 保存和加载权重
- Weights and Biases 日志记录
- 多 GPU 训练支持
- TPU 支持
- 16 位训练支持
dm = CIFAR10DataModule(batch_size=32)
# 要访问 x_dataloader,我们需要调用 prepare_data 和 setup。
dm.prepare_data()
dm.setup()# 自定义 ImagePredictionLogger 回调所需的样本,用于记录图像预测。
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape
model = LitModel((3, 32, 32), dm.num_classes)# 初始化 wandb logger
wandb_logger = WandbLogger(project='wandb-lightning', job_type='train')# 初始化 Callbacks
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss")
checkpoint_callback = pl.callbacks.ModelCheckpoint()# 初始化一个 trainer
trainer = pl.Trainer(max_epochs=2,logger=wandb_logger,callbacks=[early_stop_callback,ImagePredictionLogger(val_samples),checkpoint_callback],)# 训练模型 ⚡🚅⚡
trainer.fit(model, dm)# 在保留的测试集上评估模型 ⚡⚡
trainer.test(dataloaders=dm.test_dataloader())# 关闭 wandb run
wandb.finish()
最终想法
我来自 TensorFlow/Keras 生态系统,发现 PyTorch 虽然是一个优雅的框架,但有点让人不知所措。这只是我的个人经验。在探索 PyTorch Lightning 时,我意识到几乎所有让我远离 PyTorch 的原因都得到了解决。以下是我兴奋的快速总结:
- 过去:传统的 PyTorch 模型定义通常分散在各个地方。模型在某个
model.py脚本中,训练循环在train.py文件中。需要来回查看才能理解管道。 - 现在:
LightningModule作为一个系统,模型定义与training_step、validation_step等一起定义。现在它是模块化的且可共享的。 - 过去:TensorFlow/Keras 最棒的部分是输入数据管道。他们的数据集目录丰富且不断增长。PyTorch 的数据管道曾经是最大的痛点。在普通的 PyTorch 代码中,数据下载/清理/准备通常分散在许多文件中。
- 现在:DataModule 将数据管道组织成一个可共享和可重用的类。它只是
train_dataloader、val_dataloader(s)、test_dataloader(s) 以及匹配的转换和数据处理/下载步骤的集合。 - 过去:使用 Keras,可以调用
model.fit来训练模型,调用model.predict来运行推理。model.evaluate提供了一个简单而有效的测试数据评估。这在 PyTorch 中不是这样。通常会找到单独的train.py和test.py文件。 - 现在:有了
LightningModule,Trainer自动化了一切。只需调用trainer.fit和trainer.test来训练和评估模型。 - 过去:TensorFlow 喜欢 TPU,PyTorch…嗯!
- 现在:使用 PyTorch Lightning,可以轻松地在多个 GPU 上训练相同的模型,甚至在 TPU 上。哇!
- 过去:我是回调的忠实粉丝,更喜欢编写自定义回调。像 Early Stopping 这样简单的事情曾经是传统 PyTorch 的讨论点。
- 现在:使用 PyTorch Lightning,使用 Early Stopping 和 Model Checkpointing 是小菜一碟。我甚至可以编写自定义回调。
🎨 结论和资源
我希望你觉得这份报告有帮助。我鼓励你玩一下代码,并使用你选择的数据集训练一个图像分类器。
以下是一些学习更多关于 PyTorch Lightning 的资源:
- 逐步演练 - 这是官方教程之一。他们的文档写得非常好,我强烈推荐它作为学习资源。
- 使用 PyTorch Lightning 与 Weights & Biases - 这是一个快速 colab,你可以通过它学习如何使用 W&B 与 PyTorch Lightning。
相关文章:
【WB 深度学习实验管理】使用 PyTorch Lightning 实现高效的图像分类实验跟踪
本文使用到的 Jupyter Notebook 可在GitHub仓库002文件夹找到,别忘了给仓库点个小心心~~~ https://github.com/LFF8888/FF-Studio-Resources 在机器学习项目中,实验跟踪和结果可视化是至关重要的环节。无论是调整超参数、优化模型架构,还是监…...
编译spring 6.2.2
如何编译Spring 6.2.2 下载spring 6.2.2 首先,下载spring 6.2.2,地址:下载 解压到你的目录下。 下载gradle 下载gradle,这是spring项目的依赖管理工具,本文下载的是8.12.1。 gradle idea配置如下:在你的…...
【centOS】搭建公司内网git环境-GitLab 社区版(GitLab CE)
1. 安装必要的依赖 以 CentOS 7 系统为例,安装必要的依赖包: sudo yum install -y curl policycoreutils openssh-server openssh-clients postfix sudo systemctl start postfix sudo systemctl enable postfix2. 添加 GitLab 仓库 curl -sS https:/…...
MHTML文件如何在前端页面展示
MHTML文件如何在前端页面展示 需求背景: 目前在给证券公司做项目,但是在使用新系统的过程中,甲方还希望之前之前系统的历史记录可以看到。 最初制定的计划是项目组里面做数据的把原系统页面爬取下来,转成图片,直接给…...
Spring Boot的常用注解
Spring Boot 常用注解 主要分为以下几类: Spring 核心注解Spring Boot 相关注解Spring MVC 相关注解Spring Data JPA 相关注解Spring 事务管理Spring Security 相关注解Spring AOP 相关注解Spring 其他常用注解 下面是详细分类和表格展示👇:…...
【R语言】plyr包和dplyr包
一、plyr包 plyr扩展包主要是实现数据处理中的“分割-应用-组合”(split-apply-combine)策略。此策略是指将一个问题分割成更容易操作的部分,再对每一部分进行独立的操作,最后将各部分的操作结果组合起来。 plyr扩展包中的主要函…...
《XSS跨站脚本攻击》
一、XSS简介 XSS全称(Cross Site Scripting)跨站脚本攻击,为了避免和CSS层叠样式表名称冲突,所以改为了XSS,是最常见的Web应用程序安全漏洞之一,位于OWASP top 10 2013/2017年度分别为第三名和第七名&…...
Golang:精通sync/atomic 包的Atomic 操作
在本指南中,我们将探索sync/atomic包的细节,展示如何编写更安全、更高效的并发代码。无论你是经验丰富的Gopher还是刚刚起步,你都会发现有价值的见解来提升Go编程技能。让我们一起开启原子运算的力量吧! 理解Go中的原子操作 在快…...
代码随想录_二叉树
二叉树 二叉树的递归遍历 144.二叉树的前序遍历145.二叉树的后序遍历94.二叉树的中序遍历 // 前序遍历递归LC144_二叉树的前序遍历 class Solution {public List<Integer> preorderTraversal(TreeNode root) {List<Integer> result new ArrayList<Integer&g…...
详解Swift中 Sendable AnyActor Actor GlobalActor MainActor Task、await、async
详解Swift中 Sendable AnyActor Actor GlobalActor MainActor 的关联或者关系 及其 各自的作用 和 用法 以及与 Task、await、async: Sendable 协议 作用: Sendable 是一个协议,它用于标记可以安全地跨线程或异步任务传递的数据类型。符合 S…...
【C语言标准库函数】浮点数分解与构造: frexp() 和 ldexp()
目录 一、头文件 二、函数简介 2.1. frexp(double x, int *exp) 2.2. ldexp(double x, int exp) 三、函数实现(概念性) 3.1. frexp 的概念性实现 3.2. ldexp 的概念性实现 四、注意事项 五、示例代码 在C语言标准库中,frexp() 和 ld…...
【Git】tortoisegit使用配置
1. 安装 首先下载小乌龟,下载地址:https://tortoisegit.org/download/, 可以顺便下载语言包! 安装时,默认安装就可以,一路next。也可以安装到指定目录中 目前已完成本地安装,接下来就需要与远程仓库建立连接&…...
Spring基于文心一言API使用的大模型
有时做项目我们可能会遇到要在项目中对接AI大模型 本篇文章是对使用文心一言大模型的使用总结 前置任务 在百度智能云开放平台中注册成为开发者 百度智能云开放平台 进入百度智能云官网进行登录,点击立即体验 点击千帆大模型平台 向下滑动,进入到模型…...
运维_Mac环境单体服务Docker部署实战手册
Docker部署 本小节,讲解如何将前端 后端项目,使用 Docker 容器,部署到 dev 开发环境下的一台 Mac 电脑上。 1 环境准备 需要安装如下环境: Docker:容器MySQL:数据库Redis:缓存Nginx&#x…...
[论文笔记] Deepseek-R1R1-zero技术报告阅读
启发: 1、SFT&RL的训练数据使用CoT输出的格式,先思考再回答,大大提升模型的数学与推理能力。 2、RL训练使用群体相对策略优化(GRPO),奖励模型是规则驱动,准确性奖励和格式化奖励。 1. 总体概述 背景与目标 报告聚焦于利用强化学习(RL)提升大型语言模型(LLMs)…...
Centos Ollama + Deepseek-r1+Chatbox运行环境搭建
Centos Ollama Deepseek-r1Chatbox运行环境搭建 内容介绍下载ollama在Ollama运行DeepSeek-r1模型使用chatbox连接ollama api 内容介绍 你好! 这篇文章简单讲述一下如何在linux环境搭建 Ollama Deepseek-r1。并在本地安装的Chatbox中进行远程调用 下载ollama 登…...
一文读懂:TCP网络拥塞的应对策略与方案
TCP(传输控制协议)是互联网中广泛使用的可靠传输协议,它通过序列号、确认应答、重发控制、连接管理以及窗口控制等机制确保数据的可靠传输。然而,在网络环境中,由于多个主机共享网络资源,网络拥塞成为了一个…...
SpringSecurity:授权服务器与客户端应用(入门案例)
文章目录 一、需求概述二、开发授权服务器1、pom依赖2、yml配置3、启动服务端 三、开发客户端应用1、pom依赖2、yml配置3、SecurityConfig4、接口5、测试 一、需求概述 maven需要3.6.0以上版本 二、开发授权服务器 1、pom依赖 <dependency><groupId>org.springfr…...
Python与java的区别
一开始接触Python的时候,哔哩视频铺天盖地,看了很多人主讲的,要找适合自己口味的,各种培训机构喜欢在各种平台引流打广告,看了很多家,要么就是一个视频几个小时,长篇大论不讲原理只讲应用&#…...
doris:MySQL 兼容性
Doris 高度兼容 MySQL 语法,支持标准 SQL。但是 Doris 与 MySQL 还是有很多不同的地方,下面给出了它们的差异点介绍。 数据类型 数字类型 类型MySQLDorisBoolean- 支持 - 范围:0 代表 false,1 代表 true- 支持 - 关键字&am…...
SQL中 的exists用法
EXISTS 是 SQL 中的一个子查询条件,用于检查子查询是否返回任何行。如果子查询返回至少一行,则 EXISTS 返回 TRUE。 例如,查询有订单的客户列表: SELECT * FROM customers c WHERE EXISTS (SELECT 1 FROM orders o WHERE o.cust…...
案例1.spark和flink分别实现作业配置动态更新案例
目录 目录 一、背景 二、解决 1.方法1:spark broadcast广播变量 a. 思路 b. 案例 ① 需求 ② 数据 ③ 代码 2.方法2:flink RichSourceFunction a. 思路 b. 案例 ① 需求 ② 数据 ③ 代码 ④ 测试验证 测试1 测试2 测试3 一、背景 在实时作业(如 Spark Str…...
大数据学习之SparkSql
95.SPARKSQL_简介 网址: https://spark.apache.org/sql/ Spark SQL 是 Spark 的一个模块,用于处理 结构化的数据 。 SparkSQL 特点 1 易整合 无缝的整合了 SQL 查询和 Spark 编程,随时用 SQL 或 DataFrame API 处理结构化数据。并且支…...
鸿蒙UI(ArkUI-方舟UI框架)- 使用文本
返回主章节 → 鸿蒙UI(ArkUI-方舟UI框架) 文本使用 文本显示 (Text/Span) Text是文本组件,通常用于展示用户视图,如显示文章的文字内容。Span则用于呈现显示行内文本。 创建文本 string字符串 Text("我是一段文本"…...
Spider 数据集上实现nlp2sql训练任务
NLP2SQL(自然语言处理到 SQL 查询的转换)是一个重要的自然语言处理(NLP)任务,其目标是将用户的自然语言问题转换为相应的 SQL 查询。这一任务在许多场景下具有广泛的应用,尤其是在与数据库交互的场景中&…...
数据结构——【树模板】
#思路 1、 结点类: 属性:数据,孩子结点列表 功能1:认孩子: 前提:在父子都是结点的情况下 2. 树类: 属性:根节点,生成初始化的总结点 功能1:获取初始化…...
R 数组:高效数据处理的基础
R 数组:高效数据处理的基础 引言 在数据科学和统计分析领域,R 语言以其强大的数据处理和分析能力而备受推崇。R 数组是 R 语言中用于存储和操作数据的基本数据结构。本文将详细介绍 R 数组的创建、操作和优化,帮助读者掌握 R 数组的使用技巧…...
【DeepSeek】DeepSeek概述 | 本地部署deepseek
目录 1 -> 概述 1.1 -> 技术特点 1.2 -> 模型发布 1.3 -> 应用领域 1.4 -> 优势与影响 2 -> 本地部署 2.1 -> 安装ollama 2.2 -> 部署deepseek-r1模型 1 -> 概述 DeepSeek是由中国的深度求索公司开发的一系列人工智能模型,以其…...
npm link,lerna,pnmp workspace区别
npm link、Lerna 和 pnpm workspace 是三种不同的工具/功能,用于处理 JavaScript 项目的依赖管理和 Monorepo 场景。它们的核心区别如下: 1. npm link 用途 本地调试依赖:将本地开发的包(Package A)临时链接到另一个…...
ASP.NET Core 使用 WebClient 从 URL 下载
本文使用 ASP .NET Core 3.1,但它在.NET 5、 .NET 6和.NET 8上也同样适用。如果使用较旧的.NET Framework,请参阅本文,不过,变化不大。 如果想要从 URL 下载任何数据类型,请参阅本文:HttpClient 使用WebC…...
