当前位置: 首页 > news >正文

边写代码边学习之RNN

1. 什么是 RNN

循环神经网络(Recurrent Neural Network,RNN)是一种以序列数据为输入来进行建模的深度学习模型,它是 NLP 中最常用的模型。其结构如下图:

在这里插入图片描述

 x是输入,h是隐层单元,o为输出,L为损失函数,y为训练集的标签.
这些元素右上角带的t代表t时刻的状态,其中需要注意的是,因策单元h在t时刻的表现不仅由此刻的输入决定,还受t时刻之前时刻的影响。V、W、U是权值,同一类型的权连接权值相同。
有了上面的理解,前向传播算法其实非常简单,对于t时刻:
                                       h ^{(t)} =\phi (Ux^{(t)} +Wh^{(t-1)} +b)

其中\phi ()为激活函数,一般来说会选择tanh函数,b为偏置。
t时刻的输出就更为简单:
                                                     o^{(t)} =Vh ^{(t)} +c
最终模型的预测输出为:
                                                          \hat y^{(t)} =\sigma (o^{(t)} )
其中\sigma为激活函数,通常RNN用于分类,故这里一般用softmax函数。

2. 实验代码

2.1. 搭建一个只有一层RNN和Dense网络的模型。

def simple_rnn_layer():# Create a dense layer with 10 output neurons and input shape of (None, 20)model = Sequential()model.add(SimpleRNN(units=3, input_shape=(3, 2),))  # 3 units in the RNN layer, input_shape=(timesteps, features)model.add(Dense(1))  # Output layer with one neuron# Print the summary of the dense layerprint(model.summary())
if __name__ == '__main__':simple_rnn_layer()

输出

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================simple_rnn (SimpleRNN)      (None, 3)                 18        dense (Dense)               (None, 1)                 4         =================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________
None

2.2. 验证RNN里的逻辑

写代码验证这个过程,看看结果是不是一样的。

import keras.optimizers.optimizer
import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense
def change_weight():# Create a simple Dense layerrnn_layer = SimpleRNN(units=3, input_shape=(3, 2), activation=None, return_sequences=True)# Simulate input data (batch size of 1 for demonstration)input_data = np.array([[[1.0, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]])# Pass the input data through the layer to initialize the weights and biases_ = rnn_layer(input_data)# Access the weights and biases of the dense layerkernel, recurrent_kernel, biases = rnn_layer.get_weights()# Print the initial weights and biasesprint("recurrent_kernel:", recurrent_kernel) # (3,3)print('kernal:',kernel) #(2,3)print('biase: ',biases) # (3)kernel = np.array([[1, 0, 2], [2, 1, 3]])recurrent_kernel = np.array([[1, 2, 1.0], [1, 0, 1], [0, 1, 0]])biases = np.array([0, 0, 1.0])rnn_layer.set_weights([kernel, recurrent_kernel, biases])print(rnn_layer.get_weights())test_data = np.array([[[1.0, 3], [1, 1], [2, 3]]])output = rnn_layer(test_data)print(output)if __name__ == '__main__':change_weight()

输出结果如下:可以看到结果是我手算的是一致的。

recurrent_kernel: [[ 0.06973135  0.40464386  0.9118119 ][ 0.6186313  -0.7345941   0.27868783][ 0.7825809   0.5446422  -0.3015495 ]]
kernal: [[-0.48868906  0.52718353 -0.08321357][-1.0569452  -0.9872779   0.72809434]]
biase:  [0. 0. 0.]
[array([[1., 0., 2.],[2., 1., 3.]], dtype=float32), array([[1., 2., 1.],[1., 0., 1.],[0., 1., 0.]], dtype=float32), array([0., 0., 1.], dtype=float32)]
tf.Tensor(
[[[ 7.  3. 12.][13. 27. 16.][48. 45. 54.]]], shape=(1, 3, 3), dtype=float32)

2.3 代码实现一个简单的例子

import keras.optimizers.optimizer
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense# Sample sequential data
# Each sequence has three timesteps, and each timestep has two features
data = np.array([[[1, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]
])print('data.shape= ',data.shape)
# Define the RNN model
model = Sequential()
model.add(SimpleRNN(units=4, input_shape=(3, 2), name="simpleRNN"))  # 4 units in the RNN layer, input_shape=(timesteps, features)
model.add(Dense(1, name= "output"))  # Output layer with one neuron# Compile the model
model.compile(loss='mse', optimizer=keras.optimizers.Adam(learning_rate=0.01))# Print the model summary
model.summary()before_RNN_weight = model.get_layer("simpleRNN").get_weights()
print('before train ', before_RNN_weight)# Train the model
model.fit(data, np.array([[10], [20], [30]]), epochs=2000, verbose=1)RNN_weight = model.get_layer("simpleRNN").get_weights()
print('after train ', len(RNN_weight),)for i in range(len(RNN_weight)):print('====',RNN_weight[i].shape, RNN_weight[i])# Make predictions
predictions = model.predict(data)
print("Predictions:", predictions.flatten())

代码输出

data.shape=  (3, 3, 2)
Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================simpleRNN (SimpleRNN)       (None, 4)                 28        output (Dense)              (None, 1)                 5         =================================================================
Total params: 33
Trainable params: 33
Non-trainable params: 0
_________________________________________________________________
before train  [array([[-0.00466371,  0.53100157,  0.5298798 ,  0.05514288],[-0.08896947,  0.43185067,  0.7861788 , -0.80616236]],dtype=float32), array([[-0.10712242, -0.03620092, -0.02182053, -0.9933471 ],[-0.6549012 , -0.02620655,  0.7532524 ,  0.05503315],[-0.01986913,  0.9989996 ,  0.02001702, -0.03470401],[-0.74781984,  0.00159313, -0.657065  ,  0.09502006]],dtype=float32), array([0., 0., 0., 0.], dtype=float32)]
2023-08-05 16:02:44.111298: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Epoch 1/2000
....
Epoch 1999/2000
1/1 [==============================] - 0s 11ms/step - loss: 0.0071
Epoch 2000/2000
1/1 [==============================] - 0s 13ms/step - loss: 0.0070
after train  3
==== (2, 4) [[ 0.27645147  0.6025058   1.6083356  -0.38382724][ 0.11586202  0.32901326  1.4760928  -1.2268958 ]]
==== (4, 4) [[-0.99628973 -2.444563    1.7412992  -1.5265529 ][ 0.80340594  0.9488743   2.44552    -0.7439341 ][-0.1827681  -1.3091801   1.547736   -0.6644555 ][-0.5724374   2.3090494  -2.1779017   0.35992467]]
==== (4,) [-0.40184066 -1.2391611   0.33460653 -0.29144585]
1/1 [==============================] - 0s 78ms/step
Predictions: [10.000422 19.999924 29.85534 ]

相关文章:

边写代码边学习之RNN

1. 什么是 RNN 循环神经网络(Recurrent Neural Network,RNN)是一种以序列数据为输入来进行建模的深度学习模型,它是 NLP 中最常用的模型。其结构如下图: x是输入,h是隐层单元,o为输出&#xff…...

在linux调试进程PID的方法

当我们谈论调试 PID(进程标识符)时,我们通常是指诊断和解决与操作系统中的特定进程相关的问题。有许多工具和方法可用于调试 PID,以下是一些常见的方法: 1. 使用ps命令 ps命令是最基本的调试工具,用于查看…...

【并发编程】线程安全的栈容器

std::stack容器的接口包括 empty(), size(), top(), push(), pop()等。 问题 其原接口在多线程的情况下,会持续很多问题。 例如,在std::stack容器的接口中,在多线程下应用时,empty()和size()的结果是不可信的。因为尽管在某线程…...

ES嵌套查询和普通查询的高亮显示区别

在 Elasticsearch 中,高亮显示是一种强大的搜索结果可视化工具,它可以帮助我们快速识别匹配的关键字或短语。在ES中,我们可以使用两种不同的查询方式来实现高亮显示:嵌套查询和普通查询。本文探讨这两种查询方式的高亮显示区别以及…...

Greenplum集群部署

一,安装说明 1.1环境说明 *名称**版本*操作系统CentOS 7.6 64bitgreenplumgreenplum-db-6.10.1-rhel7-x86_64.rpm1.2集群介绍 IPhostname集群节点10.240.3.244gpmastermaster10.240.3.245gpsegment1segment10.240.3.246gpsegment2segment二,安装环境准备 2.1 修改各节点名称…...

电教智能云数据可视化平台开发电能优化日志实录

电教智能云数据可视化平台开发电脑优化日志实录 一、2K和4K弹窗判断二、电能API对接1.电脑爬虫2.电能分组过滤3.数据可视化渲染4.弹窗 三.数组按顺序输出 一、2K和4K弹窗判断 {* 判断2k和4k弹窗 *}{if $dataScene[scene_standard] eq 0}<a class"menuBtn subMenu"…...

JSX语法基础总结

题记&#xff1a;首先我们要了解一下jsx是什么&#xff0c;跟js有什么区别&#xff0c;其实就是js的语法糖&#xff0c;加上了xml的语法&#xff0c;使得产生虚拟dom更加的方便&#xff0c;简单说一下&#xff0c;xml就是存储数据的格式&#xff0c;想了解xml的话&#xff0c;可…...

socker套接字

1.打印错误信息 2.socketaddr_in结构体 结构体&#xff1a; &#xff08;部分库代码&#xff09; (宏中的##) 3.manual TCP: SOCK_STREAM &#xff1a; 提供有序地&#xff0c;可靠的&#xff0c;全双工的&#xff0c;基于连接的流式服务 UDP: 面向数据报...

No111.精选前端面试题,享受每天的挑战和学习

文章目录 map和foreach的区别在组件中如何获取vuex的action对象中的属性怎么去获取封装在vuex的某个接口数据有没有抓包过&#xff1f;你如何跟踪某一个特定的请求&#xff1f;比如一个特定的URL&#xff0c;你如何把有关这部分的url数据提取出来&#xff1f;1. 使用网络抓包工…...

【Apollo学习笔记】—— 相机仿真

文章目录 前言相关代码整理 测试实践文件目录包管理BUILD文件以及cyberfile.xml文件源程序BUILD运行结果其他参考CameraOutput channels启动camera驱动启动camera video compression驱动 前言 本文是对Cyber RT的学习记录,文章可能存在不严谨、不完善、有缺漏的部分&#xff0…...

【数据结构】——线性表的相关习题

目录 题型一&#xff08;线性表的存储结构&#xff09;题型二&#xff08;链表的判空&#xff09;题型三&#xff08;单链表的建立&#xff09;题型四&#xff08;顺序表、单链表的插入删除操作&#xff09;题型五&#xff08;双链表的插入删除操作&#xff09;题型六&#xff…...

SpringBoot集成Elasticsearch8.x(8)|(新版本Java API Client的Painless语言脚本script使用)

SpringBoot集成Elasticsearch8.x&#xff08;8&#xff09;|&#xff08;新版本Java API Client的Painless语言脚本script使用&#xff09; 文章目录 SpringBoot集成Elasticsearch8.x&#xff08;8&#xff09;|&#xff08;新版本Java API Client的Painless语言脚本script使用…...

SpringBoot复习:(19)Condition接口和@Conditional注解

Condition接口代码如下&#xff1a; public interface Condition {boolean matches(ConditionContext context, AnnotatedTypeMetadata metadata);}它是一个函数式接口&#xff0c;只有一个方法matches用来表示条件是否满足。matches方法中的ConditionContext类对象context可以…...

K8s中的Controller

Controller的作用 &#xff08;1&#xff09;确保预期的pod副本数量 &#xff08;2&#xff09;无状态应用部署 &#xff08;3&#xff09;有状态应用部署 &#xff08;4&#xff09;确保所有的node运行同一个pod&#xff0c;一次性任务和定时任务 1.无状态和有状态 无状态&…...

【MFC】03.常用复杂控件的使用-笔记

热键&#xff1a; 对话框-类向导&#xff1a;初始化函数中&#xff0c;热键需要在最开始的时候就注册进去&#xff1a; 注册热键&#xff1a; 在这之前&#xff0c;先去定义一个宏&#xff0c;代表你这个快捷键。 参数&#xff1a;窗口句柄&#xff0c;热键编号&#xff08;热…...

Autosar诊断实战系列14-NRC优先级解析

本文框架 前言1. NRC分类2. NRC优先级判断2.1. NRC优先级判断逻辑介绍2.2 NRC测试注意事项前言 在本系列笔者将结合工作中对诊断实战部分的应用经验进一步介绍常用UDS服务的进一步探讨及开发中注意事项, Dem/Dcm/CanTp/Fim模块配置开发及注意事项,诊断与BswM/NvM关联模块的应…...

《向量数据库指南》——腾讯云向量数据库Tencent Cloud VectorDB产品特性,架构和应用场景

腾讯云向量数据库(Tencent Cloud VectorDB)是一款全托管的自研企业级分布式数据库服务,专用于存储、检索、分析多维向量数据。该数据库支持多种索引类型和相似度计算方法,单索引支持 10 亿级向量规模,可支持百万级 QPS 及毫秒级查询延迟。腾讯云向量数据库不仅能为大模型提…...

xcode 的app工程与ffmpeg 4.4版本的静态库联调,ffmpeg内下的断点无法暂停。

先阐述一下我的业务场景&#xff0c;我有一个iOS的app sdk项目&#xff0c;下面简称 A &#xff0c;以及运行 A 的 app 项目&#xff0c;简称 A demo 。 引用关系为 A demo 引用了 A &#xff0c;而 A 引用了 ffmpeg 的静态库&#xff08;.a文件&#xff09;。此时业务出现了 b…...

机器学习06 数据准备-(利用 scikit-learn基于Pima Indian数据集作 数据特征选定)

什么是数据特征选定? 数据特征选定&#xff08;Feature Selection&#xff09;是指从原始数据中选择最相关、最有用的特征&#xff0c;用于构建机器学习模型。特征选定是机器学习流程中非常重要的一步&#xff0c;它直接影响模型的性能和泛化能力。通过选择最重要的特征&#…...

机器学习-特征选择:如何使用Lassco回归精确选择最佳特征?

一、引言 特征选择在机器学习领域中扮演着至关重要的角色&#xff0c;它能够从原始数据中选择最具信息量的特征&#xff0c;提高模型性能、减少过拟合&#xff0c;并加快模型训练和预测的速度。在大规模数据集和高维数据中&#xff0c;特征选择尤为重要&#xff0c;因为不必要的…...

ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放

简介 前面两期文章我们介绍了I2S的读取和写入&#xff0c;一个是通过INMP441麦克风模块采集音频&#xff0c;一个是通过PCM5102A模块播放音频&#xff0c;那如果我们将两者结合起来&#xff0c;将麦克风采集到的音频通过PCM5102A播放&#xff0c;是不是就可以做一个扩音器了呢…...

[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...

Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)

参考官方文档&#xff1a;https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java&#xff08;供 Kotlin 使用&#xff09; 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...

企业如何增强终端安全?

在数字化转型加速的今天&#xff0c;企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机&#xff0c;到工厂里的物联网设备、智能传感器&#xff0c;这些终端构成了企业与外部世界连接的 “神经末梢”。然而&#xff0c;随着远程办公的常态化和设备接入的爆炸式…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路&#xff1a; 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑&#xff1a;async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现企业微信功能

1. 开发环境准备 ​​安装DevEco Studio 3.1​​&#xff1a; 从华为开发者官网下载最新版DevEco Studio安装HarmonyOS 5.0 SDK ​​项目配置​​&#xff1a; // module.json5 {"module": {"requestPermissions": [{"name": "ohos.permis…...

加密通信 + 行为分析:运营商行业安全防御体系重构

在数字经济蓬勃发展的时代&#xff0c;运营商作为信息通信网络的核心枢纽&#xff0c;承载着海量用户数据与关键业务传输&#xff0c;其安全防御体系的可靠性直接关乎国家安全、社会稳定与企业发展。随着网络攻击手段的不断升级&#xff0c;传统安全防护体系逐渐暴露出局限性&a…...

Xcode 16 集成 cocoapods 报错

基于 Xcode 16 新建工程项目&#xff0c;集成 cocoapods 执行 pod init 报错 ### Error RuntimeError - PBXGroup attempted to initialize an object with unknown ISA PBXFileSystemSynchronizedRootGroup from attributes: {"isa">"PBXFileSystemSynchro…...

【51单片机】4. 模块化编程与LCD1602Debug

1. 什么是模块化编程 传统编程会将所有函数放在main.c中&#xff0c;如果使用的模块多&#xff0c;一个文件内会有很多代码&#xff0c;不利于组织和管理 模块化编程则是将各个模块的代码放在不同的.c文件里&#xff0c;在.h文件里提供外部可调用函数声明&#xff0c;其他.c文…...

【java面试】微服务篇

【java面试】微服务篇 一、总体框架二、Springcloud&#xff08;一&#xff09;Springcloud五大组件&#xff08;二&#xff09;服务注册和发现1、Eureka2、Nacos &#xff08;三&#xff09;负载均衡1、Ribbon负载均衡流程2、Ribbon负载均衡策略3、自定义负载均衡策略4、总结 …...