当前位置: 首页 > news >正文

【NLP实战】基于Bert和双向LSTM的情感分类【下篇】

文章目录

  • 前言
  • 简介
  • 第一部分
    • 关于pytorch lightning保存模型的机制
    • 关于如何读取保存好的模型
    • 完善测试代码
  • 第二部分
    • 第一次训练出的模型的过拟合问题
    • 如何解决过拟合
  • 后记

前言

本文涉及的代码全由博主自己完成,可以随意拿去做参考。如对代码有不懂的地方请联系博主。

博主page:issey的博客 - 愿无岁月可回首

本系列文章中不会说明环境和包如何安装,这些应该是最基础的东西,可以自己边查边安装。

许多函数用法等在代码里有详细解释,但还是希望各位去看它们的官方文档,我的代码还有很多可以改进的方法,需要的函数等在官方文档都有说明。

简介

本系列将带领大家从数据获取、数据清洗,模型构建、训练,观察loss变化,调整超参数再次训练,并最后进行评估整一个过程。我们将获取一份公开竞赛中文数据,并一步步实验,到最后,我们的评估可以达到排行榜13位的位置。但重要的不是排名,而是我们能在其中学到很多。

本系列共分为三篇文章,分别是:

  • 上篇:数据获取,数据分割与数据清洗
  • 中篇:模型构建,改进pytorch结构,开始第一次训练
  • 下篇:测试与评估,绘图与过拟合,超参数调整

本文为该系列第三篇文章,也是最后一篇。本文共分为两部分,在第一部分,我们将学习如何使用pytorch lightning保存模型的机制、如何读取模型与对测试集做测试。第二部分,我们将探讨前文遇到的过拟合问题,调整我们的超参数,进行第二轮训练,并对比两次训练的区别。我们还将基于pytorch lightning实现回调函数,保存训练过程中val_loss最小的模型。最后,将我们第二轮训练的best model进行评估,这一次,模型在测试集上的表现将达到排行榜第13位。

第一部分

关于pytorch lightning保存模型的机制

官方文档:Saving and loading checkpoints (basic) — PyTorch Lightning 2.0.1 documentation

简单来说,每次用lightning进行训练时,他都会自动保存最近epoch训练出的model参数在checkpoints里。而checkpoints默认在lightning_logs目录下。

你还可以同时保存某次训练的参数,或者写回调函数改变它保存模型的机制(这个我们待会儿会用到)。当然你也可以设置不让它自动保存模型。这一切都在官方文档里。博主就不细讲这些细节了,建议读者自己做实验。

现在我们知道了重要的两件事:

  1. 默认情况下,它会自动保存最近一次epoch训练结束后的模型。
  2. 我们只需要写回调函数,就可以改变它保存模型的机制。

关于如何读取保存好的模型

官方文档:Deploy models into production (basic) — PyTorch Lightning 2.0.1 documentation

根据文档,你还可以不用pytorch lightning,将模型读取到单纯的pytorch中,也可以使用。

感觉这部分讲的有点水?因为都在文档里,感觉没有需要逐一说明的地方。

现在,完善我们进行测试的代码。

完善测试代码

有几点需要说明:我们在测试时还计算了常用的评估标准:acc,recall,pre,f1。这里博主将通常需要用到的评估标准写法逐一列出了。我是根据函数说明一点一点摸索出来的,所以一并写出来方便以后用。

import torch
from datasets import load_dataset  # hugging-face dataset
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import torch.optim as optim
from torch.nn.functional import one_hot
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torchmetrics.functional import accuracy, recall, precision, f1_score  # lightning中的评估
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint# todo:自定义数据集
class MydataSet(Dataset):def __init__(self, path, split):self.dataset = load_dataset('csv', data_files=path, split=split)def __getitem__(self, item):text = self.dataset[item]['text']label = self.dataset[item]['label']return text, labeldef __len__(self):return len(self.dataset)# todo: 定义批处理函数
def collate_fn(data):sents = [i[0] for i in data]labels = [i[1] for i in data]# 分词并编码data = token.batch_encode_plus(batch_text_or_text_pairs=sents,  # 单个句子参与编码truncation=True,  # 当句子长度大于max_length时,截断padding='max_length',  # 一律补pad到max_length长度max_length=200,return_tensors='pt',  # 以pytorch的形式返回,可取值tf,pt,np,默认为返回listreturn_length=True,)# input_ids:编码之后的数字# attention_mask:是补零的位置是0,其他位置是1input_ids = data['input_ids']  # input_ids 就是编码后的词attention_mask = data['attention_mask']  # pad的位置是0,其他位置是1token_type_ids = data['token_type_ids']  # (如果是一对句子)第一个句子和特殊符号的位置是0,第二个句子的位置是1labels = torch.LongTensor(labels)  # 该批次的labels# print(data['length'], data['length'].max())return input_ids, attention_mask, token_type_ids, labels# todo: 定义模型,上游使用bert预训练,下游任务选择双向LSTM模型,最后加一个全连接层
class BiLSTMClassifier(nn.Module):def __init__(self, drop, hidden_dim, output_dim):super(BiLSTMClassifier, self).__init__()self.drop = dropself.hidden_dim = hidden_dimself.output_dim = output_dim# 加载bert中文模型,生成embedding层self.embedding = BertModel.from_pretrained('bert-base-chinese')# 去掉移至gpu# 冻结上游模型参数(不进行预训练模型参数学习)for param in self.embedding.parameters():param.requires_grad_(False)# 生成下游RNN层以及全连接层self.lstm = nn.LSTM(input_size=768, hidden_size=self.hidden_dim, num_layers=2, batch_first=True,bidirectional=True, dropout=self.drop)self.fc = nn.Linear(self.hidden_dim * 2, self.output_dim)# 使用CrossEntropyLoss作为损失函数时,不需要激活。因为实际上CrossEntropyLoss将softmax-log-NLLLoss一并实现的。def forward(self, input_ids, attention_mask, token_type_ids):embedded = self.embedding(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)embedded = embedded.last_hidden_state  # 第0维才是我们需要的embedding,embedding.last_hidden_state = embedding[0]out, (h_n, c_n) = self.lstm(embedded)output = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)output = self.fc(output)return output# todo: 定义pytorch lightning
class BiLSTMLighting(pl.LightningModule):def __init__(self, drop, hidden_dim, output_dim):super(BiLSTMLighting, self).__init__()self.model = BiLSTMClassifier(drop, hidden_dim, output_dim)  # 设置modelself.criterion = nn.CrossEntropyLoss()  # 设置损失函数self.train_dataset = MydataSet('./data/archive/train_clean.csv', 'train')self.val_dataset = MydataSet('./data/archive/val_clean.csv', 'train')self.test_dataset = MydataSet('./data/archive/test_clean.csv', 'train')def configure_optimizers(self):optimizer = optim.AdamW(self.parameters(), lr=lr)return optimizerdef forward(self, input_ids, attention_mask, token_type_ids):  # forward(self,x)return self.model(input_ids, attention_mask, token_type_ids)def train_dataloader(self):train_loader = DataLoader(dataset=self.train_dataset, batch_size=batch_size, collate_fn=collate_fn,shuffle=True)return train_loaderdef training_step(self, batch, batch_idx):input_ids, attention_mask, token_type_ids, labels = batch  # x, y = batchy = one_hot(labels + 1, num_classes=3)# 将one_hot_labels类型转换成floaty = y.to(dtype=torch.float)# forward passy_hat = self.model(input_ids, attention_mask, token_type_ids)y_hat = y_hat.squeeze()  # 将[128, 1, 3]挤压为[128,3]loss = self.criterion(y_hat, y)  # criterion(input, target)self.log('train_loss', loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)  # 将loss输出在控制台return loss  # 必须把log返回回去才有用def val_dataloader(self):val_loader = DataLoader(dataset=self.val_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)return val_loaderdef validation_step(self, batch, batch_idx):input_ids, attention_mask, token_type_ids, labels = batchy = one_hot(labels + 1, num_classes=3)y = y.to(dtype=torch.float)# forward passy_hat = self.model(input_ids, attention_mask, token_type_ids)y_hat = y_hat.squeeze()loss = self.criterion(y_hat, y)self.log('val_loss', loss, prog_bar=False, logger=True, on_step=True, on_epoch=True)return lossdef test_dataloader(self):test_loader = DataLoader(dataset=self.test_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)return test_loaderdef test_step(self, batch, batch_idx):input_ids, attention_mask, token_type_ids, labels = batchtarget = labels + 1  # 用于待会儿计算acc和f1-scorey = one_hot(target, num_classes=3)y = y.to(dtype=torch.float)# forward passy_hat = self.model(input_ids, attention_mask, token_type_ids)y_hat = y_hat.squeeze()pred = torch.argmax(y_hat, dim=1)acc = (pred == target).float().mean()loss = self.criterion(y_hat, y)self.log('loss', loss)# task: Literal["binary", "multiclass", "multilabel"],对应[二分类,多分类,多标签]#  average=None分别输出各个类别, 不加默认算平均re = recall(pred, target, task="multiclass", num_classes=class_num, average=None)pre = precision(pred, target, task="multiclass", num_classes=class_num, average=None)f1 = f1_score(pred, target, task="multiclass", num_classes=class_num, average=None)def log_score(name, scores):for i, score_class in enumerate(scores):self.log(f"{name}_class{i}", score_class)log_score("recall", re)log_score("precision", pre)log_score("f1", f1)self.log('acc', accuracy(pred, target, task="multiclass", num_classes=class_num))self.log('avg_recall', recall(pred, target, task="multiclass", num_classes=class_num, average="weighted"))self.log('avg_precision', precision(pred, target, task="multiclass", num_classes=class_num, average="weighted"))self.log('avg_f1', f1_score(pred, target, task="multiclass", num_classes=class_num, average="weighted"))def test():# 加载之前训练好的最优模型参数model = BiLSTMLighting.load_from_checkpoint(checkpoint_path=PATH,drop=dropout, hidden_dim=rnn_hidden, output_dim=class_num)trainer = Trainer(fast_dev_run=False)result = trainer.test(model)print(result)

输出:也就是上一篇末尾提前剧透的截图。

第二部分

第一次训练出的模型的过拟合问题

为什么提到之前的模型有过拟合问题呢?让我们打开tensorboard,观察train_lossval_loss

train_loss还没有收敛的趋势,但是val_loss已经出现了反弹的趋势。如果这还不算过拟合的预兆,博主做了第二个实验,我读取了第一次模型训练好的参数,并在次基础上继续训练,于是出现了以下的图像:

红色的线。可以看到,train_loss跟着橙色的线继续下降的,而val_loss直线上升,并且train_loss低于0.3时,val_loss高达0.9+。于是我们可以断定,过拟合了!

如何解决过拟合

最简单的方式是调参,我将batch_size由128调整到了256,将drop从0.4调整到了0.5,再次进行训练。同时,为了防止第二次也过拟合,我加入了回调函数,这个回调函数将保存过拟合之前最好的一组模型。这个回调函数的作用极为重要。下面给出最终版本的train代码:

def train():# 增加过拟合回调函数,提前停止,经过测试发现不太好用,因为可能会停止在局部最优值early_stop_callback = EarlyStopping(monitor='val_loss',  # 监控对象为'val_loss'patience=4,  # 耐心观察4个epochmin_delta=0.0,  # 默认为0.0,指模型性能最小变化量verbose=True,  # 在输出中显示一些关于early stopping的信息,如为何停止等)# 增加回调最优模型,这个比较好用checkpoint_callback = ModelCheckpoint(monitor='val_loss',  # 监控对象为'val_loss'dirpath='checkpoints/',  # 保存模型的路径filename='model-{epoch:02d}-{val_loss:.2f}',  # 最优模型的名称save_top_k=1,  # 只保存最好的那个  mode='min'  # 当监控对象指标最小时)# Trainer可以帮助调试,比如快速运行、只使用一小部分数据进行测试、完整性检查等,# 详情请见官方文档https://lightning.ai/docs/pytorch/latest/debug/debugging_basic.html# auto自适应gpu数量trainer = Trainer(max_epochs=epochs, log_every_n_steps=10, accelerator='gpu', devices="auto", fast_dev_run=False,precision=16, callbacks=[checkpoint_callback])model = BiLSTMLighting(drop=dropout, hidden_dim=rnn_hidden, output_dim=class_num)trainer.fit(model)if __name__ == '__main__':# todo:定义超参数batch_size = 256epochs = 30dropout = 0.5rnn_hidden = 768rnn_layer = 1class_num = 3lr = 0.001PATH = 'PATH'token = BertTokenizer.from_pretrained('bert-base-chinese')train()# test()

把他加入到上面的代码就行了。

关于回调函数的说明在代码里。

在第二天早上,我拿到了这次训练的结果:

对比第一个模型:

好吧,这次还是过拟合了,而且train loss居然低于了0.1,说明模型太复杂了。不过!由于我们的回调函数的存在,我们及时保存了val_loss最小时的模型。现在,将我们的模型路径换成best model,再次对测试集进行评估,我们会得到以下结果:

现在,它在排行榜第13位。

后记

终于写完了,一天肝完三篇文章。前面实验时在边实验边记录,所以写的比较快。

好像也没什么要写成后记的,该说的也都说完了。这三篇文章,其实就是这次实验的后记(笑)。

歇一歇,累~

还有很多不知道和要改进的地方,继续努力吧。

相关文章:

【NLP实战】基于Bert和双向LSTM的情感分类【下篇】

文章目录前言简介第一部分关于pytorch lightning保存模型的机制关于如何读取保存好的模型完善测试代码第二部分第一次训练出的模型的过拟合问题如何解决过拟合后记前言 本文涉及的代码全由博主自己完成,可以随意拿去做参考。如对代码有不懂的地方请联系博主。 博主…...

程序设计方法学

体育竞技分析 问题分析 体育竞技分析 需求:毫厘是多少? 如何科学分析体育竞技比赛? 输入:球员的水平 输出:可预测的比赛成绩 体育竞技分析:模拟N场比赛 计算思维:抽象 自动化 模拟&am…...

Hadoop之Yarn篇

目录 ​编辑 Yarn的工作机制: 全流程作业: Yarn的调度器与调度算法: FIFO调度器(先进先出): 容量调度器(Capacity Scheduler): 容量调度器资源分配算法&#xff1…...

Spring Cloud Nacos使用总结

目录 安装Nacos服务器 服务发现与消费 服务发现与消费-添加依赖 服务发现-配置文件 服务发现-注解 服务发现-Controller 服务消费-配置文件 服务消费-注解与Ribbon消费代码 服务消费-运行 配置管理 配置管理-添加依赖 配置管理-配置文件 配置管理-注解 配置管理-…...

目标检测框架yolov5环境搭建

目前,目标检测框架中,yolov5 是很火的,它基于pytorch框架,集成opencv等框架,项目地址:https://github.com/ultralytics/yolov5,对我来说,机器学习、深度学习才开始接触,本…...

Vulnhub:Digitalworld.local (JOY)靶机

kali:192.168.111.111 靶机:192.168.111.130 信息收集 端口扫描 nmap -A -v -sV -T5 -p- --scripthttp-enum 192.168.111.130 使用enum4linux枚举目标smb服务,发现两个系统用户 enum4linux -a 192.168.111.130 ftp可以匿名登陆&#xff…...

STL源码剖析-六大部件, 部件的关系,复杂度, 区间表示

C标准库-体系结构与内核分析 根据源代码来分析 介绍 自学C侯捷老师的STL源码剖析的个人笔记,方便以后进行学习,查询。 为什么要学STL?按侯捷老师的话来说就是:使用一个东西,却不明白它的道理,不高明&…...

总有一个可用的连接,metaIPC1.2进入智能连接新时代

概述 metaIPC有1.0和2.0两个产品系列,2.0版本是可视对讲IPC,1.0新版本1.2在全面兼容ICE规范基础上进行了扩展,使metaIPC1.2进入智能化连接新时代。 metaIPC1.2在host/stun/turn/srs/zlm/janus/freeswitch等p2p/sfu/mcu进行全方位连通测试&a…...

棋盘问题c

在一个给定形状的棋盘(形状可能是不规则的)上面摆放棋子,棋子没有区别。要求摆放时任意的两个棋子不能放在棋盘中的同一行或者同一列,请编程求解对于给定形状和大小的棋盘,摆放k个棋子的所有可行的摆放方案C。 Input …...

华纳云:Linux系统下怎么创建普通用户并更改用户组

本篇内容主要讲解“Linux系统下怎么创建普通用户并更改用户组”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Linux系统下怎么创建普通用户并更改用户组”吧! 要求 项目做权限管理,不用root部…...

「她时代」背后的欧拉力量

2018年大热电视剧《北京女子图鉴》,讲述了一群在北京打拼的职业女性,她们背井离乡,被现实包裹,被压力、责任困扰,但依旧用倔强的个性、不屈的进取心和深厚的知识技能努力营造、交织出一片励志的天空,既激昂…...

kubespray v2.21.0 在线部署 kubernetes v1.24.0 集群【2】

文章目录创建 虚拟机模板虚拟机名称配置静态地址配置代理yum 配置配置主机名安装 git安装 docker安装 ansible配置内核参数安装 k8s定制安装新增节点配置主机名配置代理配置互信更新 inventory报错kubespray v2.21.0 部署 kubernetes v1.24.0 集群 【1】在 Rocky linux 8.7 使用…...

聚焦运营商信创运维,美信时代监控易四大亮点值得一试!

2021年11月《“十四五”信息通信行业发展规划》提出,到2025年,我国将建立高速泛在、集成互联、智能绿色、安全可靠的新型数字基础设施体系。 此《规划》让我国运营商信创进一步加速,中国移动、中国电信、中国联通等都先后加入信创大军&#x…...

[python刷题模板] 博弈入门-记忆化搜索/dp/打表

[python刷题模板] 博弈入门-记忆化搜索/dp/打表 一、 算法&数据结构1. 描述2. 复杂度分析3. 常见应用4. 常用优化二、 模板代码1. 打表贪心的博弈2. 464. 我能赢吗3. Nim游戏--最最基础版n1。三、其他四、更多例题五、参考链接一、 算法&数据结构 1. 描述 博弈一直没…...

I2C通信

一、理论上了解I2C时序 I2C写数据时序如图: 通过解析器解析I2C通信如上图(SCL和SDA反了)。 1---起始信号 2、3---应答信号ACK 5---停止信号 起始信号:SCL线是高电平时,SDA线从高电平向低电平切换。 停…...

【Linux】man什么都搜不了,No manual entry for xxx的解决方案

本文首发于 慕雪的寒舍 man什么都搜不了,No manual entry for xxx的解决方案 系统 CentOS 7.6 1.问题描述 今天查手册的时候,发现man什么都查不了。不管是系统接口还是函数,都显示没有入口文档(No manual entry for)…...

STM32 库函数 GPIO_SetBits、GPIO_ResetBits、GPIO_WriteBit、GPIO_Write 区别

问题:当我使用STM32库函数对 I/O 口进行赋值时,在头文件中发现有四个相关的函数可以做这个操作,那么它们有什么区别呢? 一、GPIO_SetBits //eg: GPIO_SetBits(GPIOA, GPIO_Pin_1 | GPIO_Pin_2);解释:置位(置1)选择的数…...

在 RISC-V Linux 内核中添加模块

在 RISC-V Linux 内核中添加模块 flyfish 本例以添加helloworld字符设备为例 一 源码配置 1 源码 源码文件helloworld.c拷贝到 drivers/char 目录中 源码主要是输出Hello world init 2 Kconfig 打开drivers/char 目录下的Kconfig文件 在endmenu之前加上 config HELLO…...

利用AOP实现统一功能处理

目录 一、实现用户登录校验 实现自定义拦截器 将自定义的拦截器添加到框架的配置中,并且设置拦截的规则 二、实现统一异常处理 三、实现统一数据格式封装 一、实现用户登录校验 在之前的项目中,在需要验证用户登录的部分,每次都需要利…...

会话技巧---英文单词

目录 前言原文表示同意、答应表示不同意表示建议与忠告鼓励称赞担心与忧虑赞美夸奖-单词前言 加油 原文 表示同意、答应 1.agree[əˈgri]vi. 同意(=approve of); 答应(= consent to) agreement [əˈgrimənt] n. (意见或看法)一致 agree with sb about / on sth…...

华为云AI开发平台ModelArts

华为云ModelArts:重塑AI开发流程的“智能引擎”与“创新加速器”! 在人工智能浪潮席卷全球的2025年,企业拥抱AI的意愿空前高涨,但技术门槛高、流程复杂、资源投入巨大的现实,却让许多创新构想止步于实验室。数据科学家…...

谷歌浏览器插件

项目中有时候会用到插件 sync-cookie-extension1.0.0:开发环境同步测试 cookie 至 localhost,便于本地请求服务携带 cookie 参考地址:https://juejin.cn/post/7139354571712757767 里面有源码下载下来,加在到扩展即可使用FeHelp…...

Spark 之 入门讲解详细版(1)

1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处&…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

前端开发面试题总结-JavaScript篇(一)

文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包(Closure)?闭包有什么应用场景和潜在问题?2.解释 JavaScript 的作用域链(Scope Chain) 二、原型与继承3.原型链是什么?如何实现继承&a…...

微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据

微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据 Power Query 具有大量专门帮助您清理和准备数据以供分析的功能。 您将了解如何简化复杂模型、更改数据类型、重命名对象和透视数据。 您还将了解如何分析列,以便知晓哪些列包含有价值的数据,…...

根目录0xa0属性对应的Ntfs!_SCB中的FileObject是什么时候被建立的----NTFS源代码分析--重要

根目录0xa0属性对应的Ntfs!_SCB中的FileObject是什么时候被建立的 第一部分: 0: kd> g Breakpoint 9 hit Ntfs!ReadIndexBuffer: f7173886 55 push ebp 0: kd> kc # 00 Ntfs!ReadIndexBuffer 01 Ntfs!FindFirstIndexEntry 02 Ntfs!NtfsUpda…...

如何配置一个sql server使得其它用户可以通过excel odbc获取数据

要让其他用户通过 Excel 使用 ODBC 连接到 SQL Server 获取数据,你需要完成以下配置步骤: ✅ 一、在 SQL Server 端配置(服务器设置) 1. 启用 TCP/IP 协议 打开 “SQL Server 配置管理器”。导航到:SQL Server 网络配…...

高考志愿填报管理系统---开发介绍

高考志愿填报管理系统是一款专为教育机构、学校和教师设计的学生信息管理和志愿填报辅助平台。系统基于Django框架开发,采用现代化的Web技术,为教育工作者提供高效、安全、便捷的学生管理解决方案。 ## 📋 系统概述 ### 🎯 系统定…...

[特殊字符] 手撸 Redis 互斥锁那些坑

📖 手撸 Redis 互斥锁那些坑 最近搞业务遇到高并发下同一个 key 的互斥操作,想实现分布式环境下的互斥锁。于是私下顺手手撸了个基于 Redis 的简单互斥锁,也顺便跟 Redisson 的 RLock 机制对比了下,记录一波,别踩我踩过…...