当前位置: 首页 > news >正文

Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance
Github

摘要

近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导(CFG)。这些技术在无条件生成或诸如图像恢复等各种下游任务中往往并不适用。在本文中,我们提出了一种新颖的采样引导方法,称为Perturbed-Attention Guidance(PAG),它能在无条件和条件设置下提高扩散样本的质量,并且无需额外的训练或集成外部模块。PAG 旨在通过去噪过程逐步增强样本的结构。它通过用单位矩阵替换 UNet 中的self-attention map来生成结构退化的中间样本,这是考虑到自注意力机制捕捉结构信息的能力,并引导去噪过程远离这些退化样本。在 ADM 和 Stable Diffusion 中,PAG 在条件甚至无条件场景下都显著提高了样本质量。此外,在诸如空提示的 ControlNet 以及图像修复(如修补和去模糊)等现有引导(如 CG 或 CFG)无法充分利用的各种下游任务中,PAG 也显著提高了基线性能。
在这里插入图片描述
研究表明,在diffusion U-Net的self-attention 模块中,query-key 主要影响structure ,values主要影响appearance。
在这里插入图片描述
如果直接扰动Vt 的话,会导致 out-of-distribution (OOD),因此选择使用单位矩阵替换query-key 部分。
在这里插入图片描述
在这里插入图片描述
那么具体扰动Unet的哪一部分呢?作者使用了5k个样本,在PAG guidance scale s = 2.5 and DDIM 25 step的条件下,表现最好的是mid-block “m0”
在这里插入图片描述

代码

Diffusers 已经支持PAG用在多种任务中,并且可以和ControlNet、 IP-Adapter 一起使用。

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torchpipeline = AutoPipelineForText2Image.from_pretrained("~/.cache/modelscope/hub/AI-ModelScope/stable-diffusion-xl-base-1___0",enable_pag=True,  ##addpag_applied_layers=["mid"], ##addtorch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()prompt = "an insect robot preparing a delicious meal, anime style"
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(prompt=prompt,num_inference_steps=25,guidance_scale=7.0,generator=generator,pag_scale=2.5,
).imagesimages[0].save("pag.jpg")

PAG代码细节

如果同时使用PAG和CFG,那么输入到Unet中prompt_embeds定义如下,也就是[uncond,cond,cond]

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):cond = torch.cat([cond] * 2, dim=0)if do_classifier_free_guidance:cond = torch.cat([uncond, cond], dim=0)return cond

PAGCFGIdentitySelfAttnProcessor2_0计算,其中[uncond,cond]正常计算SA,第二个cond则计算PSA。

class PAGCFGIdentitySelfAttnProcessor2_0:r"""Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).PAG reference: https://arxiv.org/abs/2403.17377"""def __init__(self):if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: Optional[torch.FloatTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,temb: Optional[torch.FloatTensor] = None,) -> torch.Tensor:residual = hidden_statesif attn.spatial_norm is not None:hidden_states = attn.spatial_norm(hidden_states, temb)input_ndim = hidden_states.ndimif input_ndim == 4:batch_size, channel, height, width = hidden_states.shapehidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# chunkhidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])# original pathbatch_size, sequence_length, _ = hidden_states_org.shapeif attention_mask is not None:attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# scaled_dot_product_attention expects attention_mask shape to be# (batch, heads, source_length, target_length)attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])if attn.group_norm is not None:hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)query = attn.to_q(hidden_states_org)key = attn.to_k(hidden_states_org)value = attn.to_v(hidden_states_org)inner_dim = key.shape[-1]head_dim = inner_dim // attn.headsquery = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# the output of sdp = (batch, num_heads, seq_len, head_dim)# TODO: add support for attn.scale when we move to Torch 2.1hidden_states_org = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)hidden_states_org = hidden_states_org.to(query.dtype)# linear projhidden_states_org = attn.to_out[0](hidden_states_org)# dropouthidden_states_org = attn.to_out[1](hidden_states_org)if input_ndim == 4:hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)# perturbed path (identity attention)batch_size, sequence_length, _ = hidden_states_ptb.shapeif attn.group_norm is not None:hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)value = attn.to_v(hidden_states_ptb)hidden_states_ptb = valuehidden_states_ptb = hidden_states_ptb.to(query.dtype)# linear projhidden_states_ptb = attn.to_out[0](hidden_states_ptb)# dropouthidden_states_ptb = attn.to_out[1](hidden_states_ptb)if input_ndim == 4:hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)# cathidden_states = torch.cat([hidden_states_org, hidden_states_ptb])if attn.residual_connection:hidden_states = hidden_states + residualhidden_states = hidden_states / attn.rescale_output_factorreturn hidden_states

经过Unet后,noise_pred的计算方法。

    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False):r"""Apply perturbed attention guidance to the noise prediction.Args:noise_pred (torch.Tensor): The noise prediction tensor.do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.guidance_scale (float): The scale factor for the guidance term.t (int): The current time step.return_pred_text (bool): Whether to return the text noise prediction.Returns:Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applyingperturbed attention guidance and the text noise prediction."""pag_scale = self._get_pag_scale(t)if do_classifier_free_guidance:noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)noise_pred = (noise_pred_uncond+ guidance_scale * (noise_pred_text - noise_pred_uncond)+ pag_scale * (noise_pred_text - noise_pred_perturb))else:noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)if return_pred_text:return noise_pred, noise_pred_textreturn noise_pred

相关文章:

Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance Github 摘要 近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导&#xff…...

自动驾驶控制与规划——Project 6: A* Route Planning

目录 零、任务介绍一、算法原理1.1 A* Algorithm1.2 启发函数 二、代码实现三、结果分析四、效果展示4.1 Dijkstra距离4.2 Manhatten距离4.3 欧几里德距离4.4 对角距离 五、后记 零、任务介绍 carla-ros-bridge/src/ros-bridge/carla_shenlan_projects/carla_shenlan_a_star_p…...

通俗易懂之线性回归时序预测PyTorch实践

线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入…...

[离线数仓] 总结二、Hive数仓分层开发

接 [离线数仓] 总结一、数据采集 5.8 数仓开发之ODS层 ODS层的设计要点如下: (1)ODS层的表结构设计依托于从业务系统同步过来的数据结构。 (2)ODS层要保存全部历史数据,故其压缩格式应选择压缩比率,较高的,此处选择gzip。 CompressedStorage - Apache Hive - Apac…...

页面顶部导航栏(Navbar)的功能(Navbar/index.vue)

这段代码是一个 Vue.js 组件&#xff0c;实现了页面顶部导航栏&#xff08;Navbar&#xff09;的功能。我将分块分析它的各个部分&#xff1a; 模板 (Template): <!-- spid-admin/src/layout/components/Navbar/index.vue --> <template><div class"navb…...

thinnkphp5.1和 thinkphp6以及nginx,apache 解决跨域问题

ThinkPHP 5.1 使用中间件设置响应头 ThinkPHP 5.1 及以上版本支持中间件&#xff0c;可以通过中间件统一设置跨域响应头。 步骤&#xff1a; 创建一个中间件文件&#xff0c;例如 CorsMiddleware.php&#xff1a; namespace app\middleware;class CorsMiddleware {public fu…...

vue2新增删除

&#xff08;只是页面实现&#xff0c;不涉及数据库&#xff09; list组件&#xff1a; <button click"onAdd">新增</button><el-table:header-cell-style"{ textAlign: center }" :cell-style"{ textAlign: center }":data&quo…...

测试ip端口-telnet开启与使用

前言 开发过程中我们总会要去测试ip通不通&#xff0c;或者ip下某个端口是否可以联通&#xff0c;为此我们可以使用telnet 命令来实现。 一、telnet 开启 可能有些人使用telnet报错&#xff0c;不是内部命令&#xff0c;可以如下开启&#xff1a; 1、打开控制面板&#xff…...

Python爬虫基础——XPath表达式

首先说一下这节内容在学习过程中存在的问题吧&#xff0c;在爬取百度网页文字时&#xff0c;出现了问题&#xff0c;就是通过表达式在网页搜索中可以定位&#xff0c;但是通过代码无法定位&#xff0c;请教了一位老师&#xff0c;他说是动态链接&#xff0c;目前这部分内容比较…...

ansible-性能优化

一. 简述&#xff1a; 搞过运维自动化工具的人&#xff0c;肯定会发现很多运维伙伴们经常用saltstack和ansible做比较&#xff0c;单从执行效率上来说&#xff0c;ansible确实比不上saltstack(ansible使用的是ssh,salt使用的是zeromq消息队列[暂没深入了解])&#xff0c;但其实…...

高等数学学习笔记 ☞ 一元函数微分的基础知识

1. 微分的定义 &#xff08;1&#xff09;定义&#xff1a;设函数在点的某领域内有定义&#xff0c;取附近的点&#xff0c;对应的函数值分别为和&#xff0c; 令&#xff0c;若可以表示成&#xff0c;则称函数在点是可微的。 【 若函数在点是可微的&#xff0c;则可以表达为】…...

前后端实现防抖节流实现

在前端和 Java 后端中实现防抖&#xff08;Debounce&#xff09;和节流&#xff08;Throttle&#xff09;主要用于减少频繁请求或事件触发对系统的压力。前端和后端的实现方式有些不同&#xff0c;以下是两种方法的具体实现&#xff1a; 1. 前端实现防抖和节流 在前端中&…...

【笔记】算法记录

1、求一个数的素因子&#xff08;试除法&#xff09; // 获取一个数的所有素因子 set<int> getPrimeFactors(int num) {set<int> primeFactors;for (int i 2; i * i < num; i) {while (num % i 0) {primeFactors.insert(i);num / i;}}if (num > 1) {prime…...

【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析

文章目录 一、选择题二、理论题三、实操题 【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析 一、选择题 生成树协议的主要作用是 B. 防止网络环路解释&#xff1a;生成树协议&#xff08;STP&#xff09;的主要目的是防止网络中…...

git push -f 指定分支

要将本地代码推送到指定的远程分支&#xff0c;你可以使用以下步骤和命令&#xff1a; 确认远程仓库&#xff1a; 确保你的本地仓库已经与远程仓库关联。你可以使用以下命令查看当前的远程仓库状态&#xff1a; git remote -v查看本地分支&#xff1a; 使用命令查看当前存在的本…...

CTF知识点总结(二)

异或注入&#xff1a;两个条件相同&#xff08;同真或同假&#xff09;即为假。 http://120.24.86.145:9004/1ndex.php?id1^(length(union)!0)-- 如上&#xff0c;如果union被过滤&#xff0c;则 length(union)!0 为假&#xff0c;那么返回页面正常。 2|0updatexml() 函数报…...

解决Edge打开PDF总是没有焦点

【问题描述】 使用Edge浏览器作为默认PDF阅读器打开本地PDF文件&#xff0c;Edge窗口总是不获得焦点&#xff0c;而是在任务栏以橙色显示&#xff0c;需要再手动点击一次才能查看文件内容。 本强迫症来治一治这个问题&#xff01; 【解决方法】 GPT老师指出问题出在Edge的启动…...

69.基于SpringBoot + Vue实现的前后端分离-家乡特色推荐系统(项目 + 论文PPT)

项目介绍 在Internet高速发展的今天&#xff0c;我们生活的各个领域都涉及到计算机的应用&#xff0c;其中包括家乡特色推荐的网络应用&#xff0c;在外国家乡特色推荐系统已经是很普遍的方式&#xff0c;不过国内的管理网站可能还处于起步阶段。家乡特色推荐系统采用java技术&…...

计算机视觉目标检测-DETR网络

目录 摘要abstractDETR目标检测网络详解二分图匹配和损失函数 DETR总结总结 摘要 DETR&#xff08;DEtection TRansformer&#xff09;是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题&#xff0c;摒弃了锚框设计和非…...

《自动驾驶与机器人中的SLAM技术》ch1:自动驾驶

目录 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 L2 在技术实现上会更倾向于实时感知&#xff0c;乃至可以使用感知结果直接构建鸟瞰图&#xff08;bird eye view, BEV&#xff09;&#xff0c;而 L4 则依赖离线地图。 高精地…...

零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?

一、核心优势&#xff1a;专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发&#xff0c;是一款收费低廉但功能全面的Windows NAS工具&#xff0c;主打“无学习成本部署” 。与其他NAS软件相比&#xff0c;其优势在于&#xff1a; 无需硬件改造&#xff1a;将任意W…...

rknn优化教程(二)

文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK&#xff0c;开始写第二篇的内容了。这篇博客主要能写一下&#xff1a; 如何给一些三方库按照xmake方式进行封装&#xff0c;供调用如何按…...

【项目实战】通过多模态+LangGraph实现PPT生成助手

PPT自动生成系统 基于LangGraph的PPT自动生成系统&#xff0c;可以将Markdown文档自动转换为PPT演示文稿。 功能特点 Markdown解析&#xff1a;自动解析Markdown文档结构PPT模板分析&#xff1a;分析PPT模板的布局和风格智能布局决策&#xff1a;匹配内容与合适的PPT布局自动…...

Java多线程实现之Callable接口深度解析

Java多线程实现之Callable接口深度解析 一、Callable接口概述1.1 接口定义1.2 与Runnable接口的对比1.3 Future接口与FutureTask类 二、Callable接口的基本使用方法2.1 传统方式实现Callable接口2.2 使用Lambda表达式简化Callable实现2.3 使用FutureTask类执行Callable任务 三、…...

C++中string流知识详解和示例

一、概览与类体系 C 提供三种基于内存字符串的流&#xff0c;定义在 <sstream> 中&#xff1a; std::istringstream&#xff1a;输入流&#xff0c;从已有字符串中读取并解析。std::ostringstream&#xff1a;输出流&#xff0c;向内部缓冲区写入内容&#xff0c;最终取…...

MySQL 知识小结(一)

一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库&#xff0c;分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷&#xff0c;但是文件存放起来数据比较冗余&#xff0c;用二进制能够更好管理咱们M…...

Proxmox Mail Gateway安装指南:从零开始配置高效邮件过滤系统

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐&#xff1a;「storms…...

给网站添加live2d看板娘

给网站添加live2d看板娘 参考文献&#xff1a; stevenjoezhang/live2d-widget: 把萌萌哒的看板娘抱回家 (ノ≧∇≦)ノ | Live2D widget for web platformEikanya/Live2d-model: Live2d model collectionzenghongtu/live2d-model-assets 前言 网站环境如下&#xff0c;文章也主…...

《Offer来了:Java面试核心知识点精讲》大纲

文章目录 一、《Offer来了:Java面试核心知识点精讲》的典型大纲框架Java基础并发编程JVM原理数据库与缓存分布式架构系统设计二、《Offer来了:Java面试核心知识点精讲(原理篇)》技术文章大纲核心主题:Java基础原理与面试高频考点Java虚拟机(JVM)原理Java并发编程原理Jav…...

EEG-fNIRS联合成像在跨频率耦合研究中的创新应用

摘要 神经影像技术对医学科学产生了深远的影响&#xff0c;推动了许多神经系统疾病研究的进展并改善了其诊断方法。在此背景下&#xff0c;基于神经血管耦合现象的多模态神经影像方法&#xff0c;通过融合各自优势来提供有关大脑皮层神经活动的互补信息。在这里&#xff0c;本研…...