强化学习嵌入Transformer(代码实践)
这里写目录标题
- ChatGPT的答案
- GPT4.0
ChatGPT的答案
# 定义Transformer模块
class Transformer(nn.Module):def __init__(self, input_dim, hidden_dim, num_heads, num_layers):super(Transformer, self).__init__()self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads)self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)self.linear = nn.Linear(input_dim, hidden_dim)def forward(self, x):x = self.transformer_encoder(x)x = self.linear(x)return x# 定义策略网络
class PolicyNetwork(nn.Module):def __init__(self, input_dim, output_dim, hidden_dim, num_heads, num_layers):super(PolicyNetwork, self).__init__()self.transformer = Transformer(input_dim, hidden_dim, num_heads, num_layers)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = self.transformer(x)x = F.softmax(self.fc(x), dim=-1)return x
问:为什么只有Encoder 没有decoder?
在强化学习任务中,特别是在策略梯度方法中,通常不需要一个完整的Transformer模型,包括Encoder和Decoder。这是因为在这些任务中,我们通常只需要一个从状态到动作的映射,而不需要生成序列。
因此,我们可以只使用Transformer的Encoder部分,将状态作为输入,经过编码后得到一个表示状态的向量,然后将这个向量传递给策略网络(或者Critic网络)来进行动作选择或值函数估计。
使用Transformer的Encoder部分可以有效地处理状态的变长输入,并且具有对序列建模的能力,这在处理许多强化学习任务中是非常有用的。
因此,在PPO算法中,我们只需使用Transformer的Encoder部分,而不需要Decoder部分。在强化学习中,尤其是在处理具有复杂状态和动作空间的任务时,能够捕获状态序列间的依赖关系对于预测未来奖励和决策至关重要。Transformer通过其自注意力机制,能够有效地处理这种序列依赖,从而提升策略性能。
GPT4.0
# time: 2024/3/1 21:31
# author: YanJP
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler# Transformer特征提取器
class TransformerFeatureExtractor(nn.Module):def __init__(self, input_dim, model_dim, num_heads, num_layers, dropout=0.1):super(TransformerFeatureExtractor, self).__init__()self.model_dim = model_dimencoder_layers = TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout)self.transformer_encoder = TransformerEncoder(encoder_layer=encoder_layers, num_layers=num_layers)self.input_projection = nn.Linear(input_dim, model_dim)def forward(self, x):x = self.input_projection(x)# 添加一个序列长度的维度,假设x的原始形状为(batch, model_dim)x = x.unsqueeze(0) # 现在x的形状变为(1, batch, model_dim)output = self.transformer_encoder(x)return output.squeeze(0) # 移除序列长度的维度,恢复到(batch, model_dim)# PPO网络定义
class PPONetwork(nn.Module):def __init__(self, state_dim, action_dim, model_dim=64, num_heads=4, num_layers=4):super(PPONetwork, self).__init__()self.feature_extractor = TransformerFeatureExtractor(input_dim=state_dim, model_dim=model_dim,num_heads=num_heads, num_layers=num_layers)self.policy_head = nn.Linear(model_dim, action_dim)self.value_head = nn.Linear(model_dim, 1)def forward(self, state):features = self.feature_extractor(state)# features = features[:, -1, :] # 使用最后一个时间步的特征action_probs = torch.softmax(self.policy_head(features), dim=-1)state_values = self.value_head(features)return action_probs, state_values# PPO Agent
class PPOAgent:def __init__(self, env):self.env = envself.state_dim = env.observation_space.shape[0]self.action_dim = env.action_space.nself.network = PPONetwork(self.state_dim, self.action_dim)self.optimizer = optim.Adam(self.network.parameters(), lr=2.5e-4)self.gamma = 0.99self.lamda = 0.95self.eps_clip = 0.2self.K_epoch = 4self.buffer_capacity = 1000self.batch_size = 64self.buffer = {'states': [], 'actions': [], 'log_probs': [], 'rewards': [], 'is_terminals': []}def select_action(self, state):state = torch.FloatTensor(state).unsqueeze(0)with torch.no_grad():action_probs, _ = self.network(state)dist = Categorical(action_probs)action = dist.sample()return action.item(), dist.log_prob(action)def put_data(self, transition):self.buffer['states'].append(transition[0])self.buffer['actions'].append(transition[1])self.buffer['log_probs'].append(transition[2])self.buffer['rewards'].append(transition[3])self.buffer['is_terminals'].append(transition[4])def train_net(self):R = 0discounted_rewards = []for reward, is_terminal in zip(reversed(self.buffer['rewards']), reversed(self.buffer['is_terminals'])):if is_terminal:R = 0R = reward + (self.gamma * R)discounted_rewards.insert(0, R)discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32)old_states = torch.tensor(np.array(self.buffer['states']), dtype=torch.float32)old_actions = torch.tensor(self.buffer['actions']).view(-1, 1)old_log_probs = torch.tensor(self.buffer['log_probs']).view(-1, 1)# Normalize the rewardsdiscounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-5)for _ in range(self.K_epoch):for index in BatchSampler(SubsetRandomSampler(range(len(self.buffer['states']))), self.batch_size, False):# Extract batchesstate_sample = old_states[index]action_sample = old_actions[index]old_log_probs_sample = old_log_probs[index]returns_sample = discounted_rewards[index].view(-1, 1)# Get current policiesaction_probs, state_values = self.network(state_sample)dist = Categorical(action_probs)entropy = dist.entropy().mean()new_log_probs = dist.log_prob(action_sample.squeeze(-1))# Calculating the ratio (pi_theta / pi_theta__old):ratios = torch.exp(new_log_probs - old_log_probs_sample.detach())# Calculating Surrogate Loss:advantages = returns_sample - state_values.detach()surr1 = ratios * advantagessurr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantagesloss = -torch.min(surr1, surr2) + 0.5 * (state_values - returns_sample).pow(2) - 0.01 * entropy# take gradient stepself.optimizer.zero_grad()loss.mean().backward()self.optimizer.step()self.buffer = {'states': [], 'actions': [], 'log_probs': [], 'rewards': [], 'is_terminals': []}def train(self, max_episodes):for episode in range(max_episodes):state = self.env.reset()done = Falserewards=0while not done:action, log_prob = self.select_action(state)next_state, reward, done, _ = self.env.step(action)rewards+=rewardself.put_data((state, action, log_prob, reward, done))state = next_stateif done:self.train_net()if episode % 5 == 0:print("eposide:", episode, '\t reward:', rewards)# 主函数
def main():env = gym.make('CartPole-v1')agent = PPOAgent(env)max_episodes = 300agent.train(max_episodes)if __name__ == "__main__":main()
注意:代码能跑,但是不能正常学习到策略!!!!!!!!!!!!!!!!!!!!!!!!!!!!
相关文章:
强化学习嵌入Transformer(代码实践)
这里写目录标题 ChatGPT的答案GPT4.0 ChatGPT的答案 # 定义Transformer模块 class Transformer(nn.Module):def __init__(self, input_dim, hidden_dim, num_heads, num_layers):super(Transformer, self).__init__()self.encoder_layer nn.TransformerEncoderLayer(d_modeli…...

决定西弗吉尼亚州地区版图的关键历史事件
决定西弗吉尼亚州地区版图的关键历史事件: 1. 内部分裂与美国内战: - 在1861年美国内战爆发时,弗吉尼亚州作为南方邦联的一员宣布退出美利坚合众国。然而,弗吉尼亚州西部的一些县由于经济结构(主要是农业非依赖奴隶制…...
LeetCode_22_中等_括号生成
文章目录 1. 题目2. 思路及代码实现(Python)2.1 暴力法2.2 回溯法 1. 题目 数字 n n n 代表生成括号的对数,请你设计一个函数,用于能够生成所有可能的并且 有效的 括号组合。 示例 1: 输入: n 3 n 3 …...

Verilog(未完待续)
Verilog教程 这个教程写的很好,可以多看看。本篇还没整理完。 一、Verilog简介 什么是FPGA?一种可通过编程来修改其逻辑功能的数字集成电路(芯片) 与单片机的区别?对单片机编程并不改变其地电路的内部结构࿰…...

【Linux实践室】Linux初体验
🌈个人主页:聆风吟 🔥系列专栏:Linux实践室、网络奇遇记 🔖少年有梦不应止于心动,更要付诸行动。 文章目录 一. ⛳️任务描述二. ⛳️相关知识2.1 🔔Linux 目录结构介绍2.2 🔔Linux …...

Flutter中高级JSON处理:使用json_serializable进行深入定制
Flutter中高级JSON处理 使用json_serializable库进行深入定制 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at: https://jclee95.blog.csdn.netEmail: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/qq_28550263/article/details/1363…...
华为OD技术面试案例4-2024年
个人情况:985本,目标院校非计算机专业,情况比较特殊,23年11月研究生退学,电子信息类专业。 初识od:10月底打算退学的时候在智联、BOSS上疯狂投硬件方面的岗位。投了大概一两天后有德科和HW的HR打电话给我介…...
【TestNG】(4) 重试机制与监听器的使用
在UI自动化测试用例执行过程中,经常会有很多不确定的因素导致用例执行失败,比如网络原因、环境问题等,所以我们有必要引入重试机制(失败重跑),来提高测试用例成功率。 在不写代码的情况没有提供可配置方式…...

“智农”-高标准农田
高标准农田是指通过土地整治、土壤改良、水利设施、农电配套、机械化作业等措施,提升农田质量和生产能力,达到田块平整、集中连片、设施完善、节水高效、宜机作业、土壤肥沃、生态友好、抗灾能力强、与现代农业生产和经营方式相适应的旱涝保收、稳产高产…...
利用 lxml 库的XPath()方法在网页中快速查找元素
XPath() 函数是 lxml 库中 Element 对象的方法。在使用 lxml 库解析 HTML 或 XML 文档时,您可以通过创建 Element 对象来表示文档的元素,然后使用 Element 对象的 XPath() 方法来执行 XPath 表达式并选择相应的元素。 具体而言,XPath() 方法是…...

nginx---------------重写功能 防盗链 反向代理 (五)
一、重写功能 rewrite Nginx服务器利用 ngx_http_rewrite_module 模块解析和处理rewrite请求,此功能依靠 PCRE(perl compatible regular expression),因此编译之前要安装PCRE库,rewrite是nginx服务器的重要功能之一,重写功能(…...

unity shaderGraph实例-物体线框显示
文章目录 本项目基于URP实现一,读取UV网格,由自定义shader实现效果优缺点效果展示模型准备整体结构各区域内容区域1区域2区域3区域4shader属性颜色属性材质属性后处理 实现二,直接使用纹理,使用默认shader实现优缺点贴图准备材质准…...

分类问题经典算法 | 二分类问题 | Logistic回归:公式推导
目录 一. Logistic回归的思想1. 分类任务思想2. Logistic回归思想 二. Logistic回归算法:线性可分推导 一. Logistic回归的思想 1. 分类任务思想 分类问题通常可以分为二分类,多分类任务;而对于不同的分类任务,训练的主要目标是…...

redis实现分布式全局唯一id
目录 一、前言二、如何通过Redis设计一个分布式全局唯一ID生成工具2.1 使用 Redis 计数器实现2.2 使用 Redis Hash结构实现 三、通过代码实现分布式全局唯一ID工具3.1 导入依赖配置3.2 配置yml文件3.3 序列化配置3.4 编写获取工具3.5 测试获取工具 四、运行结果 一、前言 在很…...

Sora引发安全新挑战
文章目录 前言一、如何看待Sora二、Sora加剧“深度伪造”忧虑三、Sora无法区分对错四、滥用导致的安全危机五、Sora面临的安全挑战总结前言 今年2月,美国人工智能巨头企业OpenAI再推行业爆款Sora,将之前ChatGPT以图文为主的生成式内容全面扩大到视频领域,引发了全球热议,这…...
Android 14.0 Launcher3定制化之桌面分页横线改成圆点显示功能实现
1.前言 在14.0的系统rom产品定制化开发中,在进行launcher3的定制化中,在双层改为单层的开发中,在原生的分页 是横线,而为了美观就采用了系统原来的另外一种分页方式,就是圆点比较美观,接下来就来分析下相关…...

SemiDrive E3 MCAL 开发系列(3)– Wdg 模块的使用
一、 概述 本文将会介绍 SemiDrive E3 MCAL Wdg 模块的基本配置,并且会结合实际操作的介绍,帮助新手快速了解并掌握这个模块的使用,文中的 MCAL 是基于 PTG3.0 的版本,开发板是官方的 E3640 网关板。 二、 Wdg 模块的主要配置 …...
AI推荐算法的演进之路
推荐算法 基于大数据和AI技术,提供全流程一站式推荐平台,协助企业构建个性化推荐应用,提升企业应用的点击率留存率和永久体验。目前,主要的推荐方法包括:基于内容推荐、协同过滤推荐、基于关联规则推荐、基于效用推荐…...

Tomcat安装,配置文件、组件
一、Tomcat的基本功能 1.1.Tomcat是什么? Tomcat服务器是一个免费的开放源代码的Web应用服务器,属于轻量级应用服务器,在中小型系统和并发访问用户不是很多的场合下被普遍使用,是开发和调试JSP程序的首选。一般来说,T…...
精读《React Hooks 最佳实践》
简介 React 16.8 于 2019.2 正式发布,这是一个能提升代码质量和开发效率的特性,笔者就抛砖引玉先列出一些实践点,希望得到大家进一步讨论。 然而需要理解的是,没有一个完美的最佳实践规范,对一个高效团队来说&#x…...
React hook之useRef
React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...
反向工程与模型迁移:打造未来商品详情API的可持续创新体系
在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...

NFT模式:数字资产确权与链游经济系统构建
NFT模式:数字资产确权与链游经济系统构建 ——从技术架构到可持续生态的范式革命 一、确权技术革新:构建可信数字资产基石 1. 区块链底层架构的进化 跨链互操作协议:基于LayerZero协议实现以太坊、Solana等公链资产互通,通过零知…...
uniapp中使用aixos 报错
问题: 在uniapp中使用aixos,运行后报如下错误: AxiosError: There is no suitable adapter to dispatch the request since : - adapter xhr is not supported by the environment - adapter http is not available in the build 解决方案&…...

C++ Visual Studio 2017厂商给的源码没有.sln文件 易兆微芯片下载工具加开机动画下载。
1.先用Visual Studio 2017打开Yichip YC31xx loader.vcxproj,再用Visual Studio 2022打开。再保侟就有.sln文件了。 易兆微芯片下载工具加开机动画下载 ExtraDownloadFile1Info.\logo.bin|0|0|10D2000|0 MFC应用兼容CMD 在BOOL CYichipYC31xxloaderDlg::OnIni…...
重启Eureka集群中的节点,对已经注册的服务有什么影响
先看答案,如果正确地操作,重启Eureka集群中的节点,对已经注册的服务影响非常小,甚至可以做到无感知。 但如果操作不当,可能会引发短暂的服务发现问题。 下面我们从Eureka的核心工作原理来详细分析这个问题。 Eureka的…...

算法:模拟
1.替换所有的问号 1576. 替换所有的问号 - 力扣(LeetCode) 遍历字符串:通过外层循环逐一检查每个字符。遇到 ? 时处理: 内层循环遍历小写字母(a 到 z)。对每个字母检查是否满足: 与…...
现有的 Redis 分布式锁库(如 Redisson)提供了哪些便利?
现有的 Redis 分布式锁库(如 Redisson)相比于开发者自己基于 Redis 命令(如 SETNX, EXPIRE, DEL)手动实现分布式锁,提供了巨大的便利性和健壮性。主要体现在以下几个方面: 原子性保证 (Atomicity)ÿ…...

STM32---外部32.768K晶振(LSE)无法起振问题
晶振是否起振主要就检查两个1、晶振与MCU是否兼容;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容(CL)与匹配电容(CL1、CL2)的关系 2. 如何选择 CL1 和 CL…...

stm32wle5 lpuart DMA数据不接收
配置波特率9600时,需要使用外部低速晶振...