深度学习 - 模型的保存与部署方式汇总
深度学习模型保存和加载格式科普
在深度学习中,模型的保存和加载是非常重要的环节。不同的格式有不同的特点和适用场景。本文将为新手朋友们介绍几种常见的模型格式,包括它们的简介、保存方式、加载方式、优缺点以及应用场景。
1. PyTorch (.pth, .pt)
简介:PyTorch 的默认保存格式,灵活支持保存整个模型、模型的权重和优化器状态。
保存方式:
import torch
torch.save(model.state_dict(), 'model.pth')
加载方式:
model.load_state_dict(torch.load('model.pth'))
model.eval()
部署代码:
from flask import Flask, request, jsonifyapp = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():data = request.jsontext = data['text']inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)with torch.no_grad():outputs = model(**inputs)logits = outputs.logitspredictions = torch.argmax(logits, dim=-1)return jsonify({'prediction': predictions.item()})if __name__ == '__main__':from transformers import BertTokenizer, BertForSequenceClassificationtokenizer = BertTokenizer.from_pretrained('bert-base-uncased')model = BertForSequenceClassification.from_pretrained('bert-base-uncased')model.load_state_dict(torch.load('model.pth'))model.eval()app.run(host='0.0.0.0', port=5000)
优点:
- 高度灵活:支持复杂的模型和训练过程,因为 PyTorch 允许使用 Python 语言编写任意代码来定义模型。
- 与 PyTorch 框架紧密集成:保存和加载模型非常方便,因为
.pth和.pt是 PyTorch 的原生格式。
缺点:
- 只能在 PyTorch 环境中加载和使用:这限制了跨平台和跨框架的兼容性,因为其他框架无法直接读取这种格式。
应用场景:
- 研究和开发环境。
- 需要频繁保存和加载模型的场景。
2. TensorFlow/Keras (.h5, SavedModel)
简介:TensorFlow 和 Keras 的保存格式,支持保存模型的权重、架构和优化器状态。
保存方式:
model.save('model.h5')
加载方式:
from tensorflow.keras.models import load_model
model = load_model('model.h5')
部署代码:
from flask import Flask, request, jsonify
import tensorflow as tfapp = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():data = request.jsontext = data['text']inputs = tokenizer(text, return_tensors='tf', truncation=True, padding=True, max_length=512)outputs = model(inputs)logits = outputs.logitspredictions = tf.argmax(logits, axis=-1)return jsonify({'prediction': int(predictions.numpy()[0])})if __name__ == '__main__':from transformers import BertTokenizer, TFBertForSequenceClassificationtokenizer = BertTokenizer.from_pretrained('bert-base-uncased')model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')model.load_weights('model.h5')app.run(host='0.0.0.0', port=5000)
优点:
- 适用于 TensorFlow 和 Keras 环境:模型文件可以直接在这些框架中加载和使用,方便开发和部署。
- 支持多种部署方式:包括 TensorFlow Serving,这使得在生产环境中的部署更加灵活和高效。
缺点:
- 模型文件较大:可能影响加载速度,因为
.h5文件包含了完整的模型架构和权重。
应用场景:
- 生产环境中的模型部署。
- 需要与 TensorFlow 生态系统集成的应用。
3. ONNX (Open Neural Network Exchange)
简介:开放格式,旨在实现不同深度学习框架之间的互操作性。
保存方式:
import torch.onnx
torch.onnx.export(model, dummy_input, 'model.onnx')
加载方式:
import onnx
import onnxruntime as ort
onnx_model = onnx.load('model.onnx')
ort_session = ort.InferenceSession('model.onnx')def to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()outputs = ort_session.run(None, {ort_session.get_inputs()[0].name: to_numpy(dummy_input)})
部署代码:
from flask import Flask, request, jsonify
import onnxruntime as ort
import numpy as npapp = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():data = request.jsontext = data['text']inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(inputs['input_ids'])}ort_outs = ort_session.run(None, ort_inputs)predictions = np.argmax(ort_outs[0], axis=1)return jsonify({'prediction': int(predictions[0])})if __name__ == '__main__':import onnxfrom transformers import BertTokenizertokenizer = BertTokenizer.from_pretrained('bert-base-uncased')ort_session = ort.InferenceSession('model.onnx')app.run(host='0.0.0.0', port=5000)
优点:
- 跨平台兼容:支持多种深度学习框架,如 PyTorch、TensorFlow、Caffe2 等,使得模型可以在不同平台之间迁移。
- 统一格式:简化了在不同框架之间转换模型的复杂性。
缺点:
- 需要额外的工具链:需要使用 ONNX 工具来转换和部署模型,增加了一定的复杂性。
应用场景:
- 跨平台模型部署。
- 在不同框架之间转换模型。
4. TensorFlow Lite
简介:专门为移动和嵌入式设备设计的轻量级模型格式。
保存方式:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:f.write(tflite_model)
加载方式:
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
部署代码:
TensorFlow Lite 模型主要用于移动设备和嵌入式设备,下面是一个简化的示例,展示如何在 Python 环境中进行推理:
import tensorflow as tf
import numpy as np# 加载模型
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()# 获取模型输入和输出的详细信息
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()# 准备输入数据
input_data = np.array([...], dtype=np.float32) # 根据模型输入需求准备数据# 设置模型输入
interpreter.set_tensor(input_details[0]['index'], input_data)# 推理
interpreter.invoke()# 获取输出数据
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
优点:
- 轻量级:适合资源受限的设备,因为 TensorFlow Lite 模型被优化为更小、更高效。
- 快速加载和推理:由于优化了模型结构,推理速度更快。
缺点:
- 支持的操作有限:不支持所有 TensorFlow 操作,可能需要调整模型架构以适应 TensorFlow Lite 的限制。
应用场景:
- 移动设备上的应用。
- 物联网和嵌入式设备。
5. CoreML
简介:苹果公司为 iOS 和 macOS 设备提供的模型格式。
保存方式:
import coremltools as ct
coreml_model = ct.convert(model)
coreml_model.save('model.mlmodel')
加载方式:
在 iOS/macOS 应用中使用 CoreML 框架加载。
部署代码:
CoreML 模型主要用于 iOS 和 macOS 应用开发,下面是一个简化的示例,展示如何在 Swift 中使用 CoreML 模型进行推理:
import CoreML
import Foundation// 加载模型
let model = try! MyCoreMLModel(configuration: MLModelConfiguration())// 准备输入数据
let input = MyCoreMLModelInput(text: "your input text")// 获取模型预测结果
let prediction = try! model.prediction(input: input)
print(prediction.label)
优点:
- 与苹果生态系统深度集成:在 iOS 和 macOS 设备上运行非常高效,且与其他苹果生态系统的服务无缝集成。
- 易于部署:CoreML 模型可以直接在 Xcode 中使用,非常适合苹果开发者。
缺点:
- 仅限于苹果设备:无法在其他平台上运行,这限制了跨平台应用的开发。
应用场景:
- iOS 应用开发。
- macOS 应用开发。
6. PaddlePaddle (.pdparams)
简介:百度开发的深度学习框架 PaddlePaddle 的保存格式。
保存方式:
import paddle
paddle.save(model.state_dict(), 'model.pdparams')
加载方式:
model.set_state_dict(paddle.load('model.pdparams'))
部署代码:
from flask import Flask, request, jsonifyapp = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():data = request.jsontext = data['text']inputs = tokenizer(text, return_tensors='pd', truncation=True, padding=True, max_length=512)with paddle.no_grad():outputs = model(**inputs)logits = outputs.logitspredictions = paddle.argmax(logits, axis=-1)return jsonify({'prediction': predictions.item()})if __name__ == '__main__':from paddlenlp.transformers import BertTokenizer, BertForSequenceClassificationtokenizer = BertTokenizer.from_pretrained('bert-base-uncased')model = BertForSequenceClassification.from_pretrained('bert-base-uncased')model.set_state_dict(paddle.load('model.pdparams'))model.eval()app.run(host='0.0.0.0', port=5000)
优点:
- 与 PaddlePaddle 框架集成:适用于使用 PaddlePaddle 进行开发和部署的项目,提供了百度生态系统的支持。
- 优化的中国市场支持:PaddlePaddle 在中国市场有良好的支持和资源。
缺点:
- 只能在 PaddlePaddle 环境中加载和使用:这限制了在其他深度学习框架中的兼容性。
应用场景:
- 需要使用百度深度学习工具的项目。
- 在中国市场的应用。
7. HDF5 (.h5)
简介:一种用于存储大型数据集的文件格式,Keras 默认支持这种格式。
保存方式:
model.save('model.h5')
加载方式:
from tensorflow.keras.models import load_model
model = load_model('model.h5')
部署代码:
与 TensorFlow/Keras 的 .h5 部署代码相同,参考 TensorFlow/Keras 部分的部署代码。
优点:
- 方便存储和管理大型数据集:HDF5 格式擅长处理大规模数据,并支持压缩和并行 I/O 操作。
- 与 Keras 深度集成:Keras 默认支持这种格式,保存和加载模型非常方便。
缺点:
- 模型文件较大:包含了完整的模型架构和权重,导致文件较大,加载速度可能较慢。
应用场景:
- Keras 环境下的模型存储和加载。
- 需要保存大型模型的场景。
8. SafeTensors
简介:一种新型的格式,旨在提高模型保存和加载的安全性和速度。
保存方式:
from safetensors.torch import save_file
save_file(model.state_dict(), 'model.safetensors')
加载方式:
from safetensors.torch import load_file
state_dict = load_file('model.safetensors')
model.load_state_dict(state_dict)
部署代码:
与 PyTorch 部署代码类似,可使用 Flask 或其他框架创建 API 服务。
优点:
- 安全性高:消除潜在执行代码风险,因为
safetensors格式不允许在加载模型时执行任意代码。 - 加载速度快:优化了模型的加载速度,特别适用于大型模型。
缺点:
- 需要额外的库支持:必须安装
safetensors库才能使用这种格式。
应用场景:
- 需要高安全性和快速加载的环境。
- 大型模型的存储和部署。
汇总
下面是各深度学习模型保存和加载格式的汇总表,包括格式、简介、优点、缺点和应用场景:
| 格式 | 简介 | 优点 | 缺点 | 应用场景 |
|---|---|---|---|---|
| PyTorch (.pth, .pt) | PyTorch 的默认保存格式,支持保存整个模型、权重和优化器状态 | 高度灵活,支持复杂的模型和训练过程。与 PyTorch 框架紧密集成。 | 只能在 PyTorch 环境中加载和使用,限制了跨平台和跨框架的兼容性。 | 研究和开发环境,频繁保存和加载模型的场景 |
| TensorFlow/Keras (.h5, SavedModel) | TensorFlow 和 Keras 的保存格式,支持保存模型的权重、架构和优化器状态 | 适用于 TensorFlow 和 Keras 环境,支持多种部署方式(如 TensorFlow Serving)。 | 模型文件较大,可能影响加载速度。 | 生产环境中的模型部署,与 TensorFlow 生态系统集成的应用 |
| ONNX | 开放格式,实现不同深度学习框架之间的互操作性 | 跨平台兼容,支持多种深度学习框架。统一格式,简化了在不同框架之间转换模型的复杂性。 | 需要额外的工具链来转换和部署模型。 | 跨平台模型部署,在不同框架之间转换模型 |
| TensorFlow Lite | 专为移动和嵌入式设备设计的轻量级模型格式 | 轻量级,适合资源受限的设备。快速加载和推理。 | 支持的操作有限,可能需要调整模型架构以适应 TensorFlow Lite 的限制。 | 移动设备上的应用,物联网和嵌入式设备 |
| CoreML | 苹果公司为 iOS 和 macOS 设备提供的模型格式 | 与苹果生态系统深度集成,在 iOS 和 macOS 设备上运行非常高效。易于部署,适合苹果开发者。 | 仅限于苹果设备,无法在其他平台上运行。 | iOS 应用开发,macOS 应用开发 |
| PaddlePaddle (.pdparams) | 百度开发的深度学习框架 PaddlePaddle 的保存格式 | 与 PaddlePaddle 框架集成,适用于百度生态系统。优化的中国市场支持。 | 只能在 PaddlePaddle 环境中加载和使用,限制了在其他深度学习框架中的兼容性。 | 使用百度深度学习工具的项目,在中国市场的应用 |
| HDF5 (.h5) | 一种用于存储大型数据集的文件格式,Keras 默认支持这种格式 | 方便存储和管理大型数据集,HDF5 格式擅长处理大规模数据,并支持压缩和并行 I/O 操作。与 Keras 深度集成。 | 模型文件较大,包含了完整的模型架构和权重,加载速度可能较慢。 | Keras 环境下的模型存储和加载,需要保存大型模型的场景 |
| SafeTensors | 一种新型格式,提高模型保存和加载的安全性和速度 | 安全性高,消除潜在执行代码风险。加载速度快,特别适用于大型模型。 | 需要额外的库支持,必须安装 safetensors 库才能使用这种格式。 | 需要高安全性和快速加载的环境,大型模型的存储和部署 |
希望这张表格能够帮助新手朋友们更好地理解不同格式的特点,并根据自己的需求选择合适的格式来保存和部署模型。
相关文章:
深度学习 - 模型的保存与部署方式汇总
深度学习模型保存和加载格式科普 在深度学习中,模型的保存和加载是非常重要的环节。不同的格式有不同的特点和适用场景。本文将为新手朋友们介绍几种常见的模型格式,包括它们的简介、保存方式、加载方式、优缺点以及应用场景。 1. PyTorch (.pth, .pt)…...
人工智能对网络安全有何影响?
人工智能网络安全在短期、中期和长期如何变化 当今数字时代网络安全的重要性 在谈论人工智能在网络安全中的作用时,必须首先考虑短期影响,因为它们是最明显的,而且它是一个未知的领域,需要超越直接炒作的能力。 因此࿰…...
Oracle的RECYCLEBIN回收站:轻松恢复误删对象
目录 Oracle的RECYCLEBIN回收站:轻松恢复误删对象一、概念二、工作原理三、使用方法1 查看回收站中的对象2 恢复回收站中的对象2.1 恢复表(TABLE)2.2 恢复索引(INDEX)2.3 恢复视图(VIEW)2.4 恢复…...
Android 内存原理详解以及优化(二)
上一篇讲了内存原理,如果还没看可以先看上一篇:Android 内存原理详解以及优化(一) 这一篇我总结一下我们经常遇到的内存优化问题: 1.内存抖动 自定义view的ondraw是会被频繁调用的,那在这个方法里面就不能频…...
Shell学习——Shell变量
文章目录 Shell变量使用变量只读变量删除变量变量类型字符串变量: 在 Shell中,变量通常被视为字符串。整数变量: 在一些Shell中,你可以使用 declare 或 typeset 命令来声明整数变量。数组变量: Shell 也支持数组&#…...
Java中的持续集成与持续部署(CI/CD)
Java中的持续集成与持续部署(CI/CD) 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!今天我们将深入探讨Java中的持续集成(Co…...
极狐GitLab 将亮相2024空天信息大会暨数字地球生态峰会,携手中科星图赋能空天行业开发者
GitLab 是一个全球知名的一体化 DevOps 平台,很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab :https://gitlab.cn/install?channelcontent&utm_sourcecsdn 是 GitLab 在中国的发行版,专门为中国程序员服务。可以一键式部署…...
Beats:使用 Filebeat 从 Python 应用程序中提取日志
本指南演示了如何从 Python 应用程序中提取日志并将其安全地传送到 Elasticsearch Service 部署中。你将设置 Filebeat 来监控具有标准 Elastic Common Schema (ECS) 格式字段的 JSON 结构日志文件,然后你将在 Kibana 中查看日志事件发生的实时可视化。虽然此示例使…...
51单片机第23步_定时器1工作在模式0(13位定时器)
重点学习51单片机定时器1工作在模式0的应用。 在51单片机中,定时器1工作在模式0,它和定时器0一样,TL1占低5位,TH1占高8位,合计13位,也是向上计数。 1、定时器1工作在模式0 1)、定时器1工作在模式0的框图…...
linux的服务管理
systemd systemd 是一个系统和服务管理器,用于Linux操作系统中,旨在替代传统的Unix系统V初始化系统(SysV init)。 不一定所有使用 yum 安装的软件都可以通过 systemctl start 来管理。能否通过 systemctl start 管理取决于软件包…...
动手学深度学习(Pytorch版)代码实践 -循环神经网络-53语言模型和数据集
53语言模型和数据集 1.自然语言统计 引入库和读取数据: import random import torch from d2l import torch as d2l import liliPytorch as lp import numpy as np import matplotlib.pyplot as plttokens lp.tokenize(lp.read_time_machine())一元语法…...
Python 学习之自动化运维技术(八)
Python 的自动化运维技术 Python的自动化运维技术是指利用Python编程语言和相关工具实现运维工作的自动化,以提高效率、减轻工作负担。以下是对Python自动化运维技术的清晰归纳和详细介绍: 一、自动化运维的核心优势 ● 提高效率:通过自动化脚…...
【python】PyQt5可视化开发,如何设计鼠标显示的形状?
✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…...
利用大模型知识库,优化智能客服问答效果 | 创新场景
ITValue 痛点 SSC( Share Service Center ,共享服务中心)是企业日常接触最多的场景之一,更多是对内服务,包括 HR 、财务、IT 等。该场景对专业度要求非常高,知识点非常多,对于知识的使用者或者查…...
物联网协议都包含哪些协议?
物联网协议是物联网生态系统中不可或缺的组成部分,它们负责处理和协调物联网设备之间的通信。具体介绍如下: Ethernet:以太网是一种有线网络协议,广泛应用于局域网络(LAN)中,提供稳定的高速数据传输。Wi-Fi࿱…...
面试专区|【52道微服务架构高频题整理(附答案背诵版)】
简述什么是微服务? 微服务是一种软件架构风格,它将应用程序拆分成一系列小型、独立的服务,每个服务都运行在其自己的进程中,通过轻量级通信机制进行通信。每个服务都具有明确的业务能力,并且可以独立开发、测试、部署…...
数据结构之算法的时间复杂度
1.时间复杂度的定义 在计算机科学中,算法的时间复杂度是一个函数,它定量描述了算法的运行时间。一个算法所花费的时间与其中语句的执行次数成正比列,算法中的基本操作的执行次数,为算法的时间复杂度 例1: 计算Func1…...
unity中物体被激活自动执行挂载代码
在Unity中,如果希望当物体被激活时自动执行特定的函数,可以利用 MonoBehaviour 的生命周期函数 OnEnable()。这个方法会在对象被激活时调用,可以用来执行初始化或者处理其他逻辑。以下是如何在脚本中使用 OnEnable() 方法: using UnityEngine;public class ActivateFuncti…...
Pandas数据可视化详解:大案例解析(第27天)
系列文章目录 Pandas数据可视化解决不显示中文和负号问题matplotlib数据可视化seaborn数据可视化pyecharts数据可视化优衣库数据分析案例 文章目录 系列文章目录前言1. Pandas数据可视化1.1 案例解析:代码实现 2. 解决不显示中文和负号问题3. matplotlib数据可视化…...
Redis基础教程(七):redis列表(List)
💝💝💝首先,欢迎各位来到我的博客,很高兴能够在这里和您见面!希望您在这里不仅可以有所收获,同时也能感受到一份轻松欢乐的氛围,祝你生活愉快! 💝Ὁ…...
质量体系的重要
质量体系是为确保产品、服务或过程质量满足规定要求,由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面: 🏛️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限,形成层级清晰的管理网络…...
srs linux
下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935,SRS管理页面端口是8080,可…...
成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战
在现代战争中,电磁频谱已成为继陆、海、空、天之后的 “第五维战场”,雷达作为电磁频谱领域的关键装备,其干扰与抗干扰能力的较量,直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器,凭借数字射…...
vulnyx Blogger writeup
信息收集 arp-scan nmap 获取userFlag 上web看看 一个默认的页面,gobuster扫一下目录 可以看到扫出的目录中得到了一个有价值的目录/wordpress,说明目标所使用的cms是wordpress,访问http://192.168.43.213/wordpress/然后查看源码能看到 这…...
三分算法与DeepSeek辅助证明是单峰函数
前置 单峰函数有唯一的最大值,最大值左侧的数值严格单调递增,最大值右侧的数值严格单调递减。 单谷函数有唯一的最小值,最小值左侧的数值严格单调递减,最小值右侧的数值严格单调递增。 三分的本质 三分和二分一样都是通过不断缩…...
【Android】Android 开发 ADB 常用指令
查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...
rknn toolkit2搭建和推理
安装Miniconda Miniconda - Anaconda Miniconda 选择一个 新的 版本 ,不用和RKNN的python版本保持一致 使用 ./xxx.sh进行安装 下面配置一下载源 # 清华大学源(最常用) conda config --add channels https://mirrors.tuna.tsinghua.edu.cn…...
软件工程 期末复习
瀑布模型:计划 螺旋模型:风险低 原型模型: 用户反馈 喷泉模型:代码复用 高内聚 低耦合:模块内部功能紧密 模块之间依赖程度小 高内聚:指的是一个模块内部的功能应该紧密相关。换句话说,一个模块应当只实现单一的功能…...
[特殊字符] 手撸 Redis 互斥锁那些坑
📖 手撸 Redis 互斥锁那些坑 最近搞业务遇到高并发下同一个 key 的互斥操作,想实现分布式环境下的互斥锁。于是私下顺手手撸了个基于 Redis 的简单互斥锁,也顺便跟 Redisson 的 RLock 机制对比了下,记录一波,别踩我踩过…...
海云安高敏捷信创白盒SCAP入选《中国网络安全细分领域产品名录》
近日,嘶吼安全产业研究院发布《中国网络安全细分领域产品名录》,海云安高敏捷信创白盒(SCAP)成功入选软件供应链安全领域产品名录。 在数字化转型加速的今天,网络安全已成为企业生存与发展的核心基石,为了解…...
