昇思大模型平台打卡体验活动:项目4基于MindSpore实现Roberta模型Prompt Tuning
基于MindNLP的Roberta模型Prompt Tuning
本文档介绍了如何基于MindNLP进行Roberta模型的Prompt Tuning,主要用于GLUE基准数据集的微调。本文提供了完整的代码示例以及详细的步骤说明,便于理解和复现实验。
环境配置
在运行此代码前,请确保MindNLP库已经安装。本文档基于大模型平台运行,因此需要进行适当的环境配置,确保代码可以在相应的平台上运行。
模型与数据集加载
在本案例中,我们使用 roberta-large 模型并基于GLUE基准数据集进行Prompt Tuning。GLUE (General Language Understanding Evaluation) 是自然语言处理中的标准评估基准,包括多个子任务,如句子相似性匹配、自然语言推理等。Prompt Tuning是一种新的微调技术,通过插入虚拟的“提示”Token在模型的输入中,以微调较少的参数达到较好的性能。
import mindspore
from tqdm import tqdm
from mindnlp import evaluate
from mindnlp.dataset import load_dataset
from mindnlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from mindnlp.core.optim import AdamW
from mindnlp.transformers.optimization import get_linear_schedule_with_warmup
from mindnlp.peft import (get_peft_model,PeftType,PromptTuningConfig,
)
1. 定义训练参数
首先,定义模型名称、数据集任务名称、Prompt Tuning类型、训练轮数等基本参数。
batch_size = 32
model_name_or_path = "roberta-large"
task = "mrpc"
peft_type = PeftType.PROMPT_TUNING
num_epochs = 20
2. 配置Prompt Tuning
在Prompt Tuning的配置中,选择任务类型为"SEQ_CLS"(序列分类任务),并定义虚拟Token的数量。虚拟Token即为插入模型输入中的“提示”Token,通过这些Token的微调,使得模型能够更好地完成下游任务。
peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
lr = 1e-3
3. 加载Tokenizer
根据模型类型选择padding的侧边,如果模型为GPT、OPT或BLOOM类模型,则从序列左侧填充(padding),否则从序列右侧填充。
if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):padding_side = "left"
else:padding_side = "right"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)
if getattr(tokenizer, "pad_token_id") is None:tokenizer.pad_token_id = tokenizer.eos_token_id
4. 加载数据集
通过MindNLP加载GLUE数据集,并打印样本以便确认数据格式。在此示例中,我们使用GLUE的MRPC(Microsoft Research Paraphrase Corpus)任务,该任务用于句子匹配,即判断两个句子是否表达相同的意思。
datasets = load_dataset("glue", task)
print(next(datasets['train'].create_dict_iterator()))
5. 数据预处理
为了适配MindNLP的数据处理流程,我们定义了一个映射函数 MapFunc,用于将句子转换为 input_ids 和 attention_mask,并对数据进行padding处理。
from mindnlp.dataset import BaseMapFunctionclass MapFunc(BaseMapFunction):def __call__(self, sentence1, sentence2, label, idx):outputs = tokenizer(sentence1, sentence2, truncation=True, max_length=None)return outputs['input_ids'], outputs['attention_mask'], labeldef get_dataset(dataset, tokenizer):input_colums=['sentence1', 'sentence2', 'label', 'idx']output_columns=['input_ids', 'attention_mask', 'labels']dataset = dataset.map(MapFunc(input_colums, output_columns),input_colums, output_columns)dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return datasettrain_dataset = get_dataset(datasets['train'], tokenizer)
eval_dataset = get_dataset(datasets['validation'], tokenizer)
6. 设置评估指标
我们使用 evaluate 模块加载评估指标(accuracy 和 F1-score)来评估模型的性能。
metric = evaluate.load("./glue.py", task)
7. 加载模型并配置Prompt Tuning
加载 roberta-large 模型,并根据配置进行Prompt Tuning。可以看到,微调的参数量仅为总参数量的0.3%左右,节省了大量计算资源。
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
模型微调(Prompt Tuning)
在Prompt Tuning中,训练过程中仅微调部分参数(主要是虚拟Token相关的参数),相比于传统微调而言,大大减少了需要调整的参数量,使得模型能够高效适应下游任务。
1. 优化器与学习率调整
使用 AdamW 优化器,并设置线性学习率调整策略。
optimizer = AdamW(params=model.parameters(), lr=lr)# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=0.06 * (len(train_dataset) * num_epochs),num_training_steps=(len(train_dataset) * num_epochs),
)
2. 训练逻辑定义
训练步骤如下:
- 构建正向计算函数
forward_fn。 - 定义梯度计算函数
grad_fn。 - 定义每一步的训练逻辑
train_step。 - 遍历数据集进行训练和评估,在每个 epoch 结束时,计算评估指标。
def forward_fn(**batch):outputs = model(**batch)loss = outputs.lossreturn lossgrad_fn = mindspore.value_and_grad(forward_fn, None, tuple(model.parameters()))def train_step(**batch):loss, grads = grad_fn(**batch)optimizer.step(grads)return lossfor epoch in range(num_epochs):model.set_train()train_total_size = train_dataset.get_dataset_size()for step, batch in enumerate(tqdm(train_dataset.create_dict_iterator(), total=train_total_size)):loss = train_step(**batch)lr_scheduler.step()model.set_train(False)eval_total_size = eval_dataset.get_dataset_size()for step, batch in enumerate(tqdm(eval_dataset.create_dict_iterator(), total=eval_total_size)):outputs = model(**batch)predictions = outputs.logits.argmax(axis=-1)predictions, references = predictions, batch["labels"]metric.add_batch(predictions=predictions,references=references,)eval_metric = metric.compute()print(f"epoch {epoch}:", eval_metric)
在每个 epoch 后,程序输出当前模型的评估指标(accuracy 和 F1-score)。从结果中可以看到,模型的准确率和 F1-score 会随着训练的进展逐渐提升。


总结
本案例通过Prompt Tuning技术,在Roberta模型上进行了微调以适应GLUE数据集任务。通过控制微调参数量,Prompt Tuning展示了较强的高效性。
相关文章:
昇思大模型平台打卡体验活动:项目4基于MindSpore实现Roberta模型Prompt Tuning
基于MindNLP的Roberta模型Prompt Tuning 本文档介绍了如何基于MindNLP进行Roberta模型的Prompt Tuning,主要用于GLUE基准数据集的微调。本文提供了完整的代码示例以及详细的步骤说明,便于理解和复现实验。 环境配置 在运行此代码前,请确保…...
hadoop 3.x 伪分布式搭建
hadoop 伪分布式搭建 环境 CentOS 7jdk 1.8hadoop 3.3.6 1. 准备 准备环境所需包上传所有压缩包到服务器 2. 安装jdk # 解压jdk到/usr/local目录下 tar -xvf jdk-8u431-linux-x64.tar.gz -C /usr/local先不着急配置java环境变量,后面和hadoop一起配置 3. 安装had…...
springboot 整合mybatis
一,引入MyBatis起步依赖 <!--mybatis依赖--><dependency><groupId>org.mybatis.spring.boot</groupId><artifactId>mybatis-spring-boot-starter</artifactId><version>3.0.0</version></dependency> 二&a…...
餐饮门店收银系统源码、php收银系统源码
1. 系统开发语言 核心开发语言: PHP、HTML5、Dart后台接口: PHP7.3后台管理网站: HTML5vue2.0element-uicssjs线下收银台(安卓/PC收银、安卓自助收银): Dart3框架:Flutter 3.19.6移动店务助手: uniapp线上商城: uniapp 2.系统概况及适用行业…...
canal1.1.7使用canal-adapter进行mysql同步数据
重要的事情说前面,canal1.1.8需要jdk11以上,大家自行选择,我这由于项目原因只能使用1.1.7兼容版的 文章参考地址: canal 使用详解_canal使用-CSDN博客 使用canal.deployer-1.1.7和canal.adapter-1.1.7实现mysql数据同步_mysql更…...
揭秘文心一言,智能助手新体验
一、产品描述 文心一言是一款集先进人工智能技术与自然语言处理能力于一体的智能助手软件。它采用了深度学习算法和大规模语料库训练,具备强大的语义理解和生成能力。通过简洁直观的用户界面,文心一言能够与用户进行流畅的对话交流,理解用户…...
良心无广,这5款才是你电脑上该装的神仙软件,很多人都不知道
图吧工具箱 这是一款完全纯净的硬件检测工具包,体积小巧不足0.5MB,却全面整合了CPU、硬盘、内存、显卡等电脑大神常用的检测工具与压力测试软件。 还特别为游戏爱好者们准备了直达平台官网的链接以及Directx修复工具,而且全部免费哦…...
Scala图书馆创建图书信息
图书馆书籍管理系统相关的练习。内容要求: 1.创建一个可变 Set,用于存储图书馆中的书籍信息(假设书籍信息用字符串表示,如 “Java 编程思想”“Scala 实战” 等),初始化为包含几本你喜欢的书籍。 2.添加两本…...
【Python】深入理解Python中的单例模式:用元类、装饰器和模块实现高效的单例设计
解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 单例模式是一种重要的设计模式,旨在确保一个类的实例在整个应用程序中仅存在一个。Python作为一种动态语言,为实现单例模式提供了多种方式…...
Flutter 小技巧之 Shader 实现酷炫的粒子动画
在之前的《不一样的思路实现炫酷 3D 翻页折叠动画》我们其实介绍过:如何使用 Shader 去实现一个 3D 的翻页效果,具体就是使用 Flutter 在 3.7 开始提供 Fragment Shader API ,因为每个像素都会过 Fragment Shader ,所以我们可以通…...
【LeetCode】【算法】42. 接雨水
LeetCode 42. 接雨水 题目描述 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水。 示例: 输入:height [0,1,0,2,1,0,1,3,2,1,2,1] 输出:6 解释:上面是由数…...
深⼊理解指针(5)[回调函数、qsort相关知识(qsort可用于各种类型变量的排序)】
目录 1. 回调函数 2. qsort相关知识(qsort可用于各种类型变量的排序) 一 回调函数 1定义/作用:把函数的指针(地址)作为参数传递给另⼀个函数,当这个指针被⽤来调⽤其所指向的函数 时,被调⽤的函数就…...
qt QRunnable 与 QThreadPool详解
1. 概述 QRunnable是所有runnable对象的基类,它表示一个任务或要执行的代码。开发者需要子类化QRunnable并重写其run()函数来实现具体的任务逻辑。而QThreadPool则是一个管理QThread集合的类,它帮助减少创建线程的成本,通过管理和循环使用单…...
博客摘录「 java三年工作经验面试题整理《精华》」2023年6月12日
JDK 和 JRE 有什么区别?JDK:java 开发工具包,提供了 java 的开发环境和运行环境。JRE:java 运行环境,为 java 的运行提供了所需环境。JDK 其实包含了 JRE,同时还包含了编译 java 源码的编译器 javac&#x…...
福禄克FLUKE5500A与fluke5520a校准仪的区别功能
FLUKE5500A是美国福禄克公司的一款高性能的多功能校准仪,能够对手持式和台式多用表、示波器、示波表、功率计、电子温度表、数据采集器、功率谐波分析仪、进程校准器等多种仪器进行校准。 FLUKE5500A多功能校准仪供给了GPIB(IEEE-488)、RS-2…...
量化交易系统开发-实时行情自动化交易-2.技术栈
2019年创业做过一年的量化交易但没有成功,作为交易系统的开发人员积累了一些经验,最近想重新研究交易系统,一边整理一边写出来一些思考供大家参考,也希望跟做量化的朋友有更多的交流和合作。 本篇谈谈系统主要可以选择的技术栈&a…...
【逆向爬虫实战】--全方位分析+某某学堂登录(DES加密)
🤵♂️ 个人主页:rain雨雨编程 😄微信公众号:rain雨雨编程 ✍🏻作者简介:持续分享机器学习,爬虫,数据分析 🐋 希望大家多多支持,我们一起进步! …...
第2关:装载问题 (最优队列法)
问题描述 任务描述 相关知识 编程要求 测试说明 问题描述 有一批共个集装箱要装上 2 艘载重量分别为 C1 和 C2 的轮船,其中集 装箱i的重量为 Wi ,且 装载问题要求确定是否有一个合理的装载方案可将这个集装箱装上这 2 艘轮船。如果有,找出一种…...
萤石设备视频接入平台EasyCVR海康私有化视频平台监控硬盘和普通硬盘有何区别?
在现代安防监控领域,对于数据存储和视频处理的需求日益增长,特别是在需要长时间、高稳定性监控的环境中,选择合适的存储设备和监控系统显得尤为重要。本文将深入探讨监控硬盘与普通硬盘的区别,并详细介绍海康私有化视频平台EasyCV…...
【Webpack配置全解析】打造你的专属构建流程️(4)
webpack 提供的 CLI 支持很多参数,例如 --mode,但更多的时候,我们会使用更加灵活的配置文件来控制 webpack 的行为。默认情况下,webpack 会读取 webpack.config.js 文件作为配置文件,但也可以通过 CLI 参数 --config 来…...
MPNet:旋转机械轻量化故障诊断模型详解python代码复现
目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...
深入浅出Asp.Net Core MVC应用开发系列-AspNetCore中的日志记录
ASP.NET Core 是一个跨平台的开源框架,用于在 Windows、macOS 或 Linux 上生成基于云的新式 Web 应用。 ASP.NET Core 中的日志记录 .NET 通过 ILogger API 支持高性能结构化日志记录,以帮助监视应用程序行为和诊断问题。 可以通过配置不同的记录提供程…...
为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?
在建筑行业,项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升,传统的管理模式已经难以满足现代工程的需求。过去,许多企业依赖手工记录、口头沟通和分散的信息管理,导致效率低下、成本失控、风险频发。例如&#…...
【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验
系列回顾: 在上一篇中,我们成功地为应用集成了数据库,并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了!但是,如果你仔细审视那些 API,会发现它们还很“粗糙”:有…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
华硕a豆14 Air香氛版,美学与科技的馨香融合
在快节奏的现代生活中,我们渴望一个能激发创想、愉悦感官的工作与生活伙伴,它不仅是冰冷的科技工具,更能触动我们内心深处的细腻情感。正是在这样的期许下,华硕a豆14 Air香氛版翩然而至,它以一种前所未有的方式&#x…...
【JVM】Java虚拟机(二)——垃圾回收
目录 一、如何判断对象可以回收 (一)引用计数法 (二)可达性分析算法 二、垃圾回收算法 (一)标记清除 (二)标记整理 (三)复制 (四ÿ…...
MinIO Docker 部署:仅开放一个端口
MinIO Docker 部署:仅开放一个端口 在实际的服务器部署中,出于安全和管理的考虑,我们可能只能开放一个端口。MinIO 是一个高性能的对象存储服务,支持 Docker 部署,但默认情况下它需要两个端口:一个是 API 端口(用于存储和访问数据),另一个是控制台端口(用于管理界面…...
nnUNet V2修改网络——暴力替换网络为UNet++
更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题 阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。 U-Net存在两个局限,一是网络的最佳深度因应用场景而异,这取决于任务的难度和可用于训练的标注数…...
【Linux】Linux安装并配置RabbitMQ
目录 1. 安装 Erlang 2. 安装 RabbitMQ 2.1.添加 RabbitMQ 仓库 2.2.安装 RabbitMQ 3.配置 3.1.启动和管理服务 4. 访问管理界面 5.安装问题 6.修改密码 7.修改端口 7.1.找到文件 7.2.修改文件 1. 安装 Erlang 由于 RabbitMQ 是用 Erlang 编写的,需要先安…...
