基于Stable Diffusion XL模型进行文本生成图像的训练
基于Stable Diffusion XL模型进行文本生成图像的训练
flyfish
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--report_to="wandb" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model" \--push_to_hub
环境变量部分
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"
MODEL_NAME
:指定预训练模型的名称或路径。这里使用的是stabilityai/stable-diffusion-xl-base-1.0
,也就是Stable Diffusion XL的基础版本1.0。VAE_NAME
:指定变分自编码器(VAE)的名称或路径。madebyollin/sdxl-vae-fp16-fix
是针对Stable Diffusion XL的一个经过修复的VAE模型,适用于半精度(FP16)计算。DATASET_NAME
:指定训练所使用的数据集名称或路径。这里使用的是lambdalabs/naruto-blip-captions
,是一个包含火影忍者相关图像及其描述的数据集。
accelerate launch
命令参数部分
accelerate launch train_text_to_image_sdxl.py \
这行代码使用 accelerate
工具来启动 train_text_to_image_sdxl.py
脚本,accelerate
可以帮助我们在多GPU、TPU等环境下进行分布式训练。
脚本参数部分
--pretrained_model_name_or_path=$MODEL_NAME
:指定预训练模型的名称或路径,这里使用前面定义的MODEL_NAME
环境变量。--pretrained_vae_model_name_or_path=$VAE_NAME
:指定预训练VAE模型的名称或路径,使用前面定义的VAE_NAME
环境变量。--dataset_name=$DATASET_NAME
:指定训练数据集的名称或路径,使用前面定义的DATASET_NAME
环境变量。--enable_xformers_memory_efficient_attention
:启用xformers
库的内存高效注意力机制,能减少训练过程中的内存占用。--resolution=512 --center_crop --random_flip
:--resolution=512
:将输入图像的分辨率统一调整为512x512像素。--center_crop
:对图像进行中心裁剪,使其达到指定的分辨率。--random_flip
:在训练过程中随机对图像进行水平翻转,以增加数据的多样性。
--proportion_empty_prompts=0.2
:设置空提示(没有文本描述)的样本在训练数据中的比例为20%。--train_batch_size=1
:每个训练批次包含的样本数量为1。--gradient_accumulation_steps=4 --gradient_checkpointing
:--gradient_accumulation_steps=4
:梯度累积步数为4,即每4个批次的梯度进行一次更新,这样可以在有限的内存下模拟更大的批次大小。--gradient_checkpointing
:启用梯度检查点机制,通过减少内存使用来支持更大的模型和批次大小。
--max_train_steps=10000
:最大训练步数为10000步。--use_8bit_adam
:使用8位Adam优化器,能减少内存占用。--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0
:--learning_rate=1e-06
:学习率设置为1e-6。--lr_scheduler="constant"
:学习率调度器设置为常数,即训练过程中学习率保持不变。--lr_warmup_steps=0
:学习率预热步数为0,即不进行学习率预热。
--mixed_precision="fp16"
:使用半精度(FP16)混合精度训练,能减少内存使用并加快训练速度。--report_to="wandb"
:将训练过程中的指标报告到Weights & Biases(WandB)平台,方便进行可视化和监控。--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5
:--validation_prompt="a cute Sundar Pichai creature"
:指定验证时使用的文本提示,这里是“一个可爱的桑达尔·皮查伊形象”。--validation_epochs 5
:每5个训练轮次进行一次验证。
--checkpointing_steps=5000
:每5000步保存一次模型的检查点。--output_dir="sdxl-naruto-model"
:指定训练好的模型的输出目录为sdxl-naruto-model
。--push_to_hub
:将训练好的模型推送到Hugging Face模型库。
离线环境运行
# 假设已经把模型、VAE和数据集下载到本地了
# 这里假设模型在当前目录下的 sdxl-base-1.0 文件夹
# VAE 在 sdxl-vae-fp16-fix 文件夹
# 数据集在 naruto-blip-captions 文件夹# 定义本地路径
MODEL_NAME="./sdxl-base-1.0"
VAE_NAME="./sdxl-vae-fp16-fix"
DATASET_NAME="./naruto-blip-captions"# 移除需要外网连接的参数
accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model"
移除需要外网连接的参数:去掉 --report_to="wandb"
和 --push_to_hub
参数,因为 wandb
需要外网连接来上传训练指标,--push_to_hub
则需要外网连接把模型推送到Hugging Face模型库。
推理
from diffusers import DiffusionPipeline
import torchmodel_path = "you-model-id-goes-here" # <-- 替换为你的模型路径
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")
训练后的文件夹结构
.
├── checkpoint-10000
│ ├── optimizer.bin
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ ├── scheduler.bin
│ └── unet
│ ├── config.json
│ ├── diffusion_pytorch_model-00001-of-00002.safetensors
│ ├── diffusion_pytorch_model-00002-of-00002.safetensors
│ └── diffusion_pytorch_model.safetensors.index.json
├── checkpoint-5000
│ ├── optimizer.bin
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ ├── scheduler.bin
│ └── unet
│ ├── config.json
│ ├── diffusion_pytorch_model-00001-of-00002.safetensors
│ ├── diffusion_pytorch_model-00002-of-00002.safetensors
│ └── diffusion_pytorch_model.safetensors.index.json
├── model_index.json
├── scheduler
│ └── scheduler_config.json
├── text_encoder
│ ├── config.json
│ └── model.safetensors
├── text_encoder_2
│ ├── config.json
│ └── model.safetensors
├── tokenizer
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
├── tokenizer_2
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
├── unet
│ ├── config.json
│ ├── diffusion_pytorch_model-00001-of-00002.safetensors
│ ├── diffusion_pytorch_model-00002-of-00002.safetensors
│ └── diffusion_pytorch_model.safetensors.index.json
└── vae├── config.json└── diffusion_pytorch_model.safetensors
LoRA训练
accelerate launch train_text_to_image_lora_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME --caption_column="text" \--resolution=1024 --random_flip \--train_batch_size=1 \--num_train_epochs=2 --checkpointing_steps=500 \--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--seed=42 \--output_dir="sd-naruto-model-lora-sdxl" \--validation_prompt="cute dragon creature"
推理
from diffusers import DiffusionPipeline
import torchsdxl_model_path="/media/models/AI-ModelScope/stable-diffusion-xl-base-1___0/"
lora_model_path = "/media/text_to_image/sd-naruto-model-lora-sdxl/"pipe = DiffusionPipeline.from_pretrained(sdxl_model_path, torch_dtype=torch.float16)
pipe.to("cuda")
pipe.load_lora_weights(lora_model_path)prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")
LoRA训练后的文件夹结构
├── checkpoint-1000
│ ├── optimizer.bin
│ ├── pytorch_lora_weights.safetensors
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ └── scheduler.bin
├── checkpoint-1500
│ ├── optimizer.bin
│ ├── pytorch_lora_weights.safetensors
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ └── scheduler.bin
├── checkpoint-2000
│ ├── optimizer.bin
│ ├── pytorch_lora_weights.safetensors
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ └── scheduler.bin
├── checkpoint-500
│ ├── optimizer.bin
│ ├── pytorch_lora_weights.safetensors
│ ├── random_states_0.pkl
│ ├── scaler.pt
│ └── scheduler.bin
└── pytorch_lora_weights.safetensors
相关文章:
基于Stable Diffusion XL模型进行文本生成图像的训练
基于Stable Diffusion XL模型进行文本生成图像的训练 flyfish export MODEL_NAME"stabilityai/stable-diffusion-xl-base-1.0" export VAE_NAME"madebyollin/sdxl-vae-fp16-fix" export DATASET_NAME"lambdalabs/naruto-blip-captions"acceler…...

Facebook的元宇宙新次元:社交互动如何改变?
科技的浪潮正将我们推向一个全新的时代——元宇宙时代。Facebook,这个全球最大的社交网络平台,已经宣布将公司名称更改为 Meta,全面拥抱元宇宙概念。那么,元宇宙究竟是什么?它将如何改变我们的社交互动方式呢ÿ…...

概统期末复习--速成
随机事件及其概率 加法公式 推三个的时候ABC,夹逼准则 减法准则 除法公式 相互独立定义 两种分析 两个解法 古典概型求概率(排列组合) 分步相乘、分类相加 全概率公式和贝叶斯公式 两阶段问题 第一个小概率*A在小概率的概率。。。累计 …...

n8n系列(1)初识n8n:工作流自动化平台概述
1. 引言 随着各类自动化工具的涌现,n8n作为一款开源的工作流自动化平台,凭借其灵活性、可扩展性和强大的集成能力,正在获得越来越多技术团队的青睐。 本文作为n8n系列的开篇,将带您全面了解这个强大的自动化平台,探索其起源、特性以及与其他工具的差异,帮助您判断n8n是否…...
Java中Comparator排序原理详解
引言 在Java编程中,集合排序是一个常见需求。很多开发者对于为什么o2-o1实现降序排列而o1-o2实现升序排列感到困惑。本文将从数学角度解析这个问题,帮助读者彻底理解Comparator的排序原理。 问题引入 看看以下排序代码: List<Student&…...
PyQt5基础:QWidget类的全面解析与应用实践
在Python的GUI编程领域,PyQt5是一个强大且广泛应用的库。其中,QWidget类作为所有用户界面对象的基类,是构建丰富多样用户界面的基础。今天,我们就来深入了解QWidget类及其相关应用。 QWidget类概述 QWidget类是PyQt中所有窗口和…...
Python-77:古生物DNA序列血缘分析
问题描述 小U是一位古生物学家,正在研究不同物种之间的血缘关系。为了分析两种古生物的血缘远近,她需要比较它们的DNA序列。DNA由四种核苷酸A、C、G、T组成,并且可能通过三种方式发生变异:添加一个核苷酸、删除一个核苷酸或替换一…...

QT6 源(82):阅读与注释日历类型 QCalendar,本类并未完结,儒略历,格里高利历原来就是公历,
(1)本代码来自于头文件 qcalendar . h : #ifndef QCALENDAR_H #define QCALENDAR_H#include <limits>#include <QtCore/qglobal.h> #include <QtCore/qlocale.h> #include <QtCore/qstring.h> #include <QtCore/…...

CVE体系若消亡将如何影响网络安全防御格局
CVE体系的核心价值与当前危机 由MITRE运营的通用漏洞披露(CVE)项目的重要性不容低估。25年来,它始终是网络安全专业人员理解和缓解安全漏洞的基准参照系。通过提供标准化的漏洞命名与分类方法,这套体系为防御者建立了理解、优先级…...

OpenKylin安装Elastic Search8
一、环境准备 Java安装 安装过程此处不做赘述,使用以下命令检查是否安装成功。 java -version 注意:Elasticsearch 自 7.0 版本起内置了 OpenJDK,无需单独安装。但如需自定义 JDK,可设置 JAVA_HOME。 二、安装Elasticsearch …...

【ARM AMBA AHB 入门 3 -- AHB 总线介绍】
请阅读【ARM AMBA 总线 文章专栏导读】 文章目录 AHB Bus 简介AHB Bus 构成AHB BUS 工作机制AHB 传输阶段 AHB InterfacesAHB仲裁信号 AHB 数据访问零等待传输(no waitstatetransfer)等待传输(transfers with wait states)多重传送(multipletransfer)--Pipeline AHB 控制信号 A…...

多模态大模型中的视觉分词器(Tokenizer)前沿研究介绍
文章目录 引言MAETok背景方法介绍高斯混合模型(GMM)分析模型架构 实验分析总结 FlexTok背景方法介绍模型架构 实验分析总结 Emu3背景方法介绍模型架构训练细节 实验分析总结 InternVL2.5背景方法介绍模型架构 实验分析总结 LLAVA-MINI背景方法介绍出发点…...

sqli-labs靶场第二关——数字型
一:查找注入类型: 输入 ?id1--与第一关的差别:报错; 说明不是字符型 渐进测试:?id1--,结果正常,说明是数字型 二:判断列数和回显位 ?id1 order by 3-- 正常, 说明有三列&am…...
使用FastAPI微服务在AWS EKS上实现AI会话历史的管理
架构概述 本文介绍如何使用FastAPI构建微服务架构,在AWS EKS上部署两个微服务: 服务A:接收用户提示服务B:处理对话逻辑,与Redis缓存和MongoDB数据库交互 该架构利用AWS ElastiCache(Redis)实现快速响应,…...

[模型选择与调优]机器学习-part4
七 模型选择与调优 1 交叉验证 (1) 保留交叉验证HoldOut HoldOut Cross-validation(Train-Test Split) 在这种交叉验证技术中,整个数据集被随机地划分为训练集和验证集。根据经验法则,整个数据集的近70%被用作训练集ÿ…...

【计算机网络-数据链路层】以太网、MAC地址、MTU与ARP协议
📚 博主的专栏 🐧 Linux | 🖥️ C | 📊 数据结构 | 💡C 算法 | 🅒 C 语言 | 🌐 计算机网络 上篇文章:传输层-TCP协议TCP核心机制与可靠性保障 下篇文章: 网络…...
学习适应对智能软件对对象的属性进行表征、计算的影响
下面的链接是我新发表的文章。这篇文章是关于智能软件对对象进行标志、表征的问题,这是所有智能实体都无法回避的基本问题。 我最近写了一篇关于奖惩系统的文章。并开始写智能是如何在基础编程的基础上涌现出来的文章。 https://www.oalib.com/articles/6857382 …...
vue 组件函数式调用实战:以身份验证弹窗为例
通常我们在 Vue 中使用组件,是像这样在模板中写标签: <MyComponent :prop"value" event"handleEvent" />而函数式调用,则是让我们像调用一个普通 JavaScript 函数一样来使用这个组件,例如:…...
多线程面试题总结
基础概念 进程与线程的区别 进程:操作系统资源分配的基本单位,有独立内存空间线程:CPU调度的基本单位,共享进程资源对比: 创建开销:进程 > 线程通信方式:进程(IPC)、线程(共享内存)安全性:进程更安全(隔离),线程需要同步线程的生命周期与状态转换 NEW → RUNNABLE …...

Kafka 与 RabbitMQ、RocketMQ 有何不同?
一、不同的诞生背景,塑造了不同的“性格” 名称 背景与目标 产品定位 Kafka 为了解决 LinkedIn 的日志收集瓶颈,强调吞吐与持久化 更像一个“可持久化的分布式日志系统” RabbitMQ 出自金融通信协议 AMQP 的实现,强调协议标准与广泛适…...
【比赛真题解析】篮球迷
本次给大家分享一道比赛的题目:篮球迷。 洛谷链接:U561543 篮球迷 题目如下: 【题目描述】 众所周知,jimmy是个篮球迷。众所周知,Jimmy非常爱看NBA。 众所周知,Jimmy对NBA冠军球队的获奖年份和队名了如指掌。 所以,Jimmy要告诉你n个冠军球队的名字和获奖年份,并要求你…...

【MATLAB源码-第277期】基于matlab的AF中继系统仿真,AF和直传误码率对比、不同中继位置误码率对比、信道容量、中继功率分配以及终端概率。
操作环境: MATLAB 2022a 1、算法描述 在AF(放大转发)中继通信系统中,信号的传输质量和效率受到多个因素的影响,理解这些因素对于系统的优化至关重要。AF中继通信的基本架构由发射端、中继节点和接收端组成。发射端负…...

webRtc之指定摄像头设备绿屏问题
摘要:最近发现,在使用navigator.mediaDevices.getUserMedia({ deviceId: ‘xxx’}),指定设备的时候,video播放总是绿屏,发现关闭浏览器硬件加速不会出现,但显然这不是一个最好的方案; 播放后张这样 修复后 上代码 指定…...

2023年03月青少年软件编程(图形化)等级考试四级编程题
求和 1.准备工作 (1)保留舞台中的小猫角色和白色背景。 2.功能实现 (1)计算1~100中,可以被3整除的数之和; (2)说出被3整除的数之和。 标准答案: 参考程序&…...

ensp的华为小实验
1.先进行子网划分 2.进行接口的IP地址配置和ospf的简易配置,先做到全网小通 3.进行ospf优化 对区域所有区域域间路由器进行一个汇总 对区域1进行优化 对区域2.3进行nssa设置 4.对ISP的路由进行协议配置 最后ping通5.5.5.5...

ragflow报错:KeyError: ‘\n “序号“‘
环境: ragflowv 0.17.2 问题描述: ragflow报错:KeyError: ‘\n “序号”’ **1. 推荐表(输出json格式)** [{"},{},{"},{} ]raceback (most recent call last): May 08 20:06:09 VM-0-2-ubuntu ragflow-s…...
Java大数据可视化在城市空气质量监测与污染溯源中的应用:GIS与实时数据流的技术融合
随着城市化进程加速,空气质量监测与污染溯源成为智慧城市建设的核心议题。传统监测手段受限于数据离散性、分析滞后性及可视化能力不足,难以支撑实时决策。2025年4月27日发布的《Java大数据可视化在城市空气质量监测与污染溯源中的应用》一文,…...

FHE与后量子密码学
1. 引言 近年来,关于 后量子密码学(PQC, Post-Quantum Cryptography) 的讨论愈发热烈。这是因为安全专家担心,一旦有人成功研发出量子计算机,会发生什么可怕的事情。由于 Shor 算法的存在,量子计算机将能够…...
Flask 调试的时候进入main函数两次
在 Flask 开启 Debug 模式时,程序会因为自动重载(reloader)的机制而启动两个进程,导致if __name__ __main__底层的程序代码被执行两次。以下说明其原理与常见解法。 Flask Debug 模式下自动重载机制 Flask 使用的底层服务器 Wer…...
cv_area_center()
主题 用opencv实现了halcon中area_center算子的功能, 返回region的面积,中心的行坐标和中心的列坐标 代码很简单 def cv_area_center(region):area[]row []col []for re in region:retval cv2.moments(re)area.append(retval[m00])row.append(int(r…...