当前位置: 首页 > news >正文

Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance
Github

摘要

近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导(CFG)。这些技术在无条件生成或诸如图像恢复等各种下游任务中往往并不适用。在本文中,我们提出了一种新颖的采样引导方法,称为Perturbed-Attention Guidance(PAG),它能在无条件和条件设置下提高扩散样本的质量,并且无需额外的训练或集成外部模块。PAG 旨在通过去噪过程逐步增强样本的结构。它通过用单位矩阵替换 UNet 中的self-attention map来生成结构退化的中间样本,这是考虑到自注意力机制捕捉结构信息的能力,并引导去噪过程远离这些退化样本。在 ADM 和 Stable Diffusion 中,PAG 在条件甚至无条件场景下都显著提高了样本质量。此外,在诸如空提示的 ControlNet 以及图像修复(如修补和去模糊)等现有引导(如 CG 或 CFG)无法充分利用的各种下游任务中,PAG 也显著提高了基线性能。
在这里插入图片描述
研究表明,在diffusion U-Net的self-attention 模块中,query-key 主要影响structure ,values主要影响appearance。
在这里插入图片描述
如果直接扰动Vt 的话,会导致 out-of-distribution (OOD),因此选择使用单位矩阵替换query-key 部分。
在这里插入图片描述
在这里插入图片描述
那么具体扰动Unet的哪一部分呢?作者使用了5k个样本,在PAG guidance scale s = 2.5 and DDIM 25 step的条件下,表现最好的是mid-block “m0”
在这里插入图片描述

代码

Diffusers 已经支持PAG用在多种任务中,并且可以和ControlNet、 IP-Adapter 一起使用。

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torchpipeline = AutoPipelineForText2Image.from_pretrained("~/.cache/modelscope/hub/AI-ModelScope/stable-diffusion-xl-base-1___0",enable_pag=True,  ##addpag_applied_layers=["mid"], ##addtorch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()prompt = "an insect robot preparing a delicious meal, anime style"
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(prompt=prompt,num_inference_steps=25,guidance_scale=7.0,generator=generator,pag_scale=2.5,
).imagesimages[0].save("pag.jpg")

PAG代码细节

如果同时使用PAG和CFG,那么输入到Unet中prompt_embeds定义如下,也就是[uncond,cond,cond]

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):cond = torch.cat([cond] * 2, dim=0)if do_classifier_free_guidance:cond = torch.cat([uncond, cond], dim=0)return cond

PAGCFGIdentitySelfAttnProcessor2_0计算,其中[uncond,cond]正常计算SA,第二个cond则计算PSA。

class PAGCFGIdentitySelfAttnProcessor2_0:r"""Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).PAG reference: https://arxiv.org/abs/2403.17377"""def __init__(self):if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: Optional[torch.FloatTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,temb: Optional[torch.FloatTensor] = None,) -> torch.Tensor:residual = hidden_statesif attn.spatial_norm is not None:hidden_states = attn.spatial_norm(hidden_states, temb)input_ndim = hidden_states.ndimif input_ndim == 4:batch_size, channel, height, width = hidden_states.shapehidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# chunkhidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])# original pathbatch_size, sequence_length, _ = hidden_states_org.shapeif attention_mask is not None:attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# scaled_dot_product_attention expects attention_mask shape to be# (batch, heads, source_length, target_length)attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])if attn.group_norm is not None:hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)query = attn.to_q(hidden_states_org)key = attn.to_k(hidden_states_org)value = attn.to_v(hidden_states_org)inner_dim = key.shape[-1]head_dim = inner_dim // attn.headsquery = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# the output of sdp = (batch, num_heads, seq_len, head_dim)# TODO: add support for attn.scale when we move to Torch 2.1hidden_states_org = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)hidden_states_org = hidden_states_org.to(query.dtype)# linear projhidden_states_org = attn.to_out[0](hidden_states_org)# dropouthidden_states_org = attn.to_out[1](hidden_states_org)if input_ndim == 4:hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)# perturbed path (identity attention)batch_size, sequence_length, _ = hidden_states_ptb.shapeif attn.group_norm is not None:hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)value = attn.to_v(hidden_states_ptb)hidden_states_ptb = valuehidden_states_ptb = hidden_states_ptb.to(query.dtype)# linear projhidden_states_ptb = attn.to_out[0](hidden_states_ptb)# dropouthidden_states_ptb = attn.to_out[1](hidden_states_ptb)if input_ndim == 4:hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)# cathidden_states = torch.cat([hidden_states_org, hidden_states_ptb])if attn.residual_connection:hidden_states = hidden_states + residualhidden_states = hidden_states / attn.rescale_output_factorreturn hidden_states

经过Unet后,noise_pred的计算方法。

    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False):r"""Apply perturbed attention guidance to the noise prediction.Args:noise_pred (torch.Tensor): The noise prediction tensor.do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.guidance_scale (float): The scale factor for the guidance term.t (int): The current time step.return_pred_text (bool): Whether to return the text noise prediction.Returns:Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applyingperturbed attention guidance and the text noise prediction."""pag_scale = self._get_pag_scale(t)if do_classifier_free_guidance:noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)noise_pred = (noise_pred_uncond+ guidance_scale * (noise_pred_text - noise_pred_uncond)+ pag_scale * (noise_pred_text - noise_pred_perturb))else:noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)if return_pred_text:return noise_pred, noise_pred_textreturn noise_pred

相关文章:

Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance Github 摘要 近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导&#xff…...

自动驾驶控制与规划——Project 6: A* Route Planning

目录 零、任务介绍一、算法原理1.1 A* Algorithm1.2 启发函数 二、代码实现三、结果分析四、效果展示4.1 Dijkstra距离4.2 Manhatten距离4.3 欧几里德距离4.4 对角距离 五、后记 零、任务介绍 carla-ros-bridge/src/ros-bridge/carla_shenlan_projects/carla_shenlan_a_star_p…...

通俗易懂之线性回归时序预测PyTorch实践

线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入…...

[离线数仓] 总结二、Hive数仓分层开发

接 [离线数仓] 总结一、数据采集 5.8 数仓开发之ODS层 ODS层的设计要点如下: (1)ODS层的表结构设计依托于从业务系统同步过来的数据结构。 (2)ODS层要保存全部历史数据,故其压缩格式应选择压缩比率,较高的,此处选择gzip。 CompressedStorage - Apache Hive - Apac…...

页面顶部导航栏(Navbar)的功能(Navbar/index.vue)

这段代码是一个 Vue.js 组件&#xff0c;实现了页面顶部导航栏&#xff08;Navbar&#xff09;的功能。我将分块分析它的各个部分&#xff1a; 模板 (Template): <!-- spid-admin/src/layout/components/Navbar/index.vue --> <template><div class"navb…...

thinnkphp5.1和 thinkphp6以及nginx,apache 解决跨域问题

ThinkPHP 5.1 使用中间件设置响应头 ThinkPHP 5.1 及以上版本支持中间件&#xff0c;可以通过中间件统一设置跨域响应头。 步骤&#xff1a; 创建一个中间件文件&#xff0c;例如 CorsMiddleware.php&#xff1a; namespace app\middleware;class CorsMiddleware {public fu…...

vue2新增删除

&#xff08;只是页面实现&#xff0c;不涉及数据库&#xff09; list组件&#xff1a; <button click"onAdd">新增</button><el-table:header-cell-style"{ textAlign: center }" :cell-style"{ textAlign: center }":data&quo…...

测试ip端口-telnet开启与使用

前言 开发过程中我们总会要去测试ip通不通&#xff0c;或者ip下某个端口是否可以联通&#xff0c;为此我们可以使用telnet 命令来实现。 一、telnet 开启 可能有些人使用telnet报错&#xff0c;不是内部命令&#xff0c;可以如下开启&#xff1a; 1、打开控制面板&#xff…...

Python爬虫基础——XPath表达式

首先说一下这节内容在学习过程中存在的问题吧&#xff0c;在爬取百度网页文字时&#xff0c;出现了问题&#xff0c;就是通过表达式在网页搜索中可以定位&#xff0c;但是通过代码无法定位&#xff0c;请教了一位老师&#xff0c;他说是动态链接&#xff0c;目前这部分内容比较…...

ansible-性能优化

一. 简述&#xff1a; 搞过运维自动化工具的人&#xff0c;肯定会发现很多运维伙伴们经常用saltstack和ansible做比较&#xff0c;单从执行效率上来说&#xff0c;ansible确实比不上saltstack(ansible使用的是ssh,salt使用的是zeromq消息队列[暂没深入了解])&#xff0c;但其实…...

高等数学学习笔记 ☞ 一元函数微分的基础知识

1. 微分的定义 &#xff08;1&#xff09;定义&#xff1a;设函数在点的某领域内有定义&#xff0c;取附近的点&#xff0c;对应的函数值分别为和&#xff0c; 令&#xff0c;若可以表示成&#xff0c;则称函数在点是可微的。 【 若函数在点是可微的&#xff0c;则可以表达为】…...

前后端实现防抖节流实现

在前端和 Java 后端中实现防抖&#xff08;Debounce&#xff09;和节流&#xff08;Throttle&#xff09;主要用于减少频繁请求或事件触发对系统的压力。前端和后端的实现方式有些不同&#xff0c;以下是两种方法的具体实现&#xff1a; 1. 前端实现防抖和节流 在前端中&…...

【笔记】算法记录

1、求一个数的素因子&#xff08;试除法&#xff09; // 获取一个数的所有素因子 set<int> getPrimeFactors(int num) {set<int> primeFactors;for (int i 2; i * i < num; i) {while (num % i 0) {primeFactors.insert(i);num / i;}}if (num > 1) {prime…...

【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析

文章目录 一、选择题二、理论题三、实操题 【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析 一、选择题 生成树协议的主要作用是 B. 防止网络环路解释&#xff1a;生成树协议&#xff08;STP&#xff09;的主要目的是防止网络中…...

git push -f 指定分支

要将本地代码推送到指定的远程分支&#xff0c;你可以使用以下步骤和命令&#xff1a; 确认远程仓库&#xff1a; 确保你的本地仓库已经与远程仓库关联。你可以使用以下命令查看当前的远程仓库状态&#xff1a; git remote -v查看本地分支&#xff1a; 使用命令查看当前存在的本…...

CTF知识点总结(二)

异或注入&#xff1a;两个条件相同&#xff08;同真或同假&#xff09;即为假。 http://120.24.86.145:9004/1ndex.php?id1^(length(union)!0)-- 如上&#xff0c;如果union被过滤&#xff0c;则 length(union)!0 为假&#xff0c;那么返回页面正常。 2|0updatexml() 函数报…...

解决Edge打开PDF总是没有焦点

【问题描述】 使用Edge浏览器作为默认PDF阅读器打开本地PDF文件&#xff0c;Edge窗口总是不获得焦点&#xff0c;而是在任务栏以橙色显示&#xff0c;需要再手动点击一次才能查看文件内容。 本强迫症来治一治这个问题&#xff01; 【解决方法】 GPT老师指出问题出在Edge的启动…...

69.基于SpringBoot + Vue实现的前后端分离-家乡特色推荐系统(项目 + 论文PPT)

项目介绍 在Internet高速发展的今天&#xff0c;我们生活的各个领域都涉及到计算机的应用&#xff0c;其中包括家乡特色推荐的网络应用&#xff0c;在外国家乡特色推荐系统已经是很普遍的方式&#xff0c;不过国内的管理网站可能还处于起步阶段。家乡特色推荐系统采用java技术&…...

计算机视觉目标检测-DETR网络

目录 摘要abstractDETR目标检测网络详解二分图匹配和损失函数 DETR总结总结 摘要 DETR&#xff08;DEtection TRansformer&#xff09;是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题&#xff0c;摒弃了锚框设计和非…...

《自动驾驶与机器人中的SLAM技术》ch1:自动驾驶

目录 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 L2 在技术实现上会更倾向于实时感知&#xff0c;乃至可以使用感知结果直接构建鸟瞰图&#xff08;bird eye view, BEV&#xff09;&#xff0c;而 L4 则依赖离线地图。 高精地…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

在鸿蒙HarmonyOS 5中实现抖音风格的点赞功能

下面我将详细介绍如何使用HarmonyOS SDK在HarmonyOS 5中实现类似抖音的点赞功能&#xff0c;包括动画效果、数据同步和交互优化。 1. 基础点赞功能实现 1.1 创建数据模型 // VideoModel.ets export class VideoModel {id: string "";title: string ""…...

全球首个30米分辨率湿地数据集(2000—2022)

数据简介 今天我们分享的数据是全球30米分辨率湿地数据集&#xff0c;包含8种湿地亚类&#xff0c;该数据以0.5X0.5的瓦片存储&#xff0c;我们整理了所有属于中国的瓦片名称与其对应省份&#xff0c;方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

srs linux

下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935&#xff0c;SRS管理页面端口是8080&#xff0c;可…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说&#xff0c;传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度&#xff0c;通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

Java 加密常用的各种算法及其选择

在数字化时代&#xff0c;数据安全至关重要&#xff0c;Java 作为广泛应用的编程语言&#xff0c;提供了丰富的加密算法来保障数据的保密性、完整性和真实性。了解这些常用加密算法及其适用场景&#xff0c;有助于开发者在不同的业务需求中做出正确的选择。​ 一、对称加密算法…...

QT: `long long` 类型转换为 `QString` 2025.6.5

在 Qt 中&#xff0c;将 long long 类型转换为 QString 可以通过以下两种常用方法实现&#xff1a; 方法 1&#xff1a;使用 QString::number() 直接调用 QString 的静态方法 number()&#xff0c;将数值转换为字符串&#xff1a; long long value 1234567890123456789LL; …...

Java面试专项一-准备篇

一、企业简历筛选规则 一般企业的简历筛选流程&#xff1a;首先由HR先筛选一部分简历后&#xff0c;在将简历给到对应的项目负责人后再进行下一步的操作。 HR如何筛选简历 例如&#xff1a;Boss直聘&#xff08;招聘方平台&#xff09; 直接按照条件进行筛选 例如&#xff1a…...

代码随想录刷题day30

1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币&#xff0c;另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额&#xff0c;返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...

[大语言模型]在个人电脑上部署ollama 并进行管理,最后配置AI程序开发助手.

ollama官网: 下载 https://ollama.com/ 安装 查看可以使用的模型 https://ollama.com/search 例如 https://ollama.com/library/deepseek-r1/tags # deepseek-r1:7bollama pull deepseek-r1:7b改token数量为409622 16384 ollama命令说明 ollama serve #&#xff1a…...