Python----循环神经网络(BiLSTM:双向长短时记忆网络)
一、LSTM 与 BiLSTM对比
1.1、LSTM
LSTM(长短期记忆网络) 是一种改进的循环神经网络(RNN),专门解决传统RNN难以学习长期依赖的问题。它通过遗忘门、输入门和输出门来控制信息的流动,保留重要信息并丢弃无关内容,从而有效处理长序列数据。LSTM的核心是细胞状态,它像一条传送带,允许信息在不同时间步之间稳定传递,避免梯度消失或爆炸,适用于时间序列预测、语音识别等任务。
1.2、BiLSTM
BiLSTM(双向长短期记忆网络) 在LSTM的基础上增加反向处理层,同时捕捉过去和未来的上下文信息。前向LSTM按时间顺序处理序列,后向LSTM逆序处理,最终结合两个方向的输出,增强模型对全局上下文的理解。BiLSTM在自然语言处理任务(如机器翻译、命名实体识别)中表现优异,但计算成本更高。它特别适合需要双向信息交互的场景,如语义理解、情感分析等。
BiLSTM结构包含两个方向的LSTM网络:一个正向(forward)LSTM和一 个反向(backward)LSTM。
这两个方向的LSTM在模型训练过程中分别处理输入序列,最后的隐藏状态 由这两个方向的LSTM拼接而成。这样的结构使得模型能够同时考虑到输入 序列中每个位置的过去和未来信息,更全面地捕捉序列中的上下文信息。
如下面这个情感分类的例子,正向的LSTM按照从左到右的顺序处理“我”、 “爱”、“你”,反向的LSTM按照从右到左的顺序处理“你”、“爱”、“我”,然后 将两个LSTM的最后一个隐藏层拼接起来再经过softmax等处理得到分类结果。
举一个例子,如一句话“我今天很开心,因为我考试考了 100 分”要做情感 分类,LSTM只能从左到右的看,因此在看到“很开心”这个关键词时它获得 的只有上文的信息,而BiLSTM是双向的因此也能看到“因为我考试考了 100 分”这一部分,而这一部分对应最终结果是否准确有很大的帮助。
特征 | LSTM | BiLSTM |
---|---|---|
方向性 | 单向(仅过去信息) | 双向(过去和未来信息) |
计算复杂度 | 较低 | 较高(约2倍) |
典型应用 | 时间序列预测、语言模型 | 文本分类、序列标注、机器翻译 |
内存需求 | 较少 | 较多 |
13、优势
BiLSTM相对于单向LSTM具有以下优势:
能够捕捉到输入序列中每个位置的过去和未来信息,更全面地捕捉序列 中的上下文信息。
可以更好地处理长距离的依赖关系。
在许多自然语言处理任务中都取得了良好的效果。
二、库函数-LSTM
torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)
LSTM — PyTorch 2.7 documentation
参数 | 描述 |
input_size | 输入 x 中预期特征的数量 |
hidden_size | 处于隐藏状态 h 的特征数量 |
num_layers | 循环层数。例如,设置意味着将两个 LSTM 堆叠在一起以形成一个堆叠的 LSTM。 第二个 LSTM 接收第一个 LSTM 的输出,并且 计算最终结果。默认值:1num_layers=2 |
bias | 偏置如果 ,则层不使用 b_ih 和 b_hh 的偏差权重。 违约:False True |
batch_first | 如果 ,则提供输入和输出张量 作为 (batch, seq, feature) 而不是 (seq, batch, feature)。 请注意,这不适用于隐藏状态或单元格状态。请参阅 Inputs/Outputs 部分了解详细信息。违约:True False |
dropout | 如果为非零,则在每个 除最后一层外的 LSTM 层,其 dropout 概率等于 。默认值:0dropout |
bidirectional | 如果 ,则变为双向 LSTM。违约:
|
proj_size | 如果 ,将使用 LSTM 和相应大小的投影。默认值:0 |
import torch
import numpy as np
from torch import nn# 1.字符输入
text = "In Beijing Sarah bought a basket of apples In Guangzhou Sarah bought a basket of bananas"torch.manual_seed(1)# 3.数据集划分
input_seq = [text[:-1]]
output_seq = [text[1:]]
print("input_seq:", input_seq)
# print("output_seq:", output_seq)# 4.数据编码:one-hot
chars = set(text)
chars = sorted(chars)
# print("chars:", chars)
# {" ":0, "a":1 }
char2int = {char: ind for ind, char in enumerate(chars)}
# print("char2int:", char2int)
# {0:" ", 1: "a"}
int2char = dict(enumerate(chars))# 将字符转成数字编码
input_seq = [[char2int[char] for char in seq] for seq in input_seq]
# print("input_seq:", input_seq)
output_seq = [[char2int[char] for char in seq] for seq in output_seq]# one-hot 编码,pytorch的RNN的输入张量的填充
def one_hot_encode(seq, bs, seq_len, size):features = np.zeros((bs, seq_len, size), dtype=np.float32)for i in range(bs):for u in range(seq_len):features[i, u, seq[i][u]] = 1.0return torch.tensor(features, dtype=torch.float32)input_seq = one_hot_encode(input_seq, 1, len(text)-1, len(chars))
output_seq = torch.tensor(output_seq, dtype=torch.long).view(-1)
print("output_seq:", output_seq)# 5.定义前向模型
class Model(nn.Module):def __init__(self, input_size, hidden_size, out_size):super(Model, self).__init__()self.hidden_size = hidden_sizeself.bilstm1 = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, bidirectional=True)self.fc1 = nn.Linear(hidden_size * 2, out_size)def forward(self, x):out, hidden = self.bilstm1(x)x = out.contiguous().view(-1, self.hidden_size * 2)x = self.fc1(x)return x, hiddenmodel = Model(len(chars), 32, len(chars))# 6.定义损失函数和优化器
cri = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)# 7.开始迭代
epochs = 1000
for epoch in range(1, epochs+1):output, hidden = model(input_seq)loss = cri(output, output_seq)optimizer.zero_grad()loss.backward()optimizer.step()# 8.显示频率设置if epoch == 0 or epoch % 50 == 0:print(f"Epoch [{epoch}/{epochs}], Loss {loss:.4f}")# print("input_seq.shape:", input_seq.shape)
# print("hidden.shape:", hidden.shape)
# print("output.shape:", output.shape)
# print("input_w:", model.rnn1.weight_ih_l0.shape)# 预测下面几个字符
input_text = "In Beijing Sarah bought a basket of" # re
to_be_pre_len = 20for i in range(to_be_pre_len):chars = [char for char in input_text]# print(chars)character = np.array([[char2int[c] for c in chars]])character = one_hot_encode(character, 1, character.shape[1], 23)character = torch.tensor(character, dtype=torch.float32)out, hidden = model(character)char_index = torch.argmax(out[-1]).item()input_text += int2char[char_index]
print("预测到的:", input_text)
相关文章:

Python----循环神经网络(BiLSTM:双向长短时记忆网络)
一、LSTM 与 BiLSTM对比 1.1、LSTM LSTM(长短期记忆网络) 是一种改进的循环神经网络(RNN),专门解决传统RNN难以学习长期依赖的问题。它通过遗忘门、输入门和输出门来控制信息的流动,保留重要信息并丢弃无关…...
Elasticsearch 常用操作命令整合 (cURL 版本)
Elasticsearch 常用操作命令整合 (cURL 版本) 集群管理 查看集群健康状态 curl -X GET "localhost:9200/_cluster/health?pretty"查看节点信息 curl -X GET "localhost:9200/_cat/nodes?v"查看集群统计信息 curl -X GET "localhost:9200/_clus…...
Redis持久化策略:RDB与AOF详解
目录 1. RDB持久化工作原理触发机制优点缺点配置示例 2. AOF持久化工作原理同步策略重写机制优点缺点配置示例 3. RDB与AOF比较4. 混合持久化(Redis 4.0)5. 选择建议 Redis提供了两种主要的持久化机制来保证数据安全:RDB(Redis Database)和AOF(Append Only File)。本…...

Linux系统编程-DAY10(TCP操作)
一、网络模型 1、服务器/客户端模型 (1)C/S:client server (2)B/S:browser server (3)P2P:peer to peer 2、C/S与B/S区别 (1)客户端不同&#…...

基于eclipse进行Birt报表开发
Birt报表开发最终实现效果: 简洁版的Birt报表开发实现效果,仅供参考! 可动态获取采购单ID,来打印出报表! 下面开始Birt报表开发教程: 首先:汉化的eclipse及Birt值得拥有:至少感觉上…...

GPU虚拟化
引言 现有如下环境(注意相关配置:只有一个k8s节点,且该节点上只有一张GPU卡): // k8s版本 $ kubectl version Client Version: version.Info{Major:"1", Minor:"22", GitVersion:"v1.22.7&…...

LabVIEW工业级多任务实时测控系统
采用LabVIEW构建了一套适用于工业自动化领域的多任务实时测控系统。系统采用分布式架构,集成高精度数据采集、实时控制、网络通信及远程监控等功能,通过硬件与软件的深度协同,实现对工业现场多类型信号的精准测控,展现 LabVIEW 在…...
Python学习(7) ----- Python起源
🐍《Python 的诞生》:一段圣诞假期的奇妙冒险 📍时间:1989 年圣诞节 在荷兰阿姆斯特丹的一个寒冷冬夜,灯光昏黄、窗外飘着雪。一个程序员 Guido van Rossum 正窝在家里度假——没有会议、没有项目、没有 bug…...
Java中List的forEach用法详解
在 Java 中,List.forEach() 是 Java 8 引入的一种简洁的遍历集合元素的方法。它基于函数式编程思想,接受一个 Consumer 函数式接口作为参数,用于对集合中的每个元素执行操作。 基本语法 java 复制 下载 list.forEach(consumer); 使用示…...
LeetCode 1356.根据数字二进制下1的数目排序
题目: 给你一个整数数组 arr 。请你将数组中的元素按照其二进制表示中数字 1 的数目升序排序。 如果存在多个数字二进制中 1 的数目相同,则必须将它们按照数值大小升序排列。 请你返回排序后的数组。 提示: 1 < arr.length < 5000…...

破解HTTP无状态:基于Java的Session与Cookie协同工作指南
HTTP协议自身是属于“无状态”协议 无状态是指:默认情况下,HTTP协议的客户端和服务器之间的这次通信,和下次通信之间没有直接的关系 但在实际开发中,我们很多时候是需要知道请求之间的关联关系的 上述图中的令牌,通常就…...

JS 事件流机制详解:冒泡、捕获与完整事件流
JS 事件流机制详解:冒泡、捕获与完整事件流 文章目录 JS 事件流机制详解:冒泡、捕获与完整事件流一、DOM 事件流基本概念二、事件捕获 (Event Capturing)特点代码示例 三、事件冒泡 (Event Bubbling)特点代码示例 四、完整事件流示例HTML 结构JavaScript…...
MYSQL too many connection问题排查和修复
1.连接数据库 mysql -u root -p 1.1 查看mysql路径 如果没有配置mysql的环境变量,可以直接找mysql的安装目录 打开任务管理器-》服务-》Mysql(根据版本不同后面带有数字,找运行的那个) 打开服务->mysql->属性-》可执行文件的路径,…...
SpringCloudAlibaba和SpringBoot版本问题
SpringCloudAlibaba和SpringBoot版本问题 直接参考官方给出的版本说明,具体地址:https://github.com/alibaba/spring-cloud-alibaba/wiki/%E7%89%88%E6%9C%AC%E8%AF%B4%E6%98%8E Spring Cloud Alibaba VersionSentinel VersionNacos VersionRocketMQ Ver…...

算法专题七:分治
快排 1.颜色分类 题目链接:75. 颜色分类 - 力扣(LeetCode) class Solution {public void swap(int[] nums, int i, int j){int t = nums[i];nums[i] = nums[j];nums[j] = t;}public void sortColors(int[] nums) {int left=-1 ,i=0 ,right=nums.length;while(i<right){i…...

Vue中虚拟DOM的原理与作用
绪论 首先我们先了解,DOM(Document Object Model,文档对象模型) 是浏览器对 HTML/XML 文档的结构化表示,它将文档解析为一个由节点(Node)和对象组成的树形结构(称为 DOM 树…...
前端十种排序算法解析
1. 冒泡排序 1.1 说明 冒泡排序为一种常用排序算法,执行过程为从数组的第一个位置开始,相邻的进行比较,将最大的数移动到数组的最后位置执行的时间复杂度与空间复杂度为 o(n^2) 1.2 执行过程 从数组的第一个位置开始,截止位置为 …...
使用 C/C++ 和 OpenCV 添加图片水印
使用 C/C 和 OpenCV 添加图片水印 🖼️ 在数字图像处理中,添加水印是一种常见的操作,可以用于版权保护、品牌宣传或信息标注。本文将介绍如何使用 C/C 和强大的计算机视觉库 OpenCV 来实现将自定义水印(图片或文字)添…...
Secs/Gem第十二讲(基于secs4net项目的ChatGpt介绍)
好,那我们进入最关键的一讲—— 第十二讲:完整事件通知流程全景图——CEID 触发到主机接收的全过程 关键词:CEID 事件上报、S6F11 报文、事件触发流程、数据驱动机制、Report Dispatch、主机解析流程 本讲目标 你将彻底理解: 设…...
FastAPI实战起步:从Python环境到你的第一个“Hello World”API接口
上一篇文章中介绍了有关FastAPI的优势,本篇文章我将手把手带你从零开始,搭建FastAPI的开发环境,并成功运行你的第一个“Hello World”API。在开始之前,请确保你的电脑已经安装了Python 3.7或更高版本,以及VS Code&…...

使用vue3+ts+input封装上传组件,上传文件显示文件图标
效果图: 代码 <template><div class"custom-file-upload"><div class"upload"><!-- 显示已选择的文件 --><div class"file-list"><div v-for"(item, index) in state.filsList" :key&q…...
iOS 抖音导航栏首页一键分两列功能的实现
要实现 iOS 抖音首页导航栏的“一键分两列”功能(通常指将单列内容切换为双列瀑布流布局),需结合自定义导航栏控件与布局动态切换逻辑。以下是关键实现步骤和技术要点,基于 iOS 原生开发框架(Swift/Objective-C&#x…...
零基础入门 C 语言基础知识(含面试题):结构体、联合体、枚举、链表、环形队列、指针全解析!
🌟 零基础入门 C 语言基础知识(含面试题):结构体、联合体、枚举、链表、环形队列、指针全解析! C 语言是所有程序员通向“系统世界”的第一把钥匙。很多嵌入式开发、操作系统内核、网络通信、图形引擎,背后…...

【Linux】Ubuntu 创建应用图标的方式汇总,deb/appimage/通用方法
Ubuntu 创建应用图标的方式汇总,deb/appimage/通用方法 对于标准的 Ubuntu(使用 GNOME 桌面),desktop 后缀的桌面图标文件主要保存在以下三个路径: 当前用户的桌面目录(这是最常见的位置)。所…...
【Unity】R3 CSharp 响应式编程 - 使用篇(集合)(三)
1、ObservableList 基础 List 类型测试 using System;using System.Collections.Specialized;using ObservableCollections;using UnityEngine;namespace Aladdin.Standard.Observable.Collections.List{public class ObservableListTest : MonoBehaviour{protected readonly O…...
振动力学:弹性杆的纵向振动(固有振动和固有频率的概念)
文章1、2、3中讨论的是离散系统的振动特性,然而实际系统的惯性质量、弹性、阻尼等特性都是连续分布的,因而成为连续系统或分布参数系统。确定连续介质中无数个点的运动需要无限个广义坐标,因此也称为无限自由度系统,典型的结构例如:弦、杆、膜、环、梁、板、壳等,也称为弹…...

LangGraph--Agent工作流
Agent的工作流 下面展示了如何创建一个“计划并执行”风格的代理。 这在很大程度上借鉴了 计划和解决 论文以及Baby-AGI项目。 核心思想是先制定一个多步骤计划,然后逐项执行。完成一项特定任务后,您可以重新审视计划并根据需要进行修改。 般的计算图如…...

Spring Boot 常用注解面试题深度解析
🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 Spring Boot 常用注解面试题深度解析一、核心…...

Linux系统的CentOS7发行版安装MySQL80
文章目录 前言Linux命令行内的”应用商店”安装CentOS的安装软件的yum命令安装MySQL1. 配置yum仓库2. 使用yum安装MySQL3. 安装完成后,启动MySQL并配置开机自启动4. 检查MySQL的运行状态 MySQL的配置1. 获取MySQL的初始密码2. 登录MySQL数据库系统3. 修改root密码4.…...

408第一季 - 数据结构 - 栈与队列
栈 闲聊 栈是一个线性表 栈的特点是后进先出 然后是一个公式 比如123要入栈,一共有5种排列组合的出栈 栈的数组实现 这里有两种情况,,一个是有下标为-1的,一个没有 代码不用看,真题不会考 栈的链式存储结构 L ->…...