当前位置: 首页 > news >正文

PyTorch Lightning Trainer介绍

PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例:

Trainer 的核心功能

  1. 自动化训练循环
    自动处理 training_stepvalidation_steptest_step 的调用,无需手动编写 for epoch in epochs 循环。

  2. 硬件加速支持
    支持 CPU/GPU/TPU、多卡训练(DDP、DeepSpeed)、混合精度训练等。

  3. 训练控制
    控制训练轮数 (max_epochs)、批次大小 (batch_size)、梯度裁剪 (gradient_clip_val) 等。

  4. 日志与监控
    集成 TensorBoard、W&B、MLFlow 等日志工具,监控损失、准确率等指标。

  5. 回调机制
    通过回调函数(如 EarlyStoppingModelCheckpoint)实现早停、模型保存等扩展功能。

Trainer 的常用参数

from pytorch_lightning import Trainertrainer = Trainer(# 基础配置max_epochs=10,            # 最大训练轮数accelerator="auto",       # 自动选择设备 (CPU/GPU/TPU)devices="auto",           # 使用所有可用设备(如多 GPU)precision="16-mixed",     # 混合精度训练(FP16)# 日志与调试logger=True,              # 默认使用 TensorBoardlog_every_n_steps=10,     # 每 10 个批次记录一次日志fast_dev_run=False,       # 快速运行一个批次(调试模式)# 回调函数callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2)],# 分布式训练strategy="ddp",           # 分布式数据并行策略(多 GPU)num_nodes=1,              # 节点数量(多机器训练)
)

使用示例代码

步骤 1:定义 LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as plclass LitModel(pl.LightningModule):def __init__(self):super().__init__()self.layer1 = nn.Linear(28*28, 128)self.layer2 = nn.Linear(128, 10)def forward(self, x):x = x.view(x.size(0), -1)  # 展平输入x = F.relu(self.layer1(x))x = self.layer2(x)return xdef training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss)  # 自动记录日志return lossdef validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("val_loss", loss)     # 自动记录验证损失def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.001)
步骤 2:定义 DataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensorclass MNISTDataModule(pl.LightningDataModule):def __init__(self, batch_size=32):super().__init__()self.batch_size = batch_sizedef prepare_data(self):MNIST(root="data", download=True)def setup(self, stage=None):full_dataset = MNIST(root="data", train=True, transform=ToTensor())self.train_data, self.val_data = random_split(full_dataset, [55000, 5000])def train_dataloader(self):return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.val_data, batch_size=self.batch_size)dm = MNISTDataModule(batch_size=32)

步骤 3:启动训练

model = LitModel()
trainer = Trainer(max_epochs=10,accelerator="auto",devices="auto",logger=True,callbacks=[pl.callbacks.ModelCheckpoint(monitor="val_loss")]
)# 开始训练与验证
trainer.fit(model, datamodule=dm)# 测试(可选)
trainer.test(model, datamodule=dm)

关键功能演示

1. 多 GPU 训练
# 使用 4 个 GPU 训练
trainer = Trainer(devices=4, strategy="ddp")
2. 混合精度训练

# 使用 FP16 混合精度
trainer = Trainer(precision="16-mixed")
3. 早停与模型保存
callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(dirpath="checkpoints/",filename="best-model-{epoch:02d}-{val_loss:.2f}",save_top_k=2,monitor="val_loss")
]
trainer = Trainer(callbacks=callbacks)
4. 调试模式
# 快速验证代码正确性(仅运行一个批次)
trainer = Trainer(fast_dev_run=True)

常见问题

如何恢复训练?
使用 resume_from_checkpoint 参数:

trainer = Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")

如何限制训练时间?

trainer = Trainer(max_time="00:02:00")  # 最多训练 2 分钟

如何自定义学习率调度器?
在 自定义的 LightningDataModule继承类的 configure_optimizers 方法中返回优化器和调度器:

def configure_optimizers(self):optimizer = Adam(self.parameters())scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)return [optimizer], [scheduler]

总结

通过 Trainer,PyTorch Lightning 将训练流程的复杂性封装在几行配置中,开发者只需关注模型逻辑和数据加载。其灵活的参数和回调机制能够覆盖从实验到生产的全流程需求。

参考:

https://lightning.ai/docs/pytorch/stable/common/trainer.html

相关文章:

PyTorch Lightning Trainer介绍

PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例: Trainer 的核心功能 自动化训练循环 自…...

mysql监控--慢查询

一、监控配置 二、慢查询文件 在 MySQL 中,慢查询日志记录了执行时间较长的查询,通常,慢查询日志可能会生成以下几种文件: 1. 慢查询日志文件 这是最主要的文件,记录了执行时间超过设置阈值的 SQL 查询。可以通过 …...

Conda 包管理:高效安装、更新和删除软件包

Conda 包管理:高效安装、更新和删除软件包 1. 引言 在使用 Anaconda 进行 Python 开发时,包管理是日常操作的核心内容。Conda 提供了一整套高效的工具来管理 Python 环境中的软件包,避免了版本冲突,并确保了环境的一致性。 本篇…...

AcWing 798. 差分矩阵

题目来源: 找不到页面 - AcWing 题目内容: 输入一个 n 行 m 列的整数矩阵,再输入 q 个操作,每个操作包含五个整数 x1,y1,x2,y2,c,其中 (x1,y1) 和 (x2,y2)表示一个子矩阵的左上角坐标和右下角坐标。 每个操作都要将…...

通用定时器学习记录

简介 通用定时器:TIM2/TIM3/TIM4/TIM5 主要特性:16位递增、递减、中心对齐计数器(计数值0~65535) 16位预分频器(分频系数1~65536) 可用于触发DAC、ADC 在更新事件、触发事件、输入捕获、输出比较时&am…...

科技之光闪耀江城:2025武汉国际半导体产业与电子技术博览会5月15日盛大开幕

在科技浪潮汹涌澎湃的当下,半导体产业作为现代信息技术的中流砥柱,正以令人惊叹的速度重塑着世界的面貌。2025年5月15-17日,一场聚焦半导体与电子技术前沿的行业盛会 ——2025 武汉国际半导体产业与电子技术博览会,将在武汉・中国…...

vue开发06:前端通过webpack配置代理处理跨域问题

1.定义 在浏览器尝试请求不同源(域名、协议、端口号不同)的资源时,浏览器的同源策略会阻止这种跨域请求。(比如前端端口15500,后端端口5050,前端界面不可以直接调用5050端口) 2.解决方案 使用前…...

⚡️《静电刺客的猎杀手册:芯片世界里的“千伏惊魂“》⚡️

前言: 在这个电子产品无孔不入的时代,我们每天都在与一群隐形刺客打交道——它们身怀数千伏特的高压绝技,能在0.1秒内让价值百万的芯片灰飞烟灭。这就是静电放电(ESD),电子工业界最令人闻风丧胆的"沉默…...

【云安全】云原生-K8S(三) 安装 Dashboard 面板

在Kubernetes中安装Dashboard需要几个步骤,包括部署Dashboard组件、配置访问权限以及暴露Dashboard服务等。以下是详细的步骤: 1. 部署 K8S Dashboard 可以通过以下命令用Kubernetes官方的YAML文件来快速部署,由于是国外网站,需…...

Spring Boot 常用依赖详解:如何选择和使用常用依赖

在Spring Boot项目中,依赖(Dependencies)是项目的核心组成部分。每个依赖都提供了一些特定的功能或工具,帮助我们快速开发应用程序。本文将详细介绍Spring Boot中常用的依赖及其作用,并指导你如何根据项目需求选择合适…...

C++ 设计模式-组合模式

组合模式(Composite Pattern)允许将对象组合成树形结构,使得客户端以统一的方式处理单个对象和组合对象。以下是一个经典的 C 实现示例,包含透明式设计(基类定义统一接口)和内存管理: #include…...

【Spring Boot】Spring 魔法世界:Bean 作用域与生命周期的奇妙之旅

前言 ???本期讲解关于spring原理Bean的相关知识介绍~~~ ??感兴趣的小伙伴看一看小编主页:-CSDN博客 ?? 你的点赞就是小编不断更新的最大动力 ??那么废话不多说直接开整吧~~ 目录 ???1.Bean的作用域 ??1.1概念 ??1.2Bean的作用域 ??1.3代码演示…...

移远通信边缘计算模组成功运行DeepSeek模型,以领先的工程能力加速端侧AI落地

近日,国产大模型DeepSeek凭借其“开源开放、高效推理、端侧友好”的核心优势,迅速风靡全球。移远通信基于边缘计算模组SG885G,已成功实现DeepSeek模型的稳定运行,并完成了针对性微调。 目前,该模型正在多款智能终端上进…...

Cables Finance 构建集成LST与外汇RWA永续合约的综合性DEX

虽然 DeFi 领域整体发展迅速,但仍旧缺乏交易体验。现阶段市场已拓展至 RWAs 、永续期货和外汇领域,但跨资产交易的实际操作仍充满阻力。交易者面临流动性碎片化、抵押品被锁定在质押合约中缺乏流动性,以及整个系统仍围绕美元稳定币运转等问题…...

AI大模型(DeepSeek)科研应用、论文写作、数据分析与AI绘图学习

【介绍】 在人工智能浪潮中,2024年12月中国公司研发的 DeepSeek 横空出世以惊艳全球的姿态,成为 AI领域不可忽视的力量!DeepSeek 完全开源,可本地部署,无使用限制,保护用户隐私。其次,其性能强大&#xff…...

【算法工程】解决linux下Aspose.slides提示No usable version of libssl found以及强化推理模型的短板

1. 背景 构建ubuntu镜像,然后使用Aspose.slides解析PPTX文档,发现一直提示“No usable version of libssl found”。 2. 尝试 使用deepseek R1、kimi1.5、chatgpt o3,并且都带上联网能力,居然还是没有一个能够真正解决&#xf…...

什么是HTTP和HTTPS?它们之间有什么区别?

什么是HTTP和HTTPS?它们之间有什么区别? HTTP(超文本传输协议)简介 HTTP就像是你通过明信片给朋友发送信息。你在明信片上写下内容,然后寄出去。任何人都可以在途中看到明信片上的内容,因为它是公开的。 …...

【一文读懂】TCP与UDP协议

TCP协议 概述 TCP(Transmission Control Protocol),即传输控制协议,是一种面向连接的、可靠的、基于字节流的传输层通信协议,常用于保证数据可靠、按顺序、无差错地传输。TCP 是互联网协议族(TCP/IP&…...

数据结构 树的存储和遍历

一、树的定义 树的定义 树型结构是⼀类重要的⾮线性数据结构。 • 有⼀个特殊的结点,称为根结点,根结点没有前驱结点。 • 除根结点外,其余结点被分成M个互不相交的集合T1 、T2 、...、Tm T,其中每⼀个集合⼜是⼀棵树&#xff0c…...

Jenkins项目CICD流程

Jenkins项目流程:1.配置git环境 git config --...2.把前后端的目录初始化位本地工作目录 #git init3.提交到本地git #git add ./ git commit -m "" git tag v14.然后提交到远程git(通过,用户,群组,项目,管理项目)git remote add origin http://...git push -…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”,无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息: 关注测试号:扫二维码关注测试号。 发送模版消息: import requests da…...

(十)学生端搭建

本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...

Python:操作 Excel 折叠

💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖 本博客的精华专栏: 【自动化测试】 【测试经验】 【人工智能】 【Python】 Python 操作 Excel 系列 读取单元格数据按行写入设置行高和列宽自动调整行高和列宽水平…...

linux 错误码总结

1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...

python如何将word的doc另存为docx

将 DOCX 文件另存为 DOCX 格式(Python 实现) 在 Python 中,你可以使用 python-docx 库来操作 Word 文档。不过需要注意的是,.doc 是旧的 Word 格式,而 .docx 是新的基于 XML 的格式。python-docx 只能处理 .docx 格式…...

ip子接口配置及删除

配置永久生效的子接口,2个IP 都可以登录你这一台服务器。重启不失效。 永久的 [应用] vi /etc/sysconfig/network-scripts/ifcfg-eth0修改文件内内容 TYPE"Ethernet" BOOTPROTO"none" NAME"eth0" DEVICE"eth0" ONBOOT&q…...

ABAP设计模式之---“简单设计原则(Simple Design)”

“Simple Design”(简单设计)是软件开发中的一个重要理念,倡导以最简单的方式实现软件功能,以确保代码清晰易懂、易维护,并在项目需求变化时能够快速适应。 其核心目标是避免复杂和过度设计,遵循“让事情保…...

MyBatis中关于缓存的理解

MyBatis缓存 MyBatis系统当中默认定义两级缓存:一级缓存、二级缓存 默认情况下,只有一级缓存开启(sqlSession级别的缓存)二级缓存需要手动开启配置,需要局域namespace级别的缓存 一级缓存(本地缓存&#…...