当前位置: 首页 > news >正文

transformer实现词性标注

1、self-attention

1.1、self-attention结构图

上图是 Self-Attention 的结构,在计算的时候需要用到矩阵 Q(查询), K(键值), V(值)。在实际中,Self-Attention 接收的是输入(单词的表示向量 x组成的矩阵 X) 或者上一个 Encoder block 的输出。而 QK正是通过 Self-Attention 的输入进行线性变换得到的。

1.2 Q,K,V的计算

Self-Attention 的输入用矩阵 X进行表示,则可以使用线性变阵矩阵 WQWKWV 计算得到 QKV。计算如下图所示,注意 X, Q, K, V每一行都表示一个单词

 3.3 Self-Attention 的输出

得到矩阵 QKV之后就可以计算出 Self-Attention 的输出了,计算的公式如下: 

公式中计算矩阵 Q和 K 每一行向量的内积,为了防止内积过大,因此除以 dk 的平方根。乘以 K 的转置后,得到的矩阵行列数都为 n,n 为句子单词数,这个矩阵可以表示单词之间的 attention 强度。下图为 乘以 的转置,1234 表示的是句子中的单词。

得到 QK^{T} 之后,使用 Softmax 计算每一个单词对于其他单词的 attention 系数,公式中的 Softmax 是对矩阵的每一行进行 Softmax,即每一行的和都变为 1。

对矩阵每一行进行softmax
​​​​​

 

得到 Softmax 矩阵之后可以和 V相乘,得到最终的输出 Z

self-attention输出

 上图中 Softmax 矩阵的第 1 行表示单词 1 与其他所有单词的 attention 系数,最终单词 1 的输出 Z1 等于所有单词 i 的值 Vi 根据 attention 系数的比例加在一起得到,如下图所示:

Zi的计算方法

class Attention(nn.Module):def __init__(self, input_n:int,hidden_n:int):super().__init__()self.hidden_n = hidden_nself.input_n=input_nself.W_q = torch.nn.Linear(input_n, hidden_n)self.W_k = torch.nn.Linear(input_n, hidden_n)self.W_v = torch.nn.Linear(input_n, hidden_n)def forward(self, Q, K, V, mask=None):Q = self.W_q(Q)K = self.W_k(K)V = self.W_v(V)attention_scores = torch.matmul(Q, K.transpose(-2, -1))attention_weights = softmax(attention_scores)output = torch.matmul(attention_weights, V)return output

2、multi-head attention

       

从上图可以看到 Multi-Head Attention 包含多个 Self-Attention 层,首先将输入 X分别传递到 h 个不同的 Self-Attention 中,计算得到 h 个输出矩阵 Z。下图是 h=8 时候的情况,此时会得到 8 个输出矩阵 Z

多个self-attention

 得到 8 个输出矩阵 Z1 到 Z8 之后,Multi-Head Attention 将它们拼接在一起 (Concat),然后传入一个 Linear层,得到 Multi-Head Attention 最终的输出 Z

Multi-Head Attention的输出

 可以看到 Multi-Head Attention 输出的矩阵 Z与其输入的矩阵 X 的维度是一样的。

class MultiHeadAttention(nn.Module):def __init__(self,hidden_n:int, h:int = 2):"""hidden_n: hidden dimensionh: number of heads"""super().__init__()embed_size=hidden_nheads=hself.embed_size = embed_sizeself.heads = heads# 每个head的处理的特征个数self.head_dim = embed_size // heads# 如果不能整除就报错assert (self.head_dim * self.heads == self.embed_size), 'embed_size should be divided by heads'# 三个全连接分别计算qkvself.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)# 输出层self.fc_out = nn.Linear(self.head_dim * self.heads, embed_size)def forward(self, Q, K, V, mask=None):query,values,keys=Q,K,VN = query.shape[0]  # batch# 获取每个句子有多少个单词value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# 维度调整 [b,seq_len,embed_size] ==> [b,seq_len,heads,head_dim]values = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)# 对原始输入数据计算q、k、vvalues = self.values(values)keys = self.keys(keys)queries = self.queries(queries)# 爱因斯坦简记法,用于张量矩阵运算,q和k的转置矩阵相乘# queries.shape = [N, query_len, self.heads, self.head_dim]# keys.shape = [N, keys_len, self.heads, self.head_dim]# energy.shape = [N, heads, query_len, keys_len]energy = torch.einsum('nqhd, nkhd -> nhqk', [queries, keys])# 是否使用mask遮挡t时刻以后的所有q、kif mask is not None:# 将mask中所有为0的位置的元素,在energy中对应位置都置为 -1*10^10energy = energy.masked_fill(mask==0, torch.tensor(-1e10))# 根据公式计算attention, 在最后一个维度上计算softmaxattention = torch.softmax(energy/(self.embed_size**(1/2)), dim=3)# 爱因斯坦简记法矩阵元素,其中query_len == keys_len == value_len# attention.shape = [N, heads, query_len, keys_len]# values.shape = [N, value_len, heads, head_dim]# out.shape = [N, query_len, heads, head_dim]out = torch.einsum('nhql, nlhd -> nqhd', [attention, values])# 维度调整 [N, query_len, heads, head_dim] ==> [N, query_len, heads*head_dim]out = out.reshape(N, query_len, self.heads*self.head_dim)# 全连接,shape不变output = self.fc_out(out)return output

3、transformer block

3.1 encoder blockg构架图

 上图红色部分是 Transformer 的 Encoder block 结构,可以看到是由 Multi-Head Attention, Add & Norm, Feed Forward, Add & Norm 组成的。刚刚已经了解了 Multi-Head Attention 的计算过程,现在了解一下 Add & Norm 和 Feed Forward 部分。

3.2 Add & Norm

Add & Norm 层由 Add 和 Norm 两部分组成,其计算公式如下:

 其中 X表示 Multi-Head Attention 或者 Feed Forward 的输入,MultiHeadAttention(X) 和 FeedForward(X) 表示输出 (输出与输入 X 维度是一样的,所以可以相加)。

Add指 X+MultiHeadAttention(X),是一种残差连接,通常用于解决多层网络训练的问题,可以让网络只关注当前差异的部分,在 ResNet 中经常用到。

残差连接

 Norm指 Layer Normalization,通常用于 RNN 结构,Layer Normalization 会将每一层神经元的输入都转成均值方差都一样的,这样可以加快收敛。

3.3 Feed Forward

Feed Forward 层比较简单,是一个两层的全连接层,第一层的激活函数为 Relu,第二层不使用激活函数,对应的公式如下。

Feed Forward

 X是输入,Feed Forward 最终得到的输出矩阵的维度与 X 一致。

class TransformerBlock(nn.Module):def __init__(self, hidden_n:int, h:int = 2):"""hidden_n: hidden dimensionh: number of heads"""super().__init__()embed_size=hidden_nheads=h# 实例化自注意力模块self.attention =MultiHeadAttention (embed_size, heads)# muti_head之后的layernormself.norm1 = nn.LayerNorm(embed_size)# FFN之后的layernormself.norm2 = nn.LayerNorm(embed_size)forward_expansion=1dropout=0.2# 构建FFN前馈型神经网络self.feed_forward = nn.Sequential(# 第一个全连接层上升特征个数nn.Linear(embed_size, embed_size * forward_expansion),# relu激活nn.ReLU(),# 第二个全连接下降特征个数nn.Linear(embed_size * forward_expansion, embed_size))# dropout层随机杀死神经元self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask=None):attention = self.attention(value, key, query, mask)# 输入和输出做残差连接x = query + attention# layernorm标准化x = self.norm1(x)# dropoutx = self.dropout(x)# FFNffn = self.feed_forward(x)# 残差连接输入和输出forward = ffn + x# layernorm + dropoutout = self.dropout(self.norm2(forward))return out

transformer

import torch.nn as nn
class Transformer(nn.Module):def __init__(self,vocab_size, emb_n: int, hidden_n: int, n:int =3, h:int =2):"""emb_n: number of token embeddingshidden_n: hidden dimensionn: number of layersh: number of heads per layer"""embedding_dim=emb_nsuper().__init__()self.embedding_dim = embedding_dimself.embeddings = nn.Embedding(vocab_size,embedding_dim)self.layers=nn.ModuleList([TransformerBlock(hidden_n,h) for _ in range(n)    ])def forward(self,x):N,seq_len=x.shapeout=self.embeddings(x)for layer in self.layers:out=layer(out,out,out)return out

相关文章:

transformer实现词性标注

1、self-attention 1.1、self-attention结构图 上图是 Self-Attention 的结构,在计算的时候需要用到矩阵 Q(查询), K(键值), V(值)。在实际中,Self-Attention 接收的是输入(单词的表示向量 x组成的矩阵 X) 或者上一个 Encoder block 的输出。而 Q, K, V…...

Java中异或操作和OTP算法

最近在研究加密算法,发现异或操作在加密算法中用途特别广,也特别好用。下面以Java语言为例,简单记录一下异或操作,以及在算法中的使用,包括常用的OTP算法。 一,异或操作特征 1, 相同出0&#…...

K8S最新版本集群部署(v1.28) + 容器引擎Docker部署(下)

温故知新 📚第三章 Kubernetes各组件部署📗安装kubectl(可直接跳转到安装kubeadm章节,直接全部安装了)📕下载kubectl安装包📕执行kubectl安装📕验证kubectl 📗安装kubead…...

女子垒球运动的发展·垒球1号位

女子垒球运动的发展 1. 女子垒球运动的起源和发展概述 女子垒球运动,诞生于19世纪末的美国,作为棒球运动的衍生品,经过百年的积淀,已在全球范围内广泛传播,形成了丰富的赛事文化。她的起源,可以追溯到19世…...

Debian 30 周年,生日快乐!

导读近日是 Debian 日,也是由伊恩-默多克(Ian Murdock)创立的 Debian GNU/Linux 通用操作系统和社区支持的 Debian 项目 30 周年纪念日。 不管你信不信,从已故的伊恩-默多克于 1993 年 8 月 16 日宣布成立 Debian 项目&#xff0c…...

字符串匹配的Rabin–Karp算法

leetcode-28 实现strStr() 更熟悉的字符串匹配算法可能是KMP算法, 但在Golang中,使用的是Rabin–Karp算法 一般中文译作 拉宾-卡普算法,由迈克尔拉宾与理查德卡普于1987年提出 “ 要在一段文本中找出单个模式串的一个匹配,此算法具有线性时间的平均复杂度&#xff0…...

傅里叶变换(FFT)笔记存档

参考博客:https://www.luogu.com.cn/blog/command-block/fft-xue-xi-bi-ji 目录: FFT引入复数相关知识单位根及其相关性质DFT过程(难点)DFT结论(重要)IDFT结论(重要)IDFT结论证明&…...

ELK安装、部署、调试 (二) ES的安装部署

ElasticSearch是一个基于Lucene的搜索服务器。它提供了一个分布式多用户能力的全文搜索引擎,基于RESTful web接口操作ES,也可以利用Java API。Elasticsearch是用Java开发的,并作为Apache许可条款下的开放源码发布,是当前流行的企业…...

Android 13 - Media框架(8)- MediaExtractor

上一篇我们了解了 GenericSource 需要依赖 IMediaExtractor 完成 demux 工作,这一篇我们就来学习 android media 框架中的第二个服务 media.extractor,看看 IMediaExtractor 是如何创建与工作的。 1、MediaExtractorService media.extractor 和 media.p…...

Flutter 混合开发调试

针对Flutter开发的同学来说,大部分的应用还是Native Flutter的混合开发,所以每次改完Flutter代码,运行整个项目无疑是很费时间的。所以Flutter官方也给我们提供了混合调试的方案【在混合开发模式下进行调试】,这里以Android Stud…...

C语言每日一练------(Day3)

本专栏为c语言练习专栏,适合刚刚学完c语言的初学者。本专栏每天会不定时更新,通过每天练习,进一步对c语言的重难点知识进行更深入的学习。 今天练习题的关键字: 尼科彻斯定理 等差数列 💓博主csdn个人主页&#xff1a…...

14、监测数据采集物联网应用开发步骤(10)

监测数据采集物联网应用开发步骤(9.2) Modbus rtu协议开发 本章节在《监测数据采集物联网应用开发步骤(7)》基础上实现可参考《...开发步骤(7)》调试工具,本章节代码需要调用modbus_tk组件,阅读本章节前建议baidu熟悉modbus rtu协议内容 组件安装modb…...

Linux禅道上修改Apache 和 MySQL 默认端口号

1. 修改Apache默认端口号 80 cd /opt/zbox/etc/apachevim httpd.conf :wq 保存 2. 修改MySQL默认端口号 3306 cd /opt/zbox/etc/mysql vim my.cnf :wq 保存 3. 重启服务 ./zbox restart...

操作教程|通过1Panel开源Linux面板快速安装DataEase

DataEase开源数据可视化分析工具(dataease.io)的在线安装是通过在服务器命令行执行Linux命令来进行的。但是在实际的安装部署过程中,很多数据分析师或者业务人员经常会因为不熟悉Linux操作系统及命令行操作方式,在安装DataEase的过…...

机器学习策略——优化深度学习系统

正交化(Orthogonalization) 老式电视机,有很多旋钮可以用来调整图像的各种性质,对于这些旧式电视,可能有一个旋钮用来调图像垂直方向的高度,另外有一个旋钮用来调图像宽度,也许还有一个旋钮用来…...

ES6中Proxy和Proxy实例

1.Proxy Proxy 这个词的原意是代理,用在这里表示由它来“代理”某些操作,可以译为“代理器” 使用方法 let p new Proxy(target, handler);其中,target 为被代理对象。handler 是一个对象,其声明了代理 target 的一些操作。p 是…...

UDP协议的重要知识点

UDP,即用户数据报协议(User Datagram Protocol),是一个简单的无连接的传输层协议。与TCP相比,UDP提供了更少的错误检查机制,并允许数据包在网络上更快地传输。在这篇博客中,我们将深入探讨UDP的…...

QT6为工程添加资源文件,并在ui界面引用

以添加图片资源为例 右键工程名字(不是最上面的名字),点击添加现有文件 这种方式虽然添加到了工程中,但不能在UI设计界面完成引用。主要原因可能是未把文件放入到项目资源文件中,以下面一种方式可以看出区别。 点击添…...

Python小知识 - 如何使用Python的Flask框架快速开发Web应用

如何使用Python的Flask框架快速开发Web应用 现在越来越多的人把Python作为自己的第一语言来学习,Python的简洁易学的语法以及丰富的第三方库让人们越来越喜欢上了这门语言。本文将介绍如何使用Python的Flask框架快速开发Web应用。 Flask是一个使用Python编写的轻量级…...

Flutter 项目结构文件

1、Flutter项目的文件结构 先helloworld项目,看看它都包含哪些组成部分。首先,来看一下项目的文件结构,如下图所示。 2、介绍上图的内容。 -litb/main.dart文件:整个应用的入口文件,其中的main函数是整个Flutter应…...

SpringBoot-17-MyBatis动态SQL标签之常用标签

文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...

【OSG学习笔记】Day 18: 碰撞检测与物理交互

物理引擎(Physics Engine) 物理引擎 是一种通过计算机模拟物理规律(如力学、碰撞、重力、流体动力学等)的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互,广泛应用于 游戏开发、动画制作、虚…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

【SQL学习笔记1】增删改查+多表连接全解析(内附SQL免费在线练习工具)

可以使用Sqliteviz这个网站免费编写sql语句,它能够让用户直接在浏览器内练习SQL的语法,不需要安装任何软件。 链接如下: sqliteviz 注意: 在转写SQL语法时,关键字之间有一个特定的顺序,这个顺序会影响到…...

HTML前端开发:JavaScript 常用事件详解

作为前端开发的核心,JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例: 1. onclick - 点击事件 当元素被单击时触发(左键点击) button.onclick function() {alert("按钮被点击了!&…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用

1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...

Spring是如何解决Bean的循环依赖:三级缓存机制

1、什么是 Bean 的循环依赖 在 Spring框架中,Bean 的循环依赖是指多个 Bean 之间‌互相持有对方引用‌,形成闭环依赖关系的现象。 多个 Bean 的依赖关系构成环形链路,例如: 双向依赖:Bean A 依赖 Bean B,同时 Bean B 也依赖 Bean A(A↔B)。链条循环: Bean A → Bean…...

Webpack性能优化:构建速度与体积优化策略

一、构建速度优化 1、​​升级Webpack和Node.js​​ ​​优化效果​​:Webpack 4比Webpack 3构建时间降低60%-98%。​​原因​​: V8引擎优化(for of替代forEach、Map/Set替代Object)。默认使用更快的md4哈希算法。AST直接从Loa…...

Golang——6、指针和结构体

指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...