reinforce 跑 CartPole-v1
gym
版本是0.26.1
CartPole-v1的详细信息,点链接里看就行了。
修改了下动手深度强化学习对应的代码。
然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。
代码
import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils # 这个要下载源码,然后放到同个文件目录下,链接在上面给出了
from d2l import torch as d2l # 这个是动手深度学习的库, pip/conda install d2l 就好了class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self, X):X = F.relu(self.fc1(X))return F.softmax(self.fc2(X),dim=1)class REINFORCE:def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr = learning_rate)self.gamma = gamma # 折扣因子self.device = devicedef take_action(self, state): # 根据动作概率分布随机采样state = torch.tensor(np.array([state]),dtype=torch.float).to(self.device)probs = self.policy_net(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict): # 公式用的是简化推导reward_list = transition_dict['rewards']state_list = transition_dict['states']action_list = transition_dict['actions']G = 0self.optimizer.zero_grad()for i in reversed(range(len(reward_list))): # 从最后一步算起reward = reward_list[i]state = torch.tensor(np.array([state_list[i]]), dtype=torch.float).to(self.device)action = torch.tensor([action_list[i]]).reshape(-1,1).to(self.device)log_prob = torch.log(self.policy_net(state).gather(1, action))G = self.gamma * G + reward loss = -log_prob * G # 因为梯度更新是减的,所以取个负号loss.backward()self.optimizer.step()
lr = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()env_name="CartPole-v1"
env = gym.make(env_name)
print(f"_max_episode_steps:{env._max_episode_steps}")
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.nagent = REINFORCE(state_dim, hidden_dim, action_dim, lr, gamma, device)
return_list = []
for i in range(10):with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()[0]done, truncated= False, Falsewhile not done and not truncated : # 主要是这部分和原始的有点不同action = agent.take_action(state)next_state, reward, done, truncated, info = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()
我是在jupyter里直接跑的,结果如下所示。



相关文章:

reinforce 跑 CartPole-v1
gym版本是0.26.1 CartPole-v1的详细信息,点链接里看就行了。 修改了下动手深度强化学习对应的代码。 然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。 代码 import gym import torch from …...

【VRTK】【VR开发】【Unity】13-攀爬
课程配套学习资源下载 https://download.csdn.net/download/weixin_41697242/88485426?spm=1001.2014.3001.5503 【概述】 VRTK提供两个预制件实现攀爬 Climbing Controller,用于控制Player的物理义体Climbable Interactable,用于设置可攀爬对象【设置Climbing Controller…...
华为OD机试真题-求幸存数之和-2023年OD统一考试(C卷)
题目描述: 给一个正整数列 nums,一个跳数 jump,及幸存数量 left。运算过程为:从索引为0的位置开始向后跳,中间跳过 J 个数字,命中索引为J1的数字,该数被敲出,并从该点起跳ÿ…...

python pyaudio实时读取音频数据并展示波形图
python pyaudio实时读取音频数据并展示波形图 下面代码可以驱动电脑接受声音数据,并实时展示音波图: import numpy as np import matplotlib.pyplot as plt import matplotlib.animation as animation import pyaudio import wave import os import op…...

【算法系列篇】递归、搜索和回溯(二)
文章目录 前言1. 两两交换链表中的节点1.1 题目要求1.2 做题思路1.3 代码实现 2. Pow(X,N)2.1 题目要求2.2 做题思路2.3 代码实现 3. 计算布尔二叉树的值3.1 题目要求3.2 做题思路3.3 代码实现 4. 求根节点到叶结点数字之和4.1 题目要求4.2 做题思路4.3 代码实现 前言 前面为大…...

Ubuntu下安装SDL
源码下载地址(SDL version 2.0.14):https://www.libsdl.org/release/SDL2-2.0.14.tar.gz 将源码包拷贝到系统里 使用命令解压 tar -zxvf SDL2-2.0.14.tar.gz 解压得到文件夹 SDL2-2.0.14 进入文件夹 执行命令 ./configure 执行命令 make…...

创建vue项目:vue脚手架安装、vue-cli安装,vue ui界面创建vue工程(vue2/vue3),安装vue、搭建vue项目开发环境(保姆级教程二)
今天讲解 Windows 如何利用脚手架创建 vue 工程,以及 vue ui 图形化界面搭建 vue 开发环境,这是这个系列的第二章,有什么问题请留言,请点赞收藏!!! 文章目录 1、安装vue-cli脚手架2、vue ui创建…...

【3】密评-物理和环境安全测评
0x01 依据 GB/T 39786 -2021《信息安全技术 信息系统密码应用基本要求》针对等保三级系统要求: 物理和环境层面: a)宜采用密码技术进行物理访问身份鉴别,保证重要区域进入人员身份的真实性; b)宜采用密码技术保证电子门…...

笨爸爸工房,我们在校园|“小鲁班”,铸未来
为了响应国家号召,将劳动教育课程真正实现融入校园生活,笨爸爸工房已与洛阳市西下池小学、洛阳市第一实验小学西工校区、洛阳市西工区第二实验小学、洛阳第二外国语学校(兰溪校区)、洛阳市睿源幼儿园,这4所学校及1家幼…...
RPC 集群,gRPC 广播和组播
一、集群抽象:cluster 它是指我们在调用远程的时候,尝试解决: 1、failover:即引入重试功能,但是重试的时候会换一个新节点 2、failfast: 立刻失败,不需要重试 3、广播:将请求发送到所有的节点上 4、组…...

OpenSSL SSL_read: Connection was reset, errno 10054
fatal: unable to access ‘https://github.com/vangleer/es-big-screen.git/’: OpenSSL SSL_read: Connection was reset, errno 10054 解决方法:git config --global http.sslVerify “false” 参考链接: https://github.com/Kong/insomnia/issues/2…...

【springboot】整合redis和定制化
1.前提条件:docker安装好了redis,确定redis可以访问 可选软件: 2.测试代码 (1)redis依赖 org.springframework.boot spring-boot-starter-data-redis (2)配置redis (3) 注入 Resource StringRedisTemplate stringRedisTemplate; 这里如果用Autowi…...

HarmonyOS鸿蒙操作系统架构开发
什么是HarmonyOS鸿蒙操作系统? HarmonyOS是华为公司开发的一种全场景分布式操作系统。它可以在各种智能设备(如手机、电视、汽车、智能穿戴设备等)上运行,具有高效、安全、低延迟等优势。 目录 HarmonyOS 一、HarmonyOS 与其他操…...

共创共赢|美创科技获江苏移动2023DICT生态合作“产品共创奖”
12月6日,以“5G江山蓝 算网融百业 数智创未来”为主题的中国移动江苏公司2023DICT合作伙伴大会在南京成功举办。来自行业领军企业、科研院所等DICT产业核心力量的百余家单位代表参加本次大会,共话数实融合新趋势,共拓合作发展新空间。 作为生…...
深度学习——第3章 Python程序设计语言(3.5 Python类和对象)
3.5 Python类和对象 目录 1. 面向对象的基本概念 2. 类和对象的关系 3. 类的声明 4. 对象的创建和使用 5. 类对象属性 6. 类对象方法 7. 面向对象的三个基本特征 8. 综合案例:汉诺塔图形化移动 1.1 面向对象的基本概念 1.1.1 对象(object&#x…...

【原创】【一类问题的通法】【真题+李6卷6+李4卷4(+李6卷5)分析】合同矩阵A B有PTAP=B,求可逆阵P的策略
【铺垫】二次型做的变换与相应二次型矩阵的对应:二次型f(x1,x2,x3)xTAx,g(y1,y2,y3)yTBy ①若f在可逆变换xPy下化为g,即P为可逆阵,有P…...
代码随想录算法训练营第六十天 | 84.柱状图中最大的矩形
84.柱状图中最大的矩形 题目链接:84. 柱状图中最大的矩形 本题与接雨水相近。按列来看,是要找到每一个柱子左右第一个比它矮的柱子,即对于该柱子来说所能组成的最大面积,将每个柱子所能得到的最大面积进行对比最终得到最大矩形。 …...

C#结合JavaScript实现多文件上传
目录 需求 引入 关键代码 操作界面 JavaScript包程序 服务端 ashx 程序 服务端上传后处理程序 小结 需求 在许多应用场景里,多文件上传是一项比较实用的功能。实际应用中,多文件上传可以考虑如下需求: 1、对上传文件的类型、大小…...

STM32——继电器
继电器工作原理 单片机供电 VCC GND 接单片机, VCC 需要接 3.3V , 5V 不行! 最大负载电路交流 250V/10A ,直流 30V/10A 引脚 IN 接收到 低电平 时,开关闭合。...

性能监控体系:InfluxDB Grafana Prometheus
InfluxDB 简介 什么是 InfluxDB ? InfluxDB 是一个由 InfluxData 开发的,开源的时序型数据库。它由 Go 语言写成,着力于高性能地查询与存储时序型数据。 InfluxDB 被广泛应用于存储系统的监控数据、IoT 行业的实时数据等场景。 可配合 Te…...
Leetcode 3576. Transform Array to All Equal Elements
Leetcode 3576. Transform Array to All Equal Elements 1. 解题思路2. 代码实现 题目链接:3576. Transform Array to All Equal Elements 1. 解题思路 这一题思路上就是分别考察一下是否能将其转化为全1或者全-1数组即可。 至于每一种情况是否可以达到…...
React hook之useRef
React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...
R语言AI模型部署方案:精准离线运行详解
R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...

Docker 运行 Kafka 带 SASL 认证教程
Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明:server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...

基于Flask实现的医疗保险欺诈识别监测模型
基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...

自然语言处理——Transformer
自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效,它能挖掘数据中的时序信息以及语义信息,但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN,但是…...

优选算法第十二讲:队列 + 宽搜 优先级队列
优选算法第十二讲:队列 宽搜 && 优先级队列 1.N叉树的层序遍历2.二叉树的锯齿型层序遍历3.二叉树最大宽度4.在每个树行中找最大值5.优先级队列 -- 最后一块石头的重量6.数据流中的第K大元素7.前K个高频单词8.数据流的中位数 1.N叉树的层序遍历 2.二叉树的锯…...
服务器--宝塔命令
一、宝塔面板安装命令 ⚠️ 必须使用 root 用户 或 sudo 权限执行! sudo su - 1. CentOS 系统: yum install -y wget && wget -O install.sh http://download.bt.cn/install/install_6.0.sh && sh install.sh2. Ubuntu / Debian 系统…...
Java毕业设计:WML信息查询与后端信息发布系统开发
JAVAWML信息查询与后端信息发布系统实现 一、系统概述 本系统基于Java和WML(无线标记语言)技术开发,实现了移动设备上的信息查询与后端信息发布功能。系统采用B/S架构,服务器端使用Java Servlet处理请求,数据库采用MySQL存储信息࿰…...

uniapp 开发ios, xcode 提交app store connect 和 testflight内测
uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...