当前位置: 首页 > news >正文

用深度强化学习来玩Flappy Bird

目录

演示视频

核心代码


演示视频

用深度强化学习来玩Flappy Bird

核心代码

import torch.nn as nnclass DeepQNetwork(nn.Module):def __init__(self):super(DeepQNetwork, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True))self.conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True))self.conv3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True))self.fc1 = nn.Sequential(nn.Linear(7 * 7 * 64, 512), nn.ReLU(inplace=True))self.fc2 = nn.Linear(512, 2)self._create_weights()def _create_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):nn.init.uniform_(m.weight, -0.01, 0.01)nn.init.constant_(m.bias, 0)def forward(self, input):output = self.conv1(input)output = self.conv2(output)output = self.conv3(output)output = output.view(output.size(0), -1)output = self.fc1(output)output = self.fc2(output)return output
from itertools import cycle
from numpy.random import randint
from pygame import Rect, init, time, display
from pygame.event import pump
from pygame.image import load
from pygame.surfarray import array3d, pixels_alpha
from pygame.transform import rotate
import numpy as npclass FlappyBird(object):init()fps_clock = time.Clock()screen_width = 288screen_height = 512screen = display.set_mode((screen_width, screen_height))display.set_caption('Deep Q-Network Flappy Bird')base_image = load('assets/sprites/base.png').convert_alpha()background_image = load('assets/sprites/background-black.png').convert()pipe_images = [rotate(load('assets/sprites/pipe-green.png').convert_alpha(), 180),load('assets/sprites/pipe-green.png').convert_alpha()]bird_images = [load('assets/sprites/redbird-upflap.png').convert_alpha(),load('assets/sprites/redbird-midflap.png').convert_alpha(),load('assets/sprites/redbird-downflap.png').convert_alpha()]# number_images = [load('assets/sprites/{}.png'.format(i)).convert_alpha() for i in range(10)]bird_hitmask = [pixels_alpha(image).astype(bool) for image in bird_images]pipe_hitmask = [pixels_alpha(image).astype(bool) for image in pipe_images]fps = 30pipe_gap_size = 100pipe_velocity_x = -4# parameters for birdmin_velocity_y = -8max_velocity_y = 10downward_speed = 1upward_speed = -9bird_index_generator = cycle([0, 1, 2, 1])def __init__(self):self.iter = self.bird_index = self.score = 0self.bird_width = self.bird_images[0].get_width()self.bird_height = self.bird_images[0].get_height()self.pipe_width = self.pipe_images[0].get_width()self.pipe_height = self.pipe_images[0].get_height()self.bird_x = int(self.screen_width / 5)self.bird_y = int((self.screen_height - self.bird_height) / 2)self.base_x = 0self.base_y = self.screen_height * 0.79self.base_shift = self.base_image.get_width() - self.background_image.get_width()pipes = [self.generate_pipe(), self.generate_pipe()]pipes[0]["x_upper"] = pipes[0]["x_lower"] = self.screen_widthpipes[1]["x_upper"] = pipes[1]["x_lower"] = self.screen_width * 1.5self.pipes = pipesself.current_velocity_y = 0self.is_flapped = Falsedef generate_pipe(self):x = self.screen_width + 10gap_y = randint(2, 10) * 10 + int(self.base_y / 5)return {"x_upper": x, "y_upper": gap_y - self.pipe_height, "x_lower": x, "y_lower": gap_y + self.pipe_gap_size}def is_collided(self):# Check if the bird touch groundif self.bird_height + self.bird_y + 1 >= self.base_y:return Truebird_bbox = Rect(self.bird_x, self.bird_y, self.bird_width, self.bird_height)pipe_boxes = []for pipe in self.pipes:pipe_boxes.append(Rect(pipe["x_upper"], pipe["y_upper"], self.pipe_width, self.pipe_height))pipe_boxes.append(Rect(pipe["x_lower"], pipe["y_lower"], self.pipe_width, self.pipe_height))# Check if the bird's bounding box overlaps to the bounding box of any pipeif bird_bbox.collidelist(pipe_boxes) == -1:return Falsefor i in range(2):cropped_bbox = bird_bbox.clip(pipe_boxes[i])min_x1 = cropped_bbox.x - bird_bbox.xmin_y1 = cropped_bbox.y - bird_bbox.ymin_x2 = cropped_bbox.x - pipe_boxes[i].xmin_y2 = cropped_bbox.y - pipe_boxes[i].yif np.any(self.bird_hitmask[self.bird_index][min_x1:min_x1 + cropped_bbox.width,min_y1:min_y1 + cropped_bbox.height] * self.pipe_hitmask[i][min_x2:min_x2 + cropped_bbox.width,min_y2:min_y2 + cropped_bbox.height]):return Truereturn Falsedef next_frame(self, action):pump()reward = 0.1terminal = False# Check input actionif action == 1:self.current_velocity_y = self.upward_speedself.is_flapped = True# Update scorebird_center_x = self.bird_x + self.bird_width / 2for pipe in self.pipes:pipe_center_x = pipe["x_upper"] + self.pipe_width / 2if pipe_center_x < bird_center_x < pipe_center_x + 5:self.score += 1reward = 1break# Update index and iterationif (self.iter + 1) % 3 == 0:self.bird_index = next(self.bird_index_generator)self.iter = 0self.base_x = -((-self.base_x + 100) % self.base_shift)# Update bird's positionif self.current_velocity_y < self.max_velocity_y and not self.is_flapped:self.current_velocity_y += self.downward_speedif self.is_flapped:self.is_flapped = Falseself.bird_y += min(self.current_velocity_y, self.bird_y - self.current_velocity_y - self.bird_height)if self.bird_y < 0:self.bird_y = 0# Update pipes' positionfor pipe in self.pipes:pipe["x_upper"] += self.pipe_velocity_xpipe["x_lower"] += self.pipe_velocity_x# Update pipesif 0 < self.pipes[0]["x_lower"] < 5:self.pipes.append(self.generate_pipe())if self.pipes[0]["x_lower"] < -self.pipe_width:del self.pipes[0]if self.is_collided():terminal = Truereward = -1self.__init__()# Draw everythingself.screen.blit(self.background_image, (0, 0))self.screen.blit(self.base_image, (self.base_x, self.base_y))self.screen.blit(self.bird_images[self.bird_index], (self.bird_x, self.bird_y))for pipe in self.pipes:self.screen.blit(self.pipe_images[0], (pipe["x_upper"], pipe["y_upper"]))self.screen.blit(self.pipe_images[1], (pipe["x_lower"], pipe["y_lower"]))image = array3d(display.get_surface())display.update()self.fps_clock.tick(self.fps)return image, reward, terminal
import argparse
import torchfrom src.deep_q_network import DeepQNetwork
from src.flappy_bird import FlappyBird
from src.utils import pre_processingdef get_args():parser = argparse.ArgumentParser("""Implementation of Deep Q Network to play Flappy Bird""")parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")parser.add_argument("--saved_path", type=str, default="trained_models")args = parser.parse_args()return argsdef q_test(opt):if torch.cuda.is_available():torch.cuda.manual_seed(123)else:torch.manual_seed(123)if torch.cuda.is_available():model = torch.load("{}/flappy_bird".format(opt.saved_path))else:model = torch.load("{}/flappy_bird".format(opt.saved_path), map_location=lambda storage, loc: storage)model.eval()game_state = FlappyBird()image, reward, terminal = game_state.next_frame(0)image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)image = torch.from_numpy(image)if torch.cuda.is_available():model.cuda()image = image.cuda()state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]while True:prediction = model(state)[0]action = torch.argmax(prediction)next_image, reward, terminal = game_state.next_frame(action)next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,opt.image_size)next_image = torch.from_numpy(next_image)if torch.cuda.is_available():next_image = next_image.cuda()next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]state = next_stateif __name__ == "__main__":opt = get_args()q_test(opt)
def get_args():parser = argparse.ArgumentParser("""Implementation of Deep Q Network to play Flappy Bird""")parser.add_argument("--image_size", type=int, default=84, help="The common width and height for all images")parser.add_argument("--batch_size", type=int, default=32, help="The number of images per batch")parser.add_argument("--optimizer", type=str, choices=["sgd", "adam"], default="adam")parser.add_argument("--lr", type=float, default=1e-6)parser.add_argument("--gamma", type=float, default=0.99)parser.add_argument("--initial_epsilon", type=float, default=0.1)parser.add_argument("--final_epsilon", type=float, default=1e-4)parser.add_argument("--num_iters", type=int, default=2000000)parser.add_argument("--replay_memory_size", type=int, default=50000,help="Number of epoches between testing phases")parser.add_argument("--log_path", type=str, default="tensorboard")parser.add_argument("--saved_path", type=str, default="trained_models")args = parser.parse_args()return argsdef train(opt):if torch.cuda.is_available():torch.cuda.manual_seed(123)else:torch.manual_seed(123)model = DeepQNetwork()if os.path.isdir(opt.log_path):shutil.rmtree(opt.log_path)os.makedirs(opt.log_path)writer = SummaryWriter(opt.log_path)optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)criterion = nn.MSELoss()game_state = FlappyBird()image, reward, terminal = game_state.next_frame(0)image = pre_processing(image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size, opt.image_size)image = torch.from_numpy(image)if torch.cuda.is_available():model.cuda()image = image.cuda()state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]replay_memory = []iter = 0while iter < opt.num_iters:prediction = model(state)[0]# Exploration or exploitationepsilon = opt.final_epsilon + ((opt.num_iters - iter) * (opt.initial_epsilon - opt.final_epsilon) / opt.num_iters)u = random()random_action = u <= epsilonif random_action:print("Perform a random action")action = randint(0, 1)else:action = torch.argmax(prediction)next_image, reward, terminal = game_state.next_frame(action)next_image = pre_processing(next_image[:game_state.screen_width, :int(game_state.base_y)], opt.image_size,opt.image_size)next_image = torch.from_numpy(next_image)if torch.cuda.is_available():next_image = next_image.cuda()next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]replay_memory.append([state, action, reward, next_state, terminal])if len(replay_memory) > opt.replay_memory_size:del replay_memory[0]batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))state_batch, action_batch, reward_batch, next_state_batch, terminal_batch = zip(*batch)state_batch = torch.cat(tuple(state for state in state_batch))action_batch = torch.from_numpy(np.array([[1, 0] if action == 0 else [0, 1] for action in action_batch], dtype=np.float32))reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])next_state_batch = torch.cat(tuple(state for state in next_state_batch))if torch.cuda.is_available():state_batch = state_batch.cuda()action_batch = action_batch.cuda()reward_batch = reward_batch.cuda()next_state_batch = next_state_batch.cuda()current_prediction_batch = model(state_batch)next_prediction_batch = model(next_state_batch)y_batch = torch.cat(tuple(reward if terminal else reward + opt.gamma * torch.max(prediction) for reward, terminal, prediction inzip(reward_batch, terminal_batch, next_prediction_batch)))q_value = torch.sum(current_prediction_batch * action_batch, dim=1)optimizer.zero_grad()# y_batch = y_batch.detach()loss = criterion(q_value, y_batch)loss.backward()optimizer.step()state = next_stateiter += 1print("Iteration: {}/{}, Action: {}, Loss: {}, Epsilon {}, Reward: {}, Q-value: {}".format(iter + 1,opt.num_iters,action,loss,epsilon, reward, torch.max(prediction)))writer.add_scalar('Train/Loss', loss, iter)writer.add_scalar('Train/Epsilon', epsilon, iter)writer.add_scalar('Train/Reward', reward, iter)writer.add_scalar('Train/Q-value', torch.max(prediction), iter)if (iter+1) % 1000000 == 0:torch.save(model, "{}/flappy_bird_{}".format(opt.saved_path, iter+1))torch.save(model, "{}/flappy_bird".format(opt.saved_path))if __name__ == "__main__":opt = get_args()train(opt)

相关文章:

用深度强化学习来玩Flappy Bird

目录 演示视频 核心代码 演示视频 用深度强化学习来玩Flappy Bird 核心代码 import torch.nn as nnclass DeepQNetwork(nn.Module):def __init__(self):super(DeepQNetwork, self).__init__()self.conv1 nn.Sequential(nn.Conv2d(4, 32, kernel_size8, stride4), nn.ReLU(inp…...

HTML5-4-表单

文章目录 表单属性表单标签输入元素文本域&#xff08;Text Fields&#xff09;密码字段单选按钮&#xff08;Radio Buttons&#xff09;复选框&#xff08;Checkboxes&#xff09;按钮&#xff08;button&#xff09;提交按钮(Submit)label标签 文本框&#xff08;textarea&am…...

Nacos 开源版的使用测评

文章目录 一、Nacos的使用二、Nacos和Eureka在性能、功能、控制台体验、上下游生态和社区体验的对比&#xff1a;三、记使使用Nacos中容易犯的错误四、对Nacos开源提出的一些需求 一、Nacos的使用 这里配置mysql的连接方式&#xff0c;spring.datasource.platformmysql是老版本…...

【Linux】一些常见查看各种各样信息的命令

Linux命令 find命令&#xff0c;用来查找文件。常用的按照名字查找-name&#xff0c;按照文件类型查找-type&#xff0c;linux常用的文件类型有七种&#xff0c;普通文件&#xff0c;目录文件&#xff0c;管道&#xff0c;套接字&#xff0c;软链接&#xff0c;块设备&#xf…...

51单片机DHT11温湿度控制系统仿真设计( proteus仿真+程序+原理图+报告+讲解视频)

51单片机DHT11温湿度控制系统仿真设计 1.主要功能&#xff1a;2.仿真3. 程序代码4. 原理图元器件清单5. 设计报告6. 设计资料内容清单&下载链接 51单片机DHT11温湿度控制系统仿真设计( proteus仿真程序原理图报告讲解视频&#xff09; 仿真图proteus8.9及以上 程序编译器&…...

神仙级python入门教程(非常详细),从0到精通,从看这篇开始!

毫无疑问&#xff0c;Python 是当下最火的编程语言之一。对于许多未曾涉足计算机编程的领域「小白」来说&#xff0c;深入地掌握 Python 看似是一件十分困难的事。其实&#xff0c;只要掌握了科学的学习方法并制定了合理的学习计划&#xff0c;Python 从 入门到精通只需要一个月…...

详解4种类型的爬虫技术

聚焦网络爬虫是“面向特定主题需求”的一种爬虫程序&#xff0c;而通用网络爬虫则是捜索引擎抓取系统&#xff08;Baidu、Google、Yahoo等&#xff09;的重要组成部分&#xff0c;主要目的是将互联网上的网页下载到本地&#xff0c;形成一个互联网内容的镜像备份。 增量抓取意…...

QTday1基础

作业 一、做个QT页面 #include "hqyj.h"HQYJ::HQYJ(QWidget *parent)//构造函数定义: QWidget(parent)//显性调用父类的有参构造 {//主界面设置this->resize(540,410);//设置大小this->setFixedSize(540,410);//设置固定大小this->setWindowIcon(QIcon(&q…...

activiti 通过xml上传 直接部署模型

通过流程xml 直接先发布模型&#xff0c;然后再通过发布模型之后的流程定义获取bpmn model来创建Model. 1、通过xml先发布模型 InputStream bpmnStream multipartFile.getInputStream() deployment repositoryService.createDeployment().addInputStream(multipartFile.getO…...

算法题打卡day56-编辑距离 | 583. 两个字符串的删除操作、72. 编辑距离

583. 两个字符串的删除操作 - 力扣&#xff08;LeetCode&#xff09; 状态&#xff1a;查看思路后AC。 和查找子序列的操作类似&#xff0c;但是考虑的是删除操作。代码如下&#xff1a; class Solution { public:int minDistance(string word1, string word2) {int len1 wor…...

SQL中的CASE WHEN语句:从基础到高级应用指南

SQL中的CASE WHEN语句&#xff1a;从基础到高级应用指南 准备工作 - 表1: products 示例数据&#xff1a; 我们使用一个名为"Products"的表&#xff0c;包含以下列&#xff1a;ProductID、ProductName、CategoryID、UnitPrice、StockQuantity。 -- 建表 CREATE TA…...

超时取消子线程任务

文章目录 前言一、编码思路二、使用步骤直接上代码 总结 前言 问题背景: 主线程需要执行一些任务,不能影响主任务执行,这些任务有超时时间,当超过处理时间后,应该不予处理;如果未超时,应该获取到这些任务的执行结果; 一、编码思路 由于主线程正常执行不能影响,任务会处理很久…...

模块化---common.js

入口文件&#xff1a;app.js // require是同步加载 // 客户端&#xff1a;common.js的模块化&#xff0c;需要browserify编译之后才能使用 // 服务端&#xff1a;运行时同步加载&#xff0c;无问题 let module1 require(./module1.js) let module2 require(./module2.js) co…...

VSCode下载、安装及配置、调试的一些过程理解

第一步先下载了vscode&#xff0c;官方地址为&#xff1a;https://code.visualstudio.com/Download 第二步安装vscode&#xff0c;安装环境是win10&#xff0c;安装基本上就是一步步默认即可。 第三步汉化vscode&#xff0c;这一步就是去扩展插件里面下载一个中文插件即可&am…...

KC705开发板——MGT IBERT测试记录

本文介绍使用KC705开发板进行MGT的IBERT测试。 KC705开发板 KC705开发板的图片如下图所示。FPGA芯片型号为XC7K325T-2FFG900C。 MGT MGT是 Multi-Gigabit Transceiver的缩写&#xff0c;是Multi-Gigabit Serializer/Deserializer (SERDES)的别称。MGT包含GTP、GTX、GTH、G…...

前端代码优化散记

把import Button from xxx 的引入方式&#xff0c;变成import {Button} from xxx 的方式引入&#xff0c;以利于按需打包。原生监听事件、定时器等&#xff0c;必须在componentWillUnmount中清除&#xff0c;大型项目会发生内存泄露&#xff0c;极度影响性能。使用PureComponen…...

HTML <map> 标签的使用

map标签的用途&#xff1a;是与img标签绑定使用的&#xff0c;常被用来赋予给客户端图像某处区域特殊的含义&#xff0c;点击该区域可跳转到新的文档。 编写格式&#xff1a; <img src"图片" border"0" usemap"#planetmap" alt"Planets…...

stable diffusion实践操作-大模型介绍

本文专门开一节写大模型相关的内容&#xff0c;在看之前&#xff0c;可以同步关注&#xff1a; stable diffusion实践操作 模型下载网站 国内的是&#xff1a;https://www.liblibai.com 国外的是&#xff1a;https://civitai.com&#xff08;科学上网&#xff09; 一、发展历…...

W5500-EVB-PICO进行MQTT连接订阅发布教程(十二)

前言 上一章我们用开发板通过SNTP协议获取网络协议&#xff0c;本章我们介绍一下开发板通过配置MQTT连接到服务器上&#xff0c;并且订阅和发布消息。 什么是MQTT&#xff1f; MQTT是一种轻量级的消息传输协议&#xff0c;旨在物联网&#xff08;IoT&#xff09;应用中实现设备…...

90、00后严选出的数据可视化工具:奥威BI工具

90、00后主打一个巧用工具&#xff0c;绝不低效率上班&#xff0c;因此当擅长大数据智能可视化分析的BI数据可视化工具出现后&#xff0c;自然而然地就成了90、00后职场人常用的数据可视化工具。 奥威BI工具三大特点&#xff0c;让职场人眼前一亮&#xff01; 1、零编程&…...

无法与IP建立连接,未能下载VSCode服务器

如题&#xff0c;在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈&#xff0c;发现是VSCode版本自动更新惹的祸&#xff01;&#xff01;&#xff01; 在VSCode的帮助->关于这里发现前几天VSCode自动更新了&#xff0c;我的版本号变成了1.100.3 才导致了远程连接出…...

Cloudflare 从 Nginx 到 Pingora:性能、效率与安全的全面升级

在互联网的快速发展中&#xff0c;高性能、高效率和高安全性的网络服务成为了各大互联网基础设施提供商的核心追求。Cloudflare 作为全球领先的互联网安全和基础设施公司&#xff0c;近期做出了一个重大技术决策&#xff1a;弃用长期使用的 Nginx&#xff0c;转而采用其内部开发…...

视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)

前言&#xff1a; 最近在做行为检测相关的模型&#xff0c;用的是时空图卷积网络&#xff08;STGCN&#xff09;&#xff0c;但原有kinetic-400数据集数据质量较低&#xff0c;需要进行细粒度的标注&#xff0c;同时粗略搜了下已有开源工具基本都集中于图像分割这块&#xff0c…...

Kafka主题运维全指南:从基础配置到故障处理

#作者&#xff1a;张桐瑞 文章目录 主题日常管理1. 修改主题分区。2. 修改主题级别参数。3. 变更副本数。4. 修改主题限速。5.主题分区迁移。6. 常见主题错误处理常见错误1&#xff1a;主题删除失败。常见错误2&#xff1a;__consumer_offsets占用太多的磁盘。 主题日常管理 …...

如何配置一个sql server使得其它用户可以通过excel odbc获取数据

要让其他用户通过 Excel 使用 ODBC 连接到 SQL Server 获取数据&#xff0c;你需要完成以下配置步骤&#xff1a; ✅ 一、在 SQL Server 端配置&#xff08;服务器设置&#xff09; 1. 启用 TCP/IP 协议 打开 “SQL Server 配置管理器”。导航到&#xff1a;SQL Server 网络配…...

用鸿蒙HarmonyOS5实现国际象棋小游戏的过程

下面是一个基于鸿蒙OS (HarmonyOS) 的国际象棋小游戏的完整实现代码&#xff0c;使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├── …...

【2D与3D SLAM中的扫描匹配算法全面解析】

引言 扫描匹配(Scan Matching)是同步定位与地图构建(SLAM)系统中的核心组件&#xff0c;它通过对齐连续的传感器观测数据来估计机器人的运动。本文将深入探讨2D和3D SLAM中的各种扫描匹配算法&#xff0c;包括数学原理、实现细节以及实际应用中的性能对比&#xff0c;特别关注…...

Qt 按钮类控件(Push Button 与 Radio Button)(1)

文章目录 Push Button前提概要API接口给按钮添加图标给按钮添加快捷键 Radio ButtonAPI接口性别选择 Push Button&#xff08;鼠标点击不放连续移动快捷键&#xff09; Radio Button Push Button 前提概要 1. 之前文章中所提到的各种跟QWidget有关的各种属性/函数/方法&#…...

EC2安装WebRTC sdk-c环境、构建、编译

1、登录新的ec2实例&#xff0c;证书可以跟之前的实例用一个&#xff1a; ssh -v -i ~/Documents/cert/qa.pem ec2-user70.xxx.165.xxx 2、按照sdk-c demo中readme的描述开始安装环境&#xff1a; https://github.com/awslabs/amazon-kinesis-video-streams-webrtc-sdk-c 2…...

起重机指挥人员在工作中需要注意哪些安全事项?

起重机指挥人员在作业中承担着协调设备运行、保障作业安全的关键职责&#xff0c;其安全操作直接关系到整个起重作业的安全性。以下从作业前、作业中、作业后的全流程&#xff0c;详细说明指挥人员需注意的安全事项&#xff1a; 一、作业前的安全准备 资质与状态检查&#xff…...