当前位置: 首页 > news >正文

注意力机制(代码实现案例)

学习目标

  • 了解什么是注意力计算规则以及常见的计算规则.
  • 了解什么是注意力机制及其作用.
  • 掌握注意力机制的实现步骤.

1 注意力机制介绍

1.1 注意力概念

  • 我们观察事物时,之所以能够快速判断一种事物(当然允许判断是错误的), 是因为我们大脑能够很快把注意力放在事物最具有辨识度的部分从而作出判断,而并非是从头到尾的观察一遍事物后,才能有判断结果. 正是基于这样的理论,就产生了注意力机制.

1.2 注意力计算规则

  • 它需要三个指定的输入Q(query), K(key), V(value), 然后通过计算公式得到注意力的结果, 这个结果代表query在key和value作用下的注意力表示. 当输入的Q=K=V时, 称作自注意力计算规则.

1.3 常见的注意力计算规则

  • bmm运算演示:

# 如果参数1形状是(b × n × m), 参数2形状是(b × m × p), 则输出为(b × n × p)
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

2 什么是注意力机制

  • 注意力机制是注意力计算规则能够应用的深度学习网络的载体, 同时包括一些必要的全连接层以及相关张量处理, 使其与应用网络融为一体. 使用自注意力计算规则的注意力机制称为自注意力机制.
  • 说明: NLP领域中, 当前的注意力机制大多数应用于seq2seq架构, 即编码器和解码器模型.

3 注意力机制的作用

  • 在解码器端的注意力机制: 能够根据模型目标有效的聚焦编码器的输出结果, 当其作为解码器的输入时提升效果. 改善以往编码器输出是单一定长张量, 无法存储过多信息的情况.
  • 在编码器端的注意力机制: 主要解决表征问题, 相当于特征提取过程, 得到输入的注意力表示. 一般使用自注意力(self-attention).

注意力机制在网络中实现的图形表示:

4 注意力机制实现步骤

4.1 步骤

  • 第一步: 根据注意力计算规则, 对Q,K,V进行相应的计算.
  • 第二步: 根据第一步采用的计算方法, 如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接, 如果是转置点积, 一般是自注意力, Q与V相同, 则不需要进行与Q的拼接.
  • 第三步: 最后为了使整个attention机制按照指定尺寸输出, 使用线性层作用在第二步的结果上做一个线性变换, 得到最终对Q的注意力表示.

4.2 代码实现

  • 常见注意力机制的代码分析:
    import torch
    import torch.nn as nn
    import torch.nn.functional as Fclass Attn(nn.Module):def __init__(self, query_size, key_size, value_size1, value_size2, output_size):"""初始化函数中的参数有5个, query_size代表query的最后一维大小key_size代表key的最后一维大小, value_size1代表value的导数第二维大小, value = (1, value_size1, value_size2)value_size2代表value的倒数第一维大小, output_size输出的最后一维大小"""super(Attn, self).__init__()# 将以下参数传入类中self.query_size = query_sizeself.key_size = key_sizeself.value_size1 = value_size1self.value_size2 = value_size2self.output_size = output_size# 初始化注意力机制实现第一步中需要的线性层.self.attn = nn.Linear(self.query_size + self.key_size, value_size1)# 初始化注意力机制实现第三步中需要的线性层.self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)def forward(self, Q, K, V):"""forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量"""# 第一步, 按照计算规则进行计算, # 我们采用常见的第一种计算规则# 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)# 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算, # 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)# 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法, # 需要将Q与第一步的计算结果再进行拼接output = torch.cat((Q[0], attn_applied[0]), 1)# 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出# 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度output = self.attn_combine(output).unsqueeze(0)return output, attn_weights
    

  • 调用:
  • query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    output_size = 64
    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
    Q = torch.randn(1,1,32)
    K = torch.randn(1,1,32)
    V = torch.randn(1,32,64)
    out = attn(Q, K ,V)
    print(out[0])
    print(out[1])

  • 输出效果:
    tensor([[[ 0.4477, -0.0500, -0.2277, -0.3168, -0.4096, -0.5982,  0.1548,-0.0771, -0.0951,  0.1833,  0.3128,  0.1260,  0.4420,  0.0495,-0.7774, -0.0995,  0.2629,  0.4957,  1.0922,  0.1428,  0.3024,-0.2646, -0.0265,  0.0632,  0.3951,  0.1583,  0.1130,  0.5500,-0.1887, -0.2816, -0.3800, -0.5741,  0.1342,  0.0244, -0.2217,0.1544,  0.1865, -0.2019,  0.4090, -0.4762,  0.3677, -0.2553,-0.5199,  0.2290, -0.4407,  0.0663, -0.0182, -0.2168,  0.0913,-0.2340,  0.1924, -0.3687,  0.1508,  0.3618, -0.0113,  0.2864,-0.1929, -0.6821,  0.0951,  0.1335,  0.3560, -0.3215,  0.6461,0.1532]]], grad_fn=<UnsqueezeBackward0>)tensor([[0.0395, 0.0342, 0.0200, 0.0471, 0.0177, 0.0209, 0.0244, 0.0465, 0.0346,0.0378, 0.0282, 0.0214, 0.0135, 0.0419, 0.0926, 0.0123, 0.0177, 0.0187,0.0166, 0.0225, 0.0234, 0.0284, 0.0151, 0.0239, 0.0132, 0.0439, 0.0507,0.0419, 0.0352, 0.0392, 0.0546, 0.0224]], grad_fn=<SoftmaxBackward>)
    

相关文章:

注意力机制(代码实现案例)

学习目标 了解什么是注意力计算规则以及常见的计算规则.了解什么是注意力机制及其作用.掌握注意力机制的实现步骤. 1 注意力机制介绍 1.1 注意力概念 我们观察事物时&#xff0c;之所以能够快速判断一种事物(当然允许判断是错误的), 是因为我们大脑能够很快把注意力放在事物…...

全量知识系统问题及SmartChat给出的答复 之8 三套工具之3语法解析器 之1

Q19. 问题 : 解释单词解释单词occupied 的字典条目 (word-def occupiedinterest 5type EBsubclass SEBtemplate (script $Demonstrateactor nilobject nildemands nilmethod (scene $Occupyactor nillocation nil))fill (((actor) (top-of *actor-s…...

软考59-上午题-【数据库】-小结+杂题

一、杂题 真题1&#xff1a; 真题2&#xff1a; 真题3&#xff1a; 真题4&#xff1a; 真题5&#xff1a; 真题6&#xff1a; 真题7&#xff1a; 真题8&#xff1a; 二、数据库总结 考试题型&#xff1a; 1、选择题&#xff08;6题&#xff0c;6分&#xff09; 2、综合分析题…...

【ARM Trace32(劳特巴赫) 高级篇 21 -- SystemTrace ITM 使用介绍】

文章目录 SystemTrace ITMSystemTrace ITM 常用命令Trace Data AnalysisSystemTrace ITM CoreSight ITM (Instrumentation Trace Macrocell) provides the following information: Address, data value and instruction address for selected data cyclesInterrupt event info…...

Python系列(20)—— 循环语句

Python中的循环控制语句 一、引言 在Python编程中&#xff0c;循环是重复执行一段代码直到满足特定条件的基本结构。Python提供了多种循环控制语句&#xff0c;如For 和While &#xff0c;以及用于控制循环流程的辅助语句&#xff0c;如Break、Continue和Pass。这些语句的组合…...

MYSQL的sql性能优化技巧

在编写 SQL 查询时&#xff0c;有一些技巧可以帮助你提高性能、简化查询并避免常见错误。以下是一些 MySQL 的写 SQL 技巧&#xff1a; 1. 使用索引 确保经常用于搜索、排序和连接的列上有索引。避免在索引列上使用函数或表达式&#xff0c;这会导致索引失效。使用 EXPLAIN 关…...

C#(C Sharp)学习笔记_数组的遍历【十】

输出数组内容 一般而言&#xff0c;我们会使用索引来输出指定的内容。 int[] arrayInt new int[] {4, 5, 2, 7, 9}; Console.WriteLine(arrayInt[3]);但这样只能输出指定的索引指向的内容&#xff0c;无法一下子查看数组全部的值。所以我们需要用到遍历方法输出所有元素。 …...

掌握未来技术:一站式深度学习学习平台体验!

介绍&#xff1a;深度学习是机器学习的一个子领域&#xff0c;它模仿人脑的分析和学习能力&#xff0c;通过构建和训练多层神经网络来学习数据的内在规律和表示层次。 深度学习的核心在于能够自动学习数据中的高层次特征&#xff0c;而无需人工进行复杂的特征工程。这种方法在图…...

Doris实战——特步集团零售数据仓库项目实践

目录 一、背景 二、总体架构 三、ETL实践 3.1 批量数据的导入 3.2 实时数据接入 3.3 数据加工 3.4 BI 查询 四、实时需求响应 五、其他经验 5.1 Doris BE内存溢出 5.2 SQL任务超时 5.3 删除语句不支持表达式 5.4 Drop 表闪回 六、未来展望 原文大佬的这篇Doris数…...

【python】(4)条件和循环

条件语句(Conditional Statements) 条件语句允许程序根据条件的不同执行不同的代码段。这是实现决策逻辑、分支和循环的基础。 if 语句 if 语句是最基本的条件语句,它用于执行仅当特定条件为真时才需要执行的代码块。 x = 10 if x > 5:print("x is greater than…...

Docker 的基本概念

Docker是一种开源的容器化平台&#xff0c;可以用于将应用程序和它们的依赖项打包到一个可移植的容器中。Docker容器可以在任何支持Docker的操作系统上运行&#xff0c;提供了隔离、可移植性和易于部署的优势。 Docker的基本概念包括以下几点&#xff1a; 镜像&#xff08;Im…...

5.44 BCC工具之killsnoop.py解读

一,工具简介 工具用于追踪通过 kill() 系统调用发送的信号,并实时报告相关信息。 二,代码示例 #!/usr/bin/env pythonfrom __future__ import print_function from bcc import BPF from bcc.utils import ArgString, printb import argparse from time import strftime# …...

2023人机交互期末复习

考试题型及分值分布 1、选择题&#xff08;10题、20分&#xff09; 2、填空题&#xff08;10题、20分&#xff09; 3、判断题&#xff08;可选、5题、10分&#xff09; 4、解答题&#xff08;5~6题、30分&#xff09; 5、分析计算题&#xff08;1~2题、20分&#xff09; 注意&…...

Linux使用bcache 将SSD加速硬盘

前言 在Linux下&#xff0c;使用SSD为HDD加速&#xff0c;目前较为成熟的方案有&#xff1a;flashcache&#xff0c;enhanceIO&#xff0c;dm-cache&#xff0c;bcache等&#xff0c;多方面比较以后最终选择了bcache。 bcache 是一个 Linux 内核块层超速缓存。它允许使用一个或…...

大厂报价查询系统性能优化之道!

0 前言 机票查询系统&#xff0c;日均亿级流量&#xff0c;要求高吞吐&#xff0c;低延迟架构设计。提升缓存的效率以及实时计算模块长尾延迟&#xff0c;成为制约机票查询系统性能关键。本文介绍机票查询系统在缓存和实时计算两个领域的架构提升。 1 机票搜索服务概述 1.1 …...

Carbondata编译适配Spark3

背景 当前carbondata版本2.3.1-rc1中项目源码适配的spark版本最高为3.1,我们需要进行spark3.3版本的编译适配。 原始编译 linux系统下载源码后&#xff0c;安装maven3.6.3&#xff0c;然后执行&#xff1a; mvn -DskipTests -Pspark-3.1 clean package会遇到一些网络问题&a…...

数学建模【灰色关联分析】

一、灰色关联分析简介 一般的抽象系统,如社会系统、经济系统、农业系统、生态系统、教育系统等都包含有许多种因素&#xff0c;多种因素共同作用的结果决定了该系统的发展态势。人们常常希望知道在众多的因素中&#xff0c;哪些是主要因素&#xff0c;哪些是次要因素;哪些因素…...

Vue.js的单向数据流:让你的应用更清晰、更可控

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…...

IntelliJ IDEA社区版传统web开发环境搭建

前言 现在主流的开发框架是SpringBoot,使用maven配置的开发环境&#xff0c;网上有很多教程&#xff0c;这里记录一下传统Web开发项目&#xff08;mvc架构的框架&#xff0c;如SSH&#xff09;使用idea社区版的开发环境搭建。防止被人说都2024年了还用eclipse。 一、下载文件…...

arm-linux-gnueabi、arm-linux-gnueabihf 交叉编译器区别

1、arm-linux-gnueabi&#xff1a; 使用软件浮点&#xff08;软浮点&#xff09;。这意味着所有的浮点运算都将由软件库来处理&#xff0c;而不会利用硬件中的浮点运算单元。因此&#xff0c;生成的目标代码包含了对软件浮点库的调用。 2、arm-linux-gnueabihf&#xff1a; 使…...

JavaSec-RCE

简介 RCE(Remote Code Execution)&#xff0c;可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景&#xff1a;Groovy代码注入 Groovy是一种基于JVM的动态语言&#xff0c;语法简洁&#xff0c;支持闭包、动态类型和Java互操作性&#xff0c…...

ES6从入门到精通:前言

ES6简介 ES6&#xff08;ECMAScript 2015&#xff09;是JavaScript语言的重大更新&#xff0c;引入了许多新特性&#xff0c;包括语法糖、新数据类型、模块化支持等&#xff0c;显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var&#xf…...

电脑插入多块移动硬盘后经常出现卡顿和蓝屏

当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时&#xff0c;可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案&#xff1a; 1. 检查电源供电问题 问题原因&#xff1a;多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module

1、为什么要修改 CONNECT 报文&#xff1f; 多租户隔离&#xff1a;自动为接入设备追加租户前缀&#xff0c;后端按 ClientID 拆分队列。零代码鉴权&#xff1a;将入站用户名替换为 OAuth Access-Token&#xff0c;后端 Broker 统一校验。灰度发布&#xff1a;根据 IP/地理位写…...

python如何将word的doc另存为docx

将 DOCX 文件另存为 DOCX 格式&#xff08;Python 实现&#xff09; 在 Python 中&#xff0c;你可以使用 python-docx 库来操作 Word 文档。不过需要注意的是&#xff0c;.doc 是旧的 Word 格式&#xff0c;而 .docx 是新的基于 XML 的格式。python-docx 只能处理 .docx 格式…...

Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)

引言 工欲善其事&#xff0c;必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后&#xff0c;我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集&#xff0c;就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...

破解路内监管盲区:免布线低位视频桩重塑停车管理新标准

城市路内停车管理常因行道树遮挡、高位设备盲区等问题&#xff0c;导致车牌识别率低、逃费率高&#xff0c;传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法&#xff0c;正成为破局关键。该设备安装于车位侧方0.5-0.7米高度&#xff0c;直接规避树枝遮…...

C++_哈希表

本篇文章是对C学习的哈希表部分的学习分享 相信一定会对你有所帮助~ 那咱们废话不多说&#xff0c;直接开始吧&#xff01; 一、基础概念 1. 哈希核心思想&#xff1a; 哈希函数的作用&#xff1a;通过此函数建立一个Key与存储位置之间的映射关系。理想目标&#xff1a;实现…...

针对药品仓库的效期管理问题,如何利用WMS系统“破局”

案例&#xff1a; 某医药分销企业&#xff0c;主要经营各类药品的批发与零售。由于药品的特殊性&#xff0c;效期管理至关重要&#xff0c;但该企业一直面临效期问题的困扰。在未使用WMS系统之前&#xff0c;其药品入库、存储、出库等环节的效期管理主要依赖人工记录与检查。库…...

海云安高敏捷信创白盒SCAP入选《中国网络安全细分领域产品名录》

近日&#xff0c;嘶吼安全产业研究院发布《中国网络安全细分领域产品名录》&#xff0c;海云安高敏捷信创白盒&#xff08;SCAP&#xff09;成功入选软件供应链安全领域产品名录。 在数字化转型加速的今天&#xff0c;网络安全已成为企业生存与发展的核心基石&#xff0c;为了解…...