当前位置: 首页 > news >正文

Diffusion的unet中用到的AttentionBlock详解

AttentionBlock

  • torch.split
  • torch中的permute的用法
    • torch.transpose()
    • view()
  • torch.bmm
  • softmax(x, dim=-1)

Diffusion的unet中用到的AttentionBlock详解

class AttentionBlock(nn.Module):__doc__ = r"""Applies QKV self-attention with a residual connection.Input:x: tensor of shape (N, in_channels, H, W)norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"num_groups (int): number of groups used in group normalization. Default: 32Output:tensor of shape (N, in_channels, H, W)Args:in_channels (int): number of input channels"""def __init__(self, in_channels, norm="gn", num_groups=32):super().__init__()self.in_channels = in_channelsself.norm = get_norm(norm, in_channels, num_groups)# 为啥这里的QKV并不是一样的???而是把通道数翻了3倍self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)self.to_out = nn.Conv2d(in_channels, in_channels, 1)def forward(self, x):b, c, h, w = x.shapeq, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)q = q.permute(0, 2, 3, 1).view(b, h * w, c)k = k.view(b, c, h * w)v = v.permute(0, 2, 3, 1).view(b, h * w, c)dot_products = torch.bmm(q, k) * (c ** (-0.5))assert dot_products.shape == (b, h * w, h * w)attention = torch.softmax(dot_products, dim=-1)out = torch.bmm(attention, v)assert out.shape == (b, h * w, c)out = out.view(b, h, w, c).permute(0, 3, 1, 2)return self.to_out(out) + x

x: (batch, channel, h, w)
经过to_qkv操作,变成了(batch, channel*3, h, w)

torch.split

torch.split(tensor, split_size_or_sections, dim=0)
# 作用:将tensor分成块结构
'''
split_size_or_secctions: 即多少个为一组
dim: 对哪个维度进行划分
'''

eg:
q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)
即对大小为(batch, channel*3, h, w)的张量,在dim=1上划分,每channel个为一组
所以,q, k, v 的形状均为(batch, channel, h, w)

torch.split详解

torch中的permute的用法

作用:permute可以对tensor进行转置

import torch
import torch.nn as nnx = torch.randn(1, 2, 3, 4)
print(x.size())    # torch.Size([1, 2, 3, 4])   
print(x.permute(2, 1, 0, 3).size())# torch.Size([3, 2, 1, 4])   

torch.transpose()

因为torch.transpose 一次只能进行两个维度的转置,如果需要多个维度的转置,那么需要多次调用transpose()。比如上述的tensor[1,2,3,4]转置为tensor[3,4,1,2],使用transpose需要做如下:

x.transpose(0,2).transpose(1,3)

view()

view()函数作用的内存必须是连续的,如果操作数不是连续存储的,必须在操作之前执行contiguous(),把tensor变成在内存中连续分布的形式;view的功能有点像reshape,可以对tensor进行重新塑型

import torch
import torch.nn as nn
import numpy as npy = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
y_tensor = torch.tensor(y)
y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
print(y_tensor.size())
print(y_tensor_trans.size())print(y_tensor)
print(y_tensor_trans)
print(y_tensor.view(1, 3, 2)) 
torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 2, 3],[4, 5, 6]]])
tensor([[[1, 4]],[[2, 5]],[[3, 6]]])
tensor([[[1, 2],[3, 4],[5, 6]]])

permute参考
permute详解参考

torch.bmm

作用:
计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m) 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,对于剩下的则不做要求,输出维度 (b,h,m)

torch.bmm要求a,b的维度必须是3维的,不能为2D or 4D

矩阵相乘

softmax(x, dim=-1)

import torch 
a = torch.randn(2,3)
print(a)
tensor([[-8.2976e-01,  5.8105e-04,  1.2218e+00],[ 1.9745e-01,  1.2727e+00,  5.9587e-01]])
b = torch.softmax(a, dim=-1)
print(b)
tensor([[0.0903, 0.2072, 0.7025],[0.1845, 0.5407, 0.2748]])

softmax(x, dim=-1)

相关文章:

Diffusion的unet中用到的AttentionBlock详解

AttentionBlocktorch.splittorch中的permute的用法torch.transpose()view()torch.bmmsoftmax(x, dim-1)Diffusion的unet中用到的AttentionBlock详解class AttentionBlock(nn.Module):__doc__ r"""Applies QKV self-attention with a residual connection.Input…...

ElasticSearch索引文档写入和近实时搜索

一、基本概念 1.Segments In Lucene 众所周知,ElasticSearch存储的基本单元Shard,ES中一个Index可能分为多个Shard,事实上每个Shard都是一个Lucence的Index,并且每个Lucene Index由多个Segment组成,每个Segment事实上…...

【C语言蓝桥杯每日一题】——等差数列

【C语言蓝桥杯每日一题】——等差数列😎前言🙌等差数列🙌解题思路分析:😍解题源代码分享:😍总结撒花💞😎博客昵称:博客小梦 😊最喜欢的座右铭&…...

EM7电磁铁的技术参数

电磁铁可以通过更换电磁铁极头在一定范围内改善磁场的大小和磁场的均匀度 ,并且可以通过调整极头间距改变磁场的大小。主要用于磁滞现象研究、磁化系数测量、霍尔效应研究、磁光实验、磁场退火、核磁共振、电子顺磁共振、生物学研究、磁性测量、磁性材料取向、磁性产…...

选择很重要,骑友,怎么挑选骑行装备?

骑行装备的重要性,已经不用多说了,大家也都知道。但是如何挑选,如何选择适合自己的骑行装备呢?今天我来和大家聊一聊这个问题。首先我们需要了解一个概念:骑行装备分为两大类:骑行服和骑行鞋。对于公路车来…...

【JUC面试题】Java并发编程面试题

Java并发编程 基础知识 1. 为什么要使用并发编程? 提升多核系统的CPU利用率一般来说一台主机上的会有多个CPU核心,我们可以创建多个线程,理论 上讲操作系统可以将多个线程分配给不同的CPU去执行,每个CPU执行一个线程&#xff0c…...

spark笔记

spark笔记 1. 概述 Spark是一种基于内存的快速、通用、可扩展的大数据分析计算引擎;Spark提供内存计算,将计算结果直接放在内存中,减少了迭代计算的IO开销,有更高效的运算效率。 1.1 Spark核心模块 Spark Core:提供S…...

丢失了packet.dll原因和解决方法全面指南

packet.dll是Windows操作系统中的一个重要文件,它主要用于网络通信,如果丢失了这个文件,可能会导致网络连接问题。本文将探讨packet.dll文件丢失的原因,并提供相应的解决方法。 一、丢失packet.dll文件的原因 1. 病毒感染&#x…...

算法练习随记(三)

1.全排列 给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1: 输入:nums [1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2: 输入&#x…...

基于Python 进行卫星图像多种指数分析

一、前言本文帮助读者更好地了解卫星数据以及使用 Python 探索和分析哨兵2卫星数号数据在Sundarbans地区的不同方法。二、Sundarbans研究区孙德尔本斯(Sundarbans)是恒河、雅鲁藏布江和梅克纳河在孟加拉湾汇合形成的三角洲中最大的红树林区之一。 孙德尔…...

(Week 15)综合复习(C++,字符串,数学)

文章目录T1 [Daimayuan]删删(C,字符串)输入格式输出格式样例输入样例输出数据规模解题思路T2 [Daimayuan]快快变大(C,区间DP)输入格式输出格式样例输入样例输出数据规模解题思路T3 [Daimayuan]饿饿 饭饭2&a…...

迪赛智慧数——柱状图(正负条形图):“光棍”排行榜TOP10省份

效果图 中国单身男女最多的省份是广东,广东的人口是全国最多的。人口多了,单身的人也会多,单身女性324万,男性498万。全国第二的省份是四川省,单身女性256万,单身男性296万。 数据源:静态数据…...

IDEA集成chatGTP让你编码如虎添翼

第一步,打开您的IDEA, 打开首选项(Preference) -> 插件(Plugin) 在插件市场搜索 chatGPT, 点击安装 安装完毕后会提示您重启IDE, 重启IDEA. 重启后您会发现窗口,右边条上 竖着挂着个chatGPT按钮了。 第二步、配置APIkey或accessToken(二选一,推荐accessToken无费用…...

Python3 os.close() 方法、Python3 File readline() 方法

Python3 os.close() 方法 概述 os.close() 方法用于关闭指定的文件描述符 fd。 语法 close()方法语法格式如下: os.close(fd);参数 fd -- 文件描述符。 返回值 该方法没有返回值。 实例 以下实例演示了 close() 方法的使用: #!/usr/bin/python3…...

Vision Pro 自己写的一些自定义工具(c#)

目录前言一、保存图片工具1、展示2、源码下载地址二、3D图片格式转化1、展示2、源码下载地址三、所有工具汇总下载地址前言 自己用c#写的一些visionPro自定义工具,便于使用的时候直接拿出来,后续会不断添加新的工具。 想看怎么使用c#写visionPro自定义…...

ARM/FPGA/DSP板卡选型大全,总有一款适合您

创龙科技ARM/FPGA/DSP嵌入式板卡选型大全2023.2版本正式发布!接下来,跟着我们一起看看有哪些亮点吧! 6大主流工业处理器原厂 创龙科技现有30多条产品线,覆盖工业自动化、能源电力、仪器仪表、通信、医疗、安防等工业领域,与6大主流工业处理器原厂强强联合,包括德州仪器…...

【C语言蓝桥杯每日一题】—— 既约分数

【C语言蓝桥杯每日一题】—— 既约分数😎前言🙌既约分数🙌递归版解题代码:😍非递归版解题代码:😍总结撒花💞既约分数😎)😎博客昵称:博客小梦 &…...

【机器学习】线性回归

文章目录前言一、单变量线性回归1.导入必要的库2.读取数据3.绘制散点图4.划分数据5.定义模型函数6.定义损失函数7.求权重向量w7.1 梯度下降函数7.2 最小二乘法8.训练模型9.绘制预测曲线10.试试正则化11.绘制预测曲线12.试试sklearn库二、多变量线性回归1.导入库2.读取数据3.划分…...

用ChatGPT学习多传感器融合中的基础知识

困惑与解答: 问题:匈牙利算法中的增广矩阵路径是什么意思 解答: 匈牙利算法是解决二分图最大匹配的经典算法之一。其中的增广矩阵路径指的是在当前匹配下,从一个未匹配节点开始,沿着交替路(交替路是指依次…...

PyCharm2020介绍

PyCharm2020PyCharm2020安装过程PyCharm2020安装包1、PyCharm2020介绍2、PyCharm2020特点3、PyCharm2020特点4、PyCharm2020PyCharm2020安装过程 PyCharm2020安装过程安装步骤点击此链接。 PyCharm2020安装包 链接:https://pan.baidu.com/s/19R3nJx6wMyNBU9oY4N4n…...

内存分配函数malloc kmalloc vmalloc

内存分配函数malloc kmalloc vmalloc malloc实现步骤: 1)请求大小调整:首先,malloc 需要调整用户请求的大小,以适应内部数据结构(例如,可能需要存储额外的元数据)。通常,这包括对齐调整,确保分配的内存地址满足特定硬件要求(如对齐到8字节或16字节边界)。 2)空闲…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地

借阿里云中企出海大会的东风,以**「云启出海,智联未来|打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办,现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:

在 HarmonyOS 应用开发中,手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力,既支持点击、长按、拖拽等基础单一手势的精细控制,也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档&#xff0c…...

循环冗余码校验CRC码 算法步骤+详细实例计算

通信过程:(白话解释) 我们将原始待发送的消息称为 M M M,依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)(意思就是 G ( x ) G(x) G(x) 是已知的)&#xff0…...

大数据零基础学习day1之环境准备和大数据初步理解

学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 (1)设置网关 打开VMware虚拟机,点击编辑…...

Leetcode 3577. Count the Number of Computer Unlocking Permutations

Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接:3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯,要想要能够将所有的电脑解锁&#x…...

【机器视觉】单目测距——运动结构恢复

ps:图是随便找的,为了凑个封面 前言 在前面对光流法进行进一步改进,希望将2D光流推广至3D场景流时,发现2D转3D过程中存在尺度歧义问题,需要补全摄像头拍摄图像中缺失的深度信息,否则解空间不收敛&#xf…...

WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)

一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解,适合用作学习或写简历项目背景说明。 🧠 一、概念简介:Solidity 合约开发 Solidity 是一种专门为 以太坊(Ethereum)平台编写智能合约的高级编…...

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别

OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别 直接训练提示词嵌入向量的核心区别 您提到的代码: prompt_embedding = initial_embedding.clone().requires_grad_(True) optimizer = torch.optim.Adam([prompt_embedding...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中,UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...