当前位置: 首页 > news >正文

LSTM的变体

一、GRU

1、什么是GRU

门控循环单元(GRU)是一种循环神经网络(RNN)的变体,它通过引入门控机制来控制信息的流动,从而有效地解决了传统RNN中的梯度消失问题。GRU由Cho等人在2014年提出,它简化了LSTM的结构,将遗忘门和输入门合并为一个更新门,并增加了一个重置门,同时合并了单元状态和隐藏状态,使得模型更加简洁,训练速度更快,且在性能上与LSTM相当。

2、GRU的核心

核心在于两个门:更新门(update gate)和重置门(reset gate)。更新门控制着从前一时刻的状态信息中保留多少到当前状态,而重置门决定着前一状态有多少信息被写入到当前的候选集中。这种结构使得GRU在处理长序列数据时能够更好地捕捉长期依赖关系,同时减少了模型参数,提高了计算效率。

3、GRU的应用

应用非常广泛,包括但不限于自然语言处理(NLP)、语音识别、图像处理等领域。在NLP领域,GRU可以用于语言建模、情感分析、机器翻译等任务;在语音识别领域,GRU可以用于语音信号的特征提取和识别;在图像处理领域,GRU可以用于图像分类、目标检测等任务。GRU的简洁性和效率使其在处理大规模序列数据时具有优势。

在选择GRU和LSTM时,通常考虑的因素包括任务的复杂性、数据集的大小以及训练资源。由于GRU参数更少,收敛速度更快,因此在需要快速迭代和实验时,GRU通常是首选。然而,在某些需要对复杂序列依赖关系进行建模的任务中,LSTM可能会表现得更好。

总的来说,GRU是一种强大的循环神经网络架构,它通过引入门控机制来控制信息流,有效地解决了传统RNN的梯度消失问题。GRU的简洁性和效率使其在多种序列建模任务中表现出色,成为了深度学习中处理时序数据的重要工具之一。

4、GRU的工作原理

5、手写代码实现

import numpy as npclass GRU():def __init__(self, input_size, hidden_size):self.input_size = input_sizeself.hidden_size = hidden_size# 初始化参数w和bself.W_z = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)self.b_z = np.zeros(self.hidden_size)# 重置门self.W_r = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)self.b_r = np.zeros(self.hidden_size)# 候选隐藏状态self.W_h = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)self.b_h = np.zeros(self.hidden_size)def tanh(self, x):return np.tanh(x)def sigmoid(self, x):return 1 / (1 + np.exp(-x))def forward(self, x):h_prev = np.zeros((self.hidden_size,))concat_input = np.concatenate([x, h_prev], axis=0)z_t = self.sigmoid(np.dot(self.W_z, concat_input) + self.b_z)r_t = self.sigmoid(np.dot(self.W_r, concat_input) + self.b_r)concat_reset_input = np.concatenate([x, r_t * h_prev], axis=0)h_hat_t = self.tanh(np.dot(self.W_h, concat_reset_input) + self.b_h)h_t = (1 - z_t) * h_prev + z_t * h_hat_treturn h_t

二、BiLSTM

1、什么是BiSTM

BiSTM,即双向门控循环单元(Bidirectional Gated Recurrent Unit),是一种循环神经网络(RNN)的变体。它结合了前向和后向的GRU,能够同时处理过去和未来的信息,从而更好地捕捉序列数据中的上下文关系。

在BiSTM中,数据通过两个GRU网络进行处理:一个从左到右(前向),另一个从右到左(后向)。这两个网络的输出然后被拼接或相加,形成最终的特征表示,这个特征表示包含了序列的双向信息。这种结构特别适合于需要理解序列中前后文信息的任务,如文本分类、语音识别、命名实体识别(NER)等。

2、BiSTM的关键特点包括:

  1. 双向信息捕捉:BiSTM能够同时考虑序列中每个元素之前的和之后的上下文信息,这使得它在处理像文本这样的序列数据时非常有效,因为文本中词汇的含义往往受到其前后词汇的影响。

  2. 门控机制:BiSTM继承了GRU的门控机制,包括更新门和重置门,这些门控单元可以控制信息的流动,从而减少无效或噪声信息的干扰,并增强模型对重要信息的记忆能力。

  3. 应用广泛:BiSTM因其强大的序列处理能力而被广泛应用于各种领域,包括自然语言处理(NLP)、语音识别、时间序列分析等。

  4. 模型性能:在某些任务中,BiSTM能够提供比单向GRU或LSTM更好的性能,尤其是在需要捕捉长期依赖关系的任务中。

  5. 模型复杂度:由于BiSTM包含两个GRU网络,其模型参数和计算复杂度相对于单向GRU或LSTM会有所增加,但在很多情况下,这种增加是值得的,因为它能带来更准确的预测结果

 3、手写BiLSTM代码

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequenceclass LSTM(nn.Module):def __init__(self, vocab_size, target_size, input_size=512, hidden_size=512):super(LSTM, self).__init__()self.hidden_size = hidden_sizeself.embedding = nn.Embedding(vocab_size, input_size)self.mlp = nn.Sequential(nn.Linear(input_size, hidden_size),nn.GELU(),nn.Linear(hidden_size, hidden_size))self.lstm = nn.LSTM(hidden_size, hidden_size * 2, num_layers=3, batch_first=True, dropout=0.5)self.avg_lstm = nn.AdaptiveAvgPool1d(1)self.avg_linear = nn.AdaptiveAvgPool1d(1)self.out_linear = nn.Sequential(nn.Linear(hidden_size * 2 + hidden_size, hidden_size),nn.GELU(),nn.LayerNorm(hidden_size),nn.Linear(hidden_size, target_size))self.norm = nn.LayerNorm(hidden_size * 2)def forward(self, x, lengths):x = self.embedding(x)mlp = self.mlp(x)pached_embed = pack_padded_sequence(mlp, lengths, batch_first=True, enforce_sorted=False)lstm_out, _ = self.lstm(pached_embed)lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)lstm_out = self.norm(lstm_out)avg_lstm = self.avg_lstm(lstm_out.permute(0, 2, 1)).squeeze(-1)avg_linear = self.avg_linear(mlp.permute(0, 2, 1)).squeeze(-1)out = torch.cat([avg_lstm, avg_linear], dim=-1)return self.out_linear(out)class BiLSTM(nn.Module):def __init__(self, input_size=512, hidden_size=512, output_size=512):super(BiLSTM, self).__init__()self.hidden_size = hidden_sizeself.lstm_forward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)self.lstm_backward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)def forward(self, x):out_forward, _ = self.lstm_forward(x)out_backward, _ = self.lstm_backward(torch.flip(x, dims=[1]))out_backward = torch.flip(out_backward, dims=[1])combined_output = torch.cat([out_forward, out_backward], dim=-1)return combined_output

相关文章:

LSTM的变体

一、GRU 1、什么是GRU 门控循环单元(GRU)是一种循环神经网络(RNN)的变体,它通过引入门控机制来控制信息的流动,从而有效地解决了传统RNN中的梯度消失问题。GRU由Cho等人在2014年提出,它简化了…...

LeetCode讲解篇之852. 山脉数组的峰顶索引

文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 我们可以采用二分查找,每次查询区间中点元素与中点下一个元素比较 如果中点元素大于其下一个元素,则表示从中点开始向右是递减趋势,那峰值索引一定小于等于中点,我…...

矿井人员数据集,用于目标检测,深度学习,采用txt打标签,即yolo格式,也有原文件可以自己转换。总共3500张图片的数据量,划分给训练集2446张,

矿井人员数据集,用于目标检测,深度学习,采用txt打标签,即yolo格式,也有原文件可以自己转换。总共3500张图片的数据量,划分给训练集2446张: ### 矿井人员数据集用于目标检测的详细说明 #### 1. …...

消息队列RabbitMQ

文章目录 1. 简介与安装2. 基本概念3. SpringAMQP4. 交换机类型5. 消息转换器5.1 默认转换器5.2 配置JSON转换器 6 生产者的可靠性6.1 生产者超时重连机制6.2 生产者确认机制 6. MQ的可靠性6.1 数据持久化6.2 惰性队列 Lazy Queue 7. 消费者的可靠性7.1 消费者确认机制7.2 失败…...

RabbitMQ概述

什么是MQ MQ (message queue)消息队列 MQ从字⾯意思上看,本质是个队列,FIFO先⼊先出,只不过队列中存放的内容是消息(message).消息可以⾮常简单,⽐如只包含⽂本字符串,JSON等,也可以很复杂,⽐如内嵌对象 RabbitMQ是MQ的一种实现,是Rabbit 企业下的⼀个消息队列产…...

Golang学习路线

以下是一条学习Golang(Go语言)的路线: 一、基础入门 1. 环境搭建 安装Go编译器,在官网(https://golang.org/dl/)下载适合操作系统的安装包并配置好环境变量。 2. 语法学习学习变量、数据类型&#xff08…...

Flink从ck拉起任务脚本

#!/bin/bashAPP_NAME"orderTest"CHECKPOINT_BASE_PATH"hdfs:///jobs/flink/checkpoints/aaa-test/"is_running$(yarn application -list | grep -w "$APP_NAME" | grep -c "RUNNING")if [ $is_running -gt 0 ]; thenecho "应用程…...

GADBench Revisiting and Benchmarking Supervised Graph Anomaly Detection

Neurips 23 推荐指数: #paper/⭐⭐⭐ 领域:图异常检测 胡言乱语: neurips 的benchmark模块的文章总能给人一些启发性的理解,这篇的insight真有意思。个人感兴趣的地方会加粗。此外,这篇文章和腾讯AIlab合作&#xff…...

某象异形滑块99%准确率方案

注意,本文只提供学习的思路,严禁违反法律以及破坏信息系统等行为,本文只提供思路 如有侵犯,请联系作者下架 该文章模型已经上线ocr识别网站,欢迎测试!!,地址:https://yxlocr.windy-rain.cn/ocr/slider/6 所谓的顶象异形滑块,是指没有采用常规的缺口,使用各种形状的…...

CDN绕过学习

1.什么是CDN? CDN就是分布在各个地区的服务器,这些服务器储存着数据的副本。 哪些服务器比较接近你,当你发起请求时,提前就会快速为你提供服务。 总结来说就是: 其实就是用来加速访问的,以及缓解压力&a…...

SpringBoot日常:redission的接入使用和源码解析

文章目录 一、简介二、集成redissionpom文件redission 配置文件application.yml文件启动类 三、JAVA 操作案例字符串操作哈希操作列表操作集合操作有序集合操作布隆过滤器操作分布式锁操作 四、源码解析 一、简介 Redisson 是一个在 Redis 的基础上实现的 Java 驻内存数据网格…...

npm包管理深度探索:从基础到进阶全面教程!

目录 一、npm概述1、npm简介(1)什么是npm?(2)npm的核心功能(3)npm的工作原理(4)npm的优势(5)npm的局限性(6)总结 2、npm的…...

最新免费GPT4O和Midjourney

一、什么是GPT4O? GPT-4 是 OpenAI 研发的大型语言模型。它具有强大的语言理解和生成能力,在自然语言处理等诸多领域有着广泛的应用和表现。 二、什么是Midjourney? Midjourney 是一款人工智能图像生成工具。它可以根据用户输入的描述或提…...

python操作OpenAI教程

python操作OpenAI pip install -U openai代码: from openai import OpenAI# 解决请求超时问题 import os os.environ["http_proxy"] "http://localhost:7890" os.environ["https_proxy"] "http://localhost:7890"# 需要…...

如何版本REST API:综合指南

目录 总则什么是REST API版本控制?为什么API版本控制很重要?如何对REST API进行版本控制 理解API契约评估需求选择版本控制策略沟通变化保持向后兼容性详细文档记录REST API版本控制最佳实践REST API版本控制常见问题:REST API版本控制总则 版本化REST API对于确保软件应用…...

Docker 环境下 Nginx 监控实战:使用 Prometheus 实现 Nginx 性能监控的完整部署指南

Docker 环境下 Nginx 监控实战:使用 Prometheus 实现 Nginx 性能监控的完整部署指南 文章目录 Docker 环境下 Nginx 监控实战:使用 Prometheus 实现 Nginx 性能监控的完整部署指南一 查看模块是否安装二 配置 status 访问端点三 Docker 部署 nginx-prome…...

网络安全-IPv4和IPv6的区别

1. 2409:8c20:6:1135:0:ff:b027:210d。 这是一个IPv6地址。IPv6(互联网协议版本6)是用于标识网络中的设备的一种协议,它可以提供比IPv4更大的地址空间。这个地址由八组十六进制数字组成,每组之间用冒号分隔。IPv6地址通常用于替代…...

【移动端】事件基础

一、移动端事件分类 移动端事件主要分为以下几类: 1. 触摸事件(Touch Events) 触摸事件是移动设备特有的事件,用来处理用户通过触摸屏幕进行的操作。主要的触摸事件有: touchstart:手指触摸屏幕时触发。…...

软件测试比赛-学习

一、环境配置 二、浏览器适配 //1.设置浏览器的位置,google浏览器位置是默认且固定在电脑里的//2.设置浏览器驱动的位置,C:\Users\27743\AppData\Local\Google\Chrome\ApplicationSystem.setProperty("webdriver.chrome.driver", "C:\\Users\\27743\\AppData\\…...

力扣LeetCode-链表中的循环与递归使用

标题做题的时候发现循环与递归的使用差别: 看两道题: 两道题都是不知道链表有多长,所以需要用到循环,用到循环就可以把整个过程分成多个循环体,就是每一次循环要执行的内容。 反转链表: 把null–>1…...

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...

AI Agent与Agentic AI:原理、应用、挑战与未来展望

文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例:使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例:使用OpenAI GPT-3进…...

【位运算】消失的两个数字(hard)

消失的两个数字(hard) 题⽬描述:解法(位运算):Java 算法代码:更简便代码 题⽬链接:⾯试题 17.19. 消失的两个数字 题⽬描述: 给定⼀个数组,包含从 1 到 N 所有…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络,将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具,支持 Chrome、Firefox、Safari 等主流浏览器,提供多语言 API(Python、JavaScript、Java、.NET)。它的特点包括&a…...

基于当前项目通过npm包形式暴露公共组件

1.package.sjon文件配置 其中xh-flowable就是暴露出去的npm包名 2.创建tpyes文件夹,并新增内容 3.创建package文件夹...

MVC 数据库

MVC 数据库 引言 在软件开发领域,Model-View-Controller(MVC)是一种流行的软件架构模式,它将应用程序分为三个核心组件:模型(Model)、视图(View)和控制器(Controller)。这种模式有助于提高代码的可维护性和可扩展性。本文将深入探讨MVC架构与数据库之间的关系,以…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文全面剖析RNN核心原理,深入讲解梯度消失/爆炸问题,并通过LSTM/GRU结构实现解决方案,提供时间序列预测和文本生成…...

AI病理诊断七剑下天山,医疗未来触手可及

一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...

《C++ 模板》

目录 函数模板 类模板 非类型模板参数 模板特化 函数模板特化 类模板的特化 模板,就像一个模具,里面可以将不同类型的材料做成一个形状,其分为函数模板和类模板。 函数模板 函数模板可以简化函数重载的代码。格式:templa…...

腾讯云V3签名

想要接入腾讯云的Api,必然先按其文档计算出所要求的签名。 之前也调用过腾讯云的接口,但总是卡在签名这一步,最后放弃选择SDK,这次终于自己代码实现。 可能腾讯云翻新了接口文档,现在阅读起来,清晰了很多&…...