深度学习笔记(九)——tf模型导出保存、模型加载、常用模型导出tflite、权重量化、模型部署
文中程序以Tensorflow-2.6.0为例
部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。
本篇博客主要是工具性介绍,可能由于软件版本问题导致的部分内容无法使用。
首先介绍tflite: TensorFlow Lite 是一组工具,可帮助开发者在移动设备、嵌入式设备和 loT 设备上运行模型,以便实现设备端机器学习。
框架具有的主要特性:
- 延时(数据无需往返服务器)
- 隐私(没有任何个人数据离开设备)
- 连接性(无需连接互联网)
- 大小(缩减了模型和二进制文件的大小)
- 功耗(高效推断,且无需网络连接)
官方目前支持了大约130中可以量化的算子,在查阅大量资料后目前自定义的算子使用tflite导出任然存在较多问题。就解决常见的算法,使用支持的算子基本可以覆盖。tflite的压缩能力极强:使用官方算子构建的模型,导出TensorFlow Lite 二进制文件的大小约为 1 MB(针对 32 位 ARM build);如果仅使用支持常见图像分类模型(InceptionV3 和 MobileNet)所需的运算符,TensorFlow Lite 二进制文件的大小不到 300 KB。在后文的实例中我们用iris数据集的分类演示,可以导出一个仅仅只有2kb大小的模型权重相比未压缩的70kb模型缩小了30多倍。
同时tflite还实验性的在支持导出极轻量化的TFLM模型(TensorFlow Lite for Microcontrollers),这些模型可以直接在嵌入式单片机上进行推理,不过现阶段支持的算子还很少,简单的可以利用全连接和低向量卷积实现一些传感器参数的识别任务。现在主要的实例场景是MCU+IMU组合,识别IMU连续数据,来判断人体特定动作。同时开可以在MCU上离线运行语音命令识别,可以实现一个关键字的识别。
好了那我们继续看一下怎么保存模型,加载模型,保存tflite,加载tflite
保存权重或TF格式标准模型
通常情况下当完成了网络结构设计,数据处理,网络训练和评价之后需要及时的保存数据。先看到前面博客中已经介绍过的iris数据集实现网络分类任务。当时通过添加保存回调函数实现了网络权重的保存,这样保存下的是网络权重模型,需要配合网络结构的实例化使用。当然tf还提供了很多种模型的保存方式,tf2官方推荐使用tf形式保存,通过这种方式相关文件会保存到一个指定文件夹中,包含模型的权重参数模型结构信息。
ckpt格式
通过回调函数实现动态权重的保存。
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as np
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)# 定义网络结构
class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return y
# 实例化化模型
model = IrisModel()
# 定义保存和记录数据的回调器
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath="./checkpoint/iris/iris.ckpt", # 保存模型权重参数save_weights_only=True,save_best_only=True)
# 初始化模型
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'],callbacks=[cp_callback])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()
在上面的代码中通过tf.keras.callbacks.ModelCheckpoint
设置了一个回调器,会动态的在网络训练的过程中保存下参数表现效果最好的权重参数。这里主要保存网络中各个可变参数的值和网络的当前图数据。这个模型无法直接用于推理,应为其不包含网络完整的图信息。所以我们需要在训练结束时保存网络整体的图信息。
.pd格式
pd格式是tf保存静态模型的专用权重文件。在训练完成后直接执行:
model.save('./yor_save_path/model', save_format='tf') # 保存模型为静态权重
这样就可以把model的全部图信息保存下来了那么怎么保存最好的呢?可以结合上一个.ckpt文件使用。
对比两个的保存方式差距,前者在动态的训练过程中存储数据,后者针对某一个节点的网络状态完整保存。所以可以在训练过程中保存下最好的参数,当训练结束后再加载回最好的动态权重,然后再保存为.pd文件。
加载权重或TF格式标准模型
动态权重和静态图模型的保存不同,加载也不同。加载.ckpt时,需要先实例化网络结构,然后再读取权重参数给实例化的模型赋值。对于静态模型文件则不需要实例化模型,也就是无需关注网络的内部,直接读取加载模型就会完成网络构建和参数赋值两个任务,在部署时明显静态模型模型的程序文件会更加简单。
model = userModer()
model.compile()
checkpoint_save_path = "./yor_file_path.ckpt"
if os.path.exists(checkpoint_save_path + '.index'): #print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)
上面的程序展示从动态图加载权重。下面的程序则直接从静态图加载模型。
model_path = './yor_model_path'
new_model = tf.keras.models.load_model(model_path) # 从tf模型加载,无需重新实例化网络
从静态图文件加载模型有便捷之处,但是也需要注意模型的输入和输出结构,要保证预测时输入网络的数据维度是符合要求的,同时根据网络输出的模式接收输出数据做相应处理。
转化模型到tflite
转化模型主要有三种方式:
- 使用现有的 TensorFlow Lite 模型
- 创建 TensorFlow Lite 模型
- 将 TensorFlow 模型转换为TensorFlow Lite 模型
模型的保存就分别对应三个主要函数:
后续主要介绍使用tf构建网络后从tf模型保存到tflite,并以keras model为主。
首先我们需要上面iris数据集分类的例子,当网络训练结束后,可以使用如下的程序导出:
tflite_save_path = './your_file_path'
os.makedirs(tflite_save_path, exist_ok=True)
# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
# Save the model.
with open(tflite_save_path+'/model.tflite', 'wb') as f:f.write(tflite_model)
导出后可以看到model.tflite
的模型文件。可以比较上面直接导出的完整模型,这个模型的体积小了很多,更加适合在低算力和存储能力的设备上运行。
从tflite加载模型并执行推理
从tflite上加载模型并推理主要有两个手段:使用完整tf框架加载tf.lite读取;或使用tflite_runtime
,这是 TensorFlow Lite 解释器,无需安装所有 TensorFlow 软件包,但是对python版本和系统,硬件有一定的要求。目前tf-runtime支持的平台有:
在此以外的模型需要拉取完整源码在本地设备上编译执行。
安装了相关的软件环境后,可以使用如下的代码来加载模型并推理:
# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=tflite_save_path+'/model.tflite') # 加载模型
interpreter.allocate_tensors() # 为模型分配张量参数# Get input and output tensors.
input_details = interpreter.get_input_details() # 设置网络输入
output_details = interpreter.get_output_details() # 设置网络输出# Test the model on set input data.
input_shape = input_details[0]['shape'] # 获取输入层(第一层)的数据维度
print(input_shape) # 输出维度结构,便于调试input_data = np.array([6.0,3.4,4.5,1.3], dtype=np.float32) # 手动给一组鸢尾花数据
input_data = input_data.reshape([1,4]) # 确保维度相同
print(input_data.shape)
interpreter.set_tensor(input_details[0]['index'], input_data) # 将数据输入到网络中
interpreter.invoke() # 运行推理
output_data = interpreter.get_tensor(output_details[0]['index']) # 获得网络输出print(output_data)pred = tf.argmax(output_data, axis=1) # 网络输出层是softmax,需要找到最大值
print(int(pred)) # 输出最大位置的index
通过上面几行简单的代码就可以在终端设备实现预测推理。使用tf-runtim时只需要做简单修改,将包名替换即可。例如:
import tensorflow as tf 改为:import tflite_runtime.interpreter as tflite
interpreter = tf.lite.Interpreter(model_path=args.model_file) 改为: interpreter = tflite.Interpreter(model_path=args.model_file)
模型量化
对模型执行量化可以进一步解决嵌入式终端设备的痛点。量化模型可以实现:
- 较小的存储大小:小模型在用户设备上占用的存储空间更少
- 较小的下载大小:小模型下载到用户设备所需的时间和带宽较少
- 更少的内存用量:小模型在运行时使用的内存更少,从而释放内存供应用的其他部分使用,并可以转化为更好的性能和稳定性
tflite支持的量化形式有:
训练后量化
训练后量化是一种转换技术,它可以在改善 CPU 和硬件加速器延迟的同时缩减模型大小,且几乎不会降低模型准确率。使用 TensorFlow Lite 转换器将已训练的浮点 TensorFlow 模型转换为 TensorFlow Lite 格式后,可以对该模型进行量化。
动态范围量化
动态范围量化能使模型大小缩减至原来的四分之一,在量化时激活函数始终以浮点格式保存,其它支持的算子会根据损失动态保存为8位整形,以此减小模型体积。在导出时量化模型,设置 optimizations 标记以优化大小:
# 上续训练后的模型
tflite_save_path = './your_file_path'
os.makedirs(tflite_save_path, exist_ok=True)
# Convert the model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir/"opt_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)
加载模型时使用相同的方式加载即可
全整数量化
全整型量化相对更加复杂一些。
在上面导出tflite的过程中,实际是将tf默认的协议缓冲区模型压缩为FlatBuffers的格式,这种格式具有多种优势,例如可缩减大小(代码占用的空间较小)以及提高推断速度(可直接访问数据,无需执行额外的解析/解压缩步骤)。
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
上面的两行代码实质是做了模型格式的转换和压缩,并没有调整权重参数和计算格式。在上面的基础上,可以进一步使用动态范围量化:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_quant = converter.convert()
通过动态范围量化之后模型已经缩小了,但是任然有一部分模型参数是浮点格式,这对存储有效,计算能力有限的设备还是存在限制。
在进行整形量化时需要量化模型内部层和输入输出层。tflite给出了两种量化的方式,第一种量化兼容性相对广泛,但是需要输入一组足够大的代表数据集用来推理量化。这样得到的模型任然有小部分参数会是浮点,这无法支持纯整形计算的硬件。
这里转述官方给出的第二种整形量化方式:
为了量化输入和输出张量,并让转换器在遇到无法量化的运算时引发错误,使用一些附加参数再次转换模型:
def representative_data_gen():for input_value in tf.data.Dataset.from_tensor_slices(train_data).batch(yordatabatch).take(100):yield [input_value]converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8tflite_model_quant = converter.convert()
上面的第一个函数是量化的一个必要步骤,要量化可变数据(例如模型输入/输出和层之间的中间体),需要提供 RepresentativeDataset。这是一个生成器函数,它提供一组足够大的输入数据来代表典型值。转换器可以通过该函数估算所有可变数据的动态范围。(相比训练或评估数据集,此数据集不必唯一。)为了支持多个输入,每个代表性数据点都是一个列表,并且列表中的元素会根据其索引被馈送到模型。
通过转化给定数据推理量化,现在模型的输入层和输出层数据已经是整形格式:
interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
input_type = interpreter.get_input_details()[0]['dtype']
print('input: ', input_type)
output_type = interpreter.get_output_details()[0]['dtype']
print('output: ', output_type)
此时模型已经完全支持全整形设备的计算。
那继续的,将模型文件保存下来:
import pathlib
tflite_models_dir = pathlib.Path("/tmp/user_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)
# Save the unquantized/float model:
tflite_model_file = tflite_models_dir/"user_model.tflite"
tflite_model_file.write_bytes(tflite_model)
# Save the quantized model:
tflite_model_quant_file = tflite_models_dir/"user_model_quant.tflite"
tflite_model_quant_file.write_bytes(tflite_model_quant)
执行推理时使用的程序结构和上文介绍的从tflite加载模型并执行推理的内容相同。
相关文章:

深度学习笔记(九)——tf模型导出保存、模型加载、常用模型导出tflite、权重量化、模型部署
文中程序以Tensorflow-2.6.0为例 部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。 本篇博客主要是工具性介绍,可能由于软件版本问题导致的部分内容无法使用。 首先介绍tflite: TensorFlow Lite 是一组工具,可帮助开…...

七Docker可视化管理工具
Docker可视化管理工具 本节介绍几款Docker可视化管理工具。 DockerUI(ui for Docker) 官方GitHub:https://github.com/kevana/ui-for-docker 项目已废弃,现在转投Portainer项目,不建议使用。 Portainer 简介:Portainer是一个…...

vue和react的差异梳理
特性VueReact响应式系统使用Object.defineProperty()或Proxy使用不可变数据流和状态提升模板系统HTML模板语法JSX(JavaScript扩展语法)组件作用域样式支持scoped样式需要CSS-in-JS库(如styled-components)状态管理Vuex(…...

(笔记总结)C/C++语言的常用库函数(持续记录,积累量变)
写在前面: 由于时间的不足与学习的碎片化,写博客变得有些奢侈。 但是对于记录学习(忘了以后能快速复习)的渴望一天天变得强烈。 既然如此 不如以天为单位,以时间为顺序,仅仅将博客当做一个知识学习的目录&a…...

OceanBase集群扩缩容
OceanBase 数据库采用 Shared-Nothing 架构,各个节点之间完全对等,每个节点都有自己的 SQL 引擎、存储引擎、事务引擎,天然支持多租户,租户间资源、数据隔离,集群运行的最小资源单元是Unit,每个租户在每…...

html 3D 倒计时爆炸特效
下面是代码: <!DOCTYPE html> <html><head><meta charset"UTF-8"><title>HTML5 Canvas 3D 倒计时爆炸特效DEMO演示</title><link rel"stylesheet" href"css/style.css" media"screen&q…...

记一次垃圾笔记应用VNote安装失败过程
特色功能简介 1.全文搜索: VNote支持根据关键词搜索整个笔记本或者特定文件夹内的文档内容,非常适合快速找到信息。 2.标签管理: 你可以给笔记添加标签,从而更好地组织和检索你的笔记内容。 3.自定义主题和样式: 进入设置,VNote允许你选…...

记一次 stackoverflowerror 线上排查过程
一.线上 stackOverFlowError xxx日,突然收到线上日志关键字频繁告警 classCastException.从字面上的报警来看,仅仅是类型转换异常,查看细则发现其实是 stackOverFlowError.很多同学面试的时候总会被问到有没有遇到过线上stackOverFlowError?有么有遇到栈溢出?具体栈溢出怎么来…...

论文写作之十个问题
前言 最近进入瓶颈? 改论文,改到有些抑郁了 总是不对,总是被打回 好的写作,让人一看就清楚明白非常重要 郁闷时候看看大佬们怎么说的 沈向洋、华刚:读科研论文的三个层次、四个阶段与十个问题 十问 What is the pro…...

leetcode2171 拿出最少数目的魔法豆
题目 给定一个 正整数 数组 beans ,其中每个整数表示一个袋子里装的魔法豆的数目。 请你从每个袋子中 拿出 一些豆子(也可以 不拿出),使得剩下的 非空 袋子中(即 至少还有一颗 魔法豆的袋子)魔法豆的数目…...

测试C#调用OpenCvSharp和ViewFaceCore从摄像头中识别人脸
学习了基于OpenCvSharp获取摄像头数据,同时学习了基于ViewFaceCore的人脸识别用法,将这两者结合即是从摄像头中识别人脸。本文测试测试C#调用OpenCvSharp和ViewFaceCore从摄像头中识别人脸,并进行人脸红框标记。 新建Winform项目…...

测试经理面试初体验
家人们谁懂啊,我在海口实在难找计算机类的实习,就直接在BOss上海投了,结果一个hr直接给我弄了个测试经理的面试(可能年底冲业绩吧),然后就在明天下午,我直接抱下f脚了,就当体验一下~…...

使用ffmpeg调整视频中音频采样率及声道
1 原始视频信息 通过ffmpeg -i命令查看视频基本信息 ffmpeg -i example2.mp4 ffmpeg version 6.1-essentials_build-www.gyan.dev Copyright (c) 2000-2023 the FFmpeg developersbuilt with gcc 12.2.0 (Rev10, Built by MSYS2 project)configuration: --enable-gpl --enable…...

详细分析Java中的Date类以及格式转换
目录 前言1. 基本知识2. 格式化输出3. 格式转换 前言 记录这篇文章的缘由,主要是涉及一个格式转换,对此深挖了这个类 在Java中,Date类是用于表示日期和时间的类。 位于java.util包中,是Java平台中处理日期和时间的基本类之一。…...

【计算机网络】应用层——HTTP 协议(一)
个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【网络编程】 本专栏旨在分享学习计算机网络的一点学习心得,欢迎大家在评论区交流讨论💌 目录 一、什么是 HTTP 协…...

线程和进程的区别
Java面试题 线程和进程的区别 进程是操作系统资源分配的基本单位。 线程是处理器任务调度和执行的基本单位 一个进程可以包含多个线程。进程之间的资源是相互独立,而同一进程下的线程之间可以共享进程中的资源。...

proxy 代理的接口报错301问题
项目系统里仅仅这个接口报错,反向代理错误导致。 默认情况下,不接受运行在HTTPS上,且使用了无效证书的后端服务器。如果你想要接受,修改配置:secure: false(简单意思:如果本地没有进行过https相…...

mysql进阶-执行计划
目录 1. 概念 2. 使用 3. 具体相关字段含义 3.1 id 3.2 select_type 3.3 table 3.4 partition 3.5 type 3.6 possible_key 3.7 key 3.8 key_len 3.9 ref 3.10 row 3.11 filtered 3.12 extra 1. 概念 一条语句通过优化器之后,会生成具体的执行计划用…...

【UE5】第一次尝试项目转插件(Plugin)的时候,无法编译
VS显示100条左右的错误,UE热编译也不能通过。原因可能是[名字.Build.cs]文件的错误,缺少一些内容,比如说如果要写UserWidget类,那么就要在 ]名字.Build.cs] 中加入如下内容: public class beibaoxitong : ModuleRules …...

MeterSphere本地化部署实践
项目结构 搭建本地环境 安装JDK11,配置好JDK环境,系统同时支持JDK8和JDK11安装IEAD,配置JDK环境配置maven环境,IDEA配置(解压可以直接使用)无限重置IDEA试用期配置redis环境(解压可以直接使用) 配置kafka环境 安装mysql-5.7环境ÿ…...

巨变!如何理解中国发起的“数据要素X”计划?
作者 张群(赛联区块链教育首席讲师,工信部赛迪特聘资深专家,CSDN认证业界专家,微软认证专家,多家企业区块链产品顾问)关注张群,为您提供一站式区块链技术和方案咨询。 刘烈宏在第25届北大光华新…...

CS8370错误,这是由于使用了C# 7.3中不支持的功能
目录 背景: 第一种方法: 第二种办法: 背景: 在敲代码的时候,程序提示报错消息提示:CS8370错误,那么这是什么原因导致的,这是由于使用了C# 7.3中不支持的功能,不支持该功能,那就是版本太低我们就需要升级更高的版本&…...

Raspbian安装云台
Raspbian安装云台 1. 源由2. 选型3. 组装4. 调试4.1 python3-print问题4.2 python函数入参类型错误4.3 缺少mjpg-streamer可执行文件4.4 缺失编译头文件和库4.5 python库缺失4.6 图像无法显示,但libcamera-jpeg测试正常4.7 异常IOCTL报错4.8 Git问题 5. 效果5.1 WEB…...

蓝桥杯理历年真题 —— 数学
1. 买不到的数目 这道题目,考得就是一个日常数学的积累,如果你学过这个公式的话,就是一道非常简单的输出问题;可是如果没学过,就非常吃亏,在考场上只能暴力求解,或是寻找规律。这就要求我们什么…...

自然语言处理--双向匹配算法
自然语言处理作业1--双向匹配算法 一、概述 双向匹配算法是一种用于自然语言处理的算法,用于确定两个文本之间的相似度或匹配程度。该算法通常使用在文本对齐、翻译、语义匹配等任务中。 在双向匹配算法中,首先将两个文本分别进行处理,然后…...

IDEA 2023.3.2 安装教程
1.下载2023.3.2版本IDEA 链接:https://pan.baidu.com/s/1RkXBLz6qxsd8VxXuvXCEMA?pwd5im6 提取码:5im6 2.安装 3.解压文件,进入,选择方式3 4.将下面文件夹复制到任意位置(不要有中文路径) 5.进入下面文…...

C语言常见面试题:什么是宏,宏的作用是什么?
宏在计算机科学中是一种批量处理程序命令,它是一种抽象的规则或模式,用于说明某一特定输入(通常是字符串)如何根据预定义的规则转换成对应的输出(通常也是字符串)。在编译时,预处理器会对宏进行…...

【0248】Background Writing实现机制分析
文章目录 1. 前言2. 有了checkpoint,为何还需要background writing?2.1 checkpoint和background writing有何差异? 如何协同工作?2.2 background writing如何工作? 职责是什么?1. 前言 本文是Background Writing进程理论篇,源码剖析实战篇会在后面给出。本文的主要内容…...

基于springboot+vue的教师工作量管理系统(前后端分离)
博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容:毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 项目背景…...

4-新建子模块(尝鲜)
新建子模块 Maven多模块下新建子模块流程案例。 1、新建业务模块目录,例如:ruoyi-test。 2、在ruoyi-test业务模块下新建pom.xml文件以及src\main\java,src\main\resources目录。 <?xml version"1.0" encoding"UTF-8&…...