当前位置: 首页 > news >正文

多头自注意力机制的代码实现

文章目录

  • 1、自注意力机制
  • 2、多头注意力机制

  • transformer的整体结构:
    在这里插入图片描述

1、自注意力机制

  • 自注意力机制如下:
    在这里插入图片描述
  • 计算过程:
    在这里插入图片描述
  • 代码如下:
class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_size):super().__init__()self.W_q = nn.Linear(embed_dim, key_size, bias=False)self.W_k = nn.Linear(embed_dim, key_size, bias=False)self.W_v = nn.Linear(embed_dim, value_size, bias=False)def forward(self, x, attn_mask=None):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量attn_mask: (N, L, L),用于对注意力矩阵(L, L)进行mask输出:shape:(N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(query.size(2))if attn_mask is not None:scores = scores.masked_fill(attn_mask, 0)attn_weights = F.softmax(scores, dim=-1)	# dim为-1表示,对每个嵌入向量与其他所有向量的注意力权重,进行softmax,以使每一行的和为1return torch.matmul(attn_weights, value)

2、多头注意力机制

  • 结构如下:
    在这里插入图片描述
  • 计算过程如下:
class MultiHeadSelfAttention(nn.Module):def __init__(self, embed_dim, num_heads, key_size, value_size, bias=False):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.q_head_dim = key_size // num_headsself.k_head_dim = key_size // num_headsself.v_head_dim = value_size // num_headsself.W_q = nn.Linear(embed_dim, key_size, bias=bias)self.W_k = nn.Linear(embed_dim, key_size, bias=bias)self.W_v = nn.Linear(embed_dim, value_size, bias=bias)        self.q_proj = nn.Linear(key_size, key_size, bias=bias)self.k_proj = nn.Linear(key_size, key_size, bias=bias)self.v_proj = nn.Linear(value_size, value_size, bias=bias)self.out_proj = nn.Linear(value_size, embed_dim, bias=bias)def forward(self, x):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量Returns:output: (N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)N, L, value_size = v.size()q = q.reshape(N, L, self.num_heads, self.q_head_dim).transpose(1, 2)k = k.reshape(N, L, self.num_heads, self.k_head_dim).transpose(1, 2)v = v.reshape(N, L, self.num_heads, self.v_head_dim).transpose(1, 2)att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))att = F.softmax(att, dim=-1)output = torch.matmul(att, v)output = output.transpose(1, 2).reshape(N, L, value_size)output = self.out_proj(output)return output

相关文章:

多头自注意力机制的代码实现

文章目录 1、自注意力机制2、多头注意力机制 transformer的整体结构: 1、自注意力机制 自注意力机制如下: 计算过程: 代码如下: class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_…...

抽象工厂模式

目录 了解抽象工厂模式前的前置知识 什么是抽象工厂模式? 为什么要提出抽象工厂模式? 抽象工厂模式中的四大角色? 抽象工厂模式的优缺点? 抽象工厂模式的适用场景? 了解抽象工厂模式前的前置知识 在讲抽象工厂模式…...

登录校验-Filter-详解

目录 执行流程 拦截路径 过滤器链 小结 执行流程 过滤器Filter拦截到请求之后,首先执行方放行之前的逻辑,然后执行放行操作(doFilter),然后会访问对应的Web资源(对应的Controller类)&#…...

堆栈方法区笔记记录

成员变量分两种: 1)实例变量:没有static修饰,属于对象,存储在堆中,有几个对象就有几份,通过对象点来访问 2)静态变量:由static修饰,属于类,存储在方法区中,只有一份,通过类名点来访…...

新版微信小程序获取用户手机号

小程序手机号验证组件有两种 手机号快速验证组件 //原生写法 <button open-type"getPhoneNumber" bindgetphonenumber"getPhoneNumber"></button>Page({getPhoneNumber (e) {console.log(e.detail.code)} })uniapp写法 <button open-type…...

CSS实践 —— 悬浮盒子阴影加上移效果

悬浮盒子阴影加上移效果 代码 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>Title</title><style>body{background-color: #f5f5f5;}.shadow {width: 100px;height: 100px;margin:…...

安全测试基础知识

软件安全测试是评估和测试系统以发现系统及其数据的安全风险和漏洞的过程。没有通用术语&#xff0c;但出于我们的目的&#xff0c;我们将评估定义为分析和发现漏洞&#xff0c;而不尝试实际利用这些漏洞。我们将测试定义为发现和尝试利用漏洞。 安全测试通常根据要测试的漏洞…...

列表首屏毫秒级加载与自动滚动定位方案

引用自 摸鱼wiki 场景 <template><div ref"commentsRef"><divv-for"comment in displayComments":key"comment.id":data-cell-id"comment.id"class"card">{{ comment.data }}</div></div> &…...

小区物业业主管理信息系统设计的设计与实现(论文+源码)_kaic

摘 要 随着互联网的发展&#xff0c;网络技术的发展变得极其重要&#xff0c;所以依靠计算机处理业务成为了一种社会普遍的现状。管理方式也自然而然的向着现代化技术方向而改变&#xff0c;所以纯人工管理方式在越来越完善的现代化管理技术的比较之下也就显得过于繁琐&#x…...

Fortran 微分方程求解 --ODEPACK

最近涉及到使用Fortran对微分方程求解&#xff0c;我们知道MATLAB已有内置的函数&#xff0c;比如ode家族&#xff0c;ode15s&#xff0c;对应着不同的求解办法。通过查看odepack的官方文档&#xff0c;我尝试使用了dlsode求解刚性和非刚性常微分方程组。 首先是github网址&am…...

8路光栅尺磁栅尺编码器或16路高速DI脉冲信号转Modbus TCP网络模块 YL99-RJ45

特点&#xff1a; ● 光栅尺磁栅尺解码转换成标准Modbus TCP协议 ● 高速光栅尺磁栅尺4倍频计数&#xff0c;频率可达5MHz ● 模块可以输出5V的电源给光栅尺或传感器供电 ● 支持8个光栅尺同时计数&#xff0c;可识别正反转 ● 可以设置作为16路独立DI高速计数器 ● 可网…...

【Python】函数

None类型 思考&#xff1a;若函数没有使用return语句返回数据&#xff0c;那么函数有返回值吗&#xff1f; 答&#xff1a;实际上是有的&#xff0c;Python中有一个特殊的字面量None&#xff0c;其类型是<class ‘NoneType’>&#xff0c;无返回值的函数&#xff0c;实…...

centos安装MySQL 解压版完整教程(按步骤傻瓜式安装

一、卸载系统自带的 Mariadb 查看&#xff1a; rpm -qa|grep mariadb 卸载&#xff1a; rpm -e --nodeps mariadb-libs-5.5.68-1.el7.x86_64 二、卸载 etc 目录下的 my.cnf 文件 rm -rf /etc/my.cnf 三、检查MySQL是否存在 有则先删除 #卸载mysql服务以及删除所有mysql目录 #没…...

【后端速成 Vue】第一个 Vue 程序

1、为什么要学习 Vue&#xff1f; 为什么使用 Vue? 回想之前&#xff0c;前后端交互的时候&#xff0c;前端收到后端响应的数据&#xff0c;接着将数据渲染到页面上&#xff0c;之前使用的是 JavaScript 或者 基于 JavaScript 的 Jquery&#xff0c;但是这两个用起来还是不太…...

Macbook pro M1 安装Ubuntu教程

先讲下心路历程 由于版主最近刚切换到Mac&#xff0c;所以在安装的时候一上手就选择了virutalbox&#xff0c;结果报错“The installer has detected an unsupported architecture. VirtualBox only runs on the amd64 architecture.” 后来去Reddit论坛上一看&#xff0c;才知…...

前端console.log打印内容与后端请求返回数据不一致

后端传值num0 前端打印num1 ,如图&#xff0c;console.log后台显示的数据与展开后不一致 造成该问题原因是深拷贝与浅拷贝的问题。 var obj JSON.parse(JSON.stringify(res)) 修改后打印 正常...

SQL入门:多表查询

SQL&#xff0c;或者说结构化查询语言(Structured Query Language)&#xff0c;是用于管理和操作关系型数据库的标准语言。在本篇文章中&#xff0c;我们将重点介绍SQL中的多表查询&#xff0c;这是一种强大的工具&#xff0c;可以帮助我们从多个相关的表格中获取数据。 数据库…...

【C++】进一步认识模板

&#x1f3d6;️作者&#xff1a;malloc不出对象 ⛺专栏&#xff1a;C的学习之路 &#x1f466;个人简介&#xff1a;一名双非本科院校大二在读的科班编程菜鸟&#xff0c;努力编程只为赶上各位大佬的步伐&#x1f648;&#x1f648; 目录 前言一、非类型模板参数二、模板的特…...

Mysql Oracle 区别

1. oracle select *, id需要在星号前加别名&#xff0c;mysql则不需要 mysql语法&#xff1a; select *, id from xin_student_t;oracle语法&#xff1a; select st.*, st.id from xin_student_t st;2. oracle表定义了别名&#xff0c;在查询时可以不用别名指定字段&#xf…...

华为OD-第K长的连续字母字符串长度

题目描述 给定一个字符串&#xff0c;只包含大写字母&#xff0c;求在包含同一字母的子串中&#xff0c;长度第 k 长的子串的长度&#xff0c;相同字母只取最长的那个子串。 代码实现 # coding:utf-8 # 第K长的连续字母字符串长度 # https://www.nowcoder.com/discuss/353150…...

7.4.分块查找

一.分块查找的算法思想&#xff1a; 1.实例&#xff1a; 以上述图片的顺序表为例&#xff0c; 该顺序表的数据元素从整体来看是乱序的&#xff0c;但如果把这些数据元素分成一块一块的小区间&#xff0c; 第一个区间[0,1]索引上的数据元素都是小于等于10的&#xff0c; 第二…...

RocketMQ延迟消息机制

两种延迟消息 RocketMQ中提供了两种延迟消息机制 指定固定的延迟级别 通过在Message中设定一个MessageDelayLevel参数&#xff0c;对应18个预设的延迟级别指定时间点的延迟级别 通过在Message中设定一个DeliverTimeMS指定一个Long类型表示的具体时间点。到了时间点后&#xf…...

ssc377d修改flash分区大小

1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

浅谈不同二分算法的查找情况

二分算法原理比较简单&#xff0c;但是实际的算法模板却有很多&#xff0c;这一切都源于二分查找问题中的复杂情况和二分算法的边界处理&#xff0c;以下是博主对一些二分算法查找的情况分析。 需要说明的是&#xff0c;以下二分算法都是基于有序序列为升序有序的情况&#xf…...

MySQL用户和授权

开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务&#xff1a; test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南

1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发&#xff0c;使用DevEco Studio作为开发工具&#xff0c;采用Java语言实现&#xff0c;包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中&#xff0c;CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时&#xff0c;通常会导致应用响应缓慢&#xff0c;甚至服务不可用&#xff0c;严重影响用户体验和业务运行。因此&#xff0c;掌握一套科学有效的CPU飙高问题排查方法&…...

云原生安全实战:API网关Kong的鉴权与限流详解

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关&#xff08;API Gateway&#xff09; API网关是微服务架构中的核心组件&#xff0c;负责统一管理所有API的流量入口。它像一座…...

华为OD机考-机房布局

import java.util.*;public class DemoTest5 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseSystem.out.println(solve(in.nextLine()));}}priv…...