当前位置: 首页 > news >正文

Tensorflow2基础代码实战系列之双层RNN文本分类任务

深度学习框架Tensorflow2系列

注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
这个系列主要和大家分享深度学习框架Tensorflow2的各种api,从基础开始。
#博学谷IT学习技术支持#


文章目录

  • 深度学习框架Tensorflow2系列
  • 前言
  • 一、文本分类任务实战
  • 二、数据集介绍
  • 三、RNN模型所需数据解读
  • 四、实战代码
    • 1.数据预处理
    • 2.构建初始化embedding层
    • 3.构建训练数据
    • 4.自定义双层RNN网络模型
    • 5.设置参数和训练策略
    • 6.模型训练
  • 总结


前言

通过代码案例实战,学习Tensorflow2的各种api。


一、文本分类任务实战

任务介绍:
数据集构建:影评数据集进行情感分析(分类任务)
词向量模型:加载训练好的词向量或者自己训练都可以
序列网络模型:训练RNN模型进行识别

二、数据集介绍

训练和测试集都是比较简单的电影评价数据集,标签为0和1的二分类,表示对电影的喜欢和不喜欢
在这里插入图片描述

三、RNN模型所需数据解读

在这里插入图片描述
RNN是一个比较基础的序列化模型,其中输入的数据为[batch_size,max_len,feature_dim]

四、实战代码

1.数据预处理

import os
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import numpy as np
import pprint
import logging
import time
from collections import Counter
from pathlib import Path
from tqdm import tqdm# 构建语料表,基于词频来进行统计
counter = Counter()
with open('./data/train.txt',encoding='utf-8') as f:for line in f:line = line.rstrip()label, words = line.split('\t')words = words.split(' ')counter.update(words)words = ['<pad>'] + [w for w, freq in counter.most_common() if freq >= 10]
print('Vocab Size:', len(words))Path('./vocab').mkdir(exist_ok=True)with open('./vocab/word.txt', 'w',encoding='utf-8') as f:for w in words:f.write(w+'\n')# 得到word2id映射表
word2idx = {}
with open('./vocab/word.txt',encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i

得到的结果如下
在这里插入图片描述

2.构建初始化embedding层

# 做了一个大表,里面有20598个不同的词,【20599*50】
embedding = np.zeros((len(word2idx)+1, 50)) # + 1 表示如果不在语料表中,就都是unknowwith open('./data/glove.6B.50d.txt',encoding='utf-8') as f: #下载好的count = 0for i, line in enumerate(f):if i % 100000 == 0:print('- At line {}'.format(i)) #打印处理了多少数据line = line.rstrip()sp = line.split(' ')word, vec = sp[0], sp[1:]if word in word2idx:count += 1embedding[word2idx[word]] = np.asarray(vec, dtype='float32') #将词转换成对应的向量# 保存结果
print("[%d / %d] words have found pre-trained values"%(count, len(word2idx)))
np.save('./vocab/word.npy', embedding)
print('Saved ./vocab/word.npy')

得到的结果如下:word.txt中的每个单词转换成对应的向量
在这里插入图片描述

3.构建训练数据

def data_generator(f_path, params):with open(f_path,encoding='utf-8') as f:print('Reading', f_path)for line in f:line = line.rstrip()label, text = line.split('\t')text = text.split(' ')x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的IDif len(x) >= params['max_len']:#截断操作x = x[:params['max_len']]else:x += [0] * (params['max_len'] - len(x))#补齐操作y = int(label)yield x, ydef dataset(is_training, params):_shapes = ([params['max_len']], ())_types = (tf.int32, tf.int32)if is_training:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['train_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.shuffle(params['num_samples'])ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)#设置缓存序列,目的加速else:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['test_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)return ds

4.自定义双层RNN网络模型

class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'),dtype=tf.float32,name='pretrained_embedding',trainable=False,)self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2)def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2*params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs)x = self.drop1(x, training=training)x = self.rnn1(x)x = self.drop2(x, training=training)x = self.rnn2(x)x = self.drop3(x, training=training)x = self.rnn3(x)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x

5.设置参数和训练策略

params = {'vocab_path': './vocab/word.txt','train_path': './data/train.txt','test_path': './data/test.txt','num_samples': 25000,'num_labels': 2,'batch_size': 32,'max_len': 1000,'rnn_units': 200,'dropout_rate': 0.2,'clip_norm': 10.,'num_patience': 3,'lr': 3e-4,
}
def is_descending(history: list):history = history[-(params['num_patience']+1):]for i in range(1, len(history)):if history[i-1] <= history[i]:return Falsereturn True  
word2idx = {}
with open(params['vocab_path'],encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i
params['word2idx'] = word2idx
params['vocab_size'] = len(word2idx) + 1model = Model(params)
model.build(input_shape=(None, None))#设置输入的大小,或者fit时候也能自动找到decay_lr = tf.optimizers.schedules.ExponentialDecay(params['lr'], 1000, 0.95)#相当于加了一个指数衰减函数
optim = tf.optimizers.Adam(params['lr'])
global_step = 0history_acc = []
best_acc = .0t0 = time.time()
logger = logging.getLogger('tensorflow')
logger.setLevel(logging.INFO)

6.模型训练

while True:# 训练模型for texts, labels in dataset(is_training=True, params=params):with tf.GradientTape() as tape:#梯度带,记录所有在上下文中的操作,并且通过调用.gradient()获得任何上下文中计算得出的张量的梯度logits = model(texts, training=True)loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)loss = tf.reduce_mean(loss)optim.lr.assign(decay_lr(global_step))grads = tape.gradient(loss, model.trainable_variables)grads, _ = tf.clip_by_global_norm(grads, params['clip_norm']) #将梯度限制一下,有的时候回更新太猛,防止过拟合optim.apply_gradients(zip(grads, model.trainable_variables))#更新梯度if global_step % 50 == 0:logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(global_step, loss.numpy().item(), time.time()-t0, optim.lr.numpy().item()))t0 = time.time()global_step += 1# 验证集效果m = tf.keras.metrics.Accuracy()for texts, labels in dataset(is_training=False, params=params):logits = model(texts, training=False)y_pred = tf.argmax(logits, axis=-1)m.update_state(y_true=labels, y_pred=y_pred)acc = m.result().numpy()logger.info("Evaluation: Testing Accuracy: {:.3f}".format(acc))history_acc.append(acc)if acc > best_acc:best_acc = acclogger.info("Best Accuracy: {:.3f}".format(best_acc))if len(history_acc) > params['num_patience'] and is_descending(history_acc):logger.info("Testing Accuracy not improved over {} epochs, Early Stop".format(params['num_patience']))break

总结

通过RNN文本分类任务代码案例实战,学习Tensorflow2的各种api。

相关文章:

Tensorflow2基础代码实战系列之双层RNN文本分类任务

深度学习框架Tensorflow2系列 注&#xff1a;大家觉得博客好的话&#xff0c;别忘了点赞收藏呀&#xff0c;本人每周都会更新关于人工智能和大数据相关的内容&#xff0c;内容多为原创&#xff0c;Python Java Scala SQL 代码&#xff0c;CV NLP 推荐系统等&#xff0c;Spark …...

Python爬虫-快手photoId

前言 本文是该专栏的第49篇,后面会持续分享python爬虫干货知识,记得关注。 笔者在本专栏的上一篇,有详细介绍平台视频播放量的爬取方法。与该平台相关联的文章,笔者已整理在下方,感兴趣的同学可查看翻阅。 1. Python如何解决“快手滑块验证码”(4) 2. 快手pcursor 3. …...

软件测试人员如何为项目的质量保障兜底?看完你就明白了...

上线前层层保障 01文档管理 关键词&#xff1a;需求文档、设计文档、测试文档 1.需求和设计产出方为产品、开发&#xff0c;测试需要做好流程监督&#xff0c;这里重点说下测试文档。 2.测试文档&#xff0c;从业务领域来说&#xff0c;一般有测试计划、测试用例、业务总结文…...

《幸福关系的7段旅程》

关于作者 本书作者安德鲁∙马歇尔&#xff0c;英国顶尖婚姻咨询机构RELATE的资深专家&#xff0c;拥有 30年丰富的咨询经验&#xff0c;并为《泰晤士报》《观察家》和《星期日快报》撰写专栏文章。已出版19部作品&#xff0c;并被翻译成20种语言。 关于本书 《幸福关系的7段…...

使用Python中PDB模块中的命令来调试Python代码的教程

这篇文章主要介绍了使用Python中PDB模块中的命令来调试Python代码的教程,包括设置断点来修改代码等、对于Python团队项目工作有一定帮助&#xff0c;需要的朋友可以参考下 你有多少次陷入不得不更改别人代码的境地&#xff1f;如果你是一个开发团队的一员&#xff0c;那么你遇…...

Codeforces Round 764 (Div. 3)

比赛链接 Codeforces Round 764 A. Plus One on the SubsetB. Make APC. Division by Two and PermutationD. Palindromes ColoringE. Masha-forgetful A. Plus One on the Subset Example input 3 6 3 4 2 4 1 2 3 1000 1002 998 2 12 11output 3 4 1题意&#xff1a; 你可…...

四月,收割12家offer,面试也太容易了吧....

前言 下面是我根据工作这几年来的面试经验&#xff0c;加上之前收集的资料&#xff0c;整理出来350道软件测试工程师 常考的面试题。字节跳动、阿里、腾讯、百度、快手、美团等大厂常考的面试题&#xff0c;在文章里面都有 提到。 虽然这篇文章很长&#xff0c;但是绝对值得你…...

Xubuntu22.04之自动调节亮度护眼redshift(一百七十四)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…...

Spark基础学习笔记----RDD检查点与共享变量

零、本讲学习目标 了解RDD容错机制理解RDD检查点机制的特点与用处理解共享变量的类别、特点与使用 一、RDD容错机制 当Spark集群中的某一个节点由于宕机导致数据丢失&#xff0c;则可以通过Spark中的RDD进行容错恢复已经丢失的数据。RDD提供了两种故障恢复的方式&#xff0c…...

ES6(对象,数组,类型化数组)

对象 1&#xff0c;Object.is 用于判断两个值是否相等&#xff0c; 其内部实现类SameValue算法&#xff0c; 其行为类似于“” 但与“”不同的是 它认为两个NaN是相等的 而0&#xff0c;-0是不相等的 2&#xff0c;Object.assign 表示此方法可以将对象合并成一个 他的第一个…...

JVM系列-第12章-垃圾回收器

垃圾回收器 GC 分类与性能指标 垃圾回收器概述 垃圾收集器没有在规范中进行过多的规定&#xff0c;可以由不同的厂商、不同版本的JVM来实现。 由于JDK的版本处于高速迭代过程中&#xff0c;因此Java发展至今已经衍生了众多的GC版本。 从不同角度分析垃圾收集器&#xff0c;…...

零操作难度,轻松进行应用测试,App专项测试之Monkey测试完全指南!

目录 前言&#xff1a; 一、 Monkey测试的基础参数 1.1 事件类型参数&#xff1a; 1.2 覆盖包 1.3 事件数量 二、 Monkey测试的高级参数 2.1 稳定性级别 2.2 策略参数 2.3 包含选项参数 三、 附加代码 四、 总结 前言&#xff1a; 在移动应用的开发过程中&#xff0…...

Linux安装Docker(这应该是你看过的最简洁的安装教程)

Docker是一种开源的容器化平台&#xff0c;可以将应用程序及其依赖项打包成一个可移植的容器&#xff0c;以便在不同的环境中运行。Docker的核心是Docker引擎&#xff0c;它可以自动化应用程序的部署、扩展和管理&#xff0c;同时还提供了一个开放的API&#xff0c;可以与其他工…...

使用AES算法加密技术集成Java和Vue保护您的数据,代码示例和算法原理

1 算法的原理: AES是一种对称加密算法,也就是说加密和解密使用的是同一个密钥。其基本原理是将明文分成固定大小的块(128位),然后使用密钥对每个块进行加密操作,最后生成密文。在加密过程中,还需要使用一个向量(IV)来增加安全性,避免相同的明文块生成相同的密文块。…...

vcruntime140_1.dll丢失怎样修复,推荐4个vcruntime140_1.dll丢失的修复方法

vcruntime140_1.dll文件是Microsoft Visual C Redistributable for Visual Studio 2015运行库的一部分&#xff0c;它是一个用于支持Visual C构建的应用程序的系统文件。这个文件包含了在运行C程序时所需要的函数和类库&#xff0c;主要负责向应用程序提供运行时环境。如果电脑…...

快来试试这几个简单好用的手机技巧吧

技巧一&#xff1a;相机功能 苹果手机的相机功能确实非常出色&#xff0c;除了出色的像素之外&#xff0c;还有许多其他实用功能可以提升拍摄体验。 这些相机功能提供了更多的选择和便利性&#xff0c;使用户能够更好地适应不同的拍摄需求。 自拍功能&#xff1a;通过选择自…...

OneDrive同步角标消失 - 解决方案

问题 在电脑端使用OneDrive时&#xff0c;文件管理器OneDrive文件夹内的文件会在左下角显示同步状态&#xff0c;如下图。若没有显示同步角标&#xff0c;则此功能出现异常&#xff0c;下文介绍如何显示同步角标。 值得一提的是&#xff0c;同步角标只起到显示作用&#xff0…...

自学网络安全【黑客】,一般人我劝你还是算了吧

前言&#xff1a;我是劝一般人算了&#xff0c;看你是一般人还是。。。 一、网络安全学习的误区 1.不要试图以编程为基础去学习网络安全2.不要刚开始就深度学习网络安全3.收集适当的学习资料4.适当的报班学习二、学习网络安全的些许准备 1.硬件选择2.软件选择3.外语能力三、网…...

Java集合工具:first和last

在平常开发过程中&#xff0c;我们经常会遇到截取列表片段的需求&#xff0c;比如取列表中前4个元素、取后四个元素。Java的List提供了subList方法&#xff0c;可以用来完成这些工作&#xff0c;但是使用起来并没有那么便利&#xff0c;比如取前四个元素&#xff1a; list.sub…...

leetcode 905. 按奇偶排序数组

题目描述解题思路执行结果 leetcode 905. 按奇偶排序数组 题目描述 按奇偶排序数组 给你一个整数数组 nums&#xff0c;将 nums 中的的所有偶数元素移动到数组的前面&#xff0c;后跟所有奇数元素。 返回满足此条件的 任一数组 作为答案。 示例 1&#xff1a; 输入&#xff1a;…...

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...

web vue 项目 Docker化部署

Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段&#xff1a; 构建阶段&#xff08;Build Stage&#xff09;&#xff1a…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天&#xff0c;再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至&#xff0c;这不仅是开发者的盛宴&#xff0c;更是全球数亿苹果用户翘首以盼的科技春晚。今年&#xff0c;苹果依旧为我们带来了全家桶式的系统更新&#xff0c;包括 iOS 26、iPadOS 26…...

Appium+python自动化(十六)- ADB命令

简介 Android 调试桥(adb)是多种用途的工具&#xff0c;该工具可以帮助你你管理设备或模拟器 的状态。 adb ( Android Debug Bridge)是一个通用命令行工具&#xff0c;其允许您与模拟器实例或连接的 Android 设备进行通信。它可为各种设备操作提供便利&#xff0c;如安装和调试…...

23-Oracle 23 ai 区块链表(Blockchain Table)

小伙伴有没有在金融强合规的领域中遇见&#xff0c;必须要保持数据不可变&#xff0c;管理员都无法修改和留痕的要求。比如医疗的电子病历中&#xff0c;影像检查检验结果不可篡改行的&#xff0c;药品追溯过程中数据只可插入无法删除的特性需求&#xff1b;登录日志、修改日志…...

ffmpeg(四):滤镜命令

FFmpeg 的滤镜命令是用于音视频处理中的强大工具&#xff0c;可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下&#xff1a; ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜&#xff1a; ffmpeg…...

Rust 异步编程

Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

OpenLayers 分屏对比(地图联动)

注&#xff1a;当前使用的是 ol 5.3.0 版本&#xff0c;天地图使用的key请到天地图官网申请&#xff0c;并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能&#xff0c;和卷帘图层不一样的是&#xff0c;分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”

2025年#高考 将在近日拉开帷幕&#xff0c;#AI 监考一度冲上热搜。当AI深度融入高考&#xff0c;#时间同步 不再是辅助功能&#xff0c;而是决定AI监考系统成败的“生命线”。 AI亮相2025高考&#xff0c;40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕&#xff0c;江西、…...