当前位置: 首页 > news >正文

分类任务实现模型集成代码模版

分类任务实现模型(投票式)集成代码模版

简介

本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化,对每个基学习器的性能进行对比,直观的看出模型集成的作用。

代码

# -*- coding:utf-8 -*-
import os
import torch
import torchvision
import torchmetrics
import torch.nn as nn
import my_utils as utils
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchensemble.utils import set_module
from torchensemble.voting import VotingClassifierclasses = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']def get_args_parser(add_help=True):import argparseparser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)parser.add_argument("--data-path", default=r"E:\Pytorch-Tutorial-2nd\data\datasets\cifar10-office", type=str,help="dataset path")parser.add_argument("--model", default="resnet8", type=str, help="model name")parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")parser.add_argument("-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)")parser.add_argument("--opt", default="SGD", type=str, help="optimizer")parser.add_argument("--random-seed", default=42, type=int, help="random seed")parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")parser.add_argument("--wd","--weight-decay",default=1e-4,type=float,metavar="W",help="weight decay (default: 1e-4)",dest="weight_decay",)parser.add_argument("--lr-step-size", default=80, type=int, help="decrease lr every step-size epochs")parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")parser.add_argument("--print-freq", default=80, type=int, help="print frequency")parser.add_argument("--output-dir", default="./Result", type=str, help="path to save outputs")parser.add_argument("--resume", default="", type=str, help="path of checkpoint")parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")return parserdef main():args = get_args_parser().parse_args()utils.setup_seed(args.random_seed)args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device = args.devicedata_dir = args.data_pathresult_dir = args.output_dir# ------------------------------------  log ------------------------------------logger, log_dir = utils.make_logger(result_dir)writer = SummaryWriter(log_dir=log_dir)# ------------------------------------ step1: dataset ------------------------------------normMean = [0.4948052, 0.48568845, 0.44682974]normStd = [0.24580306, 0.24236229, 0.2603115]normTransform = transforms.Normalize(normMean, normStd)train_transform = transforms.Compose([transforms.Resize(32),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),normTransform])valid_transform = transforms.Compose([transforms.ToTensor(),normTransform])# root变量下需要存放cifar-10-python.tar.gz 文件# cifar-10-python.tar.gz可从 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下载train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=train_transform, download=True)test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=valid_transform, download=True)# 构建DataLodertrain_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)valid_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, num_workers=args.workers)# ------------------------------------ tep2: model ------------------------------------model_base = utils.resnet20()# model_base = utils.LeNet5()model = MyEnsemble(estimator=model_base, n_estimators=3, logger=logger, device=device, args=args,classes=classes, writer=writer, save_dir=log_dir)model.set_optimizer(args.opt, lr=args.lr, weight_decay=args.weight_decay)model.fit(train_loader, test_loader=valid_loader, epochs=args.epochs)class MyEnsemble(VotingClassifier):def __init__(self, **kwargs):# logger, device, args, classes, writersuper(VotingClassifier, self).__init__(kwargs["estimator"], kwargs["n_estimators"])self.logger = kwargs["logger"]self.writer = kwargs["writer"]self.device = kwargs["device"]self.args = kwargs["args"]self.classes = kwargs["classes"]self.save_dir = kwargs["save_dir"]@staticmethoddef save(model, save_dir, logger):"""Implement model serialization to the specified directory."""if save_dir is None:save_dir = "./"if not os.path.isdir(save_dir):os.mkdir(save_dir)# Decide the base estimator nameif isinstance(model.base_estimator_, type):base_estimator_name = model.base_estimator_.__name__else:base_estimator_name = model.base_estimator_.__class__.__name__# {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators}filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,base_estimator_name,model.n_estimators,)# The real number of base estimators in some ensembles is not same as# `n_estimators`.state = {"n_estimators": len(model.estimators_),"model": model.state_dict(),"_criterion": model._criterion,}save_dir = os.path.join(save_dir, filename)logger.info("Saving the model to `{}`".format(save_dir))# Savetorch.save(state, save_dir)returndef fit(self, train_loader, epochs=100, log_interval=100, test_loader=None, save_model=True, save_dir=None, ):# 模型、优化器、学习率调整器、评估器 列表创建estimators = []for _ in range(self.n_estimators):estimators.append(self._make_estimator())optimizers = []schedulers = []for i in range(self.n_estimators):optimizers.append(set_module.set_optimizer(estimators[i],self.optimizer_name, **self.optimizer_args))scheduler_ = torch.optim.lr_scheduler.MultiStepLR(optimizers[i], milestones=[100, 150],gamma=self.args.lr_gamma)  # 设置学习率下降策略# scheduler_ = torch.optim.lr_scheduler.StepLR(optimizers[i], step_size=self.args.lr_step_size,#                                             gamma=self.args.lr_gamma)  # 设置学习率下降策略schedulers.append(scheduler_)acc_metrics = []for i in range(self.n_estimators):# task类型与任务一致# num_classes与分类任务的类别数一致acc_metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=len(self.classes)))self._criterion = nn.CrossEntropyLoss()# epoch循环迭代best_acc = 0.for epoch in range(epochs):# trainingfor model_idx, (estimator, optimizer, scheduler) in enumerate(zip(estimators, optimizers, schedulers)):loss_m_train, acc_m_train, mat_train = \utils.ModelTrainerEnsemble.train_one_epoch(train_loader, estimator, self._criterion, optimizer, scheduler, epoch,self.device, self.args, self.logger, self.classes)# 学习率更新scheduler.step()# 记录self.writer.add_scalars('Loss_group', {'train_loss_{}'.format(model_idx):loss_m_train.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'train_acc_{}'.format(model_idx):acc_m_train.avg}, epoch)self.writer.add_scalar('learning rate', scheduler.get_last_lr()[0], epoch)# 训练混淆矩阵图conf_mat_figure_train = utils.show_conf_mat(mat_train, classes, "train", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_train', conf_mat_figure_train, global_step=epoch)# validateloss_valid_meter, acc_valid, top1_group, mat_valid = \utils.ModelTrainerEnsemble.evaluate(test_loader, estimators, self._criterion, self.device, self.classes)# 日志self.writer.add_scalars('Loss_group', {'valid_loss':loss_valid_meter.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'valid_acc':acc_valid * 100}, epoch)# 验证混淆矩阵图conf_mat_figure_valid = utils.show_conf_mat(mat_valid, classes, "valid", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_valid', conf_mat_figure_valid, global_step=epoch)self.logger.info('Epoch: [{:0>3}/{:0>3}]  ''Train Loss avg: {loss_train:>6.4f}  ''Valid Loss avg: {loss_valid:>6.4f}  ''Train Acc@1 avg:  {top1_train:>7.2f}%   ''Valid Acc@1 avg: {top1_valid:>7.2%}    ''LR: {lr}'.format(epoch, self.args.epochs, loss_train=loss_m_train.avg, loss_valid=loss_valid_meter.avg,top1_train=acc_m_train.avg, top1_valid=acc_valid, lr=schedulers[0].get_last_lr()[0]))for model_idx, top1_meter in enumerate(top1_group):self.writer.add_scalars('Accuracy_group',{'valid_acc_{}'.format(model_idx): top1_meter.compute() * 100}, epoch)if acc_valid > best_acc:best_acc = acc_validself.estimators_ = nn.ModuleList()self.estimators_.extend(estimators)if save_model:self.save(self, self.save_dir, self.logger)if __name__ == "__main__":main()

效果图

本实验采用3个学习器进行投票式集成,因此绘制了7条曲线,其中各学习器在训练和验证各有2条曲线,集成模型的结果通过 valid_acc输出(蓝色),通过下图可发现,集成模型与三个基学习器相比,分类准确率都能提高3-4百分点左右,是非常高的提升了。

image-20240830103703565

image-20240830154555390

image-20240830154619630

参考

7.7 TorchEnsemble 模型集成库 · PyTorch实用教程(第二版) (tingsongyu.github.io)

相关文章:

分类任务实现模型集成代码模版

分类任务实现模型(投票式)集成代码模版 简介 本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化&#xff0…...

从Milvus迁移DashVector

本文档演示如何从Milvus将Collection数据全量导出,并适配迁移至DashVector。方案的主要流程包括: 首先,升级Milvus版本,目前Milvus只有在最新版本(v.2.3.x)中支持全量导出其次,将Milvus Collection的Schema信息和数据…...

彻底改变计算机视觉的 Vision Transformer (ViT) 综合指南(视觉转换器终极指南)

欢迎来到雲闪世界。大家好!对于那些还不认识我的人,我叫 Francois,我是 Meta 的研究科学家。我热衷于解释先进的 AI 概念并使其更容易理解。 今天,让我们深入探讨计算机视觉领域最重要的贡献之一:Vision Transformer&…...

vue3 v-bind=“$attrs“ 的一些理解,透传 Attributes相关说明及事例说明

1、可能小伙伴们经常会在自己的项目中看到v-bind"$attrs"&#xff0c;这个一般是在自定义组件中看到。 比如&#xff1a; <template><BasicModalv-bind"$attrs"register"registerModal":title"getTitle"ok"handleSubm…...

鸿蒙开发基础知识-页面布局【第四篇】

1.类型转换 2.交互点击事件 3.状态管理 4.forEch渲染和右上角图标 测试案例 Stack 层叠布局一个生肖卡 5. 动画展示图片 6. Swiper 轮播组件的基本使用 图片等比显示 aspectRatio&#xff08;&#xff09;...

用CSS实现前端响应式布局

一、响应式布局的重要性 随着移动设备的普及&#xff0c;越来越多的用户通过手机、平板电脑等设备访问网页。如果网页不能适应不同的屏幕尺寸&#xff0c;就会出现布局混乱、内容显示不全等问题&#xff0c;严重影响用户体验。响应式布局可以确保网页在各种设备上都能保持美观…...

【docker】docker启动sqlserver

sqlserver-docker官方地址 # sqlserver不是从docker的中央仓库拉取的&#xff0c;而是从ms的仓库拉取的。 docker pull mcr.microsoft.com/mssql/server:2019-latest# 宿主机即docker程序运行的linux服务器 docker run -d \ --user root \ --name mssql2019 \ -e "ACCEPT…...

Python爬虫01

requests模块 文档 安装 pip/pip3 install requestsresponse.text 和 response.content的区别 1.response.text 等价于 response.content.decode("推测出的编码字符集")response.text 类型&#xff1a;str 编码类型&#xff1a;requests模块自动根据Http头部对…...

关于vue项目启动报错Error: error:0308010C:digital envelope routines::unsupported

周五啦&#xff0c;总结一下这周遇到的个别问题吧&#xff0c;就是关于启动项目的时候其他的东西都准备好了&#xff0c;执行命令后报错Error: error:0308010C:digital envelope routines::unsupported 这里看一下我标注的地方&#xff0c;然后总结一下就不难发现问题所在 查看…...

随笔1:数学建模与数值计算

目录 1.1 矩阵运算 1.2 基本数学函数 1.3 数值求解 数学建模与数值计算 是将实际问题通过数学公式和模型进行描述&#xff0c;并通过计算获得模型解的过程。这是数学建模中最基本也是最重要的环节之一。下面是详细的知识点讲解及相应的MATLAB代码示例。 1.1 矩阵运算 知识点…...

SDN架构详解

目录 1&#xff09;经典的IP网络-分布式网络 2&#xff09;经典网络面临的问题 3&#xff09;SDN起源 4&#xff09;OpenFlow基本概念 5&#xff09;Flow Table简介 6&#xff09;SDN的网络架构 7&#xff09;华为SDN网络架构 8&#xff09;传统网络 vs SDN 9&#xf…...

platform框架

platform框架 注册设备进入总线platform_device_register函数 注册驱动进入总线platform_driver_register函数 注册设备进入总线 platform_device_register函数 int platform_device_register(struct platform_device *pdev) struct platform_device {const char * name; 名…...

零成本搞定静态博客——十分钟安装hugo与主题

文章目录 hugo介绍hugo安装与使用方式一&#xff1a;新建站点自建主题方式二&#xff1a;新建站点使用系统推荐的主题 hugo介绍 通过 Hugo 你可以快速搭建你的静态网站&#xff0c;比如博客系统、文档介绍、公司主页、产品介绍等等。相对于其他静态网站生成器来说&#xff0c;…...

windows C++ 并行编程-转换使用取消的 OpenMP 循环以使用并发运行时

某些并行循环不需要执行所有迭代。 例如&#xff0c;搜索值的算法可以在找到值后终止。 OpenMP 不提供中断并行循环的机制。 但是&#xff0c;可以使用布尔值或标志来启用循环迭代&#xff0c;以指示已找到解决方案。 并发运行时提供允许一个任务取消其他尚未启动的任务的功能。…...

经验笔记:跨站脚本攻击(Cross-Site Scripting,简称XSS)

跨站脚本攻击&#xff08;Cross-Site Scripting&#xff0c;简称XSS&#xff09;经验笔记 跨站脚本攻击&#xff08;XSS&#xff1a;Cross-Site Scripting&#xff09;是一种常见的Web应用程序安全漏洞&#xff0c;它允许攻击者将恶意脚本注入到看起来来自可信网站的网页上。当…...

演示:基于WPF的DrawingVisual和谷歌地图瓦片开发的地图(完全独立不依赖第三方库)

一、目的&#xff1a;基于WPF的DrawingVisual和谷歌地图瓦片开发的地图 二、预览 三、环境 VS2022&#xff0c;Net7,DrawingVisual&#xff0c;谷歌地图瓦片 四、主要功能 地图缩放&#xff0c;平移&#xff0c;定位 真实经纬度 显示瓦片信息 显示真实经纬度和经纬线 省市县…...

【C++】static作用总结

文章目录 1. 在函数内&#xff08;局部静态变量&#xff09;2. 在类中的静态成员变量3. 在类中的静态成员函数4. 在文件/模块中的静态变量或函数总结 1. 在函数内&#xff08;局部静态变量&#xff09; 当 static 用于函数内的局部变量时&#xff0c;该变量的生命周期变为整个…...

视频提取字幕的软件有哪些?高效转录用这些

探索视频的奥秘&#xff0c;从字幕开始&#xff01;你是否曾被繁复的字幕处理困扰&#xff0c;渴望有一款简单好用的在线免费软件来轻松解锁字幕提取&#xff1f; 告别手动输入的烦恼&#xff0c;我们为你精选了6款视频字幕提取在线免费软件&#xff0c;它们不仅能一键转录&am…...

(4)SVG-path中的椭圆弧A(绝对)或a(相对)

1、概念 表示经过起始点(即上一条命令的结束点)&#xff0c;到结束点之间画一段椭圆弧 2、7个参数 rx&#xff0c;ry&#xff0c;x-axis-rotation&#xff0c;large-arc-flag&#xff0c;sweep-flag&#xff0c;x&#xff0c;y &#xff08;1&#xff09;和&#xff08;2&a…...

docker国内镜像源报错解决方案

Job for docker.service failed because the control process exited with error code. See "systemctl status docker.service" and "journalctl -xe" for details. 遇到 Job for docker.service failed because the control process exited with error …...

19c补丁后oracle属主变化,导致不能识别磁盘组

补丁后服务器重启&#xff0c;数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后&#xff0c;存在与用户组权限相关的问题。具体表现为&#xff0c;Oracle 实例的运行用户&#xff08;oracle&#xff09;和集…...

基于大模型的 UI 自动化系统

基于大模型的 UI 自动化系统 下面是一个完整的 Python 系统,利用大模型实现智能 UI 自动化,结合计算机视觉和自然语言处理技术,实现"看屏操作"的能力。 系统架构设计 #mermaid-svg-2gn2GRvh5WCP2ktF {font-family:"trebuchet ms",verdana,arial,sans-…...

论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)

HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

Prompt Tuning、P-Tuning、Prefix Tuning的区别

一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...

Spark 之 入门讲解详细版(1)

1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室&#xff08;Algorithms, Machines, and People Lab&#xff09;开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目&#xff0c;8个月后成为Apache顶级项目&#xff0c;速度之快足见过人之处&…...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 &#xff08;一&#xff09;实时滤波与参数调整 基础滤波操作 60Hz 工频滤波&#xff1a;勾选界面右侧 “60Hz” 复选框&#xff0c;可有效抑制电网干扰&#xff08;适用于北美地区&#xff0c;欧洲用户可调整为 50Hz&#xff09;。 平滑处理&…...

线程与协程

1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指&#xff1a;像函数调用/返回一样轻量地完成任务切换。 举例说明&#xff1a; 当你在程序中写一个函数调用&#xff1a; funcA() 然后 funcA 执行完后返回&…...

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话&#xff1a; “利润不是赚出来的&#xff0c;是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业&#xff0c;很多企业看着销售不错&#xff0c;账上却没钱、利润也不见了&#xff0c;一翻库存才发现&#xff1a; 一堆卖不动的旧货…...

华为OD机试-食堂供餐-二分法

import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...