边写代码边学习之RNN
1. 什么是 RNN
循环神经网络(Recurrent Neural Network,RNN)是一种以序列数据为输入来进行建模的深度学习模型,它是 NLP 中最常用的模型。其结构如下图:

x是输入,h是隐层单元,o为输出,L为损失函数,y为训练集的标签.
这些元素右上角带的t代表t时刻的状态,其中需要注意的是,因策单元h在t时刻的表现不仅由此刻的输入决定,还受t时刻之前时刻的影响。V、W、U是权值,同一类型的权连接权值相同。
有了上面的理解,前向传播算法其实非常简单,对于t时刻:
其中为激活函数,一般来说会选择tanh函数,b为偏置。
t时刻的输出就更为简单:
最终模型的预测输出为:
其中为激活函数,通常RNN用于分类,故这里一般用softmax函数。
2. 实验代码
2.1. 搭建一个只有一层RNN和Dense网络的模型。
def simple_rnn_layer():# Create a dense layer with 10 output neurons and input shape of (None, 20)model = Sequential()model.add(SimpleRNN(units=3, input_shape=(3, 2),)) # 3 units in the RNN layer, input_shape=(timesteps, features)model.add(Dense(1)) # Output layer with one neuron# Print the summary of the dense layerprint(model.summary())
if __name__ == '__main__':simple_rnn_layer()
输出
Model: "sequential"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================simple_rnn (SimpleRNN) (None, 3) 18 dense (Dense) (None, 1) 4 =================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________
None
2.2. 验证RNN里的逻辑

写代码验证这个过程,看看结果是不是一样的。
import keras.optimizers.optimizer
import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense
def change_weight():# Create a simple Dense layerrnn_layer = SimpleRNN(units=3, input_shape=(3, 2), activation=None, return_sequences=True)# Simulate input data (batch size of 1 for demonstration)input_data = np.array([[[1.0, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]])# Pass the input data through the layer to initialize the weights and biases_ = rnn_layer(input_data)# Access the weights and biases of the dense layerkernel, recurrent_kernel, biases = rnn_layer.get_weights()# Print the initial weights and biasesprint("recurrent_kernel:", recurrent_kernel) # (3,3)print('kernal:',kernel) #(2,3)print('biase: ',biases) # (3)kernel = np.array([[1, 0, 2], [2, 1, 3]])recurrent_kernel = np.array([[1, 2, 1.0], [1, 0, 1], [0, 1, 0]])biases = np.array([0, 0, 1.0])rnn_layer.set_weights([kernel, recurrent_kernel, biases])print(rnn_layer.get_weights())test_data = np.array([[[1.0, 3], [1, 1], [2, 3]]])output = rnn_layer(test_data)print(output)if __name__ == '__main__':change_weight()
输出结果如下:可以看到结果是我手算的是一致的。
recurrent_kernel: [[ 0.06973135 0.40464386 0.9118119 ][ 0.6186313 -0.7345941 0.27868783][ 0.7825809 0.5446422 -0.3015495 ]]
kernal: [[-0.48868906 0.52718353 -0.08321357][-1.0569452 -0.9872779 0.72809434]]
biase: [0. 0. 0.]
[array([[1., 0., 2.],[2., 1., 3.]], dtype=float32), array([[1., 2., 1.],[1., 0., 1.],[0., 1., 0.]], dtype=float32), array([0., 0., 1.], dtype=float32)]
tf.Tensor(
[[[ 7. 3. 12.][13. 27. 16.][48. 45. 54.]]], shape=(1, 3, 3), dtype=float32)
2.3 代码实现一个简单的例子
import keras.optimizers.optimizer
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense# Sample sequential data
# Each sequence has three timesteps, and each timestep has two features
data = np.array([[[1, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]
])print('data.shape= ',data.shape)
# Define the RNN model
model = Sequential()
model.add(SimpleRNN(units=4, input_shape=(3, 2), name="simpleRNN")) # 4 units in the RNN layer, input_shape=(timesteps, features)
model.add(Dense(1, name= "output")) # Output layer with one neuron# Compile the model
model.compile(loss='mse', optimizer=keras.optimizers.Adam(learning_rate=0.01))# Print the model summary
model.summary()before_RNN_weight = model.get_layer("simpleRNN").get_weights()
print('before train ', before_RNN_weight)# Train the model
model.fit(data, np.array([[10], [20], [30]]), epochs=2000, verbose=1)RNN_weight = model.get_layer("simpleRNN").get_weights()
print('after train ', len(RNN_weight),)for i in range(len(RNN_weight)):print('====',RNN_weight[i].shape, RNN_weight[i])# Make predictions
predictions = model.predict(data)
print("Predictions:", predictions.flatten())
代码输出
data.shape= (3, 3, 2)
Model: "sequential"
_________________________________________________________________Layer (type) Output Shape Param #
=================================================================simpleRNN (SimpleRNN) (None, 4) 28 output (Dense) (None, 1) 5 =================================================================
Total params: 33
Trainable params: 33
Non-trainable params: 0
_________________________________________________________________
before train [array([[-0.00466371, 0.53100157, 0.5298798 , 0.05514288],[-0.08896947, 0.43185067, 0.7861788 , -0.80616236]],dtype=float32), array([[-0.10712242, -0.03620092, -0.02182053, -0.9933471 ],[-0.6549012 , -0.02620655, 0.7532524 , 0.05503315],[-0.01986913, 0.9989996 , 0.02001702, -0.03470401],[-0.74781984, 0.00159313, -0.657065 , 0.09502006]],dtype=float32), array([0., 0., 0., 0.], dtype=float32)]
2023-08-05 16:02:44.111298: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Epoch 1/2000
....
Epoch 1999/2000
1/1 [==============================] - 0s 11ms/step - loss: 0.0071
Epoch 2000/2000
1/1 [==============================] - 0s 13ms/step - loss: 0.0070
after train 3
==== (2, 4) [[ 0.27645147 0.6025058 1.6083356 -0.38382724][ 0.11586202 0.32901326 1.4760928 -1.2268958 ]]
==== (4, 4) [[-0.99628973 -2.444563 1.7412992 -1.5265529 ][ 0.80340594 0.9488743 2.44552 -0.7439341 ][-0.1827681 -1.3091801 1.547736 -0.6644555 ][-0.5724374 2.3090494 -2.1779017 0.35992467]]
==== (4,) [-0.40184066 -1.2391611 0.33460653 -0.29144585]
1/1 [==============================] - 0s 78ms/step
Predictions: [10.000422 19.999924 29.85534 ]
相关文章:
边写代码边学习之RNN
1. 什么是 RNN 循环神经网络(Recurrent Neural Network,RNN)是一种以序列数据为输入来进行建模的深度学习模型,它是 NLP 中最常用的模型。其结构如下图: x是输入,h是隐层单元,o为输出ÿ…...
在linux调试进程PID的方法
当我们谈论调试 PID(进程标识符)时,我们通常是指诊断和解决与操作系统中的特定进程相关的问题。有许多工具和方法可用于调试 PID,以下是一些常见的方法: 1. 使用ps命令 ps命令是最基本的调试工具,用于查看…...
【并发编程】线程安全的栈容器
std::stack容器的接口包括 empty(), size(), top(), push(), pop()等。 问题 其原接口在多线程的情况下,会持续很多问题。 例如,在std::stack容器的接口中,在多线程下应用时,empty()和size()的结果是不可信的。因为尽管在某线程…...
ES嵌套查询和普通查询的高亮显示区别
在 Elasticsearch 中,高亮显示是一种强大的搜索结果可视化工具,它可以帮助我们快速识别匹配的关键字或短语。在ES中,我们可以使用两种不同的查询方式来实现高亮显示:嵌套查询和普通查询。本文探讨这两种查询方式的高亮显示区别以及…...
Greenplum集群部署
一,安装说明 1.1环境说明 *名称**版本*操作系统CentOS 7.6 64bitgreenplumgreenplum-db-6.10.1-rhel7-x86_64.rpm1.2集群介绍 IPhostname集群节点10.240.3.244gpmastermaster10.240.3.245gpsegment1segment10.240.3.246gpsegment2segment二,安装环境准备 2.1 修改各节点名称…...
电教智能云数据可视化平台开发电能优化日志实录
电教智能云数据可视化平台开发电脑优化日志实录 一、2K和4K弹窗判断二、电能API对接1.电脑爬虫2.电能分组过滤3.数据可视化渲染4.弹窗 三.数组按顺序输出 一、2K和4K弹窗判断 {* 判断2k和4k弹窗 *}{if $dataScene[scene_standard] eq 0}<a class"menuBtn subMenu"…...
JSX语法基础总结
题记:首先我们要了解一下jsx是什么,跟js有什么区别,其实就是js的语法糖,加上了xml的语法,使得产生虚拟dom更加的方便,简单说一下,xml就是存储数据的格式,想了解xml的话,可…...
socker套接字
1.打印错误信息 2.socketaddr_in结构体 结构体: (部分库代码) (宏中的##) 3.manual TCP: SOCK_STREAM : 提供有序地,可靠的,全双工的,基于连接的流式服务 UDP: 面向数据报...
No111.精选前端面试题,享受每天的挑战和学习
文章目录 map和foreach的区别在组件中如何获取vuex的action对象中的属性怎么去获取封装在vuex的某个接口数据有没有抓包过?你如何跟踪某一个特定的请求?比如一个特定的URL,你如何把有关这部分的url数据提取出来?1. 使用网络抓包工…...
【Apollo学习笔记】—— 相机仿真
文章目录 前言相关代码整理 测试实践文件目录包管理BUILD文件以及cyberfile.xml文件源程序BUILD运行结果其他参考CameraOutput channels启动camera驱动启动camera video compression驱动 前言 本文是对Cyber RT的学习记录,文章可能存在不严谨、不完善、有缺漏的部分࿰…...
【数据结构】——线性表的相关习题
目录 题型一(线性表的存储结构)题型二(链表的判空)题型三(单链表的建立)题型四(顺序表、单链表的插入删除操作)题型五(双链表的插入删除操作)题型六ÿ…...
SpringBoot集成Elasticsearch8.x(8)|(新版本Java API Client的Painless语言脚本script使用)
SpringBoot集成Elasticsearch8.x(8)|(新版本Java API Client的Painless语言脚本script使用) 文章目录 SpringBoot集成Elasticsearch8.x(8)|(新版本Java API Client的Painless语言脚本script使用…...
SpringBoot复习:(19)Condition接口和@Conditional注解
Condition接口代码如下: public interface Condition {boolean matches(ConditionContext context, AnnotatedTypeMetadata metadata);}它是一个函数式接口,只有一个方法matches用来表示条件是否满足。matches方法中的ConditionContext类对象context可以…...
K8s中的Controller
Controller的作用 (1)确保预期的pod副本数量 (2)无状态应用部署 (3)有状态应用部署 (4)确保所有的node运行同一个pod,一次性任务和定时任务 1.无状态和有状态 无状态&…...
【MFC】03.常用复杂控件的使用-笔记
热键: 对话框-类向导:初始化函数中,热键需要在最开始的时候就注册进去: 注册热键: 在这之前,先去定义一个宏,代表你这个快捷键。 参数:窗口句柄,热键编号(热…...
Autosar诊断实战系列14-NRC优先级解析
本文框架 前言1. NRC分类2. NRC优先级判断2.1. NRC优先级判断逻辑介绍2.2 NRC测试注意事项前言 在本系列笔者将结合工作中对诊断实战部分的应用经验进一步介绍常用UDS服务的进一步探讨及开发中注意事项, Dem/Dcm/CanTp/Fim模块配置开发及注意事项,诊断与BswM/NvM关联模块的应…...
《向量数据库指南》——腾讯云向量数据库Tencent Cloud VectorDB产品特性,架构和应用场景
腾讯云向量数据库(Tencent Cloud VectorDB)是一款全托管的自研企业级分布式数据库服务,专用于存储、检索、分析多维向量数据。该数据库支持多种索引类型和相似度计算方法,单索引支持 10 亿级向量规模,可支持百万级 QPS 及毫秒级查询延迟。腾讯云向量数据库不仅能为大模型提…...
xcode 的app工程与ffmpeg 4.4版本的静态库联调,ffmpeg内下的断点无法暂停。
先阐述一下我的业务场景,我有一个iOS的app sdk项目,下面简称 A ,以及运行 A 的 app 项目,简称 A demo 。 引用关系为 A demo 引用了 A ,而 A 引用了 ffmpeg 的静态库(.a文件)。此时业务出现了 b…...
机器学习06 数据准备-(利用 scikit-learn基于Pima Indian数据集作 数据特征选定)
什么是数据特征选定? 数据特征选定(Feature Selection)是指从原始数据中选择最相关、最有用的特征,用于构建机器学习模型。特征选定是机器学习流程中非常重要的一步,它直接影响模型的性能和泛化能力。通过选择最重要的特征&#…...
机器学习-特征选择:如何使用Lassco回归精确选择最佳特征?
一、引言 特征选择在机器学习领域中扮演着至关重要的角色,它能够从原始数据中选择最具信息量的特征,提高模型性能、减少过拟合,并加快模型训练和预测的速度。在大规模数据集和高维数据中,特征选择尤为重要,因为不必要的…...
基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销,平衡网络负载,延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...
在鸿蒙HarmonyOS 5中实现抖音风格的点赞功能
下面我将详细介绍如何使用HarmonyOS SDK在HarmonyOS 5中实现类似抖音的点赞功能,包括动画效果、数据同步和交互优化。 1. 基础点赞功能实现 1.1 创建数据模型 // VideoModel.ets export class VideoModel {id: string "";title: string ""…...
FFmpeg 低延迟同屏方案
引言 在实时互动需求激增的当下,无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作,还是游戏直播的画面实时传输,低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架,凭借其灵活的编解码、数据…...
Auto-Coder使用GPT-4o完成:在用TabPFN这个模型构建一个预测未来3天涨跌的分类任务
通过akshare库,获取股票数据,并生成TabPFN这个模型 可以识别、处理的格式,写一个完整的预处理示例,并构建一个预测未来 3 天股价涨跌的分类任务 用TabPFN这个模型构建一个预测未来 3 天股价涨跌的分类任务,进行预测并输…...
【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统
目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...
华硕a豆14 Air香氛版,美学与科技的馨香融合
在快节奏的现代生活中,我们渴望一个能激发创想、愉悦感官的工作与生活伙伴,它不仅是冰冷的科技工具,更能触动我们内心深处的细腻情感。正是在这样的期许下,华硕a豆14 Air香氛版翩然而至,它以一种前所未有的方式&#x…...
Python ROS2【机器人中间件框架】 简介
销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...
JAVA后端开发——多租户
数据隔离是多租户系统中的核心概念,确保一个租户(在这个系统中可能是一个公司或一个独立的客户)的数据对其他租户是不可见的。在 RuoYi 框架(您当前项目所使用的基础框架)中,这通常是通过在数据表中增加一个…...
scikit-learn机器学习
# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: # Also add the following code, # so that every time the environment (kernel) starts, # just run the following code: import sys sys.path.append(/home/aistudio/external-libraries)机…...
uniapp 字符包含的相关方法
在uniapp中,如果你想检查一个字符串是否包含另一个子字符串,你可以使用JavaScript中的includes()方法或者indexOf()方法。这两种方法都可以达到目的,但它们在处理方式和返回值上有所不同。 使用includes()方法 includes()方法用于判断一个字…...
