当前位置: 首页 > article >正文

策略梯度 (Policy Gradient):直接优化策略的强化学习方法

策略梯度 (Policy Gradient) 是强化学习中的一种方法,用于优化智能体的策略,使其在给定环境中表现得更好。与值函数方法(如 Q-learning)不同,策略梯度方法直接对策略进行优化,而不是通过学习一个值函数来间接估计最优策略。

核心思想:

在策略梯度方法中,智能体的策略是一个参数化的函数(通常是神经网络),通过梯度上升法来优化该策略的参数,使得智能体在与环境互动时获得最大的预期奖励。该方法通过计算策略相对于策略参数的梯度来更新策略参数,从而改善智能体的行为。

实现方式:

  1. 收集经验: 智能体与环境互动,收集状态-动作对以及相应的奖励。
  2. 计算梯度: 基于当前策略和收集到的经验,计算梯度。
  3. 更新策略: 使用计算出的梯度更新策略参数。

优点:

  • 可以直接优化策略,适用于连续动作空间。
  • 不依赖于环境的价值函数,适用于部分可观测或高维的状态空间。

缺点:

  • 策略梯度的估计通常具有较高的方差,需要更多的样本来获得稳定的结果。
  • 收敛速度较慢,可能需要更多的计算资源。

简单例子:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt# 1D迷宫环境,目标是从位置0移动到位置10
class SimpleMazeEnv:def __init__(self):self.state = 0  # 初始位置self.target = 10  # 目标位置self.max_steps = 20  # 最大步数def reset(self):self.state = 0return self.statedef step(self, action):if action == 0:  # 向左移动self.state = max(0, self.state - 1)elif action == 1:  # 向右移动self.state = min(self.target, self.state + 1)# 计算奖励,靠近目标位置时奖励更高reward = -abs(self.state - self.target)  # 离目标越远奖励越低done = (self.state == self.target)  # 到达目标时结束return self.state, reward, done# 策略网络
class PolicyNetwork(nn.Module):def __init__(self, input_dim, output_dim):super(PolicyNetwork, self).__init__()self.fc1 = nn.Linear(input_dim, 128)self.fc2 = nn.Linear(128, 128)self.fc3 = nn.Linear(128, output_dim)self.softmax = nn.Softmax(dim=-1)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return self.softmax(x)# 策略梯度算法(REINFORCE)
def reinforce(env, policy, optimizer, episodes=1000, gamma=0.99):episode_rewards = []best_reward = -float('inf')best_path = []for episode in range(episodes):state = env.reset()state = torch.tensor([state], dtype=torch.float32)done = Falserewards = []log_probs = []path = []  # 记录当前回合的路径while not done:# 选择动作action_probs = policy(state)dist = torch.distributions.Categorical(action_probs)action = dist.sample()# 执行动作并观察结果next_state, reward, done = env.step(action.item())next_state = torch.tensor([next_state], dtype=torch.float32)# 保存奖励和动作的log概率rewards.append(reward)log_probs.append(dist.log_prob(action))path.append(state.item())  # 记录当前位置state = next_state# 计算回报returns = []G = 0for r in reversed(rewards):G = r + gamma * Greturns.insert(0, G)# 计算损失并更新模型returns = torch.tensor(returns, dtype=torch.float32)log_probs = torch.stack(log_probs)loss = -torch.sum(log_probs * returns)optimizer.zero_grad()loss.backward()optimizer.step()total_reward = sum(rewards)episode_rewards.append(total_reward)if total_reward > best_reward:best_reward = total_rewardbest_path = pathif (episode + 1) % 100 == 0:print(f"Episode {episode + 1}, Total Reward: {total_reward}, Best Reward: {best_reward}")return episode_rewards, best_path# 初始化环境和模型
env = SimpleMazeEnv()
input_dim = 1  # 状态是一个标量
output_dim = 2  # 动作是向左或向右
policy = PolicyNetwork(input_dim, output_dim)
optimizer = optim.Adam(policy.parameters(), lr=0.001)# 训练模型
episode_rewards, best_path = reinforce(env, policy, optimizer, episodes=1000)# 可视化训练结果
plt.figure(figsize=(12, 6))# 绘制奖励曲线
plt.subplot(1, 2, 1)
plt.plot(episode_rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Training Progress')# 绘制最优路径图
plt.subplot(1, 2, 2)
plt.plot(best_path, marker='o', markersize=5, label="Best Path")
for i, coord in enumerate(best_path):plt.text(i, coord, f"({i}, {coord})", fontsize=8)  # 显示坐标
plt.xlabel('Steps')
plt.ylabel('State')
plt.title('Best Path Taken')
plt.legend()plt.tight_layout()
plt.show()
  1. 环境SimpleMazeEnv是一个非常简单的1D迷宫环境,智能体的目标是从位置0移动到目标位置10。每步,智能体可以选择向左或向右移动。
  2. 策略网络PolicyNetwork是一个简单的神经网络,输出的是两个动作的概率(向左和向右)。
  3. 训练过程:采用策略梯度算法(REINFORCE),在每一轮训练中,智能体根据当前策略选择动作,通过累积奖励(回报)来更新策略网络。
  4. 奖励:智能体的奖励是与目标位置的距离成反比,离目标越近奖励越高。

预期效果:

  • 训练过程:每个回合的奖励会逐渐增加,智能体会逐步学习到正确的动作。
  • 可视化:我们会看到训练过程中每个回合的奖励曲线,以及最优路径(即智能体最终到达目标位置时的移动轨迹)。

运行后的图:

  • 左图:训练过程中的奖励变化。
  • 右图:最优路径的轨迹图,标记了每一步的位置。

相关文章:

策略梯度 (Policy Gradient):直接优化策略的强化学习方法

策略梯度 (Policy Gradient) 是强化学习中的一种方法,用于优化智能体的策略,使其在给定环境中表现得更好。与值函数方法(如 Q-learning)不同,策略梯度方法直接对策略进行优化,而不是通过学习一个值函数来间…...

练习(复习)

大家好,今天我们来做几道简单的选择题目来巩固一下最近学习的知识,以便我们接下来更好的学习。 这道题比较简单,我们前面学过,在Java中,一个类只能继承一个父类,但是一个父类可以有多个子类,一个…...

数据结构与算法之栈: LeetCode 739. 每日温度 (Ts版)

每日温度 https://leetcode.cn/problems/daily-temperatures/description/ 描述 给定一个整数数组 temperatures ,表示每天的温度,返回一个数组 answer ,其中 answer[i] 是指对于第 i 天,下一个更高温度出现在几天后。如果气温…...

【原创改进】SCI级改进算法,一种多策略改进Alpha进化算法(IAE)

目录 1.前言2.CEC2017指标3.效果展示4.探索开发比5.定性分析6.附件材料7.代码获取 1.前言 本期推出一期原创改进——一种多策略改进Alpha进化算法(IAE)~ 选择CEC2017测试集低维(30dim)和高维(100dim)进行测…...

创建 priority_queue - 初阶(c++)

优先级队列 普通的队列是⼀种先进先出的数据结构,即元素插⼊在队尾,⽽元素删除在队头。 ⽽在优先级队列中,元素被赋予优先级,当插⼊元素时,同样是在队尾,但是会根据优先级进⾏位置 调整,优先级…...

56. 协议及端口号

协议及端口号 在计算机网络中,协议和端口号是两个重要的概念。它们共同确保了不同计算机和网络设备之间可以正确、有效地进行通信。 协议(Protocol) 协议是网络通信的一组规则或标准,它定义了如何在计算机网络中发送、接收和解释…...

前端知识速记—JS篇:null 与 undefined

前端知识速记—JS篇:null 与 undefined 什么是 null 和 undefined? 1. undefined 的含义 undefined 是 JavaScript 中默认的值,表示某个变量已被声明但尚未被赋值。当尝试访问一个未初始化的变量、函数没有返回值时,都会得到 u…...

短链接项目02---依赖的添加和postman测试

文章目录 1.声明2.对于依赖的引入和处理2.1原有的内容说明2.2添加公共信息2.3dependencies和management区别说明2.4添加spring-boot依赖2.5数据库的相关依赖2.6hutool工具类的依赖添加2.7测试test 的依赖添加 3.core文件的代码3.1目录层级结构3.2启动类3.3testcontroller测试类…...

Docker容器数据恢复

Docker容器数据恢复 1 创建mongo数据库时未挂载数据到宿主机2 查找数据卷位置3 将容器在宿主机上的数据复制到指定目录下4 修改docker-compose并挂载数据(注意端口)5 重新运行新容器 以mongodb8.0.3为例。 1 创建mongo数据库时未挂载数据到宿主机 versi…...

Sqoop源码修改:增加落地HDFS文件数与MapTask数量一致性检查

个人博客地址:Sqoop源码修改:增加落地HDFS文件数与MapTask数量一致性检查 | 一张假钞的真实世界 本篇是对记录一次Sqoop从MySQL导入数据到Hive问题的排查经过的补充。 Sqoop 命令通过 bin 下面的脚本调用,调用如下: exec ${HAD…...

AI常见的算法

人工智能(AI)中常见的算法分为多个领域,如机器学习、深度学习、强化学习、自然语言处理和计算机视觉等。以下是一些常见的算法及其用途: 例子代码:纠结哥/pytorch_learn 1. 机器学习 (Machine Learning) 监督学习 (S…...

ADC 精度 第二部分:总的未调整误差解析

在关于ADC精度的第一篇文章中,我们阐述了模拟-数字转换器(ADC)的分辨率和精度之间的区别。现在,我们可以深入探讨影响ADC总精度的因素,这通常被称为总未调整误差(TUE)。 你是否曾好奇ADC数据表…...

我的毕设之路:(2)系统类型的论文写法

一般先进行毕设的设计与实现,再在现成毕设基础上进行描述形成文档,那么论文也就成形了。 1 需求分析:毕业设计根据开题报告和要求进行需求分析和功能确定,区分贴合主题的主要功能和拓展功能能,删除偏离无关紧要的功能…...

密码强度验证代码解析:C语言实现与细节剖析

在日常的应用开发中,密码强度验证是保障用户账户安全的重要环节。今天,我们就来深入分析一段用C语言编写的密码强度验证代码,看看它是如何实现对密码强度的多维度检测的。 代码整体结构 这段C语言代码主要实现了对输入密码的一系列规则验证&a…...

独立成分分析 (ICA):用于信号分离或降维

独立成分分析 (Independent Component Analysis, ICA) 是一种用于信号分离和降维的统计方法,常用于盲源分离 (Blind Source Separation, BSS) 问题,例如音频信号分离或脑电信号 (EEG) 处理。 实现 ICA(独立成分分析) 步骤 生成…...

Java中的注解与反射:深入理解getAnnotation(Class<T> annotationClass)方法

Java的注解(Annotation)是一种元数据机制,它允许我们在代码中添加额外的信息,这些信息可以在编译时或运行时被读取和处理。结合Java的反射机制(Reflection),我们可以在运行时动态地获取类、方法…...

Vue - pinia

Pinia 是 Vue 3 的官方状态管理库,旨在替代 Vuex,提供更简单的 API 和更好的 TypeScript 支持。Pinia 的设计遵循了组合式 API 的理念,能够很好地与 Vue 3 的功能结合使用。 Pinia 的基本概念 Store: Pinia 中的核心概念,类似于…...

JxBrowser 7.41.7 版本发布啦!

JxBrowser 7.41.7 版本发布啦! • 已更新 #Chromium 至更新版本 • 实施了多项质量改进 🔗 点击此处了解更多详情。 🆓 获取 30 天免费试用。...

亚博microros小车-原生ubuntu支持系列:17 gmapping

前置依赖 先看下亚博官网的介绍 Gmapping简介 gmapping只适用于单帧二维激光点数小于1440的点,如果单帧激光点数大于1440,那么就会出【[mapping-4] process has died】 这样的问题。 Gmapping是基于滤波SLAM框架的常用开源SLAM算法。 Gmapping基于RBp…...

Python 变量和简单数据类型思维导图_2025-01-30

变量和简单数据类型思维导图 下载链接腾讯云盘: https://share.weiyun.com/15A8hrTs...

小麦重测序-文献精读107

Whole-genome sequencing of diverse wheat accessions uncovers genetic changes during modern breeding in China and the United States 中国和美国现代育种过程中小麦不同种质的全基因组测序揭示遗传变化 大豆重测序-文献精读53_gmsw17-CSDN博客 大豆重测序二&#xff…...

Django基础之ORM

一.前言 上一节简单的讲了一下orm,主要还是做个了解,这一节将和大家介绍更加细致的orm,以及他们的用法,到最后再和大家说一下cookie和session,就结束了全部的django基础部分 二.orm的基本操作 1.settings.py&#x…...

大模型知识蒸馏技术(2)——蒸馏技术发展简史

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl2006年模型压缩研究 知识蒸馏的早期思想可以追溯到2006年,当时Geoffrey Hinton等人在模型压缩领域进行了开创性研究。尽管当时深度学习尚未像今天这样广泛普及,但Hinton的研究已经为知识迁移和模…...

android获取EditText内容,TextWatcher按条件触发

android获取EditText内容,TextWatcher按条件触发 背景:解决方案:效果: 背景: 最近在尝试用原生安卓实现仿element-ui表单校验功能,其中涉及到EditText组件内容的动态校验,初步实现功能后&#…...

毕业设计--具有车流量检测功能的智能交通灯设计

摘要: 随着21世纪机动车保有量的持续增加,城市交通拥堵已成为一个日益严重的问题。传统的固定绿灯时长方案导致了大量的时间浪费和交通拥堵。为解决这一问题,本文设计了一款智能交通灯系统,利用车流量检测功能和先进的算法实现了…...

[权限提升] 操作系统权限介绍

关注这个专栏的其他相关笔记:[内网安全] 内网渗透 - 学习手册-CSDN博客 权限提升简称提权,顾名思义就是提升自己在目标系统中的权限。现在的操作系统都是多用户操作系统,用户之间都有权限控制,我们通过 Web 漏洞拿到的 Web 进程的…...

Qt Designer and Python: Build Your GUI

1.install pyside6 2.pyside6-designer.exe 发送到桌面快捷方式 在Python安装的所在 Scripts 文件夹下找到此文件。如C:\Program Files\Python312\Scripts 3. 打开pyside6-designer 设计UI 4.保存为simple.ui 文件,再转成py文件 用代码执行 pyside6-uic.exe simpl…...

数据结构与算法之栈: LeetCode LCR 152. 验证二叉搜索树的后序遍历序列 (Ts版)

验证二叉搜索树的后序遍历序列 https://leetcode.cn/problems/er-cha-sou-suo-shu-de-hou-xu-bian-li-xu-lie-lcof/description/ 描述 请实现一个函数来判断整数数组 postorder 是否为二叉搜索树的后序遍历结果 示例 1 输入: postorder [4,9,6,5,8] 输出: false解释&#…...

蓝桥备赛指南(5)

这篇文章相对简单&#xff0c;主要是让大家简单了解下stack函数。 stack的定义和结构 stack是一种先进后出的数据结构&#xff0c;使用前也需要包含头文件<stack>。stack提供了一组函数来操作和访问元素&#xff0c;但它的功能相对简单。 stack的常用函数 1.push&…...

[STM32 - 野火] - - - 固件库学习笔记 - - -十三.高级定时器

一、高级定时器简介 高级定时器的简介在前面一章已经介绍过&#xff0c;可以点击下面链接了解&#xff0c;在这里进行一些补充。 [STM32 - 野火] - - - 固件库学习笔记 - - -十二.基本定时器 1.1 功能简介 1、高级定时器可以向上/向下/两边计数&#xff0c;还独有一个重复计…...