当前位置: 首页 > news >正文

Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建

Mindspore框架循环神经网络RNN模型实现情感分类

Mindspore框架循环神经网络RNN模型实现情感分类|(一)IMDB影评数据集准备
Mindspore框架循环神经网络RNN模型实现情感分类|(二)预训练词向量
Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建
Mindspore框架循环神经网络RNN模型实现情感分类|(四)损失函数与优化器
Mindspore框架循环神经网络RNN模型实现情感分类|(五)模型训练
Mindspore框架循环神经网络RNN模型实现情感分类|(六)模型加载和推理(情感分类模型资源下载)
Mindspore框架循环神经网络RNN模型实现情感分类|(七)模型导出ONNX与应用部署

tips:安装依赖库

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install tqdm requests

一、RNN模型构建

数据集准备完成了输入文本通过查字典(序列化)的向量化。并使用nn.Embedding层加载了Glove词向量。下一步将使用RNN循环神经网络做特征提取,最后将RNN连接至全连接网络nn.Dednse,将特征转化为分类。

nn.Embedding -> nn.RNN -> nn.Dense

本项目,采用规避RNN梯度消的变种LSTM(Long short-term memory)代替RNN做特征提取层。

1.1 关于RNN

循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的神经网络。下图为RNN的一般结构:

RNN-0

图示左侧为一个RNN Cell循环,右侧为RNN的链式连接平铺。实际上不管是单个RNN Cell还是一个RNN网络,都只有一个Cell的参数,在不断进行循环计算中更新。

由于RNN的循环特性,和自然语言文本的序列特性(句子是由单词组成的序列)十分匹配,因此被大量应用于自然语言处理研究中。下图为RNN的结构拆解:

RNN

1.2 关于LSTM(Long short-term memory)

RNN单个Cell的结构简单,因此也造成了梯度消失(Gradient Vanishing)问题,具体表现为RNN网络在序列较长时,在序列尾部已经基本丢失了序列首部的信息。为了克服这一问题,LSTM(Long short-term memory)被提出,通过门控机制(Gating Mechanism)来控制信息流在每个循环步中的留存和丢弃。下图为LSTM的结构拆解:

LSTM

本项目选择LSTM变种而不是经典的RNN做特征提取,可规避梯度消失问题,并获得更好的模型效果。
在MindSpore中nn.LSTM对应的公式:

h 0 : t , ( h t , c t ) = LSTM ( x 0 : t , ( h 0 , c 0 ) ) h_{0:t}, (h_t, c_t) = \text{LSTM}(x_{0:t}, (h_0, c_0)) h0:t,(ht,ct)=LSTM(x0:t,(h0,c0))

这里nn.LSTM隐藏了整个循环神经网络在序列时间步(Time step)上的循环,送入输入序列、初始状态,即可获得每个时间步的隐状态(hidden state`)拼接而成的矩阵,以及最后一个时间步对应的隐状态。我们使用最后的一个时间步的隐状态作为输入句子的编码特征,送入下一层

Time step:在循环神经网络计算的每一次循环,成为一个Time step。在送入文本序列时,一个Time step对应一个单词。因此在本例中,LSTM的输出 h 0 : t h_{0:t} h0:t对应每个单词的隐状态集合, h t h_t ht c t c_t ct对应最后一个单词对应的隐状态。

下一层:全连接层,即nn.Dense,将特征维度变换为二分类所需的维度1,经过Dense层后的输出即为模型预测结果。

1.3 特征提取网络构建

RNN循环神经网络: nn.LSTM()
初始化参数:

 embeddings:输入向量,hidden_dim:隐藏层特征的维度, output_dim:输出维数, n_layers:RNN 层的数量,bidirectional:是否为双向 RNN, pad_idx:padding_idx参数用于标记输入中的填充值(padding value)。在自然语言处理任务中,文本序列的长度不一致是非常常见的。为了能够对不同长度的文本序列进行批处理,我们通常会使用填充值对较短的序列进行填补。

tips:使用nn.embeddings()创建嵌入层时,可以通过padding_idx参数指定一个特定的索引,用于表示填充值。
embedding_layer = nn.Embedding(num_embeddings, embedding_dim, padding_idx=0),将padding_idx设置为0,表示使用索引为0的词汇作为填充值。在文本序列中,我们将使用0来填充较短的序列。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers,bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers=n_layers,bidirectional=bidirectional,batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output

实例化模型,打印输出

hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
print(model)

在这里插入图片描述

相关文章:

Mindspore框架循环神经网络RNN模型实现情感分类|(三)RNN模型构建

Mindspore框架循环神经网络RNN模型实现情感分类 Mindspore框架循环神经网络RNN模型实现情感分类|&#xff08;一&#xff09;IMDB影评数据集准备 Mindspore框架循环神经网络RNN模型实现情感分类|&#xff08;二&#xff09;预训练词向量 Mindspore框架循环神经网络RNN模型实现…...

深度解读大语言模型中的Transformer架构

一、Transformer的诞生背景 传统的循环神经网络&#xff08;RNN&#xff09;和长短期记忆网络&#xff08;LSTM&#xff09;在处理自然语言时存在诸多局限性。RNN 由于其递归的结构&#xff0c;在处理长序列时容易出现梯度消失和梯度爆炸的问题。这导致模型难以捕捉长距离的依…...

安装好anaconda,打开jupyter notebook,新建 报500错

解决办法&#xff1a; 打开anaconda prompt 输入 jupyter --version 重新进入jupyter notebook&#xff1a; 可以成功进入进行代码编辑...

C++20之设计模式:状态模式

状态模式 状态模式状态驱动的状态机手工状态机Boost.MSM 中的状态机总结 状态模式 我必须承认:我的行为是由我的状态支配的。如果我没有足够的睡眠&#xff0c;我会有点累。如果我喝了酒&#xff0c;我就不会开车了。所有这些都是状态(states)&#xff0c;它们支配着我的行为:…...

数据库安全综合治理方案(可编辑54页PPT)

引言&#xff1a;数据库安全综合治理方案是一个系统性的工作&#xff0c;需要从多个方面入手&#xff0c;综合运用各种技术和管理手段&#xff0c;确保数据库系统的安全稳定运行。 方案介绍&#xff1a; 数据库安全综合治理方案是一个综合性的策略&#xff0c;旨在确保数据库系…...

人工智能:大语言模型提示注入攻击安全风险分析报告下载

大语言模型提示注入攻击安全风险分析报告下载 今天分享的是人工智能AI研究报告&#xff1a;《大语言模型提示注入攻击安全风险分析报告》。&#xff08;报告出品方&#xff1a;大数据协同安全技术国家工程研究中心安全大脑国家新一代人工智能开放创新平台&#xff09; 研究报告…...

【购买源码时有许多需要注意的坑】

购买源码时有许多需要注意的“坑”&#xff0c;这些坑可能会对项目的后续开发和使用造成严重影响。以下是一些需要特别注意的方面&#xff1a; 源码的完整性 编译测试&#xff1a;确保到手的源码能够从头至尾编译、打包、部署和功能测试无误。这一步非常关键&#xff0c;因为只…...

CAS的三大问题和解决方案

一、ABA问题的解决方案 变量第一次读取的值是1&#xff0c;后来其他线程改成了3&#xff0c;然后又被其他线程修改成了1&#xff0c;原来期望的值是第一个1才会设置新值&#xff0c;第二个1跟期望不符合&#xff0c;但是&#xff0c;可以设置新值。 解决方案&#xff1a; &a…...

EDA和统计分析有什么区别

EDA&#xff08;Electronic Design Automation&#xff09;和统计分析在多个方面存在显著的区别&#xff0c;这些区别主要体现在它们的应用领域、目的、方法以及所使用的工具上。 EDA&#xff08;电子设计自动化&#xff09; 定义与目的&#xff1a; EDA是电子设计自动化&…...

CentOS 7 修改DNS

1、nmcli connection show 命令找到设备名称 # nmcli connection show NAME UUID TYPE DEVICE enp4s0 99559edf-4e0a-4bae-a528-6d75065261e9 ethernet enp4s0 2、nmcli connection modify 命令修改dns nmcli connection modif…...

PHP基础语法-Part2

if-else语句、switch语句 与其他语言相同 循环结构 for循环while循环do-while循环foreach循环&#xff0c;搭配数组使用 foreach ($age as $avlue) //只输出值 {xxx; } foreach ($age as $key > $avlue) //键和值都输出 {xxx; }foreach ($age as $key >…...

数据结构门槛-顺序表

顺序表 1. 线性表2. 顺序表2.1 静态顺序表2.2 动态顺序表2.2.1 动态数据表初始化和销毁2.2.2 动态数据表的尾插尾删2.2.3 动态数据表的头插头删2.2.4 动态数据表的中间部分插入删除2.2.5 动态数据表的查询数据位置 3. 总结 1. 线性表 线性表&#xff08;linear list&#xff0…...

软件测试面试准备工作

1、 什么是数据库? 答&#xff1a;数据库是按照某种数据模型组织起来的并存放二级存储器中的数据集合。 2、 什么是关系型数据库? 答&#xff1a;关系型数据库是建立在关系数据库模型基础上的数据库&#xff0c; 借助集合代数等概念和方法处理数据库中的数据。目前主流的关…...

Java面试八股之后Spring、spring mvc和spring boot的区别

Spring、spring mvc和spring boot的区别 Spring, Spring Boot和Spring MVC都是Spring框架家族的一部分&#xff0c;它们各自有其特定的用途和优势。下面是它们之间的主要区别&#xff1a; Spring: Spring 是一个开源的轻量级Java开发框架&#xff0c;最初由Rod Johnson创建&…...

linux对齐TOF和RGB摄像头画面

问题&#xff1a;TOF和RGB画面不对齐 linux同时接入TOF和RGB&#xff0c;两者出图时间是由驱动层控制&#xff08;RGB硬件触发出图&#xff09;&#xff0c;应用层只负责读取数据。 现在两者画面不对齐&#xff0c;发现是开始的时候两者出图数量不一致导致的。底层解决不了&a…...

配置linux客户端免密登录服务端linux主机的root用户

1、客户端与服务端的ip 客户端IP地址服务端IP地址 2、定位客户端&#xff0c;由客户端制作公私钥对 [rootclient ~]# ssh-keygen -t rsa &#xff08;RSA是非对称加密算法&#xff09; # 一路回车 3、定位客户端&#xff0c;将公钥上传到服务器端root账户 [rootc…...

SpringMVC实现文件上传

导入文件上传相关依赖 <!--文件上传--> <dependency><groupId>commons-fileupload</groupId><artifactId>commons-fileupload</artifactId><version>1.3.1</version> </dependency> <dependency><groupId>…...

计算机实验室排课查询小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;学生管理&#xff0c;教师管理&#xff0c;实验室信息管理&#xff0c;实验室预约管理&#xff0c;取消预约管理&#xff0c;实验课程管理&#xff0c;实验报告管理&#xff0c;报修信息管理&#xff0…...

分享几种电商平台商品数据的批量自动抓取方式

在当今数字化时代&#xff0c;电商平台作为商品交易的重要渠道&#xff0c;其数据对于商家、市场分析师及数据科学家来说具有极高的价值。批量自动抓取电商平台商品数据成为提升业务效率、优化市场策略的重要手段。本文将详细介绍几种主流的电商平台商品数据批量自动抓取方式&a…...

mysql面试(五)

前言 本章节从数据页的具体结构&#xff0c;分析到如何生成索引&#xff0c;如何构成B树的索引结构。 以及什么是聚簇索引&#xff0c;什么是联合索引 InnoDB数据结构 行数据 我看各种文档中有好多记录数据结构的&#xff0c;但是这些都是看完就忘的东西。在这里详细讲也没…...

手把手教你用ESP8266 AT指令连接华为云IoT(附固件烧录与MQTT避坑指南)

从零玩转ESP8266&#xff1a;华为云IoT连接实战与深度排错指南 当你第一次拿到那块拇指大小的ESP8266模块时&#xff0c;可能不会想到这个售价不到20元的Wi-Fi芯片能成为物联网世界的通行证。作为全球使用量最大的IoT连接方案之一&#xff0c;ESP8266配合华为云物联网平台&…...

TEA加密算法实战:用Python和C语言实现QQ同款加密(附完整代码)

TEA加密算法实战&#xff1a;从原理到跨语言实现 在即时通讯和物联网设备中&#xff0c;数据安全传输一直是核心需求。TEA&#xff08;Tiny Encryption Algorithm&#xff09;以其轻量级、高效率的特性&#xff0c;成为资源受限环境下的理想选择。本文将深入探讨TEA算法家族的工…...

别再被ToggleGroup坑了!手把手教你写一个不自动选首项的CustomToggleGroup组件(附完整代码)

深度定制Unity ToggleGroup&#xff1a;打造无默认选中行为的智能组件 引言 在Unity UI开发中&#xff0c;ToggleGroup组件是构建选项卡式界面的常见选择&#xff0c;但许多开发者都遇到过这样的困扰&#xff1a;当ToggleGroup激活时&#xff0c;系统总会自动选中第一个Toggle项…...

零基础入门esp32开发:用快马平台生成第一个led控制程序详解

最近在学ESP32开发&#xff0c;发现对于新手来说&#xff0c;从零开始写代码还是挺有挑战的。不过我发现了一个超好用的工具——InsCode(快马)平台&#xff0c;它可以根据你的需求直接生成可运行的代码&#xff0c;特别适合像我这样的初学者。 项目需求分析 我想实现一个简单的…...

告别盲目搜索!Unity大版本升级时,系统化处理API变更的5个步骤

Unity大版本升级的系统化实践&#xff1a;从API变更管理到团队协作优化 当Unity 2023 LTS发布时&#xff0c;某中型游戏团队在升级过程中发现超过40%的脚本因API变更而报错&#xff0c;导致项目停滞两周。这种场景在技术迭代中并不罕见&#xff0c;但大多数团队仍采用"遇到…...

开发提效新组合:用Cursor生成代码片段,在快马一键集成与部署

最近在做一个数据整理的小工具时&#xff0c;发现了一个特别高效的工作流组合&#xff1a;先用Cursor快速生成核心代码片段&#xff0c;再用InsCode(快马)平台一键整合部署。整个过程就像搭积木一样顺畅&#xff0c;特别适合需要快速实现功能模块的场景。 需求分析 我们经常要处…...

告别手动复制!用ArcGIS字段计算器(VB/Python)批量提取字段值的保姆级教程

ArcGIS字段计算器实战指南&#xff1a;VB与Python高效提取字段值的深度对比 在GIS数据处理工作中&#xff0c;属性表字段值的部分提取是最常见却又最耗时的操作之一。想象一下&#xff0c;当你面对一个包含上万条记录的"BSM"字段&#xff0c;需要提取前6位作为行政区…...

利用快马平台快速生成PyTorch图像分类原型,十分钟验证模型思路

最近在尝试用PyTorch做图像分类的原型验证时&#xff0c;发现从零开始搭建环境、写基础代码特别耗时。后来尝试用InsCode(快马)平台生成项目模板&#xff0c;十分钟就完成了模型验证。这里分享下用PyTorch快速构建MNIST分类器的关键步骤和踩坑经验。 数据准备环节 平台生成的代…...

计算机组成原理实验避坑指南:存储器地址映射常见错误及解决方法

计算机组成原理实验避坑指南&#xff1a;存储器地址映射常见错误及解决方法 第一次在Proteus里搭建存储器系统时&#xff0c;看着密密麻麻的地址线和片选信号&#xff0c;我对着实验指导书发呆了半小时——明明按照图示连接了所有线路&#xff0c;可写入RAM的数据总是莫名其妙出…...

translategemma-27b-it入门必看:Gemma3轻量化设计如何平衡精度与推理速度

translategemma-27b-it入门必看&#xff1a;Gemma3轻量化设计如何平衡精度与推理速度 本文深度解析基于Gemma 3构建的TranslateGemma-27B-IT模型&#xff0c;通过实际部署演示展示其如何在保持翻译精度的同时实现高效推理&#xff0c;为开发者提供完整的入门指南。 1. 认识Tran…...