Tensorflow2基础代码实战系列之双层RNN文本分类任务
深度学习框架Tensorflow2系列
注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
这个系列主要和大家分享深度学习框架Tensorflow2的各种api,从基础开始。
#博学谷IT学习技术支持#
文章目录
- 深度学习框架Tensorflow2系列
- 前言
- 一、文本分类任务实战
- 二、数据集介绍
- 三、RNN模型所需数据解读
- 四、实战代码
- 1.数据预处理
- 2.构建初始化embedding层
- 3.构建训练数据
- 4.自定义双层RNN网络模型
- 5.设置参数和训练策略
- 6.模型训练
- 总结
前言
通过代码案例实战,学习Tensorflow2的各种api。
一、文本分类任务实战
任务介绍:
数据集构建:影评数据集进行情感分析(分类任务)
词向量模型:加载训练好的词向量或者自己训练都可以
序列网络模型:训练RNN模型进行识别
二、数据集介绍
训练和测试集都是比较简单的电影评价数据集,标签为0和1的二分类,表示对电影的喜欢和不喜欢
三、RNN模型所需数据解读
RNN是一个比较基础的序列化模型,其中输入的数据为[batch_size,max_len,feature_dim]
四、实战代码
1.数据预处理
import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import numpy as np
import pprint
import logging
import time
from collections import Counter
from pathlib import Path
from tqdm import tqdm# 构建语料表,基于词频来进行统计
counter = Counter()
with open('./data/train.txt',encoding='utf-8') as f:for line in f:line = line.rstrip()label, words = line.split('\t')words = words.split(' ')counter.update(words)words = ['<pad>'] + [w for w, freq in counter.most_common() if freq >= 10]
print('Vocab Size:', len(words))Path('./vocab').mkdir(exist_ok=True)with open('./vocab/word.txt', 'w',encoding='utf-8') as f:for w in words:f.write(w+'\n')# 得到word2id映射表
word2idx = {}
with open('./vocab/word.txt',encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i
得到的结果如下
2.构建初始化embedding层
# 做了一个大表,里面有20598个不同的词,【20599*50】
embedding = np.zeros((len(word2idx)+1, 50)) # + 1 表示如果不在语料表中,就都是unknowwith open('./data/glove.6B.50d.txt',encoding='utf-8') as f: #下载好的count = 0for i, line in enumerate(f):if i % 100000 == 0:print('- At line {}'.format(i)) #打印处理了多少数据line = line.rstrip()sp = line.split(' ')word, vec = sp[0], sp[1:]if word in word2idx:count += 1embedding[word2idx[word]] = np.asarray(vec, dtype='float32') #将词转换成对应的向量# 保存结果
print("[%d / %d] words have found pre-trained values"%(count, len(word2idx)))
np.save('./vocab/word.npy', embedding)
print('Saved ./vocab/word.npy')
得到的结果如下:word.txt中的每个单词转换成对应的向量
3.构建训练数据
def data_generator(f_path, params):with open(f_path,encoding='utf-8') as f:print('Reading', f_path)for line in f:line = line.rstrip()label, text = line.split('\t')text = text.split(' ')x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的IDif len(x) >= params['max_len']:#截断操作x = x[:params['max_len']]else:x += [0] * (params['max_len'] - len(x))#补齐操作y = int(label)yield x, ydef dataset(is_training, params):_shapes = ([params['max_len']], ())_types = (tf.int32, tf.int32)if is_training:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['train_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.shuffle(params['num_samples'])ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)#设置缓存序列,目的加速else:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['test_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)return ds
4.自定义双层RNN网络模型
class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'),dtype=tf.float32,name='pretrained_embedding',trainable=False,)self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2)def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2*params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs)x = self.drop1(x, training=training)x = self.rnn1(x)x = self.drop2(x, training=training)x = self.rnn2(x)x = self.drop3(x, training=training)x = self.rnn3(x)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x
5.设置参数和训练策略
params = {'vocab_path': './vocab/word.txt','train_path': './data/train.txt','test_path': './data/test.txt','num_samples': 25000,'num_labels': 2,'batch_size': 32,'max_len': 1000,'rnn_units': 200,'dropout_rate': 0.2,'clip_norm': 10.,'num_patience': 3,'lr': 3e-4,
}
def is_descending(history: list):history = history[-(params['num_patience']+1):]for i in range(1, len(history)):if history[i-1] <= history[i]:return Falsereturn True
word2idx = {}
with open(params['vocab_path'],encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i
params['word2idx'] = word2idx
params['vocab_size'] = len(word2idx) + 1model = Model(params)
model.build(input_shape=(None, None))#设置输入的大小,或者fit时候也能自动找到decay_lr = tf.optimizers.schedules.ExponentialDecay(params['lr'], 1000, 0.95)#相当于加了一个指数衰减函数
optim = tf.optimizers.Adam(params['lr'])
global_step = 0history_acc = []
best_acc = .0t0 = time.time()
logger = logging.getLogger('tensorflow')
logger.setLevel(logging.INFO)
6.模型训练
while True:# 训练模型for texts, labels in dataset(is_training=True, params=params):with tf.GradientTape() as tape:#梯度带,记录所有在上下文中的操作,并且通过调用.gradient()获得任何上下文中计算得出的张量的梯度logits = model(texts, training=True)loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)loss = tf.reduce_mean(loss)optim.lr.assign(decay_lr(global_step))grads = tape.gradient(loss, model.trainable_variables)grads, _ = tf.clip_by_global_norm(grads, params['clip_norm']) #将梯度限制一下,有的时候回更新太猛,防止过拟合optim.apply_gradients(zip(grads, model.trainable_variables))#更新梯度if global_step % 50 == 0:logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(global_step, loss.numpy().item(), time.time()-t0, optim.lr.numpy().item()))t0 = time.time()global_step += 1# 验证集效果m = tf.keras.metrics.Accuracy()for texts, labels in dataset(is_training=False, params=params):logits = model(texts, training=False)y_pred = tf.argmax(logits, axis=-1)m.update_state(y_true=labels, y_pred=y_pred)acc = m.result().numpy()logger.info("Evaluation: Testing Accuracy: {:.3f}".format(acc))history_acc.append(acc)if acc > best_acc:best_acc = acclogger.info("Best Accuracy: {:.3f}".format(best_acc))if len(history_acc) > params['num_patience'] and is_descending(history_acc):logger.info("Testing Accuracy not improved over {} epochs, Early Stop".format(params['num_patience']))break
总结
通过RNN文本分类任务代码案例实战,学习Tensorflow2的各种api。
相关文章:

Tensorflow2基础代码实战系列之双层RNN文本分类任务
深度学习框架Tensorflow2系列 注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark …...
Python爬虫-快手photoId
前言 本文是该专栏的第49篇,后面会持续分享python爬虫干货知识,记得关注。 笔者在本专栏的上一篇,有详细介绍平台视频播放量的爬取方法。与该平台相关联的文章,笔者已整理在下方,感兴趣的同学可查看翻阅。 1. Python如何解决“快手滑块验证码”(4) 2. 快手pcursor 3. …...

软件测试人员如何为项目的质量保障兜底?看完你就明白了...
上线前层层保障 01文档管理 关键词:需求文档、设计文档、测试文档 1.需求和设计产出方为产品、开发,测试需要做好流程监督,这里重点说下测试文档。 2.测试文档,从业务领域来说,一般有测试计划、测试用例、业务总结文…...

《幸福关系的7段旅程》
关于作者 本书作者安德鲁∙马歇尔,英国顶尖婚姻咨询机构RELATE的资深专家,拥有 30年丰富的咨询经验,并为《泰晤士报》《观察家》和《星期日快报》撰写专栏文章。已出版19部作品,并被翻译成20种语言。 关于本书 《幸福关系的7段…...
使用Python中PDB模块中的命令来调试Python代码的教程
这篇文章主要介绍了使用Python中PDB模块中的命令来调试Python代码的教程,包括设置断点来修改代码等、对于Python团队项目工作有一定帮助,需要的朋友可以参考下 你有多少次陷入不得不更改别人代码的境地?如果你是一个开发团队的一员,那么你遇…...

Codeforces Round 764 (Div. 3)
比赛链接 Codeforces Round 764 A. Plus One on the SubsetB. Make APC. Division by Two and PermutationD. Palindromes ColoringE. Masha-forgetful A. Plus One on the Subset Example input 3 6 3 4 2 4 1 2 3 1000 1002 998 2 12 11output 3 4 1题意: 你可…...

四月,收割12家offer,面试也太容易了吧....
前言 下面是我根据工作这几年来的面试经验,加上之前收集的资料,整理出来350道软件测试工程师 常考的面试题。字节跳动、阿里、腾讯、百度、快手、美团等大厂常考的面试题,在文章里面都有 提到。 虽然这篇文章很长,但是绝对值得你…...

Xubuntu22.04之自动调节亮度护眼redshift(一百七十四)
简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…...

Spark基础学习笔记----RDD检查点与共享变量
零、本讲学习目标 了解RDD容错机制理解RDD检查点机制的特点与用处理解共享变量的类别、特点与使用 一、RDD容错机制 当Spark集群中的某一个节点由于宕机导致数据丢失,则可以通过Spark中的RDD进行容错恢复已经丢失的数据。RDD提供了两种故障恢复的方式,…...
ES6(对象,数组,类型化数组)
对象 1,Object.is 用于判断两个值是否相等, 其内部实现类SameValue算法, 其行为类似于“” 但与“”不同的是 它认为两个NaN是相等的 而0,-0是不相等的 2,Object.assign 表示此方法可以将对象合并成一个 他的第一个…...

JVM系列-第12章-垃圾回收器
垃圾回收器 GC 分类与性能指标 垃圾回收器概述 垃圾收集器没有在规范中进行过多的规定,可以由不同的厂商、不同版本的JVM来实现。 由于JDK的版本处于高速迭代过程中,因此Java发展至今已经衍生了众多的GC版本。 从不同角度分析垃圾收集器,…...

零操作难度,轻松进行应用测试,App专项测试之Monkey测试完全指南!
目录 前言: 一、 Monkey测试的基础参数 1.1 事件类型参数: 1.2 覆盖包 1.3 事件数量 二、 Monkey测试的高级参数 2.1 稳定性级别 2.2 策略参数 2.3 包含选项参数 三、 附加代码 四、 总结 前言: 在移动应用的开发过程中࿰…...

Linux安装Docker(这应该是你看过的最简洁的安装教程)
Docker是一种开源的容器化平台,可以将应用程序及其依赖项打包成一个可移植的容器,以便在不同的环境中运行。Docker的核心是Docker引擎,它可以自动化应用程序的部署、扩展和管理,同时还提供了一个开放的API,可以与其他工…...
使用AES算法加密技术集成Java和Vue保护您的数据,代码示例和算法原理
1 算法的原理: AES是一种对称加密算法,也就是说加密和解密使用的是同一个密钥。其基本原理是将明文分成固定大小的块(128位),然后使用密钥对每个块进行加密操作,最后生成密文。在加密过程中,还需要使用一个向量(IV)来增加安全性,避免相同的明文块生成相同的密文块。…...

vcruntime140_1.dll丢失怎样修复,推荐4个vcruntime140_1.dll丢失的修复方法
vcruntime140_1.dll文件是Microsoft Visual C Redistributable for Visual Studio 2015运行库的一部分,它是一个用于支持Visual C构建的应用程序的系统文件。这个文件包含了在运行C程序时所需要的函数和类库,主要负责向应用程序提供运行时环境。如果电脑…...

快来试试这几个简单好用的手机技巧吧
技巧一:相机功能 苹果手机的相机功能确实非常出色,除了出色的像素之外,还有许多其他实用功能可以提升拍摄体验。 这些相机功能提供了更多的选择和便利性,使用户能够更好地适应不同的拍摄需求。 自拍功能:通过选择自…...

OneDrive同步角标消失 - 解决方案
问题 在电脑端使用OneDrive时,文件管理器OneDrive文件夹内的文件会在左下角显示同步状态,如下图。若没有显示同步角标,则此功能出现异常,下文介绍如何显示同步角标。 值得一提的是,同步角标只起到显示作用࿰…...

自学网络安全【黑客】,一般人我劝你还是算了吧
前言:我是劝一般人算了,看你是一般人还是。。。 一、网络安全学习的误区 1.不要试图以编程为基础去学习网络安全2.不要刚开始就深度学习网络安全3.收集适当的学习资料4.适当的报班学习二、学习网络安全的些许准备 1.硬件选择2.软件选择3.外语能力三、网…...
Java集合工具:first和last
在平常开发过程中,我们经常会遇到截取列表片段的需求,比如取列表中前4个元素、取后四个元素。Java的List提供了subList方法,可以用来完成这些工作,但是使用起来并没有那么便利,比如取前四个元素: list.sub…...
leetcode 905. 按奇偶排序数组
题目描述解题思路执行结果 leetcode 905. 按奇偶排序数组 题目描述 按奇偶排序数组 给你一个整数数组 nums,将 nums 中的的所有偶数元素移动到数组的前面,后跟所有奇数元素。 返回满足此条件的 任一数组 作为答案。 示例 1: 输入:…...
Java如何权衡是使用无序的数组还是有序的数组
在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

UE5 学习系列(三)创建和移动物体
这篇博客是该系列的第三篇,是在之前两篇博客的基础上展开,主要介绍如何在操作界面中创建和拖动物体,这篇博客跟随的视频链接如下: B 站视频:s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...

【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...
ffmpeg(四):滤镜命令
FFmpeg 的滤镜命令是用于音视频处理中的强大工具,可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下: ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜: ffmpeg…...

在WSL2的Ubuntu镜像中安装Docker
Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包: for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...
return this;返回的是谁
一个审批系统的示例来演示责任链模式的实现。假设公司需要处理不同金额的采购申请,不同级别的经理有不同的审批权限: // 抽象处理者:审批者 abstract class Approver {protected Approver successor; // 下一个处理者// 设置下一个处理者pub…...

AI+无人机如何守护濒危物种?YOLOv8实现95%精准识别
【导读】 野生动物监测在理解和保护生态系统中发挥着至关重要的作用。然而,传统的野生动物观察方法往往耗时耗力、成本高昂且范围有限。无人机的出现为野生动物监测提供了有前景的替代方案,能够实现大范围覆盖并远程采集数据。尽管具备这些优势…...

云原生安全实战:API网关Kong的鉴权与限流详解
🔥「炎码工坊」技术弹药已装填! 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关(API Gateway) API网关是微服务架构中的核心组件,负责统一管理所有API的流量入口。它像一座…...
C#学习第29天:表达式树(Expression Trees)
目录 什么是表达式树? 核心概念 1.表达式树的构建 2. 表达式树与Lambda表达式 3.解析和访问表达式树 4.动态条件查询 表达式树的优势 1.动态构建查询 2.LINQ 提供程序支持: 3.性能优化 4.元数据处理 5.代码转换和重写 适用场景 代码复杂性…...
Go语言多线程问题
打印零与奇偶数(leetcode 1116) 方法1:使用互斥锁和条件变量 package mainimport ("fmt""sync" )type ZeroEvenOdd struct {n intzeroMutex sync.MutexevenMutex sync.MutexoddMutex sync.Mutexcurrent int…...