在Codelab对llama3做Lora Fine tune微调
Unsloth 高效微调大模型的工具,通过Unsloth微调Llama3, Mistral, Gemma 速度提升2-5倍,内存减少70%!
Codelab 创建一个jupyter notebook

选择 T4 GPU

安装Fine tune 相关的lib
%%capture
import torch
major_version, minor_version= torch.cuda.get_device_capability()
# Must install separately since Colab has torch 2.2.1, which breaks packages
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
if major_version >= 8:# Use this for new GPs like Ampere, Hopper GPUs(RTX 30xx. RIX 40xx, A100. H100. L40)!pip install -no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
else:# Use this for older GPUs (V100, Tesla T4, RTX 20xx)!pip install --no-deps xformers trl peft accelerate bitsandbytes
pass
下载llama3
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False# 4bit pre quantized models we support for 4x faster downloading + no OOMs
fourbit_models = ["unsloth/mistral-7b-bnb-4bit","unsloth/mistral-7b-instruct-bnb-4bit","unsloth/llama-2-7b-bnb-4bit","unsloth/gemma-7b-bnb-4bit","unsloth/gemma-7b-it-bnb-4bit","unsloth/gemma-2b-bnb-4bit","unsloth/gemma-2b-it-bnb-4bit","unsloth/llama-3-8b-bnb-4bit",
] # More models at https://huggingface.co/unslothmodel, tokenizer = FastLanguageModel.from_pretrained(model_name = "unsloth/llama-3-8b-bnb-4bit",max_seq_length = max_seq_length,dtype = dtype,load_in_4bit = load_in_4bit# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf)

model = FastLanguageModel.get_peft_model(model,r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",],lora_alpha = 16,lora_dropout = 0, # Supports any, but = 0 is optimizedbias = "none", # Supports any, but = "none" is optimized# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long contextrandom_state = 3407,use_rslora = False, # We support rank stabilized LoRAloftq_config = None # And LoftQ
)

加载hugging face数据集
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{}
### Input:
{}
### Response:
{}
"""EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):instructions = examples["instruction"]inputs = examples["input"]outputs = examples["output"]texts = []for instruction, input, output in zip(instructions, inputs, outputs):# Must add EOS_TOKEN, otherwise your generation will go on forever!text = alpaca_prompt.format(instruction, input, output) + EOS_TOKENtexts.append(text)return { "text": texts, }
passfrom datasets import load_dataset
dataset = load_dataset("pinzhenchen/alpaca-cleaned-zh", split="train")
dataset = dataset.map(formatting_prompts_func, batched=True,)

HuggingFace 官网, 点击数据集 Datasets

搜索数据集 alpaca-cleaned-zh

复制数据集的名字 pinzhenchen/alpaca-cleaned-zh

定义training 方法
from trl import SFTTrainer
from transformers import TrainingArgumentstrainer = SFTTrainer(model = model,tokenizer = tokenizer,train_dataset = dataset,dataset_text_field = "text",max_seq_length = max_seq_length,dataset_num_proc = 2,packing = False, # Can make training 5x faster for short sequences.args = TrainingArguments(per_device_train_batch_size = 2,gradient_accumulation_steps = 4,warmup_steps = 5,max_steps = 60,learning_rate = 2e-4,fp16 = not torch.cuda.is_bf16_supported(),bf16 = torch.cuda.is_bf16_supported(),logging_steps = 1,optim = "adamw_8bit",weight_decay = 0.01,lr_scheduler_type = "linear",seed = 3407,output_dir = "outputs",),
)
打印显存使用情况
#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = (gpu_stats.name). Max memory = (max_memory) GB.")
print(f"(start_gpu_memory) GB of memory reserved.")

开始FineTune
trainer_stats = trainer.train()

#@title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory*100, 3)
lora_percentage = round(used_memory_for_lora / max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} GB.")

用fineTune 过的model,做问答
# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer([alpaca_prompt.format("如何保持健康", # instruction"", # input"", # output - leave this blank for generation!)], return_tensors = "pt"
).to("cuda")outputs = model.generate(**inputs, max_new_tokens = 64, use_cache=True)
tokenizer.batch_decode(outputs)

TextStreamer 流式一个字一个字地打印结果
# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer([alpaca_prompt.format("续写这段话", # instruction"天天向上,好好学习", # input"", # output - leave this blank for generation!)], return_tensors = "pt"
).to("cuda")from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens=128)

保存model到google drive 和 HuggingFace
model.save_pretrained("lora_model") # local saving
model.push_to_hub("zgpeace/lora_model", token="####") # online saving
google drive


相关文章:
在Codelab对llama3做Lora Fine tune微调
Unsloth 高效微调大模型的工具,通过Unsloth微调Llama3, Mistral, Gemma 速度提升2-5倍,内存减少70%! Codelab 创建一个jupyter notebook 选择 T4 GPU 安装Fine tune 相关的lib %%capture import torch major_version, minor_version torch…...
KEIL 5.38的ARM-CM3/4 ARM汇编设计学习笔记13 - STM32的SDIO学习5 - 卡的轮询读写擦
KEIL 5.38的ARM-CM3/4 ARM汇编设计学习笔记13 - STM32的SDIO学习5 - 卡的轮询读写擦 一、前情提要二、目标三、技术方案3.1 读写擦的操作3.1.1 读卡操作3.1.2 写卡操作3.1.3 擦除操作 3.2 一些技术点3.2.1 轮询标志位的选择不唯一3.2.2 写和擦的卡状态查询3.2.3 写的速度 四、代…...
【C++】HP-Socket(三):UdpClient、UdpServer、UdpCast、UdpNode的区别
1、简述 UDP是无连接的,在UDP传输层中并没有客户端和服务端的概念。但是可以在应用层定义客户端和服务端,可以灵活的互换客户端和服务端,或者同时既是客户端也是服务端。 HP-Socket中在应用层定义了四种UDP组件:UdpClient、UdpS…...
java设计模式六 访问者
访问者模式(Visitor Pattern)是一种设计模式,它允许你将算法附加到对象结构中的各个元素上,而不必修改对象结构本身。它主要用于处理对象结构非常稳定,但频繁需要在此结构上执行不同操作的场景。访问者模式通过将操作移…...
中间件研发之Springboot自定义starter
Spring Boot Starter是一种简化Spring Boot应用开发的机制,它可以通过引入一些预定义的依赖和配置,让我们快速地集成某些功能模块,而无需繁琐地编写代码和配置文件。Spring Boot官方提供了很多常用的Starter,例如spring-boot-star…...
libcity笔记:添加新模型(以RNN.py为例)
创建的新模型应该继承AbstractModel或AbstractTrafficStateModel 交通状态预测任务——>继承 AbstractTrafficStateModel类轨迹位置预测任务——>继承AbstractModel类 1 AbstractTrafficStateModel 2 RNN 2.1 构造函数 2.2 predict 2.3 calculate_loss...
Ansible---自动化运维工具
一、Ansible概述 1.1 Ansible简介 Ansible是一款自动化运维工具,通过ssh对目标主机进行配置、应用部署、任务执行、编排调度等操作。它简化了复杂的环境管理和自动化任务,提高了工作效率和一致性,同时,Ansible的剧本(playbooks)…...
5.Git
Git是一个分布式版本控制工具,主要用于管理开发过程中的源代码文件(Java类、xml文件、html文件等)。通过Git仓库来存储和管理这些文件,Git仓库分为两种 本地仓库:开发人员自己电脑上的Git仓库远程仓库:远程…...
探索中位数快速排序算法:高效寻找数据集的中间值
在计算机科学领域,寻找数据集的中位数是一个常见而重要的问题。而快速排序算法作为一种高效的排序算法,可以被巧妙地利用来解决中位数查找的问题。本文将深入探讨中位数快速排序算法的原理、实现方法以及应用场景,带你领略这一寻找中间值的高…...
密码学《图解密码技术》 记录学习 第十五章
目录 十五章 15.1本章学习的内容 15.2 密码技术小结 15.2.1 密码学家的工具箱 15.2.2 密码与认证 15.2.3 密码技术的框架化 15.2.4 密码技术与压缩技术 15.3 虚拟货币——比特币 15.3.1 什么是比特币 15.3.2 P2P 网络 15.3.3地址 15.3.4 钱包 15.3.5 区块链 15.3.…...
如何在 Ubuntu 16.04 上为 Nginx 创建自签名 SSL 证书
简介 TLS,即传输层安全协议,及其前身SSL,即安全套接字层,是用于将普通流量包装在受保护的加密包装中的网络协议。 使用这项技术,服务器可以在服务器和客户端之间安全地发送流量,而不会被外部方拦截。证书…...
5.协议的编解码
本章内容其实没有多大难度,主要考察大家的细心程度.计算数据长度然后截取相应字节数组并按照协议进行解码,编码则反之。 1.基础消息的编解码 Override public BasicMessage decode(byte[] bytes) {int dataLength ByteUtil.bytesToInt(ByteUtil.extra…...
数据结构基础| 线性表
线性表 定义 没有元素则为空表 例子: 稀疏多项式的运算 图书信息管理系统 特点 线性结构 同类型 线性表的类型定义 1.基本操作: InitList(&L) 操作结果:构造空的线性表L DestroyList(&L) 初始化条件:线性表L存在 操作结果:销毁线性表L(线性表L不存在) Cle…...
嵌入式学习
笔记 作业 有如下结构体 struct Student{ char name[16]; int age; double math_score; double chinese_score; double english_score; double physics_score; double chemistry…...
sass-loader和node-sass与node版本的依赖问题
sass-loader和node-sass与node版本的依赖问题 没有人会陪你走到最后,碰到了便是有缘,即使到了要下车的时候,也要心存感激地告别,在心里留下空白的一隅之地,多年后想起时依然心存甘味。——林清玄 报错截图 报错信息 np…...
基于BP神经网络的QPSK解调算法matlab性能仿真
目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 ........................................................................ for ij 1:leng…...
Linux服务器常用巡检命令
在Linux服务器上进行常规巡检是确保服务器稳定性和安全性的重要措施之一。以下是一些常用的巡检命令和技巧: 1. 查看系统信息 1.1 系统信息显示 命令:uname -a [rootlinux100 ~]# uname -a Linux linux100 4.15.0-70-generic #79-Ubuntu SMP…...
VSCode 配置 CMake
VSCode 配置 C/C 环境的详细过程可参考:VSCode 配置 C/C 环境 1 配置C/C编译环境 如果是 Windows 环境,需要安装 MingW。 方案一 可以去官网(https://sourceforge.net/projects/mingw-w64/)下载安装包。 注意安装路径不要出现中文。 打开 windows she…...
《MATLAB科研绘图与学术图表绘制从入门到精通》示例:绘制德国每日风能和太阳能产量3D线图
在MATLAB中,要绘制3D线图,可以使用 plot3 函数。 在《MATLAB科研绘图与学术图表绘制从入门到精通》书中通过绘制德国每日风能和太阳能产量3D线图解释了如何在MATLAB中绘制3D线图。 购书地址:https://item.jd.com/14102657.html...
【信息系统项目管理师知识点速记】质量管理:控制质量
控制质量是为了评估绩效,确保项目输出完整、正确且满足客户期望,而监督和记录质量管理活动执行结果的过程。控制质量过程需要在整个项目期间开展,其目的是测量产品或服务的完整性、合规性和适用性,以确保项目达到主要干系人的质量要求。 12.5.1 输入 项目管理计划 质量管理…...
聚类算法详解
聚类算法作为无监督学习的核心分支,就像一位“智能分类师”,能在没有标签的数据集里,自动把相似的对象归为一类,把不同的对象分开。它广泛应用于客户分群、图像分割、异常检测等场景,接下来我们用通俗易懂的方式拆解常…...
OBS多平台直播插件:打破平台限制的5分钟专业解决方案
OBS多平台直播插件:打破平台限制的5分钟专业解决方案 【免费下载链接】obs-multi-rtmp OBS複数サイト同時配信プラグイン 项目地址: https://gitcode.com/gh_mirrors/ob/obs-multi-rtmp 想象一下这样的场景:你精心准备的游戏直播即将开始…...
YOLO11涨点优化:训练技巧 | 基于EMA(指数滑动平均)与SWA(随机权重平均)双保险,刷榜最后一公里必备
写在前面 在目标检测竞赛和工业落地中,有一个令人头疼的现象:模型在COCO预训练权重上表现惊艳,但迁移到自己的数据集后,精度长期“趴窝”——涨不上去,也掉不下来。投入大量资源调参、改结构、加数据增强,mAP就是纹丝不动。这种“不涨点”现象已经成为许多算法工程师在冲…...
别再让Excel卡死了!手把手教你安装Oracle Crystal Ball并管理加载项(附32/64位安装包)
高效管理Oracle Crystal Ball加载项:告别Excel卡顿的终极指南 你是否经历过这样的场景:刚安装完Oracle Crystal Ball准备大展身手,却发现Excel启动速度慢得像蜗牛爬行?作为一款强大的蒙特卡洛模拟工具,Crystal Ball确…...
3步免费查询:手机号快速查找QQ号的终极Python工具指南
3步免费查询:手机号快速查找QQ号的终极Python工具指南 【免费下载链接】phone2qq 项目地址: https://gitcode.com/gh_mirrors/ph/phone2qq 你是否曾因忘记老同学的QQ号而无法联系?或者需要验证某个手机号是否关联QQ账号?phone2qq这个…...
FPGA高速收发器CDR模块深度解析:从NRZ码中“捞出”时钟的RXOUTCLKPMA是怎么工作的?
FPGA高速收发器CDR模块技术探秘:解码NRZ数据中的时钟玄机 在高速数字通信系统中,时钟数据恢复(CDR)技术如同一位技艺精湛的侦探,能够从看似杂乱无章的NRZ(非归零码)数据流中,精准地&…...
C++ 显式类型转换详解
C 显式类型转换详解一、C 显示类型转换详解1、static_cast2、dynamic_cast3、const_cast4、reinterpret_cast5、C 风格转换6、总体注意事项7、总结二、代码示例1、示例代码2、运行结果一、C 显示类型转换详解 在 C 中,类型转换是编程的核心概念之一。显示类型转换&…...
网盘下载提速终极指南:9大平台直链获取工具完整教程
网盘下载提速终极指南:9大平台直链获取工具完整教程 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 ,支持 百度网盘 / 阿里云盘 / 中国移动云盘 / 天翼云…...
5分钟快速上手:用TMSpeech实现Windows离线语音转文字,保护隐私的会议记录神器
5分钟快速上手:用TMSpeech实现Windows离线语音转文字,保护隐私的会议记录神器 【免费下载链接】TMSpeech 腾讯会议摸鱼工具 项目地址: https://gitcode.com/gh_mirrors/tm/TMSpeech 还在为线上会议记录手忙脚乱吗?担心语音数据上传云端…...
微信网页版访问难题如何破解?wechat-need-web浏览器扩展的轻量级替代方案探索
微信网页版访问难题如何破解?wechat-need-web浏览器扩展的轻量级替代方案探索 【免费下载链接】wechat-need-web 让微信网页版可用 / Allow the use of WeChat via webpage access 项目地址: https://gitcode.com/gh_mirrors/we/wechat-need-web 你是否曾在公…...
