当前位置: 首页 > article >正文

DeepSpeed简介及加速模型训练

DeepSpeed是由微软开发的开源深度学习优化框架,专注于大规模模型的高效训练与推理。其核心目标是通过系统级优化技术降低显存占用、提升计算效率,并支持千亿级参数的模型训练。

官网链接:deepspeed
训练代码下载:git代码

一、DeepSpeed的核心作用

  1. 显存优化与高效内存管理

    • ZeRO(Zero Redundancy Optimizer)技术:通过分片存储模型状态(参数、梯度、优化器状态)至不同GPU或CPU,显著减少单卡显存占用。例如,ZeRO-2可将显存占用降低8倍,支持单卡训练130亿参数模型。
      在这里插入图片描述

    • Offload技术:将优化器状态卸载到CPU或NVMe硬盘,扩展至TB级内存,支持万亿参数模型训练。

    • 激活值重计算(Activation Checkpointing):牺牲计算时间换取显存节省,适用于长序列输入。

  2. 灵活的并行策略

    • 3D并行:融合数据并行(DP)、模型并行(张量并行TP、流水线并行PP),支持跨节点与节点内并行组合,适应不同硬件架构。

    • 动态批处理与梯度累积:减少通信频率,支持超大Batch Size训练。

  3. 训练加速与混合精度支持

    • 混合精度训练:支持FP16/BF16,结合动态损失缩放平衡效率与数值稳定性。

    • 稀疏注意力机制:针对长序列任务优化,执行效率提升6倍。

    • 通信优化:支持MPI、NCCL等协议,降低分布式训练通信开销。

  4. 推理优化与模型压缩

    • 低精度推理:通过INT8/FP16量化减少模型体积,提升推理速度。

    • 模型剪枝与蒸馏:压缩模型参数,降低部署成本。


二、与pytorch 对比分析

1. 优势

  • 显存效率:相比PyTorch DDP,单卡80GB GPU可训练130亿参数模型(传统方法仅支持约10亿)。

  • 并行灵活性:支持3D并行组合,优于Horovod(侧重数据并行)和Megatron(侧重模型并行)。

  • 生态集成:与Hugging Face Transformers、PyTorch无缝兼容,简化现有项目迁移。

  • 全流程覆盖:同时优化训练与推理,而vLLM仅专注推理优化。

2. 局限性

  • 配置复杂度:分布式训练需手动调整通信策略和分片参数,学习曲线陡峭(需编写JSON配置文件)。

  • 硬件依赖:部分高级功能(如ZeRO-Infinity)依赖NVMe硬盘或特定GPU架构。

  • 推理效率:纯推理场景下,vLLM的吞吐量更高(连续批处理优化更专精)。


三、训练用例

1、ds_config.json(deepspeed执行训练时,使用的配置文件)
  • deepspeed训练模型时,不需要在代码中定义优化器,只需要在 json 文件中进行配置即可, json文件内容如下:
{"train_batch_size": 128, //所有GPU上的 单个训练批次大小 之和"gradient_accumulation_steps": 1, //梯度累积 步数"optimizer": {"type": "Adam", //选择的 优化器"params": {"lr": 0.00015 //相关学习率大小}},"zero_optimization": { //加速策略"stage":2}
}

2、训练函数

  • 将模型包装成 deepspeed 形式
#将模型 包装成 deepspeed 形式
model_engine, _, _, _ = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters())
  • 使用 deepspeed 包装后的模型 进行 反向传播和梯度更新
#使用 deepspeed 进行 反向传播和梯度更新
#反向传播
model_engine.backward(loss)#梯度更新
model_engine.step()
  • 完整训练代码如下:
'''
使用命令行进行启动启动命令如下:
deepspeed ds_train.py --epochs 10 --deepspeed --deepspeed_config ds_config.json
'''import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModelif __name__ == '__main__':#读取命令行 传递的参数parser = argparse.ArgumentParser()parser.add_argument("--local_rank", help = "local device id on current node", type = int, default=0)parser.add_argument("--epochs", type = int, default=1)parser = deepspeed.add_config_arguments(parser)args = parser.parse_args()#获取数据集train_loader, test_loader = load_data() #数据集加载器中的 batch_size的大小 = (ds_config.json中 train_batch_size/gpu数量)#获取原始模型model = CustomModel().cuda()#将模型 包装成 deepspeed 形式model_engine, _, _, _ = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters())loss_fn = torch.nn.CrossEntropyLoss().cuda() # 损失函数(分类任务常用)for i in range(args.epochs):for inputs, labels in train_loader:#前向传播inputs = inputs.cuda()labels = labels.cuda()outputs = model_engine(inputs)loss = loss_fn(outputs, labels)#使用 deepspeed 进行 反向传播和梯度更新#反向传播model_engine.backward(loss)#梯度更新model_engine.step()model_engine.save_checkpoint('./ds_models', i)#模型保存torch.save(model_engine.module.state_dict(),'deepspeed_train_model.pth')
3、模型评估

import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt# 1. 定义数据转换(预处理)
transform = transforms.Compose([transforms.ToTensor(),          # 转为Tensor格式(自动归一化到0-1)transforms.Normalize((0.1307,), (0.3081,))  # 标准化(MNIST的均值和标准差)
])test_data = datasets.MNIST(root='./data',train=False,          # 测试集transform=transform)#获取数据集
train_loader, test_loader = load_data()model = CustomModel()
model.load_state_dict(torch.load('deepspeed_train_model.pth'))#评估
model.eval()  # 设置为评估模式
correct = 0
total = 0with torch.no_grad():  # 不计算梯度(节省内存)for images, labels in test_loader:images, labels = images, labelsoutputs = model(images)_, predicted = torch.max(outputs.data, 1)  # 取概率最大的类别total += labels.size(0)correct += (predicted == labels).sum().item()print(f"测试集准确率: {100 * correct / total:.2f}%")# 随机选择一张测试图片
index = np.random.randint(0,1000)  # 可以修改这个数字试不同图片
test_image, true_label = test_data[index]
test_image = test_image.unsqueeze(0)  # 增加批次维度# 预测
with torch.no_grad():output = model(test_image)
predicted_label = torch.argmax(output).item()print(f"预测: {predicted_label}, 真实: {true_label}")# 显示结果
plt.imshow(test_image.cpu().squeeze(), cmap='gray')
plt.title(f"预测: {predicted_label}, 真实: {true_label}")
plt.show()

相关文章:

DeepSpeed简介及加速模型训练

DeepSpeed是由微软开发的开源深度学习优化框架,专注于大规模模型的高效训练与推理。其核心目标是通过系统级优化技术降低显存占用、提升计算效率,并支持千亿级参数的模型训练。 官网链接:deepspeed 训练代码下载:git代码 一、De…...

网络安全面试题(一)

文章目录 一、基础概念与模型‌1. 什么是通信协议?列举三种常见的网络通信模型?2. 解释OSI七层模型及各层功能3. TCP/IP四层模型与OSI模型的对应关系是什么?4. 五层协议体系结构与TCP/IP模型的区别?5. 什么是面向连接与非面向连接的服务&…...

Linux 内核探秘:从零构建 GPIO 设备驱动程序实战指南

在嵌入式系统开发领域,GPIO(通用输入 / 输出)作为硬件与软件交互的桥梁,是实现设备控制与数据采集的基础。编写高效、稳定的 GPIO 设备驱动程序,对于发挥硬件性能至关重要。本文将深入剖析 Linux 内核中 GPIO 驱动开发…...

openlayer:10点击地图上某些省份利用Overlay实现提示省份名称

实现点击地图上的省份,在点击经纬度坐标位置附近利用Overlay实现提示框提示相关省份名称。本文介绍了如何通过OpenLayers库实现点击地图上的省份,并在点击的经纬度坐标位置附近显示提示框,提示相关省份名称。首先,定义了两个全局变…...

upload-labs通关笔记-第13关 文件上传之白名单POST法

目录 一、白名单过滤 二、%00截断 1.截断原理 2、截断条件 &#xff08;1&#xff09;PHP版本 < 5.3.4 &#xff08;2&#xff09;magic_quotes_gpc配置为Off &#xff08;3&#xff09;代码逻辑存在缺陷 三、源码分析 1、代码审计 &#xff08;1&#xff09;文件…...

数据库健康监测器(BHM)实战:如何通过 HTML 报告识别潜在问题

在数据库运维中,健康监测是保障系统稳定性与性能的关键环节。通过 HTML 报告,开发者可以直观查看数据库的运行状态、资源使用情况与潜在风险。 本文将围绕 数据库健康监测器(Database Health Monitor, BHM) 的核心功能展开分析,结合 Prometheus + Grafana + MySQL Export…...

C++(20): 文件输入输出库 —— <fstream>

目录 一、 的核心功能 二、核心类及功能 三、核心操作示例 1. 文本文件写入&#xff08;ofstream&#xff09; 2. 文本文件读取&#xff08;ifstream&#xff09; 3. 二进制文件操作&#xff08;fstream&#xff09; 四、文件打开模式 五、文件指针操作 六、错误处理技巧…...

使用Starrocks制作拉链表

5月1日向ods_order_info插入3条数据&#xff1a; CREATE TABLE ods_order_info(dt string,id string COMMENT 订单编号,total_amount decimal(10,2) COMMENT 订单金额 ) PRIMARY KEY(dt, id) PARTITION BY (dt) DISTRIBUTED BY HASH(id) PROPERTIES ( "replication_num&q…...

Oracle 11g 单实例使用+asm修改主机名导致ORA-29701 故障分析

解决 把服务器名修改为原来的&#xff0c;重启服务器。 故障 建表空间失败。 分析 查看告警日志 ORA-1119 signalled during: create tablespace splex datafile ‘DATA’ size 2000M… Tue May 20 18:04:28 2025 create tablespace splex datafile ‘DATA/option/dataf…...

Spring Boot接口通用返回值设计与实现最佳实践

一、核心返回值模型设计&#xff08;增强版&#xff09; package com.chat.common;import com.chat.util.I18nUtil; import com.chat.util.TraceUtil; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter;import java.io.Serializable;/*** 功能: 通…...

DeepSeek 赋能军事:重塑现代战争形态的科技密码

目录 一、引言&#xff1a;AI 浪潮下的军事变革与 DeepSeek 崛起二、DeepSeek 技术原理与特性剖析2.1 核心技术架构2.2 独特优势 三、DeepSeek 在军事侦察中的应用3.1 海量数据快速处理3.2 精准目标识别追踪3.3 预测潜在威胁 四、DeepSeek 在军事指挥决策中的应用4.1 战场态势实…...

day09-新热文章-实时计算

1. 实时计算与定时计算的区别 定时计算&#xff1a;基于固定时间间隔&#xff08;如每天/小时&#xff09;处理全量数据&#xff0c;适用于对实时性要求不高的场景。实时计算&#xff1a;持续处理无界数据流&#xff0c;结果实时输出&#xff0c;适用于高实时性场景&#xff0…...

Elasticsearch面试题带答案

Elasticsearch面试题带答案 Elasticsearch面试题及答案【最新版】Elasticsearch高级面试题大全(2025版),发现网上很多Elasticsearch面试题及答案整理都没有答案,所以花了很长时间搜集,本套Elasticsearch面试题大全,Elasticsearch面试题大汇总,有大量经典的Elasticsearch面…...

OpenCV CUDA模块图像过滤------用于创建一个最大值盒式滤波器(Max Box Filter)函数createBoxMaxFilter()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 createBoxMaxFilter()函数创建的是一个 最大值滤波器&#xff08;Maximum Filter&#xff09;&#xff0c;它对图像中每个像素邻域内的像素值取最…...

Redis数据库-消息队列

一、消息队列介绍 二、基于List结构模拟消息队列 总结&#xff1a; 三、基于PubSub实现消息队列 (1)PubSub介绍 PubSub是publish与subscribe两个单词的缩写&#xff0c;见明知意&#xff0c;PubSub就是发布与订阅的意思。 可以到Redis官网查看通配符的书写规则&#xff1a; …...

【Docker】Docker -p 将容器内部的端口映射到宿主机的端口

这里写自定义目录标题 -p 参数的作用基本语法示例单端口映射&#xff08;将容器 80 端口映射到宿主机 8080&#xff09;&#xff1a;多端口映射&#xff08;映射多个端口&#xff09;&#xff1a;自动分配宿主机端口&#xff08;Docker 随机选择宿主机端口&#xff09;&#xf…...

破解充电安全难题:智能终端的多重防护体系构建

随着智能终端的普及&#xff0c;充电安全问题日益凸显。从电池过热到短路起火&#xff0c;充电过程中的安全隐患不仅威胁用户的生命财产安全&#xff0c;也制约了行业的发展。如何构建一套高效、可靠的多重防护体系&#xff0c;成为破解充电安全难题的关键。通过技术创新和系统…...

apptrace 三大策略,助力电商 App 在 618 突围

随着 5 月 13 日 “618” 电商大促预售战的打响&#xff0c;各大平台纷纷祭出百亿补贴、消费券等大招&#xff0c;投入超百亿流量与数十亿现金&#xff0c;意图在这场年度商战中抢占先机。但这场流量争夺战远比想象中艰难&#xff0c;中国互联网络信息中心数据显示&#xff0c;…...

SpringAI的使用

1. 项目依赖配置 首先需要在 pom.xml 中添加 SpringAI 相关依赖。以下是关键依赖项&#xff1a; xml <!-- SpringAI 核心依赖 --> <dependency><groupId>org.springframework.ai</groupId><artifactId>spring-ai-core</artifactId><…...

Core Web Vitals 全链路优化:从浏览器引擎到网络协议深度调优

Core Web Vitals 全链路优化:从浏览器引擎到网络协议深度调优 一、浏览器渲染引擎级优化 1.1 合成器线程优化策略 • 分层加速:通过will-change属性创建独立的合成层 .accelerated {will-change: transform;backface-visibility: hidden; }• 光栅化策略调整:使用image-r…...

SuperVINS:应对挑战性成像条件的实时视觉-惯性SLAM框架【全流程配置与测试!!!】【2025最新版!!!!】

一、项目背景及意义 SuperVINS是一个改进的视觉-惯性SLAM&#xff08;同时定位与地图构建&#xff09;框架&#xff0c;旨在解决在挑战性成像条件下的定位和地图构建问题。该项目基于经典的VINS-Fusion框架&#xff0c;但通过引入深度学习方法进行了显著改进。 视觉-惯性导航系…...

Node-Red通过开疆智能Profinet转ModbusTCP采集西门子PLC数据配置案例

一、内容简介 本篇内容主要介绍Node-Red通过node-red-contrib-modbus插件与开疆智能ModbusTCP转Profinet设备进行通讯&#xff0c;这里Profinet转ModbusTCP网关作为从站设备&#xff0c;Node-Red作为主站分别从0地址开始读取10个线圈状态和10个保持寄存器&#xff0c;分别用Mo…...

vscode连接WSL卡住

原因&#xff1a;打开防火墙 解决&#xff1a; 使用sudo ufw disable关闭防火墙...

Redis面试题全面解析:从基础到底层实现

Redis作为当今最流行的内存数据库之一&#xff0c;是后端开发岗位面试中的高频考点。本文将系统整理Redis面试中常见的基础、中级和底层实现问题&#xff0c;帮助开发者全面准备Redis相关面试。 一、Redis基础问题 1. Redis是什么&#xff1f;主要特点是什么&#xff1f; Re…...

【性能测试】jvm监控

使用本地jvisualvm远程监控服务器 参考文章&#xff1a;https://blog.csdn.net/yeyuningzi/article/details/140261411 jvisualvm工具默认是监控本地jvm&#xff0c;如果需要监控远程就要修改配置参数 1、先查看是否打开 ps -ef|java 如果打开杀掉进程 2、进入项目服务路径下…...

Uniapp开发鸿蒙应用时如何运行和调试项目

经过前几天的分享&#xff0c;大家应该应该对uniapp开发鸿蒙应用的开发语法有了一定的了解&#xff0c;可以进行一些简单的应用开发&#xff0c;今天分享一下在使用uniapp开发鸿蒙应用时怎么运行到鸿蒙设备&#xff0c;并且在开发中怎么调试程序。 运行 Uniapp项目支持运行到…...

QT+RSVisa控制LXI仪器

1.下载并安装visa R&SVISA - Rohde & Schwarz China 2.安装后的目录说明 安装了64位visa会默认把32位的安装上&#xff1b; 64位库和头文件目录为&#xff1a;C:\Program Files\IVI Foundation 32位库和头文件目录为&#xff1a;C:\Program Files (x86)\IVI Foundation…...

PHP8.0版本导出excel失败

环境&#xff1a;fastadmin框架&#xff0c;不是原版接手的项目。PHP8.0,mysql5.7. code // 创建一个新的 Spreadsheet 对象 $spreadsheet new Spreadsheet(); $worksheet $spreadsheet->getActiveSheet();// 设置表头 $worksheet->setCellValue(A1, ID); $worksheet…...

GO语言学习(五)

GO语言学习&#xff08;五&#xff09; 前面我们已经学了一些关于golang的基础知识&#xff0c;从这一期开始&#xff0c;我们就来讲解一下基于golang为后端的web开发&#xff0c;首先这一期为一些golang为后端的web开发基础讲解&#xff0c;我们将会从web的工作方式、golang如…...

js不同浏览器标签页、窗口或 iframe 之间可以相互通信

一、创建一个广播通道 // 创建一个名为 vue-apps-channel 的广播通道 const channel new BroadcastChannel(vue-apps-channel);二、发送消息 channel.postMessage({type: popup, message: false}); 三、接收消息&#xff08;也需要创建广播通道&#xff09; // 也创建一个…...