使用 DDPO 在 TRL 中微调 Stable Diffusion 模型
引言
扩散模型 (如 DALL-E 2、Stable Diffusion) 是一类文生图模型,在生成图像 (尤其是有照片级真实感的图像) 方面取得了广泛成功。然而,这些模型生成的图像可能并不总是符合人类偏好或人类意图。因此出现了对齐问题,即如何确保模型的输出与人类偏好 (如“质感”) 一致,或者与那种难以通过提示来表达的意图一致?这里就有强化学习的用武之地了。
在大语言模型 (LLM) 领域,强化学习 (RL) 已被证明是能让目标模型符合人类偏好的非常有效的工具。这是 ChatGPT 等系统卓越性能背后的主要秘诀之一。更准确地说,强化学习是人类反馈强化学习 (RLHF) 的关键要素,它使 ChatGPT 能像人类一样聊天。
在 Training Diffusion Models with Reinforcement Learning 一文中,Black 等人展示了如何利用 RL 来对扩散模型进行强化,他们通过名为去噪扩散策略优化 (Denoising Diffusion Policy Optimization,DDPO) 的方法针对模型的目标函数实施微调。
在本文中,我们讨论了 DDPO 的诞生、简要描述了其工作原理,并介绍了如何将 DDPO 加入 RLHF 工作流中以实现更符合人类审美的模型输出。然后,我们切换到实战,讨论如何使用 trl 库中新集成的 DDPOTrainer 将 DDPO 应用到模型中,并讨论我们在 Stable Diffusion 上运行 DDPO 的发现。
DDPO 的优势
DDPO 并非解决 如何使用 RL 微调扩散模型 这一问题的唯一有效答案。
在进一步深入讨论之前,我们强调一下在对 RL 解决方案进行横评时需要掌握的两个关键点:
计算效率是关键。数据分布越复杂,计算成本就越高。
近似法很好,但由于近似值不是真实值,因此相关的错误会累积。
在 DDPO 之前,奖励加权回归 (Reward-Weighted Regression,RWR) 是使用强化学习微调扩散模型的主要方法。RWR 重用了扩散模型的去噪损失函数、从模型本身采样得的训练数据以及取决于最终生成样本的奖励的逐样本损失权重。该算法忽略中间的去噪步骤/样本。虽然有效,但应该注意两件事:
通过对逐样本损失进行加权来进行优化,这是一个最大似然目标,因此这是一种近似优化。
加权后的损失甚至不是精确的最大似然目标,而是从重新加权的变分界中得出的近似值。
所以,根本上来讲,这是一个两阶近似法,其对性能和处理复杂目标的能力都有比较大的影响。
DDPO 始于此方法,但 DDPO 没有将去噪过程视为仅关注最终样本的单个步骤,而是将整个去噪过程构建为多步马尔可夫决策过程 (MDP),只是在最后收到奖励而已。这样做的好处除了可以使用固定的采样器之外,还为让代理策略成为各向同性高斯分布 (而不是任意复杂的分布) 铺平了道路。因此,该方法不使用最终样本的近似似然 (即 RWR 的做法),而是使用易于计算的每个去噪步骤的确切似然 ( )。
如果你有兴趣了解有关 DDPO 的更多详细信息,我们鼓励你阅读 原论文 及其 附带的博文。
DDPO 算法简述
考虑到我们用 MDP 对去噪过程进行建模以及其他因素,求解该优化问题的首选工具是策略梯度方法。特别是近端策略优化 (PPO)。整个 DDPO 算法与近端策略优化 (PPO) 几乎相同,仅对 PPO 的轨迹收集部分进行了比较大的修改。
下图总结了整个算法流程:
DDPO 和 RLHF: 合力增强美观性
RLHF 的一般训练步骤如下:
有监督微调“基础”模型,以学习新数据的分布。
收集偏好数据并用它训练奖励模型。
使用奖励模型作为信号,通过强化学习对模型进行微调。
需要指出的是,在 RLHF 中偏好数据是获取人类反馈的主要来源。
DDPO 加进来后,整个工作流就变成了:
从预训练的扩散模型开始。
收集偏好数据并用它训练奖励模型。
使用奖励模型作为信号,通过 DDPO 微调模型
请注意,DDPO 工作流把原始 RLHF 工作流中的第 3 步省略了,这是因为经验表明 (后面你也会亲眼见证) 这是不需要的。
下面我们实战一下,训练一个扩散模型来输出更符合人类审美的图像,我们分以下几步来走:
从预训练的 Stable Diffusion (SD) 模型开始。
在 美学视觉分析 (Aesthetic Visual Analysis,AVA) 数据集上训练一个带有可训回归头的冻结 CLIP 模型,用于预测人们对输入图像的平均喜爱程度。
使用美学预测模型作为奖励信号,通过 DDPO 微调 SD 模型。
记住这些步骤,下面开始干活:
使用 DDPO 训练 Stable Diffusion
环境设置
首先,要成功使用 DDPO 训练模型,你至少需要一个英伟达 A100 GPU,低于此规格的 GPU 很容易遇到内存不足问题。
使用 pip 安装 trl 库
pip install trl[diffusers] 主库安装好后,再安装所需的训练过程跟踪和图像处理相关的依赖库。注意,安装完 wandb 后,请务必登录以将结果保存到个人帐户。
pip install wandb torchvision 注意: 如果不想用 wandb ,你也可以用 pip 安装 tensorboard 。
演练一遍
trl 库中负责 DDPO 训练的主要是 DDPOTrainer 和 DDPOConfig 这两个类。有关 DDPOTrainer 和 DDPOConfig 的更多信息,请参阅 相应文档。trl 代码库中有一个 示例训练脚本。它默认使用这两个类,并有一套默认的输入和参数用于微调 RunwayML 中的预训练 Stable Diffusion 模型。
此示例脚本使用 wandb 记录训练日志,并使用美学奖励模型,其权重是从公开的 Hugging Face 存储库读取的 (因此数据收集和美学奖励模型训练均已经帮你做完了)。默认提示数据是一系列动物名。
用户只需要一个命令行参数即可启动脚本。此外,用户需要有一个 Hugging Face 用户访问令牌,用于将微调后的模型上传到 Hugging Face Hub。
运行以下 bash 命令启动程序:
python stable_diffusion_tuning.py --hf_user_access_token <token> 下表列出了影响微调结果的关键超参数:
| 参数 | 描述 | 单 GPU 训练推荐值(迄今为止) |
|---|---|---|
num_epochs | 训练 epoch 数 | 200 |
train_batch_size | 训练 batch size | 3 |
sample_batch_size | 采样 batch size | 6 |
gradient_accumulation_steps | 梯度累积步数 | 1 |
sample_num_steps | 采样步数 | 50 |
sample_num_batches_per_epoch | 每个 epoch 的采样 batch 数 | 4 |
per_prompt_stat_tracking | 是否跟踪每个提示的统计信息。如果为 False,将使用整个 batch 的平均值和标准差来计算优势,而不是对每个提示进行跟踪 | True |
per_prompt_stat_tracking_buffer_size | 用于跟踪每个提示的统计数据的缓冲区大小 | 32 |
mixed_precision | 混合精度训练 | True |
train_learning_rate | 学习率 | 3e-4 |
这个脚本仅仅是一个起点。你可以随意调整超参数,甚至彻底修改脚本以适应不同的目标函数。例如,可以集成一个测量 JPEG 压缩度的函数或 使用多模态模型评估视觉文本对齐度的函数 等。
经验与教训
尽管训练提示很少,但其结果似乎已经足够泛化。对于美学奖励函数而言,该方法已经得到了彻底的验证。
尝试通过增加训练提示数以及改变提示来进一步泛化美学奖励函数,似乎反而会减慢收敛速度,但对模型的泛化能力收效甚微。
虽然推荐使用久经考验 LoRA,但非 LoRA 也值得考虑,一个经验证据就是,非 LoRA 似乎确实比 LoRA 能产生相对更复杂的图像。但同时,非 LoRA 训练的收敛稳定性不太好,对超参选择的要求也高很多。
对于非 LoRA 的超参建议是: 将学习率设低点,经验值是大约
1e-5,同时将mixed_ precision设置为None。
结果
以下是提示 bear 、 heaven 和 dune 微调前 (左) 、后 (右) 的输出 (每行都是一个提示的输出):
| 微调前 | 微调后 |
|---|---|
![]() | ![]() |
![]() | ![]() |
![]() | ![]() |
限制
目前
trl的DDPOTrainer仅限于微调原始 SD 模型;在我们的实验中,主要关注的是效果较好的 LoRA。我们也做了一些全模型训练的实验,其生成的质量会更好,但超参寻优更具挑战性。
总结
像 Stable Diffusion 这样的扩散模型,当使用 DDPO 进行微调时,可以显著提高图像的主观质感或其对应的指标,只要其可以表示成一个目标函数的形式。
DDPO 的计算效率及其不依赖近似优化的能力,在扩散模型微调方面远超之前的方法,因而成为微调扩散模型 (如 Stable Diffusion) 的有力候选。
trl 库的 DDPOTrainer 实现了 DDPO 以微调 SD 模型。
我们的实验表明 DDPO 对很多提示具有相当好的泛化能力,尽管进一步增加提示数以增强泛化似乎效果不大。为非 LoRA 微调找到正确超参的难度比较大,这也是我们得到的重要经验之一。
DDPO 是一种很有前途的技术,可以将扩散模型与任何奖励函数结合起来,我们希望通过其在 TRL 中的发布,社区可以更容易地使用它!
致谢
感谢 Chunte Lee 提供本博文的缩略图。
🤗 宝子们可以戳 阅读原文 查看文中所有的外部链接哟!
英文原文: https://hf.co/blog/trl-ddpo
原文作者: Luke Meyers,Sayak Paul,Kashif Rasul,Leandro von Werra
译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。
审校/排版: zhongdongy (阿东)
相关文章:
使用 DDPO 在 TRL 中微调 Stable Diffusion 模型
引言 扩散模型 (如 DALL-E 2、Stable Diffusion) 是一类文生图模型,在生成图像 (尤其是有照片级真实感的图像) 方面取得了广泛成功。然而,这些模型生成的图像可能并不总是符合人类偏好或人类意图。因此出现了对齐问题,即如何确保模型的输出与…...
cocosCreator 之 crypto-es数据加密
版本: 3.8.0 语言: TypeScript 环境: Mac 简介 项目开发中,针对于一些明文数据,比如本地存储和Http数据请求等,进行加密保护,是有必要的。 关于加密手段主要有: 对称加密 使用相…...
Leetcode---368周赛
题目列表 2908. 元素和最小的山形三元组 I 2909. 元素和最小的山形三元组 II 2910. 合法分组的最少组数 2911. 得到 K 个半回文串的最少修改次数 一、元素和最小的山形三元组I 没什么好说的,不会其他方法就直接暴力,时间复杂度O(n^3),代…...
矢量图形编辑软件Illustrator 2023 mac中文版软件特点(ai2023) v27.9
illustrator 2023 mac是一款矢量图形编辑软件,用于创建和编辑排版、图标、标志、插图和其他类型的矢量图形。 illustrator 2023 mac软件特点 矢量图形:illustrator创建的图形是矢量图形,可以无限放大而不失真,这与像素图形编辑软…...
一、Docker Compose——什么是 Docker Compose
Docker Compose 是一个用来定义和运行多容器 Docker 应用程序的工具,他的方便之处就是可以使用 YAML 文件来配置将要运行的 Docker 容器,然后使用一条命令即可创建并启动配置好的 Docker 容器了;相比手动输入命令的繁琐,Docker Co…...
Java提升技术,进阶为高级开发和架构师的路线
原文网址:Java提升技术,进阶为高级开发和架构师的路线-CSDN博客 简介 Java怎样提升技术?怎样进阶为高级开发和架构师?本文介绍靠谱的成长路线。 首先点明,只写业务代码是无法成长技术的。提升技术的两个方法是&…...
记一次 .Net+SqlSugar 查询超时的问题排查过程
环境和版本:.Net 6 SqlSuger 5.1.4.* ,数据库是mysql 5.7 ,数据量在2000多条左右 业务是一个非常简单的查询,代码如下: var list _dbClient.Queryable<tb_name>().ToList(); tb_name 下配置了一对多的关系…...
PHP危险函数
PHP危险函数 文章目录 PHP危险函数PHP 代码执行函数eval 语句assert()语句preg_replace()函数正则表达式里修饰符 回调函数call_user_func()函数array_map()函数 OS命令执行函数system()函数exec()函数shell_exec()函数passthru() 函数popen 函数反引号 实列 通过构造函数可以执…...
【ARM Cortex-M 系列 4 番外篇 -- 常用 benchmark 介绍】
文章目录 1.1 CPU 性能测试 MIPS 计算1.1.1 Cortex-M7 CPI 1.2 benchmark 小节1.3.1 Geekbenck 介绍 1.3 编译参数配置 1.1 CPU 性能测试 MIPS 计算 每秒百万指令数 (MIPS):在数据压缩测试中,MIPS 每秒测量一次 CPU 执行的低级指令的数量。越高越好&…...
web安全-原发抗抵赖
原发抗抵赖 原发抗抵赖也称不可否认性,主要表现以下两种形式: 数据发送者无法否认其发送数据的事实。例如,A向B发信,事后,A不能否认该信是其发送的。数据接收者事后无法否认其收到过这些数据。例如,A向B发…...
强化学习------PPO算法
目录 简介一、PPO原理1、由On-policy 转化为Off-policy2、Importance Sampling(重要性采样)3、off-policy下的梯度公式推导 二、PPO算法两种形式1、PPO-Penalty2、PPO-Clip 三、PPO算法实战四、参考 简介 PPO 算法之所以被提出,根本原因在于…...
node(三)express框架
文章目录 1.express介绍2.express初体验3.express路由3.1什么是路由?3.2路由的使用 1.express介绍 是一个基于Node平台的极简、灵活的WEB应用开发框架,官网地址:https://www.expressjs.com.cn/ 简单来说,express是一个封装好的工…...
linux find命令搜索日志内容
linux find命令搜索日志内容 查询服务器log日志 find /opt/logs/ -name "filename.log" | xargs grep -a "这里是要查询的字符"加上-a 是为了不报查出 binary 的错 服务器会返回 包含所查字符的整行日志信息...
CentOS 编译安装TinyXml2
安装 TinyXml2 Git 源码下载地址:https://github.com/leethomason/tinyxml2 步骤1:首先,你需要下载tinyxml2的源代码。你可以从Github或者源代码官方网站下载。并上传至/usr/local/source_code/ 步骤2:下载完成后,需要将源代码解…...
竞赛选题 深度学习人体跌倒检测 -yolo 机器视觉 opencv python
0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 **基于深度学习的人体跌倒检测算法研究与实现 ** 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🥇学长这里给一个题目综合评分(每项满…...
使用gson将复杂的树型结构转Json遇到的问题,写入文件为空
某个项目需要用到一个较为复杂的数据结构。定义成一个树型链表。 public class TreeNode { private String name; public String getName() { return name; } public void setName(String name) { this.name name; } public String getPartType() { retur…...
JavaScript异步编程:提升性能与用户体验
目录 什么是异步编程? 回调函数 Promise Async/Await 总结 在Web开发中,处理耗时操作是一项重要的任务。如果我们在执行这些操作时阻塞了主线程,会导致页面失去响应,用户体验下降。JavaScript异步编程则可以解决这个问题&…...
lossBN
still tips for learning classification and regression关于softmax的引入和作用分类问题损失函数 - MSE & Cross-entropy⭐Batch Normalization(BN)⭐想法:直接改error surface的landscape,把山铲平feature normalization那…...
【微信小程序】数字化会议OA系统之投票模块(附源码)
🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是Java方文山,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的专栏《微信小程序开发实战》。🎯Ἲ…...
clang-前端插件-给各种无花括号的“块”加花括号-基于llvm15--clang-plugin-add-brace
处理的语句 case 术语约定或备忘 case起止范围: 从冒号到下一个’case’开头, 简称有: case内 、case内容Ast: Abstract syntax tree: 抽象语法树没插入花括号的case 若case内, 以下任一条成立,则 跳过该case 即 不会对该case内容用花括号包裹. 有#define、有#include、有…...
万象视界灵坛快速部署:GitLab CI流水线自动触发镜像构建与K8s滚动更新
万象视界灵坛快速部署:GitLab CI流水线自动触发镜像构建与K8s滚动更新 1. 项目概述 万象视界灵坛(Omni-Vision Sanctuary)是一款基于OpenAI CLIP技术的高级多模态智能感知平台。该平台通过创新的像素风格界面,将复杂的语义对齐过…...
TVBoxOSC:电视盒子全能播放解决方案终极指南
TVBoxOSC:电视盒子全能播放解决方案终极指南 【免费下载链接】TVBoxOSC TVBoxOSC - 一个基于第三方项目的代码库,用于电视盒子的控制和管理。 项目地址: https://gitcode.com/GitHub_Trending/tv/TVBoxOSC 你是否曾经为电视盒子播放视频时遇到格式…...
告别答辩 PPT 熬夜局!PaperXie AI 一键生成,3 分钟拿捏学术范答辩神器
paperxie-免费查重复率aigc检测/开题报告/毕业论文/智能排版/文献综述/AIPPThttps://www.paperxie.cn/ppt/createhttps://www.paperxie.cn/ppt/create 一、开题答辩人破防瞬间:PPT 做得好,答辩分数高一半 “论文写完了,PPT 才是真正的修罗场…...
OpCore-Simplify:突破性黑苹果EFI配置革命,15分钟完成专业级系统搭建 [特殊字符]
OpCore-Simplify:突破性黑苹果EFI配置革命,15分钟完成专业级系统搭建 🚀 【免费下载链接】OpCore-Simplify A tool designed to simplify the creation of OpenCore EFI 项目地址: https://gitcode.com/GitHub_Trending/op/OpCore-Simplify…...
为什么Python社区推荐用pipx替代pip?以virtualenv安装为例演示工作流
为什么Python开发者应该用pipx替代pip?以virtualenv为例的完整隔离方案 当你在Ubuntu终端输入pip install virtualenv时,那个刺眼的externally-managed-environment错误提示就像一堵墙——这不是技术故障,而是Python生态进化的重要路标。传统…...
如何用网盘直链下载助手突破限制提升效率:5个实用技巧
如何用网盘直链下载助手突破限制提升效率:5个实用技巧 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 ,支持 百度网盘 / 阿里云盘 / 中国移动云盘 / 天翼…...
实战指南:在快马平台用trae构建电商购物车状态管理系统
今天想和大家分享一个实战项目:用trae在电商场景下构建购物车状态管理系统。这个方案特别适合需要清晰数据流的中小型项目,比如电商平台、管理后台等。下面我会详细拆解整个实现过程,希望能给有类似需求的同学一些参考。 项目结构设计 首先…...
RTX 3090环境下的BEVFusion实战部署:从源码编译到多模态训练调优
1. RTX 3090环境准备与BEVFusion适配 在RTX 3090上部署BEVFusion最大的挑战就是硬件与软件版本的兼容性问题。官方推荐的环境是CUDA 9.2和PyTorch 1.3.1,但这对于RTX 3090来说完全不适用——30系显卡需要CUDA 11才能发挥全部性能。我刚开始尝试直接按照官方文档安装…...
DOL-CHS-MODS:开源工具助力游戏体验一键优化
DOL-CHS-MODS:开源工具助力游戏体验一键优化 【免费下载链接】DOL-CHS-MODS Degrees of Lewdity 整合 项目地址: https://gitcode.com/gh_mirrors/do/DOL-CHS-MODS 您是否在为游戏汉化过程中的繁琐配置而头疼?是否曾因美化补丁安装不当导致游戏崩…...
从零搭建无人船:两年实战后,我总结的ArduPilot+Pixhawk避坑全流程
从零搭建无人船:两年实战后,我总结的ArduPilotPixhawk避坑全流程 第一次把无人船放进水里时,GPS信号突然丢失,船体在河中央失控打转——这个惊心动魄的瞬间让我意识到,开源飞控的实战应用远不是下载代码、连接硬件那么…...





