当前位置: 首页 > news >正文

昇思训练营打卡第二十五天(RNN实现情感分类)

RNN,即循环神经网络(Recurrent Neural Network),是一种深度学习模型,特别适用于处理序列数据。以下是对RNN的简要介绍:

RNN的特点:

  1. 记忆性:与传统的前馈神经网络不同,RNN具有内部状态(记忆),可以捕获到目前为止观察到的序列信息。
  2. 参数共享:在处理序列的不同时间步时,RNN使用相同的权重,这意味着模型的参数数量不会随着输入序列长度的增加而增加。
  3. 灵活性:RNN能够处理任意长度的输入序列。

RNN的结构:

  • 输入层:接收序列中的单个元素。
  • 隐藏层:包含循环单元,这些单元具有记忆功能,能够存储之前的信息。
  • 输出层:根据当前输入和隐藏层的状态输出结果。

RNN的类型:

  1. 简单RNN:基础模型,但容易受到梯度消失和梯度爆炸问题的影响。
  2. LSTM(长短期记忆网络):通过引入门控机制,解决了简单RNN的长期依赖问题。
  3. GRU(门控循环单元):LSTM的变体,结构更简单,但性能相似。

应用场景:

  • 自然语言处理:如语言模型、机器翻译、文本生成等。
  • 语音识别:将语音信号转换为文本。
  • 时间序列预测:如股票价格预测、天气预报等。

数据下载模块

import os
import shutil
import requests
import tempfile
from tqdm import tqdm
from typing import IO
from pathlib import Path# 指定保存路径为 `home_path/.mindspore_examples`
cache_dir = Path.home() / '.mindspore_examples'def http_get(url: str, temp_file: IO):"""使用requests库下载数据,并使用tqdm库进行流程可视化"""req = requests.get(url, stream=True)content_length = req.headers.get('Content-Length')total = int(content_length) if content_length is not None else Noneprogress = tqdm(unit='B', total=total)for chunk in req.iter_content(chunk_size=1024):if chunk:progress.update(len(chunk))temp_file.write(chunk)progress.close()def download(file_name: str, url: str):"""下载数据并存为指定名称"""if not os.path.exists(cache_dir):os.makedirs(cache_dir)cache_path = os.path.join(cache_dir, file_name)cache_exist = os.path.exists(cache_path)if not cache_exist:with tempfile.NamedTemporaryFile() as temp_file:http_get(url, temp_file)temp_file.flush()temp_file.seek(0)with open(cache_path, 'wb') as cache_file:shutil.copyfileobj(temp_file, cache_file)return cache_pathimdb_path = download('aclImdb_v1.tar.gz', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz')
imdb_path

加载IMDB数据集

import re
import six
import string
import tarfileclass IMDBData():"""IMDB数据集加载器加载IMDB数据集并处理为一个Python迭代对象。"""label_map = {"pos": 1,"neg": 0}def __init__(self, path, mode="train"):self.mode = modeself.path = pathself.docs, self.labels = [], []self._load("pos")self._load("neg")def _load(self, label):pattern = re.compile(r"aclImdb/{}/{}/.*\.txt$".format(self.mode, label))# 将数据加载至内存with tarfile.open(self.path) as tarf:tf = tarf.next()while tf is not None:if bool(pattern.match(tf.name)):# 对文本进行分词、去除标点和特殊字符、小写处理self.docs.append(str(tarf.extractfile(tf).read().rstrip(six.b("\n\r")).translate(None, six.b(string.punctuation)).lower()).split())self.labels.append([self.label_map[label]])tf = tarf.next()def __getitem__(self, idx):return self.docs[idx], self.labels[idx]def __len__(self):return len(self.docs)
imdb_train = IMDBData(imdb_path, 'train')
len(imdb_train)
import mindspore.dataset as dsdef load_imdb(imdb_path):imdb_train = ds.GeneratorDataset(IMDBData(imdb_path, "train"), column_names=["text", "label"], shuffle=True, num_samples=10000)imdb_test = ds.GeneratorDataset(IMDBData(imdb_path, "test"), column_names=["text", "label"], shuffle=False)return imdb_train, imdb_test
imdb_train, imdb_test = load_imdb(imdb_path)
imdb_train

加载预训练词向量

预训练词向量是对输入单词的数值化表示,通过nn.Embedding层,采用查表的方式,输入单词对应词表中的index,获得对应的表达向量。 因此进行模型构造前,需要将Embedding层所需的词向量和词表进行构造。

import zipfile
import numpy as npdef load_glove(glove_path):glove_100d_path = os.path.join(cache_dir, 'glove.6B.100d.txt')if not os.path.exists(glove_100d_path):glove_zip = zipfile.ZipFile(glove_path)glove_zip.extractall(cache_dir)embeddings = []tokens = []with open(glove_100d_path, encoding='utf-8') as gf:for glove in gf:word, embedding = glove.split(maxsplit=1)tokens.append(word)embeddings.append(np.fromstring(embedding, dtype=np.float32, sep=' '))# 添加 <unk>, <pad> 两个特殊占位符对应的embeddingembeddings.append(np.random.rand(100))embeddings.append(np.zeros((100,), np.float32))vocab = ds.text.Vocab.from_list(tokens, special_tokens=["<unk>", "<pad>"], special_first=False)embeddings = np.array(embeddings).astype(np.float32)return vocab, embeddings
glove_path = download('glove.6B.zip', 'https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/glove.6B.zip')
vocab, embeddings = load_glove(glove_path)
len(vocab.vocab())
idx = vocab.tokens_to_ids('the')
embedding = embeddings[idx]
idx, embedding

数据集预处理

通过加载器加载的IMDB数据集进行了分词处理,但不满足构造训练数据的需要,因此要对其进行额外的预处理。其中包含的预处理如下:

  • 通过Vocab将所有的Token处理为index id。
  • 将文本序列统一长度,不足的使用<pad>补齐,超出的进行截断。
  • import mindspore as mslookup_op = ds.text.Lookup(vocab, unknown_token='<unk>')
    pad_op = ds.transforms.PadEnd([500], pad_value=vocab.tokens_to_ids('<pad>'))
    type_cast_op = ds.transforms.TypeCast(ms.float32)
    imdb_train = imdb_train.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_train = imdb_train.map(operations=[type_cast_op], input_columns=['label'])imdb_test = imdb_test.map(operations=[lookup_op, pad_op], input_columns=['text'])
    imdb_test = imdb_test.map(operations=[type_cast_op], input_columns=['label'])
    imdb_train, imdb_valid = imdb_train.split([0.7, 0.3])
    imdb_train = imdb_train.batch(64, drop_remainder=True)
    imdb_valid = imdb_valid.batch(64, drop_remainder=True)

    Embedding

    Embedding层又可称为EmbeddingLookup层,其作用是使用index id对权重矩阵对应id的向量进行查找,当输入为一个由index id组成的序列时,则查找并返回一个相同长度的矩阵

  • RNN(循环神经网络)

    循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的神经网络。

Dense

在经过LSTM编码获取句子特征后,将其送入一个全连接层,即nn.Dense,将特征维度变换为二分类所需的维度1,经过Dense层后的输出即为模型预测结果。

import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniformclass RNN(nn.Cell):def __init__(self, embeddings, hidden_dim, output_dim, n_layers,bidirectional, pad_idx):super().__init__()vocab_size, embedding_dim = embeddings.shapeself.embedding = nn.Embedding(vocab_size, embedding_dim, embedding_table=ms.Tensor(embeddings), padding_idx=pad_idx)self.rnn = nn.LSTM(embedding_dim,hidden_dim,num_layers=n_layers,bidirectional=bidirectional,batch_first=True)weight_init = HeUniform(math.sqrt(5))bias_init = Uniform(1 / math.sqrt(hidden_dim * 2))self.fc = nn.Dense(hidden_dim * 2, output_dim, weight_init=weight_init, bias_init=bias_init)def construct(self, inputs):embedded = self.embedding(inputs)_, (hidden, _) = self.rnn(embedded)hidden = ops.concat((hidden[-2, :, :], hidden[-1, :, :]), axis=1)output = self.fc(hidden)return output
hidden_size = 256
output_size = 1
num_layers = 2
bidirectional = True
lr = 0.001
pad_idx = vocab.tokens_to_ids('<pad>')model = RNN(embeddings, hidden_size, output_size, num_layers, bidirectional, pad_idx)
loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
optimizer = nn.Adam(model.trainable_params(), learning_rate=lr)
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return lossgrad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)def train_step(data, label):loss, grads = grad_fn(data, label)optimizer(grads)return lossdef train_one_epoch(model, train_dataset, epoch=0):model.set_train()total = train_dataset.get_dataset_size()loss_total = 0step_total = 0with tqdm(total=total) as t:t.set_description('Epoch %i' % epoch)for i in train_dataset.create_tuple_iterator():loss = train_step(*i)loss_total += loss.asnumpy()step_total += 1t.set_postfix(loss=loss_total/step_total)t.update(1)

相关文章:

昇思训练营打卡第二十五天(RNN实现情感分类)

RNN&#xff0c;即循环神经网络&#xff08;Recurrent Neural Network&#xff09;&#xff0c;是一种深度学习模型&#xff0c;特别适用于处理序列数据。以下是对RNN的简要介绍&#xff1a; RNN的特点&#xff1a; 记忆性&#xff1a;与传统的前馈神经网络不同&#xff0c;R…...

昇思25天学习打卡营第02天|张量 Tensor

一、什么是张量 Tensor 张量是一种特殊的数据结构&#xff0c;与数组和矩阵非常相似。张量&#xff08;Tensor&#xff09;是MindSpore网络运算中的基本数据结构。 张量可以被看作是一个多维数组&#xff0c;但它比普通的数组更加灵活和强大&#xff0c;因为它支持在GPU等加速…...

权威认可 | 海云安开发者安全助手系统通过信通院支撑产品功能认证并荣获信通院2024年数据安全体系建设优秀案例

近日&#xff0c;2024全球数字经济大会——数字安全生态建设专题论坛&#xff08;以下简称“论坛”&#xff09;在京成功举办。由全球数字经济大会组委会主办&#xff0c;中国信息通信研究院及公安部第三研究所共同承办&#xff0c;论坛邀请多位专家和企业共同参与。 会上颁发…...

24.7.10|暑假-数组题目:实现整数的数字反转【学习记录】

1、题目&#xff1a; 32位有符号整数&#xff0c;将整数每位上的数字进行反转 输入&#xff1a;123 输出&#xff1a;321 输入&#xff1a;-123 输出&#xff1a;-321 输入&#xff1a;120 输出&#xff1a;21 &#xff01;&#xff09; 问题 怎么把整数转换成字符串&#xff…...

【ceph】ceph集群-添加/删除mon

本站以分享各种运维经验和运维所需要的技能为主 《python零基础入门》&#xff1a;python零基础入门学习 《python运维脚本》&#xff1a; python运维脚本实践 《shell》&#xff1a;shell学习 《terraform》持续更新中&#xff1a;terraform_Aws学习零基础入门到最佳实战 《k8…...

Django ORM中的Q对象

Q 对象在 Django ORM 中用于构建复杂的查询条件&#xff0c;特别是当你需要使用逻辑运算符&#xff08;如 AND、OR、NOT&#xff09;时。以下是一些使用 Q 对象进行复杂查询的实际例子。 Q对象使用 模型 假设我们有一个包含员工信息的模型 Employee&#xff1a; from djang…...

相控阵雷达原理详解

相控阵&#xff0c;即相位控制阵列&#xff0c;通过控制阵列各个单元的馈电相位来改变波束指向。 相控阵雷达的原理可以清晰地归纳为以下几点&#xff1a; 1. 基本构成&#xff1a; - 相控阵雷达&#xff0c;即相位控制电子扫描阵列雷达&#xff08;Phased Array Radar, PAR&a…...

算法项目报告:物流中的最短路径问题

问题描述 物流问题 有一个物流公司需要从起点A到终点B进行货物运输&#xff0c;在运输过程中&#xff0c;该公司需要途径多个不同的城市&#xff0c;并且在每个城市中都有一个配送站点。为了最大程度地降低运输成本和时间&#xff0c;该公司需要确定经过哪些配送站点&#xff…...

linux中 crontab 定时器用法

*/10 * * * * python3 /home/code/haha2.py Crontab 当然&#xff0c;以下是一个简短的博客&#xff0c;介绍了 Cron 和 Crontab 的用法&#xff1a; --- # 简介&#xff1a;使用 Cron 和 Crontab 在 Linux 中进行定时任务调度 在 Linux 系统中&#xff0c;Cron 是一个用于…...

java算法day16

java算法day16 112 路径总和404 左叶子之和513 找树左下角的值 112 路径总和 题型判定为自顶向下类型&#xff0c;并且为路径和类型。 那就套模板。 自顶向下就是从上到下处理&#xff0c;那么就是前序遍历的思想。 class Solution {boolean res false;public boolean hasP…...

华为HCIP Datacom H12-821 卷41

1.多选题 以下关于BGP Atomic_Aggregate和Aggregator的描述&#xff0c;正确的是哪些项? A、Aggregator属性属于可选过渡属性 B、Atomic_Aggregate属于公认任意属性 C、收到携带Atomic_Aggregate属性的路由表示这条路由不能再度明细化 D、 Agregator表示某条路由可能出现…...

【React Hooks原理 - forwardRef、useImperativeHandle】

概述 上文我们聊了useRef的使用和实现&#xff0c;主要两个用途&#xff1a;1、用于持久化保存 2、用于绑定dom。 但是有时候我们需要在父组件中访问子组件的dom或者属性/方法&#xff0c;而React中默认是不允许父组件直接访问子组件的dom的&#xff0c;这时候就可以通过forwa…...

用于可穿戴传感器的人类活动识别、健康监测和行为建模的大型语言模型

这篇论文题为《用于可穿戴传感器的人类活动识别、健康监测和行为建模的大型语言模型&#xff1a;早期趋势、数据集和挑战的综述》&#xff0c;由埃米利奥费拉拉&#xff08;Emilio Ferrara&#xff09;撰写。论文主要内容如下&#xff1a; 摘要 可穿戴技术的普及使得传感器数…...

react事件绑定

react基础事件绑定 function passwordChange(e){console.log(e.target.value); } function usernameChange(e){console.log(e.target.value); }function App() {return (<div><input type"text" placeholder请输入用户名onChange{usernameChange}/><i…...

spring框架之AOP注解方式(java代码实例)

目录 半注解形式&#xff1a; 业务层接口实现类&#xff1a; 编写切面类&#xff1a; 在配置文件里面唯一需要加的&#xff1a; 测试类&#xff1a; 全注解形式&#xff1a; 不要配置文件&#xff0c;改为配置类&#xff1a; 同样的业务层接口实现类&#xff1a; 同样的…...

windows下gcc编译C、C++程序 MinGW编译器

文章目录 1、概要2、MinGW安装2.1 编译器下载2.2 编译器安装2.3 设置环境变量2.4 查看gcc版本信息 3、编译C、C程序3.1 编写Hello World.c3.2 编译C程序3.3 运行程序3.4 编译C程序 1、概要 GCC原名为GNU C语言编译器&#xff08;GNU C Compiler&#xff09;&#xff0c;只能处…...

uniapp启动图延时效果,启动图的配置

今天阐述uniapp开发中给启动图做延迟效果&#xff0c;不然启动图太快了&#xff0c;一闪就过去了&#xff1b; 一&#xff1a;修改配置文件&#xff1a;manifest.json "app-plus" : {"splashscreen" : {"alwaysShowBeforeRender" : false,"…...

SQL,python,knime将数据混合的文字数字拆出来,合并计算(学习笔记)

将下面将数据混合的文字数字拆出来&#xff0c;合并计算 一、SQL解决&#xff1a; ---创建表插入数据 CREATE TABLE original_data (id INT AUTO_INCREMENT PRIMARY KEY,city VARCHAR(255),value DECIMAL(10, 2) );INSERT INTO original_data (city, value) VALUES (上海0.5…...

【算法】LRU缓存

难度&#xff1a;中等 题目&#xff1a; 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类&#xff1a; LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存int get(int key) 如果关键字 key 存在于缓存中&#xff0c;…...

解决elementUI列表的疑难杂症,排序显示错乱的问题

大家好&#xff0c;在使用elementUI表格时&#xff0c;有时会出现一些意料之外的问题&#xff0c;比如数据排序正常但表格显示、排序错乱等。在网上搜索后一般有2种解决方法&#xff1a;1.给表格每一项的el-table-column添加唯一的id用于区分。2.给表格每一项的el-table-column…...

重大消息:手机车机互联投屏专题发布-千里马带你学框架

背景&#xff1a; android投屏的使用场景以前在新能源车机还没火爆时候&#xff0c;大部分停留在手机小屏幕投屏到大屏幕的情况及整个多端设备的互动&#xff0c;整体需求和技术发展其实也就是比较有限&#xff0c;但是新能源车机火爆后&#xff0c;那么这种手机和车机互联互动…...

jail子系统里升级Ubuntu focal到jammy

Ubuntu focal是20.04 &#xff0c;jammy版本是22.04&#xff0c;本次的目的就是将FreeBSD jail子系统里的Ubuntu 从20.04升级到22.04 。这个focal 子系统是通过cbsd克隆得到的。使用CBSD克隆复制Ubuntu jail子系统环境-CSDN博客 do-release-upgrade升级没成功&#xff0c;用de…...

2024年7月20日(星期六)骑行支里山

2024年7月20日 (星期六&#xff09;骑行支里山&#xff0c;早8:00到8:30&#xff0c;大观公园门口集合&#xff0c;9:00准时出发【因迟到者&#xff0c;骑行速度快者&#xff0c;可自行追赶偶遇。】 偶遇地点:大观公园门口集合 &#xff0c;家住东&#xff0c;南&#xff0c;北…...

Python:正则表达式相关整理

最近因为一些原因频繁使用正则表达式&#xff0c;因为以前系统整理过关于正则表达式的相关知识&#xff0c;所以这里仅记录使用期间遇到的问题。 本文内容基于re包 1. match和search方法的区别 在Python中&#xff0c;re.search和re.match都是用于匹配字符串的正则表达式函数&a…...

ChatGPT对话:有关花卉数据集

【编者按】编者准备研究基于深度学习的花卉识别&#xff0c;首先需要花卉数据集。 后续&#xff0c;编者不断会记录研究花卉识别过程中的技术知识&#xff0c;敬请围观 1问&#xff1a;推荐一下用于深度学习的花卉数据集 ChatGPT 以下是一些用于深度学习的优秀花卉数据集&am…...

特征向量及算法

数据挖掘流程 加载数据 把需要的模型数据先计算出来 特征工程 提取数据特征&#xff0c;对特征数据进行清洗转化 数据的筛选和清洗数据转化 类型转为 性别 男&#xff0c;女 ----> 1,0特征交叉 性别/职业/收入 —> 新特这 优质男性程序员 将多个特征值组合在一起特征筛选…...

cpp 强制转换

一、static_cast static_cast 是 C 中的一个类型转换操作符&#xff0c;用于在类的层次结构中进行安全的向上转换&#xff08;从派生类到基类&#xff09;或进行不需要运行时类型检查的转换。它主要用于基本数据类型之间的转换、对象指针或引用的向上转换&#xff08;即从派生…...

MySQL字符串魔法:拼接、截取、替换与定位的艺术

在数据的世界里&#xff0c;MySQL作为一把强大的数据处理利剑&#xff0c;其字符串处理功能犹如魔术师手中的魔法棒&#xff0c;让数据变换自如。今天&#xff0c;我们就来一场关于MySQL字符串拼接、截取、替换以及查找位置的奇幻之旅&#xff0c;揭开这些操作的神秘面纱。 介绍…...

在 Windows 上开发.NET MAUI 应用_1.安装开发环境

开发跨平台的本机 .NET Multi-platform App UI (.NET MAUI) 应用需要 Visual Studio 2022 17.8 或更高版本&#xff0c;或者具有 .NET MAUI 扩展的最新 Visual Studio Code。要开始在 Windows 上开发本机跨平台 .NET MAUI 应用&#xff0c;请按照安装步骤安装 Visual Studio 20…...

深度学习驱动智能超材料设计与应用

在深度学习与超材料融合的背景下&#xff0c;不仅提高了设计的效率和质量&#xff0c;还为实现定制化和精准化的治疗提供了可能&#xff0c;展现了在材料科学领域的巨大潜力。深度学习可以帮助实现超材料结构参数的优化、电磁响应的预测、拓扑结构的自动设计、相位的预测及结构…...