1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)
1、卷积神经网络的手写数字旋转角度预测原理及流程
基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下是基于MATLAB的手写数字旋转角度预测的原理和流程:
原理:
-
数据准备:首先,准备一个包含手写数字图像和其对应标签(即旋转角度)的数据集。这些图像可以是MNIST数据集的手写数字。
-
模型建立:构建一个CNN模型,包括卷积层、池化层、全连接层等,来学习手写数字图像的特征并预测它们的旋转角度。
-
训练模型:利用准备好的训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数以最小化预测与真实标签之间的误差。
-
模型评估:使用测试数据集对训练好的模型进行评估,计算模型的准确率或其他性能指标,以评估其在预测手写数字旋转角度方面的性能。
流程:
-
加载数据集:在MATLAB中加载手写数字图像数据集,并对图像进行预处理和标签处理,以便输入到CNN模型中。
-
构建CNN模型:使用MATLAB深度学习工具箱中的函数(如
convolution2dLayer
、maxPooling2dLayer
、fullyConnectedLayer
、classificationLayer
)构建一个适合手写数字旋转角度预测的CNN模型。 -
定义训练选项:设置训练选项,包括优化器类型、学习率、最大训练轮数等。
-
训练模型:使用训练数据集对CNN模型进行训练,通过调用
trainNetwork
函数并传入训练数据和训练选项来完成训练过程。 -
评估模型:使用测试数据集对训练好的模型进行评估,计算准确率等性能指标。
-
预测手写数字的旋转角度:最后,使用训练好的模型对新的手写数字图像进行预测,得到其旋转角度的预测结果。
这是基于卷积神经网络的手写数字旋转角度预测的基本原理和流程。
2、卷积神经网络的手写数字旋转角度预测案例说明
1)解决问题
卷积神经网络来预测手写数字的旋转角度
2)技术方案
回归任务涉及预测连续数值而不是离散类标签,回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。
3、加载数据
1)数据说明
数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
2)加载数据代码
说明:变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。
load DigitsDataTrain
load DigitsDataTest
3)显示训练集代码
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);
视图效果
4)数据集划分代码
说明:使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);
4、检查数据归一化
1)归一化说明
训练神经网络时,确保数据在网络的所有阶段均归一化。
对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1
2)绘制响应的分布代码
说明:响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")
视图效果
5、定义神经网络架构
1)神经网络架构说明
对于图像输入,指定一个图像输入层。
指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
2)神经网络架构代码
numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];
6、指定训练选项
1)指定训练选项说明
使用Experiment Manager。
将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
在图中显示训练进度并监控均方根误差。
2)指定训练选项代码
miniBatchSize = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);
7、训练神经网络
1)训练神经网络说明
使用 trainnet 函数训练神经网络。
对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
2)训练神经网络代码
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
视图效果
8、测试网络
1)测试网络说明
基于测试数据评估准确度来测试网络性能。
使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
2)测试网络代码
YTest = minibatchpredict(net,XTest);
3)计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
4)散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")
视图效果
9、使用新数据进行预测
1)测试说明
使用 predict 函数并使用神经网络对第一个测试图像进行预测
2)测试代码
X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)
10、总结
基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题,通过使用MATLAB深度学习工具箱可以比较方便地实现。下面是对这一任务的总结:
总结要点:
-
数据准备:准备包含手写数字图像和对应旋转角度标签的数据集,如MNIST数据集。
-
模型建立:构建卷积神经网络(CNN)模型,通过卷积层、池化层、全连接层等结构来学习手写数字图像的特征和预测旋转角度。
-
训练模型:使用训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数,最小化预测与真实标签的误差。
-
模型评估:使用测试数据集对训练好的模型进行评估,计算准确率或其他性能指标,评定模型在预测旋转角度上的性能。
实现流程:
-
数据加载和预处理:加载手写数字图像数据集,对图像进行预处理(如缩放、归一化)并提取对应的旋转角度标签。
-
CNN模型构建:使用MATLAB深度学习工具箱中的函数构建CNN模型,包括卷积层、池化层、全连接层,并适当选择激活函数。
-
训练模型:定义训练选项,选择优化器和学习率等参数,使用训练数据集对CNN模型进行训练。
-
模型评估:使用测试数据集对训练好的模型进行评估,检验其在预测手写数字旋转角度的准确性。
-
预测和应用:最后,使用训练好的模型对新的手写数字图像进行预测,实现手写数字旋转角度的自动识别和预测。
通过以上流程和总结,您可以利用MATLAB深度学习工具箱来实现基于卷积神经网络的手写数字旋转角度预测任务。
11、源代码
代码
%% 基于卷积神经网络的手写数字旋转角度预测
%卷积神经网络来预测手写数字的旋转角度
%回归任务涉及预测连续数值而不是离散类标签
%回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。%% 加载数据
%数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
%变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。load DigitsDataTrain
load DigitsDataTest%显示训练集
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);%数据集划分
%使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);%% 检查数据归一化
%训练神经网络时,确保数据在网络的所有阶段均归一化。
%对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
%数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
%归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1%绘制响应的分布。
% 响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")%% 定义神经网络架构
%对于图像输入,指定一个图像输入层。
%指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
%在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
%在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];
%% 指定训练选项
%使用Experiment Manager。
%将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
%通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
%在图中显示训练进度并监控均方根误差。miniBatchSize = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);
%% 训练神经网络
%使用 trainnet 函数训练神经网络。
%对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
%% 测试网络
%基于测试数据评估准确度来测试网络性能。
%使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
YTest = minibatchpredict(net,XTest);
%计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异。
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
%散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")%% 使用新数据进行预测
%使用 predict 函数并使用神经网络对第一个测试图像进行预测
X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)
工程文件
https://download.csdn.net/download/XU157303764/89494539
相关文章:

1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)
1、卷积神经网络的手写数字旋转角度预测原理及流程 基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下…...
Windows如何使用Python的sphinx
在Windows上使用Python的Sphinx进行文档渲染和呈现,可以遵循以下步骤进行操作: 安装Python:首先,确保你的Windows系统上已经安装了Python。你可以从Python的官方网站下载并安装适合你系统(32位或64位&…...
C++ STL nth_element 用法
一:功能 将一个序列分为两组,前一组元素都小于*nth,后一组元素都大于*nth, 并且确保第 nth 个位置就是排序之后所处的位置。即该位置的元素是该序列中第nth小的数。 二:用法 #include <vector> #include <a…...

【PostgreSQL教程】PostgreSQL 选择数据库
博主介绍:✌全网粉丝20W+,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物联网、机器学习等设计与开发。 感兴趣的可…...
C# —— HashTable
集合collections命名空间,专门进行一系列的数据存储和检索的类,主要包含了:堆栈、和队列、list、ArrayList、数组 HashTable 字典 storeList 排序列表等类 Array 数组 长度固定, 类型固定 通过索引值来进行访问 ArrayList动态数组,…...
LeetCode 第407场周赛个人题解
目录 100372. 使两个整数相等的位更改次数 原题链接 思路分析 AC代码 100335. 字符串元音游戏 原题链接 思路分析 AC代码 100360. 将 1 移动到末尾的最大操作次数 原题链接 思路分析 AC代码 100329. 使数组等于目标数组所需的最少操作次数 原题链接 思路分析 A…...

使用Django框架实现音频上传功能
数据库设计(models.py) class Music(models.Model):""" 音乐 """name models.CharField(verbose_name"音乐名字", max_length32)singer models.CharField(verbose_name"歌手", max_length32)# 本质…...

[路由器]IP-MAC的绑定与取消
背景:当公司的网络不想与外部人员进行共享,可以在路由器页面配置IP-MAC的绑定,让公司内部人员的手机和电脑的mac,才能接入到公司。第一步:在ARP防护中,启动IP-MAC绑定选项,必须启动仅允许IP-MAC…...

Idea配置远程开发
Idea配置远程开发 本篇博客介绍使用idea通过ssh连接ubuntu服务器进行开发 目录 Idea配置远程开发1.idae上点击file->Remote Development2.点击New Connection3.填写相关信息4.输入密码5.选择IDE版本和项目路径5.1 点击open an SSH terminal打开控制台5.2 依次执行命令 6.成…...
lua 实现 函数 判断两个时间戳是否在同一天
函数用于判断两个时间戳是否在同一天。下面是对代码的详细解释: ### 函数参数 - stampA 和 stampB:两个时间戳,用于比较。- resetInfo:一个可选参数,包含小时、分钟和秒数,用于调整时间戳。 ### 函数实现…...
工作纪实53-log4j日志打印文件隔离
在项目中,我有一堆业务日志需要打印,另一部分的日志,是没有格式的,需要被云平台离线解析并收集到kafka或者hdfs、hive等,需要将日志隔离打印到不同的文件 正常的log4j配置是下面这样的,配合Sl4j直接使用默认…...

7月21日,贪心练习
大家好呀,今天带来一些贪心算法的应用解题、 一,柠檬水找零 . - 力扣(LeetCode) 解析: 本题的贪心体现在对于20美元的处理上,我们总是优先把功能较少的10元作为找零,这样可以让5元用处更大 …...

FPGA DNA 获取 DNA_PORT
FPGA DNA DNA 是 FPGA 芯片的唯一标识, FPGA 都有一个独特的 ID ,也就是 Device DNA ,这个 ID 相当于我们的身份证,在 FPGA 芯片生产的时候就已经固定在芯片的 eFuse 寄存器中,具有不可修改的属性。在 xilinx 7series…...
使用 hutool工具实现导入导出功能。
hutool工具网址 Hutool参考文档 pom依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.7.20</version></dependency><dependency><groupId>org.apache.poi</gro…...

大语言模型-Transformer-Attention Is All You Need
一、背景信息: Transformer是一种由谷歌在2017年提出的深度学习模型。 主要用于自然语言处理(NLP)任务,特别是序列到序列(Sequence-to-Sequence)的学习问题,如机器翻译、文本生成等。Transfor…...
spring(二)
一、为对象类型属性赋值 方式一:(引用外部bean) 1.创建班级类Clazz package com.spring.beanpublic class Clazz {private Integer clazzId;private String clazzName;public Integer getClazzId() {return clazzId;}public void setClazzId(Integer clazzId) {th…...
MAC 数据恢复软件: STELLAR Data Recovery For MAC V. 12.1 更多增强功能
天津鸿萌科贸发展有限公司是 Stellar 系列软件的授权代理商。 STELLAR Data Recovery For MAC 该数据恢复软件可从任何存储驱动器、清空的回收站以及崩溃或无法启动的 Mac 设备中恢复丢失或删除的文件。 轻松恢复已删除的文档、照片、音频文件和视频。自定义扫描以帮助恢复特…...

初识godot游戏引擎并安装
简介 Godot是一款自由开源、由社区驱动的2D和3D游戏引擎。游戏开发虽复杂,却蕴含一定的通用规律,正是为了简化这些通用化的工作,游戏引擎应运而生。Godot引擎作为一款功能丰富的跨平台游戏引擎,通过统一的界面支持创建2D和3D游戏。…...

Windows配置Qt+VLC
文章目录 前言下载库文件提取文件编写qmakeqtvlc测试代码 总结 前言 在Windows平台上配置Qt和VLC是开发多媒体应用程序的一个重要步骤。Qt作为一个强大的跨平台应用开发框架,为开发人员提供了丰富的GUI工具和库,而VLC则是一个开源的多媒体播放器&#x…...

本地部署 mistralai/Mistral-Nemo-Instruct-2407
本地部署 mistralai/Mistral-Nemo-Instruct-2407 1. 创建虚拟环境2. 安装 fschat3. 安装 transformers4. 安装 flash-attn5. 安装 pytorch6. 启动 controller7. 启动 mistralai/Mistral-Nemo-Instruct-24078. 启动 api9. 访问 mistralai/Mistral-Nemo-Instruct-2407 1. 创建虚拟…...
FFmpeg 低延迟同屏方案
引言 在实时互动需求激增的当下,无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作,还是游戏直播的画面实时传输,低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架,凭借其灵活的编解码、数据…...

如何在看板中体现优先级变化
在看板中有效体现优先级变化的关键措施包括:采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中,设置任务排序规则尤其重要,因为它让看板视觉上直观地体…...

家政维修平台实战20:权限设计
目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系,主要是分成几个表,用户表我们是记录用户的基础信息,包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题,不同的角色…...
Axios请求超时重发机制
Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

什么是Ansible Jinja2
理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...
2023赣州旅游投资集团
单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...

SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)
上一章用到了V2 的概念,其实 Fiori当中还有 V4,咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务),代理中间件(ui5-middleware-simpleproxy)-CSDN博客…...

计算机基础知识解析:从应用到架构的全面拆解
目录 前言 1、 计算机的应用领域:无处不在的数字助手 2、 计算机的进化史:从算盘到量子计算 3、计算机的分类:不止 “台式机和笔记本” 4、计算机的组件:硬件与软件的协同 4.1 硬件:五大核心部件 4.2 软件&#…...

【C++】纯虚函数类外可以写实现吗?
1. 答案 先说答案,可以。 2.代码测试 .h头文件 #include <iostream> #include <string>// 抽象基类 class AbstractBase { public:AbstractBase() default;virtual ~AbstractBase() default; // 默认析构函数public:virtual int PureVirtualFunct…...

aardio 自动识别验证码输入
技术尝试 上周在发学习日志时有网友提议“在网页上识别验证码”,于是尝试整合图像识别与网页自动化技术,完成了这套模拟登录流程。核心思路是:截图验证码→OCR识别→自动填充表单→提交并验证结果。 代码在这里 import soImage; import we…...