当前位置: 首页 > news >正文

Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance
Github

摘要

近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导(CFG)。这些技术在无条件生成或诸如图像恢复等各种下游任务中往往并不适用。在本文中,我们提出了一种新颖的采样引导方法,称为Perturbed-Attention Guidance(PAG),它能在无条件和条件设置下提高扩散样本的质量,并且无需额外的训练或集成外部模块。PAG 旨在通过去噪过程逐步增强样本的结构。它通过用单位矩阵替换 UNet 中的self-attention map来生成结构退化的中间样本,这是考虑到自注意力机制捕捉结构信息的能力,并引导去噪过程远离这些退化样本。在 ADM 和 Stable Diffusion 中,PAG 在条件甚至无条件场景下都显著提高了样本质量。此外,在诸如空提示的 ControlNet 以及图像修复(如修补和去模糊)等现有引导(如 CG 或 CFG)无法充分利用的各种下游任务中,PAG 也显著提高了基线性能。
在这里插入图片描述
研究表明,在diffusion U-Net的self-attention 模块中,query-key 主要影响structure ,values主要影响appearance。
在这里插入图片描述
如果直接扰动Vt 的话,会导致 out-of-distribution (OOD),因此选择使用单位矩阵替换query-key 部分。
在这里插入图片描述
在这里插入图片描述
那么具体扰动Unet的哪一部分呢?作者使用了5k个样本,在PAG guidance scale s = 2.5 and DDIM 25 step的条件下,表现最好的是mid-block “m0”
在这里插入图片描述

代码

Diffusers 已经支持PAG用在多种任务中,并且可以和ControlNet、 IP-Adapter 一起使用。

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torchpipeline = AutoPipelineForText2Image.from_pretrained("~/.cache/modelscope/hub/AI-ModelScope/stable-diffusion-xl-base-1___0",enable_pag=True,  ##addpag_applied_layers=["mid"], ##addtorch_dtype=torch.float16
)
pipeline.enable_model_cpu_offload()prompt = "an insect robot preparing a delicious meal, anime style"
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipeline(prompt=prompt,num_inference_steps=25,guidance_scale=7.0,generator=generator,pag_scale=2.5,
).imagesimages[0].save("pag.jpg")

PAG代码细节

如果同时使用PAG和CFG,那么输入到Unet中prompt_embeds定义如下,也就是[uncond,cond,cond]

    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):cond = torch.cat([cond] * 2, dim=0)if do_classifier_free_guidance:cond = torch.cat([uncond, cond], dim=0)return cond

PAGCFGIdentitySelfAttnProcessor2_0计算,其中[uncond,cond]正常计算SA,第二个cond则计算PSA。

class PAGCFGIdentitySelfAttnProcessor2_0:r"""Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).PAG reference: https://arxiv.org/abs/2403.17377"""def __init__(self):if not hasattr(F, "scaled_dot_product_attention"):raise ImportError("PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")def __call__(self,attn: Attention,hidden_states: torch.FloatTensor,encoder_hidden_states: Optional[torch.FloatTensor] = None,attention_mask: Optional[torch.FloatTensor] = None,temb: Optional[torch.FloatTensor] = None,) -> torch.Tensor:residual = hidden_statesif attn.spatial_norm is not None:hidden_states = attn.spatial_norm(hidden_states, temb)input_ndim = hidden_states.ndimif input_ndim == 4:batch_size, channel, height, width = hidden_states.shapehidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)# chunkhidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])# original pathbatch_size, sequence_length, _ = hidden_states_org.shapeif attention_mask is not None:attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)# scaled_dot_product_attention expects attention_mask shape to be# (batch, heads, source_length, target_length)attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])if attn.group_norm is not None:hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)query = attn.to_q(hidden_states_org)key = attn.to_k(hidden_states_org)value = attn.to_v(hidden_states_org)inner_dim = key.shape[-1]head_dim = inner_dim // attn.headsquery = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)# the output of sdp = (batch, num_heads, seq_len, head_dim)# TODO: add support for attn.scale when we move to Torch 2.1hidden_states_org = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)hidden_states_org = hidden_states_org.to(query.dtype)# linear projhidden_states_org = attn.to_out[0](hidden_states_org)# dropouthidden_states_org = attn.to_out[1](hidden_states_org)if input_ndim == 4:hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)# perturbed path (identity attention)batch_size, sequence_length, _ = hidden_states_ptb.shapeif attn.group_norm is not None:hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)value = attn.to_v(hidden_states_ptb)hidden_states_ptb = valuehidden_states_ptb = hidden_states_ptb.to(query.dtype)# linear projhidden_states_ptb = attn.to_out[0](hidden_states_ptb)# dropouthidden_states_ptb = attn.to_out[1](hidden_states_ptb)if input_ndim == 4:hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)# cathidden_states = torch.cat([hidden_states_org, hidden_states_ptb])if attn.residual_connection:hidden_states = hidden_states + residualhidden_states = hidden_states / attn.rescale_output_factorreturn hidden_states

经过Unet后,noise_pred的计算方法。

    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False):r"""Apply perturbed attention guidance to the noise prediction.Args:noise_pred (torch.Tensor): The noise prediction tensor.do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.guidance_scale (float): The scale factor for the guidance term.t (int): The current time step.return_pred_text (bool): Whether to return the text noise prediction.Returns:Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applyingperturbed attention guidance and the text noise prediction."""pag_scale = self._get_pag_scale(t)if do_classifier_free_guidance:noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)noise_pred = (noise_pred_uncond+ guidance_scale * (noise_pred_text - noise_pred_uncond)+ pag_scale * (noise_pred_text - noise_pred_perturb))else:noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)if return_pred_text:return noise_pred, noise_pred_textreturn noise_pred

相关文章:

Perturbed-Attention Guidance(PAG) 笔记

Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance Github 摘要 近期研究表明,扩散模型能够生成高质量样本,但其质量在很大程度上依赖于采样引导技术,如分类器引导(CG)和无分类器引导&#xff…...

自动驾驶控制与规划——Project 6: A* Route Planning

目录 零、任务介绍一、算法原理1.1 A* Algorithm1.2 启发函数 二、代码实现三、结果分析四、效果展示4.1 Dijkstra距离4.2 Manhatten距离4.3 欧几里德距离4.4 对角距离 五、后记 零、任务介绍 carla-ros-bridge/src/ros-bridge/carla_shenlan_projects/carla_shenlan_a_star_p…...

通俗易懂之线性回归时序预测PyTorch实践

线性回归(Linear Regression)是机器学习中最基本且广泛应用的算法之一。它不仅作为入门学习的经典案例,也是许多复杂模型的基础。本文将全面介绍线性回归的原理、应用,并通过一段PyTorch代码进行实践演示,帮助读者深入…...

[离线数仓] 总结二、Hive数仓分层开发

接 [离线数仓] 总结一、数据采集 5.8 数仓开发之ODS层 ODS层的设计要点如下: (1)ODS层的表结构设计依托于从业务系统同步过来的数据结构。 (2)ODS层要保存全部历史数据,故其压缩格式应选择压缩比率,较高的,此处选择gzip。 CompressedStorage - Apache Hive - Apac…...

页面顶部导航栏(Navbar)的功能(Navbar/index.vue)

这段代码是一个 Vue.js 组件&#xff0c;实现了页面顶部导航栏&#xff08;Navbar&#xff09;的功能。我将分块分析它的各个部分&#xff1a; 模板 (Template): <!-- spid-admin/src/layout/components/Navbar/index.vue --> <template><div class"navb…...

thinnkphp5.1和 thinkphp6以及nginx,apache 解决跨域问题

ThinkPHP 5.1 使用中间件设置响应头 ThinkPHP 5.1 及以上版本支持中间件&#xff0c;可以通过中间件统一设置跨域响应头。 步骤&#xff1a; 创建一个中间件文件&#xff0c;例如 CorsMiddleware.php&#xff1a; namespace app\middleware;class CorsMiddleware {public fu…...

vue2新增删除

&#xff08;只是页面实现&#xff0c;不涉及数据库&#xff09; list组件&#xff1a; <button click"onAdd">新增</button><el-table:header-cell-style"{ textAlign: center }" :cell-style"{ textAlign: center }":data&quo…...

测试ip端口-telnet开启与使用

前言 开发过程中我们总会要去测试ip通不通&#xff0c;或者ip下某个端口是否可以联通&#xff0c;为此我们可以使用telnet 命令来实现。 一、telnet 开启 可能有些人使用telnet报错&#xff0c;不是内部命令&#xff0c;可以如下开启&#xff1a; 1、打开控制面板&#xff…...

Python爬虫基础——XPath表达式

首先说一下这节内容在学习过程中存在的问题吧&#xff0c;在爬取百度网页文字时&#xff0c;出现了问题&#xff0c;就是通过表达式在网页搜索中可以定位&#xff0c;但是通过代码无法定位&#xff0c;请教了一位老师&#xff0c;他说是动态链接&#xff0c;目前这部分内容比较…...

ansible-性能优化

一. 简述&#xff1a; 搞过运维自动化工具的人&#xff0c;肯定会发现很多运维伙伴们经常用saltstack和ansible做比较&#xff0c;单从执行效率上来说&#xff0c;ansible确实比不上saltstack(ansible使用的是ssh,salt使用的是zeromq消息队列[暂没深入了解])&#xff0c;但其实…...

高等数学学习笔记 ☞ 一元函数微分的基础知识

1. 微分的定义 &#xff08;1&#xff09;定义&#xff1a;设函数在点的某领域内有定义&#xff0c;取附近的点&#xff0c;对应的函数值分别为和&#xff0c; 令&#xff0c;若可以表示成&#xff0c;则称函数在点是可微的。 【 若函数在点是可微的&#xff0c;则可以表达为】…...

前后端实现防抖节流实现

在前端和 Java 后端中实现防抖&#xff08;Debounce&#xff09;和节流&#xff08;Throttle&#xff09;主要用于减少频繁请求或事件触发对系统的压力。前端和后端的实现方式有些不同&#xff0c;以下是两种方法的具体实现&#xff1a; 1. 前端实现防抖和节流 在前端中&…...

【笔记】算法记录

1、求一个数的素因子&#xff08;试除法&#xff09; // 获取一个数的所有素因子 set<int> getPrimeFactors(int num) {set<int> primeFactors;for (int i 2; i * i < num; i) {while (num % i 0) {primeFactors.insert(i);num / i;}}if (num > 1) {prime…...

【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析

文章目录 一、选择题二、理论题三、实操题 【网络云SRE运维开发】2025第2周-每日【2025/01/08】小测-【第8章 STP生成树协议】理论和实操解析 一、选择题 生成树协议的主要作用是 B. 防止网络环路解释&#xff1a;生成树协议&#xff08;STP&#xff09;的主要目的是防止网络中…...

git push -f 指定分支

要将本地代码推送到指定的远程分支&#xff0c;你可以使用以下步骤和命令&#xff1a; 确认远程仓库&#xff1a; 确保你的本地仓库已经与远程仓库关联。你可以使用以下命令查看当前的远程仓库状态&#xff1a; git remote -v查看本地分支&#xff1a; 使用命令查看当前存在的本…...

CTF知识点总结(二)

异或注入&#xff1a;两个条件相同&#xff08;同真或同假&#xff09;即为假。 http://120.24.86.145:9004/1ndex.php?id1^(length(union)!0)-- 如上&#xff0c;如果union被过滤&#xff0c;则 length(union)!0 为假&#xff0c;那么返回页面正常。 2|0updatexml() 函数报…...

解决Edge打开PDF总是没有焦点

【问题描述】 使用Edge浏览器作为默认PDF阅读器打开本地PDF文件&#xff0c;Edge窗口总是不获得焦点&#xff0c;而是在任务栏以橙色显示&#xff0c;需要再手动点击一次才能查看文件内容。 本强迫症来治一治这个问题&#xff01; 【解决方法】 GPT老师指出问题出在Edge的启动…...

69.基于SpringBoot + Vue实现的前后端分离-家乡特色推荐系统(项目 + 论文PPT)

项目介绍 在Internet高速发展的今天&#xff0c;我们生活的各个领域都涉及到计算机的应用&#xff0c;其中包括家乡特色推荐的网络应用&#xff0c;在外国家乡特色推荐系统已经是很普遍的方式&#xff0c;不过国内的管理网站可能还处于起步阶段。家乡特色推荐系统采用java技术&…...

计算机视觉目标检测-DETR网络

目录 摘要abstractDETR目标检测网络详解二分图匹配和损失函数 DETR总结总结 摘要 DETR&#xff08;DEtection TRansformer&#xff09;是由Facebook AI提出的一种基于Transformer架构的端到端目标检测方法。它通过将目标检测建模为集合预测问题&#xff0c;摒弃了锚框设计和非…...

《自动驾驶与机器人中的SLAM技术》ch1:自动驾驶

目录 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 1.1 自动驾驶技术 1.2 自动驾驶中的定位与地图 L2 在技术实现上会更倾向于实时感知&#xff0c;乃至可以使用感知结果直接构建鸟瞰图&#xff08;bird eye view, BEV&#xff09;&#xff0c;而 L4 则依赖离线地图。 高精地…...

python打卡day49

知识点回顾&#xff1a; 通道注意力模块复习空间注意力模块CBAM的定义 作业&#xff1a;尝试对今天的模型检查参数数目&#xff0c;并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...

解锁数据库简洁之道:FastAPI与SQLModel实战指南

在构建现代Web应用程序时&#xff0c;与数据库的交互无疑是核心环节。虽然传统的数据库操作方式&#xff08;如直接编写SQL语句与psycopg2交互&#xff09;赋予了我们精细的控制权&#xff0c;但在面对日益复杂的业务逻辑和快速迭代的需求时&#xff0c;这种方式的开发效率和可…...

电脑插入多块移动硬盘后经常出现卡顿和蓝屏

当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时&#xff0c;可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案&#xff1a; 1. 检查电源供电问题 问题原因&#xff1a;多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

c++ 面试题(1)-----深度优先搜索(DFS)实现

操作系统&#xff1a;ubuntu22.04 IDE:Visual Studio Code 编程语言&#xff1a;C11 题目描述 地上有一个 m 行 n 列的方格&#xff0c;从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子&#xff0c;但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...

ffmpeg(四):滤镜命令

FFmpeg 的滤镜命令是用于音视频处理中的强大工具&#xff0c;可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下&#xff1a; ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜&#xff1a; ffmpeg…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说&#xff0c;传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度&#xff0c;通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

Qt Http Server模块功能及架构

Qt Http Server 是 Qt 6.0 中引入的一个新模块&#xff0c;它提供了一个轻量级的 HTTP 服务器实现&#xff0c;主要用于构建基于 HTTP 的应用程序和服务。 功能介绍&#xff1a; 主要功能 HTTP服务器功能&#xff1a; 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...

3403. 从盒子中找出字典序最大的字符串 I

3403. 从盒子中找出字典序最大的字符串 I 题目链接&#xff1a;3403. 从盒子中找出字典序最大的字符串 I 代码如下&#xff1a; class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量&#xff0c;这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

20个超级好用的 CSS 动画库

分享 20 个最佳 CSS 动画库。 它们中的大多数将生成纯 CSS 代码&#xff0c;而不需要任何外部库。 1.Animate.css 一个开箱即用型的跨浏览器动画库&#xff0c;可供你在项目中使用。 2.Magic Animations CSS3 一组简单的动画&#xff0c;可以包含在你的网页或应用项目中。 3.An…...