当前位置: 首页 > news >正文

Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance
Github

摘要

近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导(CFG)。这些技术在无条件生成或诸如图像恢复等各种下游任务中往往并不适用。在本文中,我们提出了一种新颖的采样引导方法,称为Perturbed-Attention Guidance(PAG),它能在无条件和条件设置下提高扩散样本的质量,并且无需额外的训练或集成外部模块。PAG 旨在通过去噪过程逐步增强样本的结构。它通过用单位矩阵替换 UNet 中的self-attention map来生成结构退化的中间样本,这是考虑到自注意力机制捕捉结构信息的能力,并引导去噪过程远离这些退化样本。在 ADM 和 Stable Diffusion 中,PAG 在条件甚至无条件场景下都显著提高了样本质量。此外,在诸如空提示的 ControlNet 以及图像修复(如修补和去模糊)等现有引导(如 CG 或 CFG)无法充分利用的各种下游任务中,PAG 也显著提高了基线性能。
在这里插入图片描述
研究表明,在diffusion U-Net的self-attention 模块中,query-key 主要影响structure ,values主要影响appearance。
在这里插入图片描述
如果直接扰动Vt 的话,会导致 out-of-distribution (OOD),因此选择使用单位矩阵替换query-key 部分。
在这里插入图片描述
在这里插入图片描述
那么具体扰动Unet的哪一部分呢?作者使用了5k个样本,在PAG guidance scale s = 2.5 and DDIM 25 step的条件下,表现最好的是mid-block “m0”
在这里插入图片描述

代码

Diffusers 已经支持PAG用在多种任务中,并且可以和ControlNet、 IP-Adapter 一起使用。

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torchpipeline = AutoPipelineForText2Image.from_pretrained("~/.cache/modelscope/hub/AI-ModelScope/stable-diffusion-xl-base-1___0",enable_pag=True,  ##addpag_applied_layers=["mid"], ##addtorch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()prompt = "an insect robot preparing a delicious meal, anime style"
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(prompt=prompt,num_inference_steps=25,guidance_scale=7.0,generator=generator,pag_scale=2.5,
).imagesimages[0].save("pag.jpg")

PAG代码细节

如果同时使用PAG和CFG,那么输入到Unet中prompt_embeds定义如下,也就是[uncond,cond,cond]

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):cond = torch.cat([cond] * 2, dim=0)if do_classifier_free_guidance:cond = torch.cat([uncond, cond], dim=0)return cond

PAGCFGIdentitySelfAttnProcessor2_0计算,其中[uncond,cond]正常计算SA,第二个cond则计算PSA。

class PAGCFGIdentitySelfAttnProcessor2_0:r"""Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).PAG reference: https://arxiv.org/abs/2403.17377"""def __init__(self):if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: Optional[torch.FloatTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,temb: Optional[torch.FloatTensor] = None,) -> torch.Tensor:residual = hidden_statesif attn.spatial_norm is not None:hidden_states = attn.spatial_norm(hidden_states, temb)input_ndim = hidden_states.ndimif input_ndim == 4:batch_size, channel, height, width = hidden_states.shapehidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# chunkhidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])# original pathbatch_size, sequence_length, _ = hidden_states_org.shapeif attention_mask is not None:attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# scaled_dot_product_attention expects attention_mask shape to be# (batch, heads, source_length, target_length)attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])if attn.group_norm is not None:hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)query = attn.to_q(hidden_states_org)key = attn.to_k(hidden_states_org)value = attn.to_v(hidden_states_org)inner_dim = key.shape[-1]head_dim = inner_dim // attn.headsquery = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# the output of sdp = (batch, num_heads, seq_len, head_dim)# TODO: add support for attn.scale when we move to Torch 2.1hidden_states_org = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)hidden_states_org = hidden_states_org.to(query.dtype)# linear projhidden_states_org = attn.to_out[0](hidden_states_org)# dropouthidden_states_org = attn.to_out[1](hidden_states_org)if input_ndim == 4:hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)# perturbed path (identity attention)batch_size, sequence_length, _ = hidden_states_ptb.shapeif attn.group_norm is not None:hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)value = attn.to_v(hidden_states_ptb)hidden_states_ptb = valuehidden_states_ptb = hidden_states_ptb.to(query.dtype)# linear projhidden_states_ptb = attn.to_out[0](hidden_states_ptb)# dropouthidden_states_ptb = attn.to_out[1](hidden_states_ptb)if input_ndim == 4:hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)# cathidden_states = torch.cat([hidden_states_org, hidden_states_ptb])if attn.residual_connection:hidden_states = hidden_states + residualhidden_states = hidden_states / attn.rescale_output_factorreturn hidden_states

经过Unet后,noise_pred的计算方法。

    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False):r"""Apply perturbed attention guidance to the noise prediction.Args:noise_pred (torch.Tensor): The noise prediction tensor.do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.guidance_scale (float): The scale factor for the guidance term.t (int): The current time step.return_pred_text (bool): Whether to return the text noise prediction.Returns:Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applyingperturbed attention guidance and the text noise prediction."""pag_scale = self._get_pag_scale(t)if do_classifier_free_guidance:noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)noise_pred = (noise_pred_uncond+ guidance_scale * (noise_pred_text - noise_pred_uncond)+ pag_scale * (noise_pred_text - noise_pred_perturb))else:noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)if return_pred_text:return noise_pred, noise_pred_textreturn noise_pred

相关文章:

Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance Github 摘要 近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导&#xff…...

自动驾驶控制与规划——Project 6: A* Route Planning

目录 零、任务介绍一、算法原理1.1 A* Algorithm1.2 启发函数 二、代码实现三、结果分析四、效果展示4.1 Dijkstra距离4.2 Manhatten距离4.3 欧几里德距离4.4 对角距离 五、后记 零、任务介绍 carla-ros-bridge/src/ros-bridge/carla_shenlan_projects/carla_shenlan_a_star_p…...

通俗易懂之线性回归时序预测PyTorch实践

线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入…...

[离线数仓] 总结二、Hive数仓分层开发

接 [离线数仓] 总结一、数据采集 5.8 数仓开发之ODS层 ODS层的设计要点如下: (1)ODS层的表结构设计依托于从业务系统同步过来的数据结构。 (2)ODS层要保存全部历史数据,故其压缩格式应选择压缩比率,较高的,此处选择gzip。 CompressedStorage - Apache Hive - Apac…...

页面顶部导航栏(Navbar)的功能(Navbar/index.vue)

这段代码是一个 Vue.js 组件&#xff0c;实现了页面顶部导航栏&#xff08;Navbar&#xff09;的功能。我将分块分析它的各个部分&#xff1a; 模板 (Template): <!-- spid-admin/src/layout/components/Navbar/index.vue --> <template><div class"navb…...

thinnkphp5.1和 thinkphp6以及nginx,apache 解决跨域问题

ThinkPHP 5.1 使用中间件设置响应头 ThinkPHP 5.1 及以上版本支持中间件&#xff0c;可以通过中间件统一设置跨域响应头。 步骤&#xff1a; 创建一个中间件文件&#xff0c;例如 CorsMiddleware.php&#xff1a; namespace app\middleware;class CorsMiddleware {public fu…...

vue2新增删除

&#xff08;只是页面实现&#xff0c;不涉及数据库&#xff09; list组件&#xff1a; <button click"onAdd">新增</button><el-table:header-cell-style"{ textAlign: center }" :cell-style"{ textAlign: center }":data&quo…...

测试ip端口-telnet开启与使用

前言 开发过程中我们总会要去测试ip通不通&#xff0c;或者ip下某个端口是否可以联通&#xff0c;为此我们可以使用telnet 命令来实现。 一、telnet 开启 可能有些人使用telnet报错&#xff0c;不是内部命令&#xff0c;可以如下开启&#xff1a; 1、打开控制面板&#xff…...

Python爬虫基础——XPath表达式

首先说一下这节内容在学习过程中存在的问题吧&#xff0c;在爬取百度网页文字时&#xff0c;出现了问题&#xff0c;就是通过表达式在网页搜索中可以定位&#xff0c;但是通过代码无法定位&#xff0c;请教了一位老师&#xff0c;他说是动态链接&#xff0c;目前这部分内容比较…...

ansible-性能优化

一. 简述&#xff1a; 搞过运维自动化工具的人&#xff0c;肯定会发现很多运维伙伴们经常用saltstack和ansible做比较&#xff0c;单从执行效率上来说&#xff0c;ansible确实比不上saltstack(ansible使用的是ssh,salt使用的是zeromq消息队列[暂没深入了解])&#xff0c;但其实…...

高等数学学习笔记 ☞ 一元函数微分的基础知识

1. 微分的定义 &#xff08;1&#xff09;定义&#xff1a;设函数在点的某领域内有定义&#xff0c;取附近的点&#xff0c;对应的函数值分别为和&#xff0c; 令&#xff0c;若可以表示成&#xff0c;则称函数在点是可微的。 【 若函数在点是可微的&#xff0c;则可以表达为】…...

前后端实现防抖节流实现

在前端和 Java 后端中实现防抖&#xff08;Debounce&#xff09;和节流&#xff08;Throttle&#xff09;主要用于减少频繁请求或事件触发对系统的压力。前端和后端的实现方式有些不同&#xff0c;以下是两种方法的具体实现&#xff1a; 1. 前端实现防抖和节流 在前端中&…...

【笔记】算法记录

1、求一个数的素因子&#xff08;试除法&#xff09; // 获取一个数的所有素因子 set<int> getPrimeFactors(int num) {set<int> primeFactors;for (int i 2; i * i < num; i) {while (num % i 0) {primeFactors.insert(i);num / i;}}if (num > 1) {prime…...

【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析

文章目录 一、选择题二、理论题三、实操题 【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析 一、选择题 生成树协议的主要作用是 B. 防止网络环路解释&#xff1a;生成树协议&#xff08;STP&#xff09;的主要目的是防止网络中…...

git push -f 指定分支

要将本地代码推送到指定的远程分支&#xff0c;你可以使用以下步骤和命令&#xff1a; 确认远程仓库&#xff1a; 确保你的本地仓库已经与远程仓库关联。你可以使用以下命令查看当前的远程仓库状态&#xff1a; git remote -v查看本地分支&#xff1a; 使用命令查看当前存在的本…...

CTF知识点总结(二)

异或注入&#xff1a;两个条件相同&#xff08;同真或同假&#xff09;即为假。 http://120.24.86.145:9004/1ndex.php?id1^(length(union)!0)-- 如上&#xff0c;如果union被过滤&#xff0c;则 length(union)!0 为假&#xff0c;那么返回页面正常。 2|0updatexml() 函数报…...

解决Edge打开PDF总是没有焦点

【问题描述】 使用Edge浏览器作为默认PDF阅读器打开本地PDF文件&#xff0c;Edge窗口总是不获得焦点&#xff0c;而是在任务栏以橙色显示&#xff0c;需要再手动点击一次才能查看文件内容。 本强迫症来治一治这个问题&#xff01; 【解决方法】 GPT老师指出问题出在Edge的启动…...

69.基于SpringBoot + Vue实现的前后端分离-家乡特色推荐系统(项目 + 论文PPT)

项目介绍 在Internet高速发展的今天&#xff0c;我们生活的各个领域都涉及到计算机的应用&#xff0c;其中包括家乡特色推荐的网络应用&#xff0c;在外国家乡特色推荐系统已经是很普遍的方式&#xff0c;不过国内的管理网站可能还处于起步阶段。家乡特色推荐系统采用java技术&…...

计算机视觉目标检测-DETR网络

目录 摘要abstractDETR目标检测网络详解二分图匹配和损失函数 DETR总结总结 摘要 DETR&#xff08;DEtection TRansformer&#xff09;是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题&#xff0c;摒弃了锚框设计和非…...

《自动驾驶与机器人中的SLAM技术》ch1:自动驾驶

目录 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 L2 在技术实现上会更倾向于实时感知&#xff0c;乃至可以使用感知结果直接构建鸟瞰图&#xff08;bird eye view, BEV&#xff09;&#xff0c;而 L4 则依赖离线地图。 高精地…...

【清华代码熊】图解 Gemma 4 架构设计细节

&#x1f4cc; 本期图解 Google 开源Gemma 4 架构设计细节&#xff0c;其中端侧模型的架构上有很多值得一看的设计。...

[具身智能-298]:深度神经网络实现语音识别的库、模型、方案

在深度神经网络时代&#xff0c;实现语音识别&#xff08;ASR&#xff09;已经不再需要从零开始编写底层算法&#xff0c;而是更多地依赖于成熟的开源库、预训练模型以及高效的工程化方案。基于最新的行业实践&#xff08;截至2026年4月&#xff09;&#xff0c;我为你梳理了目…...

ESP居然能当 DNS 服务器用?内含NCSI欺骗和DNS劫持实现妆

前言 Kubernetes 本身并不复杂&#xff0c;是我们把它搞复杂的。无论是刻意为之还是那种虽然出于好意却将优雅的原语堆砌成 鲁布戈德堡机械 的狂热。平台最初提供的 ReplicaSets、Services、ConfigMaps&#xff0c;这些基础组件简单直接&#xff0c;甚至显得有些枯燥。但后来我…...

第6章 黎曼流形优化与几何方法

第6章 黎曼流形优化与几何方法 6.1 黎曼几何基础 6.1.1 复Stiefel流形与单位模流形&#xff08;Unit-Modulus Manifold&#xff09;度量 6.1.2 指数映射&#xff08;Exponential Mapping&#xff09;与平行移动&#xff08;Parallel Transport&#xff09; 6.1.3 测…...

告别ArcGIS!用GEE+QGIS搞定流域DEM下载与地形分析(附完整代码)

告别ArcGIS&#xff01;用GEEQGIS搞定流域DEM下载与地形分析&#xff08;附完整代码&#xff09; 在GIS领域&#xff0c;数字高程模型&#xff08;DEM&#xff09;是地形分析的基础数据。传统上&#xff0c;ArcGIS凭借其完善的功能和稳定的性能&#xff0c;成为DEM处理的首选工…...

营销自动化数据驱动 - 多源数据 OLAP 架构演进嘉

1. 流图&#xff1a;数据的河流 如果把传统的堆叠面积图想象成一块块整齐堆叠的积木&#xff0c;那么流图就像一条蜿蜒流淌的河流&#xff0c;河道的宽窄变化自然流畅&#xff0c;波峰波谷过渡平滑。 它特别适合展示多个类别数据随时间的变化趋势&#xff0c;尤其是当你想强调整…...

SPI扩展CAN方案:从寄存器配置到多路通信实战

1. SPI扩展CAN方案的核心价值 在工业控制领域&#xff0c;CAN总线因其高可靠性和实时性被广泛使用。但随着设备节点增加&#xff0c;主控芯片原生CAN接口往往不够用。这时通过SPI接口扩展CAN通道就成了性价比极高的解决方案。我曾在多个工业现场实测&#xff0c;用10元级的MCP2…...

第十五届题目

握手问题 #include <stdio.h> #include <stdlib.h>int main(int argc, char *argv[]) {int sum0;for(int i49;i>7;i--){sumi;}printf("%d",sum);return 0; } 小球反弹 #include <stdio.h> #include <math.h>int main(int argc, char *ar…...

AI赋能智能制造:预测性维护在工业4.0中的落地实践

1. 预测性维护&#xff1a;从被动维修到智能预防的革命 想象一下&#xff0c;你家的空调突然在炎热的夏天罢工了&#xff0c;维修师傅告诉你&#xff1a;"这个零件本来三个月前就该换了"。这种场景在工业生产中放大1000倍&#xff0c;就是传统维护方式带来的痛点。预…...

AI Agent 赋能智能客服:Vue3 + LangChain + 千问落地实战

前言 &#x1f44b; 本文适合有前端基础的开发者阅读。我会从整体架构出发&#xff0c;详解如何用Vue3 TypeScript做前端交互、**Python3 LangChain 千问&#xff08;Qwen&#xff09;**做后端推理&#xff0c;构建一个真正能落地的智能客服 Agent。代码干货较多&#xff0…...