深度学习中的Checkpoint是什么?
诸神缄默不语-个人CSDN博文目录
文章目录
- 引言
- 1. 什么是Checkpoint?
- 2. 为什么需要Checkpoint?
- 3. 如何使用Checkpoint?
- 3.1 TensorFlow 中的 Checkpoint
- 3.2 PyTorch 中的 Checkpoint
- 3.3 transformers中的Checkpoint
- 4. 在 NLP 任务中的应用
- 5. 总结
- 6. 参考资料
引言
在深度学习训练过程中,模型的训练往往需要较长的时间,并且计算资源昂贵。由于训练过程中可能遇到各种意外情况,比如断电、程序崩溃,甚至想要在不同阶段对比模型的表现,因此我们需要一种机制来保存训练进度,以便可以随时恢复。这就是**Checkpoint(检查点)**的作用。
对于刚入门深度学习的小伙伴,理解Checkpoint的概念并合理使用它,可以大大提高模型训练的稳定性和效率。本文将详细介绍Checkpoint的概念、用途以及如何在NLP任务中使用它。
1. 什么是Checkpoint?
Checkpoint(检查点)是指在训练过程中,定期保存模型的状态,包括模型的权重参数、优化器状态以及训练进度(如当前的epoch数)。这样,即使训练中断,我们也可以从最近的Checkpoint恢复训练,而不是从头开始。
简单来说,Checkpoint 就像一个存档点,让我们能够在不重头训练的情况下继续优化模型。
一个大模型的checkpoint可能以如下文件形式储存:

2. 为什么需要Checkpoint?
Checkpoint 的主要作用包括:
-
防止训练中断导致的损失:训练神经网络需要消耗大量计算资源,训练时间可能长达数小时甚至数天。如果训练因突发情况(如断电、程序崩溃)中断,Checkpoint 可以帮助我们恢复进度。
-
支持断点续训:当训练过程中需要调整超参数或遇到不可预见的问题时,我们可以从最近的Checkpoint继续训练,而不必重新训练整个模型。
-
保存最佳模型:在训练过程中,我们通常会评估模型在验证集上的表现。通过Checkpoint,我们可以保存最优表现的模型,而不是仅仅保存最后一次训练的结果。
-
支持迁移学习:在实际应用中,我们经常会使用预训练模型(如BERT、GPT等),然后在特定任务上进行微调(fine-tuning)。这些预训练模型的Checkpoint可以用作新的任务的起点,而不必从零开始训练。
3. 如何使用Checkpoint?
在深度学习框架(如 TensorFlow 和 PyTorch)中,Checkpoint 的使用非常方便。下面分别介绍在 TensorFlow 和 PyTorch 中如何保存和加载 Checkpoint。
3.1 TensorFlow 中的 Checkpoint
保存Checkpoint:
在 TensorFlow(Keras)中,可以使用 ModelCheckpoint 回调函数来实现自动保存。
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint# 创建简单的模型
model = tf.keras.Sequential([tf.keras.layers.Dense(128, activation='relu', input_shape=(100,)),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])# 设置Checkpoint,保存最优模型
checkpoint_callback = ModelCheckpoint(filepath='best_model.h5', # 保存路径save_best_only=True, # 仅保存最优模型monitor='val_loss', # 监控的指标mode='min', # val_loss 越小越好verbose=1 # 输出日志
)# 训练模型,并使用Checkpoint
model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=10, callbacks=[checkpoint_callback])
加载Checkpoint:
from tensorflow.keras.models import load_model# 加载已保存的模型
model = load_model('best_model.h5')
这样,我们就可以在训练过程中自动保存最优模型,并在需要时加载它。
3.2 PyTorch 中的 Checkpoint
在 PyTorch 中,我们可以使用 torch.save 和 torch.load 来手动保存和加载模型。
保存Checkpoint:
import torch# 假设 model 是我们的神经网络,optimizer 是优化器
checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict()
}
torch.save(checkpoint, 'checkpoint.pth')
加载Checkpoint:
# 加载Checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
在 PyTorch 中,保存和加载 Checkpoint 需要手动指定模型和优化器的状态,而 TensorFlow 处理起来更为自动化。
3.3 transformers中的Checkpoint
如果直接用transformers的Trainer的话,就会自动根据TrainingArguments的参数来设置checkpoint保存策略。具体的参数有save_strategy、save_steps、save_total_limit、load_best_model_at_end等,可以看我之前写过的关于transformers包的博文。
epochs = 10
lr = 2e-5
train_bs = 8
eval_bs = train_bs * 2training_args = TrainingArguments(output_dir=output_dir,num_train_epochs=epochs,learning_rate=lr,per_device_train_batch_size=train_bs,per_device_eval_batch_size=eval_bs,evaluation_strategy="epoch",logging_steps=logging_steps
)
断点续训:
# Trainer 的定义
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset
)# 从最近的检查点恢复训练
trainer.train(resume_from_checkpoint=True)
4. 在 NLP 任务中的应用
在自然语言处理任务中,Checkpoint 主要用于:
- 训练 Transformer 模型(如 BERT、GPT)时,保存和恢复训练进度。
- 微调预训练模型时,从预训练权重(如
bert-base-uncased)加载 Checkpoint 进行继续训练。 - 文本生成任务(如 Seq2Seq 模型),确保中断时可以从最近的 Checkpoint 继续训练。
5. 总结
- Checkpoint 是深度学习训练过程中保存模型状态的机制,可以防止训练中断带来的损失。
- 它有助于断点续训、保存最佳模型以及进行迁移学习。
- 在 TensorFlow 和 PyTorch 中都有方便的方式来保存和加载 Checkpoint。
- 在 NLP 任务中,Checkpoint 被广泛用于 Transformer 训练、预训练模型微调等任务。
6. 参考资料
- 模型训练当中 checkpoint 作用是什么 - 简书

相关文章:
深度学习中的Checkpoint是什么?
诸神缄默不语-个人CSDN博文目录 文章目录 引言1. 什么是Checkpoint?2. 为什么需要Checkpoint?3. 如何使用Checkpoint?3.1 TensorFlow 中的 Checkpoint3.2 PyTorch 中的 Checkpoint3.3 transformers中的Checkpoint 4. 在 NLP 任务中的应用5. 总…...
字符串高频算法:无重复字符的最长子串
题目 3. 无重复字符的最长子串 - 力扣(LeetCode) 解题思路 思路 方法: 滑动窗口 [!简单思路] [^1]以示例一中的字符串 abcabcbb 为例,找出从每一个字符开始的,不包含重复字符的最长子串,其中最长的那个字符串即为答…...
用深度学习模型构建海洋动物图像分类保姆教程
使用深度学习模型构建深度学习海洋动物图像分类模型的完整步骤如下,分为关键阶段和详细操作说明: 1. 数据准备与预处理 1.1 数据集组织 按类别分文件夹存储图像,例如:dataset/train/class1/class2/...val/class1/class2/...test…...
51单片机俄罗斯方块计分函数
/************************************************************************************************************** * 名称:scoring * 功能:计分 * 参数:NULL * 返回:NULL * 备注:采用非阻塞延时 ****************…...
Android开发获取缓存,删除缓存
Android开发获取缓存,删除缓存 app设置中往往有清理缓存的功能。会显示当前缓存时多少,然后可以点击清理缓存 直接上代码: object CacheHelper {/*** 获取缓存大小* param context* return* throws Exception*/JvmStaticfun getTotalCache…...
npm无法加载文件 因为此系统禁止运行脚本
安装nodejs后遇到问题: 在项目里【node -v】可以打印出来,【npm -v】打印不出来,显示npm无法加载文件 因为此系统禁止运行脚本。 但是在winr,cmd里【node -v】,【npm -v】都也可打印出来。 解决方法: cmd里可以打印出…...
NLP_[2]-认识文本预处理
文章目录 1 认识文本预处理1 文本预处理及其作用2. 文本预处理中包含的主要环节2.1 文本处理的基本方法2.2 文本张量表示方法2.3 文本语料的数据分析2.4 文本特征处理2.5数据增强方法2.6 重要说明 2 文本处理的基本方法1. 什么是分词2 什么是命名实体识别3 什么是词性标注 1 认…...
知识库升级新思路:用生成式AI打造智能知识助手
在当今信息爆炸的时代,企业和组织面临着海量数据的处理和管理挑战。知识库管理系统(Knowledge Base Management System, KBMS)作为一种有效的信息管理工具,帮助企业存储、组织和检索知识。然而,传统的知识库系统往往依…...
蚂蚁爬行最短问题
初二数学问题记录 分析过程 考点:2点之间直线最短。 思考过程:将EBCF以BC为边翻折,EF边翻折后为,则A为蚂蚁需要爬行的最小距离。...
【电机控制器】STC8H1K芯片——低功耗
【电机控制器】STC8H1K芯片——低功耗 文章目录 [TOC](文章目录) 前言一、芯片手册说明二、IDLE模式三、PD模式四、PD模式唤醒五、实验验证1.接线2.视频(待填) 六、参考资料总结 前言 使用工具: 1.STC仿真器烧录器 提示:以下是本…...
【专题】2024-2025人工智能代理深度剖析:GenAI 前沿、LangChain 现状及演进影响与发展趋势报告汇总PDF洞察(附原数据表)
原文链接:https://tecdat.cn/?p39630 在科技飞速发展的当下,人工智能代理正经历着深刻的变革,其能力演变已然成为重塑各行业格局的关键力量。从早期简单的规则执行,到如今复杂的自主决策与多智能体协作,人工智能代理…...
SAP-ABAP:SAP的第一行REPORT后面后缀作用详解
在SAP ABAP中,REPORT 语句是定义报表程序的核心语句,其后可以跟多个后缀(参数),用于控制报表的行为和属性。以下是常见的 REPORT 后缀及其作用的详解: 程序名称 • 语法:REPORT <program_nam…...
25/2/8 <机器人基础> 阻抗控制
1. 什么是阻抗控制? 阻抗控制旨在通过调节机器人与环境的相互作用,控制其动态行为。阻抗可以理解为一个力和位移之间的关系,涉及力、速度和位置的协同控制。 2. 阻抗控制的基本概念 力控制:根据感测的外力调节机械手的动作。位置…...
java-list深入理解(流程图)
List源码学习: 此篇文章使用流程图和源码方式,理解List的源码,方便记忆 核心逻辑流程图: #mermaid-svg-BBrPrDuqUdLMtHvj {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-BBrPrDuqUdLMtHvj .error-icon{fill:#…...
Sparse4D v3:推进端到端3D检测和跟踪
论文地址:2311.11722 (arxiv.org) 代码地址:HorizonRobotics/Sparse4D (github.com) 在自动驾驶感知系统中,3D 检测和跟踪是两项基本任务。本文在 Sparse4D 框架的基础上更深入地探讨了这一领域。作者引入了两个辅助训练任务(Temp…...
LeetCode781 森林中的兔子
问题描述 在一片神秘的森林里,住着许多兔子,但是我们并不知道兔子的具体数量。现在,我们对其中若干只兔子进行提问,问题是 “还有多少只兔子与你(指被提问的兔子)颜色相同?” 我们将每只兔子的…...
M系列/Mac安装配置Node.js全栈开发环境(nvm+npm+yarn)
一、安装 nvm(Node Version Manager) 打开终端,使用 curl 在 M 系列 Mac 上安装 nvm: curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.1/install.sh | bash对于非 M 系列的 Intel Mac,上述命令同样适…...
Dify使用
1. 概述 官网:Dify.AI 生成式 AI 应用创新引擎 文档:欢迎使用 Dify | Dify GITHUB:langgenius/dify: Dify is an open-source LLM app development platform. Difys intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, ob…...
借助 Cursor 快速实现小程序前端开发
借助 Cursor 快速实现小程序前端开发 在当今快节奏的互联网时代,小程序因其便捷性、高效性以及无需下载安装的特点,成为众多企业和开发者关注的焦点。然而,小程序的开发往往需要耗费大量的时间和精力,尤其是在前端开发阶段。幸运…...
python 语音识别方案对比
目录 一、语音识别 二、代码实践 2.1 使用vosk三方库 2.2 使用SpeechRecognition 2.3 使用Whisper 一、语音识别 今天识别了别人做的这个app,觉得虽然是个日记app 但是用来学英语也挺好的,能进行语音识别,然后矫正语法,自己说的时候 ,实在不知道怎么说可以先乱说,然…...
Hanoi ( 2022 ICPC Southeastern Europe Regional Contest )
Hanoi ( 2022 ICPC Southeastern Europe Regional Contest ) The original problem “Towers of Hanoi” is about moving n n n circular disks of distinct sizes between 3 3 3 rods. In one move, the player can move only the top disk from on…...
革新在线购物体验:CatV2TON引领虚拟试穿技术新纪元
在这个数字化飞速发展的时代,图像与视频合成技术正以前所未有的速度重塑着我们的生活,尤其在在线零售领域,一场关于购物体验的革命正在悄然上演。想象一下,无需亲自试穿,仅凭一张照片或一段视频,就能精准预…...
【Git】ssh如何配置gitlab+github
当我们工作项目在gitlab上,又希望同时能更新自己个人的github项目时,可能因为隐私问题,不能使用同一′密钥。就需要在本地电脑上分别配置两次ssh。 1、分别创建ssh key 在用户主目录下,查询是否存在“.ssh”文件: 如…...
全国路网矢量shp数据(分不同类型分省份)
科研练习数据 全国路网矢量shp数据(分不同类型分省份) 有需要的自取 数据格式:shp(线) 数据包含类型:城市主干道、城市次干道、城市快速路、城市支路、高速公路、内部道路、人行道、乡村道路、自行车道路…...
音频进阶学习十二——Z变换一(Z变换、收敛域、性质与定理)
文章目录 前言一、Z变换1.Z变换的作用2.Z变换公式3.Z的状态表示1) r 1 r1 r12) 0 < r < 1 0<r<1 0<r<13) r > 1 r>1 r>1 4.关于Z的解释 二、收敛域1.收敛域的定义2.收敛域的表示方式3.ROC的分析1)当 …...
使用Redis解决使用Session登录带来的共享问题
在学习项目的过程中遇到了使用Session实现登录功能所带来的共享问题,此问题可以使用Redis来解决,也即是加上一层来解决问题。 接下来介绍一些Session的相关内容并且采用Session实现登录功能(并附上代码),进行分析其存在…...
STM32F1学习——USART串口通信
一、USART通用同步异步收发机 USART的全称是Universal Synchronous/Asynchronous Receiver Transmitter , 通用同步异步收发机,但由于他主要以异步通信为主,所以他也叫UART。它遵循TTL电平标准,是一种全双工异步通信标准ÿ…...
[概率论] 随机变量
Kolmogorov 定义的随机变量是基于测度论和实变函数的。这是因为随机变量的概念需要精确地定义其可能的取值、发生的概率以及这些事件之间的依赖关系。 测度论:在数学中,测度论是用来研究集合大小的理论,特别是无穷可数集和无界集的大小。对于…...
Docker 部署 MinIO | 国内阿里镜像
一、导读 Minio 是个基于 Golang 编写的开源对象存储套件,基于Apache License v2.0开源协议,虽然轻量,却拥有着不错的性能。它兼容亚马逊S3云存储服务接口。可以很简单的和其他应用结合使用,例如 NodeJS、Redis、MySQL等。 二、…...
探索Aviator:轻量级Java动态表达式求值引擎的使用指南
目录 一、快速介绍 (一)Aviator (二)Aviator、IKExpression、QLExpress比较和建议 二、基本应用使用手册 1.执行表达式 2.使用变量 3.exec 方法 4.调用函数 调用内置函数 调用字符串函数 调用自定义函数 5.编译表达式…...
