当前位置: 首页 > article >正文

Social LSTM实战:用Python复现行人轨迹预测模型(附代码)

Social LSTM实战从零构建行人轨迹预测系统行人轨迹预测一直是计算机视觉和机器人导航领域的核心挑战。想象一下当你走在拥挤的商场里会不自觉地调整步伐和路线避开迎面而来的人群——这种看似简单的行为背后隐藏着复杂的社交互动模式。本文将带你用Python实现一个能够理解这种社交行为的深度学习模型Social LSTM。1. 环境准备与数据加载在开始构建模型前我们需要配置合适的开发环境并准备训练数据。Social LSTM对计算资源有一定要求建议使用配备GPU的工作站或云服务器。基础环境配置conda create -n social_lstm python3.8 conda activate social_lstm pip install torch1.10.0cu113 torchvision0.11.1cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html pip install numpy pandas matplotlib scikit-learn我们将使用ETH和UCY这两个公开的行人轨迹数据集。这些数据集包含了真实场景中行人运动的坐标序列是轨迹预测任务的基准测试集。数据预处理关键步骤坐标归一化将所有轨迹坐标转换到相对坐标系序列分割按8帧观察12帧预测的标准划分样本社交关系构建基于空间距离确定行人间的交互关系def load_eth_dataset(data_path): 加载并预处理ETH数据集 raw_data pd.read_csv(data_path) # 坐标归一化 scene_min raw_data[[x, y]].min() scene_max raw_data[[x, y]].max() raw_data[x] (raw_data[x] - scene_min[x]) / (scene_max[x] - scene_min[x]) raw_data[y] (raw_data[y] - scene_min[y]) / (scene_max[y] - scene_min[y]) return raw_data注意实际应用中应考虑数据增强技术如随机旋转和平移以提高模型泛化能力。2. 模型架构设计Social LSTM的核心创新在于其社交池化层该层使模型能够捕捉行人之间的交互模式。我们将使用PyTorch实现这一复杂架构。2.1 基础LSTM模块首先构建标准的LSTM单元用于学习单个行人的运动模式import torch.nn as nn class TrajectoryLSTM(nn.Module): def __init__(self, input_dim2, embedding_dim64, hidden_dim128): super().__init__() self.embedding nn.Linear(input_dim, embedding_dim) self.lstm nn.LSTM(embedding_dim, hidden_dim, batch_firstTrue) def forward(self, x): embedded torch.relu(self.embedding(x)) output, (hidden, cell) self.lstm(embedded) return output, hidden, cell2.2 社交池化层实现社交池化层是Social LSTM区别于普通LSTM的关键组件它负责聚合邻近行人的隐藏状态class SocialPooling(nn.Module): def __init__(self, pool_size32, hidden_dim128): super().__init__() self.pool_size pool_size self.hidden_dim hidden_dim self.pool_embedding nn.Linear(pool_size*pool_size*hidden_dim, hidden_dim) def forward(self, hidden_states, positions, neighbor_masks): batch_size, num_peds hidden_states.size(0), hidden_states.size(1) # 初始化社交张量 social_tensor torch.zeros(batch_size, num_peds, self.pool_size, self.pool_size, self.hidden_dim).to(hidden_states.device) # 填充社交张量 for b in range(batch_size): for i in range(num_peds): for j in range(num_peds): if neighbor_masks[b,i,j] 1: # 是邻居 rel_pos positions[b,j] - positions[b,i] grid_x int((rel_pos[0] 1) * (self.pool_size - 1) / 2) grid_y int((rel_pos[1] 1) * (self.pool_size - 1) / 2) if 0 grid_x self.pool_size and 0 grid_y self.pool_size: social_tensor[b,i,grid_x,grid_y] hidden_states[b,j] # 嵌入聚合后的社交信息 social_tensor social_tensor.view(batch_size, num_peds, -1) social_embedded torch.relu(self.pool_embedding(social_tensor)) return social_embedded提示实际实现时可使用向量化操作替代循环显著提升计算效率。3. 完整Social LSTM实现将基础LSTM与社交池化层结合构建完整的Social LSTM模型class SocialLSTM(nn.Module): def __init__(self, args): super().__init__() self.args args self.traj_lstm TrajectoryLSTM() self.social_pooling SocialPooling() self.position_predictor nn.Linear(2*args.hidden_dim, 5) # 预测二元高斯参数 def forward(self, traj_batch): # 处理输入轨迹 obs_traj traj_batch[observed] batch_size, num_peds, seq_len, _ obs_traj.size() # 初始化隐藏状态 hidden_states torch.zeros(batch_size, num_peds, self.args.hidden_dim).to(obs_traj.device) # 处理观察序列 for t in range(seq_len): current_pos obs_traj[:,:,t,:] _, hiddens, _ self.traj_lstm(current_pos.unsqueeze(1)) hidden_states hiddens.squeeze(0) # 计算社交信息 neighbor_masks self._get_neighbor_masks(current_pos) social_embeddings self.social_pooling(hidden_states, current_pos, neighbor_masks) # 更新隐藏状态 combined torch.cat([hidden_states, social_embeddings], dim-1) hidden_states self.traj_lstm.lstm_cell(combined, hidden_states) # 预测未来轨迹 pred_traj [] last_pos obs_traj[:,:,-1,:] for _ in range(self.args.pred_len): # 预测下一步位置 gaussian_params self.position_predictor(hidden_states) next_pos self._sample_from_gaussian(gaussian_params) pred_traj.append(next_pos) # 更新状态 _, hiddens, _ self.traj_lstm(next_pos.unsqueeze(1)) hidden_states hiddens.squeeze(0) # 更新社交信息 neighbor_masks self._get_neighbor_masks(next_pos) social_embeddings self.social_pooling(hidden_states, next_pos, neighbor_masks) combined torch.cat([hidden_states, social_embeddings], dim-1) hidden_states self.traj_lstm.lstm_cell(combined, hidden_states) return torch.stack(pred_traj, dim2)4. 模型训练与优化Social LSTM的训练需要特别注意批处理策略和损失函数设计因为每个场景中的行人数量可能不同。4.1 自定义数据加载器from torch.utils.data import Dataset, DataLoader class TrajectoryDataset(Dataset): def __init__(self, data, obs_len8, pred_len12): self.data data self.obs_len obs_len self.pred_len pred_len def __len__(self): return len(self.data[scene_ids]) def __getitem__(self, idx): scene_id self.data[scene_ids][idx] frame_ids self.data[scene_frame_map][scene_id] num_peds len(self.data[scene_ped_map][scene_id]) # 提取轨迹序列 obs_traj torch.zeros(num_peds, self.obs_len, 2) pred_traj torch.zeros(num_peds, self.pred_len, 2) for i, ped_id in enumerate(self.data[scene_ped_map][scene_id]): full_traj self.data[trajectories][ped_id] obs_traj[i] torch.from_numpy(full_traj[:self.obs_len]) pred_traj[i] torch.from_numpy(full_traj[self.obs_len:self.obs_lenself.pred_len]) return { observed: obs_traj, predicted: pred_traj, scene_id: scene_id }4.2 训练循环实现def train(model, dataloader, optimizer, device): model.train() total_loss 0 for batch in dataloader: obs_traj batch[observed].to(device) pred_traj batch[predicted].to(device) # 前向传播 pred_output model(batch) # 计算负对数似然损失 loss gaussian_2d_loss(pred_output, pred_traj) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() total_loss loss.item() return total_loss / len(dataloader) def gaussian_2d_loss(pred_params, true_pos): 计算二元高斯分布的负对数似然 mu_x pred_params[..., 0] mu_y pred_params[..., 1] sigma_x torch.exp(pred_params[..., 2]) sigma_y torch.exp(pred_params[..., 3]) rho torch.tanh(pred_params[..., 4]) x true_pos[..., 0] y true_pos[..., 1] z_x ((x - mu_x) / sigma_x)**2 z_y ((y - mu_y) / sigma_y)**2 z_xy (x - mu_x)*(y - mu_y)/(sigma_x*sigma_y) nll torch.log(2*np.pi*sigma_x*sigma_y*torch.sqrt(1-rho**2)) \ 1/(2*(1-rho**2)) * (z_x z_y - 2*rho*z_xy) return torch.mean(nll)5. 评估与结果分析训练完成后我们需要定量评估模型性能并与基线方法进行比较。使用三个标准指标平均位移误差(ADE)所有预测时间步位置误差的平均最终位移误差(FDE)预测终点与真实终点的距离非线性位移误差(NL-ADE)轨迹非线性区域的误差评估代码实现def evaluate(model, dataloader, device): model.eval() metrics {ade: 0, fde: 0, nl_ade: 0} with torch.no_grad(): for batch in dataloader: obs_traj batch[observed].to(device) pred_traj batch[predicted].to(device) pred_output model(batch) pred_pos pred_output[..., :2] # 取均值作为预测位置 # 计算各项指标 metrics[ade] torch.mean(torch.norm(pred_pos - pred_traj, dim-1)).item() metrics[fde] torch.mean(torch.norm(pred_pos[:,:,-1] - pred_traj[:,:,-1], dim-1)).item() # 识别非线性区域 is_nonlinear identify_nonlinear_regions(pred_traj) nl_ade torch.mean(torch.norm(pred_pos[is_nonlinear] - pred_traj[is_nonlinear], dim-1)) metrics[nl_ade] nl_ade.item() for k in metrics: metrics[k] / len(dataloader) return metrics典型训练结果对比模型ADE (m)FDE (m)NL-ADE (m)线性模型1.332.941.89普通LSTM0.982.121.42Social LSTM0.611.240.836. 实际应用与优化建议将训练好的Social LSTM模型部署到实际系统中时还需要考虑以下优化方向实时性优化使用TensorRT加速推理实现增量式预测避免每帧重新计算多模态预测同时生成多条可能轨迹为每条轨迹分配概率场景融合结合场景语义信息障碍物、出入口等考虑行人姿态和视线方向def predict_real_time(model, current_obs, past_statesNone): 实时预测接口 if past_states is None: past_states initialize_states() # 处理当前观测 processed_obs preprocess(current_obs) # 预测未来轨迹 with torch.no_grad(): pred_traj, new_states model(processed_obs, past_states) return pred_traj, new_states在实际机器人导航系统中集成Social LSTM时发现将预测模块与路径规划器松耦合比端到端方案更稳定。具体实现中我们维护一个预测结果缓存区每100ms更新一次预测而规划器则以更高频率(如10ms)重新规划路径。

相关文章:

Social LSTM实战:用Python复现行人轨迹预测模型(附代码)

Social LSTM实战:从零构建行人轨迹预测系统 行人轨迹预测一直是计算机视觉和机器人导航领域的核心挑战。想象一下,当你走在拥挤的商场里,会不自觉地调整步伐和路线,避开迎面而来的人群——这种看似简单的行为背后,隐藏…...

分子模拟新手指南:退火朗之万动力学采样的5个常见误区

分子模拟新手指南:退火朗之万动力学采样的5个常见误区 实验室的服务器嗡嗡作响,屏幕上跳动的分子轨迹曲线让刚入门的计算化学研究者既兴奋又困惑。退火朗之万动力学采样作为探索复杂能量景观的利器,正被越来越多地应用于材料设计和药物开发领…...

技术解析:从PWM到DShot——无人机电调协议的性能跃迁与实战选择

1. 无人机电调协议的前世今生 第一次接触无人机电调时,我被各种协议缩写搞得晕头转向。直到亲眼目睹竞速无人机从PWM切换到DShot600后,电机响应速度就像从绿皮火车升级到高铁——这个直观对比让我彻底理解了协议迭代的意义。 电调(电子调速器…...

Qwen3-VL-30B使用技巧:如何写出更好的提示词,让图片分析更准确?

Qwen3-VL-30B使用技巧:如何写出更好的提示词,让图片分析更准确? 你有没有遇到过这样的情况:给AI模型上传一张图片,问了一个问题,结果得到的回答要么答非所问,要么细节缺失,要么干脆…...

普冉单片机实战入门:从零到点灯

1. 为什么选择普冉PY32F00系列单片机 第一次接触普冉单片机是在去年底,当时被它的价格震惊到了——作为一款32位ARM Cortex-M0内核的单片机,PY32F00系列的市场价居然不到10块钱。这让我这个常年使用STM32的老玩家产生了强烈的好奇心。经过半年的实际项目…...

实战应用:在快马平台构建企业级git配置管理方案

最近在团队协作中,我们遇到了一个挺典型的问题:随着项目增多,开发环境里的Git配置变得一团乱麻。个人项目和公司项目混用同一个身份,大型项目的子模块更新总忘,代码提交格式五花八门,分支合并也常常出岔子。…...

MT5 Zero-Shot部署教程:支持WebRTC实时语音输入→文本增强→TTS输出全链路

MT5 Zero-Shot部署教程:支持WebRTC实时语音输入→文本增强→TTS输出全链路 想不想体验一个能“听懂”你说话,然后帮你把话“润色”得更漂亮,最后再用“好听的声音”读出来的AI工具?今天,我们就来手把手教你部署一个功…...

通义千问1.5-1.8B-Chat-GPTQ-Int4 重装系统后AI开发环境快速恢复:模型辅助清单与脚本生成

通义千问1.5-1.8B-Chat-GPTQ-Int4 重装系统后AI开发环境快速恢复:模型辅助清单与脚本生成 1. 引言 你有没有过这样的经历?电脑系统崩溃或者换了新机器,重装完系统,看着空荡荡的桌面和命令行,心里一沉——那个精心搭建…...

Mirage Flow 本地知识库构建:基于开源模型的私有化ChatGPT方案

Mirage Flow 本地知识库构建:基于开源模型的私有化ChatGPT方案 1. 引言 你是不是也遇到过这样的场景?公司内部有一堆产品手册、技术文档、会议纪要,每次想查点东西,都得在文件夹里翻半天。或者,你想让AI帮你分析一些…...

FUTURE POLICE语音模型LSTM声学模型对比与优化选择

FUTURE POLICE语音模型:LSTM声学模型对比与优化选择 最近在语音技术圈子里,FUTURE POLICE这个名字出现的频率越来越高。很多朋友都在问,这个新模型到底强在哪里,和咱们以前常用的LSTM模型比起来,到底值不值得花时间去…...

GPEN图像增强保姆级教程:从上传到下载全流程详解

GPEN图像增强保姆级教程:从上传到下载全流程详解 你是否曾面对一张模糊、泛黄或布满划痕的老照片,感到束手无策?想修复它,却又被复杂的专业软件和晦涩的参数吓退?今天,我将带你走进一个完全不同的世界——…...

C++结构体排序实战:如何用sort函数搞定学生成绩排名(附完整代码)

C结构体排序实战:如何用sort函数搞定学生成绩排名(附完整代码) 在编程学习过程中,数据处理和排序是每个开发者必须掌握的核心技能。对于C初学者来说,理解如何自定义排序规则并应用于实际场景,是提升编程能力…...

低成本MEMS IMU标定全攻略:从imu_tk安装到实战避坑指南

低成本MEMS IMU标定全攻略:从imu_tk安装到实战避坑指南 在机器人导航、无人机控制和VR设备开发中,惯性测量单元(IMU)的精度直接影响系统性能。对于预算有限的学生团队和初创公司,如何用开源工具实现专业级标定&#xf…...

非线性系列(三)—— 共轭梯度法在机器学习优化中的实战应用

1. 共轭梯度法:从数学原理到机器学习优化 第一次接触共轭梯度法(CG)是在研究生课程《数值分析》中,当时只觉得这是个解线性方程组的数学工具。直到后来处理一个百万维度的推荐系统优化问题时,我才真正体会到它的威力。相比常见的梯度下降法&a…...

HY-Motion 1.0 Docker部署全攻略:从拉取镜像到生成第一个3D动作

HY-Motion 1.0 Docker部署全攻略:从拉取镜像到生成第一个3D动作 1. 为什么选择Docker来部署HY-Motion 1.0 想象一下,你拿到一个功能强大的新工具,但说明书全是专业术语,安装步骤有几十页,中间任何一个环节出错都得从…...

从零到一:NestJS实体设计的艺术与科学

从零到一:NestJS实体设计的艺术与科学 1. 实体设计的基础理念 在NestJS框架中,实体(Entity)作为连接对象关系映射(ORM)与业务逻辑的桥梁,其设计质量直接影响着应用的扩展性和维护成本。一个优秀的实体设计需要平衡数据库性能、代码可读性和业…...

有限元分析必看:如何快速定位和修复ANSYS中的不良网格区域

有限元分析实战:ANSYS网格质量诊断与高效修复指南 在工程仿真领域,网格质量直接决定了有限元分析结果的可靠性。许多CAE工程师都曾经历过这样的困境:耗时数小时完成的复杂模型网格划分,却在求解阶段因质量警告而被迫中断。更令人头…...

避坑指南:Xilinx ZYNQ Ultrascale+ MPSoC DP转HDMI线材选择与电视兼容性实测

Xilinx ZYNQ Ultrascale MPSoC DP转HDMI实战:线材选择与电视兼容性深度解析 当你在实验室里调试ZYNQ MPSoC的DisplayPort输出时,最令人抓狂的瞬间莫过于:代码和硬件配置都完美,却因为一根转接线导致屏幕一片漆黑。这不是假设——根…...

nanobot开箱即用:内置vllm部署,无需复杂配置即刻体验

nanobot开箱即用:内置vllm部署,无需复杂配置即刻体验 1. nanobot简介:超轻量级AI助手 nanobot是一款受OpenClaw启发的超轻量级个人人工智能助手,其最大特点是仅需约4000行代码就能提供完整的AI助手功能。相比传统AI助手动辄数十…...

一键部署SiameseAOE:搭建属于你自己的智能文本情感分析平台

一键部署SiameseAOE:搭建属于你自己的智能文本情感分析平台 1. 快速了解SiameseAOE SiameseAOE是一个专门用于中文文本情感分析的开源模型,它能从用户评论、社交媒体内容等文本中自动识别产品属性和对应的情感表达。想象一下,你有一大堆客户…...

Zynq UltraScale+ MPSoC双核协作指南:Linux与R5裸机程序的高效通信设计

Zynq UltraScale MPSoC双核协作实战:构建Linux与R5裸机的高效通信系统 在异构计算架构中,Zynq UltraScale MPSoC凭借其独特的双核设计(Cortex-A53应用处理器与Cortex-R5实时处理器)成为工业控制、自动驾驶和边缘计算等领域的理想选…...

Janus-Pro-7B在Android端部署实战:移动设备上的实时多模态推理

Janus-Pro-7B在Android端部署实战:移动设备上的实时多模态推理 你有没有想过,让手机像人一样“看懂”世界?比如,拍一张照片,手机就能立刻告诉你照片里有什么;扫描一份文档,它能马上识别出文字并…...

Tao-8k本地知识库构建:从零搭建基于向量检索的问答系统

Tao-8k本地知识库构建:从零搭建基于向量检索的问答系统 你是不是也遇到过这样的烦恼?公司内部堆积如山的文档、产品手册、技术资料,想找个答案得翻半天。或者,你想让AI助手帮你解答一些专业领域的问题,但它总是一本正…...

从零开始:Windows平台Rust开发环境配置与VSCode调试实战

1. Windows平台Rust开发环境搭建 第一次接触Rust语言时,我被它的安全性和高性能所吸引,但在Windows上配置开发环境却让我踩了不少坑。经过多次实践,我总结出一套简单可靠的安装方法,特别适合刚入门的新手。 Rust官方推荐的安装工具…...

Echarts树图实战:如何将连接线从曲线改成直角线(附完整代码)

Echarts树图连接线直角化改造:从曲线美学到结构清晰的实战指南 在数据可视化领域,树状结构展示一直是呈现层级关系的经典方式。Echarts作为国内领先的可视化库,其树图组件默认采用曲线连接线,这种设计虽然美观流畅,但在…...

Ubuntu 22.04下Zabbix 7.0.0中文乱码终极修复指南(附字体配置详解)

Ubuntu 22.04下Zabbix 7.0.0中文乱码终极修复指南(附字体配置详解) 在监控系统运维工作中,Zabbix作为企业级开源监控解决方案,其数据可视化能力直接影响运维效率。当系统语言环境与监控数据字符集不匹配时,中文乱码问题…...

FLUX.小红书极致真实V2 GPU算力优化:4090显存压缩50%,支持长时间批量生成

FLUX.小红书极致真实V2 GPU算力优化:4090显存压缩50%,支持长时间批量生成 获取更多AI镜像 想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个…...

深入解析目标检测中的IoU计算逻辑与优化实践

1. IoU:目标检测中的"黄金标准" 当你第一次接触目标检测任务时,可能会被各种评价指标搞得晕头转向。但有一个指标,它简单直观又至关重要,那就是IoU(Intersection over Union)。我刚开始做目标检测…...

YOLOv10实战:从零部署到自定义数据集实时检测

1. 环境搭建:5分钟搞定YOLOv10开发环境 第一次接触YOLOv10时,我也被复杂的配置过程吓到过。后来发现只要抓住几个关键点,环境搭建其实比想象中简单得多。这里分享我的"懒人配置法",用最少的步骤完成环境准备。 Python环…...

零基础手把手教你激活WebStorm(含最新下载链接及详细操作截图)

WebStorm 2024 官方正版激活指南:从下载到配置的全流程详解 第一次打开 WebStorm 时,那个充满各种按钮和菜单的界面确实容易让人不知所措。作为 JetBrains 家族中最受欢迎的 JavaScript IDE,WebStorm 提供了强大的代码补全、调试和版本控制功…...