当前位置: 首页 > news >正文

深度学习:GPT-1的MindSpore实践

GPT-1简介

GPT-1(Generative Pre-trained Transformer)是2018年由Open AI提出的一个结合预训练和微调的用于解决文本理解和文本生成任务的模型。它的基础是Transformer架构,具有如下创新点:

  • NLP领域的迁移学习:通过最少的任务专项数据,利用预训练模型出色地完成具体的下游任务。
  • 语言建模作为预训练任务:使用无监督学习和大规模的文本语料库来训练模型
  • 为具体任务微调:采用预训练模型来适应监督任务

和BERT类似,GPT-1同样采取pre-train + fine-tune的思路:先基于大量未标注语料数据进行预训练, 后基于少量标注数据进行微调。但GPT-1在预训练任务思路和模型结构上与BERT有所差别。

GPT-1的目标是在预训练的过程中根据现有的所有词元,预测下一个词元。这个任务被称为“自回归语言建模”。

一个简单的例子:

输入序列为:“The sun rises in the”

训练数据的原句子为:“The sun rises in the east”

所以我们的目标输出为:“east”

将输入序列输入GPT模型,GPT根据输入预测下一个词元(“east”)在语料库中的概率分布

正确词元“east”作为一个“伪标签”来帮助模型训练

模型架构

GPT主要使用Transformer Decoder架构,但因为没有Encoder,所以在Transformer Decoder的基础上移除了计算Encoder与Decoder间注意力分数的Multi-Head Attention Layer。

Masked Multi-HeadSelf-Attention

Masked Multi-Head Self-Attention 是Multi-Head Attetion的变种。 最大的不同来自于MMSA的掩码机制,掩码机制防止模型通过观测未来的词元以进行“作弊”。

一个掩码词元<mask>被用于注意力分数矩阵,所以当前词元只能注意到序列中自己和自己之前的词元。未来的次元的注意力分数将被设为0以确保其在Softmax步骤后的实际贡献为0。

为什么掩码机制非常重要?

对于自回归任务,模型必须线性地生成词元,不能基于未来的信息预测下一个词元。

损失函数

GPT使用Cross-Entropy Loss作为损失函数:\mathcal{L} = - \sum_{t=1}^N \log P(w_t | w_1, w_2, \dots, w_{t-1})

交叉熵损失是这项任务的理想选择,因为它通过测量预测的概率分布与真实分布的距离来惩罚不正确的预测。它自然适于处理多类分类任务,其中模型从大量词汇表中选择一个标记。

模型输入

GPT-1的输入同样为句子或句子对,并添加Special Tokens。

  • [BOS]:表示句子的开始,(论文中给出的token表示为[START]),添加到序列最前;
  • [EOS]:表示序列的结束,(论文中的给出的[EXTRACT]),添加到序列最后,在进行分类任务时,会将 该special token对应的输出接入输出层;我们也可以理解为该token可以学习到整个句子的语义信息;
  • [SEP]:用于间隔句子对中的两个句子;
GPT Embedding 同样分为三类:token Embedding、Position Embedding、Segment Embedding

 

GPT-1模型具体参数

模型架构

  • 12个Transformer Decoder Block
  • hidden_size为768(模型输入和输出的向量纬度)
  • 注意力头数为12
  • FFN维度为3072
  • 词表(Vocab)大小为40000
  • 序列长度为512(上下文窗口长度)

训练过程

  • Adam优化器,超参数为:0.9, 0.99
  • 学习率:最大学习率:2.5x10e-4 使用2000步作为热身,随后线性衰退
  • 批大小:64
  • 梯度剪裁:1.0
  • Dropout率:0.1

训练过程

100000步,大约花费8张NVIDIA V100 GPU训练30天,共有117M参数。使用Xavier初始化,权重衰退为0.01。 

下游任务 

GPT按照生成式的逻辑统一了下游任务的应用模板,使用最后一个token([EOS]or[EXTRACT])对应的hidden state,输出到额外的输出层中,进行分类标签预测。
任务包括:文本分类(情感分类、新闻分类)、文本蕴含(根据前提推出假设)、文本语义相似度、多类选择(在多个next token中进行选择)

基于MindSpore微调GPT-1进行情感分类

# #安装mindnlp 0.4.0套件
# !pip install mindnlp
# !pip uninstall soundfile -y
# !pip install download
# !pip install jieba
# !pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.3.1/MindSpore/unified/aarch64/mindspore-2.3.1-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simpleimport osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp.engine import Trainer# loading dataset
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']imdb_train.get_dataset_size()import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']if shuffle:dataset = dataset.shuffle(batch_size)# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return datasetfrom mindnlp.transformers import OpenAIGPTTokenizer
# tokenizer
gpt_tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')# add sepcial token: <PAD>
special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)#为方便体验流程,把原本数据集的十分之一拿出来体验训练和评估,
imdb_train, _ = imdb_train.split([0.1, 0.9], randomize=False)# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)# load GPT sequence classification model and set class=2
from mindnlp.transformers import OpenAIGPTForSequenceClassification  # Import the GPT model for sequence classification
from mindnlp import evaluate  # Import the evaluation module from MindNLP
import numpy as np  # Import NumPy for numerical operations# Set up the GPT model for sequence classification with 2 output labels (binary classification).
model = OpenAIGPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)# Set the padding token ID in the model configuration to match the tokenizer's padding token ID.
model.config.pad_token_id = gpt_tokenizer.pad_token_id# Resize the token embedding layer to account for any added tokens (e.g., special tokens).
model.resize_token_embeddings(model.config.vocab_size + 3)from mindnlp.engine import TrainingArguments  # Import training arguments for model training configuration.# Define training arguments.
training_args = TrainingArguments(output_dir="gpt_imdb_finetune",  # Directory to save model checkpoints and outputs.evaluation_strategy="epoch",  # Evaluate the model at the end of each epoch.save_strategy="epoch",  # Save model checkpoints at the end of each epoch.logging_strategy="epoch",  # Log metrics and progress at the end of each epoch.load_best_model_at_end=True,  # Automatically load the best model (based on evaluation metrics) at the end of training.num_train_epochs=1.0,  # Number of training epochs (default is 1 for quick experimentation).learning_rate=2e-5  # Learning rate for the optimizer.
)# Load the accuracy metric for evaluation.
metric = evaluate.load("accuracy")# Define a function to compute metrics during evaluation.
def compute_metrics(eval_pred):logits, labels = eval_pred  # Unpack predictions (logits) and true labels.predictions = np.argmax(logits, axis=-1)  # Convert logits to class predictions using argmax.return metric.compute(predictions=predictions, references=labels)  # Compute accuracy metric.# Initialize the Trainer class with the model, training arguments, datasets, and metric computation function.
trainer = Trainer(model=model,  # The GPT model to be fine-tuned.args=training_args,  # Training configuration arguments.train_dataset=dataset_train,  # Training dataset (must be preprocessed and tokenized).eval_dataset=dataset_val,  # Validation dataset for evaluation.compute_metrics=compute_metrics  # Metric computation function for evaluation.
)# start training
trainer.train()trainer.evaluate(dataset_test)

相关文章:

深度学习:GPT-1的MindSpore实践

GPT-1简介 GPT-1&#xff08;Generative Pre-trained Transformer&#xff09;是2018年由Open AI提出的一个结合预训练和微调的用于解决文本理解和文本生成任务的模型。它的基础是Transformer架构&#xff0c;具有如下创新点&#xff1a; NLP领域的迁移学习&#xff1a;通过最…...

前端图像处理(一)

目录 一、上传 1.1、图片转base64 二、图片样式 2.1、图片边框【border-image】 三、Canvas 3.1、把canvas图片上传到服务器 3.2、在canvas中绘制和拖动矩形 3.3、图片(同色区域)点击变色 一、上传 1.1、图片转base64 传统上传&#xff1a; 客户端选择图片&#xf…...

unity中:超低入门级显卡、集显(功耗30W以下)运行unity URP管线输出的webgl程序有那些地方可以大幅优化帧率

删除Global Volume&#xff1a; 删除Global Volume是一项简单且高效的优化措施。实测表明&#xff0c;这一改动可以显著提升帧率&#xff0c;甚至能够将原本无法流畅运行的场景变得可用。 更改前的效果&#xff1a; 更改后的效果&#xff1a; 优化阴影和材质&#xff1a; …...

ftdi_sio应用学习笔记 4 - I2C

目录 1. 查找设备 2. 打开设备 3. 写数据 4. 读数据 5. 设置频率 6 验证 6.1 遍历设备 6.2 开关设备 6.3 读写测试 I2C设备最多有6个&#xff08;FT232H&#xff09;&#xff0c;其他为2个。和之前的设备一样&#xff0c;定义个I2C结构体记录找到的设备。 #define FT…...

如何更好的把控软件测试质量

如何更好的把控软件测试质量 在软件开发过程中&#xff0c;测试是确保软件质量、稳定性和用户体验的重要环节。随着需求的不断变化以及技术的不断进步&#xff0c;如何更好的把控软件测试质量已成为一个不可忽视的话题。本文将从几个维度探讨确保软件质量的方法和方案&#xf…...

“漫步北京”小程序及“气象景观数字化服务平台”上线啦

随着科技的飞速发展&#xff0c;智慧旅游已成为现代旅游业的重要趋势。近日&#xff0c;北京万云科技有限公司联合北京市气象服务中心&#xff0c;打造的“气象景观数字化服务平台“和“漫步北京“小程序已经上线&#xff0c;作为智慧旅游的典型代表&#xff0c;以其丰富的功能…...

SOL链上的 Meme 生态发展:从文化到创新的融合#dapp开发#

一、引言 随着区块链技术的不断发展&#xff0c;Meme 文化在去中心化领域逐渐崭露头角。从 Dogecoin 到 Shiba Inu&#xff0c;再到更多细分的 Meme 项目&#xff0c;这类基于网络文化的加密货币因其幽默和社区驱动力吸引了广泛关注。作为近年来备受瞩目的区块链平台之一&…...

身份证实名认证API接口助力电商购物安全

亲爱的网购达人们&#xff0c;你们是否曾经因为网络上的虚假信息和诈骗而感到困扰&#xff1f;在享受便捷的网购乐趣时&#xff0c;如何确保交易安全成为了我们共同关注的话题。今天&#xff0c;一起来了解一下翔云身份证实名认证接口如何为电子商务保驾护航&#xff0c;让您的…...

【过程控制系统】第6章 串级控制系统

目录 6. l 串级控制系统的概念 6.1.2 串级控制系统的组成 6.l.3 串级控制系统的工作过程 6.2 串级控制系统的分析 6.2.1 增强系统的抗干扰能力 6.2.2 改善对象的动态特性 6.2.3 对负荷变化有一定的自适应能力 6.3 串级控制系统的设计 6.3.1 副回路的选择 2.串级系…...

YOLOv11融合针对小目标FFCA-YOPLO中的FEM模块及相关改进思路

YOLOv11v10v8使用教程&#xff1a; YOLOv11入门到入土使用教程 YOLOv11改进汇总贴&#xff1a;YOLOv11及自研模型更新汇总 《FFCA-YOLO for Small Object Detection in Remote Sensing Images》 一、 模块介绍 论文链接&#xff1a;https://ieeexplore.ieee.org/document/10…...

qt+opengl 三维物体加入摄像机

1 在前几期的文章中&#xff0c;我们已经实现了三维正方体的显示了&#xff0c;那我们来实现让物体的由远及近&#xff0c;和由近及远。这里我们需要了解一个概念摄像机。 1.1 摄像机定义&#xff1a;在世界空间中位置、观察方向、指向右侧向量、指向上方的向量。如下图所示: …...

day05(单片机高级)PCB基础

目录 PCB基础 什么是PCB&#xff1f;PCB的作用&#xff1f; PCB的制作过程 PCB板的层数 PCB设计软件 安装立创EDA PCB基础 什么是PCB&#xff1f;PCB的作用&#xff1f; PCB&#xff08;Printed Circuit Board&#xff09;&#xff0c;中文名称为印制电路板&#xff0c;又称印刷…...

全球天气预报5天-经纬度版免费API接口教程

接口简介&#xff1a; 获取全球任意地区未来5天天气预报&#xff0c;必须传经纬度参数。可先调用【位置坐标】分类下相关接口获取地区经纬度坐标。 请求地址&#xff1a; https://cn.apihz.cn/api/tianqi/tqybjw5.php 请求方式&#xff1a; POST或GET。 请求参数&#xff1a…...

Shell编程8

声明&#xff01; 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下&#xff0c;如涉及侵权马上删除文章&#xff0c;笔记只是方便各位师傅的学习和探讨&#xff0c;文章所提到的网站以及内容&#xff0c;只做学习交流&#xff0c;其他均与本人以及泷羽sec团队无关&a…...

python语言基础-5 进阶语法-5.5 上下文管理协议(with语句)

声明&#xff1a;本内容非盈利性质&#xff0c;也不支持任何组织或个人将其用作盈利用途。本内容来源于参考书或网站&#xff0c;会尽量附上原文链接&#xff0c;并鼓励大家看原文。侵删。 5.5 上下文管理协议&#xff08;with语句&#xff09;&#xff08;参考链接&#xff1…...

自动驾驶3D目标检测综述(三)

前两篇综述阅读理解放在这啦&#xff0c;有需要自行前往观看&#xff1a; 第一篇&#xff1a;自动驾驶3D目标检测综述&#xff08;一&#xff09;_3d 目标检测-CSDN博客 第二篇&#xff1a;自动驾驶3D目标检测综述&#xff08;二&#xff09;_子流行稀疏卷积 gpu实现-CSDN博客…...

【GESP】C++三级练习 luogu-B3661, [语言月赛202209] 排排

三级知识点一维数组练习&#xff0c;除了应用了数组以外&#xff0c;其余逻辑比较简单&#xff0c;适合初学者。 题目题解详见&#xff1a;https://www.coderli.com/gesp-3-luogu-b3661/ 【GESP】C三级练习 luogu-B3661, [语言月赛202209] 排排队 | OneCoder三级知识点一维数…...

【PPTist】添加PPT模版

前言&#xff1a;这篇文章来探索一下如何应用其他的PPT模版&#xff0c;给一个下拉菜单&#xff0c;列出几个项目中内置的模版 PPT模版数据 &#xff08;一&#xff09;增加菜单项 首先在下面这个菜单中增加一个“切换模版”的菜单项&#xff0c;点击之后在弹出框中显示所有的…...

大疆上云api开发

目前很多公司希望使用上云api开发自己的无人机平台,但是官网资料不是特别全,下面浅谈一下本人开发过程中遇到的一系列问题。 本人使用机场为大疆机场2&#xff0c;飞机为M3TD&#xff0c;纯内网使用 部署 链接: 上云api代码. 首先从github上面拉去代码 上云api代码github. 后…...

IDEA2023 SpringBoot整合MyBatis(三)

一、数据库表 CREATE TABLE students (id INT AUTO_INCREMENT PRIMARY KEY,name VARCHAR(100) NOT NULL,age INT,gender ENUM(Male, Female, Other),email VARCHAR(100) UNIQUE,phone_number VARCHAR(20),address VARCHAR(255),date_of_birth DATE,enrollment_date DATE,cours…...

【Apache Paimon】-- 6 -- 清理过期数据

目录 1、简要介绍 2、操作方式和步骤 2.1、调整快照文件过期时间 2.2、设置分区过期时间 2.2.1、举例1 2.2.2、举例2 2.3、清理废弃文件 3、参考 1、简要介绍 清理 paimon (表)过期数据可以释放存储空间,优化资源利用并提升系统运行效率等。本文将介绍如何清理 Paim…...

C语言数据结构——详细讲解 双链表

从单链表到双链表&#xff1a;数据结构的演进与优化 前言一、单链表回顾二、单链表的局限性三、什么是双链表四、双链表的优势1.双向遍历2.不带头双链表的用途3.带头双链表的用途 五、双链表的操作双链表的插入操作&#xff08;一&#xff09;双链表的尾插操作&#xff08;二&a…...

Shell脚本基础(4):条件判断

内容预览 ≧∀≦ゞ Shell脚本基础&#xff08;4&#xff09;&#xff1a;条件判断声明导语基本的if语句结构数值比较运算符文件测试运算符扩展&#xff1a;使用elif和else使用&&和||结合条件判断小结 Shell脚本基础&#xff08;4&#xff09;&#xff1a;条件判断 声明…...

在 Swift 中实现字符串分割问题:以字典中的单词构造句子

文章目录 前言摘要描述题解答案题解代码题解代码分析示例测试及结果时间复杂度空间复杂度总结 前言 本题由于没有合适答案为以往遗留问题&#xff0c;最近有时间将以往遗留问题一一完善。 LeetCode - #140 单词拆分 II 不积跬步&#xff0c;无以至千里&#xff1b;不积小流&…...

win10中使用ffmpeg和MediaMTX 推流rtsp视频

在win10上测试下ffmpeg推流rtsp视频&#xff0c;需要同时用到流媒体服务器MediaMTX 。ffmpeg推流到流媒体服务器MediaMTX &#xff0c;其他客户端从流媒体服务器拉流。 步骤如下&#xff1a; 1 下载MediaMTX github: Release v1.9.3 bluenviron/mediamtx GitHub​​​​​…...

16. 【.NET 8 实战--孢子记账--从单体到微服务】--汇率获取定时器

这篇文章我们将一起编写这个系列专栏中第一个和外部系统交互的功能&#xff1a;获取每日汇率。下面我们一起来编写代码吧。 一、需求 根据文章标题可知&#xff0c;在这片文章中我们只进行汇率的获取和写入数据库。 编号需求说明1获取每日汇率1. 从第三方汇率API中获取汇率信…...

C#元组详解:创建、访问与解构

在C#中&#xff0c;元组&#xff08;Tuple&#xff09;是一种数据结构&#xff0c;用于将多个元素组合成一个单一的对象。元组可以包含不同类型的元素&#xff0c;并且每个元素都有一个指定的位置&#xff08;索引&#xff09;。元组在需要临时组合多个值而不想创建自定义类时非…...

wsl2安装

Windows Subsystem for Linux 2 (WSL2) 是 Windows 10 和 Windows 11 中用于运行 Linux 二进制可执行文件的兼容层。WSL2 是 WSL 的最新版本&#xff0c;提供了更快的文件系统性能和完整的系统调用兼容性。本教程将指导你如何在 Windows 系统上安装 WSL2。 前提条件 操作系统要…...

android studio无法下载,Could not GET xxx, Received status code 400

-- 1. 使用下面的地址代替 原地址: distributionUrlhttps\://services.gradle.org/distributions/gradle-6.5-all.zip 镜像地址: distributionUrlhttps\://downloads.gradle-dn.com/distributions/gradle-6.5-all.zips 上面的已经不好用了 https\://mirrors.cloud.tencent.c…...

RUST学习教程-安装教程

文章目录 参考文档安装教程更新卸载 参考文档 https://course.rs/first-try/installation.html 安装教程 Linux或者mac安装教程 curl --proto https --tlsv1.2 https://sh.rustup.rs -sSf | sh安装完成&#xff0c;当出现command not found的时候&#xff0c;需要source一下…...