当前位置: 首页 > article >正文

20250303-代码笔记-class CVRPTester

文章目录

  • 前言
  • 一、class CVRPTester:__init__(self,env_params,model_params, tester_params)
    • 1.1函数解析
    • 1.2函数分析
      • 1.2.1加载预训练模型
    • 1.2函数代码
  • 二、class CVRPTester:run(self)
    • 函数解析
    • 函数代码
  • 三、class CVRPTester:_test_one_batch(self, batch_size)
    • 函数解析
    • 函数代码
  • 附录
    • 代码(全)


前言

学习代码CVRPTester.py,对代码的分析如下。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPTester.py


一、class CVRPTester:init(self,env_params,model_params, tester_params)

1.1函数解析

执行流程图链接

在这里插入图片描述

1.2函数分析

1.2.1加载预训练模型

代码:

# Restore
model_load = tester_params['model_load']
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
  • model_load: 这是一个字典,包含了从哪里加载预训练模型的路径信息以及具体的 epoch
model_load = tester_params['model_load']
  • checkpoint_fullname: 使用 Python 的字符串格式化功能,构造预训练模型的文件路径。
    这会生成形如 /path/to/model/checkpoint-8100.pt 的文件路径。即需要输入参数pathepoch
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
  • 加载模型:
    • torch.load(checkpoint_fullname, map_location=device):从磁盘加载模型检查点(即 .pt 文件),并将其存储在 checkpoint 变量中。map_location=device 确保模型会被加载到正确的设备上(GPU 或 CPU)。
    • self.model.load_state_dict(checkpoint['model_state_dict']):从加载的检查点中提取模型的状态字典,并将其加载到 self.model 中。
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])

示例
假设 tester_params_regret[‘model_load’] 如下所示:

tester_params_regret = {'model_load': {'path': '../../pretrained/vrp100','epoch': 8100,},# 其他参数...
}

然后 checkpoint_fullname 会被构造为/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/pretrained/models/checkpoint-8100.pt,模型会从该路径加载。
在这里插入图片描述


1.2函数代码

    def __init__(self,env_params,model_params,tester_params):# save argumentsself.env_params = env_paramsself.model_params = model_paramsself.tester_params = tester_params# result folder, loggerself.logger = getLogger(name='trainer')self.result_folder = get_result_folder()# cudaUSE_CUDA = self.tester_params['use_cuda']if USE_CUDA:cuda_device_num = self.tester_params['cuda_device_num']torch.cuda.set_device(cuda_device_num)device = torch.device('cuda', cuda_device_num)torch.set_default_tensor_type('torch.cuda.FloatTensor')else:device = torch.device('cpu')torch.set_default_tensor_type('torch.FloatTensor')self.device = device# ENV and MODELself.env = Env(**self.env_params)self.model = Model(**self.model_params)# Restoremodel_load = tester_params['model_load']checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)checkpoint = torch.load(checkpoint_fullname, map_location=device)self.model.load_state_dict(checkpoint['model_state_dict'])# utilityself.time_estimator = TimeEstimator()

二、class CVRPTester:run(self)

函数解析

函数执行流程图链接
在这里插入图片描述

函数代码

    def run(self):self.time_estimator.reset()score_AM = AverageMeter()aug_score_AM = AverageMeter()if self.tester_params['test_data_load']['enable']:self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)test_num_episode = self.tester_params['test_episodes']episode = 0while episode < test_num_episode:remaining = test_num_episode - episodebatch_size = min(self.tester_params['test_batch_size'], remaining)score, aug_score = self._test_one_batch(batch_size)score_AM.update(score, batch_size)aug_score_AM.update(aug_score, batch_size)episode += batch_size############################# Logs############################elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))all_done = (episode == test_num_episode)if all_done:self.logger.info(" *** Test Done *** ")self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))

三、class CVRPTester:_test_one_batch(self, batch_size)

函数解析

执行流程图链接
在这里插入图片描述

函数代码

    def _test_one_batch(self, batch_size):# Augmentation###############################################if self.tester_params['augmentation_enable']:aug_factor = self.tester_params['aug_factor']else:aug_factor = 1# Ready###############################################self.model.eval()with torch.no_grad():self.env.load_problems(batch_size, aug_factor)reset_state, _, _ = self.env.reset()self.model.pre_forward(reset_state)# POMO Rollout###############################################state, reward, done = self.env.pre_step()while not done:selected, _ = self.model(state)# shape: (batch, pomo)state, reward, done = self.env.step(selected)# Return###############################################aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)# shape: (augmentation, batch, pomo)max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo# shape: (augmentation, batch)no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive valuemax_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation# shape: (batch,)aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive valuereturn no_aug_score.item(), aug_score.item()

附录

代码(全)


import torchimport os
from logging import getLoggerfrom CVRPEnv import CVRPEnv as Env
from CVRPModel import CVRPModel as Modelfrom utils.utils import *class CVRPTester:def __init__(self,env_params,model_params,tester_params):# save argumentsself.env_params = env_paramsself.model_params = model_paramsself.tester_params = tester_params# result folder, loggerself.logger = getLogger(name='trainer')self.result_folder = get_result_folder()# cudaUSE_CUDA = self.tester_params['use_cuda']if USE_CUDA:cuda_device_num = self.tester_params['cuda_device_num']torch.cuda.set_device(cuda_device_num)device = torch.device('cuda', cuda_device_num)torch.set_default_tensor_type('torch.cuda.FloatTensor')else:device = torch.device('cpu')torch.set_default_tensor_type('torch.FloatTensor')self.device = device# ENV and MODELself.env = Env(**self.env_params)self.model = Model(**self.model_params)# Restoremodel_load = tester_params['model_load']checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)checkpoint = torch.load(checkpoint_fullname, map_location=device)self.model.load_state_dict(checkpoint['model_state_dict'])# utilityself.time_estimator = TimeEstimator()def run(self):self.time_estimator.reset()score_AM = AverageMeter()aug_score_AM = AverageMeter()if self.tester_params['test_data_load']['enable']:self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)test_num_episode = self.tester_params['test_episodes']episode = 0while episode < test_num_episode:remaining = test_num_episode - episodebatch_size = min(self.tester_params['test_batch_size'], remaining)score, aug_score = self._test_one_batch(batch_size)score_AM.update(score, batch_size)aug_score_AM.update(aug_score, batch_size)episode += batch_size############################# Logs############################elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))all_done = (episode == test_num_episode)if all_done:self.logger.info(" *** Test Done *** ")self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))def _test_one_batch(self, batch_size):# Augmentation###############################################if self.tester_params['augmentation_enable']:aug_factor = self.tester_params['aug_factor']else:aug_factor = 1# Ready###############################################self.model.eval()with torch.no_grad():self.env.load_problems(batch_size, aug_factor)reset_state, _, _ = self.env.reset()self.model.pre_forward(reset_state)# POMO Rollout###############################################state, reward, done = self.env.pre_step()while not done:selected, _ = self.model(state)# shape: (batch, pomo)state, reward, done = self.env.step(selected)# Return###############################################aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)# shape: (augmentation, batch, pomo)max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo# shape: (augmentation, batch)no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive valuemax_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation# shape: (batch,)aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive valuereturn no_aug_score.item(), aug_score.item()

相关文章:

20250303-代码笔记-class CVRPTester

文章目录 前言一、class CVRPTester:__init__(self,env_params,model_params, tester_params)1.1函数解析1.2函数分析1.2.1加载预训练模型 1.2函数代码 二、class CVRPTester:run(self)函数解析函数代码 三、class CVRPTester:_test_one_batch(self, batch_size)函数解析函数代…...

C++学习之C++初识、C++对C语言增强、对C语言扩展

一.C初识 1.C简介 2.第一个C程序 //#include <iostream> //iostream 相当于 C语言下的 stdio.h i - input 输入 o -output 输出 //using namespace std; //using 使用 namespace 命名空间 std 标准 &#xff0c;理解为打开一个房间&#xff0c;房间里有我们所需…...

依赖注入与控制反转什么关系

依赖注入属于控制反转&#xff08;Inversion of Control&#xff0c;IoC&#xff09;设计模式的一种具体实现方式。以下是具体解释&#xff1a; 控制反转 控制反转是一种设计思想&#xff0c;它将对象之间的依赖关系的控制权从对象内部转移到了外部的容器或框架中。在传统的编…...

关于虚拟环境中遇到的bug

conda和cmd介绍 介绍 Conda 概述&#xff1a; Conda是一个开源包管理系统和环境管理系统&#xff0c;尤其适用于Python和R语言的开发环境。它允许用户创建独立的虚拟环境&#xff0c;方便地管理依赖包和软件版本。 特点&#xff1a; 环境管理&#xff1a;可以创建、导入、导…...

【网络安全 | 渗透测试】GraphQL精讲一:基础知识

未经许可,不得转载, 文章目录 GraphQL 定义GraphQL 工作原理GraphQL 模式GraphQL 查询GraphQL 变更(Mutations)查询(Queries)和变更(Mutations)的组成部分字段(Fields)参数(Arguments)变量别名(Aliases)片段(Fragments)订阅(Subscriptions)自省(Introspecti…...

01_NLP基础之文本处理的基本方法

自然语言处理入门 自然语言处理&#xff08;Natural Language Processing, 简称NLP&#xff09;是计算机科学与语言学中关注于计算机与人类语言间转换的领域&#xff0c;主要目标是让机器能够理解和生成自然语言&#xff0c;这样人们可以通过语言与计算机进行更自然的互动。 …...

Oracle 导出所有表索引的创建语句

在Oracle数据库中&#xff0c;导出所有表的索引创建语句通常涉及到使用数据字典视图来查询索引的定义&#xff0c;然后生成对应的SQL语句。你可以通过查询DBA_INDEXES或USER_INDEXES视图&#xff08;取决于你的权限和需求&#xff09;来获取这些信息。 使用DBA_INDEXES视图 如…...

什么是JTAG、SWD?

一、什么是JTAG&#xff1f; JTAG&#xff08;Joint Test Action Group&#xff0c;联合测试行动小组&#xff09;是一种国际标准测试协议&#xff0c;常用于芯片内部测试及对系统进行调试、编程等操作。以下从其起源、工作原理、接口标准、应用场景等方面详细介绍&#xff1a…...

物理竞赛中的线性代数

线性代数 1 行列式 1.1 n n n 阶行列式 定义 1.1.1&#xff1a;称以下的式子为一个 n n n 阶行列式&#xff1a; ∣ A ∣ ∣ a 11 a 12 ⋯ a 1 n a 21 a 22 ⋯ a 2 n ⋮ ⋮ ⋱ ⋮ a n 1 a n 2 ⋯ a n n ∣ \begin{vmatrix}\mathbf A\end{vmatrix} \begin{vmatrix} a_{11…...

如何在Apple不再支持的MacOS上安装Homebrew

手头有一台2012年产的Macbook Pro&#xff0c;系统版本停留在了10.15.7&#xff08;2020年9月24日发布的&#xff09;。MacOS 11及后续的版本都无法安装到这台老旧的电脑上。想通过pkg安装Homebrew&#xff0c;发现Homebrew releases里最新的pkg安装包不支持MacOS 10.15.7&…...

深入探索DeepSeek开源之旅:开源Week全程解析

摘要 在农历新年刚刚结束之际&#xff0c;DeepSeek以卓越的开源精神&#xff0c;连续六天举办了开源Week活动。这一系列活动不仅展示了DeepSeek在技术领域的活跃度和影响力&#xff0c;还彰显了其对开源社区的贡献。通过这次活动&#xff0c;DeepSeek吸引了众多开发者和技术爱好…...

无人机遥控器无线传输技术解析!

一、主流无线传输方式 无线电遥控系统&#xff08;2.4GHz/5.8GHz频段&#xff09; 频段特性&#xff1a;2.4GHz频段穿透力强、覆盖距离远&#xff08;可达2公里以上&#xff09;&#xff0c;适合控制信号传输&#xff1b;5.8GHz频段带宽更高&#xff0c;适用于高清视频流&…...

微调训练方法概述:Fine-tuning、Prompt-tuning、P-tuning 及其他高效技术

在深度学习和自然语言处理&#xff08;NLP&#xff09;领域&#xff0c;预训练模型&#xff08;如 GPT、BERT 等&#xff09;已经成为许多任务的基石。然而&#xff0c;尽管这些模型在预训练阶段学习了大量的通用知识&#xff0c;它们通常仍然需要根据特定任务进行微调&#xf…...

在笔记本电脑上用DeepSeek搭建个人知识库

最近DeepSeek爆火&#xff0c;试用DeepSeek的企业和个人越来越多。最常见的应用场景就是知识库和知识问答。所以本人也试用了一下&#xff0c;在笔记本电脑上部署DeepSeek并使用开源工具搭建一套知识库&#xff0c;实现完全在本地环境下使用本地文档搭建个人知识库。操作过程共…...

Java面试第七山!《MySQL索引》

一、索引的本质与作用 索引是帮助MySQL高效获取数据的数据结构&#xff0c;类似于书籍的目录。它通过减少磁盘I/O次数&#xff08;即减少数据扫描量&#xff09;来加速查询&#xff0c;尤其在百万级数据场景下&#xff0c;索引可将查询效率提升数十倍。 核心作用&#xff1a;…...

DeepSeek 助力 Vue3 开发:打造丝滑的弹性布局(Flexbox)

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 Deep…...

大白话跨域问题的原理与多种解决方法的实现

大白话跨域问题的原理与多种解决方法的实现 跨域问题原理 简单来说&#xff0c;当一个网页中的JavaScript代码想要去访问另一个不同域名、端口或协议的服务器上的数据时&#xff0c;就会出现跨域问题。这是浏览器的一种安全机制&#xff0c;为了防止恶意网站窃取用户信息等。…...

el-table input textarea 文本域 自适应高度,切换分页滚动失效处理办法

场景&#xff1a; el-table 表格 需要 input类型是 textarea 高度是自适应&#xff0c;第一页数据都是单行数据 不会产生滚动条&#xff0c;但是第二页数据是多行数据 会产生滚动条&#xff0c; bug: 第一页切换到第二页 第二页滚动条无法展示 解决办法&#xff1a;直接修改样…...

动态表头报表的绘制与导出

目录 一、效果图 二、整体思路 三、代码区 一、效果图 根据选择的日期范围动态生成表头&#xff08;eg&#xff1a;2025.2.24--2025.03.03&#xff09;每个日期又分为白班、夜班&#xff1b;数据列表中对产线合并单元格。支持按原格式导出对应的报表excel。 点击空白区可新…...

DeepSeek 助力 Vue3 开发:打造丝滑的密码输入框(Password Input)

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 Deep…...

【解决】OnTriggerEnter/OnTriggerExit 调用匿名委托误区的问题

开发平台&#xff1a;Unity 开发语言&#xff1a;CSharp 6.0 开发工具&#xff1a;Visual Studio 2022   问题背景 public void OnTriggerEnter(Collider collider) {output.OnInteractionNoticed () > OnInteractionTriggered?.Invoke(); }public void OnTriggerExit(C…...

Linux 基础---文件权限

概念 文件权限是针对文件所有者、文件所属组、其他人这三类人而言的&#xff0c;对应的操作是chmod。设置方式&#xff1a;文字设定法、数字设定法。 文字设定法&#xff1a;r,w,x,- 来描述用户对文件的操作权限数字设定法&#xff1a;0,1,2,3,4,5,6,7 来描述用户对文件的操作…...

SpringBoot五:JSR303校验

精心整理了最新的面试资料和简历模板&#xff0c;有需要的可以自行获取 点击前往百度网盘获取 点击前往夸克网盘获取 松散绑定 意思是比如在yaml中写的是last-name&#xff0c;这个和lastName意思是一样的&#xff0c;-后的字母默认是大写的 JSR303校验 就是可以在字段增加…...

【计算机网络】考研复试高频知识点总结

文章目录 一、基础概念1、计算机⽹络的定义2、计算机⽹络的目标3、计算机⽹络的组成4、计算机⽹络的分类5、计算机⽹络的拓扑结构6、计算机⽹络的协议7、计算机⽹络的分层结构8、OSI 参考模型9、TCP/IP 参考模型10、五层协议体系结构 二、物理层1、物理层的功能2、传输媒体3、 …...

Error Density-dependent Empirical Risk Minimization

经验误差密度依赖的风险最小化 v.s. 经验风险最小化 论文&#xff1a; 《 Error Density-dependent Empirical Risk Minimization》 发表在&#xff1a; ESWA’24 相关代码&#xff1a; github.com/zxlml/EDERM 研究背景 传统的经验风险最小化&#xff08;ERM&#xff09;方…...

02_NLP文本预处理之文本张量表示法

文本张量表示法 概念 将文本使用张量进行表示,一般将词汇表示为向量,称为词向量,再由各个词向量按顺序组成矩阵形成文本表示 例如: ["人生", "该", "如何", "起头"]># 每个词对应矩阵中的一个向量 [[1.32, 4,32, 0,32, 5.2],[3…...

Spring Boot全局异常处理:“危机公关”团队

目录 一、全局异常处理的作用二、Spring Boot 实现全局异常处理&#xff08;附上代码实例&#xff09;三、总结&#xff1a; &#x1f31f;我的其他文章也讲解的比较有趣&#x1f601;&#xff0c;如果喜欢博主的讲解方式&#xff0c;可以多多支持一下&#xff0c;感谢&#x1…...

Ubuntu 下 nginx-1.24.0 源码分析 - ngx_list_init

ngx_list_init 定义在 src\core\ngx_list.h static ngx_inline ngx_int_t ngx_list_init(ngx_list_t *list, ngx_pool_t *pool, ngx_uint_t n, size_t size) {list->part.elts ngx_palloc(pool, n * size);if (list->part.elts NULL) {return NGX_ERROR;}list->par…...

C# OnnxRuntime部署DAMO-YOLO香烟检测

目录 说明 效果 模型信息 项目 代码 下载 参考 说明 效果 模型信息 Model Properties ------------------------- --------------------------------------------------------------- Inputs ------------------------- name&#xff1a;input tensor&#xff1a;Floa…...

GitHub开源协议选择指南:如何为你的项目找到最佳“许可证”?

引言 当你站在GitHub仓库创建的十字路口时&#xff0c;是否曾被众多开源协议晃花了眼&#xff1f; 别担心&#xff01;这篇指南将化身你的"协议导航仪"&#xff0c;用一张流程图五个灵魂拷问&#xff0c;帮你轻松找到最佳选择。无论你是开发者、开源爱好者&#xff…...