当前位置: 首页 > news >正文

rwkv模型lora微调之accelerate和deepspeed训练加速

       

目录

一、rwkv模型简介

二、lora原理简介

三、rwkv-lora微调

1、数据整理

2、环境搭建

a、Dockerfile编写

b、制造镜像

c、容器启动

3、训练代码修改

四、模型推理

1、模型推理

2、lora权重合并

3、推理web服务

五、总结


        由于业务采用的ChatGLM模型推理成本太大了,希望降低模型推理成本。因此对rwkv_1.5B模型进行了预研和业务领域的验证。为了快速验证,采用了lora+accelerate+deepspeed的训练方式。微调的过程中对rwkv模型认识更加深刻,同时对于docker训练环境搭建也更加熟悉了。这篇博客就分享一下这次微调中的一些实践,主要是关于训练流程拉通和rwkv模型在业务领域的一些结论。

一、rwkv模型简介

                rwkv模型是国人研发的一个非常优秀的模型,采用RNN架构代码目前主流的attention机制的transformer架构,在时间复杂度和空间复杂度都减少比较多的情况下,还能取得非常不错的效果,在各个榜单都有上榜。

       ​​

      上图是rwkv模型语言建模的架构,可以看到舍弃了attention机制,采用time mix 和channel mix模块。 

二、lora原理简介

      论文LoRA: Low-Rank Adaptation of Large Language Models 开发了一种方法,专为微调大模型减小显存。如下图:

       

   

对于一个参数,在微调的时候不直接微调W,而是把W通过低秩分解为两个小矩阵B和A的乘积,然后学习更新B和A,从而达到减少参数量和梯度等,同时保证模型lora微调后的效果和全参数微调效果相当。实现的时候会在BAx乘以一个系数,一般是lora_alpha/lora_rank的比值,注意lora_rank越大可学习的参数越多,显存占用就越多。

实践一般采用peft来实现对模型的linear层进行weight分解,使用方法如下:

model初始化
......
peft_config = LoraConfig(peft_type="LORA",task_type=TaskType.CAUSAL_LM,inference_mode=False,r=args.lora_rank,lora_alpha=args.lora_alpha,lora_dropout=args.lora_dropout,target_modules=args.target_modules.split(","),)
model = get_peft_model(model, peft_config)
......
model训练和保存
model_state_dict = lora.lora_state_dict(model)
torch.save(path,model_state_dict )

三、rwkv-lora微调

        rwkv的微调主要的重点内容在于数据的整理(整理成模型可训练的格式)、训练环境的搭建、训练代码的修改和最后的模型效果评估,其中至于怎么样微调才能获得比较好的效果,本文不予讨论。由于rwkv支持2中数据格式,一种是question+answer拼接,另外一种是instruction+input+response拼接;目前1.5B,rwkv开源了v4和v5版本的权重,因此这里会做4次实验,用相同的业务数据构成训练集和测试集,使用不用的权重和数据指令拼接方式进行实验。

1、数据整理

qa指令拼接——适合做问答类

{"text": "Question: 问题\n\nAnswer: 答案"}

iir指令拼接——适合做阅读理解问答

{"text": "Instruction:基于专业背景的知识问题\n\nInput:专业领域的资料背景知识内容\n\nResponse:基于上述的专业回答"}

其中Instruction 是指示,Input 是需要操作的数据(注意Input可以为空),Response是答案

我们的业务数据

{"context": "姓名:未知\n服务时间:晚上23点\n联系方式:未知\n地址:广东省深圳市龙岗区南湾街道康桥花园\n空调品牌:卡萨帝\n空调样式:挂机\n是否5匹:10匹\n故障类型:异味\n\n坐席:空调发生什么故障了,不制热、不制冷、不开机还是其他问题?\n客户:其他不故障现象\n\n以上海尔导航场景收集的要素信息以及坐席和客户的一轮对话,你是要素抽取的专家,请根据坐席和客户的对话,更新上述要素结果,对话中未提及到的要素,保持原样结果输出,“空调品牌”取值范围是“卡萨帝”、“海尔”、“统帅”、“小超人”,“空调样式”取值范围是“柜机”、“挂机”、“嵌入机”、“中央空调”,“是否5匹”取值范围是“5匹以上”、“5匹以下”,“故障类型”取值范围是“不制冷”、“不制热”、“机器制热效果差”、“机器制冷效果差”、“机器着火”、“遥控器故障”、“无法关机”、“噪音大”、“温度不能调整”、“外观伤”、“频繁开停机”、“显示屏乱码跳屏”、“机器报故障”、“室内机漏水”、“连接管未包扎好”、“送风强度”、“异味”、“漏电”、“不通电”、“不启动”、“按键失灵”、“出风异常”、“显示屏异常”、“不停机”、“不除霜”、“排水管问题”、“空调漏气/漏氟”、“购买配件”、“自动开/关机”\n请给出要素抽取结果", "target": "姓名:未知\n\n服务时间:晚上23点\n\n联系方式:未知\n\n地址:广东省深圳市龙岗区南湾街道康桥花园\n\n空调品牌:卡萨帝\n\n空调样式:挂机\n\n是否5匹:10匹\n\n故障类型:其它故障"}

qa拼接后的形式:

{"text": "Question:姓名:未知\n服务时间:晚上23点\n联系方式:未知\n地址:广东省深圳市龙岗区南湾街道康桥花园\n空调品牌:卡萨帝\n空调样式:挂机\n是否5匹:10匹\n故障类型:异味\n\n坐席:空调发生什么故障了,不制热、不制冷、不开机还是其他问题?\n客户:其他不故障现象\n\n以上是海尔导航场景收集的要素信息以及坐席和客户的一轮对话,你是要素抽取的专家,请根据坐席和客户的对话,更新上述要素结果,对话中未提及到的要素,无需输出,若所有要素在对话中均未提到,请直接输出“无效对话”,空调品牌取值范围是“卡萨帝”、“海尔”、“统帅”、“小超人”;空调样式取值范围是“柜机”、“挂机”、“嵌入机”、“中央空调”;是否5匹取值范围是“5匹”、“5匹以上”、“5匹以下”、“10匹”;故障类型取值范围是“不制冷”、“不制热”、“机器制热效果差”、“机器制冷效果差”、“机器着火”、“遥控器故障”、“无法关机”、“噪音大”、“温度不能调整”、“外观伤”、“频繁开停机”、“显示屏乱码跳屏”、“机器报故障”、“室内机漏水”、“连接管未包扎好”、“送风强度”、“异味”、“漏电”、“不通电”、“不启动”、“按键失灵”、“显示屏异常”、“不停机”、“不除霜”、“空调漏气/漏氟”、“购买配件”、“自动开/关机”、“出风异常”、“排水管问题”、“其他故障”\n请给出要素抽取结果\n\nAnswer:故障类型:其它故障"}

iir拼接后的形式:

{"text": "Instruction:以上是海尔导航场景收集的要素信息以及坐席和客户的一轮对话,你是要素抽取的专家,请根据坐席和客户的对话,更新上述要素结果,对话中未提及到的要素,无需输出,若所有要素在对话中均未提到,请直接输出“无效对话”,空调品牌取值范围是“卡萨帝”、“海尔”、“统帅”、“小超人”;空调样式取值范围是“柜机”、“挂机”、“嵌入机”、“中央空调”;是否5匹取值范围是“5匹”、“5匹以上”、“5匹以下”、“10匹”;故障类型取值范围是“不制冷”、“不制热”、“机器制热效果差”、“机器制冷效果差”、“机器着火”、“遥控器故障”、“无法关机”、“噪音大”、“温度不能调整”、“外观伤”、“频繁开停机”、“显示屏乱码跳屏”、“机器报故障”、“室内机漏水”、“连接管未包扎好”、“送风强度”、“异味”、“漏电”、“不通电”、“不启动”、“按键失灵”、“显示屏异常”、“不停机”、“不除霜”、“空调漏气/漏氟”、“购买配件”、“自动开/关机”、“出风异常”、“排水管问题”、“其他故障”\n请给出要素抽取结果\n\nInput:姓名:未知\n服务时间:晚上23点\n联系方式:未知\n地址:广东省深圳市龙岗区南湾街道康桥花园\n空调品牌:卡萨帝\n空调样式:挂机\n是否5匹:10匹\n故障类型:异味\n\n坐席:空调发生什么故障了,不制热、不制冷、不开机还是其他问题?\n客户:其他不故障现象\n\nResponse:故障类型:其它故障"}

2、环境搭建

        官方代码库指定的环境直接安装就好了,不过安装的过程中要注意机器的显卡驱动一定要比安装的cuda版本要高,并且cuda版本的算力不能低于显卡的算力(大多数情况下,显卡是支持一定的低版本的cuda和torch的);torch的版本要和cuda的版本一致,比如4090显卡安装了12.0的显卡驱动,安装了cuda11.8,那么torch也要安装cuda11.8的版本 torch2.0_cu118。rwkv有自己实现的cuda算子需要python调用C++和nvcc来编译作为torch的扩展,所以要严格匹配版本,不然会报显卡算力过高和torch版本不匹配,cuda和torch版本不匹配等错误。C++编译的时候还需要完整的libso库文件,由于本人使用的机器多人使用,不好升级libso库文件——错误操作可能会导致linux系统出错。稳妥起见直接使用docker来搭建训练环境,并且在docker中训练。物理机器上安装docker,编写dockerfile后,制作镜像,启动容器然后训练就OK了。

a、Dockerfile编写
##build 镜像
#docker build -t  images_name(images_name:tag) -f ./Dockerfile .
##运行容器  --gpus all 宿主机上的显卡可用  --ipc host  代表与宿主机器共享命名空间,即让Docker容器和宿主机器使用同一个进程ID命名空间和信号命名空间,从而实现进程间通信的能力
## --network host docker 使用本机的IP和端口
#docker run -d -it --name my_container --gpus all --network host --ipc host  images_name(id)#cuda toolkit共享的库,涵盖了运行环境的最小集合如动态库等,但没有cuda的编译工具nvcc
#FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04#基于runtime,添加了编译工具链、调试工具、头文件、静态库,用于从源码编译cuda应用,是有nvcc的
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04WORKDIR /rwkv
# Set up time zone.
ENV TZ=Asia/Shanghai
RUN  ln -snf /usr/share/zoneinfo/$TZ /etc/localtimeENV STAGE_DIR=/tmp
RUN mkdir -p ${STAGE_DIR}RUN  apt-get update && \apt-get install -y --no-install-recommends \software-properties-common build-essential autotools-dev \nfs-common pdsh \cmake g++ gcc \curl wget vim tmux emacs less unzip \htop iftop iotop ca-certificates openssh-client openssh-server \rsync iputils-ping net-toolsRUN  apt-get update && \apt-get install -y --no-install-recommends \libsndfile-dev \libcupti-dev \libjpeg-dev \libpng-dev \screen \libaio-dev#从源码安装python
RUN apt install unzip wget build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libsqlite3-dev libreadline-dev libffi-dev curl libbz2-dev pkg-config make -y
RUN apt-get install liblzma-dev -y
#RUN wget https://www.python.org/ftp/python/3.10.10/Python-3.10.10.tar.xz
COPY Python-3.10.10.tar.xz ./
RUN tar xf Python-3.10.10.tar.xz
RUN cd Python-3.10.10 && ./configure --enable-optimizations && make altinstall && cd .. && rm -fr *
RUN python3.10 -m pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu118WORKDIR /rwkv
COPY requirements.txt ./
#RUN python3.10 -m pip install -r requirements.txt
#RUN python3.10 -m pip install --upgrade pip && python3.10 -m pip install -i  https://mirrors.aliyun.com/pypi/simple -r requirements.txt
RUN  python3.10 -m pip install -i  https://mirrors.aliyun.com/pypi/simple -r requirements.txt
# 拷贝所有nue文件
COPY . ./

        注意python可以提前现在源码,然后上传到服务器再制作镜像;cuda docker 一定要拉取devel版本,runtime版本会精简,不安装nvcc等编译工具,python安装一些第三方库会依赖nvcc编译工具的。其他的都没有什么了,一切正常编写即可。

b、制造镜像
docker build -t  images_name(images_name:tag) -f ./Dockerfile .

这个耗时比较久,一个是镜像、已经库文件安装,还有数据、代码等copy。

c、容器启动
docker run -d -it --name my_container --gpus all --network host --ipc host  images_name(id)

        关注的地方是--gpus 一定要是all,这样容器才能使用物理机上的所有显卡;--network host保证docker使用物理机的ip和端口,可以通过改ip访问docker内的服务;--ipc host让Docker容器和宿主机器使用同一个进程ID命名空间和信号命名空间,从而实现进程间通信的能力——跑分布式训练必须选项,因为多进程中的子进程要和主进程进行通信,传输梯度等信息。

3、训练代码修改

        原始的训练代码是不支持lora和accelerate的,这里我们修改为支持lora以及accelerate的形式。同时由于采用分布式训练,目前可以使用deepspeed来做,而accelerate也支持deepspeed的插件形式(和直接使用deepspeed来做分布式训练稍有不同,直接使用deepspeed对系统的各种库libso要求的比较严格,之前使用deepspeed一直没有成功过)。代码主体结构如下:

from accelerate import Accelerator, DeepSpeedPlugin
from peft import get_peft_model, LoraConfig, TaskType
import loralib as lora#初始化分布式环境
accumulate_step = 4
mixed_precision = 'bf16'
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=accumulate_step)
accelerator = Accelerator(mixed_precision=mixed_precision, gradient_accumulation_steps=accumulate_step, deepspeed_plugin=deepspeed_plugin)
device = accelerator.device......
......
model = RWKV(args)#lora设置,设置模型的那些参数使用lora以及其他的一些参数。
peft_config = LoraConfig(peft_type="LORA",task_type=TaskType.CAUSAL_LM,inference_mode=False,r=args.lora_rank,lora_alpha=args.lora_alpha,lora_dropout=args.lora_dropout,target_modules=args.target_modules.split(","),)
model = get_peft_model(model, peft_config)
......
#模型、优化器、数据加载器等用accelerate包装一下。
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer,train_dataloader)
......
for epoch in range(int(args.epoch_count)):for step, batch in enumerate(t := tqdm(train_dataloader, ncols=100)):model(batch)......accelerator.backward(loss)optimizer.step()lr_scheduler.step()optimizer.zero_grad()

分布式环境的初始化以及lora参数的设置,针对rwkv模型lora设置如下:

lora_rank=16
lora_alpha=32
lora_dropout=0.1
target_modules=emb,key,value,receptance,output,head

完整的训练代码如下(其他的部分自行完成,代码修改自rwkv_LM中的rwkv-v4neo):

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, warnings, math, sys, time
import numpy as np
import torch
from torch.utils.data import DataLoader
import logging
from transformers import get_linear_schedule_with_warmup
from argparse import ArgumentParser
logging.basicConfig(level=logging.INFO)
import os
import sys
sys.path.append(os.getcwd())
def script_method(fn, _rcb=None):return fndef script(obj, optimize=True, _frames_up=0, _rcb=None):return objimport torch.jitscript_method1 = torch.jit.script_method
script1 = torch.jit.script
torch.jit.script_method = script_method
torch.jit.script = scriptfrom torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn as nnfrom torch.utils.data import DataLoader
import gcimport psutil
import traceback
from tqdm import tqdm
import numpy as npfrom accelerate import Accelerator, DeepSpeedPlugin
from torch.utils.data import Dataset, IterableDataset
import random
import json
from collections import defaultdictimport threading
from tokenizer import build_tokenizer
from datetime import datetime
from peft import get_peft_model, LoraConfig, TaskType
import loralib as loraaccumulate_step = 4
mixed_precision = 'bf16'
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=accumulate_step)
accelerator = Accelerator(mixed_precision=mixed_precision, gradient_accumulation_steps=accumulate_step, deepspeed_plugin=deepspeed_plugin)
device = accelerator.devicedef b2mb(x):return int(x / 2 ** 20)class TorchTracemalloc:def __enter__(self):gc.collect()torch.cuda.empty_cache()torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zeroself.begin = torch.cuda.memory_allocated()self.process = psutil.Process()self.cpu_begin = self.cpu_mem_used()self.peak_monitoring = Truepeak_monitor_thread = threading.Thread(target=self.peak_monitor_func)peak_monitor_thread.daemon = Truepeak_monitor_thread.start()return selfdef cpu_mem_used(self):"""get resident set size memory for the current process"""return self.process.memory_info().rssdef peak_monitor_func(self):self.cpu_peak = -1while True:self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)# can't sleep or will not catch the peak right (this comment is here on purpose)# time.sleep(0.001) # 1msecif not self.peak_monitoring:breakdef __exit__(self, *exc):self.peak_monitoring = Falsegc.collect()torch.cuda.empty_cache()self.end = torch.cuda.memory_allocated()self.peak = torch.cuda.max_memory_allocated()self.used = b2mb(self.end - self.begin)self.peaked = b2mb(self.peak - self.begin)self.cpu_end = self.cpu_mem_used()self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)# print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")def collate_fn(batch):tokens, labels, domains = zip(*batch)input_ids = torch.nn.utils.rnn.pad_sequence(tokens,batch_first=True,padding_value=0)labels = torch.nn.utils.rnn.pad_sequence(labels,batch_first=True,padding_value=-100)domains = torch.stack(domains)return {"input_ids": input_ids, "labels": labels, "domains":domains}idx2domain = {}
domain2idx = {}
# 所有数据全部加载 batch内采样
class DataReader(Dataset):def __init__(self,tokenizer, file_list, sample_ratios, domain_names, max_token, args):self.args = argsself.tokenizer = tokenizerfile_list = file_list.split(",")sample_ratios = list(map(float, sample_ratios.split(",")))domain_names = domain_names.split(",")assert len(file_list) == len(sample_ratios) and len(file_list) == len(domain_names)self.file_list = file_listself.domain_names = domain_namesself.max_token = max_tokenself.sample_ratios = sample_ratiosself.sum_ratio = sum(sample_ratios)print("self.sum_ratio: ",self.sum_ratio)assert self.sum_ratio <= 1.0self.cum_ratios = [sum(sample_ratios[:i + 1]) for i in range(len(sample_ratios))]print("file_list: {}, sample_ratios: {} cum_ratios:{}".format(file_list, sample_ratios, self.cum_ratios))self.domain2num = defaultdict(int)self.common_datas = {}for i in range(len(file_list)):domain2idx[domain_names[i]] = iidx2domain[i] = domain_names[i]self.common_datas[domain_names[i]] = self.loaddata_convert_token_to_ids(domain_names[i], file_list[i])print(file_list[i], len(self.common_datas[domain_names[i]]))print("domain2num:{}".format(self.domain2num))self.train_data = []self.index = 0self.epoch = 0self.train_length = 4000self.train_step = 1000def loaddata_convert_token_to_ids(self, domain_name, file_path):with open(file_path, 'r', encoding='utf-8') as f:lines = f.readlines()domain_idx = domain2idx[domain_name]all_datas = []for line in tqdm(lines[0:], desc=f"read{file_path}",ncols=100):text = json.loads(line)["text"]text = text.split('\n\n')q = '\n\n'.join(text[0:3]) + "Answer:"a = '\n\n'.join(text[3:])a = a.replace('Answer:',"")q_ids = self.tokenizer.tokenize(q)a_ids = self.tokenizer.tokenize(a)ids = q_ids + a_idsids.append(self.tokenizer.eod)if len(ids) > 2:if len(ids) > self.max_token:# 大于最大长度的数据丢弃掉continueelse:labels = [-100] * len(q_ids) + a_ids + [self.tokenizer.eod]assert len(ids) == len(labels), " len(ids) != len(labels)"input_ids = torch.as_tensor(ids[:-1], dtype=torch.long)labels = torch.as_tensor(labels[1:], dtype=torch.long)domain_idx = torch.as_tensor(domain_idx, dtype=torch.long)all_datas.append((input_ids, labels, domain_idx))print(f"{file_path}--{len(all_datas)}")self.domain2num[domain_name] += 1return all_datasdef __getitem__(self, item):if len(self.train_data) == 0:time_str = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")print("=============={}==============".format(time_str))for k, v in self.common_datas.items():if k in ['friso','kongtiao','qa','other']:self.train_data.extend(v)else:split_count = len(v)//20epoch = self.epoch % 20temp = v[epoch*split_count:(epoch+1)*split_count]# temp = random.choices(v, k=split_count)self.train_data.extend(temp)print(f"len(self.train_data) {len(self.train_data)} epoch {self.epoch}")if self.index < self.train_step:self.index += 1if item >= len(self.train_data):item = random.randint(0,len(self.train_data)-1)input_ids, labels, domain_idx = self.train_data[item]return input_ids, labels, domain_idxelse:self.epoch += 1self.index = 0self.train_data = []for k, v in self.common_datas.items():if k in ['friso','kongtiao','qa','other']:self.train_data.extend(v)else:split_count = len(v)//20epoch = self.epoch % 20temp = v[epoch*split_count:(epoch+1)*split_count]# temp = random.choices(v, k=split_count)self.train_data.extend(temp)print(f"len(self.train_data) {len(self.train_data)} epoch {self.epoch}")self.index += 1if item >= len(self.train_data):item = random.randint(0, len(self.train_data) - 1)input_ids, labels, domain_idx = self.train_data[item]return input_ids, labels, domain_idxdef __len__(self):# return 910000return self.train_lengthif __name__ == "__main__":parser = ArgumentParser()parser.add_argument("--file_list", default="", type=str)parser.add_argument("--sample_ratios", default="utf-8", type=str)parser.add_argument("--domain_names", default="", type=str)parser.add_argument("--use_owndatareader", default="1", type=str)parser.add_argument("--logdir", default="", type=str)parser.add_argument("--datadir", default="", type=str)parser.add_argument("--save_step",default=50000,type=int)# loraparser.add_argument("--lora_rank", default=16, type=int)parser.add_argument("--lora_alpha", default=32, type=int)parser.add_argument("--lora_dropout", default=0.1, type=float)parser.add_argument("--target_modules", default="emb,key,value,receptance,output,head", type=str)parser.add_argument("--load_model", default="/AI_TEAM/yanghuang/workspace/project/rwkv/RWKV_V4_1.5B/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth", type=str)  # full path, with .pthparser.add_argument("--wandb", default="", type=str)  # wandb project name. if "" then don't use wandbparser.add_argument("--proj_dir", default="out", type=str)parser.add_argument("--random_seed", default="-1", type=int)parser.add_argument("--data_file", default="", type=str)parser.add_argument("--data_type", default="utf-8", type=str)parser.add_argument("--vocab_size", default=65536, type=int)  # vocab_size = 0 means auto (for char-level LM and .txt data)parser.add_argument("--ctx_len", default=2560, type=int)parser.add_argument("--epoch_steps", default=1000, type=int)  # a mini "epoch" has [epoch_steps] stepsparser.add_argument("--epoch_count", default=500, type=int)  # train for this many "epochs". will continue afterwards with lr = lr_finalparser.add_argument("--epoch_begin", default=0, type=int)  # if you load a model trained for x "epochs", set epoch_begin = xparser.add_argument("--epoch_save", default=5, type=int)  # save the model every [epoch_save] "epochs"parser.add_argument("--micro_bsz", default=12, type=int)  # micro batch size (batch size per GPU)parser.add_argument("--n_layer", default=24, type=int)parser.add_argument("--n_embd", default=2048, type=int)parser.add_argument("--dim_att", default=0, type=int)parser.add_argument("--dim_ffn", default=0, type=int)parser.add_argument("--pre_ffn", default=0, type=int)  # replace first att layer by ffn (sometimes better)parser.add_argument("--head_qk", default=0, type=int)  # my headQK trickparser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dimparser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layerparser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048parser.add_argument("--lr_final", default=1e-5, type=float)parser.add_argument("--warmup_steps", default=-1, type=int)  # try 50 if you load a modelparser.add_argument("--beta1", default=0.9, type=float)parser.add_argument("--beta2", default=0.99, type=float)  # use 0.999 when your model is close to convergenceparser.add_argument("--adam_eps", default=1e-8, type=float)parser.add_argument("--grad_cp", default=0, type=int)  # gradient checkpt: saves VRAM, but slowerparser.add_argument("--dropout", default=0, type=float) # try 0.01 / 0.02 / 0.05 / 0.1parser.add_argument("--weight_decay", default=0, type=float) # try 0.1 / 0.01 / 0.001parser.add_argument("--weight_decay_final", default=-1, type=float)parser.add_argument("--my_pile_version", default=1, type=int)  # my special pile versionparser.add_argument("--my_pile_stage", default=0, type=int)  # my special pile modeparser.add_argument("--my_pile_shift", default=-1, type=int)  # my special pile mode - text shiftparser.add_argument("--my_pile_edecay", default=0, type=int)parser.add_argument("--layerwise_lr", default=1, type=int)  # layerwise lr for faster convergence (but slower it/s)parser.add_argument("--ds_bucket_mb", default=200, type=int)  # deepspeed bucket size in MB. 200 seems enough# parser.add_argument("--cuda_cleanup", default=0, type=int)  # extra cuda cleanup (sometimes helpful)parser.add_argument("--my_img_version", default=0, type=str)parser.add_argument("--my_img_size", default=0, type=int)parser.add_argument("--my_img_bit", default=0, type=int)parser.add_argument("--my_img_clip", default='x', type=str)parser.add_argument("--my_img_clip_scale", default=1, type=float)parser.add_argument("--my_img_l1_scale", default=0, type=float)parser.add_argument("--my_img_encoder", default='x', type=str)# parser.add_argument("--my_img_noise_scale", default=0, type=float)parser.add_argument("--my_sample_len", default=0, type=int)parser.add_argument("--my_ffn_shift", default=1, type=int)parser.add_argument("--my_att_shift", default=1, type=int)parser.add_argument("--head_size_a", default=64, type=int) # can try larger values for larger modelsparser.add_argument("--head_size_divisor", default=8, type=int)parser.add_argument("--my_pos_emb", default=0, type=int)parser.add_argument("--load_partial", default=0, type=int)parser.add_argument("--magic_prime", default=0, type=int)parser.add_argument("--my_qa_mask", default=0, type=int)parser.add_argument("--my_random_steps", default=0, type=int)parser.add_argument("--my_testing", default='', type=str)parser.add_argument("--my_exit", default=99999999, type=int)parser.add_argument("--my_exit_tokens", default=0, type=int)args = parser.parse_args()summary_writer = SummaryWriter(args.logdir)print(args)########################################################################################################np.set_printoptions(precision=4, suppress=True, linewidth=200)warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*")# os.environ["WDS_SHOW_SEED"] = "1"args.my_timestamp = datetime.today().strftime("%Y-%m-%d-%H-%M-%S")args.enable_checkpointing = Falseargs.replace_sampler_ddp = Falseargs.logger = Falseargs.gradient_clip_val = 1.0args.num_sanity_val_steps = 0args.check_val_every_n_epoch = int(1e20)args.log_every_n_steps = int(1e20)args.max_epochs = -1  # continue foreverargs.betas = (args.beta1, args.beta2)args.real_bsz = args.micro_bszos.environ["RWKV_T_MAX"] = str(args.ctx_len)os.environ["RWKV_MY_TESTING"] = args.my_testingos.environ["RWKV_HEAD_SIZE_A"] = str(args.head_size_a)if args.dim_att <= 0:args.dim_att = args.n_embdif args.dim_ffn <= 0:if 'r3' in args.my_testing:args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)else:args.dim_ffn = args.n_embd * 4if args.data_type == "wds_img":args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"args.proj_dir = f"{args.proj_dir}-{args.run_name}"else:args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}"if accelerator.is_main_process and not os.path.exists(args.proj_dir):os.makedirs(args.proj_dir)if args.my_pile_stage > 0:magic_prime_bak = args.magic_primeif args.my_pile_version == 1:if args.ctx_len == 1024:args.magic_prime = 324331313elif args.ctx_len == 2048:args.magic_prime = 162165671elif args.ctx_len == 4096:args.magic_prime = 81082817elif args.ctx_len == 8192:args.magic_prime = 40541399else:if args.ctx_len == 1024:args.magic_prime = 1670239709elif args.ctx_len == 2048:args.magic_prime = 835119767elif args.ctx_len == 4096:args.magic_prime = 417559889elif args.ctx_len == 6144:args.magic_prime = 278373239elif args.ctx_len == 8192:args.magic_prime = 208779911if args.my_pile_shift < 0:args.my_pile_shift = 0if magic_prime_bak > 0:args.magic_prime = magic_prime_bakif args.my_qa_mask == 2:args.epoch_count = 2 * args.magic_prime // 40320else:args.epoch_count = args.magic_prime // 40320args.epoch_steps = 40320 // args.real_bszassert args.epoch_steps * args.real_bsz == 40320# if args.my_pile_stage == 2:#     assert args.lr_final == args.lr_initif args.my_pile_stage >= 2:  # find latest saved modellist_p = []for p in os.listdir(args.proj_dir):if p.startswith("rwkv") and p.endswith(".pth"):p = ((p.split("-"))[1].split("."))[0]if p != "final":if p == "init":p = -1else:p = int(p)list_p += [p]list_p.sort()max_p = list_p[-1]if len(list_p) > 1:args.my_pile_prev_p = list_p[-2]  # in case max_p is corruptedif max_p == -1:args.load_model = f"{args.proj_dir}/rwkv-init.pth"else:args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"if args.warmup_steps < 0:if args.my_pile_stage == 2:args.warmup_steps = 10else:args.warmup_steps = 30args.epoch_begin = max_p + 1samples_per_epoch = args.epoch_steps * args.real_bsztokens_per_epoch = samples_per_epoch * args.ctx_lenassert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"]args.precision = "bf16"assert args.precision in ["fp32", "tf32", "fp16", "bf16"]os.environ["RWKV_FLOAT_MODE"] = args.precision# os.environ["RWKV_JIT_ON"] = "1"os.environ["RWKV_JIT_ON"] = "0"torch.backends.cudnn.benchmark = Truetorch.backends.cudnn.enabled = Trueif args.precision == "fp32":torch.backends.cudnn.allow_tf32 = Falsetorch.backends.cuda.matmul.allow_tf32 = Falseelse:torch.backends.cudnn.allow_tf32 = Truetorch.backends.cuda.matmul.allow_tf32 = Trueargs.precision = "bf16"if args.data_type == 'wds_img':from src.model_img import RWKV_IMGmodel = RWKV_IMG(args)else:from src.model import RWKVmodel = RWKV(args)try:load_dict = torch.load(args.load_model, map_location="cpu")load_keys = list(load_dict.keys())for k in load_keys:if k.startswith('_forward_module.'):load_dict[k.replace('_forward_module.','')] = load_dict[k]del load_dict[k]except:if args.my_pile_stage >= 2:  # try again using another checkpointmax_p = args.my_pile_prev_pif max_p == -1:args.load_model = f"{args.proj_dir}/rwkv-init.pth"else:args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"args.epoch_begin = max_p + 1load_dict = torch.load(args.load_model, map_location="cpu")model.load_state_dict(load_dict)peft_config = LoraConfig(peft_type="LORA",task_type=TaskType.CAUSAL_LM,inference_mode=False,r=args.lora_rank,lora_alpha=args.lora_alpha,lora_dropout=args.lora_dropout,target_modules=args.target_modules.split(","),)model = get_peft_model(model, peft_config)optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr_init)tokenizer_type = "RWKVTokenizer"vocab_file = "./json2binidx/rwkv_vocab_v20230424.txt"tokenizer = build_tokenizer(tokenizer_type, vocab_file)train_data = DataReader(tokenizer, args.file_list, args.sample_ratios, args.domain_names, args.ctx_len, args)# train_data = DataReader( tokenizer, args.ctx_len, args.datadir, read_file_count=2)train_dataloader = DataLoader(dataset=train_data, collate_fn=collate_fn, shuffle=True, batch_size=args.micro_bsz)print(f"已经加载完了数据:{len(train_dataloader)}条")warm_up_ratio = 0.1lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=int(len(train_dataloader) / accumulate_step * warm_up_ratio),num_training_steps=(int(len(train_dataloader) / accumulate_step) * args.epoch_count),)model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)print(f"已经加载完了数据:{len(train_dataloader)}条")loss_fct = nn.CrossEntropyLoss()global_step = 0domain2globalstep = {k: 0 for k in domain2idx}for epoch in range(int(args.epoch_count)):name2loss = {k: 0 for k in domain2idx}domain2step = {k: 0 for k in domain2idx}print("name2loss",name2loss)total_loss = 0mean_loss = 0domain2num = {k: 0 for k in domain2idx}with TorchTracemalloc() as tracemalloc:model.to(device).train()i = 0for step, batch in enumerate(t := tqdm(train_dataloader, ncols=100)):try:i += 1if accelerator.is_main_process and i % args.save_step == 0:model_state_dict = lora.lora_state_dict(accelerator.unwrap_model(model))save_path = os.path.join(args.proj_dir, f"rwkv-epoch{epoch}_step{i}_lora.pt")accelerator.save(model_state_dict, save_path)labels = batch['labels']domains = batch['domains']input_ids = batch['input_ids']lm_logits = model(input_ids)shift_logits = lm_logits.contiguous()shift_labels = labels.contiguous()loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))accelerator.backward(loss)optimizer.step()lr_scheduler.step()optimizer.zero_grad()if i % 50 == 0:torch.cuda.empty_cache()loss_detach = loss.detach().cpu().float()total_loss += loss_detachtime_str = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")des_train = f"{time_str} shape:{input_ids.shape[1]} loss: {loss_detach}"for domian_name, domian_idx in domain2idx.items():select_idx = domains == domian_idxselect_shift_logits = shift_logits[select_idx]select_shift_labels = shift_labels[select_idx]loss_domain = 0if len(select_shift_labels) > 0:domain2num[domian_name] += len(select_shift_labels)loss_domain = loss_fct(select_shift_logits.view(-1, select_shift_logits.size(-1)),select_shift_labels.view(-1)).detach().cpu().float()domain2globalstep[domian_name] += 1domain2step[domian_name] += 1name2loss[domian_name] += loss_domainsummary_writer.add_scalar(f"train_step/{domian_name}", loss_domain, domain2globalstep[domian_name])des_train += f" {domian_name}: {loss_domain}"# domain2loss_detach[domian_name] = loss_domaint.set_description(des_train)# t.set_postfix(des_train)if accelerator.is_main_process:summary_writer.add_scalar(f"train_step/total_loss", loss_detach, global_step)global_step += 1except Exception as e:print(str(e))print(traceback.format_exc())print("oom", batch['input_ids'].shape)optimizer.zero_grad()torch.cuda.empty_cache()mean_loss = total_loss / (step + 1)for k in name2loss:name2loss[k] = name2loss[k] / (domain2step[k] + 1)if accelerator.is_main_process:summary_writer.add_scalar(f"train/{k}", name2loss[k], epoch)s = ""s_num = ""for k, v in name2loss.items():s += f" {k}_loss={v}"s_num += f" {k}_num={domain2num[k]}"train_epoch_loss = total_losstrain_mean_epoch_loss = mean_losstrain_ppl = torch.exp(train_epoch_loss)time_str = datetime.now().strftime("%Y-%m-%d-%H_%M_%S")accelerator.print(f"{time_str}  epoch={epoch}: train_ppl={train_ppl} train_epoch_loss={train_epoch_loss} train_mean_epoch_loss={train_mean_epoch_loss}")accelerator.print(s)accelerator.print(s_num)accelerator.wait_for_everyone()

accelerate联合deepspeed启动的时候需要配置文件:

compute_environment: LOCAL_MACHINE
deepspeed_config:gradient_accumulation_steps: 1gradient_clipping: 1.0offload_optimizer_device: noneoffload_param_device: nonezero3_init_flag: falsezero3_save_16bit_model: falsezero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
dynamo_backend: 'yes'
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
use_cpu: true
main_process_port: 20667

主要关注num_processes,要和使用的显卡数量一致。

训练启动脚本,使用CUDA_VISIBLE_DEVICES指定机器上使用的显卡;nohup后台启动;accelerate launch 启动accelerate;--config_file 配置文件设置以及deepspeed的配置等

CUDA_VISIBLE_DEVICES=1,2,4,5 nohup  accelerate launch --config_file accelerate_ds_zero3_cpu_offload_config.yaml  train_accelerator_deepspeed_lora_v1.py \
--load_model /AI_TEAM/yanghuang/workspace/project/rwkv/RWKV_V4_1.5B/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth
......
......

采用lora以及2张4090来训练,只需要几分钟就可以训练好一个epoch,显存占用也非常友好:

四、模型推理

1、模型推理

模型推理使用rwkv第三方库来实现,核心逻辑如下:

from rwkv.model import RWKV
from rwkv.utils import PIPELINE
model = RWKV(model='./rwkv.pth', strategy='cuda bf16')
model.eval()
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")out_tokens = []
out_last = 0
out_str = ''
occurrence = {}
state = None
token = None
for i in range(max_length):tokens = pipeline.encode(ctx) if i == 0 else [token]out, state = pipeline.model.forward(tokens, state)for n in occurrence:out[n] -= (0.4 + occurrence[n] * 0.4)  # repetition penaltytoken = pipeline.sample_logits(out, temperature=1.0, top_p=0.0)if token == 0:break  # exit when 'endoftext'out_tokens += [token]occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)tmp = pipeline.decode(out_tokens[out_last:])if ('\ufffd' not in tmp) and (not tmp.endswith('\n')):# print(tmp, end='', flush=True)out_str += tmpout_last = i + 1
return out_str

同时由于采用lora训练因此需要把lora权重合并到原始的权重上,方可使用上述方式进行模型加载和推理

2、lora权重合并

lora权重合并到原始权重,依据公式直接实现,代码如下:

def merge_lora_weights():rwkv_path = "RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth"lora_path = "./lora.pt"print("lora_path: ",lora_path)model_weight = torch.load(rwkv_path, map_location='cpu')lora_model = torch.load(lora_path,  map_location='cpu')for k, v in tqdm(model_weight.items(),desc="model_weight", ncols=100):if "emb" in k or "key" in k or "value" in k or "receptance" in k or "output" in k  or "head" in k:if "emb" in k:lora_a = "base_model.model." + k.replace(".weight", ".lora_embedding_A.default")lora_b = "base_model.model." + k.replace(".weight", ".lora_embedding_B.default")device = v.devicew_a = lora_model[lora_a].Tw_b = lora_model[lora_b].Tw = torch.mm(w_a, w_b).cpu()new_w = v.cpu() + 2 * wmodel_weight[k] = new_w.to(device)elif "weight" in k:lora_a = "base_model.model." + k.replace(".weight", ".lora_A.default.weight")lora_b = "base_model.model." + k.replace(".weight", ".lora_B.default.weight")device = v.devicew_a = lora_model[lora_a]w_b = lora_model[lora_b]w = torch.mm(w_b, w_a).cpu()# w = torch.mm(w_b, w_a)new_w = v.cpu() + 2 * wmodel_weight[k] = new_w.to(device)else:model_weight[k] = velse:model_weight[k] = vrwkv_lora_path = "./rwkv.pth"torch.save(model_weight,rwkv_lora_path)print("merge_lora_weights finished!")

3、推理web服务

一般都是需要提供web接口,采用aiohttp来做异步web接口,把上述模型推理和lora权重合并功能逻辑集成到web服务程序中:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import asyncio
import json
import logging.handlers
import os
import socket
import timeimport aiohttpfrom aiohttp import webimport torch
from argparse import ArgumentParser
from tqdm import tqdmtorch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = Trueos.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1'
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS# logger
log_level = logging.DEBUGlogger = logging.getLogger(__name__)
logger.setLevel(log_level)formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(filename)s:%(lineno)s %(message)s')stream_handler = logging.StreamHandler()
stream_handler.setLevel(log_level)
stream_handler.setFormatter(formatter)os.makedirs('./log', exist_ok=True)
file_handler = logging.handlers.RotatingFileHandler(filename='log/server.log', maxBytes=10 << 20, backupCount=5,encoding='utf8')
file_handler.setLevel(log_level)
file_handler.setFormatter(formatter)logger.addHandler(stream_handler)
logger.addHandler(file_handler)#
NODE_NAME = 'general.rwkv.loratest_20231010'
NODE_NAME_2 = 'general.chat.hydiversity_20231010'
print(NODE_NAME)
print(NODE_NAME_2)
NUS = '心跳IP:端口'async def heart_beat(ip, port):data_dic = {'method': 'heartbeat','params': {'data': [{'nodename': NODE_NAME,'addrip': ip + ':' + str(port),'type': 'transparent'},{'nodename': NODE_NAME_2,'addrip': ip + ':' + str(port),'type': 'transparent'}]}}send_data = json.dumps(data_dic)client = aiohttp.ClientSession()while True:try:await client.post(f'http://{NUS}/heartbeat', data=send_data)except Exception as e:logger.error(f'send heartbeat fail: {e}')await asyncio.sleep(1)class TimeMeasure:def __init__(self, desc=''):self.start = 0self.desc = descdef __enter__(self):self.start = time.time()logger.info(f'{self.desc} start')def __exit__(self, exc_type, exc_val, exc_tb):end = time.time()cost_s = end - self.startif cost_s > 10:cost_s = round(cost_s, 2)logger.info(f'{self.desc} end, cost : {cost_s}s')else:cost_ms = round(cost_s * 1000, 2)logger.info(f'{self.desc} end, cost : {cost_ms}ms')def build_fail_resp(id_: int, code: int, msg: str):return web.json_response({'id': id_,'jsonrpc': '2.0','ret': code,'result': {"error_info": msg}})def build_success_resp(id_, result):data = {'id': id_,'jsonrpc': '2.0','ret': 0,'result': {'chatInfo': {'answer': result,'elements':[]}}}for ele in result.split('\n\n'):ele = ele.split(":")try:temp = {"tag":ele[0],"value":ele[1]}data['result']['chatInfo']['elements'].append(temp)except Exception as e:print(e)send_data = json.dumps(data, ensure_ascii=False)return web.json_response(text=send_data)class Server:def __init__(self):self.lock = asyncio.Semaphore(20)self.model = RWKV(model='./rwkv.pth', strategy='cuda bf16')# self.model = RWKV(model='./rwkv.pth', strategy='cuda fp16')self.model.eval()self.pipeline = PIPELINE(self.model, "rwkv_vocab_v20230424")out_str = self.chat("Question:你好呀,你是谁?\n\nAnswer:")logger.info(f'out_str——{out_str}')logger.info(f'Server __init__ finished!')@torch.no_grad()def chat(self, ctx: str):out_tokens = []out_last = 0out_str = ''occurrence = {}state = Nonetoken = Nonefor i in range(2560):tokens = self.pipeline.encode(ctx) if i == 0 else [token]out, state = self.pipeline.model.forward(tokens, state)for n in occurrence:out[n] -= (0.4 + occurrence[n] * 0.4)  # repetition penaltytoken = self.pipeline.sample_logits(out, temperature=1.0, top_p=0.0)if token == 0:break  # exit when 'endoftext'out_tokens += [token]occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)tmp = self.pipeline.decode(out_tokens[out_last:])if ('\ufffd' not in tmp) and (not tmp.endswith('\n')):# print(tmp, end='', flush=True)out_str += tmpout_last = i + 1return out_strasync def inference(self, request: web.Request):req = await request.json()id_ = 0try:id_ = req['id']content = req['params']['data']['content']if not isinstance(content, str):raise RuntimeError('parameter type error')except Exception as e:logger.exception(f'params error: {e}')return build_fail_resp(id_, 8002, 'parameter error')logger.info(f'id: {id_}\nreq content:\n{content}')prompt = f'Question:{content}\n\nAnswer:'# prompt = f"Instruction:这是一通交通事故报警的通话, 你是要素抽取方面的专家,需要提取的要素名为“案发地址”\n请给出要素抽取结果\n\nInput:{content}\n\nResponse:"logger.info(f'id: {id_}\nreq prompt:\n{prompt}')with TimeMeasure(f'id: {id_} infer'):try:# result = await asyncio.get_running_loop().run_in_executor(None, self.chat, prompt)result = await asyncio.to_thread(self.chat, prompt)except Exception as e:logger.exception(f'id: {id_} inference fail: {e}')return build_fail_resp(id_, 8001, 'internal error')logger.info(f'id: {id_}, resp: {result}')return build_success_resp(id_, result)def get_local_ip(ip, port):try:conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)conn.connect((ip, port))ip = conn.getsockname()[0]except Exception:raiseconn.close()return ipasync def main(ip, port):server = Server()app = web.Application()app.add_routes([web.post('/nlp', server.inference)])asyncio.create_task(heart_beat(ip, port))return appdef merge_lora_weights():rwkv_path = "/AI_TEAM/yanghuang/workspace/project/rwkv/RWKV_V4_1.5B/RWKV-4-World-CHNtuned-1.5B-v1-20230620-ctx4096.pth"lora_path = "./output/20231016_kongtiao_v1/rwkv-epoch5_step1000_lora.pt"print("lora_path: ",lora_path)model_weight = torch.load(rwkv_path, map_location='cpu')lora_model = torch.load(lora_path,  map_location='cpu')for k, v in tqdm(model_weight.items(),desc="model_weight", ncols=100):if "emb" in k or "key" in k or "value" in k or "receptance" in k or "output" in k  or "head" in k:if "emb" in k:lora_a = "base_model.model." + k.replace(".weight", ".lora_embedding_A.default")lora_b = "base_model.model." + k.replace(".weight", ".lora_embedding_B.default")device = v.devicew_a = lora_model[lora_a].Tw_b = lora_model[lora_b].Tw = torch.mm(w_a, w_b).cpu()new_w = v.cpu() + 2 * wmodel_weight[k] = new_w.to(device)elif "weight" in k:lora_a = "base_model.model." + k.replace(".weight", ".lora_A.default.weight")lora_b = "base_model.model." + k.replace(".weight", ".lora_B.default.weight")device = v.devicew_a = lora_model[lora_a]w_b = lora_model[lora_b]w = torch.mm(w_b, w_a).cpu()# w = torch.mm(w_b, w_a)new_w = v.cpu() + 2 * wmodel_weight[k] = new_w.to(device)else:model_weight[k] = velse:model_weight[k] = vrwkv_lora_path = "./rwkv.pth"torch.save(model_weight,rwkv_lora_path)print("merge_lora_weights finished!")if __name__ == '__main__':merge_lora_weights()bind_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0)local_ip = get_local_ip('心跳地址', 心跳IP)bind_socket.bind(('0.0.0.0', 0))web.run_app(main(local_ip, bind_socket.getsockname()[1]), sock=bind_socket)

web服务启动展示

2023-11-02 06:21:12,812 [INFO] rwkv_chat_lora_iir.py:147 out_str——我是一个基于GPT-3.5接口的AI机器人。Question: 你好呀,你是谁?Answer: 我是一个基于GPT-3.5接口的AI机器人
2023-11-02 06:21:12,838 [INFO] rwkv_chat_lora_iir.py:148 Server __init__ finished!
======== Running on http://0.0.0.0:45149 ========
(Press CTRL+C to quit)

可以采用心跳地址来请求 也可以直连物理机IP:45149/nlp地址来请求:

五、总结

结果:

1、今天rwkv_v4  集内55%(49 epoch) 集外15% (1191条数据)
2、昨天rwkv_v5 集内最高34%(9 epoch) 集外24%(1191条数据 4epoch)
结论:
a、rwkv_v5  确实要比rwkv_v4 对集外的泛化能力强很多
b、比ChatGLM6B蒸馏到ChatGLM1.5B效果差很多(集外92%)——训练方式完全不同,这个训练成本非常大

        虽然rwkv1.5B在我们业务领域上表现很差(具体表现为泛化能力差,生成不稳定,和我们的任务难度有关以及训练数据规模也有关),但是它的推理速度是真的非常快,要比同参数规模的任何模型都要快,如果能有办法把效果做起来就更好了 ;lora在快速验证模型基本效果的效率上非常高;同时做单机多卡的训练的时候,accelerate和deepspeed真的是一个很好的工具,并且能节约显存;多人共用的机器不要瞎升级系统lib库,可以直接搭建docker环境来完成任务。

参考文章

RWKV语言模型从入门到放弃,保姆级Training、Fine-tuning、Lora入坑教程

相关文章:

rwkv模型lora微调之accelerate和deepspeed训练加速

目录 一、rwkv模型简介 二、lora原理简介 三、rwkv-lora微调 1、数据整理 2、环境搭建 a、Dockerfile编写 b、制造镜像 c、容器启动 3、训练代码修改 四、模型推理 1、模型推理 2、lora权重合并 3、推理web服务 五、总结 由于业务采用的ChatGLM模型推理成本太大了…...

分享一下在微信小程序里怎么做一个投票链接

在当今信息化社会&#xff0c;投票已成为各行各业收集意见、汇聚智慧的重要手段。传统的投票方式往往需要投入大量人力物力&#xff0c;而如今&#xff0c;借助微信小程序&#xff0c;我们可以在几分钟内创建一个高效、便捷的投票平台。本文将详细介绍如何在微信小程序中添加投…...

v-model语法糖

v-model原理 v-model实现双向绑定的语法糖&#xff0c;常用于表单与组件之间的数据双向绑定v-model本质上是 value属性和input事件的一层包装 v-model的作用&#xff1a;提供数据的双向绑定数据发生了改变&#xff0c;页面会自动变 v-bind:value页面输入改变 &#xff0c; 数据…...

纷享销客荣获最佳制造业数字营销服务商奖

2023年10月26日&#xff0c;第二届中国制造业数智化发展大会在上海盛大召开。本次大会汇聚了制造行业的顶尖企业和专家&#xff0c;共同探讨如何通过数字化转型赋能企业自身成长&#xff0c;实现信息化向数字化的升级转型。 在本次盛会上&#xff0c;纷享销客以其卓越的基本面、…...

蓝桥杯每日一题2023.11.3

题目描述 承压计算 - 蓝桥云课 (lanqiao.cn) 题目分析 将重量存入a中&#xff0c;每一层从上到下进行计算&#xff0c;用d进行计算列的重量&#xff0c;当前d的重量应为正上数组和右上数组的个半和并加上自身的重量 计算到30层记录最大最小值&#xff0c;进行比例运算即可 …...

中国电子云-隐私计算-云原生安全可信计算,物理-硬件-系统-云产品-云平台,数据安全防护

目录 联邦学习的架构思想 中国电子云-隐私计算-云原生安全...

PHP服务器端电商API原理及示例讲解(电商接口开发/接入)

下面小编就为大家分享一篇PHP服务器端API原理及示例讲解(接口开发)&#xff0c;具有很好的参考价值&#xff0c;希望对大家有所帮助 相信大家都做过PHP请求电商API接口获取数据&#xff0c;比如淘宝平台商品API接口&#xff0c;订单接口&#xff0c;京东接口&#xff0c;1688接…...

Spring Cloud应用- Eureka原理、搭建

初期对Spring Cloud的学习以应用搭建为主&#xff0c;所以内容不会太枯燥。 一直以来&#xff0c;自以为Spring全家桶的学习中&#xff0c;Spring framework是基础中的基础&#xff0c;部分内容也还是必须要读源码去理解底层原理&#xff0c;SpringMVC、SpringBoot&#xff0c…...

Servlet 设置启动时机(web.xml方式和@WebServlet方式)

1、通过web.xml方式 5)Servlet的启动时机 - 默认情况下&#xff0c;servlet是不会随着容器的启动而被实例化的&#xff0c;只有当第一次给我发请求时才会被实例化那么&#xff0c;这种情况对于第一次请求是不公平的因此&#xff0c;为了提高用户体验度&#xff0c;提高服务器的…...

一个使用uniapp+vue3+ts+pinia+uview-plus开发小程序的基础模板

uniappuviewPlusvue3tspiniavite 开发基础模板 使用 uniapp vue3 ts pinia vite 开发基础模板&#xff0c;拿来即可使用&#xff0c;不要删除 yarn.lock 文件&#xff0c;否则会启动报错&#xff0c;这个可能和 pinia 的版本有关&#xff0c;所以不要随意修改。 拉取代码…...

Kali安装docker

第一步&#xff1a;kali添加Docker官方的GPG密钥 curl -fsSL https://download.docker.com/linux/debian/gpg | sudo apt-key add 第二步&#xff1a;进入root更新源&#xff1a; su rootecho ‘deb https://download.docker.com/linux/debian stretch stable’> /etc/ap…...

Maven第七章:Maven工程最佳实践

Maven第七章:Maven工程最佳实践 前言 本章重点,通过一个maven工程最佳实践案例,熟悉和掌握maven在项目中的应用基本思路,让你的技能值瞬间暴涨。 最佳实践 确定项目的坐标和依赖 在Maven中,项目的坐标定义了项目的唯一标识符,包括groupId、artifactId和version。因此,在…...

【深度学习】【pytorch】对卷积层置零卷积核进行真实剪枝

最近需要对深度学习模型进行部署,因此需要对模型进行压缩,博主取舍了很多大佬的博文并亲测有效,分享笔记邀大家共同学习讨论 文章目录 前言卷积层剪枝总结 前言 深度学习剪枝(Pruning)是一种用于减少神经网络模型大小、减少计算量和提高推理效率的技术&#xff0c;通过去除神经…...

机器人仿真-gazebo学习笔记(3)URDF和机器人模型

1.URDF简介 URDF(统一机器人麦哦书格式)是ROS中的重要机器人模型描述格式&#xff0c;ROS提供了URDF文件的c解析器&#xff0c;可以解析URDF文件中使用XML格式的机器人模型。 urdf - ROS Wiki 自己查阅ros官方对URDF的介绍其实会强于大部分网上流传的文章。 1.URDF文件常用的…...

lua-resty-request库写入爬虫ip实现数据抓取

根据提供的引用内容&#xff0c;正确的库名称应该是lua-resty-http&#xff0c;而不是lua-resty-request。使用lua-resty-http库可以方便地进行爬虫&#xff0c;需要先安装OpenResty和lua-resty-http库&#xff0c;并将其引入到Lua脚本中。然后&#xff0c;可以使用lua-resty-h…...

gitlab Activating and deactivating users

原文&#xff1a;Redirecting... Deactivating a userActivating a user Activating and deactivating users GitLab 管理员可以停用和激活用户. Deactivating a user 在 GitLab 12.4 中引入 . 为了临时阻止没有最近活动的 GitLab 用户访问&#xff0c;管理员可以选择停用…...

linux入门到精通-第五章-动态库和静态库

目录 参考概述1、静态链接2 、动态链接3 、静态、动态编译对比 静态库和动态库简介传统编译 静态库制作和使用1、创建静态库的过程2、使用静态库 动态库制作和使用1、创建动态库的过程1&#xff09;、生成目标文件&#xff0c;此时要加编译选项&#xff1a;-fPIC &#xff08;f…...

markdown 如何更改字体以及颜色等功能

markdown 是IT人士写文档的常用方式&#xff0c;但是markdown默认又不支持颜色字体等特殊功能&#xff0c;所以呢想实现字体颜色高亮等特殊功能&#xff0c;实现的方法呢就是使用HTML&#xff0c;所以将部分文字改成HTML代码就行 颜色 <font color#0099ff>color #0099f…...

一次cs上线服务器的练习

环境&#xff1a;利用vm搭建的环境 仅主机为65段 测试是否能与win10ping通 配置转发 配置好iis Kali访问测试 现在就用burp抓取winser的包 开启代理 使用默认的8080抓取成功 上线...

STM32-高级定时器

以STM32F407为例。 高级定时器 高级定时器比通用定时器增加了可编程死区互补输出、重复计数器、带刹车&#xff08;断路&#xff09;功能&#xff0c;这些功能都是针对工业电机控制方面。 功能框图 16位向上、向下、向上/向下自动重装载计数器。 16位可编程预分频器&#xff0c…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂

蛋白质结合剂&#xff08;如抗体、抑制肽&#xff09;在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上&#xff0c;高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术&#xff0c;但这类方法普遍面临资源消耗巨大、研发周期冗长…...

ETLCloud可能遇到的问题有哪些?常见坑位解析

数据集成平台ETLCloud&#xff0c;主要用于支持数据的抽取&#xff08;Extract&#xff09;、转换&#xff08;Transform&#xff09;和加载&#xff08;Load&#xff09;过程。提供了一个简洁直观的界面&#xff0c;以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...

Axios请求超时重发机制

Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式&#xff1a; 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量&#xff0c;这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

优选算法第十二讲:队列 + 宽搜 优先级队列

优选算法第十二讲&#xff1a;队列 宽搜 && 优先级队列 1.N叉树的层序遍历2.二叉树的锯齿型层序遍历3.二叉树最大宽度4.在每个树行中找最大值5.优先级队列 -- 最后一块石头的重量6.数据流中的第K大元素7.前K个高频单词8.数据流的中位数 1.N叉树的层序遍历 2.二叉树的锯…...

AGain DB和倍数增益的关系

我在设置一款索尼CMOS芯片时&#xff0c;Again增益0db变化为6DB&#xff0c;画面的变化只有2倍DN的增益&#xff0c;比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析&#xff1a; 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...

【Elasticsearch】Elasticsearch 在大数据生态圈的地位 实践经验

Elasticsearch 在大数据生态圈的地位 & 实践经验 1.Elasticsearch 的优势1.1 Elasticsearch 解决的核心问题1.1.1 传统方案的短板1.1.2 Elasticsearch 的解决方案 1.2 与大数据组件的对比优势1.3 关键优势技术支撑1.4 Elasticsearch 的竞品1.4.1 全文搜索领域1.4.2 日志分析…...

C++--string的模拟实现

一,引言 string的模拟实现是只对string对象中给的主要功能经行模拟实现&#xff0c;其目的是加强对string的底层了解&#xff0c;以便于在以后的学习或者工作中更加熟练的使用string。本文中的代码仅供参考并不唯一。 二,默认成员函数 string主要有三个成员变量&#xff0c;…...

比较数据迁移后MySQL数据库和ClickHouse数据仓库中的表

设计一个MySQL数据库和Clickhouse数据仓库的表数据比较的详细程序流程,两张表是相同的结构,都有整型主键id字段,需要每次从数据库分批取得2000条数据,用于比较,比较操作的同时可以再取2000条数据,等上一次比较完成之后,开始比较,直到比较完所有的数据。比较操作需要比较…...