当前位置: 首页 > news >正文

NLP文本匹配任务Text Matching [无监督训练]:SimCSE、ESimCSE、DiffCSE 项目实践

NLP文本匹配任务Text Matching [无监督训练]:SimCSE、ESimCSE、DiffCSE 项目实践

文本匹配多用于计算两个文本之间的相似度,该示例会基于 ESimCSE 实现一个无监督的文本匹配模型的训练流程。文本匹配多用于计算两段「自然文本」之间的「相似度」。

例如,在搜索引擎中,我们通常需要判断用户的搜索内容是否相似:

A:蛋黄吃多了有什么坏处    B:吃鸡蛋白过多有什么坏处  ->  不相似
A:蛋黄吃多了有什么坏处    B:蛋黄可以多吃吗         ->  相似
...

那最直觉的思路就是让人工去标注文本对,再喂给模型去学习,这种方法称为基于「监督学习」训练出的模型:

但是,如果我们今天没有这么多的标注数据,只有一大堆的「未标注」数据,我们还能训练一个匹配模型吗?这种不依赖于「人工标注数据」的方式,就叫做「无监督」(或自监督)学习方式。我们今天要讲的 SimCSE, 就是一种「无监督」训练模型。

SimCSE: Simple Contrastive Learning of Sentence Embeddings

1.SimCSE 是如何做到无监督的?

SimCSE 将对比学习(Contrastive Learning)的思想引入到文本匹配中。对比学习的核心思想就是:将相似的样本拉近,将不相似的样本推远

但现在问题是:我们没有标注数据,怎么知道哪些文本是相似的,哪些是不相似的呢?SimCSE 相出了一种很妙的办法,由于预训练模型在训练的时候通常都会使用 dropout 机制。这就意味着:即使是同一个样本过两次模型也会得到两个不同的 embedding。而因为同样的样本,那一定是相似的,模型输出的这两个 embedding 距离就应当尽可能的相近;反之,那些不同的输入样本过模型后得到的 embedding 就应当尽可能的被推远。

具体来讲,一个 batch 内每个句子会过 2 次模型,得到 2 * batch 个向量,将这些句子中通过同样句子得到的向量设置为正例,其他设置为负例。

假设 a1 和 a2 是由句子 a 过两次模型得到的结果,那么一个 batch 内的正负例构建如下所示:

a1a2b1b2c1c2
a1-10010000
a21-1000000
b100-100100
b2001-10000
c10000-1001
c200001-100

其中,对角线上的 - 100 表示自身和自身不做相似度比较。

2. SimCSE 的缺点?

从 SimCSE 的正例构建中我们可以看出来,所有的正例都是由「同一个句子」过了两次模型得到的。这就会造成一个问题:模型会更倾向于认为,长度相同的句子就代表一样的意思。由于数据样本是随机选取的,那么很有可能在一个 batch 内采样到的句子长度是不相同的。

为了解决这个问题,我们最终采取的实现方式为 ESimCSE

3. ESimCSE 解决模型对文本长度的敏感问题

ESimCSE 通过随机重复单词(Word Repetition)的方式来构建正例,巧妙的解决了句子长度敏感性的问题:

ESimCSE: Enhanced Sample Building Method for Contrastive Learning of Unsupervised Sentence Embedding

要想消除模型对句子长度的敏感,我们就需要在构建正例的时候让输入句子的长度发生改变,如下所示:

那么,改变句子长度通常有 3 种方法:随机删除、随机添加、同义词替换,但它们均存在句意变化的风险:

方法原句子变换后的句子句意是否改变
随机删除我 [不] 喜欢你我喜欢你
随机添加今天的饭好吃今天的饭 [不] 好吃
同义词替换小明长得像一只 [狼]小明长得像一只 [狗]

用语义变换后的句子去构建正例,模型效果自然会受到影响。

那如果我们随机重复一些单词呢?

方法原句子变换后的句子句意是否改变
随机重复单词今天天气很好今今天天气很好好
随机重复单词我喜欢你我我喜欢欢你

可以看到,通过随机重复单词,既能够改变句子长度,又不会轻易改变语义。

实现上,假设我们有一个 batch 的句子,我们先依次将每一个句子都进行随机单词重复(产生正例),如下:

origin ->     ['人和畜生的区别', '今天天气很好', '三星手机屏幕是不是最好的?']
repetition -> ['人人和畜生的的区别', '今今天天气很好好', '三星星手机屏屏幕是不是最最好好的?']

随后,我们将 origin 的 embedding(batch,768) 和 repetition 的 embedding(batch,768)做矩阵乘法,可以得到一个矩阵(batch,batch),矩阵对角线上就是正例,其余的均是负例:

句子 a句子 b句子 c
句子 a0.92480.23420.4242
句子 b0.31420.91230.1422
句子 c0.29030.18570.9983

矩阵中第(i,j)个元素代表 origin 列表中的第 i 个元素和 repetition 列表中第 j 个元素的相似度。

接下来就好构建训练标签了,因为 label 都在对角线上,所以第 n 行的 label 就是 n 。

labels = [i for i in range(len(origin))]     # labels = [0, 1, 2]

之后就用 CrossEntropyLoss 去计算并梯度回传就能开始训练啦。

def forward(self,query_input_ids: torch.tensor,query_token_type_ids: torch.tensor,doc_input_ids: torch.tensor,doc_token_type_ids: torch.tensor,device='cpu') -> torch.tensor:"""传入query/doc对,构建正/负例并计算contrastive loss。Args:query_input_ids (torch.LongTensor): (batch, seq_len)query_token_type_ids (torch.LongTensor): (batch, seq_len)doc_input_ids (torch.LongTensor): (batch, seq_len)doc_token_type_ids (torch.LongTensor): (batch, seq_len)device (str): 使用设备Returns:torch.tensor: (1)"""query_embedding = self.get_pooled_embedding(input_ids=query_input_ids,token_type_ids=query_token_type_ids)                                                           # (batch, self.output_embedding_dim)doc_embedding = self.get_pooled_embedding(input_ids=doc_input_ids,token_type_ids=doc_token_type_ids)                                                           # (batch, self.output_embedding_dim)cos_sim = torch.matmul(query_embedding, doc_embedding.T)    # (batch, batch)margin_diag = torch.diag(torch.full(                        # (batch, batch), 只有对角线等于margin值的对角矩阵size=[query_embedding.size()[0]], fill_value=self.margin)).to(device)cos_sim = cos_sim - margin_diag                             # 主对角线(正例)的余弦相似度都减掉 margincos_sim *= self.scale                                       # 缩放相似度,便于收敛labels = torch.arange(                                      # 只有对角上为正例,其余全是负例,所以这个batch样本标签为 -> [0, 1, 2, ...]0, query_embedding.size()[0], dtype=torch.int64).to(device)loss = self.criterion(cos_sim, labels)return loss

4.DiffCSE

结合句子间差异的无监督句子嵌入对比学习方法——DiffCSE主要还是在SimCSE上进行优化(可见SimCSE的重要性),通过ELECTRA模型的生成伪造样本和RTD(Replaced Token Detection)任务,来学习原始句子与伪造句子之间的差异,以提高句向量表征模型的效果。

其思想同样来自于CV领域(采用不变对比学习和可变对比学习相结合的方法可以提高图像表征的效果)。作者提出使用基于dropout masks机制的增强作为不敏感转换学习对比学习损失和基于MLM语言模型进行词语替换的方法作为敏感转换学习「原始句子与编辑句子」之间的差异,共同优化句向量表征。

在SimCSE模型中,采用pooler层(一个带有tanh激活函数的全连接层)作为句子向量输出。该论文发现,采用带有BN的两层pooler效果更为突出,BN在SimCSE模型上依然有效。

①对于掩码概率,经实验发现,在掩码概率为30%时,模型效果最优。
②针对两个损失之间的权重值,经实验发现,对比学习损失为RTD损失200倍时,模型效果最优。

参考链接:https://blog.csdn.net/PX2012007/article/details/127696477

5. 数据集准备

项目中提供了一部分示例数据,我们使用未标注的用户搜索记录数据来训练一个文本匹配模型,数据在 data/LCQMC

若想使用自定义数据训练,只需要仿照示例数据构建数据集即可:

  • 训练集:
喜欢打篮球的男生喜欢什么样的女生
我手机丢了,我想换个手机
大家觉得她好看吗
晚上睡觉带着耳机听音乐有什么害处吗?
学日语软件手机上的
...
  • 测试集:
开初婚未育证明怎么弄?	初婚未育情况证明怎么开?	1
谁知道她是网络美女吗?	爱情这杯酒谁喝都会醉是什么歌	0
人和畜生的区别是什么?	人与畜生的区别是什么!	1
男孩喝女孩的尿的故事	怎样才知道是生男孩还是女孩	0
...

由于是无监督训练,因此训练集(train.txt)中不需要记录标签,只需要大量的文本即可。

测试集(dev.tsv)用于测试无监督模型的效果,因此需要包含真实标签。

每一行用 \t 分隔符分开,第一部分部分为句子A,中间部分为句子B,最后一部分为两个句子是否相似(label)

6.模型训练

修改训练脚本 train.sh 里的对应参数, 开启模型训练:

python train.py \--model "nghuyong/ernie-3.0-base-zh" \--train_path "data/LCQMC/train.txt" \--dev_path "data/LCQMC/dev.tsv" \--save_dir "checkpoints/LCQMC" \--img_log_dir "logs/LCQMC" \--img_log_name "ERNIE-ESimCSE" \--learning_rate 1e-5 \--dropout 0.3 \--batch_size 64 \--max_seq_len 64 \--valid_steps 400 \--logging_steps 50 \--num_train_epochs 8 \--device "cuda:0"

正确开启训练后,终端会打印以下信息:

...
0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 226.41it/s]
DatasetDict({train: Dataset({features: ['text'],num_rows: 477532})dev: Dataset({features: ['text'],num_rows: 8802})
})
global step 50, epoch: 1, loss: 0.34367, speed: 2.01 step/s
global step 100, epoch: 1, loss: 0.19121, speed: 2.02 step/s
global step 150, epoch: 1, loss: 0.13498, speed: 2.00 step/s
global step 200, epoch: 1, loss: 0.10696, speed: 1.99 step/s
global step 250, epoch: 1, loss: 0.08858, speed: 2.02 step/s
global step 300, epoch: 1, loss: 0.07613, speed: 2.02 step/s
global step 350, epoch: 1, loss: 0.06673, speed: 2.01 step/s
global step 400, epoch: 1, loss: 0.05954, speed: 1.99 step/s
Evaluation precision: 0.58459, recall: 0.87210, F1: 0.69997, spearman_corr: 
0.36698
best F1 performence has been updated: 0.00000 --> 0.69997
global step 450, epoch: 1, loss: 0.25825, speed: 2.01 step/s
global step 500, epoch: 1, loss: 0.27889, speed: 1.99 step/s
global step 550, epoch: 1, loss: 0.28029, speed: 1.98 step/s
global step 600, epoch: 1, loss: 0.27571, speed: 1.98 step/s
global step 650, epoch: 1, loss: 0.26931, speed: 2.00 step/s
...

logs/LCQMC 文件下将会保存训练曲线图:

7.模型推理

完成模型训练后,运行 inference.py 以加载训练好的模型并应用:

...if __name__ == '__main__':...sentence_pair = [('男孩喝女孩的故事', '怎样才知道是生男孩还是女孩'),('这种图片是用什么软件制作的?', '这种图片制作是用什么软件呢?')]...res = inference(query_list, doc_list, model, tokenizer, device)print(res)

运行推理程序:

python inference.py

得到以下推理结果:

[0.1527191698551178, 0.9263839721679688]   # 第一对文本相似分数较低,第二对文本相似分数较高

参考链接:https://github.com/HarderThenHarder/transformers_tasks/blob/main/text_matching/supervised

github无法连接的可以在:https://download.csdn.net/download/sinat_39620217/88214437 下载

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

相关文章:

NLP文本匹配任务Text Matching [无监督训练]:SimCSE、ESimCSE、DiffCSE 项目实践

NLP文本匹配任务Text Matching [无监督训练]&#xff1a;SimCSE、ESimCSE、DiffCSE 项目实践 文本匹配多用于计算两个文本之间的相似度&#xff0c;该示例会基于 ESimCSE 实现一个无监督的文本匹配模型的训练流程。文本匹配多用于计算两段「自然文本」之间的「相似度」。 例如…...

复习vue3,简简单单记录

这里的知识是结合视频以及其他文章一起学习&#xff0c;仅用于个人复习记录 ref 和reactive ref 用于基本类型 reactive 用于引用类型 如果使用ref 传递对象&#xff0c;修改值时候需要写为obj.value.attr 方式修改属性值 如果使用reactive 处理对象&#xff0c;直接obj.att…...

【自用】云服务器 docker 环境下 HomeAssistant 安装 HACS 教程

一、进入 docker 中的 HomeAssistant 1.查找 HomeAssistant 的 CONTAINER ID 连接上云服务器&#xff08;宿主机&#xff09;后&#xff0c;终端内进入 root &#xff0c;输入&#xff1a; docker ps找到了 docker 的 container ID 2.config HomeAssistant 输入下面的命令&…...

使用dockerfile手动构建JDK11镜像运行容器并校验

Docker官方维护镜像的公共仓库网站 Docker Hub 国内无法访问了&#xff0c;大部分镜像无法下载&#xff0c;准备逐步构建自己的镜像库。【转载aliyun官方-容器镜像服务 ACR】Docker常见问题 阿里云容器镜像服务ACR&#xff08;Alibaba Cloud Container Registry&#xff09;是面…...

编程语言学习笔记-架构师和工程师的区别,PHP架构师之路

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;全栈领域新星创作者✌&#xff0c;CSDN博客专家&#xff0c;阿里云社区专家博主&#xff0c;2023年6月CSDN上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责…...

Streamlit 讲解专栏(十):数据可视化-图表绘制详解(上)

文章目录 1 前言2 st.line_chart&#xff1a;绘制线状图3 st.area_chart&#xff1a;绘制面积图4 st.bar_chart&#xff1a;绘制柱状图5 st.pyplot&#xff1a;绘制自定义图表6 结语 1 前言 在数据可视化的世界中&#xff0c;绘制清晰、易于理解的图表是非常关键的。Streamlit…...

其他行业跳槽转入计算机领域简单看法

其他行业跳槽转入计算机领域简单看法 本人选择从以下几个方向谈谈自己的想法和观点。 先看一下总体图&#xff0c;下面会详细分析 如何规划才能实现转码 自我评估和目标设定&#xff1a;首先&#xff0c;你需要评估自己的技能和兴趣&#xff0c;确定你希望在计算机领域从事…...

Unity制作一个简单的登入注册页面

1.创建Canvas组件 首先我们创建一个Canvas画布&#xff0c;我们再在Canvas画布底下创建一个空物体&#xff0c;取名为Resgister。把空物体的锚点设置为全屏撑开。 2.我们在Resgister空物体底下创建一个Image组件&#xff0c;改名为bg。我们也把它 的锚点设置为全屏撑开状态。接…...

常用游戏运营指标DAU、LTV及参考范围

文章目录 前言运营指标指标范围参考值留存指标的意义总结 前言 作为游戏人免不了听到 DAU 、UP值、留存 等名词&#xff0c;并且有些名词听起来还很像&#xff0c;特别是一款上线的游戏&#xff0c;这些游戏运营指标是衡量游戏业务绩效和用户参与度的重要数据&#xff0c;想做…...

标准模板库STL——deque和list

deque概述 deque属于顺序容器&#xff0c;称为双端队列容器 底层数据结构是动态二维数组&#xff0c;从整体上看&#xff0c;deque的内存不连续 初始数组第一维数量为2&#xff0c;必要时进行2倍扩容 每次第一维扩容后&#xff0c;原来数组第二维元素从新数组下标为OldSize/2的…...

分类预测 | MATLAB实现WOA-CNN-BiGRU-Attention数据分类预测

分类预测 | MATLAB实现WOA-CNN-BiGRU-Attention数据分类预测 目录 分类预测 | MATLAB实现WOA-CNN-BiGRU-Attention数据分类预测分类效果基本描述模型描述程序设计参考资料 分类效果 基本描述 1.Matlab实现WOA-CNN-BiGRU-Attention多特征分类预测&#xff0c;多特征输入模型&…...

C++ Primer Plus 第6版 读书笔记(10) 第十章 类与对象

第十章 类与对象 在面向对象编程中&#xff0c;类和对象是两个重要的概念。 类&#xff08;Class&#xff09;是一种用户自定义的数据类型&#xff0c;用于封装数据和操作。它是对象的模板或蓝图&#xff0c;描述了对象的属性&#xff08;成员变量&#xff09;和行为&#xf…...

基于C++ 的OpenCV绘制多边形,多边形多条边用不用的颜色绘制

使用基于C的OpenCV库来绘制多边形&#xff0c;并且为多边形的不同边使用不同的颜色&#xff0c;可以按照以下步骤进行操作&#xff1a; 首先&#xff0c;确保你已经安装了OpenCV库并配置好了你的开发环境。 导入必要的头文件&#xff1a; #include <opencv2/opencv.hpp&g…...

(六)、深度学习框架中的算子

1、深度学习框架算子的基本概念 深度学习框架中的算子&#xff08;operator&#xff09;是指用于执行各种数学运算和操作的函数或类。这些算子通常被用来构建神经网络的各个层和组件&#xff0c;实现数据的传递、转换和计算。 算子是深度学习模型的基本组成单元&#xff0c;它们…...

Redis实现共享Session

Redis实现共享Session 分布式系统中&#xff0c;sessiong共享有很多的解决方案&#xff0c;其中托管到缓存中应该是最常用的方案之一。 1、引入依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM…...

网络通信原理UDP协议(第五十课)

UDP协议:用户数据包协议,无连接、不可靠,效率高 字段长度描述Source Port2字节标识哪个应用程序发送(发送进程)。Destination Port2字节标识哪个应用程序接收(接收进程)。Length2字节UDP首部加上UDP数据的字节数,最小为8。Checksum2字节覆盖UDP首部和UDP数据,是可…...

43、TCP报文(一)

本节内容开始&#xff0c;我们正式学习TCP协议中具体的一些原理。首先&#xff0c;最重要的内容仍然是这个协议的封装结构和首部格式&#xff0c;因为这里面牵扯到一些环环相扣的知识点&#xff0c;例如ACK、SYN等等&#xff0c;如果这些内容不能很好的理解&#xff0c;那么后续…...

【JavaScript】使用js实现滑块验证码功能与浏览器打印

滑块验证码 效果图&#xff1a; 实现思路&#xff1a; 根据滑块的最左侧点跟最右侧点&#xff0c; 是否在规定的距离内【页面最左侧为原点】&#xff0c;来判断是否通过 html代码&#xff1a; <!DOCTYPE html> <html><head><title>滑动图片验证码&…...

【使用群晖远程链接drive挂载电脑硬盘】

文章目录 前言1.群晖Synology Drive套件的安装1.1 安装Synology Drive套件1.2 设置Synology Drive套件1.3 局域网内电脑测试和使用 2.使用cpolar远程访问内网Synology Drive2.1 Cpolar云端设置2.2 Cpolar本地设置2.3 测试和使用 3. 结语 前言 群晖作为专业的数据存储中心&…...

easyx图形库基础4:贪吃蛇

贪吃蛇 一实现贪吃蛇&#xff1a;1.绘制网格&#xff1a;1.绘制蛇&#xff1a;3.控制蛇的默认移动向右&#xff1a;4.控制蛇的移动方向&#xff1a;5.生成食物6.判断蛇吃到食物并且长大。7.判断游戏结束&#xff1a;8.重置函数&#xff1a; 二整体代码&#xff1a; 一实现贪吃蛇…...

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...

stm32G473的flash模式是单bank还是双bank?

今天突然有人stm32G473的flash模式是单bank还是双bank&#xff1f;由于时间太久&#xff0c;我真忘记了。搜搜发现&#xff0c;还真有人和我一样。见下面的链接&#xff1a;https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

Leetcode 3576. Transform Array to All Equal Elements

Leetcode 3576. Transform Array to All Equal Elements 1. 解题思路2. 代码实现 题目链接&#xff1a;3576. Transform Array to All Equal Elements 1. 解题思路 这一题思路上就是分别考察一下是否能将其转化为全1或者全-1数组即可。 至于每一种情况是否可以达到&#xf…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时&#xff0c;需结合业务场景设计数据流转链路&#xff0c;重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点&#xff1a; 一、核心对接场景与目标 商品数据同步 场景&#xff1a;将1688商品信息…...

dedecms 织梦自定义表单留言增加ajax验证码功能

增加ajax功能模块&#xff0c;用户不点击提交按钮&#xff0c;只要输入框失去焦点&#xff0c;就会提前提示验证码是否正确。 一&#xff0c;模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…...

python如何将word的doc另存为docx

将 DOCX 文件另存为 DOCX 格式&#xff08;Python 实现&#xff09; 在 Python 中&#xff0c;你可以使用 python-docx 库来操作 Word 文档。不过需要注意的是&#xff0c;.doc 是旧的 Word 格式&#xff0c;而 .docx 是新的基于 XML 的格式。python-docx 只能处理 .docx 格式…...

C++八股 —— 单例模式

文章目录 1. 基本概念2. 设计要点3. 实现方式4. 详解懒汉模式 1. 基本概念 线程安全&#xff08;Thread Safety&#xff09; 线程安全是指在多线程环境下&#xff0c;某个函数、类或代码片段能够被多个线程同时调用时&#xff0c;仍能保证数据的一致性和逻辑的正确性&#xf…...

代码随想录刷题day30

1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币&#xff0c;另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额&#xff0c;返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...

从 GreenPlum 到镜舟数据库:杭银消费金融湖仓一体转型实践

作者&#xff1a;吴岐诗&#xff0c;杭银消费金融大数据应用开发工程师 本文整理自杭银消费金融大数据应用开发工程师在StarRocks Summit Asia 2024的分享 引言&#xff1a;融合数据湖与数仓的创新之路 在数字金融时代&#xff0c;数据已成为金融机构的核心竞争力。杭银消费金…...

Docker拉取MySQL后数据库连接失败的解决方案

在使用Docker部署MySQL时&#xff0c;拉取并启动容器后&#xff0c;有时可能会遇到数据库连接失败的问题。这种问题可能由多种原因导致&#xff0c;包括配置错误、网络设置问题、权限问题等。本文将分析可能的原因&#xff0c;并提供解决方案。 一、确认MySQL容器的运行状态 …...