当前位置: 首页 > news >正文

Pytorch单、多GPU和CPU训练模型保存和加载

Pytorch多GPU训练模型保存和加载

在多GPU训练中,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,这会在模型的参数名前加上module前缀。因此,在保存模型时,需要使用model.module.state_dict()来获取模型的状态字典,以确保保存的参数名与模型定义中的参数名一致。(本质上原来的model还是存在的,参数也会同步更新)

  1. 多GPU训练模型保存
    在多GPU训练时,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,这会在模型的参数名前加上module前缀。因此,在保存模型时,需要使用model.module.state_dict()来获取模型的状态字典,以确保保存的参数名与模型定义中的参数名一致。

  2. 单GPU或CPU加载模型
    当在单GPU或CPU上加载模型时,如果直接使用model.state_dict()保存的模型,由于缺少module前缀,会导致参数名不匹配,从而无法正确加载模型。因此,在保存多GPU训练的模型时,应该使用model.module.state_dict()来保存模型的状态字典,这样在单GPU或CPU上加载模型时,可以直接加载,不会出现参数名不匹配的问题。

  3. 示例代码
    以下是一个示例代码,展示了如何在多GPU训练时保存模型,并在单GPU或CPU上加载模型:

import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"	#设置GPU编号
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假设这是你的模型定义
class YourModel(nn.Module):def __init__(self):super(YourModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 创建模型实例
model = YourModel()# 将模型移动到多GPU上
if torch.cuda.device_count() > 1:model = nn.DataParallel(model)model = model.to(device)
else:model = model.to(device)
······
# 假设这是你的训练代码,训练完成后保存模型
if torch.cuda.device_count() > 1:torch.save(model.module.state_dict(), 'model.pth')
else:torch.save(model.state_dict(), 'model.pth')# 在单、多GPU或CPU上加载模型
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to(device)

2 在多GPU训练得到的模型加载时,通常需要考虑以下几个步骤:

  1. 模型保存
    在多GPU训练时,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中。因此,在保存模型时,需要确保保存的是模型的state_dict而不是整个模型对象。例如:
if torch.cuda.device_count() > 1:torch.save(model.module.state_dict(), 'model.pth')
else:torch.save(model.state_dict(), 'model.pth')
  1. 模型加载
    在加载模型时,首先需要创建模型的实例,然后使用load_state_dict方法来加载保存的权重。如果模型是在多GPU环境下训练的,那么在加载时也应该使用torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel来包装模型。例如:
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth'))
model = model.to('cuda')
  1. 注意事项
    在加载模型时,需要注意以下几点:

如果模型是在多GPU环境下训练的,那么在加载时也应该使用相同数量的GPU,或者使用torch.nn.DataParallel来包装模型,即使只有一个GPU可用。
如果模型是在分布式训练环境下训练的,那么在加载时也应该使用torch.nn.parallel.DistributedDataParallel来包装模型。
如果模型是在混合精度训练(如使用了torch.cuda.amp)下训练的,那么在加载模型后,应该恢复之前的精度设置。

3 为了避免模型保存和加载出错

在多GPU训练的模型使用了torch.nn.DataParallel来包装模型,但本质上原来的model是依然存在的,且参数会同步更新:

  1. torch.nn.DataParallel 的工作原理
    torch.nn.DataParallel 是 PyTorch 提供的一个类,用于在多个 GPU 上并行训练模型。它的工作原理如下:
    模型复制:DataParallel 会在每个 GPU 上创建模型的副本。
    数据分发:输入数据会被分发到各个 GPU 上。
    前向传播:每个 GPU 上的模型副本会独立进行前向传播计算。
    梯度收集:所有 GPU 上的梯度会被收集并汇总到主 GPU 上。
    参数更新:主 GPU 上的优化器会根据汇总后的梯度更新模型参数,然后将更新后的参数同步回其他 GPU。
  2. 模型参数更新
    当你使用 model_train = torch.nn.DataParallel(model) 后,model_train 实际上是一个包装了原始模型 model 的对象。虽然 model_train 是多GPU并行的版本,但它的参数更新是通过主 GPU 上的优化器完成的,并且这些更新会同步回原始模型 model
    因此,model 的参数确实会被更新。具体来说:
    前向传播和反向传播:在 train_model 函数中,model_train 用于前向传播和反向传播。
    参数更新:优化器 optimizer 使用的是 model.parameters(),即原始模型的参数。在每次迭代中,优化器会根据汇总后的梯度更新这些参数。
    参数同步:更新后的参数会自动同步到 model_train 中的各个 GPU 副本。
    因此可以使用如下代码,加载模型和保存模型:
import torch
import torch.nn as nn
import os
os.os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"	#设置GPU编号
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 假设这是你的模型定义
class YourModel(nn.Module):def __init__(self):super(YourModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):return self.fc(x)# 创建模型实例
model = YourModel()# 将模型移动到多GPU上,单GPU依然适用
if torch.cuda.device_count() > 1:model_train = nn.DataParallel(model)model_train = model_train.to(device)
else:model_train = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)#注意这是model的参数
······
output = model_train(input)	# 多卡时训练的输入和输出,注意这是model_train# 假设这是你的训练代码,训练完成后保存模型
torch.save(model.state_dict(), 'model.pth')	#注意这是model
  • 再在单/多GPU或CPU上加载模型,都不会报错,因为这里的model不是包装体,不带module
model = YourModel()
if torch.cuda.device_count() > 1:model = torch.nn.DataParallel(model)
model.load_state_dict(torch.load('model.pth',map_location = device))
model = model.to(device)

相关文章:

Pytorch单、多GPU和CPU训练模型保存和加载

Pytorch多GPU训练模型保存和加载 在多GPU训练中,模型通常被包装在torch.nn.DataParallel或torch.nn.parallel.DistributedDataParallel中,这会在模型的参数名前加上module前缀。因此,在保存模型时,需要使用model.module.state_di…...

Karate 介绍与快速示例(API测试自动化、模拟、性能测试与UI自动化工具)

Karate是一个将API测试自动化、模拟、性能测试甚至UI自动化结合到一个统一框架中的开源工具。 Karate使用Gherkin 的BDD语法,是语言中性的,即使是非程序员也很容易。断言和HTML报告是内置的,支持并行运行测试以提高速度Karate 是用Java语言编写, 可以在Java 项目项目中运行…...

Pytest 高级用法:间接参数化

文章目录 1. 引言2. 基础概念2.1 Fixture2.2 参数化 3. 代码实例3.1 基础设置3.2 测试用例示例示例 1:基础的间接参数化示例 2:通过 request 获取参数值示例 3:多参数组合测试示例 4:部分间接参数化 4. 最佳实践5. 总结参考资料 1…...

第07章 存储管理(一)

一、磁盘简介 1.1 名称称呼 磁盘/硬盘/disk是同一个东西,不同于内存的是容量比较大。 1.2 类型 机械:机械硬盘即是传统普通硬盘,主要由:盘片,磁头,盘片转轴及控制电机,磁头控制器&#xff0…...

Go语言的 的设计模式(Design Patterns)核心知识

Go语言的设计模式(Design Patterns)核心知识 Go语言(Golang)是一种静态类型、编译型的编程语言,自2009年由Google正式推出以来,因其高效的性能、卓越的并发能力以及简洁的语法受到广泛欢迎。在软件开发中&…...

js函数预览图片:支持鼠标和手势拖拽缩放

对之前的方式改进:原生js实现图片预览控件,支持丝滑拖拽,滚轮放缩,放缩聚焦_js图片预览-CSDN博客 /*** 图片预览函数,调用后自动预览图片* param {图片地址} imgurl*/ function openImagePreview(imgurl) {if (!imgurl…...

用QT实现 端口扫描工具1

安装在线QT,尽量是完整地自己进行安装,不然会少包 参考【保姆级图文教程】QT下载、安装、入门、配置VS Qt环境-CSDN博客 临时存储空间不够。 Windows系统通常会使用C盘来存储临时文件。 修改临时文件存储位置 打开系统属性: 右键点击“此电…...

设计模式 结构型 适配器模式(Adapter Pattern)与 常见技术框架应用 解析

适配器模式(Adapter Pattern)是一种结构型设计模式,它允许将一个类的接口转换成客户端所期望的另一个接口,从而使原本因接口不兼容而无法一起工作的类能够协同工作。这种设计模式在软件开发中非常有用,尤其是在需要集成…...

vue 项目集成 electron 和 electron 打包及环境配置

vue electron 开发桌面端应用 安装 electron npm i electron -D记得加上-D,electron 需添加到devDependencies,如果添加到dependencies后面运行可能会报错 根目录创建electron文件夹,在electron文件夹创建main.js(或者backgrou…...

vscode如何离线安装插件

在没有网络的时候,如果要安装插件,就会麻烦一些,需要通过离线安装的方式进行。下面记录如何在vscode离线安装插件。 一、下载离线插件 在一台能联网的电脑中,下载好离线插件,拷贝到无法联网的电脑上。等待安装。 vscode插件商店地址:https://marketplace.visualstudio.co…...

计算机网络常见面试题及解答

以下是计算机网络中常见的面试题及解答,按主题分类: --- ## **一、基础概念** ### **1. OSI 七层模型和 TCP/IP 模型的区别是什么?** **答:** - **OSI 七层模型:** - 应用层、表示层、会话层、传输层、网络层、数…...

举例说明AI模型怎么聚类,最后神经网络怎么保存

举例说明怎么聚类,最后神经网络怎么保存 目录 举例说明怎么聚类,最后神经网络怎么保存K - Means聚类算法实现神经元特征聚类划分成不同专家的原理和过程 特征提取: 首先,需要从神经元中提取有代表性的特征。例如,对于一个多层感知机(MLP)中的神经元,其权重向量可以作为特…...

HarmonyOS NEXT应用开发实战(一):边学边玩,从零开发一款影视APP

引言 学习一项技能,最好也最快的办法就是动手实战。通过自己给自己找项目练习,不仅能够激发兴趣,还能从开发实战中不断总结经验。这种学习方法是最为高效的。今天,我们将通过开发一款名为“爱影家”的影视APP,来学习H…...

STM32G0B1 can Error_Handler 解决方法

问题现象 MCU上电,发送0x13帧数据固定进入 Error_Handler 硬件介绍 MCU :STM32G0B1 can:NSI1042 tx 接TX RX 接RX 折腾了一下午,无解,问题依旧; 对比测试 STM32G431 手头有块G431 官方评估版CAN 模块; 同样的…...

使用 `llama_index` 构建智能问答系统:多种文档切片方法的评估

使用 llama_index 构建智能问答系统:多种文档切片方法的评估 代码优化与解析1. **代码结构优化**2. **日志管理**3. **环境变量管理**4. **模型初始化**5. **提示模板更新**6. **问答函数优化**7. **索引构建与查询引擎**8. **节点解析器测试** 总结 在现代自然语言…...

【大模型】7 天 AI 大模型学习

7 天 AI 大模型学习 Day 2 今天是 7 天AI 大模型学习的第二天 😄,今天我将会学习 Transformer 、Encoder-based and Decoder-Based LLMs 等 。如果有感兴趣的,就和我一起开始吧 ~ 课程链接 :2025年快速吃透AI大模型&am…...

软件工程大复习之(四)——面向对象与UML

4.1 面向对象概述 面向对象(OO)是一种编程范式,它将数据和处理数据的方法封装在对象中。面向对象的主要概念包括: 对象:实例化的数据和方法的集合。类:对象的蓝图或模板。封装:隐藏对象的内部…...

【Linux】shell命令

目录 shell的基本命令 shell - 贝壳 外在保护工具 用户、shell、内核、硬件之间的关系 解析器的分类: shell命令格式 history -历史记录查询 修改环境变量的值: shell中的特殊字符 通配符 管道 | 输入输出重定向 命令置换符 shell的基本命…...

ValuesRAG:以检索增强情境学习强化文化对齐

随着大型语言模型(LLMs)的迅猛发展,其在各个领域展现出强大的能力。然而,训练数据中西方中心主义的倾向,使得 LLMs 在文化价值观一致性方面面临严峻挑战,这一问题在跨文化场景中尤为突出,可能导…...

【机器学习篇】交通革命:机器学习如何引领未来的道路创新

嘿,你知道吗?机器学习正在交通领域掀起一场革命啦!它将如何引领未来道路创新呢 本文有精彩的 C 代码演示、实用的图片解释,还有超多干货,保证让你大开眼界,点赞收藏关注, 开启一场奇妙的探索之…...

Vim 调用外部命令学习笔记

Vim 外部命令集成完全指南 文章目录 Vim 外部命令集成完全指南核心概念理解命令语法解析语法对比 常用外部命令详解文本排序与去重文本筛选与搜索高级 grep 搜索技巧文本替换与编辑字符处理高级文本处理编程语言处理其他实用命令 范围操作示例指定行范围处理复合命令示例 实用技…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以? 在 Golang 的面试中,map 类型的使用是一个常见的考点,其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例

一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...

遍历 Map 类型集合的方法汇总

1 方法一 先用方法 keySet() 获取集合中的所有键。再通过 gey(key) 方法用对应键获取值 import java.util.HashMap; import java.util.Set;public class Test {public static void main(String[] args) {HashMap hashMap new HashMap();hashMap.put("语文",99);has…...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代,加密货币作为一种新兴的金融现象,正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而,加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下,稳定…...

零基础在实践中学习网络安全-皮卡丘靶场(第九期-Unsafe Fileupload模块)(yakit方式)

本期内容并不是很难,相信大家会学的很愉快,当然对于有后端基础的朋友来说,本期内容更加容易了解,当然没有基础的也别担心,本期内容会详细解释有关内容 本期用到的软件:yakit(因为经过之前好多期…...

虚拟电厂发展三大趋势:市场化、技术主导、车网互联

市场化:从政策驱动到多元盈利 政策全面赋能 2025年4月,国家发改委、能源局发布《关于加快推进虚拟电厂发展的指导意见》,首次明确虚拟电厂为“独立市场主体”,提出硬性目标:2027年全国调节能力≥2000万千瓦&#xff0…...

Linux nano命令的基本使用

参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时,显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...

Python+ZeroMQ实战:智能车辆状态监控与模拟模式自动切换

目录 关键点 技术实现1 技术实现2 摘要: 本文将介绍如何利用Python和ZeroMQ消息队列构建一个智能车辆状态监控系统。系统能够根据时间策略自动切换驾驶模式(自动驾驶、人工驾驶、远程驾驶、主动安全),并通过实时消息推送更新车…...

系统掌握PyTorch:图解张量、Autograd、DataLoader、nn.Module与实战模型

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文通过代码驱动的方式,系统讲解PyTorch核心概念和实战技巧,涵盖张量操作、自动微分、数据加载、模型构建和训练全流程&#…...