当前位置: 首页 > news >正文

基于强化学习算法玩CartPole游戏

什么事CartPole游戏

CartPole(也称为倒立摆问题)是一个经典的控制理论和强化学习的基础问题,通常用于测试和验证控制算法的性能。具体来说,它是一个简单的物理模拟问题,其目标是通过在一个平衡杆(倒立摆)上安装在小车(或称为平衡车)上的水平移动,使杆子保持竖直直立的状态。

有两个动作(action):

左移(0)

右移(1)

四个状态(state): 1. 小车在轨道上的位置 2. 杆子与竖直方向的夹角 3. 小车速度 4. 角度变化率

神经网络设计

1、强化学习的训练网络cartpole_train.py

import  gym
import pygame
import time
import random
import torch
from torch.distributions import Categoricalfrom torch import nn, optim
import torch.nn.functional as Fdef compute_policy_loss(n, log_p):r = list()#构造奖励r列表for i in range(n, 0 ,-1):r.append(i *1.0)r = torch.tensor(r)r = (r - r.mean()) / r.std() #进行标准化处理loss = 0#计算损失函数for pi, ri in zip(log_p, r):loss += -pi * rireturn  lossclass CartPolePolicy(nn.Module):def __init__(self):super(CartPolePolicy, self).__init__()self.fc1 = nn.Linear(in_features = 4, out_features = 128)self.fc2 = nn.Linear(128, 2) #输出为神经元个数为2表示,向左和向向右self.drop = nn.Dropout(p=0.6)def forward(self, x):x = self.fc1(x)x = self.drop(x)x = F.relu(x)x = self.fc2(x)#使用softmax决策最终的行动,是向左还是右return F.softmax(x, dim=1)if __name__ == '__main__':env = gym.make("CartPole-v1") #启动环境env.reset(seed= 543)torch.manual_seed(543)policy = CartPolePolicy() #定义模型optimizer = optim.Adam(policy.parameters(), lr = 0.01) #优化器#我们一共最多训练1000个回合#每个回合最多行动10000次#当某一回合的游戏步数超过5000时,就认为完成训练max_episod = 1000 #最大游戏回合数max_action = 10000 #每回合最大行动数max_steps = 5000 #完成训练的步数for episod in range(1, max_episod + 1):# 对于每一轮循环,都要重新启动一次游戏环境state, _ = env.reset()step = 0log_p = list()for step in range(1, max_action + 1):state = torch.from_numpy(state).float().unsqueeze(0)probs = policy(state) #计算神经网络给出的行动概率# 基于网络给出的概率分布,随机选择行动m = Categorical(probs)# 这里并不是直接使用概率较大的行动,而是通过概率分布生成action, 这样可以进一步探索低概率行动action = m.sample()state, _, done, _, _ = env.step(action.item())if done:break #表示跳出该for循环log_p.append(m.log_prob(action)) #保存每次行动对应的概率分布if step > max_steps: #当step大于最大步数时print(f"Done! last episode {episod} Run steps {step}")break #跳出循序,结束训练#每一回合游戏,都会做一次梯度下降算法optimizer.zero_grad()loss = compute_policy_loss(step, log_p)loss.backward()optimizer.step()if episod % 10 ==0:print(f"Episode {episod} Run step {step}")#保存模型torch.save(policy.state_dict(), f"cartpole_policy.pth")

2、验证:cartpole_eval.py

import  gym
import pygame
import torch.nn as nn
import torch.nn.functional as F
import time
import torch
class CartPolePolicy(nn.Module):def __init__(self):super(CartPolePolicy, self).__init__()self.fc1 = nn.Linear(4, 128)self.fc2 = nn.Linear(128, 2)self.drop = nn.Dropout(p=0.6)def forward(self, x):x = self.fc1(x)x = self.drop(x)x = F.relu(x)x = self.fc2(x)return F.softmax(x, dim=1)if __name__ == '__main__':pygame.init() #初始化pygame#使用gym, 创建一个artPole游戏的运行环境,这个环境是提供给人类玩家使用的env = gym.make('CartPole-v1', render_mode = "human")state, _ =env.reset()#使用env.reset重置环境后,会得到CartPole游戏中关键参数statecart_position = state[0] #小车位置cart_speed = state[1] #小车速度pole_angle = state[2] #杆的角度pole_speed = state[3] #杆的尖端速度#加载网络policy = CartPolePolicy()policy.load_state_dict(torch.load("cartpole_policy.pth"))policy.eval()start_time =time.time()max_action =1000 #设置游戏最大执行次数#最多执行1000次方向键,游戏就可以通关结束step = 0fail = Falsefor step in range(1, max_action + 1):#首先使用time.sleep,使游戏暂停0.3s,用于人的反应,觉得自己反应慢可以设置更长时间# time.sleep(0.3)#小车的控制方式,通过神经网络,来决定小车的运动方向#将环境参数state转为张量state = torch.from_numpy(state).float().unsqueeze(0)#输入至网络模型,计算行动概率probsprobs = policy(state)#选取行动概率最大的行动action =torch.argmax(probs, dim = 1).item()state, _, done, _, _ = env.step(action) #done为True,表示杆倒了if done:fail = Truebreakprint(f"step = {step} action = {action} angle = {state[2]:.2f}  position = {state[0]:.2f}")end_time = time.time()game_time = end_time - start_timeif fail:print(f"Game over ,you play {game_time:.2f} seconds, {step} steps.")else:print(f"Congratulations! you play  {game_time:.2f} seconds, {step} steps.")env.close()

视频讲解:

什么是reinforce强化学习算法,基于强化学习玩CartPole游戏_哔哩哔哩_bilibili

相关文章:

基于强化学习算法玩CartPole游戏

什么事CartPole游戏 CartPole(也称为倒立摆问题)是一个经典的控制理论和强化学习的基础问题,通常用于测试和验证控制算法的性能。具体来说,它是一个简单的物理模拟问题,其目标是通过在一个平衡杆(倒立摆&a…...

uniapp0基础编写安卓原生插件和调用第三方jar包(Ch34的jar包)和如何解决android 如何Application初始化

前言 我假设你会uniapp安卓插件开发了,如果不会请看这篇文章,这篇文章是0基础教学。 这篇文章我们将讲一下如何使用CH34XUARTDriver.jar进行开发成uniapp插件。 它的难点是:uniapp如何Application初始化第三方jar包 先去官网下载CH340/CH341的USB转串口安卓免驱应用库:h…...

使用Leaflet进行船舶航行警告区域绘制实战

目录 前言 一、坐标格式转换 1、数据初认识 2、将区域分割成多个点 3、数据转换 4、数据转换调用 二、WebGIS展示空间位置信息 1、定义底图 2、Polygon的可视化 3、实际效果 三、总结 前言 通常而言,海事部门如海事局,通常会在所述的管辖区域内…...

用Ollama 和 Open WebUI本地部署Llama 3.1 8B

说明: 本人运行环境windows11 N卡6G显存。部署Llama3.1 8B 简介 Ollama是一个开源的大型语言模型服务工具,它允许用户在自己的硬件环境中轻松部署和使用大规模预训练模型。Ollama 的主要功能是在Docker容器内部署和管理大型语言模型(LLM&…...

计算机毕业设计选题推荐-学生作业管理系统-Java/Python项目实战

✨作者主页:IT研究室✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Python…...

RIP实验

实验拓扑: 实验要求: R1-R2-R3-R4-R5:RIP 100 运行版本2 R6-R7:RIP 200 运行版本1 1.使用合理IP地址规划网络,各自创建环回接口 2.R1创建环回 172.16.1.1/24 172.16.2.1/24 172.16.3.1/24 3.要求R3使用R2访问R1环…...

手把手教你如何在宝塔上添加可道云登录页面的ICP备案信息,别跟权威开玩笑。

如何在宝塔上添加可道云登录页面的ICP备案信息 事情的原由来我们开始吧首先登录你的宝塔页面双击打开index.php文件保存退出即可 感谢大佬,希望对被查到的朋友有所帮助! 事情的原由 今天突然收到腾讯云发来的一封Email,说我需要整改我的网站…...

基于JSP技术的大学生校园兼职系统

你好呀,我是计算机学姐码农小野!如果有相关需求,可以私信联系我。 开发语言:JSP 数据库:MySQL 技术:JSPJavaBeans 工具:MyEclipse,Tomcat,Navicat 系统展示 首页 学…...

VSCode在windows系统下的配置简单版

参考链接 从零开始的vscode安装及环境配置教程(C/C)(Windows系统)_vscode搭建编译器环境-CSDN博客 vscode生成tasks.json、launch.json、c_cpp_properties.json文件_vscode生成launch.json-CSDN博客 自动生成配置文件简单方便!!! 运行c代…...

C++初学(9)

9.1、结构简介 虽然数组能够和存储多个元素,但所有元素必须相同,也就是说,同一个数组不能既存放int类型也存放float类型,而C的结构可以满足要求。结构是一种比数组更灵活的数据格式,因为同一个结构可以存储多种类型的…...

ardupilot开发 --- 网络技术综述 篇

不信人间有白头 一些概念参考文献 一些概念 以太网、局域网、互联网 以太网(Ethernet),是一种计算机局域网技术。以太网是一种有线网络技术,网络传输介质包括:以太网电缆,如常见的双绞线、光纤等。根据传输速度,可以氛…...

一文详解大模型蒸馏工具TextBrewer

原文:https://zhuanlan.zhihu.com/p/648674584 本文分享自华为云社区《TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用》,作者:汀丶。 TextBre…...

Go语言加Vue3零基础入门全栈班10 Go语言+gRPC用户微服务项目实战 2024年07月31日 课程笔记

概述 如果您没有Golang的基础,应该学习如下前置课程。 Golang零基础入门Golang面向对象编程Go Web 基础Go语言开发REST API接口_20240728Go语言操作MySQL开发用户管理系统API教程_20240729Redis零基础快速入门_20231227GoRedis开发用户管理系统API实战_20240730Mo…...

ChatGPT能代替网络作家吗?

最强AI视频生成:小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频百万播放量https://aitools.jurilu.com/ 当然可以!只要你玩写作AI玩得6,甚至可以达到某些大神的水平! 看看大神、小白、AI输出内容的区…...

Http自定义Header导致的跨域问题

最近写一个小项目,前后端分离,在调试过程中访问远程接口,出现了CORS问题,接口使用的laravel框架,于是添加了解决跨域的中间件,但是前端显示仍存在跨域问题,以为自己写的有问题,检查了…...

python 中 file.read(), file.readline()和file.readlines()区别和用法

python 中 file.read(), file.readline()和file.readlines()区别和用法 文章目录 python 中 file.read(), file.readline()和file.readlines()区别和用法1. file.read()2. file.readline()3. file.readlines()4. 总结5. 注意事项 file.read(), file.readline(), 和 file.readli…...

python 学习: np.pad

在NumPy中,np.pad函数用于对数组进行填充(padding),即在数组的边界处添加额外的值。这在图像处理、信号处理或任何需要扩展数据边界的场景中非常有用。 以下是np.pad函数的一些关键参数和使用示例: array&#xff1a…...

等保2.0 | 人大金仓数据库测评

人大金仓数据库,全称为金仓数据库管理系统KingbaseES(简称:金仓数据库或KingbaseES),是北京人大金仓信息技术股份有限公司自主研制开发的具有自主知识产权的通用关系型数据库管理系统。以下是关于人大金仓数据库的详细…...

AIGC赋能智慧农业:用AI技术绘就作物生长新蓝图

( 于景鑫 国家农业信息化工程技术研究中心)随着人工智能技术的日新月异,AIGC(AI-Generated Content,AI生成内容)正在各行各业掀起一场革命性的浪潮。而在智慧农业领域,AIGC技术的应用也正迸发出耀眼的火花。特别是在作物生长管理方面,AIGC有望彻底改变传…...

yolov8蒸馏(附代码-免费)

首先蒸馏是什么? 模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。 为什么需要蒸馏? 在不增加模型计算…...

stm32G473的flash模式是单bank还是双bank?

今天突然有人stm32G473的flash模式是单bank还是双bank?由于时间太久,我真忘记了。搜搜发现,还真有人和我一样。见下面的链接:https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?

uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件,用于在原生应用中加载 HTML 页面: 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...

【Go语言基础【13】】函数、闭包、方法

文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数(函数作为参数、返回值) 三、匿名函数与闭包1. 匿名函数(Lambda函…...

Netty从入门到进阶(二)

二、Netty入门 1. 概述 1.1 Netty是什么 Netty is an asynchronous event-driven network application framework for rapid development of maintainable high performance protocol servers & clients. Netty是一个异步的、基于事件驱动的网络应用框架,用于…...

深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用

文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么?1.1.2 感知机的工作原理 1.2 感知机的简单应用:基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发,后来由Pivotal Software Inc.(现为VMware子公司)接管。RabbitMQ 是一个开源的消息代理和队列服务器,用 Erlang 语言编写。广泛应用于各种分布…...

腾讯云V3签名

想要接入腾讯云的Api,必然先按其文档计算出所要求的签名。 之前也调用过腾讯云的接口,但总是卡在签名这一步,最后放弃选择SDK,这次终于自己代码实现。 可能腾讯云翻新了接口文档,现在阅读起来,清晰了很多&…...

MySQL JOIN 表过多的优化思路

当 MySQL 查询涉及大量表 JOIN 时,性能会显著下降。以下是优化思路和简易实现方法: 一、核心优化思路 减少 JOIN 数量 数据冗余:添加必要的冗余字段(如订单表直接存储用户名)合并表:将频繁关联的小表合并成…...

【网络安全】开源系统getshell漏洞挖掘

审计过程: 在入口文件admin/index.php中: 用户可以通过m,c,a等参数控制加载的文件和方法,在app/system/entrance.php中存在重点代码: 当M_TYPE system并且M_MODULE include时,会设置常量PATH_OWN_FILE为PATH_APP.M_T…...