当前位置: 首页 > news >正文

小白入门:sentence-transformer 提取embedding模型转onnx

文章目录

  • 序言
  • 原理讲解
    • 哪些部分可转onnx
  • 代码区
    • 0. 安装依赖
    • 1. 路径配置
    • 2. 测试数据
    • 3. 准备工作
      • 3.1迁移保存目标文件
    • 4. model转onnx-gpu
    • 5. 测试一下是否出错以及速度
      • 5.1 测试速度是否OK
      • 5.2测试结果是否OK
    • 6. tar 这些文件

序言

本文适合小白入门,以自己训练的句子embedding模型为例,像大家展示了如何手动将sentence-transformer的模型转为onnx。

很多时候,我也不知道这段代码啥意思,但是作为应用人员,不要在意这段代码到底干了啥,除非必要。

这里不仅展示了如何转onnx,还有你部署时候,所需要的所有的文件,都打包到一个文件夹中了。

原理讲解

哪些部分可转onnx

onnx转换的时候,tokenizer部分是无法被onnx的,只有你backone模型才能进行转onnx,不要问我为啥,因为我也不知道。
我的模型使用代码如下:

from sentence_transformers import SentenceTransformer, models# 1. backone模型,这里用的bert-small
bert_model = models.Transformer("all-MiniLM-L6-v2") # 2. bert_model得到的是所有单词的向量,这些向量通过pool变成一个向量,
# 再通过normalize变成单位向量,即可进行dot,计算得到cosine相似度。
pool = models.Pooling(bert_model.get_word_embedding_dimension())
normalize = models.Normalize()# 模型组装
mymodel = SentenceTransformer(modules=[bert_model, pool, normalize])

代码区

0. 安装依赖

pip install onnx
pip install onnxruntime
pip install onnxruntime-gpu
先CPU然后GPU,不按顺序装可能会出现问题

1. 路径配置

import os# 你自己的模型
raw_model_dir = "../model/model11_all-MiniLM-L6-v2/"
abspath, raw_model_name = os.path.split(os.path.abspath(raw_model_dir))
# onnx后,所需要的文件,都转到了这个文件夹中
onnx_dir = os.path.join(abspath, raw_model_name+"_onnx-gpu/")if not os.path.exists(onnx_dir):os.mkdir(onnx_dir)print("build dir:", onnx_dir)

2. 测试数据

titles = ["Treehobby Metal 2PCS Front CVD Drive Shafts RC Cars Upgrade Parts for WLtoys 144001 1/14 RC Car Truck Buggy Replacement Accessories", 
"Solar System for Kids Space Toys, 8 Planets for Kids Solar System Model with Projector, Stem Educational Toys for 5 Year Old Boys Gift", 
"Bella Haus Design Peeing Gnome - 10.3 Tall Polyresin - Naughty Garden Gnome for Lawn Ornaments, Indoor or Outdoor Decorations - Red and Green Funny Flashing Gnomes", 
"FATExpress CMX500 CMX300 Parts Motorcycle CNC Front Fork Boot Shock Absorber Tube Slider Cover Gaiters for 2017 2018 2019 Rebel CMX 300 500 17-19 (Black)", 
"All Balls Racing 56-133-1 Fork and Dust Seal Kit", 
"Shaluoman Plating 5-Spoke Wheel Rims with Hard Plastic Tires for RC 1:10 Drift Car Color Black", 
"Betonex 5pcs PLASTIK MOLDS Casting Concrete Paving Garden Paths Pavement Stone Patio#S25", 
"OwnMy 5.2 Inch Rainbow Crystal Lotus Candle Tealight Holder Candlestick, Glass Votive Candle Lamps Holder Night Light Candlestick with Gift Box for Altar Windowsill Home Decor Christmas Wedding Party", 
"cnomg Pot Creative Plants DIY Container Pot Mini Fairy Garden Flower Plants and Sweet House for Decoration, Holiday Decoration, Indoor Decoration and Gift (Silver)", 
"DUSICHIN DUS-018 Foam Cannon Lance Pressure Washer Nozzle Tip Spray Gun 3000 PSI Jet Wash", 
"Haoohu Multicolored Bucket Hat for Women Men Girls Frog Fisherman Hat Beach Sun Hat for Outdoor Travel", 
"Renzline Pool CUE Glove Billiard Player - Green/Black - for Left Hand - One Size fits All",
"Hobbywing QUICRUN WP 1080 brushed (2-3S) Electronic Speed Controller Waterproof ESC With Program Box LED BEC XT60-Plug RC Car 1:10 30112750", 
"Mk Morse CSM868NTSC Metal Devil NXT Metal Cutting Circular Saw Blade, Thin Steel, 8-Inch Diameter, 68 TPI, 5/8-Inch Arbor, multi", 
"Barbie Fashionistas Doll 109", 
"KeShi Cordless Rotary Tool, Upgraded 3.7V Li-ion Rotary Accessory Kit with 42 Pieces Swap-able Heads, 3-Speed and USB Charging Multi-Purpose Power Tool for Delicate & Light DIY Small Projects", 
"White Knight 1707SBK-20AM Black Chrome M12x1.50 Bulge Acorn Lug Nut, 20 Pack", 
"Memory Foam Bath Mat Rug,16x24 Inches,Luxury Non Slip Washable Bath Rugs for Bathroom,Soft Absorbent Floor Mats of Green Leaves for Kitchen Bedroom Indoor", 
"DEWIN Airbrush Kit, Multi-purpose Airbrush Sets with Compressor -Dual Action 0.3mm 7CC Capacity Mini Air Compressor Spray Gun for Paint Makeup Tattoo Cake Decoration, Art Tattoo Nail Design", 
"Park Tool BBT-69.2 16-Notch Bottom Bracket Tool - Fits Shimano, SRAM, Chris King, Campagnolo, etc.", 
"ElaDeco 216 Ft Artificial Vines Garland Leaf Ribbon Greenery Foliage Rattan Greek Wild Jungle Decorative Accessory Wedding Party Garden Craft Wall Decoration"]

3. 准备工作

def load_all_model(path):#从modules.json读取模型路径modules_json_path = os.path.join(path, 'modules.json')with open(modules_json_path) as fIn:modules_config = json.load(fIn)from_backbone_path = os.path.join(path, modules_config[0].get('path'))from_pooling_path = os.path.join(path, modules_config[1].get('path'))from_Normalize_path = os.path.join(path, modules_config[2].get('path'))return from_backbone_path, from_pooling_path, from_Normalize_path
from_backbone_path, from_pooling_path, from_Normalize_path = load_all_model(raw_model_dir)
from transformers import AutoConfig, AutoModel, AutoTokenizer
model = AutoModel.from_pretrained(from_backbone_path)
tokenizer = AutoTokenizer.from_pretrained(from_backbone_path)
inputs = tokenizer(titles, padding=True, truncation=True, max_length=256, return_tensors="pt")
import torch
from sentence_transformers import modelspooling = models.Pooling.load(from_pooling_path)
normalize = models.Normalize.load(from_Normalize_path)

3.1迁移保存目标文件

import shutil
_dir, pooling_end_dir = os.path.split(from_pooling_path)
shutil.copytree(from_pooling_path, os.path.join(onnx_dir, pooling_end_dir))_dir, normalize_end_dir = os.path.split(from_Normalize_path)
shutil.copytree(from_Normalize_path, os.path.join(onnx_dir, normalize_end_dir))
'../model/model11_all-MiniLM-L6-v2_onnx-gpu-test/2_Normalize'
def copy_tokenize_filename(filename):full_filename = os.path.join(from_backbone_path, filename)return shutil.copy(full_filename, os.path.join(onnx_dir, filename))print(copy_tokenize_filename("tokenizer.json"))
print(copy_tokenize_filename("tokenizer_config.json"))
print(copy_tokenize_filename("vocab.txt"))
../model/model11_all-MiniLM-L6-v2_onnx-gpu-test/tokenizer.json
../model/model11_all-MiniLM-L6-v2_onnx-gpu-test/tokenizer_config.json
../model/model11_all-MiniLM-L6-v2_onnx-gpu-test/vocab.txt

4. model转onnx-gpu

device = torch.device("cuda:0")
model.eval()
model.to(device)
inputs = inputs.to(device)
export_model_path = os.path.join(onnx_dir, "model.onnx")with torch.no_grad():symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}torch.onnx.export(model,  # model being runargs=tuple(inputs.values()),f=export_model_path,opset_version=12,  # 这个值传说12比11好,当然取决于onnx和onnxruntimedo_constant_folding=True, input_names=['input_ids',  'attention_mask','token_type_ids'],output_names=['start', 'end'], dynamic_axes={'input_ids': symbolic_names,  'attention_mask': symbolic_names,'token_type_ids': symbolic_names,'start': symbolic_names,'end': symbolic_names})print("Model exported at ", export_model_path)
Model exported at  ../model/model11_all-MiniLM-L6-v2_onnx-gpu-test/model.onnx

5. 测试一下是否出错以及速度

5.1 测试速度是否OK

import onnxruntime
from torch import Tensor
export_model_path = os.path.join(onnx_dir, "model.onnx")
device = torch.device("cuda:0")
sess_options = onnxruntime.SessionOptions()
sess_options.optimized_model_filepath = export_model_path
session = onnxruntime.InferenceSession(export_model_path, sess_options, providers=['CUDAExecutionProvider']) # 你的是安装在cuda
2023-07-21 17:54:55.912264962 [W:onnxruntime:, session_state.cc:1136 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2023-07-21 17:54:55.912385419 [W:onnxruntime:, session_state.cc:1138 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
2023-07-21 17:54:56.222846005 [W:onnxruntime:, inference_session.cc:1491 Initialize] Serializing optimized model with Graph Optimization level greater than ORT_ENABLE_EXTENDED and the NchwcTransformer enabled. The generated model may contain hardware specific optimizations, and should only be used in the same environment the model was optimized in.
pooling_gpu = pooling.cuda()
normalize_gpu = normalize.cuda()
import time
begin = time.time()
for i in range(1000):inputs = tokenizer(titles, padding=True, truncation=True, max_length=256, return_tensors="pt")ort_inputs = {k: v.cpu().numpy() for k, v in inputs.items()}ort_outputs = session.run(None, ort_inputs)ort_outputs1 = pooling_gpu.forward(features={'token_embeddings': Tensor(ort_outputs[0]),'attention_mask': Tensor(ort_inputs.get('attention_mask'))})ort_outputs2 = normalize_gpu.forward(ort_outputs1)['sentence_embedding']
end = time.time()    
print("cost time:", end-begin)
cost time: 31.3445
begin = time.time()
for i in range(1000):inputs = tokenizer(titles, padding=True, truncation=True, max_length=256, return_tensors="np")ort_inputs = dict(inputs)ort_outputs = session.run(None, ort_inputs)ort_outputs1 = pooling_gpu.forward(features={'token_embeddings': Tensor(ort_outputs[0]).to(device),'attention_mask': Tensor(ort_inputs.get('attention_mask')).to(device)})ort_outputs2 = normalize_gpu.forward(ort_outputs1)['sentence_embedding']
end = time.time()    
print("cost time:", end-begin)
cost time: 19.234

5.2测试结果是否OK

from sentence_transformers import SentenceTransformerst_model = SentenceTransformer(raw_model_dir)
x = st_model.encode(titles)
import numpy as np
np.abs((x - ort_outputs2.cpu().numpy())).sum()
0.00010381325

误差数值很小,结果OK

6. tar 这些文件

abs_onnx_dir = os.path.abspath(onnx_dir)
# _dir, onnx_name = os.path.split(abs_onnx_dir)
os.system(f"tar -cf {abs_onnx_dir[:-1]}.tar {abs_onnx_dir}")
# f"tar -cf {abs_onnx_dir[:-1]}.tar {abs_onnx_dir}"
tar: Removing leading `/' from member names
0

相关文章:

小白入门:sentence-transformer 提取embedding模型转onnx

文章目录 序言原理讲解哪些部分可转onnx 代码区0. 安装依赖1. 路径配置2. 测试数据3. 准备工作3.1迁移保存目标文件 4. model转onnx-gpu5. 测试一下是否出错以及速度5.1 测试速度是否OK5.2测试结果是否OK 6. tar 这些文件 序言 本文适合小白入门,以自己训练的句子e…...

数据库应用:Redis持久化

目录 一、理论 1.Redis 高可用 2.Redis持久化 3.RDB持久化 4.AOF持久化(支持秒级写入) 5.RDB和AOF的优缺点 6.RDB和AOF对比 7.Redis性能管理 8.Redis的优化 二、实验 1.RDB持久化 2.AOF持久化 3.Redis性能管理 4.Redis的优化 三、总结 一、…...

js版计算比亚迪行驶里程连续12个月计算不超3万公里改进版带echar

<!DOCTYPE html> <html lang"zh-CN" style"height: 100%"> <head> <meta charset"utf-8" /> <title>连续12个月不超3万公里计算LIGUANGHUA</title> <style> .clocks { …...

一文详解Spring Bean循环依赖

一、背景 有好几次线上发布老应用时&#xff0c;遭遇代码启动报错&#xff0c;具体错误如下&#xff1a; Caused by: org.springframework.beans.factory.BeanCurrentlyInCreationException: Error creating bean with name xxxManageFacadeImpl: Bean with name xxxManageFa…...

基于PHP+ vue2 + element +mysql自主研发的医院不良事件上报系统

医院不良事件上报管理系统源码 不良事件上报是为了响应卫生部下发的等级医院评审细则中第三章第9条规定&#xff1a;医院要有主动报告医疗安全&#xff08;不良&#xff09;事件的制度与工作流程。由医疗机构医院或医疗机构报告医疗安全不良事件信息&#xff0c;利用报告进行研…...

微服务远程调用openFeign简单回顾(内附源码示例)

目录 一. OpenFeign简介 二. OpenFeign原理 演示使用 provider模块 消费者模块 配置全局feign日志 示例源代码: 一. OpenFeign简介 OpenFeign是SpringCloud服务调用中间件&#xff0c;可以帮助代理服务API接口。并且可以解析SpringMVC的RequestMapping注解下的接口&#x…...

【云计算小知识】云环境是什么意思?有什么优点?

随着云计算的快速发展&#xff0c;了解云计算相关知识也是运维人员必备的。那你知道云环境是什么意思&#xff1f;有什么优点&#xff1f;云环境安全威胁有哪些&#xff1f;如何保证云环境的运维安全&#xff1f;这里我们就来简单聊聊。 云环境是什么意思&#xff1f; 云环境是…...

【搜索引擎Solr】Apache Solr 神经搜索

Sease[1] 与 Alessandro Benedetti&#xff08;Apache Lucene/Solr PMC 成员和提交者&#xff09;和 Elia Porciani&#xff08;Sease 研发软件工程师&#xff09;共同为开源社区贡献了 Apache Solr 中神经搜索的第一个里程碑。 它依赖于 Apache Lucene 实现 [2] 进行 K-最近邻…...

PostgreSQL 设置时区,时间/日期函数汇总

文章目录 前言查看时区修改时区时间/日期操作符和函数时间/日期操作符日期/时间函数&#xff1a;extract&#xff0c;date_part函数支持的field 数据类型格式化函数用于日期/时间格式化的模式&#xff1a; 扩展 前言 本文基于 PostgreSQL 12.6 版本&#xff0c;不同版本的函数…...

性能测试Ⅱ(压力测试与负载测试详解)

协议 性能理论&#xff1a;并发编程 &#xff0c;系统调度&#xff0c;调度算法 监控 压力测试与负载测试的区别是什么&#xff1f; 负载测试 在被测系统上持续不断的增加压力&#xff0c;直到性能指标(响应时间等)超过预定指标或者某种资源(CPU&内存)使用已达到饱和状…...

【Python入门系列】第十八篇:Python自然语言处理和文本挖掘

文章目录 前言一、Python常用的NLP和文本挖掘库二、Python自然语言处理和文本挖掘1、文本预处理和词频统计2、文本分类3、命名实体识别4、情感分析5、词性标注6、文本相似度计算 总结 前言 Python自然语言处理&#xff08;Natural Language Processing&#xff0c;简称NLP&…...

【GD32F103】自定义程序库08-DMA+ADC

DMA 自定义函数库说明: 将DMA先关的变量方式在一个机构体中封装起来,主要参数有 dma外设,时钟,通道,外设寄存器地址,数据传输宽度,数据方向,外设是能dma传输使能回调函数,扫描模式中断编号dma中断使能传输完成标志数据存储空间使用一个枚举类型指明每个DMA绑定到那个…...

集成了Eureka的应用启动失败,端口号变为8080

问题 报错&#xff1a;集成了Eureka的应用启动失败&#xff0c;端口号变为8080。 原来运行的项目&#xff0c;突然报错&#xff0c;端口号变为8080&#xff1a; Tomcat initialized with port(s): 8080 (http)并且&#xff0c;还有如下的错误提示&#xff1a; RedirectingE…...

CMU 15-445 -- Timestamp Ordering Concurrency Control - 15

CMU 15-445 -- Timestamp Ordering Concurrency Control - 15 引言Basic T/OBasic T/O ReadsBasic T/O WritesBasic T/O - Example #1Basic T/O - Example #2 Basic T/O SummaryRecoverable Schedules Optimistic Concurrency Control (OCC)OCC - ExampleSERIAL VALIDATIONOCC …...

MURF2080CT/MURF2080CTR-ASEMI快恢复对管

编辑&#xff1a;ll MURF2080CT/MURF2080CTR-ASEMI快恢复对管 型号&#xff1a;MURF2080CT/MURF2080CTR 品牌&#xff1a;ASEMI 芯片个数&#xff1a;2 芯片尺寸&#xff1a;102MIL*2 封装&#xff1a;TO-220F 恢复时间&#xff1a;50ns 工作温度&#xff1a;-50C~150C…...

去除 idea warn Raw use of parameterized class ‘Map‘

去除 idea warn Raw use of parameterized class ‘Map’ 文档&#xff1a;Raw use of parameterized class ‘Map’… 链接&#xff1a;http://note.youdao.com/noteshare?id99bf4003db8cc5ae9813ee11e58c4d13&sub5856371AEFA740AF8FA4D8935B4F6912 添加链接描述 public…...

使用BERT分类的可解释性探索

最近尝试了使用BERT将告警信息当成一个文本去做分类&#xff0c;从分类的准召率上来看&#xff0c;还是取得了不错的效果&#xff08;非结构化数据强标签训练&#xff0c;BERT确实是一把大杀器&#xff09;。但准召率并不是唯一追求的目标&#xff0c;在安全场景下&#xff0c;…...

web APIs-练习二

轮播图点击切换&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><meta name"viewport" content"…...

rpc通信原理浅析

rpc通信原理浅析 rpc(remote procedure call)&#xff0c;即远程过程调用&#xff0c;广泛用于分布式或是异构环境下的通信&#xff0c;数据格式一般采取protobuf。 protobuf&#xff08;protocol buffer&#xff09;是google 的一种数据交换的格式&#xff0c;它独立于平台语…...

【机器学习】分类算法 - KNN算法(K-近邻算法)KNeighborsClassifier

「作者主页」&#xff1a;士别三日wyx 「作者简介」&#xff1a;CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」&#xff1a;零基础快速入门人工智能《机器学习入门到精通》 K-近邻算法 1、什么是K-近邻算法&#xff1f;2、K-近邻算法API3、…...

OpenLayers 可视化之热力图

注&#xff1a;当前使用的是 ol 5.3.0 版本&#xff0c;天地图使用的key请到天地图官网申请&#xff0c;并替换为自己的key 热力图&#xff08;Heatmap&#xff09;又叫热点图&#xff0c;是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地

借阿里云中企出海大会的东风&#xff0c;以**「云启出海&#xff0c;智联未来&#xff5c;打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办&#xff0c;现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

Debian系统简介

目录 Debian系统介绍 Debian版本介绍 Debian软件源介绍 软件包管理工具dpkg dpkg核心指令详解 安装软件包 卸载软件包 查询软件包状态 验证软件包完整性 手动处理依赖关系 dpkg vs apt Debian系统介绍 Debian 和 Ubuntu 都是基于 Debian内核 的 Linux 发行版&#xff…...

centos 7 部署awstats 网站访问检测

一、基础环境准备&#xff08;两种安装方式都要做&#xff09; bash # 安装必要依赖 yum install -y httpd perl mod_perl perl-Time-HiRes perl-DateTime systemctl enable httpd # 设置 Apache 开机自启 systemctl start httpd # 启动 Apache二、安装 AWStats&#xff0…...

Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务

通过akshare库&#xff0c;获取股票数据&#xff0c;并生成TabPFN这个模型 可以识别、处理的格式&#xff0c;写一个完整的预处理示例&#xff0c;并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务&#xff0c;进行预测并输…...

剑指offer20_链表中环的入口节点

链表中环的入口节点 给定一个链表&#xff0c;若其中包含环&#xff0c;则输出环的入口节点。 若其中不包含环&#xff0c;则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...

关于 WASM:1. WASM 基础原理

一、WASM 简介 1.1 WebAssembly 是什么&#xff1f; WebAssembly&#xff08;WASM&#xff09; 是一种能在现代浏览器中高效运行的二进制指令格式&#xff0c;它不是传统的编程语言&#xff0c;而是一种 低级字节码格式&#xff0c;可由高级语言&#xff08;如 C、C、Rust&am…...

汇编常见指令

汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX&#xff08;不访问内存&#xff09;XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...

CSS设置元素的宽度根据其内容自动调整

width: fit-content 是 CSS 中的一个属性值&#xff0c;用于设置元素的宽度根据其内容自动调整&#xff0c;确保宽度刚好容纳内容而不会超出。 效果对比 默认情况&#xff08;width: auto&#xff09;&#xff1a; 块级元素&#xff08;如 <div>&#xff09;会占满父容器…...

C++.OpenGL (20/64)混合(Blending)

混合(Blending) 透明效果核心原理 #mermaid-svg-SWG0UzVfJms7Sm3e {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-icon{fill:#552222;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-text{fill…...