当前位置: 首页 > news >正文

Diffusion的unet中用到的AttentionBlock详解

AttentionBlock

  • torch.split
  • torch中的permute的用法
    • torch.transpose()
    • view()
  • torch.bmm
  • softmax(x, dim=-1)

Diffusion的unet中用到的AttentionBlock详解

class AttentionBlock(nn.Module):__doc__ = r"""Applies QKV self-attention with a residual connection.Input:x: tensor of shape (N, in_channels, H, W)norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"num_groups (int): number of groups used in group normalization. Default: 32Output:tensor of shape (N, in_channels, H, W)Args:in_channels (int): number of input channels"""def __init__(self, in_channels, norm="gn", num_groups=32):super().__init__()self.in_channels = in_channelsself.norm = get_norm(norm, in_channels, num_groups)# 为啥这里的QKV并不是一样的???而是把通道数翻了3倍self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)self.to_out = nn.Conv2d(in_channels, in_channels, 1)def forward(self, x):b, c, h, w = x.shapeq, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)q = q.permute(0, 2, 3, 1).view(b, h * w, c)k = k.view(b, c, h * w)v = v.permute(0, 2, 3, 1).view(b, h * w, c)dot_products = torch.bmm(q, k) * (c ** (-0.5))assert dot_products.shape == (b, h * w, h * w)attention = torch.softmax(dot_products, dim=-1)out = torch.bmm(attention, v)assert out.shape == (b, h * w, c)out = out.view(b, h, w, c).permute(0, 3, 1, 2)return self.to_out(out) + x

x: (batch, channel, h, w)
经过to_qkv操作,变成了(batch, channel*3, h, w)

torch.split

torch.split(tensor, split_size_or_sections, dim=0)
# 作用:将tensor分成块结构
'''
split_size_or_secctions: 即多少个为一组
dim: 对哪个维度进行划分
'''

eg:
q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)
即对大小为(batch, channel*3, h, w)的张量,在dim=1上划分,每channel个为一组
所以,q, k, v 的形状均为(batch, channel, h, w)

torch.split详解

torch中的permute的用法

作用:permute可以对tensor进行转置

import torch
import torch.nn as nnx = torch.randn(1, 2, 3, 4)
print(x.size())    # torch.Size([1, 2, 3, 4])   
print(x.permute(2, 1, 0, 3).size())# torch.Size([3, 2, 1, 4])   

torch.transpose()

因为torch.transpose 一次只能进行两个维度的转置,如果需要多个维度的转置,那么需要多次调用transpose()。比如上述的tensor[1,2,3,4]转置为tensor[3,4,1,2],使用transpose需要做如下:

x.transpose(0,2).transpose(1,3)

view()

view()函数作用的内存必须是连续的,如果操作数不是连续存储的,必须在操作之前执行contiguous(),把tensor变成在内存中连续分布的形式;view的功能有点像reshape,可以对tensor进行重新塑型

import torch
import torch.nn as nn
import numpy as npy = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
y_tensor = torch.tensor(y)
y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
print(y_tensor.size())
print(y_tensor_trans.size())print(y_tensor)
print(y_tensor_trans)
print(y_tensor.view(1, 3, 2)) 
torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 2, 3],[4, 5, 6]]])
tensor([[[1, 4]],[[2, 5]],[[3, 6]]])
tensor([[[1, 2],[3, 4],[5, 6]]])

permute参考
permute详解参考

torch.bmm

作用:
计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m) 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,对于剩下的则不做要求,输出维度 (b,h,m)

torch.bmm要求a,b的维度必须是3维的,不能为2D or 4D

矩阵相乘

softmax(x, dim=-1)

import torch 
a = torch.randn(2,3)
print(a)
tensor([[-8.2976e-01,  5.8105e-04,  1.2218e+00],[ 1.9745e-01,  1.2727e+00,  5.9587e-01]])
b = torch.softmax(a, dim=-1)
print(b)
tensor([[0.0903, 0.2072, 0.7025],[0.1845, 0.5407, 0.2748]])

softmax(x, dim=-1)

相关文章:

Diffusion的unet中用到的AttentionBlock详解

AttentionBlocktorch.splittorch中的permute的用法torch.transpose()view()torch.bmmsoftmax(x, dim-1)Diffusion的unet中用到的AttentionBlock详解class AttentionBlock(nn.Module):__doc__ r"""Applies QKV self-attention with a residual connection.Input…...

ElasticSearch索引文档写入和近实时搜索

一、基本概念 1.Segments In Lucene 众所周知,ElasticSearch存储的基本单元Shard,ES中一个Index可能分为多个Shard,事实上每个Shard都是一个Lucence的Index,并且每个Lucene Index由多个Segment组成,每个Segment事实上…...

【C语言蓝桥杯每日一题】——等差数列

【C语言蓝桥杯每日一题】——等差数列😎前言🙌等差数列🙌解题思路分析:😍解题源代码分享:😍总结撒花💞😎博客昵称:博客小梦 😊最喜欢的座右铭&…...

EM7电磁铁的技术参数

电磁铁可以通过更换电磁铁极头在一定范围内改善磁场的大小和磁场的均匀度 ,并且可以通过调整极头间距改变磁场的大小。主要用于磁滞现象研究、磁化系数测量、霍尔效应研究、磁光实验、磁场退火、核磁共振、电子顺磁共振、生物学研究、磁性测量、磁性材料取向、磁性产…...

选择很重要,骑友,怎么挑选骑行装备?

骑行装备的重要性,已经不用多说了,大家也都知道。但是如何挑选,如何选择适合自己的骑行装备呢?今天我来和大家聊一聊这个问题。首先我们需要了解一个概念:骑行装备分为两大类:骑行服和骑行鞋。对于公路车来…...

【JUC面试题】Java并发编程面试题

Java并发编程 基础知识 1. 为什么要使用并发编程? 提升多核系统的CPU利用率一般来说一台主机上的会有多个CPU核心,我们可以创建多个线程,理论 上讲操作系统可以将多个线程分配给不同的CPU去执行,每个CPU执行一个线程&#xff0c…...

spark笔记

spark笔记 1. 概述 Spark是一种基于内存的快速、通用、可扩展的大数据分析计算引擎;Spark提供内存计算,将计算结果直接放在内存中,减少了迭代计算的IO开销,有更高效的运算效率。 1.1 Spark核心模块 Spark Core:提供S…...

丢失了packet.dll原因和解决方法全面指南

packet.dll是Windows操作系统中的一个重要文件,它主要用于网络通信,如果丢失了这个文件,可能会导致网络连接问题。本文将探讨packet.dll文件丢失的原因,并提供相应的解决方法。 一、丢失packet.dll文件的原因 1. 病毒感染&#x…...

算法练习随记(三)

1.全排列 给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1: 输入:nums [1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2: 输入&#x…...

基于Python 进行卫星图像多种指数分析

一、前言本文帮助读者更好地了解卫星数据以及使用 Python 探索和分析哨兵2卫星数号数据在Sundarbans地区的不同方法。二、Sundarbans研究区孙德尔本斯(Sundarbans)是恒河、雅鲁藏布江和梅克纳河在孟加拉湾汇合形成的三角洲中最大的红树林区之一。 孙德尔…...

(Week 15)综合复习(C++,字符串,数学)

文章目录T1 [Daimayuan]删删(C,字符串)输入格式输出格式样例输入样例输出数据规模解题思路T2 [Daimayuan]快快变大(C,区间DP)输入格式输出格式样例输入样例输出数据规模解题思路T3 [Daimayuan]饿饿 饭饭2&a…...

迪赛智慧数——柱状图(正负条形图):“光棍”排行榜TOP10省份

效果图 中国单身男女最多的省份是广东,广东的人口是全国最多的。人口多了,单身的人也会多,单身女性324万,男性498万。全国第二的省份是四川省,单身女性256万,单身男性296万。 数据源:静态数据…...

IDEA集成chatGTP让你编码如虎添翼

第一步,打开您的IDEA, 打开首选项(Preference) -> 插件(Plugin) 在插件市场搜索 chatGPT, 点击安装 安装完毕后会提示您重启IDE, 重启IDEA. 重启后您会发现窗口,右边条上 竖着挂着个chatGPT按钮了。 第二步、配置APIkey或accessToken(二选一,推荐accessToken无费用…...

Python3 os.close() 方法、Python3 File readline() 方法

Python3 os.close() 方法 概述 os.close() 方法用于关闭指定的文件描述符 fd。 语法 close()方法语法格式如下: os.close(fd);参数 fd -- 文件描述符。 返回值 该方法没有返回值。 实例 以下实例演示了 close() 方法的使用: #!/usr/bin/python3…...

Vision Pro 自己写的一些自定义工具(c#)

目录前言一、保存图片工具1、展示2、源码下载地址二、3D图片格式转化1、展示2、源码下载地址三、所有工具汇总下载地址前言 自己用c#写的一些visionPro自定义工具,便于使用的时候直接拿出来,后续会不断添加新的工具。 想看怎么使用c#写visionPro自定义…...

ARM/FPGA/DSP板卡选型大全,总有一款适合您

创龙科技ARM/FPGA/DSP嵌入式板卡选型大全2023.2版本正式发布!接下来,跟着我们一起看看有哪些亮点吧! 6大主流工业处理器原厂 创龙科技现有30多条产品线,覆盖工业自动化、能源电力、仪器仪表、通信、医疗、安防等工业领域,与6大主流工业处理器原厂强强联合,包括德州仪器…...

【C语言蓝桥杯每日一题】—— 既约分数

【C语言蓝桥杯每日一题】—— 既约分数😎前言🙌既约分数🙌递归版解题代码:😍非递归版解题代码:😍总结撒花💞既约分数😎)😎博客昵称:博客小梦 &…...

【机器学习】线性回归

文章目录前言一、单变量线性回归1.导入必要的库2.读取数据3.绘制散点图4.划分数据5.定义模型函数6.定义损失函数7.求权重向量w7.1 梯度下降函数7.2 最小二乘法8.训练模型9.绘制预测曲线10.试试正则化11.绘制预测曲线12.试试sklearn库二、多变量线性回归1.导入库2.读取数据3.划分…...

用ChatGPT学习多传感器融合中的基础知识

困惑与解答: 问题:匈牙利算法中的增广矩阵路径是什么意思 解答: 匈牙利算法是解决二分图最大匹配的经典算法之一。其中的增广矩阵路径指的是在当前匹配下,从一个未匹配节点开始,沿着交替路(交替路是指依次…...

PyCharm2020介绍

PyCharm2020PyCharm2020安装过程PyCharm2020安装包1、PyCharm2020介绍2、PyCharm2020特点3、PyCharm2020特点4、PyCharm2020PyCharm2020安装过程 PyCharm2020安装过程安装步骤点击此链接。 PyCharm2020安装包 链接:https://pan.baidu.com/s/19R3nJx6wMyNBU9oY4N4n…...

ES6从入门到精通:前言

ES6简介 ES6(ECMAScript 2015)是JavaScript语言的重大更新,引入了许多新特性,包括语法糖、新数据类型、模块化支持等,显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var&#xf…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装(Encapsulation) 定义:将数据(属性)和操作数据的方法绑定在一起,通过访问控制符(private、protected、public)隐藏内部实现细节。示例: public …...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个?3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制(过半机制&#xff0…...

376. Wiggle Subsequence

376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…...

论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)

笔记整理&#xff1a;刘治强&#xff0c;浙江大学硕士生&#xff0c;研究方向为知识图谱表示学习&#xff0c;大语言模型 论文链接&#xff1a;http://arxiv.org/abs/2407.16127 发表会议&#xff1a;ISWC 2024 1. 动机 传统的知识图谱补全&#xff08;KGC&#xff09;模型通过…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言&#xff1a;为什么 Eureka 依然是存量系统的核心&#xff1f; 尽管 Nacos 等新注册中心崛起&#xff0c;但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制&#xff0c;是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

MySQL 知识小结(一)

一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库&#xff0c;分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷&#xff0c;但是文件存放起来数据比较冗余&#xff0c;用二进制能够更好管理咱们M…...

pikachu靶场通关笔记19 SQL注入02-字符型注入(GET)

目录 一、SQL注入 二、字符型SQL注入 三、字符型注入与数字型注入 四、源码分析 五、渗透实战 1、渗透准备 2、SQL注入探测 &#xff08;1&#xff09;输入单引号 &#xff08;2&#xff09;万能注入语句 3、获取回显列orderby 4、获取数据库名database 5、获取表名…...

uniapp 实现腾讯云IM群文件上传下载功能

UniApp 集成腾讯云IM实现群文件上传下载功能全攻略 一、功能背景与技术选型 在团队协作场景中&#xff0c;群文件共享是核心需求之一。本文将介绍如何基于腾讯云IMCOS&#xff0c;在uniapp中实现&#xff1a; 群内文件上传/下载文件元数据管理下载进度追踪跨平台文件预览 二…...

Linux基础开发工具——vim工具

文章目录 vim工具什么是vimvim的多模式和使用vim的基础模式vim的三种基础模式三种模式的初步了解 常用模式的详细讲解插入模式命令模式模式转化光标的移动文本的编辑 底行模式替换模式视图模式总结 使用vim的小技巧vim的配置(了解) vim工具 本文章仍然是继续讲解Linux系统下的…...