当前位置: 首页 > news >正文

MXNet中使用双向循环神经网络BiRNN对文本进行情感分类

文本分类类似于图片分类,也是很常见的一种分类任务,将一段不定长的文本序列变换为文本的类别。这节主要就是关注文本的情感分析(sentiment analysis),对电影的评论进行一个正面情绪与负面情绪的分类。

整理数据集

第一步都是将数据集整理好,这里我们使用"大型电影评论数据集"LMDB(Large Movie Review Dataset v1.0),该数据集包含电影评论及其相关二进制情感标签。标签的整体分布是平衡的,一半的正类标签和一半的负类标签,另外有一些未贴标签的用于无监督学习。电影评分满分是10分,将评分>=7分的判定为正面评论,评论得分<= 4分则为负面评论。

下载数据集,可以使用自带的函数

import d2lzh as d2l
d2l.download_imdb(data_dir='data')

或者手动下载:http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz

自动下载虽然只有80M的大小,但是下载特别慢。这里依然推荐迅雷下载,下载下来之后就手动解压(自动下载的函数包括自动解压)

我们先来看下这个数据集里面有一些什么内容,本人地址截图如下:

可以看到有traintest两个数据集,里面都有negpos的评论,分别表示负面和正面的评论:

每个文本是一条影评,文本名称构造:id_评分,比如上面图中的200_8.txt表示id为200的这条影评的评分是8分。

还有一种feat文件,如下图:

这种.feat文件的格式为LIBSVM,是一种用于标记的ascii稀疏向量格式数据,比如图片中红色划线处的第200条评论,8后面的数字表示什么意思呢?

8 0:5 1:2 3:1 4:2 6:4 7:7 8:4 9:2 10:2 11:3 16:1 17:3 ... ...

这里的0:5表示第一个单词出现了5次,1:2就是第二个单词出现了2次,后面依次类推。

接下来使用自带的read_imdb函数来读取训练集和测试集,当然这里使用自带的函数需要注意目录的位置,将aclImdb整个目录剪切到上级目录data里面,比如本人电脑上的地址:D:\data\aclImdb

train_data, test_data = d2l.read_imdb("train"), d2l.read_imdb("test")
print(train_data[1])
'''
(pygpu) D:\DOG-BREED>python test.py
["i went to this movie expecting an artsy scary film. what i got was scare after scare. it's a horror film at it's core. it's not dull like other horror films where a haunted house just has ghosts and gore. this film doesn't even show you the majority of the deaths it shows the fear of the characters. i think one of the best things about the concept where it's not just the house thats haunted its whoever goes into the house. they become haunted no matter where they are. office buildings, police stations, hotel rooms... etc. after reading some of the external reviews i am really surprised that critics didn't like this film. i am going to see it again this week and am excited about it.<br /><br />i gave this film 10 stars because it did what a horror film should. it scared the s**t out of me.", 1]
'''

返回的结果是列表,里面元素是评论加一个正负类标签。这里是赞叹这部恐怖片拍的很不错,后面的1表示正类评价。

上面两个函数的源码附上[../envs/pygpu/Lib/site-packages/d2lzh/utils.py]

def download_imdb(data_dir='../data'):"""Download the IMDB data set for sentiment analysis."""url = ('http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz')sha1 = '01ada507287d82875905620988597833ad4e0903'fname = gutils.download(url, data_dir, sha1_hash=sha1)with tarfile.open(fname, 'r') as f:f.extractall(data_dir)def read_imdb(folder='train'):"""Read the IMDB data set for sentiment analysis."""data = []for label in ['pos', 'neg']:folder_name = os.path.join('../data/aclImdb/', folder, label)for file in os.listdir(folder_name):with open(os.path.join(folder_name, file), 'rb') as f:review = f.read().decode('utf-8').replace('\n', '').lower()data.append([review, 1 if label == 'pos' else 0])random.shuffle(data)return data

预处理数据集

数据集和测试集读取没有问题之后,我们对评论进行分词,这里基于空格分词,也是自带的函数get_tokenized_imdb进行分词并做了小写处理。

def get_tokenized_imdb(data):"""Get the tokenized IMDB data set for sentiment analysis."""def tokenizer(text):return [tok.lower() for tok in text.split(' ')]return [tokenizer(review) for review, _ in data]

然后将分好词的训练数据集创建Vocabulary词典,我们这里过滤掉出现次数少于5的词,min_freq=5。

def get_vocab_imdb(data):"""Get the vocab for the IMDB data set for sentiment analysis."""tokenized_data = get_tokenized_imdb(data)counter = collections.Counter([tk for st in tokenized_data for tk in st])return text.vocab.Vocabulary(counter, min_freq=5)tokenized_data = d2l.get_tokenized_imdb(train_data)
vocab=d2l.get_vocab_imdb(train_data)
print(len(vocab))#46151

可以看到过滤掉次数少的之后,词汇量从25000降低到了46151,这里返回的变量vocabmxnet.contrib.text.vocab.Vocabulary类型,我们可以查看它里面有哪些属性与方法:

dir(mxnet.contrib.text.vocab.Vocabulary)
'''
['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_index_counter_keys', '_index_unknown_and_reserved_tokens', 'idx_to_token', 'reserved_tokens', 'to_indices', 'to_tokens', 'token_to_idx', 'unknown_token']
'''
print(vocab.idx_to_token[1])#the

由于每条评论的字数或说长度不一样,所以不能直接组合成小批量,我们通过一个辅助函数让它的长度固定在500,超出的进行截断,不足的进行'<pad>'补足。这个函数preprocess_imdb在d2lzh包中也自带有

features, labels = d2l.preprocess_imdb(train_data, vocab)
print(features.shape, labels.shape)#(25000, 500) (25000,)

从形状可以看到每条评论都固定到了长度为500

print(features)
'''
[[5.0000e+00 5.3200e+02 0.0000e+00 ... 0.0000e+00 0.0000e+00 0.0000e+00][2.0100e+02 5.4810e+03 4.2891e+04 ... 1.6000e+01 2.9200e+02 1.1000e+01][0.0000e+00 0.0000e+00 3.6000e+01 ... 0.0000e+00 0.0000e+00 0.0000e+00]...[9.0000e+00 2.2600e+02 3.0000e+00 ... 0.0000e+00 0.0000e+00 0.0000e+00][2.8690e+03 1.2220e+03 1.4000e+01 ... 1.1538e+04 5.2700e+02 2.9000e+01][9.0000e+00 1.9900e+02 1.2108e+04 ... 0.0000e+00 0.0000e+00 0.0000e+00]]
<NDArray 25000x500 @cpu(0)>
'''

附上源码:

def preprocess_imdb(data, vocab):"""Preprocess the IMDB data set for sentiment analysis."""max_l = 500def pad(x):return x[:max_l] if len(x) > max_l else x + [0] * (max_l - len(x))tokenized_data = get_tokenized_imdb(data)features = nd.array([pad(vocab.to_indices(x)) for x in tokenized_data])labels = nd.array([score for _, score in data])return features, labels

当然如果想要查看'<pad>'对应的值,print(vocab.token_to_idx['<pad>'])会报错:

Traceback (most recent call last):
File "test.py", line 19, in <module>
print(vocab.token_to_idx['<pad>'])
KeyError: '<pad>'

所以在创建词典Vocabulary的时候,需指定参数reserved_tokens=['<pad>']保留这个词

def get_vocab_imdb(data):"""Get the vocab for the IMDB data set for sentiment analysis."""tokenized_data = d2l.get_tokenized_imdb(data)counter = collections.Counter([tk for st in tokenized_data for tk in st])return text.vocab.Vocabulary(counter, min_freq=5,reserved_tokens=['<pad>'])

创建数据迭代器

数据集都整理好了之后,就开始做数据迭代器,每次迭代将返回一个小批量的数据

batch_size = 64
#train_set = gdata.ArrayDataset(*d2l.preprocess_imdb(train_data, vocab))
train_set=gdata.ArrayDataset(*[features,labels])
test_set = gdata.ArrayDataset(*d2l.preprocess_imdb(test_data, vocab))
train_iter = gdata.DataLoader(train_set, batch_size, shuffle=True)
test_ieter = gdata.DataLoader(test_set, batch_size)print(len(train_iter))
for X,y in train_iter:print(X.shape,y.shape)break
'''
391
(64, 500) (64,)
'''

创建RNN模型

数据迭代器测试没有问题之后,接下来就是选择循环神经网络模型来试下效果怎么样了。

首先就是将每个词做嵌入,也就是通过嵌入层得到特征向量,然后我们使用双向循环神经网络对特征序列进一步编码得到序列信息,最后将编码的序列信息通过全连接层变换成输出。

具体来说,我们可以将双向长短期记忆在最初时间步和最终时间步的隐藏状态连结,作为特征序列的表征传递给输出层分类。在下面实现BiRNN类中,Embedding实例就是嵌入层,LSTM实例即为序列编码的隐藏层,Dense实例即生成分类结果的输出层。

class BiRNN(nn.Block):def __init__(self, vocab, embed_size, num_hiddens, num_layers, **kwargs):super(BiRNN, self).__init__(**kwargs)# 词嵌入层self.embedding = nn.Embedding(input_dim=len(vocab), output_dim=embed_size)# bidirectional设为True就是双向循环神经网络self.encoder = rnn.LSTM(hidden_size=num_hiddens,num_layers=num_layers,bidirectional=True,input_size=embed_size,)self.decoder = nn.Dense(2)def forward(self, inputs):# LSTM需要序列长度(词数)作为第一维,所以inputs[形状为:(批量大小,词数)]需做转置embeddings = self.embedding(inputs.T)print(embeddings.shape)outputs = self.encoder(embeddings)print(outputs.shape)# 将初始时间步和最终时间步的隐藏状态作为全连接层输入encoding = nd.concat(outputs[0], outputs[-1])print(encoding.shape)outs = self.decoder(encoding)return outs# 创建一个含2个隐藏层的双向循环神经网络
embed_size, num_hiddens, num_layers, ctx = 100, 100, 2, d2l.try_all_gpus()
net = BiRNN(vocab=vocab, embed_size=embed_size, num_hiddens=num_hiddens, num_layers=num_layers
)
net.initialize(init.Xavier(), ctx=ctx)
#print(net)
'''
BiRNN((embedding): Embedding(46152 -> 100, float32)(encoder): LSTM(100 -> 100, TNC, num_layers=2, bidirectional)(decoder): Dense(None -> 2, linear)
)
'''

其中LSTM长短期记忆的公式如下(来自源码):

训练模型

由于情感分类的训练数据集并不大,容易过拟合,所以这里将使用glove.6B.100d.txt的语料库,将这个预训练的词向量作为每个词的特征向量。

需要注意的是,这里选择的预训练词向量维度是100,需要跟创建的模型中的嵌入层输出层大小embed_size一致,以及在训练中就不再需要更新这些词向量。

glove_embedding = text.embedding.create("glove", pretrained_file_name="glove.6B.100d.txt", vocabulary=vocab
)
net.embedding.weight.set_data(glove_embedding.idx_to_vec)
net.embedding.collect_params().setattr('grad_req','null')lr,num_epochs=0.01,5
trainer=gluon.Trainer(net.collect_params(),'adam',{'learning_rate':lr})
loss=gloss.SoftmaxCrossEntropyLoss()
d2l.train(train_iter,test_ieter,net,loss,trainer,ctx,num_epochs)print(d2l.predict_sentiment(net,vocab,['this','movie','is','so','good']))
print(d2l.predict_sentiment(net,vocab,['this','movie','is','so','bad']))
'''
training on [gpu(0)]
epoch 1, loss 0.6553, train acc 0.605, test acc 0.738, time 65.4 sec
epoch 2, loss 0.4273, train acc 0.807, test acc 0.809, time 65.4 sec
epoch 3, loss 0.3514, train acc 0.851, test acc 0.849, time 65.5 sec
epoch 4, loss 0.3054, train acc 0.874, test acc 0.859, time 65.6 sec
epoch 5, loss 0.2765, train acc 0.887, test acc 0.843, time 65.6 sec
positive
negative
'''

其中预测函数的源码如下:

def predict_sentiment(net, vocab, sentence):"""Predict the sentiment of a given sentence."""sentence = nd.array(vocab.to_indices(sentence), ctx=try_gpu())label = nd.argmax(net(sentence.reshape((1, -1))), axis=1)return 'positive' if label.asscalar() == 1 else 'negative'

相关文章:

MXNet中使用双向循环神经网络BiRNN对文本进行情感分类

文本分类类似于图片分类&#xff0c;也是很常见的一种分类任务&#xff0c;将一段不定长的文本序列变换为文本的类别。这节主要就是关注文本的情感分析(sentiment analysis)&#xff0c;对电影的评论进行一个正面情绪与负面情绪的分类。整理数据集第一步都是将数据集整理好&…...

SpringBoot 整合 MongoDB 6 以上版本副本集及配置 SSL / TLS 协议

续上一篇 Linux 中使用 docker-compose 部署 MongoDB 6 以上版本副本集及配置 SSL / TLS 协议 前提&#xff1a;此篇文章是对上一篇文章的实战和项目中相关配置的使用&#xff0c;我这边针对 MongoDB 原有基础上做了增强&#xff0c;简化了 MongoDB 配置 SSL / TLS 协议上的支…...

C语言static关键字

目录static修饰局部变量static修饰全局变量static修饰函数static是C语言的关键字&#xff0c;它有静态的意思static的三种用法&#xff1a;修饰局部变量修饰全局变量修饰函数 static修饰局部变量 我们先看一个程序&#xff1a; void print() {int a 0;a;printf("%d\n&…...

【华为OD机试模拟题】用 C++ 实现 - 单词接龙(2023.Q1)

最近更新的博客 【华为OD机试模拟题】用 C++ 实现 - 货币单位换算(2023.Q1) 【华为OD机试模拟题】用 C++ 实现 - 选座位(2023.Q1) 【华为OD机试模拟题】用 C++ 实现 - 停车场最大距离(2023.Q1) 【华为OD机试模拟题】用 C++ 实现 - 重组字符串(2023.Q1) 【华为OD机试模…...

PHP基础(2)

PHP基础常用函数数组及多维数组数组遍历强制类型转换运算符赋值与基本运算字符串运算逻辑运算符常用函数 substr的用法是&#xff1a;substr&#xff08;目标字符串&#xff0c;从字符串的哪个位置开始&#xff0c;然后返回往后的几个字符&#xff09;strchr的用法是&#xff1…...

Java8(JDK1.8)新特性

一、Java8(JDK1.8)新特性 1、Lamdba表达式 2、函数式接口 3、方法引用和构造引用 4、Stream API 5、接口中的默认方法和静态方法 6、新时间日期API 7、OPtional 8、其他特性 二、java8&#xff08;JDK1.8&#xff09;新特性简介 1、速度快&#xff1b; 2、代码少、简…...

【C语言】指针的定义和使用

指针一、什么是指针二、指针类型三、指针和数组的关系四、空指针五、野指针一、什么是指针 指针&#xff08;Pointer&#xff09;是编程语言中的一个对象&#xff0c;通过地址直接指向内存中该地址的值。由于通过地址能够找到所需的变量存储单元&#xff0c;可以说地址指向该变…...

Parameter ‘zpspid‘ not found

异常&#xff1a;nested exception is org.apache.ibatis.binding.BindingException: Parameter testypid not found. Available parameters are [ztpsXmjcxx, pageable, param1, param2]分析&#xff1a;以为是xml文件中没有对应的字段&#xff0c;一细看了几遍是有这个字段的…...

23、高自由度下的E类波形理论计算(附Matlab代码)

23、高自由度下的E类波形理论计算&#xff08;附Matlab代码&#xff09; 0、代码 任意占空比、电压导数条件下的E类波形与阻抗条件计算Matlab 注意修改路径&#xff0c;我这边是&#xff1a;&#xff01;&#xff01;&#xff01;&#xff01;&#xff01;&#xff01;&#…...

软件测试:用“bug”来表示“在电脑程序里的错误”

计算机基础知识计算机&#xff08;personal computer&#xff09;俗称电脑&#xff08;pc&#xff09;&#xff0c;是现代一种用于高速计算的电子机器&#xff0c;可以进行数值计算&#xff0c;又可以进行逻辑判断&#xff0c;还具有存储记忆功能&#xff0c;且能够按照程序的运…...

Git命令

git init # 初始化本地git仓库&#xff08;创建新仓库&#xff09;git config --global user.name "xxx" # 配置用户名git config --global user.email "xxxxxx.com" # 配置邮件git config --global color.ui true # git status等命令自动着色git config -…...

Java的异常概念和类型

Java是一种流行的编程语言&#xff0c;拥有强大的异常处理机制&#xff0c;以帮助开发人员在程序出现异常时更好地处理错误情况。本文将介绍Java异常的概念和类型。异常的概念在Java中&#xff0c;异常是指在程序运行时发生的错误或异常情况。例如&#xff0c;当程序试图打开不…...

【Leedcode】环形链表必备的面试题和证明题(附图解)

环形链表必备的面试题和证明题&#xff08;附图解&#xff09; 文章目录环形链表必备的面试题和证明题&#xff08;附图解&#xff09;前言一、第一题1.题目2.思路3.代码4.延伸问题(1)证明题一&#xff1a;(2)证明题二&#xff1a;二、第二题1.题目2.思路延伸的证明题总结前言 …...

Vulnhub靶场----7、DC-7

文章目录一、环境搭建二、渗透流程三、思路总结一、环境搭建 DC-7下载地址&#xff1a;https://download.vulnhub.com/dc/DC-7.zip kali&#xff1a;192.168.144.148 DC-7&#xff1a;192.168.144.155 二、渗透流程 nmap -T5 -A -p- -sV -sT 192.168.144.155思路&#xff1a; …...

【Unity VR开发】结合VRTK4.0:创建滑块

语录&#xff1a; 只有经历地狱般的磨练&#xff0c;才能炼出创造天堂的力量。 前言&#xff1a; 滑块是一个非常简单的控件&#xff0c;它允许通过沿有限的驱动轴滑动 Interactable 来选择不同的值。我们将使用线性驱动器创建一个滑块控件&#xff0c;该控件允许我们根据与滑…...

Latex中的表格(2)

Latex中的表格一、一个加脚注的三线表的例子二、表格中加注释三、并排的表格3.1 使用小页环境并排表格3.2 使用子表格并排表格四、一个复杂的表格五、一个长表格这篇文章主要罗列一些特殊的表格例子。内容来自&#xff1a;一篇北师大学位论文模板&#xff0c;详见https://githu…...

(七)输运定理

本文主要内容包括&#xff1a;1. 物质积分2. 曲线上物质积分的时间变化率3. 曲面上物质积分的时间变化率4. 体积域上物质积分的时间变化率 (Reynolds 输运定理)1. 物质积分 考虑 t0t_0t0​ 时刻参考构型中由物质点 X⃗\vec{X}X所形成的 物质曲线 ct0c_{t_0}ct0​​、物质曲面 …...

ABBYYFineReader15免费电脑pdf文档文字识别软件

ABBYYFineReader是一款OCR文字识别软件&#xff0c;它可以对图片、文档等进行扫描识别&#xff0c;并将其转换为可编辑的格式&#xff0c;比如Word、Excel等&#xff0c;操作也是挺方便的。 我们在官网找到该软件并进行下载&#xff0c;打开软件后&#xff0c;选择转换为“Mic…...

顺序表(超详解哦)

全文目录引言顺序表定义静态顺序表动态顺序表动态顺序表的接口实现顺序表的初始化与销毁顺序表尾插/尾删顺序表头插/头删顺序表在pos位置插入/删除顺序表的打印顺序表中的查找总结引言 在生产中&#xff0c;为了方便管理数据&#xff0c;我们经常会需要将一些数据连续的存储起…...

Compose-Animation高级别动画

目录前言AnimatedVisibilityisScrollingUpFABscaffoldanimateContentSizeCrossfade顶部气泡下弹前言 AnimatedVisibility 驱动可视性相关动画&#xff0c;即布局显隐 animateContentSize 内容变换动画相关 Crossfade 布局&#xff08;或者页面&#xff09;切换过渡动画 Animat…...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

Vue记事本应用实现教程

文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展&#xff1a;显示创建时间8. 功能扩展&#xff1a;记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

【Linux】C语言执行shell指令

在C语言中执行Shell指令 在C语言中&#xff0c;有几种方法可以执行Shell指令&#xff1a; 1. 使用system()函数 这是最简单的方法&#xff0c;包含在stdlib.h头文件中&#xff1a; #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...

前端导出带有合并单元格的列表

// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...

系统设计 --- MongoDB亿级数据查询优化策略

系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log&#xff0c;共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题&#xff0c;不能使用ELK只能使用…...

质量体系的重要

质量体系是为确保产品、服务或过程质量满足规定要求&#xff0c;由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面&#xff1a; &#x1f3db;️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限&#xff0c;形成层级清晰的管理网络&#xf…...

BCS 2025|百度副总裁陈洋:智能体在安全领域的应用实践

6月5日&#xff0c;2025全球数字经济大会数字安全主论坛暨北京网络安全大会在国家会议中心隆重开幕。百度副总裁陈洋受邀出席&#xff0c;并作《智能体在安全领域的应用实践》主题演讲&#xff0c;分享了在智能体在安全领域的突破性实践。他指出&#xff0c;百度通过将安全能力…...

EtherNet/IP转DeviceNet协议网关详解

一&#xff0c;设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络&#xff0c;本网关连接到EtherNet/IP总线中做为从站使用&#xff0c;连接到DeviceNet总线中做为从站使用。 在自动…...

Pinocchio 库详解及其在足式机器人上的应用

Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库&#xff0c;专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性&#xff0c;并提供了一个通用的框架&…...

HarmonyOS运动开发:如何用mpchart绘制运动配速图表

##鸿蒙核心技术##运动开发##Sensor Service Kit&#xff08;传感器服务&#xff09;# 前言 在运动类应用中&#xff0c;运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据&#xff0c;如配速、距离、卡路里消耗等&#xff0c;用户可以更清晰…...