当前位置: 首页 > news >正文

Tensorflow2基础代码实战系列之双层RNN文本分类任务

深度学习框架Tensorflow2系列

注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
这个系列主要和大家分享深度学习框架Tensorflow2的各种api,从基础开始。
#博学谷IT学习技术支持#


文章目录

  • 深度学习框架Tensorflow2系列
  • 前言
  • 一、文本分类任务实战
  • 二、数据集介绍
  • 三、RNN模型所需数据解读
  • 四、实战代码
    • 1.数据预处理
    • 2.构建初始化embedding层
    • 3.构建训练数据
    • 4.自定义双层RNN网络模型
    • 5.设置参数和训练策略
    • 6.模型训练
  • 总结


前言

通过代码案例实战,学习Tensorflow2的各种api。


一、文本分类任务实战

任务介绍:
数据集构建:影评数据集进行情感分析(分类任务)
词向量模型:加载训练好的词向量或者自己训练都可以
序列网络模型:训练RNN模型进行识别

二、数据集介绍

训练和测试集都是比较简单的电影评价数据集,标签为0和1的二分类,表示对电影的喜欢和不喜欢
在这里插入图片描述

三、RNN模型所需数据解读

在这里插入图片描述
RNN是一个比较基础的序列化模型,其中输入的数据为[batch_size,max_len,feature_dim]

四、实战代码

1.数据预处理

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import numpy as np
import pprint
import logging
import time
from collections import Counter
from pathlib import Path
from tqdm import tqdm# 构建语料表,基于词频来进行统计
counter = Counter()
with open('./data/train.txt',encoding='utf-8') as f:for line in f:line = line.rstrip()label, words = line.split('\t')words = words.split(' ')counter.update(words)words = ['<pad>'] + [w for w, freq in counter.most_common() if freq >= 10]
print('Vocab Size:', len(words))Path('./vocab').mkdir(exist_ok=True)with open('./vocab/word.txt', 'w',encoding='utf-8') as f:for w in words:f.write(w+'\n')# 得到word2id映射表
word2idx = {}
with open('./vocab/word.txt',encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i

得到的结果如下
在这里插入图片描述

2.构建初始化embedding层

# 做了一个大表,里面有20598个不同的词,【20599*50】
embedding = np.zeros((len(word2idx)+1, 50)) # + 1 表示如果不在语料表中,就都是unknowwith open('./data/glove.6B.50d.txt',encoding='utf-8') as f: #下载好的count = 0for i, line in enumerate(f):if i % 100000 == 0:print('- At line {}'.format(i)) #打印处理了多少数据line = line.rstrip()sp = line.split(' ')word, vec = sp[0], sp[1:]if word in word2idx:count += 1embedding[word2idx[word]] = np.asarray(vec, dtype='float32') #将词转换成对应的向量# 保存结果
print("[%d / %d] words have found pre-trained values"%(count, len(word2idx)))
np.save('./vocab/word.npy', embedding)
print('Saved ./vocab/word.npy')

得到的结果如下:word.txt中的每个单词转换成对应的向量
在这里插入图片描述

3.构建训练数据

def data_generator(f_path, params):with open(f_path,encoding='utf-8') as f:print('Reading', f_path)for line in f:line = line.rstrip()label, text = line.split('\t')text = text.split(' ')x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的IDif len(x) >= params['max_len']:#截断操作x = x[:params['max_len']]else:x += [0] * (params['max_len'] - len(x))#补齐操作y = int(label)yield x, ydef dataset(is_training, params):_shapes = ([params['max_len']], ())_types = (tf.int32, tf.int32)if is_training:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['train_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.shuffle(params['num_samples'])ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)#设置缓存序列,目的加速else:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['test_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)return ds

4.自定义双层RNN网络模型

class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'),dtype=tf.float32,name='pretrained_embedding',trainable=False,)self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2)def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2*params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs)x = self.drop1(x, training=training)x = self.rnn1(x)x = self.drop2(x, training=training)x = self.rnn2(x)x = self.drop3(x, training=training)x = self.rnn3(x)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x

5.设置参数和训练策略

params = {'vocab_path': './vocab/word.txt','train_path': './data/train.txt','test_path': './data/test.txt','num_samples': 25000,'num_labels': 2,'batch_size': 32,'max_len': 1000,'rnn_units': 200,'dropout_rate': 0.2,'clip_norm': 10.,'num_patience': 3,'lr': 3e-4,
}
def is_descending(history: list):history = history[-(params['num_patience']+1):]for i in range(1, len(history)):if history[i-1] <= history[i]:return Falsereturn True  
word2idx = {}
with open(params['vocab_path'],encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i
params['word2idx'] = word2idx
params['vocab_size'] = len(word2idx) + 1model = Model(params)
model.build(input_shape=(None, None))#设置输入的大小,或者fit时候也能自动找到decay_lr = tf.optimizers.schedules.ExponentialDecay(params['lr'], 1000, 0.95)#相当于加了一个指数衰减函数
optim = tf.optimizers.Adam(params['lr'])
global_step = 0history_acc = []
best_acc = .0t0 = time.time()
logger = logging.getLogger('tensorflow')
logger.setLevel(logging.INFO)

6.模型训练

while True:# 训练模型for texts, labels in dataset(is_training=True, params=params):with tf.GradientTape() as tape:#梯度带,记录所有在上下文中的操作,并且通过调用.gradient()获得任何上下文中计算得出的张量的梯度logits = model(texts, training=True)loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)loss = tf.reduce_mean(loss)optim.lr.assign(decay_lr(global_step))grads = tape.gradient(loss, model.trainable_variables)grads, _ = tf.clip_by_global_norm(grads, params['clip_norm']) #将梯度限制一下,有的时候回更新太猛,防止过拟合optim.apply_gradients(zip(grads, model.trainable_variables))#更新梯度if global_step % 50 == 0:logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(global_step, loss.numpy().item(), time.time()-t0, optim.lr.numpy().item()))t0 = time.time()global_step += 1# 验证集效果m = tf.keras.metrics.Accuracy()for texts, labels in dataset(is_training=False, params=params):logits = model(texts, training=False)y_pred = tf.argmax(logits, axis=-1)m.update_state(y_true=labels, y_pred=y_pred)acc = m.result().numpy()logger.info("Evaluation: Testing Accuracy: {:.3f}".format(acc))history_acc.append(acc)if acc > best_acc:best_acc = acclogger.info("Best Accuracy: {:.3f}".format(best_acc))if len(history_acc) > params['num_patience'] and is_descending(history_acc):logger.info("Testing Accuracy not improved over {} epochs, Early Stop".format(params['num_patience']))break

总结

通过RNN文本分类任务代码案例实战,学习Tensorflow2的各种api。

相关文章:

Tensorflow2基础代码实战系列之双层RNN文本分类任务

深度学习框架Tensorflow2系列 注&#xff1a;大家觉得博客好的话&#xff0c;别忘了点赞收藏呀&#xff0c;本人每周都会更新关于人工智能和大数据相关的内容&#xff0c;内容多为原创&#xff0c;Python Java Scala SQL 代码&#xff0c;CV NLP 推荐系统等&#xff0c;Spark …...

Python爬虫-快手photoId

前言 本文是该专栏的第49篇,后面会持续分享python爬虫干货知识,记得关注。 笔者在本专栏的上一篇,有详细介绍平台视频播放量的爬取方法。与该平台相关联的文章,笔者已整理在下方,感兴趣的同学可查看翻阅。 1. Python如何解决“快手滑块验证码”(4) 2. 快手pcursor 3. …...

软件测试人员如何为项目的质量保障兜底?看完你就明白了...

上线前层层保障 01文档管理 关键词&#xff1a;需求文档、设计文档、测试文档 1.需求和设计产出方为产品、开发&#xff0c;测试需要做好流程监督&#xff0c;这里重点说下测试文档。 2.测试文档&#xff0c;从业务领域来说&#xff0c;一般有测试计划、测试用例、业务总结文…...

《幸福关系的7段旅程》

关于作者 本书作者安德鲁∙马歇尔&#xff0c;英国顶尖婚姻咨询机构RELATE的资深专家&#xff0c;拥有 30年丰富的咨询经验&#xff0c;并为《泰晤士报》《观察家》和《星期日快报》撰写专栏文章。已出版19部作品&#xff0c;并被翻译成20种语言。 关于本书 《幸福关系的7段…...

使用Python中PDB模块中的命令来调试Python代码的教程

这篇文章主要介绍了使用Python中PDB模块中的命令来调试Python代码的教程,包括设置断点来修改代码等、对于Python团队项目工作有一定帮助&#xff0c;需要的朋友可以参考下 你有多少次陷入不得不更改别人代码的境地&#xff1f;如果你是一个开发团队的一员&#xff0c;那么你遇…...

Codeforces Round 764 (Div. 3)

比赛链接 Codeforces Round 764 A. Plus One on the SubsetB. Make APC. Division by Two and PermutationD. Palindromes ColoringE. Masha-forgetful A. Plus One on the Subset Example input 3 6 3 4 2 4 1 2 3 1000 1002 998 2 12 11output 3 4 1题意&#xff1a; 你可…...

四月,收割12家offer,面试也太容易了吧....

前言 下面是我根据工作这几年来的面试经验&#xff0c;加上之前收集的资料&#xff0c;整理出来350道软件测试工程师 常考的面试题。字节跳动、阿里、腾讯、百度、快手、美团等大厂常考的面试题&#xff0c;在文章里面都有 提到。 虽然这篇文章很长&#xff0c;但是绝对值得你…...

Xubuntu22.04之自动调节亮度护眼redshift(一百七十四)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…...

Spark基础学习笔记----RDD检查点与共享变量

零、本讲学习目标 了解RDD容错机制理解RDD检查点机制的特点与用处理解共享变量的类别、特点与使用 一、RDD容错机制 当Spark集群中的某一个节点由于宕机导致数据丢失&#xff0c;则可以通过Spark中的RDD进行容错恢复已经丢失的数据。RDD提供了两种故障恢复的方式&#xff0c…...

ES6(对象,数组,类型化数组)

对象 1&#xff0c;Object.is 用于判断两个值是否相等&#xff0c; 其内部实现类SameValue算法&#xff0c; 其行为类似于“” 但与“”不同的是 它认为两个NaN是相等的 而0&#xff0c;-0是不相等的 2&#xff0c;Object.assign 表示此方法可以将对象合并成一个 他的第一个…...

JVM系列-第12章-垃圾回收器

垃圾回收器 GC 分类与性能指标 垃圾回收器概述 垃圾收集器没有在规范中进行过多的规定&#xff0c;可以由不同的厂商、不同版本的JVM来实现。 由于JDK的版本处于高速迭代过程中&#xff0c;因此Java发展至今已经衍生了众多的GC版本。 从不同角度分析垃圾收集器&#xff0c;…...

零操作难度,轻松进行应用测试,App专项测试之Monkey测试完全指南!

目录 前言&#xff1a; 一、 Monkey测试的基础参数 1.1 事件类型参数&#xff1a; 1.2 覆盖包 1.3 事件数量 二、 Monkey测试的高级参数 2.1 稳定性级别 2.2 策略参数 2.3 包含选项参数 三、 附加代码 四、 总结 前言&#xff1a; 在移动应用的开发过程中&#xff0…...

Linux安装Docker(这应该是你看过的最简洁的安装教程)

Docker是一种开源的容器化平台&#xff0c;可以将应用程序及其依赖项打包成一个可移植的容器&#xff0c;以便在不同的环境中运行。Docker的核心是Docker引擎&#xff0c;它可以自动化应用程序的部署、扩展和管理&#xff0c;同时还提供了一个开放的API&#xff0c;可以与其他工…...

使用AES算法加密技术集成Java和Vue保护您的数据,代码示例和算法原理

1 算法的原理: AES是一种对称加密算法,也就是说加密和解密使用的是同一个密钥。其基本原理是将明文分成固定大小的块(128位),然后使用密钥对每个块进行加密操作,最后生成密文。在加密过程中,还需要使用一个向量(IV)来增加安全性,避免相同的明文块生成相同的密文块。…...

vcruntime140_1.dll丢失怎样修复,推荐4个vcruntime140_1.dll丢失的修复方法

vcruntime140_1.dll文件是Microsoft Visual C Redistributable for Visual Studio 2015运行库的一部分&#xff0c;它是一个用于支持Visual C构建的应用程序的系统文件。这个文件包含了在运行C程序时所需要的函数和类库&#xff0c;主要负责向应用程序提供运行时环境。如果电脑…...

快来试试这几个简单好用的手机技巧吧

技巧一&#xff1a;相机功能 苹果手机的相机功能确实非常出色&#xff0c;除了出色的像素之外&#xff0c;还有许多其他实用功能可以提升拍摄体验。 这些相机功能提供了更多的选择和便利性&#xff0c;使用户能够更好地适应不同的拍摄需求。 自拍功能&#xff1a;通过选择自…...

OneDrive同步角标消失 - 解决方案

问题 在电脑端使用OneDrive时&#xff0c;文件管理器OneDrive文件夹内的文件会在左下角显示同步状态&#xff0c;如下图。若没有显示同步角标&#xff0c;则此功能出现异常&#xff0c;下文介绍如何显示同步角标。 值得一提的是&#xff0c;同步角标只起到显示作用&#xff0…...

自学网络安全【黑客】,一般人我劝你还是算了吧

前言&#xff1a;我是劝一般人算了&#xff0c;看你是一般人还是。。。 一、网络安全学习的误区 1.不要试图以编程为基础去学习网络安全2.不要刚开始就深度学习网络安全3.收集适当的学习资料4.适当的报班学习二、学习网络安全的些许准备 1.硬件选择2.软件选择3.外语能力三、网…...

Java集合工具:first和last

在平常开发过程中&#xff0c;我们经常会遇到截取列表片段的需求&#xff0c;比如取列表中前4个元素、取后四个元素。Java的List提供了subList方法&#xff0c;可以用来完成这些工作&#xff0c;但是使用起来并没有那么便利&#xff0c;比如取前四个元素&#xff1a; list.sub…...

leetcode 905. 按奇偶排序数组

题目描述解题思路执行结果 leetcode 905. 按奇偶排序数组 题目描述 按奇偶排序数组 给你一个整数数组 nums&#xff0c;将 nums 中的的所有偶数元素移动到数组的前面&#xff0c;后跟所有奇数元素。 返回满足此条件的 任一数组 作为答案。 示例 1&#xff1a; 输入&#xff1a;…...

uniapp 对接腾讯云IM群组成员管理(增删改查)

UniApp 实战&#xff1a;腾讯云IM群组成员管理&#xff08;增删改查&#xff09; 一、前言 在社交类App开发中&#xff0c;群组成员管理是核心功能之一。本文将基于UniApp框架&#xff0c;结合腾讯云IM SDK&#xff0c;详细讲解如何实现群组成员的增删改查全流程。 权限校验…...

Python爬虫实战:研究MechanicalSoup库相关技术

一、MechanicalSoup 库概述 1.1 库简介 MechanicalSoup 是一个 Python 库,专为自动化交互网站而设计。它结合了 requests 的 HTTP 请求能力和 BeautifulSoup 的 HTML 解析能力,提供了直观的 API,让我们可以像人类用户一样浏览网页、填写表单和提交请求。 1.2 主要功能特点…...

Linux 文件类型,目录与路径,文件与目录管理

文件类型 后面的字符表示文件类型标志 普通文件&#xff1a;-&#xff08;纯文本文件&#xff0c;二进制文件&#xff0c;数据格式文件&#xff09; 如文本文件、图片、程序文件等。 目录文件&#xff1a;d&#xff08;directory&#xff09; 用来存放其他文件或子目录。 设备…...

基于Flask实现的医疗保险欺诈识别监测模型

基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施&#xff0c;由雇主和个人按一定比例缴纳保险费&#xff0c;建立社会医疗保险基金&#xff0c;支付雇员医疗费用的一种医疗保险制度&#xff0c; 它是促进社会文明和进步的…...

【机器视觉】单目测距——运动结构恢复

ps&#xff1a;图是随便找的&#xff0c;为了凑个封面 前言 在前面对光流法进行进一步改进&#xff0c;希望将2D光流推广至3D场景流时&#xff0c;发现2D转3D过程中存在尺度歧义问题&#xff0c;需要补全摄像头拍摄图像中缺失的深度信息&#xff0c;否则解空间不收敛&#xf…...

376. Wiggle Subsequence

376. Wiggle Subsequence 代码 class Solution { public:int wiggleMaxLength(vector<int>& nums) {int n nums.size();int res 1;int prediff 0;int curdiff 0;for(int i 0;i < n-1;i){curdiff nums[i1] - nums[i];if( (prediff > 0 && curdif…...

Axios请求超时重发机制

Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式&#xff1a; 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

Android15默认授权浮窗权限

我们经常有那种需求&#xff0c;客户需要定制的apk集成在ROM中&#xff0c;并且默认授予其【显示在其他应用的上层】权限&#xff0c;也就是我们常说的浮窗权限&#xff0c;那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...

sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!

简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求&#xff0c;并检查收到的响应。它以以下模式之一…...

【Post-process】【VBA】ETABS VBA FrameObj.GetNameList and write to EXCEL

ETABS API实战:导出框架元素数据到Excel 在结构工程师的日常工作中,经常需要从ETABS模型中提取框架元素信息进行后续分析。手动复制粘贴不仅耗时,还容易出错。今天我们来用简单的VBA代码实现自动化导出。 🎯 我们要实现什么? 一键点击,就能将ETABS中所有框架元素的基…...