本地部署推理TextDiffuser-2:释放语言模型用于文本渲染的力量
系列文章目录
文章目录
- 系列文章目录
- 一、模型下载和环境配置
- 二、模型训练
- (一)训练布局规划器
- (二)训练扩散模型
- 三、模型推理
- (一)准备训练好的模型checkpoint
- (二)全参数推理
- (三)LoRA微调推理
- 四、遇到的错误
- (一)importerror,缺少某些库
- (二)报错:libGL.so.1: cannot open shared object file: No such file or directory
- (三)各种奇奇怪怪的错误(本质上是diffusers版本不对)
- (四)各种库的版本不兼容
- (五)RuntimeError: expected scalar type float Float bu found Half
一、模型下载和环境配置
- 将textdiffuser-2模型仓库克隆到本地
git clone https://github.com/microsoft/unilm/
cd unilm/textdiffuser-2
- 创建并激活虚拟环境,在textdiffuser-2目录下安装需要的软件包
conda create -n textdiffuser2 python=3.8
conda activate textdiffuser2
pip install -r requirements.txt
- 安装与系统版本和cuda版本相匹配的torch、torchvision、xformers (我的环境下cuda是12.2的,其他版本需要自己去官网查询)
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple xformers
- 如果想用FastChat训练布局规划器,还需要安装flash-attention:
先将flash-attention模型仓库克隆下来
git clone https://github.com/Dao-AILab/flash-attention.git
然后安装对应的软件包
pip install packaging
pip uninstall -y ninja && pip install ninja
conda install -c nvidia cuda
pip install flash-attn --no-build-isolation
- 为了训练文本修复任务,还需要安装 differs 包
pip install https://github.com/JingyeChen/diffusers_td2.git
二、模型训练
(一)训练布局规划器
- 需要先下载lmsys/vicuna-7b-v1.5模型和FastChat模型。
模型下载方式: 采用git远程clone下来,具体方式可以参考之前的内容:huggingface学习 | 云服务器使用git-lfs下载huggingface上的模型文件;
- 进行训练
CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=50008 FastChat-main/fastchat/train/train_mem.py \--model_name_or_path vicuna-7b-v1.5 \--data_path data/layout_planner_data_5k.json \--bf16 True \--output_dir experiment_result \--num_train_epochs 6 \--per_device_train_batch_size 2 \--per_device_eval_batch_size 2 \--gradient_accumulation_steps 16 \--evaluation_strategy "no" \--save_strategy "steps" \--save_steps 500 \--save_total_limit 5 \--learning_rate 2e-5 \--weight_decay 0. \--warmup_ratio 0.03 \--lr_scheduler_type "cosine" \--logging_steps 1 \--fsdp "full_shard auto_wrap" \--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \--tf32 True \--model_max_length 2048 \--gradient_checkpointing True \--lazy_preprocess True
(二)训练扩散模型
- 需要先准备需要训练的扩散模型:stable-diffusion-v1-5模型
- 对于全参数训练:
accelerate launch train_textdiffuser2_t2i_full.py \--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \--train_batch_size=18 \--gradient_accumulation_steps=4 \--gradient_checkpointing \--mixed_precision="fp16" \--num_train_epochs=6 \--learning_rate=1e-5 \--max_grad_norm=1 \--lr_scheduler="constant" \--lr_warmup_steps=0 \--output_dir="diffusion_experiment_result" \--enable_xformers_memory_efficient_attention \--dataloader_num_workers=8 \--index_file_path='/path/to/train_dataset_index.txt' \--dataset_path='/path/to/laion-ocr-select/' \--granularity=128 \--coord_mode="ltrb" \--max_length=77 \--resume_from_checkpoint="latest"
- 对于 LoRA 训练:
accelerate launch train_textdiffuser2_t2i_lora.py \--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \--train_batch_size=18 \--gradient_accumulation_steps=4 \--gradient_checkpointing \--mixed_precision="fp16" \--num_train_epochs=6 \--learning_rate=1e-4 \--text_encoder_learning_rate=1e-5 \--lr_scheduler="constant" \--output_dir="diffusion_experiment_result" \--enable_xformers_memory_efficient_attention \--dataloader_num_workers=8 \--index_file_path='/path/to/train_dataset_index.txt' \--dataset_path='/path/to/laion-ocr-select/' \--granularity=128 \--coord_mode="ltrb" \--max_length=77 \--resume_from_checkpoint="latest"
三、模型推理
(一)准备训练好的模型checkpoint
-
下载官网提供的模型checkpoint:layout planner、diffusion model (full parameter fine-tuning) 和diffusion model (lora fine-tuning)
-
准备stable-diffusion-v1-5模型
(二)全参数推理
CUDA_VISIBLE_DEVICES=4 accelerate launch inference_textdiffuser2_t2i_full.py \--pretrained_model_name_or_path="./stable-diffusion-v1-5" \--mixed_precision="fp16" \--output_dir="inference_results_1" \--enable_xformers_memory_efficient_attention \--resume_from_checkpoint="./textdiffuser2-full-ft" \--granularity=128 \--max_length=77 \--coord_mode="ltrb" \--cfg=7.5 \--sample_steps=20 \--seed=43555 \--m1_model_path="./textdiffuser2_layout_planner" \--input_format='prompt' \--input_prompt='a hotdog with mustard and other toppings on it'
推理结果:
(三)LoRA微调推理
CUDA_VISIBLE_DEVICES=4 accelerate launch inference_textdiffuser2_t2i_lora.py \--pretrained_model_name_or_path="./stable-diffusion-v1-5" \--gradient_accumulation_steps=4 \--gradient_checkpointing \--mixed_precision="fp16" \--output_dir="inference_results_2" \--enable_xformers_memory_efficient_attention \--resume_from_checkpoint="./textdiffuser2-lora-ft" \--granularity=128 \--coord_mode="ltrb" \--cfg=7.5 \--sample_steps=50 \--seed=43555 \--m1_model_path="./textdiffuser2_layout_planner" \--input_format='prompt' \--input_prompt='a stamp of u.s.a'
运行结果:
四、遇到的错误
(一)importerror,缺少某些库
在运行过程中出现了各种各样的importerror,于是就是缺少哪个库就下载那个库:
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python
pip install protobuf
(二)报错:libGL.so.1: cannot open shared object file: No such file or directory
pip uninstall opencv-python
pip install opencv-python-headless
(三)各种奇奇怪怪的错误(本质上是diffusers版本不对)
- RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
- The deprecation tuple (‘LoRAXFormersAttnProcessor’, ‘0.26.0’, 'Make sure use XFormersAttnProcessor instead by settingLoRA layers to `self.
pip install diffusers==0.24.0 -i https://pypi.mirrors.ustc.edu.cn/simple/
(四)各种库的版本不兼容
由于作者在官网上提供了实验中使用的软件包列表可供参考,所以我直接将textdiffuser-2的assets文件夹下的refere_requirements.txt文件中的库一次性安装下来:
cd assets
pip install -r reference_requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/
(五)RuntimeError: expected scalar type float Float bu found Half
这个错误是因为安装的diffusers包里有个文件需要用官网提供的新文件进行替换
可以先根据错误提示找到diffusers库包中attention_processor.py所在的位置,然后用assets文件夹下attention_processor.py进行替换即可解决问题。
参考:libGL.so.1: cannot open shared object file: No such file or directory
相关文章:

本地部署推理TextDiffuser-2:释放语言模型用于文本渲染的力量
系列文章目录 文章目录 系列文章目录一、模型下载和环境配置二、模型训练(一)训练布局规划器(二)训练扩散模型 三、模型推理(一)准备训练好的模型checkpoint(二)全参数推理ÿ…...
设计模式: 模板方法模式
文章目录 一、什么是模板方法模式二、模板方法模式结构三、优点 一、什么是模板方法模式 模板方法模式(Template Method Pattern)是一种行为设计模式,它定义了一个操作中的算法骨架,将一些步骤延迟到子类中实现。这样可以使得子类…...

C++ 打印输出十六进制数 指定占位符前面填充0
C 打印十六进制数据,指定数据长度,前面不够时,补充0. 代码如下: #include <iostream> #include <iomanip> #include <cmath>using namespace std;int main() {unsigned int id 0xc01;unsigned int testCaseId…...

基于 Win Server 2008 复现 IPC$ 漏洞
写在前面 本篇博客演示了使用 winXP(配合部分 win10 的命令)对 win server 2008 的 IPC$ 漏洞进行内网渗透,原本的实验是要求使用 win server 2003,使用 win server 2003 可以规避掉很多下面存在的问题,建议大家使用 …...
HTML笔记2
11,路径 <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"> <title>Document</title> <…...

使用Docker实现Jenkins+Python + Pytest +Allure 接口自动化
一、Jenkins搭建 参考《Docker 安装 Jenkins》 进入 jenkins 容器 CLI 界面 docker exec -itu root jenkins /bin/bash二、准备条件 1、替换镜像内源 为了安装wget,默认用yum会安装不上wget命令,参考文章《docker容器内如何更换yum源【只想换成国内…...

论企业安全漏洞扫描的重要性
前言 随着信息技术的迅猛发展和互联网的广泛普及,网络安全问题日益凸显。在这个数字化的世界里,无论是企业还是个人,都面临着前所未有的安全威胁。安全漏洞,作为这些威胁的源头,常常被忽视或无法及时发现。 而安全漏洞…...

23.网络游戏逆向分析与漏洞攻防-网络通信数据包分析工具-实现配置工具数据结构
免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 如果看不懂、不知道现在做的什么,那就跟着做完看效果 内容参考于:易道云信息技术研究院VIP课 上一个内容:22.加载配置文件…...

STM32CubeMX学习笔记20——SD卡FATFS文件系统
1. FATFS文件系统简介 文件系统是操作系统用于明确存储设备或分区上的文件的方法和数据结构(即在存储设备上组织文件的方法)。操作系统中负责管理和存储文件信息的软件机构称为文件管理系统,简称文件系统;不带文件系统的SD卡仅能…...

Facebook商城号为什么被封?如何防封?
由于Facebook商城的高利润空间,越来越多的跨境电商商家注意到它的存在。Facebook作为全球最大、用户量最大的社媒平台,同时也孕育了一个巨大的商业生态,包括广告投放、商城交易等。依托背后的大流量,Facebook商城起号较快…...

【教程】APP备案全攻略:确保你的应用合规上线
【教程】APP备案全攻略:确保你的应用合规上线 摘要 本文详细介绍了中国大陆地区互联网信息服务提供者(AP)进行APP备案的流程、要求和注意事项。包括备案对象、备案方式、备案内容、备案流程等方面的详细说明,帮助开发者顺利完成…...
Vue入门2
v-model 原理:v-model本质上是一个语法糖。例如应用于输入框,就是value属性和input事件的合写。 作用:提供数据的双向绑定 数据变,视图跟着变 :value视图变,数据跟着变 input 注意:$event用于在模板中&…...

简介:CMMI软件能力成熟度集成模型
前言 CMMI是英文Capability Maturity Model Integration的缩写。 CMMI认证简称软件能力成熟度集成模型,是鉴定企业在开发流程化和质量管理上的国际通行标准,全球软件生产标准大都以此为基点,并都努力争取成为CMMI认证队伍中的一分子。 对一个…...
mysql的其他问题
1.MySQL数据库作发布系统的存储,一天五万条以上的增量,预计运维三年,怎么优化? a. 设计良好的数据库结构,允许部分数据冗余,尽量避免join查询,提高效率。 b. 选择合适的表字段数据类型和存储引擎…...

数据结构---复杂度(2)
1.斐波那契数列的时间复杂度问题 每一行分别是2^0---2^1---2^2-----2^3-------------------------------------------2^(n-2) 利用错位相减法,可以得到结果是,2^(n-1)-1,其实还是要减去右下角的灰色部分,我们可以拿简单的数字进行举例子&…...
【设计模式】设计原则和常见的23种经典设计模式
设计模式 1. 设计原则(记忆口诀:SOLID)【记忆口诀:单开里依接迪合(单开礼仪接地和)】 (1)单一职责原则(Single Responsibility Principle, SRP) ÿ…...

Spring Cloud Gateway自定义断言
问题:Spring Cloud Gateway自带的断言(Predicate)不满足业务怎么办?可以自定义断言! 先看Spring Cloud Gateway是如何实现断言的 Gateway中断言的整体架构如下: public abstract class AbstractRoutePred…...

智能测径仪在胶管行业的应用
关键字:胶管外径尺寸测量,胶管检测仪器,胶管外径检测,高温胶管外径检测,软硬胶管检测, 智能测径仪在家胶管行业中的应用主要体现在对胶管外径的精确测量和控制上。在胶管生产过程中,外径的大小直…...

vue自定义主题皮肤方案
方案一:CSS变量换肤(推荐) 利用css定义变量的方法,用var在全局定义颜色变量(需将变量提升到全局即伪类选择器 :root)然后利用js操作css变量,document.getElementsByTagName(‘body’)[0].style…...
iOS中使用schema协议调用APP和使用iframe打开APP的例子
大家好我是咕噜美乐蒂,很高兴又和大家见面了! 当调用自定义 URL scheme 或使用 iframe 打开应用程序时,可以采取以下详细步骤: 使用自定义 URL scheme 协议调用应用程序 1.首先,确认目标应用程序已经注册了自定义 U…...
Linux链表操作全解析
Linux C语言链表深度解析与实战技巧 一、链表基础概念与内核链表优势1.1 为什么使用链表?1.2 Linux 内核链表与用户态链表的区别 二、内核链表结构与宏解析常用宏/函数 三、内核链表的优点四、用户态链表示例五、双向循环链表在内核中的实现优势5.1 插入效率5.2 安全…...

从零实现STL哈希容器:unordered_map/unordered_set封装详解
本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说,直接开始吧! 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...
Rust 异步编程
Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...
是否存在路径(FIFOBB算法)
题目描述 一个具有 n 个顶点e条边的无向图,该图顶点的编号依次为0到n-1且不存在顶点与自身相连的边。请使用FIFOBB算法编写程序,确定是否存在从顶点 source到顶点 destination的路径。 输入 第一行两个整数,分别表示n 和 e 的值(1…...

基于IDIG-GAN的小样本电机轴承故障诊断
目录 🔍 核心问题 一、IDIG-GAN模型原理 1. 整体架构 2. 核心创新点 (1) 梯度归一化(Gradient Normalization) (2) 判别器梯度间隙正则化(Discriminator Gradient Gap Regularization) (3) 自注意力机制(Self-Attention) 3. 完整损失函数 二…...
tomcat入门
1 tomcat 是什么 apache开发的web服务器可以为java web程序提供运行环境tomcat是一款高效,稳定,易于使用的web服务器tomcathttp服务器Servlet服务器 2 tomcat 目录介绍 -bin #存放tomcat的脚本 -conf #存放tomcat的配置文件 ---catalina.policy #to…...
学习一下用鸿蒙DevEco Studio HarmonyOS5实现百度地图
在鸿蒙(HarmonyOS5)中集成百度地图,可以通过以下步骤和技术方案实现。结合鸿蒙的分布式能力和百度地图的API,可以构建跨设备的定位、导航和地图展示功能。 1. 鸿蒙环境准备 开发工具:下载安装 De…...
云原生周刊:k0s 成为 CNCF 沙箱项目
开源项目推荐 HAMi HAMi(原名 k8s‑vGPU‑scheduler)是一款 CNCF Sandbox 级别的开源 K8s 中间件,通过虚拟化 GPU/NPU 等异构设备并支持内存、计算核心时间片隔离及共享调度,为容器提供统一接口,实现细粒度资源配额…...

若依登录用户名和密码加密
/*** 获取公钥:前端用来密码加密* return*/GetMapping("/getPublicKey")public RSAUtil.RSAKeyPair getPublicKey() {return RSAUtil.rsaKeyPair();}新建RSAUti.Java package com.ruoyi.common.utils;import org.apache.commons.codec.binary.Base64; im…...

aardio 自动识别验证码输入
技术尝试 上周在发学习日志时有网友提议“在网页上识别验证码”,于是尝试整合图像识别与网页自动化技术,完成了这套模拟登录流程。核心思路是:截图验证码→OCR识别→自动填充表单→提交并验证结果。 代码在这里 import soImage; import we…...