NLP之RNN的原理讲解(python示例)
目录
- 代码示例
- 代码解读
- 知识点介绍
代码示例
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNNCell# 第t时刻要训练的数据
xt = tf.Variable(np.random.randint(2, 3, size=[1, 1]), dtype=tf.float32)
print(xt)
# https://www.cnblogs.com/Renyi-Fan/p/13722276.htmlcell = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones', recurrent_initializer='ones',bias_initializer=tf.keras.initializers.Constant(value=3))
cell.build(input_shape=[None, 1])
print('variables', cell.variables)
print('config:', cell.get_config())print(tf.nn.tanh(tf.constant([-float("inf"), 6, float("inf")])))# 第t时刻运算
ht_1 = tf.ones([1, 1])
out, ht = cell(xt, ht_1) # LSTM
print(out, ht[0])
print(id(out), id(ht[0]))# 第t+1时刻运算
cell2 = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones',recurrent_initializer=tf.keras.initializers.Constant(value=3), bias_initializer='ones')
xt2 = tf.Variable(np.random.randint(3, 4, size=[1, 1]), dtype=tf.float32)
out2, ht2 = cell2(xt2, ht)
print(out2, ht2[0])
代码解读
这段代码包含了一些使用 TensorFlow 来创建和操作循环神经网络(RNN)的基础操作。我们将一步步地解释其含义。
-
导入所需的库:
import numpy as np import tensorflow as tf from tensorflow.keras.layers import SimpleRNNCell代码导入了NumPy库、TensorFlow库以及
SimpleRNNCell,这是一个实现了简单的RNN单元操作的类。 -
创建训练数据:
xt = tf.Variable(np.random.randint(2, 3, size=[1, 1]), dtype=tf.float32) print(xt)这里创建了一个
1x1的张量,其值是2或3之间的随机整数。这代表了在时间t的输入数据。 -
定义RNN单元:
cell = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones', recurrent_initializer='ones',bias_initializer=tf.keras.initializers.Constant(value=3))使用
SimpleRNNCell创建了一个RNN单元。这个单元有以下特性:- 只有一个神经元(
units=1)。 - 不使用激活函数(
activation=None)。 - 使用偏置,并初始化为3(
bias_initializer=tf.keras.initializers.Constant(value=3))。 - 输入权重和循环权重都初始化为1。
kernel_initializer='ones':- 这是一个初始化器,用于初始化RNN单元的权重(也称为内核权重)。
'ones'表示所有的权重都被初始化为1。- 换句话说,当输入数据经过RNN单元时,它会与这些权重相乘,而这些权重的初始值都是1。
recurrent_initializer='ones':- 这是一个初始化器,用于初始化RNN单元的循环权重。
- 在RNN中,当前时间步的隐藏状态是基于前一个时间步的隐藏状态计算的。这个计算涉及到的权重就是循环权重。
'ones'表示所有的循环权重都被初始化为1。
bias_initializer=tf.keras.initializers.Constant(value=3):- 这是一个初始化器,用于初始化RNN单元的偏置。
tf.keras.initializers.Constant(value=3)表示所有的偏置被初始化为常数3。
- 简而言之,这些参数(kernel_initializer、recurrent_initializer、bias_initializer)确定了RNN单元在开始训练之前的权重和偏置的初始状态。这些初始值在训练过程中会被更新。选择合适的初始化器对于模型的收敛速度和性能至关重要,尽管在这个特定的例子中,这些权重和偏置被赋予了特定的常数值。
cell.build(input_shape=[None, 1])这行代码是用来告诉RNN单元输入的形状,这样它就可以创建相应的权重和偏置张量。- 在TensorFlow和Keras中,
input_shape是用来指定输入数据的维度的参数。具体到这里的input_shape=[None, 1],我们可以解读它为: [None, 1]:这是一个形状列表,其中有两个维度。None:- 第一个维度通常表示批处理的大小(即在一个批次中的样本数)。在许多情况下,为了使模型更加灵活,我们可能不想在定义模型时硬编码一个固定的批处理大小。
- 使用
None作为批处理的大小意味着模型可以接受任何大小的批次。 - 例如,你可以选择在训练时使用64的批大小,在评估或推理时使用1的批大小,或者使用其他任何数字。
1:- 第二个维度是数据的特征维度。
- 在这里,它指的是输入数据的每个样本有1个特征。
- 综上所述,
input_shape=[None, 1]表示模型可以接受一个二维的输入,其中第一个维度是任意大小的批处理,第二个维度是1个特征。
- 只有一个神经元(
-
显示RNN单元的变量和配置:
代码打印出RNN单元的所有变量(如权重和偏置)以及配置。print('variables', cell.variables) print('config:', cell.get_config())这两行代码是关于打印关于
cell(这里的cell是一个SimpleRNNCell的实例)的相关信息。-
print('variables', cell.variables):cell.variables: 这是一个属性,它返回一个列表,该列表包含cell中的所有可训练变量(权重和偏置)。在RNN cell的上下文中,这通常包括核权重、递归权重以及偏置。print(...): 打印变量列表,以便于你查看和调试。通常这可以帮助你理解RNN cell中的权重如何初始化(例如,这里你已经明确地设置了初始化器)。
-
print('config:', cell.get_config()):cell.get_config(): 这是一个方法,它返回一个字典,该字典包含cell的配置。这通常包括其初始化时使用的参数(例如units的数量、激活函数、是否使用偏置等)。这允许你查看或者后续再次使用这些配置信息,例如,如果你想保存模型的结构并稍后再次创建它。print(...): 打印配置字典,使你能够查看cell的配置。
-
总之,这两行代码提供了关于
SimpleRNNCell实例(cell)的详细信息,包括它的权重(和它们的初始值)以及它的配置。这是非常有用的,特别是当你在调试或了解你的模型结构时。
-
-
计算tanh的值:
print(tf.nn.tanh(tf.constant([-float("inf"), 6, float("inf")])))这行代码计算了
tanh函数在-∞、6和∞三个点的值。tanh是RNN和其他神经网络中常用的激活函数。 -
第t时刻的计算:
这部分代码首先定义了上一个时间步的隐藏状态ht_1,然后使用cell(xt, ht_1)调用RNN单元来获取当前时间步的输出和隐藏状态。ht_1 = tf.ones([1, 1]) out, ht = cell(xt, ht_1) # LSTM print(out, ht[0]) print(id(out), id(ht[0])) -
第t+1时刻的计算:
同样地,这部分代码定义了一个新的RNN单元cell2,然后用新的输入xt2和上一个时间步的隐藏状态ht来获取下一个时间步的输出和隐藏状态。cell2 = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones',recurrent_initializer=tf.keras.initializers.Constant(value=3), bias_initializer='ones') xt2 = tf.Variable(np.random.randint(3, 4, size=[1, 1]), dtype=tf.float32) out2, ht2 = cell2(xt2, ht) -
输出与隐藏状态的关系:
print(id(out), id(ht[0]))这部分代码展示了在简单的RNN中,输出状态
out和隐藏状态ht是相同的对象。
最后,代码的主要目的是演示如何使用SimpleRNNCell在给定的输入和隐藏状态上进行计算,并展示其结果。
知识点介绍
tf.Variable 是 TensorFlow(TF)中的一个核心概念,它用于表示在 TF 计算过程中可能会发生变化的数据。在 TF 中,计算通常是通过计算图(graph)来定义的,而 tf.Variable 允许我们将可以变化的状态添加到这些计算图中。
以下是 tf.Variable 的一些关键点:
-
可变性:与 TensorFlow 的常量(
tf.constant)不同,tf.Variable表示的值是可变的。这意味着在训练过程中,可以更新、修改或赋予其新值。 -
用途:
tf.Variable通常用于表示模型的参数,例如神经网络中的权重和偏置。 -
初始化:当创建一个
tf.Variable时,你必须为它提供一个初始值。这个初始值可以是一个固定值,也可以是其他任何 TensorFlow 计算的结果。 -
赋值:使用
assign、assign_add等方法,你可以修改tf.Variable的值。 -
存储和恢复:
tf.Variable的值可以被存储到磁盘并在之后恢复,这是通过 TensorFlow 的保存和恢复机制实现的,这样可以方便地保存和加载模型。
示例:
import tensorflow as tf# 创建一个初始化为1的变量
v = tf.Variable(1.0)# 使用变量
result = v * 2.0# 修改变量的值
v.assign(2.0) # 现在 v 的值为 2.0
总之,tf.Variable 是 TensorFlow 中表示可变状态的主要方式,尤其是在模型训练中,它用于存储和更新模型的参数。
相关文章:
NLP之RNN的原理讲解(python示例)
目录 代码示例代码解读知识点介绍 代码示例 import numpy as np import tensorflow as tf from tensorflow.keras.layers import SimpleRNNCell# 第t时刻要训练的数据 xt tf.Variable(np.random.randint(2, 3, size[1, 1]), dtypetf.float32) print(xt) # https://www.cnblog…...
yo!这里是进程间通信
目录 前言 进程间通信简介 目的 分类 匿名通道 介绍 举例(进程池) 命名管道 介绍 举例 共享内存 介绍 共享内存函数 1.shmget 2.shmat 3.shmdt 4.shmctl 举例 1.框架 2.通信逻辑 消息队列 信号量 同步与互斥 理解信号量 后记…...
使用docker安装MySQL,Redis,Nacos,Consul教程
文章目录 安装MySQL安装Redis安装Nacos安装Consul 如未安装docker,参考教程: https://blog.csdn.net/m0_63230155/article/details/134090090 安装MySQL #拉取镜像 sudo docker pull mysql:latestsudo docker run --name mysql \-p 3306:3306 \-e MYSQ…...
python和Springboot如何交互?
Python和Spring Boot可以通过RESTful API进行交互。Spring Boot通常用于后端开发,提供了快速构建RESTful API的工具,而Python则可以用于编写前端或与后端交互的代码。 要实现Python和Spring Boot的交互,可以按照以下步骤进行: 在…...
Qt实现json解析
前提要点 json文件,可通过键值的方式存储你所需要的数据,斌且支持多种类型存储,类似于一种结构化的数据库,在读取json文件时可通过相对应的关键字精准获取。他是一种树状结构,我们可以自己设定叶子的数量以及他所代表…...
Ajax、Json深入浅出,及原生Ajax及简化版Ajax
Ajax 1.路径介绍 1.1 JavaWeb中的路径 在JavaWeb中,路径分为相对路径和绝对路径两种: 相对路径: ./ 表示当前目录(可省略) ../ 表示当前文件所在目录的上一级目录 绝对路径: http://ip:port/工程名/资源路径 2.2 在JavaWeb中…...
前端第一阶段测试
前端第一阶段测试 选择问答 如果觉得有用请给我点个赞⑧~ 选择 1、【单选】下列哪个是子代选择器 A A、p>b B、p b C、pb D、p.b 2、【单选】下述有关css属性position的属性值的描述,说法错误的是?B A、static:没有定位,元素出…...
openlayers+vue的bug
使用addInteraction添加交互draw绘制,预期removeInteraction删除交互draw绘制时不再绘制,但是删除绘制不起作用,各种找原因,结果把data中的map变量注释掉即可,原因未知。 <template><div><div id"…...
实时数仓-Hologres介绍与架构
本文是向大家介绍Hologres是一款实时HSAP产品,隶属阿里自研大数据品牌MaxCompute,兼容 PostgreSQL 生态、支持MaxCompute数据直接查询,支持实时写入实时查询,实时离线联邦分析,低成本、高时效、快速构筑企业实时数据仓…...
asp.net教务管理信息系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio计算机毕业设计
一、源码特点 asp.net 教务管理信息系统是一套完善的web设计管理系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为vs2010,数据库为sqlserver2008,使用c#语言 开发 asp.net教务管理系统 应用技术&a…...
爬虫、数据清洗和分析
爬虫、数据清洗和分析是在数据科学、数据挖掘和网络爬虫开发领域中常见的概念。 爬虫(Web Scraping):爬虫是一种自动化程序或脚本,用于从互联网上的网站上提取信息。这些信息可以是文本、图像、视频或其他类型的数据。爬虫通常会…...
SpringBoot | SpringBoot中实现“微信支付“
SpringBoot中实现"微信支付": 1.“微信支付”产品2."微信支付"接入流程3.“微信小程序支付”时序图:3.1 “商家端JSAPI下单” 接口3.2 “微信小程序端调起支付” 接口 4.微信支付准备工作:4.1 获得微信支付平台证书、商户私钥文件4…...
基于SSM和VUE的留守儿童信息管理系统
末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:Vue 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是 目录…...
VMware 16开启虚拟机电脑就蓝屏W11解决方法
问题现象 解决方法 控制面板->程序->启用或关闭windows功能->勾选虚拟机平台->重启...
【Bug——VMware Workstation】虚拟机桥接网络没有 VMnet0
此时 没有VMnet0用来桥接网络。 接下来进行解决 1.找到安装VM的路径,在安装的目录里面找到如图所示的三个文件: 2.依次点击鼠标右键 将这三个文件依次安装如图所示: 二.windows下的操作 1.首先 找到电脑的控制面板->网络和internet->…...
centos中安装Mysql8.0
其实和mysql5.7的安装差不多 1.root用户 2.更新密钥 rpm --import https://repo.mysql.com/RPM-GPG-KEY-mysql-2022 3.安装mysql yum库 rpm -Uvh https://dev.mysql.com/ get/mysql80-community-release-el7-2.noarch.rpm 4.通过上两步,我们就可以使用yum去安装…...
简化对象和函数写法
简化对象写法: 传统写法: var x 10, y 20; var obj {x: x, y: y};简化写法: var x 10, y 20; var obj {x, y};简化函数写法: 传统写法: function add(x, y) {return x y; }简化写法: var add …...
GB/T28181流媒体相关协议详解
GB/T28181流媒体相关协议详解 文章目录 GB/T28181流媒体相关协议详解1 GB/T28181协议中使用的应用层协议介绍2 实时视频点播协议交互流程2.1 设备注册2.2 设备保活2.3 视频播放 总结 本文主要主要针对28181协议中视频流的部分,来阐述视频流通过28181协议如何进行视频…...
十进制转二进制的算法代码 ← Python
【算法分析】 本算法需要用到的Python知识点: 1.求余%,整除 //。例如,7%21,7//23,而7/23.5。 2.Python列表的 append 及 pop 函数。 • append(x) 函数用于将 x 添加到现有列表中。 • pop() 函数默认移除列表中…...
智慧垃圾站:AI视频智能识别技术助力智慧环保项目,以“智”替人强监管
一、背景分析 建设“技术先进、架构合理、开放智能、安全可靠”的智慧环保平台,整合环境相关的数据,对接已建业务系统,将环境相关数据进行统一管理,结合GIS技术进行监测、监控信息的展现和挖掘分析,实现业务数据的快速…...
Go 语言接口详解
Go 语言接口详解 核心概念 接口定义 在 Go 语言中,接口是一种抽象类型,它定义了一组方法的集合: // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的: // 矩形结构体…...
【单片机期末】单片机系统设计
主要内容:系统状态机,系统时基,系统需求分析,系统构建,系统状态流图 一、题目要求 二、绘制系统状态流图 题目:根据上述描述绘制系统状态流图,注明状态转移条件及方向。 三、利用定时器产生时…...
用机器学习破解新能源领域的“弃风”难题
音乐发烧友深有体会,玩音乐的本质就是玩电网。火电声音偏暖,水电偏冷,风电偏空旷。至于太阳能发的电,则略显朦胧和单薄。 不知你是否有感觉,近两年家里的音响声音越来越冷,听起来越来越单薄? —…...
Python基于历史模拟方法实现投资组合风险管理的VaR与ES模型项目实战
说明:这是一个机器学习实战项目(附带数据代码文档),如需数据代码文档可以直接到文章最后关注获取。 1.项目背景 在金融市场日益复杂和波动加剧的背景下,风险管理成为金融机构和个人投资者关注的核心议题之一。VaR&…...
快速排序算法改进:随机快排-荷兰国旗划分详解
随机快速排序-荷兰国旗划分算法详解 一、基础知识回顾1.1 快速排序简介1.2 荷兰国旗问题 二、随机快排 - 荷兰国旗划分原理2.1 随机化枢轴选择2.2 荷兰国旗划分过程2.3 结合随机快排与荷兰国旗划分 三、代码实现3.1 Python实现3.2 Java实现3.3 C实现 四、性能分析4.1 时间复杂度…...
PydanticAI快速入门示例
参考链接:https://ai.pydantic.dev/#why-use-pydanticai 示例代码 from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.openai import OpenAIProvider# 配置使用阿里云通义千问模型 model OpenAIMode…...
Easy Excel
Easy Excel 一、依赖引入二、基本使用1. 定义实体类(导入/导出共用)2. 写 Excel3. 读 Excel 三、常用注解说明(完整列表)四、进阶:自定义转换器(Converter) 其它自定义转换器没生效 Easy Excel在…...
在Spring Boot中集成RabbitMQ的完整指南
前言 在现代微服务架构中,消息队列(Message Queue)是实现异步通信、解耦系统组件的重要工具。RabbitMQ 是一个流行的消息中间件,支持多种消息协议,具有高可靠性和可扩展性。 本博客将详细介绍如何在 Spring Boot 项目…...
Gitlab + Jenkins 实现 CICD
CICD 是持续集成(Continuous Integration, CI)和持续交付/部署(Continuous Delivery/Deployment, CD)的缩写,是现代软件开发中的一种自动化流程实践。下面介绍 Web 项目如何在代码提交到 Gitlab 后,自动发布…...
Go 语言中的内置运算符
1. 算术运算符 注意: (自增)和--(自减)在 Go 语言中是单独的语句,并不是运算符。 package mainimport "fmt"func main() {fmt.Println("103", 103) // 13fmt.Println("10-3…...
