联合目标检测与图像分类提升数据不平衡场景下的准确率
联合目标检测与图像分类提升数据不平衡场景下的准确率
在一些数据不平衡的场景下,使用单一的目标检测模型很难达到99%的准确率。为了优化这一问题,适当将其拆解为目标检测模型和图像分类模型的组合,可以更有效地控制最终效果,尤其是在添加焦点损失(focal loss)、调整超参数和数据预处理无效的情况下。以下是具体的实现方式及联合两个模型的推理代码。
整体功能概述
这段代码的主要功能包括:
- 加载目标检测和分类模型:使用两个 Ultralytics YOLO(YOLOv8/YOLOv11均可) 模型进行目标检测和分类。
- 处理图像:遍历指定输入文件夹中的所有图像,进行目标检测和分类。
- 绘制检测框和分类标签:在图像上绘制检测到的对象的边界框,并在框上方添加分类名称和置信度。
- 可选保存裁剪的对象图像:根据设置,裁剪检测到的对象区域并保存为单独的图像文件,文件名包含类别名称、置信度和坐标信息(便于调试)。
实现细节
1. 加载模型
代码加载了两个 YOLO 模型:
- 目标检测模型:一个单一类别的 YOLO 模型,用于检测主体对象。
- 图像分类模型:一个多类别的 YOLO 模型,用于对检测到的对象进行分类。
2. 处理图像
脚本处理输入文件夹中的每一张图像,步骤如下:
- 目标检测:使用目标检测模型检测图像中的对象。
- 裁剪检测到的对象:根据检测到的边界框坐标,裁剪出感兴趣的区域。
- 图像分类:对裁剪出的对象区域进行分类。
- 数据增强或欠采样:根据任务需求,对裁剪出的子图像进行数据增强或欠采样,以平衡数据集。
3. 绘制检测框和标签
对于每一个检测到的对象,脚本会:
- 在图像上绘制一个边界框。
- 在边界框上方添加分类名称和置信度标签。
4. 保存裁剪的对象图像
可选地,脚本会保存裁剪出的对象图像,文件名包含以下信息:
- 分类名称
- 置信度
- 边界框坐标
这对于调试和分析特定的检测结果非常有帮助。
推理代码
import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import randomdef generate_random_color_from_name(name):"""根据类别名生成可重复的颜色。"""random.seed(name) # 使用类别名作为随机种子return tuple(random.randint(0, 255) for _ in range(3))def generate_class_colors(names):"""为每个类别生成一个固定的颜色。"""class_colors = {}for class_name in names:class_colors[class_name] = generate_random_color_from_name(class_name)return class_colorsdef draw_box_on_image(image, box, color=(0, 255, 0), thickness=2):"""在图像上绘制检测框。"""x1, y1, x2, y2 = map(int, box)cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)def add_classification_to_box(image, box, class_name, confidence, color=(0, 255, 0)):"""在边界框上方添加分类名称和置信度。"""x1, y1, x2, y2 = map(int, box)label = f"{class_name}: {confidence:.2f}"cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2, cv2.LINE_AA)def save_cropped_object(image, box, cls_class_name, confidence, output_folder, image_name):"""将裁剪的对象区域保存为图像到子文件夹中,文件名包含类别名、置信度和坐标。"""x1, y1, x2, y2 = map(int, box)cropped_img = image[y1:y2, x1:x2]# 为当前图像创建一个以图像文件名命名的子文件夹image_subfolder = Path(output_folder) / Path(image_name).stemimage_subfolder.mkdir(parents=True, exist_ok=True)# 为裁剪的对象创建文件名(class_name_confidence_x1_y1_x2_y2.jpg)# 确保置信度格式安全,使用两位小数,并用下划线分隔cropped_img_name = f"{cls_class_name}_{confidence:.2f}_{x1}_{y1}_{x2}_{y2}.jpg"cropped_img_path = image_subfolder / cropped_img_namecv2.imwrite(str(cropped_img_path), cropped_img)print(f"已保存裁剪对象: {cropped_img_path}")def process_image_with_detection_and_classification(model_det, model_cls, img_path, names, class_colors, output_folder, save_cropped=False, detection_size=1280, classification_size=640):"""处理单张图像:执行对象检测,分类每个对象,并返回处理后的图像。:param model_det: 检测模型:param model_cls: 分类模型:param img_path: 图像路径:param names: 类别名称列表:param class_colors: 类别颜色映射字典:param output_folder: 输出文件夹路径:param save_cropped: 是否保存裁剪的对象图像:param detection_size: 检测模型输入图像大小:param classification_size: 分类模型输入图像大小:return: 处理后的图像"""img = cv2.imread(str(img_path))if img is None:print(f"无法读取图像: {img_path}")return None# 创建图像副本用于绘制(不修改原始图像)img_copy = img.copy()# 执行对象检测results_det = model_det.predict(str(img_path), imgsz=detection_size, conf=0.25, iou=0.45)# 处理每个检测结果(每个检测框)for r in results_det:boxes = r.boxes.xyxy.cpu().numpy() # xyxy 格式classes = r.boxes.cls.cpu().numpy()confidences = r.boxes.conf.cpu().numpy()for box, cls_id, confidence in zip(boxes, classes, confidences):# 检测模型的类别名det_class_name = names[int(cls_id)]# 使用检测到的类别名对应的颜色(该颜色是全局唯一的)color = class_colors.get(det_class_name, (255, 255, 255))# 裁剪对象区域x1, y1, x2, y2 = map(int, box)object_region = img[y1:y2, x1:x2]# 将对象区域调整为分类模型的输入大小object_region = cv2.resize(object_region, (classification_size, classification_size))# 执行分类results_cls = model_cls.predict(object_region, imgsz=classification_size)for result in results_cls:try:# 获取Top1预测结果classification_confidence = result.probs.cpu().numpy().top1conftop1_index = result.probs.top1cls_class_name = names[top1_index]# 根据分类结果的类别名设置颜色final_color = class_colors.get(cls_class_name, color)add_classification_to_box(img_copy, box, cls_class_name, classification_confidence, color=final_color)# 如果启用了保存裁剪对象,则保存if save_cropped:save_cropped_object(img, box, cls_class_name, classification_confidence, output_folder, img_path.name)except Exception as e:print(f"分类时出错: {e}")# 在图像副本上绘制检测框draw_box_on_image(img_copy, box, color=color)return img_copydef process_images(model_det, model_cls, input_folder, output_folder, names, class_colors, save_cropped=False, detection_size=1280, classification_size=640):"""处理输入文件夹中的图像,执行对象检测和分类,并保存处理后的图像。:param model_det: 检测模型:param model_cls: 分类模型:param input_folder: 输入文件夹路径:param output_folder: 输出文件夹路径:param names: 类别名称列表:param class_colors: 类别颜色映射字典:param save_cropped: 是否保存裁剪的对象图像:param detection_size: 检测模型输入图像大小:param classification_size: 分类模型输入图像大小"""Path(output_folder).mkdir(parents=True, exist_ok=True)image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.webp']for ext in image_extensions:for img_path in Path(input_folder).glob(ext):print(f"正在处理: {img_path}")processed_img = process_image_with_detection_and_classification(model_det, model_cls, img_path, names, class_colors, output_folder, save_cropped, detection_size, classification_size)if processed_img is not None:output_image_path = Path(output_folder) / f"{img_path.stem}_with_boxes_and_classification.jpg"cv2.imwrite(str(output_image_path), processed_img)print(f"已保存处理后的图像: {output_image_path}")else:print(f"跳过图像: {img_path} (无法处理)")if __name__ == '__main__':# 设置是否保存裁剪的对象图像(默认不保存)SAVE_CROPPED = True # 设置为 True 以启用保存裁剪对象# 加载检测和分类模型model_det = YOLO('runs/device_train/exp9/weights/best.pt')model_cls = YOLO('runs/cls_99.4%_exp14/weights/best.pt')# 设置输入和输出文件夹路径input_folder = 'test1'output_folder = 'infer-1216'# 获取类别名(用于生成一致的类别颜色映射)# 这里使用一张全白的图像来获取类别名black_image = 255 * np.ones((224, 224, 3), dtype=np.uint8)results = model_cls.predict(source=black_image)name_dict = results[0].namesnames = list(name_dict.values())# 只在这里生成一次类别颜色映射class_colors = generate_class_colors(names)# 开始处理图像process_images(model_det, model_cls, input_folder, output_folder,names, class_colors,save_cropped=SAVE_CROPPED,detection_size=1280,classification_size=224)
执行完后的结果
下面贴一下目标检测和图像分类的ultralytics的训练代码
目标检测训练代码
注意把single_cls=False改成True,变成单类训练
# nohup python -m torch.distributed.launch --nproc_per_node=4 --master_port=25643 det_train.py > output-lane-1212.txt 2>&1 &
# nohup python -m torch.distributed.launch --nproc_per_node=5 --master_port=25698 det_train.py > output-lane-1212.txt 2>&1 &
from ultralytics import YOLOif __name__ == '__main__':# 加载模型model = YOLO("checkpoints/yolo11l.pt") # 使用预训练权重训练# 训练参数 ----------------------------------------------------------------------------------------------model.train(data='/home/lizhijun/01.det/ultralytics-8.3.23/datasets/device_1212_yolo_without_vdd/config.yaml',epochs=150, # (int) 训练的周期数patience=50, # (int) 等待无明显改善以进行早期停止的周期数batch=16, # (int) 每批次的图像数量(-1 为自动批处理)imgsz=1280, # (int) 输入图像的大小,整数或w,hsave=True, # (bool) 保存训练检查点和预测结果save_period=-1, # (int) 每x周期保存检查点(如果小于1则禁用)cache=False, # (bool) True/ram、磁盘或False。使用缓存加载数据device='1,2,3,5', # (int | str | list, optional) 运行的设备,例如 cuda device=0 或 device=0,1,2,3 或 device=cpuworkers=8, # (int) 数据加载的工作线程数(每个DDP进程)project='runs/device_train', # (str, optional) 项目名称name='exp', # (str, optional) 实验名称,结果保存在'project/name'目录下exist_ok=False, # (bool) 是否覆盖现有实验pretrained=True, # (bool | str) 是否使用预训练模型(bool),或从中加载权重的模型(str)optimizer='auto', # (str) 要使用的优化器,选择=[SGD,Adam,Adamax,AdamW,NAdam,RAdam,RMSProp,auto]verbose=True, # (bool) 是否打印详细输出seed=0, # (int) 用于可重复性的随机种子deterministic=True, # (bool) 是否启用确定性模式single_cls=False, # (bool) 将多类数据训练为单类rect=False, # (bool) 如果mode='train',则进行矩形训练,如果mode='val',则进行矩形验证cos_lr=True, # (bool) 使用余弦学习率调度器close_mosaic=10, # (int) 在最后几个周期禁用马赛克增强resume=False, # (bool) 从上一个检查点恢复训练amp=True, # (bool) 自动混合精度(AMP)训练,选择=[True, False],True运行AMP检查fraction=1.0, # (float) 要训练的数据集分数(默认为1.0,训练集中的所有图像)profile=False, # (bool) 在训练期间为记录器启用ONNX和TensorRT速度freeze= None, # (int | list, 可选) 在训练期间冻结前 n 层,或冻结层索引列表。# 超参数 ----------------------------------------------------------------------------------------------lr0=0.01, # (float) 初始学习率(例如,SGD=1E-2,Adam=1E-3)lrf=0.01, # (float) 最终学习率(lr0 * lrf)momentum=0.937, # (float) SGD动量/Adam beta1weight_decay=0.0005, # (float) 优化器权重衰减 5e-4warmup_epochs=3.0, # (float) 预热周期(分数可用)warmup_momentum=0.8, # (float) 预热初始动量warmup_bias_lr=0.1, # (float) 预热初始偏置学习率box=6, # (float) 盒损失增益cls=1.5, # (float) 类别损失增益(与像素比例)dfl=1.5, # (float) dfl损失增益pose=12.0, # (float) 姿势损失增益kobj=1.0, # (float) 关键点对象损失增益label_smoothing=0.05, # (float) 标签平滑(分数)nbs=64, # (int) 名义批量大小hsv_h=0.015, # (float) 图像HSV-Hue增强(分数)hsv_s=0.7, # (float) 图像HSV-Saturation增强(分数)hsv_v=0.4, # (float) 图像HSV-Value增强(分数)degrees=90.0, # (float) 图像旋转(+/- deg)translate=0.5, # (float) 图像平移(+/- 分数)scale=0.5, # (float) 图像缩放(+/- 增益)shear=0.4, # (float) 图像剪切(+/- deg)perspective=0.0, # (float) 图像透视(+/- 分数),范围为0-0.001flipud=0.5, # (float) 图像上下翻转(概率)fliplr=0.5, # (float) 图像左右翻转(概率)mosaic=1.0, # (float) 图像马赛克(概率)mixup=0.0, # (float) 图像混合(概率)copy_paste=0.0, # (float) 分割复制-粘贴(概率))
图像分类训练代码
from ultralytics import YOLOmodel = YOLO("checkpoints/yolo11l-cls.pt")
model.train(data='/home/lizhijun/01.det/ultralytics-8.3.23/datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate_grid_110%', project='runs/cls_train', # (str, optional) 项目名称name='exp', # (str, optional) 实验名称,结果保存在'project/name'目录下epochs=20, batch=1024,device='1,2,3,5',erasing=0.0,crop_fraction=1.0,augment=False,auto_augment=False,hsv_h=0.015, # (float) 图像HSV-Hue增强(分数)hsv_s=0.7, # (float) 图像HSV-Saturation增强(分数)hsv_v=0.4, # (float) 图像HSV-Value增强(分数)degrees=0.0, # (float) 图像旋转(+/- deg)translate=0.0, # (float) 图像平移(+/- 分数)scale=0.0, # (float) 图像缩放(+/- 增益)shear=0.0, # (float) 图像剪切(+/- deg)perspective=0.0, # (float) 图像透视(+/- 分数),范围为0-0.001flipud=0.5, # (float) 图像上下翻转(概率)fliplr=0.5, # (float) 图像左右翻转(概率)mosaic=1.0, # (float) 图像马赛克(概率)mixup=0.0) # (float) 图像混合(概率))
相关文章:

联合目标检测与图像分类提升数据不平衡场景下的准确率
联合目标检测与图像分类提升数据不平衡场景下的准确率 在一些数据不平衡的场景下,使用单一的目标检测模型很难达到99%的准确率。为了优化这一问题,适当将其拆解为目标检测模型和图像分类模型的组合,可以更有效地控制最终效果,尤其…...
Git的简介
文章目录 一.Git是什么二.核心概念三.工作流程四.Git的优势 下载Git 推荐官网下载 官网地址 一.Git是什么 Git是一个分布式版本控制系统,用于跟踪文件的变化并协调多人对同一项目的开发工作。它就像是一个时光机器,能够记录文件在不同时间点的状态&…...

麒麟操作系统服务架构保姆级教程(四)NGINX中间件
如果你想拥有你从未拥有过的东西,那么你必须去做你从未做过的事情 想要在网页上访问到代码那么就需要用到应用服务类中间件,国外的有Nginx,Tomcat等,国内的有金蝶web,东方通的服务中间件(Tongweb࿰…...
Glide 自定义圆角、铺满FitXY
在 Android 开发中,使用 Glide 来加载图片时,有时需要对图片进行特定的处理,比如设置圆角或者使图片完全填充到一个视图中(类似于 ImageView 的 scaleType 中的 FitXY)。以下是如何使用 Glide 来实现这些自定义需求的处…...

蓝牙协议——音乐启停控制
手机播放音乐 手机暂停音乐 耳机播放音乐 耳机暂停音乐...

Krita安装krita-ai-diffusion工具搭建comfyui报错没有ComfyUI_IPAdapter_plus解决办法
我们在使用Kirta安装krita-ai-diffusion工具之后搭建comfyui环境需要安装很多扩展文件。 一般正常安装都可以使用了。 但是有一个插件很特别,无论你安装多少遍都会显示缺失,是什么插件这么难搞定呢? 没错,就是我们的ComfyUI_IPAdapter_plus插件。 就像下图一样: 那么怎…...

四相机设计实现全向视觉感知的开源空中机器人无人机
开源空中机器人 基于深度学习的OmniNxt全向视觉算法OAK-4p-New 全景硬件同步相机 机器人的纯视觉避障定位建图一直是个难题: 系统实现复杂 纯视觉稳定性不高 很难选到实用的视觉传感器 为此多数厂家还是采用激光雷达的定位方案。 OAK-4p-New 为了弥合这一差距…...

LightGBM分类算法在医疗数据挖掘中的深度探索与应用创新(上)
一、引言 1.1 医疗数据挖掘的重要性与挑战 在当今数字化医疗时代,医疗数据呈爆炸式增长,这些数据蕴含着丰富的信息,对医疗决策具有极为重要的意义。通过对医疗数据的深入挖掘,可以发现潜在的疾病模式、治疗效果关联以及患者的健康风险因素,从而为精准医疗、个性化治疗方…...
JVM(Java虚拟机)的组成部分详解
摘要: JVM (Java Virtual Machine) 是一个抽象计算模型,它使Java程序可以在任何支持JVM的操作系统上运行,而无需考虑底层硬件架构。本文将深入探讨JVM的内部结构和工作机制,包括类加载器、运行时数据区、执行引擎以及内存管理等关…...
jsp中的四个域对象(Spring MVC)
在Spring MVC中,Model中的数据会被自动放入到请求域(Request Scope)中。也就是说,当我们在控制器中使用model.addAttribute()时,这些属性会被放入到HttpServletRequest对象的属性中。 让我们通过代码来详细解释&#…...
计算机基础知识复习12.24
http和https有那些区别 http是超文本传输协议,信息是明文传输,存在安全风险的问题,https则解决http不安全的缺点,在TCP和HTTP网络层之间加入了SSL/TLS安全协议,使得报文能够加密传输 http连接建立相对简单࿰…...
如何使用vscode解决git冲突
在使用VSCode时,遇到Git冲突是很常见的情况。Git冲突是指当多个人同时修改同一个文件的同一行或相邻行时,Git无法自动决定应该保留哪一个修改,需要手动解决这个冲突。 要解决Git冲突,可以按照以下步骤操作: 1. 打开V…...

告别卡顿:CasaOS轻NAS设备安装Gopeed打造高效下载环境
文章目录 前言1. 更新应用中心2.Gopeed安装与配置3. 本地下载测试4. 安装内网穿透工具5. 配置公网地址6. 配置固定公网地址 前言 无论你是需要大量文件传输的专业人士,还是只是想快速下载电影或音乐的普通用户,都会使用到下载工具。如果你对现有的下载工…...

Java 重写(Override)与重载(Overload)
重写 (Override) 重写是子类对父类的允许访问的方法的实现过程进行重新编写!返回值和形参都不能改变。即外壳不变,核心重写! 重写的好处在于子类可以根据需要,定义特定于自己的行为。 也就是说子类能够根据需要实现父类的方法。…...
HDFS与HBase有什么关系?
1 、 HDFS 文件存储系统和 HBase 分布式数据库 HDFS 是 Hadoop 分布式文件系统。 HBase 的数据通常存储在 HDFS 上。 HDFS 为 HBase 提供了高可靠性的底层存储支持。 Hbase 是 Hadoop database ,即 Hadoop 数据库。它是一个适合于非结构化数据存储的数据库, HBase 基于列的…...

CentOS7下的vsftpd服务器和客户端
目录 1、安装vsftpd服务器和ftp客户端; 2、配置vsftpd服务器,允许普通用户登录、下载、上传文件; 3、配置vsftpd服务器,允许anonymous用户登录、下载、上传文件; 4、配置vsftpd服务器,允许root用户登录…...

全网最详细Gradio教程系列10——Blocks:底层区块类(下)
全网最详细Gradio教程系列10——Blocks:底层区块类(下) 前言本篇摘要10. Blocks:底层区块类10.4 Blocks Layout:布局10.4.1 行与列1. Rows2. Columns 10.4.2 选项卡和折叠类10.4.3 重渲染.render()10.4.4 Group分组10.…...
嵌入式设备常用性能和内存调试指令
文章目录 嵌入式设备常用性能和内存调试指令内存问题分析性能测试android设备通过NDK 使用SimplePerf 抓取火焰图嵌入式linux抓取特定进程的perf火焰图 杂记 嵌入式设备常用性能和内存调试指令 内存问题分析 安装valgrind,按照如下指令执行应用程序: …...

数据库系统原理:数据恢复与备份策略
3.1可行性分析 开发者在进行开发系统之前,都需要进行可行性分析,保证该系统能够被成功开发出来。 3.1.1技术可行性 开发该《数据库系统原理》课程平台所采用的技术是vue和MYSQL数据库。计算机专业的学生在学校期间已经比较系统的学习了很多编程方面的知识…...
C++软件设计模式之装饰器模式
装饰器模式(Decorator Pattern)是C软件设计模式中的一种结构型设计模式,主要用于解决在不改变现有对象结构的情况下动态地给对象添加新功能的问题。通过使用装饰器模式,可以在运行时为对象添加新的行为,而不需要修改其…...

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析
1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具,该工具基于TUN接口实现其功能,利用反向TCP/TLS连接建立一条隐蔽的通信信道,支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式,适应复杂网…...

K8S认证|CKS题库+答案| 11. AppArmor
目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作: 1)、切换集群 2)、切换节点 3)、切换到 apparmor 的目录 4)、执行 apparmor 策略模块 5)、修改 pod 文件 6)、…...

Linux相关概念和易错知识点(42)(TCP的连接管理、可靠性、面临复杂网络的处理)
目录 1.TCP的连接管理机制(1)三次握手①握手过程②对握手过程的理解 (2)四次挥手(3)握手和挥手的触发(4)状态切换①挥手过程中状态的切换②握手过程中状态的切换 2.TCP的可靠性&…...
JVM垃圾回收机制全解析
Java虚拟机(JVM)中的垃圾收集器(Garbage Collector,简称GC)是用于自动管理内存的机制。它负责识别和清除不再被程序使用的对象,从而释放内存空间,避免内存泄漏和内存溢出等问题。垃圾收集器在Ja…...

HBuilderX安装(uni-app和小程序开发)
下载HBuilderX 访问官方网站:https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本: Windows版(推荐下载标准版) Windows系统安装步骤 运行安装程序: 双击下载的.exe安装文件 如果出现安全提示&…...
OpenLayers 分屏对比(地图联动)
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 地图分屏对比在WebGIS开发中是很常见的功能,和卷帘图层不一样的是,分屏对比是在各个地图中添加相同或者不同的图层进行对比查看。…...

SiFli 52把Imagie图片,Font字体资源放在指定位置,编译成指定img.bin和font.bin的问题
分区配置 (ptab.json) img 属性介绍: img 属性指定分区存放的 image 名称,指定的 image 名称必须是当前工程生成的 binary 。 如果 binary 有多个文件,则以 proj_name:binary_name 格式指定文件名, proj_name 为工程 名&…...

iview框架主题色的应用
1.下载 less要使用3.0.0以下的版本 npm install less2.7.3 npm install less-loader4.0.52./src/config/theme.js文件 module.exports {yellow: {theme-color: #FDCE04},blue: {theme-color: #547CE7} }在sass中使用theme配置的颜色主题,无需引入,直接可…...
前端中slice和splic的区别
1. slice slice 用于从数组中提取一部分元素,返回一个新的数组。 特点: 不修改原数组:slice 不会改变原数组,而是返回一个新的数组。提取数组的部分:slice 会根据指定的开始索引和结束索引提取数组的一部分。不包含…...
十九、【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建
【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建 前言准备工作第一部分:回顾 Django 内置的 `User` 模型第二部分:设计并创建 `Role` 和 `UserProfile` 模型第三部分:创建 Serializers第四部分:创建 ViewSets第五部分:注册 API 路由第六部分:后端初步测…...