多头自注意力机制的代码实现
文章目录
- 1、自注意力机制
- 2、多头注意力机制
- transformer的整体结构:

1、自注意力机制
- 自注意力机制如下:

- 计算过程:

- 代码如下:
class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_size):super().__init__()self.W_q = nn.Linear(embed_dim, key_size, bias=False)self.W_k = nn.Linear(embed_dim, key_size, bias=False)self.W_v = nn.Linear(embed_dim, value_size, bias=False)def forward(self, x, attn_mask=None):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量attn_mask: (N, L, L),用于对注意力矩阵(L, L)进行mask输出:shape:(N, L, embed_dim)"""query = self.W_q(x) # (N, L, key_size)key = self.W_k(x) # (N, L, key_size)value = self.W_v(x) # (N, L, value_size)scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(query.size(2))if attn_mask is not None:scores = scores.masked_fill(attn_mask, 0)attn_weights = F.softmax(scores, dim=-1) # dim为-1表示,对每个嵌入向量与其他所有向量的注意力权重,进行softmax,以使每一行的和为1return torch.matmul(attn_weights, value)
2、多头注意力机制
- 结构如下:

- 计算过程如下:
class MultiHeadSelfAttention(nn.Module):def __init__(self, embed_dim, num_heads, key_size, value_size, bias=False):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.q_head_dim = key_size // num_headsself.k_head_dim = key_size // num_headsself.v_head_dim = value_size // num_headsself.W_q = nn.Linear(embed_dim, key_size, bias=bias)self.W_k = nn.Linear(embed_dim, key_size, bias=bias)self.W_v = nn.Linear(embed_dim, value_size, bias=bias) self.q_proj = nn.Linear(key_size, key_size, bias=bias)self.k_proj = nn.Linear(key_size, key_size, bias=bias)self.v_proj = nn.Linear(value_size, value_size, bias=bias)self.out_proj = nn.Linear(value_size, embed_dim, bias=bias)def forward(self, x):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量Returns:output: (N, L, embed_dim)"""query = self.W_q(x) # (N, L, key_size)key = self.W_k(x) # (N, L, key_size)value = self.W_v(x) # (N, L, value_size)q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)N, L, value_size = v.size()q = q.reshape(N, L, self.num_heads, self.q_head_dim).transpose(1, 2)k = k.reshape(N, L, self.num_heads, self.k_head_dim).transpose(1, 2)v = v.reshape(N, L, self.num_heads, self.v_head_dim).transpose(1, 2)att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))att = F.softmax(att, dim=-1)output = torch.matmul(att, v)output = output.transpose(1, 2).reshape(N, L, value_size)output = self.out_proj(output)return output
相关文章:
多头自注意力机制的代码实现
文章目录 1、自注意力机制2、多头注意力机制 transformer的整体结构: 1、自注意力机制 自注意力机制如下: 计算过程: 代码如下: class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_…...
抽象工厂模式
目录 了解抽象工厂模式前的前置知识 什么是抽象工厂模式? 为什么要提出抽象工厂模式? 抽象工厂模式中的四大角色? 抽象工厂模式的优缺点? 抽象工厂模式的适用场景? 了解抽象工厂模式前的前置知识 在讲抽象工厂模式…...
登录校验-Filter-详解
目录 执行流程 拦截路径 过滤器链 小结 执行流程 过滤器Filter拦截到请求之后,首先执行方放行之前的逻辑,然后执行放行操作(doFilter),然后会访问对应的Web资源(对应的Controller类)&#…...
堆栈方法区笔记记录
成员变量分两种: 1)实例变量:没有static修饰,属于对象,存储在堆中,有几个对象就有几份,通过对象点来访问 2)静态变量:由static修饰,属于类,存储在方法区中,只有一份,通过类名点来访…...
新版微信小程序获取用户手机号
小程序手机号验证组件有两种 手机号快速验证组件 //原生写法 <button open-type"getPhoneNumber" bindgetphonenumber"getPhoneNumber"></button>Page({getPhoneNumber (e) {console.log(e.detail.code)} })uniapp写法 <button open-type…...
CSS实践 —— 悬浮盒子阴影加上移效果
悬浮盒子阴影加上移效果 代码 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>Title</title><style>body{background-color: #f5f5f5;}.shadow {width: 100px;height: 100px;margin:…...
安全测试基础知识
软件安全测试是评估和测试系统以发现系统及其数据的安全风险和漏洞的过程。没有通用术语,但出于我们的目的,我们将评估定义为分析和发现漏洞,而不尝试实际利用这些漏洞。我们将测试定义为发现和尝试利用漏洞。 安全测试通常根据要测试的漏洞…...
列表首屏毫秒级加载与自动滚动定位方案
引用自 摸鱼wiki 场景 <template><div ref"commentsRef"><divv-for"comment in displayComments":key"comment.id":data-cell-id"comment.id"class"card">{{ comment.data }}</div></div> &…...
小区物业业主管理信息系统设计的设计与实现(论文+源码)_kaic
摘 要 随着互联网的发展,网络技术的发展变得极其重要,所以依靠计算机处理业务成为了一种社会普遍的现状。管理方式也自然而然的向着现代化技术方向而改变,所以纯人工管理方式在越来越完善的现代化管理技术的比较之下也就显得过于繁琐&#x…...
Fortran 微分方程求解 --ODEPACK
最近涉及到使用Fortran对微分方程求解,我们知道MATLAB已有内置的函数,比如ode家族,ode15s,对应着不同的求解办法。通过查看odepack的官方文档,我尝试使用了dlsode求解刚性和非刚性常微分方程组。 首先是github网址&am…...
8路光栅尺磁栅尺编码器或16路高速DI脉冲信号转Modbus TCP网络模块 YL99-RJ45
特点: ● 光栅尺磁栅尺解码转换成标准Modbus TCP协议 ● 高速光栅尺磁栅尺4倍频计数,频率可达5MHz ● 模块可以输出5V的电源给光栅尺或传感器供电 ● 支持8个光栅尺同时计数,可识别正反转 ● 可以设置作为16路独立DI高速计数器 ● 可网…...
【Python】函数
None类型 思考:若函数没有使用return语句返回数据,那么函数有返回值吗? 答:实际上是有的,Python中有一个特殊的字面量None,其类型是<class ‘NoneType’>,无返回值的函数,实…...
centos安装MySQL 解压版完整教程(按步骤傻瓜式安装
一、卸载系统自带的 Mariadb 查看: rpm -qa|grep mariadb 卸载: rpm -e --nodeps mariadb-libs-5.5.68-1.el7.x86_64 二、卸载 etc 目录下的 my.cnf 文件 rm -rf /etc/my.cnf 三、检查MySQL是否存在 有则先删除 #卸载mysql服务以及删除所有mysql目录 #没…...
【后端速成 Vue】第一个 Vue 程序
1、为什么要学习 Vue? 为什么使用 Vue? 回想之前,前后端交互的时候,前端收到后端响应的数据,接着将数据渲染到页面上,之前使用的是 JavaScript 或者 基于 JavaScript 的 Jquery,但是这两个用起来还是不太…...
Macbook pro M1 安装Ubuntu教程
先讲下心路历程 由于版主最近刚切换到Mac,所以在安装的时候一上手就选择了virutalbox,结果报错“The installer has detected an unsupported architecture. VirtualBox only runs on the amd64 architecture.” 后来去Reddit论坛上一看,才知…...
前端console.log打印内容与后端请求返回数据不一致
后端传值num0 前端打印num1 ,如图,console.log后台显示的数据与展开后不一致 造成该问题原因是深拷贝与浅拷贝的问题。 var obj JSON.parse(JSON.stringify(res)) 修改后打印 正常...
SQL入门:多表查询
SQL,或者说结构化查询语言(Structured Query Language),是用于管理和操作关系型数据库的标准语言。在本篇文章中,我们将重点介绍SQL中的多表查询,这是一种强大的工具,可以帮助我们从多个相关的表格中获取数据。 数据库…...
【C++】进一步认识模板
🏖️作者:malloc不出对象 ⛺专栏:C的学习之路 👦个人简介:一名双非本科院校大二在读的科班编程菜鸟,努力编程只为赶上各位大佬的步伐🙈🙈 目录 前言一、非类型模板参数二、模板的特…...
Mysql Oracle 区别
1. oracle select *, id需要在星号前加别名,mysql则不需要 mysql语法: select *, id from xin_student_t;oracle语法: select st.*, st.id from xin_student_t st;2. oracle表定义了别名,在查询时可以不用别名指定字段…...
华为OD-第K长的连续字母字符串长度
题目描述 给定一个字符串,只包含大写字母,求在包含同一字母的子串中,长度第 k 长的子串的长度,相同字母只取最长的那个子串。 代码实现 # coding:utf-8 # 第K长的连续字母字符串长度 # https://www.nowcoder.com/discuss/353150…...
C++实现分布式网络通信框架RPC(3)--rpc调用端
目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中,我们已经大致实现了rpc服务端的各项功能代…...
循环冗余码校验CRC码 算法步骤+详细实例计算
通信过程:(白话解释) 我们将原始待发送的消息称为 M M M,依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)(意思就是 G ( x ) G(x) G(x) 是已知的)࿰…...
【位运算】消失的两个数字(hard)
消失的两个数字(hard) 题⽬描述:解法(位运算):Java 算法代码:更简便代码 题⽬链接:⾯试题 17.19. 消失的两个数字 题⽬描述: 给定⼀个数组,包含从 1 到 N 所有…...
【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...
vue3 字体颜色设置的多种方式
在Vue 3中设置字体颜色可以通过多种方式实现,这取决于你是想在组件内部直接设置,还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法: 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...
华为OD机试-食堂供餐-二分法
import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...
如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...
Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...
今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存
文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...
Java线上CPU飙高问题排查全指南
一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...
