24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)
知识要点
keras 保存成hdf5文件, 1.保存模型和参数, 2.只保存参数
- 1.保存模型和参数
- save_model
- callback ModelCheckpoint
- 2. 只保存参数
- save_weights
- callback ModelCheckpoint save_weights_only = True
保存模型:
- 案例数据: Fashion-MNIST总共有十个类别的图像
- model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5')) # 保存参数的方法
- 加载参数: model.load_weights(os.path.join(logdir, 'fashion_mnist_weight.h5'))
- 保存模型: model.save(os.path.join(logdir, 'fashion_mnist_model.h5'))
- 加载模型: model2 = keras.models.load_model(os.path.join(logdir, 'fashion_mnist_model.h5'))
- 把keras模型保存成savedmodel格式: tf.saved_model.save(model, './keras_saved_model')
一 模型保存和部署
- TFLite是为了将深度学习模型部署在移动端和嵌入式设备的工具包,可以把训练好的TF模型通过转化、部署和优化三个步骤,达到提升运算速度,减少内存、显存占用的效果。
- TFlite主要由Converter和Interpreter组成。Converter负责把TensorFlow训练好的模型转化,并输出为.tflite文件(FlatBuffer格式)。转化的同时,还完成了对网络的优化,如量化。Interpreter则负责把.tflite部署到移动端,嵌入式(embedded linux device)和microcontroller,并高效地执行推理过程,同时提供API接口给Python,Objective-C,Swift,Java等多种语言。简单来说,Converter负责打包优化模型,Interpreter负责高效易用地执行推理。
-
Fashion-MNIST总共有十个类别的图像。每一个类别由训练数据集6000张图像和测试数据集1000张图像。所以训练集和测试集分别包含60000张和10000张。测试训练集用于评估模型的性能。
-
每一个输入图像的高度和宽度均为28像素。数据集由灰度图像组成。Fashion-MNIST,中包含十个类别,分别是t-shirt,trouser,pillover,dress,coat,sandal,shirt,sneaker,bag,ankle boot。
1.1 模型创建
- 导包
# 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
- 时尚数据导入
# 时尚数据导入
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
- 标准化
# 标准化
from sklearn.preprocessing import StandardScaler # preprocessing 预处理
scaler = StandardScaler()x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 784))
x_valid_scaled = scaler.fit_transform(x_valid.astype(np.float32).reshape(-1, 784))
x_test_scaled = scaler.fit_transform(x_test.astype(np.float32).reshape(-1, 784))
- 创建模型
# 创建模型
model = keras.models.Sequential([keras.layers.Dense(512, activation = 'relu', input_shape = (784, )),keras.layers.Dense(256, activation = 'relu'),keras.layers.Dense(128, activation = 'relu'),keras.layers.Dense(10, activation = 'softmax')])model.compile(loss = 'sparse_categorical_crossentropy',optimizer = 'adam',metrics = ['accuracy'])
1.2 保存模型
# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
callbacks = [keras.callbacks.TensorBoard(logdir), # 保存地址# 保存效果最好的模型: save_best_onlykeras.callbacks.ModelCheckpoint(output_model_file, save_best_only = True, save_weights_only = True),keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)] history = model.fit(x_train_scaled, y_train, epochs = 10,validation_data= (x_valid_scaled, y_valid),callbacks = callbacks)
- 保存模型
# 保存模型
output_model_file2 = os.path.join(logdir, 'fashion_mnist_model.h5')
model.save(output_model_file2)
-
保存参数
# 另一种保存参数的方法
model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5'))
- 模型评估
# evaluate 评估
model.evaluate(x_valid_scaled, y_valid) # [0.35909169912338257, 0.88919997215271]
- 模型加载
# 加载模型
model2 = keras.models.load_model(output_model_file2)
model2.evaluate(x_valid_scaled, y_valid) # [0.35909169912338257, 0.88919997215271]
二 保存模型为savemodel格式
# 把keras模型保存成savedmodel格式
tf.saved_model.save(model, './keras_saved_model')
- 读取模型
# 加载savedmodel模型
loaded_saved_model = tf.saved_model.load('./keras_saved_model')
loaded_saved_model
2.1 另一种保存
# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):os.mkdir(logdir)output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
model.load_weights(output_model_file)
三 tflite_interpreter 的使用
- 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import os
with open('./tflite_models/concrete_func_tf_lite', 'rb') as f:concrete_func_tflite = f.read()
- 创建interpreter
# 创建interpreter
interpreter = tf.lite.Interpreter(model_content = concrete_func_tflite)
# 分配内存
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
- 预测数值
input_data = tf.constant(np.ones(input_details[0]['shape'], dtype = np.float32))
# 传入预测数据
interpreter.set_tensor(input_details[0]['index'], input_data)# 执行预测
interpreter.invoke()# 获取输出
output_results = interpreter.get_tensor(output_details[0]['index'])
print(output_results)
四 to_concrete_function
- 加载文件
# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model(np.ones((1, 784)))
- 把keras模型转化为concrete function
# 把keras模型转化为concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,loaded_keras_model.inputs[0].dtype))
# 使用
keras_concrete_func(tf.constant(np.ones((1, 784), dtype = np.float32)))
五 to_quantized_tflite
5.1 keras to tflite
# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model
# lite 精简版模型 # 创建转化器
keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)
keras_to_tflite_converter
# 给converter添加量化的优化 # 把32位的浮点数变成8位整数
keras_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]# 执行转化
keras_tflite = keras_to_tflite_converter.convert()
# 写入指定文件
import os
if not os.path.exists('./tflite_models'):os.mkdir('./tflite_models')with open('./tflite_models/quantized_keras_tflite', 'wb') as f:f.write(keras_tflite)
5.2 concrete function to tflite
# 把keras模型转化成concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,loaded_keras_model.inputs[0].dtype))
concrete_func_to_tflite_converter = tf.lite.TFLiteConverter.from_concrete_functions([keras_concrete_func])
concrete_func_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
concrete_func_tflite = concrete_func_to_tflite_converter.convert()
with open('./tflite_models/quantized_concrete_func_tf_lite', 'wb') as f:f.write(concrete_func_tflite)
5.3 saved_model to tflite
saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model('./keras_saved_model/')
saved_model_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
saved_model_tflite = saved_model_to_tflite_converter.convert()
with open('./tflite_models/quantized_saved_model_tflite', 'wb') as f:f.write(saved_model_tflite)
相关文章:

24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)
知识要点 keras 保存成hdf5文件, 1.保存模型和参数, 2.只保存参数 1.保存模型和参数 save_modelcallback ModelCheckpoint2. 只保存参数 save_weightscallback ModelCheckpoint save_weights_only True 保存模型: 案例数据: Fashion-MNIST总共有十个类别的图像model.save_w…...

【Echarts图例点击事件】自定义Echarts图例legend点击事件(已解决)
目录先睹为快(效果)1、实现Echarts多条曲线2、点击echarts触发接口请求2.1 先默认隐藏部分数据2.2 自定义legend图例点击事件3、源码下载地址(解压即用)**【写在前面】**这下我又不得不说了,还是客户现场使用时想查询一…...

uniapp-首页配置
为了获取到后台服务器发来的数据,需要配置相应的网络地址。位置在main.js入口文件中。 import { $http } from escook/request-miniprogramuni.$http $http // 配置请求根路径 $http.baseUrl https://api-hmugo-web.itheima.net// 请求开始之前做一些事情 $http.…...

支持DDR5,超频更简单,小雕够给力,技嘉B760M小雕WIFI主板上手
目前13代酷睿已经全员集结了,其中全新的i5 13490F应该依然会备受欢迎,当然了,刚上市不久的13代酷睿价格方面还不是很有吸引力,好在12代酷睿在新一代主板上面依然可用,所以预算有限的朋友,完全可用继续使用1…...

fengMap 自定义dom 偏离实际位置;缩放时飘出地图所在区域
目录 一、问题 二、原因及解决方法 三、总结 一、问题 1.前人写了一份代码,很奇怪。使用 new fengmap.FMCompositeMarker添加的复合覆盖物位置是正常的,缩放的时候也是正常的,仍然处于地图内部;但是new fengmap.FMDomMarker添加…...

TryHackMe-黑我杯
黑我杯 相信我们大家在TryHackMe的日积月累都学到了不少东西,从纯萌新到oscp再到更高 我很高兴能将国内各thm玩家聚集到一起,构建一个更好的学习环境和氛围 本次娱乐分两场: Offensive Pentesting — 中等难度Junior Penetration — 容易难…...

【JAVA程序设计】【C00109】基于SSM(非maven)的员工工资管理系统
基于SSM(非maven)的员工工资管理系统项目简介项目获取开发环境项目技术运行截图项目简介 基于ssm框架非maven开发的企业工资管理系统共分为二个角色:系统管理员、员工 管理员角色包含以下功能: 系统后台登陆、管理员管理、员工信…...

《计算机原理》——HelloWorld.cpp如何运行的
学校《计算机原理》开课啦!特此开辟专栏,将一些知识作为笔记,记录下来。 前言 本篇博客知识点来源于educoder的相关题目 1. 相关知识 1.1 计算机语言 计算机语言是人与计算机之间通讯的语言,计算机语言包括编写计算机程序的字符…...
【面试题】在JS循环中使用await会怎么样?
前言这个问题是这样产生的?某天,在学习异步的知识遇到这样一道题:使用Promise的方式,每隔一秒输出数组中一个值const arr [1, 2, 3] arr.reduce((pre, cur) > {return pre.then(() > {returnnewPromise((resolve, rejec…...

Qt QMessageBox详解
文章目录一.QMessageBox介绍枚举属性函数二.QMessageBox的用法1.导入QMessage库2.弹窗提示3.提供选项的弹窗提示4.作为提示,报警,报错提示窗口一.QMessageBox介绍 文本消息显示框(message box)向用户发出情况警报信息并进一步解释警报或向用户提问&…...
Flutter之beamer路由入门指南
beamer路由入门指南 前言使用方法1、路由配置方式1路由配置方式2路由跳转测试现象前言 Beamer是一个很好用的路由组件,本文以beamer1.5.0版本进行说明,前面博主也介绍了其他路由组件 Flutter实战之go_router路由组件入门指南 、 Flutter之Fluro路由组件入门指南 Flutter之Ge…...

「基础篇」机器学习概览
文章目录1. 什么是机器学习2. 引入机器学习3. 应用场景4. 机器学习分类4.1. 有无人类监督4.2. 是否增量学习4.3. 泛化方式5. 主要挑战6. 测试与验证1. 什么是机器学习 机器学习(Machine Learning,ML)是一个研究领域,让计算机无需…...

揭秘可视化图探索工具 NebulaGraph Explore 是如何实现图计算的
前言 在可视化图探索工具 NebulaGraph Explorer 3.1.0 版本中加入了图计算工作流功能,针对 NebulaGraph 提供了图计算的能力,同时可以利用工作流的 nGQL 运行能力支持简单的数据读取,过滤及写入等数据处理功能。 本文将简单分享下 NebulaGr…...

移动架构43_什么是Jetpack
Android移动架构汇总 文章目录一 Android 开发框架演变1 MVC2 MVP3 MVVM二 什么是JetPack三 如何构建支持Jetpack项目一 Android 开发框架演变 1 MVC Model-View-Controller,模型-视图-控制器,Model负责数据管理,View负责UI显…...
TiDB的分布式事务原理探究
事务开启 获取全局授时作为startTS构建一个tikvTxn对象(包括snapshot)。 事务写 txn.Set方法本质上将kv值写入了一个内存缓存(即kv/memdb_buffer.go中的memDbBuffer)中。该内存kv数据库利用的是golevel提供的功能。 事务回滚 直接将tikvTxn的valid字段…...

【C语言】函数指针和指针函数
文章目录[TOC](文章目录)前言概述函数指针定义:使用:回调函数指针函数前言 今天学一下函数指针 提示:以下是本篇文章正文内容,下面案例可供参考 概述 函数指针:是一个指向函数的指针,在内存空间中存放的…...
Nodejs中npx简介和作用
一、npx简介npm从5.25.2版开始,增加了 npx 命令。方便了我在项目中使用全局包。二、安装Node安装后自带npm模块,可以直接使用npx命令。如果不能使用用,就要手动安装一下。npm install -g npx三、使用npx想要解决的主要问题,就是调…...

Matplotlib精品学习笔记001——绘制3D图形详解+实例讲解
3D图片更生动,或许在时间序列数据的展示上更胜一筹 想法: 学习3D绘图的想法来自科研绘图中。我从事的专业是古植物学,也就是和植物化石打交道。化石有三大信息:1.物种信息,也就是它的分类学价值;2.时间信息…...

学习ifconfig实战技巧,成为网络管理高手
文章目录前言一. ifconfig 命令介绍二. 语法格式及常用选项三. 参考案例3.1 显示网络设备信息3.2 启动和关闭指定的网卡3.3 对指定的网卡设备执行修改IP地址操作3.4 启动和关闭ARP协议3.5 使用ifconfig添加网卡总结前言 大家好,又见面了,我是沐风晓月&a…...
day38|70. 爬楼梯(进阶)、322. 零钱兑换、279.完全平方数
70. 爬楼梯(进阶) 假设你正在爬楼梯。需要 n 阶你才能到达楼顶。 每次你可以爬 1 或 2 个台阶。你有多少种不同的方法可以爬到楼顶呢? 示例 1: 输入:n 2 输出:2 解释:有两种方法可以爬到楼顶。 1. 1 阶 1 阶 2. 2…...

UE5 学习系列(二)用户操作界面及介绍
这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...

地震勘探——干扰波识别、井中地震时距曲线特点
目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波:可以用来解决所提出的地质任务的波;干扰波:所有妨碍辨认、追踪有效波的其他波。 地震勘探中,有效波和干扰波是相对的。例如,在反射波…...
React hook之useRef
React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

(二)原型模式
原型的功能是将一个已经存在的对象作为源目标,其余对象都是通过这个源目标创建。发挥复制的作用就是原型模式的核心思想。 一、源型模式的定义 原型模式是指第二次创建对象可以通过复制已经存在的原型对象来实现,忽略对象创建过程中的其它细节。 📌 核心特点: 避免重复初…...
【算法训练营Day07】字符串part1
文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接:344. 反转字符串 双指针法,两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...
Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!
一、引言 在数据驱动的背景下,知识图谱凭借其高效的信息组织能力,正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合,探讨知识图谱开发的实现细节,帮助读者掌握该技术栈在实际项目中的落地方法。 …...

云原生玩法三问:构建自定义开发环境
云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

排序算法总结(C++)
目录 一、稳定性二、排序算法选择、冒泡、插入排序归并排序随机快速排序堆排序基数排序计数排序 三、总结 一、稳定性 排序算法的稳定性是指:同样大小的样本 **(同样大小的数据)**在排序之后不会改变原始的相对次序。 稳定性对基础类型对象…...

宇树科技,改名了!
提到国内具身智能和机器人领域的代表企业,那宇树科技(Unitree)必须名列其榜。 最近,宇树科技的一项新变动消息在业界引发了不少关注和讨论,即: 宇树向其合作伙伴发布了一封公司名称变更函称,因…...