PyTorch Lightning Trainer介绍
PyTorch Lightning 的 Trainer
是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例:
Trainer
的核心功能
-
自动化训练循环
自动处理training_step
、validation_step
、test_step
的调用,无需手动编写for epoch in epochs
循环。 -
硬件加速支持
支持 CPU/GPU/TPU、多卡训练(DDP、DeepSpeed)、混合精度训练等。 -
训练控制
控制训练轮数 (max_epochs
)、批次大小 (batch_size
)、梯度裁剪 (gradient_clip_val
) 等。 -
日志与监控
集成 TensorBoard、W&B、MLFlow 等日志工具,监控损失、准确率等指标。 -
回调机制
通过回调函数(如EarlyStopping
,ModelCheckpoint
)实现早停、模型保存等扩展功能。
Trainer
的常用参数
from pytorch_lightning import Trainertrainer = Trainer(# 基础配置max_epochs=10, # 最大训练轮数accelerator="auto", # 自动选择设备 (CPU/GPU/TPU)devices="auto", # 使用所有可用设备(如多 GPU)precision="16-mixed", # 混合精度训练(FP16)# 日志与调试logger=True, # 默认使用 TensorBoardlog_every_n_steps=10, # 每 10 个批次记录一次日志fast_dev_run=False, # 快速运行一个批次(调试模式)# 回调函数callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2)],# 分布式训练strategy="ddp", # 分布式数据并行策略(多 GPU)num_nodes=1, # 节点数量(多机器训练)
)
使用示例代码
步骤 1:定义 LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as plclass LitModel(pl.LightningModule):def __init__(self):super().__init__()self.layer1 = nn.Linear(28*28, 128)self.layer2 = nn.Linear(128, 10)def forward(self, x):x = x.view(x.size(0), -1) # 展平输入x = F.relu(self.layer1(x))x = self.layer2(x)return xdef training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss) # 自动记录日志return lossdef validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("val_loss", loss) # 自动记录验证损失def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.001)
步骤 2:定义 DataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensorclass MNISTDataModule(pl.LightningDataModule):def __init__(self, batch_size=32):super().__init__()self.batch_size = batch_sizedef prepare_data(self):MNIST(root="data", download=True)def setup(self, stage=None):full_dataset = MNIST(root="data", train=True, transform=ToTensor())self.train_data, self.val_data = random_split(full_dataset, [55000, 5000])def train_dataloader(self):return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.val_data, batch_size=self.batch_size)dm = MNISTDataModule(batch_size=32)
步骤 3:启动训练
model = LitModel()
trainer = Trainer(max_epochs=10,accelerator="auto",devices="auto",logger=True,callbacks=[pl.callbacks.ModelCheckpoint(monitor="val_loss")]
)# 开始训练与验证
trainer.fit(model, datamodule=dm)# 测试(可选)
trainer.test(model, datamodule=dm)
关键功能演示
1. 多 GPU 训练
# 使用 4 个 GPU 训练
trainer = Trainer(devices=4, strategy="ddp")
2. 混合精度训练
# 使用 FP16 混合精度
trainer = Trainer(precision="16-mixed")
3. 早停与模型保存
callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(dirpath="checkpoints/",filename="best-model-{epoch:02d}-{val_loss:.2f}",save_top_k=2,monitor="val_loss")
]
trainer = Trainer(callbacks=callbacks)
4. 调试模式
# 快速验证代码正确性(仅运行一个批次)
trainer = Trainer(fast_dev_run=True)
常见问题
如何恢复训练?
使用 resume_from_checkpoint
参数:
trainer = Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")
如何限制训练时间?
trainer = Trainer(max_time="00:02:00") # 最多训练 2 分钟
如何自定义学习率调度器?
在 自定义的 LightningDataModule继承类的 configure_optimizers
方法中返回优化器和调度器:
def configure_optimizers(self):optimizer = Adam(self.parameters())scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)return [optimizer], [scheduler]
总结
通过 Trainer
,PyTorch Lightning 将训练流程的复杂性封装在几行配置中,开发者只需关注模型逻辑和数据加载。其灵活的参数和回调机制能够覆盖从实验到生产的全流程需求。
参考:
https://lightning.ai/docs/pytorch/stable/common/trainer.html
相关文章:
PyTorch Lightning Trainer介绍
PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例: Trainer 的核心功能 自动化训练循环 自…...
mysql监控--慢查询
一、监控配置 二、慢查询文件 在 MySQL 中,慢查询日志记录了执行时间较长的查询,通常,慢查询日志可能会生成以下几种文件: 1. 慢查询日志文件 这是最主要的文件,记录了执行时间超过设置阈值的 SQL 查询。可以通过 …...
Conda 包管理:高效安装、更新和删除软件包
Conda 包管理:高效安装、更新和删除软件包 1. 引言 在使用 Anaconda 进行 Python 开发时,包管理是日常操作的核心内容。Conda 提供了一整套高效的工具来管理 Python 环境中的软件包,避免了版本冲突,并确保了环境的一致性。 本篇…...

AcWing 798. 差分矩阵
题目来源: 找不到页面 - AcWing 题目内容: 输入一个 n 行 m 列的整数矩阵,再输入 q 个操作,每个操作包含五个整数 x1,y1,x2,y2,c,其中 (x1,y1) 和 (x2,y2)表示一个子矩阵的左上角坐标和右下角坐标。 每个操作都要将…...

通用定时器学习记录
简介 通用定时器:TIM2/TIM3/TIM4/TIM5 主要特性:16位递增、递减、中心对齐计数器(计数值0~65535) 16位预分频器(分频系数1~65536) 可用于触发DAC、ADC 在更新事件、触发事件、输入捕获、输出比较时&am…...

科技之光闪耀江城:2025武汉国际半导体产业与电子技术博览会5月15日盛大开幕
在科技浪潮汹涌澎湃的当下,半导体产业作为现代信息技术的中流砥柱,正以令人惊叹的速度重塑着世界的面貌。2025年5月15-17日,一场聚焦半导体与电子技术前沿的行业盛会 ——2025 武汉国际半导体产业与电子技术博览会,将在武汉・中国…...
vue开发06:前端通过webpack配置代理处理跨域问题
1.定义 在浏览器尝试请求不同源(域名、协议、端口号不同)的资源时,浏览器的同源策略会阻止这种跨域请求。(比如前端端口15500,后端端口5050,前端界面不可以直接调用5050端口) 2.解决方案 使用前…...
⚡️《静电刺客的猎杀手册:芯片世界里的“千伏惊魂“》⚡️
前言: 在这个电子产品无孔不入的时代,我们每天都在与一群隐形刺客打交道——它们身怀数千伏特的高压绝技,能在0.1秒内让价值百万的芯片灰飞烟灭。这就是静电放电(ESD),电子工业界最令人闻风丧胆的"沉默…...

【云安全】云原生-K8S(三) 安装 Dashboard 面板
在Kubernetes中安装Dashboard需要几个步骤,包括部署Dashboard组件、配置访问权限以及暴露Dashboard服务等。以下是详细的步骤: 1. 部署 K8S Dashboard 可以通过以下命令用Kubernetes官方的YAML文件来快速部署,由于是国外网站,需…...
Spring Boot 常用依赖详解:如何选择和使用常用依赖
在Spring Boot项目中,依赖(Dependencies)是项目的核心组成部分。每个依赖都提供了一些特定的功能或工具,帮助我们快速开发应用程序。本文将详细介绍Spring Boot中常用的依赖及其作用,并指导你如何根据项目需求选择合适…...
C++ 设计模式-组合模式
组合模式(Composite Pattern)允许将对象组合成树形结构,使得客户端以统一的方式处理单个对象和组合对象。以下是一个经典的 C 实现示例,包含透明式设计(基类定义统一接口)和内存管理: #include…...

【Spring Boot】Spring 魔法世界:Bean 作用域与生命周期的奇妙之旅
前言 ???本期讲解关于spring原理Bean的相关知识介绍~~~ ??感兴趣的小伙伴看一看小编主页:-CSDN博客 ?? 你的点赞就是小编不断更新的最大动力 ??那么废话不多说直接开整吧~~ 目录 ???1.Bean的作用域 ??1.1概念 ??1.2Bean的作用域 ??1.3代码演示…...

移远通信边缘计算模组成功运行DeepSeek模型,以领先的工程能力加速端侧AI落地
近日,国产大模型DeepSeek凭借其“开源开放、高效推理、端侧友好”的核心优势,迅速风靡全球。移远通信基于边缘计算模组SG885G,已成功实现DeepSeek模型的稳定运行,并完成了针对性微调。 目前,该模型正在多款智能终端上进…...

Cables Finance 构建集成LST与外汇RWA永续合约的综合性DEX
虽然 DeFi 领域整体发展迅速,但仍旧缺乏交易体验。现阶段市场已拓展至 RWAs 、永续期货和外汇领域,但跨资产交易的实际操作仍充满阻力。交易者面临流动性碎片化、抵押品被锁定在质押合约中缺乏流动性,以及整个系统仍围绕美元稳定币运转等问题…...
AI大模型(DeepSeek)科研应用、论文写作、数据分析与AI绘图学习
【介绍】 在人工智能浪潮中,2024年12月中国公司研发的 DeepSeek 横空出世以惊艳全球的姿态,成为 AI领域不可忽视的力量!DeepSeek 完全开源,可本地部署,无使用限制,保护用户隐私。其次,其性能强大ÿ…...
【算法工程】解决linux下Aspose.slides提示No usable version of libssl found以及强化推理模型的短板
1. 背景 构建ubuntu镜像,然后使用Aspose.slides解析PPTX文档,发现一直提示“No usable version of libssl found”。 2. 尝试 使用deepseek R1、kimi1.5、chatgpt o3,并且都带上联网能力,居然还是没有一个能够真正解决…...
什么是HTTP和HTTPS?它们之间有什么区别?
什么是HTTP和HTTPS?它们之间有什么区别? HTTP(超文本传输协议)简介 HTTP就像是你通过明信片给朋友发送信息。你在明信片上写下内容,然后寄出去。任何人都可以在途中看到明信片上的内容,因为它是公开的。 …...
【一文读懂】TCP与UDP协议
TCP协议 概述 TCP(Transmission Control Protocol),即传输控制协议,是一种面向连接的、可靠的、基于字节流的传输层通信协议,常用于保证数据可靠、按顺序、无差错地传输。TCP 是互联网协议族(TCP/IP&…...
数据结构 树的存储和遍历
一、树的定义 树的定义 树型结构是⼀类重要的⾮线性数据结构。 • 有⼀个特殊的结点,称为根结点,根结点没有前驱结点。 • 除根结点外,其余结点被分成M个互不相交的集合T1 、T2 、...、Tm T,其中每⼀个集合⼜是⼀棵树,…...

Jenkins项目CICD流程
Jenkins项目流程:1.配置git环境 git config --...2.把前后端的目录初始化位本地工作目录 #git init3.提交到本地git #git add ./ git commit -m "" git tag v14.然后提交到远程git(通过,用户,群组,项目,管理项目)git remote add origin http://...git push -…...
模型参数、模型存储精度、参数与显存
模型参数量衡量单位 M:百万(Million) B:十亿(Billion) 1 B 1000 M 1B 1000M 1B1000M 参数存储精度 模型参数是固定的,但是一个参数所表示多少字节不一定,需要看这个参数以什么…...
day52 ResNet18 CBAM
在深度学习的旅程中,我们不断探索如何提升模型的性能。今天,我将分享我在 ResNet18 模型中插入 CBAM(Convolutional Block Attention Module)模块,并采用分阶段微调策略的实践过程。通过这个过程,我不仅提升…...

智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql
智慧工地管理云平台系统,智慧工地全套源码,java版智慧工地源码,支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求,提供“平台网络终端”的整体解决方案,提供劳务管理、视频管理、智能监测、绿色施工、安全管…...
测试markdown--肇兴
day1: 1、去程:7:04 --11:32高铁 高铁右转上售票大厅2楼,穿过候车厅下一楼,上大巴车 ¥10/人 **2、到达:**12点多到达寨子,买门票,美团/抖音:¥78人 3、中饭&a…...
Frozen-Flask :将 Flask 应用“冻结”为静态文件
Frozen-Flask 是一个用于将 Flask 应用“冻结”为静态文件的 Python 扩展。它的核心用途是:将一个 Flask Web 应用生成成纯静态 HTML 文件,从而可以部署到静态网站托管服务上,如 GitHub Pages、Netlify 或任何支持静态文件的网站服务器。 &am…...

华为OD机试-食堂供餐-二分法
import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...
【HTML-16】深入理解HTML中的块元素与行内元素
HTML元素根据其显示特性可以分为两大类:块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

Map相关知识
数据结构 二叉树 二叉树,顾名思义,每个节点最多有两个“叉”,也就是两个子节点,分别是左子 节点和右子节点。不过,二叉树并不要求每个节点都有两个子节点,有的节点只 有左子节点,有的节点只有…...

推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材)
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材) 这个项目能干嘛? 使用 gemini 2.0 的 api 和 google 其他的 api 来做衍生处理 简化和优化了文生图和图生图的行为(我的最主要) 并且有一些目标检测和切割(我用不到) 视频和 imagefx 因为没 a…...

(一)单例模式
一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...