当前位置: 首页 > news >正文

昇思MindSpore进阶教程--在ResNet-50网络上应用二阶优化实践(下)

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

文章上半部分请查看
在ResNet-50网络上应用二阶优化实践(上)

训练网络

配置模型保存

MindSpore提供了callback机制,可以在训练过程中执行自定义逻辑,这里使用框架提供的ModelCheckpoint函数。 ModelCheckpoint可以保存网络模型和参数,以便进行后续的fine-tuning操作。 TimeMonitor、LossMonitor是MindSpore官方提供的callback函数,可以分别用于监控训练过程中单步迭代时间和loss值的变化。


import mindspore as ms
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitorif __name__ == "__main__":# define callbackstime_cb = TimeMonitor(data_size=step_size)loss_cb = LossMonitor()cb = [time_cb, loss_cb]if config.save_checkpoint:config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,keep_checkpoint_max=config.keep_checkpoint_max)ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)cb += [ckpt_cb]

配置训练网络

通过MindSpore提供的model.train接口可以方便地进行网络的训练。THOR优化器通过降低二阶矩阵更新频率,来减少计算量,提升计算速度,故重新定义一个ModelThor类,继承MindSpore提供的Model类。在ModelThor类中获取THOR的二阶矩阵更新频率控制参数,用户可以通过调整该参数,优化整体的性能。 MindSpore提供Model类向ModelThor类的一键转换接口。


import mindspore as ms
from mindspore import amp
from mindspore.train import Model, ConvertModelUtilsif __name__ == "__main__":...loss_scale = amp.FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics,amp_level="O2", keep_batchnorm_fp32=False, eval_network=dist_eval_network)if cfg.optimizer == "Thor":model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,loss_scale_manager=loss_scale, metrics={'acc'},amp_level="O2", keep_batchnorm_fp32=False)  ...

运行脚本

训练脚本定义完成之后,调scripts目录下的shell脚本,启动分布式训练进程。

Atlas训练系列产品

目前MindSpore分布式在Ascend上执行采用单卡单进程运行方式,即每张卡上运行1个进程,进程数量与使用的卡的数量一致。进程均放在后台执行,每个进程创建1个目录,目录名称为train_parallel+ device_id,用来保存日志信息,算子编译信息以及训练的checkpoint文件。下面以使用8张卡的分布式训练脚本为例,演示如何运行脚本。

使用以下命令运行脚本:

bash run_distribute_train.sh <RANK_TABLE_FILE> <DATASET_PATH> [CONFIG_PATH]

脚本需要传入变量RANK_TABLE_FILE,DATASET_PATH和CONFIG_PATH,其中:

  • RANK_TABLE_FILE:组网信息文件的路径。(rank table文件的生成,参考HCCL_TOOL)

  • DATASET_PATH:训练数据集路径。

  • CONFIG_PATH:配置文件路径。

其余环境变量请参考安装教程中的配置项。

训练过程中loss打印示例如下:


epoch: 1 step: 5004, loss is 4.4182425
epoch: 2 step: 5004, loss is 3.740064
epoch: 3 step: 5004, loss is 4.0546017
epoch: 4 step: 5004, loss is 3.7598825
epoch: 5 step: 5004, loss is 3.3744206epoch: 40 step: 5004, loss is 1.6907625
epoch: 41 step: 5004, loss is 1.8217756
epoch: 42 step: 5004, loss is 1.6453942

训练完后,每张卡训练产生的checkpoint文件保存在各自训练目录下,device_0产生的checkpoint文件示例如下:

└─train_parallel0├─ckpt_0├─resnet-1_5004.ckpt├─resnet-2_5004.ckpt│      ......├─resnet-42_5004.ckpt│      ......

其中, *.ckpt:指保存的模型参数文件。checkpoint文件名称具体含义:网络名称-epoch数_step数.ckpt。

GPU

在GPU硬件平台上,MindSpore采用OpenMPI的mpirun进行分布式训练,进程创建1个目录,目录名称为train_parallel,用来保存日志信息和训练的checkpoint文件。下面以使用8张卡的分布式训练脚本为例,演示如何运行脚本。

使用以下命令运行脚本:

bash run_distribute_train_gpu.sh <DATASET_PATH> <CONFIG_PATH>

脚本需要传入变量DATASET_PATH和CONFIG_PATH,其中:

  • DATASET_PATH:训练数据集路径。

  • CONFIG_PATH:配置文件路径。

在GPU训练时,无需设置DEVICE_ID环境变量,因此在主训练脚本中不需要调用int(os.getenv(‘DEVICE_ID’))来获取卡的物理序号,同时context中也无需传入device_id。我们需要将device_target设置为GPU,并需要调用init()来使能NCCL。

训练过程中loss打印示例如下:


epoch: 1 step: 5004, loss is 4.2546034
epoch: 2 step: 5004, loss is 4.0819564
epoch: 3 step: 5004, loss is 3.7005644
epoch: 4 step: 5004, loss is 3.2668946
epoch: 5 step: 5004, loss is 3.023509epoch: 36 step: 5004, loss is 1.645802

训练完后,保存的模型文件示例如下:

└─train_parallel├─ckpt_0├─resnet-1_5004.ckpt├─resnet-2_5004.ckpt│      ......├─resnet-36_5004.ckpt│      ............├─ckpt_7├─resnet-1_5004.ckpt├─resnet-2_5004.ckpt│      ......├─resnet-36_5004.ckpt│      ......

模型推理

使用训练过程中保存的checkpoint文件进行推理,验证模型的泛化能力。首先通过load_checkpoint接口加载模型文件,然后调用Model的eval接口对输入图片类别作出预测,再与输入图片的真实类别做比较,得出最终的预测精度值。

定义推理网络

  1. 使用load_checkpoint接口加载模型文件。

  2. 使用model.eval接口读入测试数据集,进行推理。

  3. 计算得出预测精度值。


import mindspore as ms
from mindspore.train import Modelif __name__ == "__main__":...# define netnet = resnet(class_num=config.class_num)# load checkpointparam_dict = ms.load_checkpoint(args_opt.checkpoint_path)ms.load_param_into_net(net, param_dict)net.set_train(False)# define lossif args_opt.dataset == "imagenet2012":if not config.use_label_smooth:config.label_smooth_factor = 0.0loss = CrossEntropySmooth(sparse=True, reduction='mean',smooth_factor=config.label_smooth_factor, num_classes=config.class_num)else:loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')# define modelmodel = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})# eval modelres = model.eval(dataset)print("result:", res, "ckpt=", args_opt.checkpoint_path)...

执行推理

推理网络定义完成之后,调用scripts目录下的shell脚本,进行推理。

Atlas训练系列产品

在Atlas训练系列产品平台上,推理的执行命令如下:

bash run_eval.sh <DATASET_PATH> <CHECKPOINT_PATH> <CONFIG_PATH>

脚本需要传入变量DATASET_PATH,CHECKPOINT_PATH和<CONFIG_PATH>,其中:

  • DATASET_PATH:推理数据集路径。

  • CHECKPOINT_PATH:保存的checkpoint路径。

  • CONFIG_PATH:配置文件路径。

目前推理使用的是单卡(默认device 0)进行推理,推理的结果如下:

result: {'top_5_accuracy': 0.9295574583866837, 'top_1_accuracy': 0.761443661971831} ckpt=train_parallel0/resnet-42_5004.ckpt
GPU

在GPU硬件平台上,推理的执行命令如下:

  bash run_eval_gpu.sh <DATASET_PATH> <CHECKPOINT_PATH> <CONFIG_PATH>

脚本需要传入变量DATASET_PATH,CHECKPOINT_PATH和CONFIG_PATH,其中:

  • DATASET_PATH:推理数据集路径。

  • CHECKPOINT_PATH:保存的checkpoint路径。

  • CONFIG_PATH:配置文件路径。

推理的结果如下:

result: {'top_5_accuracy': 0.9287972151088348, 'top_1_accuracy': 0.7597031049935979} ckpt=train_parallel/resnet-36_5004.ckpt

相关文章:

昇思MindSpore进阶教程--在ResNet-50网络上应用二阶优化实践(下)

大家好&#xff0c;我是刘明&#xff0c;明志科技创始人&#xff0c;华为昇思MindSpore布道师。 技术上主攻前端开发、鸿蒙开发和AI算法研究。 努力为大家带来持续的技术分享&#xff0c;如果你也喜欢我的文章&#xff0c;就点个关注吧 文章上半部分请查看 在ResNet-50网络上应…...

基于大数据的Python+Django电影票房数据可视化分析系统设计与实现

目录 1 引言 2 系统需求分析 3 技术选型 4 系统架构设计 5 关键技术实现 6 系统实现 7 总结与展望 1 引言 随着数字媒体技术的发展&#xff0c;电影产业已经成为全球经济文化不可或缺的一部分。电影不仅是艺术表达的形式&#xff0c;更是大众娱乐的重要来源。在这个背景…...

实景三维技术对光伏产业的发展具有哪些优势?

实景三维技术对光伏产业的发展具有显著的优势&#xff0c;主要体现在提高选址准确性、优化用地规划、促进数据融合应用以及赋能文旅服务领域。‌ 提高选址准确性‌&#xff1a;通过构建高精度的三维地形模型&#xff0c;结合卫星遥感、无人机测绘等技术手段&#xff0c;实景三维…...

四非人的保研之路,2024(2025届)四非计算机的保研经验分享(西南交通、苏大nlp、西电、北邮、山软、山计、电科、厦大等)

文章目录 一、个人背景二、夏令营北京邮电大学CS西南交通大学CS深圳大学CS苏州大学NLP南开大学CS 三、预推免北京邮电大学CS华东师范大学 CS和大数据电子科技大学 CS东北大学 CS厦门大学 信息学院山东大学 CS和SE西安电子科技大学 CS 四、个人经验五、上岸 一、个人背景 学校专…...

UE5.4.3 录屏回放系统ReplaySystem蓝图版

这是ReplaySystem的蓝图使用方法版&#xff0c;以第三人称模版为例&#xff0c;需要几个必须步骤 项目config内DefaultEngine.ini的最后添加&#xff1a; [/Script/Engine.GameEngine] NetDriverDefinitions(DefName"DemoNetDriver",DriverClassName"/Script/…...

ECCV 2024 | 融合跨模态先验与扩散模型,快手处理大模型让视频画面更清晰!

计算机视觉领域顶级会议 European Conference on Computer Vision&#xff08;ECCV 2024&#xff09;将于9月29日至10月4日在意大利米兰召开&#xff0c;快手音视频技术部联合清华大学所发表的题为《XPSR: Cross-modal Priors for Diffusion-based Image Super-Resolution》——…...

9--苍穹外卖-SpringBoot项目中Redis的介绍及其使用实例 详解

目录 Redis入门 Redis简介 Redis服务启动与停止 服务启动命令 Redis数据类型 5种常用数据类型介绍 各种数据类型的特点 Redis常用命令 字符串操作命令 哈希操作命令 列表操作命令 集合操作命令 有序集合操作命令 通用命令 在java中操作Redis Redis的Java客户端 …...

【EXCEL数据处理】000014 案例 EXCEL分类汇总、定位和创建组。附多个操作案例。

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 【EXCEL数据处理】000014 案例 EXCEL分类汇总、定位和创建组。附多个操…...

Windows环境Apache httpd 2.4 web服务器加载PHP8:Hello,world!

Windows环境Apache httpd 2.4 web服务器加载PHP8&#xff1a;Hello&#xff0c;world&#xff01; &#xff08;1&#xff09;首先需要安装apache httpd 2.4 web服务器&#xff1a; Windows安装启动apache httpd 2.4 web服务器-CSDN博客文章浏览阅读222次&#xff0c;点赞5次&…...

Spring框架使用Api接口实现AOP的切面编程、两种方式的程序示例以及Java各数据类型及基本数据类型的默认值/最大值/最小值列表

一、Spring框架使用Api接口-继承类实现AOP的切面编程示例 要使用Spring框架AOP&#xff0c;除了要导入spring框架包外&#xff0c;还需要导入一个织入的包org.aspectj&#xff0c;具体maven依赖如下&#xff1a; <dependency><groupId>org.springframework</gr…...

【达梦数据库】尽可能 disql 的使用效果与异构数据库一致

文章目录 前言disql 效果优化参数设置参数说明 mysql参数设置参数说明 db2参数设置参数说明 待补充 前言 让达梦的disql 使用起来更跟手&#xff0c;与其他优质数据库的命令行工具通过配置参数的方式尽可能一致&#xff0c;提高使用体验&#xff0c;长期整理中~~~ 测试版本&…...

【研1深度学习】《神经网络和深度学习》阅读笔记(记录中......

9.27 语义鸿沟&#xff1a; 是指输入数据的底层特征和高层语义信息之间的不一致性和查一下。如果可以有一个好的表示在某种程度上能够反映出数据的高层语义特征&#xff0c;那么我们就能相对容易的构建后续的机器学习模型。嵌入&#xff08;Embedding&#xff09;&#xff1a;…...

十一不停歇-学习ROS2第一天 (10.2 10:45)

话题通信 1.1 发布第一个节点&#xff1a; import rclpy #导入此类模块 rcl类型 from rclpy.node import Node #从这个子模块中导入这类函数 def main(): #定义这个函数 rclpy.init() #使用初始化函数 node Node(hello_python) 将类函数里面的内容调给…...

Java高效编程(14):考虑实现 `Comparable

解锁Python编程的无限可能&#xff1a;《奇妙的Python》带你漫游代码世界 与其他方法不同&#xff0c;compareTo 并非 Object 类中声明的&#xff0c;而是 Comparable 接口的唯一方法。compareTo 方法与 equals 类似&#xff0c;但它不仅支持相等性比较&#xff0c;还允许顺序…...

华为昇腾CANN训练营2024第二季--Ascend C算子开发能力认证(中级)题目和经验分享

大家好&#xff0c;我是刘明&#xff0c;明志科技创始人&#xff0c;华为昇思MindSpore布道师。 技术上主攻前端开发、鸿蒙开发和AI算法研究。 努力为大家带来持续的技术分享&#xff0c;如果你也喜欢我的文章&#xff0c;就点个关注吧 正文开始 华为昇腾CANN训练营2024第二季…...

实战OpenCV之形态学操作

基础入门 形态学操作是一种基于图像形状的处理方法,主要用于结构分析,比如:边缘检测、轮廓提取、噪声去除等。这些操作通常使用一个称为“结构元素”(Structuring Element)的核来进行,结构元素可以是任何形状,但最常见的有矩形和圆形。形态学操作的核心在于通过结构元素…...

矩阵的特征值和特征向量

矩阵的特征值和特征向量是线性代数中非常重要的概念&#xff0c;用于描述矩阵对向量的作用&#xff0c;特别是在矩阵对向量的线性变换中的表现。它们帮助我们理解矩阵在某些方向上的缩放或旋转效果。 1. 特征值和特征向量的定义&#xff1a; 给定一个 n n n \times n nn 的…...

(11)MATLAB莱斯(Rician)衰落信道仿真2

文章目录 前言一、莱斯衰落信道仿真模型二、仿真代码与结果1.仿真代码2.仿真结果画图 三、后续&#xff1a;四、参考文献&#xff1a; 前言 首先给出莱斯衰落信道仿真模型&#xff0c;该模型由直射路径分量和反射路径分量组成&#xff0c;其中反射路径分量由瑞利衰落信道模型构…...

ComfyUI局部重绘换衣讲解

一、下载插件 ComfyUI-Impact-Pack 下载地址 https://github.com/ltdrdata/ComfyUI-Impact-Pack 主要用到sam Detector去绘制衣服蒙版和高斯模糊蒙版&#xff0c;高斯模糊让蒙版边缘更加柔和 sams模型 放在E:\Comfyui\ComfyUI\models\sams二、换衣思路 文生图或直接上传…...

Android——添加联系人

概述 方式一&#xff1a;使用ContentResolver多次写入&#xff0c;每次写入一个字段 第一步 往手机联系人应用中的raw_contacts表添加一条记录 raw_contacts表 ContentValues values new ContentValues();// 往 raw_contacts 添加联系人记录&#xff0c;并获取添加后的联…...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…...

【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15

缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下&#xff1a; struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...

Spark 之 入门讲解详细版(1)

1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室&#xff08;Algorithms, Machines, and People Lab&#xff09;开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目&#xff0c;8个月后成为Apache顶级项目&#xff0c;速度之快足见过人之处&…...

关于nvm与node.js

1 安装nvm 安装过程中手动修改 nvm的安装路径&#xff0c; 以及修改 通过nvm安装node后正在使用的node的存放目录【这句话可能难以理解&#xff0c;但接着往下看你就了然了】 2 修改nvm中settings.txt文件配置 nvm安装成功后&#xff0c;通常在该文件中会出现以下配置&…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例&#xff0c;也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下&#xff1a; 定义实例工厂类&#xff08;Java代码&#xff09;&#xff0c;定义实例工厂&#xff08;xml&#xff09;&#xff0c;定义调用实例工厂&#xff…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用

1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...

大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计

随着大语言模型&#xff08;LLM&#xff09;参数规模的增长&#xff0c;推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长&#xff0c;而KV缓存的内存消耗可能高达数十GB&#xff08;例如Llama2-7B处理100K token时需50GB内存&a…...

云原生安全实战:API网关Kong的鉴权与限流详解

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关&#xff08;API Gateway&#xff09; API网关是微服务架构中的核心组件&#xff0c;负责统一管理所有API的流量入口。它像一座…...

Windows安装Miniconda

一、下载 https://www.anaconda.com/download/success 二、安装 三、配置镜像源 Anaconda/Miniconda pip 配置清华镜像源_anaconda配置清华源-CSDN博客 四、常用操作命令 Anaconda/Miniconda 基本操作命令_miniconda创建环境命令-CSDN博客...

Golang——9、反射和文件操作

反射和文件操作 1、反射1.1、reflect.TypeOf()获取任意值的类型对象1.2、reflect.ValueOf()1.3、结构体反射 2、文件操作2.1、os.Open()打开文件2.2、方式一&#xff1a;使用Read()读取文件2.3、方式二&#xff1a;bufio读取文件2.4、方式三&#xff1a;os.ReadFile读取2.5、写…...