reinforce 跑 CartPole-v1
gym版本是0.26.1
CartPole-v1的详细信息,点链接里看就行了。
修改了下动手深度强化学习对应的代码。
然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。
代码
import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import rl_utils # 这个要下载源码,然后放到同个文件目录下,链接在上面给出了
from d2l import torch as d2l # 这个是动手深度学习的库, pip/conda install d2l 就好了class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)def forward(self, X):X = F.relu(self.fc1(X))return F.softmax(self.fc2(X),dim=1)class REINFORCE:def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma, device):self.policy_net = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr = learning_rate)self.gamma = gamma # 折扣因子self.device = devicedef take_action(self, state): # 根据动作概率分布随机采样state = torch.tensor(np.array([state]),dtype=torch.float).to(self.device)probs = self.policy_net(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict): # 公式用的是简化推导reward_list = transition_dict['rewards']state_list = transition_dict['states']action_list = transition_dict['actions']G = 0self.optimizer.zero_grad()for i in reversed(range(len(reward_list))): # 从最后一步算起reward = reward_list[i]state = torch.tensor(np.array([state_list[i]]), dtype=torch.float).to(self.device)action = torch.tensor([action_list[i]]).reshape(-1,1).to(self.device)log_prob = torch.log(self.policy_net(state).gather(1, action))G = self.gamma * G + reward loss = -log_prob * G # 因为梯度更新是减的,所以取个负号loss.backward()self.optimizer.step()
lr = 1e-3
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = d2l.try_gpu()env_name="CartPole-v1"
env = gym.make(env_name)
print(f"_max_episode_steps:{env._max_episode_steps}")
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.nagent = REINFORCE(state_dim, hidden_dim, action_dim, lr, gamma, device)
return_list = []
for i in range(10):with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()[0]done, truncated= False, Falsewhile not done and not truncated : # 主要是这部分和原始的有点不同action = agent.take_action(state)next_state, reward, done, truncated, info = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'REINFORCE on {env_name}')
plt.show()
我是在jupyter里直接跑的,结果如下所示。
相关文章:
reinforce 跑 CartPole-v1
gym版本是0.26.1 CartPole-v1的详细信息,点链接里看就行了。 修改了下动手深度强化学习对应的代码。 然后这里 J ( θ ) J(\theta) J(θ)梯度上升更新的公式是用的不严谨的,这个和王树森书里讲的严谨公式有点区别。 代码 import gym import torch from …...
【VRTK】【VR开发】【Unity】13-攀爬
课程配套学习资源下载 https://download.csdn.net/download/weixin_41697242/88485426?spm=1001.2014.3001.5503 【概述】 VRTK提供两个预制件实现攀爬 Climbing Controller,用于控制Player的物理义体Climbable Interactable,用于设置可攀爬对象【设置Climbing Controller…...
华为OD机试真题-求幸存数之和-2023年OD统一考试(C卷)
题目描述: 给一个正整数列 nums,一个跳数 jump,及幸存数量 left。运算过程为:从索引为0的位置开始向后跳,中间跳过 J 个数字,命中索引为J1的数字,该数被敲出,并从该点起跳ÿ…...
python pyaudio实时读取音频数据并展示波形图
python pyaudio实时读取音频数据并展示波形图 下面代码可以驱动电脑接受声音数据,并实时展示音波图: import numpy as np import matplotlib.pyplot as plt import matplotlib.animation as animation import pyaudio import wave import os import op…...
【算法系列篇】递归、搜索和回溯(二)
文章目录 前言1. 两两交换链表中的节点1.1 题目要求1.2 做题思路1.3 代码实现 2. Pow(X,N)2.1 题目要求2.2 做题思路2.3 代码实现 3. 计算布尔二叉树的值3.1 题目要求3.2 做题思路3.3 代码实现 4. 求根节点到叶结点数字之和4.1 题目要求4.2 做题思路4.3 代码实现 前言 前面为大…...
Ubuntu下安装SDL
源码下载地址(SDL version 2.0.14):https://www.libsdl.org/release/SDL2-2.0.14.tar.gz 将源码包拷贝到系统里 使用命令解压 tar -zxvf SDL2-2.0.14.tar.gz 解压得到文件夹 SDL2-2.0.14 进入文件夹 执行命令 ./configure 执行命令 make…...
创建vue项目:vue脚手架安装、vue-cli安装,vue ui界面创建vue工程(vue2/vue3),安装vue、搭建vue项目开发环境(保姆级教程二)
今天讲解 Windows 如何利用脚手架创建 vue 工程,以及 vue ui 图形化界面搭建 vue 开发环境,这是这个系列的第二章,有什么问题请留言,请点赞收藏!!! 文章目录 1、安装vue-cli脚手架2、vue ui创建…...
【3】密评-物理和环境安全测评
0x01 依据 GB/T 39786 -2021《信息安全技术 信息系统密码应用基本要求》针对等保三级系统要求: 物理和环境层面: a)宜采用密码技术进行物理访问身份鉴别,保证重要区域进入人员身份的真实性; b)宜采用密码技术保证电子门…...
笨爸爸工房,我们在校园|“小鲁班”,铸未来
为了响应国家号召,将劳动教育课程真正实现融入校园生活,笨爸爸工房已与洛阳市西下池小学、洛阳市第一实验小学西工校区、洛阳市西工区第二实验小学、洛阳第二外国语学校(兰溪校区)、洛阳市睿源幼儿园,这4所学校及1家幼…...
RPC 集群,gRPC 广播和组播
一、集群抽象:cluster 它是指我们在调用远程的时候,尝试解决: 1、failover:即引入重试功能,但是重试的时候会换一个新节点 2、failfast: 立刻失败,不需要重试 3、广播:将请求发送到所有的节点上 4、组…...
OpenSSL SSL_read: Connection was reset, errno 10054
fatal: unable to access ‘https://github.com/vangleer/es-big-screen.git/’: OpenSSL SSL_read: Connection was reset, errno 10054 解决方法:git config --global http.sslVerify “false” 参考链接: https://github.com/Kong/insomnia/issues/2…...
【springboot】整合redis和定制化
1.前提条件:docker安装好了redis,确定redis可以访问 可选软件: 2.测试代码 (1)redis依赖 org.springframework.boot spring-boot-starter-data-redis (2)配置redis (3) 注入 Resource StringRedisTemplate stringRedisTemplate; 这里如果用Autowi…...
HarmonyOS鸿蒙操作系统架构开发
什么是HarmonyOS鸿蒙操作系统? HarmonyOS是华为公司开发的一种全场景分布式操作系统。它可以在各种智能设备(如手机、电视、汽车、智能穿戴设备等)上运行,具有高效、安全、低延迟等优势。 目录 HarmonyOS 一、HarmonyOS 与其他操…...
共创共赢|美创科技获江苏移动2023DICT生态合作“产品共创奖”
12月6日,以“5G江山蓝 算网融百业 数智创未来”为主题的中国移动江苏公司2023DICT合作伙伴大会在南京成功举办。来自行业领军企业、科研院所等DICT产业核心力量的百余家单位代表参加本次大会,共话数实融合新趋势,共拓合作发展新空间。 作为生…...
深度学习——第3章 Python程序设计语言(3.5 Python类和对象)
3.5 Python类和对象 目录 1. 面向对象的基本概念 2. 类和对象的关系 3. 类的声明 4. 对象的创建和使用 5. 类对象属性 6. 类对象方法 7. 面向对象的三个基本特征 8. 综合案例:汉诺塔图形化移动 1.1 面向对象的基本概念 1.1.1 对象(object&#x…...
【原创】【一类问题的通法】【真题+李6卷6+李4卷4(+李6卷5)分析】合同矩阵A B有PTAP=B,求可逆阵P的策略
【铺垫】二次型做的变换与相应二次型矩阵的对应:二次型f(x1,x2,x3)xTAx,g(y1,y2,y3)yTBy ①若f在可逆变换xPy下化为g,即P为可逆阵,有P…...
代码随想录算法训练营第六十天 | 84.柱状图中最大的矩形
84.柱状图中最大的矩形 题目链接:84. 柱状图中最大的矩形 本题与接雨水相近。按列来看,是要找到每一个柱子左右第一个比它矮的柱子,即对于该柱子来说所能组成的最大面积,将每个柱子所能得到的最大面积进行对比最终得到最大矩形。 …...
C#结合JavaScript实现多文件上传
目录 需求 引入 关键代码 操作界面 JavaScript包程序 服务端 ashx 程序 服务端上传后处理程序 小结 需求 在许多应用场景里,多文件上传是一项比较实用的功能。实际应用中,多文件上传可以考虑如下需求: 1、对上传文件的类型、大小…...
STM32——继电器
继电器工作原理 单片机供电 VCC GND 接单片机, VCC 需要接 3.3V , 5V 不行! 最大负载电路交流 250V/10A ,直流 30V/10A 引脚 IN 接收到 低电平 时,开关闭合。...
性能监控体系:InfluxDB Grafana Prometheus
InfluxDB 简介 什么是 InfluxDB ? InfluxDB 是一个由 InfluxData 开发的,开源的时序型数据库。它由 Go 语言写成,着力于高性能地查询与存储时序型数据。 InfluxDB 被广泛应用于存储系统的监控数据、IoT 行业的实时数据等场景。 可配合 Te…...
Omni-Vision Sanctuary 企业级部署架构设计:高可用与弹性伸缩
Omni-Vision Sanctuary 企业级部署架构设计:高可用与弹性伸缩 1. 企业级AI部署面临的挑战 当企业决定在生产环境中部署Omni-Vision Sanctuary这类AI服务时,通常会遇到几个关键挑战。首先是服务可用性问题,任何计划外停机都可能直接影响业务…...
HsMod:炉石传说个性化增强工具 玩家的全方位游戏体验优化方案
HsMod:炉石传说个性化增强工具 玩家的全方位游戏体验优化方案 【免费下载链接】HsMod Hearthstone Modify Based on BepInEx 项目地址: https://gitcode.com/GitHub_Trending/hs/HsMod 你是否曾因炉石传说中繁琐的操作流程而感到沮丧?是否希望拥有…...
多场景适配:ClearerVoice-Studio支持16K/48K采样率,会议直播都适用
多场景适配:ClearerVoice-Studio支持16K/48K采样率,会议直播都适用 1. 为什么音频采样率如此重要? 在语音处理领域,采样率选择直接影响最终效果。就像相机像素决定照片清晰度一样,音频采样率决定了声音的"分辨率…...
Qwen Pixel Art企业级应用:游戏公司美术外包降本提效实战路径
Qwen Pixel Art企业级应用:游戏公司美术外包降本提效实战路径 1. 游戏美术外包的痛点与机遇 游戏开发中,美术资源制作往往占据大量成本和时间。传统像素美术外包存在三个核心痛点: 成本高:资深像素画师日薪通常在800-1500元&am…...
3步实现HTML到Word的智能转换:html-to-docx技术深度解析
3步实现HTML到Word的智能转换:html-to-docx技术深度解析 【免费下载链接】html-to-docx HTML to DOCX converter 项目地址: https://gitcode.com/gh_mirrors/ht/html-to-docx 你是否曾遇到过这样的场景?精心设计的网页报告需要转换为Word文档进行…...
保姆级教程:在PX4 SITL仿真中为Iris无人机挂载Kinect、RPLidar和FPV摄像头
PX4仿真环境多传感器集成实战:从零搭建SLAM无人机开发平台 无人机仿真开发中最令人头疼的,莫过于将各类传感器完美集成到飞行平台上。我曾花了整整两周时间调试Kinect和RPLidar在Gazebo中的兼容性问题,直到找到这套经过验证的解决方案。本文将…...
STM32F4读写SD卡:填一填ST官方HAL库的坑
使用STM32读写SD卡在低功耗存储中的应用是比较常见的,但是网上大多数资料都是基于标准库或者基于寄存器的开发。随着嵌入式设备越来越复杂,使用HAL库能够大大降低开发者的学习成本,从而提高开发效率。近年来,ST官方主推以STM32Cub…...
秋招简历模板下载怎么选?6款主流简历模板工具深度测评
秋招季来临,对应届生来说,简历是踏入职场的第一块敲门砖,而一份贴合岗位需求、契合HR筛选思路的简历模板,既能降低简历制作难度,也是提高简历初筛通过率的关键。如今市面上的简历模板工具五花八门,功能定位…...
从Java到AI Agent:传统后端工程师的下一站,不是学AI,是成为系统工程师!
文章探讨了在AI技术发展的背景下,传统后端工程师的转型方向。作者认为,未来的竞争焦点不再是单纯的技术能力,而是如何将AI技术融入现有系统,构建自动化系统。文章提出了AI Agent工程师的概念,强调系统工程能力的重要性…...
如何用Mermaid Live Editor 5分钟创建专业图表
如何用Mermaid Live Editor 5分钟创建专业图表 【免费下载链接】mermaid-live-editor Edit, preview and share mermaid charts/diagrams. New implementation of the live editor. 项目地址: https://gitcode.com/GitHub_Trending/me/mermaid-live-editor Mermaid Live…...
