当前位置: 首页 > news >正文

昇思MindSpore进阶教程--在ResNet-50网络上应用二阶优化实践(下)

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

文章上半部分请查看
在ResNet-50网络上应用二阶优化实践(上)

训练网络

配置模型保存

MindSpore提供了callback机制,可以在训练过程中执行自定义逻辑,这里使用框架提供的ModelCheckpoint函数。 ModelCheckpoint可以保存网络模型和参数,以便进行后续的fine-tuning操作。 TimeMonitor、LossMonitor是MindSpore官方提供的callback函数,可以分别用于监控训练过程中单步迭代时间和loss值的变化。


import mindspore as ms
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitorif __name__ == "__main__":# define callbackstime_cb = TimeMonitor(data_size=step_size)loss_cb = LossMonitor()cb = [time_cb, loss_cb]if config.save_checkpoint:config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,keep_checkpoint_max=config.keep_checkpoint_max)ckpt_cb = ModelCheckpoint(prefix="resnet", directory=ckpt_save_dir, config=config_ck)cb += [ckpt_cb]

配置训练网络

通过MindSpore提供的model.train接口可以方便地进行网络的训练。THOR优化器通过降低二阶矩阵更新频率,来减少计算量,提升计算速度,故重新定义一个ModelThor类,继承MindSpore提供的Model类。在ModelThor类中获取THOR的二阶矩阵更新频率控制参数,用户可以通过调整该参数,优化整体的性能。 MindSpore提供Model类向ModelThor类的一键转换接口。


import mindspore as ms
from mindspore import amp
from mindspore.train import Model, ConvertModelUtilsif __name__ == "__main__":...loss_scale = amp.FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics=metrics,amp_level="O2", keep_batchnorm_fp32=False, eval_network=dist_eval_network)if cfg.optimizer == "Thor":model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,loss_scale_manager=loss_scale, metrics={'acc'},amp_level="O2", keep_batchnorm_fp32=False)  ...

运行脚本

训练脚本定义完成之后,调scripts目录下的shell脚本,启动分布式训练进程。

Atlas训练系列产品

目前MindSpore分布式在Ascend上执行采用单卡单进程运行方式,即每张卡上运行1个进程,进程数量与使用的卡的数量一致。进程均放在后台执行,每个进程创建1个目录,目录名称为train_parallel+ device_id,用来保存日志信息,算子编译信息以及训练的checkpoint文件。下面以使用8张卡的分布式训练脚本为例,演示如何运行脚本。

使用以下命令运行脚本:

bash run_distribute_train.sh <RANK_TABLE_FILE> <DATASET_PATH> [CONFIG_PATH]

脚本需要传入变量RANK_TABLE_FILE,DATASET_PATH和CONFIG_PATH,其中:

  • RANK_TABLE_FILE:组网信息文件的路径。(rank table文件的生成,参考HCCL_TOOL)

  • DATASET_PATH:训练数据集路径。

  • CONFIG_PATH:配置文件路径。

其余环境变量请参考安装教程中的配置项。

训练过程中loss打印示例如下:


epoch: 1 step: 5004, loss is 4.4182425
epoch: 2 step: 5004, loss is 3.740064
epoch: 3 step: 5004, loss is 4.0546017
epoch: 4 step: 5004, loss is 3.7598825
epoch: 5 step: 5004, loss is 3.3744206epoch: 40 step: 5004, loss is 1.6907625
epoch: 41 step: 5004, loss is 1.8217756
epoch: 42 step: 5004, loss is 1.6453942

训练完后,每张卡训练产生的checkpoint文件保存在各自训练目录下,device_0产生的checkpoint文件示例如下:

└─train_parallel0├─ckpt_0├─resnet-1_5004.ckpt├─resnet-2_5004.ckpt│      ......├─resnet-42_5004.ckpt│      ......

其中, *.ckpt:指保存的模型参数文件。checkpoint文件名称具体含义:网络名称-epoch数_step数.ckpt。

GPU

在GPU硬件平台上,MindSpore采用OpenMPI的mpirun进行分布式训练,进程创建1个目录,目录名称为train_parallel,用来保存日志信息和训练的checkpoint文件。下面以使用8张卡的分布式训练脚本为例,演示如何运行脚本。

使用以下命令运行脚本:

bash run_distribute_train_gpu.sh <DATASET_PATH> <CONFIG_PATH>

脚本需要传入变量DATASET_PATH和CONFIG_PATH,其中:

  • DATASET_PATH:训练数据集路径。

  • CONFIG_PATH:配置文件路径。

在GPU训练时,无需设置DEVICE_ID环境变量,因此在主训练脚本中不需要调用int(os.getenv(‘DEVICE_ID’))来获取卡的物理序号,同时context中也无需传入device_id。我们需要将device_target设置为GPU,并需要调用init()来使能NCCL。

训练过程中loss打印示例如下:


epoch: 1 step: 5004, loss is 4.2546034
epoch: 2 step: 5004, loss is 4.0819564
epoch: 3 step: 5004, loss is 3.7005644
epoch: 4 step: 5004, loss is 3.2668946
epoch: 5 step: 5004, loss is 3.023509epoch: 36 step: 5004, loss is 1.645802

训练完后,保存的模型文件示例如下:

└─train_parallel├─ckpt_0├─resnet-1_5004.ckpt├─resnet-2_5004.ckpt│      ......├─resnet-36_5004.ckpt│      ............├─ckpt_7├─resnet-1_5004.ckpt├─resnet-2_5004.ckpt│      ......├─resnet-36_5004.ckpt│      ......

模型推理

使用训练过程中保存的checkpoint文件进行推理,验证模型的泛化能力。首先通过load_checkpoint接口加载模型文件,然后调用Model的eval接口对输入图片类别作出预测,再与输入图片的真实类别做比较,得出最终的预测精度值。

定义推理网络

  1. 使用load_checkpoint接口加载模型文件。

  2. 使用model.eval接口读入测试数据集,进行推理。

  3. 计算得出预测精度值。


import mindspore as ms
from mindspore.train import Modelif __name__ == "__main__":...# define netnet = resnet(class_num=config.class_num)# load checkpointparam_dict = ms.load_checkpoint(args_opt.checkpoint_path)ms.load_param_into_net(net, param_dict)net.set_train(False)# define lossif args_opt.dataset == "imagenet2012":if not config.use_label_smooth:config.label_smooth_factor = 0.0loss = CrossEntropySmooth(sparse=True, reduction='mean',smooth_factor=config.label_smooth_factor, num_classes=config.class_num)else:loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')# define modelmodel = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})# eval modelres = model.eval(dataset)print("result:", res, "ckpt=", args_opt.checkpoint_path)...

执行推理

推理网络定义完成之后,调用scripts目录下的shell脚本,进行推理。

Atlas训练系列产品

在Atlas训练系列产品平台上,推理的执行命令如下:

bash run_eval.sh <DATASET_PATH> <CHECKPOINT_PATH> <CONFIG_PATH>

脚本需要传入变量DATASET_PATH,CHECKPOINT_PATH和<CONFIG_PATH>,其中:

  • DATASET_PATH:推理数据集路径。

  • CHECKPOINT_PATH:保存的checkpoint路径。

  • CONFIG_PATH:配置文件路径。

目前推理使用的是单卡(默认device 0)进行推理,推理的结果如下:

result: {'top_5_accuracy': 0.9295574583866837, 'top_1_accuracy': 0.761443661971831} ckpt=train_parallel0/resnet-42_5004.ckpt
GPU

在GPU硬件平台上,推理的执行命令如下:

  bash run_eval_gpu.sh <DATASET_PATH> <CHECKPOINT_PATH> <CONFIG_PATH>

脚本需要传入变量DATASET_PATH,CHECKPOINT_PATH和CONFIG_PATH,其中:

  • DATASET_PATH:推理数据集路径。

  • CHECKPOINT_PATH:保存的checkpoint路径。

  • CONFIG_PATH:配置文件路径。

推理的结果如下:

result: {'top_5_accuracy': 0.9287972151088348, 'top_1_accuracy': 0.7597031049935979} ckpt=train_parallel/resnet-36_5004.ckpt

相关文章:

昇思MindSpore进阶教程--在ResNet-50网络上应用二阶优化实践(下)

大家好&#xff0c;我是刘明&#xff0c;明志科技创始人&#xff0c;华为昇思MindSpore布道师。 技术上主攻前端开发、鸿蒙开发和AI算法研究。 努力为大家带来持续的技术分享&#xff0c;如果你也喜欢我的文章&#xff0c;就点个关注吧 文章上半部分请查看 在ResNet-50网络上应…...

基于大数据的Python+Django电影票房数据可视化分析系统设计与实现

目录 1 引言 2 系统需求分析 3 技术选型 4 系统架构设计 5 关键技术实现 6 系统实现 7 总结与展望 1 引言 随着数字媒体技术的发展&#xff0c;电影产业已经成为全球经济文化不可或缺的一部分。电影不仅是艺术表达的形式&#xff0c;更是大众娱乐的重要来源。在这个背景…...

实景三维技术对光伏产业的发展具有哪些优势?

实景三维技术对光伏产业的发展具有显著的优势&#xff0c;主要体现在提高选址准确性、优化用地规划、促进数据融合应用以及赋能文旅服务领域。‌ 提高选址准确性‌&#xff1a;通过构建高精度的三维地形模型&#xff0c;结合卫星遥感、无人机测绘等技术手段&#xff0c;实景三维…...

四非人的保研之路,2024(2025届)四非计算机的保研经验分享(西南交通、苏大nlp、西电、北邮、山软、山计、电科、厦大等)

文章目录 一、个人背景二、夏令营北京邮电大学CS西南交通大学CS深圳大学CS苏州大学NLP南开大学CS 三、预推免北京邮电大学CS华东师范大学 CS和大数据电子科技大学 CS东北大学 CS厦门大学 信息学院山东大学 CS和SE西安电子科技大学 CS 四、个人经验五、上岸 一、个人背景 学校专…...

UE5.4.3 录屏回放系统ReplaySystem蓝图版

这是ReplaySystem的蓝图使用方法版&#xff0c;以第三人称模版为例&#xff0c;需要几个必须步骤 项目config内DefaultEngine.ini的最后添加&#xff1a; [/Script/Engine.GameEngine] NetDriverDefinitions(DefName"DemoNetDriver",DriverClassName"/Script/…...

ECCV 2024 | 融合跨模态先验与扩散模型,快手处理大模型让视频画面更清晰!

计算机视觉领域顶级会议 European Conference on Computer Vision&#xff08;ECCV 2024&#xff09;将于9月29日至10月4日在意大利米兰召开&#xff0c;快手音视频技术部联合清华大学所发表的题为《XPSR: Cross-modal Priors for Diffusion-based Image Super-Resolution》——…...

9--苍穹外卖-SpringBoot项目中Redis的介绍及其使用实例 详解

目录 Redis入门 Redis简介 Redis服务启动与停止 服务启动命令 Redis数据类型 5种常用数据类型介绍 各种数据类型的特点 Redis常用命令 字符串操作命令 哈希操作命令 列表操作命令 集合操作命令 有序集合操作命令 通用命令 在java中操作Redis Redis的Java客户端 …...

【EXCEL数据处理】000014 案例 EXCEL分类汇总、定位和创建组。附多个操作案例。

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 【EXCEL数据处理】000014 案例 EXCEL分类汇总、定位和创建组。附多个操…...

Windows环境Apache httpd 2.4 web服务器加载PHP8:Hello,world!

Windows环境Apache httpd 2.4 web服务器加载PHP8&#xff1a;Hello&#xff0c;world&#xff01; &#xff08;1&#xff09;首先需要安装apache httpd 2.4 web服务器&#xff1a; Windows安装启动apache httpd 2.4 web服务器-CSDN博客文章浏览阅读222次&#xff0c;点赞5次&…...

Spring框架使用Api接口实现AOP的切面编程、两种方式的程序示例以及Java各数据类型及基本数据类型的默认值/最大值/最小值列表

一、Spring框架使用Api接口-继承类实现AOP的切面编程示例 要使用Spring框架AOP&#xff0c;除了要导入spring框架包外&#xff0c;还需要导入一个织入的包org.aspectj&#xff0c;具体maven依赖如下&#xff1a; <dependency><groupId>org.springframework</gr…...

【达梦数据库】尽可能 disql 的使用效果与异构数据库一致

文章目录 前言disql 效果优化参数设置参数说明 mysql参数设置参数说明 db2参数设置参数说明 待补充 前言 让达梦的disql 使用起来更跟手&#xff0c;与其他优质数据库的命令行工具通过配置参数的方式尽可能一致&#xff0c;提高使用体验&#xff0c;长期整理中~~~ 测试版本&…...

【研1深度学习】《神经网络和深度学习》阅读笔记(记录中......

9.27 语义鸿沟&#xff1a; 是指输入数据的底层特征和高层语义信息之间的不一致性和查一下。如果可以有一个好的表示在某种程度上能够反映出数据的高层语义特征&#xff0c;那么我们就能相对容易的构建后续的机器学习模型。嵌入&#xff08;Embedding&#xff09;&#xff1a;…...

十一不停歇-学习ROS2第一天 (10.2 10:45)

话题通信 1.1 发布第一个节点&#xff1a; import rclpy #导入此类模块 rcl类型 from rclpy.node import Node #从这个子模块中导入这类函数 def main(): #定义这个函数 rclpy.init() #使用初始化函数 node Node(hello_python) 将类函数里面的内容调给…...

Java高效编程(14):考虑实现 `Comparable

解锁Python编程的无限可能&#xff1a;《奇妙的Python》带你漫游代码世界 与其他方法不同&#xff0c;compareTo 并非 Object 类中声明的&#xff0c;而是 Comparable 接口的唯一方法。compareTo 方法与 equals 类似&#xff0c;但它不仅支持相等性比较&#xff0c;还允许顺序…...

华为昇腾CANN训练营2024第二季--Ascend C算子开发能力认证(中级)题目和经验分享

大家好&#xff0c;我是刘明&#xff0c;明志科技创始人&#xff0c;华为昇思MindSpore布道师。 技术上主攻前端开发、鸿蒙开发和AI算法研究。 努力为大家带来持续的技术分享&#xff0c;如果你也喜欢我的文章&#xff0c;就点个关注吧 正文开始 华为昇腾CANN训练营2024第二季…...

实战OpenCV之形态学操作

基础入门 形态学操作是一种基于图像形状的处理方法,主要用于结构分析,比如:边缘检测、轮廓提取、噪声去除等。这些操作通常使用一个称为“结构元素”(Structuring Element)的核来进行,结构元素可以是任何形状,但最常见的有矩形和圆形。形态学操作的核心在于通过结构元素…...

矩阵的特征值和特征向量

矩阵的特征值和特征向量是线性代数中非常重要的概念&#xff0c;用于描述矩阵对向量的作用&#xff0c;特别是在矩阵对向量的线性变换中的表现。它们帮助我们理解矩阵在某些方向上的缩放或旋转效果。 1. 特征值和特征向量的定义&#xff1a; 给定一个 n n n \times n nn 的…...

(11)MATLAB莱斯(Rician)衰落信道仿真2

文章目录 前言一、莱斯衰落信道仿真模型二、仿真代码与结果1.仿真代码2.仿真结果画图 三、后续&#xff1a;四、参考文献&#xff1a; 前言 首先给出莱斯衰落信道仿真模型&#xff0c;该模型由直射路径分量和反射路径分量组成&#xff0c;其中反射路径分量由瑞利衰落信道模型构…...

ComfyUI局部重绘换衣讲解

一、下载插件 ComfyUI-Impact-Pack 下载地址 https://github.com/ltdrdata/ComfyUI-Impact-Pack 主要用到sam Detector去绘制衣服蒙版和高斯模糊蒙版&#xff0c;高斯模糊让蒙版边缘更加柔和 sams模型 放在E:\Comfyui\ComfyUI\models\sams二、换衣思路 文生图或直接上传…...

Android——添加联系人

概述 方式一&#xff1a;使用ContentResolver多次写入&#xff0c;每次写入一个字段 第一步 往手机联系人应用中的raw_contacts表添加一条记录 raw_contacts表 ContentValues values new ContentValues();// 往 raw_contacts 添加联系人记录&#xff0c;并获取添加后的联…...

golang循环变量捕获问题​​

在 Go 语言中&#xff0c;当在循环中启动协程&#xff08;goroutine&#xff09;时&#xff0c;如果在协程闭包中直接引用循环变量&#xff0c;可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下&#xff1a; 问题背景 看这个代码片段&#xff1a; fo…...

跨链模式:多链互操作架构与性能扩展方案

跨链模式&#xff1a;多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈&#xff1a;模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展&#xff08;H2Cross架构&#xff09;&#xff1a; 适配层&#xf…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

Java入门学习详细版(一)

大家好&#xff0c;Java 学习是一个系统学习的过程&#xff0c;核心原则就是“理论 实践 坚持”&#xff0c;并且需循序渐进&#xff0c;不可过于着急&#xff0c;本篇文章推出的这份详细入门学习资料将带大家从零基础开始&#xff0c;逐步掌握 Java 的核心概念和编程技能。 …...

用docker来安装部署freeswitch记录

今天刚才测试一个callcenter的项目&#xff0c;所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...

【生成模型】视频生成论文调研

工作清单 上游应用方向&#xff1a;控制、速度、时长、高动态、多主体驱动 类型工作基础模型WAN / WAN-VACE / HunyuanVideo控制条件轨迹控制ATI~镜头控制ReCamMaster~多主体驱动Phantom~音频驱动Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation速…...

JVM 内存结构 详解

内存结构 运行时数据区&#xff1a; Java虚拟机在运行Java程序过程中管理的内存区域。 程序计数器&#xff1a; ​ 线程私有&#xff0c;程序控制流的指示器&#xff0c;分支、循环、跳转、异常处理、线程恢复等基础功能都依赖这个计数器完成。 ​ 每个线程都有一个程序计数…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行

项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战&#xff0c;克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

基于SpringBoot在线拍卖系统的设计和实现

摘 要 随着社会的发展&#xff0c;社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 在线拍卖系统&#xff0c;主要的模块包括管理员&#xff1b;首页、个人中心、用户管理、商品类型管理、拍卖商品管理、历史竞拍管理、竞拍订单…...

免费数学几何作图web平台

光锐软件免费数学工具&#xff0c;maths,数学制图&#xff0c;数学作图&#xff0c;几何作图&#xff0c;几何&#xff0c;AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...