当前位置: 首页 > news >正文

深度学习中的Checkpoint是什么?

诸神缄默不语-个人CSDN博文目录

文章目录

  • 引言
  • 1. 什么是Checkpoint?
  • 2. 为什么需要Checkpoint?
  • 3. 如何使用Checkpoint?
    • 3.1 TensorFlow 中的 Checkpoint
    • 3.2 PyTorch 中的 Checkpoint
    • 3.3 transformers中的Checkpoint
  • 4. 在 NLP 任务中的应用
  • 5. 总结
  • 6. 参考资料

引言

在深度学习训练过程中,模型的训练往往需要较长的时间,并且计算资源昂贵。由于训练过程中可能遇到各种意外情况,比如断电、程序崩溃,甚至想要在不同阶段对比模型的表现,因此我们需要一种机制来保存训练进度,以便可以随时恢复。这就是**Checkpoint(检查点)**的作用。

对于刚入门深度学习的小伙伴,理解Checkpoint的概念并合理使用它,可以大大提高模型训练的稳定性和效率。本文将详细介绍Checkpoint的概念、用途以及如何在NLP任务中使用它。

1. 什么是Checkpoint?

Checkpoint(检查点)是指在训练过程中,定期保存模型的状态,包括模型的权重参数、优化器状态以及训练进度(如当前的epoch数)。这样,即使训练中断,我们也可以从最近的Checkpoint恢复训练,而不是从头开始。

简单来说,Checkpoint 就像一个存档点,让我们能够在不重头训练的情况下继续优化模型。

一个大模型的checkpoint可能以如下文件形式储存:
在这里插入图片描述

2. 为什么需要Checkpoint?

Checkpoint 的主要作用包括:

  1. 防止训练中断导致的损失:训练神经网络需要消耗大量计算资源,训练时间可能长达数小时甚至数天。如果训练因突发情况(如断电、程序崩溃)中断,Checkpoint 可以帮助我们恢复进度。

  2. 支持断点续训:当训练过程中需要调整超参数或遇到不可预见的问题时,我们可以从最近的Checkpoint继续训练,而不必重新训练整个模型。

  3. 保存最佳模型:在训练过程中,我们通常会评估模型在验证集上的表现。通过Checkpoint,我们可以保存最优表现的模型,而不是仅仅保存最后一次训练的结果。

  4. 支持迁移学习:在实际应用中,我们经常会使用预训练模型(如BERT、GPT等),然后在特定任务上进行微调(fine-tuning)。这些预训练模型的Checkpoint可以用作新的任务的起点,而不必从零开始训练。

3. 如何使用Checkpoint?

在深度学习框架(如 TensorFlow 和 PyTorch)中,Checkpoint 的使用非常方便。下面分别介绍在 TensorFlow 和 PyTorch 中如何保存和加载 Checkpoint。

3.1 TensorFlow 中的 Checkpoint

保存Checkpoint:

在 TensorFlow(Keras)中,可以使用 ModelCheckpoint 回调函数来实现自动保存。

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint# 创建简单的模型
model = tf.keras.Sequential([tf.keras.layers.Dense(128, activation='relu', input_shape=(100,)),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 设置Checkpoint,保存最优模型
checkpoint_callback = ModelCheckpoint(filepath='best_model.h5',  # 保存路径save_best_only=True,        # 仅保存最优模型monitor='val_loss',         # 监控的指标mode='min',                 # val_loss 越小越好verbose=1                   # 输出日志
)# 训练模型,并使用Checkpoint
model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=10, callbacks=[checkpoint_callback])

加载Checkpoint:

from tensorflow.keras.models import load_model# 加载已保存的模型
model = load_model('best_model.h5')

这样,我们就可以在训练过程中自动保存最优模型,并在需要时加载它。

3.2 PyTorch 中的 Checkpoint

在 PyTorch 中,我们可以使用 torch.savetorch.load 来手动保存和加载模型。

保存Checkpoint:

import torch# 假设 model 是我们的神经网络,optimizer 是优化器
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict()
}
torch.save(checkpoint, 'checkpoint.pth')

加载Checkpoint:

# 加载Checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

在 PyTorch 中,保存和加载 Checkpoint 需要手动指定模型和优化器的状态,而 TensorFlow 处理起来更为自动化。

3.3 transformers中的Checkpoint

如果直接用transformers的Trainer的话,就会自动根据TrainingArguments的参数来设置checkpoint保存策略。具体的参数有save_strategy、save_steps、save_total_limit、load_best_model_at_end等,可以看我之前写过的关于transformers包的博文。

epochs = 10
lr = 2e-5
train_bs = 8
eval_bs = train_bs * 2training_args = TrainingArguments(output_dir=output_dir,num_train_epochs=epochs,learning_rate=lr,per_device_train_batch_size=train_bs,per_device_eval_batch_size=eval_bs,evaluation_strategy="epoch",logging_steps=logging_steps
)

断点续训:

# Trainer 的定义
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset
)# 从最近的检查点恢复训练
trainer.train(resume_from_checkpoint=True)

4. 在 NLP 任务中的应用

在自然语言处理任务中,Checkpoint 主要用于:

  1. 训练 Transformer 模型(如 BERT、GPT)时,保存和恢复训练进度。
  2. 微调预训练模型时,从预训练权重(如 bert-base-uncased)加载 Checkpoint 进行继续训练。
  3. 文本生成任务(如 Seq2Seq 模型),确保中断时可以从最近的 Checkpoint 继续训练。

5. 总结

  • Checkpoint 是深度学习训练过程中保存模型状态的机制,可以防止训练中断带来的损失。
  • 它有助于断点续训、保存最佳模型以及进行迁移学习
  • 在 TensorFlow 和 PyTorch 中都有方便的方式来保存和加载 Checkpoint
  • 在 NLP 任务中,Checkpoint 被广泛用于 Transformer 训练、预训练模型微调等任务

6. 参考资料

  1. 模型训练当中 checkpoint 作用是什么 - 简书

在这里插入图片描述

相关文章:

深度学习中的Checkpoint是什么?

诸神缄默不语-个人CSDN博文目录 文章目录 引言1. 什么是Checkpoint?2. 为什么需要Checkpoint?3. 如何使用Checkpoint?3.1 TensorFlow 中的 Checkpoint3.2 PyTorch 中的 Checkpoint3.3 transformers中的Checkpoint 4. 在 NLP 任务中的应用5. 总…...

STM32开发笔记,编译与烧录

1. Keil开发环境 【Project】》【Manager】》【Pack Installer】选择相应的芯片,Unpack安装。 2. 编译 3. 烧录 烧录时,Boot0 为 1,Boot1 为 0。烧录后启动,Boot0 为 0 ,Boot 1 为 0。 3.1 ST-LINK烧录 测试连接&a…...

【CXX-Qt】1 CXX-Qt入门

与其他Qt-Rust绑定相比,CXX-Qt的目标不仅仅是将Qt功能暴露给Rust,而是完全将Rust集成到Qt生态系统中。我们将通过一个最小示例,展示如何使用CXX-Qt在Rust中创建自己的QObject,并将其与基于QML的小型GUI集成。 一、阅读前准备知识…...

JS宏进阶:XMLHttpRequest对象

一、概述 XMLHttpRequest简称XHR,它是一个可以在JavaScript中使用的对象,用于在后台与服务器交换数据,实现页面的局部更新,而无需重新加载整个页面,也是Ajax(Asynchronous JavaScript and XML)…...

物联网智能语音控制灯光系统设计与实现

背景 随着物联网技术的蓬勃发展,智能家居逐渐成为现代生活的一部分。在众多智能家居应用中,智能灯光控制系统尤为重要。通过语音控制和自动调节灯光,用户可以更便捷地操作家中的照明设备,提高生活的舒适度与便利性。本文将介绍一…...

hyperf知识问题汇总

1、简单说下 hyperf(什么是 hyperf) 答:hyperf 是一个依赖swoole扩展的 php 开源开发框架,它由黄朝辉团队设计创建维护,具备简洁而强大的组件和超强的并发性能,而且还支持微服务架构,例如&…...

制药行业 BI 可视化数据分析方案

一、行业背景 随着医药行业数字化转型的深入,企业积累了海量的数据,包括销售数据、生产数据、研发数据、市场数据等。如何利用这些数据,挖掘其价值,为企业决策提供支持,成为医药企业面临的重大挑战。在当今竞争激烈的…...

【SVN基础】

软件:ToritoiseSVN 代码版本回退:回退到上一个版本 问题:SVN版本已经提交了版本1和版本2,现在发现不需要版本2的内容,需要回退到版本1然后继续开发。 如图SVN版本已经提交到了107版本,那么本地仓库也已经…...

多项式插值(数值计算方法)Matlab实现

多项式插值(数值计算方法)Matlab实现 一. 原理介绍二. 程序设计1. 构建矩阵2. 求解矩阵方程3. 作出多项式函数4. 绘制插值曲线5. 完整代码 三. 图例 一. 原理介绍 关于插值的定义及基本原理可以参照如下索引 插值原理(数值计算方法&#xff…...

[AI]Mac本地部署Deepseek R1模型 — — 保姆级教程

[AI]Mac本地部署DeepSeek R1模型 — — 保姆级教程 DeepSeek R1是中国AI初创公司深度求索(DeepSeek)推出大模型DeepSeek-R1。 作为一款开源模型,R1在数学、代码、自然语言推理等任务上的性能能够比肩OpenAI o1模型正式版,并采用MI…...

android手机本地部署deepseek1.5B

手机本地部署大模型需要一个开源软件 Release Release v1.6.7 a-ghorbani/pocketpal-ai GitHub 下载release版本apk 它也支持ios,并且是开源的,你可以编译修改它 安装完后是这样的 可以下载推荐的模型,也可以在pc上下载好,然后copy到手机里 点 + 号加载本地模型...

理解UML中的四种关系:依赖、关联、泛化和实现

在软件工程中,统一建模语言(UML)是一种广泛使用的工具,用于可视化、设计、构造和文档化软件系统。UML提供了多种图表类型,如类图、用例图、序列图等,帮助开发者和设计师更好地理解系统的结构和行为。在UML中…...

机器学习 - 词袋模型(Bag of Words)实现文本情感分类的详细示例

为了简单直观的理解模型训练,我这里搜集了两个简单的实现文本情感分类的例子,第一个例子基于朴素贝叶斯分类器,第二个例子基于逻辑回归,通过这两个例子,掌握词袋模型(Bag of Words)实现文本情感…...

Kimi k1.5: Scaling Reinforcement Learning with LLMs

TL;DR 2025 年 kimi 发表的 k1.5 模型技术报告,和 DeepSeek R1 同一天发布,虽然精度上和 R1 有微小差距,但是文章提出的 RL 路线也有很强的参考意义 Paper name Kimi k1.5: Scaling Reinforcement Learning with LLMs Paper Reading Note…...

如何评估云原生GenAI应用开发中的安全风险(下)

以上就是如何评估云原生GenAI应用开发中的安全风险系列中的上篇内容,在本篇中我们介绍了在云原生AI应用开发中不同层级的风险,并了解了如何定义AI系统的风险。在本系列下篇中我们会继续探索我们为我们的云原生AI应用评估风险的背景和意义,并且…...

ASP.NET Core程序的部署

发布 不能直接把bin/Debug部署到生产环境的服务器上,性能低。应该创建网站的发布版,用【发布】功能。两种部署模式:“框架依赖”和“独立”。独立模式选择目标操作系统和CPU类型。Windows、Linux、iOS;关于龙芯。 网站的运行 在…...

《深度LSTM vs 普通LSTM:训练与效果的深度剖析》

在深度学习领域,长短期记忆网络(LSTM)以其出色的处理序列数据能力而备受瞩目。而深度LSTM作为LSTM的扩展形式,与普通LSTM在训练和效果上存在着一些显著的不同。 训练方面 参数数量与计算量:普通LSTM通常只有一层或较少…...

Spring依赖注入方式

写在前面:大家好!我是晴空๓。如果博客中有不足或者的错误的地方欢迎在评论区或者私信我指正,感谢大家的不吝赐教。我的唯一博客更新地址是:https://ac-fun.blog.csdn.net/。非常感谢大家的支持。一起加油,冲鸭&#x…...

Photoshop自定义键盘快捷键

编辑 - 键盘快捷键 CtrlShiftAltK 把画笔工具改成Q , 橡皮擦改成W , 涂抹工具改成E , 增加和减小画笔大小A和S 偏好设置 - 透明度和色域 设置一样颜色 套索工具 可以自定义套选一片区域 Shiftf5 填充 CtrlU 可以改颜色/色相/饱和度 CtrlE 合并图层 CtrlShiftS 另存…...

解决VsCode的 Vetur 插件has no default export Vetur问题

文章目录 前言1.问题2. 原因3. 解决其他 前言 提示: 1.问题 Cannot find module ‘ant-design-vue’. Did you mean to set the ‘moduleResolution’ option to ‘node’, or to add aliases to the ‘paths’ option? Module ‘“/xxx/xxx/xxx/xxx/xxx/src/vie…...

IDEA运行Tomcat出现乱码问题解决汇总

最近正值期末周,有很多同学在写期末Java web作业时,运行tomcat出现乱码问题,经过多次解决与研究,我做了如下整理: 原因: IDEA本身编码与tomcat的编码与Windows编码不同导致,Windows 系统控制台…...

使用VSCode开发Django指南

使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...

Zustand 状态管理库:极简而强大的解决方案

Zustand 是一个轻量级、快速和可扩展的状态管理库,特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...

相机Camera日志实例分析之二:相机Camx【专业模式开启直方图拍照】单帧流程日志详解

【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了: 这一篇我们开始讲: 目录 一、场景操作步骤 二、日志基础关键字分级如下 三、场景日志如下: 一、场景操作步骤 操作步…...

Python爬虫实战:研究feedparser库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...

Python爬虫(一):爬虫伪装

一、网站防爬机制概述 在当今互联网环境中,具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类: 身份验证机制:直接将未经授权的爬虫阻挡在外反爬技术体系:通过各种技术手段增加爬虫获取数据的难度…...

Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!

一、引言 在数据驱动的背景下,知识图谱凭借其高效的信息组织能力,正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合,探讨知识图谱开发的实现细节,帮助读者掌握该技术栈在实际项目中的落地方法。 …...

LLM基础1_语言模型如何处理文本

基于GitHub项目:https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken:OpenAI开发的专业"分词器" torch:Facebook开发的强力计算引擎,相当于超级计算器 理解词嵌入:给词语画"…...

使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台

🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...

PAN/FPN

import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...