DeepSpeed简介及加速模型训练
DeepSpeed是由微软开发的开源深度学习优化框架,专注于大规模模型的高效训练与推理。其核心目标是通过系统级优化技术降低显存占用、提升计算效率,并支持千亿级参数的模型训练。
官网链接:deepspeed
训练代码下载:git代码
一、DeepSpeed的核心作用
-
显存优化与高效内存管理
-
ZeRO(Zero Redundancy Optimizer)技术:通过分片存储模型状态(参数、梯度、优化器状态)至不同GPU或CPU,显著减少单卡显存占用。例如,ZeRO-2可将显存占用降低8倍,支持单卡训练130亿参数模型。
-
Offload技术:将优化器状态卸载到CPU或NVMe硬盘,扩展至TB级内存,支持万亿参数模型训练。
-
激活值重计算(Activation Checkpointing):牺牲计算时间换取显存节省,适用于长序列输入。
-
-
灵活的并行策略
-
3D并行:融合数据并行(DP)、模型并行(张量并行TP、流水线并行PP),支持跨节点与节点内并行组合,适应不同硬件架构。
-
动态批处理与梯度累积:减少通信频率,支持超大Batch Size训练。
-
-
训练加速与混合精度支持
-
混合精度训练:支持FP16/BF16,结合动态损失缩放平衡效率与数值稳定性。
-
稀疏注意力机制:针对长序列任务优化,执行效率提升6倍。
-
通信优化:支持MPI、NCCL等协议,降低分布式训练通信开销。
-
-
推理优化与模型压缩
-
低精度推理:通过INT8/FP16量化减少模型体积,提升推理速度。
-
模型剪枝与蒸馏:压缩模型参数,降低部署成本。
-
二、与pytorch 对比分析
1. 优势
-
显存效率:相比PyTorch DDP,单卡80GB GPU可训练130亿参数模型(传统方法仅支持约10亿)。
-
并行灵活性:支持3D并行组合,优于Horovod(侧重数据并行)和Megatron(侧重模型并行)。
-
生态集成:与Hugging Face Transformers、PyTorch无缝兼容,简化现有项目迁移。
-
全流程覆盖:同时优化训练与推理,而vLLM仅专注推理优化。
2. 局限性
-
配置复杂度:分布式训练需手动调整通信策略和分片参数,学习曲线陡峭(需编写JSON配置文件)。
-
硬件依赖:部分高级功能(如ZeRO-Infinity)依赖NVMe硬盘或特定GPU架构。
-
推理效率:纯推理场景下,vLLM的吞吐量更高(连续批处理优化更专精)。
三、训练用例
1、ds_config.json(deepspeed执行训练时,使用的配置文件)
- deepspeed训练模型时,不需要在代码中定义优化器,只需要在 json 文件中进行配置即可, json文件内容如下:
{"train_batch_size": 128, //所有GPU上的 单个训练批次大小 之和"gradient_accumulation_steps": 1, //梯度累积 步数"optimizer": {"type": "Adam", //选择的 优化器"params": {"lr": 0.00015 //相关学习率大小}},"zero_optimization": { //加速策略"stage":2}
}
2、训练函数
- 将模型包装成 deepspeed 形式
#将模型 包装成 deepspeed 形式
model_engine, _, _, _ = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters())
- 使用 deepspeed 包装后的模型 进行 反向传播和梯度更新
#使用 deepspeed 进行 反向传播和梯度更新
#反向传播
model_engine.backward(loss)#梯度更新
model_engine.step()
- 完整训练代码如下:
'''
使用命令行进行启动启动命令如下:
deepspeed ds_train.py --epochs 10 --deepspeed --deepspeed_config ds_config.json
'''import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModelif __name__ == '__main__':#读取命令行 传递的参数parser = argparse.ArgumentParser()parser.add_argument("--local_rank", help = "local device id on current node", type = int, default=0)parser.add_argument("--epochs", type = int, default=1)parser = deepspeed.add_config_arguments(parser)args = parser.parse_args()#获取数据集train_loader, test_loader = load_data() #数据集加载器中的 batch_size的大小 = (ds_config.json中 train_batch_size/gpu数量)#获取原始模型model = CustomModel().cuda()#将模型 包装成 deepspeed 形式model_engine, _, _, _ = deepspeed.initialize(args=args,model=model,model_parameters=model.parameters())loss_fn = torch.nn.CrossEntropyLoss().cuda() # 损失函数(分类任务常用)for i in range(args.epochs):for inputs, labels in train_loader:#前向传播inputs = inputs.cuda()labels = labels.cuda()outputs = model_engine(inputs)loss = loss_fn(outputs, labels)#使用 deepspeed 进行 反向传播和梯度更新#反向传播model_engine.backward(loss)#梯度更新model_engine.step()model_engine.save_checkpoint('./ds_models', i)#模型保存torch.save(model_engine.module.state_dict(),'deepspeed_train_model.pth')
3、模型评估
import argparse
import torch
import torchvision
import deepspeed
from model_definition import load_data, CustomModel
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt# 1. 定义数据转换(预处理)
transform = transforms.Compose([transforms.ToTensor(), # 转为Tensor格式(自动归一化到0-1)transforms.Normalize((0.1307,), (0.3081,)) # 标准化(MNIST的均值和标准差)
])test_data = datasets.MNIST(root='./data',train=False, # 测试集transform=transform)#获取数据集
train_loader, test_loader = load_data()model = CustomModel()
model.load_state_dict(torch.load('deepspeed_train_model.pth'))#评估
model.eval() # 设置为评估模式
correct = 0
total = 0with torch.no_grad(): # 不计算梯度(节省内存)for images, labels in test_loader:images, labels = images, labelsoutputs = model(images)_, predicted = torch.max(outputs.data, 1) # 取概率最大的类别total += labels.size(0)correct += (predicted == labels).sum().item()print(f"测试集准确率: {100 * correct / total:.2f}%")# 随机选择一张测试图片
index = np.random.randint(0,1000) # 可以修改这个数字试不同图片
test_image, true_label = test_data[index]
test_image = test_image.unsqueeze(0) # 增加批次维度# 预测
with torch.no_grad():output = model(test_image)
predicted_label = torch.argmax(output).item()print(f"预测: {predicted_label}, 真实: {true_label}")# 显示结果
plt.imshow(test_image.cpu().squeeze(), cmap='gray')
plt.title(f"预测: {predicted_label}, 真实: {true_label}")
plt.show()
相关文章:

DeepSpeed简介及加速模型训练
DeepSpeed是由微软开发的开源深度学习优化框架,专注于大规模模型的高效训练与推理。其核心目标是通过系统级优化技术降低显存占用、提升计算效率,并支持千亿级参数的模型训练。 官网链接:deepspeed 训练代码下载:git代码 一、De…...
网络安全面试题(一)
文章目录 一、基础概念与模型1. 什么是通信协议?列举三种常见的网络通信模型?2. 解释OSI七层模型及各层功能3. TCP/IP四层模型与OSI模型的对应关系是什么?4. 五层协议体系结构与TCP/IP模型的区别?5. 什么是面向连接与非面向连接的服务&…...
Linux 内核探秘:从零构建 GPIO 设备驱动程序实战指南
在嵌入式系统开发领域,GPIO(通用输入 / 输出)作为硬件与软件交互的桥梁,是实现设备控制与数据采集的基础。编写高效、稳定的 GPIO 设备驱动程序,对于发挥硬件性能至关重要。本文将深入剖析 Linux 内核中 GPIO 驱动开发…...

openlayer:10点击地图上某些省份利用Overlay实现提示省份名称
实现点击地图上的省份,在点击经纬度坐标位置附近利用Overlay实现提示框提示相关省份名称。本文介绍了如何通过OpenLayers库实现点击地图上的省份,并在点击的经纬度坐标位置附近显示提示框,提示相关省份名称。首先,定义了两个全局变…...

upload-labs通关笔记-第13关 文件上传之白名单POST法
目录 一、白名单过滤 二、%00截断 1.截断原理 2、截断条件 (1)PHP版本 < 5.3.4 (2)magic_quotes_gpc配置为Off (3)代码逻辑存在缺陷 三、源码分析 1、代码审计 (1)文件…...

数据库健康监测器(BHM)实战:如何通过 HTML 报告识别潜在问题
在数据库运维中,健康监测是保障系统稳定性与性能的关键环节。通过 HTML 报告,开发者可以直观查看数据库的运行状态、资源使用情况与潜在风险。 本文将围绕 数据库健康监测器(Database Health Monitor, BHM) 的核心功能展开分析,结合 Prometheus + Grafana + MySQL Export…...
C++(20): 文件输入输出库 —— <fstream>
目录 一、 的核心功能 二、核心类及功能 三、核心操作示例 1. 文本文件写入(ofstream) 2. 文本文件读取(ifstream) 3. 二进制文件操作(fstream) 四、文件打开模式 五、文件指针操作 六、错误处理技巧…...
使用Starrocks制作拉链表
5月1日向ods_order_info插入3条数据: CREATE TABLE ods_order_info(dt string,id string COMMENT 订单编号,total_amount decimal(10,2) COMMENT 订单金额 ) PRIMARY KEY(dt, id) PARTITION BY (dt) DISTRIBUTED BY HASH(id) PROPERTIES ( "replication_num&q…...

Oracle 11g 单实例使用+asm修改主机名导致ORA-29701 故障分析
解决 把服务器名修改为原来的,重启服务器。 故障 建表空间失败。 分析 查看告警日志 ORA-1119 signalled during: create tablespace splex datafile ‘DATA’ size 2000M… Tue May 20 18:04:28 2025 create tablespace splex datafile ‘DATA/option/dataf…...
Spring Boot接口通用返回值设计与实现最佳实践
一、核心返回值模型设计(增强版) package com.chat.common;import com.chat.util.I18nUtil; import com.chat.util.TraceUtil; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter;import java.io.Serializable;/*** 功能: 通…...
DeepSeek 赋能军事:重塑现代战争形态的科技密码
目录 一、引言:AI 浪潮下的军事变革与 DeepSeek 崛起二、DeepSeek 技术原理与特性剖析2.1 核心技术架构2.2 独特优势 三、DeepSeek 在军事侦察中的应用3.1 海量数据快速处理3.2 精准目标识别追踪3.3 预测潜在威胁 四、DeepSeek 在军事指挥决策中的应用4.1 战场态势实…...
day09-新热文章-实时计算
1. 实时计算与定时计算的区别 定时计算:基于固定时间间隔(如每天/小时)处理全量数据,适用于对实时性要求不高的场景。实时计算:持续处理无界数据流,结果实时输出,适用于高实时性场景࿰…...
Elasticsearch面试题带答案
Elasticsearch面试题带答案 Elasticsearch面试题及答案【最新版】Elasticsearch高级面试题大全(2025版),发现网上很多Elasticsearch面试题及答案整理都没有答案,所以花了很长时间搜集,本套Elasticsearch面试题大全,Elasticsearch面试题大汇总,有大量经典的Elasticsearch面…...

OpenCV CUDA模块图像过滤------用于创建一个最大值盒式滤波器(Max Box Filter)函数createBoxMaxFilter()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 createBoxMaxFilter()函数创建的是一个 最大值滤波器(Maximum Filter),它对图像中每个像素邻域内的像素值取最…...

Redis数据库-消息队列
一、消息队列介绍 二、基于List结构模拟消息队列 总结: 三、基于PubSub实现消息队列 (1)PubSub介绍 PubSub是publish与subscribe两个单词的缩写,见明知意,PubSub就是发布与订阅的意思。 可以到Redis官网查看通配符的书写规则: …...
【Docker】Docker -p 将容器内部的端口映射到宿主机的端口
这里写自定义目录标题 -p 参数的作用基本语法示例单端口映射(将容器 80 端口映射到宿主机 8080):多端口映射(映射多个端口):自动分配宿主机端口(Docker 随机选择宿主机端口)…...

破解充电安全难题:智能终端的多重防护体系构建
随着智能终端的普及,充电安全问题日益凸显。从电池过热到短路起火,充电过程中的安全隐患不仅威胁用户的生命财产安全,也制约了行业的发展。如何构建一套高效、可靠的多重防护体系,成为破解充电安全难题的关键。通过技术创新和系统…...

apptrace 三大策略,助力电商 App 在 618 突围
随着 5 月 13 日 “618” 电商大促预售战的打响,各大平台纷纷祭出百亿补贴、消费券等大招,投入超百亿流量与数十亿现金,意图在这场年度商战中抢占先机。但这场流量争夺战远比想象中艰难,中国互联网络信息中心数据显示,…...
SpringAI的使用
1. 项目依赖配置 首先需要在 pom.xml 中添加 SpringAI 相关依赖。以下是关键依赖项: xml <!-- SpringAI 核心依赖 --> <dependency><groupId>org.springframework.ai</groupId><artifactId>spring-ai-core</artifactId><…...
Core Web Vitals 全链路优化:从浏览器引擎到网络协议深度调优
Core Web Vitals 全链路优化:从浏览器引擎到网络协议深度调优 一、浏览器渲染引擎级优化 1.1 合成器线程优化策略 • 分层加速:通过will-change属性创建独立的合成层 .accelerated {will-change: transform;backface-visibility: hidden; }• 光栅化策略调整:使用image-r…...

SuperVINS:应对挑战性成像条件的实时视觉-惯性SLAM框架【全流程配置与测试!!!】【2025最新版!!!!】
一、项目背景及意义 SuperVINS是一个改进的视觉-惯性SLAM(同时定位与地图构建)框架,旨在解决在挑战性成像条件下的定位和地图构建问题。该项目基于经典的VINS-Fusion框架,但通过引入深度学习方法进行了显著改进。 视觉-惯性导航系…...

Node-Red通过开疆智能Profinet转ModbusTCP采集西门子PLC数据配置案例
一、内容简介 本篇内容主要介绍Node-Red通过node-red-contrib-modbus插件与开疆智能ModbusTCP转Profinet设备进行通讯,这里Profinet转ModbusTCP网关作为从站设备,Node-Red作为主站分别从0地址开始读取10个线圈状态和10个保持寄存器,分别用Mo…...
vscode连接WSL卡住
原因:打开防火墙 解决: 使用sudo ufw disable关闭防火墙...
Redis面试题全面解析:从基础到底层实现
Redis作为当今最流行的内存数据库之一,是后端开发岗位面试中的高频考点。本文将系统整理Redis面试中常见的基础、中级和底层实现问题,帮助开发者全面准备Redis相关面试。 一、Redis基础问题 1. Redis是什么?主要特点是什么? Re…...

【性能测试】jvm监控
使用本地jvisualvm远程监控服务器 参考文章:https://blog.csdn.net/yeyuningzi/article/details/140261411 jvisualvm工具默认是监控本地jvm,如果需要监控远程就要修改配置参数 1、先查看是否打开 ps -ef|java 如果打开杀掉进程 2、进入项目服务路径下…...

Uniapp开发鸿蒙应用时如何运行和调试项目
经过前几天的分享,大家应该应该对uniapp开发鸿蒙应用的开发语法有了一定的了解,可以进行一些简单的应用开发,今天分享一下在使用uniapp开发鸿蒙应用时怎么运行到鸿蒙设备,并且在开发中怎么调试程序。 运行 Uniapp项目支持运行到…...

QT+RSVisa控制LXI仪器
1.下载并安装visa R&SVISA - Rohde & Schwarz China 2.安装后的目录说明 安装了64位visa会默认把32位的安装上; 64位库和头文件目录为:C:\Program Files\IVI Foundation 32位库和头文件目录为:C:\Program Files (x86)\IVI Foundation…...
PHP8.0版本导出excel失败
环境:fastadmin框架,不是原版接手的项目。PHP8.0,mysql5.7. code // 创建一个新的 Spreadsheet 对象 $spreadsheet new Spreadsheet(); $worksheet $spreadsheet->getActiveSheet();// 设置表头 $worksheet->setCellValue(A1, ID); $worksheet…...
GO语言学习(五)
GO语言学习(五) 前面我们已经学了一些关于golang的基础知识,从这一期开始,我们就来讲解一下基于golang为后端的web开发,首先这一期为一些golang为后端的web开发基础讲解,我们将会从web的工作方式、golang如…...
js不同浏览器标签页、窗口或 iframe 之间可以相互通信
一、创建一个广播通道 // 创建一个名为 vue-apps-channel 的广播通道 const channel new BroadcastChannel(vue-apps-channel);二、发送消息 channel.postMessage({type: popup, message: false}); 三、接收消息(也需要创建广播通道) // 也创建一个…...