具身系列——Diffusion Policy算法实现CartPole游戏
代码原理分析
1. 核心思想
该代码实现了一个基于扩散模型(Diffusion Model)的强化学习策略网络。扩散模型通过逐步去噪过程生成动作,核心思想是:
• 前向过程:通过T步逐渐将专家动作添加高斯噪声,最终变成纯噪声
• 逆向过程:训练神经网络预测噪声,通过T步逐步去噪生成动作
• 数学基础:基于DDPM(Denoising Diffusion Probabilistic Models)框架
算法步骤:
1.1 前向加噪:在动作空间逐步添加高斯噪声,将真实动作分布转化为高斯分布
q ( a t ∣ a t − 1 ) = N ( a t ; 1 − β t a t − 1 , β t I ) q(\mathbf{a}_t|\mathbf{a}_{t-1}) = \mathcal{N}(\mathbf{a}_t; \sqrt{1-\beta_t}\mathbf{a}_{t-1}, \beta_t\mathbf{I}) q(at∣at−1)=N(at;1−βtat−1,βtI)
其中 β t \beta_t βt 为噪声调度参数(网页4][网页5][网页8])。
1.2 逆向去噪:基于观测 o t \mathbf{o}_t ot 条件去噪生成动作
p θ ( a t − 1 ∣ a t , o t ) = N ( a t − 1 ; μ θ ( a t , o t , t ) , Σ t ) p_\theta(\mathbf{a}_{t-1}|\mathbf{a}_t, \mathbf{o}_t) = \mathcal{N}(\mathbf{a}_{t-1}; \mu_\theta(\mathbf{a}_t, \mathbf{o}_t, t), \Sigma_t) pθ(at−1∣at,ot)=N(at−1;μθ(at,ot,t),Σt)
去噪网络 μ θ \mu_\theta μθ 预测噪声残差(网页5][网页6][网页8])。
1.3 训练目标:最小化噪声预测误差
L = E t , a 0 , ϵ [ ∥ ϵ − ϵ θ ( α t a 0 + 1 − α t ϵ , o t , t ) ∥ 2 ] \mathcal{L} = \mathbb{E}_{t,\mathbf{a}_0,\epsilon}\left[ \|\epsilon - \epsilon_\theta(\sqrt{\alpha_t}\mathbf{a}_0 + \sqrt{1-\alpha_t}\epsilon, \mathbf{o}_t, t)\|^2 \right] L=Et,a0,ϵ[∥ϵ−ϵθ(αta0+1−αtϵ,ot,t)∥2]
其中 α t = ∏ s = 1 t ( 1 − β s ) \alpha_t = \prod_{s=1}^t (1-\beta_s) αt=∏s=1t(1−βs)(网页4][网页8][网页11])。
2. 关键数学公式
• 前向过程(扩散过程):
q(a_t|a_{t-1}) = N(a_t; √(α_t)a_{t-1}, (1-α_t)I)
α_t = 1 - β_t,ᾱ_t = ∏_{i=1}^t α_i
a_t = √ᾱ_t a_0 + √(1-ᾱ_t)ε,其中ε ~ N(0,I)
• 训练目标(噪声预测):
L = ||ε - ε_θ(a_t, s, t)||^2
• 逆向过程(采样过程):
p_θ(a_{t-1}|a_t) = N(a_{t-1}; μ_θ(a_t, s, t), Σ_t)
μ_θ = 1/√α_t (a_t - β_t/√(1-ᾱ_t) ε_θ)
逐行代码注释
import torch
import gymnasium as gym
import numpy as npclass DiffusionPolicy(torch.nn.Module):def __init__(self, state_dim=4, action_dim=2, T=20):super().__init__()self.T = T # 扩散过程总步数self.betas = torch.linspace(1e-4, 0.02, T) # 噪声方差调度self.alphas = 1 - self.betas # 前向过程参数self.alpha_bars = torch.cumprod(self.alphas, dim=0) # 累积乘积ᾱ# 去噪网络(输入维度:state(4) + action(2) + timestep(1) = 7)self.denoiser = torch.nn.Sequential(torch.nn.Linear(7, 64), # 输入层torch.nn.ReLU(), # 激活函数torch.nn.Linear(64, 2) # 输出预测的噪声)self.optimizer = torch.optim.Adam(self.denoiser.parameters(), lr=1e-3)def train_step(self, states, expert_actions):batch_size = states.size(0)t = torch.randint(0, self.T, (batch_size,)) # 随机采样时间步alpha_bar_t = self.alpha_bars[t].unsqueeze(1) # 获取对应ᾱ_t# 前向加噪(公式实现)noise = torch.randn_like(expert_actions) # 生成高斯噪声noisy_actions = torch.sqrt(alpha_bar_t) * expert_actions + \torch.sqrt(1 - alpha_bar_t) * noise # 公式(2)# 输入拼接(状态、加噪动作、归一化时间步)inputs = torch.cat([states, noisy_actions,(t.float() / self.T).unsqueeze(1) # 时间步归一化到[0,1]], dim=1) # 最终维度:batch_size x 7pred_noise = self.denoiser(inputs) # 预测噪声loss = torch.mean((noise - pred_noise)**2) # MSE损失return lossdef sample_action(self, state):state_tensor = torch.FloatTensor(state).unsqueeze(0)a_t = torch.randn(1, 2) # 初始化为随机噪声(动作维度2)# 逆向去噪过程(需要补全)for t in reversed(range(self.T)):# 应实现的步骤:# 1. 获取当前时间步参数# 2. 拼接输入(状态,当前动作,时间步)# 3. 预测噪声ε_θ# 4. 根据公式计算均值μ# 5. 采样新动作(最后一步不添加噪声)passreturn a_t.detach().numpy()[0] # 返回最终动作
执行过程详解
训练流程
- 随机采样时间步:为每个样本随机选择扩散步t ∈ [0, T-1]
- 前向加噪:根据公式将专家动作添加对应程度的噪声
- 输入构造:拼接状态、加噪动作和归一化时间步
- 噪声预测:神经网络预测添加的噪声
- 损失计算:最小化预测噪声与真实噪声的MSE
采样流程(需补全)
- 初始化:从高斯噪声开始
- 迭代去噪:从t=T到t=1逐步去噪
• 根据当前动作和状态预测噪声
• 计算前一步的均值
• 添加随机噪声(最后一步除外) - 输出:得到最终去噪后的动作
关键改进建议
- 实现逆向过程:需要补充时间步循环和去噪公式
- 添加方差调度:在采样时使用更复杂的方差计算
- 时间步嵌入:可以使用正弦位置编码代替简单归一化
- 网络结构优化:考虑使用Transformer或条件批归一化
该实现展示了扩散策略的核心思想,但完整的扩散策略还需要实现完整的逆向采样过程,并可能需要调整噪声调度参数以获得更好的性能。
最终可执行代码:
import torch
import gymnasium as gym
import numpy as npclass DiffusionPolicy(torch.nn.Module):def __init__(self, state_dim=4, action_dim=2, T=20):super().__init__()self.T = Tself.betas = torch.linspace(1e-4, 0.02, T)self.alphas = 1 - self.betasself.alpha_bars = torch.cumprod(self.alphas, dim=0)# 去噪网络(输入维度:4+2+1=7)self.denoiser = torch.nn.Sequential(torch.nn.Linear(7, 64),torch.nn.ReLU(),torch.nn.Linear(64, 2))self.optimizer = torch.optim.Adam(self.denoiser.parameters(), lr=1e-3)def train_step(self, states, expert_actions):batch_size = states.size(0)t = torch.randint(0, self.T, (batch_size,))alpha_bar_t = self.alpha_bars[t].unsqueeze(1)# 前向加噪公式[2](@ref)noise = torch.randn_like(expert_actions)noisy_actions = torch.sqrt(alpha_bar_t) * expert_actions + torch.sqrt(1 - alpha_bar_t) * noise# 输入拼接(维度对齐)[1](@ref)inputs = torch.cat([states, noisy_actions,(t.float() / self.T).unsqueeze(1)], dim=1) # 最终维度:batch_size x 7pred_noise = self.denoiser(inputs)loss = torch.mean((noise - pred_noise)**2)return lossdef sample_action(self, state):state_tensor = torch.FloatTensor(state).unsqueeze(0)a_t = torch.randn(1, 2) # 二维动作空间[2](@ref)# 逆向去噪过程[2](@ref)for t in reversed(range(self.T)):alpha_t = self.alphas[t]alpha_bar_t = self.alpha_bars[t]inputs = torch.cat([state_tensor,a_t,torch.tensor([[t / self.T]], dtype=torch.float32)], dim=1)pred_noise = self.denoiser(inputs)a_t = (a_t - (1 - alpha_t)/torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)if t > 0:a_t += torch.sqrt(self.betas[t]) * torch.randn_like(a_t)return torch.argmax(a_t).item() # 离散动作选择[1](@ref)if __name__ == "__main__":env = gym.make('CartPole-v1')policy = DiffusionPolicy()# 关键修复:确保状态数据维度统一[1,2](@ref)states, actions = [], []state, _ = env.reset()for _ in range(1000):action = env.action_space.sample()next_state, _, terminated, truncated, _ = env.step(action)done = terminated or truncated# 强制转换状态为numpy数组并检查维度[2](@ref)state = np.array(state, dtype=np.float32).flatten()if len(state) != 4:raise ValueError(f"Invalid state shape: {state.shape}")states.append(state) # 确保每个状态是(4,)的数组actions.append(action)if done:state, _ = env.reset()else:state = next_state# 维度验证与转换[1](@ref)states_array = np.stack(states) # 强制转换为(1000,4)if states_array.shape != (1000,4):raise ValueError(f"States shape error: {states_array.shape}")actions_onehot = np.eye(2)[np.array(actions)] # 转换为one-hot编码[2](@ref)states_tensor = torch.FloatTensor(states_array)actions_tensor = torch.FloatTensor(actions_onehot)# 训练循环for epoch in range(100):loss = policy.train_step(states_tensor, actions_tensor)policy.optimizer.zero_grad()loss.backward()policy.optimizer.step()print(f"Epoch {epoch}, Loss: {loss.item():.4f}")# 测试state, _ = env.reset()for _ in range(200):action = policy.sample_action(state)state, _, done, _, _ = env.step(action)if done: break
相关文章:
具身系列——Diffusion Policy算法实现CartPole游戏
代码原理分析 1. 核心思想 该代码实现了一个基于扩散模型(Diffusion Model)的强化学习策略网络。扩散模型通过逐步去噪过程生成动作,核心思想是: • 前向过程:通过T步逐渐将专家动作添加高斯噪声,最终变成…...
前端性能优化:深入解析哈希算法与TypeScript实践
/ 示例:开放寻址哈希表核心实现 class OpenAddressingHashTable<T> {private size: number;private keys: (string | null)[];private values: (T | null)[];private tombstone Symbol(Deleted);constructor(size: number 53) {this.size size;this.keys …...
知识就是力量——物联网应用技术
基础知识篇 一、常用电子元器件1——USB Type C 接口引脚详解特点接口定义作用主从设备关于6P引脚的简介 2——常用通信芯片CH343P概述特点引脚定义 CH340概述特点封装 3——蜂鸣器概述类型驱动电路原文链接 二、常用封装介绍贴片电阻电容封装介绍封装尺寸与功率关系࿱…...
(windows)conda虚拟环境下open-webui安装与启动
一、创建conda环境 重点强调下,如果用python pip安装,一定要选择python3.11系列版本,我选的3.11.9。 如果你的版本不是这个系列,将会出现一些未知的问题。 conda create -n open-webui python3.11 -y如下就创建好了 二、安装o…...
oracle密码过期 ORA-28001解决方案: the password has expired
** oracle密码过期 ORA-28001解决方案: the password has expired ** oracle 11g 默认密码过期时间为180天密码过期后,访问数据库会出现如下异常java.sql.SQLException: ORA-28001: the password has expired 查询密码过期设定 select * from dba profiles where…...
GStreamer —— 3.1、Qt+GStreamer制作多功能播放器,支持本地mp4文件、rtsp流、usb摄像头等(可跨平台,附源码)
🔔 GStreamer 相关音视频技术、疑难杂症文章合集(掌握后可自封大侠 ⓿_⓿)(记得收藏,持续更新中…) 运行效果...
六十天Linux从0到项目搭建(第十天)(系统调用 vs 库函数/进程管理的建模/为什么进程管理中需要PCB?/exec 函数/fork原理与行为详解)
1 系统调用 vs 库函数:本质区别与协作关系 核心区别 特性系统调用(System Call)库函数(Library Function)定义操作系统内核提供的 底层接口,直接操作硬件。封装系统调用的 高级函数,提供便捷功…...
资本运营:基于Python实现的资本运作模拟
基于Python实现的一个简单的资本运营框架; 企业生命周期演示:观察初创→成长→上市→并购全流程 行业对比分析:不同行业的财务特征和估值差异 资本运作策略:体验IPO定价、投资决策、并购整合等操作 市场动态观察ÿ…...
当EFISH-SBC-RK3576遇上区块链:物联网安全与可信数据网络
在工业物联网场景中,设备身份伪造与数据篡改是核心安全隐患。EFISH-SBC-RK3576 通过 硬件安全模块 区块链链上验证,实现设备身份可信锚定与数据全生命周期加密,安全性能提升10倍以上。 1. 安全架构:从芯片到链的端到端防…...
关于spark在yarn上运行时候内存的介绍
在YARN上运行Spark时,内存管理是性能调优的核心环节。以下是 Driver Memory、Executor Memory、堆内存(Heap Memory) 和 堆外内存(Off-Heap Memory) 的区别与配置方法,以及实际场景中的最佳实践:…...
分布式系统面试总结:3、分布式锁(和本地锁的区别、特点、常见实现方案)
仅供自学回顾使用,请支持javaGuide原版书籍。 本篇文章涉及到的分布式锁,在本人其他文章中也有涉及。 《JUC:三、两阶段终止模式、死锁的jconsole检测、乐观锁(版本号机制CAS实现)悲观锁》:https://blog.…...
【VSCode的安装与配置】
目录: 一:下载 VSCode二:安装 VSCode三:配置 VSCode 一:下载 VSCode 下载地址:https://code.visualstudio.com/download 下载完成之后,在对应的下载目录中可以看到安装程序。 二:安装…...
ElasticSearch常用优化点
关闭交换分区:因为Linux采用了三级页表虚存管理,关闭交换分区可以减少系统IO,页面换入唤出时所耗费的总线时间以及减少系统中断次数;swap的使用会显著增加延迟和降低吞吐量。文件描述符配置:任何网络应用都需要增加文件…...
脱围机制-react18废除forwardRef->react19直接使用ref的理解
采用ref,可以在父组件调用到子组件的功能 第一步:在父组件声明ref并传递ref interface SideOptsHandle {refreshData: () > Promise<void> }const sideOptsRef useRef<SideOptsHandle>(null) // 创建 ref<SideOpts ref{sideOptsRef…...
Spark2 之 Expression/Functions
ExpressionConverter src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala TopNTransformer src/main/scala/org/apache/gluten/execution/TopNTransformer.scala...
Windows中安装git工具
下载好git安装包 点击next 选择安装目录 根据需要去勾选 点击next 点击next PATH环境选择第二个【Git...software】即可,再点击【Next】。 第一种配置是“仅从Git Bash使用Git”。这是最安全的选择,因为您的PATH根本不会被修改。您只能使用 Git Bash 的…...
【CSS】CSS 使用全教程
CSS 使用全教程 介绍 CSS(层叠样式表,Cascading Style Sheets)是一种样式表语言,用于描述 HTML 或 XML 文档的布局和外观,它允许开发者将文档的内容结构与样式表现分离,通过定义一系列的样式规则来控制网页…...
《HarmonyOS Next自定义TabBar页签凸起和凹陷案例与代码》
引言 自定义TabBar在HarmonyOS Next应用中很常见,本文将介绍如何实现页签的凸起和凹陷效果,并通过代码示例展示实现过程。 实现思路 基于已有的自定义TabBar思路,通过调整布局和样式实现凸起和凹陷效果。凸起效果可以通过在选中的页签下方…...
全分辨率免ROOT懒人精灵-自动化编程思维-设计思路-实战训练
全分辨率免ROOT懒人精灵-自动化编程思维-设计思路-实战训练 1.2025新版懒人精灵-实战红果搜索关键词刷视频:https://www.bilibili.com/video/BV1eK9kY7EWV 2.懒人精灵-全分辨率节点识别(红果看广告领金币小实战):https://www.bili…...
如何在IDEA中借助深度思考模型 QwQ 提高编码效率?
通义灵码上新模型选择功能,不仅引入了 DeepSeek 满血版 V3 和 R1 这两大 “新星”,Qwen2.5-Max 和 QWQ 也强势登场,正式加入通义灵码的 “豪华阵容”。开发者只需在通义灵码智能问答窗口的输入框中,单击模型选择的下拉菜单&#x…...
C++11QT复习 (四)
Day6-1 输入输出流运算符重载(2025.03.25) 1. 拷贝构造函数的调用时机 2. 友元2.1 友元函数 3. 输入输出流运算符重载3.1 关键知识点3.2 代码3.3 关键问题3.4 完整代码 4. 下标访问运算符 operator[]4.1 关键知识点4.2 代码 5. 函数调用运算符 operator…...
LVS的 NAT 模式实验
文章目录 目录 文章目录 概要 IP规划与题目分析 实验步骤 一、nginx配置(rs1、rs2、rs3) 二、LVS配置 三、客户端配置 四、防火墙和selinux配置 实验结果 痛点解答 概要 LVS/NAT lvs/nat网络地址转换模式,进站/出站的数据流量经过分发器(IP负…...
【MacOS】2025年硬核方法清理MacOS中的可清除空间(Purgeable space)
背景 MacOS使用一段时间之后,硬盘空间会越来越少,但自己的文件没有存储那么多,在储存空间中可以发现可用空间明明还剩很多,但磁盘工具却显示已满,见下图。 尝试解决 df -h 命令却发现磁盘已经被快被占满。使用du命…...
ue材质学习感想总结笔记
2025 - 3 - 27 1.1 加法 对TexCoord上的每一个像素加上一个值,如果加上0.1,0.1, 那么左上角原来0,0的位置变成了0.1,0.1 右上角就变成了1.1,1.1,那么原来0,0的位置就去到了左上角左上边,所以图像往左上偏移。 总而言…...
Go 语言 sync 包使用教程
Go 语言 sync 包使用教程 Go 语言的 sync 包提供了基本的同步原语,用于在并发编程中协调 goroutine 之间的操作。 1. 互斥锁 (Mutex) 互斥锁用于保护共享资源,确保同一时间只有一个 goroutine 可以访问。 特点: 最基本的同步原语&#x…...
约束文件SDC常用命令
约束文件SDC常用命令 定义时钟create_clock -name CLK-period 2 [get_ports_clk]告诉工具主时钟周期是2ns(频率500MHz),从clk端口输入 输入信号延迟set_input_delay 0.5 -clock CLK [get_ports data_in]数据进芯片前,外部电路已消耗0.5ns,综合要预留这段“堵车时间”。 输出…...
信而泰PFC/ECN流量测试方案:打造智能无损网络的关键利器
导语: AI算力爆发的背后,如何保障网络“零丢包”? 在当今数据中心网络中,随着AI、高性能计算(HPC)和分布式存储等应用的飞速发展,网络的无损传输能力变得至关重要。PFC(基于优先级的…...
golang不使用锁的情况下,对slice执行并发写操作,是否会有并发问题呢?
背景 并发问题最简单的解决方案加个锁,但是,加锁就会有资源争用,提高并发能力其中的一个优化方向就是减少锁的使用。 我在之前的这篇文章《开启多个协程,并行对struct中的每个元素操作,是否会引起并发问题?》中讨论过多协程场景下struct的并发问题。 Go语言中的slice在…...
Android 底部EditView输入时悬浮到软键盘上方
1. 修改 Activity 的 Manifest 配置 确保你的 Activity 在 AndroidManifest.xml 中有以下配置: <activityandroid:name".YourActivity"android:windowSoftInputMode"adjustResize|stateHidden" /> 关键点: adjustResize 是…...
CNN和LSTM的计算复杂度分析
前言:今天做边缘计算的时候,在评估模型性能的时候发现NPU计算的大部分时间都花在了LSTM上,使用的是Bi-LSTM(耗时占比98%),CNN耗时很短,不禁会思考为什么LSTM会花费这么久时间。 首先声明一下实…...
