当前位置: 首页 > news >正文

YOLOv8添加MobileViTv3模块(代码+free)

目录

一、理由

二、方法

(1)导入MobileViTv3模块

(2)在ultralytics/nn/tasks.py的函数parse_model中修改

(3)在yaml配置文件中写入

(4)开始训练,先把其他梯度关闭,保留新加的模块的梯度。

代码已在GitHub上传,链接:yolov8_vit


一、理由

        MobileViTv3是一种为移动设备优化的轻量级视觉Transformer架构,它结合了卷积神经网络(CNN)和视觉Transformer(ViT)的特点,以创建适合移动视觉任务的轻量级模型。

二、方法

(1)导入MobileViTv3模块

在ultralytics/nn创建vit文件夹,文件夹内放MobileViTv3以及需要的包。MobileViTv3模块如下:

import numpy as np
from torch import nn, Tensor
import math
import torch
from torch.nn import functional as F
from typing import Optional, Dict, Tuple, Union, Sequence
from mobilevit_v2_block import MobileViTBlockv2 as MbViTBkV2class MbViTV3(MbViTBkV2):def __init__(self,in_channels: int,attn_unit_dim: int,patch_h: Optional[int] = 2,patch_w: Optional[int] = 2,ffn_multiplier: Optional[Union[Sequence[Union[int, float]], int, float]] = 2.0,n_attn_blocks: Optional[int] = 2,attn_dropout: Optional[float] = 0.0,dropout: Optional[float] = 0.0,ffn_dropout: Optional[float] = 0.0,conv_ksize: Optional[int] = 3,attn_norm_layer: Optional[str] = "layer_norm_2d",enable_coreml_compatible_fn: Optional[bool] = False,) -> None:super(MbViTV3, self).__init__(in_channels, attn_unit_dim)self.enable_coreml_compatible_fn = enable_coreml_compatible_fnif self.enable_coreml_compatible_fn:# we set persistent to false so that these weights are not part of model's state_dictself.register_buffer(name="unfolding_weights",tensor=self._compute_unfolding_weights(),persistent=False,)cnn_out_dim = attn_unit_dimself.conv_proj = nn.Conv2d(2 * cnn_out_dim, in_channels, 1, 1)def forward_spatial(self, x: Tensor, *args, **kwargs) -> Tensor:x = self.resize_input_if_needed(x)fm_conv = self.local_rep(x)# convert feature map to patchesif self.enable_coreml_compatible_fn:patches, output_size = self.unfolding_coreml(fm_conv)else:patches, output_size = self.unfolding_pytorch(fm_conv)# learn global representations on all patchespatches = self.global_rep(patches)# [B x Patch x Patches x C] --> [B x C x Patches x Patch]if self.enable_coreml_compatible_fn:fm = self.folding_coreml(patches=patches, output_size=output_size)else:fm = self.folding_pytorch(patches=patches, output_size=output_size)# MobileViTv3: local+global instead of only globalfm = self.conv_proj(torch.cat((fm, fm_conv), dim=1))# MobileViTv3: skip connectionfm = fm + xreturn fmif __name__ == '__main__':from thop import profile  ## 导入thop模块model = MbViTV3(320, 160, enable_coreml_compatible_fn=False)input = torch.randn(1, 320, 44, 84)#flops, params = profile(model, inputs=(input,))outpus = model.forward_spatial(input)print('flops')  ## 打印计算量# print('params', params)  ## 打印参数量

(2)在ultralytics/nn/tasks.py的函数parse_model中修改

def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)# Parse a YOLO model.yaml dictionaryif verbose:LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')if act:Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()if verbose:LOGGER.info(f"{colorstr('activation:')} {act}")  # printlayers, save, c2 = [], [], ch[-1]  # layers, savelist, ch outfor i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args.......elif m in {MbViTV3}:c2 = args[0].......

(3)在yaml配置文件中写入

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 2  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2        320*320*64- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4       160*160*128- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8       80*80*256- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16      40*40*512- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32     20*20*1024- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9              20*20*1024# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 10- [[-1, 6], 1, Concat, [1]]                  # 11- [-1, 3, C2f, [512]]                        # 12                 40*40*512- [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 13- [[-1, 4], 1, Concat, [1]]                  # 14- [-1, 3, C2f, [256]]                        # 15 (P3/8-small)    44*84*320- [-1, 1, MbViTV3, [320, 160]]               # 16- [-1, 1, Conv, [256, 3, 2]]                 # 17- [[-1, 12], 1, Concat, [1]]                 # 18- [-1, 3, C2f, [512]]                        # 19 (P4/16-medium)  40*40*512- [-1, 1, Conv, [512, 3, 2]]                # 20- [[-1, 9], 1, Concat, [1]]                 # 21- [-1, 3, C2f, [1024]]                      # 22 (P5/32-large)  20*20*1024- [[16, 19, 22], 1, Detect, [nc]]           # 23

(4)开始训练,先把其他梯度关闭,保留新加的模块的梯度。

import os
from ultralytics import YOLO
import subprocess
from ultralytics.nn.vit.Vit import MbViTV3
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def add_vit(model):for name, param in model.model.named_parameters():stand = name[6:8]vit_ls = ['16']if stand in vit_ls:param.requires_grad = Trueelse:param.requires_grad = Falsefor name, param in model.model.named_parameters():if param.requires_grad:print(name)return modeldef main():# model = YOLO(r'ultralytics/cfg/models/v8/yolov8x.yaml').load('/root/autodl-tmp/yolov8x.pt')model = YOLO(r'yolov8x_vit.yaml').load('runs/detect/vit/weights/vit.pt')model = add_vit(model)model.train(data="data.yaml", imgsz=640, epochs=50, batch=10, device=0, workers=0)
if __name__ == '__main__':main()

————————————over————————————

相关文章:

YOLOv8添加MobileViTv3模块(代码+free)

目录 一、理由 二、方法 &#xff08;1&#xff09;导入MobileViTv3模块 &#xff08;2&#xff09;在ultralytics/nn/tasks.py的函数parse_model中修改 &#xff08;3&#xff09;在yaml配置文件中写入 &#xff08;4&#xff09;开始训练&#xff0c;先把其他梯度关闭&…...

从概念到落地:全面解析DApp项目开发的核心要素与未来趋势

随着区块链技术的迅猛发展&#xff0c;去中心化应用程序&#xff08;DApp&#xff09;逐渐成为Web3时代的重要组成部分。DApp通过智能合约和分布式账本技术&#xff0c;提供了无需信任中介的解决方案&#xff0c;这种去中心化的特性使其在金融、游戏、社交等多个领域得到了广泛…...

仓颉编程入门 -- 泛型概述 , 如何定义泛型函数

泛型概述 , 如何定义泛型函数 1 . 泛型的定义 在仓颉编程语言中&#xff0c;泛型机制允许我们定义参数化类型&#xff0c;这些类型在声明时不具体指定其操作的数据类型&#xff0c;而是作为类型形参保留&#xff0c;待使用时通过类型实参来明确。这种灵活性在函数和类型声明中…...

SOC估算方法之(OCV-SOC+安时积分法)

一、引言 此方法主要参考电动汽车用磷酸铁锂电池SOC估算方法这篇论文 总结&#xff1a; 开路电压的测量需要将电池静止相当长的一段时间才能达到平衡状态进行测量。 安时积分法存在初始SOC的估算和累积的误差。 所以上述两种方法都存在一定的缺陷&#xff0c;因此下面主要讲…...

指针(下)

文章目录 指针(下)野指针、空指针野指针空指针 二级指针**main**函数的原型说明 常量指针与指针常量常量指针指针常量常量指针常量 动态内存分配常用函数**malloc****calloc****realloc****free** **void**与**void***的区别扩展&#xff1a;形式参数和实际参数的对应关系 指针…...

C# 浅谈IEnumerable

一、IEnumerable 简介 IEnumerable 是一个接口&#xff0c;它定义了对集合进行迭代所需的方法。IEnumerable 接口主要用于允许开发者使用foreach循环来遍历集合中的元素。这个接口定义了一个名为 GetEnumerator 的方法&#xff0c;该方法返回一个实现了 IEnumerator 接口的对象…...

mmdebstrap:创建 Debian 系统 chroot 环境的利器 ️

文章目录 mmdebstrap 的一般性参数说明 &#x1f4dc;mmdebstrap 的常见用法示例 &#x1f308;使用 mmdebstrap 的注意事项 ⚠️ &#x1f308;你好呀&#xff01;我是 山顶风景独好 &#x1f388;欢迎踏入我的博客世界&#xff0c;能与您在此邂逅&#xff0c;真是缘分使然&am…...

【Linux SQLite数据库】一、SQLite交叉编译与移植

SQLite 是一个用 C 语言编写的开源、轻量级、快速、独立且高可靠性的 SQL 数据库引擎&#xff0c;它提供了功能齐全的数据库解决方案。SQLite 几乎可以在所有的手机和计算机上运行&#xff0c;它被嵌入到无数人每天都在使用的众多应用程序中。此外&#xff0c;SQLite 还具有稳定…...

每天写两道(数组篇)移除元素、

27.移除元素 给你一个数组 nums 和一个值 val&#xff0c;你需要 原地 移除所有数值等于 val 的元素。元素的顺序可能发生改变。然后返回 nums 中与 val 不同的元素的数量。 假设 nums 中不等于 val 的元素数量为 k&#xff0c;要通过此题&#xff0c;您需要执行以下操作&#…...

Unity 使用 NewtonSoft Json插件报错

JsonReaderException: Unexpected character encountered while parsing value: . Path , line 0, position 0. 通过断点发现&#xff0c;头有一串ZWNBSP&#xff0c;这个是BOM格式的JSON。在文件下看不到。 解决方法&#xff1a;改编码格式&#xff0c;Remove BOM....

k8s 部署 Mysqld_exporter 以及添加告警规则

最近监控 mysql 数据库&#xff0c;用了 pmm-server、pmm-client 发现监控是真的不太好用&#xff0c;还是用回 prometheus 吧。 部署mysqld_exporter k8s 部署最新版本的 mysqld_exporter&#xff0c;支持的数据库版本 MySQL >5.6、MariaDB > 10.3。 先在数据库创建用…...

基于STM32开发的智能农业环境监测系统

目录 引言环境准备工作 硬件准备软件安装与配置系统设计 系统架构硬件连接代码实现 初始化代码控制代码应用场景 农田环境监测温室环境控制常见问题及解决方案 常见问题解决方案结论 1. 引言 智能农业环境监测系统通过集成多种环境传感器&#xff0c;实时监测土壤湿度、温度…...

【SQL】平均售价

目录 题目 分析 代码 题目 表&#xff1a;Prices ------------------------ | Column Name | Type | ------------------------ | product_id | int | | start_date | date | | end_date | date | | price | int | ---------------…...

存储器与CPU的连接

1.单块存储芯片与CPU的连接 单独的一块独立的存储芯片提供的线有&#xff1a;地址总线&#xff0c;数据总线&#xff0c;读写控制线&#xff0c;片选线&#xff0c;如果该存储器只有八根数据总线用于输出数据&#xff0c;而cpu一次可以读64位的数据呢&#xff1f; 我们可以将八…...

unity--webgl 访问本地index.html

目录 1:使用本地服务器 1.1 使用 Python 的 SimpleHTTPServer 1.2 使用 Node.js 的 http-server 2&#xff1a;让其他人通过 IP 地址来访问你的 Unity WebGL 项目 2.1: 确保服务器可访问 2.2 获取公共 IP 地址 2.3 配置本地服务器 1.使用 Python 的 SimpleHTTPServer 2…...

慢慢欣赏DPDK RTE_MAX_ETHPORTS的定义

DPDK代码里面&#xff0c;RTE_MAX_ETHPORTS是一个常见的宏定义&#xff0c;但是在.c和.h文件找不到其定义&#xff0c;在全文件搜索条件下&#xff0c;在config/meson.build找到这么一个定义 dpdk_conf.set(RTE_MAX_ETHPORTS, get_option(max_ethports)) 该宏定义是根据构建输…...

Java Nacos与Gateway的使用

Java系列文章目录 IDEA使用指南 Java泛型总结&#xff08;快速上手详解&#xff09; Java Lambda表达式总结&#xff08;快速上手详解&#xff09; Java Optional容器总结&#xff08;快速上手图解&#xff09; Java 自定义注解笔记总结&#xff08;油管&#xff09; Jav…...

前端项目中的Server-sent Events(SSE)项目实践及其与websocket的区别

前端项目中的Server-sent Events(SSE)项目实践 前言 在前端开发中&#xff0c;实时数据更新是提升用户体验的重要因素之一。Server-SentEvents(SSE)是一种高效的技术&#xff0c;允许服务器通过单向连接将实时数据推送到客户端。下面将从SSE的基本改变&#xff0c;使用场景展…...

《老俞闲话|唯爱和热情不可辜负》读后感

《老俞闲话&#xff5c;唯爱和热情不可辜负》读后感 俞敏洪先生的这篇讲话充满了深情与智慧&#xff0c;他以自己丰富的人生经历和教育实践&#xff0c;向我们展现了一位教育家对于教育事业的热爱和对教师角色的深刻理解。 情感真挚&#xff0c;触动人心 俞敏洪先生的讲话中流…...

C语言 ——— 在杨氏矩阵中查找具体的某个数

目录 何为杨氏矩阵 题目要求 代码实现 何为杨氏矩阵 可以把杨氏矩阵理解为一个二维数组&#xff0c;这个二维数组中的每一行从左到右是递增的&#xff0c;每一列从上到下是递增的 题目要求 在杨氏矩阵中查找具体的某个数 要求&#xff1a;时间复杂度小于O(N) 代码实现…...

Android Wi-Fi 连接失败日志分析

1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分&#xff1a; 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析&#xff1a; CTR…...

7.4.分块查找

一.分块查找的算法思想&#xff1a; 1.实例&#xff1a; 以上述图片的顺序表为例&#xff0c; 该顺序表的数据元素从整体来看是乱序的&#xff0c;但如果把这些数据元素分成一块一块的小区间&#xff0c; 第一个区间[0,1]索引上的数据元素都是小于等于10的&#xff0c; 第二…...

ubuntu搭建nfs服务centos挂载访问

在Ubuntu上设置NFS服务器 在Ubuntu上&#xff0c;你可以使用apt包管理器来安装NFS服务器。打开终端并运行&#xff1a; sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享&#xff0c;例如/shared&#xff1a; sudo mkdir /shared sud…...

可靠性+灵活性:电力载波技术在楼宇自控中的核心价值

可靠性灵活性&#xff1a;电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中&#xff0c;电力载波技术&#xff08;PLC&#xff09;凭借其独特的优势&#xff0c;正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据&#xff0c;无需额外布…...

解决Ubuntu22.04 VMware失败的问题 ubuntu入门之二十八

现象1 打开VMware失败 Ubuntu升级之后打开VMware上报需要安装vmmon和vmnet&#xff0c;点击确认后如下提示 最终上报fail 解决方法 内核升级导致&#xff0c;需要在新内核下重新下载编译安装 查看版本 $ vmware -v VMware Workstation 17.5.1 build-23298084$ lsb_release…...

【单片机期末】单片机系统设计

主要内容&#xff1a;系统状态机&#xff0c;系统时基&#xff0c;系统需求分析&#xff0c;系统构建&#xff0c;系统状态流图 一、题目要求 二、绘制系统状态流图 题目&#xff1a;根据上述描述绘制系统状态流图&#xff0c;注明状态转移条件及方向。 三、利用定时器产生时…...

土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等

&#x1f50d; 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术&#xff0c;可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势&#xff0c;还能有效评价重大生态工程…...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

Redis的发布订阅模式与专业的 MQ(如 Kafka, RabbitMQ)相比,优缺点是什么?适用于哪些场景?

Redis 的发布订阅&#xff08;Pub/Sub&#xff09;模式与专业的 MQ&#xff08;Message Queue&#xff09;如 Kafka、RabbitMQ 进行比较&#xff0c;核心的权衡点在于&#xff1a;简单与速度 vs. 可靠与功能。 下面我们详细展开对比。 Redis Pub/Sub 的核心特点 它是一个发后…...

JAVA后端开发——多租户

数据隔离是多租户系统中的核心概念&#xff0c;确保一个租户&#xff08;在这个系统中可能是一个公司或一个独立的客户&#xff09;的数据对其他租户是不可见的。在 RuoYi 框架&#xff08;您当前项目所使用的基础框架&#xff09;中&#xff0c;这通常是通过在数据表中增加一个…...