当前位置: 首页 > news >正文

动手学习RAG:迟交互模型colbert微调实践 bge-m3

  • 动手学习RAG: 向量模型
  • 动手学习RAG: BGE向量模型微调实践]()
  • 动手学习RAG: BCEmbedding 向量模型 微调实践]()
  • BCE ranking 微调实践]()
  • GTE向量与排序模型 微调实践]()
  • 模型微调中的模型序列长度]()
  • 相似度与温度系数

本文我们来进行ColBERT模型的实践,按惯例,还是以open-retrievals中的代码为蓝本。在RAG兴起之后,ColBERT也获得了更多的关注。ColBERT整体结构和双塔特别相似,但迟交互式也就意味着比起一般ranking模型,交互来的更晚一些。
请添加图片描述

准备环境

pip install transformers
pip install open-retrievals

准备数据

还是采用C-MTEB/T2Reranking数据。

  • 每个样本有query, positive, negative。其中query和positive构成正样本对,query和negative构成负样本对
    请添加图片描述

使用

由于ColBERT作为迟交互式模型,既可以像向量模型一样生成向量,也可以计算相似度。BAAI/bge-m3中的colbert模型是基于XLMRoberta训练而来,因此使用ColBERT可以直接从bge-m3中加载预训练权重。

import transformers
from retrievals import ColBERT
model_name_or_path: str =  'BAAI/bge-m3' 
model = ColBERT.from_pretrained(model_name_or_path,colbert_dim=1024,    use_fp16=True,loss_fn=ColbertLoss(use_inbatch_negative=True),
)model

请添加图片描述

  • 生成向量的方法
sentences_1 = ["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."]
sentences_2 = ['A dog is chasing car.', 'A man is playing a guitar.']output_1 = model.encode(sentences_1, normalize_embeddings=True)
print(output_1.shape, output_1)output_2 = model.encode(sentences_2, normalize_embeddings=True)
print(output_2.shape, output_2)

请添加图片描述

  • 计算句子对 相似度的方法
sentences = [["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."],["In 1974, I won the championship in Southeast Asia in my first kickboxing match", 'A man is playing a guitar.'],
]scores_list = model.compute_score(sentences)
print(scores_list)

请添加图片描述

微调

尝试了两种方法来做,一种是调包自己写代码,一种是采用open-retrievals中的代码写shell脚本。这里我们采用第一种,另外一种方法可参考文章最后番外中的微调

import transformers
from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup, AdamW
from retrievals import AutoModelForRanking, RerankCollator, RerankTrainDataset, RerankTrainer, ColBERT, RetrievalTrainDataset, ColBertCollator
from retrievals.losses import ColbertLoss
transformers.logging.set_verbosity_error()model_name_or_path: str = 'BAAI/bge-m3'learning_rate: float = 1e-5
batch_size: int = 2
epochs: int = 1
output_dir: str = './checkpoints'train_dataset = RetrievalTrainDataset('C-MTEB/T2Reranking', positive_key='positive', negative_key='negative', dataset_split='dev'
)tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)data_collator = ColBertCollator(tokenizer,query_max_length=64,document_max_length=128,positive_key='positive',negative_key='negative',
)
model = ColBERT.from_pretrained(model_name_or_path,colbert_dim=1024,loss_fn=ColbertLoss(use_inbatch_negative=False),
)optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)training_args = TrainingArguments(learning_rate=learning_rate,per_device_train_batch_size=batch_size,num_train_epochs=epochs,output_dir = './checkpoints',remove_unused_columns=False,gradient_accumulation_steps=8,logging_steps=100,)
trainer = RerankTrainer(model=model,args=training_args,train_dataset=train_dataset,data_collator=data_collator,
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()model.save_pretrained(output_dir)

训练过程中会加载BAAI/bge-m3模型权重
请添加图片描述
损失函数下降
请添加图片描述

{'loss': 7.4858, 'grad_norm': 30.484981536865234, 'learning_rate': 4.076305220883534e-06, 'epoch': 0.6024096385542169}
{'loss': 1.18, 'grad_norm': 28.68316650390625, 'learning_rate': 3.072289156626506e-06, 'epoch': 1.2048192771084336}
{'loss': 1.1399, 'grad_norm': 14.203865051269531, 'learning_rate': 2.068273092369478e-06, 'epoch': 1.8072289156626506}
{'loss': 1.1261, 'grad_norm': 24.30337905883789, 'learning_rate': 1.0642570281124499e-06, 'epoch': 2.4096385542168672}
{'train_runtime': 471.8191, 'train_samples_per_second': 33.827, 'train_steps_per_second': 1.055, 'train_loss': 2.4146631079984, 'epoch': 3.0}

评测

在C-MTEB中进行评测。微调前保留10%的数据集作为测试集验证

from datasets import load_datasetdataset = load_dataset("C-MTEB/T2Reranking", split="dev")
ds = dataset.train_test_split(test_size=0.1, seed=42)ds_train = ds["train"].filter(lambda x: len(x["positive"]) > 0 and len(x["negative"]) > 0
)ds_train.to_json("t2_ranking.jsonl", force_ascii=False)

微调前的指标:
请添加图片描述
微调后的指标:
请添加图片描述

{"dataset_revision": null,"mteb_dataset_name": "CustomReranking","mteb_version": "1.1.1","test": {"evaluation_time": 221.45,"map": 0.6950128151840831,"mrr": 0.8193114944390455}
}

番外:从语言模型直接训练ColBERT

之前的例子里是从BAAI/bge-m3继续微调,这里再跑一个从hfl/chinese-roberta-wwm-ext训练一个ColBERT模型

  • 注意,从头跑需要设置更大的学习率与更多的epochs
MODEL_NAME='hfl/chinese-roberta-wwm-ext'
TRAIN_DATA="/root/kaggle101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kaggle101/src/open-retrievals/t2/ft_out"cd /root/open-retrievals/srctorchrun --nproc_per_node 1 \--module retrievals.pipelines.rerank \--output_dir $OUTPUT_DIR \--overwrite_output_dir \--model_name_or_path $MODEL_NAME \--tokenizer_name $MODEL_NAME \--model_type colbert \--do_train \--data_name_or_path $TRAIN_DATA \--positive_key positive \--negative_key negative \--learning_rate 5e-5 \--bf16 \--num_train_epochs 5 \--per_device_train_batch_size 32 \--dataloader_drop_last True \--query_max_length 128 \--max_length 256 \--train_group_size 4 \--unfold_each_positive false \--save_total_limit 1 \--logging_steps 100 \--use_inbatch_negative False

微调后指标

{"dataset_revision": null,"mteb_dataset_name": "CustomReranking","mteb_version": "1.1.1","test": {"evaluation_time": 75.38,"map": 0.6865308507184888,"mrr": 0.8039965986394558}
}

相关文章:

动手学习RAG:迟交互模型colbert微调实践 bge-m3

动手学习RAG: 向量模型动手学习RAG: BGE向量模型微调实践]()动手学习RAG: BCEmbedding 向量模型 微调实践]()BCE ranking 微调实践]()GTE向量与排序模型 微调实践]()模型微调中的模型序列长度]()相似度与温度系数 本文我们来进行ColBERT模型的实践,按惯例&#xff…...

springboot 整合quartz定时任务

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、pom的配置1.加注解 二、使用方法1.工程图2.创建工具类 三、controller 实现 前言 提示:这里可以添加本文要记录的大概内容: 提示&a…...

erlang学习: Mnesia Erlang数据库3

Mnesia数据库删除实现和事务处理 -module(test_mnesia). -include_lib("stdlib/include/qlc.hrl").-record(shop, {item, quantity, cost}). %% API -export([insert/3, select/0, select/1, delete/1, transaction/1,start/0, do_this_once/0]). start() ->mnes…...

善于善行——贵金属回收

在当今社会,贵金属回收已成为一项日益重要的产业。随 着科技的不断进步和人们对资源可持续利用的认识逐渐提高,贵金属回收的现状也备受关注。 目前,贵金属回收市场呈现出蓬勃发展的态势。一方面,贵金属如金、银、铂、钯等在众多领…...

用CSS 方式设置 table 样式

在现代Web开发中,使用CSS来设置table的样式是一种常见且强大的方法,它能让你的表格数据既美观又易于阅读。下面我将通过一个示例来展示如何使用现代CSS技巧来美化表格。 效果图 HTML 结构 首先,我们定义一个基本的HTML表格结构:…...

Elasticsearch7.x 集群迁移文档

一、集群样例信息 集群名称:escluster-ali-test 1、源集群:(source_cluster) 节点IP节点名称节点角色是否为master节点10.200.112.149es2.gj1.china-job.cndata,master是10.200.112.151es1.gj1.china-job.cndata,master否10.200.112.153es…...

高空抛物检测算法的应用场景解析

高空抛物事件频发,对公众安全构成严重威胁。无论是居民区还是商业中心,从高层建筑中丢弃物品都可能导致人员伤亡和财产损失。传统的监控手段多以事后追溯为主,无法在事发时及时预警和干预。为应对这一难题,视觉分析技术的发展为高…...

Leetcode 无重复字符的最长子串

算法思想: 滑动窗口:通过 start 和 end 来维护一个滑动窗口,start 指向当前窗口的起点,end 是当前窗口的末尾。滑动窗口中的字符都是无重复的。哈希表 charIndexMap:用于存储每个字符及其最近一次出现的位置。更新起始…...

用命令行的方式启动.netcore webapi

用命令行的方式启动.netcore web项目 进入指定的项目文件夹,比如我发布后的代码放在下面文件夹中 在此地址栏中输入“cmd”,打开命令提示符,进入到发布代码目录 命令行启动.netcore项目的命令为: dotnet 项目启动文件.dll --urls"ht…...

Spring6详细学习笔记(IOC+AOP)

一、Spring系统架构介绍 1.1、定义 Spring是一个轻量级的控制反转(IoC)和面向切面(AOP)的容器(框架)。Spring官网 Spring是一款主流的Java EE 轻量级开源框架,目的是用于简化Java企业级引用的开发难度和开发周期。从简单性、可测试性和松耦…...

@RequestMapping 基于哪个库进行通信

RequestMapping 是 Spring Framework 中用于处理 HTTP 请求的注解,主要用于定义控制器方法的请求映射。它并不直接基于某个特定的通信库,而是依赖于 Spring MVC 框架的核心功能。 1. Spring MVC RequestMapping 是 Spring MVC 的一部分,Spr…...

GPIO(General Purpose Input/Output)输入/输出

GPIO最简单的功能是输出高低电平;GPIO还可以被设置为输入功能,用于读取按键等输入信号;也可以将GPIO复用成芯片上的其他外设的控制引脚。 STM32F407ZGT6有8组IO。分别为GPIOA~GPIOH,除了GPIOH只有两个IO,其余每组IO有…...

两个pdf合并成一个pdf,这些pdf合并小技巧了解下

在日常工作和学习中,我们经常会遇到需要将多个PDF文件合并成一个文件的情况。这不仅可以提高文件管理的效率,还能让信息展示更加集中和便捷。今天就来给大家分享几种非常简单便捷的PDF合并小技巧,一起来学习下吧。 方法一:WPS WP…...

Transformer学习(2):自注意力机制

回顾 注意力机制 自注意力机制 自注意力机制中同样包含QKV,但它们是同源(Q≈K≈V),也就是来自相同的输入数据X,X可以分为 ( x 1 , x 2 , . . , x n ) (x_1,x_2,..,x_n) (x1​,x2​,..,xn​)。 而通过输入嵌入层(input embedding)&#xff0c…...

分类预测|基于粒子群优化径向基神经网络的数据分类预测Matlab程序PSO-RBF 多特征输入多类别输出 含基础RBF程序

分类预测|基于粒子群优化径向基神经网络的数据分类预测Matlab程序PSO-RBF 多特征输入多类别输出 含基础RBF程序 文章目录 一、基本原理1. 粒子群优化算法(PSO)2. 径向基神经网络(RBF)PSO-RBF模型流程总结 二、实验结果三、核心代码…...

【React】Vite 构建 React

项目搭建 vite 官网:Vite 跟着文档走即可,选择 react ,然后 ts swc。 着重说一下 package-lock.json 这个文件有两个作用: 锁版本号(保证项目在不同人手里安装的依赖都是相同的,解决版本冲突的问题&am…...

算法刷题:300. 最长递增子序列、674. 最长连续递增序列、718. 最长重复子数组

300. 最长递增子序列 1.dp定义:dp[i]表示i之前包括i的以nums[i]结尾的最长递增子序列的长度 2.递推公式:if (nums[i] > nums[j]) dp[i] max(dp[i], dp[j] 1); 注意这里不是要dp[i] 与 dp[j] 1进行比较,而是我们要取dp[j] 1的最大值…...

【linux】一种基于虚拟串口的方式使两个应用通讯

在Linux系统中,两个应用之间通过串口(Serial Port)进行通信是一种常见的通信方式,特别是在嵌入式系统、工业自动化等领域。串口通信通常涉及到对串口设备的配置和读写操作。以下是一个基本的步骤指南,说明如何在Linux中…...

并行程序设计基础——并行I/O(3)

目录 一、多视口的并行文件并行读写 1、文件视口与指针 1.1 MPI_FILE_SET_VIEW 1.2 MPI_FILE_GET_VIEW 1.3 MPI_FILE_SEEK 1.4 MPI_FILE_GET_POSTION 1.5 MPI_FILE_GET_BYTE_OFFSET 2、阻塞方式的视口读写 2.1 MPI_FILE_READ 2.2 MPI_FILE_WRITE 2.3 MPI_FILE_READ_…...

性能测试-jmeter脚本录制(十五)

一、jmeter脚本录制(不推荐)简介: 二、jmeter脚本录制步骤 1、添加代理服务器和线程组 2、配置http代理服务器的端口和目标线程组 3修改本机浏览器代理 4、点击启动 5、每次操作页面前,修改提示文字...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道(多模态 OCR → 语义检索 → 答案渲染)、两级检索(倒排 BM25 向量 HNSW)并以大语言模型兜底”的整体框架: 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后,分别用…...

设计模式和设计原则回顾

设计模式和设计原则回顾 23种设计模式是设计原则的完美体现,设计原则设计原则是设计模式的理论基石, 设计模式 在经典的设计模式分类中(如《设计模式:可复用面向对象软件的基础》一书中),总共有23种设计模式,分为三大类: 一、创建型模式(5种) 1. 单例模式(Sing…...

TDengine 快速体验(Docker 镜像方式)

简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能,本节首先介绍如何通过 Docker 快速体验 TDengine,然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker,请使用 安装包的方式快…...

反向工程与模型迁移:打造未来商品详情API的可持续创新体系

在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...

Unity3D中Gfx.WaitForPresent优化方案

前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...

3403. 从盒子中找出字典序最大的字符串 I

3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...

实现弹窗随键盘上移居中

实现弹窗随键盘上移的核心思路 在Android中&#xff0c;可以通过监听键盘的显示和隐藏事件&#xff0c;动态调整弹窗的位置。关键点在于获取键盘高度&#xff0c;并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

基于matlab策略迭代和值迭代法的动态规划

经典的基于策略迭代和值迭代法的动态规划matlab代码&#xff0c;实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代&#xff0c;加密货币作为一种新兴的金融现象&#xff0c;正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而&#xff0c;加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下&#xff0c;稳定…...