当前位置: 首页 > news >正文

昇思25天学习打卡营第2天|MindSpore快速入门

打卡

目录

打卡

快速入门案例:minist图像数据识别任务

案例任务说明

流程

1 加载并处理数据集

2 模型网络构建与定义

3 模型约束定义

4 模型训练

5 模型保存

6 模型推理

相关参考文档入门理解

MindSpore数据处理引擎

模型网络参数初始化

模型优化器

损失函数

代码

安装

从模型训练到预测推理

self_main_train_and_save.py

self_dataprocess.py

self_network.py

self_modeltrain.py

self_modeltest.py

self_predict.py


快速入门案例:minist图像数据识别任务

案例任务说明

MINIST数据集是有标签的图像数据,图像数据是0-9的手写阿拉伯数字。其中,训练集有6W个,测试集1W个。

目的是训练一个可以高效识别手写阿拉伯数字的模型。

流程

1 加载并处理数据集

涉及到的mindspore接口 mindspore.dataset。例如对数据集的map、batch、shuffle等操作,数据列名获取,对数据集进行迭代访问、查看数据和标签的shape和datatype等。

2 模型网络构建与定义

涉及到 mindspore.nn 类。例如用户可继承nn.Cell类来自定义网络结构,其中的construct类函数包含数据(Tensor)的变换过程。。

3 模型约束定义

包括损失函数、优化器等。如 nn.CrossEntropyLoss() 、nn.SGD(model.trainable_params(), 1e-2)

4 模型训练

- 定义训练函数,用set_train设置为训练模式,执行正向计算、反向传播和参数优化。

- 定义测试函数,用来评估模型的性能。

5 模型保存

- 两种保存方式:

1)模型参数保存:mindspore.save_checkpoint(model, "model.ckpt")

2)统一的中间表示(Intermediate Representation,IR)的保存,MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

6 模型推理

- 两种加载方式:

1)模型参数加载: 

> model = network()

> param_dict = mindspore.load_checkpoint("model.ckpt");  

param_not_load, _ = mindspore.load_param_into_net(model, param_dict)

2)统一的中间表示(Intermediate Representation,IR)的加载:

> mindspore.set_context(mode=mindspore.GRAPH_MODE)
> graph = mindspore.load("model.mindir")
> model = nn.GraphCell(graph)  ## nn.GraphCell 仅支持图模式。
> outputs = model(inputs)

保存与加载 — MindSpore master 文档

相关参考文档入门理解

MindSpore数据处理引擎

MindSpore 通过对外暴露API层来构建数据图;内部的Data Processing Pipeline 层用来进行数据加载和预处理多步并行流水线。
高性能数据处理引擎 — MindSpore master 文档

MindSpore 通过数据集(Dataset)和数据变换(Transforms)实现高效的数据预处理。

数据集 Dataset — MindSpore master 文档

数据变换 Transforms — MindSpore master 文档

模型网络参数初始化

Initializer是MindSpore内置的参数初始化基类,所有内置参数初始化方法均继承该类。mindspore.nn中提供的神经网络层封装均提供weight_initbias_init等入参,可以直接使用实例化的Initializer进行参数初始化。

参数初始化 — MindSpore master 文档

模型优化器

优化器 — MindSpore master 文档

损失函数

损失函数 — MindSpore master 文档

代码

安装

pip/conda均可:

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1

从模型训练到预测推理

训练:

python self_main_train_and_save.py

推理:

python self_predict.py

self_main_train_and_save.py

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# 用download库从公开华为云obs桶下载 MINIST 数据集并解压。因为mindspore.dataset 提供的接口仅支持解压后的数据文件 
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)    ## 1 加载数据集
train_dataset = MnistDataset('MNIST_Data/train', shuffle=False)
test_dataset = MnistDataset('MNIST_Data/test')
print(train_dataset.get_col_names())   # 打印数据集中包含的数据列名,用于dataset的预处理。输出['image', 'label']## 2 MindSpore的dataset使用数据处理流水线,这里将处理好的数据集打包为大小为64的batch。
from self_dataprocess import datapipe
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)  
test_dataset = datapipe(test_dataset, 64)  ## 3 数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。可使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问,查看数据和标签的shape和datatype。
for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")break“”“Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32Shape of label: (64,) Int32”“”
for data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break## 4 模型训练
from self_network import Network
from self_modeltrain import train, loss_fn 
from self_modelteset import test
model = Network()
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")## 5 保存模型
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

self_dataprocess.py

from mindspore.dataset import vision, transforms
def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset

self_network.py

# Define model
from mindspore import nnclass Network(nn.Cell): def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsdef check_network():model = Network()print(model)

self_modeltrain.py

# Instantiate loss function and optimizer
from mindspore import nnloss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train()     ## 设置当前Cell和所有子Cell的训练模式。对于训练和预测具有不同结构的网络层(如 BatchNorm),将通过这个属性区分分支。如果设置为True,则执行训练分支,否则执行另一个分支。默认Truefor batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

self_modeltest.py

from mindspore import nn def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

self_predict.py

## 加载模型
from self_network import Network# Instantiate a random initialized model
model = Network()# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)  
print(param_not_load)   ## param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。## 加载后的模型可以直接用于预测推理。
model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break

相关文章:

昇思25天学习打卡营第2天|MindSpore快速入门

打卡 目录 打卡 快速入门案例:minist图像数据识别任务 案例任务说明 流程 1 加载并处理数据集 2 模型网络构建与定义 3 模型约束定义 4 模型训练 5 模型保存 6 模型推理 相关参考文档入门理解 MindSpore数据处理引擎 模型网络参数初始化 模型优化器 …...

django之url路径

方式一&#xff1a;path 语法&#xff1a;<<转换器类型:自定义>> 作用&#xff1a;若转换器类型匹配到对应类型的数据&#xff0c;则将数据按照关键字传参的方式传递给视图函数 类型&#xff1a; str: 匹配除了”/“之外的非空字符串。 /test/zvxint: 匹配0或任何…...

【OnlyOffice】桌面应用编辑器,插件开发大赛,等你来挑战

OnlyOffice&#xff0c;桌面应用编辑器&#xff0c;最近版本已从8.0升级到了8.1 从PDF、Word、Excel、PPT等全面进行了升级。随着AI应用持续的火热&#xff0c;OnlyOffice也在不断推出AI相关插件。 因此&#xff0c;在此给大家推荐一下OnlyOffice本次的插件开发大赛。 详细信息…...

[学习笔记]SQL学习笔记(连载中。。。)

学习视频&#xff1a;【数据库】SQL 3小时快速入门 #数据库教程 #SQL教程 #MySQL教程 #database#Python连接数据库 目录 1.SQL的基础知识1.1.表(table)和键(key)1.2.外键、联合主键 2.MySQL安装&#xff08;略&#xff0c;请自行参考视频&#xff09;3.基本的MySQL语法3.1.规…...

Buuctf之SimpleRev做法

首先&#xff0c;查个壳&#xff0c;64bit&#xff0c;那就丢进ida64中进行反编译进来之后&#xff0c;我们进入main函数&#xff0c;发现里面没什么东西&#xff0c;那就shiftf12搜索字符串&#xff0c;找到关键字符串&#xff0c;双击进入然后再选中该字符串&#xff0c;ctrl…...

【云原生监控】Prometheus 普罗米修斯从搭建到使用详解

目录 一、前言 二、服务监控概述 2.1 什么是微服务监控 2.2 微服务监控指标 2.3 微服务监控工具 三、Prometheus概述 3.1 Prometheus是什么 3.2 Prometheus 特点 3.3 Prometheus 架构图 3.3.1 Prometheus核心组件 3.3.2 Prometheus 工作流程 3.4 Prometheus 应用场景…...

【C++】模板进阶--保姆级解析(什么是非类型模板参数?什么是模板的特化?模板的特化如何应用?)

目录 一、前言 二、什么是C模板&#xff1f; &#x1f4a6;泛型编程的思想 &#x1f4a6;C模板的分类 三、非类型模板参数 ⚡问题引入⚡ ⚡非类型模板参数的使用⚡ &#x1f525;非类型模板参数的定义 &#x1f525;非类型模板参数的两种类型 &#x1f52…...

Cookie与Session

Cookie Set-Cookie: sessionIdabc123; ExpiresWed, 09 Jun 2024 10:18:14 GMT; Path/; Secure; HttpOnlySession session作用域 首先需要了解servlet容器可能包含多个web应用。 在servlet容器中同一应用的servlet 对 session数据是可见的&#xff0c;不同应用之间session是相互…...

Nuxt3 的生命周期和钩子函数(十一)

title: Nuxt3 的生命周期和钩子函数&#xff08;十一&#xff09; date: 2024/7/5 updated: 2024/7/5 author: cmdragon excerpt: 摘要&#xff1a;本文详细介绍了Nuxt3中几个关键的生命周期钩子和它们的使用方法&#xff0c;包括webpack:done用于Webpack编译完成后执行操作…...

Windows ipconfig命令详解,Windows查看IP地址信息

「作者简介」&#xff1a;冬奥会网络安全中国代表队&#xff0c;CSDN Top100&#xff0c;就职奇安信多年&#xff0c;以实战工作为基础著作 《网络安全自学教程》&#xff0c;适合基础薄弱的同学系统化的学习网络安全&#xff0c;用最短的时间掌握最核心的技术。 ipconfig 1、基…...

在C#/Net中使用Mqtt

net中MQTT的应用场景 c#常用来开发上位机程序&#xff0c;或者其他一些跟设备打交道比较多的系统&#xff0c;所以会经常作为拥有数据的终端&#xff0c;可以用来采集上传数据&#xff0c;而MQTT也是物联网常用的协议&#xff0c;所以下面介绍在C#开发中使用MQTT。 安装MQTTn…...

VBA提取word表格内容到excel

这是一段提取word表格中部分内容的vb代码。 Sub 提取word表格() mypath ThisWorkbook.Path & "\"myname Dir(mypath & "*.doc*")n 4 index of rowsRange("A1:F1") Array("课程代码", "课程名称", "专业&…...

html+css+js图片手动轮播

源代码在界面图片后面 轮播演示用的几张图片是Bing上的&#xff0c;直接用的几张图片的URL&#xff0c;谁加载可能需要等一下&#xff0c;现实中替换成自己的图片即可 关注一下点个赞吧&#x1f604; 谢谢大佬 界面图片 源代码 <!DOCTYPE html> <html lang&quo…...

【十三】图解 Spring 核心数据结构:BeanDefinition 其二

图解 Spring 核心数据结构&#xff1a;BeanDefinition 其二 概述 前面写过一篇相关文章作为开篇介绍了一下BeanDefinition&#xff0c;本篇将深入细节来向读者展示BeanDefinition的设计&#xff0c;让我们一起来揭开日常开发中使用的bean的神秘面纱&#xff0c;深入细节透彻理解…...

数据库作业

命令 登陆数据库 mysql -uroot -p123456 --prompt"\u\h:\d--> " 创建数据库zcr create database zcr&#xff1b; 修改数据库zcr字符集为gbk alter database zcr default character set gbk collate gbk_chinese_ci; 选择数据库zcr use zcr 查看数据库zc…...

12、matlab中for循环,if else判断语句,break和continue用法以及switch case语句使用

1、前言 在MATLAB中&#xff0c;for循环用于迭代一个固定次数的循环。可以使用if else语句在循环中进行条件判断&#xff0c;根据条件的不同执行相应的代码块。break和continue可以用于控制循环的执行流程&#xff0c;break用于提前结束循环&#xff0c;而continue用于跳过当前…...

AcWing 3207:门禁系统 ← 桶排序中“桶”的思想

【题目来源】https://www.acwing.com/problem/content/3210/【题目描述】 涛涛最近要负责图书馆的管理工作&#xff0c;需要记录下每天读者的到访情况。 每位读者有一个唯一编号&#xff0c;每条记录用读者的编号来表示。 给出读者的来访记录&#xff0c;请问每一条记录中的读者…...

开发个人Go-ChatGPT--3 服务拆分

开发个人Go-ChatGPT–3 服务拆分 个人Go-ChatGPT项目可拆分用户服务&#xff08;user&#xff09;&#xff0c;AI模型服务&#xff08;AiModel&#xff09;&#xff0c;… 每个服务都可以再分为 api 服务和 rpc 服务。api 服务对外&#xff0c;可提供给 app 调用。rpc 服务是…...

Android --- 新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死,鼠标和键盘都操作不了

新电脑安装Android Studio 使用 Android 内置模拟器电脑直接卡死&#xff0c;鼠标和键盘都操作不了 大概原因就是,初始化默认Google的安卓模拟器占用的RAM内存是2048&#xff0c;如果电脑的性能和内存一般的话就可能卡死&#xff0c;解决方案是手动修改安卓模拟器的config文件&…...

从入门到深入,Docker新手学习教程

编译整理&#xff5c;TesterHome社区 作者&#xff5c;Ishaan Gupta 以下为作者观点&#xff1a; Docker 彻底改变了我们开发、交付和运行应用程序的方式。它使开发人员能够将应用程序打包到容器中 - 标准化的可执行组件&#xff0c;将应用程序源代码与在任何环境中运行该代码…...

三维GIS开发cesium智慧地铁教程(5)Cesium相机控制

一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点&#xff1a; 路径验证&#xff1a;确保相对路径.…...

通过Wrangler CLI在worker中创建数据库和表

官方使用文档&#xff1a;Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后&#xff0c;会在本地和远程创建数据库&#xff1a; npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库&#xff1a; 现在&#xff0c;您的Cloudfla…...

页面渲染流程与性能优化

页面渲染流程与性能优化详解&#xff08;完整版&#xff09; 一、现代浏览器渲染流程&#xff08;详细说明&#xff09; 1. 构建DOM树 浏览器接收到HTML文档后&#xff0c;会逐步解析并构建DOM&#xff08;Document Object Model&#xff09;树。具体过程如下&#xff1a; (…...

华为OD机试-食堂供餐-二分法

import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术&#xff0c;它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton)&#xff1a;由层级结构的骨头组成&#xff0c;类似于人体骨骼蒙皮 (Mesh Skinning)&#xff1a;将模型网格顶点绑定到骨骼上&#xff0c;使骨骼移动…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习视频及资料&#xff0c;尽在聚客AI学院。 本文全面剖析RNN核心原理&#xff0c;深入讲解梯度消失/爆炸问题&#xff0c;并通过LSTM/GRU结构实现解决方案&#xff0c;提供时间序列预测和文本生成…...

Python ROS2【机器人中间件框架】 简介

销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...

PAN/FPN

import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...

NPOI Excel用OLE对象的形式插入文件附件以及插入图片

static void Main(string[] args) {XlsWithObjData();Console.WriteLine("输出完成"); }static void XlsWithObjData() {// 创建工作簿和单元格,只有HSSFWorkbook,XSSFWorkbook不可以HSSFWorkbook workbook new HSSFWorkbook();HSSFSheet sheet (HSSFSheet)workboo…...