momask-codes 部署踩坑笔记
目录
依赖项
t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns
推理代码完善:
代码地址:
https://github.com/EricGuo5513/momask-codes
依赖项
pip install numpy==1.23
matplotlib 必须指定版本:pip install matplotlib==3.3.4
t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns
下载模型:
cd t2m
echo -e "Downloading pretrained models for HumanML3D dataset"
gdown --fuzzy https://drive.google.com/file/d/1vXS7SHJBgWPt59wupQ5UUzhFObrnGkQ0/view?usp=sharing
echo -e "Unzipping humanml3d_models.zip"
unzip humanml3d_models.zip
推理代码完善:
text_motion/momask-codes/gen_t2m.py
# coding=utf-8
import sys
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
current_dir = os.path.dirname(os.path.abspath(__file__))paths = [os.path.abspath(__file__).split('scripts')[0]]
print('current_dir',current_dir)
paths.append(os.path.abspath(os.path.join(current_dir, 'src')))for path in paths:sys.path.insert(0, path)os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')import os
from os.path import join as pjoinimport torch
import torch.nn.functional as Ffrom models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
from models.vq.model import RVQVAE, LengthEstimator# from options.eval_option import EvalT2MOptionsfrom options.base_option import BaseOptionsclass EvalT2MOptions(BaseOptions):def initialize(self):BaseOptions.initialize(self)self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint you want to use, {latest, net_best_fid, etc}')self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size')self.parser.add_argument('--ext', type=str, default='text2motion', help='Extension of the result file or folder')self.parser.add_argument("--num_batch", default=2, type=int,help="Number of batch for generation")self.parser.add_argument("--repeat_times", default=1, type=int,help="Number of repetitions, per sample text prompt")self.parser.add_argument("--cond_scale", default=4, type=float,help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")self.parser.add_argument("--temperature", default=1., type=float,help="Sampling Temperature.")self.parser.add_argument("--topkr", default=0.9, type=float,help="Filter out percentil low prop entries.")self.parser.add_argument("--time_steps", default=18, type=int,help="Mask Generate steps.")self.parser.add_argument("--seed", default=10107, type=int)self.parser.add_argument('--gumbel_sample', action="store_true", help='True: gumbel sampling, False: categorical sampling.')self.parser.add_argument('--use_res_model', action="store_true", help='Whether to use residual transformer.')# self.parser.add_argument('--est_length', action="store_true", help='Training iterations')self.parser.add_argument('--res_name', type=str, default='tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw', help='Model name of residual transformer')self.parser.add_argument('--text_path', type=str, default="", help='Text prompt file')self.parser.add_argument('-msec', '--mask_edit_section', nargs='*', type=str, help='Indicate sections for editing, use comma to separate the start and end of a section''type int will specify the token frame, type float will specify the ratio of seq_len')self.parser.add_argument('--text_prompt', default='A person is running on a treadmill.', type=str, help="A text prompt to be generated. If empty, will take text prompts from dataset.")self.parser.add_argument('--source_motion', default='example_data/000612.npy', type=str, help="Source motion path for editing. (new_joint_vecs format .npy file)")self.parser.add_argument("--motion_length", default=0, type=int,help="Motion length for generation, only applicable with single text prompt.")self.is_train = Falsefrom utils.get_opt import get_optfrom utils.fixseed import fixseed
from visualization.joints2bvh import Joint2BVHConvertor
from torch.distributions.categorical import Categoricalfrom utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motionfrom utils.paramUtil import t2m_kinematic_chainimport numpy as np
clip_version = 'ViT-B/32'def load_vq_model(vq_opt):# opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')vq_model = RVQVAE(vq_opt,vq_opt.dim_pose,vq_opt.nb_code,vq_opt.code_dim,vq_opt.output_emb_width,vq_opt.down_t,vq_opt.stride_t,vq_opt.width,vq_opt.depth,vq_opt.dilation_growth_rate,vq_opt.vq_act,vq_opt.vq_norm)ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),map_location='cpu')model_key = 'vq_model' if 'vq_model' in ckpt else 'net'vq_model.load_state_dict(ckpt[model_key])print(f'Loading VQ Model {vq_opt.name} Completed!')return vq_model, vq_optdef load_trans_model(model_opt, opt, which_model):t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,cond_mode='text',latent_dim=model_opt.latent_dim,ff_size=model_opt.ff_size,num_layers=model_opt.n_layers,num_heads=model_opt.n_heads,dropout=model_opt.dropout,clip_dim=512,cond_drop_prob=model_opt.cond_drop_prob,clip_version=clip_version,opt=model_opt)ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model),map_location='cpu')model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'# print(ckpt.keys())missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False)assert len(unexpected_keys) == 0assert all([k.startswith('clip_model.') for k in missing_keys])print(f'Loading Transformer {opt.name} from epoch {ckpt["ep"]}!')return t2m_transformerdef load_res_model(res_opt, vq_opt, opt):res_opt.num_quantizers = vq_opt.num_quantizersres_opt.num_tokens = vq_opt.nb_coderes_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,cond_mode='text',latent_dim=res_opt.latent_dim,ff_size=res_opt.ff_size,num_layers=res_opt.n_layers,num_heads=res_opt.n_heads,dropout=res_opt.dropout,clip_dim=512,shared_codebook=vq_opt.shared_codebook,cond_drop_prob=res_opt.cond_drop_prob,# codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,share_weight=res_opt.share_weight,clip_version=clip_version,opt=res_opt)ckpt = torch.load(pjoin(res_opt.checkpoints_dir, res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'),map_location=opt.device)missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False)assert len(unexpected_keys) == 0assert all([k.startswith('clip_model.') for k in missing_keys])print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!')return res_transformerdef load_len_estimator(opt):model = LengthEstimator(512, 50)ckpt = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_estimator', 'model', 'finest.tar'),map_location=opt.device)model.load_state_dict(ckpt['estimator'])print(f'Loading Length Estimator from epoch {ckpt["epoch"]}!')return modelif __name__ == '__main__':parser = EvalT2MOptions()opt = parser.parse()fixseed(opt.seed)opt.device = torch.device("cuda:1")torch.autograd.set_detect_anomaly(True)dim_pose = 251 if opt.dataset_name == 'kit' else 263# out_dir = pjoin(opt.check)root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)model_dir = pjoin(root_dir, 'model')result_dir = pjoin('./generation', opt.ext)joints_dir = pjoin(result_dir, 'joints')animation_dir = pjoin(result_dir, 'animations')os.makedirs(joints_dir, exist_ok=True)os.makedirs(animation_dir,exist_ok=True)model_opt_path = pjoin(root_dir, 'opt.txt')model_opt = get_opt(model_opt_path, device=opt.device)#############################Loading RVQ#############################vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')vq_opt = get_opt(vq_opt_path, device=opt.device)vq_opt.dim_pose = dim_posevq_model, vq_opt = load_vq_model(vq_opt)model_opt.num_tokens = vq_opt.nb_codemodel_opt.num_quantizers = vq_opt.num_quantizersmodel_opt.code_dim = vq_opt.code_dim#######################################Loading R-Transformer#######################################res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')res_opt = get_opt(res_opt_path, device=opt.device)res_model = load_res_model(res_opt, vq_opt, opt)assert res_opt.vq_name == model_opt.vq_name#######################################Loading M-Transformer#######################################t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')#######################################Loading Length Predictor#######################################length_estimator = load_len_estimator(model_opt)t2m_transformer.eval()vq_model.eval()res_model.eval()length_estimator.eval()res_model.to(opt.device)t2m_transformer.to(opt.device)vq_model.to(opt.device)length_estimator.to(opt.device)##### ---- Dataloader ---- #####opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))def inv_transform(data):return data * std + meanprompt_list = []length_list = []est_length = Falseif opt.text_prompt != "":prompt_list.append(opt.text_prompt)if opt.motion_length == 0:est_length = Trueelse:length_list.append(opt.motion_length)elif opt.text_path != "":with open(opt.text_path, 'r') as f:lines = f.readlines()for line in lines:infos = line.split('#')prompt_list.append(infos[0])if len(infos) == 1 or (not infos[1].isdigit()):est_length = Truelength_list = []else:length_list.append(int(infos[-1]))else:raise "A text prompt, or a file a text prompts are required!!!"# print('loading checkpoint {}'.format(file))if est_length:print("Since no motion length are specified, we will use estimated motion lengthes!!")text_embedding = t2m_transformer.encode_text(prompt_list)pred_dis = length_estimator(text_embedding)probs = F.softmax(pred_dis, dim=-1) # (b, ntoken)token_lens = Categorical(probs).sample() # (b, seqlen)# lengths = torch.multinomial()else:token_lens = torch.LongTensor(length_list) // 4token_lens = token_lens.to(opt.device).long()m_length = token_lens * 4captions = prompt_listsample = 0kinematic_chain = t2m_kinematic_chainconverter = Joint2BVHConvertor()for r in range(opt.repeat_times):print("-->Repeat %d"%r)with torch.no_grad():mids = t2m_transformer.generate(captions, token_lens,timesteps=opt.time_steps,cond_scale=opt.cond_scale,temperature=opt.temperature,topk_filter_thres=opt.topkr,gsample=opt.gumbel_sample)# print(mids)# print(mids.shape)mids = res_model.generate(mids, captions, token_lens, temperature=1, cond_scale=5)pred_motions = vq_model.forward_decoder(mids)pred_motions = pred_motions.detach().cpu().numpy()data = inv_transform(pred_motions)for k, (caption, joint_data) in enumerate(zip(captions, data)):print("---->Sample %d: %s %d"%(k, caption, m_length[k]))animation_path = pjoin(animation_dir, str(k))joint_path = pjoin(joints_dir, str(k))print('save_dir',animation_path)os.makedirs(animation_path, exist_ok=True)os.makedirs(joint_path, exist_ok=True)joint_data = joint_data[:m_length[k]]joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh"%(k, r, m_length[k]))_, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100)bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k]))_, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4"%(k, r, m_length[k]))ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4"%(k, r, m_length[k]))plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=caption, fps=20)plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy"%(k, r, m_length[k])), joint)np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy"%(k, r, m_length[k])), ik_joint)
相关文章:
momask-codes 部署踩坑笔记
目录 依赖项 t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns 推理代码完善: 代码地址: https://github.com/EricGuo5513/momask-codes 依赖项 pip install numpy1.23 matplotlib 必须指定版本:pip install matplotlib3.3.4 t2m_nlayer…...

H3CNE-31-BFD
Bidirectional Forwarding Dection,双向转发检查 作用:毫秒级故障检查,通常结合三层协议(静态路由、vrrp、ospf、BGP等),实现链路故障快速检查。 BFD配置示例 没有中间的SW,接口downÿ…...
蓝桥备赛指南(5)
queue队列 queue是一种先进先出的数据结构。它提供了一组函数来操作和访问元素,但它的功能相对较简单,queue函数的内部实现了底层容器来存储元素,并且只能通过特定的函数来访问和操作元素。 queue函数的常用函数 1.push()函数:…...
讯飞智作 AI 配音技术浅析(一)
一、核心技术 讯飞智作 AI 配音技术作为科大讯飞在人工智能领域的重要成果,融合了多项前沿技术,为用户提供了高质量的语音合成服务。其核心技术主要涵盖以下几个方面: 1. 深度学习与神经网络 讯飞智作 AI 配音技术以深度学习为核心驱动力&…...

MySQL(高级特性篇) 14 章——MySQL事务日志
事务有4种特性:原子性、一致性、隔离性和持久性 事务的隔离性由锁机制实现事务的原子性、一致性和持久性由事务的redo日志和undo日志来保证(1)REDO LOG称为重做日志,用来保证事务的持久性(2)UNDO LOG称为回…...

openRv1126 AI算法部署实战之——TensorFlow TFLite Pytorch ONNX等模型转换实战
Conda简介 查看当前系统的环境列表 conda env list base为基础环境 py3.6-rknn-1.7.3为模型转换环境,rknn-toolkit版本V1.7.3,python版本3.6 py3.6-tensorflow-2.5.0为tensorflow模型训练环境,tensorflow版本2.5.0,python版本…...

【Redis】常见面试题
什么是Redis? Redis 和 Memcached 有什么区别? 为什么用 Redis 作为 MySQL 的缓存? 主要是因为Redis具备高性能和高并发两种特性。 高性能:MySQL中数据是从磁盘读取的,而Redis是直接操作内存,速度相当快…...
每日 Java 面试题分享【第 17 天】
欢迎来到每日 Java 面试题分享栏目! 订阅专栏,不错过每一天的练习 今日分享 3 道面试题目! 评论区复述一遍印象更深刻噢~ 目录 问题一:Java 中的访问修饰符有哪些?问题二:Java 中静态方法和实例方法的区…...

「全网最细 + 实战源码案例」设计模式——桥接模式
核心思想 桥接模式(Bridge Pattern)是一种结构型设计模式,将抽象部分与其实现部分分离,使它们可以独立变化。降低代码耦合度,避免类爆炸,提高代码的可扩展性。 结构 1. Implementation(实现类…...

JavaScript 进阶(上)
作用域 局部作用域 局部作用域分为函数作用域和块作用域。 函数作用域: 在函数内部声明的变量只能在函数内部被访问,外部无法直接访问。 总结: 函数内部声明的变量,在函数外部无法被访问 函数的参数也是函数内部的局部变量 …...

【编译原理实验二】——自动机实验:NFA转DFA并最小化
本篇适用于ZZU的编译原理课程实验二——自动机实验:NFA转DFA并最小化,包含了实验代码和实验报告的内容,读者可根据需要参考完成自己的程序设计。 如果是ZZU的学弟学妹看到这篇,那么恭喜你,你来对地方啦! 如…...
深入探讨:服务器如何响应前端请求及后端如何查看前端提交的数据
深入探讨:服务器如何响应前端请求及后端如何查看前端提交的数据 一、服务器如何响应前端请求 前端与后端的交互主要通过 HTTP 协议实现。以下是详细步骤: 1. 前端发起 HTTP 请求 GET 请求:用于从服务器获取数据。POST 请求:用…...
如何利用Docker和.NET Core实现环境一致性、简化依赖管理、快速部署与扩展,同时提高资源利用率、确保安全性和生态系统支持
目录 1. 环境一致性 2. 简化依赖管理 3. 快速部署与扩展 4. 提高资源利用率 5. 确保安全性 6. 生态系统支持 总结 使用 Docker 和 .NET Core 结合,可以有效地实现环境一致性、简化依赖管理、快速部署与扩展,同时提高资源利用率、确保安全性和生态…...
@Inject @Qualifier @Named
Inject Qualifier Named 在依赖注入(DI)中,Inject、Qualifier 和 Named 是用于管理对象创建和绑定的关键注解。以下是它们的用途、依赖配置和代码示例的详细说明: 1. 注解的作用 Inject:标记需要注入的构造函数、字段…...
创建 priority_queue - 进阶(内置类型)c++
内置类型就是 C 提供的数据类型,⽐如 int 、 double 、 long long 等。以 int 类型为例,分 别创建⼤根堆和⼩根堆。 这种写法意思是,我要告诉这个优先级队列要建一个什么样的堆,第一个int是要存什么数据类型,vecto…...

2. Java-MarkDown文件解析-工具类
2. Java-MarkDown文件解析-工具类 1. 思路 读取markdown文件的内容,根据markdown的语法进行各个类型语法的解析。引入工具类 commonmark 和 commonmark-ext-gfm-tables进行markdown语法解析。 2. 工具类 pom.xml <!-- commonmark 解析markdown --> <d…...

动态规划DP 最长上升子序列模型 登山(题目分析+C++完整代码)
概览检索 动态规划DP 最长上升子序列模型 登山 原题链接 AcWing 1014. 登山 题目描述 五一到了,ACM队组织大家去登山观光,队员们发现山上一共有N个景点,并且决定按照顺序来浏览这些景点,即每次所浏览景点的编号都要大于前一个…...
css-设置元素的溢出行为为可见overflow: visible;
1.前言 overflow 属性用于设置当元素的内容溢出其框时如何处理。 2. overflow overflow 属性的一些常见值: 1 visible:默认值。内容不会被剪裁,会溢出元素的框。 2 hidden:内容会被剪裁,不会显示溢出的部分。 3 sc…...

家居EDI:Hom Furniture EDI需求分析
HOM Furniture 是一家成立于1977年的美国家具零售商,总部位于明尼苏达州。公司致力于提供高品质、时尚的家具和家居用品,满足各种家庭和办公需求。HOM Furniture 以广泛的产品线和优质的客户服务在市场上赢得了良好的口碑。公司经营的产品包括卧室、客厅…...

1、开始简单使用rag
文章目录 前言数据存放申请api开始代码安装依赖从文件夹中读取文档文档切块将分割嵌入并存储在向量库中检索部分代码构造用户接口演示提示 整体代码 前言 本章只是简单使用rag的一个示例,为了引出以后的学习,将整个rag的流程串起来 数据存放 一个示例…...
浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)
✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义(Task Definition&…...
浅谈 React Hooks
React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...
Leetcode 3576. Transform Array to All Equal Elements
Leetcode 3576. Transform Array to All Equal Elements 1. 解题思路2. 代码实现 题目链接:3576. Transform Array to All Equal Elements 1. 解题思路 这一题思路上就是分别考察一下是否能将其转化为全1或者全-1数组即可。 至于每一种情况是否可以达到…...

ElasticSearch搜索引擎之倒排索引及其底层算法
文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

如何在最短时间内提升打ctf(web)的水平?
刚刚刷完2遍 bugku 的 web 题,前来答题。 每个人对刷题理解是不同,有的人是看了writeup就等于刷了,有的人是收藏了writeup就等于刷了,有的人是跟着writeup做了一遍就等于刷了,还有的人是独立思考做了一遍就等于刷了。…...

C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...
智能AI电话机器人系统的识别能力现状与发展水平
一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行
项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战,克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

基于Java+VUE+MariaDB实现(Web)仿小米商城
仿小米商城 环境安装 nodejs maven JDK11 运行 mvn clean install -DskipTestscd adminmvn spring-boot:runcd ../webmvn spring-boot:runcd ../xiaomi-store-admin-vuenpm installnpm run servecd ../xiaomi-store-vuenpm installnpm run serve 注意:运行前…...

恶补电源:1.电桥
一、元器件的选择 搜索并选择电桥,再multisim中选择FWB,就有各种型号的电桥: 电桥是用来干嘛的呢? 它是一个由四个二极管搭成的“桥梁”形状的电路,用来把交流电(AC)变成直流电(DC)。…...