当前位置: 首页 > news >正文

[pai-diffusion]pai的easynlp的clip模型训练

EasyNLP带你玩转CLIP图文检索 - 知乎作者:熊兮、章捷、岑鸣、临在导读随着自媒体的不断发展,多种模态数据例如图像、文本、语音、视频等不断增长,创造了互联网上丰富多彩的世界。为了准确建模用户的多模态内容,跨模态检索是跨模态理解的重要任务,…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/528476134

initialize_easynlp()->train_dataset = CLIPDataset(pretrained_model_name_or_path=get_pretrain_model_path("alibaba-pai/clip_chinese_roberta_base_vit_base"),data_file="MUGE_MR_train_base64_part.tsv",max_seq_length=32,input_schema="text:str:1,image:str:1",first_sequence="text",second_sequence="image",is_training=True)
valid_dataset = CLIPDataset()model = get_application_model(app_name='clip',...)
- easynlp.appzoo.api.ModelMapping->CLIPApp
- easynlp.appzoo.clip.model.py->CLIPApp
- CHINESE_CLIP->
- self.visual = VisualTransformer()
- self.bert = BertModel()trainer = Trainer(model,train_dataset,user_defined_parameters,  evaluator=get_application_evaluator(app_name="clip",valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32))trainer.train()
- for _epoch in range(self._first_epoch,int(args.epoch_num)):for _step,batch in enumerate(self._train_loader):    label_ids = batch.pop()forward_outputs = self._model(batch)loss_dict = self.model_module.compute_loss(forward_outputs,label_ids)_loss = loss_dict('loss')_loss.backward()model = get_application_model_evaluation()
evaluator = get_application_evaluator()
evaluator.evaluate(model)

数据处理:

import os
import base64
import multiprocessing
from tqdm import tqdmdef process_image(image_path):# 从图片路径中提取中文描述image_name = os.path.basename(image_path)description = os.path.splitext(image_name)[0]# 将图片转换为 Base64 编码with open(image_path, 'rb') as f:image_data = f.read()base64_data = base64.b64encode(image_data).decode('utf-8')return description, base64_datadef generate_tsv(directory):image_paths = [os.path.join(directory, filename) for filename in os.listdir(directory) iffilename.endswith(('.jpg', '.png'))]with multiprocessing.Pool() as pool, tqdm(total=len(image_paths), desc='Processing Images') as pbar:results = []for result in pool.imap_unordered(process_image, image_paths):results.append(result)pbar.update(1)with open('/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train.tsv','w', encoding='utf-8') as f:for description, base64_data in results:line = f"{description}\t{base64_data}\n"f.write(line)if __name__ == '__main__':target_directory = "/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train/img_download/"# import pdb;pdb.set_trace()generate_tsv(target_directory)

训练代码:

import torch.cuda
from easynlp.appzoo import CLIPDataset
from easynlp.appzoo import get_application_predictor, get_application_model, get_application_evaluator, \get_application_model_for_evaluation
from easynlp.core import Trainer, PredictorManager
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parametersdef main():# /root/.easynlp/modelzoo中train_dataset = CLIPDataset(pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),data_file=args.tables.split(",")[0],max_seq_length=args.sequence_length,input_schema=args.input_schema,first_sequence=args.first_sequence,second_sequence=args.second_sequence,is_training=True)valid_dataset = CLIPDataset(# 预训练模型名称路径,这里我们使用封装好的get_pretrain_model_path函数,来处理模型名称"alibaba-pai/clip_chinese_roberta_base_vit_base"以得到其路径,并自动下载模型pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),data_file=args.tables.split(",")[-1],# "data/pai/MUGE_MR_valid_base64_part.tsv"max_seq_length=args.sequence_length,  # 文本最大长度,超过将截断,不足将paddinginput_schema=args.input_schema,  # 输入tsv数据的格式,逗号分隔的每一项对应数据文件中每行以\t分隔的一项,每项开头为其字段标识,如label、sent1等first_sequence=args.first_sequence,  # 用于说明input_schema中哪些字段作为第一/第二列输入数据second_sequence=args.second_sequence,is_training=False)  # 是否为训练过程,train_dataset为True,valid_dataset为Falsemodel = get_application_model(app_name=args.app_name,  # 任务名称,这里选择文本分类"clip"pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),user_defined_parameters=user_defined_parameters# user_defined_parameters:用户自定义参数,直接填入刚刚处理好的自定义参数user_defined_parameters)trainer = Trainer(model=model,train_dataset=train_dataset,user_defined_parameters=user_defined_parameters,evaluator=get_application_evaluator(app_name=args.app_name,valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32))trainer.train()# 模型评估model = get_application_model_for_evaluation(app_name=args.app_name,pretrained_model_name_or_path=args.checkpoint_dir,user_defined_parameters=user_defined_parameters)evaluator = get_application_evaluator(app_name=args.app_name,valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32)model.to(torch.cuda.current_device())evaluator.evaluate(model=model)# 模型预测if test:predictor = get_application_predictor(app_name="clip",model_dir="./outputs/clip_model/",first_sequence="text",second_sequence="image",sequence_length=32,user_defined_parameters=user_defined_parameters)predictor_manager = PredictorManager(predictor=predictor,input_file="data/vcg_furnitures_text_image/vcg_furnitures_test.tsv",input_schema="text:str:1",output_file="text_feat.tsv",output_schema="text_feat",append_cols="text",batch_size=2)predictor_manager.run()if __name__ == "__main__":initialize_easynlp()args = get_args()user_defined_parameters = parse_user_defined_parameters('pretrain_model_name_or_path=alibaba-pai/clip_chinese_roberta_base_vit_base')args.checkpoint_dir = "./outputs/clip_model/"args.pretrained_model_name_or_path = "alibaba-pai/clip_chinese_roberta_base_vit_base"# args.n_gpu = 3# args.worker_gpu = "1,2,3"args.app_name = "clip"args.tables = "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"# "data/vcg_furnitures_text_image/vcg_furnitures_train.tsv," \#               "data/vcg_furnitures_text_image/vcg_furnitures_test.tsv"# "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"args.input_schema = "text:str:1,image:str:1"args.first_sequence = "text"args.second_sequence = "image"args.learning_rate = 1e-4args.epoch_num = 1000args.random_seed = 42args.save_checkpoint_steps = 200args.sequence_length = 32# args.train_batch_size = 2args.micro_batch_size = 32test = Falsemain()# python -m torch.distributed.launch --nproc_per_node 4 tools/train_pai_chinese_clip.py

说一点自己的想法,在我自己工作之初,我很喜欢去拆解一些框架,例如openmm系列,但其实大部分在训练过程上都是相似的,大可不必,在改动上,也没有必要对其进行流程上的大改动,兼具百家之长,了解整体pipeline,更加专注在pipeline实现和效果导向型的结果提交更加有效。

相关文章:

[pai-diffusion]pai的easynlp的clip模型训练

EasyNLP带你玩转CLIP图文检索 - 知乎作者:熊兮、章捷、岑鸣、临在导读随着自媒体的不断发展,多种模态数据例如图像、文本、语音、视频等不断增长,创造了互联网上丰富多彩的世界。为了准确建模用户的多模态内容,跨模态检索是跨模态…...

期权如何交易?期权如何做模拟交易?

买卖期权的第一步就是要有期权账户,国内的期权品种有商品期权和ETF期权以及股指期权,每种的开户方式和要求都不同,下文为大家介绍期权如何交易?期权如何做模拟交易? 一、期权交易需要开立一个期权账户,可以…...

【新书推荐】大模型赛道如何实现华丽的弯道超车 —— 《分布式统一大数据虚拟文件系统 Alluxio原理、技术与实践》

文章目录 大模型赛道如何实现华丽的弯道超车 —— AI/ML训练赋能解决方案01 具备对海量小文件的频繁数据访问的 I/O 效率02 提高 GPU 利用率,降低成本并提高投资回报率03 支持各种存储系统的原生接口04 支持单云、混合云和多云部署01 通过数据抽象化统一数据孤岛02 …...

Calendar对象获取当前周的bug

项目场景: 双周项目管理,需要获取当前周为一年之中的第几周,原先的代码是用Calendar对象,先用setTime()把当前时间传入,再用get(3)获取一年中的第几周 问题描述 实际发…...

嵌入式环境buildroot的espeak配置与编译

1、在buildroot目录下输入make menuconfig 2、选择Target packages 3、选择Audio and video applications 4、选择espeak、选择alsa via portaudio (新版嵌入式linux一般都是用alsa音频驱动) 5、配置portaudio 选择Library 6、选择Audio/Sound 7、选择…...

物理机环境搭建-linux部署nginx

1、安装nginx部署所需依赖 yum install -y gcc-c pcre pcre-devel zlib zlib-devel openssl openssl-devel2、安装nginx包 wget http://nginx.org/download/nginx-1.8.0.tar.gz 如果没有wget可以安装一下 yum install -y wget下载完成后可以在/usr/local/下放置tar包&#xf…...

删除安装Google Chrome浏览器时捆绑安装的Google 文档、表格、幻灯片、Gmail、Google 云端硬盘、YouTube网址链接(Mac)

删除安装Google Chrome浏览器时捆绑安装的Google 文档、表格、幻灯片、Gmail、Google 云端硬盘、YouTube网址链接(Mac) Mac mini操作系统,安装完 Google Chrome 浏览器以后,单击 启动台 桌面左下角的“显示应用程序”,我们发现捆绑安装了 Goo…...

硬件故障诊断:快速定位问题

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…...

IP代理与加速器:理解它们的区别与共同点

在网络使用过程中,我们经常会遇到需要提高访问速度或保护隐私的需求。IP代理和加速器都是常见的应对方案,但它们在工作原理和应用场景上存在一些区别。本文将为您深入探讨IP代理和加速器的异同,帮助您更好地理解它们的作用和适用情况&#xf…...

Java中List转字符串的方法

一、使用String.join方法 在Java 8之后&#xff0c;String类增加了一个静态方法join()&#xff0c;可以方便地将列表中的元素连接成字符串。 // 创建List List<String> list Arrays.asList("Google", "Baidu", "Taobao"); // 以逗号分隔…...

PyTorch实战:实现MNIST手写数字识别

前言 PyTorch可以说是三大主流框架中最适合初学者学习的了&#xff0c;相较于其他主流框架&#xff0c;PyTorch的简单易用性使其成为初学者们的首选。这样我想要强调的一点是&#xff0c;框架可以类比为编程语言&#xff0c;仅为我们实现项目效果的工具&#xff0c;也就是我们…...

【计算机网络】深入理解TCP协议二(连接管理机制、WAIT_TIME、滑动窗口、流量控制、拥塞控制)

TCP协议 1.连接管理机制2.再谈WAIT_TIME状态2.1理解WAIT_TIME状态2.2解决TIME_WAIT状态引起的bind失败的方法2.3监听套接字listen第二个参数介绍 3.滑动窗口3.1介绍3.2丢包情况分析 4.流量控制5.拥塞控制5.1介绍5.2慢启动 6.捎带应答、延时应答 1.连接管理机制 正常情况下&…...

springboot整合sentinel完成限流

1、直入正题&#xff0c;下载sentinel的jar包 1.1 直接到Sentinel官网里的releases下即可下载最新版本&#xff0c;Sentinel官方下载地址&#xff0c;直接下载jar包即可。不过慢&#xff0c;可能下载不下来 1.2 可以去gitee去下载jar包 1.3 下载完成后&#xff0c;进行打包…...

signal(SIGPIPE, SIG_IGN)

linux查看signal常见信号。 [rootplatform:]# kill -l1) HUP2) INT3) QUIT4) ILL5) TRAP6) ABRT7) BUS8) FPE9) KILL 10) USR1 11) SEGV 12) USR2 13) PIPE 14) ALRM 15) TERM 16) STKFLT 17) CHLD 18) CONT 19) STOP 20) TSTP 21) TTIN 22) TTOU 23) URG 24) XCPU 25) XFSZ 2…...

GAN学习笔记

1.原始的GAN 1.1原始的损失函数 1.1.1写法1参考1&#xff0c;参考2 1.1.2 写法2 where, G Generator D Discriminator Pdata(x) distribution of real data P(z) distribution of generator x sample from Pdata(x) z sample from P(z) D(x) Discriminator network G…...

layui框架学习(45: 工具集模块)

layui的工具集模块util支持固定条、倒计时等组件&#xff0c;同时提供辅助函数处理时间数据、字符转义、批量事件处理等操作。   util模块中的fixbar函数支持设置固定条&#xff08;2.7版本的帮助文档中叫固定块&#xff09;&#xff0c;是指固定在页面一侧的工具条元素&…...

车道检测:Decoupling the Curve Modeling and Pavement Regression for Lane Detection

论文作者&#xff1a;Wencheng Han,Jianbing Shen 作者单位&#xff1a;University of Macau 论文链接&#xff1a;http://arxiv.org/abs/2309.10533v1 内容简介&#xff1a; 1&#xff09;方向&#xff1a;车道检测 2&#xff09;应用&#xff1a;车道检测 3&#xff09…...

【扩散生成模型】Diffusion Generative Models

提出扩散模型思想的论文&#xff1a; 《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》理解 扩散模型综述&#xff1a; “扩散模型”首篇综述论文分类汇总&#xff0c;谷歌&北大最新研究 理论推导、代码实现&#xff1a; What are Diffusion Models?…...

美联储加息步伐“暂停”!BTC凌晨力守27000美元!

美东时间9月20日下午&#xff0c;美联储宣布放缓加息步伐&#xff0c;将联邦基金利率目标维持在5.25%至5.50%的区间不变&#xff0c;保持在22年来的最高点&#xff0c;符合市场预期。 在最新的FOMC声明中&#xff0c;美联储表示最近的指标表明&#xff0c;经济活动一直在稳步扩…...

微信小程序与idea后端如何进行数据交互

交互使用的其实就是调用的req.get(url)方法 进行路径访问&#xff0c;你要先保证自己的springboot项目已经成功运行了&#xff1a; 如下&#xff1a; 如何交互的&#xff1f; 微信小程序&#xff1a;如下为index.js页面 在onLoad()事件中调用方法Project.findAllCities() 要…...

接口测试中缓存处理策略

在接口测试中&#xff0c;缓存处理策略是一个关键环节&#xff0c;直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性&#xff0c;避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明&#xff1a; 一、缓存处理的核…...

华为云AI开发平台ModelArts

华为云ModelArts&#xff1a;重塑AI开发流程的“智能引擎”与“创新加速器”&#xff01; 在人工智能浪潮席卷全球的2025年&#xff0c;企业拥抱AI的意愿空前高涨&#xff0c;但技术门槛高、流程复杂、资源投入巨大的现实&#xff0c;却让许多创新构想止步于实验室。数据科学家…...

【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15

缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下&#xff1a; struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中&#xff0c;接口是一种抽象类型&#xff0c;它定义了一组方法的集合&#xff1a; // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的&#xff1a; // 矩形结构体…...

相机从app启动流程

一、流程框架图 二、具体流程分析 1、得到cameralist和对应的静态信息 目录如下: 重点代码分析: 启动相机前,先要通过getCameraIdList获取camera的个数以及id,然后可以通过getCameraCharacteristics获取对应id camera的capabilities(静态信息)进行一些openCamera前的…...

C++中string流知识详解和示例

一、概览与类体系 C 提供三种基于内存字符串的流&#xff0c;定义在 <sstream> 中&#xff1a; std::istringstream&#xff1a;输入流&#xff0c;从已有字符串中读取并解析。std::ostringstream&#xff1a;输出流&#xff0c;向内部缓冲区写入内容&#xff0c;最终取…...

AI编程--插件对比分析:CodeRider、GitHub Copilot及其他

AI编程插件对比分析&#xff1a;CodeRider、GitHub Copilot及其他 随着人工智能技术的快速发展&#xff0c;AI编程插件已成为提升开发者生产力的重要工具。CodeRider和GitHub Copilot作为市场上的领先者&#xff0c;分别以其独特的特性和生态系统吸引了大量开发者。本文将从功…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战

在现代战争中&#xff0c;电磁频谱已成为继陆、海、空、天之后的 “第五维战场”&#xff0c;雷达作为电磁频谱领域的关键装备&#xff0c;其干扰与抗干扰能力的较量&#xff0c;直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器&#xff0c;凭借数字射…...

Map相关知识

数据结构 二叉树 二叉树&#xff0c;顾名思义&#xff0c;每个节点最多有两个“叉”&#xff0c;也就是两个子节点&#xff0c;分别是左子 节点和右子节点。不过&#xff0c;二叉树并不要求每个节点都有两个子节点&#xff0c;有的节点只 有左子节点&#xff0c;有的节点只有…...

Yolov8 目标检测蒸馏学习记录

yolov8系列模型蒸馏基本流程&#xff0c;代码下载&#xff1a;这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中&#xff0c;**知识蒸馏&#xff08;Knowledge Distillation&#xff09;**被广泛应用&#xff0c;作为提升模型…...