当前位置: 首页 > news >正文

动手学习RAG:迟交互模型colbert微调实践 bge-m3

  • 动手学习RAG: 向量模型
  • 动手学习RAG: BGE向量模型微调实践]()
  • 动手学习RAG: BCEmbedding 向量模型 微调实践]()
  • BCE ranking 微调实践]()
  • GTE向量与排序模型 微调实践]()
  • 模型微调中的模型序列长度]()
  • 相似度与温度系数

本文我们来进行ColBERT模型的实践,按惯例,还是以open-retrievals中的代码为蓝本。在RAG兴起之后,ColBERT也获得了更多的关注。ColBERT整体结构和双塔特别相似,但迟交互式也就意味着比起一般ranking模型,交互来的更晚一些。
请添加图片描述

准备环境

pip install transformers
pip install open-retrievals

准备数据

还是采用C-MTEB/T2Reranking数据。

  • 每个样本有query, positive, negative。其中query和positive构成正样本对,query和negative构成负样本对
    请添加图片描述

使用

由于ColBERT作为迟交互式模型,既可以像向量模型一样生成向量,也可以计算相似度。BAAI/bge-m3中的colbert模型是基于XLMRoberta训练而来,因此使用ColBERT可以直接从bge-m3中加载预训练权重。

import transformers
from retrievals import ColBERT
model_name_or_path: str =  'BAAI/bge-m3' 
model = ColBERT.from_pretrained(model_name_or_path,colbert_dim=1024,    use_fp16=True,loss_fn=ColbertLoss(use_inbatch_negative=True),
)model

请添加图片描述

  • 生成向量的方法
sentences_1 = ["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."]
sentences_2 = ['A dog is chasing car.', 'A man is playing a guitar.']output_1 = model.encode(sentences_1, normalize_embeddings=True)
print(output_1.shape, output_1)output_2 = model.encode(sentences_2, normalize_embeddings=True)
print(output_2.shape, output_2)

请添加图片描述

  • 计算句子对 相似度的方法
sentences = [["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."],["In 1974, I won the championship in Southeast Asia in my first kickboxing match", 'A man is playing a guitar.'],
]scores_list = model.compute_score(sentences)
print(scores_list)

请添加图片描述

微调

尝试了两种方法来做,一种是调包自己写代码,一种是采用open-retrievals中的代码写shell脚本。这里我们采用第一种,另外一种方法可参考文章最后番外中的微调

import transformers
from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup, AdamW
from retrievals import AutoModelForRanking, RerankCollator, RerankTrainDataset, RerankTrainer, ColBERT, RetrievalTrainDataset, ColBertCollator
from retrievals.losses import ColbertLoss
transformers.logging.set_verbosity_error()model_name_or_path: str = 'BAAI/bge-m3'learning_rate: float = 1e-5
batch_size: int = 2
epochs: int = 1
output_dir: str = './checkpoints'train_dataset = RetrievalTrainDataset('C-MTEB/T2Reranking', positive_key='positive', negative_key='negative', dataset_split='dev'
)tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)data_collator = ColBertCollator(tokenizer,query_max_length=64,document_max_length=128,positive_key='positive',negative_key='negative',
)
model = ColBERT.from_pretrained(model_name_or_path,colbert_dim=1024,loss_fn=ColbertLoss(use_inbatch_negative=False),
)optimizer = AdamW(model.parameters(), lr=learning_rate)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)training_args = TrainingArguments(learning_rate=learning_rate,per_device_train_batch_size=batch_size,num_train_epochs=epochs,output_dir = './checkpoints',remove_unused_columns=False,gradient_accumulation_steps=8,logging_steps=100,)
trainer = RerankTrainer(model=model,args=training_args,train_dataset=train_dataset,data_collator=data_collator,
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()model.save_pretrained(output_dir)

训练过程中会加载BAAI/bge-m3模型权重
请添加图片描述
损失函数下降
请添加图片描述

{'loss': 7.4858, 'grad_norm': 30.484981536865234, 'learning_rate': 4.076305220883534e-06, 'epoch': 0.6024096385542169}
{'loss': 1.18, 'grad_norm': 28.68316650390625, 'learning_rate': 3.072289156626506e-06, 'epoch': 1.2048192771084336}
{'loss': 1.1399, 'grad_norm': 14.203865051269531, 'learning_rate': 2.068273092369478e-06, 'epoch': 1.8072289156626506}
{'loss': 1.1261, 'grad_norm': 24.30337905883789, 'learning_rate': 1.0642570281124499e-06, 'epoch': 2.4096385542168672}
{'train_runtime': 471.8191, 'train_samples_per_second': 33.827, 'train_steps_per_second': 1.055, 'train_loss': 2.4146631079984, 'epoch': 3.0}

评测

在C-MTEB中进行评测。微调前保留10%的数据集作为测试集验证

from datasets import load_datasetdataset = load_dataset("C-MTEB/T2Reranking", split="dev")
ds = dataset.train_test_split(test_size=0.1, seed=42)ds_train = ds["train"].filter(lambda x: len(x["positive"]) > 0 and len(x["negative"]) > 0
)ds_train.to_json("t2_ranking.jsonl", force_ascii=False)

微调前的指标:
请添加图片描述
微调后的指标:
请添加图片描述

{"dataset_revision": null,"mteb_dataset_name": "CustomReranking","mteb_version": "1.1.1","test": {"evaluation_time": 221.45,"map": 0.6950128151840831,"mrr": 0.8193114944390455}
}

番外:从语言模型直接训练ColBERT

之前的例子里是从BAAI/bge-m3继续微调,这里再跑一个从hfl/chinese-roberta-wwm-ext训练一个ColBERT模型

  • 注意,从头跑需要设置更大的学习率与更多的epochs
MODEL_NAME='hfl/chinese-roberta-wwm-ext'
TRAIN_DATA="/root/kaggle101/src/open-retrievals/t2/t2_ranking.jsonl"
OUTPUT_DIR="/root/kaggle101/src/open-retrievals/t2/ft_out"cd /root/open-retrievals/srctorchrun --nproc_per_node 1 \--module retrievals.pipelines.rerank \--output_dir $OUTPUT_DIR \--overwrite_output_dir \--model_name_or_path $MODEL_NAME \--tokenizer_name $MODEL_NAME \--model_type colbert \--do_train \--data_name_or_path $TRAIN_DATA \--positive_key positive \--negative_key negative \--learning_rate 5e-5 \--bf16 \--num_train_epochs 5 \--per_device_train_batch_size 32 \--dataloader_drop_last True \--query_max_length 128 \--max_length 256 \--train_group_size 4 \--unfold_each_positive false \--save_total_limit 1 \--logging_steps 100 \--use_inbatch_negative False

微调后指标

{"dataset_revision": null,"mteb_dataset_name": "CustomReranking","mteb_version": "1.1.1","test": {"evaluation_time": 75.38,"map": 0.6865308507184888,"mrr": 0.8039965986394558}
}

相关文章:

动手学习RAG:迟交互模型colbert微调实践 bge-m3

动手学习RAG: 向量模型动手学习RAG: BGE向量模型微调实践]()动手学习RAG: BCEmbedding 向量模型 微调实践]()BCE ranking 微调实践]()GTE向量与排序模型 微调实践]()模型微调中的模型序列长度]()相似度与温度系数 本文我们来进行ColBERT模型的实践,按惯例&#xff…...

springboot 整合quartz定时任务

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、pom的配置1.加注解 二、使用方法1.工程图2.创建工具类 三、controller 实现 前言 提示:这里可以添加本文要记录的大概内容: 提示&a…...

erlang学习: Mnesia Erlang数据库3

Mnesia数据库删除实现和事务处理 -module(test_mnesia). -include_lib("stdlib/include/qlc.hrl").-record(shop, {item, quantity, cost}). %% API -export([insert/3, select/0, select/1, delete/1, transaction/1,start/0, do_this_once/0]). start() ->mnes…...

善于善行——贵金属回收

在当今社会,贵金属回收已成为一项日益重要的产业。随 着科技的不断进步和人们对资源可持续利用的认识逐渐提高,贵金属回收的现状也备受关注。 目前,贵金属回收市场呈现出蓬勃发展的态势。一方面,贵金属如金、银、铂、钯等在众多领…...

用CSS 方式设置 table 样式

在现代Web开发中,使用CSS来设置table的样式是一种常见且强大的方法,它能让你的表格数据既美观又易于阅读。下面我将通过一个示例来展示如何使用现代CSS技巧来美化表格。 效果图 HTML 结构 首先,我们定义一个基本的HTML表格结构:…...

Elasticsearch7.x 集群迁移文档

一、集群样例信息 集群名称:escluster-ali-test 1、源集群:(source_cluster) 节点IP节点名称节点角色是否为master节点10.200.112.149es2.gj1.china-job.cndata,master是10.200.112.151es1.gj1.china-job.cndata,master否10.200.112.153es…...

高空抛物检测算法的应用场景解析

高空抛物事件频发,对公众安全构成严重威胁。无论是居民区还是商业中心,从高层建筑中丢弃物品都可能导致人员伤亡和财产损失。传统的监控手段多以事后追溯为主,无法在事发时及时预警和干预。为应对这一难题,视觉分析技术的发展为高…...

Leetcode 无重复字符的最长子串

算法思想: 滑动窗口:通过 start 和 end 来维护一个滑动窗口,start 指向当前窗口的起点,end 是当前窗口的末尾。滑动窗口中的字符都是无重复的。哈希表 charIndexMap:用于存储每个字符及其最近一次出现的位置。更新起始…...

用命令行的方式启动.netcore webapi

用命令行的方式启动.netcore web项目 进入指定的项目文件夹,比如我发布后的代码放在下面文件夹中 在此地址栏中输入“cmd”,打开命令提示符,进入到发布代码目录 命令行启动.netcore项目的命令为: dotnet 项目启动文件.dll --urls"ht…...

Spring6详细学习笔记(IOC+AOP)

一、Spring系统架构介绍 1.1、定义 Spring是一个轻量级的控制反转(IoC)和面向切面(AOP)的容器(框架)。Spring官网 Spring是一款主流的Java EE 轻量级开源框架,目的是用于简化Java企业级引用的开发难度和开发周期。从简单性、可测试性和松耦…...

@RequestMapping 基于哪个库进行通信

RequestMapping 是 Spring Framework 中用于处理 HTTP 请求的注解,主要用于定义控制器方法的请求映射。它并不直接基于某个特定的通信库,而是依赖于 Spring MVC 框架的核心功能。 1. Spring MVC RequestMapping 是 Spring MVC 的一部分,Spr…...

GPIO(General Purpose Input/Output)输入/输出

GPIO最简单的功能是输出高低电平;GPIO还可以被设置为输入功能,用于读取按键等输入信号;也可以将GPIO复用成芯片上的其他外设的控制引脚。 STM32F407ZGT6有8组IO。分别为GPIOA~GPIOH,除了GPIOH只有两个IO,其余每组IO有…...

两个pdf合并成一个pdf,这些pdf合并小技巧了解下

在日常工作和学习中,我们经常会遇到需要将多个PDF文件合并成一个文件的情况。这不仅可以提高文件管理的效率,还能让信息展示更加集中和便捷。今天就来给大家分享几种非常简单便捷的PDF合并小技巧,一起来学习下吧。 方法一:WPS WP…...

Transformer学习(2):自注意力机制

回顾 注意力机制 自注意力机制 自注意力机制中同样包含QKV,但它们是同源(Q≈K≈V),也就是来自相同的输入数据X,X可以分为 ( x 1 , x 2 , . . , x n ) (x_1,x_2,..,x_n) (x1​,x2​,..,xn​)。 而通过输入嵌入层(input embedding)&#xff0c…...

分类预测|基于粒子群优化径向基神经网络的数据分类预测Matlab程序PSO-RBF 多特征输入多类别输出 含基础RBF程序

分类预测|基于粒子群优化径向基神经网络的数据分类预测Matlab程序PSO-RBF 多特征输入多类别输出 含基础RBF程序 文章目录 一、基本原理1. 粒子群优化算法(PSO)2. 径向基神经网络(RBF)PSO-RBF模型流程总结 二、实验结果三、核心代码…...

【React】Vite 构建 React

项目搭建 vite 官网:Vite 跟着文档走即可,选择 react ,然后 ts swc。 着重说一下 package-lock.json 这个文件有两个作用: 锁版本号(保证项目在不同人手里安装的依赖都是相同的,解决版本冲突的问题&am…...

算法刷题:300. 最长递增子序列、674. 最长连续递增序列、718. 最长重复子数组

300. 最长递增子序列 1.dp定义:dp[i]表示i之前包括i的以nums[i]结尾的最长递增子序列的长度 2.递推公式:if (nums[i] > nums[j]) dp[i] max(dp[i], dp[j] 1); 注意这里不是要dp[i] 与 dp[j] 1进行比较,而是我们要取dp[j] 1的最大值…...

【linux】一种基于虚拟串口的方式使两个应用通讯

在Linux系统中,两个应用之间通过串口(Serial Port)进行通信是一种常见的通信方式,特别是在嵌入式系统、工业自动化等领域。串口通信通常涉及到对串口设备的配置和读写操作。以下是一个基本的步骤指南,说明如何在Linux中…...

并行程序设计基础——并行I/O(3)

目录 一、多视口的并行文件并行读写 1、文件视口与指针 1.1 MPI_FILE_SET_VIEW 1.2 MPI_FILE_GET_VIEW 1.3 MPI_FILE_SEEK 1.4 MPI_FILE_GET_POSTION 1.5 MPI_FILE_GET_BYTE_OFFSET 2、阻塞方式的视口读写 2.1 MPI_FILE_READ 2.2 MPI_FILE_WRITE 2.3 MPI_FILE_READ_…...

性能测试-jmeter脚本录制(十五)

一、jmeter脚本录制(不推荐)简介: 二、jmeter脚本录制步骤 1、添加代理服务器和线程组 2、配置http代理服务器的端口和目标线程组 3修改本机浏览器代理 4、点击启动 5、每次操作页面前,修改提示文字...

关系型数据库 - MySQL I

MySQL 数据库 MySQL 是一种关系型数据库。开源免费,并且方便扩展。在 Java 开发中常用于保存和管理数据。默认端口号 3306。 MySQL 数据库主要分为 Server 和存储引擎两部分,现在最常用的存储引擎是 InnoDB。 指令执行过程 MySQL 数据库接收到用户指令…...

解锁AI写作新境界:5款工具让你的论文创作事半功倍

在这个数字化飞速发展的时代,人工智能(AI)已经不再是科幻小说中的幻想,而是实实在在地融入了我们的日常生活。特别是在学术领域,AI技术的介入正在改变传统的论文写作方式。你是否还在为撰写论文而熬夜苦战?…...

一文读懂多组学联合分析产品在医学领域的应用

疾病的发生和发展通常涉及多个层面的生物学过程,包括基因表达、蛋白质功能、代谢物变化等。传统的单一组学研究只能提供某一层面的信息,而多组学关联分析能够综合多个层面的数据,提供更全面、更深入的疾病理解。例如,通过分析患者…...

js react 笔记 2

起因, 目的: 记录一些 js, react, css 1. 生成一个随机的 uuid // 需要先安装 crypto 模块 const { randomUUID } require(crypto);const uuid randomUUID(); console.log(uuid); // 输出类似 9b1deb4d-3b7d-4bad-9bdd-2b0d7b3dcb6d 2. 使用 props, 传递参数…...

快速使用react 全局状态管理工具--redux

redux Redux 是 JavaScript 应用中管理应用状态的工具,特别适用于复杂的、需要共享状态的中大型应用。Redux 的核心思想是将应用的所有状态存储在一个单一的、不可变的状态树(state tree)中,状态只能通过触发特定的 action 来更新…...

活动系统开发之采用设计模式与非设计模式的区别-非设计模式

1、父类Base.php <?php /*** 初始化控制器* User: Administrator* Date: 2022/9/26* Time: 18:00*/ declare (strict_types 1); namespace app\controller; use app\model\common\Token; use app\BaseController; use app\BaseError; use OpenSSL\Encrypt; use app\model…...

JVM面试(六)垃圾收集器

目录 概述STW收集器的并发和并行 Serial收集器ParNew收集器Parallel Scavenge收集器Serial Old收集器Parallel Old收集器CMS收集器Garbage First&#xff08;G1&#xff09;收集器 概述 上一章我们分析了垃圾收集算法&#xff0c;那这一章我们来认识一下这些垃圾收集器是如何运…...

固态硬盘装系统有必要分区吗?

前言 现在的新电脑有哪一台是不使用固态硬盘的呢&#xff1f;这个好像很少很少了…… 有个朋友买了一台新的笔记本电脑&#xff0c;开机之后&#xff0c;电脑只有一个分区&#xff08;系统C盘500GB&#xff09;。这时候她想要给笔记本分区…… 这个真的有必要分区吗&#xf…...

网络安全架构师

网络安全架构师负责构建全面的安全框架&#xff0c;以保护组织的数字资产免受侵害&#xff0c;确保组织在数字化转型的同时维持强大的安全防护。 摩根大通的网络安全运营副总裁兼安全架构总监Lester Nichols强调&#xff0c;成为网络安全架构师对现代企业至关重要&#xff0c;…...

如何本地部署Ganache并使用内网穿透配置公网地址远程连接测试网络

目录 前言 1. 安装Ganache 2. 安装cpolar 3. 创建公网地址 4. 公网访问连接 5. 固定公网地址 作者简介&#xff1a; 懒大王敲代码&#xff0c;计算机专业应届生 今天给大家聊聊如何本地部署Ganache并使用内网穿透配置公网地址远程连接测试网络&#xff0c;欢迎大家点赞 &a…...