Python实现基于TD3(Twin Delayed Deep Deterministic Policy Gradient)算法来实时更新路径规划算法
下面是一个使用Python实现基于TD3(Twin Delayed Deep Deterministic Policy Gradient)算法来实时更新路径规划算法的三个参数(sigma0,rho0 和 theta)的示例代码。该算法将依据障碍物环境进行优化。
实现思路
- 环境定义:定义一个包含障碍物的环境,用于模拟路径规划问题。
- TD3算法:使用TD3算法来学习如何优化路径规划算法的三个参数。
- 训练过程:在环境中进行训练,不断更新策略网络和价值网络。
代码示例
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque# 定义TD3网络
class Actor(nn.Module):def __init__(self, state_dim, action_dim, max_action):super(Actor, self).__init__()self.fc1 = nn.Linear(state_dim, 400)self.fc2 = nn.Linear(400, 300)self.fc3 = nn.Linear(300, action_dim)self.max_action = max_actiondef forward(self, state):x = torch.relu(self.fc1(state))x = torch.relu(self.fc2(x))x = self.max_action * torch.tanh(self.fc3(x))return xclass Critic(nn.Module):def __init__(self, state_dim, action_dim):super(Critic, self).__init__()# Q1架构self.fc1 = nn.Linear(state_dim + action_dim, 400)self.fc2 = nn.Linear(400, 300)self.fc3 = nn.Linear(300, 1)# Q2架构self.fc4 = nn.Linear(state_dim + action_dim, 400)self.fc5 = nn.Linear(400, 300)self.fc6 = nn.Linear(300, 1)def forward(self, state, action):sa = torch.cat([state, action], 1)# Q1q1 = torch.relu(self.fc1(sa))q1 = torch.relu(self.fc2(q1))q1 = self.fc3(q1)# Q2q2 = torch.relu(self.fc4(sa))q2 = torch.relu(self.fc5(q2))q2 = self.fc6(q2)return q1, q2def Q1(self, state, action):sa = torch.cat([state, action], 1)q1 = torch.relu(self.fc1(sa))q1 = torch.relu(self.fc2(q1))q1 = self.fc3(q1)return q1# TD3算法类
class TD3:def __init__(self, state_dim, action_dim, max_action):self.actor = Actor(state_dim, action_dim, max_action)self.actor_target = Actor(state_dim, action_dim, max_action)self.actor_target.load_state_dict(self.actor.state_dict())self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)self.critic = Critic(state_dim, action_dim)self.critic_target = Critic(state_dim, action_dim)self.critic_target.load_state_dict(self.critic.state_dict())self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=3e-4)self.max_action = max_actionself.gamma = 0.99self.tau = 0.005self.policy_noise = 0.2self.noise_clip = 0.5self.policy_freq = 2self.total_it = 0def select_action(self, state):state = torch.FloatTensor(state.reshape(1, -1))return self.actor(state).cpu().data.numpy().flatten()def train(self, replay_buffer, batch_size=100):self.total_it += 1# 从回放缓冲区采样state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)with torch.no_grad():# 选择动作并添加噪声noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)# 计算目标Q值target_Q1, target_Q2 = self.critic_target(next_state, next_action)target_Q = torch.min(target_Q1, target_Q2)target_Q = reward + not_done * self.gamma * target_Q# 获取当前Q估计值current_Q1, current_Q2 = self.critic(state, action)# 计算批评损失critic_loss = nn.MSELoss()(current_Q1, target_Q) + nn.MSELoss()(current_Q2, target_Q)# 优化批评网络self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()# 延迟策略更新if self.total_it % self.policy_freq == 0:# 计算演员损失actor_loss = -self.critic.Q1(state, self.actor(state)).mean()# 优化演员网络self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# 软更新目标网络for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)# 回放缓冲区类
class ReplayBuffer:def __init__(self, max_size):self.buffer = deque(maxlen=max_size)def add(self, state, action, next_state, reward, done):self.buffer.append((state, action, next_state, reward, 1 - done))def sample(self, batch_size):state, action, next_state, reward, not_done = zip(*random.sample(self.buffer, batch_size))return torch.FloatTensor(state), torch.FloatTensor(action), torch.FloatTensor(next_state), torch.FloatTensor(reward).unsqueeze(1), torch.FloatTensor(not_done).unsqueeze(1)def __len__(self):return len(self.buffer)# 模拟路径规划环境
class PathPlanningEnv:def __init__(self):# 简单模拟障碍物环境,这里用一个二维数组表示self.obstacles = np.random.randint(0, 2, (10, 10))self.state_dim = 10 * 10 # 环境状态维度self.action_dim = 3 # 三个参数 sigma0, rho0, thetaself.max_action = 1.0def reset(self):# 重置环境return self.obstacles.flatten()def step(self, action):sigma0, rho0, theta = action# 简单模拟奖励计算,这里可以根据实际路径规划算法修改reward = np.random.randn()done = Falsenext_state = self.obstacles.flatten()return next_state, reward, done# 主训练循环
def main():env = PathPlanningEnv()state_dim = env.state_dimaction_dim = env.action_dimmax_action = env.max_actiontd3 = TD3(state_dim, action_dim, max_action)replay_buffer = ReplayBuffer(max_size=1000000)total_steps = 10000episode_steps = 0state = env.reset()for step in range(total_steps):episode_steps += 1# 选择动作action = td3.select_action(state)# 执行动作next_state, reward, done = env.step(action)# 将经验添加到回放缓冲区replay_buffer.add(state, action, next_state, reward, done)# 训练TD3if len(replay_buffer) > 100:td3.train(replay_buffer)state = next_stateif done or episode_steps >= 100:state = env.reset()episode_steps = 0# 输出最终优化的参数final_state = env.reset()final_action = td3.select_action(final_state)sigma0, rho0, theta = final_actionprint(f"Optimized sigma0: {sigma0}, rho0: {rho0}, theta: {theta}")if __name__ == "__main__":main()
代码解释
- 网络定义:定义了
Actor和Critic网络,分别用于生成动作和评估动作的价值。 - TD3类:实现了TD3算法的核心逻辑,包括动作选择、训练和目标网络的软更新。
- ReplayBuffer类:用于存储和采样经验数据。
- PathPlanningEnv类:模拟了一个包含障碍物的路径规划环境,提供了重置和执行动作的方法。
- 主训练循环:在环境中进行训练,不断更新策略网络和价值网络。
注意事项
- 此示例中的奖励计算是简单模拟的,实际应用中需要根据具体的路径规划算法进行修改。
- 障碍物环境的表示可以根据实际需求进行调整。
相关文章:
Python实现基于TD3(Twin Delayed Deep Deterministic Policy Gradient)算法来实时更新路径规划算法
下面是一个使用Python实现基于TD3(Twin Delayed Deep Deterministic Policy Gradient)算法来实时更新路径规划算法的三个参数(sigma0,rho0 和 theta)的示例代码。该算法将依据障碍物环境进行优化。 实现思路 环境定义…...
pytorch实现半监督学习
半监督学习(Semi-Supervised Learning,SSL)结合了有监督学习和无监督学习的特点,通常用于部分数据有标签、部分数据无标签的场景。其主要步骤如下: 1. 数据准备 有标签数据(Labeled Data)&…...
我的毕设之路:(2)系统类型的论文写法
一般先进行毕设的设计与实现,再在现成毕设基础上进行描述形成文档,那么论文也就成形了。 1 需求分析:毕业设计根据开题报告和要求进行需求分析和功能确定,区分贴合主题的主要功能和拓展功能能,删除偏离无关紧要的功能…...
LosslessScaling-学习版[steam价值30元的游戏无损放大/补帧工具]
LosslessScaling 链接:https://pan.xunlei.com/s/VOHc-yZBgwBOoqtdZAv114ZTA1?pwdxiih# 解压后运行"A-绿化-解压后运行我.cmd"...
concurrent.futures.Future对象详解:利用线程池与进程池实现异步操作
concurrent.futures.Future对象详解:利用线程池与进程池实现异步操作 一、前言二、使用线程池三、使用进程池四、注意事项五、结语 一、前言 在现代编程中,异步操作已成为提升程序性能和响应速度的关键手段。Python的concurrent.futures模块为此提供了强…...
StarRocks 安装部署
StarRocks 安装部署 StarRocks端口: 官方《配置检查》有服务端口详细描述: https://docs.starrocks.io/zh/docs/deployment/environment_configurations/ StarRocks架构:https://docs.starrocks.io/zh/docs/introduction/Architecture/ Sta…...
Python Matplotlib库:从入门到精通
Python Matplotlib库:从入门到精通 在数据分析和科学计算领域,可视化是一项至关重要的技能。Matplotlib作为Python中最流行的绘图库之一,为我们提供了强大的绘图功能。本文将带你从Matplotlib的基础开始,逐步掌握其高级用法&…...
线程概念、操作
一、背景知识 1、地址空间进一步理解 在父子进程对同一变量进行修改时发生写时拷贝,这时候拷贝的基本单位是4KB,会将该变量所在的页框全拷贝一份,这是因为修改该变量很有可能会修改其周围的变量(局部性原理)…...
【PySide6拓展】QSoundEffect
文章目录 【PySide6拓展】QSoundEffect 音效播放类**基本概念****什么是 QSoundEffect?****QSoundEffect 的特点****安装 PySide6** **如何使用 QSoundEffect?****1. 播放音效****示例代码:播放音效** **代码解析****QSoundEffect 的高级用法…...
33【脚本解析语言】
脚本语言也叫解析语言 脚本一词,相信很多人都听过,那么什么是脚本语言,我们在开发时有一个调试功能,但是发布版是需要编译执行的,体积比较大,同时这使得我们每次更新都需要重新编译,客户再…...
【Unity】 HTFramework框架(五十九)快速开发编辑器工具(Assembly Viewer + ILSpy)
更新日期:2025年1月23日。 Github源码:[点我获取源码] Gitee源码:[点我获取源码] 索引 开发编辑器工具MouseRayTarget焦点视角Collider线框Assembly Viewer搜索程序集ILSpy反编译程序集搜索GizmosElement类找到Gizmos菜单找到Gizmos窗口分析A…...
如何解决TikTok网络不稳定的问题
TikTok是目前全球最受欢迎的短视频平台之一,凭借其丰富多彩的内容和社交功能吸引了数以亿计的用户。然而,尽管TikTok在世界范围内的使用情况不断增长,但不少用户在使用过程中仍然会遇到网络不稳定的问题。无论是在观看视频时遇到缓冲…...
告别页面刷新!如何使用AJAX和FormData优化Web表单提交
系列文章目录 01-从零开始学 HTML:构建网页的基本框架与技巧 02-HTML常见文本标签解析:从基础到进阶的全面指南 03-HTML从入门到精通:链接与图像标签全解析 04-HTML 列表标签全解析:无序与有序列表的深度应用 05-HTML表格标签全面…...
WireShark4.4.2浏览器网络调试指南:数据统计(八)
概述 Wireshark 是一款功能强大的开源网络协议分析软件,被广泛应用于网络调试和数据分析。随着互联网的发展,以及网络安全问题日益严峻,了解如何使用 Wireshark进行浏览器网络调试显得尤为重要。最新的 Wireshark4.4.2 提供了更加强大的功能…...
Hypium+python鸿蒙原生自动化安装配置
Hypiumpython自动化搭建 文章目录 Python安装pip源配置HDC安装Hypium安装DevEco Testing Hypium插件安装及使用方法插件安装工程创建区域 Python安装 推荐从官网获取3.10版本,其他版本可能出现兼容性问题 Python下载地址 下载64/32bitwindows安装文件&am…...
2025创业思路和方向有哪些?
创业思路和方向是决定创业成功与否的关键因素。以下是一些基于找到的参考内容的创业思路和方向,旨在激发创业灵感: 一、技术创新与融合: 1、智能手机与云电视结合:开发集成智能手机功能的云电视,提供通讯、娱乐一体化体…...
实验五---控制系统的稳定性分析---自动控制原理实验课
一 实验目的 1、理解控制系统稳定性的概念 2、掌握多种判定系统稳定性的原理及方法 3、掌握使用Matlab软件进行控制系统的稳定性分析 二 实验仪器 计算机,MATLAB仿真软件 三 实验内容及步骤 1.计算系统闭环特征根,判别系统稳定性; 2.绘制系统…...
AttributeError: can‘t set attribute ‘lines‘
报错: ax p3.Axes3D(fig) ax.lines [] AttributeError: cant set attribute lines 总结下来,解决方案应包括: 1. 使用ax.clear()方法清除所有内容。 2. 逐个移除lines中的元素。 3. 检查matplotlib版本,确保没有已知的bug。…...
Day07:缓存-数据淘汰策略
Redis的数据淘汰策略有哪些 ? (key过期导致的) 在redis中提供了两种数据过期删除策略 第一种是惰性删除,在设置该key过期时间后,我们不去管它,当需要该key时,我们再检查其是否过期,如果过期&…...
基于聚类与相关性分析对马来西亚房价数据进行分析
碎碎念:由于最近太忙了,更新的比较慢,提前祝大家新春快乐,万事如意!本数据集的下载地址,读者可以自行下载。 1.项目背景 本项目旨在对马来西亚房地产市场进行初步的数据分析,探索各州的房产市…...
绝地求生罗技鼠标宏终极教程:5分钟实现完美压枪
绝地求生罗技鼠标宏终极教程:5分钟实现完美压枪 【免费下载链接】logitech-pubg PUBG no recoil script for Logitech gaming mouse / 绝地求生 罗技 鼠标宏 项目地址: https://gitcode.com/gh_mirrors/lo/logitech-pubg 还在为《绝地求生》中难以控制的后坐…...
数据与大语言模型融合:从NL2SQL到RAG架构的实践指南
1. 项目概述:当数据遇见大语言模型如果你是一名数据工程师、数据分析师,或者任何需要和数据打交道的开发者,最近肯定被“大语言模型”和“数据智能”这两个词轮番轰炸。我们手里有海量的数据,从结构化的业务表到非结构化的日志、文…...
OpenClaw-China:中文场景下开源大语言模型高效微调与部署实战指南
1. 项目概述与核心价值 最近在GitHub上看到一个挺有意思的项目,叫“BytePioneer-AI/openclaw-china”。光看这个名字,你可能会有点摸不着头脑——“BytePioneer”是字节先锋,“openclaw”是开放之爪,再加上“china”的后缀&#x…...
Java-Callgraph2:Java静态分析工具终极指南
Java-Callgraph2:Java静态分析工具终极指南 【免费下载链接】java-callgraph2 Programs for producing static call graphs for Java programs. 项目地址: https://gitcode.com/gh_mirrors/ja/java-callgraph2 Java-Callgraph2是一款功能强大的Java静态分析工…...
Proxima向量检索库:硬件优化与量化技术实战解析
1. 项目概述:一个为现代开发者打造的“近邻”代码库 最近在GitHub上看到一个挺有意思的项目,叫“Zen4-bit/Proxima”。乍一看这个标题,可能会有点摸不着头脑。“Zen4-bit”像是一个用户名或者某种架构的代号,而“Proxima”则让人联…...
ARM ETMv4跟踪寄存器架构与调试实践
1. ARM ETMv4 跟踪寄存器架构概述ARM嵌入式跟踪宏单元(ETM)是处理器调试架构中的关键组件,ETMv4作为其第四代架构,提供了更强大的指令和数据跟踪能力。与传统的断点调试不同,ETM采用实时跟踪技术,能够在不中断处理器运行的情况下&…...
技术团队的“信息透明”策略:报喜也报忧,反而更受信任
在软件测试领域,我们每天都在与“不确定性”打交道。一个隐藏的边界值、一次偶发的并发冲突、一个在特定机型上才能复现的诡异Bug,都足以让看似稳固的系统瞬间变得脆弱。然而,比起代码中的不确定性,更让测试团队感到无力的&#x…...
从LLM到智能体:基于推理循环的AI应用开发框架解析
1. 项目概述:一个面向推理任务的智能体框架最近在探索如何让AI模型更“聪明”地处理复杂任务时,我注意到了GitHub上一个名为“zyron-reasoning”的项目。这个由kaiogs07维护的仓库,其核心定位是一个用于构建和运行“推理智能体”的框架。简单…...
【Proteus仿真】SRF04超声波阈值预警系统设计与LCD1602交互实现
1. SRF04超声波测距原理与硬件连接 SRF04超声波模块是工业测距的经典选择,它通过发射40kHz的声波并计算回波时间差来测量距离。在实际项目中,我发现很多初学者容易忽略声速受温度影响的问题——常温下声速约343m/s,但温度每升高1℃࿰…...
别再死记硬背公式了!用MATLAB besselj函数5分钟搞定贝塞尔函数可视化
用MATLAB可视化贝塞尔函数:从数学恐惧到图形直觉的5分钟蜕变 当《数学物理方法》教材上那些密密麻麻的积分符号和无穷级数开始在你眼前跳舞,当教授在黑板上推导贝塞尔方程时粉笔灰与数学焦虑一起飞扬——是时候让MATLAB成为你理解这些特殊函数的"视…...
