TensorFlow2实战-系列教程11:RNN文本分类3
🧡💛💚TensorFlow2实战-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传
6、构建训练数据
- 所有的输入样本必须都是相同shape(文本长度,词向量维度等)
- tf.data.Dataset.from_tensor_slices(tensor):将tensor沿其第一个维度切片,返回一个含有N个样本的数据集,这样做的问题就是需要将整个数据集整体传入,然后切片建立数据集类对象,比较占内存。
- tf.data.Dataset.from_generator(data_generator,output_data_type,output_data_shape):从一个生成器中不断读取样本
def data_generator(f_path, params):with open(f_path,encoding='utf-8') as f:print('Reading', f_path)for line in f:line = line.rstrip()label, text = line.split('\t')text = text.split(' ')x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的IDif len(x) >= params['max_len']:#截断操作x = x[:params['max_len']]else:x += [0] * (params['max_len'] - len(x))#补齐操作y = int(label)yield x, y
- 定义一个生成器函数,传进来读数据的路径、和一些有限制的参数,params 在是一个字典,它包含了最大序列长度(max_len)、词到索引的映射(word2idx)等关键信息
- 打开文件
- 打印文件路径
- 遍历每行数据
- 获取标签和文本
- 文本按照空格分离出单词
- 获取当前句子的所有词对应的索引,for w in text取出这个句子的每一个单词,[params[‘word2idx’]取出params中对应的word2idx字典,.get(w, len(word2idx))从word2idx字典中取出该单词对应的索引,如果有这个索引则返回这个索引,如果没有则返回len(word2idx)作为索引,这个索引表示unknow
- 如果当前句子大于预设的最大句子长度
- 进行截断操作
- 如果小于
- 补充0
- 标签从str转换为int类型
- yield 关键字:用于从一个函数返回一个生成器(generator)。与 return 不同,yield 不会退出函数,而是将函数暂时挂起,保存当前的状态,当生成器再次被调用时,函数会从上次 yield 的地方继续执行,使用 yield 的函数可以在处理大数据集时节省内存,因为它允许逐个生成和处理数据,而不是一次性加载整个数据集到内存中
也就是说yield 会从上一次取得地方再接着去取数据,而return却不会
def dataset(is_training, params):_shapes = ([params['max_len']], ())_types = (tf.int32, tf.int32)if is_training:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['train_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.shuffle(params['num_samples'])ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)else:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['test_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)return ds
- 定义一个制作数据集的函数,is_training表示是否是训练,这个函数在验证和测试也会使用,训练的时候设置为True,验证和测试为False
- 当前shape值
- 1
- 是否在训练,如果是:
- 构建一个Dataset
- 传进我们刚刚定义的生成器函数,并且传进实际的路径和配置参数
- 输出的shape值
- 输出的类型
- 指定shuffle
- 指定 batch_size
- 设置缓存序列,根据可用的CPU动态设置并行调用的数量,说白了就是加速
- 如果不是在训练,则:
- 验证和测试不同的就是路径不同,以及没有shuffle操作,其他都一样
- 最后把做好的Datasets返回回去
7、自定义网络模型

一条文本变成一组向量/矩阵的基本流程:
- 拿到一个英文句子
- 通过查语料表将句子变成一组索引
- 通过词嵌入表结合索引将每个单词都变成一组向量,一条句子就变成了一个矩阵,这就是特征了


BiLSTM即双向LSTM,就是在原本的LSTM增加了一个从后往前走的模块,这样前向和反向两个方向都各自生成了一组特征,把两个特征拼接起来得到一组新的特征,得到翻倍的特征。其他前面和后续的处理操作都是一样的。
class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'), dtype=tf.float32, name='pretrained_embedding', trainable=False,)self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2) def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2*params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs) x = self.drop1(x, training=training)x = self.rnn1(x)x = self.drop2(x, training=training)x = self.rnn2(x)x = self.drop3(x, training=training)x = self.rnn3(x)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x
- 自定义一个模型,继承tf.keras.Model模块
- 初始化函数
- 初始化
- 词嵌入,把之前保存好的词嵌入文件向量读进来
- 定义一层dropout1
- 定义一层dropout2
- 定义一层dropout3
- 定义一个rnn1,rnn_units表示得到多少维的特征,return_sequences表示是返回一个序列还是最后一个输出
- 定义一个rnn2,最后一层的rnn肯定只需要最后一个输出,前后两个rnn的堆叠肯定需要返回一个序列
- 定义一个rnn3 ,tf.keras.layers.LSTM()直接就可以定义一个LSTM,在外面再封装一层API:tf.keras.layers.Bidirectional就实现了双向LSTM
- 定义全连接层的dropout
- 定义一个全连接层,因为是双向的,这里就需要把参数乘以2
- 定义最后输出的全连接层,只需要得到是正例还是负例,所以是2
- 定义前向传播函数,传进来一个batch的数据和是否是在训练
- 如果输入数据不是tf.int32类型
- 转换成tf.int32类型
- 取出batch_size
- 设置LSTM神经元个数,双向乘以2
- 使用 TensorFlow 的 embedding_lookup 函数将输入的整数索引转换为词向量
- 数据通过第1个 Dropout 层
- 数据通过第1个rnn
- 数据通过第2个 Dropout 层
- 数据通过第2个rnn
- 数据通过第3个 Dropout 层
- 数据通过第3个rnn
- 经过全连接层对应的Dropout
- 数据通过一个全连接层
- 最后,数据通过一个输出层
- 返回最终的模型输出
相关文章:
TensorFlow2实战-系列教程11:RNN文本分类3
🧡💛💚TensorFlow2实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Jupyter Notebook中进行 本篇文章配套的代码资源已经上传 6、构建训练数据 所有的输入样本必须都是相同shape(文本长度,…...
故障诊断 | 一文解决,RF随机森林的故障诊断(Matlab)
效果一览 文章概述 故障诊断 | 一文解决,RF随机森林的故障诊断(Matlab) 模型描述 随机森林(Random Forest)是一种集成学习(Ensemble Learning)方法,常用于解决分类和回归问题。它由多个决策树组成,每个决策树都独立地对数据进行训练,并且最终的预测结果是由所有决策…...
DAO设计模式
概念:DAO(Data Access Object) 数据库访问对象,**面向数据库SQL操作**的封装。 (一)场景 问题分析 在实际开发中,针对一张表的复杂业务功能通常需要和表交互多次(比如转账)。如果每次针对表的…...
【Midjourney】新手指南:参数设置
1.--aspect 或 --ar 用于设置图片长宽比,例如 --ar 16:9就是设置图片宽为16,高为9 2.--chaos 用于设置躁点,噪点值越高随机性越大,取值为0到100,例如 --chaos 50 3.--turbo 覆盖seetings的设置并启用极速模式生成…...
阿里云a10GPU,centos7,cuda11.2环境配置
Anaconda3-2022.05-Linux-x86_64.sh gcc升级 centos7升级gcc至8.2_centos7 yum gcc8.2.0-CSDN博客 paddlepaddle python -m pip install paddlepaddle-gpu2.5.1.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html 报错 ImportError: libssl.so…...
RTSP/Onvif协议视频平台EasyNVR激活码授权异常该如何解决
TSINGSEE青犀视频安防监控平台EasyNVR可支持设备通过RTSP/Onvif协议接入,并能对接入的视频流进行处理与多端分发,包括RTSP、RTMP、HTTP-FLV、WS-FLV、HLS、WebRTC等多种格式。在智慧安防等视频监控场景中,EasyNVR可提供视频实时监控直播、云端…...
React16源码: React中event事件对象的创建过程源码实现
event 对象 1 ) 概述 在生产事件对象的过程当中,要去调用每一个 possiblePlugin.extractEvents 方法现在单独看下这里面的细节过程,即如何去生产这个事件对象的过程 2 )源码 定位到 packages/events/EventPluginHub.js#L172 f…...
深度学习(12)--Mnist分类任务
一.Mnist分类任务流程详解 1.1.引入数据集 Mnist数据集是官方的数据集,比较特殊,可以直接通过%matplotlib inline自动下载,博主此处已经完成下载,从本地文件中引入数据集。 设置数据路径 from pathlib import Path# 设置数据路…...
AI工具【OCR 01】Java可使用的OCR工具Tess4J使用举例(身份证信息识别核心代码及信息提取方法分享)
Java可使用的OCR工具Tess4J使用举例 1.简介1.1 简单介绍1.2 官方说明 2.使用举例2.1 依赖及语言数据包2.2 核心代码2.3 识别身份证信息2.3.1 核心代码2.3.2 截取指定字符2.3.3 去掉字符串里的非中文字符2.3.4 提取出生日期(待优化)2.3.5 实测 3.总结 1.简…...
【MySQL复制】半同步复制
介绍 除了内置的异步复制之外,MySQL 5.7 还支持通过插件实现的半同步复制接口。本节讨论半同步复制的概念及其工作原理。接下来的部分将涵盖与半同步复制相关的管理界面,以及如何安装、配置和监控它。 异步复制 MySQL 复制默认是异步的。源服务器将事…...
PHP面试知识点--echo、print、print_r、var_dump区别
echo、print、print_r、var_dump 区别 echo 输出单个或多个字符,多个使用逗号分隔无返回值 echo "String 1", "String 2";print 只可以输出单个字符返回1,因此可用于表达式 print "Hello"; if ($expr && pri…...
centos 7 部署若依前后端分离项目
目录 一、新建数据库 二、修改需求配置 1.修改数据库连接 2.修改Redis连接信息 3.文件路径 4.日志存储路径调整 三、编译后端项目 四、编译前端项目 1.上传项目 2.安装依赖 3.构建生产环境 五、项目部署 1.创建目录 2.后端文件上传 3. 前端文件上传 六、服务启…...
RFID手持终端_智能pda手持终端设备定制方案
手持终端是一款多功能、适用范围广泛的安卓产品,具有高性能、大容量存储、高端扫描头和全网通数据连接能力。它能够快速平稳地运行,并提供稳定的连接表现和快速的响应时,适用于医院、物流运输、零售配送、资产盘点等苛刻的环境。通过快速采集…...
51单片机学习——矩阵按键
目录 gitee链接 小程吃饭饭 (xiaocheng-has-a-meal) - Gitee.comhttps://gitee.com/xiaocheng-has-a-meal 1.图~突突突突突 矩阵键盘原理图 矩阵键盘的实物图 2.矩阵键盘 引入~啦啦啦啦啦 原理~沥沥沥沥沥 代码~嗷嗷嗷嗷嗷 【1】延时函数 【2】 LCD1602 【3】检测按…...
重写Sylar基于协程的服务器(1、日志模块的架构)
重写Sylar基于协程的服务器(1、日志模块的架构) 重写Sylar基于协程的服务器系列: 重写Sylar基于协程的服务器(0、搭建开发环境以及项目框架 || 下载编译简化版Sylar) 重写Sylar基于协程的服务器(1、日志模…...
ElementUI Form:Radio 单选框
ElementUI安装与使用指南 Radio 单选框 点击下载learnelementuispringboot项目源码 效果图 el-radio.vue (Radio 单选框)页面效果图 项目里el-radio.vue代码 <script> export default {name: el_radio,data() {return {radio: 1,radio2: 2,ra…...
react-activation实现缓存,且部分页面刷新缓存,清除缓存
1.安装依赖 npm i -S react-activation2.使用AliveScope 包裹根组件 import { AliveScope } from "react-activation" <AliveScope><Router><Switch><Route exact path"/" render{() > <Redirect to"/login" push …...
idea 中 tomcat 乱码问题修复
之前是修改 Tomcat 目录下 conf/logging.properties 的配置,将 UTF-8 修改为 GBK,现在发现不用这样修改了。只需要修改 IDEA 中 Tomcat 的配置就可以了。 修改IDEA中Tomcat的配置:添加-Dfile.encodingUTF-8 本文结束...
Modbus协议学习第七篇之libmodbus库API介绍(modbus_write_bits等)
写在前面 在第六篇中我们介绍了基于libmodbus库的演示代码,那本篇博客就详细介绍一下第六篇的代码中使用的基于该库的API函数。另各位读者,Modbus相关知识受众较少,如果觉得我的专栏文章有帮助,请一定点个赞,在此跪谢&…...
第九节HarmonyOS 常用基础组件13-TimePicker
1、描述 时间选择组件,根据指定参数创建选择器,支持选择小时以及分钟。默认以24小时的时间区间创建滑动选择器。 2、接口 TimePicker(options?: {selected?: Date}) 3、参数 selected - Date - 设置选中项的时间。默认是系统当前的时间。 4、属性…...
智慧工地云平台源码,基于微服务架构+Java+Spring Cloud +UniApp +MySql
智慧工地管理云平台系统,智慧工地全套源码,java版智慧工地源码,支持PC端、大屏端、移动端。 智慧工地聚焦建筑行业的市场需求,提供“平台网络终端”的整体解决方案,提供劳务管理、视频管理、智能监测、绿色施工、安全管…...
深入理解JavaScript设计模式之单例模式
目录 什么是单例模式为什么需要单例模式常见应用场景包括 单例模式实现透明单例模式实现不透明单例模式用代理实现单例模式javaScript中的单例模式使用命名空间使用闭包封装私有变量 惰性单例通用的惰性单例 结语 什么是单例模式 单例模式(Singleton Pattern&#…...
鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序
一、开发环境准备 工具安装: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 项目初始化: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...
WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成
厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...
CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云
目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...
CMake控制VS2022项目文件分组
我们可以通过 CMake 控制源文件的组织结构,使它们在 VS 解决方案资源管理器中以“组”(Filter)的形式进行分类展示。 🎯 目标 通过 CMake 脚本将 .cpp、.h 等源文件分组显示在 Visual Studio 2022 的解决方案资源管理器中。 ✅ 支持的方法汇总(共4种) 方法描述是否推荐…...
Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决
Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中,新增了一个本地验证码接口 /code,使用函数式路由(RouterFunction)和 Hutool 的 Circle…...
Springboot社区养老保险系统小程序
一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,社区养老保险系统小程序被用户普遍使用,为方…...
深度学习水论文:mamba+图像增强
🧀当前视觉领域对高效长序列建模需求激增,对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模,以及动态计算优势,在图像质量提升和细节恢复方面有难以替代的作用。 🧀因此短时间内,就有不…...
免费数学几何作图web平台
光锐软件免费数学工具,maths,数学制图,数学作图,几何作图,几何,AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...
