matlab从pytorch中导入LeNet-5网络框架
文章目录
- 一、Pytorch的LeNet-5网络准备
- 二、保存用于导入matlab的model
- 三、导入matlab
- 四、用matlab训练这个导入的网络
这里演示从pytorch的LeNet-5网络导入到matlab中进行训练用。
一、Pytorch的LeNet-5网络准备
根据LeNet-5的结构图,我们可以写如下结构
import torch
import torch.nn as nnclass LeNet5(nn.Module):def __init__(self, num_classes=10):super(LeNet5, self).__init__()self.feature_extractor = nn.Sequential(# C1: Conv(1→6), 输出 28x28 → 6x28x28nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),nn.BatchNorm2d(6),nn.ReLU(inplace=True),# S2: MaxPool 2x2, 输出 6x14x14nn.MaxPool2d(kernel_size=2, stride=2),# C3: Conv(6→16), 输出 16x10x10nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),nn.BatchNorm2d(16),nn.ReLU(inplace=True),# S4: MaxPool 2x2, 输出 16x5x5nn.MaxPool2d(kernel_size=2, stride=2),# C5: Conv(16→120), 输出 120x1x1(接近 flatten)nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),nn.BatchNorm2d(120),nn.ReLU(inplace=True))self.classifier = nn.Sequential(nn.Flatten(), # [batch, 120]nn.Linear(120, 84),nn.BatchNorm1d(84),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(84, num_classes))def forward(self, x):x = self.feature_extractor(x)x = self.classifier(x)return xif __name__ == "__main__":model = LeNet5()model.eval()# 示例输入:MNIST 的图像大小 [1, 1, 28, 28]example_input = torch.randn(1, 1, 28, 28)# Tracingtraced_model = torch.jit.trace(model, example_input)# 保存traced_model.save("traced_lenet5.pt")print("✅ traced_lenet5.pt 已成功保存!")
二、保存用于导入matlab的model
在上面的代码中,我们有几行是产生trace model的,即

torch.jit.trace() 是 PyTorch 的一种 静态图(Static Graph)转换方法,它会:
- 运行一次前向传播(forward),记录下所有的张量操作;
- 然后构建一个不可变的计算图(graph),这个图就是所谓的 trace model。
保存这个model后,我们就得到了traced_lenet5.pt这个文件。
三、导入matlab
导入matlab可以通过APPS里的Deep Network Designer,如下图

然后通过From PyTorch这个地方,导入刚才保存的网络结构

点开From PyTorch后, 我们可以复制刚才保存的traced_lenet5.pt这个文件的绝对路径用于导入,如下图

然后,import就会有,如下结果

然后,点击红色方框那部分,进行一下输入尺寸的修改

导入的这个网络框架,我们还要在末尾段加入softmax层,这个层在原pytorch框架里没写

这样,我们就完成了LeNet5从Pytorch里导入到matlab了。接着我们可以通过Analyze按钮分析这个网络,如下图

没有问题后,我们就可以Export这个网络到工作区了,输出的网络自动命名为net_1。

四、用matlab训练这个导入的网络
训练的代码如下
% 创建一个图像数据存储对象 `imds`,用于从名为 "DigitsData" 的文件夹中加载图像数据
imds = imageDatastore("DigitsData", ...IncludeSubfolders=true, ... % 指定在加载数据时包含子文件夹中的图像LabelSource="foldernames"); % 使用子文件夹的名称作为图像的标签(自动分类)% 获取数据集中所有的类别名称(即文件夹名),并将其存储在变量 classNames 中
classNames = categories(imds.Labels); % 将 imds.Labels%%
% 使用 splitEachLabel 函数将原始图像数据集 imds 随机划分为训练集、验证集和测试集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");% 设置用于网络训练的选项,这里使用的是随机梯度下降动量法(SGDM)
% 最大训练轮数(epoch):训练过程中将整个训练集完整迭代 4 次
% 指定验证数据集,用于在训练过程中评估模型的泛化能力
% 每训练 30 个 mini-batch 执行一次验证评估
% 在训练过程中显示实时图形界面,包括损失值和准确率的变化曲线
% 指定训练期间关注的评估指标为准确率(accuracy)
% 禁止在命令行窗口输出详细训练信息(安静模式)
options = trainingOptions("sgdm", ... MaxEpochs = 4, ... ValidationData = imdsValidation, ... ValidationFrequency = 30, ... Plots = "training-progress", ... Metrics = "accuracy", ... Verbose = false); % 使用 trainnet 函数对神经网络进行训练
net = trainnet(imdsTrain, net_1, "crossentropy", options);%%
% 使用 testnet 函数对训练好的神经网络进行验证,并评估其准确率
accuracy = testnet(net, imdsTest, "accuracy");%%
% 对测试集进行批量预测,输出每个图像对应的类别得分(概率)
scores = minibatchpredict(net, imdsTest);% 将得分(scores)转换为类别标签,使用 classNames 映射到原始类名
YTest = scores2label(scores, classNames);% 获取测试集图像的总数量
numTestObservations = numel(imdsTest.Files);% 从测试集中随机选取 9 个样本用于可视化
idx = randi(numTestObservations, 9, 1);% 创建一个新的图形窗口
figure
tiledlayout("flow") % 使用自动流式布局排列子图(tiled layout)% 遍历 9 张图像,显示图像并在标题中标注预测类别
for i = 1:9nexttile % 在下一个网格位置准备绘图img = readimage(imdsTest, idx(i)); % 读取第 idx(i) 张图像imshow(img) % 显示图像title("Predicted Class: " + string(YTest(idx(i)))) % 设置标题,显示预测类别
end
上面用到的数据集是0-9的数字图片,如下图

训练的详细信息如下

预测结果显示

相关文章:
matlab从pytorch中导入LeNet-5网络框架
文章目录 一、Pytorch的LeNet-5网络准备二、保存用于导入matlab的model三、导入matlab四、用matlab训练这个导入的网络 这里演示从pytorch的LeNet-5网络导入到matlab中进行训练用。 一、Pytorch的LeNet-5网络准备 根据LeNet-5的结构图,我们可以写如下结构 import…...
淘宝商品数据爬取与分析
淘宝商品数据爬取与分析是一个涉及网络爬虫技术和数据分析方法的过程,以下是其主要步骤: 数据爬取 确定爬取目标:明确要爬取的淘宝商品类别、具体商品名称或关键词等,例如想要分析智能手机市场,就以 “智能手机” 为…...
Spring Boot向Vue发送消息通过WebSocket实现通信
注意:如果后端有contextPath,如/app,那么前端访问的url就是ip:port/app/ws 后端实现步骤 添加Spring Boot WebSocket依赖配置WebSocket端点和消息代理创建控制器,使用SimpMessagingTemplate发送消息 前端实现步骤 安装sockjs-…...
Django4.0的快速查询以及分页
1. filter 方法 filter 是 Django ORM 中最常用的查询方法之一。它用来根据给定的条件过滤查询集并返回满足条件的对象。 articles Article.objects.all() # 使用 SearchFilter 进行搜索 search_param request.query_params.get(search, None) author_id request.query_pa…...
LangChain/Eliza框架在使用场景上的异同,Eliza通过配置实现功能扩展的例子
LangChain与Eliza框架的异同分析 一、相同点 模块化架构设计 两者均采用模块化设计,支持灵活扩展和功能组合。LangChain通过Chains、Agents等组件实现多步骤任务编排,Eliza通过插件系统和信任引擎实现智能体功能的动态扩展。模块化特性降低…...
用spring-webmvc包实现AI(Deepseek)事件流(SSE)推送
前后端: Spring Boot Angular spring-webmvc-5.2.2包 代码片段如下: 控制层: GetMapping(value "/realtime/page/ai/sse", produces MediaType.TEXT_EVENT_STREAM_VALUE)ApiOperation(value "获取告警记录进行AI分析…...
MusicMint ,AI音乐生成工具
MusicMint是什么 MusicMint 是一款强大的人工智能音乐创作工具,旨在帮助用户轻松制作个性化的音乐作品。借助先进的 AI 技术,用户只需输入简短的描述或选择心仪的音乐风格,便能迅速生成独特的歌曲。该平台支持多种音乐风格,包括流…...
嵌入式学习笔记——SPI协议
SPI协议详解 SPI协议概述SPI接口信号介绍SPI通信模式SPI的通信流程SPI的优缺点优点缺点 SPI在STM32上的实现SPI引脚配置SPI初始化代码(STM32F10x)SPI主设备发送和接收数据SPI从设备数据处理 总结 SPI协议概述 SPI(Serial Peripheral Interfa…...
网络编程—Socket套接字(UDP)
上篇文章: 网络编程—网络概念https://blog.csdn.net/sniper_fandc/article/details/146923380?fromshareblogdetail&sharetypeblogdetail&sharerId146923380&sharereferPC&sharesourcesniper_fandc&sharefromfrom_link 目录 1 概念 2 Soc…...
视频设备轨迹回放平台EasyCVR综合智能化,搭建运动场体育赛事直播方案
一、背景 随着5G技术的发展,体育赛事直播迎来了新的高峰。无论是NBA、西甲、英超、德甲、意甲、中超还是CBA等热门赛事,都是值得记录和回放的精彩瞬间。对于体育迷来说,选择观看的平台众多,但是作为运营者,搭建一套体…...
AIGC实战——CycleGAN详解与实现
AIGC实战——CycleGAN详解与实现 0. 前言1. CycleGAN 基本原理2. CycleGAN 模型分析3. 实现 CycleGAN小结系列链接 0. 前言 CycleGAN 是一种用于图像转换的生成对抗网络(Generative Adversarial Network, GAN),可以在不需要配对数据的情况下将一种风格的图像转换成…...
VS2022远程调试Linux程序
一、 1、VS2022安装参考 VS Studio2022安装教程(保姆级教程)_visual studio 2022-CSDN博客 注意:勾选的时候,要勾选下方的选项,才能调试Linux环境下运行的程序! 2、VS2022远程调试Linux程序测试 原文参…...
345-java人事档案管理系统的设计与实现
345-java人事档案管理系统的设计与实现 项目概述 本项目为基于Java语言的人事档案管理系统,旨在帮助企事业单位高效管理员工档案信息,实现档案的电子化、自动化管理。系统涵盖了员工信息的录入、查询、修改、删除等功能,同时具备权限控制和…...
【Linux系统编程】进程概念,进程状态
目录 一,操作系统(Operator System) 1-1概念 1-2设计操作系统的目的 1-3核心功能 1-4系统调用和库函数概念 二,进程(Process) 2-1进程概念与基本操作 2-2task_struct结构体内容 2-3查看进程 2-4通…...
优选算法的妙思之流:分治——快排专题
专栏:算法的魔法世界 个人主页:手握风云 目录 一、快速排序 二、例题讲解 2.1. 颜色分类 2.2. 排序数组 2.3. 数组中的第K个最大元素 2.4. 库存管理 III 一、快速排序 分治,简单理解为“分而治之”,将一个大问题划分为若干个…...
# 实时人脸识别系统:基于 OpenCV 和 Python 的实现
实时人脸识别系统:基于 OpenCV 和 Python 的实现 在当今数字化时代,人脸识别技术已经广泛应用于各种场景,从手机解锁到安防监控,再到智能门禁系统。今天,我将通过一个完整的代码示例,详细讲解如何使用 Pyt…...
Mysql 中 ACID 背后的原理
在 MySQL 中,ACID 是事务处理的核心原则,用于保证数据库在执行事务时的可靠性、数据一致性和稳定性。ACID 是四个关键特性的首字母缩写,分别是:Atomicity(原子性)、Consistency(一致性ÿ…...
wx206基于ssm+vue+uniapp的优购电商小程序
开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…...
React编程高级主题:错误处理(Error Handling)
文章目录 **5.2 错误处理(Error Handling)概述****5.2.1 onErrorReturn / onErrorResume(错误回退)****1. onErrorReturn:提供默认值****2. onErrorResume:切换备用数据流** **5.2.2 retry / retryWhen&…...
ubuntu20.04升级成ubuntu22.04
命令行 sudo do-release-upgrade 我是按提示输入y确认操作,也可以遇到配置文件冲突时建议选择N保留当前配置...
SpringCloud(25)——Stream介绍
1.场景描述 当我们的分布式系统建设到一定程度了,或者服务间是通过异步请求来通讯的,那么我们避免不了使用MQ来解决问题。 假如公司内部进行了业务合并或者整合,需要服务A和服务B通过MQ的方式进行消息传递,而服务A用的是RabbitMQ&…...
OrangePi5Plus开发板不能正确识别USB 3.0 设备 (绿联HUB和Camera)
1、先插好上电(可正确识别) 2、上电开机后插入USB 3.0 设备,报错如下,只能检测到USB2.0--480M,识别不到USB3.0-5Gbps,重新插拔也不行 Apr 4 21:30:00 orangepi5plus kernel: [ 423.575966] usb 5-1: re…...
centos8上实现lvs集群负载均衡dr模式
1.前言 个人备忘笔记,欢迎探讨。 centos8上实现lvs集群负载均衡nat模式 centos8上实现lvs集群负载均衡dr模式 之前写过一篇lvs-nat模式。实验起来相对顺利。dr模式最大特点是响应报文不经调度器,而是直接返回客户机。 dr模式分同网段和不同网段。同…...
uniapp如何接入星火大模型
写在前面:最近的ai是真的火啊,琢磨了一下,弄个uniappx的版本开发个东西玩一下,想了想不知道放啥内容,突然觉得deepseek可以接入,好家伙,一对接以后发现这是个付费的玩意,我穷&#x…...
三、FFmpeg学习笔记
FFmpeg是一个开源、跨平台的多媒体处理框架,能够实现音视频的录制、转换、剪辑、编码、解码、流媒体传输、过滤与后期处理等几乎所有常见的多媒体操作。其强大之处在于几乎支持所有的音视频格式、编解码器和封装格式,是业界公认的“瑞士军刀”。 FFmp…...
Linux常用基础命令应用
目录 一、文件与目录操作(12个核心命令) 1. pwd - 显示当前路径 2. ls - 查看目录内容 3. cd - 切换目录 4. mkdir - 创建目录 5. touch - 创建文件 6. cp - 复制文件/目录 7. mv - 移动…...
【python中级】关于Cython 的源代码pyx的说明
【python中级】关于Cython 的源代码pyx的说明 1.背景2.编译3.语法1.背景 Cython 是一个编程语言和工具链,用于将 Python 代码(或类 Python 的代码)编译成 C 语言,再进一步生成高性能的 Python 扩展模块(.so 或 .pyd 文件)。 在 Python 中,.pyx 文件是 Cython 的源代码文…...
第十八节课:Python编程基础复习
课程复习 前三周核心内容回顾 第一周:Python基本语法元素 基础语法:缩进、注释、变量命名、保留字数据类型:字符串、整数、浮点数、列表程序结构:赋值语句、分支语句(if)、函数输入输出:inpu…...
MySQL vs MSSQL 对比
在企业数据库管理系统中,MySQL 和 Microsoft SQL Server(MSSQL)是最受欢迎的两大选择。MySQL 是一款开源的关系型数据库管理系统(RDBMS),由 MySQL AB 开发,现归属于 Oracle 公司。而 MSSQL 是微…...
解决LeetCode“使括号有效的最少添加”问题
目录 问题描述 解题思路 复杂度分析 示例分析 暴力替换“不讲码德” 总结 问题描述 给定一个仅由 ( 和 ) 组成的字符串 s,我们需要通过添加最少数量的括号(( 或 ))使得字符串有效。有效字符串需满足: 空字符串是有效的。 …...
