当前位置: 首页 > news >正文

分类任务实现模型集成代码模版

分类任务实现模型(投票式)集成代码模版

简介

本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化,对每个基学习器的性能进行对比,直观的看出模型集成的作用。

代码

# -*- coding:utf-8 -*-
import os
import torch
import torchvision
import torchmetrics
import torch.nn as nn
import my_utils as utils
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchensemble.utils import set_module
from torchensemble.voting import VotingClassifierclasses = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']def get_args_parser(add_help=True):import argparseparser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)parser.add_argument("--data-path", default=r"E:\Pytorch-Tutorial-2nd\data\datasets\cifar10-office", type=str,help="dataset path")parser.add_argument("--model", default="resnet8", type=str, help="model name")parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")parser.add_argument("-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)")parser.add_argument("--opt", default="SGD", type=str, help="optimizer")parser.add_argument("--random-seed", default=42, type=int, help="random seed")parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")parser.add_argument("--wd","--weight-decay",default=1e-4,type=float,metavar="W",help="weight decay (default: 1e-4)",dest="weight_decay",)parser.add_argument("--lr-step-size", default=80, type=int, help="decrease lr every step-size epochs")parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")parser.add_argument("--print-freq", default=80, type=int, help="print frequency")parser.add_argument("--output-dir", default="./Result", type=str, help="path to save outputs")parser.add_argument("--resume", default="", type=str, help="path of checkpoint")parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")return parserdef main():args = get_args_parser().parse_args()utils.setup_seed(args.random_seed)args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device = args.devicedata_dir = args.data_pathresult_dir = args.output_dir# ------------------------------------  log ------------------------------------logger, log_dir = utils.make_logger(result_dir)writer = SummaryWriter(log_dir=log_dir)# ------------------------------------ step1: dataset ------------------------------------normMean = [0.4948052, 0.48568845, 0.44682974]normStd = [0.24580306, 0.24236229, 0.2603115]normTransform = transforms.Normalize(normMean, normStd)train_transform = transforms.Compose([transforms.Resize(32),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),normTransform])valid_transform = transforms.Compose([transforms.ToTensor(),normTransform])# root变量下需要存放cifar-10-python.tar.gz 文件# cifar-10-python.tar.gz可从 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下载train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=train_transform, download=True)test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=valid_transform, download=True)# 构建DataLodertrain_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)valid_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, num_workers=args.workers)# ------------------------------------ tep2: model ------------------------------------model_base = utils.resnet20()# model_base = utils.LeNet5()model = MyEnsemble(estimator=model_base, n_estimators=3, logger=logger, device=device, args=args,classes=classes, writer=writer, save_dir=log_dir)model.set_optimizer(args.opt, lr=args.lr, weight_decay=args.weight_decay)model.fit(train_loader, test_loader=valid_loader, epochs=args.epochs)class MyEnsemble(VotingClassifier):def __init__(self, **kwargs):# logger, device, args, classes, writersuper(VotingClassifier, self).__init__(kwargs["estimator"], kwargs["n_estimators"])self.logger = kwargs["logger"]self.writer = kwargs["writer"]self.device = kwargs["device"]self.args = kwargs["args"]self.classes = kwargs["classes"]self.save_dir = kwargs["save_dir"]@staticmethoddef save(model, save_dir, logger):"""Implement model serialization to the specified directory."""if save_dir is None:save_dir = "./"if not os.path.isdir(save_dir):os.mkdir(save_dir)# Decide the base estimator nameif isinstance(model.base_estimator_, type):base_estimator_name = model.base_estimator_.__name__else:base_estimator_name = model.base_estimator_.__class__.__name__# {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators}filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,base_estimator_name,model.n_estimators,)# The real number of base estimators in some ensembles is not same as# `n_estimators`.state = {"n_estimators": len(model.estimators_),"model": model.state_dict(),"_criterion": model._criterion,}save_dir = os.path.join(save_dir, filename)logger.info("Saving the model to `{}`".format(save_dir))# Savetorch.save(state, save_dir)returndef fit(self, train_loader, epochs=100, log_interval=100, test_loader=None, save_model=True, save_dir=None, ):# 模型、优化器、学习率调整器、评估器 列表创建estimators = []for _ in range(self.n_estimators):estimators.append(self._make_estimator())optimizers = []schedulers = []for i in range(self.n_estimators):optimizers.append(set_module.set_optimizer(estimators[i],self.optimizer_name, **self.optimizer_args))scheduler_ = torch.optim.lr_scheduler.MultiStepLR(optimizers[i], milestones=[100, 150],gamma=self.args.lr_gamma)  # 设置学习率下降策略# scheduler_ = torch.optim.lr_scheduler.StepLR(optimizers[i], step_size=self.args.lr_step_size,#                                             gamma=self.args.lr_gamma)  # 设置学习率下降策略schedulers.append(scheduler_)acc_metrics = []for i in range(self.n_estimators):# task类型与任务一致# num_classes与分类任务的类别数一致acc_metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=len(self.classes)))self._criterion = nn.CrossEntropyLoss()# epoch循环迭代best_acc = 0.for epoch in range(epochs):# trainingfor model_idx, (estimator, optimizer, scheduler) in enumerate(zip(estimators, optimizers, schedulers)):loss_m_train, acc_m_train, mat_train = \utils.ModelTrainerEnsemble.train_one_epoch(train_loader, estimator, self._criterion, optimizer, scheduler, epoch,self.device, self.args, self.logger, self.classes)# 学习率更新scheduler.step()# 记录self.writer.add_scalars('Loss_group', {'train_loss_{}'.format(model_idx):loss_m_train.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'train_acc_{}'.format(model_idx):acc_m_train.avg}, epoch)self.writer.add_scalar('learning rate', scheduler.get_last_lr()[0], epoch)# 训练混淆矩阵图conf_mat_figure_train = utils.show_conf_mat(mat_train, classes, "train", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_train', conf_mat_figure_train, global_step=epoch)# validateloss_valid_meter, acc_valid, top1_group, mat_valid = \utils.ModelTrainerEnsemble.evaluate(test_loader, estimators, self._criterion, self.device, self.classes)# 日志self.writer.add_scalars('Loss_group', {'valid_loss':loss_valid_meter.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'valid_acc':acc_valid * 100}, epoch)# 验证混淆矩阵图conf_mat_figure_valid = utils.show_conf_mat(mat_valid, classes, "valid", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_valid', conf_mat_figure_valid, global_step=epoch)self.logger.info('Epoch: [{:0>3}/{:0>3}]  ''Train Loss avg: {loss_train:>6.4f}  ''Valid Loss avg: {loss_valid:>6.4f}  ''Train Acc@1 avg:  {top1_train:>7.2f}%   ''Valid Acc@1 avg: {top1_valid:>7.2%}    ''LR: {lr}'.format(epoch, self.args.epochs, loss_train=loss_m_train.avg, loss_valid=loss_valid_meter.avg,top1_train=acc_m_train.avg, top1_valid=acc_valid, lr=schedulers[0].get_last_lr()[0]))for model_idx, top1_meter in enumerate(top1_group):self.writer.add_scalars('Accuracy_group',{'valid_acc_{}'.format(model_idx): top1_meter.compute() * 100}, epoch)if acc_valid > best_acc:best_acc = acc_validself.estimators_ = nn.ModuleList()self.estimators_.extend(estimators)if save_model:self.save(self, self.save_dir, self.logger)if __name__ == "__main__":main()

效果图

本实验采用3个学习器进行投票式集成,因此绘制了7条曲线,其中各学习器在训练和验证各有2条曲线,集成模型的结果通过 valid_acc输出(蓝色),通过下图可发现,集成模型与三个基学习器相比,分类准确率都能提高3-4百分点左右,是非常高的提升了。

image-20240830103703565

image-20240830154555390

image-20240830154619630

参考

7.7 TorchEnsemble 模型集成库 · PyTorch实用教程(第二版) (tingsongyu.github.io)

相关文章:

分类任务实现模型集成代码模版

分类任务实现模型(投票式)集成代码模版 简介 本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化&#xff0…...

从Milvus迁移DashVector

本文档演示如何从Milvus将Collection数据全量导出,并适配迁移至DashVector。方案的主要流程包括: 首先,升级Milvus版本,目前Milvus只有在最新版本(v.2.3.x)中支持全量导出其次,将Milvus Collection的Schema信息和数据…...

彻底改变计算机视觉的 Vision Transformer (ViT) 综合指南(视觉转换器终极指南)

欢迎来到雲闪世界。大家好!对于那些还不认识我的人,我叫 Francois,我是 Meta 的研究科学家。我热衷于解释先进的 AI 概念并使其更容易理解。 今天,让我们深入探讨计算机视觉领域最重要的贡献之一:Vision Transformer&…...

vue3 v-bind=“$attrs“ 的一些理解,透传 Attributes相关说明及事例说明

1、可能小伙伴们经常会在自己的项目中看到v-bind"$attrs"&#xff0c;这个一般是在自定义组件中看到。 比如&#xff1a; <template><BasicModalv-bind"$attrs"register"registerModal":title"getTitle"ok"handleSubm…...

鸿蒙开发基础知识-页面布局【第四篇】

1.类型转换 2.交互点击事件 3.状态管理 4.forEch渲染和右上角图标 测试案例 Stack 层叠布局一个生肖卡 5. 动画展示图片 6. Swiper 轮播组件的基本使用 图片等比显示 aspectRatio&#xff08;&#xff09;...

用CSS实现前端响应式布局

一、响应式布局的重要性 随着移动设备的普及&#xff0c;越来越多的用户通过手机、平板电脑等设备访问网页。如果网页不能适应不同的屏幕尺寸&#xff0c;就会出现布局混乱、内容显示不全等问题&#xff0c;严重影响用户体验。响应式布局可以确保网页在各种设备上都能保持美观…...

【docker】docker启动sqlserver

sqlserver-docker官方地址 # sqlserver不是从docker的中央仓库拉取的&#xff0c;而是从ms的仓库拉取的。 docker pull mcr.microsoft.com/mssql/server:2019-latest# 宿主机即docker程序运行的linux服务器 docker run -d \ --user root \ --name mssql2019 \ -e "ACCEPT…...

Python爬虫01

requests模块 文档 安装 pip/pip3 install requestsresponse.text 和 response.content的区别 1.response.text 等价于 response.content.decode("推测出的编码字符集")response.text 类型&#xff1a;str 编码类型&#xff1a;requests模块自动根据Http头部对…...

关于vue项目启动报错Error: error:0308010C:digital envelope routines::unsupported

周五啦&#xff0c;总结一下这周遇到的个别问题吧&#xff0c;就是关于启动项目的时候其他的东西都准备好了&#xff0c;执行命令后报错Error: error:0308010C:digital envelope routines::unsupported 这里看一下我标注的地方&#xff0c;然后总结一下就不难发现问题所在 查看…...

随笔1:数学建模与数值计算

目录 1.1 矩阵运算 1.2 基本数学函数 1.3 数值求解 数学建模与数值计算 是将实际问题通过数学公式和模型进行描述&#xff0c;并通过计算获得模型解的过程。这是数学建模中最基本也是最重要的环节之一。下面是详细的知识点讲解及相应的MATLAB代码示例。 1.1 矩阵运算 知识点…...

SDN架构详解

目录 1&#xff09;经典的IP网络-分布式网络 2&#xff09;经典网络面临的问题 3&#xff09;SDN起源 4&#xff09;OpenFlow基本概念 5&#xff09;Flow Table简介 6&#xff09;SDN的网络架构 7&#xff09;华为SDN网络架构 8&#xff09;传统网络 vs SDN 9&#xf…...

platform框架

platform框架 注册设备进入总线platform_device_register函数 注册驱动进入总线platform_driver_register函数 注册设备进入总线 platform_device_register函数 int platform_device_register(struct platform_device *pdev) struct platform_device {const char * name; 名…...

零成本搞定静态博客——十分钟安装hugo与主题

文章目录 hugo介绍hugo安装与使用方式一&#xff1a;新建站点自建主题方式二&#xff1a;新建站点使用系统推荐的主题 hugo介绍 通过 Hugo 你可以快速搭建你的静态网站&#xff0c;比如博客系统、文档介绍、公司主页、产品介绍等等。相对于其他静态网站生成器来说&#xff0c;…...

windows C++ 并行编程-转换使用取消的 OpenMP 循环以使用并发运行时

某些并行循环不需要执行所有迭代。 例如&#xff0c;搜索值的算法可以在找到值后终止。 OpenMP 不提供中断并行循环的机制。 但是&#xff0c;可以使用布尔值或标志来启用循环迭代&#xff0c;以指示已找到解决方案。 并发运行时提供允许一个任务取消其他尚未启动的任务的功能。…...

经验笔记:跨站脚本攻击(Cross-Site Scripting,简称XSS)

跨站脚本攻击&#xff08;Cross-Site Scripting&#xff0c;简称XSS&#xff09;经验笔记 跨站脚本攻击&#xff08;XSS&#xff1a;Cross-Site Scripting&#xff09;是一种常见的Web应用程序安全漏洞&#xff0c;它允许攻击者将恶意脚本注入到看起来来自可信网站的网页上。当…...

演示:基于WPF的DrawingVisual和谷歌地图瓦片开发的地图(完全独立不依赖第三方库)

一、目的&#xff1a;基于WPF的DrawingVisual和谷歌地图瓦片开发的地图 二、预览 三、环境 VS2022&#xff0c;Net7,DrawingVisual&#xff0c;谷歌地图瓦片 四、主要功能 地图缩放&#xff0c;平移&#xff0c;定位 真实经纬度 显示瓦片信息 显示真实经纬度和经纬线 省市县…...

【C++】static作用总结

文章目录 1. 在函数内&#xff08;局部静态变量&#xff09;2. 在类中的静态成员变量3. 在类中的静态成员函数4. 在文件/模块中的静态变量或函数总结 1. 在函数内&#xff08;局部静态变量&#xff09; 当 static 用于函数内的局部变量时&#xff0c;该变量的生命周期变为整个…...

视频提取字幕的软件有哪些?高效转录用这些

探索视频的奥秘&#xff0c;从字幕开始&#xff01;你是否曾被繁复的字幕处理困扰&#xff0c;渴望有一款简单好用的在线免费软件来轻松解锁字幕提取&#xff1f; 告别手动输入的烦恼&#xff0c;我们为你精选了6款视频字幕提取在线免费软件&#xff0c;它们不仅能一键转录&am…...

(4)SVG-path中的椭圆弧A(绝对)或a(相对)

1、概念 表示经过起始点(即上一条命令的结束点)&#xff0c;到结束点之间画一段椭圆弧 2、7个参数 rx&#xff0c;ry&#xff0c;x-axis-rotation&#xff0c;large-arc-flag&#xff0c;sweep-flag&#xff0c;x&#xff0c;y &#xff08;1&#xff09;和&#xff08;2&a…...

docker国内镜像源报错解决方案

Job for docker.service failed because the control process exited with error code. See "systemctl status docker.service" and "journalctl -xe" for details. 遇到 Job for docker.service failed because the control process exited with error …...

如何通过SEO总监的工作经验提升个人价值

SEO总监的工作经验&#xff1a;如何提升个人价值 在当今数字化时代&#xff0c;SEO&#xff08;搜索引擎优化&#xff09;已经成为各行各业不可或缺的一部分。作为一名SEO总监&#xff0c;你不仅要了解如何提升企业网站的搜索排名&#xff0c;更要通过自己的工作经验提升个人价…...

Hunyuan-MT 7B效果实测:韩语、俄语、藏语等小语种翻译到底有多准?

Hunyuan-MT 7B效果实测&#xff1a;韩语、俄语、藏语等小语种翻译到底有多准&#xff1f; 1. 小语种翻译的痛点与解决方案 在全球化交流日益频繁的今天&#xff0c;小语种翻译需求快速增长&#xff0c;但传统解决方案往往存在三大痛点&#xff1a; 准确率低&#xff1a;韩语…...

模块化多电平变换器MMC的NLM与CPS-PWM调制策略仿真实现:交流3000V-直流5000...

模块化多电平变换器MMC两种调制策略实现&#xff08;交流3000V-直流5000V整流&#xff09;仿真&#xff0c;单桥臂二十子模块&#xff0c;分别采用最近电平逼近NLM与载波移相调制CPS-PWM实现&#xff0c;仿真中使用环流抑制&#xff0c;NLM中采用快速排序&#xff0c;两个仿真动…...

测试、项目管理、软件度量和质量

欢迎来到我的软考中级——软件设计师备考合集。这里不只是一份简单的知识点堆砌&#xff0c;而是我在备考征途中&#xff0c;对庞杂知识体系进行深度梳理与内化的结晶。 面对浩瀚的考纲&#xff0c;从计算机组成原理的底层逻辑&#xff0c;到操作系统的进程调度&#xff1b;从数…...

参数党VS体验派?雅马哈、卡西欧、费森4款热门电钢琴型号终极对决,结果有点意外!

你是否也有这样的时刻&#xff1f;练习时间在不断累积&#xff0c;指法日渐熟练&#xff0c;可弹奏出的声音却依然显得机械、平淡&#xff0c;甚至有点“假”。那种在琴行试弹顶级三角钢琴时&#xff0c;指尖与琴键、琴弦与空气共鸣所带来的微妙震颤与心灵悸动&#xff0c;在自…...

计算机网络:从基础到未来趋势,从0死磕全栈之Next.js 中间件(Middleware)详解与实战。

计算机网络基础概念 计算机网络是通过通信链路和交换设备将地理上分散的计算机系统连接起来&#xff0c;实现资源共享和信息传递的系统。其核心目标是提供高效、可靠的数据传输服务。 网络拓扑结构包括星型、总线型、环型和网状等。每种拓扑结构在性能、可靠性和成本上各有优劣…...

终极指南:如何利用HTTPS-PORTAL与Docker Gen实现自动HTTPS配置的魔法

终极指南&#xff1a;如何利用HTTPS-PORTAL与Docker Gen实现自动HTTPS配置的魔法 【免费下载链接】https-portal A fully automated HTTPS server powered by Nginx, Lets Encrypt and Docker. 项目地址: https://gitcode.com/gh_mirrors/ht/https-portal HTTPS-PORTAL是…...

OpenClaw会议小秘书:Qwen3.5-9B自动生成待办事项

OpenClaw会议小秘书&#xff1a;Qwen3.5-9B自动生成待办事项 1. 为什么需要会议自动化助手 每周三下午的组会结束后&#xff0c;我的记事本上总是密密麻麻写满了待办事项。但问题在于——这些潦草的手写笔记有30%的概率会丢失&#xff0c;50%的概率会忘记执行截止时间。直到上…...

别再傻傻分不清!ESP32-S3上USB CDC、UART0和板载CH340到底谁在干活?

ESP32-S3串口全解析&#xff1a;快速识别USB CDC、UART0与CH340的实战指南 刚拿到ESP32-S3开发板时&#xff0c;很多开发者都会遇到一个令人困惑的场景——连接电脑后&#xff0c;设备管理器里突然冒出三四个COM端口&#xff0c;而Arduino IDE的下拉菜单里也列出一堆选项。到底…...

嵌入式开发自动化实践与效率提升

1. 嵌入式开发中的重复工作困境作为一名在嵌入式领域摸爬滚打多年的工程师&#xff0c;我深知这个行业的痛点——那些看似简单却消耗大量精力的重复性工作。从版本构建到代码移植&#xff0c;从环境配置到测试验证&#xff0c;这些工作就像影子一样伴随着每个开发者的日常。刚入…...