当前位置: 首页 > news >正文

昇思训练营打卡第二十五天(RNN实现情感分类)

RNN,即循环神经网络(Recurrent Neural Network),是一种深度学习模型,特别适用于处理序列数据。以下是对RNN的简要介绍:

RNN的特点:

  1. 记忆性:与传统的前馈神经网络不同,RNN具有内部状态(记忆),可以捕获到目前为止观察到的序列信息。
  2. 参数共享:在处理序列的不同时间步时,RNN使用相同的权重,这意味着模型的参数数量不会随着输入序列长度的增加而增加。
  3. 灵活性:RNN能够处理任意长度的输入序列。

RNN的结构:

  • 输入层:接收序列中的单个元素。
  • 隐藏层:包含循环单元,这些单元具有记忆功能,能够存储之前的信息。
  • 输出层:根据当前输入和隐藏层的状态输出结果。

RNN的类型:

  1. 简单RNN:基础模型,但容易受到梯度消失和梯度爆炸问题的影响。
  2. LSTM(长短期记忆网络):通过引入门控机制,解决了简单RNN的长期依赖问题。
  3. GRU(门控循环单元):LSTM的变体,结构更简单,但性能相似。

应用场景:

  • 自然语言处理:如语言模型、机器翻译、文本生成等。
  • 语音识别:将语音信号转换为文本。
  • 时间序列预测:如股票价格预测、天气预报等。

数据下载模块

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path# 指定保存路径为 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'def http_get(url: str, temp_file: IO):"""使用requests库下载数据,并使用tqdm库进行流程可视化"""req = requests.get(url, stream=True)content_length = req.headers.get('Content-Length')total = int(content_length) if content_length is not None else Noneprogress = tqdm(unit='B', total=total)for chunk in req.iter_content(chunk_size=1024):if chunk:progress.update(len(chunk))temp_file.write(chunk)progress.close()def download(file_name: str, url: str):"""下载数据并存为指定名称"""if not os.path.exists(cache_dir):os.makedirs(cache_dir)cache_path = os.path.join(cache_dir, file_name)cache_exist = os.path.exists(cache_path)if not cache_exist:with tempfile.NamedTemporaryFile() as temp_file:http_get(url, temp_file)temp_file.flush()temp_file.seek(0)with open(cache_path, 'wb') as cache_file:shutil.copyfileobj(temp_file, cache_file)return cache_pathimdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_path

加载IMDB数据集

import re
import six
import string
import tarfileclass IMDBData():"""IMDB数据集加载器加载IMDB数据集并处理为一个Python迭代对象。"""label_map = {"pos": 1,"neg": 0}def __init__(self, path, mode="train"):self.mode = modeself.path = pathself.docs, self.labels = [], []self._load("pos")self._load("neg")def _load(self, label):pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))# 将数据加载至内存with tarfile.open(self.path) as tarf:tf = tarf.next()while tf is not None:if bool(pattern.match(tf.name)):# 对文本进行分词、去除标点和特殊字符、小写处理self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate(None, six.b(string.punctuation)).lower()).split())self.labels.append([self.label_map[label]])tf = tarf.next()def __getitem__(self, idx):return self.docs[idx], self.labels[idx]def __len__(self):return len(self.docs)
imdb_train = IMDBData(imdb_path, 'train')
len(imdb_train)
import mindspore.dataset as dsdef load_imdb(imdb_path):imdb_train = ds.GeneratorDataset(IMDBData(imdb_path, "train"), column_names=["text", "label"], shuffle=True, num_samples=10000)imdb_test = ds.GeneratorDataset(IMDBData(imdb_path, "test"), column_names=["text", "label"], shuffle=False)return imdb_train, imdb_test
imdb_train, imdb_test = load_imdb(imdb_path)
imdb_train

加载预训练词向量

预训练词向量是对输入单词的数值化表示,通过nn.Embedding层,采用查表的方式,输入单词对应词表中的index,获得对应的表达向量。 因此进行模型构造前,需要将Embedding层所需的词向量和词表进行构造。

import zipfile
import numpy as npdef load_glove(glove_path):glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')if not os.path.exists(glove_100d_path):glove_zip = zipfile.ZipFile(glove_path)glove_zip.extractall(cache_dir)embeddings = []tokens = []with open(glove_100d_path, encoding='utf-8') as gf:for glove in gf:word, embedding = glove.split(maxsplit=1)tokens.append(word)embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))# 添加 <unk>, <pad> 两个特殊占位符对应的embeddingembeddings.append(np.random.rand(100))embeddings.append(np.zeros((100,), np.float32))vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)embeddings = np.array(embeddings).astype(np.float32)return vocab, embeddings
glove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip')
vocab, embeddings = load_glove(glove_path)
len(vocab.vocab())
idx = vocab.tokens_to_ids('the')
embedding = embeddings[idx]
idx, embedding

数据集预处理

通过加载器加载的IMDB数据集进行了分词处理,但不满足构造训练数据的需要,因此要对其进行额外的预处理。其中包含的预处理如下:

  • 通过Vocab将所有的Token处理为index id。
  • 将文本序列统一长度,不足的使用<pad>补齐,超出的进行截断。
  • import mindspore as mslookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
    pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
    type_cast_op = ds.transforms.TypeCast(ms.float32)
    imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])
    imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
    imdb_train = imdb_train.batch(64, drop_remainder=True)
    imdb_valid = imdb_valid.batch(64, drop_remainder=True)

    Embedding

    Embedding层又可称为EmbeddingLookup层,其作用是使用index id对权重矩阵对应id的向量进行查找,当输入为一个由index id组成的序列时,则查找并返回一个相同长度的矩阵

  • RNN(循环神经网络)

    循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的神经网络。

Dense

在经过LSTM编码获取句子特征后,将其送入一个全连接层,即nn.Dense,将特征维度变换为二分类所需的维度1,经过Dense层后的输出即为模型预测结果。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers,bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers=n_layers,bidirectional=bidirectional,batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output
hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return lossgrad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)def train_step(data, label):loss, grads = grad_fn(data, label)optimizer(grads)return lossdef train_one_epoch(model, train_dataset, epoch=0):model.set_train()total = train_dataset.get_dataset_size()loss_total = 0step_total = 0with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in train_dataset.create_tuple_iterator():loss = train_step(*i)loss_total += loss.asnumpy()step_total += 1t.set_postfix(loss=loss_total/step_total)t.update(1)

相关文章:

昇思训练营打卡第二十五天(RNN实现情感分类)

RNN&#xff0c;即循环神经网络&#xff08;Recurrent Neural Network&#xff09;&#xff0c;是一种深度学习模型&#xff0c;特别适用于处理序列数据。以下是对RNN的简要介绍&#xff1a; RNN的特点&#xff1a; 记忆性&#xff1a;与传统的前馈神经网络不同&#xff0c;R…...

昇思25天学习打卡营第02天|张量 Tensor

一、什么是张量 Tensor 张量是一种特殊的数据结构&#xff0c;与数组和矩阵非常相似。张量&#xff08;Tensor&#xff09;是MindSpore网络运算中的基本数据结构。 张量可以被看作是一个多维数组&#xff0c;但它比普通的数组更加灵活和强大&#xff0c;因为它支持在GPU等加速…...

权威认可 | 海云安开发者安全助手系统通过信通院支撑产品功能认证并荣获信通院2024年数据安全体系建设优秀案例

近日&#xff0c;2024全球数字经济大会——数字安全生态建设专题论坛&#xff08;以下简称“论坛”&#xff09;在京成功举办。由全球数字经济大会组委会主办&#xff0c;中国信息通信研究院及公安部第三研究所共同承办&#xff0c;论坛邀请多位专家和企业共同参与。 会上颁发…...

24.7.10|暑假-数组题目:实现整数的数字反转【学习记录】

1、题目&#xff1a; 32位有符号整数&#xff0c;将整数每位上的数字进行反转 输入&#xff1a;123 输出&#xff1a;321 输入&#xff1a;-123 输出&#xff1a;-321 输入&#xff1a;120 输出&#xff1a;21 &#xff01;&#xff09; 问题 怎么把整数转换成字符串&#xff…...

【ceph】ceph集群-添加/删除mon

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…...

Django ORM中的Q对象

Q 对象在 Django ORM 中用于构建复杂的查询条件&#xff0c;特别是当你需要使用逻辑运算符&#xff08;如 AND、OR、NOT&#xff09;时。以下是一些使用 Q 对象进行复杂查询的实际例子。 Q对象使用 模型 假设我们有一个包含员工信息的模型 Employee&#xff1a; from djang…...

相控阵雷达原理详解

相控阵&#xff0c;即相位控制阵列&#xff0c;通过控制阵列各个单元的馈电相位来改变波束指向。 相控阵雷达的原理可以清晰地归纳为以下几点&#xff1a; 1. 基本构成&#xff1a; - 相控阵雷达&#xff0c;即相位控制电子扫描阵列雷达&#xff08;Phased Array Radar, PAR&a…...

算法项目报告:物流中的最短路径问题

问题描述 物流问题 有一个物流公司需要从起点A到终点B进行货物运输&#xff0c;在运输过程中&#xff0c;该公司需要途径多个不同的城市&#xff0c;并且在每个城市中都有一个配送站点。为了最大程度地降低运输成本和时间&#xff0c;该公司需要确定经过哪些配送站点&#xff…...

linux中 crontab 定时器用法

*/10 * * * * python3 /home/code/haha2.py Crontab 当然&#xff0c;以下是一个简短的博客&#xff0c;介绍了 Cron 和 Crontab 的用法&#xff1a; --- # 简介&#xff1a;使用 Cron 和 Crontab 在 Linux 中进行定时任务调度 在 Linux 系统中&#xff0c;Cron 是一个用于…...

java算法day16

java算法day16 112 路径总和404 左叶子之和513 找树左下角的值 112 路径总和 题型判定为自顶向下类型&#xff0c;并且为路径和类型。 那就套模板。 自顶向下就是从上到下处理&#xff0c;那么就是前序遍历的思想。 class Solution {boolean res false;public boolean hasP…...

华为HCIP Datacom H12-821 卷41

1.多选题 以下关于BGP Atomic_Aggregate和Aggregator的描述&#xff0c;正确的是哪些项? A、Aggregator属性属于可选过渡属性 B、Atomic_Aggregate属于公认任意属性 C、收到携带Atomic_Aggregate属性的路由表示这条路由不能再度明细化 D、 Agregator表示某条路由可能出现…...

【React Hooks原理 - forwardRef、useImperativeHandle】

概述 上文我们聊了useRef的使用和实现&#xff0c;主要两个用途&#xff1a;1、用于持久化保存 2、用于绑定dom。 但是有时候我们需要在父组件中访问子组件的dom或者属性/方法&#xff0c;而React中默认是不允许父组件直接访问子组件的dom的&#xff0c;这时候就可以通过forwa…...

用于可穿戴传感器的人类活动识别、健康监测和行为建模的大型语言模型

这篇论文题为《用于可穿戴传感器的人类活动识别、健康监测和行为建模的大型语言模型&#xff1a;早期趋势、数据集和挑战的综述》&#xff0c;由埃米利奥费拉拉&#xff08;Emilio Ferrara&#xff09;撰写。论文主要内容如下&#xff1a; 摘要 可穿戴技术的普及使得传感器数…...

react事件绑定

react基础事件绑定 function passwordChange(e){console.log(e.target.value); } function usernameChange(e){console.log(e.target.value); }function App() {return (<div><input type"text" placeholder请输入用户名onChange{usernameChange}/><i…...

spring框架之AOP注解方式(java代码实例)

目录 半注解形式&#xff1a; 业务层接口实现类&#xff1a; 编写切面类&#xff1a; 在配置文件里面唯一需要加的&#xff1a; 测试类&#xff1a; 全注解形式&#xff1a; 不要配置文件&#xff0c;改为配置类&#xff1a; 同样的业务层接口实现类&#xff1a; 同样的…...

windows下gcc编译C、C++程序 MinGW编译器

文章目录 1、概要2、MinGW安装2.1 编译器下载2.2 编译器安装2.3 设置环境变量2.4 查看gcc版本信息 3、编译C、C程序3.1 编写Hello World.c3.2 编译C程序3.3 运行程序3.4 编译C程序 1、概要 GCC原名为GNU C语言编译器&#xff08;GNU C Compiler&#xff09;&#xff0c;只能处…...

uniapp启动图延时效果,启动图的配置

今天阐述uniapp开发中给启动图做延迟效果&#xff0c;不然启动图太快了&#xff0c;一闪就过去了&#xff1b; 一&#xff1a;修改配置文件&#xff1a;manifest.json "app-plus" : {"splashscreen" : {"alwaysShowBeforeRender" : false,"…...

SQL,python,knime将数据混合的文字数字拆出来,合并计算(学习笔记)

将下面将数据混合的文字数字拆出来&#xff0c;合并计算 一、SQL解决&#xff1a; ---创建表插入数据 CREATE TABLE original_data (id INT AUTO_INCREMENT PRIMARY KEY,city VARCHAR(255),value DECIMAL(10, 2) );INSERT INTO original_data (city, value) VALUES (上海0.5…...

【算法】LRU缓存

难度&#xff1a;中等 题目&#xff1a; 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类&#xff1a; LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中&#xff0c;…...

解决elementUI列表的疑难杂症,排序显示错乱的问题

大家好&#xff0c;在使用elementUI表格时&#xff0c;有时会出现一些意料之外的问题&#xff0c;比如数据排序正常但表格显示、排序错乱等。在网上搜索后一般有2种解决方法&#xff1a;1.给表格每一项的el-table-column添加唯一的id用于区分。2.给表格每一项的el-table-column…...

椭圆曲线密码学(ECC)

一、ECC算法概述 椭圆曲线密码学&#xff08;Elliptic Curve Cryptography&#xff09;是基于椭圆曲线数学理论的公钥密码系统&#xff0c;由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA&#xff0c;ECC在相同安全强度下密钥更短&#xff08;256位ECC ≈ 3072位RSA…...

golang循环变量捕获问题​​

在 Go 语言中&#xff0c;当在循环中启动协程&#xff08;goroutine&#xff09;时&#xff0c;如果在协程闭包中直接引用循环变量&#xff0c;可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下&#xff1a; 问题背景 看这个代码片段&#xff1a; fo…...

三维GIS开发cesium智慧地铁教程(5)Cesium相机控制

一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点&#xff1a; 路径验证&#xff1a;确保相对路径.…...

从零实现富文本编辑器#5-编辑器选区模型的状态结构表达

先前我们总结了浏览器选区模型的交互策略&#xff0c;并且实现了基本的选区操作&#xff0c;还调研了自绘选区的实现。那么相对的&#xff0c;我们还需要设计编辑器的选区表达&#xff0c;也可以称为模型选区。编辑器中应用变更时的操作范围&#xff0c;就是以模型选区为基准来…...

python/java环境配置

环境变量放一起 python&#xff1a; 1.首先下载Python Python下载地址&#xff1a;Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个&#xff0c;然后自定义&#xff0c;全选 可以把前4个选上 3.环境配置 1&#xff09;搜高级系统设置 2…...

安宝特方案丨XRSOP人员作业标准化管理平台:AR智慧点检验收套件

在选煤厂、化工厂、钢铁厂等过程生产型企业&#xff0c;其生产设备的运行效率和非计划停机对工业制造效益有较大影响。 随着企业自动化和智能化建设的推进&#xff0c;需提前预防假检、错检、漏检&#xff0c;推动智慧生产运维系统数据的流动和现场赋能应用。同时&#xff0c;…...

可靠性+灵活性:电力载波技术在楼宇自控中的核心价值

可靠性灵活性&#xff1a;电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中&#xff0c;电力载波技术&#xff08;PLC&#xff09;凭借其独特的优势&#xff0c;正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据&#xff0c;无需额外布…...

全球首个30米分辨率湿地数据集(2000—2022)

数据简介 今天我们分享的数据是全球30米分辨率湿地数据集&#xff0c;包含8种湿地亚类&#xff0c;该数据以0.5X0.5的瓦片存储&#xff0c;我们整理了所有属于中国的瓦片名称与其对应省份&#xff0c;方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

2021-03-15 iview一些问题

1.iview 在使用tree组件时&#xff0c;发现没有set类的方法&#xff0c;只有get&#xff0c;那么要改变tree值&#xff0c;只能遍历treeData&#xff0c;递归修改treeData的checked&#xff0c;发现无法更改&#xff0c;原因在于check模式下&#xff0c;子元素的勾选状态跟父节…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...