动手学习RAG:迟交互模型colbert微调实践 bge-m3
- 动手学习RAG: 向量模型
- 动手学习RAG: BGE向量模型微调实践]()
- 动手学习RAG: BCEmbedding 向量模型 微调实践]()
- BCE ranking 微调实践]()
- GTE向量与排序模型 微调实践]()
- 模型微调中的模型序列长度]()
- 相似度与温度系数
本文我们来进行ColBERT模型的实践,按惯例,还是以open-retrievals中的代码为蓝本。在RAG兴起之后,ColBERT也获得了更多的关注。ColBERT整体结构和双塔特别相似,但迟交互式也就意味着比起一般ranking模型,交互来的更晚一些。

准备环境
pip install transformers
pip install open-retrievals
准备数据
还是采用C-MTEB/T2Reranking数据。
- 每个样本有query, positive, negative。其中query和positive构成正样本对,query和negative构成负样本对

使用
由于ColBERT作为迟交互式模型,既可以像向量模型一样生成向量,也可以计算相似度。BAAI/bge-m3中的colbert模型是基于XLMRoberta训练而来,因此使用ColBERT可以直接从bge-m3中加载预训练权重。
import transformers
from retrievals import ColBERT
model_name_or_path: str = 'BAAI/bge-m3'
model = ColBERT.from_pretrained(model_name_or_path,colbert_dim=1024, use_fp16=True,loss_fn=ColbertLoss(use_inbatch_negative=True),
)model

- 生成向量的方法
sentences_1 = ["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."]
sentences_2 = ['A dog is chasing car.', 'A man is playing a guitar.']output_1 = model.encode(sentences_1, normalize_embeddings=True)
print(output_1.shape, output_1)output_2 = model.encode(sentences_2, normalize_embeddings=True)
print(output_2.shape, output_2)

- 计算句子对 相似度的方法
sentences = [["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."],["In 1974, I won the championship in Southeast Asia in my first kickboxing match", 'A man is playing a guitar.'],
]scores_list = model.compute_score(sentences)
print(scores_list)

微调
尝试了两种方法来做,一种是调包自己写代码,一种是采用open-retrievals中的代码写shell脚本。这里我们采用第一种,另外一种方法可参考文章最后番外中的微调
import transformers
from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup, AdamW
from retrievals import AutoModelForRanking, RerankCollator, RerankTrainDataset, RerankTrainer, ColBERT, RetrievalTrainDataset, ColBertCollator
from retrievals.losses import ColbertLoss
transformers.logging.set_verbosity_error()model_name_or_path: str = 'BAAI/bge-m3'learning_rate: float = 1e-5
batch_size: int = 2
epochs: int = 1
output_dir: str = './checkpoints'train_dataset = RetrievalTrainDataset('C-MTEB/T2Reranking', positive_key='positive', negative_key='negative', dataset_split='dev'
)tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)data_collator = ColBertCollator(tokenizer,query_max_length=64,document_max_length=128,positive_key='positive',negative_key='negative',
)
model = ColBERT.from_pretrained(model_name_or_path,colbert_dim=1024,loss_fn=ColbertLoss(use_inbatch_negative=False),
)optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)training_args = TrainingArguments(learning_rate=learning_rate,per_device_train_batch_size=batch_size,num_train_epochs=epochs,output_dir = './checkpoints',remove_unused_columns=False,gradient_accumulation_steps=8,logging_steps=100,)
trainer = RerankTrainer(model=model,args=training_args,train_dataset=train_dataset,data_collator=data_collator,
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()model.save_pretrained(output_dir)
训练过程中会加载BAAI/bge-m3模型权重

损失函数下降

{'loss': 7.4858, 'grad_norm': 30.484981536865234, 'learning_rate': 4.076305220883534e-06, 'epoch': 0.6024096385542169}
{'loss': 1.18, 'grad_norm': 28.68316650390625, 'learning_rate': 3.072289156626506e-06, 'epoch': 1.2048192771084336}
{'loss': 1.1399, 'grad_norm': 14.203865051269531, 'learning_rate': 2.068273092369478e-06, 'epoch': 1.8072289156626506}
{'loss': 1.1261, 'grad_norm': 24.30337905883789, 'learning_rate': 1.0642570281124499e-06, 'epoch': 2.4096385542168672}
{'train_runtime': 471.8191, 'train_samples_per_second': 33.827, 'train_steps_per_second': 1.055, 'train_loss': 2.4146631079984, 'epoch': 3.0}
评测
在C-MTEB中进行评测。微调前保留10%的数据集作为测试集验证
from datasets import load_datasetdataset = load_dataset("C-MTEB/T2Reranking", split="dev")
ds = dataset.train_test_split(test_size=0.1, seed=42)ds_train = ds["train"].filter(lambda x: len(x["positive"]) > 0 and len(x["negative"]) > 0
)ds_train.to_json("t2_ranking.jsonl", force_ascii=False)
微调前的指标:

微调后的指标:

{"dataset_revision": null,"mteb_dataset_name": "CustomReranking","mteb_version": "1.1.1","test": {"evaluation_time": 221.45,"map": 0.6950128151840831,"mrr": 0.8193114944390455}
}
番外:从语言模型直接训练ColBERT
之前的例子里是从BAAI/bge-m3继续微调,这里再跑一个从hfl/chinese-roberta-wwm-ext训练一个ColBERT模型
- 注意,从头跑需要设置更大的学习率与更多的epochs
MODEL_NAME='hfl/chinese-roberta-wwm-ext'
TRAIN_DATA="/root/kaggle101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kaggle101/src/open-retrievals/t2/ft_out"cd /root/open-retrievals/srctorchrun --nproc_per_node 1 \--module retrievals.pipelines.rerank \--output_dir $OUTPUT_DIR \--overwrite_output_dir \--model_name_or_path $MODEL_NAME \--tokenizer_name $MODEL_NAME \--model_type colbert \--do_train \--data_name_or_path $TRAIN_DATA \--positive_key positive \--negative_key negative \--learning_rate 5e-5 \--bf16 \--num_train_epochs 5 \--per_device_train_batch_size 32 \--dataloader_drop_last True \--query_max_length 128 \--max_length 256 \--train_group_size 4 \--unfold_each_positive false \--save_total_limit 1 \--logging_steps 100 \--use_inbatch_negative False
微调后指标
{"dataset_revision": null,"mteb_dataset_name": "CustomReranking","mteb_version": "1.1.1","test": {"evaluation_time": 75.38,"map": 0.6865308507184888,"mrr": 0.8039965986394558}
}
相关文章:
动手学习RAG:迟交互模型colbert微调实践 bge-m3
动手学习RAG: 向量模型动手学习RAG: BGE向量模型微调实践]()动手学习RAG: BCEmbedding 向量模型 微调实践]()BCE ranking 微调实践]()GTE向量与排序模型 微调实践]()模型微调中的模型序列长度]()相似度与温度系数 本文我们来进行ColBERT模型的实践,按惯例ÿ…...
springboot 整合quartz定时任务
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、pom的配置1.加注解 二、使用方法1.工程图2.创建工具类 三、controller 实现 前言 提示:这里可以添加本文要记录的大概内容: 提示&a…...
erlang学习: Mnesia Erlang数据库3
Mnesia数据库删除实现和事务处理 -module(test_mnesia). -include_lib("stdlib/include/qlc.hrl").-record(shop, {item, quantity, cost}). %% API -export([insert/3, select/0, select/1, delete/1, transaction/1,start/0, do_this_once/0]). start() ->mnes…...
善于善行——贵金属回收
在当今社会,贵金属回收已成为一项日益重要的产业。随 着科技的不断进步和人们对资源可持续利用的认识逐渐提高,贵金属回收的现状也备受关注。 目前,贵金属回收市场呈现出蓬勃发展的态势。一方面,贵金属如金、银、铂、钯等在众多领…...
用CSS 方式设置 table 样式
在现代Web开发中,使用CSS来设置table的样式是一种常见且强大的方法,它能让你的表格数据既美观又易于阅读。下面我将通过一个示例来展示如何使用现代CSS技巧来美化表格。 效果图 HTML 结构 首先,我们定义一个基本的HTML表格结构:…...
Elasticsearch7.x 集群迁移文档
一、集群样例信息 集群名称:escluster-ali-test 1、源集群:(source_cluster) 节点IP节点名称节点角色是否为master节点10.200.112.149es2.gj1.china-job.cndata,master是10.200.112.151es1.gj1.china-job.cndata,master否10.200.112.153es…...
高空抛物检测算法的应用场景解析
高空抛物事件频发,对公众安全构成严重威胁。无论是居民区还是商业中心,从高层建筑中丢弃物品都可能导致人员伤亡和财产损失。传统的监控手段多以事后追溯为主,无法在事发时及时预警和干预。为应对这一难题,视觉分析技术的发展为高…...
Leetcode 无重复字符的最长子串
算法思想: 滑动窗口:通过 start 和 end 来维护一个滑动窗口,start 指向当前窗口的起点,end 是当前窗口的末尾。滑动窗口中的字符都是无重复的。哈希表 charIndexMap:用于存储每个字符及其最近一次出现的位置。更新起始…...
用命令行的方式启动.netcore webapi
用命令行的方式启动.netcore web项目 进入指定的项目文件夹,比如我发布后的代码放在下面文件夹中 在此地址栏中输入“cmd”,打开命令提示符,进入到发布代码目录 命令行启动.netcore项目的命令为: dotnet 项目启动文件.dll --urls"ht…...
Spring6详细学习笔记(IOC+AOP)
一、Spring系统架构介绍 1.1、定义 Spring是一个轻量级的控制反转(IoC)和面向切面(AOP)的容器(框架)。Spring官网 Spring是一款主流的Java EE 轻量级开源框架,目的是用于简化Java企业级引用的开发难度和开发周期。从简单性、可测试性和松耦…...
@RequestMapping 基于哪个库进行通信
RequestMapping 是 Spring Framework 中用于处理 HTTP 请求的注解,主要用于定义控制器方法的请求映射。它并不直接基于某个特定的通信库,而是依赖于 Spring MVC 框架的核心功能。 1. Spring MVC RequestMapping 是 Spring MVC 的一部分,Spr…...
GPIO(General Purpose Input/Output)输入/输出
GPIO最简单的功能是输出高低电平;GPIO还可以被设置为输入功能,用于读取按键等输入信号;也可以将GPIO复用成芯片上的其他外设的控制引脚。 STM32F407ZGT6有8组IO。分别为GPIOA~GPIOH,除了GPIOH只有两个IO,其余每组IO有…...
两个pdf合并成一个pdf,这些pdf合并小技巧了解下
在日常工作和学习中,我们经常会遇到需要将多个PDF文件合并成一个文件的情况。这不仅可以提高文件管理的效率,还能让信息展示更加集中和便捷。今天就来给大家分享几种非常简单便捷的PDF合并小技巧,一起来学习下吧。 方法一:WPS WP…...
Transformer学习(2):自注意力机制
回顾 注意力机制 自注意力机制 自注意力机制中同样包含QKV,但它们是同源(Q≈K≈V),也就是来自相同的输入数据X,X可以分为 ( x 1 , x 2 , . . , x n ) (x_1,x_2,..,x_n) (x1,x2,..,xn)。 而通过输入嵌入层(input embedding),…...
分类预测|基于粒子群优化径向基神经网络的数据分类预测Matlab程序PSO-RBF 多特征输入多类别输出 含基础RBF程序
分类预测|基于粒子群优化径向基神经网络的数据分类预测Matlab程序PSO-RBF 多特征输入多类别输出 含基础RBF程序 文章目录 一、基本原理1. 粒子群优化算法(PSO)2. 径向基神经网络(RBF)PSO-RBF模型流程总结 二、实验结果三、核心代码…...
【React】Vite 构建 React
项目搭建 vite 官网:Vite 跟着文档走即可,选择 react ,然后 ts swc。 着重说一下 package-lock.json 这个文件有两个作用: 锁版本号(保证项目在不同人手里安装的依赖都是相同的,解决版本冲突的问题&am…...
算法刷题:300. 最长递增子序列、674. 最长连续递增序列、718. 最长重复子数组
300. 最长递增子序列 1.dp定义:dp[i]表示i之前包括i的以nums[i]结尾的最长递增子序列的长度 2.递推公式:if (nums[i] > nums[j]) dp[i] max(dp[i], dp[j] 1); 注意这里不是要dp[i] 与 dp[j] 1进行比较,而是我们要取dp[j] 1的最大值…...
【linux】一种基于虚拟串口的方式使两个应用通讯
在Linux系统中,两个应用之间通过串口(Serial Port)进行通信是一种常见的通信方式,特别是在嵌入式系统、工业自动化等领域。串口通信通常涉及到对串口设备的配置和读写操作。以下是一个基本的步骤指南,说明如何在Linux中…...
并行程序设计基础——并行I/O(3)
目录 一、多视口的并行文件并行读写 1、文件视口与指针 1.1 MPI_FILE_SET_VIEW 1.2 MPI_FILE_GET_VIEW 1.3 MPI_FILE_SEEK 1.4 MPI_FILE_GET_POSTION 1.5 MPI_FILE_GET_BYTE_OFFSET 2、阻塞方式的视口读写 2.1 MPI_FILE_READ 2.2 MPI_FILE_WRITE 2.3 MPI_FILE_READ_…...
性能测试-jmeter脚本录制(十五)
一、jmeter脚本录制(不推荐)简介: 二、jmeter脚本录制步骤 1、添加代理服务器和线程组 2、配置http代理服务器的端口和目标线程组 3修改本机浏览器代理 4、点击启动 5、每次操作页面前,修改提示文字...
19c补丁后oracle属主变化,导致不能识别磁盘组
补丁后服务器重启,数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后,存在与用户组权限相关的问题。具体表现为,Oracle 实例的运行用户(oracle)和集…...
【Python】 -- 趣味代码 - 小恐龙游戏
文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...
shell脚本--常见案例
1、自动备份文件或目录 2、批量重命名文件 3、查找并删除指定名称的文件: 4、批量删除文件 5、查找并替换文件内容 6、批量创建文件 7、创建文件夹并移动文件 8、在文件夹中查找文件...
可靠性+灵活性:电力载波技术在楼宇自控中的核心价值
可靠性灵活性:电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中,电力载波技术(PLC)凭借其独特的优势,正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据,无需额外布…...
【磁盘】每天掌握一个Linux命令 - iostat
目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...
跨链模式:多链互操作架构与性能扩展方案
跨链模式:多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈:模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展(H2Cross架构): 适配层…...
从零开始打造 OpenSTLinux 6.6 Yocto 系统(基于STM32CubeMX)(九)
设备树移植 和uboot设备树修改的内容同步到kernel将设备树stm32mp157d-stm32mp157daa1-mx.dts复制到内核源码目录下 源码修改及编译 修改arch/arm/boot/dts/st/Makefile,新增设备树编译 stm32mp157f-ev1-m4-examples.dtb \stm32mp157d-stm32mp157daa1-mx.dtb修改…...
WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)
一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解,适合用作学习或写简历项目背景说明。 🧠 一、概念简介:Solidity 合约开发 Solidity 是一种专门为 以太坊(Ethereum)平台编写智能合约的高级编…...
大模型多显卡多服务器并行计算方法与实践指南
一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...
鱼香ros docker配置镜像报错:https://registry-1.docker.io/v2/
使用鱼香ros一件安装docker时的https://registry-1.docker.io/v2/问题 一键安装指令 wget http://fishros.com/install -O fishros && . fishros出现问题:docker pull 失败 网络不同,需要使用镜像源 按照如下步骤操作 sudo vi /etc/docker/dae…...
