当前位置: 首页 > news >正文

多头自注意力机制的代码实现

文章目录

  • 1、自注意力机制
  • 2、多头注意力机制

  • transformer的整体结构:
    在这里插入图片描述

1、自注意力机制

  • 自注意力机制如下:
    在这里插入图片描述
  • 计算过程:
    在这里插入图片描述
  • 代码如下:
class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_size):super().__init__()self.W_q = nn.Linear(embed_dim, key_size, bias=False)self.W_k = nn.Linear(embed_dim, key_size, bias=False)self.W_v = nn.Linear(embed_dim, value_size, bias=False)def forward(self, x, attn_mask=None):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量attn_mask: (N, L, L),用于对注意力矩阵(L, L)进行mask输出:shape:(N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(query.size(2))if attn_mask is not None:scores = scores.masked_fill(attn_mask, 0)attn_weights = F.softmax(scores, dim=-1)	# dim为-1表示,对每个嵌入向量与其他所有向量的注意力权重,进行softmax,以使每一行的和为1return torch.matmul(attn_weights, value)

2、多头注意力机制

  • 结构如下:
    在这里插入图片描述
  • 计算过程如下:
class MultiHeadSelfAttention(nn.Module):def __init__(self, embed_dim, num_heads, key_size, value_size, bias=False):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.q_head_dim = key_size // num_headsself.k_head_dim = key_size // num_headsself.v_head_dim = value_size // num_headsself.W_q = nn.Linear(embed_dim, key_size, bias=bias)self.W_k = nn.Linear(embed_dim, key_size, bias=bias)self.W_v = nn.Linear(embed_dim, value_size, bias=bias)        self.q_proj = nn.Linear(key_size, key_size, bias=bias)self.k_proj = nn.Linear(key_size, key_size, bias=bias)self.v_proj = nn.Linear(value_size, value_size, bias=bias)self.out_proj = nn.Linear(value_size, embed_dim, bias=bias)def forward(self, x):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量Returns:output: (N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)N, L, value_size = v.size()q = q.reshape(N, L, self.num_heads, self.q_head_dim).transpose(1, 2)k = k.reshape(N, L, self.num_heads, self.k_head_dim).transpose(1, 2)v = v.reshape(N, L, self.num_heads, self.v_head_dim).transpose(1, 2)att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))att = F.softmax(att, dim=-1)output = torch.matmul(att, v)output = output.transpose(1, 2).reshape(N, L, value_size)output = self.out_proj(output)return output

相关文章:

多头自注意力机制的代码实现

文章目录 1、自注意力机制2、多头注意力机制 transformer的整体结构: 1、自注意力机制 自注意力机制如下: 计算过程: 代码如下: class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_…...

抽象工厂模式

目录 了解抽象工厂模式前的前置知识 什么是抽象工厂模式? 为什么要提出抽象工厂模式? 抽象工厂模式中的四大角色? 抽象工厂模式的优缺点? 抽象工厂模式的适用场景? 了解抽象工厂模式前的前置知识 在讲抽象工厂模式…...

登录校验-Filter-详解

目录 执行流程 拦截路径 过滤器链 小结 执行流程 过滤器Filter拦截到请求之后,首先执行方放行之前的逻辑,然后执行放行操作(doFilter),然后会访问对应的Web资源(对应的Controller类)&#…...

堆栈方法区笔记记录

成员变量分两种: 1)实例变量:没有static修饰,属于对象,存储在堆中,有几个对象就有几份,通过对象点来访问 2)静态变量:由static修饰,属于类,存储在方法区中,只有一份,通过类名点来访…...

新版微信小程序获取用户手机号

小程序手机号验证组件有两种 手机号快速验证组件 //原生写法 <button open-type"getPhoneNumber" bindgetphonenumber"getPhoneNumber"></button>Page({getPhoneNumber (e) {console.log(e.detail.code)} })uniapp写法 <button open-type…...

CSS实践 —— 悬浮盒子阴影加上移效果

悬浮盒子阴影加上移效果 代码 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>Title</title><style>body{background-color: #f5f5f5;}.shadow {width: 100px;height: 100px;margin:…...

安全测试基础知识

软件安全测试是评估和测试系统以发现系统及其数据的安全风险和漏洞的过程。没有通用术语&#xff0c;但出于我们的目的&#xff0c;我们将评估定义为分析和发现漏洞&#xff0c;而不尝试实际利用这些漏洞。我们将测试定义为发现和尝试利用漏洞。 安全测试通常根据要测试的漏洞…...

列表首屏毫秒级加载与自动滚动定位方案

引用自 摸鱼wiki 场景 <template><div ref"commentsRef"><divv-for"comment in displayComments":key"comment.id":data-cell-id"comment.id"class"card">{{ comment.data }}</div></div> &…...

小区物业业主管理信息系统设计的设计与实现(论文+源码)_kaic

摘 要 随着互联网的发展&#xff0c;网络技术的发展变得极其重要&#xff0c;所以依靠计算机处理业务成为了一种社会普遍的现状。管理方式也自然而然的向着现代化技术方向而改变&#xff0c;所以纯人工管理方式在越来越完善的现代化管理技术的比较之下也就显得过于繁琐&#x…...

Fortran 微分方程求解 --ODEPACK

最近涉及到使用Fortran对微分方程求解&#xff0c;我们知道MATLAB已有内置的函数&#xff0c;比如ode家族&#xff0c;ode15s&#xff0c;对应着不同的求解办法。通过查看odepack的官方文档&#xff0c;我尝试使用了dlsode求解刚性和非刚性常微分方程组。 首先是github网址&am…...

8路光栅尺磁栅尺编码器或16路高速DI脉冲信号转Modbus TCP网络模块 YL99-RJ45

特点&#xff1a; ● 光栅尺磁栅尺解码转换成标准Modbus TCP协议 ● 高速光栅尺磁栅尺4倍频计数&#xff0c;频率可达5MHz ● 模块可以输出5V的电源给光栅尺或传感器供电 ● 支持8个光栅尺同时计数&#xff0c;可识别正反转 ● 可以设置作为16路独立DI高速计数器 ● 可网…...

【Python】函数

None类型 思考&#xff1a;若函数没有使用return语句返回数据&#xff0c;那么函数有返回值吗&#xff1f; 答&#xff1a;实际上是有的&#xff0c;Python中有一个特殊的字面量None&#xff0c;其类型是<class ‘NoneType’>&#xff0c;无返回值的函数&#xff0c;实…...

centos安装MySQL 解压版完整教程(按步骤傻瓜式安装

一、卸载系统自带的 Mariadb 查看&#xff1a; rpm -qa|grep mariadb 卸载&#xff1a; rpm -e --nodeps mariadb-libs-5.5.68-1.el7.x86_64 二、卸载 etc 目录下的 my.cnf 文件 rm -rf /etc/my.cnf 三、检查MySQL是否存在 有则先删除 #卸载mysql服务以及删除所有mysql目录 #没…...

【后端速成 Vue】第一个 Vue 程序

1、为什么要学习 Vue&#xff1f; 为什么使用 Vue? 回想之前&#xff0c;前后端交互的时候&#xff0c;前端收到后端响应的数据&#xff0c;接着将数据渲染到页面上&#xff0c;之前使用的是 JavaScript 或者 基于 JavaScript 的 Jquery&#xff0c;但是这两个用起来还是不太…...

Macbook pro M1 安装Ubuntu教程

先讲下心路历程 由于版主最近刚切换到Mac&#xff0c;所以在安装的时候一上手就选择了virutalbox&#xff0c;结果报错“The installer has detected an unsupported architecture. VirtualBox only runs on the amd64 architecture.” 后来去Reddit论坛上一看&#xff0c;才知…...

前端console.log打印内容与后端请求返回数据不一致

后端传值num0 前端打印num1 ,如图&#xff0c;console.log后台显示的数据与展开后不一致 造成该问题原因是深拷贝与浅拷贝的问题。 var obj JSON.parse(JSON.stringify(res)) 修改后打印 正常...

SQL入门:多表查询

SQL&#xff0c;或者说结构化查询语言(Structured Query Language)&#xff0c;是用于管理和操作关系型数据库的标准语言。在本篇文章中&#xff0c;我们将重点介绍SQL中的多表查询&#xff0c;这是一种强大的工具&#xff0c;可以帮助我们从多个相关的表格中获取数据。 数据库…...

【C++】进一步认识模板

&#x1f3d6;️作者&#xff1a;malloc不出对象 ⛺专栏&#xff1a;C的学习之路 &#x1f466;个人简介&#xff1a;一名双非本科院校大二在读的科班编程菜鸟&#xff0c;努力编程只为赶上各位大佬的步伐&#x1f648;&#x1f648; 目录 前言一、非类型模板参数二、模板的特…...

Mysql Oracle 区别

1. oracle select *, id需要在星号前加别名&#xff0c;mysql则不需要 mysql语法&#xff1a; select *, id from xin_student_t;oracle语法&#xff1a; select st.*, st.id from xin_student_t st;2. oracle表定义了别名&#xff0c;在查询时可以不用别名指定字段&#xf…...

华为OD-第K长的连续字母字符串长度

题目描述 给定一个字符串&#xff0c;只包含大写字母&#xff0c;求在包含同一字母的子串中&#xff0c;长度第 k 长的子串的长度&#xff0c;相同字母只取最长的那个子串。 代码实现 # coding:utf-8 # 第K长的连续字母字符串长度 # https://www.nowcoder.com/discuss/353150…...

idea大量爆红问题解决

问题描述 在学习和工作中&#xff0c;idea是程序员不可缺少的一个工具&#xff0c;但是突然在有些时候就会出现大量爆红的问题&#xff0c;发现无法跳转&#xff0c;无论是关机重启或者是替换root都无法解决 就是如上所展示的问题&#xff0c;但是程序依然可以启动。 问题解决…...

2.Vue编写一个app

1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...

Qwen3-Embedding-0.6B深度解析:多语言语义检索的轻量级利器

第一章 引言&#xff1a;语义表示的新时代挑战与Qwen3的破局之路 1.1 文本嵌入的核心价值与技术演进 在人工智能领域&#xff0c;文本嵌入技术如同连接自然语言与机器理解的“神经突触”——它将人类语言转化为计算机可计算的语义向量&#xff0c;支撑着搜索引擎、推荐系统、…...

什么是EULA和DPA

文章目录 EULA&#xff08;End User License Agreement&#xff09;DPA&#xff08;Data Protection Agreement&#xff09;一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA&#xff08;End User License Agreement&#xff09; 定义&#xff1a; EULA即…...

多模态大语言模型arxiv论文略读(108)

CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文标题&#xff1a;CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文作者&#xff1a;Sayna Ebrahimi, Sercan O. Arik, Tejas Nama, Tomas Pfister ➡️ 研究机构: Google Cloud AI Re…...

select、poll、epoll 与 Reactor 模式

在高并发网络编程领域&#xff0c;高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表&#xff0c;以及基于它们实现的 Reactor 模式&#xff0c;为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。​ 一、I…...

Python Ovito统计金刚石结构数量

大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...

腾讯云V3签名

想要接入腾讯云的Api&#xff0c;必然先按其文档计算出所要求的签名。 之前也调用过腾讯云的接口&#xff0c;但总是卡在签名这一步&#xff0c;最后放弃选择SDK&#xff0c;这次终于自己代码实现。 可能腾讯云翻新了接口文档&#xff0c;现在阅读起来&#xff0c;清晰了很多&…...

SQL Server 触发器调用存储过程实现发送 HTTP 请求

文章目录 需求分析解决第 1 步:前置条件,启用 OLE 自动化方式 1:使用 SQL 实现启用 OLE 自动化方式 2:Sql Server 2005启动OLE自动化方式 3:Sql Server 2008启动OLE自动化第 2 步:创建存储过程第 3 步:创建触发器扩展 - 如何调试?第 1 步:登录 SQL Server 2008第 2 步…...

Android屏幕刷新率与FPS(Frames Per Second) 120hz

Android屏幕刷新率与FPS(Frames Per Second) 120hz 屏幕刷新率是屏幕每秒钟刷新显示内容的次数&#xff0c;单位是赫兹&#xff08;Hz&#xff09;。 60Hz 屏幕&#xff1a;每秒刷新 60 次&#xff0c;每次刷新间隔约 16.67ms 90Hz 屏幕&#xff1a;每秒刷新 90 次&#xff0c;…...