当前位置: 首页 > article >正文

训练自定义游戏,构建Gymnasium训练环境

认识Gymnasium使用stable_baseline3只需要定义好Gymnasium环境关注训练的奖励机制将重点放在业务的开发上而不是复杂的算法。Gymnasium提供了几个核心的api方法功能返回值reset()将环境重置为初始状态开始新回合。obs, infostep(action)环境向前推进一步执行动作。obs, reward, terminated, truncated, inforender()可视化环境根据render_mode渲染图像或弹出窗口。视配置而定通常无或为np.arrayclose()释放环境资源关闭窗口、清理内存。无其中的各个返回值的含义observation(Object): 当前状态的描述。例如敌人玩家的位置玩家的状态等reward(Float): 上一步动作获得的奖励terminated(Bool): 是否由于任务逻辑结束。例如到达终点、掉进岩浆等truncated(Bool): 是否由于外部限制结束。例如达到最大步数 500 步info(Dict): 辅助诊断信息模型训练通常不用用于用户自定义调试或记录额外统计。手动构建环境案例案例描述利用pygame构建一个简单的游戏躲避掉落方块利用构建的奖励机制进行强化学习。import gymnasium as gym from gymnasium import spaces import numpy as np import pygame import random import cv2 import os from stable_baselines3 import PPO from stable_baselines3.common.callbacks import CheckpointCallback from stable_baselines3.common.env_checker import check_env class MyEnv(gym.Env): def __init__(self, render_modeNone): super(MyEnv, self).__init__() #初始化参数 self.width 400 self.height 300 self.player_size 30 self.enemy_size 30 self.render_mode render_mode self.action_space spaces.Discrete(3) self.observation_space spaces.Box( low0, high255, shape(84, 84, 3), dtypenp.uint8 ) pygame.init() if self.render_mode human: self.screen pygame.display.set_mode((self.width, self.height)) self.canvas pygame.Surface((self.width, self.height)) self.font pygame.font.SysFont(monospace, 15) def reset(self, seedNone, optionsNone): super().reset(seedseed) self.player_x self.width // 2 - self.player_size // 2 self.player_y self.height - self.player_size - 10 self.enemies [] self.score 0 self.frame_count 0 self.current_speed 5 self.spawn_rate 30 return self._get_obs(), {} def step(self, action): reward 0 terminated False truncated False move_speed 8 if action 1 and self.player_x 0: # self.player_x - move_speed reward - 0.05 if action 2 and self.player_x self.width - self.player_size: self.player_x move_speed reward - 0.05 self.frame_count 1 level self.score // 5 self.current_speed 5 level self.spawn_rate 30 - level * 2 spawn_rate max(10, 30 - level) if self.frame_count spawn_rate: self.frame_count 0 enemy_x random.randint(0, self.width - self.enemy_size) self.enemies.append([enemy_x, 0]) # [x, y] for enemy in self.enemies: enemy[1] self.current_speed player_rect pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size) enemy_rect pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size) if player_rect.colliderect(enemy_rect): reward -10 terminated True elif enemy[1] self.height: self.enemies.remove(enemy) self.score 1 reward 1 if not terminated: if self.score 100: reward 0.01 reward 0.01 obs self._get_obs() if self.render_mode human: self._render_window() return obs, reward, terminated, truncated, {} def _get_obs(self): self.canvas.fill((0, 0, 0)) pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size)) for enemy in self.enemies: pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size)) img_array pygame.surfarray.array3d(self.canvas) img_array np.transpose(img_array, (1, 0, 2)) obs cv2.resize(img_array, (84, 84), interpolationcv2.INTER_AREA) return obs.astype(np.uint8) def _render_window(self): self.screen.blit(self.canvas, (0, 0)) text self.font.render(fScore: {self.score}, True, (255, 255, 255)) self.screen.blit(text, (10, 10)) pygame.display.flip() for event in pygame.event.get(): if event.type pygame.QUIT: pygame.quit() def train(): log_dir logs/DodgeGame os.makedirs(log_dir, exist_okTrue) env MyEnv() check_env(env) print(环境检查通过...) model_path models/dodge_ai.zip if not os.path.exists(model_path): print( 未发现旧模型从头开始训练...) model PPO( CnnPolicy, env, verbose1, tensorboard_loglog_dir, learning_rate0.0001, n_steps4096, batch_size256, devicecuda) reset_timesteps True else: print(发现旧模型加载并继续训练...) model PPO.load( model_path, envenv, devicecuda, custom_objects{learning_rate: 0.0001, n_steps: 4096, batch_size: 256} ) reset_timesteps False print(开始训练...) model.learn( total_timesteps50000, reset_num_timestepsreset_timesteps ) model.save(models/dodge_ai) print(模型已保存) env.close() def prodict(): env MyEnv(render_modehuman) model PPO.load(models/dodge_ai, envenv, devicecuda) obs, _ env.reset() while True: action, _states model.predict(obs, deterministicTrue) obs, reward, terminated, truncated, info env.step(action) if terminated or truncated: obs, _ env.reset() pygame.time.Clock().tick(30) if __name__ __main__: train() prodict()代码解析代码流程如下构建游戏环境-训练模型-模型预测本篇重点讲构建游戏环境其中的pygame相关代码简略另外两个流程参考之前文章。构建游戏环境初始化类该类继承gym.Env类class MyEnv(gym.Env):构造函数__init__def __init__(self, render_modeNone): super(MyEnv, self).__init__() #初始化参数 self.width 400 self.height 300 self.player_size 30 self.enemy_size 30 self.render_mode render_mode self.action_space spaces.Discrete(3) self.observation_space spaces.Box( low0, high255, shape(84, 84, 3), dtypenp.uint8 ) pygame.init() if self.render_mode human: self.screen pygame.display.set_mode((self.width, self.height)) self.canvas pygame.Surface((self.width, self.height)) self.font pygame.font.SysFont(monospace, 15)在构造函数中我们主要完成的是声明训练的维度和输入输入self.action_space spaces.Discrete(3)其中的self.action_space是固定名称的父类变量。spaces.Discrete(3)声明输入的数量例如向左 向右 和 不动3个输入。观测维度self.observation_space也是固定名称的父类变量。spaces.Box声明观测维度。self.observation_space spaces.Box( low0, high255, shape(84, 84, 3), dtypenp.uint8 )low观测参数的最小值high观测参数的最大值shape声明维度。例如观测图片shape(高宽RGB)观测一个平面shape(高,宽)dtype每个变量类型这里选np.uint8能够节省训练成本默认是浮点型的。任务重置 reset相当于初始化游戏状态游戏的重新开始。返回的是观测值和状态信息用于调试日志def reset(self, seedNone, optionsNone): super().reset(seedseed) self.player_x self.width // 2 - self.player_size // 2 self.player_y self.height - self.player_size - 10 self.enemies [] self.score 0 self.frame_count 0 self.current_speed 5 self.spawn_rate 30 return self._get_obs(), {}观测值_get_obs通过pygame画出的画面然后用opencv进行简单处理转换坐标轴由于opencv坐标xy轴跟pygame的xy是颠倒的将画面缩放到84 * 84可以提高训练效率def _get_obs(self): self.canvas.fill((0, 0, 0)) pygame.draw.rect(self.canvas, (50, 150, 255), (self.player_x, self.player_y, self.player_size, self.player_size)) for enemy in self.enemies: pygame.draw.rect(self.canvas, (255, 50, 50), (enemy[0], enemy[1], self.enemy_size, self.enemy_size)) img_array pygame.surfarray.array3d(self.canvas) img_array np.transpose(img_array, (1, 0, 2)) obs cv2.resize(img_array, (84, 84), interpolationcv2.INTER_AREA) return obs.astype(np.uint8)步 step重要这个函数是强化训练的核心规定了在一帧或者一步我们给AI的分数。分数的设置至关重要这直接决定了训练出来AI的质量根据下面代码大部分都是游戏逻辑主要讲设置奖励分数在AI进行移动时 惩罚 0.05 分在AI存活时 奖励 0.01分游戏分数大于100时 存活奖励 0.02分在障碍物完全下落时 奖励 1 分在与障碍物碰撞时 惩罚 10 分def step(self, action): reward 0 terminated False truncated False move_speed 8 if action 1 and self.player_x 0: # self.player_x - move_speed reward - 0.05 if action 2 and self.player_x self.width - self.player_size: self.player_x move_speed reward - 0.05 self.frame_count 1 level self.score // 5 self.current_speed 5 level self.spawn_rate 30 - level * 2 spawn_rate max(10, 30 - level) if self.frame_count spawn_rate: self.frame_count 0 enemy_x random.randint(0, self.width - self.enemy_size) self.enemies.append([enemy_x, 0]) # [x, y] for enemy in self.enemies: enemy[1] self.current_speed player_rect pygame.Rect(self.player_x, self.player_y, self.player_size, self.player_size) enemy_rect pygame.Rect(enemy[0], enemy[1], self.enemy_size, self.enemy_size) if player_rect.colliderect(enemy_rect): reward -10 terminated True elif enemy[1] self.height: self.enemies.remove(enemy) self.score 1 reward 1 if not terminated: if self.score 100: reward 0.01 reward 0.01 obs self._get_obs() if self.render_mode human: self._render_window() return obs, reward, terminated, truncated, {}展示游戏画面下面完全是pygame代码用于显示游戏画面这里就不解释了。def _render_window(self): self.screen.blit(self.canvas, (0, 0)) text self.font.render(fScore: {self.score}, True, (255, 255, 255)) self.screen.blit(text, (10, 10)) pygame.display.flip() for event in pygame.event.get(): if event.type pygame.QUIT: pygame.quit()

相关文章:

训练自定义游戏,构建Gymnasium训练环境

认识Gymnasium使用stable_baseline3只需要定义好Gymnasium环境,关注训练的奖励机制,将重点放在业务的开发上而不是复杂的算法。Gymnasium提供了几个核心的api:方法功能返回值reset()将环境重置为初始状态,开始新回合。obs, infost…...

AI率降完又反弹原因在这里解决方案也在

论文AI率降到15%,隔了一周再测,又变成了24%。 这个情况不是你的错,也不是工具骗你,而是有几个实际原因导致的。这篇文章解释清楚原因,然后给解决方案。 AI率反弹的3个真实原因 原因一:检测系统更新了 这…...

如何设计高效的Emscripten与WebAssembly接口:平衡简洁与完整的终极指南

如何设计高效的Emscripten与WebAssembly接口:平衡简洁与完整的终极指南 【免费下载链接】emscripten Emscripten: An LLVM-to-WebAssembly Compiler 项目地址: https://gitcode.com/gh_mirrors/em/emscripten Emscripten作为一款强大的LLVM-to-WebAssembly编…...

Qwen-Image-Layered快速部署:ComfyUI镜像一键启动与配置

Qwen-Image-Layered快速部署:ComfyUI镜像一键启动与配置 1. 引言:图像分层的革命性突破 1.1 传统图像编辑的痛点 在常规的图像处理流程中,我们常常遇到一个根本性难题:图像一旦生成或拍摄完成,就变成了一个"不…...

XXL-SSO开源项目未来展望:技术趋势与roadmap解读

XXL-SSO开源项目未来展望:技术趋势与roadmap解读 XXL-SSO作为一款分布式单点登录框架,已在众多企业中得到广泛应用,为多系统统一认证提供了轻量级且高扩展性的解决方案。随着分布式系统架构的不断演进,XXL-SSO正面临新的技术挑战…...

vue-treeselect源码深度剖析:理解组件内部工作原理

vue-treeselect源码深度剖析:理解组件内部工作原理 【免费下载链接】vue-treeselect A multi-select component with nested options support for Vue.js 项目地址: https://gitcode.com/gh_mirrors/vu/vue-treeselect vue-treeselect是一个功能强大的Vue.js…...

Windows窗口置顶3分钟快速上手指南:告别频繁切换的烦恼

Windows窗口置顶3分钟快速上手指南:告别频繁切换的烦恼 【免费下载链接】PinWin Pin any window to be always on top of the screen 项目地址: https://gitcode.com/gh_mirrors/pin/PinWin 你是否曾在处理多个窗口时感到手忙脚乱?当你在写代码时…...

收藏!小白程序员必看:轻松入门大模型核心概念MCP与Skill,解锁AI能力新姿势!

本文通过生活化比喻,深入浅出地解释了AI领域中的MCP和Skill两大核心概念。MCP如同AI世界的“USB接口”,是标准化的连接协议,让AI能调用外部工具;Skill则像“工作手册”,是工作规范/技能模板,告诉AI在不同场…...

为什么选择Clasp?10个理由让你彻底爱上本地开发Apps Script [特殊字符]

为什么选择Clasp?10个理由让你彻底爱上本地开发Apps Script 🚀 【免费下载链接】clasp 🔗 Command Line Apps Script Projects 项目地址: https://gitcode.com/gh_mirrors/clasp/clasp Clasp(Command Line Apps Script Pro…...

PPTist:开源在线演示文稿工具的创新实践与全场景应用指南

PPTist:开源在线演示文稿工具的创新实践与全场景应用指南 【免费下载链接】PPTist PowerPoint-ist(/pauəpɔintist/), An online presentation application that replicates most of the commonly used features of MS PowerPoint, allowing…...

Windows网络测速终极指南:用iperf3精准诊断你的网络性能

Windows网络测速终极指南:用iperf3精准诊断你的网络性能 【免费下载链接】iperf3-win-builds iperf3 binaries for Windows. Benchmark your network limits. 项目地址: https://gitcode.com/gh_mirrors/ip/iperf3-win-builds 你是否经常遇到网络卡顿、视频缓…...

如何用PyFlow创建自定义节点:从函数到可视化组件的完整指南

如何用PyFlow创建自定义节点:从函数到可视化组件的完整指南 【免费下载链接】PyFlow Visual scripting framework for python 项目地址: https://gitcode.com/gh_mirrors/py/PyFlow PyFlow是一款强大的Python可视化脚本框架,它允许开发者通过拖拽…...

Ubuntu22.04部署Cartographer:从一键安装到参数调优全解析

1. 环境准备:Ubuntu 22.04与ROS2 Humble基础配置 在开始部署Cartographer之前,确保你的Ubuntu 22.04系统已经完成基础环境配置。我遇到过不少开发者因为跳过这一步,导致后续安装出现各种依赖问题。这里分享几个关键检查点: 首先…...

webpack-blocks生态全景:从官方块到第三方扩展的完整盘点

webpack-blocks生态全景:从官方块到第三方扩展的完整盘点 【免费下载链接】webpack-blocks 📦 Configure webpack using functional feature blocks. 项目地址: https://gitcode.com/gh_mirrors/we/webpack-blocks webpack-blocks是一个革命性的w…...

OpenSpeedy高效加速工具分发流程全解析:从环境到发布的实践指南

OpenSpeedy高效加速工具分发流程全解析:从环境到发布的实践指南 【免费下载链接】OpenSpeedy 🎮 An open-source game speed modifier. 项目地址: https://gitcode.com/gh_mirrors/op/OpenSpeedy OpenSpeedy作为一款开源GitHub加速工具&#xff0…...

颈椎病反复复发?终于找到根源解决办法

颈椎疼治好了又犯,花钱不少、遭罪不少,到底为啥?核心就两点:只止疼不修病灶、纤维环破裂没修复。 普通治疗只能暂时推开压迫,髓核还会再次突出,神经反复受刺激,酸痛麻木永远断不了根。长春颈椎腰…...

我在 Mac 写了个服务,硬要它在 18 岁高龄的 Windows 服务器上跑,结果…

前言 事情是这样的。 我有个朋友(以下称他为"怨种朋友"),找到我说: "帮我写个 Go 服务,在你自己 Mac 上开发,最后要能跑在咱们公司那台快入土的 Windows 2008 服务器上。" 我当时的…...

别再手动量了!用Python+Open3D给BIM模型做‘CT扫描’,自动揪出施工误差(附完整代码)

BIM模型质量检测革命:PythonOpen3D实现毫米级施工误差智能分析 施工现场的质量控制一直是建筑行业的核心痛点。传统靠人工抽检的方式不仅效率低下,还容易遗漏隐蔽问题。想象一下,如果能把BIM模型当作"数字孪生体",用三维…...

DynamiCrafter完全指南:从安装到生成高质量动画视频

DynamiCrafter完全指南:从安装到生成高质量动画视频 【免费下载链接】DynamiCrafter DynamiCrafter: Animating Open-domain Images with Video Diffusion Priors 项目地址: https://gitcode.com/gh_mirrors/dy/DynamiCrafter DynamiCrafter是一款强大的AI动…...

红蓝对抗深度解析:从技术体系到落地实践,企业安全真正的实战课

红蓝对抗深度解析:从技术体系到落地实践,企业安全真正的实战课 在数字化攻防进入 “实战对抗” 时代的今天,红蓝对抗已成为企业检验安全防御体系、提升应急响应能力的核心手段。不同于传统的漏洞扫描和合规检查,红蓝对抗以 “高仿…...

护网行动入门指南:零基础也能参与,快速积累网安实战经验

护网行动入门指南:如何参与并积累实战经验 护网行动是国内最高规格的网络安全实战演练,旨在检验企业、单位的网络安全防御能力,现已成为网络安全领域的“实战练兵场”。对计算机专业学生而言,参与护网行动不仅能积累宝贵的实战经…...

如何用MCQTSS_QQMusic解决音乐资源获取难题?3大技术突破实现无损下载

如何用MCQTSS_QQMusic解决音乐资源获取难题?3大技术突破实现无损下载 【免费下载链接】MCQTSS_QQMusic QQ音乐解析 项目地址: https://gitcode.com/gh_mirrors/mc/MCQTSS_QQMusic 在数字音乐时代,QQ音乐作为国内领先的音乐平台,拥有海…...

TFLint Docker终极指南:在容器中轻松运行Terraform代码检查

TFLint Docker终极指南:在容器中轻松运行Terraform代码检查 【免费下载链接】tflint A Pluggable Terraform Linter 项目地址: https://gitcode.com/gh_mirrors/tf/tflint TFLint是一个可插拔的Terraform代码检查工具,帮助开发者发现Terraform配置…...

React Scroll Parallax核心组件详解:Parallax、ParallaxBanner和ParallaxProvider

React Scroll Parallax核心组件详解:Parallax、ParallaxBanner和ParallaxProvider 【免费下载链接】react-scroll-parallax 🔮 React hooks and components to create parallax scroll effects for banners, images or any other DOM elements. 项目地…...

小米设备集成终极测试指南:确保HomeAssistant稳定运行的7个关键步骤

小米设备集成终极测试指南:确保HomeAssistant稳定运行的7个关键步骤 【免费下载链接】hass-xiaomi-miot Automatic integrate all Xiaomi devices to HomeAssistant via miot-spec, support Wi-Fi, BLE, ZigBee devices. 小米米家智能家居设备接入Hass集成 项目地…...

告别键盘连击烦恼:这款开源工具让你的机械键盘重获新生

告别键盘连击烦恼:这款开源工具让你的机械键盘重获新生 【免费下载链接】KeyboardChatterBlocker A handy quick tool for blocking mechanical keyboard chatter. 项目地址: https://gitcode.com/gh_mirrors/ke/KeyboardChatterBlocker 还在为键盘连击问题而…...

多模态跨语言翻译引擎实战指南:本地化部署与场景化应用

多模态跨语言翻译引擎实战指南:本地化部署与场景化应用 【免费下载链接】seamless-m4t-v2-large 项目地址: https://ai.gitcode.com/hf_mirrors/ai-gitcode/seamless-m4t-v2-large 在全球化协作日益频繁的今天,跨语言翻译已成为打破沟通壁垒的核…...

抖音批量下载工具高效应用全攻略:从单视频到批量采集的完整指南

抖音批量下载工具高效应用全攻略:从单视频到批量采集的完整指南 【免费下载链接】douyin-downloader A practical Douyin downloader for both single-item and profile batch downloads, with progress display, retries, SQLite deduplication, and browser fallb…...

AllTube Download 10个实用技巧:从基础下载到高级格式转换

AllTube Download 10个实用技巧:从基础下载到高级格式转换 【免费下载链接】alltube Web GUI for youtube-dl 项目地址: https://gitcode.com/gh_mirrors/al/alltube AllTube Download 是一款基于 youtube-dl 的 Web GUI 工具,让用户能够轻松从 Y…...

如何用开源工具实现专业级图像修复与纹理合成?揭秘GIMP Resynthesizer的技术奥秘

如何用开源工具实现专业级图像修复与纹理合成?揭秘GIMP Resynthesizer的技术奥秘 【免费下载链接】resynthesizer Suite of gimp plugins for texture synthesis 项目地址: https://gitcode.com/gh_mirrors/re/resynthesizer 在数字图像处理领域,…...