分类任务实现模型集成代码模版
分类任务实现模型(投票式)集成代码模版
简介
本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化,对每个基学习器的性能进行对比,直观的看出模型集成的作用。
代码
# -*- coding:utf-8 -*-
import os
import torch
import torchvision
import torchmetrics
import torch.nn as nn
import my_utils as utils
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchensemble.utils import set_module
from torchensemble.voting import VotingClassifierclasses = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']def get_args_parser(add_help=True):import argparseparser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)parser.add_argument("--data-path", default=r"E:\Pytorch-Tutorial-2nd\data\datasets\cifar10-office", type=str,help="dataset path")parser.add_argument("--model", default="resnet8", type=str, help="model name")parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")parser.add_argument("-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)")parser.add_argument("--opt", default="SGD", type=str, help="optimizer")parser.add_argument("--random-seed", default=42, type=int, help="random seed")parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")parser.add_argument("--wd","--weight-decay",default=1e-4,type=float,metavar="W",help="weight decay (default: 1e-4)",dest="weight_decay",)parser.add_argument("--lr-step-size", default=80, type=int, help="decrease lr every step-size epochs")parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")parser.add_argument("--print-freq", default=80, type=int, help="print frequency")parser.add_argument("--output-dir", default="./Result", type=str, help="path to save outputs")parser.add_argument("--resume", default="", type=str, help="path of checkpoint")parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")return parserdef main():args = get_args_parser().parse_args()utils.setup_seed(args.random_seed)args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device = args.devicedata_dir = args.data_pathresult_dir = args.output_dir# ------------------------------------ log ------------------------------------logger, log_dir = utils.make_logger(result_dir)writer = SummaryWriter(log_dir=log_dir)# ------------------------------------ step1: dataset ------------------------------------normMean = [0.4948052, 0.48568845, 0.44682974]normStd = [0.24580306, 0.24236229, 0.2603115]normTransform = transforms.Normalize(normMean, normStd)train_transform = transforms.Compose([transforms.Resize(32),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),normTransform])valid_transform = transforms.Compose([transforms.ToTensor(),normTransform])# root变量下需要存放cifar-10-python.tar.gz 文件# cifar-10-python.tar.gz可从 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下载train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=train_transform, download=True)test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=valid_transform, download=True)# 构建DataLodertrain_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)valid_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, num_workers=args.workers)# ------------------------------------ tep2: model ------------------------------------model_base = utils.resnet20()# model_base = utils.LeNet5()model = MyEnsemble(estimator=model_base, n_estimators=3, logger=logger, device=device, args=args,classes=classes, writer=writer, save_dir=log_dir)model.set_optimizer(args.opt, lr=args.lr, weight_decay=args.weight_decay)model.fit(train_loader, test_loader=valid_loader, epochs=args.epochs)class MyEnsemble(VotingClassifier):def __init__(self, **kwargs):# logger, device, args, classes, writersuper(VotingClassifier, self).__init__(kwargs["estimator"], kwargs["n_estimators"])self.logger = kwargs["logger"]self.writer = kwargs["writer"]self.device = kwargs["device"]self.args = kwargs["args"]self.classes = kwargs["classes"]self.save_dir = kwargs["save_dir"]@staticmethoddef save(model, save_dir, logger):"""Implement model serialization to the specified directory."""if save_dir is None:save_dir = "./"if not os.path.isdir(save_dir):os.mkdir(save_dir)# Decide the base estimator nameif isinstance(model.base_estimator_, type):base_estimator_name = model.base_estimator_.__name__else:base_estimator_name = model.base_estimator_.__class__.__name__# {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators}filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,base_estimator_name,model.n_estimators,)# The real number of base estimators in some ensembles is not same as# `n_estimators`.state = {"n_estimators": len(model.estimators_),"model": model.state_dict(),"_criterion": model._criterion,}save_dir = os.path.join(save_dir, filename)logger.info("Saving the model to `{}`".format(save_dir))# Savetorch.save(state, save_dir)returndef fit(self, train_loader, epochs=100, log_interval=100, test_loader=None, save_model=True, save_dir=None, ):# 模型、优化器、学习率调整器、评估器 列表创建estimators = []for _ in range(self.n_estimators):estimators.append(self._make_estimator())optimizers = []schedulers = []for i in range(self.n_estimators):optimizers.append(set_module.set_optimizer(estimators[i],self.optimizer_name, **self.optimizer_args))scheduler_ = torch.optim.lr_scheduler.MultiStepLR(optimizers[i], milestones=[100, 150],gamma=self.args.lr_gamma) # 设置学习率下降策略# scheduler_ = torch.optim.lr_scheduler.StepLR(optimizers[i], step_size=self.args.lr_step_size,# gamma=self.args.lr_gamma) # 设置学习率下降策略schedulers.append(scheduler_)acc_metrics = []for i in range(self.n_estimators):# task类型与任务一致# num_classes与分类任务的类别数一致acc_metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=len(self.classes)))self._criterion = nn.CrossEntropyLoss()# epoch循环迭代best_acc = 0.for epoch in range(epochs):# trainingfor model_idx, (estimator, optimizer, scheduler) in enumerate(zip(estimators, optimizers, schedulers)):loss_m_train, acc_m_train, mat_train = \utils.ModelTrainerEnsemble.train_one_epoch(train_loader, estimator, self._criterion, optimizer, scheduler, epoch,self.device, self.args, self.logger, self.classes)# 学习率更新scheduler.step()# 记录self.writer.add_scalars('Loss_group', {'train_loss_{}'.format(model_idx):loss_m_train.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'train_acc_{}'.format(model_idx):acc_m_train.avg}, epoch)self.writer.add_scalar('learning rate', scheduler.get_last_lr()[0], epoch)# 训练混淆矩阵图conf_mat_figure_train = utils.show_conf_mat(mat_train, classes, "train", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_train', conf_mat_figure_train, global_step=epoch)# validateloss_valid_meter, acc_valid, top1_group, mat_valid = \utils.ModelTrainerEnsemble.evaluate(test_loader, estimators, self._criterion, self.device, self.classes)# 日志self.writer.add_scalars('Loss_group', {'valid_loss':loss_valid_meter.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'valid_acc':acc_valid * 100}, epoch)# 验证混淆矩阵图conf_mat_figure_valid = utils.show_conf_mat(mat_valid, classes, "valid", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_valid', conf_mat_figure_valid, global_step=epoch)self.logger.info('Epoch: [{:0>3}/{:0>3}] ''Train Loss avg: {loss_train:>6.4f} ''Valid Loss avg: {loss_valid:>6.4f} ''Train Acc@1 avg: {top1_train:>7.2f}% ''Valid Acc@1 avg: {top1_valid:>7.2%} ''LR: {lr}'.format(epoch, self.args.epochs, loss_train=loss_m_train.avg, loss_valid=loss_valid_meter.avg,top1_train=acc_m_train.avg, top1_valid=acc_valid, lr=schedulers[0].get_last_lr()[0]))for model_idx, top1_meter in enumerate(top1_group):self.writer.add_scalars('Accuracy_group',{'valid_acc_{}'.format(model_idx): top1_meter.compute() * 100}, epoch)if acc_valid > best_acc:best_acc = acc_validself.estimators_ = nn.ModuleList()self.estimators_.extend(estimators)if save_model:self.save(self, self.save_dir, self.logger)if __name__ == "__main__":main()
效果图
本实验采用3个学习器进行投票式集成,因此绘制了7条曲线,其中各学习器在训练和验证各有2条曲线,集成模型的结果通过 valid_acc输出(蓝色),通过下图可发现,集成模型与三个基学习器相比,分类准确率都能提高3-4百分点左右,是非常高的提升了。



参考
7.7 TorchEnsemble 模型集成库 · PyTorch实用教程(第二版) (tingsongyu.github.io)
相关文章:
分类任务实现模型集成代码模版
分类任务实现模型(投票式)集成代码模版 简介 本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化࿰…...
从Milvus迁移DashVector
本文档演示如何从Milvus将Collection数据全量导出,并适配迁移至DashVector。方案的主要流程包括: 首先,升级Milvus版本,目前Milvus只有在最新版本(v.2.3.x)中支持全量导出其次,将Milvus Collection的Schema信息和数据…...
彻底改变计算机视觉的 Vision Transformer (ViT) 综合指南(视觉转换器终极指南)
欢迎来到雲闪世界。大家好!对于那些还不认识我的人,我叫 Francois,我是 Meta 的研究科学家。我热衷于解释先进的 AI 概念并使其更容易理解。 今天,让我们深入探讨计算机视觉领域最重要的贡献之一:Vision Transformer&…...
vue3 v-bind=“$attrs“ 的一些理解,透传 Attributes相关说明及事例说明
1、可能小伙伴们经常会在自己的项目中看到v-bind"$attrs",这个一般是在自定义组件中看到。 比如: <template><BasicModalv-bind"$attrs"register"registerModal":title"getTitle"ok"handleSubm…...
鸿蒙开发基础知识-页面布局【第四篇】
1.类型转换 2.交互点击事件 3.状态管理 4.forEch渲染和右上角图标 测试案例 Stack 层叠布局一个生肖卡 5. 动画展示图片 6. Swiper 轮播组件的基本使用 图片等比显示 aspectRatio()...
用CSS实现前端响应式布局
一、响应式布局的重要性 随着移动设备的普及,越来越多的用户通过手机、平板电脑等设备访问网页。如果网页不能适应不同的屏幕尺寸,就会出现布局混乱、内容显示不全等问题,严重影响用户体验。响应式布局可以确保网页在各种设备上都能保持美观…...
【docker】docker启动sqlserver
sqlserver-docker官方地址 # sqlserver不是从docker的中央仓库拉取的,而是从ms的仓库拉取的。 docker pull mcr.microsoft.com/mssql/server:2019-latest# 宿主机即docker程序运行的linux服务器 docker run -d \ --user root \ --name mssql2019 \ -e "ACCEPT…...
Python爬虫01
requests模块 文档 安装 pip/pip3 install requestsresponse.text 和 response.content的区别 1.response.text 等价于 response.content.decode("推测出的编码字符集")response.text 类型:str 编码类型:requests模块自动根据Http头部对…...
关于vue项目启动报错Error: error:0308010C:digital envelope routines::unsupported
周五啦,总结一下这周遇到的个别问题吧,就是关于启动项目的时候其他的东西都准备好了,执行命令后报错Error: error:0308010C:digital envelope routines::unsupported 这里看一下我标注的地方,然后总结一下就不难发现问题所在 查看…...
随笔1:数学建模与数值计算
目录 1.1 矩阵运算 1.2 基本数学函数 1.3 数值求解 数学建模与数值计算 是将实际问题通过数学公式和模型进行描述,并通过计算获得模型解的过程。这是数学建模中最基本也是最重要的环节之一。下面是详细的知识点讲解及相应的MATLAB代码示例。 1.1 矩阵运算 知识点…...
SDN架构详解
目录 1)经典的IP网络-分布式网络 2)经典网络面临的问题 3)SDN起源 4)OpenFlow基本概念 5)Flow Table简介 6)SDN的网络架构 7)华为SDN网络架构 8)传统网络 vs SDN 9…...
platform框架
platform框架 注册设备进入总线platform_device_register函数 注册驱动进入总线platform_driver_register函数 注册设备进入总线 platform_device_register函数 int platform_device_register(struct platform_device *pdev) struct platform_device {const char * name; 名…...
零成本搞定静态博客——十分钟安装hugo与主题
文章目录 hugo介绍hugo安装与使用方式一:新建站点自建主题方式二:新建站点使用系统推荐的主题 hugo介绍 通过 Hugo 你可以快速搭建你的静态网站,比如博客系统、文档介绍、公司主页、产品介绍等等。相对于其他静态网站生成器来说,…...
windows C++ 并行编程-转换使用取消的 OpenMP 循环以使用并发运行时
某些并行循环不需要执行所有迭代。 例如,搜索值的算法可以在找到值后终止。 OpenMP 不提供中断并行循环的机制。 但是,可以使用布尔值或标志来启用循环迭代,以指示已找到解决方案。 并发运行时提供允许一个任务取消其他尚未启动的任务的功能。…...
经验笔记:跨站脚本攻击(Cross-Site Scripting,简称XSS)
跨站脚本攻击(Cross-Site Scripting,简称XSS)经验笔记 跨站脚本攻击(XSS:Cross-Site Scripting)是一种常见的Web应用程序安全漏洞,它允许攻击者将恶意脚本注入到看起来来自可信网站的网页上。当…...
演示:基于WPF的DrawingVisual和谷歌地图瓦片开发的地图(完全独立不依赖第三方库)
一、目的:基于WPF的DrawingVisual和谷歌地图瓦片开发的地图 二、预览 三、环境 VS2022,Net7,DrawingVisual,谷歌地图瓦片 四、主要功能 地图缩放,平移,定位 真实经纬度 显示瓦片信息 显示真实经纬度和经纬线 省市县…...
【C++】static作用总结
文章目录 1. 在函数内(局部静态变量)2. 在类中的静态成员变量3. 在类中的静态成员函数4. 在文件/模块中的静态变量或函数总结 1. 在函数内(局部静态变量) 当 static 用于函数内的局部变量时,该变量的生命周期变为整个…...
视频提取字幕的软件有哪些?高效转录用这些
探索视频的奥秘,从字幕开始!你是否曾被繁复的字幕处理困扰,渴望有一款简单好用的在线免费软件来轻松解锁字幕提取? 告别手动输入的烦恼,我们为你精选了6款视频字幕提取在线免费软件,它们不仅能一键转录&am…...
(4)SVG-path中的椭圆弧A(绝对)或a(相对)
1、概念 表示经过起始点(即上一条命令的结束点),到结束点之间画一段椭圆弧 2、7个参数 rx,ry,x-axis-rotation,large-arc-flag,sweep-flag,x,y (1)和(2&a…...
docker国内镜像源报错解决方案
Job for docker.service failed because the control process exited with error code. See "systemctl status docker.service" and "journalctl -xe" for details. 遇到 Job for docker.service failed because the control process exited with error …...
前端倒计时误差!
提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...
【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)
骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...
让AI看见世界:MCP协议与服务器的工作原理
让AI看见世界:MCP协议与服务器的工作原理 MCP(Model Context Protocol)是一种创新的通信协议,旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天,MCP正成为连接AI与现实世界的重要桥梁。…...
SiFli 52把Imagie图片,Font字体资源放在指定位置,编译成指定img.bin和font.bin的问题
分区配置 (ptab.json) img 属性介绍: img 属性指定分区存放的 image 名称,指定的 image 名称必须是当前工程生成的 binary 。 如果 binary 有多个文件,则以 proj_name:binary_name 格式指定文件名, proj_name 为工程 名&…...
Java + Spring Boot + Mybatis 实现批量插入
在 Java 中使用 Spring Boot 和 MyBatis 实现批量插入可以通过以下步骤完成。这里提供两种常用方法:使用 MyBatis 的 <foreach> 标签和批处理模式(ExecutorType.BATCH)。 方法一:使用 XML 的 <foreach> 标签ÿ…...
RSS 2025|从说明书学习复杂机器人操作任务:NUS邵林团队提出全新机器人装配技能学习框架Manual2Skill
视觉语言模型(Vision-Language Models, VLMs),为真实环境中的机器人操作任务提供了极具潜力的解决方案。 尽管 VLMs 取得了显著进展,机器人仍难以胜任复杂的长时程任务(如家具装配),主要受限于人…...
[ACTF2020 新生赛]Include 1(php://filter伪协议)
题目 做法 启动靶机,点进去 点进去 查看URL,有 ?fileflag.php说明存在文件包含,原理是php://filter 协议 当它与包含函数结合时,php://filter流会被当作php文件执行。 用php://filter加编码,能让PHP把文件内容…...
android13 app的触摸问题定位分析流程
一、知识点 一般来说,触摸问题都是app层面出问题,我们可以在ViewRootImpl.java添加log的方式定位;如果是touchableRegion的计算问题,就会相对比较麻烦了,需要通过adb shell dumpsys input > input.log指令,且通过打印堆栈的方式,逐步定位问题,并找到修改方案。 问题…...
Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement
Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement 1. LAB环境2. L2公告策略2.1 部署Death Star2.2 访问服务2.3 部署L2公告策略2.4 服务宣告 3. 可视化 ARP 流量3.1 部署新服务3.2 准备可视化3.3 再次请求 4. 自动IPAM4.1 IPAM Pool4.2 …...
