当前位置: 首页 > article >正文

DQN 玩 2048 实战|第三期!优化网络,使用GPU、Env奖励优化

视频讲解:

DQN 玩 2048 实战|第三期!优化网络,使用GPU、Env奖励优化

1. 仅考虑局部合并奖励:目前的奖励只设置为合并方块时获得的分数,只关注了每一步的即时合并收益,而没有对最终达成 2048 这个目标给予额外的激励,如果没有对达成 2048 给予足够的奖励信号,Agent 可能不会将其作为一个重要的目标

2. 训练硬件资源利用不高,没有使用GPU进行加速,默认为CPU,较慢

代码修改如下:

step函数里面,输入维度增加max_tile最大的数是多少

if 2048 in self.board:reward += 10000done = True
state = self.board.flatten()
max_tile = np.max(self.board)
state = np.append(state, max_tile)
return state, reward, done
input_size = 17

检查系统中是否有可用的 GPU,如果有则使用 GPU 进行计算,否则使用 CPU。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

在 train ,创建模型实例后,使用 .to(device) 将模型移动到指定的设备(GPU 或 CPU)

model = DQN(input_size, output_size).to(device)
target_model = DQN(input_size, output_size).to(device)

在训练和推理过程中,将输入数据(状态、动作、奖励等)也移动到指定的设备上。

state = torch.FloatTensor(state).unsqueeze(0).to(device)next_state = torch.FloatTensor(next_state).unsqueeze(0).to(device)states = torch.FloatTensor(states).to(device)
actions = torch.LongTensor(actions).to(device)
rewards = torch.FloatTensor(rewards).to(device)
next_states = torch.FloatTensor(next_states).to(device)
dones = torch.FloatTensor(dones).to(device)

将 state 和 next_state 先使用 .cpu() 方法移动到 CPU 上,再使用 .numpy() 方法转换为 NumPy 数组

replay_buffer.add(state.cpu().squeeze(0).numpy(), action, reward, next_state.cpu().squeeze(0).numpy(), done)

这个不改的话,会出现 TypeError: can't convert cuda:0 device type tensor to numpy 错误

完整代码如下:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.table import Tabledevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 2048 游戏环境类
class Game2048:def __init__(self):self.board = np.zeros((4, 4), dtype=int)self.add_random_tile()self.add_random_tile()def add_random_tile(self):empty_cells = np.argwhere(self.board == 0)if len(empty_cells) > 0:index = random.choice(empty_cells)self.board[index[0], index[1]] = 2 if random.random() < 0.9 else 4def move_left(self):reward = 0new_board = np.copy(self.board)for row in range(4):line = new_board[row]non_zero = line[line != 0]merged = []i = 0while i < len(non_zero):if i + 1 < len(non_zero) and non_zero[i] == non_zero[i + 1]:merged.append(2 * non_zero[i])reward += 2 * non_zero[i]i += 2else:merged.append(non_zero[i])i += 1new_board[row] = np.pad(merged, (0, 4 - len(merged)), 'constant')if not np.array_equal(new_board, self.board):self.board = new_boardself.add_random_tile()return rewarddef move_right(self):self.board = np.fliplr(self.board)reward = self.move_left()self.board = np.fliplr(self.board)return rewarddef move_up(self):self.board = self.board.Treward = self.move_left()self.board = self.board.Treturn rewarddef move_down(self):self.board = self.board.Treward = self.move_right()self.board = self.board.Treturn rewarddef step(self, action):if action == 0:reward = self.move_left()elif action == 1:reward = self.move_right()elif action == 2:reward = self.move_up()elif action == 3:reward = self.move_down()done = not np.any(self.board == 0) and all([np.all(self.board[:, i] != self.board[:, i + 1]) for i in range(3)]) and all([np.all(self.board[i, :] != self.board[i + 1, :]) for i in range(3)])if 2048 in self.board:reward += 10000done = Truestate = self.board.flatten()max_tile = np.max(self.board)state = np.append(state, max_tile)return state, reward, donedef reset(self):self.board = np.zeros((4, 4), dtype=int)self.add_random_tile()self.add_random_tile()state = self.board.flatten()max_tile = np.max(self.board)state = np.append(state, max_tile)return state# 深度 Q 网络类
class DQN(nn.Module):def __init__(self, input_size, output_size):super(DQN, self).__init__()self.fc1 = nn.Linear(input_size, 128)self.fc2 = nn.Linear(128, 128)self.fc3 = nn.Linear(128, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))return self.fc3(x)# 经验回放缓冲区类
class ReplayBuffer:def __init__(self, capacity):self.buffer = deque(maxlen=capacity)def add(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):batch = random.sample(self.buffer, batch_size)states, actions, rewards, next_states, dones = zip(*batch)return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)def __len__(self):return len(self.buffer)# 可视化函数
def visualize_board(board, ax):ax.clear()table = Table(ax, bbox=[0, 0, 1, 1])nrows, ncols = board.shapewidth, height = 1.0 / ncols, 1.0 / nrows# 定义颜色映射cmap = mcolors.LinearSegmentedColormap.from_list("", ["white", "yellow", "orange", "red"])for (i, j), val in np.ndenumerate(board):color = cmap(np.log2(val + 1) / np.log2(2048 + 1)) if val > 0 else "white"table.add_cell(i, j, width, height, text=val if val > 0 else "",loc='center', facecolor=color)ax.add_table(table)ax.set_axis_off()plt.draw()plt.pause(0.1)# 训练函数
def train():env = Game2048()input_size = 17output_size = 4model = DQN(input_size, output_size).to(device)target_model = DQN(input_size, output_size).to(device)target_model.load_state_dict(model.state_dict())target_model.eval()optimizer = optim.Adam(model.parameters(), lr=0.001)criterion = nn.MSELoss()replay_buffer = ReplayBuffer(capacity=10000)batch_size = 32gamma = 0.99epsilon = 1.0epsilon_decay = 0.995epsilon_min = 0.01update_target_freq = 10num_episodes = 1000fig, ax = plt.subplots()for episode in range(num_episodes):state = env.reset()state = torch.FloatTensor(state).unsqueeze(0).to(device)done = Falsetotal_reward = 0while not done:visualize_board(env.board, ax)if random.random() < epsilon:action = random.randint(0, output_size - 1)else:q_values = model(state)action = torch.argmax(q_values, dim=1).item()next_state, reward, done = env.step(action)next_state = torch.FloatTensor(next_state).unsqueeze(0).to(device)replay_buffer.add(state.cpu().squeeze(0).numpy(), action, reward, next_state.cpu().squeeze(0).numpy(), done)if len(replay_buffer) >= batch_size:states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)states = torch.FloatTensor(states).to(device)actions = torch.LongTensor(actions).to(device)rewards = torch.FloatTensor(rewards).to(device)next_states = torch.FloatTensor(next_states).to(device)dones = torch.FloatTensor(dones).to(device)q_values = model(states)q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)next_q_values = target_model(next_states)next_q_values = next_q_values.max(1)[0]target_q_values = rewards + gamma * (1 - dones) * next_q_valuesloss = criterion(q_values, target_q_values)optimizer.zero_grad()loss.backward()optimizer.step()state = next_statetotal_reward += rewardif episode % update_target_freq == 0:target_model.load_state_dict(model.state_dict())epsilon = max(epsilon * epsilon_decay, epsilon_min)print(f"Episode {episode}: Total Reward = {total_reward}, Epsilon = {epsilon}")plt.close()if __name__ == "__main__":train()

相关文章:

DQN 玩 2048 实战|第三期!优化网络,使用GPU、Env奖励优化

视频讲解&#xff1a; DQN 玩 2048 实战&#xff5c;第三期&#xff01;优化网络&#xff0c;使用GPU、Env奖励优化 1. 仅考虑局部合并奖励&#xff1a;目前的奖励只设置为合并方块时获得的分数&#xff0c;只关注了每一步的即时合并收益&#xff0c;而没有对最终达成 2048 这个…...

【python】http post 在body中传递json数据 以发送

http post 在body中传递json数据 以发送&#xff0c;json的格式非常重要这里要传递json对象&#xff0c;而不是一个json字符串 传递post一个 JSON 字符串 是ok的 是的&#xff0c; {"rsource_rhythm_action_list": {"name": "AI_\\u6708\\u4eae\\u…...

Linux错误(2)程序触发SIGBUS信号分析

Linux错误(2)之SIGBUS错误分析 Author: Once Day Date: 2025年3月12日 一位热衷于Linux学习和开发的菜鸟&#xff0c;试图谱写一场冒险之旅&#xff0c;也许终点只是一场白日梦… 漫漫长路&#xff0c;有人对你微笑过嘛… 全系列文章可参考专栏: Linux实践记录_Once_day的博…...

【Halcon】灰度不均解决方案

目录 1、平场校正 2、形态学背景估计 3、频域滤波抑制低频光照不均 4、动态局部自适应 1、平场校正 原理:通过白场(White Image)和黑场(Black Image)图像,手动计算校正系数 * 读取图像 read_image(ImageRaw, raw_image) // 原始图像 read_image(ImageWhite, …...

滑动窗口算法详解:从入门到精通

目录 引言 1. 滑动窗口算法简介 2. 滑动窗口的基本思想 3. 滑动窗口的应用场景 3.1 最大子数组和 3.2 最小覆盖子串 3.3 最长无重复字符子串 4. 滑动窗口的实现步骤 5. 滑动窗口的代码示例 6. 滑动窗口的优化技巧 6.1 使用哈希表记录字符频率 6.2 使用双指针维护窗口…...

JAVA数据库技术(一)

JDBC 简介 JDBC&#xff08;Java Database Connectivity&#xff09;是Java平台提供的一套用于执行SQL语句的Java API。它允许Java程序连接到数据库&#xff0c;并通过发送SQL语句来查询、更新和管理数据库中的数据。JDBC为不同的数据库提供了一种统一的访问方式&#xff0c;使…...

LightGBM + TA-Lib A股实战进阶:Optuna调优与Plotly可视化详解

LightGBM TA-Lib A 股实战进阶&#xff1a;Optuna 调优与 Plotly 可视化详解 本文系统讲解了 LightGBM 在 A 股市场的应用&#xff0c;涵盖模型构建、Optuna 参数调优及 Plotly 可视化。通过实战案例&#xff0c;帮助读者全面掌握相关技术&#xff0c;提升在金融数据分析与预测…...

第二:go 链接mysql 数据库

mac  mysql 安装 的步骤 mysql  安装 配制&#xff1a; https://juejin.cn/post/7454870544929472550 mac brew 如何安装mysql数据库 要在Mac上使用Homebrew安装MySQL数据库&#xff0c;请按照以下步骤操作&#xff1a;步骤 1: 安装Homebrew 如果你还没有安装Homebrew&a…...

QListView、QListWidget、QTableView和QTableWidget

一、概念 在Qt框架中&#xff0c;QListView、QListWidget、QTableView和QTableWidget都是用于显示列表或表格数据的控件。 QListView是一个基于模型-视图架构的控件&#xff0c;用于展示列表形式的数据。它本身并不存储数据&#xff0c;而是依赖于一个QAbstractListModel或其子…...

[贪心算法]-最大数(lambda 表达式的补充)

1.解析 我们一般使用的排序比较大小都是 a>b 那么a在b的前面 ab 无所谓 a<b a在b的后面 本题的排序则是 ab>ba 那么a在b的前面 abba 无所谓 ab<ba a在b的后面 2.代码 class Solution { public:string largestNumber(vector<int>& nums) {//1.先把所有…...

C语言 —— 此去经年梦浪荡魂音 - 深入理解指针(卷二)

目录 1. 数组名与地址 2. 指针访问数组 3.一维数组传参本质 4.二级指针 5. 指针数组 6. 指针数组模拟二维数组 1. 数组名与地址 我们先看下面这个代码&#xff1a; int arr[10] { 1,2,3,4,5,6,7,8,9,10 };int* p &arr[0]; 这里我们使用 &arr[0] 的方式拿到了数…...

python实现简单的图片去水印工具

python实现简单的图片去水印工具 使用说明&#xff1a; 点击"打开图片"选择需要处理的图片 在图片上拖拽鼠标选择水印区域&#xff08;红色矩形框&#xff09; 点击"去除水印"执行处理 点击"保存结果"保存处理后的图片 运行效果 先简要说明…...

使用dify+deepseek部署本地知识库

使用difydeepseek部署本地知识库 一、概述二、安装windows docker desktop1、确认系统的Hyper-v功能正常启用2、docker官网下载安装windows客户端3、安装完成后的界面如下所示 三、下载安装ollama四、部署本地deepseek五、本地下载部署dify5.1 下载dify的安装包5.2 将dify解压到…...

(C语言)指针与指针数组的使用教学(C语言基础教学)(指针教学)

指针是什么&#xff1f;指针怎么用&#xff1f;指针数组又是什么&#xff1f;&#xff1f;&#xff1f; 想必大家刚学C语言的时候对指针可谓是十分头疼了&#xff0c;听也听不懂&#xff0c;用也不会用 下面我来用我的理解来教你指针怎么用&#xff0c;还你一个脑子 1.指针的…...

【算法day13】最长公共前缀

最长公共前缀 https://leetcode.cn/problems/longest-common-prefix/submissions/612055945/ 编写一个函数来查找字符串数组中的最长公共前缀。 如果不存在公共前缀&#xff0c;返回空字符串 “”。 class Solution { public:string longestCommonPrefix(vector<string&g…...

Effective C++ 剖析(条款1~9)

目录 条款01 视C为一个语言联邦(C由几部分组成) 条款02 尽量以 const,enum,inline 替换 #define 条款03 尽量使用 const 条款04 确定对象再使用前已经被初始化 条款05 了解c默默编写并调用那些函数 条款06 若不想使用编译器自动生成的函数就该明确拒绝 条款07 为多态基类…...

【Maven-plugin】有多少官方插件?

之前疏理了容器底层原理&#xff0c;现在回归主题,在阅读 next-public时发现 parent 将从多基础插件集成到 parent 仓库中单独维护&#xff0c;数量众多&#xff0c;故在此将所有插件分类整理。以达观其全貌&#xff0c;心中有数。 以下是 Apache Maven 官方维护的核心插件列表…...

Java高频面试之集合-13

hello啊&#xff0c;各位观众姥爷们&#xff01;&#xff01;&#xff01;本baby今天来报道了&#xff01;哈哈哈哈哈嗝&#x1f436; 面试官&#xff1a;为什么 hash 函数能降哈希碰撞&#xff1f; 哈希函数通过以下核心机制有效降低碰撞概率&#xff0c;确保不同输入尽可能映…...

RGV调度算法(三)--遗传算法

1、基于时间窗 https://wenku.baidu.com/view/470e9fd8b4360b4c2e3f5727a5e9856a57122693.html?_wkts_1741880736197&bdQuery%E7%8E%AF%E7%A9%BF%E8%B0%83%E5%BA%A6%E7%AE%97%E6%B3%95 2.2019年MathorCup高校数学建模挑战赛B题 2019-mathorcupB题-环形穿梭机调度模型&a…...

【数学建模】一致矩阵的应用及其在层次分析法(AHP)中的性质

一致矩阵在层次分析法(AHP)中的应用与性质 在层次分析法(AHP)中&#xff0c;一致矩阵是判断矩阵的一种理想状态&#xff0c;它反映了决策者判断的完全合理性和一致性&#xff0c;也就是为了避免决策者认为“A比B重要&#xff0c;B比C重要&#xff0c;但是C又比A重要”的矛盾。…...

YOLOv8轻量化改进——Coordinate Attention注意力机制

现在针对YOLOv8的架构改进越来越多&#xff0c;今天尝试引入了Coordinate Attention注意力机制以改进对小目标物体的检测效率。 yolov8的下载和安装参考我这篇博客&#xff1a; 基于SeaShips数据集的yolov8训练教程_seaships处理成yolov8-CSDN博客 首先我们可以去官网找到CA注…...

php开发转go的学习计划及课程资料信息

以下是为该课程体系整理的配套教材和教程资源清单,包含书籍、视频、官方文档和实战项目资源,帮助你系统化学习: Go语言学习教材推荐(PHP开发者适配版) 一、核心教材(按学习阶段分类) 1. 基础语法阶段(阶段一) 资源类型名称推荐理由链接/获取方式官方教程Go语言之旅交…...

解释 TypeScript 中的枚举(enum),如何使用枚举定义一组常量?

枚举&#xff08;Enum&#xff09;​ 是 TypeScript 中用于定义一组具名常量的核心类型&#xff0c;通过语义化的命名提升代码可读性&#xff0c;同时利用类型检查减少低级错误。 以下从定义方式、使用建议、注意事项三方面深入解析。 一、枚举的定义方式 1. 数字枚举 特性&…...

基于SpringBoot+Vue的驾校预约管理系统+LW示例参考

1.项目介绍 系统角色&#xff1a;管理员、普通用户、教练功能模块&#xff1a;用户管理、管理员管理、教练管理、教练预约管理、车辆管理、车辆预约管理、论坛管理、基础数据管理等技术选型&#xff1a;SpringBoot&#xff0c;Vue等测试环境&#xff1a;idea2024&#xff0c;j…...

ONNX:统一深度学习工作流的关键枢纽

引言 在深度学习领域&#xff0c;模型创建与部署的割裂曾是核心挑战。不同框架训练的模型难以在多样环境部署&#xff0c;而 ONNX&#xff08;Open Neural Network Exchange&#xff09;作为开放式神经网络交换格式&#xff0c;搭建起从模型创建到部署的统一桥梁&#xff0c;完…...

蓝桥杯————23年省赛 ——————平方差

3.平方差 - 蓝桥云课 一开始看题我还没有意识到问题的严重性 我丢&#xff0c;我想 的是用两层循环来做&#xff0c;后来我试了一下最坏情况&#xff0c;也就是l1 r 1000000000 结果运行半天没运行出来&#xff0c;我就知道坏了&#xff0c;孩子们&#xff0c;要出事&#…...

一、串行通信基础知识

一、串行通信基础知识 1.处理器与外部设备通信有两种方式 并行通信&#xff1a;数据的各个位用多条数据线同时传输。&#xff08;传输速度快&#xff0c;但占用引脚资源多。&#xff09; 串行通信&#xff1a;将数据分成一位一位的形式在一条数据线上逐个传输。&#xff08;线路…...

自带多个接口,完全免费使用!

做自媒体的小伙伴们&#xff0c;是不是经常为语音转文字的事儿头疼&#xff1f; 今天给大家推荐一款超实用的语音转文字软件——AsrTools&#xff0c;它绝对是你的得力助手&#xff01; AsrTools 免费的语音转文字软件 这款软件特别贴心&#xff0c;完全免费&#xff0c;而且操…...

大数据学习(70)-大数据调度工具对比

&#x1f34b;&#x1f34b;大数据学习&#x1f34b;&#x1f34b; &#x1f525;系列专栏&#xff1a; &#x1f451;哲学语录: 用力所能及&#xff0c;改变世界。 &#x1f496;如果觉得博主的文章还不错的话&#xff0c;请点赞&#x1f44d;收藏⭐️留言&#x1f4dd;支持一…...

Qt QML解决SVG图片显示模糊的问题

前言 在QML中直接使用SVG图片&#xff0c;使用Image控件加载资源&#xff0c;显示出来图片是模糊的&#xff0c;很影响使用体验。本文介绍重新绘制SVG图片&#xff0c;然后注册到QML中使用。 效果图&#xff1a; 左边是直接使用Image加载资源显示的效果 右边是重绘后的效果 …...