当前位置: 首页 > news >正文

昇思25天学习打卡营第11天|基于MindSpore通过GPT实现情感分类

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

基于MindSpore通过GPT实现情感分类

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com
import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
imdb_train.get_dataset_size()

加载IMDB数据集。将IMDB数据集分为训练集和测试集。IMDB (Internet Movie Database) 数据集包含来自著名在线电影数据库 IMDB 的电影评论。每条评论都被标注为正面(positive)或负面(negative),因此该数据集是一个二分类问题,也就是情感分类问题。

import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']if shuffle:dataset = dataset.shuffle(batch_size)# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

定义数据预处理函数。这个函数输入参数为数据集、分词器(GPT Tokenizer)以及一些可选参数,如最大序列长度、批量大小和是否打乱数据。预处理包括将文本转换为模型可以理解的输入格式(如input_ids和attention_mask),并将标签转换为整数类型。

from mindnlp.transformers import GPTTokenizer
# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')# add sepcial token: <PAD>
special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

加载GPT分词器并增加特殊标记。

# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

将训练集划分为训练集和验证集。

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

用 process_dataset 函数对训练集、验证集和测试集进行处理,得到相应的数据集对象。

next(dataset_train.create_tuple_iterator())
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)

导入 GPTForSequenceClassification 模型和 Adam 优化器。设置GPT模型的配置信息,包括pad_token_id和词汇表大小。使用Adam优化器对模型的可训练参数进行优化(从这里没有看出是更新部分参数,还是全部参数,有可能是部分参数。通常会改变最后一层分类器的权重和偏置,其他层的权重被冻结不变或者只微小更新些许参数。)。

Accuracy作为评价指标。

定义回调函数用于保存检查点:

   - CheckpointCallback:用于定期保存模型权重,save_path 指定了保存路径,ckpt_name保存文件的前缀,epochs=1 每个epoch保存一次,keep_checkpoint_max=2 表示最多保留2个检查点文件。
   - BestModelCallback:用于保存验证集上表现最好的模型,auto_load=True表示在训练结束后自动加载最优模型的权重。

创建 Trainer 对象,传入以下参数:
      - network:要训练的模型。
      - train_dataset:训练数据集。
      - eval_dataset:验证数据集。
      - metrics:评估指标。
      - epochs:训练轮数。
      - optimizer:优化器。
      - callbacks:回调函数列表,包括检查点保存和最佳模型保存。
      - jit:是否启用JIT编译,这里设置为False。

trainer.run(tgt_columns="labels")

通过 Trainer 的 run 方法启动训练,指定了训练过程中的目标标签列为 "labels"。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

创建 Evaluator 对象,传入以下参数:
      - network:要评估的模型。
      - eval_dataset:测试数据集。
      - metrics:评估指标。

用MindSpore通过GPT实现情感分类(Sentiment Classification)的示例。首先加载了IMDB影评数据集,并将其划分为训练集、验证集和测试集。然后使用GPTTokenizer对文本进行了标记化和转换。接下来,使用GPTForSequenceClassification构建了情感分类模型,并定义了优化器和评估指标。使用Trainer进行模型的训练,并设置了保存检查点的回调函数。训练完成后,通过Evaluator对测试集进行评估,输出分类准确率。通过对IMDB影评数据集进行训练和评估,模型可以自动进行情感分类,识别出正面或负面情感。

相关文章:

昇思25天学习打卡营第11天|基于MindSpore通过GPT实现情感分类

学AI还能赢奖品&#xff1f;每天30分钟&#xff0c;25天打通AI任督二脉 (qq.com) 基于MindSpore通过GPT实现情感分类 %%capture captured_output # 实验环境已经预装了mindspore2.2.14&#xff0c;如需更换mindspore版本&#xff0c;可更改下面mindspore的版本号 !pip uninsta…...

【Python】变量与基本数据类型

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【❤️Python】 文章目录 前言变量声明变量变量的命名规则 变量赋值多个变量赋值 标准数据类型变量的使用方式存储和访问数据&#xff1a;参与逻辑运算和数学运算在函数间传递数据构建复杂的数据结构 NameE…...

Unity按键表大全

Unity键值对应表# KeyCode是由Event.keyCode返回的。这些直接映射到键盘上的物理键&#xff0c;以下是键值对应列表&#xff1a; 常用键# Backspace 退格键 Delete Delete键 TabTab键 Clear Clear键 Return 回车键 Pause 暂停键 Escape ESC键 Space 空格键 小键盘# …...

第一周java。2

方法的作用 将重复的代码包装起来&#xff0c;写成方法&#xff0c;提高代码的复用性。 方法的语法 方法的语法格式如下 : [修饰符] 方法返回值类型 方法名(形参列表) { //由零条到多条可执行性语句组成的方法体return 返回值; } 定义方法语法格式的详细说明如下&#xf…...

Arduino - Keypad 键盘

Arduino - Keypad Arduino - Keypad The keypad is widely used in many devices such as door lock, ATM, calculator… 键盘广泛应用于门锁、ATM、计算器等多种设备中。 In this tutorial, we will learn: 在本教程中&#xff0c;我们将学习&#xff1a; How to use key…...

国产芯片方案/蓝牙咖啡电子秤方案研发

咖啡电子秤芯片方案精确值可做到分度值0.1g的精准称重,并带有过载提示、自动归零、去皮称重、压低报警等功能&#xff0c;工作电压在2.4V~3.6V之间&#xff0c;满足于咖啡电子秤的电压使用。同时咖啡电子秤PCBA设计可支持四个单位显示&#xff0c;分别为&#xff1a;g、lb、oz、…...

reactjs18 中使用@reduxjs/toolkit同步异步数据的使用

react18 中使用reduxjs/toolkit 1.安装依赖包 yarn add reduxjs/toolkit react-redux2.创建 store 根目录下面创建 store 文件夹&#xff0c;然后创建 index.js 文件。 import { configureStore } from "reduxjs/toolkit"; import { counterReducer } from "…...

剧本杀小程序:助力商家发展,提高游戏体验

近几年&#xff0c;剧本杀游戏已经成为了当下年轻人娱乐的游戏社交方式。与其他游戏相比&#xff0c;剧本杀游戏具有强大的社交性&#xff0c;玩家在游戏中既可以推理玩游戏&#xff0c;也可以与其他玩家交流互动&#xff0c;提高玩家的游戏体验感。 随着互联网的发展&#xf…...

pikachu靶场 利用Rce上传一句话木马案例(工具:中国蚁剑)

目录 一、准备靶场&#xff0c;进入RCE 二、测试写入文件 三、使用中国蚁剑 一、准备靶场&#xff0c;进入RCE 我这里用的是pikachu 打开pikachu靶场&#xff0c;选择 RCE > exec "ping" 测试是否存在 Rce 漏洞 因为我们猜测在这个 ping 功能是直接调用系统…...

CenterOS7安装java

CenterOS7安装java #进入安装目录 cd /usr/local/soft/java#wget下载java8 #直接进入官网选择相应的版本进行下载&#xff0c;然后把下载链接复制下来就可以下载了 #不时间的下载链接不一样 wget http://download.oracle.com/otn-pub/java/jdk/8u181-b13/96a7b8442fe848ef90c9…...

react 重新加载子组件

在React中&#xff0c;要重新加载某个子组件&#xff0c;你可以通过改变该组件的key属性来强制它重新渲染。这是因为React会在key变化时销毁旧的组件实例并创建一个新的实例。 多的不说直接上代码 import React, { useState } from react; import ChildComponent from ../chil…...

从零开始使用WordPress搭建个人网站并一键发布公网详细教程

文章目录 前言1. 搭建网站&#xff1a;安装WordPress2. 搭建网站&#xff1a;创建WordPress数据库3. 搭建网站&#xff1a;安装相对URL插件4. 搭建网站&#xff1a;内网穿透发布网站4.1 命令行方式&#xff1a;4.2. 配置wordpress公网地址 5. 固定WordPress公网地址5.1. 固定地…...

浅谈chrome引擎

Chrome引擎主要包括其浏览器内核Blink、JavaScript引擎V8以及其渲染、网络、安全等子系统。下面我将对这些关键部分进行简要说明分析 1. Blink浏览器内核 Blink是Google开发的浏览器排版引擎&#xff0c;自Chrome 28版本起替代了Webkit作为Chrome的渲染引擎。Blink基于Webkit…...

【常用知识点-Java】创建文件夹

Author&#xff1a;赵志乾 Date&#xff1a;2024-07-04 Declaration&#xff1a;All Right Reserved&#xff01;&#xff01;&#xff01; 1. 简介 java.io.File提供了mkdir()和mkdirs()方法创建文件夹&#xff0c;两者区别&#xff1a;mkdir()仅创建单层文件夹&#xff0c;如…...

【JavaScript脚本宇宙】颜色处理神器大比拼:哪款JavaScript库最适合你?

提升设计与开发效率&#xff1a;深入解析六大颜色处理库 前言 在现代前端开发中&#xff0c;颜色处理是设计和用户体验的重要组成部分。无论是网页设计、数据可视化还是图形设计&#xff0c;都需要强大的颜色处理功能来实现多样化的视觉效果。本文将探讨几种流行的JavaScript…...

怎么录制电脑内部声音?好用的录音软件分享,看这篇就够了!

如何录制电脑内部声音&#xff1f;平时使用电脑工作&#xff0c;难免会遇到需要录音的情况。好用的录音软件有很多&#xff0c;也有部分录屏工具也支持录音功能。 那么如何录制电脑内部声音呢&#xff1f;本文整理了几个录制电脑内部声音的方法&#xff0c;如果你需要在电脑上录…...

ios CCNSDate.m

// // CCNSDate.h // CCFC // // Created by xichen on 11-12-17. // Copyright 2011年 ccteam. All rights reserved. //#import <Foundation/Foundation.h>interface NSDate(cc)// 获取系统时间(yyyy-MM-dd HH:mm:ss.SSS格式)(NSString *)getSystemTimeStr;// prin…...

Windows系统安装SSH服务结合内网穿透配置公网地址远程ssh连接

前言 在当今的数字化转型时代&#xff0c;远程连接和管理计算机已成为日常工作中不可或缺的一部分。对于 Windows 用户而言&#xff0c;SSH&#xff08;Secure Shell&#xff09;协议提供了一种安全、高效的远程访问和命令执行方式。SSH 不仅提供了加密的通信通道&#xff0c;…...

虚拟机与主机的联通

本地光纤分配地址给路由器--》连结路由器是连结局域网--》由路由器分配IP地址 因此在网站上搜索的IP与本机的IP是不一样的 1.windows查看主机IP地址 在终端输入 2.linux虚拟机查看ip 3.主机是否联通虚拟机ping加ip...

2024年中国网络安全市场全景图 -百度下载

是自2018年开始&#xff0c;数说安全发布的第七版全景图。 企业数智化转型加速已经促使网络安全成为全社会关注的焦点&#xff0c;在网络安全边界不断扩大&#xff0c;新理念、新产品、新技术不断融合发展的进程中&#xff0c;数说安全始终秉承科学的方法论&#xff0c;以遵循…...

ubuntu搭建nfs服务centos挂载访问

在Ubuntu上设置NFS服务器 在Ubuntu上&#xff0c;你可以使用apt包管理器来安装NFS服务器。打开终端并运行&#xff1a; sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享&#xff0c;例如/shared&#xff1a; sudo mkdir /shared sud…...

k8s从入门到放弃之Ingress七层负载

k8s从入门到放弃之Ingress七层负载 在Kubernetes&#xff08;简称K8s&#xff09;中&#xff0c;Ingress是一个API对象&#xff0c;它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress&#xff0c;你可…...

MySQL 隔离级别:脏读、幻读及不可重复读的原理与示例

一、MySQL 隔离级别 MySQL 提供了四种隔离级别,用于控制事务之间的并发访问以及数据的可见性,不同隔离级别对脏读、幻读、不可重复读这几种并发数据问题有着不同的处理方式,具体如下: 隔离级别脏读不可重复读幻读性能特点及锁机制读未提交(READ UNCOMMITTED)允许出现允许…...

2.Vue编写一个app

1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...

STM32标准库-DMA直接存储器存取

文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA&#xff08;Direct Memory Access&#xff09;直接存储器存取 DMA可以提供外设…...

c++ 面试题(1)-----深度优先搜索(DFS)实现

操作系统&#xff1a;ubuntu22.04 IDE:Visual Studio Code 编程语言&#xff1a;C11 题目描述 地上有一个 m 行 n 列的方格&#xff0c;从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子&#xff0c;但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...

图表类系列各种样式PPT模版分享

图标图表系列PPT模版&#xff0c;柱状图PPT模版&#xff0c;线状图PPT模版&#xff0c;折线图PPT模版&#xff0c;饼状图PPT模版&#xff0c;雷达图PPT模版&#xff0c;树状图PPT模版 图表类系列各种样式PPT模版分享&#xff1a;图表系列PPT模板https://pan.quark.cn/s/20d40aa…...

Web 架构之 CDN 加速原理与落地实践

文章目录 一、思维导图二、正文内容&#xff08;一&#xff09;CDN 基础概念1. 定义2. 组成部分 &#xff08;二&#xff09;CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 &#xff08;三&#xff09;CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 &#xf…...

JAVA后端开发——多租户

数据隔离是多租户系统中的核心概念&#xff0c;确保一个租户&#xff08;在这个系统中可能是一个公司或一个独立的客户&#xff09;的数据对其他租户是不可见的。在 RuoYi 框架&#xff08;您当前项目所使用的基础框架&#xff09;中&#xff0c;这通常是通过在数据表中增加一个…...

IP如何挑?2025年海外专线IP如何购买?

你花了时间和预算买了IP&#xff0c;结果IP质量不佳&#xff0c;项目效率低下不说&#xff0c;还可能带来莫名的网络问题&#xff0c;是不是太闹心了&#xff1f;尤其是在面对海外专线IP时&#xff0c;到底怎么才能买到适合自己的呢&#xff1f;所以&#xff0c;挑IP绝对是个技…...