当前位置: 首页 > news >正文

reinforce 跑 CartPole-v1

gym版本是0.26.1
CartPole-v1的详细信息,点链接里看就行了。
修改了下动手深度强化学习对应的代码。

然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。

代码

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils # 这个要下载源码,然后放到同个文件目录下,链接在上面给出了
from d2l import torch as d2l # 这个是动手深度学习的库, pip/conda install d2l 就好了class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self, X):X = F.relu(self.fc1(X))return F.softmax(self.fc2(X),dim=1)class REINFORCE:def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr = learning_rate)self.gamma = gamma # 折扣因子self.device = devicedef take_action(self, state): # 根据动作概率分布随机采样state = torch.tensor(np.array([state]),dtype=torch.float).to(self.device)probs = self.policy_net(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):  # 公式用的是简化推导reward_list = transition_dict['rewards']state_list = transition_dict['states']action_list = transition_dict['actions']G = 0self.optimizer.zero_grad()for i in reversed(range(len(reward_list))):  # 从最后一步算起reward = reward_list[i]state = torch.tensor(np.array([state_list[i]]), dtype=torch.float).to(self.device)action = torch.tensor([action_list[i]]).reshape(-1,1).to(self.device)log_prob = torch.log(self.policy_net(state).gather(1, action))G = self.gamma * G + reward loss = -log_prob * G  # 因为梯度更新是减的,所以取个负号loss.backward()self.optimizer.step()
lr = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()env_name="CartPole-v1"
env = gym.make(env_name)
print(f"_max_episode_steps:{env._max_episode_steps}")
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.nagent = REINFORCE(state_dim, hidden_dim, action_dim, lr, gamma, device)
return_list = []
for i in range(10):with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()[0]done, truncated= False, Falsewhile not done and not truncated :  # 主要是这部分和原始的有点不同action = agent.take_action(state)next_state, reward, done, truncated, info = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()

我是在jupyter里直接跑的,结果如下所示。

相关文章:

reinforce 跑 CartPole-v1

gym版本是0.26.1 CartPole-v1的详细信息,点链接里看就行了。 修改了下动手深度强化学习对应的代码。 然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。 代码 import gym import torch from …...

【VRTK】【VR开发】【Unity】13-攀爬

课程配套学习资源下载 https://download.csdn.net/download/weixin_41697242/88485426?spm=1001.2014.3001.5503 【概述】 VRTK提供两个预制件实现攀爬 Climbing Controller,用于控制Player的物理义体Climbable Interactable,用于设置可攀爬对象【设置Climbing Controller…...

华为OD机试真题-求幸存数之和-2023年OD统一考试(C卷)

题目描述: 给一个正整数列 nums,一个跳数 jump,及幸存数量 left。运算过程为:从索引为0的位置开始向后跳,中间跳过 J 个数字,命中索引为J1的数字,该数被敲出,并从该点起跳&#xff…...

python pyaudio实时读取音频数据并展示波形图

python pyaudio实时读取音频数据并展示波形图 下面代码可以驱动电脑接受声音数据,并实时展示音波图: import numpy as np import matplotlib.pyplot as plt import matplotlib.animation as animation import pyaudio import wave import os import op…...

【算法系列篇】递归、搜索和回溯(二)

文章目录 前言1. 两两交换链表中的节点1.1 题目要求1.2 做题思路1.3 代码实现 2. Pow(X,N)2.1 题目要求2.2 做题思路2.3 代码实现 3. 计算布尔二叉树的值3.1 题目要求3.2 做题思路3.3 代码实现 4. 求根节点到叶结点数字之和4.1 题目要求4.2 做题思路4.3 代码实现 前言 前面为大…...

Ubuntu下安装SDL

源码下载地址(SDL version 2.0.14):https://www.libsdl.org/release/SDL2-2.0.14.tar.gz 将源码包拷贝到系统里 使用命令解压 tar -zxvf SDL2-2.0.14.tar.gz 解压得到文件夹 SDL2-2.0.14 进入文件夹 执行命令 ./configure 执行命令 make…...

创建vue项目:vue脚手架安装、vue-cli安装,vue ui界面创建vue工程(vue2/vue3),安装vue、搭建vue项目开发环境(保姆级教程二)

今天讲解 Windows 如何利用脚手架创建 vue 工程,以及 vue ui 图形化界面搭建 vue 开发环境,这是这个系列的第二章,有什么问题请留言,请点赞收藏!!! 文章目录 1、安装vue-cli脚手架2、vue ui创建…...

【3】密评-物理和环境安全测评

0x01 依据 GB/T 39786 -2021《信息安全技术 信息系统密码应用基本要求》针对等保三级系统要求: 物理和环境层面: a)宜采用密码技术进行物理访问身份鉴别,保证重要区域进入人员身份的真实性; b)宜采用密码技术保证电子门…...

笨爸爸工房,我们在校园|“小鲁班”,铸未来

为了响应国家号召,将劳动教育课程真正实现融入校园生活,笨爸爸工房已与洛阳市西下池小学、洛阳市第一实验小学西工校区、洛阳市西工区第二实验小学、洛阳第二外国语学校(兰溪校区)、洛阳市睿源幼儿园,这4所学校及1家幼…...

RPC 集群,gRPC 广播和组播

一、集群抽象:cluster 它是指我们在调用远程的时候,尝试解决: 1、failover:即引入重试功能,但是重试的时候会换一个新节点 2、failfast: 立刻失败,不需要重试 3、广播:将请求发送到所有的节点上 4、组…...

OpenSSL SSL_read: Connection was reset, errno 10054

fatal: unable to access ‘https://github.com/vangleer/es-big-screen.git/’: OpenSSL SSL_read: Connection was reset, errno 10054 解决方法:git config --global http.sslVerify “false” 参考链接: https://github.com/Kong/insomnia/issues/2…...

【springboot】整合redis和定制化

1.前提条件:docker安装好了redis,确定redis可以访问 可选软件: 2.测试代码 (1)redis依赖 org.springframework.boot spring-boot-starter-data-redis (2)配置redis (3) 注入 Resource StringRedisTemplate stringRedisTemplate; 这里如果用Autowi…...

HarmonyOS鸿蒙操作系统架构开发

什么是HarmonyOS鸿蒙操作系统? HarmonyOS是华为公司开发的一种全场景分布式操作系统。它可以在各种智能设备(如手机、电视、汽车、智能穿戴设备等)上运行,具有高效、安全、低延迟等优势。 目录 HarmonyOS 一、HarmonyOS 与其他操…...

共创共赢|美创科技获江苏移动2023DICT生态合作“产品共创奖”

12月6日,以“5G江山蓝 算网融百业 数智创未来”为主题的中国移动江苏公司2023DICT合作伙伴大会在南京成功举办。来自行业领军企业、科研院所等DICT产业核心力量的百余家单位代表参加本次大会,共话数实融合新趋势,共拓合作发展新空间。 作为生…...

深度学习——第3章 Python程序设计语言(3.5 Python类和对象)

3.5 Python类和对象 目录 1. 面向对象的基本概念 2. 类和对象的关系 3. 类的声明 4. 对象的创建和使用 5. 类对象属性 6. 类对象方法 7. 面向对象的三个基本特征 8. 综合案例:汉诺塔图形化移动 1.1 面向对象的基本概念 1.1.1 对象(object&#x…...

【原创】【一类问题的通法】【真题+李6卷6+李4卷4(+李6卷5)分析】合同矩阵A B有PTAP=B,求可逆阵P的策略

【铺垫】二次型做的变换与相应二次型矩阵的对应:二次型f(x1,x2,x3)xTAx,g(y1,y2,y3)yTBy ①若f在可逆变换xPy下化为g,即P为可逆阵,有P…...

代码随想录算法训练营第六十天 | 84.柱状图中最大的矩形

84.柱状图中最大的矩形 题目链接:84. 柱状图中最大的矩形 本题与接雨水相近。按列来看,是要找到每一个柱子左右第一个比它矮的柱子,即对于该柱子来说所能组成的最大面积,将每个柱子所能得到的最大面积进行对比最终得到最大矩形。 …...

C#结合JavaScript实现多文件上传

目录 需求 引入 关键代码 操作界面 ​JavaScript包程序 服务端 ashx 程序 服务端上传后处理程序 小结 需求 在许多应用场景里,多文件上传是一项比较实用的功能。实际应用中,多文件上传可以考虑如下需求: 1、对上传文件的类型、大小…...

STM32——继电器

继电器工作原理 单片机供电 VCC GND 接单片机, VCC 需要接 3.3V , 5V 不行! 最大负载电路交流 250V/10A ,直流 30V/10A 引脚 IN 接收到 低电平 时,开关闭合。...

性能监控体系:InfluxDB Grafana Prometheus

InfluxDB 简介 什么是 InfluxDB ? InfluxDB 是一个由 InfluxData 开发的,开源的时序型数据库。它由 Go 语言写成,着力于高性能地查询与存储时序型数据。 InfluxDB 被广泛应用于存储系统的监控数据、IoT 行业的实时数据等场景。 可配合 Te…...

Debian系统简介

目录 Debian系统介绍 Debian版本介绍 Debian软件源介绍 软件包管理工具dpkg dpkg核心指令详解 安装软件包 卸载软件包 查询软件包状态 验证软件包完整性 手动处理依赖关系 dpkg vs apt Debian系统介绍 Debian 和 Ubuntu 都是基于 Debian内核 的 Linux 发行版&#xff…...

Java - Mysql数据类型对应

Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言:为什么 Eureka 依然是存量系统的核心? 尽管 Nacos 等新注册中心崛起,但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制,是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 (部分有免费额度&#x…...

Spring AI 入门:Java 开发者的生成式 AI 实践之路

一、Spring AI 简介 在人工智能技术快速迭代的今天,Spring AI 作为 Spring 生态系统的新生力量,正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务(如 OpenAI、Anthropic)的无缝对接&…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战

在现代战争中,电磁频谱已成为继陆、海、空、天之后的 “第五维战场”,雷达作为电磁频谱领域的关键装备,其干扰与抗干扰能力的较量,直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器,凭借数字射…...

【7色560页】职场可视化逻辑图高级数据分析PPT模版

7种色调职场工作汇报PPT,橙蓝、黑红、红蓝、蓝橙灰、浅蓝、浅绿、深蓝七种色调模版 【7色560页】职场可视化逻辑图高级数据分析PPT模版:职场可视化逻辑图分析PPT模版https://pan.quark.cn/s/78aeabbd92d1...

Linux离线(zip方式)安装docker

目录 基础信息操作系统信息docker信息 安装实例安装步骤示例 遇到的问题问题1:修改默认工作路径启动失败问题2 找不到对应组 基础信息 操作系统信息 OS版本:CentOS 7 64位 内核版本:3.10.0 相关命令: uname -rcat /etc/os-rele…...

Yolov8 目标检测蒸馏学习记录

yolov8系列模型蒸馏基本流程,代码下载:这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中,**知识蒸馏(Knowledge Distillation)**被广泛应用,作为提升模型…...

【SpringBoot自动化部署】

SpringBoot自动化部署方法 使用Jenkins进行持续集成与部署 Jenkins是最常用的自动化部署工具之一,能够实现代码拉取、构建、测试和部署的全流程自动化。 配置Jenkins任务时,需要添加Git仓库地址和凭证,设置构建触发器(如GitHub…...