当前位置: 首页 > news >正文

第99步 深度学习图像目标检测:SSDlite建模

基于WIN10的64位系统演示

一、写在前面

本期,我们继续学习深度学习图像目标检测系列,SSD(Single Shot MultiBox Detector)模型的后续版本,SSDlite模型。

二、SSDlite简介

SSDLite 是 SSD 模型的一个变种,旨在为移动设备和边缘计算设备提供更高效的目标检测。SSDLite 的主要特点是使用了轻量级的骨干网络和特定的卷积操作来减少计算复杂性,从而提高检测速度,同时在大多数情况下仍保持了较高的准确性。

以下是 SSDLite 的主要特性和组件:

(1)轻量级骨干:

SSDLite 不使用 VGG 或 ResNet 这样的重量级骨干。相反,它使用 MobileNet 作为骨干,特别是 MobileNetV2 或 MobileNetV3。这些网络使用深度可分离的卷积和其他轻量级操作来减少计算成本。

(2)深度可分离的卷积:

这是 MobileNet 的核心组件,也被用于 SSDLite。深度可分离的卷积将传统的卷积操作分解为两个较小的操作:一个深度卷积和一个点卷积,这大大减少了计算和参数数量。

(3)多尺度特征映射:

与原始的 SSD 相似,SSDLite 也从不同的层级提取特征图以检测不同大小的物体。

(4)默认框:

SSDLite 也使用默认框(或称为锚框)来进行边界框预测。

(5)单阶段检测:

与 SSD 相同,SSDLite 也是一个单阶段检测器,同时进行边界框回归和分类。

(6)损失函数:

SSDLite 使用与 SSD 相同的组合损失,包括平滑 L1 损失和交叉熵损失。

综上,SSDLite 是为了速度和效率而设计的,特别是针对计算和内存资源有限的设备。通过使用轻量级的骨干和深度可分离的卷积,它能够在减少计算负担的同时,仍然保持合理的检测准确性。

三、数据源

来源于公共数据,文件设置如下:

大概的任务就是:用一个框框标记出MTB的位置。

四、SSDlite实战

直接上代码:

import os
import random
import torch
import torchvision
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
from torchvision.transforms import functional as F
from PIL import Image
from torch.utils.data import DataLoader
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np# Function to parse XML annotations
def parse_xml(xml_path):tree = ET.parse(xml_path)root = tree.getroot()boxes = []for obj in root.findall("object"):bndbox = obj.find("bndbox")xmin = int(bndbox.find("xmin").text)ymin = int(bndbox.find("ymin").text)xmax = int(bndbox.find("xmax").text)ymax = int(bndbox.find("ymax").text)# Check if the bounding box is validif xmin < xmax and ymin < ymax:boxes.append((xmin, ymin, xmax, ymax))else:print(f"Warning: Ignored invalid box in {xml_path} - ({xmin}, {ymin}, {xmax}, {ymax})")return boxes# Function to split data into training and validation sets
def split_data(image_dir, split_ratio=0.8):all_images = [f for f in os.listdir(image_dir) if f.endswith(".jpg")]random.shuffle(all_images)split_idx = int(len(all_images) * split_ratio)train_images = all_images[:split_idx]val_images = all_images[split_idx:]return train_images, val_images# Dataset class for the Tuberculosis dataset
class TuberculosisDataset(torch.utils.data.Dataset):def __init__(self, image_dir, annotation_dir, image_list, transform=None):self.image_dir = image_dirself.annotation_dir = annotation_dirself.image_list = image_listself.transform = transformdef __len__(self):return len(self.image_list)def __getitem__(self, idx):image_path = os.path.join(self.image_dir, self.image_list[idx])image = Image.open(image_path).convert("RGB")xml_path = os.path.join(self.annotation_dir, self.image_list[idx].replace(".jpg", ".xml"))boxes = parse_xml(xml_path)# Check for empty bounding boxes and return Noneif len(boxes) == 0:return Noneboxes = torch.as_tensor(boxes, dtype=torch.float32)labels = torch.ones((len(boxes),), dtype=torch.int64)iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)target = {}target["boxes"] = boxestarget["labels"] = labelstarget["image_id"] = torch.tensor([idx])target["iscrowd"] = iscrowd# Apply transformationsif self.transform:image = self.transform(image)return image, target# Define the transformations using torchvision
data_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # Convert PIL image to tensortorchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the images
])# Adjusting the DataLoader collate function to handle None values
def collate_fn(batch):batch = list(filter(lambda x: x is not None, batch))return tuple(zip(*batch))def get_ssdlite_model_for_finetuning(num_classes):# Load an SSDlite model with a MobileNetV3 Large backbone without pre-trained weightsmodel = ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=num_classes)return model# Function to save the model
def save_model(model, path="SSDlite_mtb.pth", save_full_model=False):if save_full_model:torch.save(model, path)else:torch.save(model.state_dict(), path)print(f"Model saved to {path}")# Function to compute Intersection over Union
def compute_iou(boxA, boxB):xA = max(boxA[0], boxB[0])yA = max(boxA[1], boxB[1])xB = min(boxA[2], boxB[2])yB = min(boxA[3], boxB[3])interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)iou = interArea / float(boxAArea + boxBArea - interArea)return iou# Adjusting the DataLoader collate function to handle None values and entirely empty batches
def collate_fn(batch):batch = list(filter(lambda x: x is not None, batch))if len(batch) == 0:# Return placeholder batch if entirely emptyreturn [torch.zeros(1, 3, 224, 224)], [{}]return tuple(zip(*batch))#Training function with modifications for collecting IoU and loss
def train_model(model, train_loader, optimizer, device, num_epochs=10):model.train()model.to(device)loss_values = []iou_values = []for epoch in range(num_epochs):epoch_loss = 0.0total_ious = 0num_boxes = 0for images, targets in train_loader:# Skip batches with placeholder dataif len(targets) == 1 and not targets[0]:continue# Skip batches with empty targetsif any(len(target["boxes"]) == 0 for target in targets):continueimages = [image.to(device) for image in images]targets = [{k: v.to(device) for k, v in t.items()} for t in targets]loss_dict = model(images, targets)losses = sum(loss for loss in loss_dict.values())optimizer.zero_grad()losses.backward()optimizer.step()epoch_loss += losses.item()# Compute IoU for evaluationwith torch.no_grad():model.eval()predictions = model(images)for i, prediction in enumerate(predictions):pred_boxes = prediction["boxes"].cpu().numpy()true_boxes = targets[i]["boxes"].cpu().numpy()for pred_box in pred_boxes:for true_box in true_boxes:iou = compute_iou(pred_box, true_box)total_ious += iounum_boxes += 1model.train()avg_loss = epoch_loss / len(train_loader)avg_iou = total_ious / num_boxes if num_boxes != 0 else 0loss_values.append(avg_loss)iou_values.append(avg_iou)print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss} Avg IoU: {avg_iou}")# Plotting loss and IoU valuesplt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(loss_values, label="Training Loss")plt.title("Training Loss across Epochs")plt.xlabel("Epochs")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(iou_values, label="IoU")plt.title("IoU across Epochs")plt.xlabel("Epochs")plt.ylabel("IoU")plt.show()# Save model after trainingsave_model(model)# Validation function
def validate_model(model, val_loader, device):model.eval()model.to(device)with torch.no_grad():for images, targets in val_loader:images = [image.to(device) for image in images]targets = [{k: v.to(device) for k, v in t.items()} for t in targets]model(images)# Paths to your data
image_dir = "tuberculosis-phonecamera"
annotation_dir = "tuberculosis-phonecamera"# Split data
train_images, val_images = split_data(image_dir)# Create datasets and dataloaders
train_dataset = TuberculosisDataset(image_dir, annotation_dir, train_images, transform=data_transform)
val_dataset = TuberculosisDataset(image_dir, annotation_dir, val_images, transform=data_transform)# Updated DataLoader with new collate function
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)# Model and optimizer
model = get_ssdlite_model_for_finetuning(2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# Train and validate
train_model(model, train_loader, optimizer, device="cuda", num_epochs=10)
validate_model(model, val_loader, device="cuda")

需要从头训练的,就不跑了,摆烂了。

五、写在后面

目标检测模型门槛更高了,运行起来对硬件要求也很高,时间也很久,都是小时起步的。因此只是简单介绍,算是入个门了。

相关文章:

第99步 深度学习图像目标检测:SSDlite建模

基于WIN10的64位系统演示 一、写在前面 本期&#xff0c;我们继续学习深度学习图像目标检测系列&#xff0c;SSD&#xff08;Single Shot MultiBox Detector&#xff09;模型的后续版本&#xff0c;SSDlite模型。 二、SSDlite简介 SSDLite 是 SSD 模型的一个变种&#xff0c…...

用EasyAVFilter将网络文件或者本地文件推送RTMP出去的时候发现CPU占用好高,用的也是vcodec copy呀,什么原因?

最近同事在用EasyAVFilter集成在EasyDarwin中做视频拉流转推RTMP流的功能的时候&#xff0c;发现怎么做CPU占用都会很高&#xff0c;但是视频没有调用转码&#xff0c;vcodec用的就是copy&#xff0c;这是什么原因呢&#xff1f; 我们用在线的RTSP流就不会出现这种情况&#x…...

Vatee万腾科技的独特力量:Vatee数字时代创新的新视野

在数字化时代的浪潮中&#xff0c;Vatee万腾科技以其独特而强大的创新力量&#xff0c;为整个行业描绘了一幅崭新的视野。这不仅是一场科技创新的冒险&#xff0c;更是对未来数字时代发展方向的领先探索。 Vatee万腾将创新视为数字时代发展的引擎&#xff0c;成为推动行业向前的…...

【JavaSE】基础笔记 - 异常(Exception)

目录 1、异常的概念和体系结构 1.1、异常的概念 1.2、 异常的体系结构 1.3 异常的分类 2、异常的处理 2.1、防御式编程 2.2、异常的抛出 2.3、异常的捕获 2.3.1、异常声明throws 2.3.2、try-catch捕获并处理 3、自定义异常类 1、异常的概念和体系结构 1.1、异常的…...

QTableWidget——编辑单元格

文章目录 前言熟悉QTableWiget&#xff0c;通过实现单元格的合并、拆分、通过编辑界面实现表格内容及属性的配置、实现表格的粘贴复制功能熟悉QTableWiget的属性 一、[单元格的合并、拆分](https://blog.csdn.net/qq_15672897/article/details/134476530?spm1001.2014.3001.55…...

编译QT Mysql库并集成使用

安装MSVC编译器与Windows 10 SDK 打开Visual Studio Installer&#xff0c;如果已经安装过内容了可能是如下页面&#xff0c;点击修改&#xff08;头一回打开的话不需要这一步&#xff09;&#xff1a; 然后在工作负荷中勾选使用C的桌面开发&#xff0c;它会帮我们勾选好一些…...

利用企业被执行人信息查询API保障商业交易安全

前言 在当今竞争激烈的商业环境中&#xff0c;企业为了保障商业交易的安全性不断寻求新的手段。随着技术的发展&#xff0c;利用企业被执行人信息查询API已经成为了一种强有力的工具&#xff0c;能够帮助企业在商业交易中降低风险&#xff0c;提高合作的信任度。 企业被执行人…...

【深度学习】P1 深度学习基础框架 - 张量 Tensor

深度学习基础框架 张量 Tensor 张量数据操作导入创建张量获取张量信息改变张量张量运算 张量与内存 张量 Pytorch 是一个深度学习框架&#xff0c;用于开发和训练神经网络模型。 而其核心数据结构&#xff0c;则是张量 Tensor&#xff0c;类似于 Numpy 数组&#xff0c;但是可…...

vue2 识别页面参数中的html

在Vue 2中&#xff0c;你可以使用v-html指令来识别页面参数中的HTML内容。v-html指令允许你将HTML代码作为Vue模板的一部分进行渲染。 以下是一个示例&#xff0c;演示了如何在Vue 2中使用v-html指令来识别页面参数中的HTML内容&#xff1a; <template><div v-html&…...

matlab 一些画图法总结(持续更新)

*****************************************画Dmd_L极坐标表示法**************************************** if(~exist(Dmd_L_array)) Dmd_L_array []; end Dmd_L_array [Dmd_L_array; Dmd_L]; thetaangle(Dmd_L_array); rabs(Dmd_L_array); polarplot(theta,r,o); *****…...

MDK AC5和AC6是什么?在KEIL5中添加和选择ARMCC版本

前言 看视频有UP主提到“AC5”“AC6”这样的词&#xff0c;一开始有些不理解&#xff0c;原来他说的是ARMCC版本。 keil自带的是ARMCC5&#xff0c;由于ARMCC5已经停止维护了&#xff0c;很多开发者会选择ARMCC6。 在维护公司“成年往事”项目可能就会遇到新KEIL旧版本编译器…...

杰发科技AC7801——EEP内存分布情况

简介 按照文档进行配置 核心代码如下 /*!* file sweeprom_demo.c** brief This file provides sweeprom demo test function.**//* Includes */ #include <stdlib.h> #include "ac780x_sweeprom.h" #include "ac780x_debugout.h"/* Define …...

【mybatis注解实现条件查询】

文章目录 步骤1: 引入MyBatis依赖步骤2: 创建数据模型步骤3: 创建Mapper接口步骤4: 配置MyBatis步骤5: 执行条件查询 步骤1: 引入MyBatis依赖 <dependency><groupId>org.mybatis</groupId><artifactId>mybatis</artifactId><version>3.x.…...

【广州华锐互动】VR线上课件制作软件满足数字化教学需求

随着科技的不断发展&#xff0c;虚拟现实&#xff08;VR&#xff09;技术在教学领域的应用逐渐成为趋势。其中&#xff0c;广州华锐互动开发的VR线上课件制作软件更是备受关注。这种工具为教师提供了便捷的制作VR课件的手段&#xff0c;使得VR教学成为可能&#xff0c;极大地丰…...

MySQL 中 DELETE 语句中可以使用别名么?

某天&#xff0c;正按照业务的要求删除不需要的数据&#xff0c;在执行 DELETE 语句时&#xff0c;竟然出现了报错&#xff01; 作者&#xff1a;林靖华&#xff0c;开源数据库技术爱好者&#xff0c;擅长MySQL和Redis的运维 爱可生开源社区出品&#xff0c;原创内容未经授权不…...

flutter创建不同样式的按钮,背景色,边框,圆角,圆形,大小都可以设置

在ui设计中&#xff0c;可能按钮会有不同的样式需要你来写出来&#xff0c;所以按钮的不同样式&#xff0c;应该是最基础的功能&#xff0c;在这里我们赶紧学起来吧&#xff0c;web端可能展示有问题&#xff0c;需要优化&#xff0c;但是基本样式还是出来了 我是将所有的按钮放…...

【C++】标准模板库STL作业(其二)

&#x1f383;个人专栏&#xff1a; &#x1f42c; 算法设计与分析&#xff1a;算法设计与分析_IT闫的博客-CSDN博客 &#x1f433;Java基础&#xff1a;Java基础_IT闫的博客-CSDN博客 &#x1f40b;c语言&#xff1a;c语言_IT闫的博客-CSDN博客 &#x1f41f;MySQL&#xff1a…...

基于SpringBoot+Redis实现点赞/排行榜功能,可同理实现收藏/关注功能,可拓展实现共同好友/共同关注/关注推送功能

前言 简单记录一下在SpringBoot项目中&#xff0c;使用Redis实现点赞/排行榜功能&#xff0c;可同理实现收藏/关注功能&#xff0c;可拓展实现共同好友/共同关注/关注推送功。主要用到了Redis中的Set集合和ZSet集合。 一、指定使用某个索引的数据库 在Redis中&#xff0c;可…...

AI“胡说八道”?怎么解?

原创 | 文 BFT机器人 01 引言 近年来&#xff0c;人工智能产业迅猛发展&#xff0c;大型语言模型GPT-4发展势头强劲&#xff0c;OpenAI推出ChatGPT、微软推出Bing、马斯克推出“最好的聊天机器人Grok”……科技巨头纷纷入局AI领域&#xff0c;引入人工智能作为办公工具的行业…...

[SIGGRAPH-23] 3D Gaussian Splatting for Real-Time Radiance Field Rendering

pdf | proj | code 本文提出一种新的3D数据表达形式3D Gaussians。每个Gaussian由以下参数组成&#xff1a;中心点位置、协方差矩阵、可见性、颜色。通过世界坐标系到相机坐标系&#xff0c;再到图像坐标系的仿射关系&#xff0c;可将3D Gaussian映射到相机坐标系&#xff0c;通…...

Android Wi-Fi 连接失败日志分析

1. Android wifi 关键日志总结 (1) Wi-Fi 断开 (CTRL-EVENT-DISCONNECTED reason3) 日志相关部分&#xff1a; 06-05 10:48:40.987 943 943 I wpa_supplicant: wlan0: CTRL-EVENT-DISCONNECTED bssid44:9b:c1:57:a8:90 reason3 locally_generated1解析&#xff1a; CTR…...

智慧医疗能源事业线深度画像分析(上)

引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...

K8S认证|CKS题库+答案| 11. AppArmor

目录 11. AppArmor 免费获取并激活 CKA_v1.31_模拟系统 题目 开始操作&#xff1a; 1&#xff09;、切换集群 2&#xff09;、切换节点 3&#xff09;、切换到 apparmor 的目录 4&#xff09;、执行 apparmor 策略模块 5&#xff09;、修改 pod 文件 6&#xff09;、…...

C++:std::is_convertible

C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

《Playwright:微软的自动化测试工具详解》

Playwright 简介:声明内容来自网络&#xff0c;将内容拼接整理出来的文档 Playwright 是微软开发的自动化测试工具&#xff0c;支持 Chrome、Firefox、Safari 等主流浏览器&#xff0c;提供多语言 API&#xff08;Python、JavaScript、Java、.NET&#xff09;。它的特点包括&a…...

【机器视觉】单目测距——运动结构恢复

ps&#xff1a;图是随便找的&#xff0c;为了凑个封面 前言 在前面对光流法进行进一步改进&#xff0c;希望将2D光流推广至3D场景流时&#xff0c;发现2D转3D过程中存在尺度歧义问题&#xff0c;需要补全摄像头拍摄图像中缺失的深度信息&#xff0c;否则解空间不收敛&#xf…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 &#xff08;一&#xff09;多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如&#xff0c;当用户上传一张“蓝色连衣裙”的图片时&#xff0c;接口可自动提取图像中的颜色&#xff08;RGB值&…...

MODBUS TCP转CANopen 技术赋能高效协同作业

在现代工业自动化领域&#xff0c;MODBUS TCP和CANopen两种通讯协议因其稳定性和高效性被广泛应用于各种设备和系统中。而随着科技的不断进步&#xff0c;这两种通讯协议也正在被逐步融合&#xff0c;形成了一种新型的通讯方式——开疆智能MODBUS TCP转CANopen网关KJ-TCPC-CANP…...

DBAPI如何优雅的获取单条数据

API如何优雅的获取单条数据 案例一 对于查询类API&#xff0c;查询的是单条数据&#xff0c;比如根据主键ID查询用户信息&#xff0c;sql如下&#xff1a; select id, name, age from user where id #{id}API默认返回的数据格式是多条的&#xff0c;如下&#xff1a; {&qu…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...