当前位置: 首页 > news >正文

Diffusion中的Unet (DIMP)

针对UNet2DConditionModel模型

查看Unet的源码,得知Unet的down,mid,up blocks的类型分别是:

down_block_types: Tuple[str] = ("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D","DownBlock2D",),mid_block_type: str = "UNetMidBlock2DCrossAttn",up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")

查看一下down  下采样的get_down_block方法:

def get_down_block(down_block_type,num_layers,in_channels,out_channels,temb_channels,add_downsample,resnet_eps,resnet_act_fn,attn_num_head_channels,resnet_groups=None,cross_attention_dim=None,downsample_padding=None,dual_cross_attention=False,use_linear_projection=False,only_cross_attention=False,upcast_attention=False,resnet_time_scale_shift="default",
):down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_typeif down_block_type == "DownBlock2D":return DownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "ResnetDownsampleBlock2D":return ResnetDownsampleBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnDownBlock2D":return AttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "CrossAttnDownBlock2D":if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")return CrossAttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,cross_attention_dim=cross_attention_dim,attn_num_head_channels=attn_num_head_channels,dual_cross_attention=dual_cross_attention,use_linear_projection=use_linear_projection,only_cross_attention=only_cross_attention,upcast_attention=upcast_attention,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "SimpleCrossAttnDownBlock2D":if cross_attention_dim is None:raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")return SimpleCrossAttnDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,cross_attention_dim=cross_attention_dim,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "SkipDownBlock2D":return SkipDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnSkipDownBlock2D":return AttnSkipDownBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,temb_channels=temb_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "DownEncoderBlock2D":return DownEncoderBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,resnet_time_scale_shift=resnet_time_scale_shift,)elif down_block_type == "AttnDownEncoderBlock2D":return AttnDownEncoderBlock2D(num_layers=num_layers,in_channels=in_channels,out_channels=out_channels,add_downsample=add_downsample,resnet_eps=resnet_eps,resnet_act_fn=resnet_act_fn,resnet_groups=resnet_groups,downsample_padding=downsample_padding,attn_num_head_channels=attn_num_head_channels,resnet_time_scale_shift=resnet_time_scale_shift,)raise ValueError(f"{down_block_type} does not exist.")

 我们看一下该Unet的forward函数:

def forward(self,sample: torch.FloatTensor,timestep: Union[torch.Tensor, float, int],encoder_hidden_states: torch.Tensor,class_labels: Optional[torch.Tensor] = None,attention_mask: Optional[torch.Tensor] = None,cross_attention_kwargs: Optional[Dict[str, Any]] = None,return_dict: bool = True,) -> Union[UNet2DConditionOutput, Tuple]:r"""Args:sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensortimestep (`torch.FloatTensor` or `float` or `int`): (batch) timestepsencoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden statesreturn_dict (`bool`, *optional*, defaults to `True`):Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.Returns:[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. Whenreturning a tuple, the first element is the sample tensor."""# By default samples have to be AT least a multiple of the overall upsampling factor.# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).# However, the upsampling interpolation output size can be forced to fit any upsampling size# on the fly if necessary.default_overall_up_factor = 2**self.num_upsamplers# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`forward_upsample_size = Falseupsample_size = Noneif any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):logger.info("Forward upsample size to force interpolation output size.")forward_upsample_size = True# prepare attention_maskif attention_mask is not None:attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0attention_mask = attention_mask.unsqueeze(1)# 0. center input if necessaryif self.config.center_input_sample:sample = 2 * sample - 1.0# 1. timetimesteps = timestepif not torch.is_tensor(timesteps):# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can# This would be a good case for the `match` statement (Python 3.10+)is_mps = sample.device.type == "mps"if isinstance(timestep, float):dtype = torch.float32 if is_mps else torch.float64else:dtype = torch.int32 if is_mps else torch.int64timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)elif len(timesteps.shape) == 0:timesteps = timesteps[None].to(sample.device)# broadcast to batch dimension in a way that's compatible with ONNX/Core MLtimesteps = timesteps.expand(sample.shape[0])t_emb = self.time_proj(timesteps)# timesteps does not contain any weights and will always return f32 tensors# but time_embedding might actually be running in fp16. so we need to cast here.# there might be better ways to encapsulate this.t_emb = t_emb.to(dtype=self.dtype)emb = self.time_embedding(t_emb)if self.class_embedding is not None:if class_labels is None:raise ValueError("class_labels should be provided when num_class_embeds > 0")if self.config.class_embed_type == "timestep":class_labels = self.time_proj(class_labels)class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)emb = emb + class_emb# 2. pre-processsample = self.conv_in(sample)# 3. downdown_block_res_samples = (sample,)for downsample_block in self.down_blocks:if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:sample, res_samples = downsample_block(hidden_states=sample,temb=emb,encoder_hidden_states=encoder_hidden_states,attention_mask=attention_mask,cross_attention_kwargs=cross_attention_kwargs,)else:sample, res_samples = downsample_block(hidden_states=sample, temb=emb)down_block_res_samples += res_samples# 4. midsample = self.mid_block(sample,emb,encoder_hidden_states=encoder_hidden_states,attention_mask=attention_mask,cross_attention_kwargs=cross_attention_kwargs,)# 5. upfor i, upsample_block in enumerate(self.up_blocks):is_final_block = i == len(self.up_blocks) - 1res_samples = down_block_res_samples[-len(upsample_block.resnets) :]down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]# if we have not reached the final block and need to forward the# upsample size, we do it hereif not is_final_block and forward_upsample_size:upsample_size = down_block_res_samples[-1].shape[2:]if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:sample = upsample_block(hidden_states=sample,temb=emb,res_hidden_states_tuple=res_samples,encoder_hidden_states=encoder_hidden_states,cross_attention_kwargs=cross_attention_kwargs,upsample_size=upsample_size,attention_mask=attention_mask,)else:sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size)# 6. post-processsample = self.conv_norm_out(sample)sample = self.conv_act(sample)sample = self.conv_out(sample)if not return_dict:return (sample,)return UNet2DConditionOutput(sample=sample)

也就是说在:down,mid和up Block时候都有传入text_embedding的信息encoder_hidden_states和cross attention的控制:cross_attention_kwargs.

具体每一个Block的实现看源码

相关文章:

Diffusion中的Unet (DIMP)

针对UNet2DConditionModel模型 查看Unet的源码,得知Unet的down,mid,up blocks的类型分别是: down_block_types: Tuple[str] ("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D","DownBlock2…...

编译以前项目更改在x64下面时报错:函数“PVOID GetCurrentFiber(void)”已有主体

win32下面编译成功,但是x64报错 1>GetWord.c 1>md5.c 这两个文件无法编译 1>C:\Program Files (x86)\Windows Kits\10\Include\10.0.22000.0\um\winnt.h(24125,1): error C2084: 函数“PVOID GetCurrentFiber(void)”已有主体 1>C:\Program Files (x…...

【AIGC】大模型面试高频考点-数据清洗篇

【AIGC】大模型面试高频考点-数据清洗篇 (一)常用文本清洗方法1.去除无用的符号2.去除表情符号3.文本只保留汉字4.中文繁体、简体转换5.删除 HTML 标签和特殊字符6.标记化7.小写8.停用词删除9.词干提取和词形还原10.处理缺失数据11.删除重复文本12.处理嘈…...

当测试时间与测试资源有限时,你会如何优化测试策略?

1.优先级排序:根据项目的需求和紧急程度进行优先级排序,将测试用例用例划分优先级,合理安排测试资源 和时间。这样能够保障在有限的时间内测试到最关键的功能 2.提前介入测试:在开发过程中提前进行测试,可以迅速发现问…...

基于R语言森林生态系统结构、功能与稳定性分析与可视化

在生态学研究中,森林生态系统的结构、功能与稳定性是核心研究内容之一。这些方面不仅关系到森林动态变化和物种多样性,还直接影响森林提供的生态服务功能及其应对环境变化的能力。森林生态系统的结构主要包括物种组成、树种多样性、树木的空间分布与密度…...

如何使用 Python 实现插件式架构

使用 Python 实现插件式架构可以通过动态加载和调用模块或类,构建一个易于扩展和维护的系统。以下是实现插件式架构的步骤和核心思想。 1. 插件式架构核心概念 主程序:负责加载、管理插件,并调用插件的功能。插件:独立的模块或类…...

【北京迅为】iTOP-4412全能版使用手册-第二十章 搭建和测试NFS服务器

iTOP-4412全能版采用四核Cortex-A9,主频为1.4GHz-1.6GHz,配备S5M8767 电源管理,集成USB HUB,选用高品质板对板连接器稳定可靠,大厂生产,做工精良。接口一应俱全,开发更简单,搭载全网通4G、支持WIFI、蓝牙、…...

【纯原生js】原生实现h5落地页面中的单选组件按钮及功能

h5端的按钮系统自带的一般都很丑&#xff0c;需要我们进行二次美化&#xff0c;比如单选按钮复选框之类的&#xff0c;那怎么对其进行html和css的改造&#xff1f; 实现效果 实现代码 <section id"tags"><h2>给景区添加标题</h2><label><…...

深入浅出:开发者如何快速上手Web3生态系统

Web3作为互联网的未来发展方向&#xff0c;正在逐步改变传统互联网架构&#xff0c;推动去中心化技术的发展。对于开发者而言&#xff0c;Web3代表着一个充满机遇与挑战的新领域&#xff0c;学习和掌握Web3的基本技术和工具&#xff0c;将为未来的项目开发提供强大的支持。那么…...

通过深度点图表示的隐式场实现肺树结构的高效解剖标注文献速递-生成式模型与transformer在医学影像中的应用

Title 题目 Efficient anatomical labeling of pulmonary tree structures via deeppoint-graph representation-based implicit fields 通过深度点图表示的隐式场实现肺树结构的高效解剖标注 01 文献速递介绍 近年来&#xff0c;肺部疾病&#xff08;Decramer等&#xff…...

数据结构 (17)广义表

前言 数据结构中的广义表&#xff08;Generalized List&#xff0c;又称列表Lists&#xff09;是一种重要的数据结构&#xff0c;它是对线性表的一种推广&#xff0c;放松了对表元素的原子限制&#xff0c;容许它们具有其自身的结构。 一、定义与表示 定义&#xff1a;广义表是…...

论文笔记 SliceGPT: Compress Large Language Models By Deleting Rows And Columns

欲买桂花同载酒&#xff0c;终不似&#xff0c;少年游。 数学知识 秩&#xff1a; 矩阵中最大线性无关的行/列向量数。行秩与列秩相等。 线性无关&#xff1a;对于N个向量而言&#xff0c;如果任取一个向量 v \textbf{v} v&#xff0c;不能被剩下的N-1个向量通过线性组合的方式…...

前端工具的选择和安装

选择和安装前端工具是前端开发过程中的重要步骤。现代前端开发需要一些工具来提高效率和协作能力。以下是一些常用的前端工具及其选择和安装指南。 1. 代码编辑器 选择一个好的代码编辑器可以显著提高开发效率。以下是几款流行的代码编辑器&#xff1a; Visual Studio Code (…...

Fantasy中定时器得驱动原理

一、服务器框架启动 public static async FTask Start(){// 启动ProcessStartProcess().Coroutine();await FTask.CompletedTask;while (true){ThreadScheduler.Update();Thread.Sleep(1);}} 二、主线程 Fantasy.ThreadScheduler.Update internal static void Update(){MainS…...

【反转链表】力扣 445. 两数相加 II

一、题目 二、思路 加法运算是从低位开始&#xff0c;向高位进位&#xff0c;因此需要将两个链表进行反转&#xff0c;再进行对齐后的相加操作。力扣 2. 两数相加 三、题解 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode …...

SpringBoot 项目中使用 spring-boot-starter-amqp 依赖实现 RabbitMQ

文章目录 前言1、application.yml2、RabbitMqConfig3、MqMessage4、MqMessageItem5、DirectMode6、StateConsumer&#xff1a;消费者7、InfoConsumer&#xff1a;消费者 前言 本文是工作之余的随手记&#xff0c;记录在工作期间使用 RabbitMQ 的笔记。 1、application.yml 使…...

Uniapp 安装安卓、IOS模拟器并调试

一、安装Android模拟器并调试 1. 下载并安装 Android Studio 首先下载 Mac 环境下的 Android Studio 的安装包&#xff0c;为dmg 格式。 下载完将Android Studio 向右拖拽到Applications中&#xff0c;接下来等待安装完成就OK啦&#xff01; 打开过程界面如下图所示&#xf…...

JavaScript 中的原型和原型链

JavaScript 中的原型和原型链也是一个相对较难理解透彻的知识点&#xff0c;下面结合详细例子来进行说明&#xff1a; 一、原型的概念 在 JavaScript 中&#xff0c;每个函数都有一个 prototype 属性&#xff0c;这个属性指向一个对象&#xff0c;这个对象就是所谓的 “原型对…...

数组变换(两倍)

数组变换 以最大元素为基准元素&#xff0c;判读其他元素能否通过 x 2 成为最大值&#xff01; 那么怎么判断呢&#xff1a; max % arr[i] 0arr[i] * 2 ^n max int x 2 ^ n max / arr[i] 3.只需判断 这个 x 是不是 2 的 n 次放就可以了&#xff01; 判断 是否为 2 的 n 次 …...

GBN协议、SR协议

1、回退N步&#xff08;Go-Back-N,GBN&#xff09;协议&#xff1a; 总结&#xff1a; GBN协议的特点&#xff1a; &#xff08;1&#xff09;累计确认机制&#xff1a;当发送方收到ACKn时&#xff0c;表明接收方已正确接收序号为n以及序号小于n的所有分组&#xff0c;发送窗…...

在软件开发中正确使用MySQL日期时间类型的深度解析

在日常软件开发场景中&#xff0c;时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志&#xff0c;到供应链系统的物流节点时间戳&#xff0c;时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库&#xff0c;其日期时间类型的…...

Flask RESTful 示例

目录 1. 环境准备2. 安装依赖3. 修改main.py4. 运行应用5. API使用示例获取所有任务获取单个任务创建新任务更新任务删除任务 中文乱码问题&#xff1a; 下面创建一个简单的Flask RESTful API示例。首先&#xff0c;我们需要创建环境&#xff0c;安装必要的依赖&#xff0c;然后…...

高等数学(下)题型笔记(八)空间解析几何与向量代数

目录 0 前言 1 向量的点乘 1.1 基本公式 1.2 例题 2 向量的叉乘 2.1 基础知识 2.2 例题 3 空间平面方程 3.1 基础知识 3.2 例题 4 空间直线方程 4.1 基础知识 4.2 例题 5 旋转曲面及其方程 5.1 基础知识 5.2 例题 6 空间曲面的法线与切平面 6.1 基础知识 6.2…...

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…...

【git】把本地更改提交远程新分支feature_g

创建并切换新分支 git checkout -b feature_g 添加并提交更改 git add . git commit -m “实现图片上传功能” 推送到远程 git push -u origin feature_g...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

Java 二维码

Java 二维码 **技术&#xff1a;**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...

Python 包管理器 uv 介绍

Python 包管理器 uv 全面介绍 uv 是由 Astral&#xff08;热门工具 Ruff 的开发者&#xff09;推出的下一代高性能 Python 包管理器和构建工具&#xff0c;用 Rust 编写。它旨在解决传统工具&#xff08;如 pip、virtualenv、pip-tools&#xff09;的性能瓶颈&#xff0c;同时…...

【Linux】Linux 系统默认的目录及作用说明

博主介绍&#xff1a;✌全网粉丝23W&#xff0c;CSDN博客专家、Java领域优质创作者&#xff0c;掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围&#xff1a;SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...

掌握 HTTP 请求:理解 cURL GET 语法

cURL 是一个强大的命令行工具&#xff0c;用于发送 HTTP 请求和与 Web 服务器交互。在 Web 开发和测试中&#xff0c;cURL 经常用于发送 GET 请求来获取服务器资源。本文将详细介绍 cURL GET 请求的语法和使用方法。 一、cURL 基本概念 cURL 是 "Client URL" 的缩写…...