使用 Amazon SageMaker 微调 Llama 2 模型

本篇文章主要介绍如何使用 Amazon SageMaker 进行 Llama 2 模型微调的示例。
这个示例主要包括:
Llama 2 总体介绍
Llama 2 微调介绍
Llama 2 环境设置
Llama 2 微调训练
前言
随着生成式 AI 的热度逐渐升高,国内外各种基座大语言竞相出炉,在其基础上衍生出种类繁多的应用场景。训练优异的基座大语言模型在通用性方面表现较好,但模型可能并未涉及到特定领域的专业术语、领域内的特定用语或上下文等。采用微调技术可以通过在领域特定数据上进行训练,使模型更好地适应目标领域的特殊语言模式和结构;结合基座模型的通用性和领域特定性,使得模型更具实际应用价值。
Llama 2 总体介绍
Llama 2 是 META 最新开源的 LLM,包括 7B、13B 和 70B 三个版本,训练数据集超过了 Llama 2 的 40%,达到 2 万亿 token;上下文长度也提升到 4K,可以极大扩展多轮对话的轮数、提示词输入数据;与此同时,Llama 2 Chat 模型使用基于人类反馈的强化学习(Reinforcement Learning from Human Feedback,RLHF),针对对话场景进行了大幅优化,达到了非常出色的有用性和安全性基准。HuggingFace 的 TGI 和 vLLM 等框架均有针对 Llama 2 的推理优化,进一步强化了 Llama 2 的可用性。
Llama 2 被认为是开源界大语言模型的首选,众多的垂类大模型均采用 Llama 2 作为基座大模型,在此基础上添加行业数据进行模型的预训练或者微调,适配更多的行业场景。
Llama 2 微调介绍
模型微调主要分为 Full Fine-Tune 和 PEFT (Performance-Efficient Fine-Tune),前者模型全部参数都会进行更新,训练时间较长,训练资源较大;而后者会冻结大部分参数、微调训练网络结构,常见的方式是 LoRA 和 P-Tuning v2。
PEFT 微调方式由于参数更新较少,可能导致模型无法学习到全部领域知识,对于特定任务或领域来说会出现推理不稳定的情况,因此大多数生产系统均使用全参数方式进行模型的微调。基于上述原因,本文会以全参数微调方式介绍 Llama 2 在 Amazon SageMaker 上的微调。
Llama 2 环境设置
备注:项目中的示例代码均保存于代码仓库,地址如下:
https://github.com/aws-samples/llm-workshop-on-amazon-sagemaker
1. 升级 Python SDK
pip install -U sagemaker 2. 获取运行时资源,包括区域、角色、账号、S3 桶等
import boto3
import sagemaker
from sagemaker import get_execution_rolesess = sagemaker.Session()
role = get_execution_role()
sagemaker_default_bucket = sess.default_bucket()account = sess.boto_session.client("sts").get_caller_identity()["Account"]
region = sess.boto_session.region_name Llama 2 微调训练
微调准备
克隆代码
采用 lm-sys 团队发布的 FastChat 平台进行 Llama 2 的微调,FastChat 也用于训练了知名的 Vicuna 模型,具有良好的代码规范和性能优化。
git clone https://github.com/lm-sys/FastChat.git
cd FastChat
git reset --hard 974537efbd82093b45e64d07904efe7728193a52 下载 Llama 2 原始模型
from huggingface_hub import snapshot_download
from pathlib import Pathlocal_cache_path = Path("./model")
local_cache_path.mkdir(exist_ok=True)model_name = "TheBloke/Llama-2-13B-fp16"# Only download pytorch checkpoint files
allow_patterns = ["*.json", "*.pt", "*.bin", "*.model", "*.py"]model_download_path = snapshot_download(repo_id=model_name,cache_dir=local_cache_path,allow_patterns=allow_patterns,revision='b2e65e8ad4bb35e5abaee0170ebd5fc2134a50bb'
)# Get the model files path
import os
from glob import globlocal_model_path = Nonepaths = os.walk(r'./model')
for root, dirs, files in paths:for file in files:if file == 'config.json':print(os.path.join(root,file))local_model_path = str(os.path.join(root,file))[0:-11]print(local_model_path)
if local_model_path == None:print("Model download may failed, please check prior step!") 拷贝模型和数据到 Amazon S3
chmod +x ./s5cmd
./s5cmd sync ${local_model_path} s3://${sagemaker_default_bucket}/llm/models/llama2/TheBloke/Llama-2-13B-fp16/
rm -rf model 模型微调
模型的微调使用全参数模型,以实现微调后模型的稳定性。
模型的微调使用开源框架 DeepSpeed 进行加速。
准备基础镜像
使用 Amazon SageMaker 定制的深度学习训练镜像作为基础镜像,再安装 Llama 2 训练所需的依赖包。Dockerfile 如下:
%%writefile Dockerfile
## You should change below region code to the region you used, here sample is use us-west-2
From 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04 ENV LANG=C.UTF-8
ENV PYTHONUNBUFFERED=TRUE
ENV PYTHONDONTWRITEBYTECODE=TRUERUN pip3 uninstall -y deepspeed \&& pip3 install deepspeed==0.10.0 \&& pip3 install transformers==4.30.2## Make all local GPUs visible
ENV NVIDIA_VISIBLE_DEVICES="all" 模型微调代码
模型微调源代码较多,细节可以参考上述 git 仓库。
微调参数
为了节省显存,采用 DeepSpeed Stage-3
训练过程开启 bf16,实现整数范围和精度的平衡
训练数据集采用官方提供的 dummy_conversation.json,也就是典型的 {"instruction"、"input"、"output"} 的格式,同时可以支持多轮对话
DEEPSPEED_OPTS="""FastChat/fastchat/train/train_mem.py --deepspeed ds.json --model_name_or_path "/tmp/llama_pretrain/" --data_path FastChat/data/dummy_conversation.json --output_dir "/tmp/llama_out" --num_train_epochs 1 --per_device_train_batch_size 1 --per_device_eval_batch_size 1 --gradient_accumulation_steps 4 --evaluation_strategy "no" --save_strategy "no" --save_steps 2000 --save_total_limit 1 --learning_rate 2e-5 --weight_decay 0. --warmup_ratio 0.03 --lr_scheduler_type "cosine" --logging_steps 1 --cache_dir '/tmp' --model_max_length 2048 --gradient_checkpointing True --lazy_preprocess True --bf16 True --tf32 True --report_to "none"
""" 微调脚本
微调使用 torchrun + DeepSpeed 进行分布式训练
%%writefile ./src/ds-train-dist.sh
#!/bin/bash
CURRENT_HOST="${SM_CURRENT_HOST}"IFS=',' read -ra hosts_array <<< "${SM_HOSTS}"
NNODES=${#hosts_array[@]}
NODE_RANK=0for i in "${!hosts_array[@]}"; doif [[ "${hosts_array[$i]}" == *${CURRENT_HOST}* ]]; thenecho "host index:$i"NODE_RANK="$i" fi
doneMASTER_PORT="13579"
export NCCL_SOCKET_IFNAME="eth0"#Configure the distributed arguments for torch.distributed.launch.
GPUS_PER_NODE="$SM_NUM_GPUS"
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \--nnodes $NNODES \--node_rank $NODE_RANK \--master_addr $MASTER_ADDR \--master_port $MASTER_PORT"chmod +x ./s5cmd
./s5cmd sync s3://$MODEL_S3_BUCKET/llm/models/llama2/TheBloke/Llama-2-13B-fp16/* /tmp/llama_pretrain/CMD="torchrun ${DISTRIBUTED_ARGS} ${DEEPSPEED_OPTS}"
echo ${CMD}
${CMD} 2>&1 if [[ "${CURRENT_HOST}" == "${MASTER_ADDR}" ]]; then ./s5cmd sync /tmp/llama_out s3://$MODEL_S3_BUCKET/llm/models/llama2/output/TheBloke/Llama-2-13B-fp16/$(date +%Y-%m-%d-%H-%M-%S)/
fi 启动微调
全参数微调,需要使用至少一台 p4de.12xlarge(8 卡 A100 40GB)作为训练机器。
当微调完成后,训练好的模型自动存储于指定的 S3 桶内,可用于后续的模型部署推理。
import time
from sagemaker.estimator import Estimatorenvironment = {'MODEL_S3_BUCKET': sagemaker_default_bucket # The bucket to store pretrained model and fine-tune model
}base_job_name = 'llama2-13b-finetune'instance_type = 'ml.p4d.24xlarge'estimator = Estimator(role=role,entry_point='ds-train-dist.sh',source_dir='./src',base_job_name=base_job_name,instance_count=1,instance_type=instance_type,image_uri=image_uri,environment=environment,disable_profiler=True,debugger_hook_config=False)estimator.fit() 总结
大语言模型方兴未艾,正在以各种方式改变和影响着整个世界。客户拥抱大语言模型,亚马逊云科技团队同样在深耕客户需求和大语言模型技术,可以在未来更好地协助客户实现需求,提升业务价值。
本篇作者

高郁
亚马逊云科技解决方案架构师,主要负责企业客户上云,帮助客户进行云架构设计和技术咨询,专注于智能湖仓、AI/ML 等技术方向。

星标不迷路,开发更极速!
关注后记得星标「亚马逊云开发者」

听说,点完下面4个按钮
就不会碰到bug了!

相关文章:
使用 Amazon SageMaker 微调 Llama 2 模型
本篇文章主要介绍如何使用 Amazon SageMaker 进行 Llama 2 模型微调的示例。 这个示例主要包括: Llama 2 总体介绍Llama 2 微调介绍Llama 2 环境设置Llama 2 微调训练 前言 随着生成式 AI 的热度逐渐升高,国内外各种基座大语言竞相出炉,在其基础上衍生出…...
牛客小白月赛86(D剪纸游戏)
题目链接:D-剪纸游戏_牛客小白月赛86 (nowcoder.com) 题目描述: 输入描述: 输入第一行包含两个空格分隔的整数分别代表 n 和 m。 接下来输入 n行,每行包含 m 个字符,代表残缺纸张。 保证: 1≤n,m≤10001 字符仅有 . 和 * 两种字符…...
MySQL的基础操作与管理
一.MySQL数据库基本操作知识: 1.SQL语句: 关系型数据库,都是使用SQL语句来管理数据库中的数据。 SQL,即结构化查询语言(Structured Query Language) 。 SQL语句用于维护管理数据库,包括数据查询、数据更新、访问控…...
Pytorch 中的forward 函数内部原理
PyTorch中的forward函数是nn.Module类的一部分,它定义了模型的前向传播规则。当你创建一个继承自nn.Module的类时,你实际上是在定义网络的结构。forward函数是这个结构中最关键的部分,因为它指定了数据如何通过网络流动。 单独设计 forward …...
四、C语言中的数组:如何输入与输出二维数组(数组,完)
本章的学习内容如下 四、C语言中的数组:数组的创建与初始化四、C语言中的数组:数组的输入与元素个数C语言—第6次作业—十道代码题掌握一维数组四、C语言中的数组:二维数组 1.二维数组的输入与输出 当我们输入一维数组时需要一个循环来遍历…...
基于python+vue智慧农业小程序flask-django-php-nodejs
传统智慧农业采取了人工的管理方法,但这种管理方法存在着许多弊端,比如效率低下、安全性低以及信息传输的不准确等,同时由于智慧农业中会形成众多的个人文档和信息系统数据,通过人工方法对知识科普、土壤信息、水质信息、购物商城…...
好用的GPTs:指定主题搜索、爬虫、数据清洗、数据分析自动化
好用的GPTs:指定主题搜索、爬虫、数据清洗、数据分析自动化 Scholar:搜索 YOLO小目标医学方面最新论文Scraper:爬虫自动化数据清洗数据分析 点击 Explore GPTs: Scholar:搜索 YOLO小目标医学方面最新论文 搜索 Scho…...
使用Qt自带windeployqt打包QML的exe
1.在开始菜单输入CMD找到对应的Qt开发版本,我的是Qt5.15.2(MinGW 8.1.0 64-bit)。 2.在控制台输入如下字符串,格式为 windeployqt exe绝对路径 --qmldir 工程的绝对路径 如下是我的打包代码。 我需要打包的exe的绝对路径 D:\Prj\Code\Demo\QML\Ana…...
C代码快速傅里叶变换-分类和推理-常微分和偏微分方程
要点 C代码例程函数计算实现: 线性代数方程解:全旋转高斯-乔丹消元,LU分解前向替换和后向替换,对角矩阵处理,任意矩阵奇异值分解,稀疏线性系统循环三对角系统解,将矩阵从完整存储模式转换为行索…...
计算机组成原理 双端口存储器原理实验
一、实验目的 1、了解双端口静态随机存储器IDT7132的工作特性及使用方法 2、了解半导体存储器怎样存储和读出数据 3、了解双端口存储器怎样并行读写,产生冲突的情况如何 二、实验任务 (1)按图7所示,将有关控制信号和和二进制开关对应接好,…...
[音视频学习笔记]六、自制音视频播放器Part1 -新版本ffmpeg,Qt +VS2022,都什么年代了还在写传统播放器?
前言 参考了雷神的自制播放器项目,100行代码实现最简单的基于FFMPEGSDL的视频播放器(SDL1.x) 不过老版本的代码参考意义不大了,我现在准备使用Qt VS2022 FFmpeg59重写这部分代码,具体的代码仓库如下: …...
GPT-5可能会在今年夏天作为对ChatGPT的“实质性改进”而到来
每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…...
官宣|阿里巴巴捐赠的 Flink CDC 项目正式加入 Apache 基金会
摘要:本文整理自阿里云开源大数据平台徐榜江 (雪尽),关于阿里巴巴捐赠的 Flink CDC 项目正式加入 Apache 基金会,内容主要分为以下四部分: 1、Flink CDC 新仓库,新流程 2、Flink CDC 新定位,新玩法 3、Flin…...
部署单节点k8s并允许master节点调度pod
安装k8s 需要注意的是k8s1.24 已经弃用dockershim,现在使用docker需要cri-docker插件作为垫片,对接k8s的CRI。 硬件环境: 2c2g 主机环境: CentOS Linux release 7.9.2009 (Core) IP地址: 192.168.44.161 一、 主机配…...
Django日志(三)
内置TimedRotatingFileHandler 按时间自动切分的log文件,文件后缀 %Y-%m-%d_%H-%M-%S , 初始化参数: 注意 发送邮件的邮箱,开启SMTP服务 filename when=h 时间间隔类型,不区分大小写 S:秒 M:分钟 H:小时 D:天 W0-W6:星期几(0 = 星期一) midnight:如果atTime未指定,…...
【吾爱破解】Android初级题(二)的解题思路 _
拿到apk,我们模拟器打开看一下 好好,抽卡模拟器是吧😀 jadx反编译看一下源码 找到生成flag的地方,大概逻辑就是 java signatureArr getPackageManager().getPackageInfo(getPackageName(), 64).signaturesfor (int i 0; i &l…...
富格林:谨记可信计策安全做单
富格林悉知,现货黄金由于活跃的行情给投资者带来不少的盈利的机会,吸引着众多的投资者进场做单。但在黄金投资市场中一定要掌握可信的投资方法,提前布局好策略,这样才能增加安全获利的机会。不建议直接进入市场做单,因…...
【工具使用】mingw64编译完成运行可执行文件时出现乱码
一,问题现象: notepad设置的时UTF-8编码: mingw64命令行设置的编码格式为: 二,问题原因: 在执行的时候,windows下的编码格式是GBK 三,解决方法: 编译时࿰…...
WebSocket 使用示例,后台为nodejs
效果图 页面代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>WebSocket Client</title&g…...
【算法】力扣【树形DP】687. 最长同值路径
【算法】力扣【树形DP】687. 最长同值路径 687. 最长同值路径 文章目录 【算法】力扣【树形DP】687. 最长同值路径题目描述输入输出示例 题解思路代码描述 复杂度分析总结 题目描述 本题要求在给定的二叉树中寻找最长的同值路径,这个路径中的每个节点的值都相同。…...
【Java学习笔记】Arrays类
Arrays 类 1. 导入包:import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序(自然排序和定制排序)Arrays.binarySearch()通过二分搜索法进行查找(前提:数组是…...
【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)
服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...
服务器硬防的应用场景都有哪些?
服务器硬防是指一种通过硬件设备层面的安全措施来防御服务器系统受到网络攻击的方式,避免服务器受到各种恶意攻击和网络威胁,那么,服务器硬防通常都会应用在哪些场景当中呢? 硬防服务器中一般会配备入侵检测系统和预防系统&#x…...
使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装
以下是基于 vant-ui(适配 Vue2 版本 )实现截图中照片上传预览、删除功能,并封装成可复用组件的完整代码,包含样式和逻辑实现,可直接在 Vue2 项目中使用: 1. 封装的图片上传组件 ImageUploader.vue <te…...
镜像里切换为普通用户
如果你登录远程虚拟机默认就是 root 用户,但你不希望用 root 权限运行 ns-3(这是对的,ns3 工具会拒绝 root),你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案:创建非 roo…...
Python爬虫(二):爬虫完整流程
爬虫完整流程详解(7大核心步骤实战技巧) 一、爬虫完整工作流程 以下是爬虫开发的完整流程,我将结合具体技术点和实战经验展开说明: 1. 目标分析与前期准备 网站技术分析: 使用浏览器开发者工具(F12&…...
Axios请求超时重发机制
Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...
SpringCloudGateway 自定义局部过滤器
场景: 将所有请求转化为同一路径请求(方便穿网配置)在请求头内标识原来路径,然后在将请求分发给不同服务 AllToOneGatewayFilterFactory import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; impor…...
大数据学习(132)-HIve数据分析
🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言Ǵ…...
均衡后的SNRSINR
本文主要摘自参考文献中的前两篇,相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程,其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt 根发送天线, n r n_r nr 根接收天线的 MIMO 系…...
