1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)
1、卷积神经网络的手写数字旋转角度预测原理及流程
基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下是基于MATLAB的手写数字旋转角度预测的原理和流程:
原理:
-
数据准备:首先,准备一个包含手写数字图像和其对应标签(即旋转角度)的数据集。这些图像可以是MNIST数据集的手写数字。
-
模型建立:构建一个CNN模型,包括卷积层、池化层、全连接层等,来学习手写数字图像的特征并预测它们的旋转角度。
-
训练模型:利用准备好的训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数以最小化预测与真实标签之间的误差。
-
模型评估:使用测试数据集对训练好的模型进行评估,计算模型的准确率或其他性能指标,以评估其在预测手写数字旋转角度方面的性能。
流程:
-
加载数据集:在MATLAB中加载手写数字图像数据集,并对图像进行预处理和标签处理,以便输入到CNN模型中。
-
构建CNN模型:使用MATLAB深度学习工具箱中的函数(如
convolution2dLayer、maxPooling2dLayer、fullyConnectedLayer、classificationLayer)构建一个适合手写数字旋转角度预测的CNN模型。 -
定义训练选项:设置训练选项,包括优化器类型、学习率、最大训练轮数等。
-
训练模型:使用训练数据集对CNN模型进行训练,通过调用
trainNetwork函数并传入训练数据和训练选项来完成训练过程。 -
评估模型:使用测试数据集对训练好的模型进行评估,计算准确率等性能指标。
-
预测手写数字的旋转角度:最后,使用训练好的模型对新的手写数字图像进行预测,得到其旋转角度的预测结果。
这是基于卷积神经网络的手写数字旋转角度预测的基本原理和流程。
2、卷积神经网络的手写数字旋转角度预测案例说明
1)解决问题
卷积神经网络来预测手写数字的旋转角度
2)技术方案
回归任务涉及预测连续数值而不是离散类标签,回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。
3、加载数据
1)数据说明
数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
2)加载数据代码
说明:变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。
load DigitsDataTrain
load DigitsDataTest
3)显示训练集代码
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);
视图效果

4)数据集划分代码
说明:使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);
4、检查数据归一化
1)归一化说明
训练神经网络时,确保数据在网络的所有阶段均归一化。
对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1
2)绘制响应的分布代码
说明:响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")
视图效果

5、定义神经网络架构
1)神经网络架构说明
对于图像输入,指定一个图像输入层。
指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
2)神经网络架构代码
numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];
6、指定训练选项
1)指定训练选项说明
使用Experiment Manager。
将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
在图中显示训练进度并监控均方根误差。
2)指定训练选项代码
miniBatchSize = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);
7、训练神经网络
1)训练神经网络说明
使用 trainnet 函数训练神经网络。
对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
2)训练神经网络代码
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
视图效果

8、测试网络
1)测试网络说明
基于测试数据评估准确度来测试网络性能。
使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
2)测试网络代码
YTest = minibatchpredict(net,XTest);
3)计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
4)散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")
视图效果

9、使用新数据进行预测
1)测试说明
使用 predict 函数并使用神经网络对第一个测试图像进行预测
2)测试代码
X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)
10、总结
基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题,通过使用MATLAB深度学习工具箱可以比较方便地实现。下面是对这一任务的总结:
总结要点:
-
数据准备:准备包含手写数字图像和对应旋转角度标签的数据集,如MNIST数据集。
-
模型建立:构建卷积神经网络(CNN)模型,通过卷积层、池化层、全连接层等结构来学习手写数字图像的特征和预测旋转角度。
-
训练模型:使用训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数,最小化预测与真实标签的误差。
-
模型评估:使用测试数据集对训练好的模型进行评估,计算准确率或其他性能指标,评定模型在预测旋转角度上的性能。
实现流程:
-
数据加载和预处理:加载手写数字图像数据集,对图像进行预处理(如缩放、归一化)并提取对应的旋转角度标签。
-
CNN模型构建:使用MATLAB深度学习工具箱中的函数构建CNN模型,包括卷积层、池化层、全连接层,并适当选择激活函数。
-
训练模型:定义训练选项,选择优化器和学习率等参数,使用训练数据集对CNN模型进行训练。
-
模型评估:使用测试数据集对训练好的模型进行评估,检验其在预测手写数字旋转角度的准确性。
-
预测和应用:最后,使用训练好的模型对新的手写数字图像进行预测,实现手写数字旋转角度的自动识别和预测。
通过以上流程和总结,您可以利用MATLAB深度学习工具箱来实现基于卷积神经网络的手写数字旋转角度预测任务。
11、源代码
代码
%% 基于卷积神经网络的手写数字旋转角度预测
%卷积神经网络来预测手写数字的旋转角度
%回归任务涉及预测连续数值而不是离散类标签
%回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。%% 加载数据
%数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
%变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。load DigitsDataTrain
load DigitsDataTest%显示训练集
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);%数据集划分
%使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);%% 检查数据归一化
%训练神经网络时,确保数据在网络的所有阶段均归一化。
%对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
%数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
%归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1%绘制响应的分布。
% 响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")%% 定义神经网络架构
%对于图像输入,指定一个图像输入层。
%指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
%在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
%在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];
%% 指定训练选项
%使用Experiment Manager。
%将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
%通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
%在图中显示训练进度并监控均方根误差。miniBatchSize = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);
%% 训练神经网络
%使用 trainnet 函数训练神经网络。
%对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
%% 测试网络
%基于测试数据评估准确度来测试网络性能。
%使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
YTest = minibatchpredict(net,XTest);
%计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异。
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
%散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")%% 使用新数据进行预测
%使用 predict 函数并使用神经网络对第一个测试图像进行预测
X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)
工程文件
https://download.csdn.net/download/XU157303764/89494539
相关文章:
1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)
1、卷积神经网络的手写数字旋转角度预测原理及流程 基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下…...
Windows如何使用Python的sphinx
在Windows上使用Python的Sphinx进行文档渲染和呈现,可以遵循以下步骤进行操作: 安装Python:首先,确保你的Windows系统上已经安装了Python。你可以从Python的官方网站下载并安装适合你系统(32位或64位&…...
C++ STL nth_element 用法
一:功能 将一个序列分为两组,前一组元素都小于*nth,后一组元素都大于*nth, 并且确保第 nth 个位置就是排序之后所处的位置。即该位置的元素是该序列中第nth小的数。 二:用法 #include <vector> #include <a…...
【PostgreSQL教程】PostgreSQL 选择数据库
博主介绍:✌全网粉丝20W+,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物联网、机器学习等设计与开发。 感兴趣的可…...
C# —— HashTable
集合collections命名空间,专门进行一系列的数据存储和检索的类,主要包含了:堆栈、和队列、list、ArrayList、数组 HashTable 字典 storeList 排序列表等类 Array 数组 长度固定, 类型固定 通过索引值来进行访问 ArrayList动态数组,…...
LeetCode 第407场周赛个人题解
目录 100372. 使两个整数相等的位更改次数 原题链接 思路分析 AC代码 100335. 字符串元音游戏 原题链接 思路分析 AC代码 100360. 将 1 移动到末尾的最大操作次数 原题链接 思路分析 AC代码 100329. 使数组等于目标数组所需的最少操作次数 原题链接 思路分析 A…...
使用Django框架实现音频上传功能
数据库设计(models.py) class Music(models.Model):""" 音乐 """name models.CharField(verbose_name"音乐名字", max_length32)singer models.CharField(verbose_name"歌手", max_length32)# 本质…...
[路由器]IP-MAC的绑定与取消
背景:当公司的网络不想与外部人员进行共享,可以在路由器页面配置IP-MAC的绑定,让公司内部人员的手机和电脑的mac,才能接入到公司。第一步:在ARP防护中,启动IP-MAC绑定选项,必须启动仅允许IP-MAC…...
Idea配置远程开发
Idea配置远程开发 本篇博客介绍使用idea通过ssh连接ubuntu服务器进行开发 目录 Idea配置远程开发1.idae上点击file->Remote Development2.点击New Connection3.填写相关信息4.输入密码5.选择IDE版本和项目路径5.1 点击open an SSH terminal打开控制台5.2 依次执行命令 6.成…...
lua 实现 函数 判断两个时间戳是否在同一天
函数用于判断两个时间戳是否在同一天。下面是对代码的详细解释: ### 函数参数 - stampA 和 stampB:两个时间戳,用于比较。- resetInfo:一个可选参数,包含小时、分钟和秒数,用于调整时间戳。 ### 函数实现…...
工作纪实53-log4j日志打印文件隔离
在项目中,我有一堆业务日志需要打印,另一部分的日志,是没有格式的,需要被云平台离线解析并收集到kafka或者hdfs、hive等,需要将日志隔离打印到不同的文件 正常的log4j配置是下面这样的,配合Sl4j直接使用默认…...
7月21日,贪心练习
大家好呀,今天带来一些贪心算法的应用解题、 一,柠檬水找零 . - 力扣(LeetCode) 解析: 本题的贪心体现在对于20美元的处理上,我们总是优先把功能较少的10元作为找零,这样可以让5元用处更大 …...
FPGA DNA 获取 DNA_PORT
FPGA DNA DNA 是 FPGA 芯片的唯一标识, FPGA 都有一个独特的 ID ,也就是 Device DNA ,这个 ID 相当于我们的身份证,在 FPGA 芯片生产的时候就已经固定在芯片的 eFuse 寄存器中,具有不可修改的属性。在 xilinx 7series…...
使用 hutool工具实现导入导出功能。
hutool工具网址 Hutool参考文档 pom依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.7.20</version></dependency><dependency><groupId>org.apache.poi</gro…...
大语言模型-Transformer-Attention Is All You Need
一、背景信息: Transformer是一种由谷歌在2017年提出的深度学习模型。 主要用于自然语言处理(NLP)任务,特别是序列到序列(Sequence-to-Sequence)的学习问题,如机器翻译、文本生成等。Transfor…...
spring(二)
一、为对象类型属性赋值 方式一:(引用外部bean) 1.创建班级类Clazz package com.spring.beanpublic class Clazz {private Integer clazzId;private String clazzName;public Integer getClazzId() {return clazzId;}public void setClazzId(Integer clazzId) {th…...
MAC 数据恢复软件: STELLAR Data Recovery For MAC V. 12.1 更多增强功能
天津鸿萌科贸发展有限公司是 Stellar 系列软件的授权代理商。 STELLAR Data Recovery For MAC 该数据恢复软件可从任何存储驱动器、清空的回收站以及崩溃或无法启动的 Mac 设备中恢复丢失或删除的文件。 轻松恢复已删除的文档、照片、音频文件和视频。自定义扫描以帮助恢复特…...
初识godot游戏引擎并安装
简介 Godot是一款自由开源、由社区驱动的2D和3D游戏引擎。游戏开发虽复杂,却蕴含一定的通用规律,正是为了简化这些通用化的工作,游戏引擎应运而生。Godot引擎作为一款功能丰富的跨平台游戏引擎,通过统一的界面支持创建2D和3D游戏。…...
Windows配置Qt+VLC
文章目录 前言下载库文件提取文件编写qmakeqtvlc测试代码 总结 前言 在Windows平台上配置Qt和VLC是开发多媒体应用程序的一个重要步骤。Qt作为一个强大的跨平台应用开发框架,为开发人员提供了丰富的GUI工具和库,而VLC则是一个开源的多媒体播放器&#x…...
本地部署 mistralai/Mistral-Nemo-Instruct-2407
本地部署 mistralai/Mistral-Nemo-Instruct-2407 1. 创建虚拟环境2. 安装 fschat3. 安装 transformers4. 安装 flash-attn5. 安装 pytorch6. 启动 controller7. 启动 mistralai/Mistral-Nemo-Instruct-24078. 启动 api9. 访问 mistralai/Mistral-Nemo-Instruct-2407 1. 创建虚拟…...
MATLAB实战:如何用最小二乘法搞定系统辨识(附完整代码)
MATLAB实战:最小二乘法在系统辨识中的工程应用指南 在工业控制、信号处理等领域,系统辨识是建立数学模型的关键步骤。想象一下,当你面对一组输入输出数据,却不知道背后的系统规律时,最小二乘法就像一把瑞士军刀&#x…...
LVDS信号完整性救星:Xilinx OSERDESE2+IDELAY2配置避坑指南
LVDS信号完整性救星:Xilinx OSERDESE2IDELAY2配置避坑指南 当你在Gbps级LVDS接口设计中遇到信号抖动问题时,是否曾盯着眼图上的毛刺束手无策?作为Xilinx FPGA开发者,我们常陷入这样的困境:明明按照手册配置了OSERDESE2…...
FlyEnv-安装使用摸索记录
下载 官网地址:https://www.macphpstudy.com/zh/ 进入github下载,也可以百度网盘下载。 下载完后进行安装,我是选择为当前用户安装,没有为所有用户安装。 进入页面进行需要安装的软件;看上去还是有蛮多的,…...
5步搞定Qwen3-ASR语音识别:支持多语言和方言,快速上手教程
5步搞定Qwen3-ASR语音识别:支持多语言和方言,快速上手教程 语音识别技术正在改变我们与数字世界的交互方式,而Qwen3-ASR以其强大的多语言和方言支持能力脱颖而出。本文将带你用最简单的方式,在5个步骤内完成这个专业级语音识别系…...
HunyuanVideo-Foley 效果对比:不同算法模型生成音效的质量评估
HunyuanVideo-Foley 效果对比:不同算法模型生成音效的质量评估 1. 音效生成技术概览 音效生成技术正在经历一场革命性的变革。从早期的采样拼接到如今的AI生成,算法模型已经能够根据简单的文字描述创造出丰富多样的声音效果。这项技术在影视制作、游戏…...
惯性导航系统深度解析:从平台式到捷联式的技术演进与精度优化
1. 惯性导航系统的基本原理 想象一下你被蒙上眼睛放在一个陌生的城市里,只给你一个计步器和指南针,要求你记录自己的行走路线。这就是惯性导航系统(INS)工作的基本场景——它通过测量运动载体的加速度和角速度,像做数…...
3大维度破解C盘空间困局:Windows Cleaner让系统重获新生的开源方案
3大维度破解C盘空间困局:Windows Cleaner让系统重获新生的开源方案 【免费下载链接】WindowsCleaner Windows Cleaner——专治C盘爆红及各种不服! 项目地址: https://gitcode.com/gh_mirrors/wi/WindowsCleaner 当你的电脑频繁弹出"磁盘空间…...
Qwen3-Embedding-4B广告过滤应用:恶意内容识别系统实战
Qwen3-Embedding-4B广告过滤应用:恶意内容识别系统实战 1. 引言:当广告变成“牛皮癣”,我们如何反击? 想象一下,你运营着一个用户社区或内容平台。每天,用户都在热情地分享、讨论。但总有一些不速之客&am…...
Noi:整合多 AI 服务的新利器能否突出重围?
Noi:一站式 AI 服务整合新体验Noi 是一款图形用户界面(GUI)应用程序,它的核心亮点在于将所有 AI 服务整合到一处。用户通过单一用户界面(UI)就能访问 ChatGPT、Claude、Gemini、Perplexity 等多个服务&…...
FastMoss TikTok电商数据爬取实战:JS逆向与MD5签名破解
1. FastMoss TikTok电商数据爬取的核心挑战 最近在研究FastMoss平台的TikTok电商数据爬取,发现最大的难点在于请求签名加密。当你访问https://www.fastmoss.com/zh/e-commerce/saleslist这个页面时,切换周榜会触发一个带有fm-sign签名的加密请求。这个签…...
