当前位置: 首页 > news >正文

【动手学深度学习-Pytorch版】BERT预测系列——BERTModel

本小节主要实现了以下几部分内容:

  • 从一个句子中提取BERT输入序列以及相对的segments段落索引(因为BERT支持输入两个句子)
  • BERT使用的是Transformer的Encoder部分,所以需要需要使用Encoder进行前向传播:输出的特征等于词嵌入+位置编码+Encoder块
  • 用于BERT预训练时预测的掩蔽语言模型任务中的掩蔽标记< mask >
  • 用于预训练任务的下一个句子的预测——在为预训练生成句子对时,有一半的时间它们确实是标签为“真”的连续句子;在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。
  • 通过BERTModel整合代码

"""可学习的位置编码也需要进行初始化"""
import torch
import d2l.torch
from torch import nn
import transformers
"""将一个句子或者两个句子作为输入,然后返回BERT输入序列及其相应的序列对的片段索引segments"""
def get_tokens_segments(tokens_a,tokens_b=None):"""获取输入序列的词元及其片段索引"""tokens = ['<cls>'] + tokens_a + ['<sep>']# 利用0和1分别标记片段A和片段Bsegments = [0] * (len(tokens_a)+2)  #加上<cls>和sepif tokens_b is not None:# 如果是句子对tokens += tokens_b+['<sep>']segments += [1]*(len(tokens_b)+1)  # 加上<sep>return tokens,segments"""在原始的Transformer架构中,编码器的位置嵌入信息是直接加到了输入序列的每个位置,但是BERT使用的是可学习的位置嵌入"""
"""bert-input = tokens_embedding + position_embedding + segment_embedding"""
class BERTEncoder(nn.Module):"""BERT编码器"""def __init__(self,vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout,max_len=1000,key_size=768,query_size=768,value_size=768,use_bias=True):super(BERTEncoder, self).__init__()self.token_embedding = nn.Embedding(vocab_size,num_hiddens)self.segment_embedding = nn.Embedding(2,num_hiddens)# 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入的参数self.pos_embedding = nn.Parameter(torch.randn(size=(1,max_len,num_hiddens)))# print('self.pos_embedding:',self.pos_embedding)"""self.pos_embedding.data : [1,1000,768]在下面与X相加时利用的是广播机制"""self.blks = nn.Sequential()for i in range(num_layers):self.blks.add_module(f'{i}',d2l.torch.EncoderBlock(key_size,query_size,value_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,dropout,use_bias))def forward(self,tokens,segments,valid_lens):# 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)X = self.token_embedding(tokens)+self.segment_embedding(segments)print('X.shape:',X.shape)   # [2,8,768]X += self.pos_embedding.data[:,:X.shape[1],:]  #[2,8,768]for blk in self.blks:X = blk(X,valid_lens)return X
"""演示BERTEncoder的前向传播--->词表大小:10000"""
vocab_size,num_hiddens,ffn_num_input,ffn_num_hiddens,num_heads,num_layers = 1000,768,768,1024,4,2
norm_shape,dropout = [768],0.2
encoder = BERTEncoder(vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout)
"""将tokens定义为长度为8的2个输入序列"""
tokens = torch.randint(0,vocab_size,(2,8))
print('tokens:',tokens)
print('tokens_shape:',tokens.shape)
"""其中每个词元由向量表示,其长度由超参数num_hiddens定义,此超参数通常称为Transformer编码器的隐藏大小(隐藏单元数)"""
segments = torch.tensor([[0,0,0,0,1,1,1,1],[0,0,0,1,1,1,1,1]])
print('segments:',segments)
enc_outputs = encoder(tokens,segments,None)
print('enc_outputs.shape',enc_outputs.shape)# 预训练任务---》双向编码上下文:掩蔽语言模型
"""预测BERT预训练的掩蔽语言模型任务中的掩蔽标记"""
#@save
class MaskLM(nn.Module):"""BERT的掩蔽语言模型任务"""def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):super(MaskLM, self).__init__(**kwargs)# 两层的MLP,同时使用激活函数ReLU  和 层归一化self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),nn.ReLU(),nn.LayerNorm(num_hiddens),nn.Linear(num_hiddens, vocab_size))# 前向传播时的输入信息包括:# 1 BERTEncoder编码结果# 2 用于预测词元的位置def forward(self, X, pred_positions):num_pred_positions = pred_positions.shape[1]# 将预测的位置压缩成一维向量空间pred_positions = pred_positions.reshape(-1)# BERTEncoder的输出特征形状:[batch_size,...]batch_size = X.shape[0]batch_idx = torch.arange(0, batch_size)# 假设batch_size=2,num_pred_positions=3# 那么batch_idx是np.array([0,0,0,1,1,1])# torch.repeat_interleave用于重复张量元素batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)print('输入的X形状:',X.shape)# batch_idx# pred_positions# 都是两个list其中batch_idx选择的是屏蔽的行# pred_positions选择的是屏蔽的列masked_X = X[batch_idx, pred_positions]print('masked后X的内容:',masked_X)# 最后把所有要屏蔽的数据拉成一个一维的向量masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))mlm_Y_hat = self.mlp(masked_X)# 最后返回的是利用MLP预测这些位置的结果return mlm_Y_hat"""将mlm_positions定义为在encoded_X的任一输如系列中预测3个值"""
"""而且对于每一个预测的结果都等于词表的大小"""
mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(enc_outputs, mlm_positions)
mlm_Y_hat_shape = mlm_Y_hat.shape
print('mlm_Y_hat_shape:',mlm_Y_hat_shape)# 通过掩码下的预测词元mlm_Y的真实标签mlm_Y_hat,我们可以计算在BERT预训练中的遮蔽语言模型任务的交叉熵损失
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l_shape = mlm_l.shape
print('mlm_l_shape:',mlm_l_shape)# 预训练任务---》下一个句子的预测
"""在为预训练生成句子对时,有一半的时间它们确实是标签为“真”的连续句子;在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。
"""
#@save
class NextSentencePred(nn.Module):"""BERT的下一句预测任务"""def __init__(self, num_inputs, **kwargs):super(NextSentencePred, self).__init__(**kwargs)self.output = nn.Linear(num_inputs, 2)def forward(self, X):# X的形状:(batchsize,num_hiddens)return self.output(X)
"""NextSentencePred实例的前向推断返回每个BERT输入序列的二分类预测"""
enc_outputs = torch.flatten(enc_outputs, start_dim=1)
# NSP的输入形状:(batchsize,num_hiddens)
nsp = NextSentencePred(enc_outputs.shape[-1])
nsp_Y_hat = nsp(enc_outputs)
print('nsp_Y_hat.shape',nsp_Y_hat.shape)
# 计算两个二元分类的交叉熵损失
nsp_y = torch.tensor([0, 1])
nsp_l = loss(nsp_Y_hat, nsp_y)
nsp_l_shape = nsp_l.shape
print('nsp_l_shape:',nsp_l_shape)#@save
class BERTModel(nn.Module):"""BERT模型"""def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,ffn_num_hiddens, num_heads, num_layers, dropout,max_len=1000, key_size=768, query_size=768, value_size=768,hid_in_features=768, mlm_in_features=768,nsp_in_features=768):super(BERTModel, self).__init__()self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,ffn_num_input, ffn_num_hiddens, num_heads, num_layers,dropout, max_len=max_len, key_size=key_size,query_size=query_size, value_size=value_size)self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),nn.Tanh())self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)self.nsp = NextSentencePred(nsp_in_features)def forward(self, tokens, segments, valid_lens=None,pred_positions=None):encoded_X = self.encoder(tokens, segments, valid_lens)if pred_positions is not None:mlm_Y_hat = self.mlm(encoded_X, pred_positions)else:mlm_Y_hat = None# 用于下一句预测的多层感知机分类器的隐藏层,0是“<cls>”标记的索引nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat

相关文章:

【动手学深度学习-Pytorch版】BERT预测系列——BERTModel

本小节主要实现了以下几部分内容&#xff1a; 从一个句子中提取BERT输入序列以及相对的segments段落索引&#xff08;因为BERT支持输入两个句子&#xff09;BERT使用的是Transformer的Encoder部分&#xff0c;所以需要需要使用Encoder进行前向传播&#xff1a;输出的特征等于词…...

Python之元组、字典和集合练习

1、餐厅下午茶 &#xff08;列表与元组 crr66&#xff09; 某餐厅推出了优惠下午茶套餐活动。顾客可以以优惠的价格从给定的糕点和给定的饮 料中各选一款组成套餐。已知&#xff0c;指定的糕点包括松饼(Muffins)、提拉米苏(Tiramisu)、芝士蛋 糕(Cheese Cake)和三明治(Sandwic…...

【数据结构】归并排序和计数排序(排序的总结)

目录 一&#xff0c;归并排序的递归 二&#xff0c;归并排序的非递归 三&#xff0c;计数排序 四&#xff0c;排序算法的综合分析 一&#xff0c;归并排序的递归 基本思想&#xff1a; 归并采用的是分治思想&#xff0c;是分治法的一个经典的运用。该算法先将原数据进行拆…...

某医疗机构:建立S-SDLC安全开发流程,保障医疗前沿科技应用高质量发展

某医疗机构是头部资本集团旗下专注大健康领域战略性投资与运营的实业公司&#xff0c;市场规模超300亿。该医疗机构已完成数字赋能&#xff0c;形成了标准化、专业化、数字化的疾病和健康管理体系&#xff0c;将进一步规划战略方向&#xff0c;为人工智能纳米技术、高温超导、生…...

验证二叉搜索树的后序遍历序列

LCR 152. 验证二叉搜索树的后序遍历序列 class VerifyTreeOrder:"""LCR 152. 验证二叉搜索树的后序遍历序列https://leetcode.cn/problems/er-cha-sou-suo-shu-de-hou-xu-bian-li-xu-lie-lcof/description/"""def solution(self, postorder: Lis…...

第三章 内存管理 一、内存的基础知识

目录 一、什么是内存 二、有何作用 三、常用数量单位 四、指令的工作原理 五、装入方式 1、绝对装入 2、可重定位装入&#xff08;静态重定位&#xff09; 3、动态运行时装入&#xff08;动态重定位&#xff09; 六、从写程序到程序运行 七、链接的三种方式 1、静态…...

【Java学习之道】Java常用集合框架

引言 在Java中&#xff0c;集合框架是一个非常重要的概念。它提供了一种方式&#xff0c;让你可以方便地存储和操作数据。Java中的集合框架包括各种集合类和接口&#xff0c;这些类和接口提供了不同的功能和特性。通过学习和掌握Java的集合框架&#xff0c;你可以更好地管理和…...

logicFlow 流程图编辑工具使用及开源地址

一、工具介绍 LogicFlow 是一款流程图编辑框架&#xff0c;提供了一系列流程图交互、编辑所必需的功能和灵活的节点自定义、插件等拓展机制。LogicFlow 支持前端研发自定义开发各种逻辑编排场景&#xff0c;如流程图、ER 图、BPMN 流程等。在工作审批配置、机器人逻辑编排、无…...

ATF(TF-A)/OPTEE之动态代码分析汇总

安全之安全(security)博客目录导读 1、ASAN(AddressSanitizer)地址消毒动态代码分析 2、ATF(TF-A)之UBSAN动态代码分析 3、OPTEE之KASAN地址消毒动态代码分析...

10-11 周三 shell xargs tr curl 做大事情

最近发现&#xff0c;shell的小工具非常的强大&#xff0c;简单记录下 tr命令 -d 删除字符串1中所有输入字符。-s 删除所有重复出现字符序列&#xff0c;只保留第一个&#xff1b;即将重复出现字符串压缩为一个字符串 -d 用于删除查询到的字符串中的空格。 [test3NH-DC-NM1…...

1.1 向量与线性组合

一、向量的基础知识 两个独立的数字 v 1 v_1 v1​ 和 v 2 v_2 v2​&#xff0c;将它们配对可以产生一个二维向量 v \boldsymbol{v} v&#xff1a; 列向量 v v [ v 1 v 2 ] v 1 v 的第一个分量 v 2 v 的第二个分量 \textbf{列向量}\,\boldsymbol v\kern 10pt\boldsymbol …...

django: You may need to add ‘localhost‘ to ALLOWED_HOSTS

参考:https://blog.csdn.net/qq_21744873/article/details/87857279 python manage.py runserver后页面访问失败&#xff0c;提示&#xff1a; DisallowedHost at /admin/ Invalid HTTP_HOST header: ‘localhost:8000’. You may need to add ‘localhost’ to ALLOWED_HOSTS…...

网络安全(黑客技术)—自学手册

1.网络安全是什么 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 2.网络安全市场 一、是市场需求量高&#xff1b; 二、则是发展相对成熟…...

【Vue】之Vuex的入门使用,取值,修改值,同异步请求处理---保姆级别教学

一&#xff0c;Vuex入门 1.1 什么是Vuex Vuex是一个专门为Vue.js应用程序开发的状态管理库。它用于管理应用程序中的共享状态&#xff0c;它采用集中式存储管理应用的所有组件的状态&#xff0c;使得状态的管理变得简单和可预测 官方解释&#xff1a;Vuex 是一个专为 Vue.js 应…...

ubuntu20.04 nerf Instant-ngp (下) 复现,自建数据集,导出mesh

参考链接 Ubuntu20.04复现instant-ngp&#xff0c;自建数据集&#xff0c;导出mesh_XINYU W的博客-CSDN博客 GitHub - NVlabs/instant-ngp: Instant neural graphics primitives: lightning fast NeRF and more youtube上的一个博主自建数据集 https://www.youtube.com/watch…...

【常见错误】SVN提交项目时,出现了这样的提示:“XXX“ is scheduled for addition, but is missing。

SVN提交项目时&#xff0c;出现了这样的提示&#xff1a;“XXX“ is scheduled for addition, but is missing。 原因是&#xff1a;之前用SVN提交过的文件/文件夹&#xff0c;被标记为"addition"状态&#xff0c;等待被加入到仓库。虽然你把这个文件删除了&#xf…...

深度学习基础知识 给模型的不同层 设置不同学习率

深度学习基础知识 给模型的不同层 设置不同学习率 1、使用预训练模型时&#xff0c;可能需要将2、学习率设置方式&#xff1a; 1、使用预训练模型时&#xff0c;可能需要将 &#xff08;1&#xff09;预训练好的 backbone 的 参数学习率设置为较小值&#xff0c; &#xff08;2…...

【Python 零基础入门】 Numpy

【Python 零基础入门】第六课 Numpy 概述什么是 Numpy?Numpy 与 Python 数组的区别并发 vs 并行单线程 vs 多线程GILNumpy 在数据科学中的重要性 Numpy 安装Anaconda导包 ndarraynp.array 创建数组属性np.zeros 创建np.ones 创建 数组的切片和索引基本索引切片操作数组运算 常…...

1600*C. Circle of Monsters(贪心)

Problem - 1334C - Codeforces 解析&#xff1a; 对于某个怪兽&#xff0c;他的耗费为两种情况&#xff0c;要么直接用子弹打&#xff0c;要么被前面的怪兽炸&#xff0c;显然第二种情况耗费更少。 统计出所有怪兽的 max&#xff08;0&#xff0c;a[ i ] - b[ i - 1 ]&#xff…...

国外互联网巨头常用的项目管理工具揭秘

大型互联网公司有涉及多个团队和利益相关者的复杂项目。为了保持项目的组织性和效率&#xff0c;他们中的许多人依赖于项目管理工具。这些工具有助于跟踪任务&#xff0c;与团队成员沟通&#xff0c;并监控进度。让我们来看看一些大型互联网公司正在使用的项目管理工具。 1、Zo…...

sql 注入(4), 盲注

sql 注入, 盲注 盲注适合在页面没有任何回显时使用. 测试页面有变化, 但是没有显示任何异常错误等信息. 情景: url: http://192.168.112.200/security/read.php?id1 服务器数据库名: learn一, boolean盲注 # 盲注可能需要一个一个字符去试探, 字符串处理函数经常会用到. 比…...

【string题解 C++】字符串相乘 | 翻转字符串III:翻转单词

字符串相乘 题面 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 给定两个以字符串形式表示的非负整数 num1 和 num2&#xff0c;返回 num1 和 num2 的乘积&#xff0c;它们的乘积也表示为字符串形式。 注意&#xff1a;不能使用任何内置的 BigIn…...

CentOS 7下JumpServer安装及配置(超详细版)

前言 Jumpserver是一种用于访问和管理远程设备的Web应用程序&#xff0c;通常用于对服务器进行安全访问。它基于SSH协议&#xff0c;提供了一个安全和可管理的环境来管理SSH访问。Jumpserver是基于Python开发的一款开源工具&#xff0c;其提供了强大的访问控制功能&#xff0c;…...

基于 ACK Fluid 的混合云优化数据访问(五):自动化跨区域中心数据分发

作者&#xff1a;车漾 前文回顾&#xff1a; 本系列将介绍如何基于 ACK Fluid 支持和优化混合云的数据访问场景&#xff0c;相关文章请参考&#xff1a; -基于 ACK Fluid 的混合云优化数据访问&#xff08;一&#xff09;&#xff1a;场景与架构 -基于 ACK Fluid 的混合云优…...

sentinel的启动与运行

首先我们github下载sentinel Releases alibaba/Sentinel (github.com) 下载好了后输入命令让它运行即可&#xff0c;使用cmd窗口输入一下命令即可 java -Dserver.port8089 -jar sentinel-dashboard-1.8.6.jar 账号密码默认都是sentinel 启动成功后登录进去效果如下...

模拟量采集无线WiFi网络接口TCP Server, UDP, MQTT

● 4-20mA信号转换成标准Modbus TCP协议 ● 支持TCP Server, UDP, MQTT等通讯协议 ● 内置网页功能&#xff0c;可以通过网页查询数据 ● 宽电源供电范围&#xff1a;8 ~ 32VDC ● 可靠性高&#xff0c;编程方便&#xff0c;易于应用 ● 标准DIN35导轨安装&#xff0c;方便…...

五、OSPF动态路由实验

拓扑图&#xff1a; 基本ip的配置已经配置好了&#xff0c;接下来对两台路由器配置ospf协议&#xff0c;两台PC进行跨网段通讯 R1与R2构成单区域OSPF区域0&#xff0c;首先对R1进行配置 首先进入ospf 默认进程1&#xff0c;router id省略空缺&#xff0c;之后进入area 0区域&…...

系统架构设计:16 论软件开发过程RUP及其应用

目录 一 统一过程RUP 1 典型特点 2 四个阶段 (1)构思阶段(初始阶段/初启阶段)...

Gralloc ION DMABUF in Camera Display

目录 Background knowledge Introduction ia pa va and memory addressing Memory Addressing Page Frame Management Memory area management DMA IOVA and IOMMU Introduce DMABUF What is DMABUF DMABUF 关键概念 DMABUF APIS –The Exporter DMABUF APIS –The…...

【LVS】lvs的四种模式的区别是什么?

LVS中的DR模式、NAT模式、TUN模式和FANT模式是四种不同的负载均衡模式&#xff0c;它们之间的主要区别在于数据包转发方式和网络地址转换。 DR模式&#xff08;Direct Routing&#xff09;&#xff1a;此模式通过改写请求报文的目标MAC地址&#xff0c;将请求发给真实服务器&a…...