BERT训练环节(代码实现)
1.代码实现
#导包
import torch
from torch import nn
import dltools
#加载数据需要用到的声明变量
batch_size, max_len = 1, 64
#获取训练数据迭代器、词汇表
train_iter, vocab = dltools.load_data_wiki(batch_size, max_len)
#其余都是二维数组
#tokens, segments, valid_lens(一维), pred_position, mlm_weights, mlm, nsp(一维)对应每条数据i中包含的数据
for i in train_iter: #遍历迭代器break #只遍历一条数据
[tensor([[ 3, 25, 0, 4993, 0, 24, 4, 26, 13, 2,158, 20, 5, 73, 1399, 2, 9, 813, 9, 987,45, 26, 52, 46, 53, 158, 2, 5, 3140, 5880,9, 543, 6, 6974, 2, 2, 315, 6, 8, 5,8698, 8, 17229, 9, 308, 2, 4, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1]]),tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),tensor([47.]),tensor([[ 9, 15, 26, 32, 34, 35, 45, 0, 0, 0]]),tensor([[1., 1., 1., 1., 1., 1., 1., 0., 0., 0.]]),tensor([[ 484, 1288, 20, 6, 2808, 9, 18, 0, 0, 0]]),tensor([0])]
#创建BERT网络模型
net = dltools.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128], ffn_num_input=128, ffn_num_hiddens=256, num_heads=2, num_layers=2, dropout=0.2, key_size=128, query_size=128, value_size=128, hid_in_features=128, mlm_in_features=128, nsp_in_features=128)
#调用设备上的GPU
devices = dltools.try_all_gpus()
#损失函数对象
loss = nn.CrossEntropyLoss() #多分类问题,使用交叉熵
#@save #表示用于指示某些代码应该被保存或导出,以便于管理和重用
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y):#前向传播#获取遮蔽词元的预测结果、下一个句子的预测结果_, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X, valid_lens_x.reshape(-1), pred_positions_X)#计算遮蔽语言模型的损失mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1,1)mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8) #MLM损失函数的归一化版本 #加一个很小的数1e-8,防止分母为0,抵消上一行代码乘以的数值#计算下一个句子预测任务的损失nsp_l = loss(nsp_Y_hat, nsp_y)l = mlm_l + nsp_lreturn mlm_l, nsp_l, l
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps): #文本词元样本量太多,全跑完花费的时间太多,若num_steps=1在BERT中表示,跑了1个batch_sizenet = nn.DataParallel(net, device_ids=devices).to(devices[0]) #调用设备的GPUtrainer = torch.optim.Adam(net.parameters(), lr=0.01) #梯度下降的优化算法Adamstep, timer = 0, dltools.Timer() #设置计时器#调用画图工具animator = dltools.Animator(xlabel='step', ylabel='loss', xlim=[1, num_steps], legend=['mlm', 'nsp'])#遮蔽语言模型损失的和, 下一句预测任务损失的和, 句子对的数量, 计数metric = dltools.Accumulator(4) #Accumulator类被设计用来收集和累加各种指标(metric)num_steps_reached = False #设置一个判断标志, 训练步数是否达到预设的步数while step < num_steps and not num_steps_reached:for tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y in train_iter:#将遍历的数据发送到设备上tokens_X = tokens_X.to(devices[0])segments_X = segments_X.to(devices[0])valid_lens_x = valid_lens_x.to(devices[0])pred_positions_X = pred_positions_X.to(devices[0])mlm_weights_X = mlm_weights_X.to(devices[0])mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])#梯度清零trainer.zero_grad()timer.start() #开始计时mlm_l, nsp_l, l = _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)l.backward() #反向传播trainer.step() #梯度更新metric.add(mlm_l, nsp_l, tokens_X.shape[0], l) #累积的参数指标timer.stop() #计时停止animator.add(step + 1, (metric[0] / metric[3], metric[1] / metric[3])) #画图的step += 1 #训练完一个batch_size,就+1if step == num_steps: #若步数与预设的训练步数相等num_steps_reached = True #判断标志改为Truebreak #退出while循环print(f'MLM loss {metric[0] / metric[3]:.3f}, 'f'NSP loss {metric[1] / metric[3]:.3f}')print(f'{metric[2]/ timer.sum():.1f} sentence pairs/sec on 'f'{str(devices)}')
train_bert(train_iter, net, loss, len(vocab), devices, 500)
def get_bert_encoding(net, tokens_a, tokens_b=None):tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0) #unsqueeze(0)增加一个维度segments = torch.tensor(segments, device=devices[0]).unsqueeze(0) valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)endoced_X, _, _ = net(token_ids, segments, valid_len)return endoced_X
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# 词元:'<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
(torch.Size([1, 6, 128]),torch.Size([1, 128]),tensor([-0.5872, -0.0510, -0.7376], device='cuda:0', grad_fn=<SliceBackward0>))
encoded_text_crane
tensor([[-5.8725e-01, -5.0994e-02, -7.3764e-01, -4.3832e-02, 9.2467e-02,1.2745e+00, 2.7062e-01, 6.0271e-01, -5.5055e-02, 7.5122e-02,4.4872e-01, 7.5821e-01, -6.1558e-02, -1.2549e+00, 2.4479e-01,1.3132e+00, -1.0382e+00, -4.7851e-03, -6.3590e-01, -1.3180e+00,5.2245e-02, 5.0982e-01, 7.4168e-02, -2.2352e+00, 7.4425e-02,5.0371e-01, 7.2120e-02, -4.6384e-01, -1.6588e+00, 6.3987e-01,-6.4567e-01, 1.7187e+00, -6.9696e-01, 5.6788e-01, 3.2628e-01,-1.0486e+00, -7.2610e-01, 5.7909e-02, -1.6380e-01, -1.2834e+00,1.6431e+00, -1.5972e+00, -4.5678e-03, 8.8022e-02, 5.5931e-02,-7.2332e-02, -4.9313e-01, -4.2971e+00, 6.9757e-01, 7.0690e-02,-1.8613e+00, 2.0366e-01, 8.9868e-01, -3.4565e-01, 9.6776e-02,1.3699e-02, 7.1410e-01, 5.4820e-01, 9.7358e-01, -8.1038e-01,2.6216e-01, -5.7850e-01, -1.1969e-01, -2.5277e-01, -2.0046e-01,-1.6718e-01, 5.5540e-01, -1.8172e-01, -2.5639e-02, -6.0961e-01,-1.1521e-03, -9.2973e-02, 9.5226e-01, -2.4453e-01, 9.7340e-01,-1.7908e+00, -2.9840e-02, 2.3087e+00, 2.4889e-01, -7.2734e-01,2.1827e+00, -1.1172e+00, -7.0915e-02, 2.5138e+00, -1.0356e+00,-3.7332e-02, -5.6668e-01, 5.2251e-01, -5.0058e-01, 1.7354e+00,4.0760e-01, -1.2982e-01, -7.0230e-01, 3.1563e+00, 1.8754e-01,2.0220e-01, 1.4500e-01, 2.3296e+00, 4.5522e-02, 1.1762e-01,1.0662e+00, -4.0858e+00, 1.6024e-01, 1.7885e+00, -2.7034e-01,-1.6869e-01, -8.7018e-02, -4.2451e-01, 1.1446e-01, -1.5761e+00,7.6947e-02, 2.4336e+00, 4.5346e-02, -6.5078e-02, 1.4203e+00,3.7165e-01, -7.9571e-01, -1.3515e+00, 4.1511e-02, 1.3561e-01,-3.3006e+00, 1.4821e-01, 1.3024e-01, 1.9966e-01, -8.5910e-01,1.4505e+00, 7.6774e-02, 9.3771e-01]], device='cuda:0',grad_fn=<SliceBackward0>)
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# 词元:'<cls>','a','crane','driver','came','<sep>','he','just', 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
(torch.Size([1, 10, 128]),torch.Size([1, 128]),tensor([-0.4637, -0.0569, -0.6119], device='cuda:0', grad_fn=<SliceBackward0>))
相关文章:

BERT训练环节(代码实现)
1.代码实现 #导包 import torch from torch import nn import dltools #加载数据需要用到的声明变量 batch_size, max_len 1, 64 #获取训练数据迭代器、词汇表 train_iter, vocab dltools.load_data_wiki(batch_size, max_len) #其余都是二维数组 #tokens, segments, vali…...

必须执行该语句才能获得结果
UncategorizedSQLException: Error getting generated key or setting result to parameter object. Cause: com.microsoft.sqlserver.jdbc.SQLServerException: 必须执行该语句才能获得结果。 ; uncategorized SQLException; SQL state [null]; error code [0]; 必须执行该语句…...

AI论文写作可靠吗?分享5款论文写作助手ai免费网站
AI论文写作的可靠性是一个备受关注的话题。在当前的技术背景下,AI写作工具能够显著提高论文写作的效率和质量,但其可靠性和安全性仍需谨慎评估。 AI论文写作的可靠性 技术能力与限制 AI论文写作的质量很大程度上取决于用户提供的输入指令或素材的质量…...

AJAX 入门 day3 XMLHttpRequest、Promise对象、自己封装简单版的axios
目录 1.XMLHttpRequest 1.1 XMLHttpRequest认识 1.2 用ajax发送请求 1.3 案例 1.4 XMLHttpRequest - 查询参数 1.5 XMLHttpRequest - 数据提交 2.Promise 2.1 Promise认识 2.2 Promise - 三种状态 2.3 案例 3.封装简易版 axios 3.1 封装_简易axios_获取省份列表 3…...

oracle avg、count、max、min、sum、having、any、all、nvl的用法
组函数 having的使用 any的使用 all的使用 nvl 从执行结果来看,nvl(列名,默认值),nvl的作用就是如果列名所在的这一行出现空则用默认值替换...

Python一分钟:装饰器
一、装饰器基础 函数即对象 在python中函数可以作为参数传递,和任何其它对象一样如:str、int、float、list等等 def say_hello(name):return f"Hello {name}"def be_awesome(name):return f"Yo {name}, together were the awesomest!"def gr…...

Docker部署ddns-go教程(包含完整的配置过程)
本章教程教程,主要介绍如何用Docker部署ddns-go。 一、拉取容器 docker pull jeessy/ddns-go:v6.7.0二、运行容器 docker run -d \--name ddns-go \--restart unless-stopped \...

简单多状态dp第三弹 leetcode -买卖股票的最佳时机问题
309. 买卖股票的最佳时机含冷冻期 买卖股票的最佳时机含冷冻期 分析: 使用动态规划解决 状态表示: 由于有「买入」「可交易」「冷冻期」三个状态,因此我们可以选择用三个数组,其中: ▪ dp[i][0] 表示:第 i 天结束后,…...

游戏化在电子课程中的作用:提高参与度和学习成果
游戏化,即游戏设计元素在非游戏环境中的应用,已成为电子学习领域的强大工具。通过将积分、徽章、排行榜和挑战等游戏机制整合到教育内容中,电子课程可以变得更具吸引力、激励性和有效性。以下是游戏化如何在转变电子学习中发挥重要作用&#…...

php+mysql安装
1.卸载mysql 没启动不停止 2.下载 3.解压 4.点击安装 5.出现成功 端口占用修改 修改端口89或者87 可视化扩展 修改后重启 开启扩展...

音视频入门基础:FLV专题(5)——FFmpeg源码中,判断某文件是否为FLV文件的实现
一、引言 通过FFmpeg命令: ./ffmpeg -i XXX.flv 可以判断出某个文件是否为FLV文件: 所以FFmpeg是怎样判断出某个文件是否为FLV文件呢?它内部其实是通过flv_probe函数来判断的。从《FFmpeg源码:av_probe_input_format3函数和AVI…...

Tomcat 乱码问题彻底解决
1. 终端乱码问题 找到 tomcat 安装目录下的 conf ---> logging.properties .修改ConsoleHandler.endcoding GBK (如果在idea中设置了UTF-8字符集,这里就不需要修改) 2. CMD命令窗口设置编码 参考:WIN10的cmd查看编码方式&am…...

RGB颜色模型
RGB颜色模型是一种广泛应用于数字图像和计算机图形领域的颜色表示方法 一、基本概念 RGB 代表红色(Red)、绿色(Green)和蓝色(Blue)三种基本颜色。这些颜色被视为加色模型中的原色,意味着它们可…...

智能工厂的软件设计 创新型原始制造商(“创新工厂“)的Creator原型(统一行为理论)之2
Q8、今天我们继续昨天开始的 “智能工厂的软件设计”以“统一行为理论”为指导原则的 创新型原始制造商的Creator伪代码--创新工厂“原型”。这是在前述将“程序program”问题的三个体现“方面”(逻辑/语言/数学) 视为符号学的三分支(语用语义…...

【个人博客hexo版】hexo安装时会出现的一些问题
项目场景: 项目场景:在完成了GitHub仓库和git的连接之后,就要新建一个文件夹(例如hexo blog)进行下一步hexo的使用 问题描述 例如:如图所示 原因分析: 这些error不用看它到底是什么…...

道路裂缝,坑洼,病害数据集-包括无人机视角,摩托车视角,车辆视角覆盖道路
道路裂缝,坑洼,病害数据集 包括无人机视角,摩托车视角,车辆视角 覆盖道路所有问题 一共有八类16000张 1到7依次为: [横向裂缝, 纵向裂缝, 块状裂缝, 龟裂, 坑槽, 修补网状裂缝, 修补裂缝, 修补坑槽] 道路病害(如裂缝、…...

java接口文档配置
接口文档配置 一. swagger与knife4j 配置 1. 导入依赖 <!--swagger接口文档说明--> <dependency><groupId>io.springfox</groupId><artifactId>springfox-swagger2</artifactId> </dependency> <dependency><groupId>…...

【服务器第二期】mobaxterm软件下载及连接
【服务器第二期】mobaxterm软件下载及连接 前言什么是SSH什么是FTP/SFTP mobaxterm软件介绍mobaxterm软件下载SSH登录使用方法1-新建ssh连接方法2-打开已有的ssh连接方法3-通过ssh命令建立连接 SFTP数据传输方法1-建立ssh连接后直接拖拽方法2-建立sftp连接再拖拽方法3-直接使用…...

排序-----计数排序(非比较排序)
原理: 存在的问题:数组空间浪费 所以要相对映射,不要绝对映射 calloc()函数的功能是:为num个大小为size的元素开辟一块空间,并且把空间的每个字节初始化为0. // 时间复杂度:O(Nrange) // 只适合整数/适合范围集中 // 空间范围度:…...

[Python]案例驱动最佳入门:Python数据可视化在气候研究中的应用
在全球气候问题日益受到关注的今天,气温变化成为了科学家、政府、公众讨论的热门话题。然而,全球气温究竟是如何变化的?我们能通过数据洞察到哪些趋势?本文将通过真实模拟的气温数据,结合Python数据分析和可视化技术&a…...

PyQt5 导入ui文件报错 AttributeError: type object ‘Qt‘ has no attribute
问题描述: 利用 PyQt5 编写可视化界面是较为普遍的做法,但是使用全新UI版本的 Pycharm 修改之前正常的UI文件时,在没有动其他代码的情况下发现出现以下报错 AttributeError: type object Qt has no attribute Qt::ContextMenuPolicy::Defaul…...

Unity中Rigidbody 刚体组件和Rigidbody类是什么?
Rigidbody 刚体组件 Rigidbody 是 Unity 中的一个组件,它可以让你的游戏对象像真实世界中的物体一样移动和碰撞。想象一下,你有一个小球,你希望它像真实世界中的球一样滚动、弹跳和碰撞,那么你就可以给这个小球添加一个 Rigidbod…...

MySQL学习笔记(持续更新中)
1、Mysql概述 1.1 数据库相关概念 三个概念:数据库、数据库管理系统、SQL 名称全称简称数据库存储数据的仓库,数据是有组织的进行存储DataBase(DB)数据库管理系统操纵和管理数据库的大型软件DataBase Mangement System…...

sqlserver插入数据删除数据
1、插入数据 1.1 直接插入 1.1.1 方式一 insert into test values(001,黎明,1),(002,冯绍峰,1),(003,菲菲,2);1.1.2 方式二 insert into test(ID,Name,Sex) values(004,丽丽,2),(005,凌晨,2),(006,虾米,1);1.2 插入部分行 insert into test(ID,Name) values(007,红)2、删除…...

[51单片机] 简单介绍 (一)
文章目录 1.单片机介绍2.单片机内部三大资源3.单片机最小系统4.STC89C52单片机 1.单片机介绍 兼容Intel的MCS-51体系架构的一系列单片机。 STC89C52:8K FLASH、512字节RAM、32个IO口、3个定时器、1个UART、8个中断源。 单片机简称MCU单片机内部集成了CPU、RAM、RO…...

6个岗位抢1个人,百万年薪抢毕业生?大厂打响AI人才战
前言 “24岁毕业时年薪50万元,到了30岁大概能升到P7(注:职级名称),那时就能年薪百万了。” 从上海交大硕士毕业后,出生于2000年的赵宏在今年入职腾讯,担任AI算法工程师,成为AI风口下第一批就业…...

erlang学习:Linux命令学习3
shell基本输出 创建一个test.sh文件,并开放他的权限,之后向其中编辑以下内容 touch test.sh chmod 777 test.sh vim test.shecho "hello linux"之后运行相应shell程序得到输出 ./test.sh变量 单引号特点: 单引号里的任何字符都…...

力扣41 缺失的第一个正数 Java版本
文章目录 题目描述代码 题目描述 给你一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除…...

第五篇:Linux进程的相关知识总结(1)
目录 第四章:进程 4.1进程管理 4.1.1进程管理需要的学习目标 4.1.1.1了解进程的相关信息 4.1.1.2僵尸进程的概念和处理方法: 4.1.1.3PID、PPID的概念以及特性: 4.1.1.4进程状态 4.1.2进程管理PS 4.1.2.1静态查看进程 4.1.2.1.1自定义…...

企业级Windows server服务器技术(1)
windows server服务器安装 准备工作: 1.准备安装的镜像 2.安装好虚拟机VMware或者virtual box 3.准备安装的位置(选择你的电脑的磁盘上比较空闲的位置,新建一个文件夹并命名) 4.开始安装(按步骤)----…...