当前位置: 首页 > article >正文

基于Stable Diffusion XL模型进行文本生成图像的训练

基于Stable Diffusion XL模型进行文本生成图像的训练

flyfish

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--report_to="wandb" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model" \--push_to_hub

环境变量部分

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"
  • MODEL_NAME:指定预训练模型的名称或路径。这里使用的是 stabilityai/stable-diffusion-xl-base-1.0,也就是Stable Diffusion XL的基础版本1.0。
  • VAE_NAME:指定变分自编码器(VAE)的名称或路径。madebyollin/sdxl-vae-fp16-fix 是针对Stable Diffusion XL的一个经过修复的VAE模型,适用于半精度(FP16)计算。
  • DATASET_NAME:指定训练所使用的数据集名称或路径。这里使用的是 lambdalabs/naruto-blip-captions,是一个包含火影忍者相关图像及其描述的数据集。

accelerate launch 命令参数部分

accelerate launch train_text_to_image_sdxl.py \

这行代码使用 accelerate 工具来启动 train_text_to_image_sdxl.py 脚本,accelerate 可以帮助我们在多GPU、TPU等环境下进行分布式训练。

脚本参数部分

  • --pretrained_model_name_or_path=$MODEL_NAME:指定预训练模型的名称或路径,这里使用前面定义的 MODEL_NAME 环境变量。
  • --pretrained_vae_model_name_or_path=$VAE_NAME:指定预训练VAE模型的名称或路径,使用前面定义的 VAE_NAME 环境变量。
  • --dataset_name=$DATASET_NAME:指定训练数据集的名称或路径,使用前面定义的 DATASET_NAME 环境变量。
  • --enable_xformers_memory_efficient_attention:启用 xformers 库的内存高效注意力机制,能减少训练过程中的内存占用。
  • --resolution=512 --center_crop --random_flip
    • --resolution=512:将输入图像的分辨率统一调整为512x512像素。
    • --center_crop:对图像进行中心裁剪,使其达到指定的分辨率。
    • --random_flip:在训练过程中随机对图像进行水平翻转,以增加数据的多样性。
  • --proportion_empty_prompts=0.2:设置空提示(没有文本描述)的样本在训练数据中的比例为20%。
  • --train_batch_size=1:每个训练批次包含的样本数量为1。
  • --gradient_accumulation_steps=4 --gradient_checkpointing
    • --gradient_accumulation_steps=4:梯度累积步数为4,即每4个批次的梯度进行一次更新,这样可以在有限的内存下模拟更大的批次大小。
    • --gradient_checkpointing:启用梯度检查点机制,通过减少内存使用来支持更大的模型和批次大小。
  • --max_train_steps=10000:最大训练步数为10000步。
  • --use_8bit_adam:使用8位Adam优化器,能减少内存占用。
  • --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0
    • --learning_rate=1e-06:学习率设置为1e-6。
    • --lr_scheduler="constant":学习率调度器设置为常数,即训练过程中学习率保持不变。
    • --lr_warmup_steps=0:学习率预热步数为0,即不进行学习率预热。
  • --mixed_precision="fp16":使用半精度(FP16)混合精度训练,能减少内存使用并加快训练速度。
  • --report_to="wandb":将训练过程中的指标报告到Weights & Biases(WandB)平台,方便进行可视化和监控。
  • --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5
    • --validation_prompt="a cute Sundar Pichai creature":指定验证时使用的文本提示,这里是“一个可爱的桑达尔·皮查伊形象”。
    • --validation_epochs 5:每5个训练轮次进行一次验证。
  • --checkpointing_steps=5000:每5000步保存一次模型的检查点。
  • --output_dir="sdxl-naruto-model":指定训练好的模型的输出目录为 sdxl-naruto-model
  • --push_to_hub:将训练好的模型推送到Hugging Face模型库。

离线环境运行

# 假设已经把模型、VAE和数据集下载到本地了
# 这里假设模型在当前目录下的 sdxl-base-1.0 文件夹
# VAE 在 sdxl-vae-fp16-fix 文件夹
# 数据集在 naruto-blip-captions 文件夹# 定义本地路径
MODEL_NAME="./sdxl-base-1.0"
VAE_NAME="./sdxl-vae-fp16-fix"
DATASET_NAME="./naruto-blip-captions"# 移除需要外网连接的参数
accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model"

移除需要外网连接的参数:去掉 --report_to="wandb"--push_to_hub 参数,因为 wandb 需要外网连接来上传训练指标,--push_to_hub 则需要外网连接把模型推送到Hugging Face模型库。

推理

from diffusers import DiffusionPipeline
import torchmodel_path = "you-model-id-goes-here" # <-- 替换为你的模型路径
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")

训练后的文件夹结构

.
├── checkpoint-10000
│   ├── optimizer.bin
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   ├── scheduler.bin
│   └── unet
│       ├── config.json
│       ├── diffusion_pytorch_model-00001-of-00002.safetensors
│       ├── diffusion_pytorch_model-00002-of-00002.safetensors
│       └── diffusion_pytorch_model.safetensors.index.json
├── checkpoint-5000
│   ├── optimizer.bin
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   ├── scheduler.bin
│   └── unet
│       ├── config.json
│       ├── diffusion_pytorch_model-00001-of-00002.safetensors
│       ├── diffusion_pytorch_model-00002-of-00002.safetensors
│       └── diffusion_pytorch_model.safetensors.index.json
├── model_index.json
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── model.safetensors
├── text_encoder_2
│   ├── config.json
│   └── model.safetensors
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── tokenizer_2
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   ├── diffusion_pytorch_model-00001-of-00002.safetensors
│   ├── diffusion_pytorch_model-00002-of-00002.safetensors
│   └── diffusion_pytorch_model.safetensors.index.json
└── vae├── config.json└── diffusion_pytorch_model.safetensors

LoRA训练

accelerate launch train_text_to_image_lora_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME --caption_column="text" \--resolution=1024 --random_flip \--train_batch_size=1 \--num_train_epochs=2 --checkpointing_steps=500 \--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--seed=42 \--output_dir="sd-naruto-model-lora-sdxl" \--validation_prompt="cute dragon creature"

推理

from diffusers import DiffusionPipeline
import torchsdxl_model_path="/media/models/AI-ModelScope/stable-diffusion-xl-base-1___0/"
lora_model_path = "/media/text_to_image/sd-naruto-model-lora-sdxl/"pipe = DiffusionPipeline.from_pretrained(sdxl_model_path, torch_dtype=torch.float16)
pipe.to("cuda")
pipe.load_lora_weights(lora_model_path)prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")

LoRA训练后的文件夹结构

├── checkpoint-1000
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-1500
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-2000
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-500
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
└── pytorch_lora_weights.safetensors

相关文章:

基于Stable Diffusion XL模型进行文本生成图像的训练

基于Stable Diffusion XL模型进行文本生成图像的训练 flyfish export MODEL_NAME"stabilityai/stable-diffusion-xl-base-1.0" export VAE_NAME"madebyollin/sdxl-vae-fp16-fix" export DATASET_NAME"lambdalabs/naruto-blip-captions"acceler…...

Facebook的元宇宙新次元:社交互动如何改变?

科技的浪潮正将我们推向一个全新的时代——元宇宙时代。Facebook&#xff0c;这个全球最大的社交网络平台&#xff0c;已经宣布将公司名称更改为 Meta&#xff0c;全面拥抱元宇宙概念。那么&#xff0c;元宇宙究竟是什么&#xff1f;它将如何改变我们的社交互动方式呢&#xff…...

概统期末复习--速成

随机事件及其概率 加法公式 推三个的时候ABC&#xff0c;夹逼准则 减法准则 除法公式 相互独立定义 两种分析 两个解法 古典概型求概率&#xff08;排列组合&#xff09; 分步相乘、分类相加 全概率公式和贝叶斯公式 两阶段问题 第一个小概率*A在小概率的概率。。。累计 …...

n8n系列(1)初识n8n:工作流自动化平台概述

1. 引言 随着各类自动化工具的涌现,n8n作为一款开源的工作流自动化平台,凭借其灵活性、可扩展性和强大的集成能力,正在获得越来越多技术团队的青睐。 本文作为n8n系列的开篇,将带您全面了解这个强大的自动化平台,探索其起源、特性以及与其他工具的差异,帮助您判断n8n是否…...

Java中Comparator排序原理详解

引言 在Java编程中&#xff0c;集合排序是一个常见需求。很多开发者对于为什么o2-o1实现降序排列而o1-o2实现升序排列感到困惑。本文将从数学角度解析这个问题&#xff0c;帮助读者彻底理解Comparator的排序原理。 问题引入 看看以下排序代码&#xff1a; List<Student&…...

PyQt5基础:QWidget类的全面解析与应用实践

在Python的GUI编程领域&#xff0c;PyQt5是一个强大且广泛应用的库。其中&#xff0c;QWidget类作为所有用户界面对象的基类&#xff0c;是构建丰富多样用户界面的基础。今天&#xff0c;我们就来深入了解QWidget类及其相关应用。 QWidget类概述 QWidget类是PyQt中所有窗口和…...

Python-77:古生物DNA序列血缘分析

问题描述 小U是一位古生物学家&#xff0c;正在研究不同物种之间的血缘关系。为了分析两种古生物的血缘远近&#xff0c;她需要比较它们的DNA序列。DNA由四种核苷酸A、C、G、T组成&#xff0c;并且可能通过三种方式发生变异&#xff1a;添加一个核苷酸、删除一个核苷酸或替换一…...

QT6 源(82):阅读与注释日历类型 QCalendar,本类并未完结,儒略历,格里高利历原来就是公历,

&#xff08;1&#xff09;本代码来自于头文件 qcalendar . h &#xff1a; #ifndef QCALENDAR_H #define QCALENDAR_H#include <limits>#include <QtCore/qglobal.h> #include <QtCore/qlocale.h> #include <QtCore/qstring.h> #include <QtCore/…...

CVE体系若消亡将如何影响网络安全防御格局

CVE体系的核心价值与当前危机 由MITRE运营的通用漏洞披露&#xff08;CVE&#xff09;项目的重要性不容低估。25年来&#xff0c;它始终是网络安全专业人员理解和缓解安全漏洞的基准参照系。通过提供标准化的漏洞命名与分类方法&#xff0c;这套体系为防御者建立了理解、优先级…...

OpenKylin安装Elastic Search8

一、环境准备 Java安装 安装过程此处不做赘述&#xff0c;使用以下命令检查是否安装成功。 java -version 注意&#xff1a;Elasticsearch 自 7.0 版本起内置了 OpenJDK&#xff0c;无需单独安装。但如需自定义 JDK&#xff0c;可设置 JAVA_HOME。 二、安装Elasticsearch …...

【ARM AMBA AHB 入门 3 -- AHB 总线介绍】

请阅读【ARM AMBA 总线 文章专栏导读】 文章目录 AHB Bus 简介AHB Bus 构成AHB BUS 工作机制AHB 传输阶段 AHB InterfacesAHB仲裁信号 AHB 数据访问零等待传输(no waitstatetransfer)等待传输(transfers with wait states)多重传送(multipletransfer)--Pipeline AHB 控制信号 A…...

多模态大模型中的视觉分词器(Tokenizer)前沿研究介绍

文章目录 引言MAETok背景方法介绍高斯混合模型&#xff08;GMM&#xff09;分析模型架构 实验分析总结 FlexTok背景方法介绍模型架构 实验分析总结 Emu3背景方法介绍模型架构训练细节 实验分析总结 InternVL2.5背景方法介绍模型架构 实验分析总结 LLAVA-MINI背景方法介绍出发点…...

sqli-labs靶场第二关——数字型

一&#xff1a;查找注入类型&#xff1a; 输入 ?id1--与第一关的差别&#xff1a;报错; 说明不是字符型 渐进测试&#xff1a;?id1--&#xff0c;结果正常&#xff0c;说明是数字型 二&#xff1a;判断列数和回显位 ?id1 order by 3-- 正常&#xff0c; 说明有三列&am…...

使用FastAPI微服务在AWS EKS上实现AI会话历史的管理

架构概述 本文介绍如何使用FastAPI构建微服务架构&#xff0c;在AWS EKS上部署两个微服务&#xff1a; 服务A&#xff1a;接收用户提示服务B&#xff1a;处理对话逻辑&#xff0c;与Redis缓存和MongoDB数据库交互 该架构利用AWS ElastiCache(Redis)实现快速响应&#xff0c;…...

[模型选择与调优]机器学习-part4

七 模型选择与调优 1 交叉验证 (1) 保留交叉验证HoldOut HoldOut Cross-validation&#xff08;Train-Test Split&#xff09; 在这种交叉验证技术中&#xff0c;整个数据集被随机地划分为训练集和验证集。根据经验法则&#xff0c;整个数据集的近70%被用作训练集&#xff…...

【计算机网络-数据链路层】以太网、MAC地址、MTU与ARP协议

&#x1f4da; 博主的专栏 &#x1f427; Linux | &#x1f5a5;️ C | &#x1f4ca; 数据结构 | &#x1f4a1;C 算法 | &#x1f152; C 语言 | &#x1f310; 计算机网络 上篇文章&#xff1a;传输层-TCP协议TCP核心机制与可靠性保障 下篇文章&#xff1a; 网络…...

学习适应对智能软件对对象的属性进行表征、计算的影响

下面的链接是我新发表的文章。这篇文章是关于智能软件对对象进行标志、表征的问题&#xff0c;这是所有智能实体都无法回避的基本问题。 我最近写了一篇关于奖惩系统的文章。并开始写智能是如何在基础编程的基础上涌现出来的文章。 https://www.oalib.com/articles/6857382 …...

vue 组件函数式调用实战:以身份验证弹窗为例

通常我们在 Vue 中使用组件&#xff0c;是像这样在模板中写标签&#xff1a; <MyComponent :prop"value" event"handleEvent" />而函数式调用&#xff0c;则是让我们像调用一个普通 JavaScript 函数一样来使用这个组件&#xff0c;例如&#xff1a;…...

多线程面试题总结

基础概念 进程与线程的区别 进程:操作系统资源分配的基本单位,有独立内存空间线程:CPU调度的基本单位,共享进程资源对比: 创建开销:进程 > 线程通信方式:进程(IPC)、线程(共享内存)安全性:进程更安全(隔离),线程需要同步线程的生命周期与状态转换 NEW → RUNNABLE …...

Kafka 与 RabbitMQ、RocketMQ 有何不同?

一、不同的诞生背景&#xff0c;塑造了不同的“性格” 名称 背景与目标 产品定位 Kafka 为了解决 LinkedIn 的日志收集瓶颈&#xff0c;强调吞吐与持久化 更像一个“可持久化的分布式日志系统” RabbitMQ 出自金融通信协议 AMQP 的实现&#xff0c;强调协议标准与广泛适…...

【比赛真题解析】篮球迷

本次给大家分享一道比赛的题目:篮球迷。 洛谷链接:U561543 篮球迷 题目如下: 【题目描述】 众所周知,jimmy是个篮球迷。众所周知,Jimmy非常爱看NBA。 众所周知,Jimmy对NBA冠军球队的获奖年份和队名了如指掌。 所以,Jimmy要告诉你n个冠军球队的名字和获奖年份,并要求你…...

【MATLAB源码-第277期】基于matlab的AF中继系统仿真,AF和直传误码率对比、不同中继位置误码率对比、信道容量、中继功率分配以及终端概率。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 在AF&#xff08;放大转发&#xff09;中继通信系统中&#xff0c;信号的传输质量和效率受到多个因素的影响&#xff0c;理解这些因素对于系统的优化至关重要。AF中继通信的基本架构由发射端、中继节点和接收端组成。发射端负…...

webRtc之指定摄像头设备绿屏问题

摘要&#xff1a;最近发现&#xff0c;在使用navigator.mediaDevices.getUserMedia({ deviceId: ‘xxx’}),指定设备的时候&#xff0c;video播放总是绿屏&#xff0c;发现关闭浏览器硬件加速不会出现&#xff0c;但显然这不是一个最好的方案; 播放后张这样 修复后 上代码 指定…...

2023年03月青少年软件编程(图形化)等级考试四级编程题

求和 1.准备工作 &#xff08;1&#xff09;保留舞台中的小猫角色和白色背景。 2.功能实现 &#xff08;1&#xff09;计算1&#xff5e;100中&#xff0c;可以被3整除的数之和&#xff1b; &#xff08;2&#xff09;说出被3整除的数之和。 标准答案&#xff1a; 参考程序&…...

ensp的华为小实验

1.先进行子网划分 2.进行接口的IP地址配置和ospf的简易配置&#xff0c;先做到全网小通 3.进行ospf优化 对区域所有区域域间路由器进行一个汇总 对区域1进行优化 对区域2.3进行nssa设置 4.对ISP的路由进行协议配置 最后ping通5.5.5.5...

ragflow报错:KeyError: ‘\n “序号“‘

环境&#xff1a; ragflowv 0.17.2 问题描述&#xff1a; ragflow报错&#xff1a;KeyError: ‘\n “序号”’ **1. 推荐表&#xff08;输出json格式&#xff09;** [{"},{},{"},{} ]raceback (most recent call last): May 08 20:06:09 VM-0-2-ubuntu ragflow-s…...

Java大数据可视化在城市空气质量监测与污染溯源中的应用:GIS与实时数据流的技术融合

随着城市化进程加速&#xff0c;空气质量监测与污染溯源成为智慧城市建设的核心议题。传统监测手段受限于数据离散性、分析滞后性及可视化能力不足&#xff0c;难以支撑实时决策。2025年4月27日发布的《Java大数据可视化在城市空气质量监测与污染溯源中的应用》一文&#xff0c…...

FHE与后量子密码学

1. 引言 近年来&#xff0c;关于 后量子密码学&#xff08;PQC, Post-Quantum Cryptography&#xff09; 的讨论愈发热烈。这是因为安全专家担心&#xff0c;一旦有人成功研发出量子计算机&#xff0c;会发生什么可怕的事情。由于 Shor 算法的存在&#xff0c;量子计算机将能够…...

Flask 调试的时候进入main函数两次

在 Flask 开启 Debug 模式时&#xff0c;程序会因为自动重载&#xff08;reloader&#xff09;的机制而启动两个进程&#xff0c;导致if __name__ __main__底层的程序代码被执行两次。以下说明其原理与常见解法。 Flask Debug 模式下自动重载机制 Flask 使用的底层服务器 Wer…...

cv_area_center()

主题 用opencv实现了halcon中area_center算子的功能&#xff0c; 返回region的面积&#xff0c;中心的行坐标和中心的列坐标 代码很简单 def cv_area_center(region):area[]row []col []for re in region:retval cv2.moments(re)area.append(retval[m00])row.append(int(r…...