多头自注意力机制的代码实现
文章目录
- 1、自注意力机制
- 2、多头注意力机制
- transformer的整体结构:

1、自注意力机制
- 自注意力机制如下:

- 计算过程:

- 代码如下:
class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_size):super().__init__()self.W_q = nn.Linear(embed_dim, key_size, bias=False)self.W_k = nn.Linear(embed_dim, key_size, bias=False)self.W_v = nn.Linear(embed_dim, value_size, bias=False)def forward(self, x, attn_mask=None):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量attn_mask: (N, L, L),用于对注意力矩阵(L, L)进行mask输出:shape:(N, L, embed_dim)"""query = self.W_q(x) # (N, L, key_size)key = self.W_k(x) # (N, L, key_size)value = self.W_v(x) # (N, L, value_size)scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(query.size(2))if attn_mask is not None:scores = scores.masked_fill(attn_mask, 0)attn_weights = F.softmax(scores, dim=-1) # dim为-1表示,对每个嵌入向量与其他所有向量的注意力权重,进行softmax,以使每一行的和为1return torch.matmul(attn_weights, value)
2、多头注意力机制
- 结构如下:

- 计算过程如下:
class MultiHeadSelfAttention(nn.Module):def __init__(self, embed_dim, num_heads, key_size, value_size, bias=False):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.q_head_dim = key_size // num_headsself.k_head_dim = key_size // num_headsself.v_head_dim = value_size // num_headsself.W_q = nn.Linear(embed_dim, key_size, bias=bias)self.W_k = nn.Linear(embed_dim, key_size, bias=bias)self.W_v = nn.Linear(embed_dim, value_size, bias=bias) self.q_proj = nn.Linear(key_size, key_size, bias=bias)self.k_proj = nn.Linear(key_size, key_size, bias=bias)self.v_proj = nn.Linear(value_size, value_size, bias=bias)self.out_proj = nn.Linear(value_size, embed_dim, bias=bias)def forward(self, x):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量Returns:output: (N, L, embed_dim)"""query = self.W_q(x) # (N, L, key_size)key = self.W_k(x) # (N, L, key_size)value = self.W_v(x) # (N, L, value_size)q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)N, L, value_size = v.size()q = q.reshape(N, L, self.num_heads, self.q_head_dim).transpose(1, 2)k = k.reshape(N, L, self.num_heads, self.k_head_dim).transpose(1, 2)v = v.reshape(N, L, self.num_heads, self.v_head_dim).transpose(1, 2)att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))att = F.softmax(att, dim=-1)output = torch.matmul(att, v)output = output.transpose(1, 2).reshape(N, L, value_size)output = self.out_proj(output)return output
相关文章:
多头自注意力机制的代码实现
文章目录 1、自注意力机制2、多头注意力机制 transformer的整体结构: 1、自注意力机制 自注意力机制如下: 计算过程: 代码如下: class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_…...
抽象工厂模式
目录 了解抽象工厂模式前的前置知识 什么是抽象工厂模式? 为什么要提出抽象工厂模式? 抽象工厂模式中的四大角色? 抽象工厂模式的优缺点? 抽象工厂模式的适用场景? 了解抽象工厂模式前的前置知识 在讲抽象工厂模式…...
登录校验-Filter-详解
目录 执行流程 拦截路径 过滤器链 小结 执行流程 过滤器Filter拦截到请求之后,首先执行方放行之前的逻辑,然后执行放行操作(doFilter),然后会访问对应的Web资源(对应的Controller类)&#…...
堆栈方法区笔记记录
成员变量分两种: 1)实例变量:没有static修饰,属于对象,存储在堆中,有几个对象就有几份,通过对象点来访问 2)静态变量:由static修饰,属于类,存储在方法区中,只有一份,通过类名点来访…...
新版微信小程序获取用户手机号
小程序手机号验证组件有两种 手机号快速验证组件 //原生写法 <button open-type"getPhoneNumber" bindgetphonenumber"getPhoneNumber"></button>Page({getPhoneNumber (e) {console.log(e.detail.code)} })uniapp写法 <button open-type…...
CSS实践 —— 悬浮盒子阴影加上移效果
悬浮盒子阴影加上移效果 代码 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>Title</title><style>body{background-color: #f5f5f5;}.shadow {width: 100px;height: 100px;margin:…...
安全测试基础知识
软件安全测试是评估和测试系统以发现系统及其数据的安全风险和漏洞的过程。没有通用术语,但出于我们的目的,我们将评估定义为分析和发现漏洞,而不尝试实际利用这些漏洞。我们将测试定义为发现和尝试利用漏洞。 安全测试通常根据要测试的漏洞…...
列表首屏毫秒级加载与自动滚动定位方案
引用自 摸鱼wiki 场景 <template><div ref"commentsRef"><divv-for"comment in displayComments":key"comment.id":data-cell-id"comment.id"class"card">{{ comment.data }}</div></div> &…...
小区物业业主管理信息系统设计的设计与实现(论文+源码)_kaic
摘 要 随着互联网的发展,网络技术的发展变得极其重要,所以依靠计算机处理业务成为了一种社会普遍的现状。管理方式也自然而然的向着现代化技术方向而改变,所以纯人工管理方式在越来越完善的现代化管理技术的比较之下也就显得过于繁琐&#x…...
Fortran 微分方程求解 --ODEPACK
最近涉及到使用Fortran对微分方程求解,我们知道MATLAB已有内置的函数,比如ode家族,ode15s,对应着不同的求解办法。通过查看odepack的官方文档,我尝试使用了dlsode求解刚性和非刚性常微分方程组。 首先是github网址&am…...
8路光栅尺磁栅尺编码器或16路高速DI脉冲信号转Modbus TCP网络模块 YL99-RJ45
特点: ● 光栅尺磁栅尺解码转换成标准Modbus TCP协议 ● 高速光栅尺磁栅尺4倍频计数,频率可达5MHz ● 模块可以输出5V的电源给光栅尺或传感器供电 ● 支持8个光栅尺同时计数,可识别正反转 ● 可以设置作为16路独立DI高速计数器 ● 可网…...
【Python】函数
None类型 思考:若函数没有使用return语句返回数据,那么函数有返回值吗? 答:实际上是有的,Python中有一个特殊的字面量None,其类型是<class ‘NoneType’>,无返回值的函数,实…...
centos安装MySQL 解压版完整教程(按步骤傻瓜式安装
一、卸载系统自带的 Mariadb 查看: rpm -qa|grep mariadb 卸载: rpm -e --nodeps mariadb-libs-5.5.68-1.el7.x86_64 二、卸载 etc 目录下的 my.cnf 文件 rm -rf /etc/my.cnf 三、检查MySQL是否存在 有则先删除 #卸载mysql服务以及删除所有mysql目录 #没…...
【后端速成 Vue】第一个 Vue 程序
1、为什么要学习 Vue? 为什么使用 Vue? 回想之前,前后端交互的时候,前端收到后端响应的数据,接着将数据渲染到页面上,之前使用的是 JavaScript 或者 基于 JavaScript 的 Jquery,但是这两个用起来还是不太…...
Macbook pro M1 安装Ubuntu教程
先讲下心路历程 由于版主最近刚切换到Mac,所以在安装的时候一上手就选择了virutalbox,结果报错“The installer has detected an unsupported architecture. VirtualBox only runs on the amd64 architecture.” 后来去Reddit论坛上一看,才知…...
前端console.log打印内容与后端请求返回数据不一致
后端传值num0 前端打印num1 ,如图,console.log后台显示的数据与展开后不一致 造成该问题原因是深拷贝与浅拷贝的问题。 var obj JSON.parse(JSON.stringify(res)) 修改后打印 正常...
SQL入门:多表查询
SQL,或者说结构化查询语言(Structured Query Language),是用于管理和操作关系型数据库的标准语言。在本篇文章中,我们将重点介绍SQL中的多表查询,这是一种强大的工具,可以帮助我们从多个相关的表格中获取数据。 数据库…...
【C++】进一步认识模板
🏖️作者:malloc不出对象 ⛺专栏:C的学习之路 👦个人简介:一名双非本科院校大二在读的科班编程菜鸟,努力编程只为赶上各位大佬的步伐🙈🙈 目录 前言一、非类型模板参数二、模板的特…...
Mysql Oracle 区别
1. oracle select *, id需要在星号前加别名,mysql则不需要 mysql语法: select *, id from xin_student_t;oracle语法: select st.*, st.id from xin_student_t st;2. oracle表定义了别名,在查询时可以不用别名指定字段…...
华为OD-第K长的连续字母字符串长度
题目描述 给定一个字符串,只包含大写字母,求在包含同一字母的子串中,长度第 k 长的子串的长度,相同字母只取最长的那个子串。 代码实现 # coding:utf-8 # 第K长的连续字母字符串长度 # https://www.nowcoder.com/discuss/353150…...
C++ ONNX Runtime推理踩坑记:为什么我的全局Session一Run就报ORT_RUNTIME_EXCEPTION?
C ONNX Runtime推理异常解析:全局Session与Env生命周期的陷阱 在C项目中使用ONNX Runtime进行模型推理时,许多开发者都遇到过这样一个令人困惑的场景:明明代码逻辑看起来完全正确,却在调用Session.Run()时突然抛出ORT_RUNTIME_EXC…...
RestTemplate遇到非RESTful接口怎么办?3种表单参数处理方案对比
RestTemplate应对非RESTful接口的实战指南 在现实开发中,我们常常会遇到各种不符合RESTful规范的接口设计。这些接口可能采用传统的表单传参方式,或是混合了路径参数与查询参数的"四不像"设计。本文将深入探讨三种高效处理这类非标准接口的方案…...
Cursor Pro免费激活指南:3步解锁AI编程工具的完整功能
Cursor Pro免费激活指南:3步解锁AI编程工具的完整功能 【免费下载链接】cursor-free-vip [Support 0.45](Multi Language 多语言)自动注册 Cursor Ai ,自动重置机器ID , 免费升级使用Pro 功能: Youve reached your tri…...
Mid-70激光雷达与相机无目标标定:从环境搭建到实战避坑
1. 为什么选择Ubuntu 16.04进行Mid-70标定 最近在给Livox Mid-70激光雷达做相机标定时,我踩了个大坑——在Ubuntu 22.04上折腾了整整两天都没搞定环境配置。后来才发现问题出在版本兼容性上:ROS Kinetic、Ceres 1.14.x和Eigen 3.2.92这几个关键组件在新系…...
文墨共鸣大模型入门指南:Ubuntu 20.04系统下的保姆级部署教程
文墨共鸣大模型入门指南:Ubuntu 20.04系统下的保姆级部署教程 想试试最近挺火的文墨共鸣大模型,但被复杂的部署步骤劝退了?别担心,这篇教程就是为你准备的。咱们今天不谈复杂的原理,就手把手教你,如何在Ub…...
5分钟掌握Vue工作流设计器:workflow-bpmn-modeler终极指南
5分钟掌握Vue工作流设计器:workflow-bpmn-modeler终极指南 【免费下载链接】workflow-bpmn-modeler 🔥 flowable workflow designer based on vue and bpmn.io7.0 项目地址: https://gitcode.com/gh_mirrors/wo/workflow-bpmn-modeler 还在为复杂…...
解锁Claude无限潜能:技能生态系统的构建艺术
解锁Claude无限潜能:技能生态系统的构建艺术 【免费下载链接】awesome-claude-skills A curated list of awesome Claude Skills, resources, and tools for customizing Claude AI workflows 项目地址: https://gitcode.com/GitHub_Trending/aw/awesome-claude-s…...
GLM-4.1V-9B-Base效果展示:艺术画作风格+主题+文化元素三重解析
GLM-4.1V-9B-Base效果展示:艺术画作风格主题文化元素三重解析 1. 视觉理解新标杆:GLM-4.1V-9B-Base简介 GLM-4.1V-9B-Base是智谱开源的一款视觉多模态理解模型,专为图像内容识别、场景描述和目标问答任务而设计。不同于普通的图像识别工具&…...
ssm+java2026年毕设体育赛事管理系统App【源码+论文】
本系统(程序源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容一、选题背景关于赛事管理问题的研究,现有研究主要以大型综合性体育赛事(如奥运会、亚运会)的信息化管理…...
解锁智能OCR新范式:Pix2Text多模态内容识别技术全解析
解锁智能OCR新范式:Pix2Text多模态内容识别技术全解析 【免费下载链接】Pix2Text Pix In, Latex & Text Out. Recognize Chinese, English Texts, and Math Formulas from Images. 项目地址: https://gitcode.com/gh_mirrors/pi/Pix2Text Pix2Text是一款…...
