边写代码边学习之RNN
1. 什么是 RNN
循环神经网络(Recurrent Neural Network,RNN)是一种以序列数据为输入来进行建模的深度学习模型,它是 NLP 中最常用的模型。其结构如下图:

x是输入,h是隐层单元,o为输出,L为损失函数,y为训练集的标签.
这些元素右上角带的t代表t时刻的状态,其中需要注意的是,因策单元h在t时刻的表现不仅由此刻的输入决定,还受t时刻之前时刻的影响。V、W、U是权值,同一类型的权连接权值相同。
有了上面的理解,前向传播算法其实非常简单,对于t时刻:
其中为激活函数,一般来说会选择tanh函数,b为偏置。
t时刻的输出就更为简单:
最终模型的预测输出为:
其中为激活函数,通常RNN用于分类,故这里一般用softmax函数。
2. 实验代码
2.1. 搭建一个只有一层RNN和Dense网络的模型。
def simple_rnn_layer():# Create a dense layer with 10 output neurons and input shape of (None, 20)model = Sequential()model.add(SimpleRNN(units=3, input_shape=(3, 2),)) # 3 units in the RNN layer, input_shape=(timesteps, features)model.add(Dense(1)) # Output layer with one neuron# Print the summary of the dense layerprint(model.summary())
if __name__ == '__main__':simple_rnn_layer()
输出
Model: "sequential"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================simple_rnn (SimpleRNN) (None, 3) 18 dense (Dense) (None, 1) 4 =================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________
None
2.2. 验证RNN里的逻辑

写代码验证这个过程,看看结果是不是一样的。
import keras.optimizers.optimizer
import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense
def change_weight():# Create a simple Dense layerrnn_layer = SimpleRNN(units=3, input_shape=(3, 2), activation=None, return_sequences=True)# Simulate input data (batch size of 1 for demonstration)input_data = np.array([[[1.0, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]])# Pass the input data through the layer to initialize the weights and biases_ = rnn_layer(input_data)# Access the weights and biases of the dense layerkernel, recurrent_kernel, biases = rnn_layer.get_weights()# Print the initial weights and biasesprint("recurrent_kernel:", recurrent_kernel) # (3,3)print('kernal:',kernel) #(2,3)print('biase: ',biases) # (3)kernel = np.array([[1, 0, 2], [2, 1, 3]])recurrent_kernel = np.array([[1, 2, 1.0], [1, 0, 1], [0, 1, 0]])biases = np.array([0, 0, 1.0])rnn_layer.set_weights([kernel, recurrent_kernel, biases])print(rnn_layer.get_weights())test_data = np.array([[[1.0, 3], [1, 1], [2, 3]]])output = rnn_layer(test_data)print(output)if __name__ == '__main__':change_weight()
输出结果如下:可以看到结果是我手算的是一致的。
recurrent_kernel: [[ 0.06973135 0.40464386 0.9118119 ][ 0.6186313 -0.7345941 0.27868783][ 0.7825809 0.5446422 -0.3015495 ]]
kernal: [[-0.48868906 0.52718353 -0.08321357][-1.0569452 -0.9872779 0.72809434]]
biase: [0. 0. 0.]
[array([[1., 0., 2.],[2., 1., 3.]], dtype=float32), array([[1., 2., 1.],[1., 0., 1.],[0., 1., 0.]], dtype=float32), array([0., 0., 1.], dtype=float32)]
tf.Tensor(
[[[ 7. 3. 12.][13. 27. 16.][48. 45. 54.]]], shape=(1, 3, 3), dtype=float32)
2.3 代码实现一个简单的例子
import keras.optimizers.optimizer
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense# Sample sequential data
# Each sequence has three timesteps, and each timestep has two features
data = np.array([[[1, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]
])print('data.shape= ',data.shape)
# Define the RNN model
model = Sequential()
model.add(SimpleRNN(units=4, input_shape=(3, 2), name="simpleRNN")) # 4 units in the RNN layer, input_shape=(timesteps, features)
model.add(Dense(1, name= "output")) # Output layer with one neuron# Compile the model
model.compile(loss='mse', optimizer=keras.optimizers.Adam(learning_rate=0.01))# Print the model summary
model.summary()before_RNN_weight = model.get_layer("simpleRNN").get_weights()
print('before train ', before_RNN_weight)# Train the model
model.fit(data, np.array([[10], [20], [30]]), epochs=2000, verbose=1)RNN_weight = model.get_layer("simpleRNN").get_weights()
print('after train ', len(RNN_weight),)for i in range(len(RNN_weight)):print('====',RNN_weight[i].shape, RNN_weight[i])# Make predictions
predictions = model.predict(data)
print("Predictions:", predictions.flatten())
代码输出
data.shape= (3, 3, 2)
Model: "sequential"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================simpleRNN (SimpleRNN) (None, 4) 28 output (Dense) (None, 1) 5 =================================================================
Total params: 33
Trainable params: 33
Non-trainable params: 0
_________________________________________________________________
before train [array([[-0.00466371, 0.53100157, 0.5298798 , 0.05514288],[-0.08896947, 0.43185067, 0.7861788 , -0.80616236]],dtype=float32), array([[-0.10712242, -0.03620092, -0.02182053, -0.9933471 ],[-0.6549012 , -0.02620655, 0.7532524 , 0.05503315],[-0.01986913, 0.9989996 , 0.02001702, -0.03470401],[-0.74781984, 0.00159313, -0.657065 , 0.09502006]],dtype=float32), array([0., 0., 0., 0.], dtype=float32)]
2023-08-05 16:02:44.111298: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Epoch 1/2000
....
Epoch 1999/2000
1/1 [==============================] - 0s 11ms/step - loss: 0.0071
Epoch 2000/2000
1/1 [==============================] - 0s 13ms/step - loss: 0.0070
after train 3
==== (2, 4) [[ 0.27645147 0.6025058 1.6083356 -0.38382724][ 0.11586202 0.32901326 1.4760928 -1.2268958 ]]
==== (4, 4) [[-0.99628973 -2.444563 1.7412992 -1.5265529 ][ 0.80340594 0.9488743 2.44552 -0.7439341 ][-0.1827681 -1.3091801 1.547736 -0.6644555 ][-0.5724374 2.3090494 -2.1779017 0.35992467]]
==== (4,) [-0.40184066 -1.2391611 0.33460653 -0.29144585]
1/1 [==============================] - 0s 78ms/step
Predictions: [10.000422 19.999924 29.85534 ]
相关文章:
边写代码边学习之RNN
1. 什么是 RNN 循环神经网络(Recurrent Neural Network,RNN)是一种以序列数据为输入来进行建模的深度学习模型,它是 NLP 中最常用的模型。其结构如下图: x是输入,h是隐层单元,o为输出ÿ…...
在linux调试进程PID的方法
当我们谈论调试 PID(进程标识符)时,我们通常是指诊断和解决与操作系统中的特定进程相关的问题。有许多工具和方法可用于调试 PID,以下是一些常见的方法: 1. 使用ps命令 ps命令是最基本的调试工具,用于查看…...
【并发编程】线程安全的栈容器
std::stack容器的接口包括 empty(), size(), top(), push(), pop()等。 问题 其原接口在多线程的情况下,会持续很多问题。 例如,在std::stack容器的接口中,在多线程下应用时,empty()和size()的结果是不可信的。因为尽管在某线程…...
ES嵌套查询和普通查询的高亮显示区别
在 Elasticsearch 中,高亮显示是一种强大的搜索结果可视化工具,它可以帮助我们快速识别匹配的关键字或短语。在ES中,我们可以使用两种不同的查询方式来实现高亮显示:嵌套查询和普通查询。本文探讨这两种查询方式的高亮显示区别以及…...
Greenplum集群部署
一,安装说明 1.1环境说明 *名称**版本*操作系统CentOS 7.6 64bitgreenplumgreenplum-db-6.10.1-rhel7-x86_64.rpm1.2集群介绍 IPhostname集群节点10.240.3.244gpmastermaster10.240.3.245gpsegment1segment10.240.3.246gpsegment2segment二,安装环境准备 2.1 修改各节点名称…...
电教智能云数据可视化平台开发电能优化日志实录
电教智能云数据可视化平台开发电脑优化日志实录 一、2K和4K弹窗判断二、电能API对接1.电脑爬虫2.电能分组过滤3.数据可视化渲染4.弹窗 三.数组按顺序输出 一、2K和4K弹窗判断 {* 判断2k和4k弹窗 *}{if $dataScene[scene_standard] eq 0}<a class"menuBtn subMenu"…...
JSX语法基础总结
题记:首先我们要了解一下jsx是什么,跟js有什么区别,其实就是js的语法糖,加上了xml的语法,使得产生虚拟dom更加的方便,简单说一下,xml就是存储数据的格式,想了解xml的话,可…...
socker套接字
1.打印错误信息 2.socketaddr_in结构体 结构体: (部分库代码) (宏中的##) 3.manual TCP: SOCK_STREAM : 提供有序地,可靠的,全双工的,基于连接的流式服务 UDP: 面向数据报...
No111.精选前端面试题,享受每天的挑战和学习
文章目录 map和foreach的区别在组件中如何获取vuex的action对象中的属性怎么去获取封装在vuex的某个接口数据有没有抓包过?你如何跟踪某一个特定的请求?比如一个特定的URL,你如何把有关这部分的url数据提取出来?1. 使用网络抓包工…...
【Apollo学习笔记】—— 相机仿真
文章目录 前言相关代码整理 测试实践文件目录包管理BUILD文件以及cyberfile.xml文件源程序BUILD运行结果其他参考CameraOutput channels启动camera驱动启动camera video compression驱动 前言 本文是对Cyber RT的学习记录,文章可能存在不严谨、不完善、有缺漏的部分࿰…...
【数据结构】——线性表的相关习题
目录 题型一(线性表的存储结构)题型二(链表的判空)题型三(单链表的建立)题型四(顺序表、单链表的插入删除操作)题型五(双链表的插入删除操作)题型六ÿ…...
SpringBoot集成Elasticsearch8.x(8)|(新版本Java API Client的Painless语言脚本script使用)
SpringBoot集成Elasticsearch8.x(8)|(新版本Java API Client的Painless语言脚本script使用) 文章目录 SpringBoot集成Elasticsearch8.x(8)|(新版本Java API Client的Painless语言脚本script使用…...
SpringBoot复习:(19)Condition接口和@Conditional注解
Condition接口代码如下: public interface Condition {boolean matches(ConditionContext context, AnnotatedTypeMetadata metadata);}它是一个函数式接口,只有一个方法matches用来表示条件是否满足。matches方法中的ConditionContext类对象context可以…...
K8s中的Controller
Controller的作用 (1)确保预期的pod副本数量 (2)无状态应用部署 (3)有状态应用部署 (4)确保所有的node运行同一个pod,一次性任务和定时任务 1.无状态和有状态 无状态&…...
【MFC】03.常用复杂控件的使用-笔记
热键: 对话框-类向导:初始化函数中,热键需要在最开始的时候就注册进去: 注册热键: 在这之前,先去定义一个宏,代表你这个快捷键。 参数:窗口句柄,热键编号(热…...
Autosar诊断实战系列14-NRC优先级解析
本文框架 前言1. NRC分类2. NRC优先级判断2.1. NRC优先级判断逻辑介绍2.2 NRC测试注意事项前言 在本系列笔者将结合工作中对诊断实战部分的应用经验进一步介绍常用UDS服务的进一步探讨及开发中注意事项, Dem/Dcm/CanTp/Fim模块配置开发及注意事项,诊断与BswM/NvM关联模块的应…...
《向量数据库指南》——腾讯云向量数据库Tencent Cloud VectorDB产品特性,架构和应用场景
腾讯云向量数据库(Tencent Cloud VectorDB)是一款全托管的自研企业级分布式数据库服务,专用于存储、检索、分析多维向量数据。该数据库支持多种索引类型和相似度计算方法,单索引支持 10 亿级向量规模,可支持百万级 QPS 及毫秒级查询延迟。腾讯云向量数据库不仅能为大模型提…...
xcode 的app工程与ffmpeg 4.4版本的静态库联调,ffmpeg内下的断点无法暂停。
先阐述一下我的业务场景,我有一个iOS的app sdk项目,下面简称 A ,以及运行 A 的 app 项目,简称 A demo 。 引用关系为 A demo 引用了 A ,而 A 引用了 ffmpeg 的静态库(.a文件)。此时业务出现了 b…...
机器学习06 数据准备-(利用 scikit-learn基于Pima Indian数据集作 数据特征选定)
什么是数据特征选定? 数据特征选定(Feature Selection)是指从原始数据中选择最相关、最有用的特征,用于构建机器学习模型。特征选定是机器学习流程中非常重要的一步,它直接影响模型的性能和泛化能力。通过选择最重要的特征&#…...
机器学习-特征选择:如何使用Lassco回归精确选择最佳特征?
一、引言 特征选择在机器学习领域中扮演着至关重要的角色,它能够从原始数据中选择最具信息量的特征,提高模型性能、减少过拟合,并加快模型训练和预测的速度。在大规模数据集和高维数据中,特征选择尤为重要,因为不必要的…...
基于服务器使用 apt 安装、配置 Nginx
🧾 一、查看可安装的 Nginx 版本 首先,你可以运行以下命令查看可用版本: apt-cache madison nginx-core输出示例: nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...
vue3 字体颜色设置的多种方式
在Vue 3中设置字体颜色可以通过多种方式实现,这取决于你是想在组件内部直接设置,还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法: 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...
Nuxt.js 中的路由配置详解
Nuxt.js 通过其内置的路由系统简化了应用的路由配置,使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...
【HTML-16】深入理解HTML中的块元素与行内元素
HTML元素根据其显示特性可以分为两大类:块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
HDFS分布式存储 zookeeper
hadoop介绍 狭义上hadoop是指apache的一款开源软件 用java语言实现开源框架,允许使用简单的变成模型跨计算机对大型集群进行分布式处理(1.海量的数据存储 2.海量数据的计算)Hadoop核心组件 hdfs(分布式文件存储系统)&a…...
技术栈RabbitMq的介绍和使用
目录 1. 什么是消息队列?2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...
【VLNs篇】07:NavRL—在动态环境中学习安全飞行
项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战,克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...
Web后端基础(基础知识)
BS架构:Browser/Server,浏览器/服务器架构模式。客户端只需要浏览器,应用程序的逻辑和数据都存储在服务端。 优点:维护方便缺点:体验一般 CS架构:Client/Server,客户端/服务器架构模式。需要单独…...
LOOI机器人的技术实现解析:从手势识别到边缘检测
LOOI机器人作为一款创新的AI硬件产品,通过将智能手机转变为具有情感交互能力的桌面机器人,展示了前沿AI技术与传统硬件设计的完美结合。作为AI与玩具领域的专家,我将全面解析LOOI的技术实现架构,特别是其手势识别、物体识别和环境…...
