【深度学习】基于华为MindSpore的手写体图像识别实验
1 实验介绍
1.1 简介
Mnist手写体图像识别实验是深度学习入门经典实验。Mnist数据集包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到255。为简单起见,每个图像都被平展并转换为784(28*28)个特征的一维numpy数组。
1.2 实验目的
- 学会如何搭建全连接神经网络。
- 掌握搭建网络过程中的关键点。
- 掌握分类任务的整体流程。
2.2 实验环境要求
推荐在华为云ModelArts实验平台完成实验,也可在本地搭建python3.7.5和MindSpore1.0.0环境完成实验。
2.3 实验总体设计
创建实验环境:在本地搭建MindSpore环境。
导入实验所需模块:该步骤通常都是程序编辑的第一步,将实验代码所需要用到的模块包用import命令进行导入。
导入数据集并预处理:神经网络的训练离不开数据,这里对数据进行导入。同时,因为全连接网络只能接收固定维度的输入数据,所以,要对数据集进行预处理,以符合网络的输入维度要求。同时,设定好每一次训练的Batch的大小,以Batch Size为单位进行输入。
模型搭建:利用mindspore.nn的cell模块搭建全连接网络,包含输入层,隐藏层,输出层。同时,配置好网络需要的优化器,损失函数和评价指标。传入数据,并开始训练模型。
模型评估:利用测试集进行模型的评估。
2.4 实验过程
2.4.1 搭建实验环境
Windows下MindSpore实验环境搭建并配置Pycharm请参考【机器学习】Windows下MindSpore实验环境搭建并配置Pycharm_在pycharm上安装mindspore_弓长纟隹为的博客-CSDN博客
官网下载MNIST数据集 MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
在MNIST文件夹下建立train和test两个文件夹,train中存放train-labels-idx1-ubyte和train-images-idx3-ubyte文件,test中存放t10k-labels-idx1-ubyte和t10k-images-idx3-ubyte文件。
2.4.2 模型训练、测试及评估
#导入相关依赖库
import os
import numpy as np
from matplotlib import pyplot as plt
import mindspore as ms
#context模块用于设置实验环境和实验设备
import mindspore.context as context
#dataset模块用于处理数据形成数据集
import mindspore.dataset as ds
#c_transforms模块用于转换数据类型
import mindspore.dataset.transforms as C
#vision.c_transforms模块用于转换图像,这是一个基于opencv的高级API
import mindspore.dataset.vision as CV
#导入Accuracy作为评价指标
from mindspore.nn.metrics import Accuracy
#nn中有各种神经网络层如:Dense,ReLu
from mindspore import nn
#Model用于创建模型对象,完成网络搭建和编译,并用于训练和评估
from mindspore.train import Model
#LossMonitor可以在训练过程中返回LOSS值作为监控指标
from mindspore.train.callback import LossMonitor
#设定运行模式为动态图模式,并且运行设备为昇腾芯片
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
#MindSpore内置方法读取MNIST数据集
ds_train = ds.MnistDataset(os.path.join(r'D:\Dataset\MNIST', "train"))
ds_test = ds.MnistDataset(os.path.join(r'D:\Dataset\MNIST', "test"))print('训练数据集数量:',ds_train.get_dataset_size())
print('测试数据集数量:',ds_test.get_dataset_size())
#该数据集可以通过create_dict_iterator()转换为迭代器形式,然后通过get_next()一个个输出样本
image=ds_train.create_dict_iterator().get_next()
#print(type(image))
print('图像长/宽/通道数:',image['image'].shape)
#一共10类,用0-9的数字表达类别。
print('一张图像的标签样式:',image['label'])
DATA_DIR_TRAIN = "D:/Dataset/MNIST/train" # 训练集信息
DATA_DIR_TEST = "D:/Dataset/MNIST/test" # 测试集信息def create_dataset(training=True, batch_size=128, resize=(28, 28), rescale=1 / 255, shift=-0.5, buffer_size=64):ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)# 定义改变形状、归一化和更改图片维度的操作。# 改为(28,28)的形状resize_op = CV.Resize(resize)# rescale方法可以对数据集进行归一化和标准化操作,这里就是将像素值归一到0和1之间,shift参数可以让值域偏移至-0.5和0.5之间rescale_op = CV.Rescale(rescale, shift)# 由高度、宽度、深度改为深度、高度、宽度hwc2chw_op = CV.HWC2CHW()# 利用map操作对原数据集进行调整ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op])ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))# 设定洗牌缓冲区的大小,从一定程度上控制打乱操作的混乱程度ds = ds.shuffle(buffer_size=buffer_size)# 设定数据集的batch_size大小,并丢弃剩余的样本ds = ds.batch(batch_size, drop_remainder=True)return ds
#显示前10张图片以及对应标签,检查图片是否是正确的数据集
dataset_show = create_dataset(training=False)
data = dataset_show.create_dict_iterator().get_next()
images = data['image'].asnumpy()
labels = data['label'].asnumpy()for i in range(1,11):plt.subplot(2, 5, i)#利用squeeze方法去掉多余的一个维度plt.imshow(np.squeeze(images[i]))plt.title('Number: %s' % labels[i])plt.xticks([])
plt.show()# 利用定义类的方式生成网络,Mindspore中定义网络需要继承nn.cell。在init方法中定义该网络需要的神经网络层
# 在construct方法中梳理神经网络层与层之间的关系。
class ForwardNN(nn.Cell):def __init__(self):super(ForwardNN, self).__init__()self.flatten = nn.Flatten()self.relu = nn.ReLU()self.fc1 = nn.Dense(784, 512, activation='relu')self.fc2 = nn.Dense(512, 256, activation='relu')self.fc3 = nn.Dense(256, 128, activation='relu')self.fc4 = nn.Dense(128, 64, activation='relu')self.fc5 = nn.Dense(64, 32, activation='relu')self.fc6 = nn.Dense(32, 10, activation='softmax')def construct(self, input_x):output = self.flatten(input_x)output = self.fc1(output)output = self.fc2(output)output = self.fc3(output)output = self.fc4(output)output = self.fc5(output)output = self.fc6(output)return outputlr = 0.001
num_epoch = 10
momentum = 0.9net = ForwardNN()
#定义loss函数,改函数不需要求导,可以给离散的标签值,且loss值为均值
loss = nn.loss.SoftmaxCrossEntropyWithLogits( sparse=True, reduction='mean')
#定义准确率为评价指标,用于评价模型
metrics={"Accuracy": Accuracy()}
#定义优化器为Adam优化器,并设定学习率
opt = nn.Adam(net.trainable_params(), lr)#生成验证集,验证机不需要训练,所以不需要repeat
ds_eval = create_dataset(False, batch_size=32)
#模型编译过程,将定义好的网络、loss函数、评价指标、优化器编译
model = Model(net, loss, opt, metrics)#生成训练集
ds_train = create_dataset(True, batch_size=32)
print("============== Starting Training ==============")
#训练模型,用loss作为监控指标,并利用昇腾芯片的数据下沉特性进行训练
model.train(num_epoch, ds_train,callbacks=[LossMonitor()],dataset_sink_mode=True)#使用测试集评估模型,打印总体准确率
metrics_result=model.eval(ds_eval)
print(metrics_result)
备注:
若报错 AttributeError: ‘DictIterator’ object has no attribute ‘get_next’ ,这是说MindSpore数据类中缺少 “get_next”这个方法,但是在MNIST图像识别的官方代码中却使用了这个方法,这就说明MindSpore官方把这个变成私密方法。
只需要在源码iterators.py中找到DictIterator这个类,将私有方法变成公有方法就行了(即去掉最前面的下划线)。
参考mindspore 报错 AttributeError: ‘DictIterator‘ object has no attribute ‘get_next‘_create_dict_iterator_TNiuB的博客-CSDN博客
MindSpore:前馈神经网络时报错‘DictIterator‘ has no attribute ‘get_next‘_skytier的博客-CSDN博客
更多问题请参考Window10 上MindSpore(CPU)用LeNet网络训练MNIST - 知乎
相关文章:

【深度学习】基于华为MindSpore的手写体图像识别实验
1 实验介绍 1.1 简介 Mnist手写体图像识别实验是深度学习入门经典实验。Mnist数据集包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到255。为简单起见,每…...

Linux:内核调试之内核魔术键sysrq
在linux系统下,我们可能会遇到系统某个命令hang住的情况,通常情况下,我们会查看/proc/pid/wchan文件,看看进程处于什么状况,然后进一步查看系统日志或者使用strace跟踪命令执行时的系统调用等等方法来分析问题。我们知…...

Python import导包快速入门
import 和 from import 在 Python 中,使用 import 语句可以将其他 Python 模块或包中的代码引入到当前模块中,以供使用。通常情况下,我们可以使用以下语法将整个模块导入到当前命名空间中: import module_name其中,m…...

ChatGPT这么火,我们能怎么办?
今天打开百度,看到这样一条热搜高居榜二:B站UP主发起停更潮,然后点进去了解一看,大体是因为最近AI创作太火,对高质量原创形成了巨大冲击!记得之前看过一位UP主的分享,说B站UP主的年收入大体约等…...

HashMap底层原理
文章目录1. 基本概念2. HashMap 的底层数据结构3. HashMap 的 put 方法流程4. 怎么计算节点存储的下标5. Hash 冲突1)概念2)解决 hash 冲突的办法开放地址法再哈希法链地址法建立公共溢出区6. HashMap 的扩容机制1)扩容时涉及到的几个属性2&a…...

卡顿优化小结
卡顿的本质 卡顿的本质是因为一次垂直同步信号来的时候,当前帧要显示的图像数据还没准备好,只能等待16ms下一次垂直同步信号来时才能更新画面,在这段时间里显示器只能一直停留在上一帧的画面,如果跳过的帧数过多,就会…...

springboot前端ajax 04 关于后台传的时间和状态在前端的转换
修改状态及时间格式 在jsp中,时间显式: 只需要把json的时间部分改为用Date对象来显示就好了。 <td>new Date(jsonObj[i].startTime).toLocaleString()</td> <td>new Date(jsonObj[i].endTime).toLocaleString()</td> 状态对象…...

解决Windows微信和 PowerToys 的键盘管理器冲突
Windows开机之后PowerToys能正常使用, 但是打开微信之后设置好的快捷键映射就全部失效了 打开微信 -> 左下角三条杠 -> 设置 -> 快捷键 首先我把微信的快捷键全部清空了,发现还是没用 然后发现了设置里默认勾选了检测快捷键,我在想程序肯定是一直在后台检测,而powerTo…...

组会时间的工作
1. 党支部活动 看看组织生活记录本写完了没有 2. 论文翻译...

linux udp bind 返回值-1分析
在linux socket通信中,我们通常用到open/bind/read/write等内部函数,那么当这些函数返回值为-1的时候,我们怎么进一步定位呢! (1)怎么打印出返回值出错的原因呢!系统调用的错误都会存放在errno中 errno需要的头文件: #include<errno.h> strerror头文件,将错误信…...

Hexo搭建博客
文章目录开始安装使用配置主题配置gitee配置域名之前配置过hexo但是后来hexo文件夹莫名其妙崩了,我也懒得修理,就没管了,现在又想重拾回来。然后hexo可以搭建静态博客网站,放在github或者gitee都行,有免费的网页空间&a…...

Lesson11:http协议
前言 应用层:就是程序员基于socket接口之上编写的具体逻辑,做了很多工作,都是和文本处理有关的--- 协议分析与处理http协议,一定会具有大量的文体分析和协议处理如果用户想再url中包含url本身用来作为特殊字符的字符,url形式的时候,浏览器会自动给我们进行编码encode一般服务端…...

计算机信息安全有哪些SCI期刊推荐? - 易智编译EaseEditing
以下是计算机信息安全方向的SCI期刊推荐: IEEE Transactions on Information Forensics and Security 该期刊主要发表信息安全和数字取证方面的原创性研究,包括数据安全、网络安全、身份认证、加密、信息隐藏等领域的研究成果。该期刊的影响因子为8.134…...

CNVD-2023-12632 泛微e-cology9 sql注入 附poc
目录 漏洞描述影响版本漏洞复现漏洞修复 漏洞描述 泛微 E-Cology9 协同办公系统是一套基于 JSP 及 SQL Server 数据库的 OA 系统,包括知识文档管理、人力资源管理、客户关系管理、项目管理、财务管理、工作流程管理、数据中心等打造协同高效的企业管理环境&#…...

赛宁网安合作伙伴大会成功举办,重磅发布SCBaaS服务!
3月29日,“赛宁网安合作伙伴大会”在江苏南京隆重举办。大会现场汇集网络安全数字化领域的专业人才、技术专家,共同研讨数字安全发展趋势,分享智能安全解决方案和技术创新产品。 会上,赛宁网安产品专家针对数字化靶场、网络安…...

R语言 4.2.2安装包下载及安装教程
[软件名称]:R语言 4.2.2 [软件大小]: 75.6 MB [安装环境]: Win11/Win10/Win7 [软件安装包下载]: https://pan.quark.cn/s/b6f604930d04 R语言软件的GUI界面比较的简陋,只有一个命令行窗口,且每次创建图片都会跳出一个新的窗口,比较的繁琐,我们可以安装RStudio,来更方便的操作R(…...

快速玩转 CNStack 2.0 流量防护
作者:冠钰 云原生下的服务治理 在云原生技术的演进过程中,依托云原生技术能力,形成一个可以向下管理基础设施,向上管理业务应用的技术中台,越来越成为企业期望的云原生技术落地趋势。随着云原生技术中台 CNStack 发布…...

你还在用原生 poi 处理 excel?太麻烦了来瞧瞧这个
1、easypoi 前言 Excel 在日常工作中经常被用来存储用例信息,是一种非常便捷的数据存储工具有着众多的优点,我们就不一一介绍了。 今天来讲讲 Java 操作 Excel,总所周知 Java 是世界上最好的语言(不容反驳)ÿ…...

No.027<软考>《(高项)备考大全》【第11章】项目风险管理
【第11章】项目风险管理1 章节相关1.1 考试相关1.2 ITO口诀2 章节概述2.1 风险的含义2.2 风险定义的三个必要条件2.3 项目风险2.4 风险的随机性和相对性2.5 风险的分类2.6 风险成本2.7.1 风险损失有形成本2.7.2 风险损失无形成本2.8 项目风险管理过程3 规划风险管理4 识别风险4…...

mit6.824 lab2c-数据持久化
目录2c简介2b、2a问题测试时间2c简介 简单的说,raft需要将currentTerm、voteFor、entries(当前的所有日志)保存到硬盘进行持久化存储。 保存的方法:在变量改变时,利用persist()中的gob将变量序列化,存储在persister结构体中。&a…...

leaflet使用L.geoJSON加载文件,参数filter的使用方法(127)
第127个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+leaflet中加载geojson文件,这里介绍filter的使用方法。filter将用于决定是否包含某个功能的函数。 默认是包括所有特征。 直接复制下面的 vue+leaflet源代码,操作2分钟即可运行实现效果 文章目录 示例效果配置方…...

23年5月高项学习笔记7—— 质量管理
质量通常指产品质量,也包括工作质量(即过程),产品质量是指产品的使用价值,工作质量是产品质量的保证,反映了产品质量直接相关的工作的对产品质量的保证程度。 公差:结果的可接受范围 项目合同…...

学编程需要哪些基础呢?一起来看看吧
众所周知程序员薪酬高、工作环境好,是很多人向往的职业,那么学编程需要什么基础?0基础能学编程吗? 学编程需要什么基础? 1、数学基础 从计算机发展和应用的历史来看计算机的数学模型和体系结构等都是有数学家提出的&…...

PECS In Java泛型类型通配符限定之<? extends T>与<? super T>
泛型类型通配符限定 🚆PECS | 类型通配符限定如何使用“<? extends T>”和“<? super T>”通配符java源码示例PECS | 类型通配符限定 PECS原则是指在使用泛型时,当我们需要传递一个泛型集合时,如何选择适当的泛型类型通配符来…...

电子招投标系统源码之了解电子招标投标全流程
随着各级政府部门的大力推进,以及国内互联网的建设,电子招投标已经逐渐成为国内主流的招标投标方式,但是依然有很多人对电子招投标的流程不够了解,在具体操作上存在困难。虽然各个交易平台的招标投标在线操作会略有不同࿰…...

admin Tips
1 获取 当前浏览器 url new URL(window.location.href)...

ToBeWritten之Radare2 使用教程
也许每个人出生的时候都以为这世界都是为他一个人而存在的,当他发现自己错的时候,他便开始长大 少走了弯路,也就错过了风景,无论如何,感谢经历 转移发布平台通知:将不再在CSDN博客发布新文章,敬…...

实时翻译屏幕插件
程序插件的功能是:点击按钮,将获取屏幕截图,然后翻译输出图片。(目前只支持翻译英语) 要实现这个功能,我们可以使用Python编程语言,结合一些库来完成。以下是一个简单的实现方案: …...

代码随想录算法训练营第二天| 977,209,59
977.有序数组的平方 * 数组平方后,最大值一定是在两侧 因为可以采用双指针 package algor.trainingcamp;import java.util.Arrays;/*** author lizhe* version 1.0* description: https://leetcode.cn/problems/squares-of-a-sorted-array/** 有序数组的平方* 给…...

echarts 地图板块点击着色,移除着色
//选择省份变色 showProvince(name) { this.oldName name; this.mapChart && this.mapChart.dispatchAction({ type: geoSelect, name }) }, //移除上次点击变色 hideProvince() { this.mapChart && this.mapChart.dispatchAction({ type: geoUnSelect, name:…...