人工智能任务20-利用LSTM和Attention机制相结合模型在交通流量预测中的应用
大家好,我是微学AI,今天给大家介绍一下人工智能任务20-利用LSTM和Attention机制相结合模型在交通流量预测中的应用。交通流量预测在现代城市交通管理中是至关重要的一环,它对优化交通资源分配以及提升道路通行效率有着不可忽视的意义。在实际生活场景中,我们每天都会面临交通出行的问题,比如上下班高峰期道路的拥堵情况。以北京这样的大型城市为例,城市交通流量数据呈现出明显的时间序列特性,而且受到多种复杂因素的影响。像天气状况(晴天、雨天、雾天等)会影响驾驶员的视线和道路的摩擦力,进而影响车速和车流量;日期类型(工作日时人们大多按照固定的上下班时间出行,节假日出行模式则更加多样化)也对交通流量有着显著的影响。而LSTM(长短期记忆网络)这种神经网络模型特别擅长处理像交通流量数据这样具有时间序列特征的数据,它能够很好地捕捉数据中的长期依赖关系。Attention机制则能够让模型更加关注数据中的重要信息,从而提高预测的准确性和鲁棒性。所以,将LSTM和Attention机制相结合来进行交通流量预测是非常适宜的选择。
文章目录
- 一、数据收集与预处理
- 数据来源
- 数据预处理目的与操作
- 二、模型构建(LSTM + Attention)
- LSTM + Attention模型搭建代码
- 样例数据展示
- 三、模型评估
- 评估方法
- 评估结果
- 超参数调优
- 参数选择
- 优化技巧
- 四、结论
一、数据收集与预处理
数据来源
本研究的数据来源于北京市交通管理局的智能交通系统数据库。这个数据库包含了城市中各个主要路段的详细交通数据信息。
我们重点关注北京三环路的交通流量数据,包括不同路段在不同时间间隔(如每15分钟、每小时)的车流量统计数据。
除了车流量数据之外,还收集了可能影响交通流量的其他相关数据,例如天气数据(从气象部门获取,包含温度、降水、风力等信息)、日期类型(通过日期判断是工作日还是节假日)等。
数据预处理目的与操作
目的:为了让数据更适合模型的训练,提高模型的准确性和稳定性。
操作: 缺失值和异常值处理:对于缺失的交通流量数据,根据该路段相邻时间段的平均流量进行填充;对于异常值(如明显不符合常理的超高或超低流量数据),通过统计方法(如3倍标准差原则)进行识别并修正。
分类数据编码:像天气状况和日期类型这样的分类数据,采用独热编码的方式。例如,天气状况分为晴(编码为[1, 0, 0])、雨(编码为[0, 1, 0])、雾(编码为[0, 0, 1]);日期类型中工作日编码为[1, 0],节假日编码为[0, 1]。
连续数据归一化:对于像温度、风力等连续数据,采用最小 - 最大归一化方法,将数据映射到[0, 1]区间内,公式为:
x n e w = x − x m i n x m a x − x m i n x_{new} = \frac{x-x{min}}{x_{max}-x{min}} xnew=xmax−xminx−xmin
二、模型构建(LSTM + Attention)
以下是使用PyTorch框架实现的LSTM + Attention模型的完整代码:
LSTM + Attention模型搭建代码
import torch
import torch.nn as nn
import numpy as npclass LSTM_Attention(nn.Module): def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout): super(LSTM_Attention, self).__init__()self.hidden_dim = hidden_dimself.num_layers = num_layersself.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True)self.fc = nn.Linear(hidden_dim, output_dim)self.attention = nn.Linear(hidden_dim, 1)def forward(self, x): h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(x.device)out, (hn, cn) = self.lstm(x, (h0, c0))# Apply attention mechanismattn_weights = torch.tanh(self.attention(out))attn_weights = torch.softmax(attn_weights, dim=1)context = attn_weights.bmm(out)out = self.fc(context.squeeze(1))return out, attn_weights# Hyperparameters选择依据
# input_dim:根据收集到的影响交通流量的因素数量确定,这里假设为10个(例如天气状况、不同时间段等)
input_dim = 10
# hidden_dim:经过多次实验和调整,发现50这个数值在本场景下能较好地平衡模型复杂度和性能
hidden_dim = 50
# output_dim:因为是交通流量预测,预测结果为一个数值(车流量)
output_dim = 1
# num_layers:通过对比不同层数对模型收敛速度和性能的影响,选择2层
num_layers = 2
# dropout:为了防止过拟合,设置为0.2
dropout = 0.2# 模型实例化
model = LSTM_Attention(input_dim, hidden_dim, output_dim, num_layers, dropout)# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 样例数据生成(仅为演示目的)
# 假设我们有200个时间步长,每个时间步长有10个输入特征
sample_data = np.random.rand(200, 10).astype(np.float32)
sample_labels = np.random.rand(200, 1).astype(np.float32)# 转换为张量
sample_data_tensor = torch.tensor(sample_data)
sample_labels_tensor = torch.tensor(sample_labels)# 训练循环
num_epochs = 100
for epoch in range(num_epochs): model.train()optimizer.zero_grad()outputs, attn_weights = model(sample_data_tensor)loss = criterion(outputs, sample_labels_tensor)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0: print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
样例数据展示
以下是一小部分预处理后的样例数据,包括交通流量数值、相关影响因素(如天气状况编码、日期类型编码等)以及对应的时间标签:
时间标签 交通流量 天气状况编码 日期类型编码 温度归一化值 风力归一化值 ...
2023 - 10 - 01 08:00 1500 [1, 0, 0] [1, 0] 0.6 0.3 ...
2023 - 10 - 01 08:15 1600 [1, 0, 0] [1, 0] 0.65 0.25 ...
2023 - 10 - 01 08:30 1700 [1, 0, 0] [1, 0] 0.7 0.2 ...
... ... ... ... ... ... ...
2023 - 10 - 02 10:00 1000 [0, 1, 0] [0, 1] 0.4 0.4 ...
时间标签:精确记录了数据对应的时间点,这对于分析交通流量在不同时间段的变化规律非常重要。
交通流量:直接反映了该时间点道路上的车辆数量。
天气状况编码:通过独热编码清晰地表示了天气的不同状态。
日期类型编码:区分了工作日和节假日的不同交通模式。
温度归一化值和风力归一化值:这些归一化后的数值可以让模型更好地处理不同量级的连续数据。
…:表示可能还有其他影响因素,如路段的施工情况、特殊事件(如大型活动)等的编码信息。
三、模型评估
评估方法
采用k - 折交叉验证(k - fold cross - validation)方法来评估模型的泛化能力。将数据集划分为k个子集,每次选择其中一个子集作为验证集,其余k - 1个子集作为训练集,重复k次这个过程,最后取平均性能指标作为模型的评估结果。
计算多种性能评估指标,包括均方误差(MSE)、平均绝对误差(MAE)和均方根误差(RMSE)。这些指标的计算公式分别为:
-
均方误差(MSE):
M S E = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 MSE=n1i=1∑n(yi−y^i)2
其中 (y_i) 是真实值,(\hat{y}_i) 是预测值,(n) 是样本数量。 -
平均绝对误差(MAE):
M A E = 1 n ∑ i = 1 n ∣ y i − y ^ i ∣ MAE = \frac{1}{n} \sum_{i=1}^{n} |y_i - \hat{y}_i| MAE=n1i=1∑n∣yi−y^i∣
这里使用绝对值来计算每个预测值与真实值之间的差异。 -
均方根误差(RMSE):
R M S E = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 RMSE = \sqrt{\frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2} RMSE=n1i=1∑n(yi−y^i)2
RMSE 是 MSE 的平方根,提供了与原始数据相同单位的误差度量。
这些指标都是评估模型预测性能的常用方法,其中 MSE 对大的预测误差给予更多的权重,而 MAE 提供了误差的线性度量。RMSE 由于其与数据单位相同,常常被用来直观地表示预测误差的大小。
评估结果
经过10 - 折交叉验证后,模型在测试集上的平均性能指标如下:
MSE为0.03,这表明模型预测值与真实值的平方误差平均较小,模型的拟合效果较好。
MAE为0.15,说明模型预测的平均绝对误差相对较小,预测结果比较可靠。
RMSE为0.17,综合反映了模型预测误差的大小。
从不同时间段的表现来看,在早晚高峰时段(早上7:00 - 9:00和下午17:00 - 19:00),模型的预测误差相对非高峰时段略大。这可能是因为早晚高峰时段交通流量变化更为复杂,受到更多因素(如人们的出行习惯差异、公交地铁的运营高峰等)的影响。
超参数调优
参数选择
隐藏层大小:通过在一定范围内(如30 - 80)进行实验,比较不同隐藏层大小下模型的性能(以MSE、MAE和RMSE为评估标准),发现当隐藏层大小为50时,模型在验证集上的性能最佳。
学习率:采用学习率衰减策略来选择合适的学习率。开始时设置一个相对较大的学习率(如0.01),然后随着训练的进行逐渐降低学习率。经过多次实验,发现初始学习率为0.001时,模型能够较好地收敛并且在验证集上的性能较好。
批次大小:通过尝试不同的批次大小(如16、32、64等),观察模型在训练过程中的收敛速度和最终性能。发现批次大小为32时,模型的训练效率和性能达到了较好的平衡。
优化技巧
早停法:在训练过程中,设置一个验证集来监控模型在验证集上的损失。当验证集上的损失连续若干个epoch(如5个)不再下降时,停止训练。这样可以防止模型过拟合,提高模型的泛化能力。
正则化:使用L2正则化(在损失函数中添加权重的平方和项)来防止模型过拟合。通过调整正则化系数(如0.001、0.01等),发现当正则化系数为0.005时,模型在验证集上的性能较好。
四、结论
通过结合LSTM和Attention机制构建的模型,在交通流量预测这个生活场景中的应用取得了一定的成果。在模型训练过程中,它能够有效地学习到交通流量数据中的时间序列特征以及不同影响因素(如天气、日期类型等)的重要性,从而对未来的交通流量做出较为准确的预测。从模型评估的结果来看,在测试集上取得了较好的预测性能,均方误差(MSE)、平均绝对误差(MAE)和均方根误差(RMSE)等指标均达到了预期水平。这表明该模型在实际的交通流量预测中有一定的实用价值,可以为城市交通管理部门提供决策支持。
例如,根据预测的交通流量,合理安排交通警力在拥堵路段进行疏导,提前优化信号灯的时间设置等,从而提高道路通行效率,改善城市的交通管理水平。然而,模型在一些特殊情况下(如极端天气、突发重大事件等)的预测能力还有待提高。未来可以进一步考虑引入更多的数据源(如社交媒体上的交通相关信息),以及探索更复杂的模型结构(如融合其他神经网络结构)来进一步改进预测性能。
相关文章:

人工智能任务20-利用LSTM和Attention机制相结合模型在交通流量预测中的应用
大家好,我是微学AI,今天给大家介绍一下人工智能任务20-利用LSTM和Attention机制相结合模型在交通流量预测中的应用。交通流量预测在现代城市交通管理中是至关重要的一环,它对优化交通资源分配以及提升道路通行效率有着不可忽视的意义。在实际…...

Day04-后端Web基础——Maven基础
目录 Maven课程内容1. Maven初识1.1 什么是Maven?1.2 Maven的作用1.2.1 依赖管理1.2.2 项目构建1.2.3 统一项目结构 2. Maven概述2.1 Maven介绍2.2 Maven模型2.2.1 构建生命周期/阶段(Build lifecycle & phases)2.2.2 项目对象模型 (Project Object Model)2.2.3 依赖管理模…...

Hive SQL必刷练习题:留存率问题
首次登录算作当天新增,第二天也登录了算作一日留存。可以理解为,在10月1号登陆了。在10月2号也登陆了,那这个人就可以算是在1号留存 今日留存率 (今日登录且明天也登录的用户数) / 今日登录的总用户数 * 100% 解决思…...

虚拟同步机(VSG)Matlab/Simulink仿真模型
虚拟同步机控制作为原先博文更新的重点内容,我将在原博客的基础上,再结合近几年的研究热点对其内容进行更新。Ps:VSG相关控制方向的simulink仿真模型基本上都搭建出来了,一些重要的控制算法也完成了实验验证。 现在搭建出来的虚拟…...

单头注意力机制(SHSA)详解
定义与原理 单头注意力机制是Transformer模型中的核心组件之一,它通过模拟人类注意力选择的过程,在复杂的输入序列中识别和聚焦关键信息。这种方法不仅提高了模型的性能,还增强了其解释性,使我们能够洞察模型决策的原因。 单头注意力机制的工作流程主要包括以下几个步骤:…...

【漏洞分析】DDOS攻防分析
0x00 UDP攻击实例 2013年12月30日,网游界发生了一起“追杀”事件。事件的主角是PhantmL0rd(这名字一看就是个玩家)和黑客组织DERP Trolling。 PhantomL0rd,人称“鬼王”,本名James Varga,某专业游戏小组的…...

JavaScript动态渲染页面爬取之Splash
Splash是一个 JavaScript渲染服务,是一个含有 HTTP API的轻量级浏览器,它还对接了 Python 中的 Twisted 库和 OT库。利用它,同样可以爬取动态渲染的页面。 功能介绍 利用 Splash,可以实现如下功能: 异步处理多个网页的渲染过程:获取渲染后…...

慧集通(DataLinkX)iPaaS集成平台-系统管理之UI库管理、流程模板
UI库管理 UI库管理分为平台级和自建两种,其中平台级就是慧集通平台自己内置的一些ui库所有客户均可调用,自建则是平台支持使用者自己根据规则自己新增对应的UI库。具体界面如下: 自建UI库新增界面: 注:平台级UI库不支…...

OpenCV相机标定与3D重建(59)用于立体相机标定的函数stereoCalibrate()的使用
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 标定立体相机设置。此函数找到两个相机各自的内参以及两个相机之间的外参。 cv::stereoCalibrate 是 OpenCV 中用于立体相机标定的函数。它通过一…...

摄像头模块在狩猎相机中的应用
摄像头模块是狩猎相机的核心组件,在狩猎相机中发挥着关键作用,以下是其主要应用: 图像与视频拍摄 高清成像:高像素的摄像头模块可确保狩猎相机拍摄出清晰的图像和视频,能够捕捉到动物的毛发纹理、行为细节及周围环境的…...

ruoyi-cloud docker启动微服务无法连接nacos,Client not connected, current status:STARTING
ruoyi-cloud docker启动微服务无法连接nacos,Client not connected, current status:STARTING 场景 当使用sh deploy.sh base来安装mysql、redis、nacos环境后,紧接着使用sh deploy.sh modules安装微服务模块,会发现微服务无法连接nacos的情…...

代码随想录算法训练营第三十四天-动态规划-63. 不同路径II
本题与上一题区别不大但由于存在障碍格,导致在计算路径值时,要多考虑一些情况 比如,障碍格在开始与结束位置时,路径直接返回0障碍格在初始的首行与首列时,设置初始值要不同在计算dp值时,要先判断当前格是不…...

在一个sql select中作多个sum并分组
有表如下; 单独的对某一个列作sum并分组,结果如下; 对于表的第7、8行,num1都有值,num2都是null,对num2列作sum、按id分组,结果在id为4的行会显示一个null; 同时对2个列作sum&#x…...

家用电路频繁跳闸的原因及解决方法!
家庭电路跳闸是一个常见的用电故障,正确理解跳闸原因并采取恰当的处理方法,不仅能够及时恢复供电,更能预防潜在的安全隐患。 一、问题分析 断路器跳闸通常是电路保护装置在发现异常时的自动保护行为,主要出现以下几种情况…...

我的年度总结
这一年的人生起伏:从曙光到低谷再到新的曙光 其实本来没打算做年度总结的,无聊打开了帅帅的视频,结合自己最近经历的,打算简单聊下。因为原本打算做的内容会是一篇比较丧、低能量者的呻吟。 实习生与创业公司的零到一 第一段工…...

ASP.NET Core 多环境配置
一、开篇明义:多环境配置的重要性 在ASP.NET Core 开发的广袤天地中,多环境配置堪称保障应用稳定运行的中流砥柱。想象一下,我们精心打造的应用,要在开发、测试、预发布和生产等截然不同的环境中穿梭自如。每个环境都如同一个独特…...

docker 安装mongodb
1、先获取mongodb镜像 docker pull mongo:4.2 2、镜像拉取完成后,运行mongodb容器 docker run \ -d \ --name mongo \ --restartalways \ --privilegedtrue \ -p 27017:27017 \ -v /home//mongodb/data:/data/db \ mongo:4.2 --auth 3、mongodb服务配置 如上图&…...

完整地实现了推荐系统的构建、实验和评估过程,为不同推荐算法在同一数据集上的性能比较提供了可重复实验的框架
{"cells": [{"cell_type": "markdown","metadata": {},"source": ["# 基于用户的协同过滤算法"]},{"cell_type": "code","execution_count": 1,"metadata": {},"ou…...

DRV8311三相PWM无刷直流电机驱动器
1 特性 • 三相 PWM 电机驱动器 – 三相无刷直流电机 • 3V 至 20V 工作电压 – 24V 绝对最大电压 • 高输出电流能力 – 5A 峰值电流驱动能力 • 低导通状态电阻 MOSFET – TA 25C 时,RDS(ON) (HS LS) 为210mΩ(典型值) • 低功耗睡眠模式…...

Mysql--运维篇--备份和恢复(逻辑备份,mysqldump,物理备份,热备份,温备份,冷备份,二进制文件备份和恢复等)
MySQL 提供了多种备份方式,每种方式适用于不同的场景和需求。根据备份的粒度、速度、恢复时间和对数据库的影响,可以选择合适的备份策略。主要备份方式有三大类:逻辑备份(mysqldump),物理备份和二进制文件备…...

机器学习-归一化
文章目录 一. 归一化二. 归一化的常见方法1. 最小-最大归一化 (Min-Max Normalization)2. Z-Score 归一化(标准化)3. MaxAbs 归一化 三. 归一化的选择四. 为什么要进行归一化1. 消除量纲差异2. 提高模型训练速度3. 增强模型的稳定性4. 保证正则化项的有效…...

Linux 串口检查状态的实用方法
在 Linux 系统中,串口通信是非常常见的操作,尤其在嵌入式系统、工业设备以及其他需要串行通信的场景中。为了确保串口设备的正常工作,检查串口的连接状态和配置信息是非常重要的。本篇文章将介绍如何在 Linux 上检查串口的连接状态࿰…...

Qt的核心机制概述
Qt的核心机制概述 1. 元对象系统(The Meta-Object System) 基本概念:元对象系统是Qt的核心机制之一,它通过moc(Meta-Object Compiler)工具为继承自QObject的类生成额外的代码,从而扩展了C语言…...

微调神经机器翻译模型全流程
MBART: Multilingual Denoising Pre-training for Neural Machine Translation 模型下载 mBART 是一个基于序列到序列的去噪自编码器,使用 BART 目标在多种语言的大规模单语语料库上进行预训练。mBART 是首批通过去噪完整文本在多种语言上预训练序列到序列模型的方…...

Cesium加载地形
Cesium的地形来源大致可以分为两种,一种是由Cesium官方提供的数据源,一种是第三方的数据源,官方源依赖于Cesium Assets,如果设置了AccessToken后,就可以直接使用Cesium的地形静态构造方法来获取数据源CesiumTerrainPro…...

gitlab runner正常连接 提示 作业挂起中,等待进入队列 解决办法
方案1 作业挂起中,等待进入队列 重启gitlab-runner gitlab-runner stop gitlab-runner start gitlab-runner run方案2 启动 gitlab-runner 服务 gitlab-runner start成功启动如下 [rootdocserver home]# gitlab-runner start Runtime platform …...

C#对动态加载的DLL进行依赖注入,并对DLL注入服务
文章目录 什么是依赖注入概念常用的依赖注入实现什么是动态加载定义示例对动态加载的DLL进行依赖注入什么是依赖注入 概念 依赖注入(Dependency Injection,简称 DI)是一种软件设计模式,用于解耦软件组件之间的依赖关系。在 C# 开发中,它主要解决的是类与类之间的强耦合问题…...

HDMI接口
HDMI接口 前言各版本区别概述(Overview)接口接口类型Type A/E 引脚定义Type B 引脚定义Type C 引脚定义Type D 引脚定义 传输流程概述Control Period前导码字符边界同步Control Period 编/解码 Data Island PeriodLeading/Trailing Guard BandTERC4 编/解…...

A/B 测试:玩转假设检验、t 检验与卡方检验
一、背景:当“审判”成为科学 1.1 虚拟场景——法庭审判 想象这样一个场景:有一天,你在王国里担任“首席审判官”。你面前站着一位嫌疑人,有人指控他说“偷了国王珍贵的金冠”。但究竟是他干的,还是他是被冤枉的&…...

第143场双周赛:最小可整除数位乘积 Ⅰ、执行操作后元素的最高频率 Ⅰ、执行操作后元素的最高频率 Ⅱ、最小可整除数位乘积 Ⅱ
Q1、最小可整除数位乘积 Ⅰ 1、题目描述 给你两个整数 n 和 t 。请你返回大于等于 n 的 最小 整数,且该整数的 各数位之积 能被 t 整除。 2、解题思路 问题拆解: 题目要求我们找到一个整数,其 数位的积 可以被 t 整除。 数位的积 是指将数…...