昇思MindSpore学习总结十六 —— 基于MindSpore的GPT2文本摘要
1、mindnlp 版本要求
!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
2、数据集加载与处理
2.1 数据集加载
本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。
from mindspore.dataset import TextFileDataset # 从mindspore.dataset模块中导入TextFileDataset类# load dataset # 加载数据集
dataset = TextFileDataset(str(path), shuffle=False) # 创建一个TextFileDataset实例,参数是文件路径(path)转换成字符串格式,shuffle=False表示不打乱数据顺序
dataset.get_dataset_size() # 获取数据集的大小,即数据集中样本的数量
![]()
# split into training and testing dataset # 将数据集分割为训练集和测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False) # 将数据集按比例[0.9, 0.1]分割为训练集和测试集,randomize=False表示不随机打乱数据
2.2 数据预处理

import json # 导入json模块,用于处理JSON数据
import numpy as np # 导入numpy模块,并简写为np,用于处理数组和矩阵# preprocess dataset # 预处理数据集
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes()) # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization']) # 返回文章和摘要的numpy数组# 定义一个嵌套函数merge_and_pad,用于合并并填充数据def merge_and_pad(article, summary):# tokenization # 进行分词操作# pad to max_seq_length, only truncate the article # 填充到最大序列长度,仅截断文章部分tokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len) # 使用tokenizer对文章和摘要进行分词,填充到最大长度,仅截断文章部分return tokenized['input_ids'], tokenized['input_ids'] # 返回分词后的输入ID(注意:这里的input_ids和labels是相同的)dataset = dataset.map(read_map, 'text', ['article', 'summary']) # 使用read_map函数对数据集进行映射,提取文章和摘要# change column names to input_ids and labels for the following training # 更改列名为input_ids和labels,以便后续训练dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels']) # 使用merge_and_pad函数对数据进行映射,生成input_ids和labelsdataset = dataset.batch(batch_size) # 将数据集按批次大小进行分批处理if shuffle:dataset = dataset.shuffle(batch_size) # 如果shuffle为True,则对批次进行随机打乱return dataset # 返回预处理后的数据集
因GPT2无中文的tokenizer,我们使用BertTokenizer替代。
from mindnlp.transformers import BertTokenizer # 从mindnlp.transformers模块中导入BertTokenizer类# We use BertTokenizer for tokenizing Chinese context. # 我们使用BertTokenizer对中文内容进行分词
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') # 使用预训练的'bert-base-chinese'模型初始化BertTokenizer
len(tokenizer) # 获取tokenizer的词汇表大小

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4) # 使用process_dataset函数对训练数据集进行预处理,传入参数包括训练数据集、分词器和批次大小为4
next(train_dataset.create_tuple_iterator()) # 创建一个tuple迭代器并获取其第一个元素

3、模型构建
3.1 构建GPT2ForSummarization模型,注意shift right的操作。
from mindspore import ops # 从mindspore模块导入ops操作
from mindnlp.transformers import GPT2LMHeadModel # 从mindnlp.transformers模块中导入GPT2LMHeadModel类# 定义一个用于摘要生成的GPT2模型类,继承自GPT2LMHeadModel
class GPT2ForSummarization(GPT2LMHeadModel):# 定义模型的构造函数def construct(self,input_ids=None, # 输入IDattention_mask=None, # 注意力掩码labels=None, # 标签):# 调用父类的construct方法,获取模型输出outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :] # 移动logits,使其与shift_labels对齐shift_labels = labels[..., 1:] # 移动标签,使其与shift_logits对齐# Flatten the tokens # 将tokens展平loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id) # 计算交叉熵损失,忽略填充的tokenreturn loss # 返回计算的损失
3.2 动态学习率
from mindspore import ops # 从mindspore模块导入ops操作
from mindspore.nn.learning_rate_schedule import LearningRateSchedule # 从mindspore.nn.learning_rate_schedule模块导入LearningRateSchedule类# 定义一个线性学习率衰减与热身相结合的学习率调度类,继承自LearningRateSchedule
class LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate. # 热身-衰减学习率。"""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__() # 调用父类的构造函数self.learning_rate = learning_rate # 初始化学习率self.num_warmup_steps = num_warmup_steps # 初始化热身步数self.num_training_steps = num_training_steps # 初始化训练步数# 定义构造函数def construct(self, global_step):# 如果当前步数小于热身步数if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate # 线性增加学习率# 否则,学习率进行线性衰减return ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate # 计算并返回衰减后的学习率
4、模型训练
num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()
from mindspore import nn # 从mindspore模块导入nn(神经网络)模块
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel # 从mindnlp.transformers模块导入GPT2Config和GPT2LMHeadModel类# 配置GPT2模型的配置
config = GPT2Config(vocab_size=len(tokenizer)) # 创建GPT2配置实例,设置词汇表大小为tokenizer的长度
model = GPT2ForSummarization(config) # 使用配置实例创建一个GPT2ForSummarization模型# 创建学习率调度器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps) # 创建线性热身-衰减学习率调度器# 创建优化器
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler) # 使用AdamWeightDecay优化器,并传入模型的可训练参数和学习率调度器
# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))
![]()
from mindnlp._legacy.engine import Trainer # 从mindnlp._legacy.engine模块导入Trainer类
from mindnlp._legacy.engine.callbacks import CheckpointCallback # 从mindnlp._legacy.engine.callbacks模块导入CheckpointCallback类# 创建一个CheckpointCallback实例,用于保存检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', # 检查点保存路径ckpt_name='gpt2_summarization', # 检查点文件名epochs=1, # 每个epoch保存一次检查点keep_checkpoint_max=2 # 最多保留两个检查点
)# 创建一个Trainer实例,用于训练模型
trainer = Trainer(network=model, # 要训练的模型train_dataset=train_dataset, # 训练数据集epochs=1, # 训练的epoch数optimizer=optimizer, # 优化器callbacks=ckpoint_cb # 回调函数,包括检查点回调
)trainer.set_amp(level='O1') # 开启混合精度训练,级别设置为'O1'
下面这段代码,运行时间较长,最好选择较高算力。
trainer.run(tgt_columns="labels") # 运行训练器,指定目标列为“labels”

配置不够,训练时间太长。
5、模型推理
数据处理,将向量数据变为中文数据
def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes()) # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization']) # 返回文章和摘要的numpy数组# 定义一个嵌套函数pad,用于对文章进行分词和填充def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len) # 对文章进行分词,截断至最大长度减去摘要长度return tokenized['input_ids'] # 返回分词后的输入IDdataset = dataset.map(read_map, 'text', ['article', 'summary']) # 使用read_map函数对数据集进行映射,提取文章和摘要dataset = dataset.map(pad, 'article', ['input_ids']) # 使用pad函数对文章进行分词和填充,生成input_idsdataset = dataset.batch(batch_size) # 将数据集按批次大小进行分批处理return dataset # 返回预处理后的数据集
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
# 创建一个tuple迭代器并获取其第一个元素,以NumPy数组的形式输出,并打印出来
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config) # 从预训练的检查点加载模型
model.set_train(False) # 设置模型为评估模式(非训练模式)
model.config.eos_token_id = model.config.sep_token_id # 设置模型的eos_token_id为sep_token_id
i = 0 # 初始化计数器为0# 遍历测试数据集的迭代器,获取输入ID和原始摘要
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():# 使用模型生成新的摘要,参数包括最大新生成的token数量、束搜索的束数、不重复的ngram大小output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)# 将生成的ID转换为文本output_text = tokenizer.decode(output_ids[0].tolist())print(output_text) # 打印生成的摘要文本i += 1 # 计数器加1if i == 1: # 如果计数器达到1break # 跳出循环,仅生成并打印一个摘要

相关文章:
昇思MindSpore学习总结十六 —— 基于MindSpore的GPT2文本摘要
1、mindnlp 版本要求 !pip install tokenizers0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple # 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行!pip install mindnlp0.3.1 !pip install mindnlp …...
React Router 6笔记
一个路由就是一个映射关系 key为路径,value可能是function或component 路由分类 后端路由(node) value是function,用来处理客户端提交的请求注册路由:router.get(path, function(req, res))工作过程:当…...
Android init 中的wait_for_property指令
Android开机优化系列文档-CSDN博客 Android 14 开机时间优化措施汇总-CSDN博客Android 14 开机时间优化措施-CSDN博客根据systrace报告优化系统时需要关注的指标和优化策略-CSDN博客Android系统上常见的性能优化工具-CSDN博客Android上如何使用perfetto分析systrace-CSDN博客A…...
智能合约语言(eDSL)—— 并行化方案——调度算法
3、调度算法 处理区块的时候,我们会同时启动多个线程去执行多个交易,这个时候我们需要一个良好的调度策略,来决定当前的线程是应该执行交易还是验证交易、提前结束还是立刻重新执行交易等,只有有一个良好调度策略才能保证所有交易都稳定有序的执行; 线程数量 这是一个不…...
vue2.0中如何实现数据监听
vue2中实现数据监听的原理 在Vue 2中,数据监听是通过ES5的Object.defineProperty实现的。Vue在初始化数据对象时,会遍历data对象,并使用Object.defineProperty为每个属性设置getter和setter。当你尝试读取或修改数据属性时,这些g…...
kafka开启kerberos和ACL
作者:恩慈 一、部署kafka-KB包 1.上传软件包 依次点击 部署中心----部署组件----上传软件包 选择需要升级的kafka版本并点击确定 2.部署kafka 依次点击部署中心----部署组件----物理/虚拟机部署----选择集群----下一步 选择手动部署-…...
QT+winodow 代码适配调试总结(三)
问题描述: 1、开发测试环境为: A: window10 64位 B: QT版本为4.8.6 C:采用VS2017 C++ Compiler 9.0 (x86)编译器版本 根据总结(二)经验,开发环境的可执行程序显示正常; 2、新的环境运行的时候显示乱码; 经过查阅资料,还是代码环境编码配置的问题,下面为解…...
Linux之旅:常用的指令,热键和权限管理
目录 前言 1. Linux指令 (1) ls (2) pwd 和 cd (3)touch 和 mkdir (4) rmdir 和 rm (5)cp (6)mv (7)…...
简单实用的企业舆情安全解决方案
前言:企业舆情安全重要吗?其实很重要,尤其面对负面新闻,主动处理和应对,可以掌握主动权,避免股价下跌等,那么如何做使用简单实用的企业舆情解决方案呢? 背景 好了,提取词…...
【中项】系统集成项目管理工程师-第2章 信息技术发展-2.1信息技术及其发展-2.1.1计算机软硬件与2.1.2计算机网络
前言:系统集成项目管理工程师专业,现分享一些教材知识点。觉得文章还不错的喜欢点赞收藏的同时帮忙点点关注。 软考同样是国家人社部和工信部组织的国家级考试,全称为“全国计算机与软件专业技术资格(水平)考试”&…...
SpringBoot集成Sharding-JDBC-5.3.0实现按月动态建表分表
Sharding-JDBC系列 1、Sharding-JDBC分库分表的基本使用 2、Sharding-JDBC分库分表之SpringBoot分片策略 3、Sharding-JDBC分库分表之SpringBoot主从配置 4、SpringBoot集成Sharding-JDBC-5.3.0分库分表 5、SpringBoot集成Sharding-JDBC-5.3.0实现按月动态建表分表 前言 …...
ubuntu 上安装中文输入法
在Ubuntu上安装中文输入法,通常有以下几种方法: 方法一:使用Fcitx输入法框架和搜狗输入法 安装Fcitx: sudo apt update sudo apt install fcitx fcitx-bin fcitx-table-all 安装搜狗输入法: 首先,从搜狗…...
Postman导出excel文件
0 写在前面 在我们后端写接口的时候,前端页面还没有出来,我们就得先接口测试,在此记录下如何使用postman测试导出excel接口。 如果不会使用接口传参可以看我这篇博客如何使用Postman 1 方法一 2 方法二 3 写在末尾 虽然在代码中写入文件名…...
你还在手动构建Python项目吗?PyBuilder让一切自动化!
在 Python 项目开发中,构建和管理项目是一项繁琐但必不可少的工作。你可能需要处理依赖项、运行测试、生成文档等。这时候,PyBuilder 出场了。它是一个强大的构建自动化工具,可以帮助你简化项目管理,让你更专注于编写代码。 什么…...
WebRTC音视频-前言介绍
目录 效果预期 1:WebRTC相关简介 1.1:WebRTC和RTC 1.2:WebRTC前景和应用 2:WebRTC通话原理 2.1:媒体协商 2.2:网络协商 2.3:信令服务器 效果预期 1:WebRTC相关简介 1.1&…...
centos/rocky容器中安装xfce、xrdp记录
最近需要一台机器来测试rdp连接,使用容器linuxxfcexrdp来实现,在此记录下主要步骤 启动rockylinux容器(其他linux发行版步骤应该相似) docker run -it -p 33891:3389 rockylinux:9.3 bash容器内操作 # 省略替换软件源步骤 ...# …...
实战:Eureka的概念作用以及用法详解
概叙 什么是Eureka? Netflix Eureka 是一款由 Netflix 开源的基于 REST 服务的注册中心,用于提供服务发现功能。Spring Cloud Eureka 是 Spring Cloud Netflix 微服务套件的一部分,基于 Netflix Eureka 进行了二次封装,主要负责…...
jupyter_contrib_nbextensions安装失败问题
目录 1.文件路径长度问题 2.jupyter不出现Nbextensions选项 1.文件路径长度问题 问题: could not create build\bdist.win-amd64\wheel\.\jupyter_contrib_nbextensions\nbextensions\contrib_nbextensions_help_item\contrib_nbextensions_help_item.yaml: No su…...
设计模式-Git-其他
目录 设计模式? 创建型模式 单例模式? 啥情况需要单例模式 实现单例模式的关键点? 常见的单例模式实现? 01、饿汉式如何实现单例? 02、懒汉式如何实现单例? 03、双重检查锁定如何实现单例ÿ…...
【C#】计算两条直线的交点坐标
问题描述 计算两条直线的交点坐标,可以理解为给定坐标P1、P2、P3、P4,形成两条线,返回这两条直线的交点坐标? 注意区分:这两条线是否垂直、是否平行。 代码实现 斜率解释 斜率是数学中的一个概念,特别是…...
XML Group端口详解
在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...
智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql
智慧工地管理云平台系统,智慧工地全套源码,java版智慧工地源码,支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求,提供“平台网络终端”的整体解决方案,提供劳务管理、视频管理、智能监测、绿色施工、安全管…...
【快手拥抱开源】通过快手团队开源的 KwaiCoder-AutoThink-preview 解锁大语言模型的潜力
引言: 在人工智能快速发展的浪潮中,快手Kwaipilot团队推出的 KwaiCoder-AutoThink-preview 具有里程碑意义——这是首个公开的AutoThink大语言模型(LLM)。该模型代表着该领域的重大突破,通过独特方式融合思考与非思考…...
屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!
5月28日,中天合创屋面分布式光伏发电项目顺利并网发电,该项目位于内蒙古自治区鄂尔多斯市乌审旗,项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站,总装机容量为9.96MWp。 项目投运后,每年可节约标煤3670…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)
宇树机器人多姿态起立控制强化学习框架论文解析 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一) 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...
Angular微前端架构:Module Federation + ngx-build-plus (Webpack)
以下是一个完整的 Angular 微前端示例,其中使用的是 Module Federation 和 npx-build-plus 实现了主应用(Shell)与子应用(Remote)的集成。 🛠️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...
网站指纹识别
网站指纹识别 网站的最基本组成:服务器(操作系统)、中间件(web容器)、脚本语言、数据厍 为什么要了解这些?举个例子:发现了一个文件读取漏洞,我们需要读/etc/passwd,如…...
(一)单例模式
一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...
解决:Android studio 编译后报错\app\src\main\cpp\CMakeLists.txt‘ to exist
现象: android studio报错: [CXX1409] D:\GitLab\xxxxx\app.cxx\Debug\3f3w4y1i\arm64-v8a\android_gradle_build.json : expected buildFiles file ‘D:\GitLab\xxxxx\app\src\main\cpp\CMakeLists.txt’ to exist 解决: 不要动CMakeLists.…...
OD 算法题 B卷【正整数到Excel编号之间的转换】
文章目录 正整数到Excel编号之间的转换 正整数到Excel编号之间的转换 excel的列编号是这样的:a b c … z aa ab ac… az ba bb bc…yz za zb zc …zz aaa aab aac…; 分别代表以下的编号1 2 3 … 26 27 28 29… 52 53 54 55… 676 677 678 679 … 702 703 704 705;…...
