当前位置: 首页 > news >正文

PyTorch Lightning Trainer介绍

PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例:

Trainer 的核心功能

  1. 自动化训练循环
    自动处理 training_stepvalidation_steptest_step 的调用,无需手动编写 for epoch in epochs 循环。

  2. 硬件加速支持
    支持 CPU/GPU/TPU、多卡训练(DDP、DeepSpeed)、混合精度训练等。

  3. 训练控制
    控制训练轮数 (max_epochs)、批次大小 (batch_size)、梯度裁剪 (gradient_clip_val) 等。

  4. 日志与监控
    集成 TensorBoard、W&B、MLFlow 等日志工具,监控损失、准确率等指标。

  5. 回调机制
    通过回调函数(如 EarlyStoppingModelCheckpoint)实现早停、模型保存等扩展功能。

Trainer 的常用参数

from pytorch_lightning import Trainertrainer = Trainer(# 基础配置max_epochs=10,            # 最大训练轮数accelerator="auto",       # 自动选择设备 (CPU/GPU/TPU)devices="auto",           # 使用所有可用设备(如多 GPU)precision="16-mixed",     # 混合精度训练(FP16)# 日志与调试logger=True,              # 默认使用 TensorBoardlog_every_n_steps=10,     # 每 10 个批次记录一次日志fast_dev_run=False,       # 快速运行一个批次(调试模式)# 回调函数callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2)],# 分布式训练strategy="ddp",           # 分布式数据并行策略(多 GPU)num_nodes=1,              # 节点数量(多机器训练)
)

使用示例代码

步骤 1:定义 LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as plclass LitModel(pl.LightningModule):def __init__(self):super().__init__()self.layer1 = nn.Linear(28*28, 128)self.layer2 = nn.Linear(128, 10)def forward(self, x):x = x.view(x.size(0), -1)  # 展平输入x = F.relu(self.layer1(x))x = self.layer2(x)return xdef training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss)  # 自动记录日志return lossdef validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("val_loss", loss)     # 自动记录验证损失def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.001)
步骤 2:定义 DataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensorclass MNISTDataModule(pl.LightningDataModule):def __init__(self, batch_size=32):super().__init__()self.batch_size = batch_sizedef prepare_data(self):MNIST(root="data", download=True)def setup(self, stage=None):full_dataset = MNIST(root="data", train=True, transform=ToTensor())self.train_data, self.val_data = random_split(full_dataset, [55000, 5000])def train_dataloader(self):return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.val_data, batch_size=self.batch_size)dm = MNISTDataModule(batch_size=32)

步骤 3:启动训练

model = LitModel()
trainer = Trainer(max_epochs=10,accelerator="auto",devices="auto",logger=True,callbacks=[pl.callbacks.ModelCheckpoint(monitor="val_loss")]
)# 开始训练与验证
trainer.fit(model, datamodule=dm)# 测试(可选)
trainer.test(model, datamodule=dm)

关键功能演示

1. 多 GPU 训练
# 使用 4 个 GPU 训练
trainer = Trainer(devices=4, strategy="ddp")
2. 混合精度训练

# 使用 FP16 混合精度
trainer = Trainer(precision="16-mixed")
3. 早停与模型保存
callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(dirpath="checkpoints/",filename="best-model-{epoch:02d}-{val_loss:.2f}",save_top_k=2,monitor="val_loss")
]
trainer = Trainer(callbacks=callbacks)
4. 调试模式
# 快速验证代码正确性(仅运行一个批次)
trainer = Trainer(fast_dev_run=True)

常见问题

如何恢复训练?
使用 resume_from_checkpoint 参数:

trainer = Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")

如何限制训练时间?

trainer = Trainer(max_time="00:02:00")  # 最多训练 2 分钟

如何自定义学习率调度器?
在 自定义的 LightningDataModule继承类的 configure_optimizers 方法中返回优化器和调度器:

def configure_optimizers(self):optimizer = Adam(self.parameters())scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)return [optimizer], [scheduler]

总结

通过 Trainer,PyTorch Lightning 将训练流程的复杂性封装在几行配置中,开发者只需关注模型逻辑和数据加载。其灵活的参数和回调机制能够覆盖从实验到生产的全流程需求。

参考:

https://lightning.ai/docs/pytorch/stable/common/trainer.html

相关文章:

PyTorch Lightning Trainer介绍

PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例: Trainer 的核心功能 自动化训练循环 自…...

mysql监控--慢查询

一、监控配置 二、慢查询文件 在 MySQL 中,慢查询日志记录了执行时间较长的查询,通常,慢查询日志可能会生成以下几种文件: 1. 慢查询日志文件 这是最主要的文件,记录了执行时间超过设置阈值的 SQL 查询。可以通过 …...

Conda 包管理:高效安装、更新和删除软件包

Conda 包管理:高效安装、更新和删除软件包 1. 引言 在使用 Anaconda 进行 Python 开发时,包管理是日常操作的核心内容。Conda 提供了一整套高效的工具来管理 Python 环境中的软件包,避免了版本冲突,并确保了环境的一致性。 本篇…...

AcWing 798. 差分矩阵

题目来源: 找不到页面 - AcWing 题目内容: 输入一个 n 行 m 列的整数矩阵,再输入 q 个操作,每个操作包含五个整数 x1,y1,x2,y2,c,其中 (x1,y1) 和 (x2,y2)表示一个子矩阵的左上角坐标和右下角坐标。 每个操作都要将…...

通用定时器学习记录

简介 通用定时器:TIM2/TIM3/TIM4/TIM5 主要特性:16位递增、递减、中心对齐计数器(计数值0~65535) 16位预分频器(分频系数1~65536) 可用于触发DAC、ADC 在更新事件、触发事件、输入捕获、输出比较时&am…...

科技之光闪耀江城:2025武汉国际半导体产业与电子技术博览会5月15日盛大开幕

在科技浪潮汹涌澎湃的当下,半导体产业作为现代信息技术的中流砥柱,正以令人惊叹的速度重塑着世界的面貌。2025年5月15-17日,一场聚焦半导体与电子技术前沿的行业盛会 ——2025 武汉国际半导体产业与电子技术博览会,将在武汉・中国…...

vue开发06:前端通过webpack配置代理处理跨域问题

1.定义 在浏览器尝试请求不同源(域名、协议、端口号不同)的资源时,浏览器的同源策略会阻止这种跨域请求。(比如前端端口15500,后端端口5050,前端界面不可以直接调用5050端口) 2.解决方案 使用前…...

⚡️《静电刺客的猎杀手册:芯片世界里的“千伏惊魂“》⚡️

前言: 在这个电子产品无孔不入的时代,我们每天都在与一群隐形刺客打交道——它们身怀数千伏特的高压绝技,能在0.1秒内让价值百万的芯片灰飞烟灭。这就是静电放电(ESD),电子工业界最令人闻风丧胆的"沉默…...

【云安全】云原生-K8S(三) 安装 Dashboard 面板

在Kubernetes中安装Dashboard需要几个步骤,包括部署Dashboard组件、配置访问权限以及暴露Dashboard服务等。以下是详细的步骤: 1. 部署 K8S Dashboard 可以通过以下命令用Kubernetes官方的YAML文件来快速部署,由于是国外网站,需…...

Spring Boot 常用依赖详解:如何选择和使用常用依赖

在Spring Boot项目中,依赖(Dependencies)是项目的核心组成部分。每个依赖都提供了一些特定的功能或工具,帮助我们快速开发应用程序。本文将详细介绍Spring Boot中常用的依赖及其作用,并指导你如何根据项目需求选择合适…...

C++ 设计模式-组合模式

组合模式(Composite Pattern)允许将对象组合成树形结构,使得客户端以统一的方式处理单个对象和组合对象。以下是一个经典的 C 实现示例,包含透明式设计(基类定义统一接口)和内存管理: #include…...

【Spring Boot】Spring 魔法世界:Bean 作用域与生命周期的奇妙之旅

前言 ???本期讲解关于spring原理Bean的相关知识介绍~~~ ??感兴趣的小伙伴看一看小编主页:-CSDN博客 ?? 你的点赞就是小编不断更新的最大动力 ??那么废话不多说直接开整吧~~ 目录 ???1.Bean的作用域 ??1.1概念 ??1.2Bean的作用域 ??1.3代码演示…...

移远通信边缘计算模组成功运行DeepSeek模型,以领先的工程能力加速端侧AI落地

近日,国产大模型DeepSeek凭借其“开源开放、高效推理、端侧友好”的核心优势,迅速风靡全球。移远通信基于边缘计算模组SG885G,已成功实现DeepSeek模型的稳定运行,并完成了针对性微调。 目前,该模型正在多款智能终端上进…...

Cables Finance 构建集成LST与外汇RWA永续合约的综合性DEX

虽然 DeFi 领域整体发展迅速,但仍旧缺乏交易体验。现阶段市场已拓展至 RWAs 、永续期货和外汇领域,但跨资产交易的实际操作仍充满阻力。交易者面临流动性碎片化、抵押品被锁定在质押合约中缺乏流动性,以及整个系统仍围绕美元稳定币运转等问题…...

AI大模型(DeepSeek)科研应用、论文写作、数据分析与AI绘图学习

【介绍】 在人工智能浪潮中,2024年12月中国公司研发的 DeepSeek 横空出世以惊艳全球的姿态,成为 AI领域不可忽视的力量!DeepSeek 完全开源,可本地部署,无使用限制,保护用户隐私。其次,其性能强大&#xff…...

【算法工程】解决linux下Aspose.slides提示No usable version of libssl found以及强化推理模型的短板

1. 背景 构建ubuntu镜像,然后使用Aspose.slides解析PPTX文档,发现一直提示“No usable version of libssl found”。 2. 尝试 使用deepseek R1、kimi1.5、chatgpt o3,并且都带上联网能力,居然还是没有一个能够真正解决&#xf…...

什么是HTTP和HTTPS?它们之间有什么区别?

什么是HTTP和HTTPS?它们之间有什么区别? HTTP(超文本传输协议)简介 HTTP就像是你通过明信片给朋友发送信息。你在明信片上写下内容,然后寄出去。任何人都可以在途中看到明信片上的内容,因为它是公开的。 …...

【一文读懂】TCP与UDP协议

TCP协议 概述 TCP(Transmission Control Protocol),即传输控制协议,是一种面向连接的、可靠的、基于字节流的传输层通信协议,常用于保证数据可靠、按顺序、无差错地传输。TCP 是互联网协议族(TCP/IP&…...

数据结构 树的存储和遍历

一、树的定义 树的定义 树型结构是⼀类重要的⾮线性数据结构。 • 有⼀个特殊的结点,称为根结点,根结点没有前驱结点。 • 除根结点外,其余结点被分成M个互不相交的集合T1 、T2 、...、Tm T,其中每⼀个集合⼜是⼀棵树&#xff0c…...

Jenkins项目CICD流程

Jenkins项目流程:1.配置git环境 git config --...2.把前后端的目录初始化位本地工作目录 #git init3.提交到本地git #git add ./ git commit -m "" git tag v14.然后提交到远程git(通过,用户,群组,项目,管理项目)git remote add origin http://...git push -…...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

python/java环境配置

环境变量放一起 python&#xff1a; 1.首先下载Python Python下载地址&#xff1a;Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个&#xff0c;然后自定义&#xff0c;全选 可以把前4个选上 3.环境配置 1&#xff09;搜高级系统设置 2…...

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

智能在线客服平台:数字化时代企业连接用户的 AI 中枢

随着互联网技术的飞速发展&#xff0c;消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁&#xff0c;不仅优化了客户体验&#xff0c;还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用&#xff0c;并…...

DIY|Mac 搭建 ESP-IDF 开发环境及编译小智 AI

前一阵子在百度 AI 开发者大会上&#xff0c;看到基于小智 AI DIY 玩具的演示&#xff0c;感觉有点意思&#xff0c;想着自己也来试试。 如果只是想烧录现成的固件&#xff0c;乐鑫官方除了提供了 Windows 版本的 Flash 下载工具 之外&#xff0c;还提供了基于网页版的 ESP LA…...

【git】把本地更改提交远程新分支feature_g

创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...

【C语言练习】080. 使用C语言实现简单的数据库操作

080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

DeepSeek 技术赋能无人农场协同作业:用 AI 重构农田管理 “神经网”

目录 一、引言二、DeepSeek 技术大揭秘2.1 核心架构解析2.2 关键技术剖析 三、智能农业无人农场协同作业现状3.1 发展现状概述3.2 协同作业模式介绍 四、DeepSeek 的 “农场奇妙游”4.1 数据处理与分析4.2 作物生长监测与预测4.3 病虫害防治4.4 农机协同作业调度 五、实际案例大…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!

简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求&#xff0c;并检查收到的响应。它以以下模式之一…...