当前位置: 首页 > news >正文

多头自注意力机制的代码实现

文章目录

  • 1、自注意力机制
  • 2、多头注意力机制

  • transformer的整体结构:
    在这里插入图片描述

1、自注意力机制

  • 自注意力机制如下:
    在这里插入图片描述
  • 计算过程:
    在这里插入图片描述
  • 代码如下:
class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_size):super().__init__()self.W_q = nn.Linear(embed_dim, key_size, bias=False)self.W_k = nn.Linear(embed_dim, key_size, bias=False)self.W_v = nn.Linear(embed_dim, value_size, bias=False)def forward(self, x, attn_mask=None):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量attn_mask: (N, L, L),用于对注意力矩阵(L, L)进行mask输出:shape:(N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(query.size(2))if attn_mask is not None:scores = scores.masked_fill(attn_mask, 0)attn_weights = F.softmax(scores, dim=-1)	# dim为-1表示,对每个嵌入向量与其他所有向量的注意力权重,进行softmax,以使每一行的和为1return torch.matmul(attn_weights, value)

2、多头注意力机制

  • 结构如下:
    在这里插入图片描述
  • 计算过程如下:
class MultiHeadSelfAttention(nn.Module):def __init__(self, embed_dim, num_heads, key_size, value_size, bias=False):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.q_head_dim = key_size // num_headsself.k_head_dim = key_size // num_headsself.v_head_dim = value_size // num_headsself.W_q = nn.Linear(embed_dim, key_size, bias=bias)self.W_k = nn.Linear(embed_dim, key_size, bias=bias)self.W_v = nn.Linear(embed_dim, value_size, bias=bias)        self.q_proj = nn.Linear(key_size, key_size, bias=bias)self.k_proj = nn.Linear(key_size, key_size, bias=bias)self.v_proj = nn.Linear(value_size, value_size, bias=bias)self.out_proj = nn.Linear(value_size, embed_dim, bias=bias)def forward(self, x):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量Returns:output: (N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)N, L, value_size = v.size()q = q.reshape(N, L, self.num_heads, self.q_head_dim).transpose(1, 2)k = k.reshape(N, L, self.num_heads, self.k_head_dim).transpose(1, 2)v = v.reshape(N, L, self.num_heads, self.v_head_dim).transpose(1, 2)att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))att = F.softmax(att, dim=-1)output = torch.matmul(att, v)output = output.transpose(1, 2).reshape(N, L, value_size)output = self.out_proj(output)return output

相关文章:

多头自注意力机制的代码实现

文章目录 1、自注意力机制2、多头注意力机制 transformer的整体结构: 1、自注意力机制 自注意力机制如下: 计算过程: 代码如下: class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_…...

抽象工厂模式

目录 了解抽象工厂模式前的前置知识 什么是抽象工厂模式? 为什么要提出抽象工厂模式? 抽象工厂模式中的四大角色? 抽象工厂模式的优缺点? 抽象工厂模式的适用场景? 了解抽象工厂模式前的前置知识 在讲抽象工厂模式…...

登录校验-Filter-详解

目录 执行流程 拦截路径 过滤器链 小结 执行流程 过滤器Filter拦截到请求之后,首先执行方放行之前的逻辑,然后执行放行操作(doFilter),然后会访问对应的Web资源(对应的Controller类)&#…...

堆栈方法区笔记记录

成员变量分两种: 1)实例变量:没有static修饰,属于对象,存储在堆中,有几个对象就有几份,通过对象点来访问 2)静态变量:由static修饰,属于类,存储在方法区中,只有一份,通过类名点来访…...

新版微信小程序获取用户手机号

小程序手机号验证组件有两种 手机号快速验证组件 //原生写法 <button open-type"getPhoneNumber" bindgetphonenumber"getPhoneNumber"></button>Page({getPhoneNumber (e) {console.log(e.detail.code)} })uniapp写法 <button open-type…...

CSS实践 —— 悬浮盒子阴影加上移效果

悬浮盒子阴影加上移效果 代码 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>Title</title><style>body{background-color: #f5f5f5;}.shadow {width: 100px;height: 100px;margin:…...

安全测试基础知识

软件安全测试是评估和测试系统以发现系统及其数据的安全风险和漏洞的过程。没有通用术语&#xff0c;但出于我们的目的&#xff0c;我们将评估定义为分析和发现漏洞&#xff0c;而不尝试实际利用这些漏洞。我们将测试定义为发现和尝试利用漏洞。 安全测试通常根据要测试的漏洞…...

列表首屏毫秒级加载与自动滚动定位方案

引用自 摸鱼wiki 场景 <template><div ref"commentsRef"><divv-for"comment in displayComments":key"comment.id":data-cell-id"comment.id"class"card">{{ comment.data }}</div></div> &…...

小区物业业主管理信息系统设计的设计与实现(论文+源码)_kaic

摘 要 随着互联网的发展&#xff0c;网络技术的发展变得极其重要&#xff0c;所以依靠计算机处理业务成为了一种社会普遍的现状。管理方式也自然而然的向着现代化技术方向而改变&#xff0c;所以纯人工管理方式在越来越完善的现代化管理技术的比较之下也就显得过于繁琐&#x…...

Fortran 微分方程求解 --ODEPACK

最近涉及到使用Fortran对微分方程求解&#xff0c;我们知道MATLAB已有内置的函数&#xff0c;比如ode家族&#xff0c;ode15s&#xff0c;对应着不同的求解办法。通过查看odepack的官方文档&#xff0c;我尝试使用了dlsode求解刚性和非刚性常微分方程组。 首先是github网址&am…...

8路光栅尺磁栅尺编码器或16路高速DI脉冲信号转Modbus TCP网络模块 YL99-RJ45

特点&#xff1a; ● 光栅尺磁栅尺解码转换成标准Modbus TCP协议 ● 高速光栅尺磁栅尺4倍频计数&#xff0c;频率可达5MHz ● 模块可以输出5V的电源给光栅尺或传感器供电 ● 支持8个光栅尺同时计数&#xff0c;可识别正反转 ● 可以设置作为16路独立DI高速计数器 ● 可网…...

【Python】函数

None类型 思考&#xff1a;若函数没有使用return语句返回数据&#xff0c;那么函数有返回值吗&#xff1f; 答&#xff1a;实际上是有的&#xff0c;Python中有一个特殊的字面量None&#xff0c;其类型是<class ‘NoneType’>&#xff0c;无返回值的函数&#xff0c;实…...

centos安装MySQL 解压版完整教程(按步骤傻瓜式安装

一、卸载系统自带的 Mariadb 查看&#xff1a; rpm -qa|grep mariadb 卸载&#xff1a; rpm -e --nodeps mariadb-libs-5.5.68-1.el7.x86_64 二、卸载 etc 目录下的 my.cnf 文件 rm -rf /etc/my.cnf 三、检查MySQL是否存在 有则先删除 #卸载mysql服务以及删除所有mysql目录 #没…...

【后端速成 Vue】第一个 Vue 程序

1、为什么要学习 Vue&#xff1f; 为什么使用 Vue? 回想之前&#xff0c;前后端交互的时候&#xff0c;前端收到后端响应的数据&#xff0c;接着将数据渲染到页面上&#xff0c;之前使用的是 JavaScript 或者 基于 JavaScript 的 Jquery&#xff0c;但是这两个用起来还是不太…...

Macbook pro M1 安装Ubuntu教程

先讲下心路历程 由于版主最近刚切换到Mac&#xff0c;所以在安装的时候一上手就选择了virutalbox&#xff0c;结果报错“The installer has detected an unsupported architecture. VirtualBox only runs on the amd64 architecture.” 后来去Reddit论坛上一看&#xff0c;才知…...

前端console.log打印内容与后端请求返回数据不一致

后端传值num0 前端打印num1 ,如图&#xff0c;console.log后台显示的数据与展开后不一致 造成该问题原因是深拷贝与浅拷贝的问题。 var obj JSON.parse(JSON.stringify(res)) 修改后打印 正常...

SQL入门:多表查询

SQL&#xff0c;或者说结构化查询语言(Structured Query Language)&#xff0c;是用于管理和操作关系型数据库的标准语言。在本篇文章中&#xff0c;我们将重点介绍SQL中的多表查询&#xff0c;这是一种强大的工具&#xff0c;可以帮助我们从多个相关的表格中获取数据。 数据库…...

【C++】进一步认识模板

&#x1f3d6;️作者&#xff1a;malloc不出对象 ⛺专栏&#xff1a;C的学习之路 &#x1f466;个人简介&#xff1a;一名双非本科院校大二在读的科班编程菜鸟&#xff0c;努力编程只为赶上各位大佬的步伐&#x1f648;&#x1f648; 目录 前言一、非类型模板参数二、模板的特…...

Mysql Oracle 区别

1. oracle select *, id需要在星号前加别名&#xff0c;mysql则不需要 mysql语法&#xff1a; select *, id from xin_student_t;oracle语法&#xff1a; select st.*, st.id from xin_student_t st;2. oracle表定义了别名&#xff0c;在查询时可以不用别名指定字段&#xf…...

华为OD-第K长的连续字母字符串长度

题目描述 给定一个字符串&#xff0c;只包含大写字母&#xff0c;求在包含同一字母的子串中&#xff0c;长度第 k 长的子串的长度&#xff0c;相同字母只取最长的那个子串。 代码实现 # coding:utf-8 # 第K长的连续字母字符串长度 # https://www.nowcoder.com/discuss/353150…...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

练习(含atoi的模拟实现,自定义类型等练习)

一、结构体大小的计算及位段 &#xff08;结构体大小计算及位段 详解请看&#xff1a;自定义类型&#xff1a;结构体进阶-CSDN博客&#xff09; 1.在32位系统环境&#xff0c;编译选项为4字节对齐&#xff0c;那么sizeof(A)和sizeof(B)是多少&#xff1f; #pragma pack(4)st…...

C++.OpenGL (10/64)基础光照(Basic Lighting)

基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...

pikachu靶场通关笔记22-1 SQL注入05-1-insert注入(报错法)

目录 一、SQL注入 二、insert注入 三、报错型注入 四、updatexml函数 五、源码审计 六、insert渗透实战 1、渗透准备 2、获取数据库名database 3、获取表名table 4、获取列名column 5、获取字段 本系列为通过《pikachu靶场通关笔记》的SQL注入关卡(共10关&#xff0…...

从 GreenPlum 到镜舟数据库:杭银消费金融湖仓一体转型实践

作者&#xff1a;吴岐诗&#xff0c;杭银消费金融大数据应用开发工程师 本文整理自杭银消费金融大数据应用开发工程师在StarRocks Summit Asia 2024的分享 引言&#xff1a;融合数据湖与数仓的创新之路 在数字金融时代&#xff0c;数据已成为金融机构的核心竞争力。杭银消费金…...

论文阅读:LLM4Drive: A Survey of Large Language Models for Autonomous Driving

地址&#xff1a;LLM4Drive: A Survey of Large Language Models for Autonomous Driving 摘要翻译 自动驾驶技术作为推动交通和城市出行变革的催化剂&#xff0c;正从基于规则的系统向数据驱动策略转变。传统的模块化系统受限于级联模块间的累积误差和缺乏灵活性的预设规则。…...

Vue 模板语句的数据来源

&#x1f9e9; Vue 模板语句的数据来源&#xff1a;全方位解析 Vue 模板&#xff08;<template> 部分&#xff09;中的表达式、指令绑定&#xff08;如 v-bind, v-on&#xff09;和插值&#xff08;{{ }}&#xff09;都在一个特定的作用域内求值。这个作用域由当前 组件…...

第一篇:Liunx环境下搭建PaddlePaddle 3.0基础环境(Liunx Centos8.5安装Python3.10+pip3.10)

第一篇&#xff1a;Liunx环境下搭建PaddlePaddle 3.0基础环境&#xff08;Liunx Centos8.5安装Python3.10pip3.10&#xff09; 一&#xff1a;前言二&#xff1a;安装编译依赖二&#xff1a;安装Python3.10三&#xff1a;安装PIP3.10四&#xff1a;安装Paddlepaddle基础框架4.1…...

C++_哈希表

本篇文章是对C学习的哈希表部分的学习分享 相信一定会对你有所帮助~ 那咱们废话不多说&#xff0c;直接开始吧&#xff01; 一、基础概念 1. 哈希核心思想&#xff1a; 哈希函数的作用&#xff1a;通过此函数建立一个Key与存储位置之间的映射关系。理想目标&#xff1a;实现…...

解析两阶段提交与三阶段提交的核心差异及MySQL实现方案

引言 在分布式系统的事务处理中&#xff0c;如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议&#xff08;2PC&#xff09;通过准备阶段与提交阶段的协调机制&#xff0c;以同步决策模式确保事务原子性。其改进版本三阶段提交协议&#xff08;3PC&#xf…...