【深度学习】基于华为MindSpore的手写体图像识别实验
1 实验介绍
1.1 简介
Mnist手写体图像识别实验是深度学习入门经典实验。Mnist数据集包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到255。为简单起见,每个图像都被平展并转换为784(28*28)个特征的一维numpy数组。
1.2 实验目的
- 学会如何搭建全连接神经网络。
- 掌握搭建网络过程中的关键点。
- 掌握分类任务的整体流程。
2.2 实验环境要求
推荐在华为云ModelArts实验平台完成实验,也可在本地搭建python3.7.5和MindSpore1.0.0环境完成实验。
2.3 实验总体设计

创建实验环境:在本地搭建MindSpore环境。
导入实验所需模块:该步骤通常都是程序编辑的第一步,将实验代码所需要用到的模块包用import命令进行导入。
导入数据集并预处理:神经网络的训练离不开数据,这里对数据进行导入。同时,因为全连接网络只能接收固定维度的输入数据,所以,要对数据集进行预处理,以符合网络的输入维度要求。同时,设定好每一次训练的Batch的大小,以Batch Size为单位进行输入。
模型搭建:利用mindspore.nn的cell模块搭建全连接网络,包含输入层,隐藏层,输出层。同时,配置好网络需要的优化器,损失函数和评价指标。传入数据,并开始训练模型。
模型评估:利用测试集进行模型的评估。
2.4 实验过程
2.4.1 搭建实验环境
Windows下MindSpore实验环境搭建并配置Pycharm请参考【机器学习】Windows下MindSpore实验环境搭建并配置Pycharm_在pycharm上安装mindspore_弓长纟隹为的博客-CSDN博客
官网下载MNIST数据集 MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
在MNIST文件夹下建立train和test两个文件夹,train中存放train-labels-idx1-ubyte和train-images-idx3-ubyte文件,test中存放t10k-labels-idx1-ubyte和t10k-images-idx3-ubyte文件。
2.4.2 模型训练、测试及评估
#导入相关依赖库
import os
import numpy as np
from matplotlib import pyplot as plt
import mindspore as ms
#context模块用于设置实验环境和实验设备
import mindspore.context as context
#dataset模块用于处理数据形成数据集
import mindspore.dataset as ds
#c_transforms模块用于转换数据类型
import mindspore.dataset.transforms as C
#vision.c_transforms模块用于转换图像,这是一个基于opencv的高级API
import mindspore.dataset.vision as CV
#导入Accuracy作为评价指标
from mindspore.nn.metrics import Accuracy
#nn中有各种神经网络层如:Dense,ReLu
from mindspore import nn
#Model用于创建模型对象,完成网络搭建和编译,并用于训练和评估
from mindspore.train import Model
#LossMonitor可以在训练过程中返回LOSS值作为监控指标
from mindspore.train.callback import LossMonitor
#设定运行模式为动态图模式,并且运行设备为昇腾芯片
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
#MindSpore内置方法读取MNIST数据集
ds_train = ds.MnistDataset(os.path.join(r'D:\Dataset\MNIST', "train"))
ds_test = ds.MnistDataset(os.path.join(r'D:\Dataset\MNIST', "test"))print('训练数据集数量:',ds_train.get_dataset_size())
print('测试数据集数量:',ds_test.get_dataset_size())
#该数据集可以通过create_dict_iterator()转换为迭代器形式,然后通过get_next()一个个输出样本
image=ds_train.create_dict_iterator().get_next()
#print(type(image))
print('图像长/宽/通道数:',image['image'].shape)
#一共10类,用0-9的数字表达类别。
print('一张图像的标签样式:',image['label'])
DATA_DIR_TRAIN = "D:/Dataset/MNIST/train" # 训练集信息
DATA_DIR_TEST = "D:/Dataset/MNIST/test" # 测试集信息def create_dataset(training=True, batch_size=128, resize=(28, 28), rescale=1 / 255, shift=-0.5, buffer_size=64):ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)# 定义改变形状、归一化和更改图片维度的操作。# 改为(28,28)的形状resize_op = CV.Resize(resize)# rescale方法可以对数据集进行归一化和标准化操作,这里就是将像素值归一到0和1之间,shift参数可以让值域偏移至-0.5和0.5之间rescale_op = CV.Rescale(rescale, shift)# 由高度、宽度、深度改为深度、高度、宽度hwc2chw_op = CV.HWC2CHW()# 利用map操作对原数据集进行调整ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op])ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))# 设定洗牌缓冲区的大小,从一定程度上控制打乱操作的混乱程度ds = ds.shuffle(buffer_size=buffer_size)# 设定数据集的batch_size大小,并丢弃剩余的样本ds = ds.batch(batch_size, drop_remainder=True)return ds
#显示前10张图片以及对应标签,检查图片是否是正确的数据集
dataset_show = create_dataset(training=False)
data = dataset_show.create_dict_iterator().get_next()
images = data['image'].asnumpy()
labels = data['label'].asnumpy()for i in range(1,11):plt.subplot(2, 5, i)#利用squeeze方法去掉多余的一个维度plt.imshow(np.squeeze(images[i]))plt.title('Number: %s' % labels[i])plt.xticks([])
plt.show()# 利用定义类的方式生成网络,Mindspore中定义网络需要继承nn.cell。在init方法中定义该网络需要的神经网络层
# 在construct方法中梳理神经网络层与层之间的关系。
class ForwardNN(nn.Cell):def __init__(self):super(ForwardNN, self).__init__()self.flatten = nn.Flatten()self.relu = nn.ReLU()self.fc1 = nn.Dense(784, 512, activation='relu')self.fc2 = nn.Dense(512, 256, activation='relu')self.fc3 = nn.Dense(256, 128, activation='relu')self.fc4 = nn.Dense(128, 64, activation='relu')self.fc5 = nn.Dense(64, 32, activation='relu')self.fc6 = nn.Dense(32, 10, activation='softmax')def construct(self, input_x):output = self.flatten(input_x)output = self.fc1(output)output = self.fc2(output)output = self.fc3(output)output = self.fc4(output)output = self.fc5(output)output = self.fc6(output)return outputlr = 0.001
num_epoch = 10
momentum = 0.9net = ForwardNN()
#定义loss函数,改函数不需要求导,可以给离散的标签值,且loss值为均值
loss = nn.loss.SoftmaxCrossEntropyWithLogits( sparse=True, reduction='mean')
#定义准确率为评价指标,用于评价模型
metrics={"Accuracy": Accuracy()}
#定义优化器为Adam优化器,并设定学习率
opt = nn.Adam(net.trainable_params(), lr)#生成验证集,验证机不需要训练,所以不需要repeat
ds_eval = create_dataset(False, batch_size=32)
#模型编译过程,将定义好的网络、loss函数、评价指标、优化器编译
model = Model(net, loss, opt, metrics)#生成训练集
ds_train = create_dataset(True, batch_size=32)
print("============== Starting Training ==============")
#训练模型,用loss作为监控指标,并利用昇腾芯片的数据下沉特性进行训练
model.train(num_epoch, ds_train,callbacks=[LossMonitor()],dataset_sink_mode=True)#使用测试集评估模型,打印总体准确率
metrics_result=model.eval(ds_eval)
print(metrics_result)


备注:
若报错 AttributeError: ‘DictIterator’ object has no attribute ‘get_next’ ,这是说MindSpore数据类中缺少 “get_next”这个方法,但是在MNIST图像识别的官方代码中却使用了这个方法,这就说明MindSpore官方把这个变成私密方法。
只需要在源码iterators.py中找到DictIterator这个类,将私有方法变成公有方法就行了(即去掉最前面的下划线)。
参考mindspore 报错 AttributeError: ‘DictIterator‘ object has no attribute ‘get_next‘_create_dict_iterator_TNiuB的博客-CSDN博客
MindSpore:前馈神经网络时报错‘DictIterator‘ has no attribute ‘get_next‘_skytier的博客-CSDN博客

更多问题请参考Window10 上MindSpore(CPU)用LeNet网络训练MNIST - 知乎
相关文章:
【深度学习】基于华为MindSpore的手写体图像识别实验
1 实验介绍 1.1 简介 Mnist手写体图像识别实验是深度学习入门经典实验。Mnist数据集包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到255。为简单起见,每…...
Linux:内核调试之内核魔术键sysrq
在linux系统下,我们可能会遇到系统某个命令hang住的情况,通常情况下,我们会查看/proc/pid/wchan文件,看看进程处于什么状况,然后进一步查看系统日志或者使用strace跟踪命令执行时的系统调用等等方法来分析问题。我们知…...
Python import导包快速入门
import 和 from import 在 Python 中,使用 import 语句可以将其他 Python 模块或包中的代码引入到当前模块中,以供使用。通常情况下,我们可以使用以下语法将整个模块导入到当前命名空间中: import module_name其中,m…...
ChatGPT这么火,我们能怎么办?
今天打开百度,看到这样一条热搜高居榜二:B站UP主发起停更潮,然后点进去了解一看,大体是因为最近AI创作太火,对高质量原创形成了巨大冲击!记得之前看过一位UP主的分享,说B站UP主的年收入大体约等…...
HashMap底层原理
文章目录1. 基本概念2. HashMap 的底层数据结构3. HashMap 的 put 方法流程4. 怎么计算节点存储的下标5. Hash 冲突1)概念2)解决 hash 冲突的办法开放地址法再哈希法链地址法建立公共溢出区6. HashMap 的扩容机制1)扩容时涉及到的几个属性2&a…...
卡顿优化小结
卡顿的本质 卡顿的本质是因为一次垂直同步信号来的时候,当前帧要显示的图像数据还没准备好,只能等待16ms下一次垂直同步信号来时才能更新画面,在这段时间里显示器只能一直停留在上一帧的画面,如果跳过的帧数过多,就会…...
springboot前端ajax 04 关于后台传的时间和状态在前端的转换
修改状态及时间格式 在jsp中,时间显式: 只需要把json的时间部分改为用Date对象来显示就好了。 <td>new Date(jsonObj[i].startTime).toLocaleString()</td> <td>new Date(jsonObj[i].endTime).toLocaleString()</td> 状态对象…...
解决Windows微信和 PowerToys 的键盘管理器冲突
Windows开机之后PowerToys能正常使用, 但是打开微信之后设置好的快捷键映射就全部失效了 打开微信 -> 左下角三条杠 -> 设置 -> 快捷键 首先我把微信的快捷键全部清空了,发现还是没用 然后发现了设置里默认勾选了检测快捷键,我在想程序肯定是一直在后台检测,而powerTo…...
组会时间的工作
1. 党支部活动 看看组织生活记录本写完了没有 2. 论文翻译...
linux udp bind 返回值-1分析
在linux socket通信中,我们通常用到open/bind/read/write等内部函数,那么当这些函数返回值为-1的时候,我们怎么进一步定位呢! (1)怎么打印出返回值出错的原因呢!系统调用的错误都会存放在errno中 errno需要的头文件: #include<errno.h> strerror头文件,将错误信…...
Hexo搭建博客
文章目录开始安装使用配置主题配置gitee配置域名之前配置过hexo但是后来hexo文件夹莫名其妙崩了,我也懒得修理,就没管了,现在又想重拾回来。然后hexo可以搭建静态博客网站,放在github或者gitee都行,有免费的网页空间&a…...
Lesson11:http协议
前言 应用层:就是程序员基于socket接口之上编写的具体逻辑,做了很多工作,都是和文本处理有关的--- 协议分析与处理http协议,一定会具有大量的文体分析和协议处理如果用户想再url中包含url本身用来作为特殊字符的字符,url形式的时候,浏览器会自动给我们进行编码encode一般服务端…...
计算机信息安全有哪些SCI期刊推荐? - 易智编译EaseEditing
以下是计算机信息安全方向的SCI期刊推荐: IEEE Transactions on Information Forensics and Security 该期刊主要发表信息安全和数字取证方面的原创性研究,包括数据安全、网络安全、身份认证、加密、信息隐藏等领域的研究成果。该期刊的影响因子为8.134…...
CNVD-2023-12632 泛微e-cology9 sql注入 附poc
目录 漏洞描述影响版本漏洞复现漏洞修复 漏洞描述 泛微 E-Cology9 协同办公系统是一套基于 JSP 及 SQL Server 数据库的 OA 系统,包括知识文档管理、人力资源管理、客户关系管理、项目管理、财务管理、工作流程管理、数据中心等打造协同高效的企业管理环境&#…...
赛宁网安合作伙伴大会成功举办,重磅发布SCBaaS服务!
3月29日,“赛宁网安合作伙伴大会”在江苏南京隆重举办。大会现场汇集网络安全数字化领域的专业人才、技术专家,共同研讨数字安全发展趋势,分享智能安全解决方案和技术创新产品。 会上,赛宁网安产品专家针对数字化靶场、网络安…...
R语言 4.2.2安装包下载及安装教程
[软件名称]:R语言 4.2.2 [软件大小]: 75.6 MB [安装环境]: Win11/Win10/Win7 [软件安装包下载]: https://pan.quark.cn/s/b6f604930d04 R语言软件的GUI界面比较的简陋,只有一个命令行窗口,且每次创建图片都会跳出一个新的窗口,比较的繁琐,我们可以安装RStudio,来更方便的操作R(…...
快速玩转 CNStack 2.0 流量防护
作者:冠钰 云原生下的服务治理 在云原生技术的演进过程中,依托云原生技术能力,形成一个可以向下管理基础设施,向上管理业务应用的技术中台,越来越成为企业期望的云原生技术落地趋势。随着云原生技术中台 CNStack 发布…...
你还在用原生 poi 处理 excel?太麻烦了来瞧瞧这个
1、easypoi 前言 Excel 在日常工作中经常被用来存储用例信息,是一种非常便捷的数据存储工具有着众多的优点,我们就不一一介绍了。 今天来讲讲 Java 操作 Excel,总所周知 Java 是世界上最好的语言(不容反驳)ÿ…...
No.027<软考>《(高项)备考大全》【第11章】项目风险管理
【第11章】项目风险管理1 章节相关1.1 考试相关1.2 ITO口诀2 章节概述2.1 风险的含义2.2 风险定义的三个必要条件2.3 项目风险2.4 风险的随机性和相对性2.5 风险的分类2.6 风险成本2.7.1 风险损失有形成本2.7.2 风险损失无形成本2.8 项目风险管理过程3 规划风险管理4 识别风险4…...
mit6.824 lab2c-数据持久化
目录2c简介2b、2a问题测试时间2c简介 简单的说,raft需要将currentTerm、voteFor、entries(当前的所有日志)保存到硬盘进行持久化存储。 保存的方法:在变量改变时,利用persist()中的gob将变量序列化,存储在persister结构体中。&a…...
React 第五十五节 Router 中 useAsyncError的使用详解
前言 useAsyncError 是 React Router v6.4 引入的一个钩子,用于处理异步操作(如数据加载)中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误:捕获在 loader 或 action 中发生的异步错误替…...
springboot 百货中心供应链管理系统小程序
一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,百货中心供应链管理系统被用户普遍使用,为方…...
Java - Mysql数据类型对应
Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...
什么是库存周转?如何用进销存系统提高库存周转率?
你可能听说过这样一句话: “利润不是赚出来的,是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业,很多企业看着销售不错,账上却没钱、利润也不见了,一翻库存才发现: 一堆卖不动的旧货…...
cf2117E
原题链接:https://codeforces.com/contest/2117/problem/E 题目背景: 给定两个数组a,b,可以执行多次以下操作:选择 i (1 < i < n - 1),并设置 或,也可以在执行上述操作前执行一次删除任意 和 。求…...
苍穹外卖--缓存菜品
1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...
Java入门学习详细版(一)
大家好,Java 学习是一个系统学习的过程,核心原则就是“理论 实践 坚持”,并且需循序渐进,不可过于着急,本篇文章推出的这份详细入门学习资料将带大家从零基础开始,逐步掌握 Java 的核心概念和编程技能。 …...
零基础在实践中学习网络安全-皮卡丘靶场(第九期-Unsafe Fileupload模块)(yakit方式)
本期内容并不是很难,相信大家会学的很愉快,当然对于有后端基础的朋友来说,本期内容更加容易了解,当然没有基础的也别担心,本期内容会详细解释有关内容 本期用到的软件:yakit(因为经过之前好多期…...
如何在网页里填写 PDF 表格?
有时候,你可能希望用户能在你的网站上填写 PDF 表单。然而,这件事并不简单,因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件,但原生并不支持编辑或填写它们。更糟的是,如果你想收集表单数据ÿ…...
初学 pytest 记录
安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...
