当前位置: 首页 > news >正文

分类任务实现模型集成代码模版

分类任务实现模型(投票式)集成代码模版

简介

本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化,对每个基学习器的性能进行对比,直观的看出模型集成的作用。

代码

# -*- coding:utf-8 -*-
import os
import torch
import torchvision
import torchmetrics
import torch.nn as nn
import my_utils as utils
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchensemble.utils import set_module
from torchensemble.voting import VotingClassifierclasses = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']def get_args_parser(add_help=True):import argparseparser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)parser.add_argument("--data-path", default=r"E:\Pytorch-Tutorial-2nd\data\datasets\cifar10-office", type=str,help="dataset path")parser.add_argument("--model", default="resnet8", type=str, help="model name")parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")parser.add_argument("-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)")parser.add_argument("--opt", default="SGD", type=str, help="optimizer")parser.add_argument("--random-seed", default=42, type=int, help="random seed")parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")parser.add_argument("--wd","--weight-decay",default=1e-4,type=float,metavar="W",help="weight decay (default: 1e-4)",dest="weight_decay",)parser.add_argument("--lr-step-size", default=80, type=int, help="decrease lr every step-size epochs")parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")parser.add_argument("--print-freq", default=80, type=int, help="print frequency")parser.add_argument("--output-dir", default="./Result", type=str, help="path to save outputs")parser.add_argument("--resume", default="", type=str, help="path of checkpoint")parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")return parserdef main():args = get_args_parser().parse_args()utils.setup_seed(args.random_seed)args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device = args.devicedata_dir = args.data_pathresult_dir = args.output_dir# ------------------------------------  log ------------------------------------logger, log_dir = utils.make_logger(result_dir)writer = SummaryWriter(log_dir=log_dir)# ------------------------------------ step1: dataset ------------------------------------normMean = [0.4948052, 0.48568845, 0.44682974]normStd = [0.24580306, 0.24236229, 0.2603115]normTransform = transforms.Normalize(normMean, normStd)train_transform = transforms.Compose([transforms.Resize(32),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),normTransform])valid_transform = transforms.Compose([transforms.ToTensor(),normTransform])# root变量下需要存放cifar-10-python.tar.gz 文件# cifar-10-python.tar.gz可从 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下载train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=train_transform, download=True)test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=valid_transform, download=True)# 构建DataLodertrain_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)valid_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, num_workers=args.workers)# ------------------------------------ tep2: model ------------------------------------model_base = utils.resnet20()# model_base = utils.LeNet5()model = MyEnsemble(estimator=model_base, n_estimators=3, logger=logger, device=device, args=args,classes=classes, writer=writer, save_dir=log_dir)model.set_optimizer(args.opt, lr=args.lr, weight_decay=args.weight_decay)model.fit(train_loader, test_loader=valid_loader, epochs=args.epochs)class MyEnsemble(VotingClassifier):def __init__(self, **kwargs):# logger, device, args, classes, writersuper(VotingClassifier, self).__init__(kwargs["estimator"], kwargs["n_estimators"])self.logger = kwargs["logger"]self.writer = kwargs["writer"]self.device = kwargs["device"]self.args = kwargs["args"]self.classes = kwargs["classes"]self.save_dir = kwargs["save_dir"]@staticmethoddef save(model, save_dir, logger):"""Implement model serialization to the specified directory."""if save_dir is None:save_dir = "./"if not os.path.isdir(save_dir):os.mkdir(save_dir)# Decide the base estimator nameif isinstance(model.base_estimator_, type):base_estimator_name = model.base_estimator_.__name__else:base_estimator_name = model.base_estimator_.__class__.__name__# {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators}filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,base_estimator_name,model.n_estimators,)# The real number of base estimators in some ensembles is not same as# `n_estimators`.state = {"n_estimators": len(model.estimators_),"model": model.state_dict(),"_criterion": model._criterion,}save_dir = os.path.join(save_dir, filename)logger.info("Saving the model to `{}`".format(save_dir))# Savetorch.save(state, save_dir)returndef fit(self, train_loader, epochs=100, log_interval=100, test_loader=None, save_model=True, save_dir=None, ):# 模型、优化器、学习率调整器、评估器 列表创建estimators = []for _ in range(self.n_estimators):estimators.append(self._make_estimator())optimizers = []schedulers = []for i in range(self.n_estimators):optimizers.append(set_module.set_optimizer(estimators[i],self.optimizer_name, **self.optimizer_args))scheduler_ = torch.optim.lr_scheduler.MultiStepLR(optimizers[i], milestones=[100, 150],gamma=self.args.lr_gamma)  # 设置学习率下降策略# scheduler_ = torch.optim.lr_scheduler.StepLR(optimizers[i], step_size=self.args.lr_step_size,#                                             gamma=self.args.lr_gamma)  # 设置学习率下降策略schedulers.append(scheduler_)acc_metrics = []for i in range(self.n_estimators):# task类型与任务一致# num_classes与分类任务的类别数一致acc_metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=len(self.classes)))self._criterion = nn.CrossEntropyLoss()# epoch循环迭代best_acc = 0.for epoch in range(epochs):# trainingfor model_idx, (estimator, optimizer, scheduler) in enumerate(zip(estimators, optimizers, schedulers)):loss_m_train, acc_m_train, mat_train = \utils.ModelTrainerEnsemble.train_one_epoch(train_loader, estimator, self._criterion, optimizer, scheduler, epoch,self.device, self.args, self.logger, self.classes)# 学习率更新scheduler.step()# 记录self.writer.add_scalars('Loss_group', {'train_loss_{}'.format(model_idx):loss_m_train.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'train_acc_{}'.format(model_idx):acc_m_train.avg}, epoch)self.writer.add_scalar('learning rate', scheduler.get_last_lr()[0], epoch)# 训练混淆矩阵图conf_mat_figure_train = utils.show_conf_mat(mat_train, classes, "train", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_train', conf_mat_figure_train, global_step=epoch)# validateloss_valid_meter, acc_valid, top1_group, mat_valid = \utils.ModelTrainerEnsemble.evaluate(test_loader, estimators, self._criterion, self.device, self.classes)# 日志self.writer.add_scalars('Loss_group', {'valid_loss':loss_valid_meter.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'valid_acc':acc_valid * 100}, epoch)# 验证混淆矩阵图conf_mat_figure_valid = utils.show_conf_mat(mat_valid, classes, "valid", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_valid', conf_mat_figure_valid, global_step=epoch)self.logger.info('Epoch: [{:0>3}/{:0>3}]  ''Train Loss avg: {loss_train:>6.4f}  ''Valid Loss avg: {loss_valid:>6.4f}  ''Train Acc@1 avg:  {top1_train:>7.2f}%   ''Valid Acc@1 avg: {top1_valid:>7.2%}    ''LR: {lr}'.format(epoch, self.args.epochs, loss_train=loss_m_train.avg, loss_valid=loss_valid_meter.avg,top1_train=acc_m_train.avg, top1_valid=acc_valid, lr=schedulers[0].get_last_lr()[0]))for model_idx, top1_meter in enumerate(top1_group):self.writer.add_scalars('Accuracy_group',{'valid_acc_{}'.format(model_idx): top1_meter.compute() * 100}, epoch)if acc_valid > best_acc:best_acc = acc_validself.estimators_ = nn.ModuleList()self.estimators_.extend(estimators)if save_model:self.save(self, self.save_dir, self.logger)if __name__ == "__main__":main()

效果图

本实验采用3个学习器进行投票式集成,因此绘制了7条曲线,其中各学习器在训练和验证各有2条曲线,集成模型的结果通过 valid_acc输出(蓝色),通过下图可发现,集成模型与三个基学习器相比,分类准确率都能提高3-4百分点左右,是非常高的提升了。

image-20240830103703565

image-20240830154555390

image-20240830154619630

参考

7.7 TorchEnsemble 模型集成库 · PyTorch实用教程(第二版) (tingsongyu.github.io)

相关文章:

分类任务实现模型集成代码模版

分类任务实现模型(投票式)集成代码模版 简介 本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化&#xff0…...

从Milvus迁移DashVector

本文档演示如何从Milvus将Collection数据全量导出,并适配迁移至DashVector。方案的主要流程包括: 首先,升级Milvus版本,目前Milvus只有在最新版本(v.2.3.x)中支持全量导出其次,将Milvus Collection的Schema信息和数据…...

彻底改变计算机视觉的 Vision Transformer (ViT) 综合指南(视觉转换器终极指南)

欢迎来到雲闪世界。大家好!对于那些还不认识我的人,我叫 Francois,我是 Meta 的研究科学家。我热衷于解释先进的 AI 概念并使其更容易理解。 今天,让我们深入探讨计算机视觉领域最重要的贡献之一:Vision Transformer&…...

vue3 v-bind=“$attrs“ 的一些理解,透传 Attributes相关说明及事例说明

1、可能小伙伴们经常会在自己的项目中看到v-bind"$attrs"&#xff0c;这个一般是在自定义组件中看到。 比如&#xff1a; <template><BasicModalv-bind"$attrs"register"registerModal":title"getTitle"ok"handleSubm…...

鸿蒙开发基础知识-页面布局【第四篇】

1.类型转换 2.交互点击事件 3.状态管理 4.forEch渲染和右上角图标 测试案例 Stack 层叠布局一个生肖卡 5. 动画展示图片 6. Swiper 轮播组件的基本使用 图片等比显示 aspectRatio&#xff08;&#xff09;...

用CSS实现前端响应式布局

一、响应式布局的重要性 随着移动设备的普及&#xff0c;越来越多的用户通过手机、平板电脑等设备访问网页。如果网页不能适应不同的屏幕尺寸&#xff0c;就会出现布局混乱、内容显示不全等问题&#xff0c;严重影响用户体验。响应式布局可以确保网页在各种设备上都能保持美观…...

【docker】docker启动sqlserver

sqlserver-docker官方地址 # sqlserver不是从docker的中央仓库拉取的&#xff0c;而是从ms的仓库拉取的。 docker pull mcr.microsoft.com/mssql/server:2019-latest# 宿主机即docker程序运行的linux服务器 docker run -d \ --user root \ --name mssql2019 \ -e "ACCEPT…...

Python爬虫01

requests模块 文档 安装 pip/pip3 install requestsresponse.text 和 response.content的区别 1.response.text 等价于 response.content.decode("推测出的编码字符集")response.text 类型&#xff1a;str 编码类型&#xff1a;requests模块自动根据Http头部对…...

关于vue项目启动报错Error: error:0308010C:digital envelope routines::unsupported

周五啦&#xff0c;总结一下这周遇到的个别问题吧&#xff0c;就是关于启动项目的时候其他的东西都准备好了&#xff0c;执行命令后报错Error: error:0308010C:digital envelope routines::unsupported 这里看一下我标注的地方&#xff0c;然后总结一下就不难发现问题所在 查看…...

随笔1:数学建模与数值计算

目录 1.1 矩阵运算 1.2 基本数学函数 1.3 数值求解 数学建模与数值计算 是将实际问题通过数学公式和模型进行描述&#xff0c;并通过计算获得模型解的过程。这是数学建模中最基本也是最重要的环节之一。下面是详细的知识点讲解及相应的MATLAB代码示例。 1.1 矩阵运算 知识点…...

SDN架构详解

目录 1&#xff09;经典的IP网络-分布式网络 2&#xff09;经典网络面临的问题 3&#xff09;SDN起源 4&#xff09;OpenFlow基本概念 5&#xff09;Flow Table简介 6&#xff09;SDN的网络架构 7&#xff09;华为SDN网络架构 8&#xff09;传统网络 vs SDN 9&#xf…...

platform框架

platform框架 注册设备进入总线platform_device_register函数 注册驱动进入总线platform_driver_register函数 注册设备进入总线 platform_device_register函数 int platform_device_register(struct platform_device *pdev) struct platform_device {const char * name; 名…...

零成本搞定静态博客——十分钟安装hugo与主题

文章目录 hugo介绍hugo安装与使用方式一&#xff1a;新建站点自建主题方式二&#xff1a;新建站点使用系统推荐的主题 hugo介绍 通过 Hugo 你可以快速搭建你的静态网站&#xff0c;比如博客系统、文档介绍、公司主页、产品介绍等等。相对于其他静态网站生成器来说&#xff0c;…...

windows C++ 并行编程-转换使用取消的 OpenMP 循环以使用并发运行时

某些并行循环不需要执行所有迭代。 例如&#xff0c;搜索值的算法可以在找到值后终止。 OpenMP 不提供中断并行循环的机制。 但是&#xff0c;可以使用布尔值或标志来启用循环迭代&#xff0c;以指示已找到解决方案。 并发运行时提供允许一个任务取消其他尚未启动的任务的功能。…...

经验笔记:跨站脚本攻击(Cross-Site Scripting,简称XSS)

跨站脚本攻击&#xff08;Cross-Site Scripting&#xff0c;简称XSS&#xff09;经验笔记 跨站脚本攻击&#xff08;XSS&#xff1a;Cross-Site Scripting&#xff09;是一种常见的Web应用程序安全漏洞&#xff0c;它允许攻击者将恶意脚本注入到看起来来自可信网站的网页上。当…...

演示:基于WPF的DrawingVisual和谷歌地图瓦片开发的地图(完全独立不依赖第三方库)

一、目的&#xff1a;基于WPF的DrawingVisual和谷歌地图瓦片开发的地图 二、预览 三、环境 VS2022&#xff0c;Net7,DrawingVisual&#xff0c;谷歌地图瓦片 四、主要功能 地图缩放&#xff0c;平移&#xff0c;定位 真实经纬度 显示瓦片信息 显示真实经纬度和经纬线 省市县…...

【C++】static作用总结

文章目录 1. 在函数内&#xff08;局部静态变量&#xff09;2. 在类中的静态成员变量3. 在类中的静态成员函数4. 在文件/模块中的静态变量或函数总结 1. 在函数内&#xff08;局部静态变量&#xff09; 当 static 用于函数内的局部变量时&#xff0c;该变量的生命周期变为整个…...

视频提取字幕的软件有哪些?高效转录用这些

探索视频的奥秘&#xff0c;从字幕开始&#xff01;你是否曾被繁复的字幕处理困扰&#xff0c;渴望有一款简单好用的在线免费软件来轻松解锁字幕提取&#xff1f; 告别手动输入的烦恼&#xff0c;我们为你精选了6款视频字幕提取在线免费软件&#xff0c;它们不仅能一键转录&am…...

(4)SVG-path中的椭圆弧A(绝对)或a(相对)

1、概念 表示经过起始点(即上一条命令的结束点)&#xff0c;到结束点之间画一段椭圆弧 2、7个参数 rx&#xff0c;ry&#xff0c;x-axis-rotation&#xff0c;large-arc-flag&#xff0c;sweep-flag&#xff0c;x&#xff0c;y &#xff08;1&#xff09;和&#xff08;2&a…...

docker国内镜像源报错解决方案

Job for docker.service failed because the control process exited with error code. See "systemctl status docker.service" and "journalctl -xe" for details. 遇到 Job for docker.service failed because the control process exited with error …...

AI Agent与Agentic AI:原理、应用、挑战与未来展望

文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例&#xff1a;使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例&#xff1a;使用OpenAI GPT-3进…...

IGP(Interior Gateway Protocol,内部网关协议)

IGP&#xff08;Interior Gateway Protocol&#xff0c;内部网关协议&#xff09; 是一种用于在一个自治系统&#xff08;AS&#xff09;内部传递路由信息的路由协议&#xff0c;主要用于在一个组织或机构的内部网络中决定数据包的最佳路径。与用于自治系统之间通信的 EGP&…...

macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用

文章目录 问题现象问题原因解决办法 问题现象 macOS启动台&#xff08;Launchpad&#xff09;多出来了&#xff1a;Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显&#xff0c;都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

在Ubuntu中设置开机自动运行(sudo)指令的指南

在Ubuntu系统中&#xff0c;有时需要在系统启动时自动执行某些命令&#xff0c;特别是需要 sudo权限的指令。为了实现这一功能&#xff0c;可以使用多种方法&#xff0c;包括编写Systemd服务、配置 rc.local文件或使用 cron任务计划。本文将详细介绍这些方法&#xff0c;并提供…...

04-初识css

一、css样式引入 1.1.内部样式 <div style"width: 100px;"></div>1.2.外部样式 1.2.1.外部样式1 <style>.aa {width: 100px;} </style> <div class"aa"></div>1.2.2.外部样式2 <!-- rel内表面引入的是style样…...

大学生职业发展与就业创业指导教学评价

这里是引用 作为软工2203/2204班的学生&#xff0c;我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要&#xff0c;而您认真负责的教学态度&#xff0c;让课程的每一部分都充满了实用价值。 尤其让我…...

LabVIEW双光子成像系统技术

双光子成像技术的核心特性 双光子成像通过双低能量光子协同激发机制&#xff0c;展现出显著的技术优势&#xff1a; 深层组织穿透能力&#xff1a;适用于活体组织深度成像 高分辨率观测性能&#xff1a;满足微观结构的精细研究需求 低光毒性特点&#xff1a;减少对样本的损伤…...

从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障

关键领域软件测试的"安全密码"&#xff1a;Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天&#xff0c;软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力&#xff0c;从金融交易到交通管控&#xff0c;这些关乎国计民生的关键领域…...

nnUNet V2修改网络——暴力替换网络为UNet++

更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题 阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。 U-Net存在两个局限,一是网络的最佳深度因应用场景而异,这取决于任务的难度和可用于训练的标注数…...

云安全与网络安全:核心区别与协同作用解析

在数字化转型的浪潮中&#xff0c;云安全与网络安全作为信息安全的两大支柱&#xff0c;常被混淆但本质不同。本文将从概念、责任分工、技术手段、威胁类型等维度深入解析两者的差异&#xff0c;并探讨它们的协同作用。 一、核心区别 定义与范围 网络安全&#xff1a;聚焦于保…...