当前位置: 首页 > news >正文

LSTM的变体

一、GRU

1、什么是GRU

门控循环单元(GRU)是一种循环神经网络(RNN)的变体,它通过引入门控机制来控制信息的流动,从而有效地解决了传统RNN中的梯度消失问题。GRU由Cho等人在2014年提出,它简化了LSTM的结构,将遗忘门和输入门合并为一个更新门,并增加了一个重置门,同时合并了单元状态和隐藏状态,使得模型更加简洁,训练速度更快,且在性能上与LSTM相当。

2、GRU的核心

核心在于两个门:更新门(update gate)和重置门(reset gate)。更新门控制着从前一时刻的状态信息中保留多少到当前状态,而重置门决定着前一状态有多少信息被写入到当前的候选集中。这种结构使得GRU在处理长序列数据时能够更好地捕捉长期依赖关系,同时减少了模型参数,提高了计算效率。

3、GRU的应用

应用非常广泛,包括但不限于自然语言处理(NLP)、语音识别、图像处理等领域。在NLP领域,GRU可以用于语言建模、情感分析、机器翻译等任务;在语音识别领域,GRU可以用于语音信号的特征提取和识别;在图像处理领域,GRU可以用于图像分类、目标检测等任务。GRU的简洁性和效率使其在处理大规模序列数据时具有优势。

在选择GRU和LSTM时,通常考虑的因素包括任务的复杂性、数据集的大小以及训练资源。由于GRU参数更少,收敛速度更快,因此在需要快速迭代和实验时,GRU通常是首选。然而,在某些需要对复杂序列依赖关系进行建模的任务中,LSTM可能会表现得更好。

总的来说,GRU是一种强大的循环神经网络架构,它通过引入门控机制来控制信息流,有效地解决了传统RNN的梯度消失问题。GRU的简洁性和效率使其在多种序列建模任务中表现出色,成为了深度学习中处理时序数据的重要工具之一。

4、GRU的工作原理

5、手写代码实现

import numpy as npclass GRU():def __init__(self, input_size, hidden_size):self.input_size = input_sizeself.hidden_size = hidden_size# 初始化参数w和bself.W_z = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)self.b_z = np.zeros(self.hidden_size)# 重置门self.W_r = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)self.b_r = np.zeros(self.hidden_size)# 候选隐藏状态self.W_h = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)self.b_h = np.zeros(self.hidden_size)def tanh(self, x):return np.tanh(x)def sigmoid(self, x):return 1 / (1 + np.exp(-x))def forward(self, x):h_prev = np.zeros((self.hidden_size,))concat_input = np.concatenate([x, h_prev], axis=0)z_t = self.sigmoid(np.dot(self.W_z, concat_input) + self.b_z)r_t = self.sigmoid(np.dot(self.W_r, concat_input) + self.b_r)concat_reset_input = np.concatenate([x, r_t * h_prev], axis=0)h_hat_t = self.tanh(np.dot(self.W_h, concat_reset_input) + self.b_h)h_t = (1 - z_t) * h_prev + z_t * h_hat_treturn h_t

二、BiLSTM

1、什么是BiSTM

BiSTM,即双向门控循环单元(Bidirectional Gated Recurrent Unit),是一种循环神经网络(RNN)的变体。它结合了前向和后向的GRU,能够同时处理过去和未来的信息,从而更好地捕捉序列数据中的上下文关系。

在BiSTM中,数据通过两个GRU网络进行处理:一个从左到右(前向),另一个从右到左(后向)。这两个网络的输出然后被拼接或相加,形成最终的特征表示,这个特征表示包含了序列的双向信息。这种结构特别适合于需要理解序列中前后文信息的任务,如文本分类、语音识别、命名实体识别(NER)等。

2、BiSTM的关键特点包括:

  1. 双向信息捕捉:BiSTM能够同时考虑序列中每个元素之前的和之后的上下文信息,这使得它在处理像文本这样的序列数据时非常有效,因为文本中词汇的含义往往受到其前后词汇的影响。

  2. 门控机制:BiSTM继承了GRU的门控机制,包括更新门和重置门,这些门控单元可以控制信息的流动,从而减少无效或噪声信息的干扰,并增强模型对重要信息的记忆能力。

  3. 应用广泛:BiSTM因其强大的序列处理能力而被广泛应用于各种领域,包括自然语言处理(NLP)、语音识别、时间序列分析等。

  4. 模型性能:在某些任务中,BiSTM能够提供比单向GRU或LSTM更好的性能,尤其是在需要捕捉长期依赖关系的任务中。

  5. 模型复杂度:由于BiSTM包含两个GRU网络,其模型参数和计算复杂度相对于单向GRU或LSTM会有所增加,但在很多情况下,这种增加是值得的,因为它能带来更准确的预测结果

 3、手写BiLSTM代码

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequenceclass LSTM(nn.Module):def __init__(self, vocab_size, target_size, input_size=512, hidden_size=512):super(LSTM, self).__init__()self.hidden_size = hidden_sizeself.embedding = nn.Embedding(vocab_size, input_size)self.mlp = nn.Sequential(nn.Linear(input_size, hidden_size),nn.GELU(),nn.Linear(hidden_size, hidden_size))self.lstm = nn.LSTM(hidden_size, hidden_size * 2, num_layers=3, batch_first=True, dropout=0.5)self.avg_lstm = nn.AdaptiveAvgPool1d(1)self.avg_linear = nn.AdaptiveAvgPool1d(1)self.out_linear = nn.Sequential(nn.Linear(hidden_size * 2 + hidden_size, hidden_size),nn.GELU(),nn.LayerNorm(hidden_size),nn.Linear(hidden_size, target_size))self.norm = nn.LayerNorm(hidden_size * 2)def forward(self, x, lengths):x = self.embedding(x)mlp = self.mlp(x)pached_embed = pack_padded_sequence(mlp, lengths, batch_first=True, enforce_sorted=False)lstm_out, _ = self.lstm(pached_embed)lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)lstm_out = self.norm(lstm_out)avg_lstm = self.avg_lstm(lstm_out.permute(0, 2, 1)).squeeze(-1)avg_linear = self.avg_linear(mlp.permute(0, 2, 1)).squeeze(-1)out = torch.cat([avg_lstm, avg_linear], dim=-1)return self.out_linear(out)class BiLSTM(nn.Module):def __init__(self, input_size=512, hidden_size=512, output_size=512):super(BiLSTM, self).__init__()self.hidden_size = hidden_sizeself.lstm_forward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)self.lstm_backward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)def forward(self, x):out_forward, _ = self.lstm_forward(x)out_backward, _ = self.lstm_backward(torch.flip(x, dims=[1]))out_backward = torch.flip(out_backward, dims=[1])combined_output = torch.cat([out_forward, out_backward], dim=-1)return combined_output

相关文章:

LSTM的变体

一、GRU 1、什么是GRU 门控循环单元(GRU)是一种循环神经网络(RNN)的变体,它通过引入门控机制来控制信息的流动,从而有效地解决了传统RNN中的梯度消失问题。GRU由Cho等人在2014年提出,它简化了…...

LeetCode讲解篇之852. 山脉数组的峰顶索引

文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 我们可以采用二分查找,每次查询区间中点元素与中点下一个元素比较 如果中点元素大于其下一个元素,则表示从中点开始向右是递减趋势,那峰值索引一定小于等于中点,我…...

矿井人员数据集,用于目标检测,深度学习,采用txt打标签,即yolo格式,也有原文件可以自己转换。总共3500张图片的数据量,划分给训练集2446张,

矿井人员数据集,用于目标检测,深度学习,采用txt打标签,即yolo格式,也有原文件可以自己转换。总共3500张图片的数据量,划分给训练集2446张: ### 矿井人员数据集用于目标检测的详细说明 #### 1. …...

消息队列RabbitMQ

文章目录 1. 简介与安装2. 基本概念3. SpringAMQP4. 交换机类型5. 消息转换器5.1 默认转换器5.2 配置JSON转换器 6 生产者的可靠性6.1 生产者超时重连机制6.2 生产者确认机制 6. MQ的可靠性6.1 数据持久化6.2 惰性队列 Lazy Queue 7. 消费者的可靠性7.1 消费者确认机制7.2 失败…...

RabbitMQ概述

什么是MQ MQ (message queue)消息队列 MQ从字⾯意思上看,本质是个队列,FIFO先⼊先出,只不过队列中存放的内容是消息(message).消息可以⾮常简单,⽐如只包含⽂本字符串,JSON等,也可以很复杂,⽐如内嵌对象 RabbitMQ是MQ的一种实现,是Rabbit 企业下的⼀个消息队列产…...

Golang学习路线

以下是一条学习Golang(Go语言)的路线: 一、基础入门 1. 环境搭建 安装Go编译器,在官网(https://golang.org/dl/)下载适合操作系统的安装包并配置好环境变量。 2. 语法学习学习变量、数据类型&#xff08…...

Flink从ck拉起任务脚本

#!/bin/bashAPP_NAME"orderTest"CHECKPOINT_BASE_PATH"hdfs:///jobs/flink/checkpoints/aaa-test/"is_running$(yarn application -list | grep -w "$APP_NAME" | grep -c "RUNNING")if [ $is_running -gt 0 ]; thenecho "应用程…...

GADBench Revisiting and Benchmarking Supervised Graph Anomaly Detection

Neurips 23 推荐指数: #paper/⭐⭐⭐ 领域:图异常检测 胡言乱语: neurips 的benchmark模块的文章总能给人一些启发性的理解,这篇的insight真有意思。个人感兴趣的地方会加粗。此外,这篇文章和腾讯AIlab合作&#xff…...

某象异形滑块99%准确率方案

注意,本文只提供学习的思路,严禁违反法律以及破坏信息系统等行为,本文只提供思路 如有侵犯,请联系作者下架 该文章模型已经上线ocr识别网站,欢迎测试!!,地址:https://yxlocr.windy-rain.cn/ocr/slider/6 所谓的顶象异形滑块,是指没有采用常规的缺口,使用各种形状的…...

CDN绕过学习

1.什么是CDN? CDN就是分布在各个地区的服务器,这些服务器储存着数据的副本。 哪些服务器比较接近你,当你发起请求时,提前就会快速为你提供服务。 总结来说就是: 其实就是用来加速访问的,以及缓解压力&a…...

SpringBoot日常:redission的接入使用和源码解析

文章目录 一、简介二、集成redissionpom文件redission 配置文件application.yml文件启动类 三、JAVA 操作案例字符串操作哈希操作列表操作集合操作有序集合操作布隆过滤器操作分布式锁操作 四、源码解析 一、简介 Redisson 是一个在 Redis 的基础上实现的 Java 驻内存数据网格…...

npm包管理深度探索:从基础到进阶全面教程!

目录 一、npm概述1、npm简介(1)什么是npm?(2)npm的核心功能(3)npm的工作原理(4)npm的优势(5)npm的局限性(6)总结 2、npm的…...

最新免费GPT4O和Midjourney

一、什么是GPT4O? GPT-4 是 OpenAI 研发的大型语言模型。它具有强大的语言理解和生成能力,在自然语言处理等诸多领域有着广泛的应用和表现。 二、什么是Midjourney? Midjourney 是一款人工智能图像生成工具。它可以根据用户输入的描述或提…...

python操作OpenAI教程

python操作OpenAI pip install -U openai代码: from openai import OpenAI# 解决请求超时问题 import os os.environ["http_proxy"] "http://localhost:7890" os.environ["https_proxy"] "http://localhost:7890"# 需要…...

如何版本REST API:综合指南

目录 总则什么是REST API版本控制?为什么API版本控制很重要?如何对REST API进行版本控制 理解API契约评估需求选择版本控制策略沟通变化保持向后兼容性详细文档记录REST API版本控制最佳实践REST API版本控制常见问题:REST API版本控制总则 版本化REST API对于确保软件应用…...

Docker 环境下 Nginx 监控实战:使用 Prometheus 实现 Nginx 性能监控的完整部署指南

Docker 环境下 Nginx 监控实战:使用 Prometheus 实现 Nginx 性能监控的完整部署指南 文章目录 Docker 环境下 Nginx 监控实战:使用 Prometheus 实现 Nginx 性能监控的完整部署指南一 查看模块是否安装二 配置 status 访问端点三 Docker 部署 nginx-prome…...

网络安全-IPv4和IPv6的区别

1. 2409:8c20:6:1135:0:ff:b027:210d。 这是一个IPv6地址。IPv6(互联网协议版本6)是用于标识网络中的设备的一种协议,它可以提供比IPv4更大的地址空间。这个地址由八组十六进制数字组成,每组之间用冒号分隔。IPv6地址通常用于替代…...

【移动端】事件基础

一、移动端事件分类 移动端事件主要分为以下几类: 1. 触摸事件(Touch Events) 触摸事件是移动设备特有的事件,用来处理用户通过触摸屏幕进行的操作。主要的触摸事件有: touchstart:手指触摸屏幕时触发。…...

软件测试比赛-学习

一、环境配置 二、浏览器适配 //1.设置浏览器的位置,google浏览器位置是默认且固定在电脑里的//2.设置浏览器驱动的位置,C:\Users\27743\AppData\Local\Google\Chrome\ApplicationSystem.setProperty("webdriver.chrome.driver", "C:\\Users\\27743\\AppData\\…...

力扣LeetCode-链表中的循环与递归使用

标题做题的时候发现循环与递归的使用差别: 看两道题: 两道题都是不知道链表有多长,所以需要用到循环,用到循环就可以把整个过程分成多个循环体,就是每一次循环要执行的内容。 反转链表: 把null–>1…...

【Python】 -- 趣味代码 - 小恐龙游戏

文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...

【OSG学习笔记】Day 18: 碰撞检测与物理交互

物理引擎(Physics Engine) 物理引擎 是一种通过计算机模拟物理规律(如力学、碰撞、重力、流体动力学等)的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互,广泛应用于 游戏开发、动画制作、虚…...

c++ 面试题(1)-----深度优先搜索(DFS)实现

操作系统:ubuntu22.04 IDE:Visual Studio Code 编程语言:C11 题目描述 地上有一个 m 行 n 列的方格,从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子,但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...

对WWDC 2025 Keynote 内容的预测

借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...

postgresql|数据库|只读用户的创建和删除(备忘)

CREATE USER read_only WITH PASSWORD 密码 -- 连接到xxx数据库 \c xxx -- 授予对xxx数据库的只读权限 GRANT CONNECT ON DATABASE xxx TO read_only; GRANT USAGE ON SCHEMA public TO read_only; GRANT SELECT ON ALL TABLES IN SCHEMA public TO read_only; GRANT EXECUTE O…...

SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现

摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口(适配服务端返回 Token) export const login async (code, avatar) > {const res await http…...

Python ROS2【机器人中间件框架】 简介

销量过万TEEIS德国护膝夏天用薄款 优惠券冠生园 百花蜂蜜428g 挤压瓶纯蜂蜜巨奇严选 鞋子除臭剂360ml 多芬身体磨砂膏280g健70%-75%酒精消毒棉片湿巾1418cm 80片/袋3袋大包清洁食品用消毒 优惠券AIMORNY52朵红玫瑰永生香皂花同城配送非鲜花七夕情人节生日礼物送女友 热卖妙洁棉…...

STM32---外部32.768K晶振(LSE)无法起振问题

晶振是否起振主要就检查两个1、晶振与MCU是否兼容;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容(CL)与匹配电容(CL1、CL2)的关系 2. 如何选择 CL1 和 CL…...

Rust 开发环境搭建

环境搭建 1、开发工具RustRover 或者vs code 2、Cygwin64 安装 https://cygwin.com/install.html 在工具终端执行: rustup toolchain install stable-x86_64-pc-windows-gnu rustup default stable-x86_64-pc-windows-gnu ​ 2、Hello World fn main() { println…...