24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)
知识要点
keras 保存成hdf5文件, 1.保存模型和参数, 2.只保存参数
- 1.保存模型和参数
- save_model
- callback ModelCheckpoint
- 2. 只保存参数
- save_weights
- callback ModelCheckpoint save_weights_only = True
保存模型:
- 案例数据: Fashion-MNIST总共有十个类别的图像
- model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5')) # 保存参数的方法
- 加载参数: model.load_weights(os.path.join(logdir, 'fashion_mnist_weight.h5'))
- 保存模型: model.save(os.path.join(logdir, 'fashion_mnist_model.h5'))
- 加载模型: model2 = keras.models.load_model(os.path.join(logdir, 'fashion_mnist_model.h5'))
- 把keras模型保存成savedmodel格式: tf.saved_model.save(model, './keras_saved_model')
一 模型保存和部署
- TFLite是为了将深度学习模型部署在移动端和嵌入式设备的工具包,可以把训练好的TF模型通过转化、部署和优化三个步骤,达到提升运算速度,减少内存、显存占用的效果。
- TFlite主要由Converter和Interpreter组成。Converter负责把TensorFlow训练好的模型转化,并输出为.tflite文件(FlatBuffer格式)。转化的同时,还完成了对网络的优化,如量化。Interpreter则负责把.tflite部署到移动端,嵌入式(embedded linux device)和microcontroller,并高效地执行推理过程,同时提供API接口给Python,Objective-C,Swift,Java等多种语言。简单来说,Converter负责打包优化模型,Interpreter负责高效易用地执行推理。
-
Fashion-MNIST总共有十个类别的图像。每一个类别由训练数据集6000张图像和测试数据集1000张图像。所以训练集和测试集分别包含60000张和10000张。测试训练集用于评估模型的性能。
-
每一个输入图像的高度和宽度均为28像素。数据集由灰度图像组成。Fashion-MNIST,中包含十个类别,分别是t-shirt,trouser,pillover,dress,coat,sandal,shirt,sneaker,bag,ankle boot。
1.1 模型创建
- 导包
# 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
- 时尚数据导入
# 时尚数据导入
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
- 标准化
# 标准化
from sklearn.preprocessing import StandardScaler # preprocessing 预处理
scaler = StandardScaler()x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 784))
x_valid_scaled = scaler.fit_transform(x_valid.astype(np.float32).reshape(-1, 784))
x_test_scaled = scaler.fit_transform(x_test.astype(np.float32).reshape(-1, 784))
- 创建模型
# 创建模型
model = keras.models.Sequential([keras.layers.Dense(512, activation = 'relu', input_shape = (784, )),keras.layers.Dense(256, activation = 'relu'),keras.layers.Dense(128, activation = 'relu'),keras.layers.Dense(10, activation = 'softmax')])model.compile(loss = 'sparse_categorical_crossentropy',optimizer = 'adam',metrics = ['accuracy'])
1.2 保存模型
# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
callbacks = [keras.callbacks.TensorBoard(logdir), # 保存地址# 保存效果最好的模型: save_best_onlykeras.callbacks.ModelCheckpoint(output_model_file, save_best_only = True, save_weights_only = True),keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)] history = model.fit(x_train_scaled, y_train, epochs = 10,validation_data= (x_valid_scaled, y_valid),callbacks = callbacks)

- 保存模型
# 保存模型
output_model_file2 = os.path.join(logdir, 'fashion_mnist_model.h5')
model.save(output_model_file2)
-
保存参数
# 另一种保存参数的方法
model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5'))
- 模型评估
# evaluate 评估
model.evaluate(x_valid_scaled, y_valid) # [0.35909169912338257, 0.88919997215271]
- 模型加载
# 加载模型
model2 = keras.models.load_model(output_model_file2)
model2.evaluate(x_valid_scaled, y_valid) # [0.35909169912338257, 0.88919997215271]
二 保存模型为savemodel格式
# 把keras模型保存成savedmodel格式
tf.saved_model.save(model, './keras_saved_model')
- 读取模型
# 加载savedmodel模型
loaded_saved_model = tf.saved_model.load('./keras_saved_model')
loaded_saved_model
2.1 另一种保存
# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
model.load_weights(output_model_file)
三 tflite_interpreter 的使用
- 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import os
with open('./tflite_models/concrete_func_tf_lite', 'rb') as f:concrete_func_tflite = f.read()
- 创建interpreter
# 创建interpreter
interpreter = tf.lite.Interpreter(model_content = concrete_func_tflite)
# 分配内存
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
- 预测数值
input_data = tf.constant(np.ones(input_details[0]['shape'], dtype = np.float32))
# 传入预测数据
interpreter.set_tensor(input_details[0]['index'], input_data)# 执行预测
interpreter.invoke()# 获取输出
output_results = interpreter.get_tensor(output_details[0]['index'])
print(output_results)
四 to_concrete_function
- 加载文件
# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model(np.ones((1, 784)))

- 把keras模型转化为concrete function
# 把keras模型转化为concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,loaded_keras_model.inputs[0].dtype))
# 使用
keras_concrete_func(tf.constant(np.ones((1, 784), dtype = np.float32)))

五 to_quantized_tflite
5.1 keras to tflite
# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model
# lite 精简版模型 # 创建转化器
keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)
keras_to_tflite_converter
# 给converter添加量化的优化 # 把32位的浮点数变成8位整数
keras_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]# 执行转化
keras_tflite = keras_to_tflite_converter.convert()
# 写入指定文件
import os
if not os.path.exists('./tflite_models'):os.mkdir('./tflite_models')with open('./tflite_models/quantized_keras_tflite', 'wb') as f:f.write(keras_tflite)
5.2 concrete function to tflite
# 把keras模型转化成concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,loaded_keras_model.inputs[0].dtype))
concrete_func_to_tflite_converter = tf.lite.TFLiteConverter.from_concrete_functions([keras_concrete_func])
concrete_func_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
concrete_func_tflite = concrete_func_to_tflite_converter.convert()
with open('./tflite_models/quantized_concrete_func_tf_lite', 'wb') as f:f.write(concrete_func_tflite)
5.3 saved_model to tflite
saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model('./keras_saved_model/')
saved_model_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
saved_model_tflite = saved_model_to_tflite_converter.convert()
with open('./tflite_models/quantized_saved_model_tflite', 'wb') as f:f.write(saved_model_tflite)
相关文章:
24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)
知识要点 keras 保存成hdf5文件, 1.保存模型和参数, 2.只保存参数 1.保存模型和参数 save_modelcallback ModelCheckpoint2. 只保存参数 save_weightscallback ModelCheckpoint save_weights_only True 保存模型: 案例数据: Fashion-MNIST总共有十个类别的图像model.save_w…...
【Echarts图例点击事件】自定义Echarts图例legend点击事件(已解决)
目录先睹为快(效果)1、实现Echarts多条曲线2、点击echarts触发接口请求2.1 先默认隐藏部分数据2.2 自定义legend图例点击事件3、源码下载地址(解压即用)**【写在前面】**这下我又不得不说了,还是客户现场使用时想查询一…...
uniapp-首页配置
为了获取到后台服务器发来的数据,需要配置相应的网络地址。位置在main.js入口文件中。 import { $http } from escook/request-miniprogramuni.$http $http // 配置请求根路径 $http.baseUrl https://api-hmugo-web.itheima.net// 请求开始之前做一些事情 $http.…...
支持DDR5,超频更简单,小雕够给力,技嘉B760M小雕WIFI主板上手
目前13代酷睿已经全员集结了,其中全新的i5 13490F应该依然会备受欢迎,当然了,刚上市不久的13代酷睿价格方面还不是很有吸引力,好在12代酷睿在新一代主板上面依然可用,所以预算有限的朋友,完全可用继续使用1…...
fengMap 自定义dom 偏离实际位置;缩放时飘出地图所在区域
目录 一、问题 二、原因及解决方法 三、总结 一、问题 1.前人写了一份代码,很奇怪。使用 new fengmap.FMCompositeMarker添加的复合覆盖物位置是正常的,缩放的时候也是正常的,仍然处于地图内部;但是new fengmap.FMDomMarker添加…...
TryHackMe-黑我杯
黑我杯 相信我们大家在TryHackMe的日积月累都学到了不少东西,从纯萌新到oscp再到更高 我很高兴能将国内各thm玩家聚集到一起,构建一个更好的学习环境和氛围 本次娱乐分两场: Offensive Pentesting — 中等难度Junior Penetration — 容易难…...
【JAVA程序设计】【C00109】基于SSM(非maven)的员工工资管理系统
基于SSM(非maven)的员工工资管理系统项目简介项目获取开发环境项目技术运行截图项目简介 基于ssm框架非maven开发的企业工资管理系统共分为二个角色:系统管理员、员工 管理员角色包含以下功能: 系统后台登陆、管理员管理、员工信…...
《计算机原理》——HelloWorld.cpp如何运行的
学校《计算机原理》开课啦!特此开辟专栏,将一些知识作为笔记,记录下来。 前言 本篇博客知识点来源于educoder的相关题目 1. 相关知识 1.1 计算机语言 计算机语言是人与计算机之间通讯的语言,计算机语言包括编写计算机程序的字符…...
【面试题】在JS循环中使用await会怎么样?
前言这个问题是这样产生的?某天,在学习异步的知识遇到这样一道题:使用Promise的方式,每隔一秒输出数组中一个值const arr [1, 2, 3] arr.reduce((pre, cur) > {return pre.then(() > {returnnewPromise((resolve, rejec…...
Qt QMessageBox详解
文章目录一.QMessageBox介绍枚举属性函数二.QMessageBox的用法1.导入QMessage库2.弹窗提示3.提供选项的弹窗提示4.作为提示,报警,报错提示窗口一.QMessageBox介绍 文本消息显示框(message box)向用户发出情况警报信息并进一步解释警报或向用户提问&…...
Flutter之beamer路由入门指南
beamer路由入门指南 前言使用方法1、路由配置方式1路由配置方式2路由跳转测试现象前言 Beamer是一个很好用的路由组件,本文以beamer1.5.0版本进行说明,前面博主也介绍了其他路由组件 Flutter实战之go_router路由组件入门指南 、 Flutter之Fluro路由组件入门指南 Flutter之Ge…...
「基础篇」机器学习概览
文章目录1. 什么是机器学习2. 引入机器学习3. 应用场景4. 机器学习分类4.1. 有无人类监督4.2. 是否增量学习4.3. 泛化方式5. 主要挑战6. 测试与验证1. 什么是机器学习 机器学习(Machine Learning,ML)是一个研究领域,让计算机无需…...
揭秘可视化图探索工具 NebulaGraph Explore 是如何实现图计算的
前言 在可视化图探索工具 NebulaGraph Explorer 3.1.0 版本中加入了图计算工作流功能,针对 NebulaGraph 提供了图计算的能力,同时可以利用工作流的 nGQL 运行能力支持简单的数据读取,过滤及写入等数据处理功能。 本文将简单分享下 NebulaGr…...
移动架构43_什么是Jetpack
Android移动架构汇总 文章目录一 Android 开发框架演变1 MVC2 MVP3 MVVM二 什么是JetPack三 如何构建支持Jetpack项目一 Android 开发框架演变 1 MVC Model-View-Controller,模型-视图-控制器,Model负责数据管理,View负责UI显…...
TiDB的分布式事务原理探究
事务开启 获取全局授时作为startTS构建一个tikvTxn对象(包括snapshot)。 事务写 txn.Set方法本质上将kv值写入了一个内存缓存(即kv/memdb_buffer.go中的memDbBuffer)中。该内存kv数据库利用的是golevel提供的功能。 事务回滚 直接将tikvTxn的valid字段…...
【C语言】函数指针和指针函数
文章目录[TOC](文章目录)前言概述函数指针定义:使用:回调函数指针函数前言 今天学一下函数指针 提示:以下是本篇文章正文内容,下面案例可供参考 概述 函数指针:是一个指向函数的指针,在内存空间中存放的…...
Nodejs中npx简介和作用
一、npx简介npm从5.25.2版开始,增加了 npx 命令。方便了我在项目中使用全局包。二、安装Node安装后自带npm模块,可以直接使用npx命令。如果不能使用用,就要手动安装一下。npm install -g npx三、使用npx想要解决的主要问题,就是调…...
Matplotlib精品学习笔记001——绘制3D图形详解+实例讲解
3D图片更生动,或许在时间序列数据的展示上更胜一筹 想法: 学习3D绘图的想法来自科研绘图中。我从事的专业是古植物学,也就是和植物化石打交道。化石有三大信息:1.物种信息,也就是它的分类学价值;2.时间信息…...
学习ifconfig实战技巧,成为网络管理高手
文章目录前言一. ifconfig 命令介绍二. 语法格式及常用选项三. 参考案例3.1 显示网络设备信息3.2 启动和关闭指定的网卡3.3 对指定的网卡设备执行修改IP地址操作3.4 启动和关闭ARP协议3.5 使用ifconfig添加网卡总结前言 大家好,又见面了,我是沐风晓月&a…...
day38|70. 爬楼梯(进阶)、322. 零钱兑换、279.完全平方数
70. 爬楼梯(进阶) 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? 示例 1: 输入:n 2 输出:2 解释:有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2…...
SciencePlots——绘制论文中的图片
文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了:一行…...
五年级数学知识边界总结思考-下册
目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解:由来、作用与意义**一、知识点核心内容****二、知识点的由来:从生活实践到数学抽象****三、知识的作用:解决实际问题的工具****四、学习的意义:培养核心素养…...
数据链路层的主要功能是什么
数据链路层(OSI模型第2层)的核心功能是在相邻网络节点(如交换机、主机)间提供可靠的数据帧传输服务,主要职责包括: 🔑 核心功能详解: 帧封装与解封装 封装: 将网络层下发…...
ServerTrust 并非唯一
NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个生活电费的缴纳和查询小程序
一、项目初始化与配置 1. 创建项目 ohpm init harmony/utility-payment-app 2. 配置权限 // module.json5 {"requestPermissions": [{"name": "ohos.permission.INTERNET"},{"name": "ohos.permission.GET_NETWORK_INFO"…...
【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)
骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...
Map相关知识
数据结构 二叉树 二叉树,顾名思义,每个节点最多有两个“叉”,也就是两个子节点,分别是左子 节点和右子节点。不过,二叉树并不要求每个节点都有两个子节点,有的节点只 有左子节点,有的节点只有…...
Java 二维码
Java 二维码 **技术:**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...
Android第十三次面试总结(四大 组件基础)
Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成,用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机: onCreate() 调用时机:Activity 首次创建时调用。…...
怎么让Comfyui导出的图像不包含工作流信息,
为了数据安全,让Comfyui导出的图像不包含工作流信息,导出的图像就不会拖到comfyui中加载出来工作流。 ComfyUI的目录下node.py 直接移除 pnginfo(推荐) 在 save_images 方法中,删除或注释掉所有与 metadata …...
