当前位置: 首页 > news >正文

循环神经网络的变体模型-LSTM、GRU

一.LSTM(长短时记忆网络)

1.1基本介绍

长短时记忆网络(Long Short-Term Memory,LSTM)是一种深度学习模型,属于循环神经网络(Recurrent Neural Network,RNN)的一种变体。LSTM的设计旨在解决传统RNN中遇到的长序列依赖问题,以更好地捕捉和处理序列数据中的长期依赖关系。

下面是LSTM的内部结构图

LSTM

LSTM为了改善梯度消失,引入了一种特殊的存储单元,该存储单元被设计用于存储和提取长期记忆。与传统的RNN不同,LSTM包含三个关键的门(gate)来控制信息的流动,这些门分别是遗忘门(Forget Gate)、输入门(Input Gate)和输出门(Output Gate)。

LSTM的结构允许它有效地处理和学习序列中的长期依赖关系,这在许多任务中很有用,如自然语言处理、语音识别和时间序列预测。由于其能捕获长期记忆,LSTM成为深度学习中重要的组件之一。

1.2 主要组成部分和工作原理

首先我们先弄明白LSTM单元中的每个符号的含义。每个黄色方框表示一个神经网络层,由权值,偏置以及激活函数组成;每个粉色圆圈表示元素级别操作;箭头表示向量流向;相交的箭头表示向量的拼接;分叉的箭头表示向量的复制。
图中元素的节点信息

以下是LSTM的主要组成部分和工作原理:

  1. 细胞状态(Cell State):
    细胞状态是LSTM网络的主要存储单元,用于存储和传递长期记忆。细胞状态在序列的每一步都会被更新。在LSTM中,细胞状态负责保留网络需要记住的信息,以便更好地处理长期依赖关系。在每个时间步,LSTM通过一系列的操作来更新细胞状态。这些操作包括遗忘门、输入门和输出门的计算。细胞状态在这些门的帮助下动态地保留和遗忘信息。
    细胞状态

  2. 遗忘门(Forget Gate):
    遗忘门决定哪些信息应该被遗忘,从而允许网络丢弃不重要的信息。它通过一个sigmoid激活函数生成一个介于0和1之间的值,用于控制细胞状态中信息的丢失程度。
    遗忘门的计算过程如下:
    2.1 输入:
    上一时刻的隐藏状态(或者是输入数据的向量)
    当前时刻的输入数据
    2.2 计算遗忘门的值:
    将上一时刻的隐藏状态和当前时刻的输入数据拼接在一起。
    通过一个带有sigmoid激活函数的全连接层(通常称为遗忘门层)得到介于0和1之间的值。
    这个值表示细胞状态中哪些信息应该被保留(接近1),哪些信息应该被遗忘(接近0)。
    2.3 遗忘操作:
    将上一时刻的细胞状态与遗忘门的输出相乘,以决定保留哪些信息。
    2.4数学表达式如下:
    遗忘门的输出:
    遗忘门

其中:
W f 和 b f 是遗忘门的权重矩阵和偏置向量。 W_f 和 b_f是遗忘门的权重矩阵和偏置向量。 Wfbf是遗忘门的权重矩阵和偏置向量。
h t − 1 ​是上一时刻的隐藏状态。 h_{t−1}​ 是上一时刻的隐藏状态。 ht1是上一时刻的隐藏状态。
x t 是当前时刻的输入数据。 x_t是当前时刻的输入数据。 xt是当前时刻的输入数据。
σ 是 s i g m o i d 激活函数。 σ 是sigmoid激活函数。 σsigmoid激活函数。

遗忘门的输出 ft 决定了细胞状态中上一时刻信息的保留程度。这个机制允许LSTM网络在处理时间序列数据时更有效地记住长期依赖关系。

  1. 输入门(Input Gate):
    输入门负责确定在当前时间步骤中要添加到细胞状态的新信息。类似于遗忘门,输入门使用sigmoid激活函数产生一个介于0和1之间的值,表示要保留多少新信息,并使用tanh激活函数生成一个新的候选值。
    在这里插入图片描述输入门的计算过程如下:
(1)输入门的输出计算:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。这个值表示要保留的新信息的程度。
(2)生成新的候选值:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有tanh激活函数的全连接层得到一个新的候选值(介于-1和1之间)。
(3)更新细胞状态的操作:将输入门的输出与新的候选值相乘,得到要添加到细胞状态的新信息。
  1. 输出门(Output Gate):
    输出门(Output Gate)在LSTM中控制细胞在特定时间步上的输出。输出门使用sigmoid激活函数产生介于0和1之间的值,这个值决定了在当前时间步细胞状态中有多少信息被输出。同时,输出门的输出与细胞状态经过tanh激活函数后的值相乘,产生最终的LSTM输出。

输出门的计算过程如下:

输出门的输出计算:将上一时刻的隐藏状态(或者是输入数据)和当前时刻的输入数据拼接在一起。通过一个带有sigmoid激活函数的全连接层得到介于0和1之间的值。这个值表示在当前时间步细胞状态中有多少信息要输出。
生成最终的LSTM输出:将当前时刻的细胞状态经过tanh激活函数,得到介于-1和1之间的值。将输出门的输出与tanh激活函数的细胞状态相乘,产生最终的LSTM输出。

在这里插入图片描述

1.3 LSTM的基础代码实现

以下是一个基础的实现,其中包括多层双向LSTM的前向传播。请注意,这个实现仍然是一个简化版本,实际应用中可能需要更多的调整和优化。

import numpy as npdef sigmoid(x):return 1 / (1 + np.exp(-x))def tanh(x):return np.tanh(x)def lstm_cell(xt, a_prev, c_prev, parameters):# 从参数中提取权重和偏置Wf = parameters["Wf"]bf = parameters["bf"]Wi = parameters["Wi"]bi = parameters["bi"]Wo = parameters["Wo"]bo = parameters["bo"]Wc = parameters["Wc"]bc = parameters["bc"]# 合并输入和上一个时间步的隐藏状态concat = np.concatenate((a_prev, xt), axis=0)# 遗忘门ft = sigmoid(np.dot(Wf, concat) + bf)# 输入门it = sigmoid(np.dot(Wi, concat) + bi)# 更新细胞状态cct = tanh(np.dot(Wc, concat) + bc)c_next = ft * c_prev + it * cct# 输出门ot = sigmoid(np.dot(Wo, concat) + bo)# 更新隐藏状态a_next = ot * tanh(c_next)# 保存计算中间结果,以便反向传播cache = (xt, a_prev, c_prev, a_next, c_next, ft, it, ot, cct)return a_next, c_next, cachedef lstm_forward(x, a0, parameters):n_x, m, T_x = x.shapen_a = a0.shape[0]a = np.zeros((n_a, m, T_x))c = np.zeros_like(a)caches = []a_prev = a0c_prev = np.zeros_like(a_prev)for t in range(T_x):xt = x[:, :, t]a_next, c_next, cache = lstm_cell(xt, a_prev, c_prev, parameters)a[:,:,t] = a_nextc[:,:,t] = c_nextcaches.append(cache)a_prev = a_nextc_prev = c_nextreturn a, c, cachesdef lstm_model_forward(x, parameters):caches = []a = xc_list = []for layer in parameters:a, c, layer_cache = lstm_forward(a, np.zeros_like(a[:, :, 0]), layer)caches.append(layer_cache)c_list.append(c)return a, c_list, cachesdef dense_layer_forward(a, parameters):W = parameters["W"]b = parameters["b"]z = np.dot(W, a) + ba_next = sigmoid(z)return a_next, zdef model_forward(x, parameters_lstm, parameters_dense):a_lstm, c_list, caches_lstm = lstm_model_forward(x, parameters_lstm)a_dense = a_lstm[:, :, -1]z_dense_list = []for layer_dense in parameters_dense:a_dense, z_dense = dense_layer_forward(a_dense, layer_dense)z_dense_list.append(z_dense)return a_dense, c_list, caches_lstm, z_dense_list# 示例数据和参数
np.random.seed(1)
x = np.random.randn(10, 5, 3)  # 10个样本,每个样本5个时间步,每个时间步3个特征# LSTM参数
parameters_lstm = [{"Wf": np.random.randn(5, 8), "bf": np.random.randn(5, 1),"Wi": np.random.randn(5, 8), "bi": np.random.randn(5, 1),"Wo": np.random.randn(5, 8), "bo": np.random.randn(5, 1),"Wc": np.random.randn(5, 8), "bc": np.random.randn(5, 1)},{"Wf": np.random.randn(3, 8), "bf": np.random.randn(3, 1),"Wi": np.random.randn(3, 8), "bi": np.random.randn(3, 1),"Wo": np.random.randn(3, 8), "bo": np.random.randn(3, 1),"Wc": np.random.randn(3, 8), "bc": np.random.randn(3, 1)}
]# Dense层参数
parameters_dense = [{"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)},{"W": np.random.randn(1, 5), "b": np.random.randn(1, 1)}
]# 进行正向传播
a_dense, c_list, caches_lstm, z_dense_list = model_forward(x, parameters_lstm, parameters_dense)# 打印输出形状
print("a_dense.shape:", a_dense.shape)

二.GRU(门控循环单元)

GRU

2.1 GRU的基本介绍

门控循环单元(GRU,Gated Recurrent Unit)是一种用于处理序列数据的循环神经网络(RNN)变体,旨在解决传统RNN中的梯度消失问题,并提供更好的长期依赖建模。GRU引入了门控机制,类似于LSTM,但相对于LSTM,GRU结构更加简单。

GRU包含两个门:更新门(Update Gate)和重置门(Reset Gate)。这两个门允许GRU网络决定在当前时间步更新细胞状态的程度以及如何利用先前的隐藏状态。

重置门(Reset Gate)的计算:

通过一个sigmoid激活函数计算重置门的输出。重置门决定了在当前时间步,应该忽略多少先前的隐藏状态信息。

更新门(Update Gate)的计算:

通过一个sigmoid激活函数计算更新门的输出。更新门决定了在当前时间步,应该保留多少先前的隐藏状态信息。

候选隐藏状态的计算:

通过tanh激活函数计算一个候选的隐藏状态。

新的隐藏状态的计算:

通过更新门和候选隐藏状态计算新的隐藏状态。

2.2 GRU的代码实现

以下是使用PyTorch库实现基本的门控循环单元(GRU)的代码。PyTorch提供了GRU的高级API,可以轻松实现和使用。下面是一个简单的例子:

import torch
import torch.nn as nn# 定义GRU模型
class SimpleGRU(nn.Module):def __init__(self, input_size, hidden_size):super(SimpleGRU, self).__init__()self.gru = nn.GRU(input_size, hidden_size)def forward(self, x, hidden=None):output, hidden = self.gru(x, hidden)return output, hidden# 示例数据和模型参数
input_size = 3
hidden_size = 5
seq_len = 1  # 序列长度
batch_size = 1# 创建GRU模型
gru_model = SimpleGRU(input_size, hidden_size)# 将输入数据转换为PyTorch的Tensor
x = torch.randn(seq_len, batch_size, input_size)# 前向传播
output, hidden = gru_model(x)# 打印输出形状
print("Output shape:", output.shape)
print("Hidden shape:", hidden.shape)

以下是使用NumPy库实现基本的门控循环单元(GRU)的代码。这个实现是一个简化版本,其中包含更新门和重置门的计算,以及候选隐藏状态和新的隐藏状态的计算。

import numpy as npdef sigmoid(x):return 1 / (1 + np.exp(-x))def tanh(x):return np.tanh(x)def gru_cell(a_prev, x, parameters):# 从参数中提取权重和偏置W_r = parameters["W_r"]b_r = parameters["b_r"]W_z = parameters["W_z"]b_z = parameters["b_z"]W_a = parameters["W_a"]b_a = parameters["b_a"]# 计算重置门r_t = sigmoid(np.dot(W_r, np.concatenate([a_prev, x])) + b_r)# 计算更新门z_t = sigmoid(np.dot(W_z, np.concatenate([a_prev, x])) + b_z)# 计算候选隐藏状态tilde_a_t = tanh(np.dot(W_a, np.concatenate([r_t * a_prev, x])) + b_a)# 计算新的隐藏状态a_t = (1 - z_t) * a_prev + z_t * tilde_a_t# 保存计算中间结果,以便反向传播cache = (a_prev, x, r_t, z_t, tilde_a_t, a_t)return a_t, cache# 示例数据和参数
np.random.seed(1)
a_prev = np.random.randn(5, 1)  # 上一时刻的隐藏状态
x = np.random.randn(3, 1)  # 当前时刻的输入数据# GRU参数
parameters = {"W_r": np.random.randn(5, 8),"b_r": np.random.randn(5, 1),"W_z": np.random.randn(5, 8),"b_z": np.random.randn(5, 1),"W_a": np.random.randn(5, 8),"b_a": np.random.randn(5, 1)
}# 单个GRU单元的前向传播
a_t, cache = gru_cell(a_prev, x, parameters)# 打印输出形状
print("a_t.shape:", a_t.shape)

本文参考了以下链接:http://colah.github.io/posts/2015-08-Understanding-LSTMs/

相关文章:

循环神经网络的变体模型-LSTM、GRU

一.LSTM(长短时记忆网络) 1.1基本介绍 长短时记忆网络(Long Short-Term Memory,LSTM)是一种深度学习模型,属于循环神经网络(Recurrent Neural Network,RNN)的一种变体。…...

视频图像的color range简介

介绍 研究FFmpeg发现,在avcodec.h中有关于color的解释,主要有四个属性,primaries、transfer、space和range。 color primaries: 基于RGB空间对应的绝对颜色XYZ的变换,决定了最终三原色RGB分别是什么颜色;…...

tcp的三次握手

http 和 https 都是是基于 TCP 的请求,https 是 http 加上 tls 连接。TCP 是面向连接的协议。 对于 http1.1 协议chrome 限制在同一个域名下最多可以建立 6 个 tcp 连接,所以如果在同一个域名下,同时有超过 6 个请求发生,那么多余…...

unity 矩阵探究

public void MatrixTest1(){ ///Matrix4x4 是列矩阵,就是一个vector4表示一列,所以在c#中矩阵和Vector4只能矩阵右乘坐标。但是在shader中是矩阵左乘坐标,所以在shader中是行矩阵 Matrix4x4 moveMatrix1 new Matrix4x4(new Vector4(1,0,0,0)…...

MySQL---单表查询综合练习

创建emp表 CREATE TABLE emp( empno INT(4) NOT NULL COMMENT 员工编号, ename VARCHAR(10) COMMENT 员工名字, job VARCHAR(10) COMMENT 职位, mgr INT(4) COMMENT 上司, hiredate DATE COMMENT 入职时间, sal INT(7) COMMENT 基本工资, comm INT(7) COMMENT 补贴, deptno INT…...

Python项目——搞怪小程序(PySide6+Pyinstaller)

1、介绍 使用python编写一个小程序,回答你是猪吗。 点击“是”提交,弹窗并退出。 点击“不是”提交,等待5秒,重新选择。 并且隐藏了关闭按钮。 2、实现 新建一个项目。 2.1、设计UI 使用Qt designer设计一个UI界面&#xff0c…...

MySQL练习题

参考:https://blog.csdn.net/paul0127/article/details/82529216 数据表介绍 --1.学生表 Student(SId,Sname,Sage,Ssex) --SId 学生编号,Sname 学生姓名,Sage 出生年月,Ssex 学生性别 --2.课程表 Course(CId,Cname,TId) --CId 课程编号,Cname 课程名称,TId 教师编号…...

vue-项目打包、配置路由懒加载

1. 简介 在现代前端开发中,Vue.js因其简洁、灵活和高效的特点,已经成为许多开发者的首选框架。 在Vue项目中,打包部署和路由懒加载是两个非常重要的环节。 打包Vue项目是为了将源代码转换为浏览器可以解析的JavaScript文件,以便…...

词语的魔力:语言在我们生活中的艺术与影响

Words That Move Mountains: The Art and Impact of Language in Our Lives 词语的魔力:语言在我们生活中的艺术与影响 Hello there, wonderful people! Today, I’d like to gab about the magical essence of language that’s more than just a chatty tool in o…...

android List,Set,Map区别和介绍

List 元素存放有序,元素可重复 1.LinkedList 链表,插入删除,非线性安全,插入和删除操作是双向链表操作,增加删除快,查找慢 add(E e)//添加元素 addFirst(E e)//向集合头部添加元素 addList(E e)//向集合…...

Mysql 编译安装部署

Mysql 编译安装部署 环境: 172.20.26.198(Centos7.6) 源码安装Mysql-5.7 大概步骤如下: 1、上传mysql-5.7.28.tar.gz 、boost_1_59_0.tar 到/usr/src 目录下 2、安装依赖 3、cmake 4、make && make install 5、…...

【目标检测】YOLOv5算法实现(九):模型预测

本系列文章记录本人硕士阶段YOLO系列目标检测算法自学及其代码实现的过程。其中算法具体实现借鉴于ultralytics YOLO源码Github,删减了源码中部分内容,满足个人科研需求。   本系列文章主要以YOLOv5为例完成算法的实现,后续修改、增加相关模…...

centos宝塔远程服务器怎么链接?

要远程连接CentOS宝塔服务器,可以按照以下步骤操作: 打开终端或远程连接工具,比如PuTTY。输入服务器的IP地址和SSH端口号(默认为22),点击连接。输入用户名和密码进行登录。 如果你已经安装了宝塔面板&…...

C语言练习day8

变种水仙花 变种水仙花_牛客题霸_牛客网 题目: 思路:我们拿到题目的第一步可以先看一看题目给的例子,1461这个数被从中间拆成了两部分:1和461,14和61,146和1,不知道看到这大家有没有觉得很熟…...

蓝凌OA-sysuicomponent-任意文件上传_exp-漏洞复现

0x01阅读须知 技术文章仅供参考,此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的…...

C#,入门教程(38)——大型工程软件中类(class)修饰词partial的使用方法

上一篇: C#,入门教程(37)——优秀程序员的修炼之道https://blog.csdn.net/beijinghorn/article/details/125011644 一、大型(工程应用)软件倚重 partial 先说说大型(工程应用)软件对源代码的文件及函数“…...

C++播放音乐:使用EGE图形库

——开胃菜,闲话篓子一大片 最近,我发现ege图形库不是个正经的图形库—— 那天,我又在打趣儿地翻代码时,无意间看到了这个: 图形库?!你哪来的音乐(Music)呢&#xff1f…...

C++中const和constexpr的区别:了解常量的不同用法

C中const和constexpr的区别 一、C中的常量概念二、const关键字的用法和特点三、constexpr关键字的用法和特点四、const和constexpr的区别对比4.1、编译时计算能力4.2、可以赋值的范围4.3、对类和对象的适用性4.4、对函数的适用性4.5、性能和效率的差异 五、使用示例六、总结 一…...

高级架构师是如何设计一个系统的?

架构师如何设计系统? 系统拆分 通过DDD领域模型,对服务进行拆分,将一个系统拆分为多个子系统,做成SpringCloud的微服务。微服务设计时要尽可能做到少扇出,多扇入,根据服务器的承载,进行客户端负…...

力扣:474. 一和零(动态规划)(01背包)

题目: 给你一个二进制字符串数组 strs 和两个整数 m 和 n 。 请你找出并返回 strs 的最大子集的长度,该子集中 最多 有 m 个 0 和 n 个 1 。 如果 x 的所有元素也是 y 的元素,集合 x 是集合 y 的 子集 。 示例 1: 输入&#…...

【网络】每天掌握一个Linux命令 - iftop

在Linux系统中,iftop是网络管理的得力助手,能实时监控网络流量、连接情况等,帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...

【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器

一.自适应梯度算法Adagrad概述 Adagrad(Adaptive Gradient Algorithm)是一种自适应学习率的优化算法,由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率,适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中,我们可能会遇到一些流式数据处理的场景,比如接收来自上游接口的 Server-Sent Events(SSE) 或 流式 JSON 内容,并将其原样中转给前端页面或客户端。这种情况下,传统的 RestTemplate 缓存机制会…...

shell脚本--常见案例

1、自动备份文件或目录 2、批量重命名文件 3、查找并删除指定名称的文件: 4、批量删除文件 5、查找并替换文件内容 6、批量创建文件 7、创建文件夹并移动文件 8、在文件夹中查找文件...

现代密码学 | 椭圆曲线密码学—附py代码

Elliptic Curve Cryptography 椭圆曲线密码学(ECC)是一种基于有限域上椭圆曲线数学特性的公钥加密技术。其核心原理涉及椭圆曲线的代数性质、离散对数问题以及有限域上的运算。 椭圆曲线密码学是多种数字签名算法的基础,例如椭圆曲线数字签…...

从零实现STL哈希容器:unordered_map/unordered_set封装详解

本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说&#xff0c;直接开始吧&#xff01; 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...

【开发技术】.Net使用FFmpeg视频特定帧上绘制内容

目录 一、目的 二、解决方案 2.1 什么是FFmpeg 2.2 FFmpeg主要功能 2.3 使用Xabe.FFmpeg调用FFmpeg功能 2.4 使用 FFmpeg 的 drawbox 滤镜来绘制 ROI 三、总结 一、目的 当前市场上有很多目标检测智能识别的相关算法&#xff0c;当前调用一个医疗行业的AI识别算法后返回…...

sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!

简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求&#xff0c;并检查收到的响应。它以以下模式之一…...

LRU 缓存机制详解与实现(Java版) + 力扣解决

&#x1f4cc; LRU 缓存机制详解与实现&#xff08;Java版&#xff09; 一、&#x1f4d6; 问题背景 在日常开发中&#xff0c;我们经常会使用 缓存&#xff08;Cache&#xff09; 来提升性能。但由于内存有限&#xff0c;缓存不可能无限增长&#xff0c;于是需要策略决定&am…...

LangFlow技术架构分析

&#x1f527; LangFlow 的可视化技术栈 前端节点编辑器 底层框架&#xff1a;基于 &#xff08;一个现代化的 React 节点绘图库&#xff09; 功能&#xff1a; 拖拽式构建 LangGraph 状态机 实时连线定义节点依赖关系 可视化调试循环和分支逻辑 与 LangGraph 的深…...