当前位置: 首页 > news >正文

NLP_情感分类_预训练加微调方案

文章目录

  • 项目背景
  • 代码
    • 导包
    • 一些模型以及训练的参数设置
    • 定义dataset
    • 定义模型
    • 读取数据
    • 声明训练及测试数据集
    • 将定义模型实例化
    • 打印模型结构
    • 模型训练
    • 测试集效果
  • 同类型项目


项目背景

项目的目的,是为了对情感评论数据集进行预测打标。在训练之前,需要对数据进行数据清洗环节,前面已对数据进行清洗,详情可移步至NLP_情感分类_数据清洗
前面用机器学习方案解决,详情可移步至NLP_情感分类_机器学习方案

下面对已清洗的数据集,用预训练加微调方案进行处理

代码

导包

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import numpy as np
import gc
import os
from sklearn.metrics import accuracy_score,f1_score,recall_score,precision_score
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.utils.data as D
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import warningswarnings.filterwarnings('ignore')

一些模型以及训练的参数设置

batch_size = 128
max_seq = 128
Epoch = 2
lr = 2e-5
debug_mode = False                #若开启此模式,则只读入很小的一部分数据,可以用来快速调试整个流程
num_workers = 0                    #多线程读取数据的worker的个数,由于win对多线程支持有bug,这里只能设置为0
seed = 4399
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH =  'pre_model/' #'juliensimon/reviews-sentiment-analysis'  #预训练权重的目录 or 远程地址
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)   

定义dataset

class TextDataSet(Dataset):def __init__(self, df, tokenizer, max_seq=128, debug_mode = False):self.max_seq = max_seqself.df = dfself.tokenizer = tokenizerself.debug_mode = debug_modeif self.debug_mode:self.df = self.df[:100]def __len__(self):return len(self.df)def __getitem__(self,item):sent = self.df['text'].iloc[item]enc_code = self.tokenizer.encode_plus(sent,max_length=self.max_seq,pad_to_max_length=True,truncation=True)input_ids = enc_code['input_ids']input_mask = enc_code['attention_mask']label = self.df['label'].iloc[item]return (torch.LongTensor(input_ids), torch.LongTensor(input_mask), int(label))

定义模型

class Model(nn.Module):def __init__(self,MODEL_PATH =None):super(Model, self).__init__()self.model = AutoModel.from_pretrained(MODEL_PATH)self.fc = nn.Linear(768, 2)def forward(self, input_ids, input_mask):sentence_emb = self.model(input_ids, attention_mask = input_mask).last_hidden_state # [batch,seq_len,emb_dim]sentence_emb = torch.mean(sentence_emb,dim=1)out = self.fc(sentence_emb)return out

读取数据

df = pd.read_csv('data/sentiment_analysis_clean.csv')
df = df.dropna()

声明训练及测试数据集

train_df, test_df = train_test_split(df,test_size=0.2,random_state=2024)#声明数据集
train_dataset = TextDataSet(train_df,tokenizer,debug_mode=debug_mode)
test_dataset = TextDataSet(test_df,tokenizer,debug_mode=debug_mode)#注意这里的num_workers参数,由于在win环境下对多线程支持不到位,所以这里只能设置为0,若使用服务器则可以设置其他的
train_loader = D.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_loader = D.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

将定义模型实例化

model = Model(MODEL_PATH)

打印模型结构

model

在这里插入图片描述

模型训练

criterion = nn.CrossEntropyLoss()
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.5, 0.999))
scaler = torch.cuda.amp.GradScaler()
#*****************************************************train*********************************************
for epoch in range(Epoch):model.train()correct = 0total = 0for i, batch_data in enumerate(train_loader):(input_ids, input_mask, label) = batch_datainput_ids = input_ids.to(device)input_mask = input_mask.to(device)label = label.to(device)optimizer.zero_grad()with torch.cuda.amp.autocast():logit = model(input_ids, input_mask)loss = criterion(logit, label)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()_, predicted = torch.max(logit.data, 1)total += label.size(0)correct += (predicted == label).sum().item()if i % 10 == 0:acc = 100 * correct / totalprint(f'Epoch [{epoch+1}/{Epoch}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}, Accuracy: {acc:.2f}%')

在这里插入图片描述

测试集效果

#*****************************************************Test*********************************************
correct = 0
total = 0
with torch.no_grad():model.eval()for batch_data in test_loader:(input_ids, input_mask, label) = batch_datainput_ids = input_ids.to(device)input_mask = input_mask.to(device)label = label.to(device)outputs = model(input_ids, input_mask)_, predicted = torch.max(outputs.data, 1)total += label.size(0)correct += (predicted == label).sum().item()print('#'*30+'Test Accuracy:{:7.3f} '.format(100 * correct / total)+'#'*30)print()

在这里插入图片描述



同类型项目

阿里云-零基础入门NLP【基于机器学习的文本分类】

阿里云-零基础入门NLP【基于深度学习的文本分类3-BERT】
也可以参考进行学习


学习的参考资料:
深度之眼

相关文章:

NLP_情感分类_预训练加微调方案

文章目录 项目背景代码导包一些模型以及训练的参数设置定义dataset定义模型读取数据声明训练及测试数据集将定义模型实例化打印模型结构模型训练测试集效果 同类型项目 项目背景 项目的目的,是为了对情感评论数据集进行预测打标。在训练之前,需要对数据…...

全网最适合入门的面向对象编程教程:36 Python的内置数据类型-字典

全网最适合入门的面向对象编程教程:36 Python 的内置数据类型-字典 摘要: 字典是非常好用的容器,它可以用来直接将一个对象映射到另一个对象。一个拥有属性的空对象在某种程度上说就是一个字典,属性名映射到属性值。在内部&#…...

DataWind看板绘制案例

摘要​: 1. 在不清楚DataWind看板怎么画的情况,可以先把表格给实现了,然后找几个有价值的数据进行看板实现 2. 还是不知道怎么画的情况,就去模仿其他人的案例; 3. 多看看DataWind提供的函数用法,就可以把表达式的使用运用起来了;​ 飞书官方文档:https://www.volcen…...

Golang | Leetcode Golang题解之第335题路径交叉

题目&#xff1a; 题解&#xff1a; func isSelfCrossing(distance []int) bool {n : len(distance)// 处理第 1 种情况i : 0for i < n && (i < 2 || distance[i] > distance[i-2]) {i}if i n {return false}// 处理第 j 次移动的情况if i 3 && di…...

C# 在Word中插入或删除分节符

在Word中&#xff0c;分节符是一种强大的工具&#xff0c;用于将文档分成不同的部分&#xff0c;每个部分可以有独立的页面设置&#xff0c;如页边距、纸张方向、页眉和页脚等。正确使用分节符可以极大地提升文档的组织性和专业性&#xff0c;特别是在长文档中&#xff0c;需要…...

基于STM32+Qt设计的无人超市收银系统(206)

文章目录 一、前言1.1 项目介绍【1】项目功能介绍【2】设计实现的功能【3】项目硬件模块组成1.2 设计思路【1】整体设计思路【2】上位机设计思路1.3 项目开发背景【1】选题的意义【2】可行性分析【3】参考文献【4】摘要【5】国内外技术发展现状1.4 开发工具的选择【1】设备端开…...

开源免费的表单收集系统TDuck

TDuck&#xff08;填鸭表单&#xff09;是一款开源免费的表单收集系统&#xff0c;它基于Apache 2.0协议开源&#xff0c;用户可以随时下载源码&#xff0c;自由修改和定制&#xff0c;也可以参与到项目的贡献和反馈中。TDuck表单系统不仅支持私有化部署&#xff0c;还提供了丰…...

Python 生成器、迭代器、可迭代对象 以及应用场景

Python 生成器&#xff08;Generators&#xff09; 生成器是一种特殊的迭代器&#xff0c;它使用 yield 语句来逐次产生数据&#xff0c;而不是一次性在内存中生成数据。这意呀着生成器提供了一种懒加载&#xff08;lazy evaluation&#xff09;的方式&#xff0c;非常适合处理…...

马斯克对欧盟的反应

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…...

uniapp + 安卓APP + H5 + 微信小程序实现PDF文件的预览和下载

文章目录 uniapp 安卓APP H5 微信小程序实现PDF文件的预览和下载1、用到的技术及插件2、简述操作&#xff1a;下载预览 3、上代码&#xff1a;(主要是写后端&#xff0c;前端不大熟&#xff0c;我感觉写的还凑活&#xff0c;不对的请指正嘻嘻)4、注意的问题 uniapp 安卓APP…...

Elasticsearch 8 RAG 技术分享

作者&#xff1a;来自 Elastic 中国区首席架构师 Jerry 本文由 Elastic 中国区首席架构师 Jerry Zhu 在【AI 搜索 TechDay】上的分享整理而成。【AI 搜索 TechDay】 是 Elastic 和阿里云联合主办的 AI 技术 Meetup 系列&#xff0c;聚焦企业级 AI 搜索应用和开发者动手实践&am…...

根据字典值回显,有颜色的

背景 本项目以若依前端vue2版本为例&#xff0c;项目中有根据字典值回显文本的函数selectDictLabel&#xff0c;但是有时候我们需要带颜色的回显&#xff0c;大概这样的 用法 <template v-slotscope><dict-label :options"dangerLevelOptions" :value&qu…...

多台PC网络ADB连接同一台RK3399 Android7.1.2设备

在RK3399 Android7.1.2上面&#xff0c;进行网络ADB调试时&#xff0c;如果多台电脑连接同一台Android设备&#xff0c;第一台连接上的能正常操作&#xff0c;之后连接的看到设备状态为OFFLINE&#xff0c;分析了下ADBD相关代码&#xff0c;发现在ACCEPT Client的时候没有区分别…...

前端黑科技:使用 JavaScript 实现网页扫码功能

在数字化时代&#xff0c;二维码已经渗透到我们生活的方方面面。从移动支付到产品溯源&#xff0c;二维码凭借其便捷性和高效性&#xff0c;成为了信息传递的重要载体。而随着前端技术的不断发展&#xff0c;我们甚至可以使用 JavaScript 在网页端实现二维码扫描功能&#xff0…...

【人工智能】全景解析:【机器学习】【深度学习】从基础理论到应用前景的【深度探索】

目录 1. 人工智能的基本概念 1.1 人工智能的定义与发展 1.1.1 人工智能的定义 1.1.2 人工智能的发展历史 1.2 人工智能的分类 1.2.1 弱人工智能 1.2.2 强人工智能 1.2.3 超人工智能 1.3 人工智能的关键组成部分 1.3.1 数据 1.3.2 算法 1.3.3 计算能力 2. 机器学习…...

MySQL与PostgreSQL语法区别

1. 数据类型差异 a. 整型 ● MySQL中的text数据类型最大存储容量为64KB&#xff0c;PostgreSQL中的text类型没有此限制。 ● MySQL中使用tinyint、mediumint和int表示不同大小的整数&#xff0c;PostgreSQL使用smallint、int和bigint。 b. 浮点数类型 ● MySQL提供了float和…...

vue2+OpenLayers 天地图上凸显出当前地理位置区域(4)

凸显出当前区域 需要当前地方的json数据 这个可以在阿里的这个阿里 看下效果图 遮盖层的逃命都是可以调的 引入 下面一段代码 import sx from "/views/json/sx1.json"; // 下载的json import GeoJSON from "ol/format/GeoJSON"; // ol的一些方法 imp…...

基于Python、Django开发Web计算器

1、创建项目 创建Django项目参照https://blog.csdn.net/qq_42148307/article/details/140798249&#xff0c;其中项目名为compute&#xff0c;并在该项目下创建一个名为app的应用&#xff0c;并且进行基本的配置。 2、导入Bootstrap前端框架 Bootstrap的使用参照https://blo…...

高性能并行计算面试-核心概念-问题理解

目录 1.什么是并行计算&#xff1f;高性能从哪些方面体现&#xff1f; 2.CPU常见的并行技术 3.GPU并行 4.并发与并行 5.常见的并行计算模型 6.如何评估并行程序的性能&#xff1f; 7.描述Am达尔定律和Gustafson定律&#xff0c;并解释它们对并行计算性能的影响 8.并行计…...

java-activiti笔记

版本&#xff1a;activiti7 <dependency><groupId>org.activiti</groupId><artifactId>activiti-json-converter</artifactId><version>7.0.0.Beta2</version><exclusions><exclusion><groupId>org.mybatis</g…...

**Modbus协议深度解析:基于Python的TCP通信实战与发散创新应用**在工业自动化领域,**Modbus协议

Modbus协议深度解析&#xff1a;基于Python的TCP通信实战与发散创新应用 在工业自动化领域&#xff0c;Modbus协议因其简单、稳定和开放性成为最广泛使用的串行通信标准之一。本文将从底层原理出发&#xff0c;深入剖析 Modbus TCP 的数据帧结构&#xff0c;并结合 Python 实现…...

树莓派C语言工程建立

从原来例子程序中拷贝一个例子例如blink目录到myPrj目录下&#xff0c;再拷贝其他几个文件&#xff0c;最终示意如下&#xff1a;修改CMakeLists.txt 文件&#xff0c;去除add_subdirectory(…)语句和add_subdirectory_exclude_platforms(…)语句&#xff0c;在最后增加 add_su…...

百度网盘提取码智能获取工具:提升资源访问效率的技术方案

百度网盘提取码智能获取工具&#xff1a;提升资源访问效率的技术方案 【免费下载链接】baidupankey 项目地址: https://gitcode.com/gh_mirrors/ba/baidupankey 核心价值&#xff1a;重新定义资源访问效率 &#x1f680; 在信息快速流转的今天&#xff0c;获取网络资源…...

说说你对spring的IOC的理解

面试 IOC指的就是控制反转&#xff0c;指的就是创建对象的控制权的转移&#xff0c;简单来说&#xff0c;由之前的手动new对象&#xff0c;转换成了由spring自动生产&#xff0c;spring利用java的反射机制&#xff0c;根据配置文件或注解在运行时动态创建并管理对象。...

网络安全这个技能学会了,不考研也能迅速找到高薪工作

网络安全这个技能学会了&#xff0c;不考研也能迅速找到高薪工作 近几年“考研热”持续升温&#xff0c;报名人数和报录比屡创新高。据数据显示&#xff1a;2003年全国考研人数仅仅才70万&#xff0c;直至2017年考研人数才刚刚突破200万。而今年考研人数居高达457万&#xff0…...

利用快马平台快速构建mcporter数据转换工具原型,十分钟验证数据管道设计

最近在做一个数据迁移项目时&#xff0c;遇到了需要频繁转换数据格式的需求。传统方式下&#xff0c;光是搭建开发环境、编写基础代码就要花上大半天时间。这次尝试用InsCode(快马)平台快速构建了一个mcporter数据转换工具原型&#xff0c;整个过程出乎意料地顺畅。 明确核心需…...

深度学习道路提取代码更换数据集后 PyCharm 闪退问题全面解决指南

深度学习道路提取代码更换数据集后 PyCharm 闪退问题全面解决指南 摘要 在基于深度学习的道路提取任务中,更换数据集后常出现 PyCharm 闪退现象。这类问题涉及环境配置、数据加载、内存管理、模型适配等多个层面,往往难以快速定位。本文从 Ubuntu 操作系统、PyCharm IDE、C…...

企业级React UI组件库实战指南:Element React深度解析与最佳实践

企业级React UI组件库实战指南&#xff1a;Element React深度解析与最佳实践 【免费下载链接】element-react Element UI 项目地址: https://gitcode.com/gh_mirrors/el/element-react Element React作为一款专业的企业级React UI组件库&#xff0c;为现代前端开发提供了…...

基于RAG的智能客服系统实战:从架构设计到生产环境优化

最近在做一个智能客服系统的升级项目&#xff0c;之前用规则引擎维护起来太痛苦了&#xff0c;纯用大模型又贵又不准。经过一番折腾&#xff0c;最终用RAG&#xff08;检索增强生成&#xff09;技术搞定了&#xff0c;效果提升非常明显。今天就来分享一下从架构设计到上线优化的…...

避坑指南:三自由度机械臂DH参数建模与逆解求解的那些‘坑’(从理论到Matlab/Python验证)

三自由度机械臂运动学建模实战&#xff1a;从DH参数陷阱到逆解验证 机械臂运动学建模是机器人学中最基础却最容易踩坑的领域之一。很多工程师和学生在理论学习阶段看似掌握了DH参数法和正逆运动学推导&#xff0c;但一旦动手实践&#xff0c;总会遇到各种"诡异"的问题…...