当前位置: 首页 > news >正文

昇思MindSpore进阶教程--在ResNet-50网络上应用二阶优化实践(下)

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

文章上半部分请查看
在ResNet-50网络上应用二阶优化实践(上)

训练网络

配置模型保存

MindSpore提供了callback机制,可以在训练过程中执行自定义逻辑,这里使用框架提供的ModelCheckpoint函数。 ModelCheckpoint可以保存网络模型和参数,以便进行后续的fine-tuning操作。 TimeMonitor、LossMonitor是MindSpore官方提供的callback函数,可以分别用于监控训练过程中单步迭代时间和loss值的变化。


import mindspore as ms
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitorif __name__ == "__main__":# define callbackstime_cb = TimeMonitor(data_size=step_size)loss_cb = LossMonitor()cb = [time_cb, loss_cb]if config.save_checkpoint:config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,keep_checkpoint_max=config.keep_checkpoint_max)ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)cb += [ckpt_cb]

配置训练网络

通过MindSpore提供的model.train接口可以方便地进行网络的训练。THOR优化器通过降低二阶矩阵更新频率,来减少计算量,提升计算速度,故重新定义一个ModelThor类,继承MindSpore提供的Model类。在ModelThor类中获取THOR的二阶矩阵更新频率控制参数,用户可以通过调整该参数,优化整体的性能。 MindSpore提供Model类向ModelThor类的一键转换接口。


import mindspore as ms
from mindspore import amp
from mindspore.train import Model, ConvertModelUtilsif __name__ == "__main__":...loss_scale = amp.FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics,amp_level="O2", keep_batchnorm_fp32=False, eval_network=dist_eval_network)if cfg.optimizer == "Thor":model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,loss_scale_manager=loss_scale, metrics={'acc'},amp_level="O2", keep_batchnorm_fp32=False)  ...

运行脚本

训练脚本定义完成之后,调scripts目录下的shell脚本,启动分布式训练进程。

Atlas训练系列产品

目前MindSpore分布式在Ascend上执行采用单卡单进程运行方式,即每张卡上运行1个进程,进程数量与使用的卡的数量一致。进程均放在后台执行,每个进程创建1个目录,目录名称为train_parallel+ device_id,用来保存日志信息,算子编译信息以及训练的checkpoint文件。下面以使用8张卡的分布式训练脚本为例,演示如何运行脚本。

使用以下命令运行脚本:

bash run_distribute_train.sh <RANK_TABLE_FILE> <DATASET_PATH> [CONFIG_PATH]

脚本需要传入变量RANK_TABLE_FILE,DATASET_PATH和CONFIG_PATH,其中:

  • RANK_TABLE_FILE:组网信息文件的路径。(rank table文件的生成,参考HCCL_TOOL)

  • DATASET_PATH:训练数据集路径。

  • CONFIG_PATH:配置文件路径。

其余环境变量请参考安装教程中的配置项。

训练过程中loss打印示例如下:


epoch: 1 step: 5004, loss is 4.4182425
epoch: 2 step: 5004, loss is 3.740064
epoch: 3 step: 5004, loss is 4.0546017
epoch: 4 step: 5004, loss is 3.7598825
epoch: 5 step: 5004, loss is 3.3744206epoch: 40 step: 5004, loss is 1.6907625
epoch: 41 step: 5004, loss is 1.8217756
epoch: 42 step: 5004, loss is 1.6453942

训练完后,每张卡训练产生的checkpoint文件保存在各自训练目录下,device_0产生的checkpoint文件示例如下:

└─train_parallel0├─ckpt_0├─resnet-1_5004.ckpt├─resnet-2_5004.ckpt│      ......├─resnet-42_5004.ckpt│      ......

其中, *.ckpt:指保存的模型参数文件。checkpoint文件名称具体含义:网络名称-epoch数_step数.ckpt。

GPU

在GPU硬件平台上,MindSpore采用OpenMPI的mpirun进行分布式训练,进程创建1个目录,目录名称为train_parallel,用来保存日志信息和训练的checkpoint文件。下面以使用8张卡的分布式训练脚本为例,演示如何运行脚本。

使用以下命令运行脚本:

bash run_distribute_train_gpu.sh <DATASET_PATH> <CONFIG_PATH>

脚本需要传入变量DATASET_PATH和CONFIG_PATH,其中:

  • DATASET_PATH:训练数据集路径。

  • CONFIG_PATH:配置文件路径。

在GPU训练时,无需设置DEVICE_ID环境变量,因此在主训练脚本中不需要调用int(os.getenv(‘DEVICE_ID’))来获取卡的物理序号,同时context中也无需传入device_id。我们需要将device_target设置为GPU,并需要调用init()来使能NCCL。

训练过程中loss打印示例如下:


epoch: 1 step: 5004, loss is 4.2546034
epoch: 2 step: 5004, loss is 4.0819564
epoch: 3 step: 5004, loss is 3.7005644
epoch: 4 step: 5004, loss is 3.2668946
epoch: 5 step: 5004, loss is 3.023509epoch: 36 step: 5004, loss is 1.645802

训练完后,保存的模型文件示例如下:

└─train_parallel├─ckpt_0├─resnet-1_5004.ckpt├─resnet-2_5004.ckpt│      ......├─resnet-36_5004.ckpt│      ............├─ckpt_7├─resnet-1_5004.ckpt├─resnet-2_5004.ckpt│      ......├─resnet-36_5004.ckpt│      ......

模型推理

使用训练过程中保存的checkpoint文件进行推理,验证模型的泛化能力。首先通过load_checkpoint接口加载模型文件,然后调用Model的eval接口对输入图片类别作出预测,再与输入图片的真实类别做比较,得出最终的预测精度值。

定义推理网络

  1. 使用load_checkpoint接口加载模型文件。

  2. 使用model.eval接口读入测试数据集,进行推理。

  3. 计算得出预测精度值。


import mindspore as ms
from mindspore.train import Modelif __name__ == "__main__":...# define netnet = resnet(class_num=config.class_num)# load checkpointparam_dict = ms.load_checkpoint(args_opt.checkpoint_path)ms.load_param_into_net(net, param_dict)net.set_train(False)# define lossif args_opt.dataset == "imagenet2012":if not config.use_label_smooth:config.label_smooth_factor = 0.0loss = CrossEntropySmooth(sparse=True, reduction='mean',smooth_factor=config.label_smooth_factor, num_classes=config.class_num)else:loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')# define modelmodel = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})# eval modelres = model.eval(dataset)print("result:", res, "ckpt=", args_opt.checkpoint_path)...

执行推理

推理网络定义完成之后,调用scripts目录下的shell脚本,进行推理。

Atlas训练系列产品

在Atlas训练系列产品平台上,推理的执行命令如下:

bash run_eval.sh <DATASET_PATH> <CHECKPOINT_PATH> <CONFIG_PATH>

脚本需要传入变量DATASET_PATH,CHECKPOINT_PATH和<CONFIG_PATH>,其中:

  • DATASET_PATH:推理数据集路径。

  • CHECKPOINT_PATH:保存的checkpoint路径。

  • CONFIG_PATH:配置文件路径。

目前推理使用的是单卡(默认device 0)进行推理,推理的结果如下:

result: {'top_5_accuracy': 0.9295574583866837, 'top_1_accuracy': 0.761443661971831} ckpt=train_parallel0/resnet-42_5004.ckpt
GPU

在GPU硬件平台上,推理的执行命令如下:

  bash run_eval_gpu.sh <DATASET_PATH> <CHECKPOINT_PATH> <CONFIG_PATH>

脚本需要传入变量DATASET_PATH,CHECKPOINT_PATH和CONFIG_PATH,其中:

  • DATASET_PATH:推理数据集路径。

  • CHECKPOINT_PATH:保存的checkpoint路径。

  • CONFIG_PATH:配置文件路径。

推理的结果如下:

result: {'top_5_accuracy': 0.9287972151088348, 'top_1_accuracy': 0.7597031049935979} ckpt=train_parallel/resnet-36_5004.ckpt

相关文章:

昇思MindSpore进阶教程--在ResNet-50网络上应用二阶优化实践(下)

大家好&#xff0c;我是刘明&#xff0c;明志科技创始人&#xff0c;华为昇思MindSpore布道师。 技术上主攻前端开发、鸿蒙开发和AI算法研究。 努力为大家带来持续的技术分享&#xff0c;如果你也喜欢我的文章&#xff0c;就点个关注吧 文章上半部分请查看 在ResNet-50网络上应…...

基于大数据的Python+Django电影票房数据可视化分析系统设计与实现

目录 1 引言 2 系统需求分析 3 技术选型 4 系统架构设计 5 关键技术实现 6 系统实现 7 总结与展望 1 引言 随着数字媒体技术的发展&#xff0c;电影产业已经成为全球经济文化不可或缺的一部分。电影不仅是艺术表达的形式&#xff0c;更是大众娱乐的重要来源。在这个背景…...

实景三维技术对光伏产业的发展具有哪些优势?

实景三维技术对光伏产业的发展具有显著的优势&#xff0c;主要体现在提高选址准确性、优化用地规划、促进数据融合应用以及赋能文旅服务领域。‌ 提高选址准确性‌&#xff1a;通过构建高精度的三维地形模型&#xff0c;结合卫星遥感、无人机测绘等技术手段&#xff0c;实景三维…...

四非人的保研之路,2024(2025届)四非计算机的保研经验分享(西南交通、苏大nlp、西电、北邮、山软、山计、电科、厦大等)

文章目录 一、个人背景二、夏令营北京邮电大学CS西南交通大学CS深圳大学CS苏州大学NLP南开大学CS 三、预推免北京邮电大学CS华东师范大学 CS和大数据电子科技大学 CS东北大学 CS厦门大学 信息学院山东大学 CS和SE西安电子科技大学 CS 四、个人经验五、上岸 一、个人背景 学校专…...

UE5.4.3 录屏回放系统ReplaySystem蓝图版

这是ReplaySystem的蓝图使用方法版&#xff0c;以第三人称模版为例&#xff0c;需要几个必须步骤 项目config内DefaultEngine.ini的最后添加&#xff1a; [/Script/Engine.GameEngine] NetDriverDefinitions(DefName"DemoNetDriver",DriverClassName"/Script/…...

ECCV 2024 | 融合跨模态先验与扩散模型,快手处理大模型让视频画面更清晰!

计算机视觉领域顶级会议 European Conference on Computer Vision&#xff08;ECCV 2024&#xff09;将于9月29日至10月4日在意大利米兰召开&#xff0c;快手音视频技术部联合清华大学所发表的题为《XPSR: Cross-modal Priors for Diffusion-based Image Super-Resolution》——…...

9--苍穹外卖-SpringBoot项目中Redis的介绍及其使用实例 详解

目录 Redis入门 Redis简介 Redis服务启动与停止 服务启动命令 Redis数据类型 5种常用数据类型介绍 各种数据类型的特点 Redis常用命令 字符串操作命令 哈希操作命令 列表操作命令 集合操作命令 有序集合操作命令 通用命令 在java中操作Redis Redis的Java客户端 …...

【EXCEL数据处理】000014 案例 EXCEL分类汇总、定位和创建组。附多个操作案例。

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 【EXCEL数据处理】000014 案例 EXCEL分类汇总、定位和创建组。附多个操…...

Windows环境Apache httpd 2.4 web服务器加载PHP8:Hello,world!

Windows环境Apache httpd 2.4 web服务器加载PHP8&#xff1a;Hello&#xff0c;world&#xff01; &#xff08;1&#xff09;首先需要安装apache httpd 2.4 web服务器&#xff1a; Windows安装启动apache httpd 2.4 web服务器-CSDN博客文章浏览阅读222次&#xff0c;点赞5次&…...

Spring框架使用Api接口实现AOP的切面编程、两种方式的程序示例以及Java各数据类型及基本数据类型的默认值/最大值/最小值列表

一、Spring框架使用Api接口-继承类实现AOP的切面编程示例 要使用Spring框架AOP&#xff0c;除了要导入spring框架包外&#xff0c;还需要导入一个织入的包org.aspectj&#xff0c;具体maven依赖如下&#xff1a; <dependency><groupId>org.springframework</gr…...

【达梦数据库】尽可能 disql 的使用效果与异构数据库一致

文章目录 前言disql 效果优化参数设置参数说明 mysql参数设置参数说明 db2参数设置参数说明 待补充 前言 让达梦的disql 使用起来更跟手&#xff0c;与其他优质数据库的命令行工具通过配置参数的方式尽可能一致&#xff0c;提高使用体验&#xff0c;长期整理中~~~ 测试版本&…...

【研1深度学习】《神经网络和深度学习》阅读笔记(记录中......

9.27 语义鸿沟&#xff1a; 是指输入数据的底层特征和高层语义信息之间的不一致性和查一下。如果可以有一个好的表示在某种程度上能够反映出数据的高层语义特征&#xff0c;那么我们就能相对容易的构建后续的机器学习模型。嵌入&#xff08;Embedding&#xff09;&#xff1a;…...

十一不停歇-学习ROS2第一天 (10.2 10:45)

话题通信 1.1 发布第一个节点&#xff1a; import rclpy #导入此类模块 rcl类型 from rclpy.node import Node #从这个子模块中导入这类函数 def main(): #定义这个函数 rclpy.init() #使用初始化函数 node Node(hello_python) 将类函数里面的内容调给…...

Java高效编程(14):考虑实现 `Comparable

解锁Python编程的无限可能&#xff1a;《奇妙的Python》带你漫游代码世界 与其他方法不同&#xff0c;compareTo 并非 Object 类中声明的&#xff0c;而是 Comparable 接口的唯一方法。compareTo 方法与 equals 类似&#xff0c;但它不仅支持相等性比较&#xff0c;还允许顺序…...

华为昇腾CANN训练营2024第二季--Ascend C算子开发能力认证(中级)题目和经验分享

大家好&#xff0c;我是刘明&#xff0c;明志科技创始人&#xff0c;华为昇思MindSpore布道师。 技术上主攻前端开发、鸿蒙开发和AI算法研究。 努力为大家带来持续的技术分享&#xff0c;如果你也喜欢我的文章&#xff0c;就点个关注吧 正文开始 华为昇腾CANN训练营2024第二季…...

实战OpenCV之形态学操作

基础入门 形态学操作是一种基于图像形状的处理方法,主要用于结构分析,比如:边缘检测、轮廓提取、噪声去除等。这些操作通常使用一个称为“结构元素”(Structuring Element)的核来进行,结构元素可以是任何形状,但最常见的有矩形和圆形。形态学操作的核心在于通过结构元素…...

矩阵的特征值和特征向量

矩阵的特征值和特征向量是线性代数中非常重要的概念&#xff0c;用于描述矩阵对向量的作用&#xff0c;特别是在矩阵对向量的线性变换中的表现。它们帮助我们理解矩阵在某些方向上的缩放或旋转效果。 1. 特征值和特征向量的定义&#xff1a; 给定一个 n n n \times n nn 的…...

(11)MATLAB莱斯(Rician)衰落信道仿真2

文章目录 前言一、莱斯衰落信道仿真模型二、仿真代码与结果1.仿真代码2.仿真结果画图 三、后续&#xff1a;四、参考文献&#xff1a; 前言 首先给出莱斯衰落信道仿真模型&#xff0c;该模型由直射路径分量和反射路径分量组成&#xff0c;其中反射路径分量由瑞利衰落信道模型构…...

ComfyUI局部重绘换衣讲解

一、下载插件 ComfyUI-Impact-Pack 下载地址 https://github.com/ltdrdata/ComfyUI-Impact-Pack 主要用到sam Detector去绘制衣服蒙版和高斯模糊蒙版&#xff0c;高斯模糊让蒙版边缘更加柔和 sams模型 放在E:\Comfyui\ComfyUI\models\sams二、换衣思路 文生图或直接上传…...

Android——添加联系人

概述 方式一&#xff1a;使用ContentResolver多次写入&#xff0c;每次写入一个字段 第一步 往手机联系人应用中的raw_contacts表添加一条记录 raw_contacts表 ContentValues values new ContentValues();// 往 raw_contacts 添加联系人记录&#xff0c;并获取添加后的联…...

龙虎榜——20250610

上证指数放量收阴线&#xff0c;个股多数下跌&#xff0c;盘中受消息影响大幅波动。 深证指数放量收阴线形成顶分型&#xff0c;指数短线有调整的需求&#xff0c;大概需要一两天。 2025年6月10日龙虎榜行业方向分析 1. 金融科技 代表标的&#xff1a;御银股份、雄帝科技 驱动…...

SkyWalking 10.2.0 SWCK 配置过程

SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外&#xff0c;K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案&#xff0c;全安装在K8S群集中。 具体可参…...

Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误

HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误&#xff0c;它们的含义、原因和解决方法都有显著区别。以下是详细对比&#xff1a; 1. HTTP 406 (Not Acceptable) 含义&#xff1a; 客户端请求的内容类型与服务器支持的内容类型不匹…...

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...

MFC内存泄露

1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...

Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例

使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件&#xff0c;常用于在两个集合之间进行数据转移&#xff0c;如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model&#xff1a;绑定右侧列表的值&…...

STM32+rt-thread判断是否联网

一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...

Linux相关概念和易错知识点(42)(TCP的连接管理、可靠性、面临复杂网络的处理)

目录 1.TCP的连接管理机制&#xff08;1&#xff09;三次握手①握手过程②对握手过程的理解 &#xff08;2&#xff09;四次挥手&#xff08;3&#xff09;握手和挥手的触发&#xff08;4&#xff09;状态切换①挥手过程中状态的切换②握手过程中状态的切换 2.TCP的可靠性&…...

前端导出带有合并单元格的列表

// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...

系统设计 --- MongoDB亿级数据查询优化策略

系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log&#xff0c;共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题&#xff0c;不能使用ELK只能使用…...