当前位置: 首页 > article >正文

用infoNCE微调Embedding模型

infoNCE

代码1:(样本格式为query_n个positive_n个hardnegative)

  • PairwiseModel并不是模型,而是连接model和loss的一个包装类。
  • PairwiseModel接收两种类型样本 【query + pos pair】or【query + pos + neg triplet】。

  • CrossEntropyLoss还可以传入label_smoothing=0.05,用于对比学习。label_smoothing = 0.3时,label_smoothing 的作用是把硬标签 [0, 0, 1, 0] 平滑成类似 [0.1, 0.1, 0.7, 0.1],从而使得 CrossEntropyLoss 不再只惩罚预测不对的类,还会对非目标类的概率也做约束,使模型更加平滑稳定、泛化更强。
  • AutoModelForEmbedding的pooling_method选择mean还是cls根据模型来定,如果模型训练的时候用cls向量当做句子表征,则用cls。否则则用mean。

代码2:(样本格式为query_positive,只有正样本,负样本为batch内其他样本)

import os
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
model_name_or_path: str = '../model/m3e-base'
# model_name_or_path: str = '../model/bge-small-zh-v1.5'
batch_size: int = 2
epochs: int = 3
#数据集会按照dev、train、test划分。具体有哪个,得print来看,再用split="dev"获取 dev的部分。
train_dataset = load_dataset("../../dataset/C-MTEB/T2Reranking", split="dev") #这个数据集并不是 query_positive格式,而是query_n个positive,因此需要更改
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
train_model = PairwiseModel(model, loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
optimizer = AdamW(train_model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
training_arguments = TrainingArguments(output_dir='./checkpoints',num_train_epochs=epochs,per_device_train_batch_size=batch_size,remove_unused_columns=False,logging_steps=50,
)
# 处理后会得到一个两个key的dict,每个value是一个包含dict(其中包含input_ids、token_type_ids、attention_mask)
dc=RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[64, 128]) 
trainer = RetrievalTrainer(model=train_model,args=training_arguments,train_dataset=train_dataset,data_collator=dc, # 相当于 自定义collate_fn 函数
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()

参考:

动手学习RAG: moka-ai/m3e 模型微调deepspeed与对比学习_m3e模型微调-CSDN博客

https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing#scrollTo=tI_Mc62ryYAm

相关文章:

用infoNCE微调Embedding模型

infoNCE 代码1:(样本格式为query_n个positive_n个hardnegative) PairwiseModel并不是模型,而是连接model和loss的一个包装类。 PairwiseModel接收两种类型样本 【query pos pair】or【query pos neg triplet】。 CrossEntropy…...

Debezium报错处理系列之第128篇:增量快照报错java.lang.OutOfMemoryError: Java heap space

Debezium报错处理系列之第128篇:增量快照报错java.lang.OutOfMemoryError: Java heap space 一、完整报错二、错误原因三、解决方法Debezium从入门到精通系列之:研究Debezium技术遇到的各种错误解决方法汇总: Debezium从入门到精通系列之:百篇系列文章汇总之研究Debezium技…...

AI——使用pandas

文章目录 1、pandas介绍2、为什么使用pandas3、pandas的数据结构1、Series2、DataFrame3、MultiIndex 4、pandas基本数据操作1、索引操作2、赋值操作3、排序4、算术运算5、逻辑运算6、逻辑运算函数7、统计函数8、累计统计函数9、自定义运算 5、pandas读取文件和存储1、csv文件2…...

SMT贴片组装工艺优化与高效生产

内容概要 现代SMT贴片组装工艺的优化与高效生产涉及多维度技术协同,其核心在于构建精密可控的制造体系。本文系统梳理了从焊接参数调控到智能检测部署的全链路关键环节,重点解析影响生产效能的核心变量及其相互作用机制。通过对比不同贴装设备的速度-精…...

2025认证杯挑战赛B题【 谣言在社交网络上的传播 】原创论文讲解(含完整python代码)

大家好呀,从发布赛题一直到现在,总算完成了认证杯数学中国数学建模网络挑战赛第一阶段B题目谣言在社交网络上的传播完整的成品论文。 本论文可以保证原创,保证高质量。绝不是随便引用一大堆模型和代码复制粘贴进来完全没有应用糊弄人的垃圾半…...

用docker容器创建属于自己的一方小世界!容器中,盖周天之变,化吾为王~

用docker容器创建属于自己的一方小世界!容器中,盖周天之变,化吾为王~ 分别查看用户id和组id。 命令: 1、id -u 2、id -g 创建并运行容器 docker run -d -p 31404:22 -v /home/liub:/home -v /data:/app/data --user 1004:1004 --…...

vue拓扑图组件

vue拓扑图组件 介绍技术栈功能特性快速开始安装依赖开发调试构建部署 使用示例演示截图组件源码 介绍 一个基于 Vue3 的拓扑图组件,具有以下特点: 1.基于 vue-flow 实现,提供流畅的拓扑图展示体验 2.支持传入 JSON 对象自动生成拓扑结构 3.自…...

前端防御性编程

关于防御性编程 你是否遇到过,接口请求失败或者返回数据错误,导致系统白屏或者前端自身写的代码存在一些缺陷,导致整个系统不够健壮,从而导致系统白屏 常见的问题与防范 最常见的问题 访问了null或者undefined的属性 null.a …...

Linux服务器网卡深度解析:从ifconfig输出到生产环境性能调优实战

Linux服务器网卡深度解析:从ifconfig输出到生产环境性能调优实战 Linux服务器网卡深度解析:从ifconfig输出到生产环境性能调优实战一、背景二、生产环境的服务器部署情况三、拆解一个真实的 ifconfig 输出1、先看 MAC 地址2、再看设备的 interrupt 和 me…...

【愚公系列】《Python网络爬虫从入门到精通》048-验证码识别(滑动拼图验证码)

🌟【技术大咖愚公搬代码:全栈专家的成长之路,你关注的宝藏博主在这里!】🌟 📣开发者圈持续输出高质量干货的"愚公精神"践行者——全网百万开发者都在追更的顶级技术博主! 👉 江湖人称"愚公搬代码",用七年如一日的精神深耕技术领域,以"…...

SpringBoot分布式项目中实现智能邮件提醒系统

一、应用场景与需求分析 在电商、OA、客服等系统中,邮件提醒是用户触达的重要方式。本文针对以下典型需求进行方案设计: 多类型支持:订单超时、服务到期、待办通知等场景动态内容:支持纯文本/HTML/模板引擎内容格式智能重发:24小时未处理自动升级提醒级别高可用性:分布式…...

对shell脚本敏感命令进行加密执行

我要加密这条命令:rm /root/scripty.sh 如何利用openssl aes-256-cbc 实现加密和解密,并执行命令 加密、解密并执行命令的完整流程 以下是使用 openssl aes-256-cbc 加密命令 rm /root/scripty.sh,解密并执行的详细步骤: 1. 加密…...

《嵌套调用与链式访问:C语言中的函数调用技巧》

🚀个人主页:BabyZZの秘密日记 📖收入专栏:C语言 🌍文章目入 一、嵌套调用(一)定义(二)实现方式(三)优点(四)缺点 二、链式…...

Python-控制语句

控制语句 控制语句和逻辑思维 控制语句:把语句组合成能完成一定功能的小逻辑模块分类:顺序、选择、循环“顺序结构”:代表“先执行a,再执行b”的逻辑“条件判断结构”:代表“如果…,则…”的逻辑“循环结构”:代表“如果…则重复执行…”的逻辑条件判断结构 选择结构通…...

教程:在Typora中显示拼音——附处理工具

原因 因为自己普通话不标准,希望可以制作适合自己的带拼音的文档,可以把平常看到的内容、说过的话作为练习普通话的材料。 在市面上,带拼音的材料、书籍并不多,而且有可能是一些比较生僻的内容。所以希望可以自己制作这样的材料…...

OpenCV 图形API(30)图像滤波-----腐蚀操作函数erode()

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 使用特定的结构元素腐蚀图像。 cv::gapi::erode 是 OpenCV 的 G-API 模块中用于执行图像腐蚀操作的函数。腐蚀是一种基本的形态学操作&#xff…...

【KWDB 创作者计划】第二卷:开发者实战篇

​KWDB技术白皮书卷二:开发者实战篇 ​1. 自然语言到量子查询的编译系统 1.1 NL2QSQL翻译引擎架构 运行时流程图解: ┌──────────────────────┐ ┌───────────────────┐ ┌─────────────…...

设计模式:里氏代换原则 - 继承设计的稳定之道

里氏代换原则(Liskov Substitution Principle, LSP)作为面向对象设计的基石之一,为我们提供了解决之道。它指导我们如何构建高扩展性和低维护成本的继承体系,避免代码行为不一致导致的混乱和错误。 一、错误的继承设计如何毁掉系…...

Node.js中fs模块详解

Node.js 中 fs 模块(非 Promise)API 详解 Node.js 的 fs 模块提供了同步和异步的文件系统操作。以下是非 Promise 版本的 API 详解: 1. 文件读取操作 const fs require(fs);// 异步读取文件 fs.readFile(file.txt, utf8, (err, data) >…...

特殊定制版,太给力了!

今天给大家分享一款超棒的免费录屏软件,真的是录屏的好帮手! 这款软件功能可以录制 MP4、AVI、WMV 格式的标清、高清、原画视频,满足你各种需求。 云豹录屏大师 多功能录屏神器 它的界面特别简洁,上手超快,用起来很顺…...

go:实现最简单区块链

1.新建文件夹命名为blockchain,在此文件夹下分别创建两个文件一个为block.go另一个为chain.go如下图所示: 2.写入代码: block.go package blockchainimport ("bytes""crypto/sha256""encoding/gob""log""strconv""ti…...

工业相机使用笔记

目前工业相机有多种分类方式,以下是基于不同原理和特点的类别总结: 按维度分类 2D相机: 原理:通过镜头将二维平面上的物体成像在图像传感器上,传感器上的像素点阵列捕捉物体的光信号,并转换为电信号或数字…...

系分论文《论面向服务开发方法在设备租赁行业的应用》

系统分析师论文系列 【摘要】 2022年5月,我司承接某工程机械租赁企业"智能租赁运营管理平台"建设项目,我作为系统分析师主导系统架构设计。该项目需整合8大类2000余台设备资产,覆盖全国15个区域运营中心与300家代理商,实…...

【Code】《代码整洁之道》笔记-Chapter12-迭进

第12章 迭进 12.1 通过迭进设计达到整洁目的 假使有4条简单的规则,跟着做就能帮助你创建优良的设计,会如何?假使遵循这些规则,你就能洞见代码的结构和设计,更能轻易地应用SRP和DIP之类的原则,便会如何&…...

04--网络属性设置与多路复用

一、TCP可靠性分析 二、 scoket 属性设置 1、socket 属性设置表 NAMEgetsockopt, setsockopt - get and set options on sockets获取 和 设置 套接字属性 SYNOPSIS#include <sys/types.h> /* See NOTES */#include <sys/socket.h>int getsockopt(int so…...

AI领域再突破,永洪科技荣获“2025人工智能+创新案例”奖

在2025年的今天&#xff0c;人工智能已从技术概念全面渗透至产业核心。中国作为全球AI技术应用的前沿阵地&#xff0c;正通过“人工智能”行动加速推进技术与实体经济深度融合。 这一背景下&#xff0c;永洪科技凭借其“国内某头部ICT人力资源板块GenAI项目”荣获“2025全国企业…...

基于疾风大模型的新能源储能优化系统:方法、实现与案例分析

一、引言 随着可再生能源渗透率不断提高,储能系统在电力系统中的重要性日益凸显。传统储能控制方法主要基于规则策略和简单优化算法,难以应对高比例新能源场景下的复杂决策需求。本文将详细介绍如何利用疾风大模型(Gale Model)构建智能化的新能源储能优化系统,包含核心方…...

菊风RTC 2.0 开发者文档正式发布,解锁音视频新体验!

重磅发布&#xff01; 开发者们&#xff0c;菊风实时音视频2.0文档已正式发布上线&#xff0c;为您提供更清晰、更高效的开发支持&#xff01;让菊风实时音视频2.0为您的音视频应用加速~ 菊风实时音视频2.0聚焦性能升级、体验升级、录制服务升级&#xff0c;助力视频通话、语…...

12c补丁滚动升级

12c打补丁前置检查 备份文件&#xff0c;可以不做&#xff0c;因为文件可能很大&#xff0c;如果可以备份整个安装文件。 1.check grid&#xff1a; % /u01/app/12.1.0/grid/OPatch/opatch prereq CheckConflictAgainstOHWithDetail -phBaseDir /home/software/27010872/2691…...

OpenCv高阶(一)——图像金字塔(上采样、下采样)

目录 图像金字塔 一、上下采样原理 1、向下取样 2、向上采样 3、图像金字塔的作用 二、案例实现 1、高斯下采样 2、高斯金字塔中的上采样 3、对下采样的结果做上采样&#xff0c;图像变模糊&#xff0c;无法复原 4、拉普拉斯金字塔&#xff08;图片复原&#xff09; 图…...