昇思25天学习打卡营第2天|MindSpore快速入门
打卡

目录
打卡
快速入门案例:minist图像数据识别任务
案例任务说明
流程
1 加载并处理数据集
2 模型网络构建与定义
3 模型约束定义
4 模型训练
5 模型保存
6 模型推理
相关参考文档入门理解
MindSpore数据处理引擎
模型网络参数初始化
模型优化器
损失函数
代码
安装
从模型训练到预测推理
self_main_train_and_save.py
self_dataprocess.py
self_network.py
self_modeltrain.py
self_modeltest.py
self_predict.py
快速入门案例:minist图像数据识别任务
案例任务说明
MINIST数据集是有标签的图像数据,图像数据是0-9的手写阿拉伯数字。其中,训练集有6W个,测试集1W个。
目的是训练一个可以高效识别手写阿拉伯数字的模型。
流程
1 加载并处理数据集
涉及到的mindspore接口 mindspore.dataset。例如对数据集的map、batch、shuffle等操作,数据列名获取,对数据集进行迭代访问、查看数据和标签的shape和datatype等。
2 模型网络构建与定义
涉及到 mindspore.nn 类。例如用户可继承nn.Cell类来自定义网络结构,其中的construct类函数包含数据(Tensor)的变换过程。。
3 模型约束定义
包括损失函数、优化器等。如 nn.CrossEntropyLoss() 、nn.SGD(model.trainable_params(), 1e-2)
4 模型训练
- 定义训练函数,用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
- 定义测试函数,用来评估模型的性能。
5 模型保存
- 两种保存方式:
1)模型参数保存:mindspore.save_checkpoint(model, "model.ckpt")
2)统一的中间表示(Intermediate Representation,IR)的保存,MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
6 模型推理
- 两种加载方式:
1)模型参数加载:
> model = network()
> param_dict = mindspore.load_checkpoint("model.ckpt");
> param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
2)统一的中间表示(Intermediate Representation,IR)的加载:
> mindspore.set_context(mode=mindspore.GRAPH_MODE) > graph = mindspore.load("model.mindir") > model = nn.GraphCell(graph) ## nn.GraphCell 仅支持图模式。 > outputs = model(inputs)
保存与加载 — MindSpore master 文档
相关参考文档入门理解
MindSpore数据处理引擎
MindSpore 通过对外暴露API层来构建数据图;内部的Data Processing Pipeline 层用来进行数据加载和预处理多步并行流水线。
高性能数据处理引擎 — MindSpore master 文档

MindSpore 通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。
数据集 Dataset — MindSpore master 文档
数据变换 Transforms — MindSpore master 文档
模型网络参数初始化
Initializer是MindSpore内置的参数初始化基类,所有内置参数初始化方法均继承该类。mindspore.nn中提供的神经网络层封装均提供weight_init、bias_init等入参,可以直接使用实例化的Initializer进行参数初始化。
参数初始化 — MindSpore master 文档
模型优化器
优化器 — MindSpore master 文档
损失函数
损失函数 — MindSpore master 文档
代码
安装
pip/conda均可:
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1
从模型训练到预测推理
训练:
python self_main_train_and_save.py
推理:
python self_predict.py
self_main_train_and_save.py
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# 用download库从公开华为云obs桶下载 MINIST 数据集并解压。因为mindspore.dataset 提供的接口仅支持解压后的数据文件
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True) ## 1 加载数据集
train_dataset = MnistDataset('MNIST_Data/train', shuffle=False)
test_dataset = MnistDataset('MNIST_Data/test')
print(train_dataset.get_col_names()) # 打印数据集中包含的数据列名,用于dataset的预处理。输出['image', 'label']## 2 MindSpore的dataset使用数据处理流水线,这里将处理好的数据集打包为大小为64的batch。
from self_dataprocess import datapipe
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64) ## 3 数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。可使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。
for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")break“”“Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32Shape of label: (64,) Int32”“”
for data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break## 4 模型训练
from self_network import Network
from self_modeltrain import train, loss_fn
from self_modelteset import test
model = Network()
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")## 5 保存模型
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")
self_dataprocess.py
from mindspore.dataset import vision, transforms
def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset
self_network.py
# Define model
from mindspore import nnclass Network(nn.Cell): def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsdef check_network():model = Network()print(model)
self_modeltrain.py
# Instantiate loss function and optimizer
from mindspore import nnloss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train() ## 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 BatchNorm),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。默认Truefor batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
self_modeltest.py
from mindspore import nn def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
self_predict.py
## 加载模型
from self_network import Network# Instantiate a random initialized model
model = Network()# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load) ## param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。## 加载后的模型可以直接用于预测推理。
model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break
相关文章:
昇思25天学习打卡营第2天|MindSpore快速入门
打卡 目录 打卡 快速入门案例:minist图像数据识别任务 案例任务说明 流程 1 加载并处理数据集 2 模型网络构建与定义 3 模型约束定义 4 模型训练 5 模型保存 6 模型推理 相关参考文档入门理解 MindSpore数据处理引擎 模型网络参数初始化 模型优化器 …...
django之url路径
方式一:path 语法:<<转换器类型:自定义>> 作用:若转换器类型匹配到对应类型的数据,则将数据按照关键字传参的方式传递给视图函数 类型: str: 匹配除了”/“之外的非空字符串。 /test/zvxint: 匹配0或任何…...
【OnlyOffice】桌面应用编辑器,插件开发大赛,等你来挑战
OnlyOffice,桌面应用编辑器,最近版本已从8.0升级到了8.1 从PDF、Word、Excel、PPT等全面进行了升级。随着AI应用持续的火热,OnlyOffice也在不断推出AI相关插件。 因此,在此给大家推荐一下OnlyOffice本次的插件开发大赛。 详细信息…...
[学习笔记]SQL学习笔记(连载中。。。)
学习视频:【数据库】SQL 3小时快速入门 #数据库教程 #SQL教程 #MySQL教程 #database#Python连接数据库 目录 1.SQL的基础知识1.1.表(table)和键(key)1.2.外键、联合主键 2.MySQL安装(略,请自行参考视频)3.基本的MySQL语法3.1.规…...
Buuctf之SimpleRev做法
首先,查个壳,64bit,那就丢进ida64中进行反编译进来之后,我们进入main函数,发现里面没什么东西,那就shiftf12搜索字符串,找到关键字符串,双击进入然后再选中该字符串,ctrl…...
【云原生监控】Prometheus 普罗米修斯从搭建到使用详解
目录 一、前言 二、服务监控概述 2.1 什么是微服务监控 2.2 微服务监控指标 2.3 微服务监控工具 三、Prometheus概述 3.1 Prometheus是什么 3.2 Prometheus 特点 3.3 Prometheus 架构图 3.3.1 Prometheus核心组件 3.3.2 Prometheus 工作流程 3.4 Prometheus 应用场景…...
【C++】模板进阶--保姆级解析(什么是非类型模板参数?什么是模板的特化?模板的特化如何应用?)
目录 一、前言 二、什么是C模板? 💦泛型编程的思想 💦C模板的分类 三、非类型模板参数 ⚡问题引入⚡ ⚡非类型模板参数的使用⚡ 🔥非类型模板参数的定义 🔥非类型模板参数的两种类型 ὒ…...
Cookie与Session
Cookie Set-Cookie: sessionIdabc123; ExpiresWed, 09 Jun 2024 10:18:14 GMT; Path/; Secure; HttpOnlySession session作用域 首先需要了解servlet容器可能包含多个web应用。 在servlet容器中同一应用的servlet 对 session数据是可见的,不同应用之间session是相互…...
Nuxt3 的生命周期和钩子函数(十一)
title: Nuxt3 的生命周期和钩子函数(十一) date: 2024/7/5 updated: 2024/7/5 author: cmdragon excerpt: 摘要:本文详细介绍了Nuxt3中几个关键的生命周期钩子和它们的使用方法,包括webpack:done用于Webpack编译完成后执行操作…...
Windows ipconfig命令详解,Windows查看IP地址信息
「作者简介」:冬奥会网络安全中国代表队,CSDN Top100,就职奇安信多年,以实战工作为基础著作 《网络安全自学教程》,适合基础薄弱的同学系统化的学习网络安全,用最短的时间掌握最核心的技术。 ipconfig 1、基…...
在C#/Net中使用Mqtt
net中MQTT的应用场景 c#常用来开发上位机程序,或者其他一些跟设备打交道比较多的系统,所以会经常作为拥有数据的终端,可以用来采集上传数据,而MQTT也是物联网常用的协议,所以下面介绍在C#开发中使用MQTT。 安装MQTTn…...
VBA提取word表格内容到excel
这是一段提取word表格中部分内容的vb代码。 Sub 提取word表格() mypath ThisWorkbook.Path & "\"myname Dir(mypath & "*.doc*")n 4 index of rowsRange("A1:F1") Array("课程代码", "课程名称", "专业&…...
html+css+js图片手动轮播
源代码在界面图片后面 轮播演示用的几张图片是Bing上的,直接用的几张图片的URL,谁加载可能需要等一下,现实中替换成自己的图片即可 关注一下点个赞吧😄 谢谢大佬 界面图片 源代码 <!DOCTYPE html> <html lang&quo…...
【十三】图解 Spring 核心数据结构:BeanDefinition 其二
图解 Spring 核心数据结构:BeanDefinition 其二 概述 前面写过一篇相关文章作为开篇介绍了一下BeanDefinition,本篇将深入细节来向读者展示BeanDefinition的设计,让我们一起来揭开日常开发中使用的bean的神秘面纱,深入细节透彻理解…...
数据库作业
命令 登陆数据库 mysql -uroot -p123456 --prompt"\u\h:\d--> " 创建数据库zcr create database zcr; 修改数据库zcr字符集为gbk alter database zcr default character set gbk collate gbk_chinese_ci; 选择数据库zcr use zcr 查看数据库zc…...
12、matlab中for循环,if else判断语句,break和continue用法以及switch case语句使用
1、前言 在MATLAB中,for循环用于迭代一个固定次数的循环。可以使用if else语句在循环中进行条件判断,根据条件的不同执行相应的代码块。break和continue可以用于控制循环的执行流程,break用于提前结束循环,而continue用于跳过当前…...
AcWing 3207:门禁系统 ← 桶排序中“桶”的思想
【题目来源】https://www.acwing.com/problem/content/3210/【题目描述】 涛涛最近要负责图书馆的管理工作,需要记录下每天读者的到访情况。 每位读者有一个唯一编号,每条记录用读者的编号来表示。 给出读者的来访记录,请问每一条记录中的读者…...
开发个人Go-ChatGPT--3 服务拆分
开发个人Go-ChatGPT–3 服务拆分 个人Go-ChatGPT项目可拆分用户服务(user),AI模型服务(AiModel),… 每个服务都可以再分为 api 服务和 rpc 服务。api 服务对外,可提供给 app 调用。rpc 服务是…...
Android --- 新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了
新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了 大概原因就是,初始化默认Google的安卓模拟器占用的RAM内存是2048,如果电脑的性能和内存一般的话就可能卡死,解决方案是手动修改安卓模拟器的config文件&…...
从入门到深入,Docker新手学习教程
编译整理|TesterHome社区 作者|Ishaan Gupta 以下为作者观点: Docker 彻底改变了我们开发、交付和运行应用程序的方式。它使开发人员能够将应用程序打包到容器中 - 标准化的可执行组件,将应用程序源代码与在任何环境中运行该代码…...
java_网络服务相关_gateway_nacos_feign区别联系
1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院挂号小程序
一、开发准备 环境搭建: 安装DevEco Studio 3.0或更高版本配置HarmonyOS SDK申请开发者账号 项目创建: File > New > Create Project > Application (选择"Empty Ability") 二、核心功能实现 1. 医院科室展示 /…...
dedecms 织梦自定义表单留言增加ajax验证码功能
增加ajax功能模块,用户不点击提交按钮,只要输入框失去焦点,就会提前提示验证码是否正确。 一,模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…...
c++ 面试题(1)-----深度优先搜索(DFS)实现
操作系统:ubuntu22.04 IDE:Visual Studio Code 编程语言:C11 题目描述 地上有一个 m 行 n 列的方格,从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子,但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...
spring:实例工厂方法获取bean
spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂ÿ…...
Module Federation 和 Native Federation 的比较
前言 Module Federation 是 Webpack 5 引入的微前端架构方案,允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...
Java 二维码
Java 二维码 **技术:**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...
在Ubuntu24上采用Wine打开SourceInsight
1. 安装wine sudo apt install wine 2. 安装32位库支持,SourceInsight是32位程序 sudo dpkg --add-architecture i386 sudo apt update sudo apt install wine32:i386 3. 验证安装 wine --version 4. 安装必要的字体和库(解决显示问题) sudo apt install fonts-wqy…...
在Mathematica中实现Newton-Raphson迭代的收敛时间算法(一般三次多项式)
考察一般的三次多项式,以r为参数: p[z_, r_] : z^3 (r - 1) z - r; roots[r_] : z /. Solve[p[z, r] 0, z]; 此多项式的根为: 尽管看起来这个多项式是特殊的,其实一般的三次多项式都是可以通过线性变换化为这个形式…...
django blank 与 null的区别
1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是,要注意以下几点: Django的表单验证与null无关:null参数控制的是数据库层面字段是否可以为NULL,而blank参数控制的是Django表单验证时字…...
