0基础学会在亚马逊云科技AWS上利用SageMaker、PEFT和LoRA高效微调AI大语言模型(含具体教程和代码)
项目简介:
小李哥今天将继续介绍亚马逊云科技AWS云计算平台上的前沿前沿AI技术解决方案,帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS上的AI软甲开发最佳实践,并应用到自己的日常工作里。本次介绍的是如何在Amazon SageMaker上微调(Fine-tune)大语言模型dolly-v2-3b,满足日常生活中不同的场景需求,并将介分享如何在SageMaker上优化模型性能并节省计算资源实现成本控制,最后将部署后的大语言模型URL集成到自己云上的软件应用中。
本方案包括通过Amazon Cloudfront和S3托管前端页面,并通过Amazon API Gateway和AWS Lambda将应用程序与AI模型集成,调用大模型实现推理。本方案的解决方案架构图如下:
利用微调模型创建的对话机器人前端UI
利用本方案小李哥用微调后的模型搭建了一个Q&A对话机器人助手,可以生成代码、文字总结、回答问题。
在开始分享案例之前,我们来了解一下本方案的技术背景,帮助大家更好的理解方案架构。
什么是Amazon SageMaker?
Amazon SageMaker 是一个完全托管的机器学习服务(大家可以理解为Serverless的Jupyter Notebook),专为应用开发和数据科学家设计,帮助他们快速构建、训练和部署机器学习模型。使用 SageMaker,您无需担心底层基础设施的管理,可以专注于模型的开发和优化。它提供了一整套工具和功能,包括数据准备、模型训练、超参数调优、模型部署和监控,简化了整个机器学习工作流程。
本方案将介绍以下内容:
1. 使用 SageMaker Jupyter Notebook进行dolly-v2-3b模型开发和微调
2. 在SageMaker部署微调后的大语言模型LLM并基于数据进行推理
3. 使用多场景的测试案例验证推理结果表现,并将部署的模型API节点集成进云端应用
项目搭建具体步骤:
下面跟着小李哥手把手微调一个亚马逊云科技AWS上的生成式AI模型(dolly-v2-3b)的软件应用,并将AI大模型部署与应用集成。
1. 在控制台进入Amazon SageMaker, 点击Notebook
2. 打开Jupyter Notebook
3. 创建一个新的Notebook:“lab-notebook.ipynb”并打开
4. 接下来我们在单元格内一步一步运行代码,检查CUDA的内存状态
!nvidia-smi
5.接下来,我们安装必要依赖并导入
%%capture!pip3 install -r requirements.txt --quiet
!pip install sagemaker --quiet --upgrade --force-reinstall
%%captureimport os
import numpy as np
import pandas as pd
from typing import Any, Dict, List, Tuple, Union
from datasets import Dataset, load_dataset, disable_caching
disable_caching() ## disable huggingface cachefrom transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import TextDatasetimport torch
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer
import accelerate
import bitsandbytesfrom IPython.display import Markdown
6. 导入提前准备好的FAQs数据集
sagemaker_faqs_dataset = load_dataset("csv", data_files='data/amazon_sagemaker_faqs.csv')['train']
sagemaker_faqs_dataset
sagemaker_faqs_dataset[0]
7. 我们定义用于模型推理的提示词格式
from utils.helpers import INTRO_BLURB, INSTRUCTION_KEY, RESPONSE_KEY, END_KEY, RESPONSE_KEY_NL, DEFAULT_SEED, PROMPT
'''
PROMPT = """{intro}{instruction_key}{instruction}{response_key}{response}{end_key}"""
'''
Markdown(PROMPT)
8. 下面我们进入重头戏,导入一个提前预训练好的LLM大语言模型“databricks/dolly-v2-3b”。
tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", padding_side="left")tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b",# use_cache=False,device_map="auto", #"balanced",load_in_8bit=True,
)
9. 对模型训练进行预准备, 处理数据集、优化模型训练(PEFT)效率
model.resize_token_embeddings(len(tokenizer))from functools import partial
from utils.helpers import mlu_preprocess_batchMAX_LENGTH = 256
_preprocessing_function = partial(mlu_preprocess_batch, max_length=MAX_LENGTH, tokenizer=tokenizer)encoded_sagemaker_faqs_dataset = sagemaker_faqs_dataset.map(_preprocessing_function,batched=True,remove_columns=["instruction", "response", "text"],
)processed_dataset = encoded_sagemaker_faqs_dataset.filter(lambda rec: len(rec["input_ids"]) < MAX_LENGTH)split_dataset = processed_dataset.train_test_split(test_size=14, seed=0)
split_dataset
10. 同时我们使用LoRA(Low-Rank Adaptation)模型加速我们的模型微调
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskTypeMICRO_BATCH_SIZE = 8
BATCH_SIZE = 64
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LORA_R = 256 # 512
LORA_ALPHA = 512 # 1024
LORA_DROPOUT = 0.05# Define LoRA Config
lora_config = LoraConfig(r=LORA_R,lora_alpha=LORA_ALPHA,lora_dropout=LORA_DROPOUT,bias="none",task_type="CAUSAL_LM"
)model = get_peft_model(model, lora_config)
model.print_trainable_parameters()from utils.helpers import MLUDataCollatorForCompletionOnlyLMdata_collator = MLUDataCollatorForCompletionOnlyLM(tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)
11. 接下来我们定义模型训练参数并开始训练。其中Batch=1,Step=20000,epoch为10.
EPOCHS = 10
LEARNING_RATE = 1e-4
MODEL_SAVE_FOLDER_NAME = "dolly-3b-lora"training_args = TrainingArguments(output_dir=MODEL_SAVE_FOLDER_NAME,fp16=True,per_device_train_batch_size=1,per_device_eval_batch_size=1,learning_rate=LEARNING_RATE,num_train_epochs=EPOCHS,logging_strategy="steps",logging_steps=100,evaluation_strategy="steps",eval_steps=100, save_strategy="steps",save_steps=20000,save_total_limit=10,
)trainer = Trainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=split_dataset['train'],eval_dataset=split_dataset["test"],data_collator=data_collator,
)
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
trainer.train()
12. 接下来我们将微调后的模型保存在本地
trainer.model.save_pretrained(MODEL_SAVE_FOLDER_NAME)trainer.save_model()trainer.model.config.save_pretrained(MODEL_SAVE_FOLDER_NAME)tokenizer.save_pretrained(MODEL_SAVE_FOLDER_NAME)
13. 接下来,我们将保存到本地的模型进行部署,生成公开访问的API节点Endpoint
对部署所需要的参数进行定义和初始化
import boto3
import json
import sagemaker.djl_inference
from sagemaker.session import Session
from sagemaker import image_uris
from sagemaker import Modelsagemaker_session = Session()
print("sagemaker_session: ", sagemaker_session)aws_role = sagemaker_session.get_caller_identity_arn()
print("aws_role: ", aws_role)aws_region = boto3.Session().region_name
print("aws_region: ", aws_region)image_uri = image_uris.retrieve(framework="djl-deepspeed",version="0.22.1",region=sagemaker_session._region_name)
print("image_uri: ", image_uri)
进行模型部署
model_data="s3://{}/lora_model.tar.gz".format(mybucket)model = Model(image_uri=image_uri,model_data=model_data,predictor_cls=sagemaker.djl_inference.DJLPredictor,role=aws_role)
14.最后我们写入提示词,对大语言模型进行测试, 得到推理
outputs = predictor.predict({"inputs": "What solutions come pre-built with Amazon SageMaker JumpStart?"})from IPython.display import Markdown
Markdown(outputs)
15. 我们下面进入SageMaker Endpoint页面,得到刚部署的模型API端点的URL,通过这种方式我们就可以在应用中调用我们的微调后的大语言模型了。
相关文章:

0基础学会在亚马逊云科技AWS上利用SageMaker、PEFT和LoRA高效微调AI大语言模型(含具体教程和代码)
项目简介: 小李哥今天将继续介绍亚马逊云科技AWS云计算平台上的前沿前沿AI技术解决方案,帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS上的AI软甲开发最佳实践,并应用到自己的日常工作里。本次介绍的是如何在Amazon SageMaker上…...

护网HW面试——redis利用方式即复现
参考:https://xz.aliyun.com/t/13071 面试中经常会问到ssrf的打法,讲到ssrf那么就会讲到配合打内网的redis,本篇就介绍redis的打法。 未授权 原理: Redis默认情况下,会绑定在0.0.0.0:6379,如果没有采用相关…...

C++ //练习 15.8 给出静态类型和动态类型的定义。
C Primer(第5版) 练习 15.8 练习 15.8 给出静态类型和动态类型的定义。 环境:Linux Ubuntu(云服务器) 工具:vim 解释 静态类型:在编译时已知,是在变量声明时的类型或表达式生成的…...

阿里云ECS服务器安装jdk并运行jar包,访问成功详解
安装 OpenJDK 8 使用 yum 包管理器安装 OpenJDK 8 sudo yum install -y java-1.8.0-openjdk-devel 验证安装 安装完成后,验证 JDK 是否安装成功: java -version设置 JAVA_HOME 环境变量: 为了确保系统中的其他应用程序可以找到 JDK&…...

Windows系统上使用npm来安装和配置Yarn,在VSCode中使用
一、安装Yarn 1. 安装Node.js和npm 如果还没有安装Node.js和npm,可以从Node.js官方网站下载并安装最新版本的Node.js,npm会随Node.js一起安装。 2. 使用npm安装Yarn 打开命令提示符或PowerShell,运行以下命令来全局安装Yarn: …...

Unity ColorSpace 之 【颜色空间】相关说明,以及【Linear】颜色校正 【Gamma】的简单整理
Unity ColorSpace 之 【颜色空间】相关说明,以及【Linear】颜色校正 【Gamma】的简单整理 目录 Unity ColorSpace 之 【颜色空间】相关说明,以及【Linear】颜色校正 【Gamma】的简单整理 一、简单介绍 二、在Unity中设置颜色空间 三、Unity中的Gamma…...

JavaScript的学习(二)
今天继续学习JavaScript的第二天,还是打基础 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title…...

【接口自动化_06课_Pytest+Excel+Allure完整框架集成】
一、logging在接口自动化里的应用 1、设置日志的配置,并收集日志文件 日志的设置需要在pytest.ini文件里设置。这个里面尽量不要有中文 2、debug日志的打印 pytest.ini文件的开关一定得是true才能在控制台打印日志 import allure import pytest from P06_PytestFr…...

Profibus协议转Profinet协议网关模块连接智能电表通讯案例
一、背景 在工业自动化领域,Profibus协议和Profinet协议是两种常见的工业通讯协议,而连接智能电表需要用到这两种协议之间的网关模块。本文将通过一个实际案例,详细介绍如何使用Profibus转Profinet模块(XD-PNPBM20)实…...

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(九)-无人机服务区分离
引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…...

acrobat 中 PDF 复制时不能精确选中所选内容所在行的一种解决方法
现象:划取行的时候,自动扩展为多行 如果整段选中复制,粘贴后是乱码 解决步骤 识别完,保存 验证 可以按行复制了。 如果遇到仅使用 acrobat OCR 不能彻底解决的,更换其他自己熟悉的进行 OCR。...

安卓学习中遇到的问题【bug】
安卓学习中遇到的问题 1Gradle下载慢怎么办? Gradle下载慢怎么办? distributionUrlhttps://mirrors.cloud.tencent.com/gradle/gradle-7.5-bin.zip 2 Could not resolve all files for configuration ‘:classpath‘. > Could not resolv…...

【日常记录】【CSS】display:inline 的样式截断
文章目录 1. 案例2. css属性:box-decoration-break参考地址 1. 案例 现在有一篇文章,某些句子,是要被标记的,加一些css 让他突出一下 可以看到,在最后,断开了,那如若要让 断开哪里的样式 和 开始…...

数据库系统安全
数据库安全威胁 数据库作为信息系统中的核心组成部分,存储和管理着大量敏感和关键的数据,成为网络攻击者的主要目标之一。以下是常见的数据库安全威胁及其详细描述: 一、常见数据库安全威胁 SQL注入攻击(SQL Injectionÿ…...

Qt MV架构-代理模型
一、基本概念 代理模型可以将一个模型中的数据进行排序或者过滤,然后提供给视图进行显示。 Qt中提供了QSortFilterProxyModel作为标准的代理模型来完成模型中数据的排序和过滤。 要使用一个代理模型,则只需要为其设置源模型,然后再视图中使…...

WebSocket实现群聊功能、房间隔离
引用WebSocket相关依赖 <dependency><groupId>javax.websocket</groupId><artifactId>javax.websocket-api</artifactId><version>1.1</version></dependency><dependency><groupId>org.springframework</grou…...

顶顶通呼叫中心中间件实现随时启动和停止质检(mod_cti基于FreeSWITCH)
文章目录 前言联系我们拨号方案启动停止ASR执行FreeSWITCH 命令接口启动ASR接口停止ASR接口 通知配置cti.json配置质检结果写入数据库 前言 顶顶通呼叫中心中间件的实时质检功能是由两个模块组成:mod_asr 和 mod_qc。 mod_asr:负责调用ASR将用户们在通…...

基于conda包的环境创建、激活、管理与删除
Anaconda是一个免费、易于安装的包管理器、环境管理器和 Python 发行版,支持平台包括Windows、macOS 和 Linux。下载安装地址:Download Anaconda Distribution | Anaconda 很多不同的项目可能需要使用不同的环境。例如某个项目需要使用pytorch1.6&#x…...

处理线程安全的列表CopyOnWriteArrayList 和Collections.synchronizedList
ConcurrentModificationException 是 Java 中的一种异常,用于指示在迭代集合时,该集合的结构发生了并发修改。 在 Java 中,许多集合类(如 ArrayList, HashMap 等)都不是线程安全的。如果一个线程在迭代集合的同时&…...

技术成神之路:设计模式(六)策略模式
1.介绍 策略模式(Strategy Pattern)是一种行为型设计模式,它定义了一系列算法,封装每一个算法,并使它们可以相互替换。策略模式使得算法的变化独立于使用算法的客户端。 2.主要作用 策略模式的主要作用是将算法或行为…...

华为OD机考题(HJ90 合法IP)
前言 经过前期的数据结构和算法学习,开始以OD机考题作为练习题,继续加强下熟练程度。 描述 IPV4地址可以用一个32位无符号整数来表示,一般用点分方式来显示,点将IP地址分成4个部分,每个部分为8位,表示成…...

值得关注的数据资产入表
不错的讲解视频,来自:第122期-杜海博士-《数据资源入表及数据资产化》-大数据百家讲坛-厦门大学数据库实验室主办第122期-杜海博士-《数据资源入表及数据资产化》-大数据百家讲坛-厦门大学数据库实验室主办-20240708_哔哩哔哩_bilibili...

Postman API性能测试:解锁高级技巧的宝库
🚀 Postman API性能测试:解锁高级技巧的宝库 在API开发和测试过程中,性能测试是确保API稳定性和可靠性的关键环节。Postman作为API测试的强大工具,提供了多种性能测试功能和高级技巧,帮助开发者深入分析API的性能表现…...

stm32中断详解
stm32中断详解 文章目录 stm32中断详解1.什么是中断?1.STM32中断系统特点2.中断处理流程3.中断配置与使用 2.AFIO寄存器3.NVIC寄存器3.中断分组、抢占优先级和响应优先级1. 中断分组2. 抢占优先级3. 响应优先级4.配置与应用 4.中断服务函数5.配置中断流程1.配置外设…...

【LeetCode】最小栈
目录 一、题目二、解法完整代码 一、题目 设计一个支持 push ,pop ,top 操作,并能在常数时间内检索到最小元素的栈。 实现 MinStack 类: MinStack() 初始化堆栈对象。 void push(int val) 将元素val推入堆栈。 void pop() 删除堆栈顶部的元…...

链接追踪系列-09.spring cloud项目整合elk显示业务日志
准备工作: 参看本系列之前篇:服务器安装elastic search 本机docker启动的kibana-tencent 使用本机安装的logstash。。。 本微服务实现的logstash配置如下: 使用腾讯云redis 启动本机mysql 启动本机docker 启动nacos,微服务依赖它作为…...

老年生活照护实训室:让养老护理更个性化
本文探讨了老年生活照护实训室在实现养老护理个性化方面的重要作用。通过分析其提供的实践环境、专业培训、模拟案例和评估机制,阐述了如何培养护理人员的个性化服务能力,以满足老年人多样化的需求,提高养老护理的质量和满意度。 在老龄化社会…...

c++课后作业
把字符串转换为整数 int main() {char pn[21];cout << "请输入一个由数字组成的字符串: ";cin >> pn;int last 0;int res[10];int j strlen(pn);int idx 2;cout << "请选择(2-二进制,10-十进制…...

SpringBoot+Vue实现简单的文件上传(txt篇)
SpringBootVue实现简单的文件上传 1 环境 SpringBoot 3.2.1,Vue 2,ElementUI 2 页面 3 效果:只能上传txt文件且大小限制为2M,选择文件后自动上传。 4 前端代码 <template><div class"container"><el-…...

LLMs之RAG:GraphRAG(本质是名词Knowledge Graph/Microsoft微软发布)的简介、安装和使用方法、案例应用之详细攻略
LLMs之RAG:GraphRAG(本质是名词Knowledge Graph/Microsoft微软发布)的简介、安装和使用方法、案例应用之详细攻略 导读:2024年7月3日,微软正式开源发布GraphRAG。GraphRAG可以提高大型语言模型在私有数据集上的推理能力。 背景痛点࿱…...