当前位置: 首页 > news >正文

【动手学深度学习-Pytorch版】BERT预测系列——BERTModel

本小节主要实现了以下几部分内容:

  • 从一个句子中提取BERT输入序列以及相对的segments段落索引(因为BERT支持输入两个句子)
  • BERT使用的是Transformer的Encoder部分,所以需要需要使用Encoder进行前向传播:输出的特征等于词嵌入+位置编码+Encoder块
  • 用于BERT预训练时预测的掩蔽语言模型任务中的掩蔽标记< mask >
  • 用于预训练任务的下一个句子的预测——在为预训练生成句子对时,有一半的时间它们确实是标签为“真”的连续句子;在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。
  • 通过BERTModel整合代码

"""可学习的位置编码也需要进行初始化"""
import torch
import d2l.torch
from torch import nn
import transformers
"""将一个句子或者两个句子作为输入,然后返回BERT输入序列及其相应的序列对的片段索引segments"""
def get_tokens_segments(tokens_a,tokens_b=None):"""获取输入序列的词元及其片段索引"""tokens = ['<cls>'] + tokens_a + ['<sep>']# 利用0和1分别标记片段A和片段Bsegments = [0] * (len(tokens_a)+2)  #加上<cls>和sepif tokens_b is not None:# 如果是句子对tokens += tokens_b+['<sep>']segments += [1]*(len(tokens_b)+1)  # 加上<sep>return tokens,segments"""在原始的Transformer架构中,编码器的位置嵌入信息是直接加到了输入序列的每个位置,但是BERT使用的是可学习的位置嵌入"""
"""bert-input = tokens_embedding + position_embedding + segment_embedding"""
class BERTEncoder(nn.Module):"""BERT编码器"""def __init__(self,vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout,max_len=1000,key_size=768,query_size=768,value_size=768,use_bias=True):super(BERTEncoder, self).__init__()self.token_embedding = nn.Embedding(vocab_size,num_hiddens)self.segment_embedding = nn.Embedding(2,num_hiddens)# 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入的参数self.pos_embedding = nn.Parameter(torch.randn(size=(1,max_len,num_hiddens)))# print('self.pos_embedding:',self.pos_embedding)"""self.pos_embedding.data : [1,1000,768]在下面与X相加时利用的是广播机制"""self.blks = nn.Sequential()for i in range(num_layers):self.blks.add_module(f'{i}',d2l.torch.EncoderBlock(key_size,query_size,value_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,dropout,use_bias))def forward(self,tokens,segments,valid_lens):# 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)X = self.token_embedding(tokens)+self.segment_embedding(segments)print('X.shape:',X.shape)   # [2,8,768]X += self.pos_embedding.data[:,:X.shape[1],:]  #[2,8,768]for blk in self.blks:X = blk(X,valid_lens)return X
"""演示BERTEncoder的前向传播--->词表大小:10000"""
vocab_size,num_hiddens,ffn_num_input,ffn_num_hiddens,num_heads,num_layers = 1000,768,768,1024,4,2
norm_shape,dropout = [768],0.2
encoder = BERTEncoder(vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout)
"""将tokens定义为长度为8的2个输入序列"""
tokens = torch.randint(0,vocab_size,(2,8))
print('tokens:',tokens)
print('tokens_shape:',tokens.shape)
"""其中每个词元由向量表示,其长度由超参数num_hiddens定义,此超参数通常称为Transformer编码器的隐藏大小(隐藏单元数)"""
segments = torch.tensor([[0,0,0,0,1,1,1,1],[0,0,0,1,1,1,1,1]])
print('segments:',segments)
enc_outputs = encoder(tokens,segments,None)
print('enc_outputs.shape',enc_outputs.shape)# 预训练任务---》双向编码上下文:掩蔽语言模型
"""预测BERT预训练的掩蔽语言模型任务中的掩蔽标记"""
#@save
class MaskLM(nn.Module):"""BERT的掩蔽语言模型任务"""def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):super(MaskLM, self).__init__(**kwargs)# 两层的MLP,同时使用激活函数ReLU  和 层归一化self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),nn.ReLU(),nn.LayerNorm(num_hiddens),nn.Linear(num_hiddens, vocab_size))# 前向传播时的输入信息包括:# 1 BERTEncoder编码结果# 2 用于预测词元的位置def forward(self, X, pred_positions):num_pred_positions = pred_positions.shape[1]# 将预测的位置压缩成一维向量空间pred_positions = pred_positions.reshape(-1)# BERTEncoder的输出特征形状:[batch_size,...]batch_size = X.shape[0]batch_idx = torch.arange(0, batch_size)# 假设batch_size=2,num_pred_positions=3# 那么batch_idx是np.array([0,0,0,1,1,1])# torch.repeat_interleave用于重复张量元素batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)print('输入的X形状:',X.shape)# batch_idx# pred_positions# 都是两个list其中batch_idx选择的是屏蔽的行# pred_positions选择的是屏蔽的列masked_X = X[batch_idx, pred_positions]print('masked后X的内容:',masked_X)# 最后把所有要屏蔽的数据拉成一个一维的向量masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))mlm_Y_hat = self.mlp(masked_X)# 最后返回的是利用MLP预测这些位置的结果return mlm_Y_hat"""将mlm_positions定义为在encoded_X的任一输如系列中预测3个值"""
"""而且对于每一个预测的结果都等于词表的大小"""
mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(enc_outputs, mlm_positions)
mlm_Y_hat_shape = mlm_Y_hat.shape
print('mlm_Y_hat_shape:',mlm_Y_hat_shape)# 通过掩码下的预测词元mlm_Y的真实标签mlm_Y_hat,我们可以计算在BERT预训练中的遮蔽语言模型任务的交叉熵损失
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l_shape = mlm_l.shape
print('mlm_l_shape:',mlm_l_shape)# 预训练任务---》下一个句子的预测
"""在为预训练生成句子对时,有一半的时间它们确实是标签为“真”的连续句子;在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。
"""
#@save
class NextSentencePred(nn.Module):"""BERT的下一句预测任务"""def __init__(self, num_inputs, **kwargs):super(NextSentencePred, self).__init__(**kwargs)self.output = nn.Linear(num_inputs, 2)def forward(self, X):# X的形状:(batchsize,num_hiddens)return self.output(X)
"""NextSentencePred实例的前向推断返回每个BERT输入序列的二分类预测"""
enc_outputs = torch.flatten(enc_outputs, start_dim=1)
# NSP的输入形状:(batchsize,num_hiddens)
nsp = NextSentencePred(enc_outputs.shape[-1])
nsp_Y_hat = nsp(enc_outputs)
print('nsp_Y_hat.shape',nsp_Y_hat.shape)
# 计算两个二元分类的交叉熵损失
nsp_y = torch.tensor([0, 1])
nsp_l = loss(nsp_Y_hat, nsp_y)
nsp_l_shape = nsp_l.shape
print('nsp_l_shape:',nsp_l_shape)#@save
class BERTModel(nn.Module):"""BERT模型"""def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,ffn_num_hiddens, num_heads, num_layers, dropout,max_len=1000, key_size=768, query_size=768, value_size=768,hid_in_features=768, mlm_in_features=768,nsp_in_features=768):super(BERTModel, self).__init__()self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,ffn_num_input, ffn_num_hiddens, num_heads, num_layers,dropout, max_len=max_len, key_size=key_size,query_size=query_size, value_size=value_size)self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),nn.Tanh())self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)self.nsp = NextSentencePred(nsp_in_features)def forward(self, tokens, segments, valid_lens=None,pred_positions=None):encoded_X = self.encoder(tokens, segments, valid_lens)if pred_positions is not None:mlm_Y_hat = self.mlm(encoded_X, pred_positions)else:mlm_Y_hat = None# 用于下一句预测的多层感知机分类器的隐藏层,0是“<cls>”标记的索引nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat

相关文章:

【动手学深度学习-Pytorch版】BERT预测系列——BERTModel

本小节主要实现了以下几部分内容&#xff1a; 从一个句子中提取BERT输入序列以及相对的segments段落索引&#xff08;因为BERT支持输入两个句子&#xff09;BERT使用的是Transformer的Encoder部分&#xff0c;所以需要需要使用Encoder进行前向传播&#xff1a;输出的特征等于词…...

Python之元组、字典和集合练习

1、餐厅下午茶 &#xff08;列表与元组 crr66&#xff09; 某餐厅推出了优惠下午茶套餐活动。顾客可以以优惠的价格从给定的糕点和给定的饮 料中各选一款组成套餐。已知&#xff0c;指定的糕点包括松饼(Muffins)、提拉米苏(Tiramisu)、芝士蛋 糕(Cheese Cake)和三明治(Sandwic…...

【数据结构】归并排序和计数排序(排序的总结)

目录 一&#xff0c;归并排序的递归 二&#xff0c;归并排序的非递归 三&#xff0c;计数排序 四&#xff0c;排序算法的综合分析 一&#xff0c;归并排序的递归 基本思想&#xff1a; 归并采用的是分治思想&#xff0c;是分治法的一个经典的运用。该算法先将原数据进行拆…...

某医疗机构:建立S-SDLC安全开发流程,保障医疗前沿科技应用高质量发展

某医疗机构是头部资本集团旗下专注大健康领域战略性投资与运营的实业公司&#xff0c;市场规模超300亿。该医疗机构已完成数字赋能&#xff0c;形成了标准化、专业化、数字化的疾病和健康管理体系&#xff0c;将进一步规划战略方向&#xff0c;为人工智能纳米技术、高温超导、生…...

验证二叉搜索树的后序遍历序列

LCR 152. 验证二叉搜索树的后序遍历序列 class VerifyTreeOrder:"""LCR 152. 验证二叉搜索树的后序遍历序列https://leetcode.cn/problems/er-cha-sou-suo-shu-de-hou-xu-bian-li-xu-lie-lcof/description/"""def solution(self, postorder: Lis…...

第三章 内存管理 一、内存的基础知识

目录 一、什么是内存 二、有何作用 三、常用数量单位 四、指令的工作原理 五、装入方式 1、绝对装入 2、可重定位装入&#xff08;静态重定位&#xff09; 3、动态运行时装入&#xff08;动态重定位&#xff09; 六、从写程序到程序运行 七、链接的三种方式 1、静态…...

【Java学习之道】Java常用集合框架

引言 在Java中&#xff0c;集合框架是一个非常重要的概念。它提供了一种方式&#xff0c;让你可以方便地存储和操作数据。Java中的集合框架包括各种集合类和接口&#xff0c;这些类和接口提供了不同的功能和特性。通过学习和掌握Java的集合框架&#xff0c;你可以更好地管理和…...

logicFlow 流程图编辑工具使用及开源地址

一、工具介绍 LogicFlow 是一款流程图编辑框架&#xff0c;提供了一系列流程图交互、编辑所必需的功能和灵活的节点自定义、插件等拓展机制。LogicFlow 支持前端研发自定义开发各种逻辑编排场景&#xff0c;如流程图、ER 图、BPMN 流程等。在工作审批配置、机器人逻辑编排、无…...

ATF(TF-A)/OPTEE之动态代码分析汇总

安全之安全(security)博客目录导读 1、ASAN(AddressSanitizer)地址消毒动态代码分析 2、ATF(TF-A)之UBSAN动态代码分析 3、OPTEE之KASAN地址消毒动态代码分析...

10-11 周三 shell xargs tr curl 做大事情

最近发现&#xff0c;shell的小工具非常的强大&#xff0c;简单记录下 tr命令 -d 删除字符串1中所有输入字符。-s 删除所有重复出现字符序列&#xff0c;只保留第一个&#xff1b;即将重复出现字符串压缩为一个字符串 -d 用于删除查询到的字符串中的空格。 [test3NH-DC-NM1…...

1.1 向量与线性组合

一、向量的基础知识 两个独立的数字 v 1 v_1 v1​ 和 v 2 v_2 v2​&#xff0c;将它们配对可以产生一个二维向量 v \boldsymbol{v} v&#xff1a; 列向量 v v [ v 1 v 2 ] v 1 v 的第一个分量 v 2 v 的第二个分量 \textbf{列向量}\,\boldsymbol v\kern 10pt\boldsymbol …...

django: You may need to add ‘localhost‘ to ALLOWED_HOSTS

参考:https://blog.csdn.net/qq_21744873/article/details/87857279 python manage.py runserver后页面访问失败&#xff0c;提示&#xff1a; DisallowedHost at /admin/ Invalid HTTP_HOST header: ‘localhost:8000’. You may need to add ‘localhost’ to ALLOWED_HOSTS…...

网络安全(黑客技术)—自学手册

1.网络安全是什么 网络安全可以基于攻击和防御视角来分类&#xff0c;我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术&#xff0c;而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 2.网络安全市场 一、是市场需求量高&#xff1b; 二、则是发展相对成熟…...

【Vue】之Vuex的入门使用,取值,修改值,同异步请求处理---保姆级别教学

一&#xff0c;Vuex入门 1.1 什么是Vuex Vuex是一个专门为Vue.js应用程序开发的状态管理库。它用于管理应用程序中的共享状态&#xff0c;它采用集中式存储管理应用的所有组件的状态&#xff0c;使得状态的管理变得简单和可预测 官方解释&#xff1a;Vuex 是一个专为 Vue.js 应…...

ubuntu20.04 nerf Instant-ngp (下) 复现,自建数据集,导出mesh

参考链接 Ubuntu20.04复现instant-ngp&#xff0c;自建数据集&#xff0c;导出mesh_XINYU W的博客-CSDN博客 GitHub - NVlabs/instant-ngp: Instant neural graphics primitives: lightning fast NeRF and more youtube上的一个博主自建数据集 https://www.youtube.com/watch…...

【常见错误】SVN提交项目时,出现了这样的提示:“XXX“ is scheduled for addition, but is missing。

SVN提交项目时&#xff0c;出现了这样的提示&#xff1a;“XXX“ is scheduled for addition, but is missing。 原因是&#xff1a;之前用SVN提交过的文件/文件夹&#xff0c;被标记为"addition"状态&#xff0c;等待被加入到仓库。虽然你把这个文件删除了&#xf…...

深度学习基础知识 给模型的不同层 设置不同学习率

深度学习基础知识 给模型的不同层 设置不同学习率 1、使用预训练模型时&#xff0c;可能需要将2、学习率设置方式&#xff1a; 1、使用预训练模型时&#xff0c;可能需要将 &#xff08;1&#xff09;预训练好的 backbone 的 参数学习率设置为较小值&#xff0c; &#xff08;2…...

【Python 零基础入门】 Numpy

【Python 零基础入门】第六课 Numpy 概述什么是 Numpy?Numpy 与 Python 数组的区别并发 vs 并行单线程 vs 多线程GILNumpy 在数据科学中的重要性 Numpy 安装Anaconda导包 ndarraynp.array 创建数组属性np.zeros 创建np.ones 创建 数组的切片和索引基本索引切片操作数组运算 常…...

1600*C. Circle of Monsters(贪心)

Problem - 1334C - Codeforces 解析&#xff1a; 对于某个怪兽&#xff0c;他的耗费为两种情况&#xff0c;要么直接用子弹打&#xff0c;要么被前面的怪兽炸&#xff0c;显然第二种情况耗费更少。 统计出所有怪兽的 max&#xff08;0&#xff0c;a[ i ] - b[ i - 1 ]&#xff…...

国外互联网巨头常用的项目管理工具揭秘

大型互联网公司有涉及多个团队和利益相关者的复杂项目。为了保持项目的组织性和效率&#xff0c;他们中的许多人依赖于项目管理工具。这些工具有助于跟踪任务&#xff0c;与团队成员沟通&#xff0c;并监控进度。让我们来看看一些大型互联网公司正在使用的项目管理工具。 1、Zo…...

MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例

一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及&#xff0c;充电桩作为核心配套设施&#xff0c;其安全性与可靠性备受关注。然而&#xff0c;在高温、高负荷运行环境下&#xff0c;充电桩的散热问题与消防安全隐患日益凸显&#xff0c;成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

【Zephyr 系列 10】实战项目:打造一个蓝牙传感器终端 + 网关系统(完整架构与全栈实现)

🧠关键词:Zephyr、BLE、终端、网关、广播、连接、传感器、数据采集、低功耗、系统集成 📌目标读者:希望基于 Zephyr 构建 BLE 系统架构、实现终端与网关协作、具备产品交付能力的开发者 📊篇幅字数:约 5200 字 ✨ 项目总览 在物联网实际项目中,**“终端 + 网关”**是…...

12.找到字符串中所有字母异位词

&#x1f9e0; 题目解析 题目描述&#xff1a; 给定两个字符串 s 和 p&#xff0c;找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义&#xff1a; 若两个字符串包含的字符种类和出现次数完全相同&#xff0c;顺序无所谓&#xff0c;则互为…...

基于SpringBoot在线拍卖系统的设计和实现

摘 要 随着社会的发展&#xff0c;社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 在线拍卖系统&#xff0c;主要的模块包括管理员&#xff1b;首页、个人中心、用户管理、商品类型管理、拍卖商品管理、历史竞拍管理、竞拍订单…...

【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看

文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...

破解路内监管盲区:免布线低位视频桩重塑停车管理新标准

城市路内停车管理常因行道树遮挡、高位设备盲区等问题&#xff0c;导致车牌识别率低、逃费率高&#xff0c;传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法&#xff0c;正成为破局关键。该设备安装于车位侧方0.5-0.7米高度&#xff0c;直接规避树枝遮…...

深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏

一、引言 在深度学习中&#xff0c;我们训练出的神经网络往往非常庞大&#xff08;比如像 ResNet、YOLOv8、Vision Transformer&#xff09;&#xff0c;虽然精度很高&#xff0c;但“太重”了&#xff0c;运行起来很慢&#xff0c;占用内存大&#xff0c;不适合部署到手机、摄…...

华为OD最新机试真题-数组组成的最小数字-OD统一考试(B卷)

题目描述 给定一个整型数组,请从该数组中选择3个元素 组成最小数字并输出 (如果数组长度小于3,则选择数组中所有元素来组成最小数字)。 输入描述 行用半角逗号分割的字符串记录的整型数组,0<数组长度<= 100,0<整数的取值范围<= 10000。 输出描述 由3个元素组成…...