TensorFlow2实战-系列教程11:RNN文本分类3
🧡💛💚TensorFlow2实战-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传
6、构建训练数据
- 所有的输入样本必须都是相同shape(文本长度,词向量维度等)
- tf.data.Dataset.from_tensor_slices(tensor):将tensor沿其第一个维度切片,返回一个含有N个样本的数据集,这样做的问题就是需要将整个数据集整体传入,然后切片建立数据集类对象,比较占内存。
- tf.data.Dataset.from_generator(data_generator,output_data_type,output_data_shape):从一个生成器中不断读取样本
def data_generator(f_path, params):with open(f_path,encoding='utf-8') as f:print('Reading', f_path)for line in f:line = line.rstrip()label, text = line.split('\t')text = text.split(' ')x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的IDif len(x) >= params['max_len']:#截断操作x = x[:params['max_len']]else:x += [0] * (params['max_len'] - len(x))#补齐操作y = int(label)yield x, y
- 定义一个生成器函数,传进来读数据的路径、和一些有限制的参数,params 在是一个字典,它包含了最大序列长度(max_len)、词到索引的映射(word2idx)等关键信息
- 打开文件
- 打印文件路径
- 遍历每行数据
- 获取标签和文本
- 文本按照空格分离出单词
- 获取当前句子的所有词对应的索引,for w in text取出这个句子的每一个单词,[params[‘word2idx’]取出params中对应的word2idx字典,.get(w, len(word2idx))从word2idx字典中取出该单词对应的索引,如果有这个索引则返回这个索引,如果没有则返回len(word2idx)作为索引,这个索引表示unknow
- 如果当前句子大于预设的最大句子长度
- 进行截断操作
- 如果小于
- 补充0
- 标签从str转换为int类型
- yield 关键字:用于从一个函数返回一个生成器(generator)。与 return 不同,yield 不会退出函数,而是将函数暂时挂起,保存当前的状态,当生成器再次被调用时,函数会从上次 yield 的地方继续执行,使用 yield 的函数可以在处理大数据集时节省内存,因为它允许逐个生成和处理数据,而不是一次性加载整个数据集到内存中
也就是说yield 会从上一次取得地方再接着去取数据,而return却不会
def dataset(is_training, params):_shapes = ([params['max_len']], ())_types = (tf.int32, tf.int32)if is_training:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['train_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.shuffle(params['num_samples'])ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)else:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['test_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)return ds
- 定义一个制作数据集的函数,is_training表示是否是训练,这个函数在验证和测试也会使用,训练的时候设置为True,验证和测试为False
- 当前shape值
- 1
- 是否在训练,如果是:
- 构建一个Dataset
- 传进我们刚刚定义的生成器函数,并且传进实际的路径和配置参数
- 输出的shape值
- 输出的类型
- 指定shuffle
- 指定 batch_size
- 设置缓存序列,根据可用的CPU动态设置并行调用的数量,说白了就是加速
- 如果不是在训练,则:
- 验证和测试不同的就是路径不同,以及没有shuffle操作,其他都一样
- 最后把做好的Datasets返回回去
7、自定义网络模型

一条文本变成一组向量/矩阵的基本流程:
- 拿到一个英文句子
- 通过查语料表将句子变成一组索引
- 通过词嵌入表结合索引将每个单词都变成一组向量,一条句子就变成了一个矩阵,这就是特征了


BiLSTM即双向LSTM,就是在原本的LSTM增加了一个从后往前走的模块,这样前向和反向两个方向都各自生成了一组特征,把两个特征拼接起来得到一组新的特征,得到翻倍的特征。其他前面和后续的处理操作都是一样的。
class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'), dtype=tf.float32, name='pretrained_embedding', trainable=False,)self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2) def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2*params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs) x = self.drop1(x, training=training)x = self.rnn1(x)x = self.drop2(x, training=training)x = self.rnn2(x)x = self.drop3(x, training=training)x = self.rnn3(x)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x
- 自定义一个模型,继承tf.keras.Model模块
- 初始化函数
- 初始化
- 词嵌入,把之前保存好的词嵌入文件向量读进来
- 定义一层dropout1
- 定义一层dropout2
- 定义一层dropout3
- 定义一个rnn1,rnn_units表示得到多少维的特征,return_sequences表示是返回一个序列还是最后一个输出
- 定义一个rnn2,最后一层的rnn肯定只需要最后一个输出,前后两个rnn的堆叠肯定需要返回一个序列
- 定义一个rnn3 ,tf.keras.layers.LSTM()直接就可以定义一个LSTM,在外面再封装一层API:tf.keras.layers.Bidirectional就实现了双向LSTM
- 定义全连接层的dropout
- 定义一个全连接层,因为是双向的,这里就需要把参数乘以2
- 定义最后输出的全连接层,只需要得到是正例还是负例,所以是2
- 定义前向传播函数,传进来一个batch的数据和是否是在训练
- 如果输入数据不是tf.int32类型
- 转换成tf.int32类型
- 取出batch_size
- 设置LSTM神经元个数,双向乘以2
- 使用 TensorFlow 的 embedding_lookup 函数将输入的整数索引转换为词向量
- 数据通过第1个 Dropout 层
- 数据通过第1个rnn
- 数据通过第2个 Dropout 层
- 数据通过第2个rnn
- 数据通过第3个 Dropout 层
- 数据通过第3个rnn
- 经过全连接层对应的Dropout
- 数据通过一个全连接层
- 最后,数据通过一个输出层
- 返回最终的模型输出
相关文章:
TensorFlow2实战-系列教程11:RNN文本分类3
🧡💛💚TensorFlow2实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Jupyter Notebook中进行 本篇文章配套的代码资源已经上传 6、构建训练数据 所有的输入样本必须都是相同shape(文本长度,…...
故障诊断 | 一文解决,RF随机森林的故障诊断(Matlab)
效果一览 文章概述 故障诊断 | 一文解决,RF随机森林的故障诊断(Matlab) 模型描述 随机森林(Random Forest)是一种集成学习(Ensemble Learning)方法,常用于解决分类和回归问题。它由多个决策树组成,每个决策树都独立地对数据进行训练,并且最终的预测结果是由所有决策…...
DAO设计模式
概念:DAO(Data Access Object) 数据库访问对象,**面向数据库SQL操作**的封装。 (一)场景 问题分析 在实际开发中,针对一张表的复杂业务功能通常需要和表交互多次(比如转账)。如果每次针对表的…...
【Midjourney】新手指南:参数设置
1.--aspect 或 --ar 用于设置图片长宽比,例如 --ar 16:9就是设置图片宽为16,高为9 2.--chaos 用于设置躁点,噪点值越高随机性越大,取值为0到100,例如 --chaos 50 3.--turbo 覆盖seetings的设置并启用极速模式生成…...
阿里云a10GPU,centos7,cuda11.2环境配置
Anaconda3-2022.05-Linux-x86_64.sh gcc升级 centos7升级gcc至8.2_centos7 yum gcc8.2.0-CSDN博客 paddlepaddle python -m pip install paddlepaddle-gpu2.5.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html 报错 ImportError: libssl.so…...
RTSP/Onvif协议视频平台EasyNVR激活码授权异常该如何解决
TSINGSEE青犀视频安防监控平台EasyNVR可支持设备通过RTSP/Onvif协议接入,并能对接入的视频流进行处理与多端分发,包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等多种格式。在智慧安防等视频监控场景中,EasyNVR可提供视频实时监控直播、云端…...
React16源码: React中event事件对象的创建过程源码实现
event 对象 1 ) 概述 在生产事件对象的过程当中,要去调用每一个 possiblePlugin.extractEvents 方法现在单独看下这里面的细节过程,即如何去生产这个事件对象的过程 2 )源码 定位到 packages/events/EventPluginHub.js#L172 f…...
深度学习(12)--Mnist分类任务
一.Mnist分类任务流程详解 1.1.引入数据集 Mnist数据集是官方的数据集,比较特殊,可以直接通过%matplotlib inline自动下载,博主此处已经完成下载,从本地文件中引入数据集。 设置数据路径 from pathlib import Path# 设置数据路…...
AI工具【OCR 01】Java可使用的OCR工具Tess4J使用举例(身份证信息识别核心代码及信息提取方法分享)
Java可使用的OCR工具Tess4J使用举例 1.简介1.1 简单介绍1.2 官方说明 2.使用举例2.1 依赖及语言数据包2.2 核心代码2.3 识别身份证信息2.3.1 核心代码2.3.2 截取指定字符2.3.3 去掉字符串里的非中文字符2.3.4 提取出生日期(待优化)2.3.5 实测 3.总结 1.简…...
【MySQL复制】半同步复制
介绍 除了内置的异步复制之外,MySQL 5.7 还支持通过插件实现的半同步复制接口。本节讨论半同步复制的概念及其工作原理。接下来的部分将涵盖与半同步复制相关的管理界面,以及如何安装、配置和监控它。 异步复制 MySQL 复制默认是异步的。源服务器将事…...
PHP面试知识点--echo、print、print_r、var_dump区别
echo、print、print_r、var_dump 区别 echo 输出单个或多个字符,多个使用逗号分隔无返回值 echo "String 1", "String 2";print 只可以输出单个字符返回1,因此可用于表达式 print "Hello"; if ($expr && pri…...
centos 7 部署若依前后端分离项目
目录 一、新建数据库 二、修改需求配置 1.修改数据库连接 2.修改Redis连接信息 3.文件路径 4.日志存储路径调整 三、编译后端项目 四、编译前端项目 1.上传项目 2.安装依赖 3.构建生产环境 五、项目部署 1.创建目录 2.后端文件上传 3. 前端文件上传 六、服务启…...
RFID手持终端_智能pda手持终端设备定制方案
手持终端是一款多功能、适用范围广泛的安卓产品,具有高性能、大容量存储、高端扫描头和全网通数据连接能力。它能够快速平稳地运行,并提供稳定的连接表现和快速的响应时,适用于医院、物流运输、零售配送、资产盘点等苛刻的环境。通过快速采集…...
51单片机学习——矩阵按键
目录 gitee链接 小程吃饭饭 (xiaocheng-has-a-meal) - Gitee.comhttps://gitee.com/xiaocheng-has-a-meal 1.图~突突突突突 矩阵键盘原理图 矩阵键盘的实物图 2.矩阵键盘 引入~啦啦啦啦啦 原理~沥沥沥沥沥 代码~嗷嗷嗷嗷嗷 【1】延时函数 【2】 LCD1602 【3】检测按…...
重写Sylar基于协程的服务器(1、日志模块的架构)
重写Sylar基于协程的服务器(1、日志模块的架构) 重写Sylar基于协程的服务器系列: 重写Sylar基于协程的服务器(0、搭建开发环境以及项目框架 || 下载编译简化版Sylar) 重写Sylar基于协程的服务器(1、日志模…...
ElementUI Form:Radio 单选框
ElementUI安装与使用指南 Radio 单选框 点击下载learnelementuispringboot项目源码 效果图 el-radio.vue (Radio 单选框)页面效果图 项目里el-radio.vue代码 <script> export default {name: el_radio,data() {return {radio: 1,radio2: 2,ra…...
react-activation实现缓存,且部分页面刷新缓存,清除缓存
1.安装依赖 npm i -S react-activation2.使用AliveScope 包裹根组件 import { AliveScope } from "react-activation" <AliveScope><Router><Switch><Route exact path"/" render{() > <Redirect to"/login" push …...
idea 中 tomcat 乱码问题修复
之前是修改 Tomcat 目录下 conf/logging.properties 的配置,将 UTF-8 修改为 GBK,现在发现不用这样修改了。只需要修改 IDEA 中 Tomcat 的配置就可以了。 修改IDEA中Tomcat的配置:添加-Dfile.encodingUTF-8 本文结束...
Modbus协议学习第七篇之libmodbus库API介绍(modbus_write_bits等)
写在前面 在第六篇中我们介绍了基于libmodbus库的演示代码,那本篇博客就详细介绍一下第六篇的代码中使用的基于该库的API函数。另各位读者,Modbus相关知识受众较少,如果觉得我的专栏文章有帮助,请一定点个赞,在此跪谢&…...
第九节HarmonyOS 常用基础组件13-TimePicker
1、描述 时间选择组件,根据指定参数创建选择器,支持选择小时以及分钟。默认以24小时的时间区间创建滑动选择器。 2、接口 TimePicker(options?: {selected?: Date}) 3、参数 selected - Date - 设置选中项的时间。默认是系统当前的时间。 4、属性…...
K8S认证|CKS题库+答案| 11. AppArmor
目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...
OkHttp 中实现断点续传 demo
在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...
【决胜公务员考试】求职OMG——见面课测验1
2025最新版!!!6.8截至答题,大家注意呀! 博主码字不易点个关注吧,祝期末顺利~~ 1.单选题(2分) 下列说法错误的是:( B ) A.选调生属于公务员系统 B.公务员属于事业编 C.选调生有基层锻炼的要求 D…...
相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)
【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...
JVM虚拟机:内存结构、垃圾回收、性能优化
1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...
保姆级教程:在无网络无显卡的Windows电脑的vscode本地部署deepseek
文章目录 1 前言2 部署流程2.1 准备工作2.2 Ollama2.2.1 使用有网络的电脑下载Ollama2.2.2 安装Ollama(有网络的电脑)2.2.3 安装Ollama(无网络的电脑)2.2.4 安装验证2.2.5 修改大模型安装位置2.2.6 下载Deepseek模型 2.3 将deepse…...
tauri项目,如何在rust端读取电脑环境变量
如果想在前端通过调用来获取环境变量的值,可以通过标准的依赖: std::env::var(name).ok() 想在前端通过调用来获取,可以写一个command函数: #[tauri::command] pub fn get_env_var(name: String) -> Result<String, Stri…...
书籍“之“字形打印矩阵(8)0609
题目 给定一个矩阵matrix,按照"之"字形的方式打印这个矩阵,例如: 1 2 3 4 5 6 7 8 9 10 11 12 ”之“字形打印的结果为:1,…...
云原生安全实战:API网关Envoy的鉴权与限流详解
🔥「炎码工坊」技术弹药已装填! 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关 作为微服务架构的统一入口,负责路由转发、安全控制、流量管理等核心功能。 2. Envoy 由Lyft开源的高性能云原生…...
IP选择注意事项
IP选择注意事项 MTP、FTP、EFUSE、EMEMORY选择时,需要考虑以下参数,然后确定后选择IP。 容量工作电压范围温度范围擦除、烧写速度/耗时读取所有bit的时间待机功耗擦写、烧写功耗面积所需要的mask layer...
