当前位置: 首页 > news >正文

momask-codes 部署踩坑笔记

目录

依赖项

t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns

推理代码完善:


代码地址:

https://github.com/EricGuo5513/momask-codes

依赖项

pip install numpy==1.23

matplotlib 必须指定版本:pip install matplotlib==3.3.4

t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns

下载模型:

cd t2m 
echo -e "Downloading pretrained models for HumanML3D dataset"
gdown --fuzzy https://drive.google.com/file/d/1vXS7SHJBgWPt59wupQ5UUzhFObrnGkQ0/view?usp=sharing

echo -e "Unzipping humanml3d_models.zip"
unzip humanml3d_models.zip

推理代码完善:

text_motion/momask-codes/gen_t2m.py

# coding=utf-8
import sys
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))
current_dir = os.path.dirname(os.path.abspath(__file__))paths = [os.path.abspath(__file__).split('scripts')[0]]
print('current_dir',current_dir)
paths.append(os.path.abspath(os.path.join(current_dir, 'src')))for path in paths:sys.path.insert(0, path)os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')import os
from os.path import join as pjoinimport torch
import torch.nn.functional as Ffrom models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
from models.vq.model import RVQVAE, LengthEstimator# from options.eval_option import EvalT2MOptionsfrom options.base_option import BaseOptionsclass EvalT2MOptions(BaseOptions):def initialize(self):BaseOptions.initialize(self)self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint you want to use, {latest, net_best_fid, etc}')self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size')self.parser.add_argument('--ext', type=str, default='text2motion', help='Extension of the result file or folder')self.parser.add_argument("--num_batch", default=2, type=int,help="Number of batch for generation")self.parser.add_argument("--repeat_times", default=1, type=int,help="Number of repetitions, per sample text prompt")self.parser.add_argument("--cond_scale", default=4, type=float,help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")self.parser.add_argument("--temperature", default=1., type=float,help="Sampling Temperature.")self.parser.add_argument("--topkr", default=0.9, type=float,help="Filter out percentil low prop entries.")self.parser.add_argument("--time_steps", default=18, type=int,help="Mask Generate steps.")self.parser.add_argument("--seed", default=10107, type=int)self.parser.add_argument('--gumbel_sample', action="store_true", help='True: gumbel sampling, False: categorical sampling.')self.parser.add_argument('--use_res_model', action="store_true", help='Whether to use residual transformer.')# self.parser.add_argument('--est_length', action="store_true", help='Training iterations')self.parser.add_argument('--res_name', type=str, default='tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw', help='Model name of residual transformer')self.parser.add_argument('--text_path', type=str, default="", help='Text prompt file')self.parser.add_argument('-msec', '--mask_edit_section', nargs='*', type=str, help='Indicate sections for editing, use comma to separate the start and end of a section''type int will specify the token frame, type float will specify the ratio of seq_len')self.parser.add_argument('--text_prompt', default='A person is running on a treadmill.', type=str, help="A text prompt to be generated. If empty, will take text prompts from dataset.")self.parser.add_argument('--source_motion', default='example_data/000612.npy', type=str, help="Source motion path for editing. (new_joint_vecs format .npy file)")self.parser.add_argument("--motion_length", default=0, type=int,help="Motion length for generation, only applicable with single text prompt.")self.is_train = Falsefrom utils.get_opt import get_optfrom utils.fixseed import fixseed
from visualization.joints2bvh import Joint2BVHConvertor
from torch.distributions.categorical import Categoricalfrom utils.motion_process import recover_from_ric
from utils.plot_script import plot_3d_motionfrom utils.paramUtil import t2m_kinematic_chainimport numpy as np
clip_version = 'ViT-B/32'def load_vq_model(vq_opt):# opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')vq_model = RVQVAE(vq_opt,vq_opt.dim_pose,vq_opt.nb_code,vq_opt.code_dim,vq_opt.output_emb_width,vq_opt.down_t,vq_opt.stride_t,vq_opt.width,vq_opt.depth,vq_opt.dilation_growth_rate,vq_opt.vq_act,vq_opt.vq_norm)ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),map_location='cpu')model_key = 'vq_model' if 'vq_model' in ckpt else 'net'vq_model.load_state_dict(ckpt[model_key])print(f'Loading VQ Model {vq_opt.name} Completed!')return vq_model, vq_optdef load_trans_model(model_opt, opt, which_model):t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,cond_mode='text',latent_dim=model_opt.latent_dim,ff_size=model_opt.ff_size,num_layers=model_opt.n_layers,num_heads=model_opt.n_heads,dropout=model_opt.dropout,clip_dim=512,cond_drop_prob=model_opt.cond_drop_prob,clip_version=clip_version,opt=model_opt)ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model),map_location='cpu')model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'# print(ckpt.keys())missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False)assert len(unexpected_keys) == 0assert all([k.startswith('clip_model.') for k in missing_keys])print(f'Loading Transformer {opt.name} from epoch {ckpt["ep"]}!')return t2m_transformerdef load_res_model(res_opt, vq_opt, opt):res_opt.num_quantizers = vq_opt.num_quantizersres_opt.num_tokens = vq_opt.nb_coderes_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,cond_mode='text',latent_dim=res_opt.latent_dim,ff_size=res_opt.ff_size,num_layers=res_opt.n_layers,num_heads=res_opt.n_heads,dropout=res_opt.dropout,clip_dim=512,shared_codebook=vq_opt.shared_codebook,cond_drop_prob=res_opt.cond_drop_prob,# codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,share_weight=res_opt.share_weight,clip_version=clip_version,opt=res_opt)ckpt = torch.load(pjoin(res_opt.checkpoints_dir, res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'),map_location=opt.device)missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False)assert len(unexpected_keys) == 0assert all([k.startswith('clip_model.') for k in missing_keys])print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!')return res_transformerdef load_len_estimator(opt):model = LengthEstimator(512, 50)ckpt = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_estimator', 'model', 'finest.tar'),map_location=opt.device)model.load_state_dict(ckpt['estimator'])print(f'Loading Length Estimator from epoch {ckpt["epoch"]}!')return modelif __name__ == '__main__':parser = EvalT2MOptions()opt = parser.parse()fixseed(opt.seed)opt.device = torch.device("cuda:1")torch.autograd.set_detect_anomaly(True)dim_pose = 251 if opt.dataset_name == 'kit' else 263# out_dir = pjoin(opt.check)root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)model_dir = pjoin(root_dir, 'model')result_dir = pjoin('./generation', opt.ext)joints_dir = pjoin(result_dir, 'joints')animation_dir = pjoin(result_dir, 'animations')os.makedirs(joints_dir, exist_ok=True)os.makedirs(animation_dir,exist_ok=True)model_opt_path = pjoin(root_dir, 'opt.txt')model_opt = get_opt(model_opt_path, device=opt.device)#############################Loading RVQ#############################vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')vq_opt = get_opt(vq_opt_path, device=opt.device)vq_opt.dim_pose = dim_posevq_model, vq_opt = load_vq_model(vq_opt)model_opt.num_tokens = vq_opt.nb_codemodel_opt.num_quantizers = vq_opt.num_quantizersmodel_opt.code_dim = vq_opt.code_dim#######################################Loading R-Transformer#######################################res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')res_opt = get_opt(res_opt_path, device=opt.device)res_model = load_res_model(res_opt, vq_opt, opt)assert res_opt.vq_name == model_opt.vq_name#######################################Loading M-Transformer#######################################t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')#######################################Loading Length Predictor#######################################length_estimator = load_len_estimator(model_opt)t2m_transformer.eval()vq_model.eval()res_model.eval()length_estimator.eval()res_model.to(opt.device)t2m_transformer.to(opt.device)vq_model.to(opt.device)length_estimator.to(opt.device)##### ---- Dataloader ---- #####opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))def inv_transform(data):return data * std + meanprompt_list = []length_list = []est_length = Falseif opt.text_prompt != "":prompt_list.append(opt.text_prompt)if opt.motion_length == 0:est_length = Trueelse:length_list.append(opt.motion_length)elif opt.text_path != "":with open(opt.text_path, 'r') as f:lines = f.readlines()for line in lines:infos = line.split('#')prompt_list.append(infos[0])if len(infos) == 1 or (not infos[1].isdigit()):est_length = Truelength_list = []else:length_list.append(int(infos[-1]))else:raise "A text prompt, or a file a text prompts are required!!!"# print('loading checkpoint {}'.format(file))if est_length:print("Since no motion length are specified, we will use estimated motion lengthes!!")text_embedding = t2m_transformer.encode_text(prompt_list)pred_dis = length_estimator(text_embedding)probs = F.softmax(pred_dis, dim=-1)  # (b, ntoken)token_lens = Categorical(probs).sample()  # (b, seqlen)# lengths = torch.multinomial()else:token_lens = torch.LongTensor(length_list) // 4token_lens = token_lens.to(opt.device).long()m_length = token_lens * 4captions = prompt_listsample = 0kinematic_chain = t2m_kinematic_chainconverter = Joint2BVHConvertor()for r in range(opt.repeat_times):print("-->Repeat %d"%r)with torch.no_grad():mids = t2m_transformer.generate(captions, token_lens,timesteps=opt.time_steps,cond_scale=opt.cond_scale,temperature=opt.temperature,topk_filter_thres=opt.topkr,gsample=opt.gumbel_sample)# print(mids)# print(mids.shape)mids = res_model.generate(mids, captions, token_lens, temperature=1, cond_scale=5)pred_motions = vq_model.forward_decoder(mids)pred_motions = pred_motions.detach().cpu().numpy()data = inv_transform(pred_motions)for k, (caption, joint_data)  in enumerate(zip(captions, data)):print("---->Sample %d: %s %d"%(k, caption, m_length[k]))animation_path = pjoin(animation_dir, str(k))joint_path = pjoin(joints_dir, str(k))print('save_dir',animation_path)os.makedirs(animation_path, exist_ok=True)os.makedirs(joint_path, exist_ok=True)joint_data = joint_data[:m_length[k]]joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh"%(k, r, m_length[k]))_, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100)bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k]))_, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4"%(k, r, m_length[k]))ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4"%(k, r, m_length[k]))plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=caption, fps=20)plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy"%(k, r, m_length[k])), joint)np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy"%(k, r, m_length[k])), ik_joint)

相关文章:

momask-codes 部署踩坑笔记

目录 依赖项 t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns 推理代码完善: 代码地址: https://github.com/EricGuo5513/momask-codes 依赖项 pip install numpy1.23 matplotlib 必须指定版本:pip install matplotlib3.3.4 t2m_nlayer…...

H3CNE-31-BFD

Bidirectional Forwarding Dection,双向转发检查 作用:毫秒级故障检查,通常结合三层协议(静态路由、vrrp、ospf、BGP等),实现链路故障快速检查。 BFD配置示例 没有中间的SW,接口down&#xff…...

蓝桥备赛指南(5)

queue队列 queue是一种先进先出的数据结构。它提供了一组函数来操作和访问元素,但它的功能相对较简单,queue函数的内部实现了底层容器来存储元素,并且只能通过特定的函数来访问和操作元素。 queue函数的常用函数 1.push()函数:…...

讯飞智作 AI 配音技术浅析(一)

一、核心技术 讯飞智作 AI 配音技术作为科大讯飞在人工智能领域的重要成果,融合了多项前沿技术,为用户提供了高质量的语音合成服务。其核心技术主要涵盖以下几个方面: 1. 深度学习与神经网络 讯飞智作 AI 配音技术以深度学习为核心驱动力&…...

MySQL(高级特性篇) 14 章——MySQL事务日志

事务有4种特性:原子性、一致性、隔离性和持久性 事务的隔离性由锁机制实现事务的原子性、一致性和持久性由事务的redo日志和undo日志来保证(1)REDO LOG称为重做日志,用来保证事务的持久性(2)UNDO LOG称为回…...

openRv1126 AI算法部署实战之——TensorFlow TFLite Pytorch ONNX等模型转换实战

Conda简介 查看当前系统的环境列表 conda env list base为基础环境 py3.6-rknn-1.7.3为模型转换环境,rknn-toolkit版本V1.7.3,python版本3.6 py3.6-tensorflow-2.5.0为tensorflow模型训练环境,tensorflow版本2.5.0,python版本…...

【Redis】常见面试题

什么是Redis? Redis 和 Memcached 有什么区别? 为什么用 Redis 作为 MySQL 的缓存? 主要是因为Redis具备高性能和高并发两种特性。 高性能:MySQL中数据是从磁盘读取的,而Redis是直接操作内存,速度相当快…...

每日 Java 面试题分享【第 17 天】

欢迎来到每日 Java 面试题分享栏目! 订阅专栏,不错过每一天的练习 今日分享 3 道面试题目! 评论区复述一遍印象更深刻噢~ 目录 问题一:Java 中的访问修饰符有哪些?问题二:Java 中静态方法和实例方法的区…...

「全网最细 + 实战源码案例」设计模式——桥接模式

核心思想 桥接模式(Bridge Pattern)是一种结构型设计模式,将抽象部分与其实现部分分离,使它们可以独立变化。降低代码耦合度,避免类爆炸,提高代码的可扩展性。 结构 1. Implementation(实现类…...

JavaScript 进阶(上)

作用域 局部作用域 局部作用域分为函数作用域和块作用域。 函数作用域: 在函数内部声明的变量只能在函数内部被访问,外部无法直接访问。 总结: 函数内部声明的变量,在函数外部无法被访问 函数的参数也是函数内部的局部变量 …...

【编译原理实验二】——自动机实验:NFA转DFA并最小化

本篇适用于ZZU的编译原理课程实验二——自动机实验:NFA转DFA并最小化,包含了实验代码和实验报告的内容,读者可根据需要参考完成自己的程序设计。 如果是ZZU的学弟学妹看到这篇,那么恭喜你,你来对地方啦! 如…...

深入探讨:服务器如何响应前端请求及后端如何查看前端提交的数据

深入探讨:服务器如何响应前端请求及后端如何查看前端提交的数据 一、服务器如何响应前端请求 前端与后端的交互主要通过 HTTP 协议实现。以下是详细步骤: 1. 前端发起 HTTP 请求 GET 请求:用于从服务器获取数据。POST 请求:用…...

如何利用Docker和.NET Core实现环境一致性、简化依赖管理、快速部署与扩展,同时提高资源利用率、确保安全性和生态系统支持

目录 1. 环境一致性 2. 简化依赖管理 3. 快速部署与扩展 4. 提高资源利用率 5. 确保安全性 6. 生态系统支持 总结 使用 Docker 和 .NET Core 结合,可以有效地实现环境一致性、简化依赖管理、快速部署与扩展,同时提高资源利用率、确保安全性和生态…...

@Inject @Qualifier @Named

Inject Qualifier Named 在依赖注入(DI)中,Inject、Qualifier 和 Named 是用于管理对象创建和绑定的关键注解。以下是它们的用途、依赖配置和代码示例的详细说明: 1. 注解的作用 Inject:标记需要注入的构造函数、字段…...

创建 priority_queue - 进阶(内置类型)c++

内置类型就是 C 提供的数据类型,⽐如 int 、 double 、 long long 等。以 int 类型为例,分 别创建⼤根堆和⼩根堆。 这种写法意思是,我要告诉这个优先级队列要建一个什么样的堆,第一个int是要存什么数据类型,vecto…...

2. Java-MarkDown文件解析-工具类

2. Java-MarkDown文件解析-工具类 1. 思路 读取markdown文件的内容&#xff0c;根据markdown的语法进行各个类型语法的解析。引入工具类 commonmark 和 commonmark-ext-gfm-tables进行markdown语法解析。 2. 工具类 pom.xml <!-- commonmark 解析markdown --> <d…...

动态规划DP 最长上升子序列模型 登山(题目分析+C++完整代码)

概览检索 动态规划DP 最长上升子序列模型 登山 原题链接 AcWing 1014. 登山 题目描述 五一到了&#xff0c;ACM队组织大家去登山观光&#xff0c;队员们发现山上一共有N个景点&#xff0c;并且决定按照顺序来浏览这些景点&#xff0c;即每次所浏览景点的编号都要大于前一个…...

css-设置元素的溢出行为为可见overflow: visible;

1.前言 overflow 属性用于设置当元素的内容溢出其框时如何处理。 2. overflow overflow 属性的一些常见值&#xff1a; 1 visible&#xff1a;默认值。内容不会被剪裁&#xff0c;会溢出元素的框。 2 hidden&#xff1a;内容会被剪裁&#xff0c;不会显示溢出的部分。 3 sc…...

家居EDI:Hom Furniture EDI需求分析

HOM Furniture 是一家成立于1977年的美国家具零售商&#xff0c;总部位于明尼苏达州。公司致力于提供高品质、时尚的家具和家居用品&#xff0c;满足各种家庭和办公需求。HOM Furniture 以广泛的产品线和优质的客户服务在市场上赢得了良好的口碑。公司经营的产品包括卧室、客厅…...

1、开始简单使用rag

文章目录 前言数据存放申请api开始代码安装依赖从文件夹中读取文档文档切块将分割嵌入并存储在向量库中检索部分代码构造用户接口演示提示 整体代码 前言 本章只是简单使用rag的一个示例&#xff0c;为了引出以后的学习&#xff0c;将整个rag的流程串起来 数据存放 一个示例…...

Java 语言特性(面试系列2)

一、SQL 基础 1. 复杂查询 &#xff08;1&#xff09;连接查询&#xff08;JOIN&#xff09; 内连接&#xff08;INNER JOIN&#xff09;&#xff1a;返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 &#xff08;一&#xff09;实时滤波与参数调整 基础滤波操作 60Hz 工频滤波&#xff1a;勾选界面右侧 “60Hz” 复选框&#xff0c;可有效抑制电网干扰&#xff08;适用于北美地区&#xff0c;欧洲用户可调整为 50Hz&#xff09;。 平滑处理&…...

django filter 统计数量 按属性去重

在Django中&#xff0c;如果你想要根据某个属性对查询集进行去重并统计数量&#xff0c;你可以使用values()方法配合annotate()方法来实现。这里有两种常见的方法来完成这个需求&#xff1a; 方法1&#xff1a;使用annotate()和Count 假设你有一个模型Item&#xff0c;并且你想…...

Spring数据访问模块设计

前面我们已经完成了IoC和web模块的设计&#xff0c;聪明的码友立马就知道了&#xff0c;该到数据访问模块了&#xff0c;要不就这俩玩个6啊&#xff0c;查库势在必行&#xff0c;至此&#xff0c;它来了。 一、核心设计理念 1、痛点在哪 应用离不开数据&#xff08;数据库、No…...

精益数据分析(97/126):邮件营销与用户参与度的关键指标优化指南

精益数据分析&#xff08;97/126&#xff09;&#xff1a;邮件营销与用户参与度的关键指标优化指南 在数字化营销时代&#xff0c;邮件列表效度、用户参与度和网站性能等指标往往决定着创业公司的增长成败。今天&#xff0c;我们将深入解析邮件打开率、网站可用性、页面参与时…...

代理篇12|深入理解 Vite中的Proxy接口代理配置

在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...

力扣-35.搜索插入位置

题目描述 给定一个排序数组和一个目标值&#xff0c;在数组中找到目标值&#xff0c;并返回其索引。如果目标值不存在于数组中&#xff0c;返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...

基于PHP的连锁酒店管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的连锁酒店管理系统 一 介绍 连锁酒店管理系统基于原生PHP开发&#xff0c;数据库mysql&#xff0c;前端bootstrap。系统角色分为用户和管理员。 技术栈 phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注册/登录/注销 2 个人中…...

9-Oracle 23 ai Vector Search 特性 知识准备

很多小伙伴是不是参加了 免费认证课程&#xff08;限时至2025/5/15&#xff09; Oracle AI Vector Search 1Z0-184-25考试&#xff0c;都顺利拿到certified了没。 各行各业的AI 大模型的到来&#xff0c;传统的数据库中的SQL还能不能打&#xff0c;结构化和非结构的话数据如何和…...

Java求职者面试指南:Spring、Spring Boot、Spring MVC与MyBatis技术解析

Java求职者面试指南&#xff1a;Spring、Spring Boot、Spring MVC与MyBatis技术解析 一、第一轮基础概念问题 1. Spring框架的核心容器是什么&#xff1f;它的作用是什么&#xff1f; Spring框架的核心容器是IoC&#xff08;控制反转&#xff09;容器。它的主要作用是管理对…...