从 PyTorch 到 TensorFlow Lite:模型训练与推理
一、方案介绍
- 研发阶段:利用 PyTorch 的动态图特性进行快速原型验证,快速迭代模型设计。
- 灵活性与易用性:PyTorch 是一个非常灵活且易于使用的深度学习框架,特别适合研究和实验。其动态计算图特性使得模型的构建和调试变得更加直观,开发者可以在运行时修改模型结构。
- 快速原型开发:许多研究人员和开发者选择 PyTorch 进行模型训练,因为它支持快速原型开发和灵活的模型设计,能够快速验证新想法并进行迭代。
- 转换阶段:将训练好的模型通过 TorchScript 导出为 ONNX 格式,再转换为 TensorFlow 格式,最后生成 TFLite 模型。
- 专为移动和嵌入式设备优化:TensorFlow Lite 是专为移动和嵌入式设备设计的推理框架,能够在资源有限的环境中高效运行模型,确保在各种设备上实现实时推理。
- 支持模型量化和优化:TFLite 支持模型量化和优化,能够显著减小模型大小并提高推理速度,适合在手机、边缘设备等场景中使用。这使得开发者能够在不牺牲准确度的情况下,提升模型的运行效率。
- 部署阶段:将 TFLite 模型集成到 Android、iOS 或嵌入式系统中,确保模型能够在目标设备上高效运行。
- 内存和计算资源的优化:在推理阶段,使用 TFLite 可以减少内存占用和计算资源消耗,尤其是在移动设备和嵌入式系统上。这对于需要长时间运行的应用尤为重要,可以延长设备的电池寿命。
- 多种优化技术:TFLite 提供了多种优化技术,如模型量化(将浮点数转换为整数),可以进一步提高推理速度并降低功耗。这使得在实时应用中能够实现更快的响应时间,提升用户体验。
二、实例1:CNN模型的转换
注:python 版本为3.10
2.1 pytorch模型训练
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 检查是否支持 MPS
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")# 定义 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 初始化模型、损失函数和优化器
model = CNNModel().to(device) # 将模型移动到 MPS 设备
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
for epoch in range(20):for images, labels in train_loader:images, labels = images.to(device), labels.to(device) # 将数据移动到 MPS 设备optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/20], Loss: {loss.item():.6f}')# 保存模型
torch.save(model.state_dict(), 'cnn_mnist.pth')
print("Model saved as cnn_mnist.pth")
2.2 pth模型转onnx 并验证一致性
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn# 定义 CNN 模型
class CNNModel(nn.Module):def __init__(self):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = nn.functional.relu(self.conv1(x))x = nn.functional.max_pool2d(x, 2)x = nn.functional.relu(self.conv2(x))x = nn.functional.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = nn.functional.relu(self.fc1(x))x = self.fc2(x)return x# 加载模型并进行推理
model = CNNModel()
model.load_state_dict(torch.load('cnn_mnist.pth', weights_only=True)) # 加载保存的模型权重
model.eval() # 设置为评估模式# 创建一个示例输入
dummy_input = torch.randn(1, 1, 28, 28) # MNIST 图像的形状# 使用 PyTorch 进行推理
with torch.no_grad():pytorch_output = model(dummy_input)# 导出模型为 ONNX 格式
torch.onnx.export(model, dummy_input, 'cnn_mnist.onnx', export_params=True, opset_version=11)
print("Model exported to cnn_mnist.onnx")# 使用 ONNX 进行推理
onnx_model = onnx.load('cnn_mnist.onnx')
ort_session = ort.InferenceSession('cnn_mnist.onnx')# 准备输入数据
onnx_input = dummy_input.numpy() # 将 PyTorch 张量转换为 NumPy 数组
onnx_input = onnx_input.astype(np.float32) # 确保数据类型为 float32# 使用 ONNX 进行推理
onnx_output = ort_session.run(None, {ort_session.get_inputs()[0].name: onnx_input})# 比较输出
pytorch_output_np = pytorch_output.numpy() # 将 PyTorch 输出转换为 NumPy 数组
onnx_output_np = onnx_output[0] # ONNX 输出是一个列表,取第一个元素# 检查输出是否一致
if np.allclose(pytorch_output_np, onnx_output_np, atol=1e-5):print("The outputs are consistent between PyTorch and ONNX.")
else:print("The outputs are NOT consistent between PyTorch and ONNX.")# 打印输出结果
print("PyTorch output:", pytorch_output_np)
print("ONNX output:", onnx_output_np)
The outputs are consistent between PyTorch and ONNX.
PyTorch output: [[ -1.5153266 -11.934659 0.5428004 -16.058285 -3.6684208 -4.596178-14.53585 -3.3159208 -5.7872214 -5.3301578]]
ONNX output: [[ -1.5153263 -11.934658 0.5428015 -16.058285 -3.66842 -4.5961757-14.53585 -3.3159204 -5.787223 -5.3301597]]
2.3 onnx模型转tflite
参考这个项目:onnx2tflite
git clone https://github.com/MPolaris/onnx2tflite.git
cd onnx2tflite
conda install tensorflow=2.11.0
pip install .
python -m onnx2tflite --weights ../pth2onnx/cnn_mnist.onnx
2.4 onnx模型和tflite一致性验证
import numpy as np
import onnxruntime as ort
import tensorflow as tf# 1. 加载 ONNX 模型
onnx_model_path = 'cnn_mnist.onnx'
onnx_session = ort.InferenceSession(onnx_model_path)# 2. 加载 TFLite 模型
tflite_model_path = 'cnn_mnist.tflite'
tflite_interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
tflite_interpreter.allocate_tensors()# 3. 准备输入数据
# 假设输入数据是 MNIST 数据集的一部分,形状为 (1, 28, 28, 1)
input_data = np.random.rand(1, 28, 28, 1).astype(np.float32) # Keras 输入
input_data_onnx = input_data.transpose(0, 3, 1, 2) # 转换为 ONNX 输入格式 (1, 1, 28, 28)# 4. 使用相同的输入数据进行推理# ONNX 模型推理
onnx_input_name = onnx_session.get_inputs()[0].name
onnx_output = onnx_session.run(None, {onnx_input_name: input_data_onnx})[0]
print("ONNX Output:", onnx_output)# TFLite 模型推理
tflite_input_details = tflite_interpreter.get_input_details()
tflite_output_details = tflite_interpreter.get_output_details()# 检查 TFLite 输入形状
print("TFLite Input Shape:", tflite_input_details[0]['shape'])# 设置 TFLite 输入
# 确保输入数据的形状与 TFLite 模型的输入要求一致
tflite_interpreter.set_tensor(tflite_input_details[0]['index'], input_data)
tflite_interpreter.invoke()
tflite_output = tflite_interpreter.get_tensor(tflite_output_details[0]['index'])
print("TFLite Output:", tflite_output)# 5. 比较输出结果
# 计算输出的差异
onnx_difference = np.abs(onnx_output - tflite_output)# 输出结果
print("Difference (ONNX vs TFLite):", onnx_difference)# 检查是否一致
if np.all(onnx_difference < 1e-5): # 设定一个阈值print("The outputs are consistent between ONNX and TFLite models.")
else:print("The outputs are not consistent between ONNX and TFLite models.")
ONNX Output: [[ -3.7372704 -6.5073314 -1.1807165 -2.4232314 -10.638929 2.2660115-4.5868526 -2.7494073 -0.5609715 -6.331989 ]]
TFLite Input Shape: [ 1 28 28 1]
TFLite Output: [[ -3.7372704 -6.5073323 -1.180716 -2.4232314 -10.6389282.2660117 -4.5868545 -2.7494078 -0.56097114 -6.331988 ]]
Difference (ONNX vs TFLite): [[0.0000000e+00 9.5367432e-07 4.7683716e-07 0.0000000e+00 9.5367432e-072.3841858e-07 1.9073486e-06 4.7683716e-07 3.5762787e-07 9.5367432e-07]]
The outputs are consistent between ONNX and TFLite models.
相关文章:

从 PyTorch 到 TensorFlow Lite:模型训练与推理
一、方案介绍 研发阶段:利用 PyTorch 的动态图特性进行快速原型验证,快速迭代模型设计。 灵活性与易用性:PyTorch 是一个非常灵活且易于使用的深度学习框架,特别适合研究和实验。其动态计算图特性使得模型的构建和调试变得更加直…...
C++ 17 正则表达式
正则表达式不是C语言的一部分,这里仅做简单的介绍。 将这项技术引进,在 』的讨论 正则表达式描述了一种字符串匹配的模式。一般使用正则表达式主要是实现下面三个需求: 1,检查一个串是否包含某种形式的子串; 2,将匹配的子串替换&a…...

【存储基础】存储设备和服务器的关系和区别
文章目录 1. 存储设备和服务器的区别2. 客户端访问数据路径场景1:经过服务器处理场景2:客户端直连 3. 服务器作为"中转站"的作用 刚开始接触存储的时候,以为数据都是存放在服务器上的,服务器和存储设备是一个东西&#…...
kernel内核和driver驱动的区别
“kernel”和“driver”虽然都跟操作系统和硬件有关,但它们指的是不同的东西。 1. Kernel(内核) 定义:操作系统的核心组件,是操作系统中负责管理系统资源和硬件的最底层软件。 职责: 管理CPU调度ÿ…...

5.29打卡
浙大疏锦行 DAY 38 Dataset和Dataloader类 知识点回顾: 1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法) 2. Dataloader类 3. minist手写数据集的了解 作业:了解下cifar数据集,尝试获取其中一张图…...

【黑马程序员uniapp】项目配置、请求函数封装
黑马程序员前端项目uniapp小兔鲜儿微信小程序项目视频教程,基于Vue3TsPiniauni-app的最新组合技术栈开发的电商业务全流程_哔哩哔哩_bilibili 参考 有代码,还有app、h5页面、小程序的演示 小兔鲜儿-vue3ts-uniapp-一套代码多端部署: 小兔鲜儿-vue3ts-un…...
ios tableview吸顶
由于项目需要实现一个上滑吸顶的效果,网上也看到有很多种方式实现,但是如果加上下拉刷新的功能会导致界面异常,还有第三方库实现方式库,太繁琐了,下面是我的实现方式,效果如下: tablevie滑动吸顶…...

PyTorch——DataLoader的使用
batch_size, drop_last 的用法 shuffle shuffleTrue 各批次训练的图像不一样 shuffleFalse 在第156step顺序一致...
【Python 进阶2】抽象方法和实例调用方法
抽象方法和实例调用方法 对比表格: 特性抽象方法 (forward)实例调用方法 (call)定义方式abc.abstractmethod 装饰器特殊方法名 __call__调用方式不能直接调用,必须通过子类实现可以直接调用对象:controller(attn, ...)实现要求必须由子类实…...
第1章:走进Golang
第1章:走进Golang 一、Golang简介 Go语言(又称Golang)是由Google的Robert Griesemer、Rob Pike及Ken Thompson开发的一种开源编程语言。它诞生于2007年,2009年11月正式开源。Go语言的设计初衷是为了在不损失应用程序性能的情况下…...

Predixy的docker化
概述 当前已有一套redis cluster的集群,但是fs中的hiredis只能配置单实例redis。 AI了一下方案,可以使用redis的proxy组件来实现从hiredis到redis cluster的互通。 代码地址:https://github.com/joyieldInc/predixy Predixy特性介绍&…...

C++ 之 多态 【虚函数表、多态的原理、动态绑定与静态绑定】
目录 前言 1.多态的原理 1.1虚函数表 1.2派生类中的虚表 1.3虚函数、虚表存放位置 1.4多态的原理 1.5多态条件的思考 2.动态绑定与静态绑定 3.单继承和虚继承中的虚函数表 3.1单继承中的虚函数表 3.2多继承(非菱形继承)中的虚函数表 4.问答题 前言 需要声明的&#x…...

【JavaWeb】Maven、Servlet、cookie/session
目录 5. Maven6. Servlet6.1 Servlet 简介6.2 HelloServlet6.3 Servlet原理6.4 Mapping( **<font style"color:rgb(44, 44, 54);">映射 ** )问题6.5 ServletContext6.6 HttpServletResponse<font style"color:rgb(232, 62, 140);background-color:rgb(…...
[蓝桥杯]阶乘求值【省模拟赛】
问题描述 给定 nn,求 n!n! 除以 10000000071000000007 的余数。 其中 n!n! 表示 nn 的阶乘,值为从 11 连乘到 nn 的积,即 n!123…nn!123…n。 输入格式 输入一行包含一个整数 nn。 输出格式 输出一行,包含一个整数ÿ…...
鸿蒙OSUniApp微服务架构实践:从设计到鸿蒙部署#三方框架 #Uniapp
UniApp微服务架构实践:从设计到鸿蒙部署 引言 在最近的一个大型跨平台项目中,我们面临着一个有趣的挑战:如何在UniApp框架下构建一个可扩展的微服务架构,并确保其在包括鸿蒙在内的多个操作系统上流畅运行。本文将分享我们的实践…...

Rust 编程实现猜数字游戏
文章目录 编程实现猜数字游戏游戏规则创建新项目默认代码处理用户输入代码解析 生成随机数添加依赖生成逻辑 比较猜测值与目标值类型转换 循环与错误处理优化添加循环优雅处理非法输入 最终完整代码核心概念总结 编程实现猜数字游戏 我们使用cargo和rust实现一个经典编程练习…...

关于神经网络中的激活函数
这篇博客主要介绍一下神经网络中的激活函数以及为什么要存在激活函数。 首先,我先做一个简单的类比:激活函数的作用就像给神经网络里的 “数字信号” 加了一个 “智能阀门”,让机器能学会像人类一样思考复杂问题。 没有激活i函数的神经网络…...

CentOS_7.9 2U物理服务器上部署系统简易操作步骤
近期单位网站革新,鉴于安全加固,计划将原有Windows环境更新到Linux-CentOS 7.9,这版本也没的说(绝)了(版)官方停止更新,但无论如何还是被sisi的牵挂着这一大批人,毕竟从接…...
第十三篇:MySQL 运维自动化与可观测性建设实践指南
本篇重点介绍 MySQL 运维自动化的关键工具与流程,深入实践如何构建高效可观测体系,实现数据库系统的持续稳定运行与故障快速响应。 一、为什么需要 MySQL 运维自动化与可观测性? 运维挑战: 手动备份容易遗漏或失败; …...

短视频平台差异视角下开源AI智能名片链动2+1模式S2B2C商城小程序的适配性研究——以抖音与快手为例
摘要 本文以抖音与快手两大短视频平台为研究对象,从用户群体、内容生态、推荐逻辑三维度分析其差异化特征,并探讨开源AI智能名片链动21模式与S2B2C商城小程序在平台适配中的创新价值。研究发现,抖音的流量中心化机制与优质内容导向适合品牌化…...
HTTP 如何升级成 HTTPS
有一个自己的项目需要上线,域名解析完成后,发现只能使用 http 协议,这在浏览器上会限制,提示用户不安全,所以需要把 HTTP 升级成 HTTPS 协议,但又不想花钱。 前提条件: 已经配置好 Nginx 服务器…...

【笔记】Windows 下载并安装 ChromeDriver
以下是 在 Windows 上下载并安装 ChromeDriver 的笔记: ✅ Windows 下载并安装 ChromeDriver 1️⃣ 确认 Chrome 浏览器版本 打开 Chrome 浏览器 点击右上角 ︙ → 帮助 → 关于 Google Chrome 记下版本号,例如:114.0.5735.199 2️⃣ 下载…...

Spark-Core Project
RDD转换算子总结 RDD转换算子分为Value类型、双Value类型和Key - Value类型。 1、Value类型 map:对数据逐条映射转换,可改变数据类型或值。如 dataRDD.map(num > num * 2 运行结果: 2)mapPartitions:以分区为单位处…...
SQL 中的 `CASE WHEN` 如何使用?
✅ SQL 中的 CASE WHEN 如何使用? 一、CASE WHEN 是什么? CASE WHEN 是 SQL 中用于实现 条件判断 的表达式,功能类似于 if-else 或 switch-case,可用于 SELECT、WHERE、ORDER BY 等子句中。 go专栏:https://duoke360.com/tutorial/path/golang 二、语法格式 1. 简单 C…...

Wireshark 使用教程:让抓包不再神秘
一、什么是 tshark? tshark 是 Wireshark 的命令行版本,支持几乎所有 Wireshark 的核心功能。它可以用来: 抓包并保存为 pcap 文件 实时显示数据包信息 提取指定字段进行分析 配合 shell 脚本完成自动化任务 二、安装与验证 Kali Linux…...

JWT安全:接收无签名令牌.【签名算法设置为none绕过验证】
JWT安全:假密钥【签名随便写实现越权绕过.】 JSON Web 令牌 (JWT)是一种在系统之间发送加密签名 JSON 数据的标准化格式。理论上,它们可以包含任何类型的数据,但最常用于在身份验证、会话处理和访问控制机制中发送有关用户的信息(“声明”)。…...
什么算得到?什么又算失去?
目录 **一、什么是“得到”?****二、什么是“失去”?****三、得到与失去的悖论****四、如何超越得失二元论?****五、一个更本质的答案** 关于“得到”与“失去”的界定,本质上是对存在状态和主观认知的辩证思考。这两者并非绝对&a…...

白银价格查询接口如何用Java进行调用?
一、什么是白银价格查询接口? 它聚焦于上海黄金交易所、上海期货交易所等权威市场,精准提供白银价格行情数据,助力用户实时把握市场脉搏,做出明智的投资决策。 二、应用场景 分析软件:金融类平台可以集成本接口&…...

FreeBSD 14.3 候选版本附带 Docker 镜像和关键修复
新的月份已经到来,FreeBSD 14.3 候选发布版 1 现已开放测试,它带来了一些您可能会觉得有用的更新,特别是如果您对Docker容器感兴趣的话。RC1 版本中一个非常受欢迎的改进是,FreeBSD 项目已开始将官方开放容器计划 (OCI) 镜像发布到…...
NodeJS全栈WEB3面试题——P6安全与最佳实践
🔐 6.1 如何防范重放攻击、私钥泄露、钓鱼签名? ✅ 重放攻击(Replay Attack)防范: 引入 nonce:每次登录或交易签名都携带唯一 nonce; 链 ID 检查:在签名中加入特定链 ID࿰…...