当前位置: 首页 > news >正文

昇思25天学习打卡营第2天|MindSpore快速入门

打卡

目录

打卡

快速入门案例:minist图像数据识别任务

案例任务说明

流程

1 加载并处理数据集

2 模型网络构建与定义

3 模型约束定义

4 模型训练

5 模型保存

6 模型推理

相关参考文档入门理解

MindSpore数据处理引擎

模型网络参数初始化

模型优化器

损失函数

代码

安装

从模型训练到预测推理

self_main_train_and_save.py

self_dataprocess.py

self_network.py

self_modeltrain.py

self_modeltest.py

self_predict.py


快速入门案例:minist图像数据识别任务

案例任务说明

MINIST数据集是有标签的图像数据,图像数据是0-9的手写阿拉伯数字。其中,训练集有6W个,测试集1W个。

目的是训练一个可以高效识别手写阿拉伯数字的模型。

流程

1 加载并处理数据集

涉及到的mindspore接口 mindspore.dataset。例如对数据集的map、batch、shuffle等操作,数据列名获取,对数据集进行迭代访问、查看数据和标签的shape和datatype等。

2 模型网络构建与定义

涉及到 mindspore.nn 类。例如用户可继承nn.Cell类来自定义网络结构,其中的construct类函数包含数据(Tensor)的变换过程。。

3 模型约束定义

包括损失函数、优化器等。如 nn.CrossEntropyLoss() 、nn.SGD(model.trainable_params(), 1e-2)

4 模型训练

- 定义训练函数,用set_train设置为训练模式,执行正向计算、反向传播和参数优化。

- 定义测试函数,用来评估模型的性能。

5 模型保存

- 两种保存方式:

1)模型参数保存:mindspore.save_checkpoint(model, "model.ckpt")

2)统一的中间表示(Intermediate Representation,IR)的保存,MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

6 模型推理

- 两种加载方式:

1)模型参数加载: 

> model = network()

> param_dict = mindspore.load_checkpoint("model.ckpt");  

param_not_load, _ = mindspore.load_param_into_net(model, param_dict)

2)统一的中间表示(Intermediate Representation,IR)的加载:

> mindspore.set_context(mode=mindspore.GRAPH_MODE)
> graph = mindspore.load("model.mindir")
> model = nn.GraphCell(graph)  ## nn.GraphCell 仅支持图模式。
> outputs = model(inputs)

保存与加载 — MindSpore master 文档

相关参考文档入门理解

MindSpore数据处理引擎

MindSpore 通过对外暴露API层来构建数据图;内部的Data Processing Pipeline 层用来进行数据加载和预处理多步并行流水线。
高性能数据处理引擎 — MindSpore master 文档

MindSpore 通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。

数据集 Dataset — MindSpore master 文档

数据变换 Transforms — MindSpore master 文档

模型网络参数初始化

Initializer是MindSpore内置的参数初始化基类,所有内置参数初始化方法均继承该类。mindspore.nn中提供的神经网络层封装均提供weight_initbias_init等入参,可以直接使用实例化的Initializer进行参数初始化。

参数初始化 — MindSpore master 文档

模型优化器

优化器 — MindSpore master 文档

损失函数

损失函数 — MindSpore master 文档

代码

安装

pip/conda均可:

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1

从模型训练到预测推理

训练:

python self_main_train_and_save.py

推理:

python self_predict.py

self_main_train_and_save.py

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# 用download库从公开华为云obs桶下载 MINIST 数据集并解压。因为mindspore.dataset 提供的接口仅支持解压后的数据文件 
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)    ## 1 加载数据集
train_dataset = MnistDataset('MNIST_Data/train', shuffle=False)
test_dataset = MnistDataset('MNIST_Data/test')
print(train_dataset.get_col_names())   # 打印数据集中包含的数据列名,用于dataset的预处理。输出['image', 'label']## 2 MindSpore的dataset使用数据处理流水线,这里将处理好的数据集打包为大小为64的batch。
from self_dataprocess import datapipe
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)  
test_dataset = datapipe(test_dataset, 64)  ## 3 数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。可使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。
for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")break“”“Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32Shape of label: (64,) Int32”“”
for data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break## 4 模型训练
from self_network import Network
from self_modeltrain import train, loss_fn 
from self_modelteset import test
model = Network()
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")## 5 保存模型
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

self_dataprocess.py

from mindspore.dataset import vision, transforms
def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset

self_network.py

# Define model
from mindspore import nnclass Network(nn.Cell): def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsdef check_network():model = Network()print(model)

self_modeltrain.py

# Instantiate loss function and optimizer
from mindspore import nnloss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train()     ## 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 BatchNorm),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。默认Truefor batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

self_modeltest.py

from mindspore import nn def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

self_predict.py

## 加载模型
from self_network import Network# Instantiate a random initialized model
model = Network()# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)  
print(param_not_load)   ## param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。## 加载后的模型可以直接用于预测推理。
model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break

相关文章:

昇思25天学习打卡营第2天|MindSpore快速入门

打卡 目录 打卡 快速入门案例:minist图像数据识别任务 案例任务说明 流程 1 加载并处理数据集 2 模型网络构建与定义 3 模型约束定义 4 模型训练 5 模型保存 6 模型推理 相关参考文档入门理解 MindSpore数据处理引擎 模型网络参数初始化 模型优化器 …...

django之url路径

方式一&#xff1a;path 语法&#xff1a;<<转换器类型:自定义>> 作用&#xff1a;若转换器类型匹配到对应类型的数据&#xff0c;则将数据按照关键字传参的方式传递给视图函数 类型&#xff1a; str: 匹配除了”/“之外的非空字符串。 /test/zvxint: 匹配0或任何…...

【OnlyOffice】桌面应用编辑器,插件开发大赛,等你来挑战

OnlyOffice&#xff0c;桌面应用编辑器&#xff0c;最近版本已从8.0升级到了8.1 从PDF、Word、Excel、PPT等全面进行了升级。随着AI应用持续的火热&#xff0c;OnlyOffice也在不断推出AI相关插件。 因此&#xff0c;在此给大家推荐一下OnlyOffice本次的插件开发大赛。 详细信息…...

[学习笔记]SQL学习笔记(连载中。。。)

学习视频&#xff1a;【数据库】SQL 3小时快速入门 #数据库教程 #SQL教程 #MySQL教程 #database#Python连接数据库 目录 1.SQL的基础知识1.1.表(table)和键(key)1.2.外键、联合主键 2.MySQL安装&#xff08;略&#xff0c;请自行参考视频&#xff09;3.基本的MySQL语法3.1.规…...

Buuctf之SimpleRev做法

首先&#xff0c;查个壳&#xff0c;64bit&#xff0c;那就丢进ida64中进行反编译进来之后&#xff0c;我们进入main函数&#xff0c;发现里面没什么东西&#xff0c;那就shiftf12搜索字符串&#xff0c;找到关键字符串&#xff0c;双击进入然后再选中该字符串&#xff0c;ctrl…...

【云原生监控】Prometheus 普罗米修斯从搭建到使用详解

目录 一、前言 二、服务监控概述 2.1 什么是微服务监控 2.2 微服务监控指标 2.3 微服务监控工具 三、Prometheus概述 3.1 Prometheus是什么 3.2 Prometheus 特点 3.3 Prometheus 架构图 3.3.1 Prometheus核心组件 3.3.2 Prometheus 工作流程 3.4 Prometheus 应用场景…...

【C++】模板进阶--保姆级解析(什么是非类型模板参数?什么是模板的特化?模板的特化如何应用?)

目录 一、前言 二、什么是C模板&#xff1f; &#x1f4a6;泛型编程的思想 &#x1f4a6;C模板的分类 三、非类型模板参数 ⚡问题引入⚡ ⚡非类型模板参数的使用⚡ &#x1f525;非类型模板参数的定义 &#x1f525;非类型模板参数的两种类型 &#x1f52…...

Cookie与Session

Cookie Set-Cookie: sessionIdabc123; ExpiresWed, 09 Jun 2024 10:18:14 GMT; Path/; Secure; HttpOnlySession session作用域 首先需要了解servlet容器可能包含多个web应用。 在servlet容器中同一应用的servlet 对 session数据是可见的&#xff0c;不同应用之间session是相互…...

Nuxt3 的生命周期和钩子函数(十一)

title: Nuxt3 的生命周期和钩子函数&#xff08;十一&#xff09; date: 2024/7/5 updated: 2024/7/5 author: cmdragon excerpt: 摘要&#xff1a;本文详细介绍了Nuxt3中几个关键的生命周期钩子和它们的使用方法&#xff0c;包括webpack:done用于Webpack编译完成后执行操作…...

Windows ipconfig命令详解,Windows查看IP地址信息

「作者简介」&#xff1a;冬奥会网络安全中国代表队&#xff0c;CSDN Top100&#xff0c;就职奇安信多年&#xff0c;以实战工作为基础著作 《网络安全自学教程》&#xff0c;适合基础薄弱的同学系统化的学习网络安全&#xff0c;用最短的时间掌握最核心的技术。 ipconfig 1、基…...

在C#/Net中使用Mqtt

net中MQTT的应用场景 c#常用来开发上位机程序&#xff0c;或者其他一些跟设备打交道比较多的系统&#xff0c;所以会经常作为拥有数据的终端&#xff0c;可以用来采集上传数据&#xff0c;而MQTT也是物联网常用的协议&#xff0c;所以下面介绍在C#开发中使用MQTT。 安装MQTTn…...

VBA提取word表格内容到excel

这是一段提取word表格中部分内容的vb代码。 Sub 提取word表格() mypath ThisWorkbook.Path & "\"myname Dir(mypath & "*.doc*")n 4 index of rowsRange("A1:F1") Array("课程代码", "课程名称", "专业&…...

html+css+js图片手动轮播

源代码在界面图片后面 轮播演示用的几张图片是Bing上的&#xff0c;直接用的几张图片的URL&#xff0c;谁加载可能需要等一下&#xff0c;现实中替换成自己的图片即可 关注一下点个赞吧&#x1f604; 谢谢大佬 界面图片 源代码 <!DOCTYPE html> <html lang&quo…...

【十三】图解 Spring 核心数据结构:BeanDefinition 其二

图解 Spring 核心数据结构&#xff1a;BeanDefinition 其二 概述 前面写过一篇相关文章作为开篇介绍了一下BeanDefinition&#xff0c;本篇将深入细节来向读者展示BeanDefinition的设计&#xff0c;让我们一起来揭开日常开发中使用的bean的神秘面纱&#xff0c;深入细节透彻理解…...

数据库作业

命令 登陆数据库 mysql -uroot -p123456 --prompt"\u\h:\d--> " 创建数据库zcr create database zcr&#xff1b; 修改数据库zcr字符集为gbk alter database zcr default character set gbk collate gbk_chinese_ci; 选择数据库zcr use zcr 查看数据库zc…...

12、matlab中for循环,if else判断语句,break和continue用法以及switch case语句使用

1、前言 在MATLAB中&#xff0c;for循环用于迭代一个固定次数的循环。可以使用if else语句在循环中进行条件判断&#xff0c;根据条件的不同执行相应的代码块。break和continue可以用于控制循环的执行流程&#xff0c;break用于提前结束循环&#xff0c;而continue用于跳过当前…...

AcWing 3207:门禁系统 ← 桶排序中“桶”的思想

【题目来源】https://www.acwing.com/problem/content/3210/【题目描述】 涛涛最近要负责图书馆的管理工作&#xff0c;需要记录下每天读者的到访情况。 每位读者有一个唯一编号&#xff0c;每条记录用读者的编号来表示。 给出读者的来访记录&#xff0c;请问每一条记录中的读者…...

开发个人Go-ChatGPT--3 服务拆分

开发个人Go-ChatGPT–3 服务拆分 个人Go-ChatGPT项目可拆分用户服务&#xff08;user&#xff09;&#xff0c;AI模型服务&#xff08;AiModel&#xff09;&#xff0c;… 每个服务都可以再分为 api 服务和 rpc 服务。api 服务对外&#xff0c;可提供给 app 调用。rpc 服务是…...

Android --- 新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了

新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死&#xff0c;鼠标和键盘都操作不了 大概原因就是,初始化默认Google的安卓模拟器占用的RAM内存是2048&#xff0c;如果电脑的性能和内存一般的话就可能卡死&#xff0c;解决方案是手动修改安卓模拟器的config文件&…...

从入门到深入,Docker新手学习教程

编译整理&#xff5c;TesterHome社区 作者&#xff5c;Ishaan Gupta 以下为作者观点&#xff1a; Docker 彻底改变了我们开发、交付和运行应用程序的方式。它使开发人员能够将应用程序打包到容器中 - 标准化的可执行组件&#xff0c;将应用程序源代码与在任何环境中运行该代码…...

别再乱用分支了!Flowable四种网关(排他/并行/包容/事件)实战选型指南

Flowable四大网关实战选型&#xff1a;从混乱到精准的决策艺术当你在设计一个请假审批流程时&#xff0c;是否遇到过这样的困惑&#xff1a;部门经理审批后需要同时通知HR和财务&#xff0c;但某些特殊情况下又需要跳过财务直接归档&#xff1f;这种看似简单的业务需求&#xf…...

DIY复刻经典:Texar Audio Prism动态处理器克隆套件全攻略

1. 项目概述&#xff1a;Texar Audio Prism 克隆套件如果你在专业音频圈子里混过一段时间&#xff0c;尤其是对上世纪八九十年代那些经典的、带点“魔法”色彩的外置动态处理器感兴趣&#xff0c;那么“Texar Audio Prism”这个名字你大概率不会陌生。它不是最常见的1176或者LA…...

FeHelper前端助手:30+开发工具集,让你的浏览器变身效率神器

FeHelper前端助手&#xff1a;30开发工具集&#xff0c;让你的浏览器变身效率神器 【免费下载链接】FeHelper &#x1f60d;FeHelper--Web前端助手&#xff08;Awesome&#xff01;Chrome & Firefox & MS-Edge Extension, All in one Toolbox!&#xff09; 项目地址:…...

3步解锁网易云音乐NCM加密:让音乐真正属于你

3步解锁网易云音乐NCM加密&#xff1a;让音乐真正属于你 【免费下载链接】ncmdump 项目地址: https://gitcode.com/gh_mirrors/ncmd/ncmdump 还在为下载的网易云音乐只能在特定客户端播放而烦恼吗&#xff1f;当你精心收藏的歌曲被NCM格式"锁"在单一平台时&a…...

孤舟笔记 互联网常用框架篇二 Dubbo服务请求失败怎么处理?集群容错策略你用过几种

文章目录先说结论Failover&#xff1a;换家店试试Failfast&#xff1a;不行就算了Failsafe&#xff1a;忘了这事Failback&#xff1a;回头再说Forking&#xff1a;同时点几家Broadcast&#xff1a;通知所有人怎么选择回答技巧与点评加分回答面试官点评个人网站分布式系统中&…...

告别元素变动导致的报错:探索自动化测试脚本的 AI“自愈”能力

前言:一个所有测试人都经历过的噩梦 周三晚上十一点,CI/CD流水线再次亮起红灯。 你打开日志,满屏的NoSuchElementException扑面而来。仔细一看——前端团队在昨天的版本中重构了登录页面的DOM结构,原本的#login-btn变成了#signin-button-v2,30个测试用例因此全军覆没。 …...

Keil µVision链接器错误204解决方案

1. 问题现象与背景解析最近在使用Keil Vision进行嵌入式开发时&#xff0c;不少工程师遇到了一个令人头疼的链接器错误。具体表现为编译时出现"FATAL ERROR 204: INVALID KEYWORD"的致命错误&#xff0c;错误位置指向链接器控制文件中的特定行。这个问题在C166和C51两…...

Taotoken的审计日志功能为企业API安全与合规管理提供支持

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 Taotoken的审计日志功能为企业API安全与合规管理提供支持 当企业决定将大模型能力集成到内部业务流程中时&#xff0c;IT管理员和安…...

抖音内容批量下载实战:从零开始构建个人视频资料库

抖音内容批量下载实战&#xff1a;从零开始构建个人视频资料库 【免费下载链接】douyin-downloader A practical Douyin downloader for both single-item and profile batch downloads, with progress display, retries, SQLite deduplication, and browser fallback support.…...

告别手动预约:i茅台自动预约系统5分钟部署指南

告别手动预约&#xff1a;i茅台自动预约系统5分钟部署指南 【免费下载链接】campus-imaotai i茅台app自动预约&#xff0c;每日自动预约&#xff0c;支持docker一键部署&#xff08;本项目不提供成品&#xff0c;使用的是已淘汰的算法&#xff09; 项目地址: https://gitcode…...