当前位置: 首页 > news >正文

昇思25天学习打卡营第11天|基于MindSpore通过GPT实现情感分类

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

基于MindSpore通过GPT实现情感分类

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com
import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nnfrom mindnlp.dataset import load_datasetfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
imdb_train.get_dataset_size()

加载IMDB数据集。将IMDB数据集分为训练集和测试集。IMDB (Internet Movie Database) 数据集包含来自著名在线电影数据库 IMDB 的电影评论。每条评论都被标注为正面(positive)或负面(negative),因此该数据集是一个二分类问题,也就是情感分类问题。

import numpy as npdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']if shuffle:dataset = dataset.shuffle(batch_size)# map datasetdataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

定义数据预处理函数。这个函数输入参数为数据集、分词器(GPT Tokenizer)以及一些可选参数,如最大序列长度、批量大小和是否打乱数据。预处理包括将文本转换为模型可以理解的输入格式(如input_ids和attention_mask),并将标签转换为整数类型。

from mindnlp.transformers import GPTTokenizer
# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')# add sepcial token: <PAD>
special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

加载GPT分词器并增加特殊标记。

# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

将训练集划分为训练集和验证集。

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

用 process_dataset 函数对训练集、验证集和测试集进行处理,得到相应的数据集对象。

next(dataset_train.create_tuple_iterator())
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)

导入 GPTForSequenceClassification 模型和 Adam 优化器。设置GPT模型的配置信息,包括pad_token_id和词汇表大小。使用Adam优化器对模型的可训练参数进行优化(从这里没有看出是更新部分参数,还是全部参数,有可能是部分参数。通常会改变最后一层分类器的权重和偏置,其他层的权重被冻结不变或者只微小更新些许参数。)。

Accuracy作为评价指标。

定义回调函数用于保存检查点:

   - CheckpointCallback:用于定期保存模型权重,save_path 指定了保存路径,ckpt_name保存文件的前缀,epochs=1 每个epoch保存一次,keep_checkpoint_max=2 表示最多保留2个检查点文件。
   - BestModelCallback:用于保存验证集上表现最好的模型,auto_load=True表示在训练结束后自动加载最优模型的权重。

创建 Trainer 对象,传入以下参数:
      - network:要训练的模型。
      - train_dataset:训练数据集。
      - eval_dataset:验证数据集。
      - metrics:评估指标。
      - epochs:训练轮数。
      - optimizer:优化器。
      - callbacks:回调函数列表,包括检查点保存和最佳模型保存。
      - jit:是否启用JIT编译,这里设置为False。

trainer.run(tgt_columns="labels")

通过 Trainer 的 run 方法启动训练,指定了训练过程中的目标标签列为 "labels"。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

创建 Evaluator 对象,传入以下参数:
      - network:要评估的模型。
      - eval_dataset:测试数据集。
      - metrics:评估指标。

用MindSpore通过GPT实现情感分类(Sentiment Classification)的示例。首先加载了IMDB影评数据集,并将其划分为训练集、验证集和测试集。然后使用GPTTokenizer对文本进行了标记化和转换。接下来,使用GPTForSequenceClassification构建了情感分类模型,并定义了优化器和评估指标。使用Trainer进行模型的训练,并设置了保存检查点的回调函数。训练完成后,通过Evaluator对测试集进行评估,输出分类准确率。通过对IMDB影评数据集进行训练和评估,模型可以自动进行情感分类,识别出正面或负面情感。

相关文章:

昇思25天学习打卡营第11天|基于MindSpore通过GPT实现情感分类

学AI还能赢奖品&#xff1f;每天30分钟&#xff0c;25天打通AI任督二脉 (qq.com) 基于MindSpore通过GPT实现情感分类 %%capture captured_output # 实验环境已经预装了mindspore2.2.14&#xff0c;如需更换mindspore版本&#xff0c;可更改下面mindspore的版本号 !pip uninsta…...

【Python】变量与基本数据类型

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【❤️Python】 文章目录 前言变量声明变量变量的命名规则 变量赋值多个变量赋值 标准数据类型变量的使用方式存储和访问数据&#xff1a;参与逻辑运算和数学运算在函数间传递数据构建复杂的数据结构 NameE…...

Unity按键表大全

Unity键值对应表# KeyCode是由Event.keyCode返回的。这些直接映射到键盘上的物理键&#xff0c;以下是键值对应列表&#xff1a; 常用键# Backspace 退格键 Delete Delete键 TabTab键 Clear Clear键 Return 回车键 Pause 暂停键 Escape ESC键 Space 空格键 小键盘# …...

第一周java。2

方法的作用 将重复的代码包装起来&#xff0c;写成方法&#xff0c;提高代码的复用性。 方法的语法 方法的语法格式如下 : [修饰符] 方法返回值类型 方法名(形参列表) { //由零条到多条可执行性语句组成的方法体return 返回值; } 定义方法语法格式的详细说明如下&#xf…...

Arduino - Keypad 键盘

Arduino - Keypad Arduino - Keypad The keypad is widely used in many devices such as door lock, ATM, calculator… 键盘广泛应用于门锁、ATM、计算器等多种设备中。 In this tutorial, we will learn: 在本教程中&#xff0c;我们将学习&#xff1a; How to use key…...

国产芯片方案/蓝牙咖啡电子秤方案研发

咖啡电子秤芯片方案精确值可做到分度值0.1g的精准称重,并带有过载提示、自动归零、去皮称重、压低报警等功能&#xff0c;工作电压在2.4V~3.6V之间&#xff0c;满足于咖啡电子秤的电压使用。同时咖啡电子秤PCBA设计可支持四个单位显示&#xff0c;分别为&#xff1a;g、lb、oz、…...

reactjs18 中使用@reduxjs/toolkit同步异步数据的使用

react18 中使用reduxjs/toolkit 1.安装依赖包 yarn add reduxjs/toolkit react-redux2.创建 store 根目录下面创建 store 文件夹&#xff0c;然后创建 index.js 文件。 import { configureStore } from "reduxjs/toolkit"; import { counterReducer } from "…...

剧本杀小程序:助力商家发展,提高游戏体验

近几年&#xff0c;剧本杀游戏已经成为了当下年轻人娱乐的游戏社交方式。与其他游戏相比&#xff0c;剧本杀游戏具有强大的社交性&#xff0c;玩家在游戏中既可以推理玩游戏&#xff0c;也可以与其他玩家交流互动&#xff0c;提高玩家的游戏体验感。 随着互联网的发展&#xf…...

pikachu靶场 利用Rce上传一句话木马案例(工具:中国蚁剑)

目录 一、准备靶场&#xff0c;进入RCE 二、测试写入文件 三、使用中国蚁剑 一、准备靶场&#xff0c;进入RCE 我这里用的是pikachu 打开pikachu靶场&#xff0c;选择 RCE > exec "ping" 测试是否存在 Rce 漏洞 因为我们猜测在这个 ping 功能是直接调用系统…...

CenterOS7安装java

CenterOS7安装java #进入安装目录 cd /usr/local/soft/java#wget下载java8 #直接进入官网选择相应的版本进行下载&#xff0c;然后把下载链接复制下来就可以下载了 #不时间的下载链接不一样 wget http://download.oracle.com/otn-pub/java/jdk/8u181-b13/96a7b8442fe848ef90c9…...

react 重新加载子组件

在React中&#xff0c;要重新加载某个子组件&#xff0c;你可以通过改变该组件的key属性来强制它重新渲染。这是因为React会在key变化时销毁旧的组件实例并创建一个新的实例。 多的不说直接上代码 import React, { useState } from react; import ChildComponent from ../chil…...

从零开始使用WordPress搭建个人网站并一键发布公网详细教程

文章目录 前言1. 搭建网站&#xff1a;安装WordPress2. 搭建网站&#xff1a;创建WordPress数据库3. 搭建网站&#xff1a;安装相对URL插件4. 搭建网站&#xff1a;内网穿透发布网站4.1 命令行方式&#xff1a;4.2. 配置wordpress公网地址 5. 固定WordPress公网地址5.1. 固定地…...

浅谈chrome引擎

Chrome引擎主要包括其浏览器内核Blink、JavaScript引擎V8以及其渲染、网络、安全等子系统。下面我将对这些关键部分进行简要说明分析 1. Blink浏览器内核 Blink是Google开发的浏览器排版引擎&#xff0c;自Chrome 28版本起替代了Webkit作为Chrome的渲染引擎。Blink基于Webkit…...

【常用知识点-Java】创建文件夹

Author&#xff1a;赵志乾 Date&#xff1a;2024-07-04 Declaration&#xff1a;All Right Reserved&#xff01;&#xff01;&#xff01; 1. 简介 java.io.File提供了mkdir()和mkdirs()方法创建文件夹&#xff0c;两者区别&#xff1a;mkdir()仅创建单层文件夹&#xff0c;如…...

【JavaScript脚本宇宙】颜色处理神器大比拼:哪款JavaScript库最适合你?

提升设计与开发效率&#xff1a;深入解析六大颜色处理库 前言 在现代前端开发中&#xff0c;颜色处理是设计和用户体验的重要组成部分。无论是网页设计、数据可视化还是图形设计&#xff0c;都需要强大的颜色处理功能来实现多样化的视觉效果。本文将探讨几种流行的JavaScript…...

怎么录制电脑内部声音?好用的录音软件分享,看这篇就够了!

如何录制电脑内部声音&#xff1f;平时使用电脑工作&#xff0c;难免会遇到需要录音的情况。好用的录音软件有很多&#xff0c;也有部分录屏工具也支持录音功能。 那么如何录制电脑内部声音呢&#xff1f;本文整理了几个录制电脑内部声音的方法&#xff0c;如果你需要在电脑上录…...

ios CCNSDate.m

// // CCNSDate.h // CCFC // // Created by xichen on 11-12-17. // Copyright 2011年 ccteam. All rights reserved. //#import <Foundation/Foundation.h>interface NSDate(cc)// 获取系统时间(yyyy-MM-dd HH:mm:ss.SSS格式)(NSString *)getSystemTimeStr;// prin…...

Windows系统安装SSH服务结合内网穿透配置公网地址远程ssh连接

前言 在当今的数字化转型时代&#xff0c;远程连接和管理计算机已成为日常工作中不可或缺的一部分。对于 Windows 用户而言&#xff0c;SSH&#xff08;Secure Shell&#xff09;协议提供了一种安全、高效的远程访问和命令执行方式。SSH 不仅提供了加密的通信通道&#xff0c;…...

虚拟机与主机的联通

本地光纤分配地址给路由器--》连结路由器是连结局域网--》由路由器分配IP地址 因此在网站上搜索的IP与本机的IP是不一样的 1.windows查看主机IP地址 在终端输入 2.linux虚拟机查看ip 3.主机是否联通虚拟机ping加ip...

2024年中国网络安全市场全景图 -百度下载

是自2018年开始&#xff0c;数说安全发布的第七版全景图。 企业数智化转型加速已经促使网络安全成为全社会关注的焦点&#xff0c;在网络安全边界不断扩大&#xff0c;新理念、新产品、新技术不断融合发展的进程中&#xff0c;数说安全始终秉承科学的方法论&#xff0c;以遵循…...

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框&#xff0c;很难让人不联想到SQL注入&#xff0c;但提示都说了不是SQL注入&#xff0c;所以就不往这方面想了 ​ 先查看一下网页源码&#xff0c;发现一段JavaScript代码&#xff0c;有一个关键类ctfs…...

多场景 OkHttpClient 管理器 - Android 网络通信解决方案

下面是一个完整的 Android 实现&#xff0c;展示如何创建和管理多个 OkHttpClient 实例&#xff0c;分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...

【解密LSTM、GRU如何解决传统RNN梯度消失问题】

解密LSTM与GRU&#xff1a;如何让RNN变得更聪明&#xff1f; 在深度学习的世界里&#xff0c;循环神经网络&#xff08;RNN&#xff09;以其卓越的序列数据处理能力广泛应用于自然语言处理、时间序列预测等领域。然而&#xff0c;传统RNN存在的一个严重问题——梯度消失&#…...

Linux简单的操作

ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...

对WWDC 2025 Keynote 内容的预测

借助我们以往对苹果公司发展路径的深入研究经验&#xff0c;以及大语言模型的分析能力&#xff0c;我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际&#xff0c;我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测&#xff0c;聊作存档。等到明…...

cf2117E

原题链接&#xff1a;https://codeforces.com/contest/2117/problem/E 题目背景&#xff1a; 给定两个数组a,b&#xff0c;可以执行多次以下操作&#xff1a;选择 i (1 < i < n - 1)&#xff0c;并设置 或&#xff0c;也可以在执行上述操作前执行一次删除任意 和 。求…...

【2025年】解决Burpsuite抓不到https包的问题

环境&#xff1a;windows11 burpsuite:2025.5 在抓取https网站时&#xff0c;burpsuite抓取不到https数据包&#xff0c;只显示&#xff1a; 解决该问题只需如下三个步骤&#xff1a; 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...

学习STC51单片机31(芯片为STC89C52RCRC)OLED显示屏1

每日一言 生活的美好&#xff0c;总是藏在那些你咬牙坚持的日子里。 硬件&#xff1a;OLED 以后要用到OLED的时候找到这个文件 OLED的设备地址 SSD1306"SSD" 是品牌缩写&#xff0c;"1306" 是产品编号。 驱动 OLED 屏幕的 IIC 总线数据传输格式 示意图 …...

全志A40i android7.1 调试信息打印串口由uart0改为uart3

一&#xff0c;概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本&#xff1a;2014.07&#xff1b; Kernel版本&#xff1a;Linux-3.10&#xff1b; 二&#xff0c;Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01)&#xff0c;并让boo…...