当前位置: 首页 > news >正文

NLP之RNN的原理讲解(python示例)

目录

    • 代码示例
    • 代码解读
    • 知识点介绍

代码示例

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNNCell# 第t时刻要训练的数据
xt = tf.Variable(np.random.randint(2, 3, size=[1, 1]), dtype=tf.float32)
print(xt)
# https://www.cnblogs.com/Renyi-Fan/p/13722276.htmlcell = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones', recurrent_initializer='ones',bias_initializer=tf.keras.initializers.Constant(value=3))
cell.build(input_shape=[None, 1])
print('variables', cell.variables)
print('config:', cell.get_config())print(tf.nn.tanh(tf.constant([-float("inf"), 6, float("inf")])))# 第t时刻运算
ht_1 = tf.ones([1, 1])
out, ht = cell(xt, ht_1)  # LSTM
print(out, ht[0])
print(id(out), id(ht[0]))# 第t+1时刻运算
cell2 = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones',recurrent_initializer=tf.keras.initializers.Constant(value=3), bias_initializer='ones')
xt2 = tf.Variable(np.random.randint(3, 4, size=[1, 1]), dtype=tf.float32)
out2, ht2 = cell2(xt2, ht)
print(out2, ht2[0])

代码解读

这段代码包含了一些使用 TensorFlow 来创建和操作循环神经网络(RNN)的基础操作。我们将一步步地解释其含义。

  1. 导入所需的库:

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import SimpleRNNCell
    

    代码导入了NumPy库、TensorFlow库以及SimpleRNNCell,这是一个实现了简单的RNN单元操作的类。

  2. 创建训练数据:

    xt = tf.Variable(np.random.randint(2, 3, size=[1, 1]), dtype=tf.float32)
    print(xt)
    

    这里创建了一个1x1的张量,其值是2或3之间的随机整数。这代表了在时间t的输入数据。

  3. 定义RNN单元:

    cell = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones', recurrent_initializer='ones',bias_initializer=tf.keras.initializers.Constant(value=3))
    

    使用SimpleRNNCell创建了一个RNN单元。这个单元有以下特性:

    • 只有一个神经元(units=1)。
    • 不使用激活函数(activation=None)。
    • 使用偏置,并初始化为3(bias_initializer=tf.keras.initializers.Constant(value=3))。
    • 输入权重和循环权重都初始化为1。
    • kernel_initializer='ones':
      • 这是一个初始化器,用于初始化RNN单元的权重(也称为内核权重)。
      • 'ones'表示所有的权重都被初始化为1。
      • 换句话说,当输入数据经过RNN单元时,它会与这些权重相乘,而这些权重的初始值都是1。
    • recurrent_initializer='ones':
      • 这是一个初始化器,用于初始化RNN单元的循环权重。
      • 在RNN中,当前时间步的隐藏状态是基于前一个时间步的隐藏状态计算的。这个计算涉及到的权重就是循环权重。
      • 'ones'表示所有的循环权重都被初始化为1。
    • bias_initializer=tf.keras.initializers.Constant(value=3):
      • 这是一个初始化器,用于初始化RNN单元的偏置。
      • tf.keras.initializers.Constant(value=3)表示所有的偏置被初始化为常数3。
    • 简而言之,这些参数(kernel_initializer、recurrent_initializer、bias_initializer)确定了RNN单元在开始训练之前的权重和偏置的初始状态。这些初始值在训练过程中会被更新。选择合适的初始化器对于模型的收敛速度和性能至关重要,尽管在这个特定的例子中,这些权重和偏置被赋予了特定的常数值。

    cell.build(input_shape=[None, 1])这行代码是用来告诉RNN单元输入的形状,这样它就可以创建相应的权重和偏置张量。

    • 在TensorFlow和Keras中,input_shape是用来指定输入数据的维度的参数。具体到这里的input_shape=[None, 1],我们可以解读它为:
    • [None, 1]:这是一个形状列表,其中有两个维度。
    • None
      • 第一个维度通常表示批处理的大小(即在一个批次中的样本数)。在许多情况下,为了使模型更加灵活,我们可能不想在定义模型时硬编码一个固定的批处理大小。
      • 使用None作为批处理的大小意味着模型可以接受任何大小的批次。
      • 例如,你可以选择在训练时使用64的批大小,在评估或推理时使用1的批大小,或者使用其他任何数字。
    • 1
      • 第二个维度是数据的特征维度。
      • 在这里,它指的是输入数据的每个样本有1个特征。
    • 综上所述,input_shape=[None, 1]表示模型可以接受一个二维的输入,其中第一个维度是任意大小的批处理,第二个维度是1个特征。
  4. 显示RNN单元的变量和配置:
    代码打印出RNN单元的所有变量(如权重和偏置)以及配置。

    print('variables', cell.variables)
    print('config:', cell.get_config())
    

    这两行代码是关于打印关于cell(这里的cell是一个SimpleRNNCell的实例)的相关信息。

    • print('variables', cell.variables):

      • cell.variables: 这是一个属性,它返回一个列表,该列表包含cell中的所有可训练变量(权重和偏置)。在RNN cell的上下文中,这通常包括核权重、递归权重以及偏置。
      • print(...): 打印变量列表,以便于你查看和调试。通常这可以帮助你理解RNN cell中的权重如何初始化(例如,这里你已经明确地设置了初始化器)。
    • print('config:', cell.get_config()):

      • cell.get_config(): 这是一个方法,它返回一个字典,该字典包含cell的配置。这通常包括其初始化时使用的参数(例如units的数量、激活函数、是否使用偏置等)。这允许你查看或者后续再次使用这些配置信息,例如,如果你想保存模型的结构并稍后再次创建它。
      • print(...): 打印配置字典,使你能够查看cell的配置。
    • 总之,这两行代码提供了关于SimpleRNNCell实例(cell)的详细信息,包括它的权重(和它们的初始值)以及它的配置。这是非常有用的,特别是当你在调试或了解你的模型结构时。

  5. 计算tanh的值:

    print(tf.nn.tanh(tf.constant([-float("inf"), 6, float("inf")])))
    

    这行代码计算了tanh函数在-∞、6和三个点的值。tanh是RNN和其他神经网络中常用的激活函数。

  6. 第t时刻的计算:
    这部分代码首先定义了上一个时间步的隐藏状态ht_1,然后使用cell(xt, ht_1)调用RNN单元来获取当前时间步的输出和隐藏状态。

    ht_1 = tf.ones([1, 1])
    out, ht = cell(xt, ht_1)  # LSTM
    print(out, ht[0])
    print(id(out), id(ht[0]))
    
  7. 第t+1时刻的计算:
    同样地,这部分代码定义了一个新的RNN单元cell2,然后用新的输入xt2和上一个时间步的隐藏状态ht来获取下一个时间步的输出和隐藏状态。

    cell2 = SimpleRNNCell(units=1, activation=None, use_bias=True, kernel_initializer='ones',recurrent_initializer=tf.keras.initializers.Constant(value=3), bias_initializer='ones')
    xt2 = tf.Variable(np.random.randint(3, 4, size=[1, 1]), dtype=tf.float32)
    out2, ht2 = cell2(xt2, ht)
    
  8. 输出与隐藏状态的关系:

    print(id(out), id(ht[0]))
    

    这部分代码展示了在简单的RNN中,输出状态out和隐藏状态ht是相同的对象。

最后,代码的主要目的是演示如何使用SimpleRNNCell在给定的输入和隐藏状态上进行计算,并展示其结果。

知识点介绍

tf.Variable 是 TensorFlow(TF)中的一个核心概念,它用于表示在 TF 计算过程中可能会发生变化的数据。在 TF 中,计算通常是通过计算图(graph)来定义的,而 tf.Variable 允许我们将可以变化的状态添加到这些计算图中。

以下是 tf.Variable 的一些关键点:

  1. 可变性:与 TensorFlow 的常量(tf.constant)不同,tf.Variable 表示的值是可变的。这意味着在训练过程中,可以更新、修改或赋予其新值。

  2. 用途tf.Variable 通常用于表示模型的参数,例如神经网络中的权重和偏置。

  3. 初始化:当创建一个 tf.Variable 时,你必须为它提供一个初始值。这个初始值可以是一个固定值,也可以是其他任何 TensorFlow 计算的结果。

  4. 赋值:使用 assignassign_add 等方法,你可以修改 tf.Variable 的值。

  5. 存储和恢复tf.Variable 的值可以被存储到磁盘并在之后恢复,这是通过 TensorFlow 的保存和恢复机制实现的,这样可以方便地保存和加载模型。

示例:

import tensorflow as tf# 创建一个初始化为1的变量
v = tf.Variable(1.0)# 使用变量
result = v * 2.0# 修改变量的值
v.assign(2.0)  # 现在 v 的值为 2.0

总之,tf.Variable 是 TensorFlow 中表示可变状态的主要方式,尤其是在模型训练中,它用于存储和更新模型的参数。

相关文章:

NLP之RNN的原理讲解(python示例)

目录 代码示例代码解读知识点介绍 代码示例 import numpy as np import tensorflow as tf from tensorflow.keras.layers import SimpleRNNCell# 第t时刻要训练的数据 xt tf.Variable(np.random.randint(2, 3, size[1, 1]), dtypetf.float32) print(xt) # https://www.cnblog…...

yo!这里是进程间通信

目录 前言 进程间通信简介 目的 分类 匿名通道 介绍 举例(进程池) 命名管道 介绍 举例 共享内存 介绍 共享内存函数 1.shmget 2.shmat 3.shmdt 4.shmctl 举例 1.框架 2.通信逻辑 消息队列 信号量 同步与互斥 理解信号量 后记…...

使用docker安装MySQL,Redis,Nacos,Consul教程

文章目录 安装MySQL安装Redis安装Nacos安装Consul 如未安装docker,参考教程: https://blog.csdn.net/m0_63230155/article/details/134090090 安装MySQL #拉取镜像 sudo docker pull mysql:latestsudo docker run --name mysql \-p 3306:3306 \-e MYSQ…...

python和Springboot如何交互?

Python和Spring Boot可以通过RESTful API进行交互。Spring Boot通常用于后端开发,提供了快速构建RESTful API的工具,而Python则可以用于编写前端或与后端交互的代码。 要实现Python和Spring Boot的交互,可以按照以下步骤进行: 在…...

Qt实现json解析

前提要点 json文件,可通过键值的方式存储你所需要的数据,斌且支持多种类型存储,类似于一种结构化的数据库,在读取json文件时可通过相对应的关键字精准获取。他是一种树状结构,我们可以自己设定叶子的数量以及他所代表…...

Ajax、Json深入浅出,及原生Ajax及简化版Ajax

Ajax 1.路径介绍 1.1 JavaWeb中的路径 在JavaWeb中,路径分为相对路径和绝对路径两种: 相对路径: ./ 表示当前目录(可省略) ../ 表示当前文件所在目录的上一级目录 绝对路径: http://ip:port/工程名/资源路径 2.2 在JavaWeb中…...

前端第一阶段测试

前端第一阶段测试 选择问答 如果觉得有用请给我点个赞⑧~ 选择 1、【单选】下列哪个是子代选择器 A A、p>b B、p b C、pb D、p.b 2、【单选】下述有关css属性position的属性值的描述,说法错误的是?B A、static:没有定位,元素出…...

openlayers+vue的bug

使用addInteraction添加交互draw绘制&#xff0c;预期removeInteraction删除交互draw绘制时不再绘制&#xff0c;但是删除绘制不起作用&#xff0c;各种找原因&#xff0c;结果把data中的map变量注释掉即可&#xff0c;原因未知。 <template><div><div id"…...

实时数仓-Hologres介绍与架构

本文是向大家介绍Hologres是一款实时HSAP产品&#xff0c;隶属阿里自研大数据品牌MaxCompute&#xff0c;兼容 PostgreSQL 生态、支持MaxCompute数据直接查询&#xff0c;支持实时写入实时查询&#xff0c;实时离线联邦分析&#xff0c;低成本、高时效、快速构筑企业实时数据仓…...

asp.net教务管理信息系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio计算机毕业设计

一、源码特点 asp.net 教务管理信息系统是一套完善的web设计管理系统&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为vs2010&#xff0c;数据库为sqlserver2008&#xff0c;使用c#语言 开发 asp.net教务管理系统 应用技术&a…...

爬虫、数据清洗和分析

爬虫、数据清洗和分析是在数据科学、数据挖掘和网络爬虫开发领域中常见的概念。 爬虫&#xff08;Web Scraping&#xff09;&#xff1a;爬虫是一种自动化程序或脚本&#xff0c;用于从互联网上的网站上提取信息。这些信息可以是文本、图像、视频或其他类型的数据。爬虫通常会…...

SpringBoot | SpringBoot中实现“微信支付“

SpringBoot中实现"微信支付": 1.“微信支付”产品2."微信支付"接入流程3.“微信小程序支付”时序图&#xff1a;3.1 “商家端JSAPI下单” 接口3.2 “微信小程序端调起支付” 接口 4.微信支付准备工作&#xff1a;4.1 获得微信支付平台证书、商户私钥文件4…...

基于SSM和VUE的留守儿童信息管理系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…...

VMware 16开启虚拟机电脑就蓝屏W11解决方法

问题现象 解决方法 控制面板->程序->启用或关闭windows功能->勾选虚拟机平台->重启...

【Bug——VMware Workstation】虚拟机桥接网络没有 VMnet0

此时 没有VMnet0用来桥接网络。 接下来进行解决 1.找到安装VM的路径&#xff0c;在安装的目录里面找到如图所示的三个文件&#xff1a; 2.依次点击鼠标右键 将这三个文件依次安装如图所示&#xff1a; 二.windows下的操作 1.首先 找到电脑的控制面板->网络和internet->…...

centos中安装Mysql8.0

其实和mysql5.7的安装差不多 1.root用户 2.更新密钥 rpm --import https://repo.mysql.com/RPM-GPG-KEY-mysql-2022 3.安装mysql yum库 rpm -Uvh https://dev.mysql.com/ get/mysql80-community-release-el7-2.noarch.rpm 4.通过上两步&#xff0c;我们就可以使用yum去安装…...

简化对象和函数写法

简化对象写法&#xff1a; 传统写法&#xff1a; var x 10, y 20; var obj {x: x, y: y};简化写法&#xff1a; var x 10, y 20; var obj {x, y};简化函数写法&#xff1a; 传统写法&#xff1a; function add(x, y) {return x y; }简化写法&#xff1a; var add …...

GB/T28181流媒体相关协议详解

GB/T28181流媒体相关协议详解 文章目录 GB/T28181流媒体相关协议详解1 GB/T28181协议中使用的应用层协议介绍2 实时视频点播协议交互流程2.1 设备注册2.2 设备保活2.3 视频播放 总结 本文主要主要针对28181协议中视频流的部分&#xff0c;来阐述视频流通过28181协议如何进行视频…...

十进制转二进制的算法代码 ← Python

【算法分析】 本算法需要用到的Python知识点&#xff1a; 1.求余%&#xff0c;整除 //。例如&#xff0c;7%21&#xff0c;7//23&#xff0c;而7/23.5。 2.Python列表的 append 及 pop 函数。 • append(x) 函数用于将 x 添加到现有列表中。 • pop() 函数默认移除列表中…...

智慧垃圾站:AI视频智能识别技术助力智慧环保项目,以“智”替人强监管

一、背景分析 建设“技术先进、架构合理、开放智能、安全可靠”的智慧环保平台&#xff0c;整合环境相关的数据&#xff0c;对接已建业务系统&#xff0c;将环境相关数据进行统一管理&#xff0c;结合GIS技术进行监测、监控信息的展现和挖掘分析&#xff0c;实现业务数据的快速…...

Ubuntu经常安装软件

1、垃圾清理工具stacer sudo apt updatesudo apt install stacer apt cleanapt autocleanapt autoremove 2、类似与everything的工具Fsearcch 1sudo add-apt-repository ppa:christian-boxdoerfer/fsearch-stable 2sudo apt update 3sudo apt install fsearch (注&#xf…...

终极歌词同步神器LRCGET:5分钟为你的音乐库添加完美歌词

终极歌词同步神器LRCGET&#xff1a;5分钟为你的音乐库添加完美歌词 【免费下载链接】lrcget Utility for mass-downloading LRC synced lyrics for your offline music library. 项目地址: https://gitcode.com/gh_mirrors/lr/lrcget 你是否厌倦了在听歌时手动搜索歌词…...

Windows Cleaner:终极免费系统清理工具,彻底解决C盘空间不足问题

Windows Cleaner&#xff1a;终极免费系统清理工具&#xff0c;彻底解决C盘空间不足问题 【免费下载链接】WindowsCleaner Windows Cleaner——专治C盘爆红及各种不服&#xff01; 项目地址: https://gitcode.com/gh_mirrors/wi/WindowsCleaner 你是否经常遇到C盘爆红、…...

NBTExplorer:让Minecraft数据编辑从专业工具变成人人可用的可视化平台

NBTExplorer&#xff1a;让Minecraft数据编辑从专业工具变成人人可用的可视化平台 【免费下载链接】NBTExplorer A graphical NBT editor for all Minecraft NBT data sources 项目地址: https://gitcode.com/gh_mirrors/nb/NBTExplorer 你是否曾经面对Minecraft世界文件…...

Unlock-Music:浏览器中一键解锁加密音乐文件的完整指南

Unlock-Music&#xff1a;浏览器中一键解锁加密音乐文件的完整指南 【免费下载链接】unlock-music 在浏览器中解锁加密的音乐文件。原仓库&#xff1a; 1. https://github.com/unlock-music/unlock-music &#xff1b;2. https://git.unlock-music.dev/um/web 项目地址: http…...

ABS+神经网络:端到端宇宙学参数推断新范式解析

1. 项目概述&#xff1a;当ABS遇上神经网络&#xff0c;一个端到端宇宙学参数推断新范式的诞生 在宇宙学研究的核心地带&#xff0c;有一项任务既令人着迷又充满挑战&#xff1a;如何从宇宙微波背景&#xff08;CMB&#xff09;这张宇宙婴儿时期的“照片”中&#xff0c;精准地…...

拒绝繁琐 PS:美图秀秀 电脑版在技术博客配图、无畸变裁剪与尺寸标准化中的应用

在日常开发、技术写作或维护 GitHub 开源项目时&#xff0c;技术配图和录屏展示是不可或缺的组成部分。 然而&#xff0c;对于大多数程序员和前端开发者来说&#xff0c;仅仅为了裁剪一个 App Icon 尺寸、给一系列产品图加防伪水印、对系统敏感配置截图进行脱敏打码&#xff0…...

5大核心功能掌握HandheldCompanion:Windows掌机终极控制伴侣

5大核心功能掌握HandheldCompanion&#xff1a;Windows掌机终极控制伴侣 【免费下载链接】HandheldCompanion ControllerService 项目地址: https://gitcode.com/gh_mirrors/ha/HandheldCompanion 你是否正在寻找一款能够彻底改变Windows掌机游戏体验的控制软件&#xf…...

如何在macOS上免费安装HSTracker:终极炉石传说套牌追踪器完整指南

如何在macOS上免费安装HSTracker&#xff1a;终极炉石传说套牌追踪器完整指南 【免费下载链接】HSTracker A deck tracker and deck manager for Hearthstone on macOS 项目地址: https://gitcode.com/gh_mirrors/hs/HSTracker 还在为炉石传说对局中记不住对手出牌而烦恼…...

MySQL全局ID生成实战:从自增主键到自定义Sequence的平滑升级方案与避坑指南

MySQL全局ID生成实战&#xff1a;从自增主键到自定义Sequence的平滑升级方案与避坑指南 当电商平台的日订单量突破百万时&#xff0c;技术团队突然发现系统开始频繁出现"Duplicate entry"错误——那些原本可靠的自增主键&#xff0c;在分库分表的环境下变成了数据一致…...