利用MMPose进行姿态估计(训练、测试全流程)
前言
MMPose
是一款基于PyTorch
的姿态分析开源工具箱,是OpenMMLab
项目成员之一,主要特性:- 支持多种人体姿态分析相关任务:2D多人姿态估计、2D手部姿态估计、动物关键点检测等等
- 更高的精度和更快的速度:包括“自顶向下”和“自底向上”两大类算法
- 支持多样的数据集:支持了很多主流数据集的准备和构建,如 COCO、 MPII等
- 模块化设计:将统一的人体姿态分析框架解耦成不同的模块组件,通过组合不同的模块组件,可以便捷地构建自定义人体姿态分析模型
- 本文主要对动物关键点检测模型进行微调与测试,从数据集构造开始,详细解释各模块作用。对一些新手可能会犯的错误做一些说明
- 本文使用的数据集为
kaggle
平台中Cat Dataset
数据集,数据说明,环境为kaggle
平台提供的P100 GPU
,完整的Jupyter Notebook
,放在这里,欢迎大家Copy & Edit
环境配置
mmcv
的安装方式在我前面的mmdetection
和mmsegmentation
教程中都有写到。这里不再提MMPose
安装方法最好是使用git
,如果没有git
工具,可以使用mim install mmpose
- 最后在项目文件夹下新建
checkpoint
、outputs
、data
文件夹,分别用来存放模型预训练权重、模型输出结果、训练数据
from IPython import display
!pip install openmim
!pip install -q /kaggle/input/frozen-packages-mmdetection/mmcv-2.0.1-cp310-cp310-linux_x86_64.whl!git clone https://github.com/open-mmlab/mmdetection.git
%cd mmdetection
!pip install -e .%cd ..
!git clone https://github.com/open-mmlab/mmpose.git
%cd mmpose
!pip install -e .!mkdir checkpoint
!mkdir outputs
!mkdir data
display.clear_output()
- 在上面的安装工作完成后,我们检查一下环境,以及核对一下安装版本
from IPython import display
import mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
print('MMCV版本', mmcv.__version__)
%cd /kaggle/working/mmdetection
import mmdet
print('mmdetection版本', mmdet.__version__)
%cd /kaggle/working/mmpose
import mmpose
print('mmpose版本', mmpose.__version__)
print('CUDA版本', get_compiling_cuda_version())
print('编译器版本', get_compiler_version())
输出:
MMCV版本 2.0.1
/kaggle/working/mmdetection
mmdetection版本 3.1.0
/kaggle/working/mmpose
mmpose版本 1.1.0
CUDA版本 11.8
编译器版本 GCC 11.3
- 为方便后续进行文件操作,导入一些常用库
import os
import io
import json
import shutil
import random
import numpy as np
from pathlib import Pathfrom PIL import Image
from tqdm import tqdm
from mmengine import Configfrom pycocotools.coco import COCO
预训练模型推理
- 在进行姿态估计前需要目标检测将不同的目标检测出来,然后再分别对不同的目标进行姿态估计。所以我们要选择一个目标检测模型。
- 这里选择的是
mmdetection
工具箱中的RTMDet
模型,型号选择RTMDet-l
。配置文件位于mmdetection/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py
,我们复制模型权重地址并进行下载。
- 姿态估计模型选择
RTMPose
模型,打开mmpose
项目文件夹projects/rtmpose/README.md
文档,发现RTMPose
模型动物姿态估计(Animal 2d (17 Keypoints)
)仅提供了一个预训练模型。配置文件位于projects/rtmpose/rtmpose/animal_2d_keypoint/rtmpose-m_8xb64-210e_ap10k-256x256.py
,我们复制模型权重地址并进行下载。
- 将预训练权重模型全部放入
mmpose
项目文件夹的checkpoint
文件夹下。
# 下载RTMDet-L模型,用于目标检测
!wget https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth -P checkpoint
# 下载RTMPose模型,用于姿态估计
!wget https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-ap10k_pt-aic-coco_210e-256x256-7a041aa1_20230206.pth -P checkpoint
display.clear_output()
MMPose
提供了一个被称为MMPoseInferencer
的、全面的推理API
。这个API
使得用户得以使用所有MMPose
支持的模型来对图像和视频进行模型推理。此外,该API
可以完成推理结果自动化,并方便用户保存预测结果。- 我们使用
Cat Dataset
数据集中的一张图片作为示例,进行模型推理。推理参数说明:det_model
:mmdetection
工具箱中目标检测模型配置文件det_weights
:mmdetection
工具箱中目标检测模型对应预训练权重文件pose2d
:mmpose
工具箱中姿态估计模型配置文件pose2d_weights
:mmpose
工具箱中姿态估计对应预训练权重文件out_dir
:图片生成的文件夹
from mmpose.apis import MMPoseInferencerimg_path = '/kaggle/input/cat-dataset/CAT_00/00000001_012.jpg'
# 使用模型别名创建推断器
inferencer = MMPoseInferencer(det_model = '/kaggle/working/mmdetection/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py',det_weights = 'checkpoint/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth',pose2d = 'projects/rtmpose/rtmpose/animal_2d_keypoint/rtmpose-m_8xb64-210e_ap10k-256x256.py',pose2d_weights = 'checkpoint/rtmpose-m_simcc-ap10k_pt-aic-coco_210e-256x256-7a041aa1_20230206.pth')# MMPoseInferencer采用了惰性推断方法,在给定输入时创建一个预测生成器
result_generator = inferencer(img_path, out_dir = 'outputs', show=False)
result = next(result_generator)
display.clear_output()
- 可视化推理结果
import matplotlib.pyplot as pltimg_og = mmcv.imread(img_path)
img_fuse = mmcv.imread('outputs/visualizations/00000001_012.jpg')fig, axes = plt.subplots(1, 2, figsize=(15, 10))
axes[0].imshow(mmcv.bgr2rgb(img_og))
axes[0].set_title('Original Image')
axes[0].axis('off')axes[1].imshow(mmcv.bgr2rgb(img_fuse))
axes[1].set_title('Keypoint Image')
axes[1].axis('off')
plt.show()
数据处理
数据内容详解
- Cat Dataset包含9000多张猫图像。对于每张图像,都有猫头部的注释,有9个点,2个用于眼睛,1个用于嘴巴,6个用于耳朵。
- 注释数据存储在1个文件中,文件名是相应的图像名称,末尾加上“cat”。每张猫图像都有1个注释文件。对于每个注释文件,注释数据按以下顺序存储:
- Number of points (关键点数目)
- Left Eye(左眼)
- Right Eye(右眼)
- Mouth(嘴)
- Left Ear-1(左耳-1)
- Left Ear-2(左耳-2)
- Left Ear-3(左耳-3)
- Right Ear-1(右耳-1)
- Right Ear-2(右耳-2)
- Right Ear-3(左耳-3)
- 数据集最初在互联网档案馆中找到,网站
- 数据层级目录如下所示:
- CAT_00- 00000001_000.jpg- 00000001_000.jpg.cat- 00000001_005.jpg- 00000001_005.jpg.cat- ...- CAT_01- 00000100_002.jpg- 00000100_002.jpg.cat- 00000100_003.jpg- 00000100_003.jpg.cat- CAT_02- CAT_03- CAT_04- CAT_05- CAT_06
- 总的来说,一共有7个文件夹,每个文件夹里面有若干
.jpg
格式的图片文件,且对应有.cat
格式的注释文件,.cat
文件可以看做是文本文件,内容示例:
9 435 322 593 315 524 446 318 285 283 118 430 195 568 186 701 81 703 267
- 除第1个数字9表示有9个关键点,后面每2个点表示1个部位的坐标
(x,y)
,所以一共有1 + 2 * 9 = 19
个点
文件夹规整
- 我们将数据集中的7个文件夹中的图片与注释文件分开,分别存储在
mmpose
项目文件夹data
文件夹中,并分别命名为images
、ann
def separate_files(og_folder, trans_folder):image_folder = os.path.join(trans_folder, 'images')ann_folder = os.path.join(trans_folder, 'ann')os.makedirs(image_folder, exist_ok=True)os.makedirs(ann_folder, exist_ok=True)for folder in os.listdir(data_folder):folder_path = os.path.join(data_folder, folder)if os.path.isdir(folder_path):for file in os.listdir(folder_path):if file.endswith('.jpg'):source_path = os.path.join(folder_path, file)target_path = os.path.join(image_folder, file)shutil.copy(source_path, target_path)elif file.endswith('.cat'):source_path = os.path.join(folder_path, file)target_path = os.path.join(ann_folder, file)shutil.copy(source_path, target_path)data_folder = '/kaggle/input/cat-dataset'
trans_folder = './data'separate_files(data_folder, trans_folder)
构造COCO注释文件
- 本质上来说COCO就是1个字典文件,第1级键包含
images
、annotations
、categories
。- 其中
images
包含id
(图片的唯一标识,必须要是数值型,不能有字符) 、file_name
(图片名字)、height
(图片高度),width
(图片宽度)这些信息 - 其中
annotations
包含category_id
(图片所属种类)、segmentation
(实例分割掩码)、iscrowd
(决定是RLE
格式还是polygon
格式)、image_id
(图片id
,对应images
键中的id
)、id
(注释信息id)、bbox
(目标检测框,[x, y, width, height]
)、area
(目标检测框面积)、num_keypoints
(关键点数量),keypoints
(关键点坐标) - 其中
categories
包含supercategory
、id
(类别id
)、name
(类别名)、keypoints
(各部位名称)、skeleton
(部位连接信息)
- 其中
- 更详细的COCO注释文件解析推荐博客COCO数据集的标注格式、如何将VOC XML文件转化成COCO数据格式
- 构造
read_file_as_list
函数,将注释文件中的坐标变成[x,y,v]
,v
为0时表示这个关键点没有标注,v
为1时表示这个关键点标注了但是不可见(被遮挡了),v
为2时表示这个关键点标注了同时可见。因为数据集中部位坐标均标注且可见,所以在x,y
坐标后均插入2。
def read_file_as_list(file_path):with open(file_path, 'r') as file:content = file.read()key_point = [int(num) for num in content.split()]key_num = key_point[0]key_point.pop(0)for i in range(2, len(key_point) + len(key_point)//2, 2 + 1):key_point.insert(i, 2)return key_num,key_point
- 构造
get_image_size
函数,用于获取图片宽度和高度。
def get_image_size(image_path):with Image.open(image_path) as img:width, height = img.sizereturn width, height
- 因为数据集没有提供目标检测框信息,且图片中基本无干扰元素,所以将目标检测框信息置为
[0, 0, width, height]
即整张图片。相应的目标检测框面积area = width * height
。
# 转换为coco数据格式
def coco_structure(ann_dir,image_dir):coco = dict()coco['images'] = []coco['annotations'] = []coco['categories'] = []coco['categories'].append(dict(supercategory = 'cat',id = 1,name = 'cat',keypoints = ['Left Eye','Right Eye','Mouth','Left Ear-1','Left Ear-2','Left Ear-3','Right Ear-1','Right Ear-2','Right Ear-3'],skeleton = [[0,1],[0,2],[1,2],[3,4],[4,5],[5,6],[6,7],[7,8],[3,8]]))ann_list = os.listdir(ann_dir)id = 0for file_name in tqdm(ann_list):key_num,key_point = read_file_as_list(os.path.join(ann_dir, file_name))if key_num == 9:image_name = os.path.splitext(file_name)[0]image_id = os.path.splitext(image_name)[0]height, width = get_image_size(os.path.join(image_dir, image_name))image = {"id": id, "file_name": image_name, "height": height, "width": width}coco['images'].append(image)key_dict = dict(category_id = 1, segmentation = [], iscrowd = 0, image_id = id, id = id, bbox = [0, 0, width, height], area = width * height, num_keypoints = key_num, keypoints = key_point)coco['annotations'].append(key_dict)id = id + 1return coco
- 写入注释信息,并将其保存为
mmpose
项目文件夹data/annotations_all.json
文件
ann_file = coco_structure('./data/ann','./data/images')
output_file_path = './data/annotations_all.json'
with open(output_file_path, "w", encoding="utf-8") as output_file:json.dump(ann_file, output_file, ensure_ascii=True, indent=4)
拆分训练、测试数据
- 按0.85、0.15的比例将注释文件拆分为训练、测试文件
def split_coco_dataset(coco_json_path: str, save_dir: str, ratios: list,shuffle: bool, seed: int):if not Path(coco_json_path).exists():raise FileNotFoundError(f'Can not not found {coco_json_path}')if not Path(save_dir).exists():Path(save_dir).mkdir(parents=True)ratios = np.array(ratios) / np.array(ratios).sum()if len(ratios) == 2:ratio_train, ratio_test = ratiosratio_val = 0train_type = 'trainval'elif len(ratios) == 3:ratio_train, ratio_val, ratio_test = ratiostrain_type = 'train'else:raise ValueError('ratios must set 2 or 3 group!')coco = COCO(coco_json_path)coco_image_ids = coco.getImgIds()val_image_num = int(len(coco_image_ids) * ratio_val)test_image_num = int(len(coco_image_ids) * ratio_test)train_image_num = len(coco_image_ids) - val_image_num - test_image_numprint('Split info: ====== \n'f'Train ratio = {ratio_train}, number = {train_image_num}\n'f'Val ratio = {ratio_val}, number = {val_image_num}\n'f'Test ratio = {ratio_test}, number = {test_image_num}')seed = int(seed)if seed != -1:print(f'Set the global seed: {seed}')np.random.seed(seed)if shuffle:print('shuffle dataset.')random.shuffle(coco_image_ids)train_image_ids = coco_image_ids[:train_image_num]if val_image_num != 0:val_image_ids = coco_image_ids[train_image_num:train_image_num +val_image_num]else:val_image_ids = Nonetest_image_ids = coco_image_ids[train_image_num + val_image_num:]categories = coco.loadCats(coco.getCatIds())for img_id_list in [train_image_ids, val_image_ids, test_image_ids]:if img_id_list is None:continueimg_dict = {'images': coco.loadImgs(ids=img_id_list),'categories': categories,'annotations': coco.loadAnns(coco.getAnnIds(imgIds=img_id_list))}if img_id_list == train_image_ids:json_file_path = Path(save_dir, f'{train_type}.json')elif img_id_list == val_image_ids:json_file_path = Path(save_dir, 'val.json')elif img_id_list == test_image_ids:json_file_path = Path(save_dir, 'test.json')else:raise ValueError('img_id_list ERROR!')print(f'Saving json to {json_file_path}')with open(json_file_path, 'w') as f_json:json.dump(img_dict, f_json, ensure_ascii=False, indent=2)print('All done!')
split_coco_dataset('./data/annotations_all.json', './data', [0.85,0.15], True, 2023)
输出:
loading annotations into memory...
Done (t=0.13s)
creating index...
index created!
Split info: ======
Train ratio = 0.85, number = 8495
Val ratio = 0, number = 0
Test ratio = 0.15, number = 1498
Set the global seed: 2023
shuffle dataset.
Saving json to data/trainval.json
Saving json to data/test.json
All done!
- 可以看到训练集有8495张图片,测试集有1498张图片
模型配置文件
- 打开项目文件夹下的
projects/rtmpose/rtmpose/animal_2d_keypoint/rtmpose-m_8xb64-210e_ap10k-256x256.py
文件,发现模型配置文件仅继承_base_/default_runtime.py
文件 - 需要修改主要有
dataset_type
、data_mode
、dataset_info
、codec
、train_dataloader
、test_dataloader
、val_evaluator
、base_lr
、max_epochs
、default_hooks
。还有一些细节我在代码中有标注,可以参照着修改 - 修改完成后将文件写入
./configs/animal_2d_keypoint/cat_keypoint.py
中
custom_config = """
_base_ = ['mmpose::_base_/default_runtime.py']# 数据集类型及路径
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = './data/'
work_dir = './work_dir'# cat dataset关键点检测数据集-元数据
dataset_info = {'dataset_name':'Keypoint_cat','classes':'cat','paper_info':{'author':'Luck','title':'Cat Keypoints Detection',},'keypoint_info':{0:{'name':'Left Eye','id':0,'color':[255,0,0],'type': '','swap': ''},1:{'name':'Right Eye','id':1,'color':[255,127,0],'type': '','swap': ''},2:{'name':'Mouth','id':2,'color':[255,255,0],'type': '','swap': ''},3:{'name':'Left Ear-1','id':3,'color':[0,255,0],'type': '','swap': ''},4:{'name':'Left Ear-2','id':4,'color':[0,255,255],'type': '','swap': ''},5:{'name':'Left Ear-3','id':5,'color':[0,0,255],'type': '','swap': ''},6:{'name':'Right Ear-1','id':6,'color':[139,0,255],'type': '','swap': ''},7:{'name':'Right Ear-2','id':7,'color':[255,0,255],'type': '','swap': ''},8:{'name':'Right Ear-3','id':8,'color':[160,82,45],'type': '','swap': ''}},'skeleton_info': {0: {'link':('Left Eye','Right Eye'),'id': 0,'color': [255,0,0]},1: {'link':('Left Eye','Mouth'),'id': 1,'color': [255,0,0]},2: {'link':('Right Eye','Mouth'),'id': 2,'color': [255,0,0]},3: {'link':('Left Ear-1','Left Ear-2'),'id': 3,'color': [255,0,0]},4: {'link':('Left Ear-2','Left Ear-3'),'id': 4,'color': [255,0,0]},5: {'link':('Left Ear-3','Right Ear-1'),'id': 5,'color': [255,0,0]},6: {'link':('Right Ear-1','Right Ear-2'),'id': 6,'color': [255,0,0]},7: {'link':('Right Ear-2','Right Ear-3'),'id': 7,'color': [255,0,0]},8: {'link':('Left Ear-1','Right Ear-3'),'id': 8,'color': [255,0,0]},}
}# 获取关键点个数
NUM_KEYPOINTS = len(dataset_info['keypoint_info'])
dataset_info['joint_weights'] = [1.0] * NUM_KEYPOINTS
dataset_info['sigmas'] = [0.025] * NUM_KEYPOINTS# 训练超参数
max_epochs = 100
val_interval = 5
train_cfg = {'max_epochs': max_epochs, 'val_begin':20, 'val_interval': val_interval}
train_batch_size = 32
val_batch_size = 32
stage2_num_epochs = 10
base_lr = 4e-3 / 16
randomness = dict(seed=2023)# 优化器
optim_wrapper = dict(type='OptimWrapper',optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),paramwise_cfg=dict(norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))# 学习率
param_scheduler = [dict(type='LinearLR', start_factor=1.0e-5, by_epoch=False, begin=0, end=600),dict(type='CosineAnnealingLR',eta_min=base_lr * 0.05,begin=max_epochs // 2,end=max_epochs,T_max=max_epochs // 2,by_epoch=True,convert_to_iter_based=True),
]# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=1024)# codec settings
# input_size可以换成128的倍数
# sigma高斯分布标准差,越大越易学,但进度低。高精度场景,可以调小,RTMPose 原始论文中为 5.66
# input_size、sigma和下面model中的in_featuremap_size参数需要成比例缩放
codec = dict(type='SimCCLabel',input_size=(512, 512),sigma=(24, 24),simcc_split_ratio=2.0,normalize=False,use_dark=False)# 模型:RTMPose-M
model = dict(type='TopdownPoseEstimator',data_preprocessor=dict(type='PoseDataPreprocessor',mean=[123.675, 116.28, 103.53],std=[58.395, 57.12, 57.375],bgr_to_rgb=True),backbone=dict(_scope_='mmdet',type='CSPNeXt',arch='P5',expand_ratio=0.5,deepen_factor=0.67,widen_factor=0.75,out_indices=(4, ),channel_attention=True,norm_cfg=dict(type='SyncBN'),act_cfg=dict(type='SiLU'),init_cfg=dict(type='Pretrained',prefix='backbone.',checkpoint='https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth')),head=dict(type='RTMCCHead',in_channels=768,out_channels=NUM_KEYPOINTS,input_size=codec['input_size'],in_featuremap_size=(16, 16),simcc_split_ratio=codec['simcc_split_ratio'],final_layer_kernel_size=7,gau_cfg=dict(hidden_dims=256,s=128,expansion_factor=2,dropout_rate=0.,drop_path=0.,act_fn='SiLU',use_rel_bias=False,pos_enc=False),loss=dict(type='KLDiscretLoss',use_target_weight=True,beta=10.,label_softmax=True),decoder=codec),test_cfg=dict(flip_test=True))backend_args = dict(backend='local')# pipelines
train_pipeline = [dict(type='LoadImage', backend_args=backend_args),dict(type='GetBBoxCenterScale'),dict(type='RandomFlip', direction='horizontal'),# dict(type='RandomHalfBody'),dict(type='RandomBBoxTransform', scale_factor=[0.8, 1.2], rotate_factor=30),dict(type='TopdownAffine', input_size=codec['input_size']),dict(type='mmdet.YOLOXHSVRandomAug'),dict(type='Albumentation',transforms=[dict(type='ChannelShuffle', p=0.5),dict(type='CLAHE', p=0.5),# dict(type='Downscale', scale_min=0.7, scale_max=0.9, p=0.2),dict(type='ColorJitter', p=0.5),dict(type='CoarseDropout',max_holes=4,max_height=0.3,max_width=0.3,min_holes=1,min_height=0.2,min_width=0.2,p=0.5),]),dict(type='GenerateTarget', encoder=codec),dict(type='PackPoseInputs')
]val_pipeline = [dict(type='LoadImage', backend_args=backend_args),dict(type='GetBBoxCenterScale'),dict(type='TopdownAffine', input_size=codec['input_size']),dict(type='PackPoseInputs')
]train_pipeline_stage2 = [dict(type='LoadImage', backend_args=backend_args),dict(type='GetBBoxCenterScale'),dict(type='RandomFlip', direction='horizontal'),dict(type='RandomHalfBody'),dict(type='RandomBBoxTransform',shift_factor=0.,scale_factor=[0.75, 1.25],rotate_factor=60),dict(type='TopdownAffine', input_size=codec['input_size']),dict(type='mmdet.YOLOXHSVRandomAug'),dict(type='Albumentation',transforms=[dict(type='Blur', p=0.1),dict(type='MedianBlur', p=0.1),dict(type='CoarseDropout',max_holes=1,max_height=0.4,max_width=0.4,min_holes=1,min_height=0.2,min_width=0.2,p=0.5),]),dict(type='GenerateTarget', encoder=codec),dict(type='PackPoseInputs')
]# data loaders
train_dataloader = dict(batch_size=train_batch_size,num_workers=2,persistent_workers=True,sampler=dict(type='DefaultSampler', shuffle=True),dataset=dict(type=dataset_type,data_root=data_root,metainfo=dataset_info,data_mode=data_mode,ann_file='trainval.json',data_prefix=dict(img='images/'),pipeline=train_pipeline,))
val_dataloader = dict(batch_size=val_batch_size,num_workers=2,persistent_workers=True,drop_last=False,sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),dataset=dict(type=dataset_type,data_root=data_root,metainfo=dataset_info,data_mode=data_mode,ann_file='test.json',data_prefix=dict(img='images/'),pipeline=val_pipeline,))
test_dataloader = val_dataloaderdefault_hooks = {'checkpoint': {'save_best': 'PCK','rule': 'greater','max_keep_ckpts': 2},'logger': {'interval': 50}
}custom_hooks = [dict(type='EMAHook',ema_type='ExpMomentumEMA',momentum=0.0002,update_buffers=True,priority=49),dict(type='mmdet.PipelineSwitchHook',switch_epoch=max_epochs - stage2_num_epochs,switch_pipeline=train_pipeline_stage2)
]# evaluators
val_evaluator = [dict(type='CocoMetric', ann_file=data_root + 'test.json'),dict(type='PCKAccuracy'),dict(type='AUC'),dict(type='NME', norm_mode='keypoint_distance', keypoint_indices=[0, 1])
]test_evaluator = val_evaluator
"""
config = './configs/animal_2d_keypoint/cat_keypoint.py'
with io.open(config, 'w', encoding='utf-8') as f:f.write(custom_config)
模型训练
- 使用训练脚本启动训练
!python tools/train.py {config}
- 因为训练输出太长,这里截取一段模型在测试集上最佳精度:
08/06 19:15:56 - mmengine - INFO - Evaluating CocoMetric...
Loading and preparing results...
DONE (t=0.07s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *keypoints*
DONE (t=0.57s).
Accumulating evaluation results...
DONE (t=0.03s).Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets= 20 ] = 0.943Average Precision (AP) @[ IoU=0.50 | area= all | maxDets= 20 ] = 0.979Average Precision (AP) @[ IoU=0.75 | area= all | maxDets= 20 ] = 0.969Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = -1.000Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.944Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 20 ] = 0.953Average Recall (AR) @[ IoU=0.50 | area= all | maxDets= 20 ] = 0.987Average Recall (AR) @[ IoU=0.75 | area= all | maxDets= 20 ] = 0.977Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = -1.000Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.953
08/06 19:15:56 - mmengine - INFO - Evaluating PCKAccuracy (normalized by ``"bbox_size"``)...
08/06 19:15:56 - mmengine - INFO - Evaluating AUC...
08/06 19:15:56 - mmengine - INFO - Evaluating NME...
08/06 19:15:57 - mmengine - INFO - Epoch(val) [60][47/47] coco/AP: 0.943453 coco/AP .5: 0.979424 coco/AP .75: 0.969202 coco/AP (M): -1.000000 coco/AP (L): 0.944082 coco/AR: 0.953471 coco/AR .5: 0.987316 coco/AR .75: 0.977303 coco/AR (M): -1.000000 coco/AR (L): 0.953471 PCK: 0.978045 AUC: 0.801710 NME: 0.121770 data_time: 0.101005 time: 0.435133
08/06 19:15:57 - mmengine - INFO - The previous best checkpoint /kaggle/working/mmpose/work_dir/best_PCK_epoch_55.pth is removed
08/06 19:16:01 - mmengine - INFO - The best checkpoint with 0.9780 PCK at 60 epoch is saved to best_PCK_epoch_60.pth.
- 可以看到模型PCK达到了0.978,AUC达到了0.8017,mAP也都挺高,说明模型效果非常不错!
模型精简
mmpose
提供模型精简脚本,模型训练权重文件大小减少一半,但不影响精度和推理- 将在验证集上表现最好的模型权重进行精简
import glob
ckpt_path = glob.glob('./work_dir/best_PCK_*.pth')[0]
ckpt_sim = './work_dir/cat_pose_sim.pth'
# 模型精简
!python tools/misc/publish_model.py \{ckpt_path} \{ckpt_sim}
模型推理
- 这里和上面的模型推理使用相同的思路,使用
RTMDet
模型进行目标检测,使用我们自己训练的RTMPose
模型进行姿态估计。不过pose2d
参数是我们上面保存的配置文件./configs/animal_2d_keypoint/cat_keypoint.py
,pose2d_weights
为最佳精度模型精简后的权重文件glob.glob('./work_dir/cat_pose_sim*.pth')[0]
。
img_path = '/kaggle/input/cat-dataset/CAT_00/00000001_012.jpg'inferencer = MMPoseInferencer(det_model = '/kaggle/working/mmdetection/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py',det_weights = 'checkpoint/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth',pose2d = './configs/animal_2d_keypoint/cat_keypoint.py',pose2d_weights = glob.glob('./work_dir/cat_pose_sim*.pth')[0])result_generator = inferencer(img_path, out_dir = 'outputs', show=False)
result = next(result_generator)
display.clear_output()
- 可视化训练结果
img_og = mmcv.imread(img_path)
img_fuse = mmcv.imread('outputs/visualizations/00000001_012.jpg')fig, axes = plt.subplots(1, 2, figsize=(15, 10))
axes[0].imshow(mmcv.bgr2rgb(img_og))
axes[0].set_title('Original Image')
axes[0].axis('off')axes[1].imshow(mmcv.bgr2rgb(img_fuse))
axes[1].set_title('Keypoint Image')
axes[1].axis('off')
plt.show()
相关文章:

利用MMPose进行姿态估计(训练、测试全流程)
前言 MMPose是一款基于PyTorch的姿态分析开源工具箱,是OpenMMLab项目成员之一,主要特性: 支持多种人体姿态分析相关任务:2D多人姿态估计、2D手部姿态估计、动物关键点检测等等更高的精度和更快的速度:包括“自顶向下”…...

ROS2 编译含有自定义消息项目报错:msg/detail/header__struct.h: 没有那个文件或目录
项目场景: 当迁移ROS 1 项目到 ROS 2 时,有时候会遇到消息类型的变化和更新,消息类型可能需要进行一些调整以适应新的ROS 2要求。本文将介绍如何处理自定义消息中的Header字段,以确保项目能够顺利地适应ROS 2的消息类型定义。 问…...

线段树思想拆解(下篇)
线段树思想拆解(下篇) 上篇回顾 到这里我们已经处理好了初始化以及添加方法,接下来实现范围的 query 方法 public int query(int queryL, int queryR) {return query(queryL, queryR, 1, orgLength - 1, 1);}到此为止通过借助 sum 数组&…...

Containerd容器镜像管理
1. 轻量级容器管理工具 Containerd 2. Containerd的两种安装方式 3. Containerd容器镜像管理 4. Containerd数据持久化和网络管理 1、Containerd镜像管理 1.1 Containerd容器镜像管理命令 docker使用docker images命令管理镜像单机containerd使用ctr images命令管理镜像,con…...

Baumer工业相机堡盟工业相机如何通过BGAPI SDK获取相机当前数据吞吐量(C#)
Baumer工业相机堡盟工业相机如何通过BGAPISDK里函数来获取相机当前数据吞吐量(C#) Baumer工业相机Baumer工业相机的数据吞吐量的技术背景CameraExplorer如何查看相机吞吐量信息在BGAPI SDK里通过函数获取相机接口吞吐量 Baumer工业相机通过BGAPI SDK获取…...

Ubuntu服务器版配置wifi
最近把曾经不用的上网本安装了一个Ubuntu-Server版,当成服务器来用,因为家庭网络布线问题,只好用自带的WIFI来连接网络,Server版也没有什么图形化的管理工具,之后手动编辑配置文件了。 Server下面配置起来还是很方便的…...

Windows 主机的VMware 虚拟机访问 wsl-ubuntu 的 API 服务
Windows 主机的VMware 虚拟机访问 wsl-ubuntu 的 API 服务 0. 背景1. 设置2. 删除 0. 背景 需要从Windows 主机的VMware 虚拟机访问 wsl-ubuntu 的 API 服务。 1. 设置 Windows 主机的IP:192.168.31.20 wsl-ubuntu Ubuntu-22.04 的IP:172.29.211.52 &…...

【Spring】(一)Spring设计核心思想
文章目录 一、初识 Spring1.1 什么是 Spring1.2 什么是 容器1.3 什么是 IoC 二、对 IoC 的深入理解2.1 传统程序开发方式存在的问题2.2 控制反转式程序的开发2.3 对比总结 三、对 Spring IoC 的理解四、DI 的概念4.1 什么是 DI4.2 DI 与 IoC的关系 一、初识 Spring 1.1 什么是…...

chrome插件开发实例04-智能收藏夹
目录 功能说明 演示 源码下载 源代码 manifest.json popup.html popup.js background.js 功能说明 基于chrome插件...

iOS技术之 手机系统15.0之后 的 UITableView section header多22像素问题
iOS 15 的 UITableView又新增了一个新属性:sectionHeaderTopPadding 会给每一个section header 增加一个默认高度,当我们 使用 UITableViewStylePlain 初始化 UITableView的时候,就会发现,系统给section header增高了22像素。 解…...

Windows下安装Kafka(图文记录详细步骤)
Windows下安装Kafka Kafka简介一、Kafka安装前提安装Kafka之前,需要安装JDK、Zookeeper、Scala。1.1、JDK安装(version:1.8)1.1.1、JDK官网下载1.1.2、JDK网盘下载1.1.3、JDK安装 1.2、Zookeeper安装1.2.1、Zookeeper官网下载1.2.…...

linuxARM裸机学习笔记(3)----主频和时钟配置实验
引言:本文主要学习当前linux该如何去配置时钟频率,这也是重中之重。 系统时钟来源: 32.768KHz 晶振是 I.MX6U 的 RTC 时钟源, 24MHz 晶振是 I.MX6U 内核 和其它外设的时钟源 1. 7路PLL时钟源【都是从24MHZ的晶振PLL而来…...

防勒索病毒
随着勒索软件攻击在2023年的激增,网络安全已成为当今最重要的议题之一。根据区块链分析公司Chainaanalysis的最新报告,勒索软件攻击已成为唯一呈增长趋势的基于加密货币的犯罪行为,勒索金额更是比一年前增加了近1.758亿美元,达到4…...

剑指 Offer 53 - II. 0~n-1 中缺失的数字
力扣 一个长度为n-1的递增排序数组中的所有数字都是唯一的,并且每个数字都在范围0~n-1之内。在范围0~n-1内的n个数字中有且只有一个数字不在该数组中,请找出这个数字。 示例 1: 输入: [0,1,3] 输出: 2 示例 2: 输入: [0,1,2,3,4,5…...

vue2和vue3区别
vue2和vue3的区别有以下8点: 1、双向数据绑定原理不同; 2、是否支持碎片; 3、API类型不同; 4、定义数据变量和方法不同; 5、生命周期钩子函数不同; 6、父子传参不同; 7、指令与插槽不同&#x…...

IMV3.0
经历了两个版本,基础内容在前面,可以使用之前的基础环境: v1: https://blog.csdn.net/wtt234/article/details/132139454 v2: https://blog.csdn.net/wtt234/article/details/132144907 一、代码组织结构 二、代码 2.…...

怎么在树莓派环境上搭建web网站,并发布到外网可访问,今天教给大家
怎么在树莓派上搭建web网站,并发布到外网可访问? 文章目录 怎么在树莓派上搭建web网站,并发布到外网可访问?概述使用 Raspberry Pi Imager 安装 Raspberry Pi OS测试 web 站点安装静态样例站点 将web站点发布到公网安装 Cpolarcpo…...

大文件传输软件| 生命科学中的关键因素
在2023年,生命科学领域以及其先进的科学技术吸引了人们的目光。这些研究背后,很少有人知道的是,其中涉及了大量的研究数据需要实时进行文件传输,以便于研究,合作,分享,分析,临床试验…...

varint编码实现原理
简言 1. varint即 variable int,也就是变长整型,在mysql,levelDB,protobuf中都有使用 2. varint编码的优点是对数值较小的数进行编码后占用字节较少,比如[0-127]只占用1个字节,[128~16383]只占用2个字节。…...

如果新电脑是刚安装的mysql,但是旧电脑迁移过来的文件里面有相关的rails文件,运行rake db:migrate一直报错
$ bundle exec rake db:migrate#运行完命令报错 rake aborted! LoadError: libmysqlclient.so.21: cannot open shared object file: No such file or directory - /home/meiyi/.asdf/installs/ruby/2.6.9/lib/ruby/gems/2.6.0/gems/mysql2-0.5.5/lib/mysql2/mysql2.so /home/m…...

ChatGPT已闯入学术界,Elsevier推出AI工具
2022年11月,OpenAI公司发布了ChatGPT,这是迄今为止人工智能在现实世界中最重要的应用之一。 当前,互联网搜索引擎中出现了越来越多的人工智能(AI)聊天机器人,例如谷歌的Bard和微软的Bing,看起来…...

深度学习论文: RepViT: Revisiting Mobile CNN From ViT Perspective及其PyTorch实现
深度学习论文: RepViT: Revisiting Mobile CNN From ViT Perspective及其PyTorch实现 RepViT: Revisiting Mobile CNN From ViT Perspective PDF: https://arxiv.org/pdf/2307.09283.pdf PyTorch代码: https://github.com/shanglianlm0525/CvPytorch PyTorch代码: https://gith…...

R语言3_安装SeurateData
环境Ubuntu22/20, R4.1 在命令行中键入, apt-get update apt install libcurl4-openssl-dev libssl-dev libxml2-dev libcairo2-dev libgtk-3-dev # libcairo2-dev :: systemfonts # libgtk :: textshaping进入r语言交互环境,键入, instal…...

详解Gillespie算法:使用Python构建分子化学模拟及其在随机多智能体动力学中的应用
第一部分:Gillespie算法简介 Gillespie算法是一种利用蒙特卡洛抽样模拟化学体系随机动力学行为的方法[3]。它是由Joseph L. Doob提出的,用于生成具有已知反应速率的随机方程组的统计上正确的轨迹(可能的解)[5]。在本文中,我们将详细介绍Gillespie算法的原理,并使用Pytho…...

Unity数字可视化学校_昼夜(三)
1、删除不需要的 UI using System.Collections; using System.Collections.Generic; using UnityEngine; using UnityEngine.UI;public class EnvControl : MonoBehaviour {//UIprivate Button btnTime;private Text txtTime; //材质public List<Material> matListnew Li…...

使用罗技鼠标后 弹出当前页面的脚本发生错误AppData/Local/Temp/LogiUI/Pak/js/jquery-1.3.2.min.js解决
使用的台式机,没有蓝牙驱动,在用logi无线鼠标时,把鼠标连接插入台式机后弹出的如上图所示这个提示,无论是点是/否,还是X掉上图提示,电脑右下角的图依然存在。不习惯这丫的存在。 我重启还是有,然…...

Kubernetes(K8s)从入门到精通系列之十四:安装工具
Kubernetes K8s从入门到精通系列之十四:安装工具 一、kubectl二、kind三、minikube四、kubeadm 一、kubectl Kubernetes 命令行工具 kubectl, 让你可以对 Kubernetes 集群运行命令。 你可以使用 kubectl 来部署应用、监测和管理集群资源以及查看日志。 …...

【Python】Python元组学习
Python之元组学习记录 一、元组的特点 可以容纳多个数据可以容纳不同类型的数据(混装)数据是有序存储的(下标索引)允许重复数据存在不可以修改(增加或删除元素等)但内部list元素可以被修改支持while&…...

HTML 元素的属性有哪些?
聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ class⭐ id⭐ style⭐ src⭐ href⭐ alt⭐ width和height⭐ disabled⭐ value⭐ required⭐ placeholder⭐ checked⭐ selected⭐ target⭐ colspan和rowspan⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 记得…...

Spring之事务实现方式及原理
目录 Spring事务简介 Spring支持事务管理的两种方式 编程式事务控制 声明式事务管理 Spring事务角色 未开启事务之前 开启Spring的事务管理后 事务配置 事务传播行为 事务传播行为的可选值 Spring事务简介 事务作用:在数据层保障一系列的数据库操作同成功…...