当前位置: 首页 > news >正文

第99步 深度学习图像目标检测:SSDlite建模

基于WIN10的64位系统演示

一、写在前面

本期,我们继续学习深度学习图像目标检测系列,SSD(Single Shot MultiBox Detector)模型的后续版本,SSDlite模型。

二、SSDlite简介

SSDLite 是 SSD 模型的一个变种,旨在为移动设备和边缘计算设备提供更高效的目标检测。SSDLite 的主要特点是使用了轻量级的骨干网络和特定的卷积操作来减少计算复杂性,从而提高检测速度,同时在大多数情况下仍保持了较高的准确性。

以下是 SSDLite 的主要特性和组件:

(1)轻量级骨干:

SSDLite 不使用 VGG 或 ResNet 这样的重量级骨干。相反,它使用 MobileNet 作为骨干,特别是 MobileNetV2 或 MobileNetV3。这些网络使用深度可分离的卷积和其他轻量级操作来减少计算成本。

(2)深度可分离的卷积:

这是 MobileNet 的核心组件,也被用于 SSDLite。深度可分离的卷积将传统的卷积操作分解为两个较小的操作:一个深度卷积和一个点卷积,这大大减少了计算和参数数量。

(3)多尺度特征映射:

与原始的 SSD 相似,SSDLite 也从不同的层级提取特征图以检测不同大小的物体。

(4)默认框:

SSDLite 也使用默认框(或称为锚框)来进行边界框预测。

(5)单阶段检测:

与 SSD 相同,SSDLite 也是一个单阶段检测器,同时进行边界框回归和分类。

(6)损失函数:

SSDLite 使用与 SSD 相同的组合损失,包括平滑 L1 损失和交叉熵损失。

综上,SSDLite 是为了速度和效率而设计的,特别是针对计算和内存资源有限的设备。通过使用轻量级的骨干和深度可分离的卷积,它能够在减少计算负担的同时,仍然保持合理的检测准确性。

三、数据源

来源于公共数据,文件设置如下:

大概的任务就是:用一个框框标记出MTB的位置。

四、SSDlite实战

直接上代码:

import os
import random
import torch
import torchvision
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
from torchvision.transforms import functional as F
from PIL import Image
from torch.utils.data import DataLoader
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np# Function to parse XML annotations
def parse_xml(xml_path):tree = ET.parse(xml_path)root = tree.getroot()boxes = []for obj in root.findall("object"):bndbox = obj.find("bndbox")xmin = int(bndbox.find("xmin").text)ymin = int(bndbox.find("ymin").text)xmax = int(bndbox.find("xmax").text)ymax = int(bndbox.find("ymax").text)# Check if the bounding box is validif xmin < xmax and ymin < ymax:boxes.append((xmin, ymin, xmax, ymax))else:print(f"Warning: Ignored invalid box in {xml_path} - ({xmin}, {ymin}, {xmax}, {ymax})")return boxes# Function to split data into training and validation sets
def split_data(image_dir, split_ratio=0.8):all_images = [f for f in os.listdir(image_dir) if f.endswith(".jpg")]random.shuffle(all_images)split_idx = int(len(all_images) * split_ratio)train_images = all_images[:split_idx]val_images = all_images[split_idx:]return train_images, val_images# Dataset class for the Tuberculosis dataset
class TuberculosisDataset(torch.utils.data.Dataset):def __init__(self, image_dir, annotation_dir, image_list, transform=None):self.image_dir = image_dirself.annotation_dir = annotation_dirself.image_list = image_listself.transform = transformdef __len__(self):return len(self.image_list)def __getitem__(self, idx):image_path = os.path.join(self.image_dir, self.image_list[idx])image = Image.open(image_path).convert("RGB")xml_path = os.path.join(self.annotation_dir, self.image_list[idx].replace(".jpg", ".xml"))boxes = parse_xml(xml_path)# Check for empty bounding boxes and return Noneif len(boxes) == 0:return Noneboxes = torch.as_tensor(boxes, dtype=torch.float32)labels = torch.ones((len(boxes),), dtype=torch.int64)iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)target = {}target["boxes"] = boxestarget["labels"] = labelstarget["image_id"] = torch.tensor([idx])target["iscrowd"] = iscrowd# Apply transformationsif self.transform:image = self.transform(image)return image, target# Define the transformations using torchvision
data_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # Convert PIL image to tensortorchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the images
])# Adjusting the DataLoader collate function to handle None values
def collate_fn(batch):batch = list(filter(lambda x: x is not None, batch))return tuple(zip(*batch))def get_ssdlite_model_for_finetuning(num_classes):# Load an SSDlite model with a MobileNetV3 Large backbone without pre-trained weightsmodel = ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=num_classes)return model# Function to save the model
def save_model(model, path="SSDlite_mtb.pth", save_full_model=False):if save_full_model:torch.save(model, path)else:torch.save(model.state_dict(), path)print(f"Model saved to {path}")# Function to compute Intersection over Union
def compute_iou(boxA, boxB):xA = max(boxA[0], boxB[0])yA = max(boxA[1], boxB[1])xB = min(boxA[2], boxB[2])yB = min(boxA[3], boxB[3])interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)iou = interArea / float(boxAArea + boxBArea - interArea)return iou# Adjusting the DataLoader collate function to handle None values and entirely empty batches
def collate_fn(batch):batch = list(filter(lambda x: x is not None, batch))if len(batch) == 0:# Return placeholder batch if entirely emptyreturn [torch.zeros(1, 3, 224, 224)], [{}]return tuple(zip(*batch))#Training function with modifications for collecting IoU and loss
def train_model(model, train_loader, optimizer, device, num_epochs=10):model.train()model.to(device)loss_values = []iou_values = []for epoch in range(num_epochs):epoch_loss = 0.0total_ious = 0num_boxes = 0for images, targets in train_loader:# Skip batches with placeholder dataif len(targets) == 1 and not targets[0]:continue# Skip batches with empty targetsif any(len(target["boxes"]) == 0 for target in targets):continueimages = [image.to(device) for image in images]targets = [{k: v.to(device) for k, v in t.items()} for t in targets]loss_dict = model(images, targets)losses = sum(loss for loss in loss_dict.values())optimizer.zero_grad()losses.backward()optimizer.step()epoch_loss += losses.item()# Compute IoU for evaluationwith torch.no_grad():model.eval()predictions = model(images)for i, prediction in enumerate(predictions):pred_boxes = prediction["boxes"].cpu().numpy()true_boxes = targets[i]["boxes"].cpu().numpy()for pred_box in pred_boxes:for true_box in true_boxes:iou = compute_iou(pred_box, true_box)total_ious += iounum_boxes += 1model.train()avg_loss = epoch_loss / len(train_loader)avg_iou = total_ious / num_boxes if num_boxes != 0 else 0loss_values.append(avg_loss)iou_values.append(avg_iou)print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss} Avg IoU: {avg_iou}")# Plotting loss and IoU valuesplt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(loss_values, label="Training Loss")plt.title("Training Loss across Epochs")plt.xlabel("Epochs")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(iou_values, label="IoU")plt.title("IoU across Epochs")plt.xlabel("Epochs")plt.ylabel("IoU")plt.show()# Save model after trainingsave_model(model)# Validation function
def validate_model(model, val_loader, device):model.eval()model.to(device)with torch.no_grad():for images, targets in val_loader:images = [image.to(device) for image in images]targets = [{k: v.to(device) for k, v in t.items()} for t in targets]model(images)# Paths to your data
image_dir = "tuberculosis-phonecamera"
annotation_dir = "tuberculosis-phonecamera"# Split data
train_images, val_images = split_data(image_dir)# Create datasets and dataloaders
train_dataset = TuberculosisDataset(image_dir, annotation_dir, train_images, transform=data_transform)
val_dataset = TuberculosisDataset(image_dir, annotation_dir, val_images, transform=data_transform)# Updated DataLoader with new collate function
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)# Model and optimizer
model = get_ssdlite_model_for_finetuning(2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# Train and validate
train_model(model, train_loader, optimizer, device="cuda", num_epochs=10)
validate_model(model, val_loader, device="cuda")

需要从头训练的,就不跑了,摆烂了。

五、写在后面

目标检测模型门槛更高了,运行起来对硬件要求也很高,时间也很久,都是小时起步的。因此只是简单介绍,算是入个门了。

相关文章:

第99步 深度学习图像目标检测:SSDlite建模

基于WIN10的64位系统演示 一、写在前面 本期&#xff0c;我们继续学习深度学习图像目标检测系列&#xff0c;SSD&#xff08;Single Shot MultiBox Detector&#xff09;模型的后续版本&#xff0c;SSDlite模型。 二、SSDlite简介 SSDLite 是 SSD 模型的一个变种&#xff0c…...

用EasyAVFilter将网络文件或者本地文件推送RTMP出去的时候发现CPU占用好高,用的也是vcodec copy呀,什么原因?

最近同事在用EasyAVFilter集成在EasyDarwin中做视频拉流转推RTMP流的功能的时候&#xff0c;发现怎么做CPU占用都会很高&#xff0c;但是视频没有调用转码&#xff0c;vcodec用的就是copy&#xff0c;这是什么原因呢&#xff1f; 我们用在线的RTSP流就不会出现这种情况&#x…...

Vatee万腾科技的独特力量:Vatee数字时代创新的新视野

在数字化时代的浪潮中&#xff0c;Vatee万腾科技以其独特而强大的创新力量&#xff0c;为整个行业描绘了一幅崭新的视野。这不仅是一场科技创新的冒险&#xff0c;更是对未来数字时代发展方向的领先探索。 Vatee万腾将创新视为数字时代发展的引擎&#xff0c;成为推动行业向前的…...

【JavaSE】基础笔记 - 异常(Exception)

目录 1、异常的概念和体系结构 1.1、异常的概念 1.2、 异常的体系结构 1.3 异常的分类 2、异常的处理 2.1、防御式编程 2.2、异常的抛出 2.3、异常的捕获 2.3.1、异常声明throws 2.3.2、try-catch捕获并处理 3、自定义异常类 1、异常的概念和体系结构 1.1、异常的…...

QTableWidget——编辑单元格

文章目录 前言熟悉QTableWiget&#xff0c;通过实现单元格的合并、拆分、通过编辑界面实现表格内容及属性的配置、实现表格的粘贴复制功能熟悉QTableWiget的属性 一、[单元格的合并、拆分](https://blog.csdn.net/qq_15672897/article/details/134476530?spm1001.2014.3001.55…...

编译QT Mysql库并集成使用

安装MSVC编译器与Windows 10 SDK 打开Visual Studio Installer&#xff0c;如果已经安装过内容了可能是如下页面&#xff0c;点击修改&#xff08;头一回打开的话不需要这一步&#xff09;&#xff1a; 然后在工作负荷中勾选使用C的桌面开发&#xff0c;它会帮我们勾选好一些…...

利用企业被执行人信息查询API保障商业交易安全

前言 在当今竞争激烈的商业环境中&#xff0c;企业为了保障商业交易的安全性不断寻求新的手段。随着技术的发展&#xff0c;利用企业被执行人信息查询API已经成为了一种强有力的工具&#xff0c;能够帮助企业在商业交易中降低风险&#xff0c;提高合作的信任度。 企业被执行人…...

【深度学习】P1 深度学习基础框架 - 张量 Tensor

深度学习基础框架 张量 Tensor 张量数据操作导入创建张量获取张量信息改变张量张量运算 张量与内存 张量 Pytorch 是一个深度学习框架&#xff0c;用于开发和训练神经网络模型。 而其核心数据结构&#xff0c;则是张量 Tensor&#xff0c;类似于 Numpy 数组&#xff0c;但是可…...

vue2 识别页面参数中的html

在Vue 2中&#xff0c;你可以使用v-html指令来识别页面参数中的HTML内容。v-html指令允许你将HTML代码作为Vue模板的一部分进行渲染。 以下是一个示例&#xff0c;演示了如何在Vue 2中使用v-html指令来识别页面参数中的HTML内容&#xff1a; <template><div v-html&…...

matlab 一些画图法总结(持续更新)

*****************************************画Dmd_L极坐标表示法**************************************** if(~exist(Dmd_L_array)) Dmd_L_array []; end Dmd_L_array [Dmd_L_array; Dmd_L]; thetaangle(Dmd_L_array); rabs(Dmd_L_array); polarplot(theta,r,o); *****…...

MDK AC5和AC6是什么?在KEIL5中添加和选择ARMCC版本

前言 看视频有UP主提到“AC5”“AC6”这样的词&#xff0c;一开始有些不理解&#xff0c;原来他说的是ARMCC版本。 keil自带的是ARMCC5&#xff0c;由于ARMCC5已经停止维护了&#xff0c;很多开发者会选择ARMCC6。 在维护公司“成年往事”项目可能就会遇到新KEIL旧版本编译器…...

杰发科技AC7801——EEP内存分布情况

简介 按照文档进行配置 核心代码如下 /*!* file sweeprom_demo.c** brief This file provides sweeprom demo test function.**//* Includes */ #include <stdlib.h> #include "ac780x_sweeprom.h" #include "ac780x_debugout.h"/* Define …...

【mybatis注解实现条件查询】

文章目录 步骤1: 引入MyBatis依赖步骤2: 创建数据模型步骤3: 创建Mapper接口步骤4: 配置MyBatis步骤5: 执行条件查询 步骤1: 引入MyBatis依赖 <dependency><groupId>org.mybatis</groupId><artifactId>mybatis</artifactId><version>3.x.…...

【广州华锐互动】VR线上课件制作软件满足数字化教学需求

随着科技的不断发展&#xff0c;虚拟现实&#xff08;VR&#xff09;技术在教学领域的应用逐渐成为趋势。其中&#xff0c;广州华锐互动开发的VR线上课件制作软件更是备受关注。这种工具为教师提供了便捷的制作VR课件的手段&#xff0c;使得VR教学成为可能&#xff0c;极大地丰…...

MySQL 中 DELETE 语句中可以使用别名么?

某天&#xff0c;正按照业务的要求删除不需要的数据&#xff0c;在执行 DELETE 语句时&#xff0c;竟然出现了报错&#xff01; 作者&#xff1a;林靖华&#xff0c;开源数据库技术爱好者&#xff0c;擅长MySQL和Redis的运维 爱可生开源社区出品&#xff0c;原创内容未经授权不…...

flutter创建不同样式的按钮,背景色,边框,圆角,圆形,大小都可以设置

在ui设计中&#xff0c;可能按钮会有不同的样式需要你来写出来&#xff0c;所以按钮的不同样式&#xff0c;应该是最基础的功能&#xff0c;在这里我们赶紧学起来吧&#xff0c;web端可能展示有问题&#xff0c;需要优化&#xff0c;但是基本样式还是出来了 我是将所有的按钮放…...

【C++】标准模板库STL作业(其二)

&#x1f383;个人专栏&#xff1a; &#x1f42c; 算法设计与分析&#xff1a;算法设计与分析_IT闫的博客-CSDN博客 &#x1f433;Java基础&#xff1a;Java基础_IT闫的博客-CSDN博客 &#x1f40b;c语言&#xff1a;c语言_IT闫的博客-CSDN博客 &#x1f41f;MySQL&#xff1a…...

基于SpringBoot+Redis实现点赞/排行榜功能,可同理实现收藏/关注功能,可拓展实现共同好友/共同关注/关注推送功能

前言 简单记录一下在SpringBoot项目中&#xff0c;使用Redis实现点赞/排行榜功能&#xff0c;可同理实现收藏/关注功能&#xff0c;可拓展实现共同好友/共同关注/关注推送功。主要用到了Redis中的Set集合和ZSet集合。 一、指定使用某个索引的数据库 在Redis中&#xff0c;可…...

AI“胡说八道”?怎么解?

原创 | 文 BFT机器人 01 引言 近年来&#xff0c;人工智能产业迅猛发展&#xff0c;大型语言模型GPT-4发展势头强劲&#xff0c;OpenAI推出ChatGPT、微软推出Bing、马斯克推出“最好的聊天机器人Grok”……科技巨头纷纷入局AI领域&#xff0c;引入人工智能作为办公工具的行业…...

[SIGGRAPH-23] 3D Gaussian Splatting for Real-Time Radiance Field Rendering

pdf | proj | code 本文提出一种新的3D数据表达形式3D Gaussians。每个Gaussian由以下参数组成&#xff1a;中心点位置、协方差矩阵、可见性、颜色。通过世界坐标系到相机坐标系&#xff0c;再到图像坐标系的仿射关系&#xff0c;可将3D Gaussian映射到相机坐标系&#xff0c;通…...

mongodb源码分析session执行handleRequest命令find过程

mongo/transport/service_state_machine.cpp已经分析startSession创建ASIOSession过程&#xff0c;并且验证connection是否超过限制ASIOSession和connection是循环接受客户端命令&#xff0c;把数据流转换成Message&#xff0c;状态转变流程是&#xff1a;State::Created 》 St…...

【决胜公务员考试】求职OMG——见面课测验1

2025最新版&#xff01;&#xff01;&#xff01;6.8截至答题&#xff0c;大家注意呀&#xff01; 博主码字不易点个关注吧,祝期末顺利~~ 1.单选题(2分) 下列说法错误的是:&#xff08; B &#xff09; A.选调生属于公务员系统 B.公务员属于事业编 C.选调生有基层锻炼的要求 D…...

【C语言练习】080. 使用C语言实现简单的数据库操作

080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...

华为云Flexus+DeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建

华为云FlexusDeepSeek征文&#xff5c;DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建 前言 如今大模型其性能出色&#xff0c;华为云 ModelArts Studio_MaaS大模型即服务平台华为云内置了大模型&#xff0c;能助力我们轻松驾驭 DeepSeek-V3/R1&#xff0c;本文中将分享如何…...

企业如何增强终端安全?

在数字化转型加速的今天&#xff0c;企业的业务运行越来越依赖于终端设备。从员工的笔记本电脑、智能手机&#xff0c;到工厂里的物联网设备、智能传感器&#xff0c;这些终端构成了企业与外部世界连接的 “神经末梢”。然而&#xff0c;随着远程办公的常态化和设备接入的爆炸式…...

10-Oracle 23 ai Vector Search 概述和参数

一、Oracle AI Vector Search 概述 企业和个人都在尝试各种AI&#xff0c;使用客户端或是内部自己搭建集成大模型的终端&#xff0c;加速与大型语言模型&#xff08;LLM&#xff09;的结合&#xff0c;同时使用检索增强生成&#xff08;Retrieval Augmented Generation &#…...

宇树科技,改名了!

提到国内具身智能和机器人领域的代表企业&#xff0c;那宇树科技&#xff08;Unitree&#xff09;必须名列其榜。 最近&#xff0c;宇树科技的一项新变动消息在业界引发了不少关注和讨论&#xff0c;即&#xff1a; 宇树向其合作伙伴发布了一封公司名称变更函称&#xff0c;因…...

NPOI操作EXCEL文件 ——CAD C# 二次开发

缺点:dll.版本容易加载错误。CAD加载插件时&#xff0c;没有加载所有类库。插件运行过程中用到某个类库&#xff0c;会从CAD的安装目录找&#xff0c;找不到就报错了。 【方案2】让CAD在加载过程中把类库加载到内存 【方案3】是发现缺少了哪个库&#xff0c;就用插件程序加载进…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)

一、OpenBCI_GUI 项目概述 &#xff08;一&#xff09;项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台&#xff0c;其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言&#xff0c;首次接触 OpenBCI 设备时&#xff0c;往…...

Vue 模板语句的数据来源

&#x1f9e9; Vue 模板语句的数据来源&#xff1a;全方位解析 Vue 模板&#xff08;<template> 部分&#xff09;中的表达式、指令绑定&#xff08;如 v-bind, v-on&#xff09;和插值&#xff08;{{ }}&#xff09;都在一个特定的作用域内求值。这个作用域由当前 组件…...