LLM - LLaMA-2 获取文本向量并计算 Cos 相似度
目录
一.引言
二.获取文本向量
◆ model config
◆ get Embedding
三.获取向量 Cos 相似度
1.向量选择
2.Cos 相似度
3.BERT-whitening 特征白化
四.总结
一.引言
前面提到了两种基于统计的机器翻译评估方法: Rouge 与 BLEU,二者通过统计概率计算 N-Gram 的准确率与召回率,在机器翻译这种回答相对固定的场景该方法可以作为一定参考,但在当前大模型更加多样性的场景以及发散的回答的情况下,Rouge 与 BLEU 有时候并不能更好的描述文本之间的相似度,下面我们尝试从 LLM 大模型提取文本的 Embedding 并进行向量相似度计算。
二.获取文本向量
1.hidden_states 与 last_hidden_states
根据 LLM 模型类型的不同,有的 Model 提供 hidden_states 方法,例如 LLaMA-2-13B,有的模型提供 last_hidden_states 方法,例如 GPT-2。查找模型对应方法 API 可以在 Transformer 官网。
◆ hidden_states
hidden_states 类型为 typing.Optional[typing.Tuple[torch.FloatTensor]],其提供一个 Tuple[Tensor] 分别记录了每层的输出,完整的解释在参数下方:
模型在每一层输出处的隐藏状态加上可选的初始嵌入输出。这里我们可以通过打印模型 Layer 和索引从而获取 hidden_states 中隐层的输出。
◆ last_hidden_states
一些传统的模型例如 GPT-2,还有当下一些的新模型例如 ChatGLM2 都有 last_hidden_states 的 API,可以直接获取最后一层的 Embedding 输出,而如果使用 hidden_states 则只需要通过 [-1] 索引即可获得 last_hidden_states,相比来如前者更全面后者更方便。
2.LLaMA-2 获取 hidden_states
◆ model config
config_kwargs = {"trust_remote_code": True,"cache_dir": None,"revision": 'main',"use_auth_token": None,"output_hidden_states": True}config = AutoConfig.from_pretrained(ori_model_path, **config_kwargs)llama_model = AutoModelForCausalLM.from_pretrained(ori_model_path,config=config,torch_dtype=torch.float16,low_cpu_mem_usage=True,trust_remote_code=True,revision='main')
根据 CausalLMOutputWithPast hidden_states 参数的提示,我们只需要在模型 config 中添加:
"output_hidden_states": True
◆ get Embedding
def get_embeddings(result, llm_tokenizer, model, args):fw = open(args.output, 'w', encoding='utf-8')for qa in result:q = qa[0]a = qa[1]# 对输出文本进行 tokenize 和编码tokens = llm_tokenizer.encode_plus(a, add_special_tokens=True, padding='max_length', truncation=True,max_length=128, return_tensors='pt')input_ids = tokens["input_ids"]attention_mask = tokens['attention_mask']# 获取文本 Embeddingwith torch.no_grad():outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)embedding = list(outputs.hidden_states)last_hidden_states = embedding[-1].cpu().numpy()first_hidden_states = embedding[0].cpu().numpy()last_hidden_states = np.squeeze(last_hidden_states)first_hidden_states = np.squeeze(first_hidden_states)fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)log = "%s\t%s\t%s\n" % (q, a, toString(fisrt_larst_avg_status))fw.write(log)fw.close()
predict 预测 ➔ 将 model 基于 Question generate 得到的 Answer 存入 result
encode 编码 ➔ 对 Answer 进行编码获取对应 Token 与 input_ids、attention_mask
output 模型输出 ➔ 直接调用 model 进行输出,有的也可以调用 model.transform 方法进行输出
hidden_states ➔ outputs.hidden_states 获取各隐层输出
最后获取的向量需要先 cpu 然后再转为 numpy 数组,一般的做法是采用 mean 获得句子的平均表征。
三.获取向量 Cos 相似度
1.向量选择
在 BERT-flow 的论文中,如果不加任何后处理手段,那么基于 BERT 抽取句向量的最好 Pooling 方法是 BERT 的第一层与最后一层的所有 token 向量的平均,即 fisrt-larst-avg,对应 hidden_state 的 0 和 -1 索引,所以后面的相似度计算我们都以 fisrt-larst-avg 为基准来评估 Embedding 相似度。
# 获取文本 Embedding
with torch.no_grad():outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)embedding = list(outputs.hidden_states)last_hidden_states = embedding[-1].cpu().numpy()first_hidden_states = embedding[0].cpu().numpy()last_hidden_states = np.squeeze(last_hidden_states)first_hidden_states = np.squeeze(first_hidden_states)fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)
2.Cos 相似度
# 计算 Cos 相似度
def compute_cosine(a_vec, b_vec):norms1 = np.linalg.norm(a_vec, axis=1)norms2 = np.linalg.norm(b_vec, axis=1)dot_products = np.sum(a_vec * b_vec, axis=1)cos_similarities = dot_products / (norms1 * norms2)return cos_similarities
a_vec 为预测文本转化得到的 Embedding,b_vec 为人工标注正样本文本转化得到的 Embedding,通过计算二者相似度,评估预测文本与人工文本的相似程度。
3.BERT-whitening 特征白化
苏神在 BERT-whitening 一文中提出了一种基于 PCA 降维的无监督 Embedding 评估方式,Bert-whitening 又叫特征白化,其思路与 PCA 降维类似,意在对 SVD 分解后的主成分矩阵取前 λ 个特征向量构造特征值矩阵,提取向量中的关键信息,使输出向量矩阵每个维度均值为零,协方差矩阵为单位阵,λ 个特征值也对应前 λ 个主成分。其算法逻辑如下:
下面我们调用 Sklearn 的 PCA 库简单实现下:
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize# 取出句子的平均表示 -> 使用 PCA 降维 -> 白化处理concatenate = np.concatenate((answer_vector, predict_vector))pca = PCA(n_components=2048)pca.fit(concatenate)ans_white_vec = pca.transform(answer_vector)ans_norm_vec = normalize(ans_white_vec)pre_white_vec = pca.transform(predict_vector)pre_norm_vec = normalize(pre_white_vec)pca_cos_similarities = compute_cosine(ans_norm_vec, pre_norm_vec)
answec_vector 和 predict_vector 均通过 first_and_last 方法从 hidden_states 中获取,n_components 即 top_k 的选择,以 LLaMA-2 为例,原始得到的向量维度为 5120,原文中也有使用 n_components = 256 实验。
四.总结
博主采用 1500+ 样本分别使用 cos、pca 和 self_pca [自己实现 SVD 与特征矩阵] 三种方法对向量相似度进行评估,n_components 设为 1024:
可以看到 SVD 处理后得到的 W 和 mu 的 shape,通过下述操作可完成向量的降维:
vecs = (vecs + bias).dot(kernel)
最终得到的结果 Cosine 与 PCA 降维的相似度差距较大,由于自然语言生成的样本没有严格意义的正样本,上面计算采用的参考文本也是人工标注,有一定的不确定性,所以基于不同的度量,我们也可以统计分析,定一个 threshold,认为大于该 threshold 的输入样本为可用。
相关文章:

LLM - LLaMA-2 获取文本向量并计算 Cos 相似度
目录 一.引言 二.获取文本向量 1.hidden_states 与 last_hidden_states ◆ hidden_states ◆ last_hidden_states 2.LLaMA-2 获取 hidden_states ◆ model config ◆ get Embedding 三.获取向量 Cos 相似度 1.向量选择 2.Cos 相似度 3.BERT-whitening 特征白化 …...
【创建型设计模式】C#设计模式之工厂模式,以及通过反射实现动态工厂。
题目如下: 假设你正在为一家汽车制造公司编写软件。公司生产多种类型的汽车,包括轿车、SUV和卡车。每种汽车都有不同的特点和功能。请设计一个工厂模式,用于创建不同类型的汽车对象。该工厂模式应具有以下要求:工厂类名为 CarFac…...

可拖拽编辑的流程图X6
先上图 //index.html,有时候可能加载失败,那就再找一个别的cdn 或者npm下载,如果npm下载, //那么需要全局引入或者局部引入,代码里面写法也会不同,详细的可以看示例<script src"https://cdn.jsdeli…...

神经网络与卷积神经网络
全连接神经网络 概念及应用场景 全连接神经网络是一种深度学习模型,也被称为多层感知机(MLP)。它由多个神经元组成的层级结构,每个神经元都与前一层的所有神经元相连,它们之间的连接权重是可训练的。每个神经元都计算…...

《Java极简设计模式》第05章:原型模式(Prototype)
作者:冰河 星球:http://m6z.cn/6aeFbs 博客:https://binghe.gitcode.host 文章汇总:https://binghe.gitcode.host/md/all/all.html 源码地址:https://github.com/binghe001/java-simple-design-patterns/tree/master/j…...

OceanBase 4.1解读:读写兼备的DBLink让数据共享“零距离”
梁长青,OceanBase 高级研发工程师,从事 SQL 执行引擎相关工作,目前主要负责 DBLink、单机引擎优化等方面工作。 沈大川,OceanBase 高级研发工程师,从事 SQL 执行引擎相关工作,曾参与 TPC-H 项目攻坚&#x…...

STM32的HAL库的定时器使用
用HAL库老是忘记了定时器中断怎么配置,该调用哪个回调函数。今天记录一下,下次再忘了就来翻一下。 系统的时钟配置,定时器的时钟是84MHz 这里定时器时钟是84M,分频是8400后,时基就是1/10000s,即0.1ms。Per…...

Flink+Paimon多流拼接性能优化实战
目录 (零)本文简介 (一)背景 (二)探索梳理过程 (三)源码改造 (四)修改效果 1、JOB状态 2、Level5的dataFile总大小 3、数据延迟 (五&…...

cocos 2.4 版本 设置物理引擎步长 解决帧数不一致的设备 物理表现不一致问题 设置帧刷新率
官网地址Cocos Creator 3.8 手册 - 2D 物理系统 官网好像写的不太对 下面是我自己运行好使的 PhysicsManager.openPhysicsSystem()var manager cc.director.getPhysicsManager();// 开启物理步长的设置manager.enabledAccumulator true;// cc.PhysicsManagercc.PhysicsManag…...

Spark及其生态简介
一、Spark简介 Spark 是一个用来实现快速而通用的集群计算的平台,官网上的解释是:Apache Spark™是用于大规模数据处理的统一分析引擎。 Spark 适用于各种各样原先需要多种不同的分布式平台的场景,包括批处理、迭代算法、交互式查询、流处理…...

从Instagram到TikTok:利用社交媒体平台实现业务成功
自 2000年代初成立和随后兴起以来,社交媒体一直被大大小小的品牌用作高度针对性的营销工具,自 Facebook推出近二十年以来,这些网站继续彻底改变企业处理广告的方式。 在这篇博文中,我们将讨论订阅企业应该如何从整体上对待社交媒…...
单元测试
1. 单元测试Junit 1.1 什么是单元测试?(掌握) 对部分代码进行测试。 1.2 Junit的特点?(掌握) 是一个第三方的工具。(把别人写的代码导入项目中)(专业叫法:…...

科技云报道:AI+云计算共生共长,能否解锁下一个高增长空间?
科技云报道原创。 在过去近一年的时间里,AI大模型从最初的框架构建,逐步走到落地阶段。 然而,随着AI大模型深入到千行百业中,市场开始意识到通用大模型虽然功能强大,但似乎并不能完全满足不同企业的个性化需求。 大…...

ReactPy:使用 Python 构建动态前端应用程序
在 Web 开发领域,ReactJS 已成为主导者,为开发人员提供了用于创建动态和交互式用户界面的强大工具集。但是,如果您更喜欢 Python 的多功能性和简单性作为后端,并且希望在前端也利用它的功能,该怎么办?ReactPy 是一个 Python 库,它将熟悉的 ReactJS 语法和灵活性带入了 P…...
安全攻防基础以及各种漏洞库
安全攻防基础以及各种漏洞库 信息搜集企业信息搜集1. 企业架构2. ICP备案查询,确定目标子域名3. 员工信息(搜集账号信息、钓鱼攻击)4. 社交渠道 域名信息搜集IP搜集信息泄露移动端搜集打点进内网命令和控制(持续控制)穿…...

护眼灯值不值得买?开学给孩子买什么样的护眼台灯
如果不想家里的孩子年纪小小的就戴着眼镜,从小就容易近视,那么护眼灯的选择就非常重要了,但是市场上那么多品类,价格也参差不齐,到底怎么选呢?大家一定要看完本期内容。为大家推荐五款热门的护眼台灯 一、…...

windows安装Scala
Windows安装Scala 下载地址:https://downloads.lightbend.com/scala/2.11.11/scala-2.11.11.zip 解压完成之后 配置环境变量...

API类型和集成规范指南
在我们的常见应用中,往往包含着大量服务于各种数据交换的API类型、以及各种常见的API架构与协议。下面,我将从集成的角度和您讨论,在准备将多个服务相互集成时,使用不同类型、架构和协议的API意味着什么?我们可以使用哪些工具&am…...

[ES]mac安装es、kibana、ik分词器
一、安装es和kibana 1、创建一个网络,网络内的框架(eskibana)互联 docker network create es-net 2、下载es和kibana docker pull elasticsearch:7.12.1 docker pull kibana:7.12.1 3、运行docker命令部署单点eskibana(用来操作es) doc…...

YOLO目标检测——视觉显著性检测MSRA1000数据集下载分享
MSRA1000数据集是一个常用的视觉显著性检测数据集,它包含了1000张图像和对应的显著性标注。在以下几个应用场景中,MSRA1000数据集可以发挥重要作用:图像编辑和后期处理、图像检索和分类、视觉注意力模型、自动驾驶和智能交通等等 数据集点击下…...
React hook之useRef
React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

微信小程序云开发平台MySQL的连接方式
注:微信小程序云开发平台指的是腾讯云开发 先给结论:微信小程序云开发平台的MySQL,无法通过获取数据库连接信息的方式进行连接,连接只能通过云开发的SDK连接,具体要参考官方文档: 为什么? 因为…...

深入解析C++中的extern关键字:跨文件共享变量与函数的终极指南
🚀 C extern 关键字深度解析:跨文件编程的终极指南 📅 更新时间:2025年6月5日 🏷️ 标签:C | extern关键字 | 多文件编程 | 链接与声明 | 现代C 文章目录 前言🔥一、extern 是什么?&…...
[Java恶补day16] 238.除自身以外数组的乘积
给你一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法,且在 O(n) 时间复杂度…...

什么是Ansible Jinja2
理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...

让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比
在机器学习的回归分析中,损失函数的选择对模型性能具有决定性影响。均方误差(MSE)作为经典的损失函数,在处理干净数据时表现优异,但在面对包含异常值的噪声数据时,其对大误差的二次惩罚机制往往导致模型参数…...

Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习) 一、Aspose.PDF 简介二、说明(⚠️仅供学习与研究使用)三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...

Python基于历史模拟方法实现投资组合风险管理的VaR与ES模型项目实战
说明:这是一个机器学习实战项目(附带数据代码文档),如需数据代码文档可以直接到文章最后关注获取。 1.项目背景 在金融市场日益复杂和波动加剧的背景下,风险管理成为金融机构和个人投资者关注的核心议题之一。VaR&…...

Linux nano命令的基本使用
参考资料 GNU nanoを使いこなすnano基础 目录 一. 简介二. 文件打开2.1 普通方式打开文件2.2 只读方式打开文件 三. 文件查看3.1 打开文件时,显示行号3.2 翻页查看 四. 文件编辑4.1 Ctrl K 复制 和 Ctrl U 粘贴4.2 Alt/Esc U 撤回 五. 文件保存与退出5.1 Ctrl …...

C++ 设计模式 《小明的奶茶加料风波》
👨🎓 模式名称:装饰器模式(Decorator Pattern) 👦 小明最近上线了校园奶茶配送功能,业务火爆,大家都在加料: 有的同学要加波霸 🟤,有的要加椰果…...