当前位置: 首页 > news >正文

注意力机制(代码实现案例)

学习目标

  • 了解什么是注意力计算规则以及常见的计算规则.
  • 了解什么是注意力机制及其作用.
  • 掌握注意力机制的实现步骤.

1 注意力机制介绍

1.1 注意力概念

  • 我们观察事物时,之所以能够快速判断一种事物(当然允许判断是错误的), 是因为我们大脑能够很快把注意力放在事物最具有辨识度的部分从而作出判断,而并非是从头到尾的观察一遍事物后,才能有判断结果. 正是基于这样的理论,就产生了注意力机制.

1.2 注意力计算规则

  • 它需要三个指定的输入Q(query), K(key), V(value), 然后通过计算公式得到注意力的结果, 这个结果代表query在key和value作用下的注意力表示. 当输入的Q=K=V时, 称作自注意力计算规则.

1.3 常见的注意力计算规则

  • bmm运算演示:

# 如果参数1形状是(b × n × m), 参数2形状是(b × m × p), 则输出为(b × n × p)
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

2 什么是注意力机制

  • 注意力机制是注意力计算规则能够应用的深度学习网络的载体, 同时包括一些必要的全连接层以及相关张量处理, 使其与应用网络融为一体. 使用自注意力计算规则的注意力机制称为自注意力机制.
  • 说明: NLP领域中, 当前的注意力机制大多数应用于seq2seq架构, 即编码器和解码器模型.

3 注意力机制的作用

  • 在解码器端的注意力机制: 能够根据模型目标有效的聚焦编码器的输出结果, 当其作为解码器的输入时提升效果. 改善以往编码器输出是单一定长张量, 无法存储过多信息的情况.
  • 在编码器端的注意力机制: 主要解决表征问题, 相当于特征提取过程, 得到输入的注意力表示. 一般使用自注意力(self-attention).

注意力机制在网络中实现的图形表示:

4 注意力机制实现步骤

4.1 步骤

  • 第一步: 根据注意力计算规则, 对Q,K,V进行相应的计算.
  • 第二步: 根据第一步采用的计算方法, 如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接, 如果是转置点积, 一般是自注意力, Q与V相同, 则不需要进行与Q的拼接.
  • 第三步: 最后为了使整个attention机制按照指定尺寸输出, 使用线性层作用在第二步的结果上做一个线性变换, 得到最终对Q的注意力表示.

4.2 代码实现

  • 常见注意力机制的代码分析:
    import torch
    import torch.nn as nn
    import torch.nn.functional as Fclass Attn(nn.Module):def __init__(self, query_size, key_size, value_size1, value_size2, output_size):"""初始化函数中的参数有5个, query_size代表query的最后一维大小key_size代表key的最后一维大小, value_size1代表value的导数第二维大小, value = (1, value_size1, value_size2)value_size2代表value的倒数第一维大小, output_size输出的最后一维大小"""super(Attn, self).__init__()# 将以下参数传入类中self.query_size = query_sizeself.key_size = key_sizeself.value_size1 = value_size1self.value_size2 = value_size2self.output_size = output_size# 初始化注意力机制实现第一步中需要的线性层.self.attn = nn.Linear(self.query_size + self.key_size, value_size1)# 初始化注意力机制实现第三步中需要的线性层.self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)def forward(self, Q, K, V):"""forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量"""# 第一步, 按照计算规则进行计算, # 我们采用常见的第一种计算规则# 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)# 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算, # 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)# 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法, # 需要将Q与第一步的计算结果再进行拼接output = torch.cat((Q[0], attn_applied[0]), 1)# 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出# 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度output = self.attn_combine(output).unsqueeze(0)return output, attn_weights
    

  • 调用:
  • query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    output_size = 64
    attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
    Q = torch.randn(1,1,32)
    K = torch.randn(1,1,32)
    V = torch.randn(1,32,64)
    out = attn(Q, K ,V)
    print(out[0])
    print(out[1])

  • 输出效果:
    tensor([[[ 0.4477, -0.0500, -0.2277, -0.3168, -0.4096, -0.5982,  0.1548,-0.0771, -0.0951,  0.1833,  0.3128,  0.1260,  0.4420,  0.0495,-0.7774, -0.0995,  0.2629,  0.4957,  1.0922,  0.1428,  0.3024,-0.2646, -0.0265,  0.0632,  0.3951,  0.1583,  0.1130,  0.5500,-0.1887, -0.2816, -0.3800, -0.5741,  0.1342,  0.0244, -0.2217,0.1544,  0.1865, -0.2019,  0.4090, -0.4762,  0.3677, -0.2553,-0.5199,  0.2290, -0.4407,  0.0663, -0.0182, -0.2168,  0.0913,-0.2340,  0.1924, -0.3687,  0.1508,  0.3618, -0.0113,  0.2864,-0.1929, -0.6821,  0.0951,  0.1335,  0.3560, -0.3215,  0.6461,0.1532]]], grad_fn=<UnsqueezeBackward0>)tensor([[0.0395, 0.0342, 0.0200, 0.0471, 0.0177, 0.0209, 0.0244, 0.0465, 0.0346,0.0378, 0.0282, 0.0214, 0.0135, 0.0419, 0.0926, 0.0123, 0.0177, 0.0187,0.0166, 0.0225, 0.0234, 0.0284, 0.0151, 0.0239, 0.0132, 0.0439, 0.0507,0.0419, 0.0352, 0.0392, 0.0546, 0.0224]], grad_fn=<SoftmaxBackward>)
    

相关文章:

注意力机制(代码实现案例)

学习目标 了解什么是注意力计算规则以及常见的计算规则.了解什么是注意力机制及其作用.掌握注意力机制的实现步骤. 1 注意力机制介绍 1.1 注意力概念 我们观察事物时&#xff0c;之所以能够快速判断一种事物(当然允许判断是错误的), 是因为我们大脑能够很快把注意力放在事物…...

全量知识系统问题及SmartChat给出的答复 之8 三套工具之3语法解析器 之1

Q19. 问题 : 解释单词解释单词occupied 的字典条目 (word-def occupiedinterest 5type EBsubclass SEBtemplate (script $Demonstrateactor nilobject nildemands nilmethod (scene $Occupyactor nillocation nil))fill (((actor) (top-of *actor-s…...

软考59-上午题-【数据库】-小结+杂题

一、杂题 真题1&#xff1a; 真题2&#xff1a; 真题3&#xff1a; 真题4&#xff1a; 真题5&#xff1a; 真题6&#xff1a; 真题7&#xff1a; 真题8&#xff1a; 二、数据库总结 考试题型&#xff1a; 1、选择题&#xff08;6题&#xff0c;6分&#xff09; 2、综合分析题…...

【ARM Trace32(劳特巴赫) 高级篇 21 -- SystemTrace ITM 使用介绍】

文章目录 SystemTrace ITMSystemTrace ITM 常用命令Trace Data AnalysisSystemTrace ITM CoreSight ITM (Instrumentation Trace Macrocell) provides the following information: Address, data value and instruction address for selected data cyclesInterrupt event info…...

Python系列(20)—— 循环语句

Python中的循环控制语句 一、引言 在Python编程中&#xff0c;循环是重复执行一段代码直到满足特定条件的基本结构。Python提供了多种循环控制语句&#xff0c;如For 和While &#xff0c;以及用于控制循环流程的辅助语句&#xff0c;如Break、Continue和Pass。这些语句的组合…...

MYSQL的sql性能优化技巧

在编写 SQL 查询时&#xff0c;有一些技巧可以帮助你提高性能、简化查询并避免常见错误。以下是一些 MySQL 的写 SQL 技巧&#xff1a; 1. 使用索引 确保经常用于搜索、排序和连接的列上有索引。避免在索引列上使用函数或表达式&#xff0c;这会导致索引失效。使用 EXPLAIN 关…...

C#(C Sharp)学习笔记_数组的遍历【十】

输出数组内容 一般而言&#xff0c;我们会使用索引来输出指定的内容。 int[] arrayInt new int[] {4, 5, 2, 7, 9}; Console.WriteLine(arrayInt[3]);但这样只能输出指定的索引指向的内容&#xff0c;无法一下子查看数组全部的值。所以我们需要用到遍历方法输出所有元素。 …...

掌握未来技术:一站式深度学习学习平台体验!

介绍&#xff1a;深度学习是机器学习的一个子领域&#xff0c;它模仿人脑的分析和学习能力&#xff0c;通过构建和训练多层神经网络来学习数据的内在规律和表示层次。 深度学习的核心在于能够自动学习数据中的高层次特征&#xff0c;而无需人工进行复杂的特征工程。这种方法在图…...

Doris实战——特步集团零售数据仓库项目实践

目录 一、背景 二、总体架构 三、ETL实践 3.1 批量数据的导入 3.2 实时数据接入 3.3 数据加工 3.4 BI 查询 四、实时需求响应 五、其他经验 5.1 Doris BE内存溢出 5.2 SQL任务超时 5.3 删除语句不支持表达式 5.4 Drop 表闪回 六、未来展望 原文大佬的这篇Doris数…...

【python】(4)条件和循环

条件语句(Conditional Statements) 条件语句允许程序根据条件的不同执行不同的代码段。这是实现决策逻辑、分支和循环的基础。 if 语句 if 语句是最基本的条件语句,它用于执行仅当特定条件为真时才需要执行的代码块。 x = 10 if x > 5:print("x is greater than…...

Docker 的基本概念

Docker是一种开源的容器化平台&#xff0c;可以用于将应用程序和它们的依赖项打包到一个可移植的容器中。Docker容器可以在任何支持Docker的操作系统上运行&#xff0c;提供了隔离、可移植性和易于部署的优势。 Docker的基本概念包括以下几点&#xff1a; 镜像&#xff08;Im…...

5.44 BCC工具之killsnoop.py解读

一,工具简介 工具用于追踪通过 kill() 系统调用发送的信号,并实时报告相关信息。 二,代码示例 #!/usr/bin/env pythonfrom __future__ import print_function from bcc import BPF from bcc.utils import ArgString, printb import argparse from time import strftime# …...

2023人机交互期末复习

考试题型及分值分布 1、选择题&#xff08;10题、20分&#xff09; 2、填空题&#xff08;10题、20分&#xff09; 3、判断题&#xff08;可选、5题、10分&#xff09; 4、解答题&#xff08;5~6题、30分&#xff09; 5、分析计算题&#xff08;1~2题、20分&#xff09; 注意&…...

Linux使用bcache 将SSD加速硬盘

前言 在Linux下&#xff0c;使用SSD为HDD加速&#xff0c;目前较为成熟的方案有&#xff1a;flashcache&#xff0c;enhanceIO&#xff0c;dm-cache&#xff0c;bcache等&#xff0c;多方面比较以后最终选择了bcache。 bcache 是一个 Linux 内核块层超速缓存。它允许使用一个或…...

大厂报价查询系统性能优化之道!

0 前言 机票查询系统&#xff0c;日均亿级流量&#xff0c;要求高吞吐&#xff0c;低延迟架构设计。提升缓存的效率以及实时计算模块长尾延迟&#xff0c;成为制约机票查询系统性能关键。本文介绍机票查询系统在缓存和实时计算两个领域的架构提升。 1 机票搜索服务概述 1.1 …...

Carbondata编译适配Spark3

背景 当前carbondata版本2.3.1-rc1中项目源码适配的spark版本最高为3.1,我们需要进行spark3.3版本的编译适配。 原始编译 linux系统下载源码后&#xff0c;安装maven3.6.3&#xff0c;然后执行&#xff1a; mvn -DskipTests -Pspark-3.1 clean package会遇到一些网络问题&a…...

数学建模【灰色关联分析】

一、灰色关联分析简介 一般的抽象系统,如社会系统、经济系统、农业系统、生态系统、教育系统等都包含有许多种因素&#xff0c;多种因素共同作用的结果决定了该系统的发展态势。人们常常希望知道在众多的因素中&#xff0c;哪些是主要因素&#xff0c;哪些是次要因素;哪些因素…...

Vue.js的单向数据流:让你的应用更清晰、更可控

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…...

IntelliJ IDEA社区版传统web开发环境搭建

前言 现在主流的开发框架是SpringBoot,使用maven配置的开发环境&#xff0c;网上有很多教程&#xff0c;这里记录一下传统Web开发项目&#xff08;mvc架构的框架&#xff0c;如SSH&#xff09;使用idea社区版的开发环境搭建。防止被人说都2024年了还用eclipse。 一、下载文件…...

arm-linux-gnueabi、arm-linux-gnueabihf 交叉编译器区别

1、arm-linux-gnueabi&#xff1a; 使用软件浮点&#xff08;软浮点&#xff09;。这意味着所有的浮点运算都将由软件库来处理&#xff0c;而不会利用硬件中的浮点运算单元。因此&#xff0c;生成的目标代码包含了对软件浮点库的调用。 2、arm-linux-gnueabihf&#xff1a; 使…...

XML Group端口详解

在XML数据映射过程中&#xff0c;经常需要对数据进行分组聚合操作。例如&#xff0c;当处理包含多个物料明细的XML文件时&#xff0c;可能需要将相同物料号的明细归为一组&#xff0c;或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码&#xff0c;增加了开…...

vscode里如何用git

打开vs终端执行如下&#xff1a; 1 初始化 Git 仓库&#xff08;如果尚未初始化&#xff09; git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...

反向工程与模型迁移:打造未来商品详情API的可持续创新体系

在电商行业蓬勃发展的当下&#xff0c;商品详情API作为连接电商平台与开发者、商家及用户的关键纽带&#xff0c;其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息&#xff08;如名称、价格、库存等&#xff09;的获取与展示&#xff0c;已难以满足市场对个性化、智能…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

CSS设置元素的宽度根据其内容自动调整

width: fit-content 是 CSS 中的一个属性值&#xff0c;用于设置元素的宽度根据其内容自动调整&#xff0c;确保宽度刚好容纳内容而不会超出。 效果对比 默认情况&#xff08;width: auto&#xff09;&#xff1a; 块级元素&#xff08;如 <div>&#xff09;会占满父容器…...

管理学院权限管理系统开发总结

文章目录 &#x1f393; 管理学院权限管理系统开发总结 - 现代化Web应用实践之路&#x1f4dd; 项目概述&#x1f3d7;️ 技术架构设计后端技术栈前端技术栈 &#x1f4a1; 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 &#x1f5c4;️ 数据库设…...

mac 安装homebrew (nvm 及git)

mac 安装nvm 及git 万恶之源 mac 安装这些东西离不开Xcode。及homebrew 一、先说安装git步骤 通用&#xff1a; 方法一&#xff1a;使用 Homebrew 安装 Git&#xff08;推荐&#xff09; 步骤如下&#xff1a;打开终端&#xff08;Terminal.app&#xff09; 1.安装 Homebrew…...

论文阅读笔记——Muffin: Testing Deep Learning Libraries via Neural Architecture Fuzzing

Muffin 论文 现有方法 CRADLE 和 LEMON&#xff0c;依赖模型推理阶段输出进行差分测试&#xff0c;但在训练阶段是不可行的&#xff0c;因为训练阶段直到最后才有固定输出&#xff0c;中间过程是不断变化的。API 库覆盖低&#xff0c;因为各个 API 都是在各种具体场景下使用。…...

【p2p、分布式,区块链笔记 MESH】Bluetooth蓝牙通信 BLE Mesh协议的拓扑结构 定向转发机制

目录 节点的功能承载层&#xff08;GATT/Adv&#xff09;局限性&#xff1a; 拓扑关系定向转发机制定向转发意义 CG 节点的功能 节点的功能由节点支持的特性和功能决定。所有节点都能够发送和接收网格消息。节点还可以选择支持一个或多个附加功能&#xff0c;如 Configuration …...

Python 高效图像帧提取与视频编码:实战指南

Python 高效图像帧提取与视频编码:实战指南 在音视频处理领域,图像帧提取与视频编码是基础但极具挑战性的任务。Python 结合强大的第三方库(如 OpenCV、FFmpeg、PyAV),可以高效处理视频流,实现快速帧提取、压缩编码等关键功能。本文将深入介绍如何优化这些流程,提高处理…...