当前位置: 首页 > news >正文

RNN文本分类任务实战

  1. 递归神经网络 (RNN):
    定义:RNN 是一类专为顺序数据处理而设计的人工神经网络。
    顺序处理:RNN 保持一个隐藏状态,该状态捕获有关序列中先前输入的信息,使其适用于涉及顺序依赖关系的任务。
  2. 词嵌入:
    定义:词嵌入是捕获语义关系的词的密集向量表示。
    重要性:它们允许神经网络学习上下文信息和单词之间的关系。
    实现:使用预先训练的词嵌入(Word2Vec、GloVe)或在模型中包含嵌入层。
  3. 文本标记化和填充:
    代币化:将文本分解为单个单词或子单词。
    填充:通过添加零或截断来确保所有序列具有相同的长度。
  4. Keras 中的顺序模型:
    实现:利用 Keras 库中的 Sequential 模型创建线性层堆栈。
  5. 嵌入层:
    实现:向模型添加嵌入层,将单词转换为密集向量。
    配置:指定输入维度、输出维度(嵌入大小)和输入长度。
  6. 循环层(LSTM 或 GRU):
    LSTM 和 GRU:长短期记忆 (LSTM) 和门控循环单元 (GRU) 层有助于捕获长期依赖关系。
    实现:将一个或多个 LSTM 或 GRU 层添加到模型中。
  7. 致密层:
    目的:密集层用于最终分类输出。
    实现:添加一个或多个具有适当激活函数的密集层。
  8. 激活功能:
    选择:ReLU(整流线性单元)或tanh是隐藏层中激活函数的常见选择。
  9. 损失函数和优化器:
    损失函数:稀疏分类交叉熵通常用于文本分类任务。
    优化:Adam 或 RMSprop 是常用的优化器。
  10. 批处理和排序:
    批处理:在批量输入序列上训练模型。
    处理不同长度的物料:使用填充来处理不同长度的序列。
  11. 培训流程:
    汇编:使用所选的损失函数、优化器和指标编译模型。
    训练:将模型拟合到训练数据,在单独的集合上进行验证。
  12. 防止过拟合:
    技术:实现 dropout 或 recurrent dropout 层以防止过拟合。
    正规化:如果需要,请考虑 L1 或 L2 正则化。
  13. 超参数调优:
    参数:根据验证性能调整超参数,例如学习率、批量大小和循环单元数。
  14. 评估指标:
    指标:选择适当的指标,如准确率、精确率、召回率或 F1 分数进行评估。

# 文本分类任务实战
# 数据集构建:影评数据集进行情感分析
# 词向量模型:加载训练好的词向量或者自己训练
# 序列网络模型:训练好RNN模型进行识别import  os
import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
import numpy as np
import pprint
import logging
import time
from collections import  Counterfrom pathlib import Path
from tqdm import tqdm#加载影评数据集,可以自动下载放到对应位置
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data()
# a=x_train.shape
# print(a)
# 读进来的数据是已经转换成ID映射的,一般的数据读进来都是词语,都需要手动转换成ID映射的_word2idx = tf.keras.datasets.imdb.get_word_index()
word2idx = {w: i+3 for w, i in _word2idx.items()}
word2idx['<pad>'] = 0
word2idx['<start>'] = 1
word2idx['<unk>'] = 2
idx2word = {i: w for w, i in word2idx.items()}# 按文本长度大小进行排序def sort_by_len(x, y):x, y = np.asarray(x), np.asarray(y)idx = sorted(range(len(x)), key=lambda i: len(x[i]))return x[idx], y[idx]# 将中间结果保存到本地,万一程序崩了还得重玩,保存的是文本数据,不是IDx_train, y_train = sort_by_len(x_train, y_train)
x_test, y_test = sort_by_len(x_test, y_test)def write_file(f_path, xs, ys):with open(f_path, 'w',encoding='utf-8') as f:for x, y in zip(xs, ys):f.write(str(y)+'\t'+' '.join([idx2word[i] for i in x][1:])+'\n')write_file('./data/train.txt', x_train, y_train)
write_file('./data/test.txt', x_test, y_test)# 构建语料表,基于词频来进行统计counter = Counter()
with open('./data/train.txt',encoding='utf-8') as f:for line in f:line = line.rstrip()label, words = line.split('\t')words = words.split(' ')counter.update(words)words = ['<pad>'] + [w for w, freq in counter.most_common() if freq >= 10]
print('Vocab Size:', len(words))Path('./vocab').mkdir(exist_ok=True)with open('./vocab/word.txt', 'w',encoding='utf-8') as f:for w in words:f.write(w+'\n')# 得到新的word2id映射表word2idx = {}
with open('./vocab/word.txt',encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i# embedding层
# 可以基于网络来训练,也可以直接加载别人训练好的,一般都是加载预训练模型
# 这里有一些常用的:https://nlp.stanford.edu/projects/glove/#做了一个大表,里面有20598个不同的词,【20599*50】
embedding = np.zeros((len(word2idx)+1, 50)) # + 1 表示如果不在语料表中,就都是unknowwith open('./data/glove.6B.50d.txt',encoding='utf-8') as f: #下载好的count = 0for i, line in enumerate(f):if i % 100000 == 0:print('- At line {}'.format(i)) #打印处理了多少数据line = line.rstrip()0sp = line.split(' ')word, vec = sp[0], sp[1:]if word in word2idx:count += 1embedding[word2idx[word]] = np.asarray(vec, dtype='float32') #将词转换成对应的向量# 现在已经得到每个词索引所对应的向量print("[%d / %d] words have found pre-trained values"%(count, len(word2idx)))
np.save('./vocab/word.npy', embedding)
print('Saved ./vocab/word.npy')# 构建训练数据
# 注意所有的输入样本必须都是相同shape(文本长度,词向量维度等)
# 数据生成器
# tf.data.Dataset.from_tensor_slices(tensor):将tensor沿其第一个维度切片,返回一个含有N个样本的数据集,这样做的问题就是需要将整个数据集整体传入,然后切片建立数据集类对象,比较占内存。
#
# tf.data.Dataset.from_generator(data_generator,output_data_type,output_data_shape):从一个生成器中不断读取样本def data_generator(f_path, params):with open(f_path,encoding='utf-8') as f:print('Reading', f_path)for line in f:line = line.rstrip()label, text = line.split('\t')text = text.split(' ')x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的IDif len(x) >= params['max_len']:#截断操作x = x[:params['max_len']]else:x += [0] * (params['max_len'] - len(x))#补齐操作y = int(label)yield x, ydef dataset(is_training, params):_shapes = ([params['max_len']], ())_types = (tf.int32, tf.int32)if is_training:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['train_path'], params),output_shapes=_shapes,output_types=_types, )ds = ds.shuffle(params['num_samples'])ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)  # 设置缓存序列,根据可用的CPU动态设置并行调用的数量,说白了就是加速else:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['test_path'], params),output_shapes=_shapes,output_types=_types, )ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)return ds# 自定义网络模型class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'),dtype=tf.float32,name='pretrained_embedding',trainable=False, )self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2 * params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2)def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2 * params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs)x = tf.reshape(x, (batch_sz * 10 * 10, 10, 50))x = self.drop1(x, training=training)x = self.rnn1(x)x = tf.reduce_max(x, 1)x = tf.reshape(x, (batch_sz * 10, 10, rnn_units))x = self.drop2(x, training=training)x = self.rnn2(x)x = tf.reduce_max(x, 1)x = tf.reshape(x, (batch_sz, 10, rnn_units))x = self.drop3(x, training=training)x = self.rnn3(x)x = tf.reduce_max(x, 1)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x# 设置参数params = {'vocab_path': './vocab/word.txt','train_path': './data/train.txt','test_path': './data/test.txt','num_samples': 25000,'num_labels': 2,'batch_size': 32,'max_len': 1000,'rnn_units': 200,'dropout_rate': 0.2,'clip_norm': 10.,'num_patience': 3,'lr': 3e-4,
}def is_descending(history: list):history = history[-(params['num_patience']+1):]for i in range(1, len(history)):if history[i-1] <= history[i]:return Falsereturn Trueword2idx = {}
with open(params['vocab_path'],encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i
params['word2idx'] = word2idx
params['vocab_size'] = len(word2idx) + 1model = Model(params)
model.build(input_shape=(None, None))#设置输入的大小,或者fit时候也能自动找到
#pprint.pprint([(v.name, v.shape) for v in model.trainable_variables])#链接:https://tensorflow.google.cn/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay?version=stable
#return initial_learning_rate * decay_rate ^ (step / decay_steps)
decay_lr = tf.optimizers.schedules.ExponentialDecay(params['lr'], 1000, 0.95)#相当于加了一个指数衰减函数
optim = tf.optimizers.Adam(params['lr'])
global_step = 0history_acc = []
best_acc = .0t0 = time.time()
logger = logging.getLogger('tensorflow')
logger.setLevel(logging.INFO)while True:# 训练模型for texts, labels in dataset(is_training=True, params=params):with tf.GradientTape() as tape:  # 梯度带,记录所有在上下文中的操作,并且通过调用.gradient()获得任何上下文中计算得出的张量的梯度logits = model(texts, training=True)loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)loss = tf.reduce_mean(loss)optim.lr.assign(decay_lr(global_step))grads = tape.gradient(loss, model.trainable_variables)grads, _ = tf.clip_by_global_norm(grads, params['clip_norm'])  # 将梯度限制一下,有的时候回更新太猛,防止过拟合optim.apply_gradients(zip(grads, model.trainable_variables))  # 更新梯度if global_step % 50 == 0:logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(global_step, loss.numpy().item(), time.time() - t0, optim.lr.numpy().item()))t0 = time.time()global_step += 1# 验证集效果m = tf.keras.metrics.Accuracy()for texts, labels in dataset(is_training=False, params=params):logits = model(texts, training=False)y_pred = tf.argmax(logits, axis=-1)m.update_state(y_true=labels, y_pred=y_pred)acc = m.result().numpy()logger.info("Evaluation: Testing Accuracy: {:.3f}".format(acc))history_acc.append(acc)if acc > best_acc:best_acc = acclogger.info("Best Accuracy: {:.3f}".format(best_acc))if len(history_acc) > params['num_patience'] and is_descending(history_acc):logger.info("Testing Accuracy not improved over {} epochs, Early Stop".format(params['num_patience']))break

相关文章:

RNN文本分类任务实战

递归神经网络 &#xff08;RNN&#xff09;&#xff1a; 定义&#xff1a;RNN 是一类专为顺序数据处理而设计的人工神经网络。 顺序处理&#xff1a;RNN 保持一个隐藏状态&#xff0c;该状态捕获有关序列中先前输入的信息&#xff0c;使其适用于涉及顺序依赖关系的任务。词嵌入…...

【算法系列 | 12】深入解析查找算法之—斐波那契查找

序言 心若有阳光&#xff0c;你便会看见这个世界有那么多美好值得期待和向往。 决定开一个算法专栏&#xff0c;希望能帮助大家很好的了解算法。主要深入解析每个算法&#xff0c;从概念到示例。 我们一起努力&#xff0c;成为更好的自己&#xff01; 今天第12讲&#xff0c;讲…...

全新的C++语言

一、概述 C 的最初目标就是成为 “更好的 C”&#xff0c;因此新的标准首先要对基本的底层编程进行强化&#xff0c;能够反映当前计算机软硬件系统的最新发展和变化&#xff08;例如多线程&#xff09;。另一方面&#xff0c;C对多线程范式的支持增加了语言的复杂度&#xff0…...

three.js 多通道组合

效果&#xff1a; 代码&#xff1a; <template><div><el-container><el-main><div class"box-card-left"><div id"threejs" style"border: 1px solid red"></div><div style"border: 1px so…...

编程笔记 html5cssjs 022 HTML表单概要

编程笔记 html5&css&js 022 HTML表单概要 一、<form> 元素二、HTML Form 属性三、操作小结 网页光是输出没有输入可不行&#xff0c;因为输出还是比输入容易&#xff0c;所有就先接触输出&#xff0c;后学习输入。html用来输入的东西叫“表单”。 HTML 表单用于搜…...

​三子棋(c语言)

前言&#xff1a; 三子棋是一种民间传统游戏&#xff0c;又叫九宫棋、圈圈叉叉棋、一条龙、井字棋等。游戏规则是双方对战&#xff0c;双方依次在9宫格棋盘上摆放棋子&#xff0c;率先将自己的三个棋子走成一条线就视为胜利。但因棋盘太小&#xff0c;三子棋在很多时候会出现和…...

MySQL-DCL

DCL是数据控制语言&#xff0c;用来管理数据库用户&#xff0c;控制数据库的访问权限。 管理用户&#xff1a;管理哪些用户可以访问哪些数据库 1.查询用户 USE mysql; SELECT * FROM user; 注意&#xff1a; MySQL中用户信息和用户的权限信息都是记录在mysql数据库的user表中的…...

QT开源类库集合

QT开源类库集合 一、自定义控件 QSintQicsTableLongscroll-qtAdvanced Docking System 二、图表控件 QwtQCustomPlotJKQTPlotter 三、网络 QHttpEngineHTTP 四、 音视频 vlc-qt 五、多线程 tasks 六、数据库 EasyQtSql 一、自定义控件 1. QSint 源代码地址&#xff1a;QSint&…...

C++ STL(2)--算法(2)

算法(2)----STL里的排序函数。 1. sort: 对容器或普通数组中指定范围内的元素进行排序&#xff0c;默认进行升序排序。 sort函数是基于快速排序实现的&#xff0c;属于不稳定排序。 只支持3种容器&#xff1a;array、vector、deque。 如果容器中存储的是自定义的对象&#xff…...

格密码基础:对偶格(超全面)

目录 一. 对偶格的格点 1.1 基本定义 1.2 对偶格的例子 1.3 对偶格的图形理解 二. 对偶格的格基 2.1 基本定义 2.2 对偶格的格基证明 三. 对偶格的行列式 3.1 满秩格 3.2 非满秩格 四. 重复对偶格 五. 对偶格的转移定理&#xff08;transference theorem&#xff…...

ECMAScript简介及特性

ECMAScript是一种由ECMA国际&#xff08;前身为欧洲计算机制造商协会&#xff09;制定和发布的脚本语言规范&#xff0c;JavaScript在它基础上进行了自己的封装。ECMAScript和JavaScript的关系是&#xff0c;前者是后者的规格&#xff0c;后者是前者的一种实现。 ECMAScript的…...

csdn中的资源文件如何删除?

csdn中的资源文件如何删除&#xff1f; 然后写文章的时候 点击资源绑定&#xff0c;解锁资源&#xff0c;就可以再次上传。...

NA原理及配置

在IP地址空间中&#xff0c;a&#xff1b;b&#xff1b;c类地址中各有一部分地址&#xff0c;被称为私有IP地址&#xff08;私网地址&#xff09;&#xff0c;其余的为公有IP地址&#xff08;公网地址&#xff09; A&#xff1a;10.0.0.0 - 10.255.255.255 --- 相当于1条A类网段…...

解决:TypeError: ‘tuple’ object does not support item assignment

解决&#xff1a;TypeError: ‘tuple’ object does not support item assignment 文章目录 解决&#xff1a;TypeError: tuple object does not support item assignment背景报错问题报错翻译报错位置代码报错原因解决方法方法一&#xff1a;方法二&#xff1a;今天的分享就到…...

vue3项目中axios的常见用法和封装拦截(详细解释)

1、axios的简单介绍 Axios是一个基于Promise的HTTP客户端库&#xff0c;用于浏览器和Node.js环境中发送HTTP请求。它提供了一种简单、易用且功能丰富的方式来与后端服务器进行通信。能够发送常见的HTTP请求&#xff0c;并获得服务端返回的数据。 此外&#xff0c;Axios还提供…...

基础语法(一)(1)

常量和表达式 在这里&#xff0c;我们可以把Python当成一个计算器&#xff0c;来进行一些算术运算 例如&#xff1a; print(1 2 - 3) print(1 2 * 3) print(1 2 / 3)注意&#xff1a; print是一个python内置的函数&#xff0c;这个稍后我们会进行介绍 可以使用-*/&…...

YOLOv8模型yaml结构图理解(逐层分析)

前言 YOLO-V8&#xff08;官网地址&#xff09;&#xff1a;https://github.com/ultralytics/ultralytics 一、yolov8配置yaml文件 YOLOv8的配置文件定义了模型的关键参数和结构&#xff0c;包括类别数、模型尺寸、骨架&#xff08;backbone&#xff09;和头部&#xff08;hea…...

【大数据】Zookeeper 集群及其选举机制

Zookeeper 集群及其选举机制 1.安装 Zookeeper 集群2.如何选取 Leader 1.安装 Zookeeper 集群 我们之前说了&#xff0c;Zookeeper 集群是由一个领导者&#xff08;Leader&#xff09;和多个追随者&#xff08;Follower&#xff09;组成&#xff0c;但这个领导者是怎么选出来的…...

Redis 过期策略

我们在set key的时候可以设置key的过期时间&#xff0c;哪redis是怎么处理过期的key的呢&#xff1f; 有三种过期策略 定时过期&#xff1a;每个设置过期时间的key会创建一个定时器&#xff0c;到过期时间就会立即对key进行清除。该策略可以立即清除过期的数据&#xff0c;对…...

RT_Thread 调试笔记:串口打印、MSH控制台 相关

说明&#xff1a;记录日常使用 RT_Thread 开发时做的笔记。 持续更新中&#xff0c;欢迎收藏。 1.打印相关 1.打印宏定义&#xff0c;可以打印打印所在文件&#xff0c;函数&#xff0c;行数。 #define PRINT_TRACE() printf("-------%s:%s:%d------\r\n", __FIL…...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中&#xff0c;我们可能会遇到一些流式数据处理的场景&#xff0c;比如接收来自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 内容&#xff0c;并将其原样中转给前端页面或客户端。这种情况下&#xff0c;传统的 RestTemplate 缓存机制会…...

YSYX学习记录(八)

C语言&#xff0c;练习0&#xff1a; 先创建一个文件夹&#xff0c;我用的是物理机&#xff1a; 安装build-essential 练习1&#xff1a; 我注释掉了 #include <stdio.h> 出现下面错误 在你的文本编辑器中打开ex1文件&#xff0c;随机修改或删除一部分&#xff0c;之后…...

Java - Mysql数据类型对应

Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...

学习STC51单片机31(芯片为STC89C52RCRC)OLED显示屏1

每日一言 生活的美好&#xff0c;总是藏在那些你咬牙坚持的日子里。 硬件&#xff1a;OLED 以后要用到OLED的时候找到这个文件 OLED的设备地址 SSD1306"SSD" 是品牌缩写&#xff0c;"1306" 是产品编号。 驱动 OLED 屏幕的 IIC 总线数据传输格式 示意图 …...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

AI,如何重构理解、匹配与决策?

AI 时代&#xff0c;我们如何理解消费&#xff1f; 作者&#xff5c;王彬 封面&#xff5c;Unplash 人们通过信息理解世界。 曾几何时&#xff0c;PC 与移动互联网重塑了人们的购物路径&#xff1a;信息变得唾手可得&#xff0c;商品决策变得高度依赖内容。 但 AI 时代的来…...

管理学院权限管理系统开发总结

文章目录 &#x1f393; 管理学院权限管理系统开发总结 - 现代化Web应用实践之路&#x1f4dd; 项目概述&#x1f3d7;️ 技术架构设计后端技术栈前端技术栈 &#x1f4a1; 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 &#x1f5c4;️ 数据库设…...

免费数学几何作图web平台

光锐软件免费数学工具&#xff0c;maths,数学制图&#xff0c;数学作图&#xff0c;几何作图&#xff0c;几何&#xff0c;AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

4. TypeScript 类型推断与类型组合

一、类型推断 (一) 什么是类型推断 TypeScript 的类型推断会根据变量、函数返回值、对象和数组的赋值和使用方式&#xff0c;自动确定它们的类型。 这一特性减少了显式类型注解的需要&#xff0c;在保持类型安全的同时简化了代码。通过分析上下文和初始值&#xff0c;TypeSc…...