当前位置: 首页 > news >正文

LSTM的变体

一、GRU

1、什么是GRU

门控循环单元(GRU)是一种循环神经网络(RNN)的变体,它通过引入门控机制来控制信息的流动,从而有效地解决了传统RNN中的梯度消失问题。GRU由Cho等人在2014年提出,它简化了LSTM的结构,将遗忘门和输入门合并为一个更新门,并增加了一个重置门,同时合并了单元状态和隐藏状态,使得模型更加简洁,训练速度更快,且在性能上与LSTM相当。

2、GRU的核心

核心在于两个门:更新门(update gate)和重置门(reset gate)。更新门控制着从前一时刻的状态信息中保留多少到当前状态,而重置门决定着前一状态有多少信息被写入到当前的候选集中。这种结构使得GRU在处理长序列数据时能够更好地捕捉长期依赖关系,同时减少了模型参数,提高了计算效率。

3、GRU的应用

应用非常广泛,包括但不限于自然语言处理(NLP)、语音识别、图像处理等领域。在NLP领域,GRU可以用于语言建模、情感分析、机器翻译等任务;在语音识别领域,GRU可以用于语音信号的特征提取和识别;在图像处理领域,GRU可以用于图像分类、目标检测等任务。GRU的简洁性和效率使其在处理大规模序列数据时具有优势。

在选择GRU和LSTM时,通常考虑的因素包括任务的复杂性、数据集的大小以及训练资源。由于GRU参数更少,收敛速度更快,因此在需要快速迭代和实验时,GRU通常是首选。然而,在某些需要对复杂序列依赖关系进行建模的任务中,LSTM可能会表现得更好。

总的来说,GRU是一种强大的循环神经网络架构,它通过引入门控机制来控制信息流,有效地解决了传统RNN的梯度消失问题。GRU的简洁性和效率使其在多种序列建模任务中表现出色,成为了深度学习中处理时序数据的重要工具之一。

4、GRU的工作原理

5、手写代码实现

import numpy as npclass GRU():def __init__(self, input_size, hidden_size):self.input_size = input_sizeself.hidden_size = hidden_size# 初始化参数w和bself.W_z = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)self.b_z = np.zeros(self.hidden_size)# 重置门self.W_r = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)self.b_r = np.zeros(self.hidden_size)# 候选隐藏状态self.W_h = np.random.randn(self.hidden_size, self.input_size + self.hidden_size)self.b_h = np.zeros(self.hidden_size)def tanh(self, x):return np.tanh(x)def sigmoid(self, x):return 1 / (1 + np.exp(-x))def forward(self, x):h_prev = np.zeros((self.hidden_size,))concat_input = np.concatenate([x, h_prev], axis=0)z_t = self.sigmoid(np.dot(self.W_z, concat_input) + self.b_z)r_t = self.sigmoid(np.dot(self.W_r, concat_input) + self.b_r)concat_reset_input = np.concatenate([x, r_t * h_prev], axis=0)h_hat_t = self.tanh(np.dot(self.W_h, concat_reset_input) + self.b_h)h_t = (1 - z_t) * h_prev + z_t * h_hat_treturn h_t

二、BiLSTM

1、什么是BiSTM

BiSTM,即双向门控循环单元(Bidirectional Gated Recurrent Unit),是一种循环神经网络(RNN)的变体。它结合了前向和后向的GRU,能够同时处理过去和未来的信息,从而更好地捕捉序列数据中的上下文关系。

在BiSTM中,数据通过两个GRU网络进行处理:一个从左到右(前向),另一个从右到左(后向)。这两个网络的输出然后被拼接或相加,形成最终的特征表示,这个特征表示包含了序列的双向信息。这种结构特别适合于需要理解序列中前后文信息的任务,如文本分类、语音识别、命名实体识别(NER)等。

2、BiSTM的关键特点包括:

  1. 双向信息捕捉:BiSTM能够同时考虑序列中每个元素之前的和之后的上下文信息,这使得它在处理像文本这样的序列数据时非常有效,因为文本中词汇的含义往往受到其前后词汇的影响。

  2. 门控机制:BiSTM继承了GRU的门控机制,包括更新门和重置门,这些门控单元可以控制信息的流动,从而减少无效或噪声信息的干扰,并增强模型对重要信息的记忆能力。

  3. 应用广泛:BiSTM因其强大的序列处理能力而被广泛应用于各种领域,包括自然语言处理(NLP)、语音识别、时间序列分析等。

  4. 模型性能:在某些任务中,BiSTM能够提供比单向GRU或LSTM更好的性能,尤其是在需要捕捉长期依赖关系的任务中。

  5. 模型复杂度:由于BiSTM包含两个GRU网络,其模型参数和计算复杂度相对于单向GRU或LSTM会有所增加,但在很多情况下,这种增加是值得的,因为它能带来更准确的预测结果

 3、手写BiLSTM代码

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequenceclass LSTM(nn.Module):def __init__(self, vocab_size, target_size, input_size=512, hidden_size=512):super(LSTM, self).__init__()self.hidden_size = hidden_sizeself.embedding = nn.Embedding(vocab_size, input_size)self.mlp = nn.Sequential(nn.Linear(input_size, hidden_size),nn.GELU(),nn.Linear(hidden_size, hidden_size))self.lstm = nn.LSTM(hidden_size, hidden_size * 2, num_layers=3, batch_first=True, dropout=0.5)self.avg_lstm = nn.AdaptiveAvgPool1d(1)self.avg_linear = nn.AdaptiveAvgPool1d(1)self.out_linear = nn.Sequential(nn.Linear(hidden_size * 2 + hidden_size, hidden_size),nn.GELU(),nn.LayerNorm(hidden_size),nn.Linear(hidden_size, target_size))self.norm = nn.LayerNorm(hidden_size * 2)def forward(self, x, lengths):x = self.embedding(x)mlp = self.mlp(x)pached_embed = pack_padded_sequence(mlp, lengths, batch_first=True, enforce_sorted=False)lstm_out, _ = self.lstm(pached_embed)lstm_out, _ = pad_packed_sequence(lstm_out, batch_first=True)lstm_out = self.norm(lstm_out)avg_lstm = self.avg_lstm(lstm_out.permute(0, 2, 1)).squeeze(-1)avg_linear = self.avg_linear(mlp.permute(0, 2, 1)).squeeze(-1)out = torch.cat([avg_lstm, avg_linear], dim=-1)return self.out_linear(out)class BiLSTM(nn.Module):def __init__(self, input_size=512, hidden_size=512, output_size=512):super(BiLSTM, self).__init__()self.hidden_size = hidden_sizeself.lstm_forward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)self.lstm_backward = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)def forward(self, x):out_forward, _ = self.lstm_forward(x)out_backward, _ = self.lstm_backward(torch.flip(x, dims=[1]))out_backward = torch.flip(out_backward, dims=[1])combined_output = torch.cat([out_forward, out_backward], dim=-1)return combined_output

相关文章:

LSTM的变体

一、GRU 1、什么是GRU 门控循环单元(GRU)是一种循环神经网络(RNN)的变体,它通过引入门控机制来控制信息的流动,从而有效地解决了传统RNN中的梯度消失问题。GRU由Cho等人在2014年提出,它简化了…...

LeetCode讲解篇之852. 山脉数组的峰顶索引

文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 我们可以采用二分查找,每次查询区间中点元素与中点下一个元素比较 如果中点元素大于其下一个元素,则表示从中点开始向右是递减趋势,那峰值索引一定小于等于中点,我…...

矿井人员数据集,用于目标检测,深度学习,采用txt打标签,即yolo格式,也有原文件可以自己转换。总共3500张图片的数据量,划分给训练集2446张,

矿井人员数据集,用于目标检测,深度学习,采用txt打标签,即yolo格式,也有原文件可以自己转换。总共3500张图片的数据量,划分给训练集2446张: ### 矿井人员数据集用于目标检测的详细说明 #### 1. …...

消息队列RabbitMQ

文章目录 1. 简介与安装2. 基本概念3. SpringAMQP4. 交换机类型5. 消息转换器5.1 默认转换器5.2 配置JSON转换器 6 生产者的可靠性6.1 生产者超时重连机制6.2 生产者确认机制 6. MQ的可靠性6.1 数据持久化6.2 惰性队列 Lazy Queue 7. 消费者的可靠性7.1 消费者确认机制7.2 失败…...

RabbitMQ概述

什么是MQ MQ (message queue)消息队列 MQ从字⾯意思上看,本质是个队列,FIFO先⼊先出,只不过队列中存放的内容是消息(message).消息可以⾮常简单,⽐如只包含⽂本字符串,JSON等,也可以很复杂,⽐如内嵌对象 RabbitMQ是MQ的一种实现,是Rabbit 企业下的⼀个消息队列产…...

Golang学习路线

以下是一条学习Golang(Go语言)的路线: 一、基础入门 1. 环境搭建 安装Go编译器,在官网(https://golang.org/dl/)下载适合操作系统的安装包并配置好环境变量。 2. 语法学习学习变量、数据类型&#xff08…...

Flink从ck拉起任务脚本

#!/bin/bashAPP_NAME"orderTest"CHECKPOINT_BASE_PATH"hdfs:///jobs/flink/checkpoints/aaa-test/"is_running$(yarn application -list | grep -w "$APP_NAME" | grep -c "RUNNING")if [ $is_running -gt 0 ]; thenecho "应用程…...

GADBench Revisiting and Benchmarking Supervised Graph Anomaly Detection

Neurips 23 推荐指数: #paper/⭐⭐⭐ 领域:图异常检测 胡言乱语: neurips 的benchmark模块的文章总能给人一些启发性的理解,这篇的insight真有意思。个人感兴趣的地方会加粗。此外,这篇文章和腾讯AIlab合作&#xff…...

某象异形滑块99%准确率方案

注意,本文只提供学习的思路,严禁违反法律以及破坏信息系统等行为,本文只提供思路 如有侵犯,请联系作者下架 该文章模型已经上线ocr识别网站,欢迎测试!!,地址:https://yxlocr.windy-rain.cn/ocr/slider/6 所谓的顶象异形滑块,是指没有采用常规的缺口,使用各种形状的…...

CDN绕过学习

1.什么是CDN? CDN就是分布在各个地区的服务器,这些服务器储存着数据的副本。 哪些服务器比较接近你,当你发起请求时,提前就会快速为你提供服务。 总结来说就是: 其实就是用来加速访问的,以及缓解压力&a…...

SpringBoot日常:redission的接入使用和源码解析

文章目录 一、简介二、集成redissionpom文件redission 配置文件application.yml文件启动类 三、JAVA 操作案例字符串操作哈希操作列表操作集合操作有序集合操作布隆过滤器操作分布式锁操作 四、源码解析 一、简介 Redisson 是一个在 Redis 的基础上实现的 Java 驻内存数据网格…...

npm包管理深度探索:从基础到进阶全面教程!

目录 一、npm概述1、npm简介(1)什么是npm?(2)npm的核心功能(3)npm的工作原理(4)npm的优势(5)npm的局限性(6)总结 2、npm的…...

最新免费GPT4O和Midjourney

一、什么是GPT4O? GPT-4 是 OpenAI 研发的大型语言模型。它具有强大的语言理解和生成能力,在自然语言处理等诸多领域有着广泛的应用和表现。 二、什么是Midjourney? Midjourney 是一款人工智能图像生成工具。它可以根据用户输入的描述或提…...

python操作OpenAI教程

python操作OpenAI pip install -U openai代码: from openai import OpenAI# 解决请求超时问题 import os os.environ["http_proxy"] "http://localhost:7890" os.environ["https_proxy"] "http://localhost:7890"# 需要…...

如何版本REST API:综合指南

目录 总则什么是REST API版本控制?为什么API版本控制很重要?如何对REST API进行版本控制 理解API契约评估需求选择版本控制策略沟通变化保持向后兼容性详细文档记录REST API版本控制最佳实践REST API版本控制常见问题:REST API版本控制总则 版本化REST API对于确保软件应用…...

Docker 环境下 Nginx 监控实战:使用 Prometheus 实现 Nginx 性能监控的完整部署指南

Docker 环境下 Nginx 监控实战:使用 Prometheus 实现 Nginx 性能监控的完整部署指南 文章目录 Docker 环境下 Nginx 监控实战:使用 Prometheus 实现 Nginx 性能监控的完整部署指南一 查看模块是否安装二 配置 status 访问端点三 Docker 部署 nginx-prome…...

网络安全-IPv4和IPv6的区别

1. 2409:8c20:6:1135:0:ff:b027:210d。 这是一个IPv6地址。IPv6(互联网协议版本6)是用于标识网络中的设备的一种协议,它可以提供比IPv4更大的地址空间。这个地址由八组十六进制数字组成,每组之间用冒号分隔。IPv6地址通常用于替代…...

【移动端】事件基础

一、移动端事件分类 移动端事件主要分为以下几类: 1. 触摸事件(Touch Events) 触摸事件是移动设备特有的事件,用来处理用户通过触摸屏幕进行的操作。主要的触摸事件有: touchstart:手指触摸屏幕时触发。…...

软件测试比赛-学习

一、环境配置 二、浏览器适配 //1.设置浏览器的位置,google浏览器位置是默认且固定在电脑里的//2.设置浏览器驱动的位置,C:\Users\27743\AppData\Local\Google\Chrome\ApplicationSystem.setProperty("webdriver.chrome.driver", "C:\\Users\\27743\\AppData\\…...

力扣LeetCode-链表中的循环与递归使用

标题做题的时候发现循环与递归的使用差别: 看两道题: 两道题都是不知道链表有多长,所以需要用到循环,用到循环就可以把整个过程分成多个循环体,就是每一次循环要执行的内容。 反转链表: 把null–>1…...

多模态跨语言翻译引擎实战指南:本地化部署与场景化应用

多模态跨语言翻译引擎实战指南:本地化部署与场景化应用 【免费下载链接】seamless-m4t-v2-large 项目地址: https://ai.gitcode.com/hf_mirrors/ai-gitcode/seamless-m4t-v2-large 在全球化协作日益频繁的今天,跨语言翻译已成为打破沟通壁垒的核…...

CCF推荐C类会议与期刊全景解析:计算机网络研究者的学术地图

1. CCF推荐C类会议与期刊:计算机网络研究者的学术指南针 刚进入计算机网络领域的研究生常常会面临一个困惑:面对海量的学术会议和期刊,到底该从哪里入手?中国计算机学会(CCF)推荐的C类会议和期刊就像一张精…...

JavaScript注释的艺术:gh_mirrors/js/js教你写出自解释代码

JavaScript注释的艺术:gh_mirrors/js/js教你写出自解释代码 【免费下载链接】js :art: A JavaScript Quality Guide 项目地址: https://gitcode.com/gh_mirrors/js/js 在JavaScript开发中,注释是代码质量的重要组成部分,但很多开发者误…...

悟空率先接入国产最强编程模型Qwen3.6-Plus

4月2日,阿里巴巴正式发布新一代大语言模型Qwen3.6-Plus,阿里在企业级市场的旗舰AI应用悟空率先完成接入。Qwen3.6-Plus在代码、智能体、推理、原生多模态等能力上整体性能大幅增强,在智能体编程SWE-bench系列评测、真实世界智能体任务Claw-Ev…...

JavaScript高效PPTX文档处理方案:js-pptx深度解析与实战指南

JavaScript高效PPTX文档处理方案:js-pptx深度解析与实战指南 【免费下载链接】js-pptx Pure Javascript reader/writer for PowerPoint 项目地址: https://gitcode.com/gh_mirrors/js/js-pptx 在当今数字化办公环境中,PowerPoint演示文稿的自动化…...

效率提升300%:OpenClaw+Phi-3-vision-128k-instruct重构我的学术工作流

效率提升300%:OpenClawPhi-3-vision-128k-instruct重构我的学术工作流 1. 从手动到自动的学术工作流革命 作为一名每天需要处理大量文献、实验数据和演示材料的科研工作者,我曾经花费近40%的工作时间在重复性文档处理上——截图标注、图表整理、笔记归…...

SeamlessM4T v2:跨语言实时对话的终极解决方案与技术实践

SeamlessM4T v2:跨语言实时对话的终极解决方案与技术实践 【免费下载链接】seamless-m4t-v2-large 项目地址: https://ai.gitcode.com/hf_mirrors/ai-gitcode/seamless-m4t-v2-large 在全球化协作日益频繁的今天,跨语言沟通已成为技术团队、跨国…...

Golang怎么用Task替代Makefile_Golang如何用go-task编写跨平台的任务脚本文件【教程】

go-task 是用 Go 编写的跨平台任务编排工具,本质区别于 Makefile:它用 YAML 定义任务、不依赖 shell 缩进、默认不继承父环境变量、无增量构建、支持变量注入与平台条件判断,且单文件分发。go-task 是什么,和 Makefile 有什么本质…...

Captain AI帮你一次过审,上品不再被驳回!

Ozon上品审核驳回、上架后违规下架,是90%以上卖家都踩过的坑。很多卖家遇到上品问题,会用DeepSeek等通用AI查询规则,却往往因为信息滞后、规则解读错误,反复修改仍无法过审,白白错过新品流量黄金期。一、Captain AI能帮…...

智能客服VS语音转写:不同场景下语音识别评估指标的选择指南

智能客服与语音转写:业务场景驱动的语音识别评估指标决策框架 当企业考虑部署语音识别系统时,技术团队常会抛出一堆专业术语:WER 15%、CER 8%、SER 22%...但对产品经理和解决方案架构师而言,这些数字背后意味着什么?选…...