当前位置: 首页 > news >正文

深入理解策略梯度算法

策略梯度(Policy Gradient)算法是强化学习中的一种重要方法,通过优化策略以获得最大回报。本文将详细介绍策略梯度算法的基本原理,推导其数学公式,并提供具体的例子来指导其实现。

策略梯度算法的基本概念

在强化学习中,智能体通过与环境交互来学习一种策略(policy),该策略定义了在每个状态下采取哪种行动的概率分布。策略可以是确定性的或随机的。在策略梯度方法中,策略通常表示为参数化的概率分布,即 $\pi_\theta(a|s)$,其中$\theta$ 是策略的参数,$s$ 是状态,$a$ 是行动。

目标是找到最佳的策略参数 $\theta$ 使得智能体在环境中获得的期望回报最大。为此,我们需要定义一个目标函数$J(\theta)$,表示期望回报。然后,通过梯度上升法(或下降法)来优化该目标函数。

策略梯度的数学推导

假设我们的目标函数 $J(\theta)$ 定义为:

J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} [R(\tau)]

其中$\tau$ 表示一个完整的轨迹(从初始状态到终止状态的状态-动作序列),$R(\tau)$ 是该轨迹的总回报。根据策略的定义,我们有:

\pi_\theta(\tau) = p(s_0) \prod_{t=0}^{T-1} \pi_\theta(a_t|s_t) p(s_{t+1}|s_t, a_t)

因此,目标函数可以重写为:

J(\theta) = \sum_{\tau} \pi_\theta(\tau) R(\tau)

为了最大化$J(\theta)$,我们需要计算其梯度 $\nabla_\theta J(\theta)$

\nabla_\theta J(\theta) = \nabla_\theta \sum_{\tau} \pi_\theta(\tau) R(\tau) = \sum_{\tau} \nabla_\theta \pi_\theta(\tau) R(\tau)

使用概率分布的梯度性质,我们有:

\nabla_\theta \pi_\theta(\tau) = \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau)

因此,梯度可以表示为:

\nabla_\theta J(\theta) = \sum_{\tau} \pi_\theta(\tau) \nabla_\theta \log \pi_\theta(\tau) R(\tau) = \mathbb{E}_{\tau \sim \pi_\theta} [\nabla_\theta \log \pi_\theta(\tau) R(\tau)]

这个公式被称为策略梯度定理。为了估计这个期望值,我们通常使用蒙特卡洛方法,从策略 $\pi_\theta$ 中采样多个轨迹 $\tau$,然后计算平均值。

策略梯度算法的实现

我们以一个简单的环境为例,展示如何实现策略梯度算法。假设我们有一个离散动作空间的环境,我们使用一个神经网络来参数化策略$\pi_\theta(a|s)$

步骤 1:环境设置

首先,设置环境和参数:

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimenv = gym.make('CartPole-v1')
n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]
步骤 2:策略网络定义

定义一个简单的策略网络:

class PolicyNetwork(nn.Module):def __init__(self, state_dim, n_actions):super(PolicyNetwork, self).__init__()self.fc1 = nn.Linear(state_dim, 128)self.fc2 = nn.Linear(128, n_actions)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return torch.softmax(x, dim=-1)policy = PolicyNetwork(state_dim, n_actions)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
步骤 3:采样轨迹

编写函数来从策略中采样轨迹:

def sample_trajectory(env, policy, max_steps=1000):state = env.reset()states, actions, rewards = [], [], []for _ in range(max_steps):state = torch.FloatTensor(state).unsqueeze(0)probs = policy(state)action = np.random.choice(n_actions, p=probs.detach().numpy()[0])next_state, reward, done, _ = env.step(action)states.append(state)actions.append(action)rewards.append(reward)if done:breakstate = next_statereturn states, actions, rewards
步骤 4:计算回报和梯度

计算每个状态的回报,并使用策略梯度定理更新策略:

def compute_returns(rewards, gamma=0.99):returns = []G = 0for r in reversed(rewards):G = r + gamma * Greturns.insert(0, G)return returnsdef update_policy(policy, optimizer, states, actions, returns):returns = torch.FloatTensor(returns)loss = 0for state, action, G in zip(states, actions, returns):state = state.squeeze(0)probs = policy(state)log_prob = torch.log(probs[action])loss += -log_prob * Goptimizer.zero_grad()loss.backward()optimizer.step()
步骤 5:训练策略

将上述步骤组合在一起,训练策略网络:

num_episodes = 1000
for episode in range(num_episodes):states, actions, rewards = sample_trajectory(env, policy)returns = compute_returns(rewards)update_policy(policy, optimizer, states, actions, returns)if episode % 100 == 0:print(f"Episode {episode}, total reward: {sum(rewards)}")
总结

通过以上步骤,我们实现了一个基本的策略梯度算法。策略梯度方法通过直接优化策略来最大化智能体的期望回报,具有理论上的简洁性和实用性。本文详细推导了策略梯度的数学公式,并提供了具体的实现步骤,希望能够帮助读者更好地理解和应用这一重要的强化学习算法。

相关文章:

深入理解策略梯度算法

策略梯度(Policy Gradient)算法是强化学习中的一种重要方法,通过优化策略以获得最大回报。本文将详细介绍策略梯度算法的基本原理,推导其数学公式,并提供具体的例子来指导其实现。 策略梯度算法的基本概念 在强化学习…...

Unicode 和 UTF-8 以及它们之间的关系

通俗易懂的 Unicode 和 UTF-8 解释 Unicode 是什么? 想象一下,我们有一个巨大的图书馆,这个图书馆里有各种各样的书,每本书都有一个唯一的编号。Unicode 就像是这个图书馆的目录系统,它给世界上所有的字符&#xff0…...

【C++】多态详解

💗个人主页💗 ⭐个人专栏——C学习⭐ 💫点击关注🤩一起学习C语言💯💫 目录 一、多态概念 二、多态的定义及实现 1. 多态的构成条件 2. 虚函数 2.1 什么是虚函数 2.2 虚函数的重写 2.3 虚函数重写的两个…...

C#异常捕获

前言 在C#中,我们无法保证我们编写的程序没有一点bug,如果我们对于这些抛出异常的bug不进行任何的处理的话,那么我们的软件在抛出这些异常的时候就会崩溃,也就是软件闪退,并且这种闪退由于我们没有进行处理&#xff0…...

工业一体机根据软件应用需求灵活选配

在当今工业领域,数字化、智能化的发展趋势愈发明显,工业一体机作为关键的设备,其重要性日益凸显。而能够根据软件应用需求进行灵活选配的工业一体机,更是为企业提供了高效、定制化的解决方案。 一、工业一体机的全封闭无风扇散热功…...

centos7 mqtt服务mosquitto搭建记录

1、系统centos7.6,安装默认版本 yum install mosquitto 2、启动运行 systemctl start mosquitto 3、设置自启动 systemctl enable mosquitto 4、修改配置文件 vim /etc/mosquitto/mosquitto.conf 监听端口,默认为1883,需要修改删除前面…...

双阶段目标检测算法:精确与效率的博弈

双阶段目标检测算法:精确与效率的博弈 目标检测是计算机视觉领域的一个核心任务,它涉及在图像或视频中识别和定位多个对象。双阶段目标检测算法是一种特殊的目标检测方法,它通过两个阶段来提高检测的准确性。本文将详细介绍双阶段目标检测算…...

Python量化交易策略

策略详情 按照1分k线图;跳过9:30点1分k线图不计算 买入;监控市面的可转债;当某1分涨幅大于x涨幅,一直重复x次,选择买入,符合x设置的条件只选择成交额最大的可转债买入(x要自定义&…...

为什么我感觉 C 语言在 Linux 下执行效率比 Windows 快得多?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「Linux的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!!Windows的终端或者叫控制台…...

算法导论 总结索引 | 第四部分 第十六章:贪心算法

1、求解最优化问题的算法 通常需要经过一系列的步骤,在每个步骤都面临多种选择。对于许多最优化问题,使用动态规划算法求最优解有些杀鸡用牛刀了,可以使用更简单、更高效的算法 贪心算法(greedy algorithm)就是这样的算…...

用“文心一言”写的文章,看看AI写得怎么样?

​零售连锁店的“支付结算”业务设计 在数字化浪潮的推动下,连锁店零售支付结算的设计愈发重要。一个优秀的支付结算设计不仅能够提升用户体验,还能增强品牌竞争力,进而促进销售增长。 本文将围绕一个具体的连锁店零售支付结算案例&#xf…...

企业消费采购成本和员工体验如何实现“鱼和熊掌“的兼得?

有企业说企业消费采购成本和员工体验的关系好比是“鱼和熊掌”,无法兼得? 要想控制好成本就一定要加强管控,但是加强管控以后,就会很难让员工获得满意的体验度。如果不加以管控,员工自由度增加了,往往就很难…...

发表EI论文相当于SCI几区?

EI(工程索引)本身并不进行分区,它是一个收录工程领域高质量文献的数据库,与SCI(科学引文索引)的分区制度不同。然而,在非正式的学术评价中,有时人们会将EI与SCI的分区进行比较。 虽…...

STFT短时傅里叶变换MTLAB简析

代码: 解释: 如果信号x有Nx个时间样本,短时傅里叶变换的结果矩阵s有k列; k的计算方式如图所示,M是窗函数的长度,L是重叠长度。 此符号是向下取整符号。 短时傅里叶变换的结果矩阵s的行数与参数‘FFTLength’…...

海致科技实施实习生面试

一、面试内容 注:此次是电话面试 1.是XX先生吗 2.你是有考虑转实施的吗? 3.请讲一下你对项目部署实施的理解和掌握 4.用过数据库,会编写SQL语句吗? 5.讲一下SQL的常用关键字 6.了解SQL中的函数吗?谈谈函数 7.多…...

论文阅读之旋转目标检测ARC:《Adaptive Rotated Convolution for Rotated Object Detection》

论文link:link code:code ARC是一个改进的backbone,相比于ResNet,最后的几层有一些改变。 Introduction ARC自适应地旋转以调整每个输入的条件参数,其中旋转角度由路由函数以数据相关的方式预测。此外,还采…...

面向对象(Java)

构造方法只能在对象实例化的时候调用 this可以作为方法参数,表示调用方法的当前对象 this可以作为方法返回值,表示返回当前对象 封装 通过方法访问数据,隐藏类的实现细节 static:类对象共享,类加载时产生,…...

I/O多路复用

参考面试官:简单说一下阻塞IO、非阻塞IO、IO复用的区别 ?_unix环境编程 阻塞io和非阻塞io-CSDN博客 同步阻塞(BIO) BIO 以流的方式处理数据 应用程序发起一个系统调用(recvform),这个时候应用程序会一直阻塞下去&am…...

线性代数基础概念:向量空间

目录 线性代数基础概念:向量空间 1. 向量空间的定义 2. 向量空间的性质 3. 基底和维数 4. 子空间 5. 向量空间的例子 总结 线性代数基础概念:向量空间 向量空间是线性代数中最基本的概念之一,它为我们提供了一个抽象的框架&#xff0c…...

php 抓取淘宝商品评论数据 json

要抓取淘宝商品评论数据,你可以使用PHP的cURL库来发送HTTP请求并获取JSON格式的数据。 API接入流程:需要开放平台或者是封装接口注册账号,并申请相应的API使用权限,以获取必要的密钥和接口文档。获取接口使用权限:接入…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天,再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至,这不仅是开发者的盛宴,更是全球数亿苹果用户翘首以盼的科技春晚。今年,苹果依旧为我们带来了全家桶式的系统更新,包括 iOS 26、iPadOS 26…...

【机器视觉】单目测距——运动结构恢复

ps:图是随便找的,为了凑个封面 前言 在前面对光流法进行进一步改进,希望将2D光流推广至3D场景流时,发现2D转3D过程中存在尺度歧义问题,需要补全摄像头拍摄图像中缺失的深度信息,否则解空间不收敛&#xf…...

使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装

以下是基于 vant-ui&#xff08;适配 Vue2 版本 &#xff09;实现截图中照片上传预览、删除功能&#xff0c;并封装成可复用组件的完整代码&#xff0c;包含样式和逻辑实现&#xff0c;可直接在 Vue2 项目中使用&#xff1a; 1. 封装的图片上传组件 ImageUploader.vue <te…...

网络编程(UDP编程)

思维导图 UDP基础编程&#xff08;单播&#xff09; 1.流程图 服务器&#xff1a;短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

C++ Visual Studio 2017厂商给的源码没有.sln文件 易兆微芯片下载工具加开机动画下载。

1.先用Visual Studio 2017打开Yichip YC31xx loader.vcxproj&#xff0c;再用Visual Studio 2022打开。再保侟就有.sln文件了。 易兆微芯片下载工具加开机动画下载 ExtraDownloadFile1Info.\logo.bin|0|0|10D2000|0 MFC应用兼容CMD 在BOOL CYichipYC31xxloaderDlg::OnIni…...

JVM虚拟机:内存结构、垃圾回收、性能优化

1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行

项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战&#xff0c;克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

算法:模拟

1.替换所有的问号 1576. 替换所有的问号 - 力扣&#xff08;LeetCode&#xff09; ​遍历字符串​&#xff1a;通过外层循环逐一检查每个字符。​遇到 ? 时处理​&#xff1a; 内层循环遍历小写字母&#xff08;a 到 z&#xff09;。对每个字母检查是否满足&#xff1a; ​与…...

mac 安装homebrew (nvm 及git)

mac 安装nvm 及git 万恶之源 mac 安装这些东西离不开Xcode。及homebrew 一、先说安装git步骤 通用&#xff1a; 方法一&#xff1a;使用 Homebrew 安装 Git&#xff08;推荐&#xff09; 步骤如下&#xff1a;打开终端&#xff08;Terminal.app&#xff09; 1.安装 Homebrew…...

「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案

在移动互联网营销竞争白热化的当下&#xff0c;推客小程序系统凭借其裂变传播、精准营销等特性&#xff0c;成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径&#xff0c;助力开发者打造具有市场竞争力的营销工具。​ 一、系统核心功能架构&…...