昇思25天学习打卡营第2天|MindSpore快速入门
打卡
目录
打卡
快速入门案例:minist图像数据识别任务
案例任务说明
流程
1 加载并处理数据集
2 模型网络构建与定义
3 模型约束定义
4 模型训练
5 模型保存
6 模型推理
相关参考文档入门理解
MindSpore数据处理引擎
模型网络参数初始化
模型优化器
损失函数
代码
安装
从模型训练到预测推理
self_main_train_and_save.py
self_dataprocess.py
self_network.py
self_modeltrain.py
self_modeltest.py
self_predict.py
快速入门案例:minist图像数据识别任务
案例任务说明
MINIST数据集是有标签的图像数据,图像数据是0-9的手写阿拉伯数字。其中,训练集有6W个,测试集1W个。
目的是训练一个可以高效识别手写阿拉伯数字的模型。
流程
1 加载并处理数据集
涉及到的mindspore接口 mindspore.dataset。例如对数据集的map、batch、shuffle等操作,数据列名获取,对数据集进行迭代访问、查看数据和标签的shape和datatype等。
2 模型网络构建与定义
涉及到 mindspore.nn 类。例如用户可继承nn.Cell类来
自定义网络结构,其中的construct类函数
包含数据(Tensor)的变换过程。。
3 模型约束定义
包括损失函数、优化器等。如 nn.CrossEntropyLoss() 、nn.SGD(model.trainable_params(), 1e-2)
4 模型训练
- 定义训练函数,用set_train设置为训练模式,执行正向计算、反向传播和参数优化。
- 定义测试函数,用来评估模型的性能。
5 模型保存
- 两种保存方式:
1)模型参数保存:mindspore.save_checkpoint(model, "model.ckpt")
2)统一的中间表示(Intermediate Representation,IR)的保存,MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
6 模型推理
- 两种加载方式:
1)模型参数加载:
> model = network()
> param_dict = mindspore.load_checkpoint("model.ckpt");
> param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
2)统一的中间表示(Intermediate Representation,IR)的加载:
> mindspore.set_context(mode=mindspore.GRAPH_MODE) > graph = mindspore.load("model.mindir") > model = nn.GraphCell(graph) ## nn.GraphCell 仅支持图模式。 > outputs = model(inputs)
保存与加载 — MindSpore master 文档
相关参考文档入门理解
MindSpore数据处理引擎
MindSpore 通过对外暴露API层来构建数据图;内部的Data Processing Pipeline 层用来进行数据加载和预处理多步并行流水线。
高性能数据处理引擎 — MindSpore master 文档
MindSpore 通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。
数据集 Dataset — MindSpore master 文档
数据变换 Transforms — MindSpore master 文档
模型网络参数初始化
Initializer
是MindSpore内置的参数初始化基类,所有内置参数初始化方法均继承该类。mindspore.nn
中提供的神经网络层封装均提供weight_init
、bias_init
等入参,可以直接使用实例化的Initializer进行参数初始化。
参数初始化 — MindSpore master 文档
模型优化器
优化器 — MindSpore master 文档
损失函数
损失函数 — MindSpore master 文档
代码
安装
pip/conda均可:
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1
从模型训练到预测推理
训练:
python self_main_train_and_save.py
推理:
python self_predict.py
self_main_train_and_save.py
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# 用download库从公开华为云obs桶下载 MINIST 数据集并解压。因为mindspore.dataset 提供的接口仅支持解压后的数据文件
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True) ## 1 加载数据集
train_dataset = MnistDataset('MNIST_Data/train', shuffle=False)
test_dataset = MnistDataset('MNIST_Data/test')
print(train_dataset.get_col_names()) # 打印数据集中包含的数据列名,用于dataset的预处理。输出['image', 'label']## 2 MindSpore的dataset使用数据处理流水线,这里将处理好的数据集打包为大小为64的batch。
from self_dataprocess import datapipe
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64) ## 3 数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。可使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。
for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")break“”“Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32Shape of label: (64,) Int32”“”
for data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break## 4 模型训练
from self_network import Network
from self_modeltrain import train, loss_fn
from self_modelteset import test
model = Network()
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")## 5 保存模型
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")
self_dataprocess.py
from mindspore.dataset import vision, transforms
def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset
self_network.py
# Define model
from mindspore import nnclass Network(nn.Cell): def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsdef check_network():model = Network()print(model)
self_modeltrain.py
# Instantiate loss function and optimizer
from mindspore import nnloss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train() ## 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 BatchNorm),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。默认Truefor batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f} [{current:>3d}/{size:>3d}]")
self_modeltest.py
from mindspore import nn def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
self_predict.py
## 加载模型
from self_network import Network# Instantiate a random initialized model
model = Network()# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load) ## param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。## 加载后的模型可以直接用于预测推理。
model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break
相关文章:

昇思25天学习打卡营第2天|MindSpore快速入门
打卡 目录 打卡 快速入门案例:minist图像数据识别任务 案例任务说明 流程 1 加载并处理数据集 2 模型网络构建与定义 3 模型约束定义 4 模型训练 5 模型保存 6 模型推理 相关参考文档入门理解 MindSpore数据处理引擎 模型网络参数初始化 模型优化器 …...

django之url路径
方式一:path 语法:<<转换器类型:自定义>> 作用:若转换器类型匹配到对应类型的数据,则将数据按照关键字传参的方式传递给视图函数 类型: str: 匹配除了”/“之外的非空字符串。 /test/zvxint: 匹配0或任何…...

【OnlyOffice】桌面应用编辑器,插件开发大赛,等你来挑战
OnlyOffice,桌面应用编辑器,最近版本已从8.0升级到了8.1 从PDF、Word、Excel、PPT等全面进行了升级。随着AI应用持续的火热,OnlyOffice也在不断推出AI相关插件。 因此,在此给大家推荐一下OnlyOffice本次的插件开发大赛。 详细信息…...

[学习笔记]SQL学习笔记(连载中。。。)
学习视频:【数据库】SQL 3小时快速入门 #数据库教程 #SQL教程 #MySQL教程 #database#Python连接数据库 目录 1.SQL的基础知识1.1.表(table)和键(key)1.2.外键、联合主键 2.MySQL安装(略,请自行参考视频)3.基本的MySQL语法3.1.规…...

Buuctf之SimpleRev做法
首先,查个壳,64bit,那就丢进ida64中进行反编译进来之后,我们进入main函数,发现里面没什么东西,那就shiftf12搜索字符串,找到关键字符串,双击进入然后再选中该字符串,ctrl…...

【云原生监控】Prometheus 普罗米修斯从搭建到使用详解
目录 一、前言 二、服务监控概述 2.1 什么是微服务监控 2.2 微服务监控指标 2.3 微服务监控工具 三、Prometheus概述 3.1 Prometheus是什么 3.2 Prometheus 特点 3.3 Prometheus 架构图 3.3.1 Prometheus核心组件 3.3.2 Prometheus 工作流程 3.4 Prometheus 应用场景…...

【C++】模板进阶--保姆级解析(什么是非类型模板参数?什么是模板的特化?模板的特化如何应用?)
目录 一、前言 二、什么是C模板? 💦泛型编程的思想 💦C模板的分类 三、非类型模板参数 ⚡问题引入⚡ ⚡非类型模板参数的使用⚡ 🔥非类型模板参数的定义 🔥非类型模板参数的两种类型 ὒ…...

Cookie与Session
Cookie Set-Cookie: sessionIdabc123; ExpiresWed, 09 Jun 2024 10:18:14 GMT; Path/; Secure; HttpOnlySession session作用域 首先需要了解servlet容器可能包含多个web应用。 在servlet容器中同一应用的servlet 对 session数据是可见的,不同应用之间session是相互…...

Nuxt3 的生命周期和钩子函数(十一)
title: Nuxt3 的生命周期和钩子函数(十一) date: 2024/7/5 updated: 2024/7/5 author: cmdragon excerpt: 摘要:本文详细介绍了Nuxt3中几个关键的生命周期钩子和它们的使用方法,包括webpack:done用于Webpack编译完成后执行操作…...

Windows ipconfig命令详解,Windows查看IP地址信息
「作者简介」:冬奥会网络安全中国代表队,CSDN Top100,就职奇安信多年,以实战工作为基础著作 《网络安全自学教程》,适合基础薄弱的同学系统化的学习网络安全,用最短的时间掌握最核心的技术。 ipconfig 1、基…...

在C#/Net中使用Mqtt
net中MQTT的应用场景 c#常用来开发上位机程序,或者其他一些跟设备打交道比较多的系统,所以会经常作为拥有数据的终端,可以用来采集上传数据,而MQTT也是物联网常用的协议,所以下面介绍在C#开发中使用MQTT。 安装MQTTn…...

VBA提取word表格内容到excel
这是一段提取word表格中部分内容的vb代码。 Sub 提取word表格() mypath ThisWorkbook.Path & "\"myname Dir(mypath & "*.doc*")n 4 index of rowsRange("A1:F1") Array("课程代码", "课程名称", "专业&…...

html+css+js图片手动轮播
源代码在界面图片后面 轮播演示用的几张图片是Bing上的,直接用的几张图片的URL,谁加载可能需要等一下,现实中替换成自己的图片即可 关注一下点个赞吧😄 谢谢大佬 界面图片 源代码 <!DOCTYPE html> <html lang&quo…...

【十三】图解 Spring 核心数据结构:BeanDefinition 其二
图解 Spring 核心数据结构:BeanDefinition 其二 概述 前面写过一篇相关文章作为开篇介绍了一下BeanDefinition,本篇将深入细节来向读者展示BeanDefinition的设计,让我们一起来揭开日常开发中使用的bean的神秘面纱,深入细节透彻理解…...

数据库作业
命令 登陆数据库 mysql -uroot -p123456 --prompt"\u\h:\d--> " 创建数据库zcr create database zcr; 修改数据库zcr字符集为gbk alter database zcr default character set gbk collate gbk_chinese_ci; 选择数据库zcr use zcr 查看数据库zc…...

12、matlab中for循环,if else判断语句,break和continue用法以及switch case语句使用
1、前言 在MATLAB中,for循环用于迭代一个固定次数的循环。可以使用if else语句在循环中进行条件判断,根据条件的不同执行相应的代码块。break和continue可以用于控制循环的执行流程,break用于提前结束循环,而continue用于跳过当前…...

AcWing 3207:门禁系统 ← 桶排序中“桶”的思想
【题目来源】https://www.acwing.com/problem/content/3210/【题目描述】 涛涛最近要负责图书馆的管理工作,需要记录下每天读者的到访情况。 每位读者有一个唯一编号,每条记录用读者的编号来表示。 给出读者的来访记录,请问每一条记录中的读者…...

开发个人Go-ChatGPT--3 服务拆分
开发个人Go-ChatGPT–3 服务拆分 个人Go-ChatGPT项目可拆分用户服务(user),AI模型服务(AiModel),… 每个服务都可以再分为 api 服务和 rpc 服务。api 服务对外,可提供给 app 调用。rpc 服务是…...

Android --- 新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了
新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了 大概原因就是,初始化默认Google的安卓模拟器占用的RAM内存是2048,如果电脑的性能和内存一般的话就可能卡死,解决方案是手动修改安卓模拟器的config文件&…...

从入门到深入,Docker新手学习教程
编译整理|TesterHome社区 作者|Ishaan Gupta 以下为作者观点: Docker 彻底改变了我们开发、交付和运行应用程序的方式。它使开发人员能够将应用程序打包到容器中 - 标准化的可执行组件,将应用程序源代码与在任何环境中运行该代码…...

Postman编写测试脚本
在 Postman 中,编写测试脚本通常使用 JavaScript,这些脚本可以在请求发送前后执行。以下是一些示例代码,展示了如何在 Postman 中使用测试脚本。 1. 测试脚本示例:检查响应状态码 // 测试脚本在请求发送后执行 pm.test("Re…...

代码随想录算法训练Day57|LeetCode200-岛屿数量、LeetCode695-岛屿的最大面积
岛屿数量 题目描述 力扣200-岛屿数量 给你一个由 1(陆地)和 0(水)组成的的二维网格,请你计算网格中岛屿的数量。 岛屿总是被水包围,并且每座岛屿只能由水平方向和/或竖直方向上相邻的陆地连接形成。 此…...

StopWatch的使用
org.springframework.util.StopWatch 是 Spring 框架提供的一个轻量级的计时工具,用于测量代码执行时间。它比 Apache Commons Lang 的 StopWatch 提供了更多的功能,例如累计多个时间段、打印详细报告等。 以下是如何使用 Spring 的 StopWatchÿ…...

MySQL基础篇(三)数据库的修改 删除 备份恢复 查看连接情况
对数据库的修改主要指的是修改数据库的字符集,校验规则。 将test1数据库字符集改为gbk。 数据库的删除: 执行完该数据库就不存在了,对应数据库文件夹被删除,级联删除,里面的数据表全部被删除。 注意:不要随…...

android手机电视相框项目-学员做出个bug版本邀请大家review提意见
背景 前几天给我的vip学员布置了一个android手机/电视相框的项目,具体详情看如下链接: https://mp.weixin.qq.com/s/l2roDoco-o59SLlORENZlA 这个项目我说过不给提供答案哈,让各位学员朋友自己独立思考完成哈,因为尽量想让大家慢…...

web零碎知识2
不知道我的这个axios的包导进去没。 找一下关键词: http请求协议:就是进行交互式的格式 需要定义好 这个式一发一收短连接 而且没有记忆 这个分为三个部分 第一个式请求行,第二个就是请求头 第三个就是请求体 以get方式进行请求的失手请求…...

Android项目框架
Android项目基于Android Studio开发,Android Studio使用Gradle作为项目构建工具。新建工程后可以看到如图所示目录结构,将Android切成Project可以看到完整的Android工程目录结构,如图所示。 图1-2 Android项目目录结构 app目录是一个典型的…...

vue 模糊查询加个禁止属性
vue 模糊查询加个禁止属性 父组件通过属性传,是否禁止输入-------默认可以输入...

MySQL 主从复制中 MHA 工具的研究与实践
MySQL 主从复制中 MHA 工具的研究与实践 一、MHA 工具简介二、MHA 的工作原理三、MHA 配置步骤环境准备1. 在主服务器上配置主从复制2. 在从服务器上配置复制 安装 MHA 工具1. 安装必要的依赖包2. 下载并安装 MHA 配置 MHA1. 创建 MHA 配置文件2. 配置 SSH 免密登录 测试 MHA1.…...

Hi3861 OpenHarmony嵌入式应用入门--TCP Server
本篇使用的是lwip编写tcp服务端。需要提前准备好一个PARAM_HOTSPOT_SSID宏定义的热点,并且密码为PARAM_HOTSPOT_PSK LwIP简介 LwIP是什么? A Lightweight TCP/IP stack 一个轻量级的TCP/IP协议栈 详细介绍请参考LwIP项目官网:lwIP - A Li…...