当前位置: 首页 > news >正文

动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习

  • 动手学习RAG: 向量模型
  • 动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习
  • 动手学习RAG:迟交互模型colbert微调实践 bge-m3

1. 环境准备

pip install transformers
pip install open-retrievals
  • 注意安装时是pip install open-retrievals,但调用时只需要import retrievals
  • 欢迎关注最新的更新 https://github.com/LongxingTan/open-retrievals

2. 使用M3E模型

from retrievals import AutoModelForEmbeddingembedder = AutoModelForEmbedding.from_pretrained('moka-ai/m3e-base', pooling_method='mean')
embedder

请添加图片描述

sentences = ['* Moka 此文本嵌入模型由 MokaAI 训练并开源,训练脚本使用 uniem','* Massive 此文本嵌入模型通过**千万级**的中文句对数据集进行训练','* Mixed 此文本嵌入模型支持中英双语的同质文本相似度计算,异质文本检索等功能,未来还会支持代码检索,ALL in one'
]embeddings = embedder.encode(sentences)for sentence, embedding in zip(sentences, embeddings):print("Sentence:", sentence)print("Embedding:", embedding)print("")

请添加图片描述

3. deepspeed 微调m3e模型

数据仍然采用之前介绍的t2-ranking数据集

  • deepspeed配置保存为 ds_zero2_no_offload.json. 不过虽然设置了zero2,这里我只用了一张卡. 但deepspeed也很容易扩展到多卡,或多机多卡
    • 关于deepspeed的分布式设置,可参考Tranformer分布式特辑
{"fp16": {"enabled": "auto","loss_scale": 0,"loss_scale_window": 100,"initial_scale_power": 16,"hysteresis": 2,"min_loss_scale": 1e-10},"zero_optimization": {"stage": 2,"allgather_partitions": true,"allgather_bucket_size": 1e8,"overlap_comm": true,"reduce_scatter": true,"reduce_bucket_size": 1e8,"contiguous_gradients": true},"gradient_accumulation_steps": "auto","gradient_clipping": "auto","steps_per_print": 2000,"train_batch_size": "auto","train_micro_batch_size_per_gpu": "auto","wall_clock_breakdown": false
}

这里稍微修改了open-retrievals这里的代码,主要是修改了导入为包的导入,而不是相对引用。保存文件为embed.py

"""Embedding fine tune pipeline"""import logging
import os
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optionalimport torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seedfrom retrievals import (EncodeCollator,EncodeDataset,PairCollator,RetrievalTrainDataset,TripletCollator,
)
from retrievals.losses import AutoLoss, InfoNCE, SimCSE, TripletLoss
from retrievals.models.embedding_auto import AutoModelForEmbedding
from retrievals.trainer import RetrievalTrainer# os.environ["WANDB_LOG_MODEL"] = "false"
logger = logging.getLogger(__name__)@dataclass
class ModelArguments:model_name_or_path: str = field(metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"})config_name: Optional[str] = field(default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"})tokenizer_name: Optional[str] = field(default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"})cache_dir: Optional[str] = field(default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"})causal_lm: bool = field(default=False, metadata={'help': "Whether the model is a causal lm or not"})lora_path: Optional[str] = field(default=None, metadata={'help': "Lora adapter save path"})@dataclass
class DataArguments:data_name_or_path: str = field(default=None, metadata={"help": "Path to train data"})train_group_size: int = field(default=2)unfold_each_positive: bool = field(default=False)query_max_length: int = field(default=32,metadata={"help": "The maximum total input sequence length after tokenization for passage. Sequences longer ""than this will be truncated, sequences shorter will be padded."},)document_max_length: int = field(default=128,metadata={"help": "The maximum total input sequence length after tokenization for passage. Sequences longer ""than this will be truncated, sequences shorter will be padded."},)query_instruction: str = field(default=None, metadata={"help": "instruction for query"})document_instruction: str = field(default=None, metadata={"help": "instruction for document"})query_key: str = field(default=None)positive_key: str = field(default='positive')negative_key: str = field(default='negative')is_query: bool = field(default=False)encoding_save_file: str = field(default='embed.pkl')def __post_init__(self):# self.data_name_or_path = 'json'self.dataset_split = 'train'self.dataset_language = 'default'if self.data_name_or_path is not None:if not os.path.isfile(self.data_name_or_path) and not os.path.isdir(self.data_name_or_path):info = self.data_name_or_path.split('/')self.dataset_split = info[-1] if len(info) == 3 else 'train'self.data_name_or_path = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)self.dataset_language = 'default'if ':' in self.data_name_or_path:self.data_name_or_path, self.dataset_language = self.data_name_or_path.split(':')@dataclass
class RetrieverTrainingArguments(TrainingArguments):train_type: str = field(default='pairwise', metadata={'help': "train type of point, pair, or list"})negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})temperature: Optional[float] = field(default=0.02)fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})normalized: bool = field(default=True)loss_fn: str = field(default='infonce')use_inbatch_negative: bool = field(default=True, metadata={"help": "use documents in the same batch as negatives"})remove_unused_columns: bool = field(default=False)use_lora: bool = field(default=False)use_bnb_config: bool = field(default=False)do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})report_to: Optional[List[str]] = field(default="none", metadata={"help": "The list of integrations to report the results and logs to."})def main():parser = HfArgumentParser((ModelArguments, DataArguments, RetrieverTrainingArguments))model_args, data_args, training_args = parser.parse_args_into_dataclasses()model_args: ModelArgumentsdata_args: DataArgumentstraining_args: TrainingArgumentsif (os.path.exists(training_args.output_dir)and os.listdir(training_args.output_dir)and training_args.do_trainand not training_args.overwrite_output_dir):raise ValueError(f"Output directory ({training_args.output_dir}) already exists and is not empty. ""Use --overwrite_output_dir to overcome.")logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",datefmt="%m/%d/%Y %H:%M:%S",level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,)logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",training_args.local_rank,training_args.device,training_args.n_gpu,bool(training_args.local_rank != -1),training_args.fp16,)logger.info("Training/evaluation parameters %s", training_args)logger.info("Model parameters %s", model_args)logger.info("Data parameters %s", data_args)set_seed(training_args.seed)tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,cache_dir=model_args.cache_dir,use_fast=False,)if training_args.use_bnb_config:from transformers import BitsAndBytesConfiglogger.info('Use quantization bnb config')quantization_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,)else:quantization_config = Noneif training_args.do_train:model = AutoModelForEmbedding.from_pretrained(model_name_or_path=model_args.model_name_or_path,pooling_method=training_args.pooling_method,use_lora=training_args.use_lora,quantization_config=quantization_config,)loss_fn = AutoLoss(loss_name=training_args.loss_fn,loss_kwargs={'use_inbatch_negative': training_args.use_inbatch_negative,'temperature': training_args.temperature,},)model = model.set_train_type("pairwise",loss_fn=loss_fn,)train_dataset = RetrievalTrainDataset(args=data_args,tokenizer=tokenizer,positive_key=data_args.positive_key,negative_key=data_args.negative_key,)logger.info(f"Total training examples: {len(train_dataset)}")trainer = RetrievalTrainer(model=model,args=training_args,train_dataset=train_dataset,data_collator=TripletCollator(tokenizer,query_max_length=data_args.query_max_length,document_max_length=data_args.document_max_length,positive_key=data_args.positive_key,negative_key=data_args.negative_key,),)Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)trainer.train()# trainer.save_model(training_args.output_dir)model.save_pretrained(training_args.output_dir)if trainer.is_world_process_zero():tokenizer.save_pretrained(training_args.output_dir)if training_args.do_encode:model = AutoModelForEmbedding.from_pretrained(model_name_or_path=model_args.model_name_or_path,pooling_method=training_args.pooling_method,use_lora=training_args.use_lora,quantization_config=quantization_config,lora_path=model_args.lora_path,)max_length = data_args.query_max_length if data_args.is_query else data_args.document_max_lengthlogger.info(f'Encoding will be saved in {training_args.output_dir}')encode_dataset = EncodeDataset(args=data_args, tokenizer=tokenizer, max_length=max_length, text_key='text')logger.info(f"Number of train samples: {len(encode_dataset)}, max_length: {max_length}")encode_loader = DataLoader(encode_dataset,batch_size=training_args.per_device_eval_batch_size,collate_fn=EncodeCollator(tokenizer, max_length=max_length, padding='max_length'),shuffle=False,drop_last=False,num_workers=training_args.dataloader_num_workers,)embeddings = model.encode(encode_loader, show_progress_bar=True, convert_to_numpy=True)lookup_indices = list(range(len(encode_dataset)))with open(os.path.join(training_args.output_dir, data_args.encoding_save_file), 'wb') as f:pickle.dump((embeddings, lookup_indices), f)if __name__ == "__main__":main()
  • 最终调用文件 shell run.sh
MODEL_NAME="moka-ai/m3e-base"TRAIN_DATA="/root/kag101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kag101/src/open-retrievals/t2/ft_out"# loss_fn: infonce, simcsedeepspeed -m --include localhost:0 embed.py \--deepspeed ds_zero2_no_offload.json \--output_dir $OUTPUT_DIR \--overwrite_output_dir \--model_name_or_path $MODEL_NAME \--do_train \--data_name_or_path $TRAIN_DATA \--positive_key positive \--negative_key negative \--pooling_method mean \--loss_fn infonce \--use_lora False \--query_instruction "" \--document_instruction "" \--learning_rate 3e-5 \--fp16 \--num_train_epochs 5 \--per_device_train_batch_size 32 \--dataloader_drop_last True \--query_max_length 64 \--document_max_length 256 \--train_group_size 4 \--logging_steps 100 \--temperature 0.02 \--save_total_limit 1 \--use_inbatch_negative false

请添加图片描述

4. 测试

微调前性能 c-mteb t2-ranking score

请添加图片描述

微调后性能

请添加图片描述

采用infoNCE损失函数,没有加in-batch negative,而关注的是困难负样本,经过微调map从0.654提升至0.692,mrr从0.754提升至0.805

对比一下非deepspeed而是直接torchrun的微调

  • map略低,mrr略高。猜测是因为deepspeed中设置的一些auto会和直接跑并不完全一样
    请添加图片描述

相关文章:

动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习

动手学习RAG: 向量模型动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习动手学习RAG:迟交互模型colbert微调实践 bge-m3 1. 环境准备 pip install transformers pip install open-retrievals注意安装时是pip install open-retrievals,但调用时只…...

Nacos rce-0day漏洞复现(nacos 2.3.2)

Nacos rce-0day漏洞复现(nacos 2.3.2) NACOS是 一个开源的服务发现、配置管理和服务治理平台,属于阿里巴巴的一款开源产品。影像版本:nacos2.3.2或2.4.0版本指纹:fofa:app“NACOS” 从 Github 官方介绍文档可以看出国…...

yjs04——matplotlib的使用(多个坐标图)

1.多个坐标图与一个图的折线对比 1.引入包;字体(同) import matplotlib.pyplot as plt import random plt.rcParams[font.family] [SimHei] plt.rcParams[axes.unicode_minus] False 2.创建幕布 2.1建立图层幕布 一个图:plt.fig…...

MOS管和三极管有什么区别?

MOS管是基于金属-氧化物-半导体结构的场效应晶体管,它的控制电压作用于氧化物层,通过调节栅极电势来控制源漏电流。MOS管是FET中的一种,现主要用增强型MOS管,分为PMOS和NMOS。 MOS管的三个极分别是G(栅极),D(漏极)&…...

医院多参数空气质量监控和压差监测系统简介@卓振思众

在现代医院管理中,确保患者和医疗人员的健康与安全是首要任务。为实现这一目标,医院需要依赖高科技设施来维持最佳的环境条件。特别是,多参数空气质量监测系统和压差监测系统在这一方面发挥了不可替代的作用。【卓振思众】多参数空气质量监测…...

[项目实战]EOS多节点部署

文章总览:YuanDaiMa2048博客文章总览 EOS多节点部署 (一)环境设计(二)节点配置(三)区块信息同步(四)启动节点并验证同步EOS单节点的环境如何配置 (一&#xf…...

setImmediate() vs setTimeout() 在 JavaScript 中的区别

setImmediate() vs setTimeout() 在 JavaScript 中的区别 在 JavaScript 中,setImmediate() 和 setTimeout() 都用于调度任务,但它们的工作方式不同。 JavaScript 的异步特性 JavaScript 以其非阻塞、异步行为而闻名,尤其是在 Node.js 环境…...

【Java文件操作】文件系统操作文件内容操作

文件系统操作 常见API 在Java中,File类是用于文件和目录路径名的抽象表示。以下是一些常见的方法: 构造方法: File(String pathname):根据给定的路径创建一个File对象。File(String parent, String child):根据父路径…...

关于若依flowable的安装

有个项目要使用工作流功能,在网上看了flowable的各种资料,最后选择用若依RuoYi-Vue-Flowable这个项目来迁移整合。 一、下载项目代码: 官方项目地址:https://gitee.com/shenzhanwang/Ruoyi-flowable/ 二、新建数据库&#xff…...

猜数字困难版(1-10000)

小游戏&#xff0c;通过提示每次猜高或猜低以及每次猜中的位数&#xff0c;10次内猜中1-10000的一个数。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthde…...

ASPICE术语表

术语来源描述活动Automotive SPICE V4.0由利益相关方或参与方执行的任务用参数Automotive SPICE V4.0应用参数是包含了在系统或软件层级可被更改的数据的软件变量&#xff0c;他们影响系统或软件的行为和属性。应用参数的概念有两种表达方式:规范(分别包括变量名称、值域范围、…...

Knife4j:打造优雅的SpringBoot API文档

1. 为什么需要API文档&#xff1f; 在现代软件开发中,API文档的重要性不言而喻。一份清晰、准确、易于理解的API文档不仅能够提高开发效率,还能降低前后端沟通成本。今天,我们要介绍的Knife4j正是这样一款强大的API文档生成工具,它专为Spring Boot项目量身打造,让API文档的生成…...

数学建模笔记—— 多目标规划

数学建模笔记—— 多目标规划 多目标规划1. 模型原理1.1 多目标规划的一般形式1.2 多目标规划的解1.3 多目标规划的求解 2. 典型例题3. matlab代码实现 多目标规划 多目标规划是数学规划的一个分支。研究多于一个的目标函数在给定区域上的最优化。又称多目标最优化。通常记为 …...

【鸿蒙HarmonyOS NEXT】页面之间相互传递参数

【鸿蒙HarmonyOS NEXT】页面之间相互传递参数 一、环境说明二、页面之间相互传参 一、环境说明 DevEco Studio 版本&#xff1a; API版本&#xff1a;以12为主 二、页面之间相互传参 说明&#xff1a; 页面间的导航可以通过页面路由router模块来实现。页面路由模块根据页…...

SonicWall SSL VPN曝出高危漏洞,可能导致防火墙崩溃

近日&#xff0c;有黑客利用 SonicWall SonicOS 防火墙设备中的一个关键安全漏洞入侵受害者的网络。 这个不当访问控制漏洞被追踪为 CVE-2024-40766&#xff0c;影响到第 5 代、第 6 代和第 7 代防火墙。SonicWall于8月22日对其进行了修补&#xff0c;并警告称其只影响防火墙的…...

关于SAP标准委外(带料外协)采购订单信息

业务背景&#xff1a; 业务部门提出需要将售料外协方式变更为带料外协&#xff0c;带料外协实际业务存在一个委外订单存在多次发料&#xff0c;且每次发票需要进行齐套发料&#xff0c;不同批次的发料涉及物料替代。在半成品收货时需要进行对发料的组件进行扣料。 需求分析&a…...

SpringBoot整合WebSocket实现消息推送或聊天功能示例

最近在做一个功能&#xff0c;就是需要实时给用户推送消息&#xff0c;所以就需要用到 websocket springboot 接入 websocket 非常简单&#xff0c;只需要下面几个配置即可 pom 文件 <!-- spring-boot-web启动器 --><dependency><groupId>org.springframewo…...

使用 QEMU 模拟器运行 FreeRTOS 实时操作系统

文章目录 QEMU 官网QEMU 文档QEMU 简介QEMU 安装QEMU 命令启动虚拟机串口控制台监控命令行 FreeRTOS安装编译工具FreeRTOS 源码RISC-V-Qemu-virt_GCC 示例编译 RISC-V-Qemu-virt_GCC启动虚拟机运行 FreeRTOS QEMU 官网 https://www.qemu.org/ QEMU 文档 https://www.qemu.or…...

Oracle EBS中AR模块的财务流程概览

应收账款 (AR) 模块是Oracle E-Business Suite (EBS) 中另一个重要的财务管理模块&#xff0c;主要用于管理企业销售过程中的账款回收。下面是AR模块中的一些关键财务流程及其详细说明&#xff1a; 1. 销售订单管理 创建销售订单&#xff1a;当客户下单时&#xff0c;销售人员…...

Minitab 的直方图结果分析解释

Minitab 的直方图结果分析解释 步骤 1&#xff1a;评估关键特征 检查分布的尖峰和散布。评估样本数量对直方图外观的影响。 标识尖峰&#xff08;即&#xff0c;条的最高聚类&#xff09;&#xff1a; 尖峰表示样本中最常见的值。评估样本的散布以了解数据的变异程度。例如…...

AgentRE:用智能体框架提升知识图谱构建效果,重点是开源!

发布时间&#xff1a;2024 年 09 月 13 日 Agent应用 AgentRE: An Agent-Based Framework for Navigating Complex Information Landscapes in Relation Extraction 在复杂场景中&#xff0c;关系抽取 (RE) 因关系类型多样和实体间关系模糊而挑战重重&#xff0c;影响了传统 “…...

力扣题解2390

大家好&#xff0c;欢迎来到无限大的频道。 今日继续给大家带来力扣题解。 题目描述​&#xff08;中等&#xff09;&#xff1a; 从字符串中移除星号 给你一个包含若干星号 * 的字符串 s 。 在一步操作中&#xff0c;你可以&#xff1a; 选中 s 中的一个星号。 移除星号…...

用Python获取PDF页面的大小、方向和旋转角度

在文档管理和自动化领域&#xff0c;了解PDF文档的内在属性&#xff08;如页面大小、方向和旋转角度&#xff09;对于确保一致的文档处理和布局保真度至关重要。这些属性在内容重用、归档以及PDF无缝集成到网络环境或其他数字工作流程中起着关键作用&#xff0c;因为它们直接影…...

【即时通讯】轮询方式实现

技术栈 LayUI、jQuery实现前端效果。django4.2、django-ninja实现后端接口。 代码仓 - 后端 代码仓 - 前端 实现功能 首次访问页面并发送消息时需要设置昵称发送内容为空时要提示用户不能发送空消息前端定时获取消息&#xff0c;然后展示在页面上。 效果展示 首次发送需要…...

Flock 明牌空投教程

FLock 旨在为人工智能构建一个去中心化的隐私保护解决方案。FLock提出了一项名为联合学习区块&#xff08;简称 FLocks&#xff09;的研究计划&#xff0c;该计划使用区块链作为数据持有者之间的协调平台来进行机器学习&#xff0c;同时数据保持本地和隐私。通过用区块链取代收…...

项目内部调用的远程接口开发

编写一个项目内部调用的远程接口通常是为了在分布式系统或者微服务架构中&#xff0c;实现各个服务之间的通信和数据交换。这样的远程接口专门用于服务之间的调用&#xff0c;而不是直接暴露给外部用户或前端。 项目内部的远程接口统一放在api工程 首先进入api编写接口&#x…...

影响IP代理池稳定性的因素有哪些?

IP代理池在提供网络服务时&#xff0c;稳定性是一项决定性指标。多个外部和内部因素可能会影响这个稳定性&#xff0c;因此深入理解这些影响因素&#xff0c;可以帮助优化IP代理池的性能与服务质量。 1. IP来源质量 纯净度与使用频次&#xff1a;优质的IP来源常常被描述为纯净…...

基于Prometheus和Grafana的现代服务器监控体系构建

构建一个基于 Prometheus 和 Grafana 的现代服务器监控体系涉及多个步骤。以下是大体的流程和步骤说明&#xff1a; 1. Prometheus 监控系统 Prometheus 是一个开源的系统监控和报警工具&#xff0c;专门设计用于抓取时间序列数据。 1.1 Prometheus 的安装 Docker 安装 Prom…...

原生 input 中的 “type=file“ 上传文件

目标&#xff1a;实现文件上传功能 原型图&#xff1a; HTML部分&#xff1a; <div class"invoice-item"><div class"invoice-title">增值税专用发票</div><div class"invoice-box"><el-form-item label"标准…...

【Unity新闻】Unity的产品命名变化

快速回顾一下Unity产品命名的调整。 在2023年 Unity就宣布版本命名的变化&#xff0c;将使用Unity 6作为最新版本的命名。 具体的规则&#xff0c;在论坛里进行了说明。 以后正式的LTS版本就是Unity 6&#xff0c;将在2024年末发布。 而不管是之前的Runtime费还是今天的费用…...