当前位置: 首页 > article >正文

transformer➕lstm训练回归模型

使用 Transformer 和 LSTM 优化时序数据回归模型:全流程分析

在机器学习和深度学习中,处理时序数据是一项常见的任务。无论是金融预测、气象预测还是库存管理等领域,时序数据都扮演着至关重要的角色。对于时序数据的建模,深度学习模型,如 LSTM(长短期记忆网络)和 Transformer,已被广泛应用。本文将介绍如何结合 LSTM 和 Transformer 模块,构建一个优化后的回归模型,并展示从数据生成到模型训练的全流程。

目录

  1. 数据生成与处理

  2. 模型构建与优化

  3. 模型训练与评估

  4. 总结与展望


数据生成与处理

在时序数据建模中,首先需要准备数据。我们将生成一组合成的时序数据,并进行数据预处理,使其适应 LSTM 和 Transformer 模型的输入要求。

生成合成时序数据

我们使用 Python 库 timeseries-generator 来生成包含线性趋势和白噪声的时序数据,数据形式类似于股票价格或传感器数据的变化。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from timeseries_generator import LinearTrend, WhiteNoise, Generator# 设置随机种子以确保结果可重现
np.random.seed(42)# 定义线性趋势和白噪声
lt = LinearTrend(coef=0.5, offset=10.0, col_name="linear_trend")
wn = WhiteNoise(stdev_factor=0.1)# 生成数据
g = Generator(factors={lt, wn}, features=None, date_range=pd.date_range(start="2020-01-01", end="2020-12-31"))
g.generate()# 获取生成的数据
df = g.df
df['target'] = df['linear_trend'] + df['white_noise'] + np.random.normal(0, 0.1, len(df))# 可视化生成的数据
plt.figure(figsize=(10, 6))
plt.plot(df['date'], df['target'], label='Generated Time Series')
plt.xlabel('Date')
plt.ylabel('Value')
plt.title('Synthetic Time Series Data')
plt.legend()
plt.grid(True)
plt.show()

这段代码生成了包含线性趋势和噪声的时序数据,并可视化了其变化趋势。

数据预处理

为了将数据输入到 LSTM 和 Transformer 模型中,我们需要对数据进行归一化处理,并将其转换为适合模型输入的格式。

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split# 选择特征和目标变量
X = df[['linear_trend', 'white_noise']].values
y = df['target'].values# 归一化处理
scaler_X = MinMaxScaler()
scaler_y = MinMaxScaler()
X_scaled = scaler_X.fit_transform(X)
y_scaled = scaler_y.fit_transform(y.reshape(-1, 1))# 重塑 X 为 LSTM 输入格式
X_scaled = X_scaled.reshape((X_scaled.shape[0], 1, X_scaled.shape[1]))# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_scaled, test_size=0.2, shuffle=False)print(f"Training data shape: {X_train.shape}")
print(f"Test data shape: {X_test.shape}")

这段代码将数据标准化,并将其重塑为 LSTM 输入格式。


模型构建与优化

在本次任务中,我们将结合 Transformer 和 LSTM 模块来构建一个优化的时序数据回归模型。Transformer 模块负责捕捉长程依赖关系,LSTM 模块负责建模短期时序依赖。

模型定义

我们定义了一个包含 Transformer 和 LSTM 的混合模型。Transformer 模块采用了多头自注意力机制(MultiHeadAttention),并与 LSTM 网络共同处理时序数据。

import tensorflow as tf
from tensorflow.keras import layers, regularizersclass TransformerLSTMModel(tf.keras.Model):def __init__(self, input_dim, output_dim, lstm_units=64, transformer_heads=4, transformer_dim=64, dropout_rate=0.5, l2_reg=0.01, initial_lr=0.1):super(TransformerLSTMModel, self).__init__()# Transformer 模块self.transformer_attention = layers.MultiHeadAttention(num_heads=transformer_heads, key_dim=transformer_dim)self.transformer_dropout = layers.Dropout(dropout_rate)self.transformer_norm = layers.LayerNormalization()# LSTM 模块self.lstm_layer = layers.LSTM(lstm_units, return_sequences=True)self.lstm_norm = layers.LayerNormalization()# 全连接层self.dense1 = layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(l2_reg))self.batch_norm1 = layers.BatchNormalization()self.dropout1 = layers.Dropout(dropout_rate)self.dense2 = layers.Dense(64, activation='relu', kernel_regularizer=regularizers.l2(l2_reg))self.batch_norm2 = layers.BatchNormalization()self.dropout2 = layers.Dropout(dropout_rate)# 输出层self.output_layer = layers.Dense(output_dim, activation='linear')# 学习率调度self.lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_lr, decay_steps=100000, decay_rate=0.96, staircase=True)self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr_schedule)def call(self, inputs):# Transformer 模块x = self.transformer_attention(inputs, inputs)x = self.transformer_dropout(x)x = self.transformer_norm(x)x = layers.Add()([x, inputs])  # 残差连接# LSTM 模块x = self.lstm_layer(x)x = self.lstm_norm(x)# 聚合时间步的输出x = layers.GlobalAveragePooling1D()(x)# 全连接层x = self.dense1(x)x = self.batch_norm1(x)x = self.dropout1(x)x = self.dense2(x)x = self.batch_norm2(x)x = self.dropout2(x)# 输出outputs = self.output_layer(x)return outputsdef compile_model(self):self.compile(optimizer=self.optimizer,loss='mean_squared_error',metrics=['mae'])  # 回归任务使用均方误差损失函数

模型优化

在模型优化过程中,我们使用了以下技术:

  1. Transformer 模块: 使用多头自注意力机制捕捉全局信息。

  2. LSTM 模块: 负责处理时序依赖,帮助模型理解时间序列的短期依赖关系。

  3. 正则化: 通过 L2 正则化、Dropout 和批量归一化,防止过拟合。

  4. 学习率调度: 使用指数衰减学习率调度,在训练过程中动态调整学习率,优化训练过程。


模型训练与评估

在数据预处理和模型构建完成后,我们开始训练模型,并监控训练过程中的损失和 MAE 曲线。

# 训练模型
history = model.fit(X_train, y_train, epochs=30, batch_size=32, validation_data=(X_test, y_test))# 评估模型
test_loss, test_mae = model.evaluate(X_test, y_test)
print(f"Test Loss: {test_loss}")
print(f"Test MAE: {test_mae}")# 绘制训练过程中的损失和 MAE 曲线
import matplotlib.pyplot as pltplt.figure(figsize=(12, 6))# 绘制损失曲线
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Loss over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()# 绘制 MAE 曲线
plt.subplot(1, 2, 2)
plt.plot(history.history['mae'], label='Train MAE')
plt.plot(history.history['val_mae'], label='Val MAE')
plt.title('Mean Absolute Error over Epochs')
plt.xlabel('Epochs')
plt.ylabel('MAE')
plt.legend()plt.tight_layout()
plt.show()

训练完成后,我们打印了测试集上的损失(Loss)和 MAE(Mean Absolute Error)结果,帮助我们了解模型的表现。

未来展望

  • 模型进一步优化:

可以尝试使用其他先进的模型架构,如 TCN(时序卷积网络)或 GRU(门控循环单元),以提高预测精度。

  • 超参数调整: 通过网格搜索或随机搜索来调整模型的超参数,以达到更好的预测效果。

  • 多任务学习: 该模型还可以扩展为多任务学习,用于解决多个相关的时序预测任务。

通过这样的深度学习模型,可以有效地捕捉时序数据中的复杂模式,提升预测的准确性。

 

相关文章:

transformer➕lstm训练回归模型

使用 Transformer 和 LSTM 优化时序数据回归模型:全流程分析 在机器学习和深度学习中,处理时序数据是一项常见的任务。无论是金融预测、气象预测还是库存管理等领域,时序数据都扮演着至关重要的角色。对于时序数据的建模,深度学习…...

用卷积神经网络 (CNN) 实现 MNIST 手写数字识别

在深度学习领域,MNIST 手写数字识别是经典的入门级项目,就像编程世界里的 “Hello, World”。卷积神经网络(Convolutional Neural Network,CNN)作为处理图像数据的强大工具,在该任务中展现出卓越的性能。本…...

windows的rancherDesktop修改镜像源

您好!要在Windows系统上的Rancher Desktop中修改Docker镜像源(即设置registry mirror),您需要根据Rancher Desktop使用的容器运行时(containerd或dockerd)进行配置。用户提到“allowed-image”没有效果&…...

spring中的@ComponentScan注解详解

ComponentScan 是 Spring 框架中用于自动扫描并注册组件的核心注解,它简化了 Spring 应用中 Bean 的发现和装配流程。以下从核心功能、属性解析、使用场景及示例等方面进行详细说明。 一、核心功能与作用 自动扫描组件 ComponentScan 会扫描指定包及其子包下的类&am…...

机器学习之嵌入(Embeddings):从理论到实践

机器学习之嵌入(Embeddings):从理论到实践 摘要 本文深入探讨了机器学习中嵌入(Embeddings)的概念和应用。通过具体的实例和可视化展示,我们将了解嵌入如何将高维数据转换为低维表示,以及这种转换在推荐系统、自然语言处理等领域的实际应用…...

深入剖析 I/O 复用之 select 机制

深入剖析 I/O 复用之 select 机制 在网络编程中,I/O 复用是一项关键技术,它允许程序同时监控多个文件描述符的状态变化,从而高效地处理多个 I/O 操作。select 作为 I/O 复用的经典实现方式,在众多网络应用中扮演着重要角色。本文…...

SpringBoot指定项目层日志记录

1、新建一个Springboot项目&#xff0c;添加Lombok依赖&#xff08;注意&#xff1a;这里使用的Lombok下的Slf4j快速日志记录方式&#xff09; <dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId></dependenc…...

RISC-V hardfault分析工具,RTTHREAD-RVBACKTRACE

RV BACKTRACE 简介 本文主要讲述RV BACKTRACE 的内部主要原理 没有接触过rvbacktrace可以看下面两篇文章&#xff0c;理解一下如何使用RVBACKTRACE RVBacktrace RISC-V极简栈回溯组件&#xff1a;https://club.rt-thread.org/ask/article/64bfe06feb7b3e29.html RVBacktra…...

xiaopiu原型设计工具笔记

文章目录 有没有行组件是否支持根据图片生成原型呢? 其他官网 做项目要用到原型设计&#xff0c;还是那句话&#xff0c;遇到的必须会用&#xff0c;走起。 支持本地也支持线上。 有没有行组件 是这样&#xff0c;同一行有多个字段&#xff0c;如何弄的准确点呢? 目前只会弄…...

matlab 中function的用法

matlab 中function的用法 前言介绍1. 基本语法示例&#xff08;1&#xff09;可以直接输出&#xff08;2&#xff09;调用函数 2.输入参数和输出参数示例多输入参数和输出参数定义一个函数&#xff0c;计算两个数的和与差&#xff1a;调用该函数&#xff1a; 3. 默认参数示例 4…...

解锁 LLM 推理速度:深入 FlashAttention 与 PagedAttention 的原理与实践

写在前面 大型语言模型 (LLM) 已经渗透到我们数字生活的方方面面,从智能问答、内容创作到代码辅助,其能力令人惊叹。然而,驱动这些强大模型的背后,是对计算资源(尤其是 GPU)的巨大需求。在模型推理 (Inference) 阶段,即模型实际对外提供服务的阶段,速度 (Latency) 和吞…...

4个纯CSS自定义的简单而优雅的滚动条样式

今天发现 uni-app 项目的滚动条不显示&#xff0c;查了下原来是设置了 ::-webkit-scrollbar {display: none; } 那么怎么用 css 设置滚动条样式呢&#xff1f; 定义滚动条整体样式‌ ::-webkit-scrollbar 定义滚动条滑块样式 ::-webkit-scrollbar-thumb 定义滚动条轨道样式‌…...

查看jdk是否安装并且配置成功?(Android studio安装前的准备)

WinR输入cmd打开命令提示窗口 输入命令 java -version 回车显示如下&#xff1a;...

5月8日直播见!Atlassian Team‘25大会精华+AI实战分享

在刚刚落幕的 Atlassian Team’25 全球大会上&#xff0c;Atlassian发布了多项重磅创新&#xff0c;全面升级其协作平台&#xff0c;涵盖从Al驱动、知识管理到跨团队协作&#xff0c;再到战略执行的各个方面。 为帮助中国用户深入了解这些前沿动态&#xff0c;Atlassian全球白…...

Windows系统下使用Kafka和Zookeeper,Python运行kafka(一)

下载和安装见Linux系统下使用Kafka和Zookeeper 配置 Zookeeper Zookeeper 是 Kafka 所依赖的分布式协调服务。在 Kafka 解压目录下,有一个 Zookeeper 的配置文件模板config/zookeeper.properties,你可以直接使用默认配置。 启动 Zookeeper 打开命令提示符(CMD),进入 K…...

C++之“继承”

继续开始关于C相关的内容。C作为面向对象的语言&#xff0c;有三大特性&#xff1a;封装&#xff0c;继承&#xff0c;多态。 这篇文章我们开始学习&#xff1a;继承。 一、继承的概念和定义 1. 继承的概念 什么是继承呢&#xff1f; 字面意思理解来看&#xff1a;继承就是…...

Webug4.0靶场通关笔记19- 第24关邮箱轰炸

目录 第24关 邮箱轰炸 1.配置环境 2.打开靶场 3.源码分析 4.邮箱轰炸 &#xff08;1&#xff09;注册界面bp抓包 &#xff08;2&#xff09;发送到intruder &#xff08;3&#xff09;配置position &#xff08;4&#xff09;配置payload &#xff08;5&#xff09;开…...

java CompletableFuture 异步编程工具用法1

1、测试异步调用&#xff1a; static void testCompletableFuture1() throws ExecutionException, InterruptedException {// 1、无返回值的异步任务。异步线程执行RunnableCompletableFuture.runAsync(() -> System.out.println("only you"));// 2、有返回值的异…...

缺乏实体人形机器人的主流高精度仿真方案

在缺乏实体人形机器人的情况下&#xff0c;可通过以下主流仿真方案实现高精度模拟&#xff08;基于2025年最新技术&#xff09;&#xff1a; 一、基础建模工具链 MATLAB Robotics Toolbox • 通过连杆(Link)和关节(Joint)定义生物力学参数 • 示例代码创建简化模型&#xff1a…...

基于STM32、HAL库的CP2104 USB转UART收发器 驱动程序设计

一、简介: CP2104是Silicon Labs公司推出的一款USB转UART桥接芯片,具有以下特点: USB 2.0全速兼容 集成USB收发器,无需外部电阻 支持UART数据传输,波特率从300bps到2Mbps 内置EEPROM可配置设备信息 支持RTS/CTS硬件流控制 3.3V I/O电平,内置5V至3.3V稳压器 紧凑的QFN-24…...

ERC-20与ERC-721:区块链代币标准的双星解析

一、代币标准的诞生背景 在以太坊生态中&#xff0c;代币标准是构建去中心化应用&#xff08;DApps&#xff09;的基石。ERC-20与ERC-721分别代表同质化与非同质化代币的两大核心标准&#xff0c;前者支撑着90%以上的加密资产流通&#xff0c;后者则开启了数字资产唯一性的新时…...

使用Go语言对接全球股票数据源API实践指南

使用Go语言对接全球股票数据API实践指南 概述 本文介绍如何通过Go语言对接支持多国股票数据的API服务。我们将基于提供的API文档&#xff0c;实现包括市场行情、K线数据、实时推送等核心功能的对接。 一、准备工作 1. 获取API Key 联系服务提供商获取访问密钥&#xff08;替…...

经典密码学算法实现

# AES-128 加密算法的规范实现&#xff08;不使用外部库&#xff09; # ECB模式S_BOX [0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B,0xFE, 0xD7, 0xAB, 0x76, 0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0,0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0x…...

git 远程仓库管理详解

Git 的远程仓库管理是多人协作和代码共享的核心功能。以下是 Git 远程仓库管理的详细说明&#xff0c;包括常用操作、命令和最佳实践。 1. 什么是远程仓库&#xff1f; 远程仓库&#xff08;Remote Repository&#xff09;&#xff1a;存储在网络服务器上的 Git 仓库&#xff0…...

ABP vNext + gRPC 实现服务间高速通信

ABP vNext gRPC 实现服务间高速通信 &#x1f4a8; 在现代微服务架构中&#xff0c;服务之间频繁的调用往往对性能构成挑战。尤其在电商秒杀、金融风控、实时监控等对响应延迟敏感的场景中&#xff0c;传统 REST API 面临序列化负担重、数据体积大、通信延迟高等瓶颈。 本文…...

若依框架Ruoyi-vue整合图表Echarts中国地图标注动态数据

若依框架Ruoyi-vue整合图表Echarts中国地图 概述创作灵感预期效果整合教程前期准备整合若依框架1、引入china.json2、方法3、data演示数据4、核心代码 完整代码[毫无保留]组件调用 总结 概述 首先&#xff0c;我需要回忆之前给出的回答&#xff0c;确保这次的内容不重复&#…...

京东(JD)API 商品详情数据接口讲解及 JSON 示例

前言 京东开放平台提供了多种商品详情相关的 API 接口&#xff0c;开发者可以通过这些接口获取商品的详细信息。以下为接口调用方式及 JSON 返回数据的参考示例。 1. 接口调用方式 京东商品详情接口通常采用以下形式&#xff1a; 请求方式&#xff1a;GET/POST关键参数&…...

算法中的数学:约数

1.求一个整数的所有约数 对于一个整数x&#xff0c;他的其中一个约数若为i&#xff0c;那么x/i也是x的一个约数。而其中一个约数的大小一定小于等于根号x&#xff08;完全平方数则两个约数都为根号x&#xff09;&#xff0c;所以我们只需要遍历到根号x&#xff0c;然后计算出另…...

Python实例题:Python获取喜马拉雅音频

目录 Python实例题 题目 python-get-ximalaya-audioPython 获取喜马拉雅音频脚本 代码解释 get_audio_info 函数&#xff1a; download_audio 函数&#xff1a; 主程序&#xff1a; 运行思路 注意事项 Python实例题 题目 Python获取喜马拉雅音频 python-get-ximala…...

[监控看板]Grafana+Prometheus+Exporter监控疑难排查

采用GrafanaPrometheusExporter监控MySQL时发现经常数据不即时同步&#xff0c;本示例也是本地搭建采用。 Prometheus面板 1&#xff0c;Detected a time difference of 11h 47m 22.337s between your browser and the server. You may see unexpected time-shifted query res…...