PyTorch Lightning Trainer介绍
PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例:
Trainer 的核心功能
-
自动化训练循环
自动处理training_step、validation_step、test_step的调用,无需手动编写for epoch in epochs循环。 -
硬件加速支持
支持 CPU/GPU/TPU、多卡训练(DDP、DeepSpeed)、混合精度训练等。 -
训练控制
控制训练轮数 (max_epochs)、批次大小 (batch_size)、梯度裁剪 (gradient_clip_val) 等。 -
日志与监控
集成 TensorBoard、W&B、MLFlow 等日志工具,监控损失、准确率等指标。 -
回调机制
通过回调函数(如EarlyStopping,ModelCheckpoint)实现早停、模型保存等扩展功能。
Trainer 的常用参数
from pytorch_lightning import Trainertrainer = Trainer(# 基础配置max_epochs=10, # 最大训练轮数accelerator="auto", # 自动选择设备 (CPU/GPU/TPU)devices="auto", # 使用所有可用设备(如多 GPU)precision="16-mixed", # 混合精度训练(FP16)# 日志与调试logger=True, # 默认使用 TensorBoardlog_every_n_steps=10, # 每 10 个批次记录一次日志fast_dev_run=False, # 快速运行一个批次(调试模式)# 回调函数callbacks=[pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2)],# 分布式训练strategy="ddp", # 分布式数据并行策略(多 GPU)num_nodes=1, # 节点数量(多机器训练)
)
使用示例代码
步骤 1:定义 LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as plclass LitModel(pl.LightningModule):def __init__(self):super().__init__()self.layer1 = nn.Linear(28*28, 128)self.layer2 = nn.Linear(128, 10)def forward(self, x):x = x.view(x.size(0), -1) # 展平输入x = F.relu(self.layer1(x))x = self.layer2(x)return xdef training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss) # 自动记录日志return lossdef validation_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("val_loss", loss) # 自动记录验证损失def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=0.001)
步骤 2:定义 DataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensorclass MNISTDataModule(pl.LightningDataModule):def __init__(self, batch_size=32):super().__init__()self.batch_size = batch_sizedef prepare_data(self):MNIST(root="data", download=True)def setup(self, stage=None):full_dataset = MNIST(root="data", train=True, transform=ToTensor())self.train_data, self.val_data = random_split(full_dataset, [55000, 5000])def train_dataloader(self):return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)def val_dataloader(self):return DataLoader(self.val_data, batch_size=self.batch_size)dm = MNISTDataModule(batch_size=32)
步骤 3:启动训练
model = LitModel()
trainer = Trainer(max_epochs=10,accelerator="auto",devices="auto",logger=True,callbacks=[pl.callbacks.ModelCheckpoint(monitor="val_loss")]
)# 开始训练与验证
trainer.fit(model, datamodule=dm)# 测试(可选)
trainer.test(model, datamodule=dm)
关键功能演示
1. 多 GPU 训练
# 使用 4 个 GPU 训练
trainer = Trainer(devices=4, strategy="ddp")
2. 混合精度训练
# 使用 FP16 混合精度
trainer = Trainer(precision="16-mixed")
3. 早停与模型保存
callbacks = [pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),pl.callbacks.ModelCheckpoint(dirpath="checkpoints/",filename="best-model-{epoch:02d}-{val_loss:.2f}",save_top_k=2,monitor="val_loss")
]
trainer = Trainer(callbacks=callbacks)
4. 调试模式
# 快速验证代码正确性(仅运行一个批次)
trainer = Trainer(fast_dev_run=True)
常见问题
如何恢复训练?
使用 resume_from_checkpoint 参数:
trainer = Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")
如何限制训练时间?
trainer = Trainer(max_time="00:02:00") # 最多训练 2 分钟
如何自定义学习率调度器?
在 自定义的 LightningDataModule继承类的 configure_optimizers 方法中返回优化器和调度器:
def configure_optimizers(self):optimizer = Adam(self.parameters())scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)return [optimizer], [scheduler]
总结
通过 Trainer,PyTorch Lightning 将训练流程的复杂性封装在几行配置中,开发者只需关注模型逻辑和数据加载。其灵活的参数和回调机制能够覆盖从实验到生产的全流程需求。
参考:
https://lightning.ai/docs/pytorch/stable/common/trainer.html
相关文章:
PyTorch Lightning Trainer介绍
PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例: Trainer 的核心功能 自动化训练循环 自…...
mysql监控--慢查询
一、监控配置 二、慢查询文件 在 MySQL 中,慢查询日志记录了执行时间较长的查询,通常,慢查询日志可能会生成以下几种文件: 1. 慢查询日志文件 这是最主要的文件,记录了执行时间超过设置阈值的 SQL 查询。可以通过 …...
Conda 包管理:高效安装、更新和删除软件包
Conda 包管理:高效安装、更新和删除软件包 1. 引言 在使用 Anaconda 进行 Python 开发时,包管理是日常操作的核心内容。Conda 提供了一整套高效的工具来管理 Python 环境中的软件包,避免了版本冲突,并确保了环境的一致性。 本篇…...
AcWing 798. 差分矩阵
题目来源: 找不到页面 - AcWing 题目内容: 输入一个 n 行 m 列的整数矩阵,再输入 q 个操作,每个操作包含五个整数 x1,y1,x2,y2,c,其中 (x1,y1) 和 (x2,y2)表示一个子矩阵的左上角坐标和右下角坐标。 每个操作都要将…...
通用定时器学习记录
简介 通用定时器:TIM2/TIM3/TIM4/TIM5 主要特性:16位递增、递减、中心对齐计数器(计数值0~65535) 16位预分频器(分频系数1~65536) 可用于触发DAC、ADC 在更新事件、触发事件、输入捕获、输出比较时&am…...
科技之光闪耀江城:2025武汉国际半导体产业与电子技术博览会5月15日盛大开幕
在科技浪潮汹涌澎湃的当下,半导体产业作为现代信息技术的中流砥柱,正以令人惊叹的速度重塑着世界的面貌。2025年5月15-17日,一场聚焦半导体与电子技术前沿的行业盛会 ——2025 武汉国际半导体产业与电子技术博览会,将在武汉・中国…...
vue开发06:前端通过webpack配置代理处理跨域问题
1.定义 在浏览器尝试请求不同源(域名、协议、端口号不同)的资源时,浏览器的同源策略会阻止这种跨域请求。(比如前端端口15500,后端端口5050,前端界面不可以直接调用5050端口) 2.解决方案 使用前…...
⚡️《静电刺客的猎杀手册:芯片世界里的“千伏惊魂“》⚡️
前言: 在这个电子产品无孔不入的时代,我们每天都在与一群隐形刺客打交道——它们身怀数千伏特的高压绝技,能在0.1秒内让价值百万的芯片灰飞烟灭。这就是静电放电(ESD),电子工业界最令人闻风丧胆的"沉默…...
【云安全】云原生-K8S(三) 安装 Dashboard 面板
在Kubernetes中安装Dashboard需要几个步骤,包括部署Dashboard组件、配置访问权限以及暴露Dashboard服务等。以下是详细的步骤: 1. 部署 K8S Dashboard 可以通过以下命令用Kubernetes官方的YAML文件来快速部署,由于是国外网站,需…...
Spring Boot 常用依赖详解:如何选择和使用常用依赖
在Spring Boot项目中,依赖(Dependencies)是项目的核心组成部分。每个依赖都提供了一些特定的功能或工具,帮助我们快速开发应用程序。本文将详细介绍Spring Boot中常用的依赖及其作用,并指导你如何根据项目需求选择合适…...
C++ 设计模式-组合模式
组合模式(Composite Pattern)允许将对象组合成树形结构,使得客户端以统一的方式处理单个对象和组合对象。以下是一个经典的 C 实现示例,包含透明式设计(基类定义统一接口)和内存管理: #include…...
【Spring Boot】Spring 魔法世界:Bean 作用域与生命周期的奇妙之旅
前言 ???本期讲解关于spring原理Bean的相关知识介绍~~~ ??感兴趣的小伙伴看一看小编主页:-CSDN博客 ?? 你的点赞就是小编不断更新的最大动力 ??那么废话不多说直接开整吧~~ 目录 ???1.Bean的作用域 ??1.1概念 ??1.2Bean的作用域 ??1.3代码演示…...
移远通信边缘计算模组成功运行DeepSeek模型,以领先的工程能力加速端侧AI落地
近日,国产大模型DeepSeek凭借其“开源开放、高效推理、端侧友好”的核心优势,迅速风靡全球。移远通信基于边缘计算模组SG885G,已成功实现DeepSeek模型的稳定运行,并完成了针对性微调。 目前,该模型正在多款智能终端上进…...
Cables Finance 构建集成LST与外汇RWA永续合约的综合性DEX
虽然 DeFi 领域整体发展迅速,但仍旧缺乏交易体验。现阶段市场已拓展至 RWAs 、永续期货和外汇领域,但跨资产交易的实际操作仍充满阻力。交易者面临流动性碎片化、抵押品被锁定在质押合约中缺乏流动性,以及整个系统仍围绕美元稳定币运转等问题…...
AI大模型(DeepSeek)科研应用、论文写作、数据分析与AI绘图学习
【介绍】 在人工智能浪潮中,2024年12月中国公司研发的 DeepSeek 横空出世以惊艳全球的姿态,成为 AI领域不可忽视的力量!DeepSeek 完全开源,可本地部署,无使用限制,保护用户隐私。其次,其性能强大ÿ…...
【算法工程】解决linux下Aspose.slides提示No usable version of libssl found以及强化推理模型的短板
1. 背景 构建ubuntu镜像,然后使用Aspose.slides解析PPTX文档,发现一直提示“No usable version of libssl found”。 2. 尝试 使用deepseek R1、kimi1.5、chatgpt o3,并且都带上联网能力,居然还是没有一个能够真正解决…...
什么是HTTP和HTTPS?它们之间有什么区别?
什么是HTTP和HTTPS?它们之间有什么区别? HTTP(超文本传输协议)简介 HTTP就像是你通过明信片给朋友发送信息。你在明信片上写下内容,然后寄出去。任何人都可以在途中看到明信片上的内容,因为它是公开的。 …...
【一文读懂】TCP与UDP协议
TCP协议 概述 TCP(Transmission Control Protocol),即传输控制协议,是一种面向连接的、可靠的、基于字节流的传输层通信协议,常用于保证数据可靠、按顺序、无差错地传输。TCP 是互联网协议族(TCP/IP&…...
数据结构 树的存储和遍历
一、树的定义 树的定义 树型结构是⼀类重要的⾮线性数据结构。 • 有⼀个特殊的结点,称为根结点,根结点没有前驱结点。 • 除根结点外,其余结点被分成M个互不相交的集合T1 、T2 、...、Tm T,其中每⼀个集合⼜是⼀棵树,…...
Jenkins项目CICD流程
Jenkins项目流程:1.配置git环境 git config --...2.把前后端的目录初始化位本地工作目录 #git init3.提交到本地git #git add ./ git commit -m "" git tag v14.然后提交到远程git(通过,用户,群组,项目,管理项目)git remote add origin http://...git push -…...
GE Eager Style Graph Builder类关系文档
Eager Style Graph Builder 类关系文档 【免费下载链接】ge GE(Graph Engine)是面向昇腾的图编译器和执行器,提供了计算图优化、多流并行、内存复用和模型下沉等技术手段,加速模型执行效率,减少模型内存占用。 GE 提供…...
CANN/AMCT大模型FlatQuant量化
AMCT大模型对于LLAMA2/Qwen3的FlatQuant量化 【免费下载链接】amct AMCT是CANN提供的昇腾AI处理器亲和的模型压缩工具仓。 项目地址: https://gitcode.com/cann/amct 1 量化前提 1.1 安装依赖 本sample依赖包可参考requirements.txt 需要注意的是torch_npu包版本需要…...
CANN/社区安全发布指南
版本发布网络安全质量要求 【免费下载链接】community 本项目是CANN开源社区的核心管理仓库,包含社区的治理章程、治理组织、通用操作指引及流程规范等基础信息 项目地址: https://gitcode.com/cann/community 为保障版本网络安全质量,版本发布前…...
边缘AI安全与隐私实战:从联邦学习到可信执行环境
1. 项目概述:当边缘计算遇上AI,安全与隐私的新战场最近几年,我身边做物联网和移动应用开发的朋友,聊天的主题都绕不开两个词:边缘计算和AI。大家一边兴奋于把AI模型塞进摄像头、网关甚至手机里带来的实时性提升&#x…...
【2025最新】基于SpringBoot+Vue的抗疫物资管理系统管理系统源码+MyBatis+MySQL
摘要 近年来,全球范围内的突发公共卫生事件频发,抗疫物资的高效管理成为保障社会稳定的重要环节。传统的物资管理方式依赖人工操作,存在效率低下、信息不透明、资源分配不均等问题,难以应对大规模疫情的需求。特别是在物资调配、库…...
保姆级教程:H3C NX30 PRO刷OpenWrt后,用Cron定时任务搞定烦人的LED灯
智能路由器灯光管理:OpenWrt定时任务实战指南 深夜的书房里,路由器LED指示灯像个小太阳一样刺眼。这种困扰对于追求完美使用体验的技术爱好者来说,简直不能忍。好在OpenWrt系统的强大自定义能力可以轻松解决这个问题——不需要复杂的命令行操…...
CANN/asc-devkit Tiling模板参数选择接口
ASCENDC_TPL_SEL_PARAM 【免费下载链接】asc-devkit 本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言,原生支持C和C标准规范,主要由类库和语言扩展层构成,提供多层级API,满足多维场景算子开发诉求。 项目地址: https://…...
通过审计日志功能回溯与分析团队的API调用情况
🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 通过审计日志功能回溯与分析团队的API调用情况 作为团队的技术负责人,在引入大模型能力支持业务创新的同时,…...
从IMU到自动驾驶:卡尔曼滤波参数(Q,R)怎么调?一个Python仿真实验说清楚
卡尔曼滤波参数调优实战:用Python仿真破解Q/R矩阵之谜 在自动驾驶和机器人定位领域,卡尔曼滤波器的性能往往取决于两个神秘参数——过程噪声协方差Q和测量噪声协方差R。许多工程师能够熟练实现算法代码,却在参数调试阶段陷入反复试错的泥潭。…...
CANN算子测试赛Add报告
【免费下载链接】cann-competitions 本仓库用于 CANN 开源社区各类竞赛、开源课题、社区任务等课题发布、开发者作品提交和展示。 项目地址: https://gitcode.com/cann/cann-competitions 元信息(请如实填写,此区块将由组委会脚本自动解析…...
