当前位置: 首页 > news >正文

LLM - LLaMA-2 获取文本向量并计算 Cos 相似度

 

目录

一.引言

二.获取文本向量

1.hidden_states 与 last_hidden_states

◆ hidden_states

◆ last_hidden_states 

2.LLaMA-2 获取 hidden_states

◆ model config 

◆ get Embedding

三.获取向量 Cos 相似度

1.向量选择

2.Cos 相似度

3.BERT-whitening 特征白化

四.总结


一.引言

前面提到了两种基于统计的机器翻译评估方法: Rouge 与 BLEU,二者通过统计概率计算 N-Gram 的准确率与召回率,在机器翻译这种回答相对固定的场景该方法可以作为一定参考,但在当前大模型更加多样性的场景以及发散的回答的情况下,Rouge 与 BLEU 有时候并不能更好的描述文本之间的相似度,下面我们尝试从 LLM 大模型提取文本的 Embedding 并进行向量相似度计算。

二.获取文本向量

1.hidden_states 与 last_hidden_states

根据 LLM 模型类型的不同,有的 Model 提供 hidden_states 方法,例如 LLaMA-2-13B,有的模型提供 last_hidden_states 方法,例如 GPT-2。查找模型对应方法 API 可以在 Transformer 官网。

 hidden_states

hidden_states 类型为 typing.Optional[typing.Tuple[torch.FloatTensor]],其提供一个 Tuple[Tensor] 分别记录了每层的输出,完整的解释在参数下方: 

模型在每一层输出处的隐藏状态加上可选的初始嵌入输出。这里我们可以通过打印模型 Layer 和索引从而获取 hidden_states 中隐层的输出。

◆ last_hidden_states 

一些传统的模型例如 GPT-2,还有当下一些的新模型例如 ChatGLM2 都有 last_hidden_states 的 API,可以直接获取最后一层的 Embedding 输出,而如果使用 hidden_states 则只需要通过 [-1] 索引即可获得 last_hidden_states,相比来如前者更全面后者更方便。

2.LLaMA-2 获取 hidden_states

model config 

    config_kwargs = {"trust_remote_code": True,"cache_dir": None,"revision": 'main',"use_auth_token": None,"output_hidden_states": True}config = AutoConfig.from_pretrained(ori_model_path, **config_kwargs)llama_model = AutoModelForCausalLM.from_pretrained(ori_model_path,config=config,torch_dtype=torch.float16,low_cpu_mem_usage=True,trust_remote_code=True,revision='main')

 根据 CausalLMOutputWithPast hidden_states 参数的提示,我们只需要在模型 config 中添加:

"output_hidden_states": True

get Embedding

def get_embeddings(result, llm_tokenizer, model, args):fw = open(args.output, 'w', encoding='utf-8')for qa in result:q = qa[0]a = qa[1]# 对输出文本进行 tokenize 和编码tokens = llm_tokenizer.encode_plus(a, add_special_tokens=True, padding='max_length', truncation=True,max_length=128, return_tensors='pt')input_ids = tokens["input_ids"]attention_mask = tokens['attention_mask']# 获取文本 Embeddingwith torch.no_grad():outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)embedding = list(outputs.hidden_states)last_hidden_states = embedding[-1].cpu().numpy()first_hidden_states = embedding[0].cpu().numpy()last_hidden_states = np.squeeze(last_hidden_states)first_hidden_states = np.squeeze(first_hidden_states)fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)log = "%s\t%s\t%s\n" % (q, a, toString(fisrt_larst_avg_status))fw.write(log)fw.close()

predict  预测       ➔  将 model 基于 Question generate 得到的 Answer 存入 result

encode 编码       ➔  对 Answer 进行编码获取对应 Token 与 input_ids、attention_mask

output 模型输出  ➔  直接调用 model 进行输出,有的也可以调用 model.transform 方法进行输出

hidden_states     ➔  outputs.hidden_states 获取各隐层输出

最后获取的向量需要先 cpu 然后再转为 numpy 数组,一般的做法是采用 mean 获得句子的平均表征。

三.获取向量 Cos 相似度

1.向量选择

在 BERT-flow 的论文中,如果不加任何后处理手段,那么基于 BERT 抽取句向量的最好 Pooling 方法是 BERT 的第一层与最后一层的所有 token 向量的平均,即 fisrt-larst-avg,对应 hidden_state 的 0 和 -1 索引,所以后面的相似度计算我们都以 fisrt-larst-avg 为基准来评估 Embedding 相似度。

# 获取文本 Embedding
with torch.no_grad():outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)embedding = list(outputs.hidden_states)last_hidden_states = embedding[-1].cpu().numpy()first_hidden_states = embedding[0].cpu().numpy()last_hidden_states = np.squeeze(last_hidden_states)first_hidden_states = np.squeeze(first_hidden_states)fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)

2.Cos 相似度

# 计算 Cos 相似度
def compute_cosine(a_vec, b_vec):norms1 = np.linalg.norm(a_vec, axis=1)norms2 = np.linalg.norm(b_vec, axis=1)dot_products = np.sum(a_vec * b_vec, axis=1)cos_similarities = dot_products / (norms1 * norms2)return cos_similarities

a_vec 为预测文本转化得到的 Embedding,b_vec 为人工标注正样本文本转化得到的 Embedding,通过计算二者相似度,评估预测文本与人工文本的相似程度。

3.BERT-whitening 特征白化

苏神在 BERT-whitening 一文中提出了一种基于 PCA 降维的无监督 Embedding 评估方式,Bert-whitening 又叫特征白化,其思路与 PCA 降维类似,意在对 SVD 分解后的主成分矩阵取前 λ 个特征向量构造特征值矩阵,提取向量中的关键信息,使输出向量矩阵每个维度均值为零,协方差矩阵为单位阵,λ 个特征值也对应前 λ 个主成分。其算法逻辑如下:

 下面我们调用 Sklearn 的 PCA 库简单实现下:

from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize# 取出句子的平均表示 -> 使用 PCA 降维 -> 白化处理concatenate = np.concatenate((answer_vector, predict_vector))pca = PCA(n_components=2048)pca.fit(concatenate)ans_white_vec = pca.transform(answer_vector)ans_norm_vec = normalize(ans_white_vec)pre_white_vec = pca.transform(predict_vector)pre_norm_vec = normalize(pre_white_vec)pca_cos_similarities = compute_cosine(ans_norm_vec, pre_norm_vec)

answec_vector 和 predict_vector 均通过 first_and_last 方法从 hidden_states 中获取,n_components 即 top_k 的选择,以 LLaMA-2 为例,原始得到的向量维度为 5120,原文中也有使用 n_components = 256 实验。

四.总结

博主采用 1500+ 样本分别使用 cos、pca 和 self_pca [自己实现 SVD 与特征矩阵] 三种方法对向量相似度进行评估,n_components 设为 1024:

可以看到 SVD 处理后得到的 W 和 mu 的 shape,通过下述操作可完成向量的降维:

vecs = (vecs + bias).dot(kernel)

最终得到的结果 Cosine 与 PCA 降维的相似度差距较大,由于自然语言生成的样本没有严格意义的正样本,上面计算采用的参考文本也是人工标注,有一定的不确定性,所以基于不同的度量,我们也可以统计分析,定一个 threshold,认为大于该 threshold 的输入样本为可用。

相关文章:

LLM - LLaMA-2 获取文本向量并计算 Cos 相似度

目录 一.引言 二.获取文本向量 1.hidden_states 与 last_hidden_states ◆ hidden_states ◆ last_hidden_states 2.LLaMA-2 获取 hidden_states ◆ model config ◆ get Embedding 三.获取向量 Cos 相似度 1.向量选择 2.Cos 相似度 3.BERT-whitening 特征白化 …...

【创建型设计模式】C#设计模式之工厂模式,以及通过反射实现动态工厂。

题目如下: 假设你正在为一家汽车制造公司编写软件。公司生产多种类型的汽车,包括轿车、SUV和卡车。每种汽车都有不同的特点和功能。请设计一个工厂模式,用于创建不同类型的汽车对象。该工厂模式应具有以下要求:工厂类名为 CarFac…...

可拖拽编辑的流程图X6

先上图 //index.html&#xff0c;有时候可能加载失败&#xff0c;那就再找一个别的cdn 或者npm下载&#xff0c;如果npm下载&#xff0c; //那么需要全局引入或者局部引入&#xff0c;代码里面写法也会不同&#xff0c;详细的可以看示例<script src"https://cdn.jsdeli…...

神经网络与卷积神经网络

全连接神经网络 概念及应用场景 全连接神经网络是一种深度学习模型&#xff0c;也被称为多层感知机&#xff08;MLP&#xff09;。它由多个神经元组成的层级结构&#xff0c;每个神经元都与前一层的所有神经元相连&#xff0c;它们之间的连接权重是可训练的。每个神经元都计算…...

《Java极简设计模式》第05章:原型模式(Prototype)

作者&#xff1a;冰河 星球&#xff1a;http://m6z.cn/6aeFbs 博客&#xff1a;https://binghe.gitcode.host 文章汇总&#xff1a;https://binghe.gitcode.host/md/all/all.html 源码地址&#xff1a;https://github.com/binghe001/java-simple-design-patterns/tree/master/j…...

OceanBase 4.1解读:读写兼备的DBLink让数据共享“零距离”

梁长青&#xff0c;OceanBase 高级研发工程师&#xff0c;从事 SQL 执行引擎相关工作&#xff0c;目前主要负责 DBLink、单机引擎优化等方面工作。 沈大川&#xff0c;OceanBase 高级研发工程师&#xff0c;从事 SQL 执行引擎相关工作&#xff0c;曾参与 TPC-H 项目攻坚&#x…...

STM32的HAL库的定时器使用

用HAL库老是忘记了定时器中断怎么配置&#xff0c;该调用哪个回调函数。今天记录一下&#xff0c;下次再忘了就来翻一下。 系统的时钟配置&#xff0c;定时器的时钟是84MHz 这里定时器时钟是84M&#xff0c;分频是8400后&#xff0c;时基就是1/10000s&#xff0c;即0.1ms。Per…...

Flink+Paimon多流拼接性能优化实战

目录 &#xff08;零&#xff09;本文简介 &#xff08;一&#xff09;背景 &#xff08;二&#xff09;探索梳理过程 &#xff08;三&#xff09;源码改造 &#xff08;四&#xff09;修改效果 1、JOB状态 2、Level5的dataFile总大小 3、数据延迟 &#xff08;五&…...

cocos 2.4 版本 设置物理引擎步长 解决帧数不一致的设备 物理表现不一致问题 设置帧刷新率

官网地址Cocos Creator 3.8 手册 - 2D 物理系统 官网好像写的不太对 下面是我自己运行好使的 PhysicsManager.openPhysicsSystem()var manager cc.director.getPhysicsManager();// 开启物理步长的设置manager.enabledAccumulator true;// cc.PhysicsManagercc.PhysicsManag…...

Spark及其生态简介

一、Spark简介 Spark 是一个用来实现快速而通用的集群计算的平台&#xff0c;官网上的解释是&#xff1a;Apache Spark™是用于大规模数据处理的统一分析引擎。 Spark 适用于各种各样原先需要多种不同的分布式平台的场景&#xff0c;包括批处理、迭代算法、交互式查询、流处理…...

从Instagram到TikTok:利用社交媒体平台实现业务成功

自 2000年代初成立和随后兴起以来&#xff0c;社交媒体一直被大大小小的品牌用作高度针对性的营销工具&#xff0c;自 Facebook推出近二十年以来&#xff0c;这些网站继续彻底改变企业处理广告的方式。 在这篇博文中&#xff0c;我们将讨论订阅企业应该如何从整体上对待社交媒…...

单元测试

1. 单元测试Junit 1.1 什么是单元测试&#xff1f;&#xff08;掌握&#xff09; 对部分代码进行测试。 1.2 Junit的特点&#xff1f;&#xff08;掌握&#xff09; 是一个第三方的工具。&#xff08;把别人写的代码导入项目中&#xff09;&#xff08;专业叫法&#xff1a;…...

科技云报道:AI+云计算共生共长,能否解锁下一个高增长空间?

科技云报道原创。 在过去近一年的时间里&#xff0c;AI大模型从最初的框架构建&#xff0c;逐步走到落地阶段。 然而&#xff0c;随着AI大模型深入到千行百业中&#xff0c;市场开始意识到通用大模型虽然功能强大&#xff0c;但似乎并不能完全满足不同企业的个性化需求。 大…...

ReactPy:使用 Python 构建动态前端应用程序

在 Web 开发领域,ReactJS 已成为主导者,为开发人员提供了用于创建动态和交互式用户界面的强大工具集。但是,如果您更喜欢 Python 的多功能性和简单性作为后端,并且希望在前端也利用它的功能,该怎么办?ReactPy 是一个 Python 库,它将熟悉的 ReactJS 语法和灵活性带入了 P…...

安全攻防基础以及各种漏洞库

安全攻防基础以及各种漏洞库 信息搜集企业信息搜集1. 企业架构2. ICP备案查询&#xff0c;确定目标子域名3. 员工信息&#xff08;搜集账号信息、钓鱼攻击&#xff09;4. 社交渠道 域名信息搜集IP搜集信息泄露移动端搜集打点进内网命令和控制&#xff08;持续控制&#xff09;穿…...

护眼灯值不值得买?开学给孩子买什么样的护眼台灯

如果不想家里的孩子年纪小小的就戴着眼镜&#xff0c;从小就容易近视&#xff0c;那么护眼灯的选择就非常重要了&#xff0c;但是市场上那么多品类&#xff0c;价格也参差不齐&#xff0c;到底怎么选呢&#xff1f;大家一定要看完本期内容。为大家推荐五款热门的护眼台灯 一、…...

windows安装Scala

Windows安装Scala 下载地址&#xff1a;https://downloads.lightbend.com/scala/2.11.11/scala-2.11.11.zip 解压完成之后 配置环境变量...

API类型和集成规范指南

在我们的常见应用中&#xff0c;往往包含着大量服务于各种数据交换的API类型、以及各种常见的API架构与协议。下面&#xff0c;我将从集成的角度和您讨论&#xff0c;在准备将多个服务相互集成时&#xff0c;使用不同类型、架构和协议的API意味着什么?我们可以使用哪些工具&am…...

[ES]mac安装es、kibana、ik分词器

一、安装es和kibana 1、创建一个网络&#xff0c;网络内的框架(eskibana)互联 docker network create es-net 2、下载es和kibana docker pull elasticsearch:7.12.1 docker pull kibana:7.12.1 3、运行docker命令部署单点eskibana&#xff08;用来操作es&#xff09; doc…...

YOLO目标检测——视觉显著性检测MSRA1000数据集下载分享

MSRA1000数据集是一个常用的视觉显著性检测数据集&#xff0c;它包含了1000张图像和对应的显著性标注。在以下几个应用场景中&#xff0c;MSRA1000数据集可以发挥重要作用&#xff1a;图像编辑和后期处理、图像检索和分类、视觉注意力模型、自动驾驶和智能交通等等 数据集点击下…...

wordpress后台更新后 前端没变化的解决方法

使用siteground主机的wordpress网站&#xff0c;会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后&#xff0c;网站没有变化的情况。 不熟悉siteground主机的新手&#xff0c;遇到这个问题&#xff0c;就很抓狂&#xff0c;明明是哪都没操作错误&#x…...

基于算法竞赛的c++编程(28)结构体的进阶应用

结构体的嵌套与复杂数据组织 在C中&#xff0c;结构体可以嵌套使用&#xff0c;形成更复杂的数据结构。例如&#xff0c;可以通过嵌套结构体描述多层级数据关系&#xff1a; struct Address {string city;string street;int zipCode; };struct Employee {string name;int id;…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架&#xff0c;它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用&#xff0c;和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻

在如今就业市场竞争日益激烈的背景下&#xff0c;越来越多的求职者将目光投向了日本及中日双语岗位。但是&#xff0c;一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧&#xff1f;面对生疏的日语交流环境&#xff0c;即便提前恶补了…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序

一、开发环境准备 ​​工具安装​​&#xff1a; 下载安装DevEco Studio 4.0&#xff08;支持HarmonyOS 5&#xff09;配置HarmonyOS SDK 5.0确保Node.js版本≥14 ​​项目初始化​​&#xff1a; ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

如何理解 IP 数据报中的 TTL?

目录 前言理解 前言 面试灵魂一问&#xff1a;说说对 IP 数据报中 TTL 的理解&#xff1f;我们都知道&#xff0c;IP 数据报由首部和数据两部分组成&#xff0c;首部又分为两部分&#xff1a;固定部分和可变部分&#xff0c;共占 20 字节&#xff0c;而即将讨论的 TTL 就位于首…...

dify打造数据可视化图表

一、概述 在日常工作和学习中&#xff0c;我们经常需要和数据打交道。无论是分析报告、项目展示&#xff0c;还是简单的数据洞察&#xff0c;一个清晰直观的图表&#xff0c;往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server&#xff0c;由蚂蚁集团 AntV 团队…...

Fabric V2.5 通用溯源系统——增加图片上传与下载功能

fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...

短视频矩阵系统文案创作功能开发实践,定制化开发

在短视频行业迅猛发展的当下&#xff0c;企业和个人创作者为了扩大影响力、提升传播效果&#xff0c;纷纷采用短视频矩阵运营策略&#xff0c;同时管理多个平台、多个账号的内容发布。然而&#xff0c;频繁的文案创作需求让运营者疲于应对&#xff0c;如何高效产出高质量文案成…...