当前位置: 首页 > news >正文

momask-codes 部署踩坑笔记

目录

依赖项

t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns

推理代码完善:


代码地址:

https://github.com/EricGuo5513/momask-codes

依赖项

pip install numpy==1.23

matplotlib 必须指定版本:pip install matplotlib==3.3.4

t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns

下载模型:

cd t2m 
echo -e "Downloading pretrained models for HumanML3D dataset"
gdown --fuzzy https://drive.google.com/file/d/1vXS7SHJBgWPt59wupQ5UUzhFObrnGkQ0/view?usp=sharing

echo -e "Unzipping humanml3d_models.zip"
unzip humanml3d_models.zip

推理代码完善:

text_motion/momask-codes/gen_t2m.py

# coding=utf-8
import sys
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
current_dir = os.path.dirname(os.path.abspath(__file__))paths = [os.path.abspath(__file__).split('scripts')[0]]
print('current_dir',current_dir)
paths.append(os.path.abspath(os.path.join(current_dir, 'src')))for path in paths:sys.path.insert(0, path)os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')import os
from os.path import join as pjoinimport torch
import torch.nn.functional as Ffrom models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
from models.vq.model import RVQVAE, LengthEstimator# from options.eval_option import EvalT2MOptionsfrom options.base_option import BaseOptionsclass EvalT2MOptions(BaseOptions):def initialize(self):BaseOptions.initialize(self)self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint you want to use, {latest, net_best_fid, etc}')self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size')self.parser.add_argument('--ext', type=str, default='text2motion', help='Extension of the result file or folder')self.parser.add_argument("--num_batch", default=2, type=int,help="Number of batch for generation")self.parser.add_argument("--repeat_times", default=1, type=int,help="Number of repetitions, per sample text prompt")self.parser.add_argument("--cond_scale", default=4, type=float,help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")self.parser.add_argument("--temperature", default=1., type=float,help="Sampling Temperature.")self.parser.add_argument("--topkr", default=0.9, type=float,help="Filter out percentil low prop entries.")self.parser.add_argument("--time_steps", default=18, type=int,help="Mask Generate steps.")self.parser.add_argument("--seed", default=10107, type=int)self.parser.add_argument('--gumbel_sample', action="store_true", help='True: gumbel sampling, False: categorical sampling.')self.parser.add_argument('--use_res_model', action="store_true", help='Whether to use residual transformer.')# self.parser.add_argument('--est_length', action="store_true", help='Training iterations')self.parser.add_argument('--res_name', type=str, default='tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw', help='Model name of residual transformer')self.parser.add_argument('--text_path', type=str, default="", help='Text prompt file')self.parser.add_argument('-msec', '--mask_edit_section', nargs='*', type=str, help='Indicate sections for editing, use comma to separate the start and end of a section''type int will specify the token frame, type float will specify the ratio of seq_len')self.parser.add_argument('--text_prompt', default='A person is running on a treadmill.', type=str, help="A text prompt to be generated. If empty, will take text prompts from dataset.")self.parser.add_argument('--source_motion', default='example_data/000612.npy', type=str, help="Source motion path for editing. (new_joint_vecs format .npy file)")self.parser.add_argument("--motion_length", default=0, type=int,help="Motion length for generation, only applicable with single text prompt.")self.is_train = Falsefrom utils.get_opt import get_optfrom utils.fixseed import fixseed
from visualization.joints2bvh import Joint2BVHConvertor
from torch.distributions.categorical import Categoricalfrom utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motionfrom utils.paramUtil import t2m_kinematic_chainimport numpy as np
clip_version = 'ViT-B/32'def load_vq_model(vq_opt):# opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')vq_model = RVQVAE(vq_opt,vq_opt.dim_pose,vq_opt.nb_code,vq_opt.code_dim,vq_opt.output_emb_width,vq_opt.down_t,vq_opt.stride_t,vq_opt.width,vq_opt.depth,vq_opt.dilation_growth_rate,vq_opt.vq_act,vq_opt.vq_norm)ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),map_location='cpu')model_key = 'vq_model' if 'vq_model' in ckpt else 'net'vq_model.load_state_dict(ckpt[model_key])print(f'Loading VQ Model {vq_opt.name} Completed!')return vq_model, vq_optdef load_trans_model(model_opt, opt, which_model):t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,cond_mode='text',latent_dim=model_opt.latent_dim,ff_size=model_opt.ff_size,num_layers=model_opt.n_layers,num_heads=model_opt.n_heads,dropout=model_opt.dropout,clip_dim=512,cond_drop_prob=model_opt.cond_drop_prob,clip_version=clip_version,opt=model_opt)ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model),map_location='cpu')model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'# print(ckpt.keys())missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False)assert len(unexpected_keys) == 0assert all([k.startswith('clip_model.') for k in missing_keys])print(f'Loading Transformer {opt.name} from epoch {ckpt["ep"]}!')return t2m_transformerdef load_res_model(res_opt, vq_opt, opt):res_opt.num_quantizers = vq_opt.num_quantizersres_opt.num_tokens = vq_opt.nb_coderes_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,cond_mode='text',latent_dim=res_opt.latent_dim,ff_size=res_opt.ff_size,num_layers=res_opt.n_layers,num_heads=res_opt.n_heads,dropout=res_opt.dropout,clip_dim=512,shared_codebook=vq_opt.shared_codebook,cond_drop_prob=res_opt.cond_drop_prob,# codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,share_weight=res_opt.share_weight,clip_version=clip_version,opt=res_opt)ckpt = torch.load(pjoin(res_opt.checkpoints_dir, res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'),map_location=opt.device)missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False)assert len(unexpected_keys) == 0assert all([k.startswith('clip_model.') for k in missing_keys])print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!')return res_transformerdef load_len_estimator(opt):model = LengthEstimator(512, 50)ckpt = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_estimator', 'model', 'finest.tar'),map_location=opt.device)model.load_state_dict(ckpt['estimator'])print(f'Loading Length Estimator from epoch {ckpt["epoch"]}!')return modelif __name__ == '__main__':parser = EvalT2MOptions()opt = parser.parse()fixseed(opt.seed)opt.device = torch.device("cuda:1")torch.autograd.set_detect_anomaly(True)dim_pose = 251 if opt.dataset_name == 'kit' else 263# out_dir = pjoin(opt.check)root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)model_dir = pjoin(root_dir, 'model')result_dir = pjoin('./generation', opt.ext)joints_dir = pjoin(result_dir, 'joints')animation_dir = pjoin(result_dir, 'animations')os.makedirs(joints_dir, exist_ok=True)os.makedirs(animation_dir,exist_ok=True)model_opt_path = pjoin(root_dir, 'opt.txt')model_opt = get_opt(model_opt_path, device=opt.device)#############################Loading RVQ#############################vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')vq_opt = get_opt(vq_opt_path, device=opt.device)vq_opt.dim_pose = dim_posevq_model, vq_opt = load_vq_model(vq_opt)model_opt.num_tokens = vq_opt.nb_codemodel_opt.num_quantizers = vq_opt.num_quantizersmodel_opt.code_dim = vq_opt.code_dim#######################################Loading R-Transformer#######################################res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')res_opt = get_opt(res_opt_path, device=opt.device)res_model = load_res_model(res_opt, vq_opt, opt)assert res_opt.vq_name == model_opt.vq_name#######################################Loading M-Transformer#######################################t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')#######################################Loading Length Predictor#######################################length_estimator = load_len_estimator(model_opt)t2m_transformer.eval()vq_model.eval()res_model.eval()length_estimator.eval()res_model.to(opt.device)t2m_transformer.to(opt.device)vq_model.to(opt.device)length_estimator.to(opt.device)##### ---- Dataloader ---- #####opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))def inv_transform(data):return data * std + meanprompt_list = []length_list = []est_length = Falseif opt.text_prompt != "":prompt_list.append(opt.text_prompt)if opt.motion_length == 0:est_length = Trueelse:length_list.append(opt.motion_length)elif opt.text_path != "":with open(opt.text_path, 'r') as f:lines = f.readlines()for line in lines:infos = line.split('#')prompt_list.append(infos[0])if len(infos) == 1 or (not infos[1].isdigit()):est_length = Truelength_list = []else:length_list.append(int(infos[-1]))else:raise "A text prompt, or a file a text prompts are required!!!"# print('loading checkpoint {}'.format(file))if est_length:print("Since no motion length are specified, we will use estimated motion lengthes!!")text_embedding = t2m_transformer.encode_text(prompt_list)pred_dis = length_estimator(text_embedding)probs = F.softmax(pred_dis, dim=-1)  # (b, ntoken)token_lens = Categorical(probs).sample()  # (b, seqlen)# lengths = torch.multinomial()else:token_lens = torch.LongTensor(length_list) // 4token_lens = token_lens.to(opt.device).long()m_length = token_lens * 4captions = prompt_listsample = 0kinematic_chain = t2m_kinematic_chainconverter = Joint2BVHConvertor()for r in range(opt.repeat_times):print("-->Repeat %d"%r)with torch.no_grad():mids = t2m_transformer.generate(captions, token_lens,timesteps=opt.time_steps,cond_scale=opt.cond_scale,temperature=opt.temperature,topk_filter_thres=opt.topkr,gsample=opt.gumbel_sample)# print(mids)# print(mids.shape)mids = res_model.generate(mids, captions, token_lens, temperature=1, cond_scale=5)pred_motions = vq_model.forward_decoder(mids)pred_motions = pred_motions.detach().cpu().numpy()data = inv_transform(pred_motions)for k, (caption, joint_data)  in enumerate(zip(captions, data)):print("---->Sample %d: %s %d"%(k, caption, m_length[k]))animation_path = pjoin(animation_dir, str(k))joint_path = pjoin(joints_dir, str(k))print('save_dir',animation_path)os.makedirs(animation_path, exist_ok=True)os.makedirs(joint_path, exist_ok=True)joint_data = joint_data[:m_length[k]]joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh"%(k, r, m_length[k]))_, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100)bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k]))_, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4"%(k, r, m_length[k]))ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4"%(k, r, m_length[k]))plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=caption, fps=20)plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy"%(k, r, m_length[k])), joint)np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy"%(k, r, m_length[k])), ik_joint)

相关文章:

momask-codes 部署踩坑笔记

目录 依赖项 t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns 推理代码完善: 代码地址: https://github.com/EricGuo5513/momask-codes 依赖项 pip install numpy1.23 matplotlib 必须指定版本:pip install matplotlib3.3.4 t2m_nlayer…...

H3CNE-31-BFD

Bidirectional Forwarding Dection,双向转发检查 作用:毫秒级故障检查,通常结合三层协议(静态路由、vrrp、ospf、BGP等),实现链路故障快速检查。 BFD配置示例 没有中间的SW,接口down&#xff…...

蓝桥备赛指南(5)

queue队列 queue是一种先进先出的数据结构。它提供了一组函数来操作和访问元素,但它的功能相对较简单,queue函数的内部实现了底层容器来存储元素,并且只能通过特定的函数来访问和操作元素。 queue函数的常用函数 1.push()函数:…...

讯飞智作 AI 配音技术浅析(一)

一、核心技术 讯飞智作 AI 配音技术作为科大讯飞在人工智能领域的重要成果,融合了多项前沿技术,为用户提供了高质量的语音合成服务。其核心技术主要涵盖以下几个方面: 1. 深度学习与神经网络 讯飞智作 AI 配音技术以深度学习为核心驱动力&…...

MySQL(高级特性篇) 14 章——MySQL事务日志

事务有4种特性:原子性、一致性、隔离性和持久性 事务的隔离性由锁机制实现事务的原子性、一致性和持久性由事务的redo日志和undo日志来保证(1)REDO LOG称为重做日志,用来保证事务的持久性(2)UNDO LOG称为回…...

openRv1126 AI算法部署实战之——TensorFlow TFLite Pytorch ONNX等模型转换实战

Conda简介 查看当前系统的环境列表 conda env list base为基础环境 py3.6-rknn-1.7.3为模型转换环境,rknn-toolkit版本V1.7.3,python版本3.6 py3.6-tensorflow-2.5.0为tensorflow模型训练环境,tensorflow版本2.5.0,python版本…...

【Redis】常见面试题

什么是Redis? Redis 和 Memcached 有什么区别? 为什么用 Redis 作为 MySQL 的缓存? 主要是因为Redis具备高性能和高并发两种特性。 高性能:MySQL中数据是从磁盘读取的,而Redis是直接操作内存,速度相当快…...

每日 Java 面试题分享【第 17 天】

欢迎来到每日 Java 面试题分享栏目! 订阅专栏,不错过每一天的练习 今日分享 3 道面试题目! 评论区复述一遍印象更深刻噢~ 目录 问题一:Java 中的访问修饰符有哪些?问题二:Java 中静态方法和实例方法的区…...

「全网最细 + 实战源码案例」设计模式——桥接模式

核心思想 桥接模式(Bridge Pattern)是一种结构型设计模式,将抽象部分与其实现部分分离,使它们可以独立变化。降低代码耦合度,避免类爆炸,提高代码的可扩展性。 结构 1. Implementation(实现类…...

JavaScript 进阶(上)

作用域 局部作用域 局部作用域分为函数作用域和块作用域。 函数作用域: 在函数内部声明的变量只能在函数内部被访问,外部无法直接访问。 总结: 函数内部声明的变量,在函数外部无法被访问 函数的参数也是函数内部的局部变量 …...

【编译原理实验二】——自动机实验:NFA转DFA并最小化

本篇适用于ZZU的编译原理课程实验二——自动机实验:NFA转DFA并最小化,包含了实验代码和实验报告的内容,读者可根据需要参考完成自己的程序设计。 如果是ZZU的学弟学妹看到这篇,那么恭喜你,你来对地方啦! 如…...

深入探讨:服务器如何响应前端请求及后端如何查看前端提交的数据

深入探讨:服务器如何响应前端请求及后端如何查看前端提交的数据 一、服务器如何响应前端请求 前端与后端的交互主要通过 HTTP 协议实现。以下是详细步骤: 1. 前端发起 HTTP 请求 GET 请求:用于从服务器获取数据。POST 请求:用…...

如何利用Docker和.NET Core实现环境一致性、简化依赖管理、快速部署与扩展,同时提高资源利用率、确保安全性和生态系统支持

目录 1. 环境一致性 2. 简化依赖管理 3. 快速部署与扩展 4. 提高资源利用率 5. 确保安全性 6. 生态系统支持 总结 使用 Docker 和 .NET Core 结合,可以有效地实现环境一致性、简化依赖管理、快速部署与扩展,同时提高资源利用率、确保安全性和生态…...

@Inject @Qualifier @Named

Inject Qualifier Named 在依赖注入(DI)中,Inject、Qualifier 和 Named 是用于管理对象创建和绑定的关键注解。以下是它们的用途、依赖配置和代码示例的详细说明: 1. 注解的作用 Inject:标记需要注入的构造函数、字段…...

创建 priority_queue - 进阶(内置类型)c++

内置类型就是 C 提供的数据类型,⽐如 int 、 double 、 long long 等。以 int 类型为例,分 别创建⼤根堆和⼩根堆。 这种写法意思是,我要告诉这个优先级队列要建一个什么样的堆,第一个int是要存什么数据类型,vecto…...

2. Java-MarkDown文件解析-工具类

2. Java-MarkDown文件解析-工具类 1. 思路 读取markdown文件的内容&#xff0c;根据markdown的语法进行各个类型语法的解析。引入工具类 commonmark 和 commonmark-ext-gfm-tables进行markdown语法解析。 2. 工具类 pom.xml <!-- commonmark 解析markdown --> <d…...

动态规划DP 最长上升子序列模型 登山(题目分析+C++完整代码)

概览检索 动态规划DP 最长上升子序列模型 登山 原题链接 AcWing 1014. 登山 题目描述 五一到了&#xff0c;ACM队组织大家去登山观光&#xff0c;队员们发现山上一共有N个景点&#xff0c;并且决定按照顺序来浏览这些景点&#xff0c;即每次所浏览景点的编号都要大于前一个…...

css-设置元素的溢出行为为可见overflow: visible;

1.前言 overflow 属性用于设置当元素的内容溢出其框时如何处理。 2. overflow overflow 属性的一些常见值&#xff1a; 1 visible&#xff1a;默认值。内容不会被剪裁&#xff0c;会溢出元素的框。 2 hidden&#xff1a;内容会被剪裁&#xff0c;不会显示溢出的部分。 3 sc…...

家居EDI:Hom Furniture EDI需求分析

HOM Furniture 是一家成立于1977年的美国家具零售商&#xff0c;总部位于明尼苏达州。公司致力于提供高品质、时尚的家具和家居用品&#xff0c;满足各种家庭和办公需求。HOM Furniture 以广泛的产品线和优质的客户服务在市场上赢得了良好的口碑。公司经营的产品包括卧室、客厅…...

1、开始简单使用rag

文章目录 前言数据存放申请api开始代码安装依赖从文件夹中读取文档文档切块将分割嵌入并存储在向量库中检索部分代码构造用户接口演示提示 整体代码 前言 本章只是简单使用rag的一个示例&#xff0c;为了引出以后的学习&#xff0c;将整个rag的流程串起来 数据存放 一个示例…...

Linux Samba 低版本漏洞(远程控制)复现与剖析

目录 前言 漏洞介绍 漏洞原理 产生条件 漏洞影响 防御措施 复现过程 结语 前言 在网络安全的复杂生态中&#xff0c;系统漏洞的探索与防范始终是保障数字世界安全稳定运行的关键所在。Linux Samba 作为一款在网络共享服务领域应用极为广泛的软件&#xff0c;其低版本中…...

安卓(android)实现注册界面【Android移动开发基础案例教程(第2版)黑马程序员】

一、实验目的&#xff08;如果代码有错漏&#xff0c;可查看源码&#xff09; 1.掌握LinearLayout、RelativeLayout、FrameLayout等布局的综合使用。 2.掌握ImageView、TextView、EditText、CheckBox、Button、RadioGroup、RadioButton、ListView、RecyclerView等控件在项目中的…...

【 AI agents】letta:2024年代理堆栈演进(中英文翻译)

The AI agents stack AI 代理堆栈 November 14, 2024 11月 14, 2024原文: The AI agents stack官方教程教程学习笔记: 【memgpt】letta 课程1/2:从头实现一个自我编辑、记忆和多步骤推理的代理Understanding the AI agents landscape 了解 AI 代理环境 Although we see a …...

Java中 instanceof 的用法(详解)

目录 引言 基本语法 基本作用 1. 检查对象是否是指定类的实例 2. 检查对象是否是子类的实例 3. 检查对象是否实现某个接口 4.null 处理 错误分析&#xff1a; 5.综合对比示例 最后总结 注意事项 引言 instanceof 概念在多态中引出&#xff0c;因为在多态发生时&…...

联想拯救者R720笔记本外接显示屏方法,显示屏是2K屏27英寸

晚上23点10分前下单&#xff0c;第二天上午显示屏送到&#xff0c;检查外包装没拆封过。这个屏幕左下方有几个按键&#xff0c;按一按就开屏幕、按一按就关闭屏幕&#xff0c;按一按方便节省时间&#xff0c;也支持阅读等模式。 显示屏是 &#xff1a;AOC 27英寸 2K高清 100Hz…...

【RocketMQ 存储】- 一文总结 RocketMQ 的存储结构-基础

文章目录 1. 前言 本文章基于 RocketMQ 4.9.3 1. 前言 RocketMQ 存储部分系列文章&#xff1a; 【RocketMQ 存储】- RocketMQ存储类 MappedFile 【RocketMQ 存储】- 一文总结 RocketMQ 的存储结构-基础 【RocketMQ 存储】- 一文总结 RocketMQ 的存储结构-基础...

S4 HANA明确税金本币和外币之间转换汇率确定(OBC8)

本文主要介绍在S4 HANA OP中明确明确税金本币和外币之间转换汇率确定(OBC8)相关设置。具体请参照如下内容&#xff1a; 明确税金本币和外币之间转换汇率确定(OBC8) 以上配置&#xff0c;我们可以根据不同公司代码所配置的使用不同的汇率来对税金外币和本币之间进行换算。来针对…...

Cocos Creator 3.8 2D 游戏开发知识点整理

目录 Cocos Creator 3.8 2D 游戏开发知识点整理 1. Cocos Creator 3.8 概述 2. 2D 游戏核心组件 (1) 节点&#xff08;Node&#xff09;与组件&#xff08;Component&#xff09; (2) 渲染组件 (3) UI 组件 3. 动画系统 (1) 传统帧动画 (2) 动画编辑器 (3) Spine 和 …...

梯度提升用于高效的分类与回归

使用 决策树&#xff08;Decision Tree&#xff09; 实现 梯度提升&#xff08;Gradient Boosting&#xff09; 主要是模拟 GBDT&#xff08;Gradient Boosting Decision Trees&#xff09; 的原理&#xff0c;即&#xff1a; 第一棵树拟合原始数据计算残差&#xff08;负梯度…...

【单细胞第二节:单细胞示例数据分析-GSE218208】

GSE218208 1.创建Seurat对象 #untar(“GSE218208_RAW.tar”) rm(list ls()) a data.table::fread("GSM6736629_10x-PBMC-1_ds0.1974_CountMatrix.tsv.gz",data.table F) a[1:4,1:4] library(tidyverse) a$alias:gene str_split(a$alias:gene,":",si…...