1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)
1、卷积神经网络的手写数字旋转角度预测原理及流程
基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下是基于MATLAB的手写数字旋转角度预测的原理和流程:
原理:
-
数据准备:首先,准备一个包含手写数字图像和其对应标签(即旋转角度)的数据集。这些图像可以是MNIST数据集的手写数字。
-
模型建立:构建一个CNN模型,包括卷积层、池化层、全连接层等,来学习手写数字图像的特征并预测它们的旋转角度。
-
训练模型:利用准备好的训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数以最小化预测与真实标签之间的误差。
-
模型评估:使用测试数据集对训练好的模型进行评估,计算模型的准确率或其他性能指标,以评估其在预测手写数字旋转角度方面的性能。
流程:
-
加载数据集:在MATLAB中加载手写数字图像数据集,并对图像进行预处理和标签处理,以便输入到CNN模型中。
-
构建CNN模型:使用MATLAB深度学习工具箱中的函数(如
convolution2dLayer、maxPooling2dLayer、fullyConnectedLayer、classificationLayer)构建一个适合手写数字旋转角度预测的CNN模型。 -
定义训练选项:设置训练选项,包括优化器类型、学习率、最大训练轮数等。
-
训练模型:使用训练数据集对CNN模型进行训练,通过调用
trainNetwork函数并传入训练数据和训练选项来完成训练过程。 -
评估模型:使用测试数据集对训练好的模型进行评估,计算准确率等性能指标。
-
预测手写数字的旋转角度:最后,使用训练好的模型对新的手写数字图像进行预测,得到其旋转角度的预测结果。
这是基于卷积神经网络的手写数字旋转角度预测的基本原理和流程。
2、卷积神经网络的手写数字旋转角度预测案例说明
1)解决问题
卷积神经网络来预测手写数字的旋转角度
2)技术方案
回归任务涉及预测连续数值而不是离散类标签,回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。
3、加载数据
1)数据说明
数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
2)加载数据代码
说明:变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。
load DigitsDataTrain
load DigitsDataTest
3)显示训练集代码
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);
视图效果

4)数据集划分代码
说明:使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);
4、检查数据归一化
1)归一化说明
训练神经网络时,确保数据在网络的所有阶段均归一化。
对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1
2)绘制响应的分布代码
说明:响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")
视图效果

5、定义神经网络架构
1)神经网络架构说明
对于图像输入,指定一个图像输入层。
指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
2)神经网络架构代码
numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];
6、指定训练选项
1)指定训练选项说明
使用Experiment Manager。
将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
在图中显示训练进度并监控均方根误差。
2)指定训练选项代码
miniBatchSize = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);
7、训练神经网络
1)训练神经网络说明
使用 trainnet 函数训练神经网络。
对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
2)训练神经网络代码
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
视图效果

8、测试网络
1)测试网络说明
基于测试数据评估准确度来测试网络性能。
使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
2)测试网络代码
YTest = minibatchpredict(net,XTest);
3)计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
4)散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")
视图效果

9、使用新数据进行预测
1)测试说明
使用 predict 函数并使用神经网络对第一个测试图像进行预测
2)测试代码
X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)
10、总结
基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题,通过使用MATLAB深度学习工具箱可以比较方便地实现。下面是对这一任务的总结:
总结要点:
-
数据准备:准备包含手写数字图像和对应旋转角度标签的数据集,如MNIST数据集。
-
模型建立:构建卷积神经网络(CNN)模型,通过卷积层、池化层、全连接层等结构来学习手写数字图像的特征和预测旋转角度。
-
训练模型:使用训练数据集对CNN模型进行训练,通过反向传播算法来调整模型参数,最小化预测与真实标签的误差。
-
模型评估:使用测试数据集对训练好的模型进行评估,计算准确率或其他性能指标,评定模型在预测旋转角度上的性能。
实现流程:
-
数据加载和预处理:加载手写数字图像数据集,对图像进行预处理(如缩放、归一化)并提取对应的旋转角度标签。
-
CNN模型构建:使用MATLAB深度学习工具箱中的函数构建CNN模型,包括卷积层、池化层、全连接层,并适当选择激活函数。
-
训练模型:定义训练选项,选择优化器和学习率等参数,使用训练数据集对CNN模型进行训练。
-
模型评估:使用测试数据集对训练好的模型进行评估,检验其在预测手写数字旋转角度的准确性。
-
预测和应用:最后,使用训练好的模型对新的手写数字图像进行预测,实现手写数字旋转角度的自动识别和预测。
通过以上流程和总结,您可以利用MATLAB深度学习工具箱来实现基于卷积神经网络的手写数字旋转角度预测任务。
11、源代码
代码
%% 基于卷积神经网络的手写数字旋转角度预测
%卷积神经网络来预测手写数字的旋转角度
%回归任务涉及预测连续数值而不是离散类标签
%回归构造卷积神经网络架构,训练网络,并使用经过训练的网络来预测旋转手写数字的角度。%% 加载数据
%数据集包含手写数字的合成图像以及每个图像的旋转角度(以度为单位)。
%变量 anglesTrain 和 anglesTest 是以度为单位的旋转角度。训练数据集和测试数据集各包含 5000 个图像。load DigitsDataTrain
load DigitsDataTest%显示训练集
numObservations = size(XTrain,4);
idx = randperm(numObservations,49);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I);%数据集划分
%使用 trainingPartitions 函数将 XTrain 和 anglesTrain 分区为训练分区和验证分区,留出 15% 的训练数据用于验证。
[idxTrain,idxValidation] = trainingPartitions(numObservations,[0.85 0.15]);XValidation = XTrain(:,:,:,idxValidation);
anglesValidaiton = anglesTrain(idxValidation);XTrain = XTrain(:,:,:,idxTrain);
anglesTrain = anglesTrain(idxTrain);%% 检查数据归一化
%训练神经网络时,确保数据在网络的所有阶段均归一化。
%对于使用梯度下降的网络训练,归一化有助于训练的稳定和加速.
%数据比例不佳,则损失可能会变为 NaN,并且网络参数在训练过程中可能发生偏离
%归一化数据的常用方法包括重新缩放数据,使其范围变为 [0,1],或使其均值为 0 且标准差为 1%绘制响应的分布。
% 响应(以度为单位的旋转角度)大致均匀地分布在 -45 和 45 之间,效果很好,无需归一化。
figure
histogram(anglesTrain)
axis tight
ylabel("Counts")
xlabel("Rotation Angle")%% 定义神经网络架构
%对于图像输入,指定一个图像输入层。
%指定四个 convolution-batchnorm-ReLU 模块,并增加滤波器数量。
%在每个模块之间指定一个具有池化区域的平均池化层,步幅大小为 2。
%在网络末尾,包含一个全连接层,其输出大小与响应数量匹配。
numResponses = 1;layers = [imageInputLayer([28 28 1])convolution2dLayer(3,8,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,16,Padding="same")batchNormalizationLayerreluLayeraveragePooling2dLayer(2,Stride=2)convolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerconvolution2dLayer(3,32,Padding="same")batchNormalizationLayerreluLayerfullyConnectedLayer(numResponses)];
%% 指定训练选项
%使用Experiment Manager。
%将初始学习率设置为 0.001,并在 20 轮训练后降低学习率。
%通过指定验证数据和验证频率,监控训练过程中的网络准确度。软件基于训练数据训练网络,并在训练过程中按固定时间间隔计算基于验证数据的准确度。验证数据不用于更新网络权重。
%在图中显示训练进度并监控均方根误差。miniBatchSize = 128;
validationFrequency = floor(numel(anglesTrain)/miniBatchSize);options = trainingOptions("sgdm", ...MiniBatchSize=miniBatchSize, ...InitialLearnRate=1e-3, ...LearnRateSchedule="piecewise", ...LearnRateDropFactor=0.1, ...LearnRateDropPeriod=20, ...Shuffle="every-epoch", ...ValidationData={XTest,anglesTest}, ...ValidationFrequency=validationFrequency, ...Plots="training-progress", ...Metrics="rmse", ...Verbose=false);
%% 训练神经网络
%使用 trainnet 函数训练神经网络。
%对于回归,请使用均方误差损失。默认情况下,trainnet 函数使用 GPU(如果有)。使用 GPU 需要 Parallel Computing Toolbox™ 许可证和受支持的 GPU 设备。要指定执行环境,请使用 ExecutionEnvironment 训练选项。
net = trainnet(XTrain,anglesTrain,layers,"mse",options);
%% 测试网络
%基于测试数据评估准确度来测试网络性能。
%使用 minibatchpredict 函数进行预测。默认情况下,minibatchpredict 函数使用 GPU(如果有)。
YTest = minibatchpredict(net,XTest);
%计算均方根误差 (RMSE) 以衡量预测旋转角度和实际旋转角度之间的差异。
predictionError = anglesTest - YTest;
squares = predictionError.^2;
rmse = sqrt(mean(squares))
%散点图中可视化预测。绘制预测值对真实值的图。
figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")hold on
plot([-60 60], [-60 60],"y--")%% 使用新数据进行预测
%使用 predict 函数并使用神经网络对第一个测试图像进行预测
X = XTest(:,:,:,1);
if canUseGPUX = gpuArray(X);
end
Y = predict(net,X)
工程文件
https://download.csdn.net/download/XU157303764/89494539
相关文章:
1.30、基于卷积神经网络的手写数字旋转角度预测(matlab)
1、卷积神经网络的手写数字旋转角度预测原理及流程 基于卷积神经网络的手写数字旋转角度预测是一个常见的计算机视觉问题。在这种情况下,我们可以通过构建一个卷积神经网络(Convolutional Neural Network,CNN)来实现该任务。以下…...
Windows如何使用Python的sphinx
在Windows上使用Python的Sphinx进行文档渲染和呈现,可以遵循以下步骤进行操作: 安装Python:首先,确保你的Windows系统上已经安装了Python。你可以从Python的官方网站下载并安装适合你系统(32位或64位&…...
C++ STL nth_element 用法
一:功能 将一个序列分为两组,前一组元素都小于*nth,后一组元素都大于*nth, 并且确保第 nth 个位置就是排序之后所处的位置。即该位置的元素是该序列中第nth小的数。 二:用法 #include <vector> #include <a…...
【PostgreSQL教程】PostgreSQL 选择数据库
博主介绍:✌全网粉丝20W+,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物联网、机器学习等设计与开发。 感兴趣的可…...
C# —— HashTable
集合collections命名空间,专门进行一系列的数据存储和检索的类,主要包含了:堆栈、和队列、list、ArrayList、数组 HashTable 字典 storeList 排序列表等类 Array 数组 长度固定, 类型固定 通过索引值来进行访问 ArrayList动态数组,…...
LeetCode 第407场周赛个人题解
目录 100372. 使两个整数相等的位更改次数 原题链接 思路分析 AC代码 100335. 字符串元音游戏 原题链接 思路分析 AC代码 100360. 将 1 移动到末尾的最大操作次数 原题链接 思路分析 AC代码 100329. 使数组等于目标数组所需的最少操作次数 原题链接 思路分析 A…...
使用Django框架实现音频上传功能
数据库设计(models.py) class Music(models.Model):""" 音乐 """name models.CharField(verbose_name"音乐名字", max_length32)singer models.CharField(verbose_name"歌手", max_length32)# 本质…...
[路由器]IP-MAC的绑定与取消
背景:当公司的网络不想与外部人员进行共享,可以在路由器页面配置IP-MAC的绑定,让公司内部人员的手机和电脑的mac,才能接入到公司。第一步:在ARP防护中,启动IP-MAC绑定选项,必须启动仅允许IP-MAC…...
Idea配置远程开发
Idea配置远程开发 本篇博客介绍使用idea通过ssh连接ubuntu服务器进行开发 目录 Idea配置远程开发1.idae上点击file->Remote Development2.点击New Connection3.填写相关信息4.输入密码5.选择IDE版本和项目路径5.1 点击open an SSH terminal打开控制台5.2 依次执行命令 6.成…...
lua 实现 函数 判断两个时间戳是否在同一天
函数用于判断两个时间戳是否在同一天。下面是对代码的详细解释: ### 函数参数 - stampA 和 stampB:两个时间戳,用于比较。- resetInfo:一个可选参数,包含小时、分钟和秒数,用于调整时间戳。 ### 函数实现…...
工作纪实53-log4j日志打印文件隔离
在项目中,我有一堆业务日志需要打印,另一部分的日志,是没有格式的,需要被云平台离线解析并收集到kafka或者hdfs、hive等,需要将日志隔离打印到不同的文件 正常的log4j配置是下面这样的,配合Sl4j直接使用默认…...
7月21日,贪心练习
大家好呀,今天带来一些贪心算法的应用解题、 一,柠檬水找零 . - 力扣(LeetCode) 解析: 本题的贪心体现在对于20美元的处理上,我们总是优先把功能较少的10元作为找零,这样可以让5元用处更大 …...
FPGA DNA 获取 DNA_PORT
FPGA DNA DNA 是 FPGA 芯片的唯一标识, FPGA 都有一个独特的 ID ,也就是 Device DNA ,这个 ID 相当于我们的身份证,在 FPGA 芯片生产的时候就已经固定在芯片的 eFuse 寄存器中,具有不可修改的属性。在 xilinx 7series…...
使用 hutool工具实现导入导出功能。
hutool工具网址 Hutool参考文档 pom依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.7.20</version></dependency><dependency><groupId>org.apache.poi</gro…...
大语言模型-Transformer-Attention Is All You Need
一、背景信息: Transformer是一种由谷歌在2017年提出的深度学习模型。 主要用于自然语言处理(NLP)任务,特别是序列到序列(Sequence-to-Sequence)的学习问题,如机器翻译、文本生成等。Transfor…...
spring(二)
一、为对象类型属性赋值 方式一:(引用外部bean) 1.创建班级类Clazz package com.spring.beanpublic class Clazz {private Integer clazzId;private String clazzName;public Integer getClazzId() {return clazzId;}public void setClazzId(Integer clazzId) {th…...
MAC 数据恢复软件: STELLAR Data Recovery For MAC V. 12.1 更多增强功能
天津鸿萌科贸发展有限公司是 Stellar 系列软件的授权代理商。 STELLAR Data Recovery For MAC 该数据恢复软件可从任何存储驱动器、清空的回收站以及崩溃或无法启动的 Mac 设备中恢复丢失或删除的文件。 轻松恢复已删除的文档、照片、音频文件和视频。自定义扫描以帮助恢复特…...
初识godot游戏引擎并安装
简介 Godot是一款自由开源、由社区驱动的2D和3D游戏引擎。游戏开发虽复杂,却蕴含一定的通用规律,正是为了简化这些通用化的工作,游戏引擎应运而生。Godot引擎作为一款功能丰富的跨平台游戏引擎,通过统一的界面支持创建2D和3D游戏。…...
Windows配置Qt+VLC
文章目录 前言下载库文件提取文件编写qmakeqtvlc测试代码 总结 前言 在Windows平台上配置Qt和VLC是开发多媒体应用程序的一个重要步骤。Qt作为一个强大的跨平台应用开发框架,为开发人员提供了丰富的GUI工具和库,而VLC则是一个开源的多媒体播放器&#x…...
本地部署 mistralai/Mistral-Nemo-Instruct-2407
本地部署 mistralai/Mistral-Nemo-Instruct-2407 1. 创建虚拟环境2. 安装 fschat3. 安装 transformers4. 安装 flash-attn5. 安装 pytorch6. 启动 controller7. 启动 mistralai/Mistral-Nemo-Instruct-24078. 启动 api9. 访问 mistralai/Mistral-Nemo-Instruct-2407 1. 创建虚拟…...
ARM Cortex-M中断状态寄存器实战:从配置到调试的完整指南
ARM Cortex-M中断状态寄存器实战:从配置到调试的完整指南 在嵌入式开发领域,中断处理是系统实时响应的核心机制。作为ARM Cortex-M系列处理器的开发者,深入理解中断状态寄存器(Interrupt Status Register)的工作原理和操作技巧,能…...
OpenClaw隐私保护:GLM-4.7-Flash本地处理敏感数据的实践方案
OpenClaw隐私保护:GLM-4.7-Flash本地处理敏感数据的实践方案 1. 为什么需要本地化AI处理敏感数据? 去年我在处理公司财务报告自动化时遇到一个棘手问题:使用云端AI服务需要上传包含客户隐私的Excel文件到第三方服务器。尽管服务商承诺数据安…...
FedProto:跨异构客户端的原型联邦学习实践指南
1. 从零理解FedProto的核心思想 第一次听说FedProto时,我正被一个医疗影像分析项目搞得焦头烂额。五家医院的数据就像五个方言区——同样的病症在CT影像上呈现的特征分布天差地别。传统联邦学习就像让这些医院用各自的方言写报告,再强行翻译成标准语&…...
AI教材生成强力工具!低查重保障,让教材编写事半功倍!
梳理教材知识点确实是一项“精细活”,最大的挑战在于平衡和衔接知识之间的关系。如果不小心,很可能会遗漏一些核心知识点,或者在难度的把控上出现问题——小学教材常常写得过于复杂,让学生难以理解;而高中教材又可能显…...
医学影像与卫星图的救星?深入聊聊JPEG-LS算法在边缘计算设备上的应用优势
JPEG-LS算法:边缘计算时代的医学影像与卫星图像压缩利器 当一台CT扫描仪每秒产生数百张16位深度的医学影像,或一颗遥感卫星每天传回数TB的高清地表数据时,传统的图像压缩方案往往面临两难选择——要么牺牲宝贵的诊断细节,要么耗尽…...
17 种 RAG 优化策略
RAG 完整解析 本文适合小白入门,全程用「公司员工手册查病假」为统一实例,清晰讲解 RAG 是什么、工作流程,以及 17 种 RAG 优化策略(含标准英文术语),所有内容可直接复制用于分享,实例均精确到具…...
如何用FCEUX重温经典游戏?全场景部署指南
如何用FCEUX重温经典游戏?全场景部署指南 【免费下载链接】fceux FCEUX, a NES Emulator 项目地址: https://gitcode.com/gh_mirrors/fc/fceux 为什么选择FCEUX模拟器?🎮 在众多NES模拟器中,FCEUX凭借三大核心优势脱颖而出…...
嵌入式串口协议中间件:轻量级SerHelp库设计与应用
1. 项目概述nahs-Bricks-Lib-SerHelp是 NAHS(North American Home System)生态中面向嵌入式砖块化(Brick-based)硬件平台的一套轻量级串行通信辅助库。该库不提供底层驱动实现,而是聚焦于串口协议层的工程化封装与通用…...
语音播报实时
目录 GPT-SoVITS(强烈推荐) Fish Speech-1.5 GPT-SoVITS(强烈推荐) RVC-Boss/GPT-SoVITS: 1 min voice data can also be used to train a good TTS model! (few shot voice cloning) Fish Speech-1.5 追求极致流畅的实时对话&a…...
TscanCode静态代码扫描工具原理与实践
嵌入式静态代码扫描工具TscanCode深度解析1. 静态代码分析技术概述1.1 静态代码扫描原理静态代码扫描是一种在不实际执行程序的情况下,通过词法分析、语法分析、控制流和数据流分析等技术对源代码进行检测的方法。这种技术能够有效识别代码中潜在的错误和缺陷&#…...
