昇思MindSpore学习总结十六 —— 基于MindSpore的GPT2文本摘要
1、mindnlp 版本要求
!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
2、数据集加载与处理
2.1 数据集加载
本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。
from mindspore.dataset import TextFileDataset # 从mindspore.dataset模块中导入TextFileDataset类# load dataset # 加载数据集
dataset = TextFileDataset(str(path), shuffle=False) # 创建一个TextFileDataset实例,参数是文件路径(path)转换成字符串格式,shuffle=False表示不打乱数据顺序
dataset.get_dataset_size() # 获取数据集的大小,即数据集中样本的数量
![]()
# split into training and testing dataset # 将数据集分割为训练集和测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False) # 将数据集按比例[0.9, 0.1]分割为训练集和测试集,randomize=False表示不随机打乱数据
2.2 数据预处理

import json # 导入json模块,用于处理JSON数据
import numpy as np # 导入numpy模块,并简写为np,用于处理数组和矩阵# preprocess dataset # 预处理数据集
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes()) # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization']) # 返回文章和摘要的numpy数组# 定义一个嵌套函数merge_and_pad,用于合并并填充数据def merge_and_pad(article, summary):# tokenization # 进行分词操作# pad to max_seq_length, only truncate the article # 填充到最大序列长度,仅截断文章部分tokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len) # 使用tokenizer对文章和摘要进行分词,填充到最大长度,仅截断文章部分return tokenized['input_ids'], tokenized['input_ids'] # 返回分词后的输入ID(注意:这里的input_ids和labels是相同的)dataset = dataset.map(read_map, 'text', ['article', 'summary']) # 使用read_map函数对数据集进行映射,提取文章和摘要# change column names to input_ids and labels for the following training # 更改列名为input_ids和labels,以便后续训练dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels']) # 使用merge_and_pad函数对数据进行映射,生成input_ids和labelsdataset = dataset.batch(batch_size) # 将数据集按批次大小进行分批处理if shuffle:dataset = dataset.shuffle(batch_size) # 如果shuffle为True,则对批次进行随机打乱return dataset # 返回预处理后的数据集
因GPT2无中文的tokenizer,我们使用BertTokenizer替代。
from mindnlp.transformers import BertTokenizer # 从mindnlp.transformers模块中导入BertTokenizer类# We use BertTokenizer for tokenizing Chinese context. # 我们使用BertTokenizer对中文内容进行分词
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') # 使用预训练的'bert-base-chinese'模型初始化BertTokenizer
len(tokenizer) # 获取tokenizer的词汇表大小

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4) # 使用process_dataset函数对训练数据集进行预处理,传入参数包括训练数据集、分词器和批次大小为4
next(train_dataset.create_tuple_iterator()) # 创建一个tuple迭代器并获取其第一个元素

3、模型构建
3.1 构建GPT2ForSummarization模型,注意shift right的操作。
from mindspore import ops # 从mindspore模块导入ops操作
from mindnlp.transformers import GPT2LMHeadModel # 从mindnlp.transformers模块中导入GPT2LMHeadModel类# 定义一个用于摘要生成的GPT2模型类,继承自GPT2LMHeadModel
class GPT2ForSummarization(GPT2LMHeadModel):# 定义模型的构造函数def construct(self,input_ids=None, # 输入IDattention_mask=None, # 注意力掩码labels=None, # 标签):# 调用父类的construct方法,获取模型输出outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :] # 移动logits,使其与shift_labels对齐shift_labels = labels[..., 1:] # 移动标签,使其与shift_logits对齐# Flatten the tokens # 将tokens展平loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id) # 计算交叉熵损失,忽略填充的tokenreturn loss # 返回计算的损失
3.2 动态学习率
from mindspore import ops # 从mindspore模块导入ops操作
from mindspore.nn.learning_rate_schedule import LearningRateSchedule # 从mindspore.nn.learning_rate_schedule模块导入LearningRateSchedule类# 定义一个线性学习率衰减与热身相结合的学习率调度类,继承自LearningRateSchedule
class LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate. # 热身-衰减学习率。"""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__() # 调用父类的构造函数self.learning_rate = learning_rate # 初始化学习率self.num_warmup_steps = num_warmup_steps # 初始化热身步数self.num_training_steps = num_training_steps # 初始化训练步数# 定义构造函数def construct(self, global_step):# 如果当前步数小于热身步数if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate # 线性增加学习率# 否则,学习率进行线性衰减return ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate # 计算并返回衰减后的学习率
4、模型训练
num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()
from mindspore import nn # 从mindspore模块导入nn(神经网络)模块
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel # 从mindnlp.transformers模块导入GPT2Config和GPT2LMHeadModel类# 配置GPT2模型的配置
config = GPT2Config(vocab_size=len(tokenizer)) # 创建GPT2配置实例,设置词汇表大小为tokenizer的长度
model = GPT2ForSummarization(config) # 使用配置实例创建一个GPT2ForSummarization模型# 创建学习率调度器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps) # 创建线性热身-衰减学习率调度器# 创建优化器
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler) # 使用AdamWeightDecay优化器,并传入模型的可训练参数和学习率调度器
# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))
![]()
from mindnlp._legacy.engine import Trainer # 从mindnlp._legacy.engine模块导入Trainer类
from mindnlp._legacy.engine.callbacks import CheckpointCallback # 从mindnlp._legacy.engine.callbacks模块导入CheckpointCallback类# 创建一个CheckpointCallback实例,用于保存检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', # 检查点保存路径ckpt_name='gpt2_summarization', # 检查点文件名epochs=1, # 每个epoch保存一次检查点keep_checkpoint_max=2 # 最多保留两个检查点
)# 创建一个Trainer实例,用于训练模型
trainer = Trainer(network=model, # 要训练的模型train_dataset=train_dataset, # 训练数据集epochs=1, # 训练的epoch数optimizer=optimizer, # 优化器callbacks=ckpoint_cb # 回调函数,包括检查点回调
)trainer.set_amp(level='O1') # 开启混合精度训练,级别设置为'O1'
下面这段代码,运行时间较长,最好选择较高算力。
trainer.run(tgt_columns="labels") # 运行训练器,指定目标列为“labels”

配置不够,训练时间太长。
5、模型推理
数据处理,将向量数据变为中文数据
def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes()) # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization']) # 返回文章和摘要的numpy数组# 定义一个嵌套函数pad,用于对文章进行分词和填充def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len) # 对文章进行分词,截断至最大长度减去摘要长度return tokenized['input_ids'] # 返回分词后的输入IDdataset = dataset.map(read_map, 'text', ['article', 'summary']) # 使用read_map函数对数据集进行映射,提取文章和摘要dataset = dataset.map(pad, 'article', ['input_ids']) # 使用pad函数对文章进行分词和填充,生成input_idsdataset = dataset.batch(batch_size) # 将数据集按批次大小进行分批处理return dataset # 返回预处理后的数据集
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
# 创建一个tuple迭代器并获取其第一个元素,以NumPy数组的形式输出,并打印出来
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config) # 从预训练的检查点加载模型
model.set_train(False) # 设置模型为评估模式(非训练模式)
model.config.eos_token_id = model.config.sep_token_id # 设置模型的eos_token_id为sep_token_id
i = 0 # 初始化计数器为0# 遍历测试数据集的迭代器,获取输入ID和原始摘要
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():# 使用模型生成新的摘要,参数包括最大新生成的token数量、束搜索的束数、不重复的ngram大小output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)# 将生成的ID转换为文本output_text = tokenizer.decode(output_ids[0].tolist())print(output_text) # 打印生成的摘要文本i += 1 # 计数器加1if i == 1: # 如果计数器达到1break # 跳出循环,仅生成并打印一个摘要

相关文章:
昇思MindSpore学习总结十六 —— 基于MindSpore的GPT2文本摘要
1、mindnlp 版本要求 !pip install tokenizers0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple # 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行!pip install mindnlp0.3.1 !pip install mindnlp …...
React Router 6笔记
一个路由就是一个映射关系 key为路径,value可能是function或component 路由分类 后端路由(node) value是function,用来处理客户端提交的请求注册路由:router.get(path, function(req, res))工作过程:当…...
Android init 中的wait_for_property指令
Android开机优化系列文档-CSDN博客 Android 14 开机时间优化措施汇总-CSDN博客Android 14 开机时间优化措施-CSDN博客根据systrace报告优化系统时需要关注的指标和优化策略-CSDN博客Android系统上常见的性能优化工具-CSDN博客Android上如何使用perfetto分析systrace-CSDN博客A…...
智能合约语言(eDSL)—— 并行化方案——调度算法
3、调度算法 处理区块的时候,我们会同时启动多个线程去执行多个交易,这个时候我们需要一个良好的调度策略,来决定当前的线程是应该执行交易还是验证交易、提前结束还是立刻重新执行交易等,只有有一个良好调度策略才能保证所有交易都稳定有序的执行; 线程数量 这是一个不…...
vue2.0中如何实现数据监听
vue2中实现数据监听的原理 在Vue 2中,数据监听是通过ES5的Object.defineProperty实现的。Vue在初始化数据对象时,会遍历data对象,并使用Object.defineProperty为每个属性设置getter和setter。当你尝试读取或修改数据属性时,这些g…...
kafka开启kerberos和ACL
作者:恩慈 一、部署kafka-KB包 1.上传软件包 依次点击 部署中心----部署组件----上传软件包 选择需要升级的kafka版本并点击确定 2.部署kafka 依次点击部署中心----部署组件----物理/虚拟机部署----选择集群----下一步 选择手动部署-…...
QT+winodow 代码适配调试总结(三)
问题描述: 1、开发测试环境为: A: window10 64位 B: QT版本为4.8.6 C:采用VS2017 C++ Compiler 9.0 (x86)编译器版本 根据总结(二)经验,开发环境的可执行程序显示正常; 2、新的环境运行的时候显示乱码; 经过查阅资料,还是代码环境编码配置的问题,下面为解…...
Linux之旅:常用的指令,热键和权限管理
目录 前言 1. Linux指令 (1) ls (2) pwd 和 cd (3)touch 和 mkdir (4) rmdir 和 rm (5)cp (6)mv (7)…...
简单实用的企业舆情安全解决方案
前言:企业舆情安全重要吗?其实很重要,尤其面对负面新闻,主动处理和应对,可以掌握主动权,避免股价下跌等,那么如何做使用简单实用的企业舆情解决方案呢? 背景 好了,提取词…...
【中项】系统集成项目管理工程师-第2章 信息技术发展-2.1信息技术及其发展-2.1.1计算机软硬件与2.1.2计算机网络
前言:系统集成项目管理工程师专业,现分享一些教材知识点。觉得文章还不错的喜欢点赞收藏的同时帮忙点点关注。 软考同样是国家人社部和工信部组织的国家级考试,全称为“全国计算机与软件专业技术资格(水平)考试”&…...
SpringBoot集成Sharding-JDBC-5.3.0实现按月动态建表分表
Sharding-JDBC系列 1、Sharding-JDBC分库分表的基本使用 2、Sharding-JDBC分库分表之SpringBoot分片策略 3、Sharding-JDBC分库分表之SpringBoot主从配置 4、SpringBoot集成Sharding-JDBC-5.3.0分库分表 5、SpringBoot集成Sharding-JDBC-5.3.0实现按月动态建表分表 前言 …...
ubuntu 上安装中文输入法
在Ubuntu上安装中文输入法,通常有以下几种方法: 方法一:使用Fcitx输入法框架和搜狗输入法 安装Fcitx: sudo apt update sudo apt install fcitx fcitx-bin fcitx-table-all 安装搜狗输入法: 首先,从搜狗…...
Postman导出excel文件
0 写在前面 在我们后端写接口的时候,前端页面还没有出来,我们就得先接口测试,在此记录下如何使用postman测试导出excel接口。 如果不会使用接口传参可以看我这篇博客如何使用Postman 1 方法一 2 方法二 3 写在末尾 虽然在代码中写入文件名…...
你还在手动构建Python项目吗?PyBuilder让一切自动化!
在 Python 项目开发中,构建和管理项目是一项繁琐但必不可少的工作。你可能需要处理依赖项、运行测试、生成文档等。这时候,PyBuilder 出场了。它是一个强大的构建自动化工具,可以帮助你简化项目管理,让你更专注于编写代码。 什么…...
WebRTC音视频-前言介绍
目录 效果预期 1:WebRTC相关简介 1.1:WebRTC和RTC 1.2:WebRTC前景和应用 2:WebRTC通话原理 2.1:媒体协商 2.2:网络协商 2.3:信令服务器 效果预期 1:WebRTC相关简介 1.1&…...
centos/rocky容器中安装xfce、xrdp记录
最近需要一台机器来测试rdp连接,使用容器linuxxfcexrdp来实现,在此记录下主要步骤 启动rockylinux容器(其他linux发行版步骤应该相似) docker run -it -p 33891:3389 rockylinux:9.3 bash容器内操作 # 省略替换软件源步骤 ...# …...
实战:Eureka的概念作用以及用法详解
概叙 什么是Eureka? Netflix Eureka 是一款由 Netflix 开源的基于 REST 服务的注册中心,用于提供服务发现功能。Spring Cloud Eureka 是 Spring Cloud Netflix 微服务套件的一部分,基于 Netflix Eureka 进行了二次封装,主要负责…...
jupyter_contrib_nbextensions安装失败问题
目录 1.文件路径长度问题 2.jupyter不出现Nbextensions选项 1.文件路径长度问题 问题: could not create build\bdist.win-amd64\wheel\.\jupyter_contrib_nbextensions\nbextensions\contrib_nbextensions_help_item\contrib_nbextensions_help_item.yaml: No su…...
设计模式-Git-其他
目录 设计模式? 创建型模式 单例模式? 啥情况需要单例模式 实现单例模式的关键点? 常见的单例模式实现? 01、饿汉式如何实现单例? 02、懒汉式如何实现单例? 03、双重检查锁定如何实现单例ÿ…...
【C#】计算两条直线的交点坐标
问题描述 计算两条直线的交点坐标,可以理解为给定坐标P1、P2、P3、P4,形成两条线,返回这两条直线的交点坐标? 注意区分:这两条线是否垂直、是否平行。 代码实现 斜率解释 斜率是数学中的一个概念,特别是…...
k8s从入门到放弃之Ingress七层负载
k8s从入门到放弃之Ingress七层负载 在Kubernetes(简称K8s)中,Ingress是一个API对象,它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress,你可…...
CMake控制VS2022项目文件分组
我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
AI病理诊断七剑下天山,医疗未来触手可及
一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)
Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习) 一、Aspose.PDF 简介二、说明(⚠️仅供学习与研究使用)三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...
LLMs 系列实操科普(1)
写在前面: 本期内容我们继续 Andrej Karpathy 的《How I use LLMs》讲座内容,原视频时长 ~130 分钟,以实操演示主流的一些 LLMs 的使用,由于涉及到实操,实际上并不适合以文字整理,但还是决定尽量整理一份笔…...
Linux系统部署KES
1、安装准备 1.版本说明V008R006C009B0014 V008:是version产品的大版本。 R006:是release产品特性版本。 C009:是通用版 B0014:是build开发过程中的构建版本2.硬件要求 #安全版和企业版 内存:1GB 以上 硬盘…...
Python常用模块:time、os、shutil与flask初探
一、Flask初探 & PyCharm终端配置 目的: 快速搭建小型Web服务器以提供数据。 工具: 第三方Web框架 Flask (需 pip install flask 安装)。 安装 Flask: 建议: 使用 PyCharm 内置的 Terminal (模拟命令行) 进行安装,避免频繁切换。 PyCharm Terminal 配置建议: 打开 Py…...
python打卡第47天
昨天代码中注意力热图的部分顺移至今天 知识点回顾: 热力图 作业:对比不同卷积层热图可视化的结果 def visualize_attention_map(model, test_loader, device, class_names, num_samples3):"""可视化模型的注意力热力图,展示模…...
云原生时代的系统设计:架构转型的战略支点
📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 一、云原生的崛起:技术趋势与现实需求的交汇 随着企业业务的互联网化、全球化、智能化持续加深,传统的 I…...
