当前位置: 首页 > news >正文

边写代码边学习之RNN

1. 什么是 RNN

循环神经网络(Recurrent Neural Network,RNN)是一种以序列数据为输入来进行建模的深度学习模型,它是 NLP 中最常用的模型。其结构如下图:

在这里插入图片描述

 x是输入,h是隐层单元,o为输出,L为损失函数,y为训练集的标签.
这些元素右上角带的t代表t时刻的状态,其中需要注意的是,因策单元h在t时刻的表现不仅由此刻的输入决定,还受t时刻之前时刻的影响。V、W、U是权值,同一类型的权连接权值相同。
有了上面的理解,前向传播算法其实非常简单,对于t时刻:
                                       h ^{(t)} =\phi (Ux^{(t)} +Wh^{(t-1)} +b)

其中\phi ()为激活函数,一般来说会选择tanh函数,b为偏置。
t时刻的输出就更为简单:
                                                     o^{(t)} =Vh ^{(t)} +c
最终模型的预测输出为:
                                                          \hat y^{(t)} =\sigma (o^{(t)} )
其中\sigma为激活函数,通常RNN用于分类,故这里一般用softmax函数。

2. 实验代码

2.1. 搭建一个只有一层RNN和Dense网络的模型。

def simple_rnn_layer():# Create a dense layer with 10 output neurons and input shape of (None, 20)model = Sequential()model.add(SimpleRNN(units=3, input_shape=(3, 2),))  # 3 units in the RNN layer, input_shape=(timesteps, features)model.add(Dense(1))  # Output layer with one neuron# Print the summary of the dense layerprint(model.summary())
if __name__ == '__main__':simple_rnn_layer()

输出

Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================simple_rnn (SimpleRNN)      (None, 3)                 18        dense (Dense)               (None, 1)                 4         =================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________
None

2.2. 验证RNN里的逻辑

写代码验证这个过程,看看结果是不是一样的。

import keras.optimizers.optimizer
import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense
def change_weight():# Create a simple Dense layerrnn_layer = SimpleRNN(units=3, input_shape=(3, 2), activation=None, return_sequences=True)# Simulate input data (batch size of 1 for demonstration)input_data = np.array([[[1.0, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]])# Pass the input data through the layer to initialize the weights and biases_ = rnn_layer(input_data)# Access the weights and biases of the dense layerkernel, recurrent_kernel, biases = rnn_layer.get_weights()# Print the initial weights and biasesprint("recurrent_kernel:", recurrent_kernel) # (3,3)print('kernal:',kernel) #(2,3)print('biase: ',biases) # (3)kernel = np.array([[1, 0, 2], [2, 1, 3]])recurrent_kernel = np.array([[1, 2, 1.0], [1, 0, 1], [0, 1, 0]])biases = np.array([0, 0, 1.0])rnn_layer.set_weights([kernel, recurrent_kernel, biases])print(rnn_layer.get_weights())test_data = np.array([[[1.0, 3], [1, 1], [2, 3]]])output = rnn_layer(test_data)print(output)if __name__ == '__main__':change_weight()

输出结果如下:可以看到结果是我手算的是一致的。

recurrent_kernel: [[ 0.06973135  0.40464386  0.9118119 ][ 0.6186313  -0.7345941   0.27868783][ 0.7825809   0.5446422  -0.3015495 ]]
kernal: [[-0.48868906  0.52718353 -0.08321357][-1.0569452  -0.9872779   0.72809434]]
biase:  [0. 0. 0.]
[array([[1., 0., 2.],[2., 1., 3.]], dtype=float32), array([[1., 2., 1.],[1., 0., 1.],[0., 1., 0.]], dtype=float32), array([0., 0., 1.], dtype=float32)]
tf.Tensor(
[[[ 7.  3. 12.][13. 27. 16.][48. 45. 54.]]], shape=(1, 3, 3), dtype=float32)

2.3 代码实现一个简单的例子

import keras.optimizers.optimizer
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense# Sample sequential data
# Each sequence has three timesteps, and each timestep has two features
data = np.array([[[1, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]
])print('data.shape= ',data.shape)
# Define the RNN model
model = Sequential()
model.add(SimpleRNN(units=4, input_shape=(3, 2), name="simpleRNN"))  # 4 units in the RNN layer, input_shape=(timesteps, features)
model.add(Dense(1, name= "output"))  # Output layer with one neuron# Compile the model
model.compile(loss='mse', optimizer=keras.optimizers.Adam(learning_rate=0.01))# Print the model summary
model.summary()before_RNN_weight = model.get_layer("simpleRNN").get_weights()
print('before train ', before_RNN_weight)# Train the model
model.fit(data, np.array([[10], [20], [30]]), epochs=2000, verbose=1)RNN_weight = model.get_layer("simpleRNN").get_weights()
print('after train ', len(RNN_weight),)for i in range(len(RNN_weight)):print('====',RNN_weight[i].shape, RNN_weight[i])# Make predictions
predictions = model.predict(data)
print("Predictions:", predictions.flatten())

代码输出

data.shape=  (3, 3, 2)
Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================simpleRNN (SimpleRNN)       (None, 4)                 28        output (Dense)              (None, 1)                 5         =================================================================
Total params: 33
Trainable params: 33
Non-trainable params: 0
_________________________________________________________________
before train  [array([[-0.00466371,  0.53100157,  0.5298798 ,  0.05514288],[-0.08896947,  0.43185067,  0.7861788 , -0.80616236]],dtype=float32), array([[-0.10712242, -0.03620092, -0.02182053, -0.9933471 ],[-0.6549012 , -0.02620655,  0.7532524 ,  0.05503315],[-0.01986913,  0.9989996 ,  0.02001702, -0.03470401],[-0.74781984,  0.00159313, -0.657065  ,  0.09502006]],dtype=float32), array([0., 0., 0., 0.], dtype=float32)]
2023-08-05 16:02:44.111298: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
Epoch 1/2000
....
Epoch 1999/2000
1/1 [==============================] - 0s 11ms/step - loss: 0.0071
Epoch 2000/2000
1/1 [==============================] - 0s 13ms/step - loss: 0.0070
after train  3
==== (2, 4) [[ 0.27645147  0.6025058   1.6083356  -0.38382724][ 0.11586202  0.32901326  1.4760928  -1.2268958 ]]
==== (4, 4) [[-0.99628973 -2.444563    1.7412992  -1.5265529 ][ 0.80340594  0.9488743   2.44552    -0.7439341 ][-0.1827681  -1.3091801   1.547736   -0.6644555 ][-0.5724374   2.3090494  -2.1779017   0.35992467]]
==== (4,) [-0.40184066 -1.2391611   0.33460653 -0.29144585]
1/1 [==============================] - 0s 78ms/step
Predictions: [10.000422 19.999924 29.85534 ]

相关文章:

边写代码边学习之RNN

1. 什么是 RNN 循环神经网络(Recurrent Neural Network,RNN)是一种以序列数据为输入来进行建模的深度学习模型,它是 NLP 中最常用的模型。其结构如下图: x是输入,h是隐层单元,o为输出&#xff…...

在linux调试进程PID的方法

当我们谈论调试 PID(进程标识符)时,我们通常是指诊断和解决与操作系统中的特定进程相关的问题。有许多工具和方法可用于调试 PID,以下是一些常见的方法: 1. 使用ps命令 ps命令是最基本的调试工具,用于查看…...

【并发编程】线程安全的栈容器

std::stack容器的接口包括 empty(), size(), top(), push(), pop()等。 问题 其原接口在多线程的情况下,会持续很多问题。 例如,在std::stack容器的接口中,在多线程下应用时,empty()和size()的结果是不可信的。因为尽管在某线程…...

ES嵌套查询和普通查询的高亮显示区别

在 Elasticsearch 中,高亮显示是一种强大的搜索结果可视化工具,它可以帮助我们快速识别匹配的关键字或短语。在ES中,我们可以使用两种不同的查询方式来实现高亮显示:嵌套查询和普通查询。本文探讨这两种查询方式的高亮显示区别以及…...

Greenplum集群部署

一,安装说明 1.1环境说明 *名称**版本*操作系统CentOS 7.6 64bitgreenplumgreenplum-db-6.10.1-rhel7-x86_64.rpm1.2集群介绍 IPhostname集群节点10.240.3.244gpmastermaster10.240.3.245gpsegment1segment10.240.3.246gpsegment2segment二,安装环境准备 2.1 修改各节点名称…...

电教智能云数据可视化平台开发电能优化日志实录

电教智能云数据可视化平台开发电脑优化日志实录 一、2K和4K弹窗判断二、电能API对接1.电脑爬虫2.电能分组过滤3.数据可视化渲染4.弹窗 三.数组按顺序输出 一、2K和4K弹窗判断 {* 判断2k和4k弹窗 *}{if $dataScene[scene_standard] eq 0}<a class"menuBtn subMenu"…...

JSX语法基础总结

题记&#xff1a;首先我们要了解一下jsx是什么&#xff0c;跟js有什么区别&#xff0c;其实就是js的语法糖&#xff0c;加上了xml的语法&#xff0c;使得产生虚拟dom更加的方便&#xff0c;简单说一下&#xff0c;xml就是存储数据的格式&#xff0c;想了解xml的话&#xff0c;可…...

socker套接字

1.打印错误信息 2.socketaddr_in结构体 结构体&#xff1a; &#xff08;部分库代码&#xff09; (宏中的##) 3.manual TCP: SOCK_STREAM &#xff1a; 提供有序地&#xff0c;可靠的&#xff0c;全双工的&#xff0c;基于连接的流式服务 UDP: 面向数据报...

No111.精选前端面试题,享受每天的挑战和学习

文章目录 map和foreach的区别在组件中如何获取vuex的action对象中的属性怎么去获取封装在vuex的某个接口数据有没有抓包过&#xff1f;你如何跟踪某一个特定的请求&#xff1f;比如一个特定的URL&#xff0c;你如何把有关这部分的url数据提取出来&#xff1f;1. 使用网络抓包工…...

【Apollo学习笔记】—— 相机仿真

文章目录 前言相关代码整理 测试实践文件目录包管理BUILD文件以及cyberfile.xml文件源程序BUILD运行结果其他参考CameraOutput channels启动camera驱动启动camera video compression驱动 前言 本文是对Cyber RT的学习记录,文章可能存在不严谨、不完善、有缺漏的部分&#xff0…...

【数据结构】——线性表的相关习题

目录 题型一&#xff08;线性表的存储结构&#xff09;题型二&#xff08;链表的判空&#xff09;题型三&#xff08;单链表的建立&#xff09;题型四&#xff08;顺序表、单链表的插入删除操作&#xff09;题型五&#xff08;双链表的插入删除操作&#xff09;题型六&#xff…...

SpringBoot集成Elasticsearch8.x(8)|(新版本Java API Client的Painless语言脚本script使用)

SpringBoot集成Elasticsearch8.x&#xff08;8&#xff09;|&#xff08;新版本Java API Client的Painless语言脚本script使用&#xff09; 文章目录 SpringBoot集成Elasticsearch8.x&#xff08;8&#xff09;|&#xff08;新版本Java API Client的Painless语言脚本script使用…...

SpringBoot复习:(19)Condition接口和@Conditional注解

Condition接口代码如下&#xff1a; public interface Condition {boolean matches(ConditionContext context, AnnotatedTypeMetadata metadata);}它是一个函数式接口&#xff0c;只有一个方法matches用来表示条件是否满足。matches方法中的ConditionContext类对象context可以…...

K8s中的Controller

Controller的作用 &#xff08;1&#xff09;确保预期的pod副本数量 &#xff08;2&#xff09;无状态应用部署 &#xff08;3&#xff09;有状态应用部署 &#xff08;4&#xff09;确保所有的node运行同一个pod&#xff0c;一次性任务和定时任务 1.无状态和有状态 无状态&…...

【MFC】03.常用复杂控件的使用-笔记

热键&#xff1a; 对话框-类向导&#xff1a;初始化函数中&#xff0c;热键需要在最开始的时候就注册进去&#xff1a; 注册热键&#xff1a; 在这之前&#xff0c;先去定义一个宏&#xff0c;代表你这个快捷键。 参数&#xff1a;窗口句柄&#xff0c;热键编号&#xff08;热…...

Autosar诊断实战系列14-NRC优先级解析

本文框架 前言1. NRC分类2. NRC优先级判断2.1. NRC优先级判断逻辑介绍2.2 NRC测试注意事项前言 在本系列笔者将结合工作中对诊断实战部分的应用经验进一步介绍常用UDS服务的进一步探讨及开发中注意事项, Dem/Dcm/CanTp/Fim模块配置开发及注意事项,诊断与BswM/NvM关联模块的应…...

《向量数据库指南》——腾讯云向量数据库Tencent Cloud VectorDB产品特性,架构和应用场景

腾讯云向量数据库(Tencent Cloud VectorDB)是一款全托管的自研企业级分布式数据库服务,专用于存储、检索、分析多维向量数据。该数据库支持多种索引类型和相似度计算方法,单索引支持 10 亿级向量规模,可支持百万级 QPS 及毫秒级查询延迟。腾讯云向量数据库不仅能为大模型提…...

xcode 的app工程与ffmpeg 4.4版本的静态库联调,ffmpeg内下的断点无法暂停。

先阐述一下我的业务场景&#xff0c;我有一个iOS的app sdk项目&#xff0c;下面简称 A &#xff0c;以及运行 A 的 app 项目&#xff0c;简称 A demo 。 引用关系为 A demo 引用了 A &#xff0c;而 A 引用了 ffmpeg 的静态库&#xff08;.a文件&#xff09;。此时业务出现了 b…...

机器学习06 数据准备-(利用 scikit-learn基于Pima Indian数据集作 数据特征选定)

什么是数据特征选定? 数据特征选定&#xff08;Feature Selection&#xff09;是指从原始数据中选择最相关、最有用的特征&#xff0c;用于构建机器学习模型。特征选定是机器学习流程中非常重要的一步&#xff0c;它直接影响模型的性能和泛化能力。通过选择最重要的特征&#…...

机器学习-特征选择:如何使用Lassco回归精确选择最佳特征?

一、引言 特征选择在机器学习领域中扮演着至关重要的角色&#xff0c;它能够从原始数据中选择最具信息量的特征&#xff0c;提高模型性能、减少过拟合&#xff0c;并加快模型训练和预测的速度。在大规模数据集和高维数据中&#xff0c;特征选择尤为重要&#xff0c;因为不必要的…...

XML Group端口详解

在XML数据映射过程中&#xff0c;经常需要对数据进行分组聚合操作。例如&#xff0c;当处理包含多个物料明细的XML文件时&#xff0c;可能需要将相同物料号的明细归为一组&#xff0c;或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码&#xff0c;增加了开…...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 &#xff08;一&#xff09;实时滤波与参数调整 基础滤波操作 60Hz 工频滤波&#xff1a;勾选界面右侧 “60Hz” 复选框&#xff0c;可有效抑制电网干扰&#xff08;适用于北美地区&#xff0c;欧洲用户可调整为 50Hz&#xff09;。 平滑处理&…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

3.3.1_1 检错编码(奇偶校验码)

从这节课开始&#xff0c;我们会探讨数据链路层的差错控制功能&#xff0c;差错控制功能的主要目标是要发现并且解决一个帧内部的位错误&#xff0c;我们需要使用特殊的编码技术去发现帧内部的位错误&#xff0c;当我们发现位错误之后&#xff0c;通常来说有两种解决方案。第一…...

TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案

一、TRS收益互换的本质与业务逻辑 &#xff08;一&#xff09;概念解析 TRS&#xff08;Total Return Swap&#xff09;收益互换是一种金融衍生工具&#xff0c;指交易双方约定在未来一定期限内&#xff0c;基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...

04-初识css

一、css样式引入 1.1.内部样式 <div style"width: 100px;"></div>1.2.外部样式 1.2.1.外部样式1 <style>.aa {width: 100px;} </style> <div class"aa"></div>1.2.2.外部样式2 <!-- rel内表面引入的是style样…...

Spring数据访问模块设计

前面我们已经完成了IoC和web模块的设计&#xff0c;聪明的码友立马就知道了&#xff0c;该到数据访问模块了&#xff0c;要不就这俩玩个6啊&#xff0c;查库势在必行&#xff0c;至此&#xff0c;它来了。 一、核心设计理念 1、痛点在哪 应用离不开数据&#xff08;数据库、No…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

Golang——7、包与接口详解

包与接口详解 1、Golang包详解1.1、Golang中包的定义和介绍1.2、Golang包管理工具go mod1.3、Golang中自定义包1.4、Golang中使用第三包1.5、init函数 2、接口详解2.1、接口的定义2.2、空接口2.3、类型断言2.4、结构体值接收者和指针接收者实现接口的区别2.5、一个结构体实现多…...

Linux中《基础IO》详细介绍

目录 理解"文件"狭义理解广义理解文件操作的归类认知系统角度文件类别 回顾C文件接口打开文件写文件读文件稍作修改&#xff0c;实现简单cat命令 输出信息到显示器&#xff0c;你有哪些方法stdin & stdout & stderr打开文件的方式 系统⽂件I/O⼀种传递标志位…...