当前位置: 首页 > news >正文

PyTorch Lightning Trainer介绍

PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例:

Trainer 的核心功能

  1. 自动化训练循环
    自动处理 training_stepvalidation_steptest_step 的调用,无需手动编写 for epoch in epochs 循环。

  2. 硬件加速支持
    支持 CPU/GPU/TPU、多卡训练(DDP、DeepSpeed)、混合精度训练等。

  3. 训练控制
    控制训练轮数 (max_epochs)、批次大小 (batch_size)、梯度裁剪 (gradient_clip_val) 等。

  4. 日志与监控
    集成 TensorBoard、W&B、MLFlow 等日志工具,监控损失、准确率等指标。

  5. 回调机制
    通过回调函数(如 EarlyStoppingModelCheckpoint)实现早停、模型保存等扩展功能。

Trainer 的常用参数

from pytorch_lightning import Trainertrainer = Trainer(# 基础配置max_epochs=10,            # 最大训练轮数accelerator="auto",       # 自动选择设备 (CPU/GPU/TPU)devices="auto",           # 使用所有可用设备(如多 GPU)precision="16-mixed",     # 混合精度训练(FP16)# 日志与调试logger=True,              # 默认使用 TensorBoardlog_every_n_steps=10,     # 每 10 个批次记录一次日志fast_dev_run=False,       # 快速运行一个批次(调试模式)# 回调函数callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2)],# 分布式训练strategy="ddp",           # 分布式数据并行策略(多 GPU)num_nodes=1,              # 节点数量(多机器训练)
)

使用示例代码

步骤 1:定义 LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as plclass LitModel(pl.LightningModule):def __init__(self):super().__init__()self.layer1 = nn.Linear(28*28, 128)self.layer2 = nn.Linear(128, 10)def forward(self, x):x = x.view(x.size(0), -1)  # 展平输入x = F.relu(self.layer1(x))x = self.layer2(x)return xdef training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss)  # 自动记录日志return lossdef validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("val_loss", loss)     # 自动记录验证损失def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.001)
步骤 2:定义 DataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensorclass MNISTDataModule(pl.LightningDataModule):def __init__(self, batch_size=32):super().__init__()self.batch_size = batch_sizedef prepare_data(self):MNIST(root="data", download=True)def setup(self, stage=None):full_dataset = MNIST(root="data", train=True, transform=ToTensor())self.train_data, self.val_data = random_split(full_dataset, [55000, 5000])def train_dataloader(self):return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.val_data, batch_size=self.batch_size)dm = MNISTDataModule(batch_size=32)

步骤 3:启动训练

model = LitModel()
trainer = Trainer(max_epochs=10,accelerator="auto",devices="auto",logger=True,callbacks=[pl.callbacks.ModelCheckpoint(monitor="val_loss")]
)# 开始训练与验证
trainer.fit(model, datamodule=dm)# 测试(可选)
trainer.test(model, datamodule=dm)

关键功能演示

1. 多 GPU 训练
# 使用 4 个 GPU 训练
trainer = Trainer(devices=4, strategy="ddp")
2. 混合精度训练

# 使用 FP16 混合精度
trainer = Trainer(precision="16-mixed")
3. 早停与模型保存
callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(dirpath="checkpoints/",filename="best-model-{epoch:02d}-{val_loss:.2f}",save_top_k=2,monitor="val_loss")
]
trainer = Trainer(callbacks=callbacks)
4. 调试模式
# 快速验证代码正确性(仅运行一个批次)
trainer = Trainer(fast_dev_run=True)

常见问题

如何恢复训练?
使用 resume_from_checkpoint 参数:

trainer = Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")

如何限制训练时间?

trainer = Trainer(max_time="00:02:00")  # 最多训练 2 分钟

如何自定义学习率调度器?
在 自定义的 LightningDataModule继承类的 configure_optimizers 方法中返回优化器和调度器:

def configure_optimizers(self):optimizer = Adam(self.parameters())scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)return [optimizer], [scheduler]

总结

通过 Trainer,PyTorch Lightning 将训练流程的复杂性封装在几行配置中,开发者只需关注模型逻辑和数据加载。其灵活的参数和回调机制能够覆盖从实验到生产的全流程需求。

参考:

https://lightning.ai/docs/pytorch/stable/common/trainer.html

相关文章:

PyTorch Lightning Trainer介绍

PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例: Trainer 的核心功能 自动化训练循环 自…...

mysql监控--慢查询

一、监控配置 二、慢查询文件 在 MySQL 中,慢查询日志记录了执行时间较长的查询,通常,慢查询日志可能会生成以下几种文件: 1. 慢查询日志文件 这是最主要的文件,记录了执行时间超过设置阈值的 SQL 查询。可以通过 …...

Conda 包管理:高效安装、更新和删除软件包

Conda 包管理:高效安装、更新和删除软件包 1. 引言 在使用 Anaconda 进行 Python 开发时,包管理是日常操作的核心内容。Conda 提供了一整套高效的工具来管理 Python 环境中的软件包,避免了版本冲突,并确保了环境的一致性。 本篇…...

AcWing 798. 差分矩阵

题目来源: 找不到页面 - AcWing 题目内容: 输入一个 n 行 m 列的整数矩阵,再输入 q 个操作,每个操作包含五个整数 x1,y1,x2,y2,c,其中 (x1,y1) 和 (x2,y2)表示一个子矩阵的左上角坐标和右下角坐标。 每个操作都要将…...

通用定时器学习记录

简介 通用定时器:TIM2/TIM3/TIM4/TIM5 主要特性:16位递增、递减、中心对齐计数器(计数值0~65535) 16位预分频器(分频系数1~65536) 可用于触发DAC、ADC 在更新事件、触发事件、输入捕获、输出比较时&am…...

科技之光闪耀江城:2025武汉国际半导体产业与电子技术博览会5月15日盛大开幕

在科技浪潮汹涌澎湃的当下,半导体产业作为现代信息技术的中流砥柱,正以令人惊叹的速度重塑着世界的面貌。2025年5月15-17日,一场聚焦半导体与电子技术前沿的行业盛会 ——2025 武汉国际半导体产业与电子技术博览会,将在武汉・中国…...

vue开发06:前端通过webpack配置代理处理跨域问题

1.定义 在浏览器尝试请求不同源(域名、协议、端口号不同)的资源时,浏览器的同源策略会阻止这种跨域请求。(比如前端端口15500,后端端口5050,前端界面不可以直接调用5050端口) 2.解决方案 使用前…...

⚡️《静电刺客的猎杀手册:芯片世界里的“千伏惊魂“》⚡️

前言: 在这个电子产品无孔不入的时代,我们每天都在与一群隐形刺客打交道——它们身怀数千伏特的高压绝技,能在0.1秒内让价值百万的芯片灰飞烟灭。这就是静电放电(ESD),电子工业界最令人闻风丧胆的"沉默…...

【云安全】云原生-K8S(三) 安装 Dashboard 面板

在Kubernetes中安装Dashboard需要几个步骤,包括部署Dashboard组件、配置访问权限以及暴露Dashboard服务等。以下是详细的步骤: 1. 部署 K8S Dashboard 可以通过以下命令用Kubernetes官方的YAML文件来快速部署,由于是国外网站,需…...

Spring Boot 常用依赖详解:如何选择和使用常用依赖

在Spring Boot项目中,依赖(Dependencies)是项目的核心组成部分。每个依赖都提供了一些特定的功能或工具,帮助我们快速开发应用程序。本文将详细介绍Spring Boot中常用的依赖及其作用,并指导你如何根据项目需求选择合适…...

C++ 设计模式-组合模式

组合模式(Composite Pattern)允许将对象组合成树形结构,使得客户端以统一的方式处理单个对象和组合对象。以下是一个经典的 C 实现示例,包含透明式设计(基类定义统一接口)和内存管理: #include…...

【Spring Boot】Spring 魔法世界:Bean 作用域与生命周期的奇妙之旅

前言 ???本期讲解关于spring原理Bean的相关知识介绍~~~ ??感兴趣的小伙伴看一看小编主页:-CSDN博客 ?? 你的点赞就是小编不断更新的最大动力 ??那么废话不多说直接开整吧~~ 目录 ???1.Bean的作用域 ??1.1概念 ??1.2Bean的作用域 ??1.3代码演示…...

移远通信边缘计算模组成功运行DeepSeek模型,以领先的工程能力加速端侧AI落地

近日,国产大模型DeepSeek凭借其“开源开放、高效推理、端侧友好”的核心优势,迅速风靡全球。移远通信基于边缘计算模组SG885G,已成功实现DeepSeek模型的稳定运行,并完成了针对性微调。 目前,该模型正在多款智能终端上进…...

Cables Finance 构建集成LST与外汇RWA永续合约的综合性DEX

虽然 DeFi 领域整体发展迅速,但仍旧缺乏交易体验。现阶段市场已拓展至 RWAs 、永续期货和外汇领域,但跨资产交易的实际操作仍充满阻力。交易者面临流动性碎片化、抵押品被锁定在质押合约中缺乏流动性,以及整个系统仍围绕美元稳定币运转等问题…...

AI大模型(DeepSeek)科研应用、论文写作、数据分析与AI绘图学习

【介绍】 在人工智能浪潮中,2024年12月中国公司研发的 DeepSeek 横空出世以惊艳全球的姿态,成为 AI领域不可忽视的力量!DeepSeek 完全开源,可本地部署,无使用限制,保护用户隐私。其次,其性能强大&#xff…...

【算法工程】解决linux下Aspose.slides提示No usable version of libssl found以及强化推理模型的短板

1. 背景 构建ubuntu镜像,然后使用Aspose.slides解析PPTX文档,发现一直提示“No usable version of libssl found”。 2. 尝试 使用deepseek R1、kimi1.5、chatgpt o3,并且都带上联网能力,居然还是没有一个能够真正解决&#xf…...

什么是HTTP和HTTPS?它们之间有什么区别?

什么是HTTP和HTTPS?它们之间有什么区别? HTTP(超文本传输协议)简介 HTTP就像是你通过明信片给朋友发送信息。你在明信片上写下内容,然后寄出去。任何人都可以在途中看到明信片上的内容,因为它是公开的。 …...

【一文读懂】TCP与UDP协议

TCP协议 概述 TCP(Transmission Control Protocol),即传输控制协议,是一种面向连接的、可靠的、基于字节流的传输层通信协议,常用于保证数据可靠、按顺序、无差错地传输。TCP 是互联网协议族(TCP/IP&…...

数据结构 树的存储和遍历

一、树的定义 树的定义 树型结构是⼀类重要的⾮线性数据结构。 • 有⼀个特殊的结点,称为根结点,根结点没有前驱结点。 • 除根结点外,其余结点被分成M个互不相交的集合T1 、T2 、...、Tm T,其中每⼀个集合⼜是⼀棵树&#xff0c…...

Jenkins项目CICD流程

Jenkins项目流程:1.配置git环境 git config --...2.把前后端的目录初始化位本地工作目录 #git init3.提交到本地git #git add ./ git commit -m "" git tag v14.然后提交到远程git(通过,用户,群组,项目,管理项目)git remote add origin http://...git push -…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

微信小程序之bind和catch

这两个呢,都是绑定事件用的,具体使用有些小区别。 官方文档: 事件冒泡处理不同 bind:绑定的事件会向上冒泡,即触发当前组件的事件后,还会继续触发父组件的相同事件。例如,有一个子视图绑定了b…...

深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法

深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

【python异步多线程】异步多线程爬虫代码示例

claude生成的python多线程、异步代码示例,模拟20个网页的爬取,每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程:允许程序同时执行多个任务,提高IO密集型任务(如网络请求)的效率…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中,UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

Linux nano命令的基本使用

参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时,显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...

Web中间件--tomcat学习

Web中间件–tomcat Java虚拟机详解 什么是JAVA虚拟机 Java虚拟机是一个抽象的计算机,它可以执行Java字节码。Java虚拟机是Java平台的一部分,Java平台由Java语言、Java API和Java虚拟机组成。Java虚拟机的主要作用是将Java字节码转换为机器代码&#x…...

CVPR2025重磅突破:AnomalyAny框架实现单样本生成逼真异常数据,破解视觉检测瓶颈!

本文介绍了一种名为AnomalyAny的创新框架,该方法利用Stable Diffusion的强大生成能力,仅需单个正常样本和文本描述,即可生成逼真且多样化的异常样本,有效解决了视觉异常检测中异常样本稀缺的难题,为工业质检、医疗影像…...

comfyui 工作流中 图生视频 如何增加视频的长度到5秒

comfyUI 工作流怎么可以生成更长的视频。除了硬件显存要求之外还有别的方法吗? 在ComfyUI中实现图生视频并延长到5秒,需要结合多个扩展和技巧。以下是完整解决方案: 核心工作流配置(24fps下5秒120帧) #mermaid-svg-yP…...