当前位置: 首页 > news >正文

大模型训练框架DeepSpeed使用入门(1): 训练设置

文章目录

  • 一、安装
  • 二、训练设置
    • Step1 第一步参数解析
    • Step2 初始化后端
    • Step3 训练初始化
  • 三、训练代码展示


官方文档直接抄过来,留个笔记。
https://deepspeed.readthedocs.io/en/latest/initialize.html

使用案例来自:
https://github.com/OvJat/DeepSpeedTutorial


大模型训练的痛点是模型参数过大,动辄上百亿,如果单靠单个GPU来完成训练基本不可能。所以需要多卡或者分布式训练来完成这项工作。

DeepSpeed是由Microsoft提供的分布式训练工具,旨在支持更大规模的模型和提供更多的优化策略和工具。对于更大模型的训练来说,DeepSpeed提供了更多策略,例如:Zero、Offload等。

本文简单介绍下如何使用DeepSpeed。


一、安装

pip install deepspeed

二、训练设置

Step1 第一步参数解析

DeepSpeed 使用 argparse 来应用控制台的设置,使用

deepspeed.add_config_arguments()

可以将DeepSpeed内置的参数增加到我们自己的应用参数解析中。

parser = argparse.ArgumentParser(description='My training script.')
parser.add_argument('--local_rank', type=int, default=-1,help='local rank passed from distributed launcher')
# Include DeepSpeed configuration arguments
parser = deepspeed.add_config_arguments(parser)
cmd_args = parser.parse_args()

Step2 初始化后端

与Step3中的 deepspeed.initialize() 不同,
直接调用即可。
一般发生在以下场景

when using model parallelism, pipeline parallelism, or certain data loader scenarios.

在Step3的initialize前,进行调用

deepspeed.init_distributed()

Step3 训练初始化

首先调用 deepspeed.initialize() 进行初始化,是整个调用DeepSpeed训练的入口。
调用后,如果分布式后端没有被初始化后,此时会初始化分布式后端。
使用案例:

model_engine, optimizer, _, _ = deepspeed.initialize(args=cmd_args,model=net,model_parameters=net.parameters(),training_data=ds)

API如下:

def initialize(args=None,model: torch.nn.Module = None,optimizer: Optional[Union[Optimizer, DeepSpeedOptimizerCallable]] = None,model_parameters: Optional[torch.nn.Module] = None,training_data: Optional[torch.utils.data.Dataset] = None,lr_scheduler: Optional[Union[_LRScheduler, DeepSpeedSchedulerCallable]] = None,distributed_port: int = TORCH_DISTRIBUTED_DEFAULT_PORT,mpu=None,dist_init_required: Optional[bool] = None,collate_fn=None,config=None,config_params=None):"""Initialize the DeepSpeed Engine.Arguments:args: an object containing local_rank and deepspeed_config fields.This is optional if `config` is passed.model: Required: nn.module class before apply any wrappersoptimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object.This overrides any optimizer definition in the DeepSpeed json config.model_parameters: Optional: An iterable of torch.Tensors or dicts.Specifies what Tensors should be optimized.training_data: Optional: Dataset of type torch.utils.data.Datasetlr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methodsdistributed_port: Optional: Master node (rank 0)'s free port that needs to be used for communication during distributed trainingmpu: Optional: A model parallelism unit object that implementsget_{model,data}_parallel_{rank,group,world_size}()dist_init_required: Optional: None will auto-initialize torch distributed if needed,otherwise the user can force it to be initialized or not via boolean.collate_fn: Optional: Merges a list of samples to form amini-batch of Tensor(s).  Used when using batched loading from amap-style dataset.config: Optional: Instead of requiring args.deepspeed_config you can pass your deepspeed configas an argument instead, as a path or a dictionary.config_params: Optional: Same as `config`, kept for backwards compatibility.Returns:A tuple of ``engine``, ``optimizer``, ``training_dataloader``, ``lr_scheduler``* ``engine``: DeepSpeed runtime engine which wraps the client model for distributed training.* ``optimizer``: Wrapped optimizer if a user defined ``optimizer`` is supplied, or ifoptimizer is specified in json config else ``None``.* ``training_dataloader``: DeepSpeed dataloader if ``training_data`` was supplied,otherwise ``None``.* ``lr_scheduler``: Wrapped lr scheduler if user ``lr_scheduler`` is passed, orif ``lr_scheduler`` specified in JSON configuration. Otherwise ``None``."""

三、训练代码展示

def parse_arguments():import argparseparser = argparse.ArgumentParser(description='deepspeed training script.')parser.add_argument('--local_rank', type=int, default=-1,help='local rank passed from distributed launcher')# Include DeepSpeed configuration argumentsparser = deepspeed.add_config_arguments(parser)args = parser.parse_args()return argsdef train():args = parse_arguments()# init distributeddeepspeed.init_distributed()# init modelmodel = MyClassifier(3, 100, ch_multi=128)# init datasetds = MyDataset((3, 512, 512), 100, sample_count=int(1e6))# init engineengine, optimizer, training_dataloader, lr_scheduler = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters(),training_data=ds,# config=deepspeed_config,)# load checkpointengine.load_checkpoint("./data/checkpoints/MyClassifier/")# trainlast_time = time.time()loss_list = []echo_interval = 10engine.train()for step, (xx, yy) in enumerate(training_dataloader):step += 1xx = xx.to(device=engine.device, dtype=torch.float16)yy = yy.to(device=engine.device, dtype=torch.long).reshape(-1)outputs = engine(xx)loss = tnf.cross_entropy(outputs, yy)engine.backward(loss)engine.step()loss_list.append(loss.detach().cpu().numpy())if step % echo_interval == 0:loss_avg = np.mean(loss_list[-echo_interval:])used_time = time.time() - last_timetime_p_step = used_time / echo_intervalif args.local_rank == 0:logging.info("[Train Step] Step:{:10d}  Loss:{:8.4f} | Time/Batch: {:6.4f}s",step, loss_avg, time_p_step,)last_time = time.time()# save checkpointengine.save_checkpoint("./data/checkpoints/MyClassifier/")

最后~

码字不易~~

独乐不如众乐~~

如有帮助,欢迎点赞+收藏~~


相关文章:

大模型训练框架DeepSpeed使用入门(1): 训练设置

文章目录 一、安装二、训练设置Step1 第一步参数解析Step2 初始化后端Step3 训练初始化 三、训练代码展示 官方文档直接抄过来,留个笔记。 https://deepspeed.readthedocs.io/en/latest/initialize.html 使用案例来自: https://github.com/OvJat/DeepSp…...

自定义类型——结构体、枚举和联合

自定义类型——结构体、枚举和联合 结构体结构体的声明匿名结构体结构体的自引用结构体的初始化结构体的内存对齐修改默认对齐数结构体传参 位段枚举联合 结构体 结构是一些值的集合,这些值被称为成员变量,结构的每个成员可以是不同类型的变量。 数组是…...

Windows11系统安装Mysql8之后,启动服务net start mysql报错“服务没有响应控制功能”的解决办法

问题 系统环境:Windows11 数据库版本:Mysql8 双击安装,一路下一步,完成,很顺利,但是开启服务后 net start mysql 报错: 服务没有响应控制功能。 请键入 NET HELPMSG 2186 以获得更多的帮助 不…...

WIFI模块的AT指令联网数据交互--第十天

1.1.蓝牙,ESP-01s,Zigbee, NB-Iot等通信模块都是基于AT指令的设计 初始配置和验证 ESP-01s出厂波特率正常是115200, 注意:AT指令,控制类都要加回车,数据传输时不加回车 1.2.上电后,通过串口输出一串系统…...

设计模式Java实现-迭代器模式

✨这里是第七人格的博客✨小七,欢迎您的到来~✨ 🍅系列专栏:设计模式🍅 ✈️本篇内容: 迭代器模式✈️ 🍱 本篇收录完整代码地址:https://gitee.com/diqirenge/design-pattern 🍱 楔子 很久…...

单页源码加密屋zip文件加密API源码

简介: 单页源码加密屋zip文件加密API源码 api源码里面的参数已改好,往服务器或主机一丢就行,出现不能加密了就是加密次数达到上限了,告诉我在到后台修改加密次数 点击下载...

47.全排列

1.题目 47. 全排列 II - 力扣&#xff08;LeetCode&#xff09;https://leetcode.cn/problems/permutations-ii/description/ 2.思路 注意剪枝的条件 3.代码 class Solution {vector<int> path;vector<vector<int>> ret;bool check[9]; public:vector<…...

呼叫中心系统选pscc好还是okcc好

选择PSCC&#xff08;商业软件呼叫中心&#xff09;还是OKCC&#xff08;开源呼叫中心&#xff09;&#xff0c;应基于以下几个关键因素来决定&#xff1a; 技术能力&#xff1a;如果企业拥有或愿意投入资源培养内部技术团队&#xff0c;开源解决方案可能更合适&#xff0c;因为…...

【SRC实战】前端脱敏信息泄露

挖个洞先 https://mp.weixin.qq.com/s/xnCQQCAneT21vYH8Q3OCpw “ 以下漏洞均为实验靶场&#xff0c;如有雷同&#xff0c;纯属巧合 ” 01 — 漏洞证明 一、前端脱敏&#xff0c;请求包泄露明文 “ 前端脱敏处理&#xff0c;请求包是否存在泄露&#xff1f; ” 1、获取验…...

区块链 | NFT 水印:Review on Watermarking Techniques(三)

&#x1f34d;原文&#xff1a;Review on Watermarking Techniques Aiming Authentication of Digital Image Artistic Works Minted as NFTs into Blockchains 一个 NFT 的水印认证协议 可以引入第三方实体来实现对交易的认证&#xff0c;即通过使用 R S A \mathsf{RSA} RSA…...

初识C语言——第十九天

for循环 1.简单概述 2.执行流程 3.建议事项&#xff1a;...

软件需求工程习题

1.&#xff08;面谈&#xff09;是需求获取活动中发生的需求工程师和用户间面对面的会见。 2.使用原型法进行需求获取&#xff0c;&#xff08;演化式&#xff09;原型必须具有健壮性&#xff0c;代码质量要从一开始就能达到最终系统的要求 3.利用面谈进行需求获取时&#xf…...

Win10弹出这个:https://logincdn.msauth.ne

问题描述&#xff1a; Win10脚本错误 Windows10家庭版操作系统开机后弹出这个 https://logincdn.msauth.net/shared/1.0/content/js/ConvergedLogin_PCore_vi321_9jVworKN8EONYo0A2.js 解决方法&#xff1a; 重启计算机后手动关闭第三方安全优化软件&#xff0c;然后在任务管理…...

Vue2 动态路由

VUE CLI 项目 router.js import Vue from "vue"; import Router from "vue-router"; import base from "/view/404/404.vue";const originalPush Router.prototype.push Router.prototype.push function push (location) {return originalPu…...

LeetCode746:使用最小花费爬楼梯

题目描述 给你一个整数数组 cost &#xff0c;其中 cost[i] 是从楼梯第 i 个台阶向上爬需要支付的费用。一旦你支付此费用&#xff0c;即可选择向上爬一个或者两个台阶。 你可以选择从下标为 0 或下标为 1 的台阶开始爬楼梯。 请你计算并返回达到楼梯顶部的最低花费。 代码 …...

DockerFile介绍与使用

一、DockerFile介绍 大家好&#xff0c;今天给大家分享一下关于 DockerFile 的介绍与使用&#xff0c;DockerFile 是一个用于定义如何构建 Docker 镜像的文本文件&#xff0c;具体来说&#xff0c;具有以下重要作用&#xff1a; 标准化构建&#xff1a;提供了一种统一、可重复…...

Java基础知识(六) 字符串

六 字符串 6.1 String字符串 1、String类对象创建 定义String类对象格式&#xff1a;** 1&#xff09;String 字符串变量名“字符串常量”&#xff1b; 2&#xff09;String 字符串变量名new String(字符串常量); 3&#xff09;String 字符串变量名; 字符串变量名“字符串常…...

为什么跨境电商大佬都在自养号测评?看完你就懂了!

在跨境电商的激烈竞争中&#xff0c;各大平台如亚马逊、拼多多Temu、shopee、Lazada、wish、速卖通、煤炉、敦煌、独立站、雅虎、eBay、TikTok、Newegg、Allegro、乐天、美客多、阿里国际、沃尔玛、Nike、OZON、Target以及Joom等&#xff0c;纷纷成为商家们竞相角逐市场份额的焦…...

AtCoder Beginner Contest 353

A 题意&#xff1a;检查是否有比第一个数大的数 #include<bits/stdc.h>using namespace std;int main() {int n;cin>>n;int a;cin>>a;int f0;for(int i2;i<n;i){int k;cin>>k;if(k>a){cout<<i<<endl;f1;break;}}if(f0){cout<&l…...

深度解读《深度探索C++对象模型》之虚继承的实现分析和效率评测(一)

目录 前言 具有虚基类的对象的构造过程 通过子类的对象存取虚基类成员的实现分析 接下来我将持续更新“深度解读《深度探索C对象模型》”系列&#xff0c;敬请期待&#xff0c;欢迎左下角点击关注&#xff01;也可以关注公众号&#xff1a;iShare爱分享&#xff0c;或文章末…...

ARM嵌入式C/C++库架构与优化实践

1. ARM C/C库架构解析ARM架构下的C/C标准库实现与通用PC环境存在显著差异&#xff0c;其设计充分考虑了嵌入式系统的特殊需求。库函数分为两个主要部分&#xff1a;与硬件无关的纯算法实现&#xff08;如字符串处理、数学运算&#xff09;&#xff0c;以及与硬件/操作系统相关的…...

如何快速掌握DevDocs:API文档浏览的终极指南

如何快速掌握DevDocs&#xff1a;API文档浏览的终极指南 【免费下载链接】devdocs API Documentation Browser 项目地址: https://gitcode.com/GitHub_Trending/de/devdocs DevDocs是一款强大的API Documentation Browser&#xff0c;它整合了多种技术文档资源&#xff…...

Voxtral-4B-TTS-2603语音合成入门:标点符号(!?。)对语调与停顿的实际影响

Voxtral-4B-TTS-2603语音合成入门&#xff1a;标点符号&#xff08;&#xff01;&#xff1f;。&#xff09;对语调与停顿的实际影响 1. 引言 你是否遇到过这样的情况&#xff1a;使用语音合成工具生成的音频听起来机械生硬&#xff0c;缺乏自然的情感表达&#xff1f;其实&a…...

Sakura编辑器 宏的基本使用

参考资料 初めてのサクラエディタマクロ(JScript版導入編) すぐに使えるJScript関数集 マクロ専用関数/変数 目录 一. 宏的基本使用 1.1 指定宏脚本执行 1.2 登录宏脚本 1.3 宏脚本执行效果展示 二. 宏案例 一. 宏的基本使用 ⏹此处写一个简单的demo脚本 Sakura编辑器中还有…...

Semantic Kernel 在企业级 Harness 开发中的应用

Semantic Kernel 在企业级 Harness 开发中的应用:打造 AI 原生的内部开发平台(IDP) 摘要 随着企业数字化转型的深入,云原生CI/CD平台Harness已经成为众多中大型企业构建内部开发平台(IDP)的首选方案,但Harness的YAML编排复杂度高、排障耗时久、自定义扩展门槛高、知识…...

Changelogger:实时更新日志聚合器的架构设计与工程实践

1. 项目概述与核心价值在技术迭代日新月异的今天&#xff0c;尤其是AI工具和开发者软件领域&#xff0c;几乎每天都有新的功能发布、API更新或产品迭代。作为一名长期泡在代码和产品里的从业者&#xff0c;我深有体会&#xff1a;错过一个关键更新&#xff0c;可能意味着浪费数…...

C语言命令行参数的使用

### C语言中命令行参数的用法与示例在C语言中&#xff0c;main函数可以通过两个参数来接收命令行参数&#xff1a;int argc 和 char *argv。其中&#xff0c;argc表示命令行参数的数量&#xff08;包括程序名本身&#xff09;&#xff0c;而argv是一个字符串数组&#xff0c;存…...

从注入到调用:一个完整的Unity il2cpp运行时Hook实战指南(附C++代码)

从注入到调用&#xff1a;一个完整的Unity il2cpp运行时Hook实战指南&#xff08;附C代码&#xff09; 在游戏开发与逆向工程领域&#xff0c;Unity引擎的il2cpp后端因其性能优势被广泛采用&#xff0c;但也带来了动态分析的独特挑战。本文将深入探讨如何通过运行时注入技术&am…...

轻松搞定文件压缩:7-Zip新手完全入门指南

轻松搞定文件压缩&#xff1a;7-Zip新手完全入门指南 【免费下载链接】7z 7-Zip Official Chinese Simplified Repository (Homepage and 7z Extra package) 项目地址: https://gitcode.com/gh_mirrors/7z1/7z 你是不是经常遇到这样的情况&#xff1f;电脑硬盘空间告急&…...

怎样高效部署ClearerVoice-Studio:专业级AI语音处理工具包全面指南

怎样高效部署ClearerVoice-Studio&#xff1a;专业级AI语音处理工具包全面指南 【免费下载链接】ClearerVoice-Studio An AI-Powered Speech Processing Toolkit and Open Source SOTA Pretrained Models, Supporting Speech Enhancement, Separation, and Target Speaker Extr…...