当前位置: 首页 > news >正文

注意力机制(代码实现案例)

学习目标

  • 了解什么是注意力计算规则以及常见的计算规则.
  • 了解什么是注意力机制及其作用.
  • 掌握注意力机制的实现步骤.

1 注意力机制介绍

1.1 注意力概念

  • 我们观察事物时,之所以能够快速判断一种事物(当然允许判断是错误的), 是因为我们大脑能够很快把注意力放在事物最具有辨识度的部分从而作出判断,而并非是从头到尾的观察一遍事物后,才能有判断结果. 正是基于这样的理论,就产生了注意力机制.

1.2 注意力计算规则

  • 它需要三个指定的输入Q(query), K(key), V(value), 然后通过计算公式得到注意力的结果, 这个结果代表query在key和value作用下的注意力表示. 当输入的Q=K=V时, 称作自注意力计算规则.

1.3 常见的注意力计算规则

  • bmm运算演示:

# 如果参数1形状是(b × n × m), 参数2形状是(b × m × p), 则输出为(b × n × p)
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

2 什么是注意力机制

  • 注意力机制是注意力计算规则能够应用的深度学习网络的载体, 同时包括一些必要的全连接层以及相关张量处理, 使其与应用网络融为一体. 使用自注意力计算规则的注意力机制称为自注意力机制.
  • 说明: NLP领域中, 当前的注意力机制大多数应用于seq2seq架构, 即编码器和解码器模型.

3 注意力机制的作用

  • 在解码器端的注意力机制: 能够根据模型目标有效的聚焦编码器的输出结果, 当其作为解码器的输入时提升效果. 改善以往编码器输出是单一定长张量, 无法存储过多信息的情况.
  • 在编码器端的注意力机制: 主要解决表征问题, 相当于特征提取过程, 得到输入的注意力表示. 一般使用自注意力(self-attention).

注意力机制在网络中实现的图形表示:

4 注意力机制实现步骤

4.1 步骤

  • 第一步: 根据注意力计算规则, 对Q,K,V进行相应的计算.
  • 第二步: 根据第一步采用的计算方法, 如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接, 如果是转置点积, 一般是自注意力, Q与V相同, 则不需要进行与Q的拼接.
  • 第三步: 最后为了使整个attention机制按照指定尺寸输出, 使用线性层作用在第二步的结果上做一个线性变换, 得到最终对Q的注意力表示.

4.2 代码实现

  • 常见注意力机制的代码分析:
    import torch
    import torch.nn as nn
    import torch.nn.functional as Fclass Attn(nn.Module):def __init__(self, query_size, key_size, value_size1, value_size2, output_size):"""初始化函数中的参数有5个, query_size代表query的最后一维大小key_size代表key的最后一维大小, value_size1代表value的导数第二维大小, value = (1, value_size1, value_size2)value_size2代表value的倒数第一维大小, output_size输出的最后一维大小"""super(Attn, self).__init__()# 将以下参数传入类中self.query_size = query_sizeself.key_size = key_sizeself.value_size1 = value_size1self.value_size2 = value_size2self.output_size = output_size# 初始化注意力机制实现第一步中需要的线性层.self.attn = nn.Linear(self.query_size + self.key_size, value_size1)# 初始化注意力机制实现第三步中需要的线性层.self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)def forward(self, Q, K, V):"""forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量"""# 第一步, 按照计算规则进行计算, # 我们采用常见的第一种计算规则# 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)# 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算, # 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)# 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法, # 需要将Q与第一步的计算结果再进行拼接output = torch.cat((Q[0], attn_applied[0]), 1)# 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出# 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度output = self.attn_combine(output).unsqueeze(0)return output, attn_weights
    

  • 调用:
  • query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    output_size = 64
    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
    Q = torch.randn(1,1,32)
    K = torch.randn(1,1,32)
    V = torch.randn(1,32,64)
    out = attn(Q, K ,V)
    print(out[0])
    print(out[1])

  • 输出效果:
    tensor([[[ 0.4477, -0.0500, -0.2277, -0.3168, -0.4096, -0.5982,  0.1548,-0.0771, -0.0951,  0.1833,  0.3128,  0.1260,  0.4420,  0.0495,-0.7774, -0.0995,  0.2629,  0.4957,  1.0922,  0.1428,  0.3024,-0.2646, -0.0265,  0.0632,  0.3951,  0.1583,  0.1130,  0.5500,-0.1887, -0.2816, -0.3800, -0.5741,  0.1342,  0.0244, -0.2217,0.1544,  0.1865, -0.2019,  0.4090, -0.4762,  0.3677, -0.2553,-0.5199,  0.2290, -0.4407,  0.0663, -0.0182, -0.2168,  0.0913,-0.2340,  0.1924, -0.3687,  0.1508,  0.3618, -0.0113,  0.2864,-0.1929, -0.6821,  0.0951,  0.1335,  0.3560, -0.3215,  0.6461,0.1532]]], grad_fn=<UnsqueezeBackward0>)tensor([[0.0395, 0.0342, 0.0200, 0.0471, 0.0177, 0.0209, 0.0244, 0.0465, 0.0346,0.0378, 0.0282, 0.0214, 0.0135, 0.0419, 0.0926, 0.0123, 0.0177, 0.0187,0.0166, 0.0225, 0.0234, 0.0284, 0.0151, 0.0239, 0.0132, 0.0439, 0.0507,0.0419, 0.0352, 0.0392, 0.0546, 0.0224]], grad_fn=<SoftmaxBackward>)
    

相关文章:

注意力机制(代码实现案例)

学习目标 了解什么是注意力计算规则以及常见的计算规则.了解什么是注意力机制及其作用.掌握注意力机制的实现步骤. 1 注意力机制介绍 1.1 注意力概念 我们观察事物时&#xff0c;之所以能够快速判断一种事物(当然允许判断是错误的), 是因为我们大脑能够很快把注意力放在事物…...

全量知识系统问题及SmartChat给出的答复 之8 三套工具之3语法解析器 之1

Q19. 问题 : 解释单词解释单词occupied 的字典条目 (word-def occupiedinterest 5type EBsubclass SEBtemplate (script $Demonstrateactor nilobject nildemands nilmethod (scene $Occupyactor nillocation nil))fill (((actor) (top-of *actor-s…...

软考59-上午题-【数据库】-小结+杂题

一、杂题 真题1&#xff1a; 真题2&#xff1a; 真题3&#xff1a; 真题4&#xff1a; 真题5&#xff1a; 真题6&#xff1a; 真题7&#xff1a; 真题8&#xff1a; 二、数据库总结 考试题型&#xff1a; 1、选择题&#xff08;6题&#xff0c;6分&#xff09; 2、综合分析题…...

【ARM Trace32(劳特巴赫) 高级篇 21 -- SystemTrace ITM 使用介绍】

文章目录 SystemTrace ITMSystemTrace ITM 常用命令Trace Data AnalysisSystemTrace ITM CoreSight ITM (Instrumentation Trace Macrocell) provides the following information: Address, data value and instruction address for selected data cyclesInterrupt event info…...

Python系列(20)—— 循环语句

Python中的循环控制语句 一、引言 在Python编程中&#xff0c;循环是重复执行一段代码直到满足特定条件的基本结构。Python提供了多种循环控制语句&#xff0c;如For 和While &#xff0c;以及用于控制循环流程的辅助语句&#xff0c;如Break、Continue和Pass。这些语句的组合…...

MYSQL的sql性能优化技巧

在编写 SQL 查询时&#xff0c;有一些技巧可以帮助你提高性能、简化查询并避免常见错误。以下是一些 MySQL 的写 SQL 技巧&#xff1a; 1. 使用索引 确保经常用于搜索、排序和连接的列上有索引。避免在索引列上使用函数或表达式&#xff0c;这会导致索引失效。使用 EXPLAIN 关…...

C#(C Sharp)学习笔记_数组的遍历【十】

输出数组内容 一般而言&#xff0c;我们会使用索引来输出指定的内容。 int[] arrayInt new int[] {4, 5, 2, 7, 9}; Console.WriteLine(arrayInt[3]);但这样只能输出指定的索引指向的内容&#xff0c;无法一下子查看数组全部的值。所以我们需要用到遍历方法输出所有元素。 …...

掌握未来技术:一站式深度学习学习平台体验!

介绍&#xff1a;深度学习是机器学习的一个子领域&#xff0c;它模仿人脑的分析和学习能力&#xff0c;通过构建和训练多层神经网络来学习数据的内在规律和表示层次。 深度学习的核心在于能够自动学习数据中的高层次特征&#xff0c;而无需人工进行复杂的特征工程。这种方法在图…...

Doris实战——特步集团零售数据仓库项目实践

目录 一、背景 二、总体架构 三、ETL实践 3.1 批量数据的导入 3.2 实时数据接入 3.3 数据加工 3.4 BI 查询 四、实时需求响应 五、其他经验 5.1 Doris BE内存溢出 5.2 SQL任务超时 5.3 删除语句不支持表达式 5.4 Drop 表闪回 六、未来展望 原文大佬的这篇Doris数…...

【python】(4)条件和循环

条件语句(Conditional Statements) 条件语句允许程序根据条件的不同执行不同的代码段。这是实现决策逻辑、分支和循环的基础。 if 语句 if 语句是最基本的条件语句,它用于执行仅当特定条件为真时才需要执行的代码块。 x = 10 if x > 5:print("x is greater than…...

Docker 的基本概念

Docker是一种开源的容器化平台&#xff0c;可以用于将应用程序和它们的依赖项打包到一个可移植的容器中。Docker容器可以在任何支持Docker的操作系统上运行&#xff0c;提供了隔离、可移植性和易于部署的优势。 Docker的基本概念包括以下几点&#xff1a; 镜像&#xff08;Im…...

5.44 BCC工具之killsnoop.py解读

一,工具简介 工具用于追踪通过 kill() 系统调用发送的信号,并实时报告相关信息。 二,代码示例 #!/usr/bin/env pythonfrom __future__ import print_function from bcc import BPF from bcc.utils import ArgString, printb import argparse from time import strftime# …...

2023人机交互期末复习

考试题型及分值分布 1、选择题&#xff08;10题、20分&#xff09; 2、填空题&#xff08;10题、20分&#xff09; 3、判断题&#xff08;可选、5题、10分&#xff09; 4、解答题&#xff08;5~6题、30分&#xff09; 5、分析计算题&#xff08;1~2题、20分&#xff09; 注意&…...

Linux使用bcache 将SSD加速硬盘

前言 在Linux下&#xff0c;使用SSD为HDD加速&#xff0c;目前较为成熟的方案有&#xff1a;flashcache&#xff0c;enhanceIO&#xff0c;dm-cache&#xff0c;bcache等&#xff0c;多方面比较以后最终选择了bcache。 bcache 是一个 Linux 内核块层超速缓存。它允许使用一个或…...

大厂报价查询系统性能优化之道!

0 前言 机票查询系统&#xff0c;日均亿级流量&#xff0c;要求高吞吐&#xff0c;低延迟架构设计。提升缓存的效率以及实时计算模块长尾延迟&#xff0c;成为制约机票查询系统性能关键。本文介绍机票查询系统在缓存和实时计算两个领域的架构提升。 1 机票搜索服务概述 1.1 …...

Carbondata编译适配Spark3

背景 当前carbondata版本2.3.1-rc1中项目源码适配的spark版本最高为3.1,我们需要进行spark3.3版本的编译适配。 原始编译 linux系统下载源码后&#xff0c;安装maven3.6.3&#xff0c;然后执行&#xff1a; mvn -DskipTests -Pspark-3.1 clean package会遇到一些网络问题&a…...

数学建模【灰色关联分析】

一、灰色关联分析简介 一般的抽象系统,如社会系统、经济系统、农业系统、生态系统、教育系统等都包含有许多种因素&#xff0c;多种因素共同作用的结果决定了该系统的发展态势。人们常常希望知道在众多的因素中&#xff0c;哪些是主要因素&#xff0c;哪些是次要因素;哪些因素…...

Vue.js的单向数据流:让你的应用更清晰、更可控

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…...

IntelliJ IDEA社区版传统web开发环境搭建

前言 现在主流的开发框架是SpringBoot,使用maven配置的开发环境&#xff0c;网上有很多教程&#xff0c;这里记录一下传统Web开发项目&#xff08;mvc架构的框架&#xff0c;如SSH&#xff09;使用idea社区版的开发环境搭建。防止被人说都2024年了还用eclipse。 一、下载文件…...

arm-linux-gnueabi、arm-linux-gnueabihf 交叉编译器区别

1、arm-linux-gnueabi&#xff1a; 使用软件浮点&#xff08;软浮点&#xff09;。这意味着所有的浮点运算都将由软件库来处理&#xff0c;而不会利用硬件中的浮点运算单元。因此&#xff0c;生成的目标代码包含了对软件浮点库的调用。 2、arm-linux-gnueabihf&#xff1a; 使…...

浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)

✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义&#xff08;Task Definition&…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制&#xff0c;因此这个了16进制的数据既可以翻译成为这个机器码&#xff0c;也可以翻译成为这个国标码&#xff0c;所以这个时候很容易会出现这个歧义的情况&#xff1b; 因此&#xff0c;我们的这个国…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式&#xff0c;可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

深入浅出Asp.Net Core MVC应用开发系列-AspNetCore中的日志记录

ASP.NET Core 是一个跨平台的开源框架&#xff0c;用于在 Windows、macOS 或 Linux 上生成基于云的新式 Web 应用。 ASP.NET Core 中的日志记录 .NET 通过 ILogger API 支持高性能结构化日志记录&#xff0c;以帮助监视应用程序行为和诊断问题。 可以通过配置不同的记录提供程…...

(十)学生端搭建

本次旨在将之前的已完成的部分功能进行拼装到学生端&#xff0c;同时完善学生端的构建。本次工作主要包括&#xff1a; 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...

在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:

在 HarmonyOS 应用开发中&#xff0c;手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力&#xff0c;既支持点击、长按、拖拽等基础单一手势的精细控制&#xff0c;也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档&#xff0c…...

在rocky linux 9.5上在线安装 docker

前面是指南&#xff0c;后面是日志 sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo sudo dnf install docker-ce docker-ce-cli containerd.io -y docker version sudo systemctl start docker sudo systemctl status docker …...

mongodb源码分析session执行handleRequest命令find过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程&#xff0c;并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令&#xff0c;把数据流转换成Message&#xff0c;状态转变流程是&#xff1a;State::Created 》 St…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet&#xff0c;点击确认后如下提示 最终上报fail 解决方法 内核升级导致&#xff0c;需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

反射获取方法和属性

Java反射获取方法 在Java中&#xff0c;反射&#xff08;Reflection&#xff09;是一种强大的机制&#xff0c;允许程序在运行时访问和操作类的内部属性和方法。通过反射&#xff0c;可以动态地创建对象、调用方法、改变属性值&#xff0c;这在很多Java框架中如Spring和Hiberna…...