当前位置: 首页 > news >正文

深度学习使用LSTM实现时间序列预测

大家好,LSTM是一种特殊的循环神经网络(RNN)架构,它被设计用来解决传统RNN在处理长序列数据时的梯度消失和梯度爆炸问题,特别是在时间序列预测、自然语言处理和语音识别等领域中表现出色。LSTM的核心在于其独特的门控机制,这些门控机制允许网络动态地决定信息的流动,从而能够学习到长期依赖关系。本文将从背景与原理、数据预处理、LSTM模型构建与训练等方面进行介绍,用LSTM预测未来一周的天气变化。

1. 基本原理

简单来说,LSTM 是 RNN 的一种,它通过引入“记忆单元”来捕捉长时间的依赖关系,使其在处理长期依赖问题时非常有效。对于天气数据的预测,LSTM特别适用,因为天气数据是高度时序依赖的。例如,某一天的温度和湿度可能会受到前几天数据的影响,这些“依赖关系”是LSTM所擅长捕捉的。

LSTM 用于解决普通RNN在处理长序列时常见的梯度消失和梯度爆炸问题,其核心特点是引入了“记忆单元”(cell state)和三个“门”机制(遗忘门、输入门、输出门)来控制信息的流动。

1.1 基本结构

LSTM单元的主要结构包括:

  • 记忆单元(Cell State) :用于存储长期的信息。记忆单元在时间上连接,不同时间步的数据可以选择性地被保留或丢弃,这使得LSTM可以“记住”长期的信息。

  • 隐藏状态(Hidden State) :与普通RNN的隐藏状态类似,用于存储短期信息,但在LSTM中,隐藏状态还依赖于记忆单元的状态。

1.2 三个“门”机制

LSTM中的三个门分别用于控制信息的“遗忘”“更新”和“输出”:

遗忘门的目的是决定哪些信息应该从单元状态中被遗忘或丢弃。它基于当前的输入和前一个时间步的隐藏状态来计算。遗忘门的输出是一个介于0和1之间的值,接近1表示“保留信息”,接近0表示“遗忘信息”。

输入门包含两部分:一部分决定是否更新单元状态,另一部分决定新输入的信息。输入门由两组sigmoid层和一个tanh层组成。决定当前输入信息是否写入记忆单元中,用于更新记忆内容。输入门同样通过sigmoid函数生成一个0到1的值,表示当前输入数据的重要性。

输出门的目的是决定当前的单元状态如何贡献到下一个隐藏状态,它基于当前的单元状态和前一个时间步的隐藏状态来计算。

1.3 LSTM 整体流程

通过上述过程,LSTM在每个时间步的操作可以概括为以下步骤:

  1. 计算遗忘门,决定旧记忆单元信息的遗忘比例。

  2. 计算输入门和候选记忆单元,决定新信息对记忆单元的更新比例。

  3. 更新记忆单元,结合遗忘门和输入门的结果,形成新的记忆状态。

  4. 计算输出门,控制隐藏状态的生成。

  5. 根据记忆单元和输出门,计算新的隐藏状态,并传递给下一个时间步。

通过这种记忆单元状态的更新与控制机制,LSTM能够有效地在较长的序列中保持记忆,从而适用于时间序列预测等长时序依赖的任务。

2. 数据预处理与虚拟数据集生成

实际数据非常大不利于学习,为了更好理解算法本身,构建一个虚拟天气数据集,包括温度、湿度、风速等变量。假设我们有一年的历史数据,每日更新。我们将模拟这些数据并将其用于训练和测试。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt# 生成虚拟天气数据集
np.random.seed(42)
days = 365  # 一年数据
temperature = 30 + 5 * np.sin(np.linspace(0, 2 * np.pi, days)) + np.random.normal(0, 1, days)
humidity = 50 + 10 * np.sin(np.linspace(0, 2 * np.pi, days)) + np.random.normal(0, 2, days)
wind_speed = 10 + 3 * np.sin(np.linspace(0, 2 * np.pi, days)) + np.random.normal(0, 1, days)data = pd.DataFrame({'temperature': temperature,'humidity': humidity,'wind_speed': wind_speed
})data.head()

在训练模型之前,需要将数据标准化,以便LSTM能够更有效地学习数据特征。

from sklearn.preprocessing import MinMaxScalerscaler = MinMaxScaler(feature_range=(0, 1))
data_scaled = scaler.fit_transform(data)

3. LSTM模型构建与训练

3.1 数据切分

数据切分是机器学习中的一个重要步骤,它涉及将数据集划分为不同的部分,以便于模型的训练和验证。将数据分为训练集和测试集(80%训练,20%测试):

train_size = int(len(data_scaled) * 0.8)
train_data = data_scaled[:train_size]
test_data = data_scaled[train_size:]def create_sequences(data, seq_length):xs, ys = [], []for i in range(len(data) - seq_length):x = data[i:i+seq_length]y = data[i+seq_length]xs.append(x)ys.append(y)return np.array(xs), np.array(ys)seq_length = 7  # 用前7天的数据预测第8天
X_train, y_train = create_sequences(train_data, seq_length)
X_test, y_test = create_sequences(test_data, seq_length)

3.2 模型定义

导入PyTorch及其相关模块,使用PyTorch构建LSTM模型,创建一个继承自torch.nn.Module的类,并在其中定义LSTM层和其他必要的层。在模型类的构造函数中初始化LSTM层和其他层,定义模型如何根据输入数据进行前向传播。

import torch
import torch.nn as nnclass WeatherLSTM(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(WeatherLSTM, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.lstm(x)out = self.fc(out[:, -1, :])return out# 定义超参数
input_size = 3  # 特征数:温度、湿度、风速
hidden_size = 64
output_size = 3
num_layers = 1model = WeatherLSTM(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

3.3 模型训练

import torch.optim as optimnum_epochs = 100
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(torch.Tensor(X_train))loss = criterion(outputs, torch.Tensor(y_train))loss.backward()optimizer.step()if (epoch+1) % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

4. 预测与可视化分析

模型训练完成后,对测试集进行预测,使用图形展示结果。

model.eval()
with torch.no_grad():predicted = model(torch.Tensor(X_test)).detach().numpy()predicted = scaler.inverse_transform(predicted)actual = scaler.inverse_transform(y_test)# 转为DataFrame便于可视化
predicted_df = pd.DataFrame(predicted, columns=['temperature', 'humidity', 'wind_speed'])
actual_df = pd.DataFrame(actual, columns=['temperature', 'humidity', 'wind_speed'])

对模型预测结果进行展示,具体包括以下信息:

  • 温度预测结果:展示LSTM对温度的预测与实际值的比较。

  • 湿度预测结果:展示LSTM对湿度的预测与实际值的差距。

  • 风速预测结果:分析风速的预测效果。

  • 多特征趋势对比:对比所有特征在不同时间段的预测效果。

colors = ['#1f77b4', '#ff7f0e']  # 蓝色:实际值,橙色:预测值
fig, axes = plt.subplots(3, 1, figsize=(12, 10))# 标题和字体设置
fig.suptitle('Weather Prediction Using LSTM', fontsize=16, weight='bold')# 温度预测图
axes[0].plot(actual_df['temperature'], color=colors[0], label='Actual Temperature', linewidth=1.5)
axes[0].plot(predicted_df['temperature'], color=colors[1], linestyle='--', label='Predicted Temperature', linewidth=1.5)
axes[0].set_title('Temperature Prediction', fontsize=14, weight='bold')
axes[0].set_ylabel('Temperature (°C)', fontsize=12)
axes[0].legend(fontsize=10, loc='upper right')
axes[0].grid(alpha=0.3)# 湿度预测图
axes[1].plot(actual_df['humidity'], color=colors[0], label='Actual Humidity', linewidth=1.5)
axes[1].plot(predicted_df['humidity'], color=colors[1], linestyle='--', label='Predicted Humidity', linewidth=1.5)
axes[1].set_title('Humidity Prediction', fontsize=14, weight='bold')
axes[1].set_ylabel('Humidity (%)', fontsize=12)
axes[1].legend(fontsize=10, loc='upper right')
axes[1].grid(alpha=0.3)# 风速预测图
axes[2].plot(actual_df['wind_speed'], color=colors[0], label='Actual Wind Speed', linewidth=1.5)
axes[2].plot(predicted_df['wind_speed'], color=colors[1], linestyle='--', label='Predicted Wind Speed', linewidth=1.5)
axes[2].set_title('Wind Speed Prediction', fontsize=14, weight='bold')
axes[2].set_ylabel('Wind Speed (km/h)', fontsize=12)
axes[2].set_xlabel('Days', fontsize=12)
axes[2].legend(fontsize=10, loc='upper right')
axes[2].grid(alpha=0.3)# 调整布局并显示
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

使用了三个基本的折线图来对比LSTM模型在温度、湿度和风速预测方面的实际值和预测值:

 温度预测的图形展示了LSTM模型对温度时间序列的捕捉能力。如果预测线能够紧密跟随实际温度曲线,说明模型能较好地捕捉温度的变化趋势。如果偏差较大,则需要调整模型复杂度或序列长度。

湿度预测的图形反映了LSTM对湿度时序变化的拟合效果。通常湿度变化较温度更不规则,因此湿度预测的误差可能更大,这提示我们可以考虑将湿度数据的平滑度处理,减少噪声。

风速图形反映了模型在风速数据上的预测效果。如果预测值偏差较大,可能说明风速的时序特征在当前的LSTM结构下未能得到充分捕捉,这时可以尝试增加风速数据的周期性特征,或调整输入序列长度。

5. 模型优化方向

LSTM模型的性能在很大程度上依赖于参数设置和数据处理,下面论述一些比较重要的方面。

5.1 隐藏层数量和单元数优化

在 LSTM 中,隐藏层数量和每一层的隐藏单元数会影响模型的复杂度。通常情况下,较高的隐藏单元数和更多的LSTM层能够捕捉更复杂的时序特征,但过多的隐藏单元数和层数可能导致过拟合。因此可以尝试:

  • 单层LSTM vs 多层LSTM:从1层开始,如果模型效果不理想可以尝试增加到2-3层,逐渐观察效果的提升。

  • 单元数(Hidden Units):一般来说,选择16、32、64、128等值逐步增加,同时注意训练时间和过拟合的风险。

5.2 学习率调整

学习率是优化器的重要参数之一,它决定了每次参数更新的步长。在训练过程中,可以使用学习率衰减策略,即随着训练轮次增加逐步减小学习率,帮助模型在接近最优点时更加平稳地收敛。常见策略:

  • Step Decay:每隔一定轮次将学习率缩小至原来的某个比例(如0.1倍)。

  • Exponential Decay:每次更新时将学习率按指数函数递减。

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
for epoch in range(num_epochs):# 模型训练代码...optimizer.step()scheduler.step()  # 调整学习率

5.3 正则化手段

LSTM 模型可能会因数据有限而出现过拟合问题,适当的正则化手段可以提高模型的泛化能力:

  • Dropout:LSTM层中添加dropout可以有效防止过拟合。

  • L2正则化:在损失函数中添加L2惩罚项,限制权重的过大波动。

self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=0.2)

5.4 批量大小调整

批量大小决定了每次训练中使用的数据量,合适的批量大小(如32、64、128等)在计算效率和泛化性能上会有较好的平衡。对于时间序列数据,一般来说,较小的批量可以帮助捕捉更多的特征信息。

6. 调参流程

在优化模型时,系统化的调参流程能够提高效率并找到最佳参数组合。推荐的几个调参方式:

  • 确定基本模型结构:先从简单的LSTM结构入手,比如1层LSTM,16个隐藏单元,学习率0.01。

  • 逐步增加复杂度:根据模型初始结果,逐渐增加隐藏单元数或层数,并观察训练集和测试集的误差变化。

  • 优化学习率和批量大小:通过实验不同的学习率(0.01,0.001等)和批量大小,找到误差最小且收敛速度较快的组合。

  • 添加正则化项:当模型效果较好但存在过拟合时,添加正则化手段(如Dropout)并调整比例(如0.1、0.2等)。

  • 迭代实验:通过实验记录并分析结果曲线,继续微调参数,直至得到满意的结果。

相关文章:

深度学习使用LSTM实现时间序列预测

大家好,LSTM是一种特殊的循环神经网络(RNN)架构,它被设计用来解决传统RNN在处理长序列数据时的梯度消失和梯度爆炸问题,特别是在时间序列预测、自然语言处理和语音识别等领域中表现出色。LSTM的核心在于其独特的门控机…...

Vue第一篇:组件模板总结

前言 本文希望读者有一定的Vue开发经验&#xff0c;样例采用vue中的单文件组件&#xff0c;也是我的个人笔记&#xff0c;欢迎一起进步 必须有根元素 这是一个最简单的vue单文件组件&#xff0c;<template></template>被称为模板&#xff0c;模板中必须有一个根元素…...

时钟使能、

时钟使能 如果正确使用&#xff0c;时钟使能能够显著地降低系统功耗&#xff0c;同时对面积或性能的影响极小。但是如果不正确地使用时钟使能&#xff0c; 可能会造成下列后果&#xff1a; • 面积增大 • 密度减小 • 功耗上升 • 性能下降 在许多使用大量控制集的…...

1. Autogen官网教程 (Introduction to AutoGen)

why autogen The whole is greater than the sum of its parts.(整体的功能或价值往往超过单独部分简单相加的总和。) -Aristotle autogen 例子 1. 导入必要的库 首先&#xff0c;导入os库和autogen库中的ConversableAgent类。 import os from autogen import Conversable…...

开源账目和账单

开源竞争&#xff1a; 开源竞争&#xff08;当你无法彻底掌握技术的时候&#xff0c;你就开源这个技术&#xff0c;让更多的人了解这个技术&#xff0c;形成更多的技术依赖&#xff0c;你会说这不就是在砸罐子吗&#xff1f;一个行业里面总会有人砸罐子&#xff0c;你不如先砸…...

vue2面试题10|[2024-11-24]

问题1&#xff1a;vue设置代理 如果你的前端应用和后端API服务器没有运行在同一个主机上&#xff0c;你需要在开发环境下将API请求代理到API服务器。这个问题可以通过vue.config.js中的devServer.proxy选项来配置。 1.devServer.proxy可以是一个指向开发环境API服务器的字符串&…...

c语言与c++到底有什么区别?

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///C爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于c语言与c区别的相关内容&#xff01; 关…...

云计算-华为HCIA-学习笔记

笔者今年7月底考取了华为云计算方向的HCIE认证&#xff0c;回顾从IA到IE的学习和项目实战&#xff0c;想整合和分享自己的学习历程&#xff0c;欢迎志同道合的朋友们一起讨论&#xff01; 第二章&#xff1a;服务器基础 服务器是什么&#xff1f; 服务器本质上就是个性能超强的…...

优先算法 —— 双指针系列 - 复写零

目录 1. 复写零 2. 算法原理 一般情况下 改为就地操作&#xff1a;从左到右&#xff08;错误&#xff09; 从右到左 总结一下解决方法&#xff1a; 如何找到最后一个复写的数 特殊情况 完整步骤&#xff1a; 3. 代码 1. 复写零 题目链接&#xff1a;1089. 复写零 - 力…...

初识Linux—— 基本指令(下)

前言&#xff1a; 本篇继续来学习Linux的基础指令&#xff0c;继续加油&#xff01;&#xff01;&#xff01; 本篇文章对于图片即内容详解&#xff0c;已同步到本人gitee&#xff1a;Linux学习: Linux学习与知识讲解 Linux指令 1、查看文件内容的指令 cat ​ cat 查看文件…...

esayexcel进行模板下载,数据导入,验证不通过,错误信息标注在excel上进行返回下载

场景&#xff1a;普普通通模板下载&#xff0c;加数据导入&#xff0c;分全量和增量&#xff0c;预计20w数据&#xff0c;每一条数据校验&#xff0c;前后端代码贴上&#xff08;代码有删改&#xff0c;关键代码都有&#xff0c;好朋友们自己取舍&#xff0c;代码一股脑贴上了&…...

服务器数据恢复—raid5阵列热备盘上线失败导致EXT3文件系统不可用的数据恢复案例

服务器数据恢复环境&#xff1a; 两组分别由4块SAS硬盘组建的raid5阵列&#xff0c;两组阵列划分的LUN组成LVM架构&#xff0c;格式化为EXT3文件系统。 服务器故障&#xff1a; 一组raid5阵列中的一块硬盘离线。热备盘自动上线替换离线硬盘&#xff0c;但在热备盘上线同步数据…...

《Qt Creator:人工智能时代的跨平台开发利器》

《Qt Creator&#xff1a;人工智能时代的跨平台开发利器》 一、Qt Creator 简介&#xff08;一&#xff09;功能和优势&#xff08;二&#xff09;快捷键与效率提升&#xff08;三&#xff09;跨平台支持&#xff08;四&#xff09;工具介绍与使用主要特性&#xff1a;使用步骤…...

AG32既可以做MCU,也可以仅当CPLD使用

Question: AHB总线上的所有外设都需要像ADC一样&#xff0c;通过cpld处理之后才能使用? Reply: 不用。 除了ADC外&#xff0c;其他都是 mcu可以直接配置使用的。 Question: DMA和CMP也不用? Reply: DMA不用。 ADC/DAC/CMP 用。 CMP 其实配置好后&#xff0c;可以直…...

51c自动驾驶~合集31

我自己的原文哦~ https://blog.51cto.com/whaosoft/12121357 #大语言模型会成为自动驾驶的灵丹妙药吗 人工智能&#xff08;AI&#xff09;在自动驾驶&#xff08;AD&#xff09;研究中起着至关重要的作用&#xff0c;推动其向智能化和高效化发展。目前AD技术的发展主要遵循…...

2023年3月GESPC++一级真题解析

一、单选题&#xff08;每题2分&#xff0c;共30分&#xff09; 题目123456789101112131415答案BAACBDDAADBCDBC 1.以下不属于计算机输入设备的有&#xff08; &#xff09;。 A &#xff0e;键盘 B &#xff0e;音箱 C &#xff0e;鼠标 D &#xff0e;传感器 【答案】 …...

linux NFS

什么是NFS NFS是Network File System的缩写&#xff0c;即网络文件系统。一种使用于分散式 文件协议通过网络让不同的机器、不同的操作系统能够分享个人数据&#xff0c;让应用 程序通过网络可以访问位于服务器磁盘中的数据。NFS在文件传送或信息传送 的过程中&#xff0c;依赖…...

查看浏览器的请求头

爬虫时用到了请求头&#xff0c;虽然可以用网上公开的&#xff0c;但是还是想了解一下本机浏览器的。以 Edge 为例&#xff0c;其余浏览器通用。 打开浏览器任一网页&#xff0c;按F12打开DevTools&#xff1b;或鼠标右键&#xff0c;选择“检查”。首次打开界面应该显示在网页…...

【JavaEE进阶】 JavaScript

本节⽬标 了解什么是JavaScript, 学习JavaScript的常⻅操作, 以及使⽤JQuery完成简单的⻚⾯元素操作. 一. 初识 JavaScript 1.JavaScript 是什么 JavaScript (简称 JS), 是⼀个脚本语⾔, 解释型或即时编译型的编程语⾔. 虽然它是作为开发Web⻚⾯的脚本语⾔⽽出名&#xff0c;…...

后端接受大写参数(亲测能用)

重要点引入包别引用错了 import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.annotation.JsonProperty; import lombok.Data;JsonSerialize Data public class Item {JsonProperty(value "Token")private String token…...

浅谈 React Hooks

React Hooks 是 React 16.8 引入的一组 API&#xff0c;用于在函数组件中使用 state 和其他 React 特性&#xff08;例如生命周期方法、context 等&#xff09;。Hooks 通过简洁的函数接口&#xff0c;解决了状态与 UI 的高度解耦&#xff0c;通过函数式编程范式实现更灵活 Rea…...

Java 语言特性(面试系列1)

一、面向对象编程 1. 封装&#xff08;Encapsulation&#xff09; 定义&#xff1a;将数据&#xff08;属性&#xff09;和操作数据的方法绑定在一起&#xff0c;通过访问控制符&#xff08;private、protected、public&#xff09;隐藏内部实现细节。示例&#xff1a; public …...

转转集团旗下首家二手多品类循环仓店“超级转转”开业

6月9日&#xff0c;国内领先的循环经济企业转转集团旗下首家二手多品类循环仓店“超级转转”正式开业。 转转集团创始人兼CEO黄炜、转转循环时尚发起人朱珠、转转集团COO兼红布林CEO胡伟琨、王府井集团副总裁祝捷等出席了开业剪彩仪式。 据「TMT星球」了解&#xff0c;“超级…...

鸿蒙DevEco Studio HarmonyOS 5跑酷小游戏实现指南

1. 项目概述 本跑酷小游戏基于鸿蒙HarmonyOS 5开发&#xff0c;使用DevEco Studio作为开发工具&#xff0c;采用Java语言实现&#xff0c;包含角色控制、障碍物生成和分数计算系统。 2. 项目结构 /src/main/java/com/example/runner/├── MainAbilitySlice.java // 主界…...

大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计

随着大语言模型&#xff08;LLM&#xff09;参数规模的增长&#xff0c;推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长&#xff0c;而KV缓存的内存消耗可能高达数十GB&#xff08;例如Llama2-7B处理100K token时需50GB内存&a…...

算法笔记2

1.字符串拼接最好用StringBuilder&#xff0c;不用String 2.创建List<>类型的数组并创建内存 List arr[] new ArrayList[26]; Arrays.setAll(arr, i -> new ArrayList<>()); 3.去掉首尾空格...

算法岗面试经验分享-大模型篇

文章目录 A 基础语言模型A.1 TransformerA.2 Bert B 大语言模型结构B.1 GPTB.2 LLamaB.3 ChatGLMB.4 Qwen C 大语言模型微调C.1 Fine-tuningC.2 Adapter-tuningC.3 Prefix-tuningC.4 P-tuningC.5 LoRA A 基础语言模型 A.1 Transformer &#xff08;1&#xff09;资源 论文&a…...

作为测试我们应该关注redis哪些方面

1、功能测试 数据结构操作&#xff1a;验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化&#xff1a;测试aof和aof持久化机制&#xff0c;确保数据在开启后正确恢复。 事务&#xff1a;检查事务的原子性和回滚机制。 发布订阅&#xff1a;确保消息正确传递。 2、性…...

基于PHP的连锁酒店管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的连锁酒店管理系统 一 介绍 连锁酒店管理系统基于原生PHP开发&#xff0c;数据库mysql&#xff0c;前端bootstrap。系统角色分为用户和管理员。 技术栈 phpmysqlbootstrapphpstudyvscode 二 功能 用户 1 注册/登录/注销 2 个人中…...

Qemu arm操作系统开发环境

使用qemu虚拟arm硬件比较合适。 步骤如下&#xff1a; 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载&#xff0c;下载地址&#xff1a;https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...