深入理解策略梯度算法
策略梯度(Policy Gradient)算法是强化学习中的一种重要方法,通过优化策略以获得最大回报。本文将详细介绍策略梯度算法的基本原理,推导其数学公式,并提供具体的例子来指导其实现。
策略梯度算法的基本概念
在强化学习中,智能体通过与环境交互来学习一种策略(policy),该策略定义了在每个状态下采取哪种行动的概率分布。策略可以是确定性的或随机的。在策略梯度方法中,策略通常表示为参数化的概率分布,即 ,其中
是策略的参数,
是状态,
是行动。
目标是找到最佳的策略参数 $\theta$ 使得智能体在环境中获得的期望回报最大。为此,我们需要定义一个目标函数,表示期望回报。然后,通过梯度上升法(或下降法)来优化该目标函数。
策略梯度的数学推导
假设我们的目标函数 $J(\theta)$ 定义为:
其中 表示一个完整的轨迹(从初始状态到终止状态的状态-动作序列),
是该轨迹的总回报。根据策略的定义,我们有:
因此,目标函数可以重写为:
为了最大化,我们需要计算其梯度
:
使用概率分布的梯度性质,我们有:
因此,梯度可以表示为:
这个公式被称为策略梯度定理。为了估计这个期望值,我们通常使用蒙特卡洛方法,从策略 中采样多个轨迹
,然后计算平均值。
策略梯度算法的实现
我们以一个简单的环境为例,展示如何实现策略梯度算法。假设我们有一个离散动作空间的环境,我们使用一个神经网络来参数化策略。
步骤 1:环境设置
首先,设置环境和参数:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimenv = gym.make('CartPole-v1')
n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]
步骤 2:策略网络定义
定义一个简单的策略网络:
class PolicyNetwork(nn.Module):def __init__(self, state_dim, n_actions):super(PolicyNetwork, self).__init__()self.fc1 = nn.Linear(state_dim, 128)self.fc2 = nn.Linear(128, n_actions)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return torch.softmax(x, dim=-1)policy = PolicyNetwork(state_dim, n_actions)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
步骤 3:采样轨迹
编写函数来从策略中采样轨迹:
def sample_trajectory(env, policy, max_steps=1000):state = env.reset()states, actions, rewards = [], [], []for _ in range(max_steps):state = torch.FloatTensor(state).unsqueeze(0)probs = policy(state)action = np.random.choice(n_actions, p=probs.detach().numpy()[0])next_state, reward, done, _ = env.step(action)states.append(state)actions.append(action)rewards.append(reward)if done:breakstate = next_statereturn states, actions, rewards
步骤 4:计算回报和梯度
计算每个状态的回报,并使用策略梯度定理更新策略:
def compute_returns(rewards, gamma=0.99):returns = []G = 0for r in reversed(rewards):G = r + gamma * Greturns.insert(0, G)return returnsdef update_policy(policy, optimizer, states, actions, returns):returns = torch.FloatTensor(returns)loss = 0for state, action, G in zip(states, actions, returns):state = state.squeeze(0)probs = policy(state)log_prob = torch.log(probs[action])loss += -log_prob * Goptimizer.zero_grad()loss.backward()optimizer.step()
步骤 5:训练策略
将上述步骤组合在一起,训练策略网络:
num_episodes = 1000
for episode in range(num_episodes):states, actions, rewards = sample_trajectory(env, policy)returns = compute_returns(rewards)update_policy(policy, optimizer, states, actions, returns)if episode % 100 == 0:print(f"Episode {episode}, total reward: {sum(rewards)}")
总结
通过以上步骤,我们实现了一个基本的策略梯度算法。策略梯度方法通过直接优化策略来最大化智能体的期望回报,具有理论上的简洁性和实用性。本文详细推导了策略梯度的数学公式,并提供了具体的实现步骤,希望能够帮助读者更好地理解和应用这一重要的强化学习算法。
相关文章:
深入理解策略梯度算法
策略梯度(Policy Gradient)算法是强化学习中的一种重要方法,通过优化策略以获得最大回报。本文将详细介绍策略梯度算法的基本原理,推导其数学公式,并提供具体的例子来指导其实现。 策略梯度算法的基本概念 在强化学习…...

Unicode 和 UTF-8 以及它们之间的关系
通俗易懂的 Unicode 和 UTF-8 解释 Unicode 是什么? 想象一下,我们有一个巨大的图书馆,这个图书馆里有各种各样的书,每本书都有一个唯一的编号。Unicode 就像是这个图书馆的目录系统,它给世界上所有的字符࿰…...

【C++】多态详解
💗个人主页💗 ⭐个人专栏——C学习⭐ 💫点击关注🤩一起学习C语言💯💫 目录 一、多态概念 二、多态的定义及实现 1. 多态的构成条件 2. 虚函数 2.1 什么是虚函数 2.2 虚函数的重写 2.3 虚函数重写的两个…...
C#异常捕获
前言 在C#中,我们无法保证我们编写的程序没有一点bug,如果我们对于这些抛出异常的bug不进行任何的处理的话,那么我们的软件在抛出这些异常的时候就会崩溃,也就是软件闪退,并且这种闪退由于我们没有进行处理࿰…...

工业一体机根据软件应用需求灵活选配
在当今工业领域,数字化、智能化的发展趋势愈发明显,工业一体机作为关键的设备,其重要性日益凸显。而能够根据软件应用需求进行灵活选配的工业一体机,更是为企业提供了高效、定制化的解决方案。 一、工业一体机的全封闭无风扇散热功…...

centos7 mqtt服务mosquitto搭建记录
1、系统centos7.6,安装默认版本 yum install mosquitto 2、启动运行 systemctl start mosquitto 3、设置自启动 systemctl enable mosquitto 4、修改配置文件 vim /etc/mosquitto/mosquitto.conf 监听端口,默认为1883,需要修改删除前面…...
双阶段目标检测算法:精确与效率的博弈
双阶段目标检测算法:精确与效率的博弈 目标检测是计算机视觉领域的一个核心任务,它涉及在图像或视频中识别和定位多个对象。双阶段目标检测算法是一种特殊的目标检测方法,它通过两个阶段来提高检测的准确性。本文将详细介绍双阶段目标检测算…...
Python量化交易策略
策略详情 按照1分k线图;跳过9:30点1分k线图不计算 买入;监控市面的可转债;当某1分涨幅大于x涨幅,一直重复x次,选择买入,符合x设置的条件只选择成交额最大的可转债买入(x要自定义&…...

为什么我感觉 C 语言在 Linux 下执行效率比 Windows 快得多?
在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「Linux的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!!Windows的终端或者叫控制台…...

算法导论 总结索引 | 第四部分 第十六章:贪心算法
1、求解最优化问题的算法 通常需要经过一系列的步骤,在每个步骤都面临多种选择。对于许多最优化问题,使用动态规划算法求最优解有些杀鸡用牛刀了,可以使用更简单、更高效的算法 贪心算法(greedy algorithm)就是这样的算…...
用“文心一言”写的文章,看看AI写得怎么样?
零售连锁店的“支付结算”业务设计 在数字化浪潮的推动下,连锁店零售支付结算的设计愈发重要。一个优秀的支付结算设计不仅能够提升用户体验,还能增强品牌竞争力,进而促进销售增长。 本文将围绕一个具体的连锁店零售支付结算案例…...

企业消费采购成本和员工体验如何实现“鱼和熊掌“的兼得?
有企业说企业消费采购成本和员工体验的关系好比是“鱼和熊掌”,无法兼得? 要想控制好成本就一定要加强管控,但是加强管控以后,就会很难让员工获得满意的体验度。如果不加以管控,员工自由度增加了,往往就很难…...

发表EI论文相当于SCI几区?
EI(工程索引)本身并不进行分区,它是一个收录工程领域高质量文献的数据库,与SCI(科学引文索引)的分区制度不同。然而,在非正式的学术评价中,有时人们会将EI与SCI的分区进行比较。 虽…...

STFT短时傅里叶变换MTLAB简析
代码: 解释: 如果信号x有Nx个时间样本,短时傅里叶变换的结果矩阵s有k列; k的计算方式如图所示,M是窗函数的长度,L是重叠长度。 此符号是向下取整符号。 短时傅里叶变换的结果矩阵s的行数与参数‘FFTLength’…...
海致科技实施实习生面试
一、面试内容 注:此次是电话面试 1.是XX先生吗 2.你是有考虑转实施的吗? 3.请讲一下你对项目部署实施的理解和掌握 4.用过数据库,会编写SQL语句吗? 5.讲一下SQL的常用关键字 6.了解SQL中的函数吗?谈谈函数 7.多…...

论文阅读之旋转目标检测ARC:《Adaptive Rotated Convolution for Rotated Object Detection》
论文link:link code:code ARC是一个改进的backbone,相比于ResNet,最后的几层有一些改变。 Introduction ARC自适应地旋转以调整每个输入的条件参数,其中旋转角度由路由函数以数据相关的方式预测。此外,还采…...
面向对象(Java)
构造方法只能在对象实例化的时候调用 this可以作为方法参数,表示调用方法的当前对象 this可以作为方法返回值,表示返回当前对象 封装 通过方法访问数据,隐藏类的实现细节 static:类对象共享,类加载时产生,…...

I/O多路复用
参考面试官:简单说一下阻塞IO、非阻塞IO、IO复用的区别 ?_unix环境编程 阻塞io和非阻塞io-CSDN博客 同步阻塞(BIO) BIO 以流的方式处理数据 应用程序发起一个系统调用(recvform),这个时候应用程序会一直阻塞下去&am…...

线性代数基础概念:向量空间
目录 线性代数基础概念:向量空间 1. 向量空间的定义 2. 向量空间的性质 3. 基底和维数 4. 子空间 5. 向量空间的例子 总结 线性代数基础概念:向量空间 向量空间是线性代数中最基本的概念之一,它为我们提供了一个抽象的框架,…...
php 抓取淘宝商品评论数据 json
要抓取淘宝商品评论数据,你可以使用PHP的cURL库来发送HTTP请求并获取JSON格式的数据。 API接入流程:需要开放平台或者是封装接口注册账号,并申请相应的API使用权限,以获取必要的密钥和接口文档。获取接口使用权限:接入…...
浅谈 React Hooks
React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...
OpenLayers 可视化之热力图
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 热力图(Heatmap)又叫热点图,是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...

Flask RESTful 示例
目录 1. 环境准备2. 安装依赖3. 修改main.py4. 运行应用5. API使用示例获取所有任务获取单个任务创建新任务更新任务删除任务 中文乱码问题: 下面创建一个简单的Flask RESTful API示例。首先,我们需要创建环境,安装必要的依赖,然后…...
Java 8 Stream API 入门到实践详解
一、告别 for 循环! 传统痛点: Java 8 之前,集合操作离不开冗长的 for 循环和匿名类。例如,过滤列表中的偶数: List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

边缘计算医疗风险自查APP开发方案
核心目标:在便携设备(智能手表/家用检测仪)部署轻量化疾病预测模型,实现低延迟、隐私安全的实时健康风险评估。 一、技术架构设计 #mermaid-svg-iuNaeeLK2YoFKfao {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg…...

Day131 | 灵神 | 回溯算法 | 子集型 子集
Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣(LeetCode) 思路: 笔者写过很多次这道题了,不想写题解了,大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...

汽车生产虚拟实训中的技能提升与生产优化
在制造业蓬勃发展的大背景下,虚拟教学实训宛如一颗璀璨的新星,正发挥着不可或缺且日益凸显的关键作用,源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例,汽车生产线上各类…...
系统设计 --- MongoDB亿级数据查询优化策略
系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log,共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题,不能使用ELK只能使用…...
【HTML-16】深入理解HTML中的块元素与行内元素
HTML元素根据其显示特性可以分为两大类:块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

自然语言处理——Transformer
自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效,它能挖掘数据中的时序信息以及语义信息,但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN,但是…...