Diffusion的unet中用到的AttentionBlock详解
AttentionBlock
- torch.split
- torch中的permute的用法
- torch.transpose()
- view()
- torch.bmm
- softmax(x, dim=-1)
Diffusion的unet中用到的AttentionBlock详解
class AttentionBlock(nn.Module):__doc__ = r"""Applies QKV self-attention with a residual connection.Input:x: tensor of shape (N, in_channels, H, W)norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"num_groups (int): number of groups used in group normalization. Default: 32Output:tensor of shape (N, in_channels, H, W)Args:in_channels (int): number of input channels"""def __init__(self, in_channels, norm="gn", num_groups=32):super().__init__()self.in_channels = in_channelsself.norm = get_norm(norm, in_channels, num_groups)# 为啥这里的QKV并不是一样的???而是把通道数翻了3倍self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)self.to_out = nn.Conv2d(in_channels, in_channels, 1)def forward(self, x):b, c, h, w = x.shapeq, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)q = q.permute(0, 2, 3, 1).view(b, h * w, c)k = k.view(b, c, h * w)v = v.permute(0, 2, 3, 1).view(b, h * w, c)dot_products = torch.bmm(q, k) * (c ** (-0.5))assert dot_products.shape == (b, h * w, h * w)attention = torch.softmax(dot_products, dim=-1)out = torch.bmm(attention, v)assert out.shape == (b, h * w, c)out = out.view(b, h, w, c).permute(0, 3, 1, 2)return self.to_out(out) + x
x: (batch, channel, h, w)
经过to_qkv操作,变成了(batch, channel*3, h, w)
torch.split
torch.split(tensor, split_size_or_sections, dim=0)
# 作用:将tensor分成块结构
'''
split_size_or_secctions: 即多少个为一组
dim: 对哪个维度进行划分
'''
eg:
q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)
即对大小为(batch, channel*3, h, w)的张量,在dim=1上划分,每channel个为一组
所以,q, k, v 的形状均为(batch, channel, h, w)
torch.split详解
torch中的permute的用法
作用:permute可以对tensor进行转置
import torch
import torch.nn as nnx = torch.randn(1, 2, 3, 4)
print(x.size()) # torch.Size([1, 2, 3, 4])
print(x.permute(2, 1, 0, 3).size())# torch.Size([3, 2, 1, 4])
torch.transpose()
因为torch.transpose
一次只能进行两个维度的转置,如果需要多个维度的转置,那么需要多次调用transpose()。比如上述的tensor[1,2,3,4]转置为tensor[3,4,1,2],使用transpose需要做如下:
x.transpose(0,2).transpose(1,3)
view()
view()函数作用的内存必须是连续的,如果操作数不是连续存储的,必须在操作之前执行contiguous(),把tensor变成在内存中连续分布的形式;view的功能有点像reshape,可以对tensor进行重新塑型
import torch
import torch.nn as nn
import numpy as npy = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
y_tensor = torch.tensor(y)
y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
print(y_tensor.size())
print(y_tensor_trans.size())print(y_tensor)
print(y_tensor_trans)
print(y_tensor.view(1, 3, 2))
torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 2, 3],[4, 5, 6]]])
tensor([[[1, 4]],[[2, 5]],[[3, 6]]])
tensor([[[1, 2],[3, 4],[5, 6]]])
permute参考
permute详解参考
torch.bmm
作用:
计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m) 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,对于剩下的则不做要求,输出维度 (b,h,m)
torch.bmm要求a,b的维度必须是3维的,不能为2D or 4D
矩阵相乘
softmax(x, dim=-1)
import torch
a = torch.randn(2,3)
print(a)
tensor([[-8.2976e-01, 5.8105e-04, 1.2218e+00],[ 1.9745e-01, 1.2727e+00, 5.9587e-01]])
b = torch.softmax(a, dim=-1)
print(b)
tensor([[0.0903, 0.2072, 0.7025],[0.1845, 0.5407, 0.2748]])
softmax(x, dim=-1)
相关文章:

Diffusion的unet中用到的AttentionBlock详解
AttentionBlocktorch.splittorch中的permute的用法torch.transpose()view()torch.bmmsoftmax(x, dim-1)Diffusion的unet中用到的AttentionBlock详解class AttentionBlock(nn.Module):__doc__ r"""Applies QKV self-attention with a residual connection.Input…...

ElasticSearch索引文档写入和近实时搜索
一、基本概念 1.Segments In Lucene 众所周知,ElasticSearch存储的基本单元Shard,ES中一个Index可能分为多个Shard,事实上每个Shard都是一个Lucence的Index,并且每个Lucene Index由多个Segment组成,每个Segment事实上…...

【C语言蓝桥杯每日一题】——等差数列
【C语言蓝桥杯每日一题】——等差数列😎前言🙌等差数列🙌解题思路分析:😍解题源代码分享:😍总结撒花💞😎博客昵称:博客小梦 😊最喜欢的座右铭&…...

EM7电磁铁的技术参数
电磁铁可以通过更换电磁铁极头在一定范围内改善磁场的大小和磁场的均匀度 ,并且可以通过调整极头间距改变磁场的大小。主要用于磁滞现象研究、磁化系数测量、霍尔效应研究、磁光实验、磁场退火、核磁共振、电子顺磁共振、生物学研究、磁性测量、磁性材料取向、磁性产…...

选择很重要,骑友,怎么挑选骑行装备?
骑行装备的重要性,已经不用多说了,大家也都知道。但是如何挑选,如何选择适合自己的骑行装备呢?今天我来和大家聊一聊这个问题。首先我们需要了解一个概念:骑行装备分为两大类:骑行服和骑行鞋。对于公路车来…...

【JUC面试题】Java并发编程面试题
Java并发编程 基础知识 1. 为什么要使用并发编程? 提升多核系统的CPU利用率一般来说一台主机上的会有多个CPU核心,我们可以创建多个线程,理论 上讲操作系统可以将多个线程分配给不同的CPU去执行,每个CPU执行一个线程,…...

spark笔记
spark笔记 1. 概述 Spark是一种基于内存的快速、通用、可扩展的大数据分析计算引擎;Spark提供内存计算,将计算结果直接放在内存中,减少了迭代计算的IO开销,有更高效的运算效率。 1.1 Spark核心模块 Spark Core:提供S…...

丢失了packet.dll原因和解决方法全面指南
packet.dll是Windows操作系统中的一个重要文件,它主要用于网络通信,如果丢失了这个文件,可能会导致网络连接问题。本文将探讨packet.dll文件丢失的原因,并提供相应的解决方法。 一、丢失packet.dll文件的原因 1. 病毒感染&#x…...

算法练习随记(三)
1.全排列 给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1: 输入:nums [1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2: 输入&#x…...

基于Python 进行卫星图像多种指数分析
一、前言本文帮助读者更好地了解卫星数据以及使用 Python 探索和分析哨兵2卫星数号数据在Sundarbans地区的不同方法。二、Sundarbans研究区孙德尔本斯(Sundarbans)是恒河、雅鲁藏布江和梅克纳河在孟加拉湾汇合形成的三角洲中最大的红树林区之一。 孙德尔…...

(Week 15)综合复习(C++,字符串,数学)
文章目录T1 [Daimayuan]删删(C,字符串)输入格式输出格式样例输入样例输出数据规模解题思路T2 [Daimayuan]快快变大(C,区间DP)输入格式输出格式样例输入样例输出数据规模解题思路T3 [Daimayuan]饿饿 饭饭2&a…...

迪赛智慧数——柱状图(正负条形图):“光棍”排行榜TOP10省份
效果图 中国单身男女最多的省份是广东,广东的人口是全国最多的。人口多了,单身的人也会多,单身女性324万,男性498万。全国第二的省份是四川省,单身女性256万,单身男性296万。 数据源:静态数据…...

IDEA集成chatGTP让你编码如虎添翼
第一步,打开您的IDEA, 打开首选项(Preference) -> 插件(Plugin) 在插件市场搜索 chatGPT, 点击安装 安装完毕后会提示您重启IDE, 重启IDEA. 重启后您会发现窗口,右边条上 竖着挂着个chatGPT按钮了。 第二步、配置APIkey或accessToken(二选一,推荐accessToken无费用…...

Python3 os.close() 方法、Python3 File readline() 方法
Python3 os.close() 方法 概述 os.close() 方法用于关闭指定的文件描述符 fd。 语法 close()方法语法格式如下: os.close(fd);参数 fd -- 文件描述符。 返回值 该方法没有返回值。 实例 以下实例演示了 close() 方法的使用: #!/usr/bin/python3…...

Vision Pro 自己写的一些自定义工具(c#)
目录前言一、保存图片工具1、展示2、源码下载地址二、3D图片格式转化1、展示2、源码下载地址三、所有工具汇总下载地址前言 自己用c#写的一些visionPro自定义工具,便于使用的时候直接拿出来,后续会不断添加新的工具。 想看怎么使用c#写visionPro自定义…...

ARM/FPGA/DSP板卡选型大全,总有一款适合您
创龙科技ARM/FPGA/DSP嵌入式板卡选型大全2023.2版本正式发布!接下来,跟着我们一起看看有哪些亮点吧! 6大主流工业处理器原厂 创龙科技现有30多条产品线,覆盖工业自动化、能源电力、仪器仪表、通信、医疗、安防等工业领域,与6大主流工业处理器原厂强强联合,包括德州仪器…...

【C语言蓝桥杯每日一题】—— 既约分数
【C语言蓝桥杯每日一题】—— 既约分数😎前言🙌既约分数🙌递归版解题代码:😍非递归版解题代码:😍总结撒花💞既约分数😎)😎博客昵称:博客小梦 &…...

【机器学习】线性回归
文章目录前言一、单变量线性回归1.导入必要的库2.读取数据3.绘制散点图4.划分数据5.定义模型函数6.定义损失函数7.求权重向量w7.1 梯度下降函数7.2 最小二乘法8.训练模型9.绘制预测曲线10.试试正则化11.绘制预测曲线12.试试sklearn库二、多变量线性回归1.导入库2.读取数据3.划分…...

用ChatGPT学习多传感器融合中的基础知识
困惑与解答: 问题:匈牙利算法中的增广矩阵路径是什么意思 解答: 匈牙利算法是解决二分图最大匹配的经典算法之一。其中的增广矩阵路径指的是在当前匹配下,从一个未匹配节点开始,沿着交替路(交替路是指依次…...

PyCharm2020介绍
PyCharm2020PyCharm2020安装过程PyCharm2020安装包1、PyCharm2020介绍2、PyCharm2020特点3、PyCharm2020特点4、PyCharm2020PyCharm2020安装过程 PyCharm2020安装过程安装步骤点击此链接。 PyCharm2020安装包 链接:https://pan.baidu.com/s/19R3nJx6wMyNBU9oY4N4n…...

Le Potato + Jumbospot MMDVM热点盒子
最近才留意到,树莓派受到编程圈一定瞩目之后,智慧的同胞早已悄咪咪的搞了一堆xx派出来,本来对于香橙派,苹果派,土豆派和香蕉派是不感冒的,但是因为最近树莓派夸张的二级市场价格和断供,终于还是…...

蓝桥杯第19天(Python)(疯狂刷题第2天)
题型: 1.思维题/杂题:数学公式,分析题意,找规律 2.BFS/DFS:广搜(递归实现),深搜(deque实现) 3.简单数论:模,素数(只需要…...

(五)手把手带你搭建精美简洁的个人时间管理网站—基于Axure的首页原型设计
🌟所属专栏:献给榕榕🐔作者简介:rchjr——五带信管菜只因一枚 😮前言:该专栏系为女友准备的,里面会不定时发一些讨好她的技术作品,感兴趣的小伙伴可以关注一下~👉文章简介…...

阿里面试:为什么MySQL不建议使用delete删除数据?
MySQL是一种关系型数据库管理系统,它的数据存储是基于磁盘上的文件系统实现的。MySQL将数据存储在表中,每个表由一系列的行和列组成。每一行表示一个记录,每一列表示一个字段。表的结构由其列名、数据类型、索引等信息组成。 MySQL的数据存储…...

低代码开发公司:用科技强力开启产业分工新时代!
实现办公自动化,是不少企业的共同追求。低代码开发公司会遵循时代发展规律,注入强劲的科技新生力量,在低代码开发市场厚积爆发、努力奋斗,推动企业数字化转型升级,为每一个企业的办公自动化升级创新贡献应有的力量。 一…...

参考mfa官方文档实践笔记(亲测)
按顺序执行以下指令: conda create -n aligner -c conda-forge montreal-forced-alignerconda config --add channels conda-forgeconda activate alignerconda install pytorch torchvision torchaudio pytorch-cuda11.7 -c pytorch -c nvidia 如果报错࿱…...

【 第六章 拦截器,注解配置springMVC,springMVC执行流程】
第六章 拦截器,注解配置springMVC,springMVC执行流程 1.拦截器: ①springMVC中的拦截器用于拦截控制器方法的执行。 ②springMVC的拦截器需要实现HandlerInterceptor或者继承HandlerInterceptorAdapter类。 ③springMVC的拦截器必须在spring…...

一种编译器视角下的python性能优化
“Life is short,You need python”!老码农很喜欢python的优雅,然而,在生产环境中,Python这样的没有优先考虑性能构建优化的动态语言特性可能是危险的,因此,流行的高性能库如TensorFlow 或PyTor…...

太逼真!这个韩国虚拟女团你追不追?
“她们看上去太像真人了”, 韩国虚拟女团MAVE的首支MV和打歌舞台引发网友阵阵惊呼。现在,她们的舞蹈已经有真人在挑战了。 这一组虚拟人的“逼真”倒不在脸,主要是MAVE女团的舞台动作接近自然,不放近景看,基本可以达到…...

安全与道路测试:自动驾驶系统安全性探究
随着自动驾驶技术的迅速发展,如何确保自动驾驶系统的安全性已成为业界关注的焦点。本文将探讨自动驾驶系统的潜在风险、安全设计原则和道路测试要求。 潜在风险 自动驾驶系统在改善交通安全和提高出行效率方面具有巨大潜力,但其安全性仍面临许多挑战&a…...