当前位置: 首页 > news >正文

华为开源自研AI框架昇思MindSpore应用案例:人体关键点检测模型Lite-HRNet

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区

在这里插入图片描述

在这里插入图片描述

一、环境准备

1.进入ModelArts官网

云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpore2.0.0-alpha版本,可以在昇思教程中进入ModelArts官网

在这里插入图片描述

选择下方CodeLab立即体验

在这里插入图片描述

等待环境搭建完成

在这里插入图片描述

2.使用CodeLab体验Notebook实例

下载NoteBook样例代码,Lite-HRNet实现人体关键点检测 ,.ipynb为样例代码

在这里插入图片描述

选择ModelArts Upload Files上传.ipynb文件

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

选择Kernel环境

在这里插入图片描述

切换至GPU环境,切换成第一个限时免费

在这里插入图片描述

进入昇思MindSpore官网,点击上方的安装

在这里插入图片描述

获取安装命令

在这里插入图片描述

回到Notebook中,在第一块代码前加入命令
在这里插入图片描述

conda update -n base -c defaults conda

在这里插入图片描述

安装MindSpore 2.0 GPU版本

conda install mindspore=2.0.0a0 -c mindspore -c conda-forge

在这里插入图片描述

安装mindvision

pip install mindvision

在这里插入图片描述

安装下载download

pip install download

人体关键点检测模型Lite-HRNet

人体关键点检测是计算机视觉的基本任务之一,在许多应用场景诸如自动驾驶、安防等有着重要的地位。可以发现,在这些应用场景下,深度学习模型可能需要部署在IoT设备上,这些设备算力较低,存储空间有限,无法支撑太大的模型,因此轻量但不失高性能的人体关键点检测级模型将极大降低模型部署难度。Lite-HRNet便提供了一轻量级神经网络骨干,通过接上不同的后续模型可以完成不同的任务,其中便包括人体关键点检测,在配置合理的情况下,Lite-HRNet可以以大型神经网络数十分之一的参数量及计算量达到相近的性能。

模型简介

Lite-HRNet由HRNet(High-Resolution Network)改进而来,HRNet的主要思路是在前向传播过程中通过维持不同分辨率的特征,使得最后生成的高阶特征既可以保留低分辨率高阶特征中的图像语义信息,也可以保留高分辨率高阶特征中的物体位置信息,进而提高在分辨率敏感的任务如语义分割、姿态检测中的表现。Lite-HRNet是HRNet的轻量化改进,改进了HRNet中的卷积模块,将HRNet中的参数量从28.5M降低至1.1M,计算量从7.1GFLOPS降低至0.2GFLOPS,但AP75仅下降了7%。
综上,Lite-HRNet具有计算量、参数量低,精度可观的优点,有利于部署在物联网低算力设备上服务于各个应用场景。

数据准备

本案例使用COCO2017数据集作为训练、验证数据集,请首先安装Mindspore Vision套件,并确保安装的Mindspore是GPU版本,随后请在https://cocodataset.org/ 上下载好2017 Train Images、2017 Val Images以及对应的标记2017 Train/Val Annotations,并解压至当前文件夹,文件夹结构下表所示

Lite-HRNet/├── imgs├── src├── annotations├──person_keypoints_train2017.json└──person_keypoints_train2017.json├── train2017└── val2017

训练、测试原始图片如下所示,图片中可能包含多个人体,且包含的人体不一定包含COCO2017中定义的17个关键点,标注中有每个人体的边框、关键点信息,以便处理图像后供模型训练。

数据预处理

src/mindspore_coco.py中定义了供mindspore模型训练、测试的COCO数据集接口,在加载训练数据集时只需指定所用数据集文件夹位置、输入图像的尺寸、目标热力图的尺寸、以及手动设置对训练图像采用的变换即可


import mindspore as ms
import mindspore.dataset as dataset
import mindspore.dataset.vision.py_transforms as py_vision
import mindspore.nn as nn
from mindspore.dataset.transforms.py_transforms import Composefrom src.configs.dataset_config import COCOConfig
from src.dataset.mindspore_coco import COCODatasetcfg = COCOConfig(root="./", output_dir="outputs/", image_size=[192, 256], heatmap_size=[48, 64])
trans = Compose([py_vision.ToTensor(),py_vision.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
train_ds = COCODataset(cfg, "../", "train2017", True, transform=trans)
train_loader = dataset.GeneratorDataset(train_ds, ["data", "target", "weight"])

在这里插入图片描述

构建网络

Lite-HRNet网络骨干大体结构如下图所示:
在这里插入图片描述

网络中存在不同的分辨率分支,网络主干上维持着较高分辨率、较少通道数的输出特征,网络分支上延展出较低分辨率、较多通道数的输出特征,且这些不同分辨率的特征之间通过上采样、下采样的卷积层进行交互、融合。Stage内的Cross Channel Weighting(CCW)则是网络实现轻量化的精髓,它将原HRNet中复杂度较高的1*1卷积以更低复杂度的Spatial Weighting等方法替代,从而实现降低网络参数、计算量的效果。CCW的结构如下图所示

在这里插入图片描述

值得注意的是,除了骨干网络,作者在论文中同时也给出了所使用的检测头即SimpleBaseline,为了简洁起见,在本次的Lite-HRNet的Mindspore实现中,检测头(代码中包括IterativeHeads和LiteTopDownSimpleHeatMap)已集成至骨干网络之后,作为整体模型的一部分,直接调用模型即可得到热力图预测输出。

损失函数

此处使用损失函数为JointMSELoss,即关节点的均方差误差损失函数,其源码如下所示,总体流程即计算每个关节点预测热力图与实际热力图的均方差,其中target是根据关节点的人工标注坐标,通过二维高斯分布生成的热力图,target_weight用于指定参与计算的关节点,若某关节点对应target_weight取值为0,则表明该关节点在输入图像中未出现,不参与计算。

"""JointMSELoss"""
import mindspore.nn as nn
import mindspore.ops as opsclass JointsMSELoss(nn.Cell):"""Joint MSELoss"""def __init__(self, use_target_weight):"""JointMSELoss"""super(JointsMSELoss, self).__init__()self.criterion = nn.MSELoss(reduction='mean')self.use_target_weight = use_target_weightdef construct(self, output, target, weight):"""construct"""target = targettarget_weight = weightbatch_size = output.shape[0]num_joints = output.shape[1]spliter = ops.Split(axis=1, output_num=num_joints)mul = ops.Mul()heatmaps_pred = spliter(output.reshape((batch_size, num_joints, -1)))heatmaps_gt = spliter(target.reshape((batch_size, num_joints, -1)))loss = 0for idx in range(num_joints):heatmap_pred = heatmaps_pred[idx].squeeze()heatmap_gt = heatmaps_gt[idx].squeeze()if self.use_target_weight:heatmap_pred = mul(heatmap_pred, target_weight[:, idx])heatmap_gt = mul(heatmap_gt, target_weight[:, idx])loss += 0.5 * self.criterion(heatmap_pred,heatmap_gt)else:loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)return loss/num_joints

模型实现与训练

在实现模型时,需指定模型内部结构,在src/net_configs中已指定原论文中10种结构配置,在训练样例种取Lite_18_coco作为模型结构,此处作为案例,仅设置epoch数量为1,在实际训练中可以设置为200,并且可以加入warmup。由于mindspore的训练接口默认数据集中每条数据只有两列(即训练数据和标签),所以这里需自定义Loss Cell。值得注意的是loss在训练前后变化并不会十分大,训练好的模型的loss为0.0004左右

class CustomWithLossCell(nn.Cell):def __init__(self,net: nn.Cell,loss_fn: nn.Cell):super(CustomWithLossCell, self).__init__()self.net = netself._loss_fn = loss_fndef construct(self, img, target, weight):""" build network """heatmap_pred = self.net(img)return self._loss_fn(heatmap_pred,target,weight)
from src.configs.net_configs import get_netconfig
from mindspore.train.callback import  LossMonitor
from src.backbone import LiteHRNetext = get_netconfig("extra_lite_18_coco")
net = LiteHRNet(ext)
criterion = JointsMSELoss(use_target_weight=True)train_loader = train_loader.batch(64)
optim = nn.Adam(net.trainable_params(), learning_rate=2e-3)
loss = JointsMSELoss(use_target_weight=True)
net_with_loss = CustomWithLossCell(net, loss)model = ms.Model(network=net_with_loss, optimizer=optim)
epochs = 1
#Start Training
model.train(epochs, train_loader, callbacks=[LossMonitor(100)], dataset_sink_mode=False)

在这里插入图片描述

模型评估

模型评估过程中使用AP、AP50、AP75以及AR50、AR75作为评价指标,val2017作为评价数据集,pycocotool包中已实现根据评价函数,且src/mindspore_coco.py中的evaluate函数也实现了调用该评价函数的接口,只需提供预测关键点坐标等信息即可获得评价指标。此处载入Lite_18_coco的预训练模型进行评价。

from mindspore import load_checkpoint
from mindspore import load_param_into_netfrom src.utils.utils import get_final_preds
import numpy as npdef evaluate_model(model, dataset, output_path):"""Evaluate"""num_samples = len(dataset)all_preds = np.zeros((num_samples, 17, 3),dtype=np.float32)all_boxes = np.zeros((num_samples, 6))image_path = []for i, data in enumerate(dataset):input_data, target, meta = data[0], data[1], data[3]input_data = ms.Tensor(input_data[0], ms.float32).reshape(1, 3, 256, 192)shit = model(input_data).asnumpy()target = target.reshape(shit.shape)c = meta['center'].reshape(1, 2)s = meta['scale'].reshape(1, 2)score = meta['score']preds, maxvals = get_final_preds(shit, c, s)all_preds[i:i + 1, :, 0:2] = preds[:, :, 0:2]all_preds[i:i + 1, :, 2:3] = maxvals# double check this all_boxes partsall_boxes[i:i + 1, 0:2] = c[:, 0:2]all_boxes[i:i + 1, 2:4] = s[:, 0:2]all_boxes[i:i + 1, 4] = np.prod(s*200, 1)all_boxes[i:i + 1, 5] = scoreimage_path.append(meta['image'])dataset.evaluate(0, all_preds, output_path, all_boxes, image_path)net_dict = load_checkpoint("./ckpt/litehrnet_18_coco_256x192.ckpt")
load_param_into_net(net, net_dict)eval_ds = COCODataset(cfg, "./", "val2017", False, transform=trans)
evaluate_model(net, eval_ds, "./result")

在这里插入图片描述

模型推理

  1. Lite-HRNet是关键点检测模型,所以输入待推理图像应为包含单个人体的图像,作者在论文中提及在coco test 2017测试前已使用SimpleBaseline生成的目标检测Bounding Box处理图像,所以待推理图像应仅包含单个人体。
  2. 网络的输入为(1,3,256,192),所以在输入图像前应先将其变换成网络可处理的形式。
import cv2
from src.utils.utils import get_max_preds
origin_img = cv2.imread("./imgs/man.jpg")
origin_h, origin_w, _ = origin_img.shape
scale_factor = [origin_w/192, origin_h/256]# resize to (112 112 3) and convert to tensor
img = cv2.resize(origin_img, (192, 256))
print(img.shape)
img = trans(img)
# img = np.expand_dims(img, axis=0)
img = ms.Tensor(img)
print(img.shape)# Infer
heatmap_pred = net(img).asnumpy()
pred, _ = get_max_preds(heatmap_pred)# Postprocess
pred = pred.reshape(pred.shape[0], -1, 2)
print(pred[0])
pre_landmark = pred[0] * 4 * scale_factor
# Draw points
for (x, y) in pre_landmark.astype(np.int32):cv2.circle(origin_img, (x, y), 3, (255, 255, 255), -1)# Save image
cv2.imwrite("./imgs/man_infer.jpg", origin_img)

在这里插入图片描述

可以看到模型基本正确标注出了关键点的位置\

在这里插入图片描述

算法基本流程

  1. 获取原始数据
  2. 从数据集的标注json文件中得到各个图像bbox以及关键点坐标信息
  3. 根据bbox裁剪图像,并放缩至指定尺寸,如果是训练还可以作适当数据增强,生成指定尺寸的目标热力图
  4. 指定尺寸的输入经过网络前向传播后得到预测的关键点热力图
  5. 经过处理后取热力图中的最大值坐标作为关键点的预测坐标

相关文章:

华为开源自研AI框架昇思MindSpore应用案例:人体关键点检测模型Lite-HRNet

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区 一、环境准备 1.进入ModelArts官网 云平台帮助用户快速创建和部署模型,管理全周期AI工作流,选择下面的云平台以开始使用昇思MindSpore,获取安装命令,安装MindSpo…...

每日OJ题_牛客_天使果冻_递推_C++_Java

目录 牛客_天使果冻_递推 题目解析 C代码 Java代码 牛客_天使果冻_递推 天使果冻 描述: 有 n 个果冻排成一排。第 i 个果冻的美味度是 ai。 天使非常喜欢吃果冻,但她想把最好吃的果冻留到最后收藏。天使想知道前 x个果冻中,美味…...

独立站干货:WordPress主机推荐

WordPress作为全球最受欢迎的独立站建设平台,提供了灵活性和强大的功能,使得建站变得简单而高效。本文将为您详细介绍WordPress建站的流程,并推荐几款实测后觉得好用的主机商。 WordPress建站流程 域名注册 首先需要注册一个域名&#xff0c…...

支持多种快充协议和支持多种功能的诱骗取电协议芯片

汇铭达XSP15是一款应用于手持电动工具、智能家居、显示器、音箱等充电方案的大功率快充协议芯片,支持最大功率100W给设备快速充电,大大缩短了充电时间。芯片支持通过UART串口发送电压/电流消息供其它芯片读取。支持自动识别连接的是电脑或是充电器。支持…...

Android中常见内存泄漏的场景和解决方案

本文讲解Android 开发中常见内存泄漏场景及其解决方案,内容包括代码示例、原因分析以及最佳实践建议。 1. 静态变量导致的内存泄漏 静态变量的生命周期与应用进程一致,如果静态变量持有了对 Activity 或其他大对象的引用,就可能导致内存泄漏…...

MyBatis Plus中的@TableId注解

TableId 注解用于将某个成员变量指定为数据表主键,以下为使用示例: import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import lo…...

java基础概念33:常见API-Objects工具类

一、使用场景 二、成员方法 2-1、equals方法 源码: 2-2、isNull方法、nonNull方法 三、小结...

脚手架vue-cli,webpack模板

先安装node.js,它是服务器端,用于给页面提供服务。前端学习不需要会node.js,只需要学会node.js衍生出来的npm命令即可。 npm 是node.js的一个工具,作用是进行包管理,npm是node.js的包管理器。 接着安装脚手架&#xff…...

什么是React Native?

写在前面 React Native (RN) 是一个由 Facebook 开发的开源框架,用于构建跨平台的移动应用程序。它允许开发者使用 JavaScript 和 React 来创建原生 iOS 和 Android 应用。RN 的出现极大地简化了移动应用的开发过程,使得开发者可以更快速、更高效地构建…...

Three.js LOD(Level of Detail)通过根据视距调整渲染细节的技术

在 Three.js 中,LOD(Level of Detail)技术是一种通过根据视距调整渲染细节的技术,旨在提高渲染性能并优化用户体验。LOD 技术尤其在处理复杂场景或高多边形模型时显得尤为重要。在这篇博客中,我们将详细介绍 LOD 的概念…...

Vulnhub靶场案例渗透[12]-Grotesque: 1.0.1

文章目录 一、靶场搭建1. 靶场描述2. 下载靶机环境3. 靶场搭建 二、渗透靶场1. 确定靶机IP2. 探测靶场开放端口及对应服务3. 目录扫描4. 敏感信息获取5. 反弹shell6. 权限提升 一、靶场搭建 1. 靶场描述 get flags difficulty: medium about vm: tested and exported from vi…...

招聘和面试

本篇内容是根据2019年4月份#82 Hiring and job interviews音频录制内容的整理与翻译 小组成员 Mat Ryer、Ashley McNamara、Johnny Boursiquot 和 Carmen Andoh 讨论了受聘、雇用和工作面试的过程。如果人是团队中最重要的部分,我们如何选择与谁一起工作&#xff1…...

Gin 框架入门(GO)-1

解决安装包失败问题(*) go env -w GO111MODULE=on go env -w GOPROXY=https://goproxy.cn,direct 1 介绍 Gin 是一个 Go (Golang) 编写的轻量级 http web 框架,运行速度非常快,Gin 最擅长的就是 Api 接口的高并发。 2 Gin 环境搭建 1.下载并安装 gin go get -u github.…...

LeetCode:700. 二叉搜索树中的搜索

目录 题目描述: 代码: 题目描述: 给定二叉搜索树(BST)的根节点 root 和一个整数值 val。 你需要在 BST 中找到节点值等于 val 的节点。 返回以该节点为根的子树。 如果节点不存在,则返回 null 。 示例 1: 输入:root [4,2,7,1,3…...

用邻接矩阵实现图的深度优先遍历

问题描述 给定一个无向图,用邻接矩阵作为图的存储结构,输出指定顶点出发的深度优先遍历序列。在深度优先遍历的过程中,如果同时出现多个待访问的顶点,则优先选择编号最小的一个进行访问。 输入描述 第一行输入三个正整数&#…...

vue2中实现token的无感刷新

后端配置 设置Token过期时间:在后端(如服务器或网关)配置access_token和refresh_token的过期时间。通常,access_token的过期时间较短,而refresh_token的过期时间较长。提供刷新Token接口:后端需要提供一个…...

无需Photoshop即可在线裁剪和调整图像大小的工具

Bitmind是一个灵活且易于使用的批量图像本地化处理器,经过抓包看,这个工具在浏览器本地运行,不会上传图片到服务器,所以安全性完全有保证。 它可以将图像调整到任何特定尺寸,并在必要时按比例裁剪。 这是一个在线工具…...

云安全之法律和合规

0x00 前言 本文主要内容是从法律,合同,电子举证,以及合规和审计这五个部分来记录一下相关的云安全内容 0x01 法律 受法律约束的影响因素 云服务所在的地区云用户所在的区域数据主体所在的区域 GDPR:通用数据保护法案&#xf…...

倒计时功能分享

今天想要分享的是一个面试题,也是一个我们在项目中常用的功能:倒计时。 首先我们在写倒计时的时候必须要考虑到是:准确性、性能。接下来我们一步一步实现这个完美地倒计时功能。 setInterval 先来简单实现一个倒计时的函数: func…...

【论文分享】使用多源数据识别建筑功能:以中国三大城市群为例

建筑功能对城市规划至关重要,而利用多源数据进行建筑功能分类有助于支持城市规划政策。本研究通过分析建筑特征和POI密度,识别了中国三个城市群的建筑功能,并使用XGBoost模型验证了其在大规模映射中的高准确性和有效性。研究强调了建筑环境对…...

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...

C++_核心编程_多态案例二-制作饮品

#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为&#xff1a;煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例&#xff0c;提供抽象制作饮品基类&#xff0c;提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

多场景 OkHttpClient 管理器 - Android 网络通信解决方案

下面是一个完整的 Android 实现&#xff0c;展示如何创建和管理多个 OkHttpClient 实例&#xff0c;分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...

Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)

概述 在 Swift 开发语言中&#xff0c;各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过&#xff0c;在涉及到多个子类派生于基类进行多态模拟的场景下&#xff0c;…...

解锁数据库简洁之道:FastAPI与SQLModel实战指南

在构建现代Web应用程序时&#xff0c;与数据库的交互无疑是核心环节。虽然传统的数据库操作方式&#xff08;如直接编写SQL语句与psycopg2交互&#xff09;赋予了我们精细的控制权&#xff0c;但在面对日益复杂的业务逻辑和快速迭代的需求时&#xff0c;这种方式的开发效率和可…...

Objective-C常用命名规范总结

【OC】常用命名规范总结 文章目录 【OC】常用命名规范总结1.类名&#xff08;Class Name)2.协议名&#xff08;Protocol Name)3.方法名&#xff08;Method Name)4.属性名&#xff08;Property Name&#xff09;5.局部变量/实例变量&#xff08;Local / Instance Variables&…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

对WWDC 2025 Keynote 内容的预测

借助我们以往对苹果公司发展路径的深入研究经验&#xff0c;以及大语言模型的分析能力&#xff0c;我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际&#xff0c;我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测&#xff0c;聊作存档。等到明…...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类&#xff1a;块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

uniapp中使用aixos 报错

问题&#xff1a; 在uniapp中使用aixos&#xff0c;运行后报如下错误&#xff1a; AxiosError: There is no suitable adapter to dispatch the request since : - adapter xhr is not supported by the environment - adapter http is not available in the build 解决方案&…...