NLP之RNN的原理讲解(python示例)
目录
- 代码示例
- 代码解读
- 知识点介绍
代码示例
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNNCell# 第t时刻要训练的数据
xt = tf.Variable(np.random.randint(2, 3, size=[1, 1]), dtype=tf.float32)
print(xt)
# https://www.cnblogs.com/Renyi-Fan/p/13722276.htmlcell = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones', recurrent_initializer='ones',bias_initializer=tf.keras.initializers.Constant(value=3))
cell.build(input_shape=[None, 1])
print('variables', cell.variables)
print('config:', cell.get_config())print(tf.nn.tanh(tf.constant([-float("inf"), 6, float("inf")])))# 第t时刻运算
ht_1 = tf.ones([1, 1])
out, ht = cell(xt, ht_1) # LSTM
print(out, ht[0])
print(id(out), id(ht[0]))# 第t+1时刻运算
cell2 = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones',recurrent_initializer=tf.keras.initializers.Constant(value=3), bias_initializer='ones')
xt2 = tf.Variable(np.random.randint(3, 4, size=[1, 1]), dtype=tf.float32)
out2, ht2 = cell2(xt2, ht)
print(out2, ht2[0])
代码解读
这段代码包含了一些使用 TensorFlow 来创建和操作循环神经网络(RNN)的基础操作。我们将一步步地解释其含义。
-
导入所需的库:
import numpy as np import tensorflow as tf from tensorflow.keras.layers import SimpleRNNCell
代码导入了NumPy库、TensorFlow库以及
SimpleRNNCell
,这是一个实现了简单的RNN单元操作的类。 -
创建训练数据:
xt = tf.Variable(np.random.randint(2, 3, size=[1, 1]), dtype=tf.float32) print(xt)
这里创建了一个
1x1
的张量,其值是2或3之间的随机整数。这代表了在时间t
的输入数据。 -
定义RNN单元:
cell = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones', recurrent_initializer='ones',bias_initializer=tf.keras.initializers.Constant(value=3))
使用
SimpleRNNCell
创建了一个RNN单元。这个单元有以下特性:- 只有一个神经元(
units=1
)。 - 不使用激活函数(
activation=None
)。 - 使用偏置,并初始化为3(
bias_initializer=tf.keras.initializers.Constant(value=3)
)。 - 输入权重和循环权重都初始化为1。
kernel_initializer='ones'
:- 这是一个初始化器,用于初始化RNN单元的权重(也称为内核权重)。
'ones'
表示所有的权重都被初始化为1。- 换句话说,当输入数据经过RNN单元时,它会与这些权重相乘,而这些权重的初始值都是1。
recurrent_initializer='ones'
:- 这是一个初始化器,用于初始化RNN单元的循环权重。
- 在RNN中,当前时间步的隐藏状态是基于前一个时间步的隐藏状态计算的。这个计算涉及到的权重就是循环权重。
'ones'
表示所有的循环权重都被初始化为1。
bias_initializer=tf.keras.initializers.Constant(value=3)
:- 这是一个初始化器,用于初始化RNN单元的偏置。
tf.keras.initializers.Constant(value=3)
表示所有的偏置被初始化为常数3。
- 简而言之,这些参数(kernel_initializer、recurrent_initializer、bias_initializer)确定了RNN单元在开始训练之前的权重和偏置的初始状态。这些初始值在训练过程中会被更新。选择合适的初始化器对于模型的收敛速度和性能至关重要,尽管在这个特定的例子中,这些权重和偏置被赋予了特定的常数值。
cell.build(input_shape=[None, 1])
这行代码是用来告诉RNN单元输入的形状,这样它就可以创建相应的权重和偏置张量。- 在TensorFlow和Keras中,
input_shape
是用来指定输入数据的维度的参数。具体到这里的input_shape=[None, 1]
,我们可以解读它为: [None, 1]
:这是一个形状列表,其中有两个维度。None
:- 第一个维度通常表示批处理的大小(即在一个批次中的样本数)。在许多情况下,为了使模型更加灵活,我们可能不想在定义模型时硬编码一个固定的批处理大小。
- 使用
None
作为批处理的大小意味着模型可以接受任何大小的批次。 - 例如,你可以选择在训练时使用64的批大小,在评估或推理时使用1的批大小,或者使用其他任何数字。
1
:- 第二个维度是数据的特征维度。
- 在这里,它指的是输入数据的每个样本有1个特征。
- 综上所述,
input_shape=[None, 1]
表示模型可以接受一个二维的输入,其中第一个维度是任意大小的批处理,第二个维度是1个特征。
- 只有一个神经元(
-
显示RNN单元的变量和配置:
代码打印出RNN单元的所有变量(如权重和偏置)以及配置。print('variables', cell.variables) print('config:', cell.get_config())
这两行代码是关于打印关于
cell
(这里的cell
是一个SimpleRNNCell
的实例)的相关信息。-
print('variables', cell.variables)
:cell.variables
: 这是一个属性,它返回一个列表,该列表包含cell
中的所有可训练变量(权重和偏置)。在RNN cell的上下文中,这通常包括核权重、递归权重以及偏置。print(...)
: 打印变量列表,以便于你查看和调试。通常这可以帮助你理解RNN cell中的权重如何初始化(例如,这里你已经明确地设置了初始化器)。
-
print('config:', cell.get_config())
:cell.get_config()
: 这是一个方法,它返回一个字典,该字典包含cell
的配置。这通常包括其初始化时使用的参数(例如units的数量、激活函数、是否使用偏置等)。这允许你查看或者后续再次使用这些配置信息,例如,如果你想保存模型的结构并稍后再次创建它。print(...)
: 打印配置字典,使你能够查看cell
的配置。
-
总之,这两行代码提供了关于
SimpleRNNCell
实例(cell
)的详细信息,包括它的权重(和它们的初始值)以及它的配置。这是非常有用的,特别是当你在调试或了解你的模型结构时。
-
-
计算tanh的值:
print(tf.nn.tanh(tf.constant([-float("inf"), 6, float("inf")])))
这行代码计算了
tanh
函数在-∞
、6和∞
三个点的值。tanh是RNN和其他神经网络中常用的激活函数。 -
第t时刻的计算:
这部分代码首先定义了上一个时间步的隐藏状态ht_1
,然后使用cell(xt, ht_1)
调用RNN单元来获取当前时间步的输出和隐藏状态。ht_1 = tf.ones([1, 1]) out, ht = cell(xt, ht_1) # LSTM print(out, ht[0]) print(id(out), id(ht[0]))
-
第t+1时刻的计算:
同样地,这部分代码定义了一个新的RNN单元cell2
,然后用新的输入xt2
和上一个时间步的隐藏状态ht
来获取下一个时间步的输出和隐藏状态。cell2 = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones',recurrent_initializer=tf.keras.initializers.Constant(value=3), bias_initializer='ones') xt2 = tf.Variable(np.random.randint(3, 4, size=[1, 1]), dtype=tf.float32) out2, ht2 = cell2(xt2, ht)
-
输出与隐藏状态的关系:
print(id(out), id(ht[0]))
这部分代码展示了在简单的RNN中,输出状态
out
和隐藏状态ht
是相同的对象。
最后,代码的主要目的是演示如何使用SimpleRNNCell
在给定的输入和隐藏状态上进行计算,并展示其结果。
知识点介绍
tf.Variable
是 TensorFlow(TF)中的一个核心概念,它用于表示在 TF 计算过程中可能会发生变化的数据。在 TF 中,计算通常是通过计算图(graph)来定义的,而 tf.Variable
允许我们将可以变化的状态添加到这些计算图中。
以下是 tf.Variable
的一些关键点:
-
可变性:与 TensorFlow 的常量(
tf.constant
)不同,tf.Variable
表示的值是可变的。这意味着在训练过程中,可以更新、修改或赋予其新值。 -
用途:
tf.Variable
通常用于表示模型的参数,例如神经网络中的权重和偏置。 -
初始化:当创建一个
tf.Variable
时,你必须为它提供一个初始值。这个初始值可以是一个固定值,也可以是其他任何 TensorFlow 计算的结果。 -
赋值:使用
assign
、assign_add
等方法,你可以修改tf.Variable
的值。 -
存储和恢复:
tf.Variable
的值可以被存储到磁盘并在之后恢复,这是通过 TensorFlow 的保存和恢复机制实现的,这样可以方便地保存和加载模型。
示例:
import tensorflow as tf# 创建一个初始化为1的变量
v = tf.Variable(1.0)# 使用变量
result = v * 2.0# 修改变量的值
v.assign(2.0) # 现在 v 的值为 2.0
总之,tf.Variable
是 TensorFlow 中表示可变状态的主要方式,尤其是在模型训练中,它用于存储和更新模型的参数。
相关文章:
NLP之RNN的原理讲解(python示例)
目录 代码示例代码解读知识点介绍 代码示例 import numpy as np import tensorflow as tf from tensorflow.keras.layers import SimpleRNNCell# 第t时刻要训练的数据 xt tf.Variable(np.random.randint(2, 3, size[1, 1]), dtypetf.float32) print(xt) # https://www.cnblog…...

yo!这里是进程间通信
目录 前言 进程间通信简介 目的 分类 匿名通道 介绍 举例(进程池) 命名管道 介绍 举例 共享内存 介绍 共享内存函数 1.shmget 2.shmat 3.shmdt 4.shmctl 举例 1.框架 2.通信逻辑 消息队列 信号量 同步与互斥 理解信号量 后记…...
使用docker安装MySQL,Redis,Nacos,Consul教程
文章目录 安装MySQL安装Redis安装Nacos安装Consul 如未安装docker,参考教程: https://blog.csdn.net/m0_63230155/article/details/134090090 安装MySQL #拉取镜像 sudo docker pull mysql:latestsudo docker run --name mysql \-p 3306:3306 \-e MYSQ…...
python和Springboot如何交互?
Python和Spring Boot可以通过RESTful API进行交互。Spring Boot通常用于后端开发,提供了快速构建RESTful API的工具,而Python则可以用于编写前端或与后端交互的代码。 要实现Python和Spring Boot的交互,可以按照以下步骤进行: 在…...
Qt实现json解析
前提要点 json文件,可通过键值的方式存储你所需要的数据,斌且支持多种类型存储,类似于一种结构化的数据库,在读取json文件时可通过相对应的关键字精准获取。他是一种树状结构,我们可以自己设定叶子的数量以及他所代表…...

Ajax、Json深入浅出,及原生Ajax及简化版Ajax
Ajax 1.路径介绍 1.1 JavaWeb中的路径 在JavaWeb中,路径分为相对路径和绝对路径两种: 相对路径: ./ 表示当前目录(可省略) ../ 表示当前文件所在目录的上一级目录 绝对路径: http://ip:port/工程名/资源路径 2.2 在JavaWeb中…...
前端第一阶段测试
前端第一阶段测试 选择问答 如果觉得有用请给我点个赞⑧~ 选择 1、【单选】下列哪个是子代选择器 A A、p>b B、p b C、pb D、p.b 2、【单选】下述有关css属性position的属性值的描述,说法错误的是?B A、static:没有定位,元素出…...
openlayers+vue的bug
使用addInteraction添加交互draw绘制,预期removeInteraction删除交互draw绘制时不再绘制,但是删除绘制不起作用,各种找原因,结果把data中的map变量注释掉即可,原因未知。 <template><div><div id"…...

实时数仓-Hologres介绍与架构
本文是向大家介绍Hologres是一款实时HSAP产品,隶属阿里自研大数据品牌MaxCompute,兼容 PostgreSQL 生态、支持MaxCompute数据直接查询,支持实时写入实时查询,实时离线联邦分析,低成本、高时效、快速构筑企业实时数据仓…...

asp.net教务管理信息系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio计算机毕业设计
一、源码特点 asp.net 教务管理信息系统是一套完善的web设计管理系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为vs2010,数据库为sqlserver2008,使用c#语言 开发 asp.net教务管理系统 应用技术&a…...
爬虫、数据清洗和分析
爬虫、数据清洗和分析是在数据科学、数据挖掘和网络爬虫开发领域中常见的概念。 爬虫(Web Scraping):爬虫是一种自动化程序或脚本,用于从互联网上的网站上提取信息。这些信息可以是文本、图像、视频或其他类型的数据。爬虫通常会…...

SpringBoot | SpringBoot中实现“微信支付“
SpringBoot中实现"微信支付": 1.“微信支付”产品2."微信支付"接入流程3.“微信小程序支付”时序图:3.1 “商家端JSAPI下单” 接口3.2 “微信小程序端调起支付” 接口 4.微信支付准备工作:4.1 获得微信支付平台证书、商户私钥文件4…...

基于SSM和VUE的留守儿童信息管理系统
末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:Vue 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是 目录…...

VMware 16开启虚拟机电脑就蓝屏W11解决方法
问题现象 解决方法 控制面板->程序->启用或关闭windows功能->勾选虚拟机平台->重启...

【Bug——VMware Workstation】虚拟机桥接网络没有 VMnet0
此时 没有VMnet0用来桥接网络。 接下来进行解决 1.找到安装VM的路径,在安装的目录里面找到如图所示的三个文件: 2.依次点击鼠标右键 将这三个文件依次安装如图所示: 二.windows下的操作 1.首先 找到电脑的控制面板->网络和internet->…...

centos中安装Mysql8.0
其实和mysql5.7的安装差不多 1.root用户 2.更新密钥 rpm --import https://repo.mysql.com/RPM-GPG-KEY-mysql-2022 3.安装mysql yum库 rpm -Uvh https://dev.mysql.com/ get/mysql80-community-release-el7-2.noarch.rpm 4.通过上两步,我们就可以使用yum去安装…...
简化对象和函数写法
简化对象写法: 传统写法: var x 10, y 20; var obj {x: x, y: y};简化写法: var x 10, y 20; var obj {x, y};简化函数写法: 传统写法: function add(x, y) {return x y; }简化写法: var add …...

GB/T28181流媒体相关协议详解
GB/T28181流媒体相关协议详解 文章目录 GB/T28181流媒体相关协议详解1 GB/T28181协议中使用的应用层协议介绍2 实时视频点播协议交互流程2.1 设备注册2.2 设备保活2.3 视频播放 总结 本文主要主要针对28181协议中视频流的部分,来阐述视频流通过28181协议如何进行视频…...
十进制转二进制的算法代码 ← Python
【算法分析】 本算法需要用到的Python知识点: 1.求余%,整除 //。例如,7%21,7//23,而7/23.5。 2.Python列表的 append 及 pop 函数。 • append(x) 函数用于将 x 添加到现有列表中。 • pop() 函数默认移除列表中…...

智慧垃圾站:AI视频智能识别技术助力智慧环保项目,以“智”替人强监管
一、背景分析 建设“技术先进、架构合理、开放智能、安全可靠”的智慧环保平台,整合环境相关的数据,对接已建业务系统,将环境相关数据进行统一管理,结合GIS技术进行监测、监控信息的展现和挖掘分析,实现业务数据的快速…...

地震勘探——干扰波识别、井中地震时距曲线特点
目录 干扰波识别反射波地震勘探的干扰波 井中地震时距曲线特点 干扰波识别 有效波:可以用来解决所提出的地质任务的波;干扰波:所有妨碍辨认、追踪有效波的其他波。 地震勘探中,有效波和干扰波是相对的。例如,在反射波…...

stm32G473的flash模式是单bank还是双bank?
今天突然有人stm32G473的flash模式是单bank还是双bank?由于时间太久,我真忘记了。搜搜发现,还真有人和我一样。见下面的链接:https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

大话软工笔记—需求分析概述
需求分析,就是要对需求调研收集到的资料信息逐个地进行拆分、研究,从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要,后续设计的依据主要来自于需求分析的成果,包括: 项目的目的…...

【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器
一.自适应梯度算法Adagrad概述 Adagrad(Adaptive Gradient Algorithm)是一种自适应学习率的优化算法,由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率,适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八
现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet,点击确认后如下提示 最终上报fail 解决方法 内核升级导致,需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...
vue3 字体颜色设置的多种方式
在Vue 3中设置字体颜色可以通过多种方式实现,这取决于你是想在组件内部直接设置,还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法: 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...

微服务商城-商品微服务
数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一)
宇树机器人多姿态起立控制强化学习框架论文解析 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(一) 论文解读:交大&港大&上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化…...
06 Deep learning神经网络编程基础 激活函数 --吴恩达
深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

vue3+vite项目中使用.env文件环境变量方法
vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量,这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...