当前位置: 首页 > news >正文

PyTorch 目标检测教程

PyTorch 目标检测教程

本教程将介绍如何在 PyTorch 中使用几种常见的目标检测模型,包括 Faster R-CNNSSD 以及 YOLO (You Only Look Once)。我们将涵盖预训练模型的使用、推理、微调,以及自定义数据集上的训练。

1. 目标检测概述

目标检测任务不仅要识别图像中的物体类别,还要精确定位物体的边界框。在此任务中,每个模型输出一个物体类别标签和一个边界框。

常见目标检测模型:
  • Faster R-CNN:基于区域提议方法,具有较高的检测精度。
  • SSD (Single Shot MultiBox Detector):单阶段检测器,速度较快,适合实时应用。
  • YOLO (You Only Look Once):单阶段检测器,具有极快的检测速度,适合大规模实时检测。

2. 官方文档链接

  • PyTorch 官方文档
  • Torchvision 模型
  • YOLO 官方实现 (YOLOv5 及 PyTorch 版 YOLO)

3. Faster R-CNN 模型

3.1 加载 Faster R-CNN 模型并进行推理
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches# 加载预训练的 Faster R-CNN 模型
model = fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()  # 切换为评估模式# 加载和预处理图像
image_path = "test_image.jpg"
image = Image.open(image_path)
image_tensor = F.to_tensor(image).unsqueeze(0)  # 转换为张量并添加批次维度# 将模型和输入图像移动到 GPU(如果可用)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
image_tensor = image_tensor.to(device)# 进行推理
with torch.no_grad():predictions = model(image_tensor)# 获取预测的边界框和类别标签
boxes = predictions[0]['boxes'].cpu().numpy()  # 预测的边界框
labels = predictions[0]['labels'].cpu().numpy()  # 预测的类别标签
scores = predictions[0]['scores'].cpu().numpy()  # 预测的分数# 可视化预测结果
fig, ax = plt.subplots(1)
ax.imshow(image)# 设置阈值,只显示高置信度的检测结果
threshold = 0.5
for box, label, score in zip(boxes, labels, scores):if score > threshold:rect = patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], linewidth=2, edgecolor='r', facecolor='none')ax.add_patch(rect)ax.text(box[0], box[1], f'{label}: {score:.2f}', color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))plt.show()
3.2 微调 Faster R-CNN 模型

你可以使用自定义的数据集微调 Faster R-CNN。下面是如何定义自定义数据集并在上面进行训练。

import torch
from torch.utils.data import Dataset
from PIL import Image# 自定义数据集类
class CustomDataset(Dataset):def __init__(self, image_paths, annotations):self.image_paths = image_pathsself.annotations = annotationsdef __len__(self):return len(self.image_paths)def __getitem__(self, idx):image = Image.open(self.image_paths[idx]).convert("RGB")boxes, labels = self.annotations[idx]boxes = torch.as_tensor(boxes, dtype=torch.float32)labels = torch.as_tensor(labels, dtype=torch.int64)target = {}target["boxes"] = boxestarget["labels"] = labelsimage = F.to_tensor(image)return image, target
from torch.utils.data import DataLoader
from torchvision.models.detection import fasterrcnn_resnet50_fpn# 加载自定义数据集
dataset = CustomDataset(image_paths=["img1.jpg", "img2.jpg"], annotations=[([[10, 20, 200, 300]], [1]), ([[30, 40, 180, 220]], [2])])
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))# 加载预训练的 Faster R-CNN 模型
model = fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 3
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torch.nn.Linear(in_features, num_classes)# 定义优化器并进行训练
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)# 训练模型
model.train()
for epoch in range(5):for images, targets in dataloader:images = list(image.to(device) for image in images)targets = [{k: v.to(device) for k, v in t.items()} for t in targets]loss_dict = model(images, targets)losses = sum(loss for loss in loss_dict.values())optimizer.zero_grad()losses.backward()optimizer.step()print(f'Epoch [{epoch+1}/5], Loss: {losses.item():.4f}')

4. SSD 模型

SSD(Single Shot MultiBox Detector)是一种速度较快的单阶段检测器。torchvision 也提供了预训练的 SSD 模型。

4.1 使用预训练 SSD 模型进行推理
import torch
from torchvision.models.detection import ssd300_vgg16
from torchvision.transforms import functional as F
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches# 加载预训练的 SSD 模型
model = ssd300_vgg16(pretrained=True)
model.eval()# 加载和预处理图像
image_path = "test_image.jpg"
image = Image.open(image_path).convert("RGB")
image_tensor = F.to_tensor(image).unsqueeze(0)# 进行推理
with torch.no_grad():predictions = model(image_tensor)# 获取预测的边界框和类别标签
boxes = predictions[0]['boxes'].cpu().numpy()
labels = predictions[0]['labels'].cpu().numpy()
scores = predictions[0]['scores'].cpu().numpy()# 可视化预测结果
fig, ax = plt.subplots(1)
ax.imshow(image)threshold = 0.5
for box, label, score in zip(boxes, labels, scores):if score > threshold:rect = patches.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], linewidth=2, edgecolor='r', facecolor='none')ax.add_patch(rect)ax.text(box[0], box[1], f'{label}: {score:.2f}', color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))plt.show()

5. YOLO 模型

YOLO(You Only Look Once)是单阶段目标检测器,具有非常快的检测速度。YOLOv5 是目前广泛使用的 YOLO 变体,官方提供了 PyTorch 版 YOLOv5。

5.1 安装 YOLOv5

首先,克隆 YOLOv5 的 GitHub 仓库并安装依赖项。

git clone https://github.com/ultralytics/yolov5
cd yolov5
pip install -r requirements.txt
5.2 使用 YOLOv5 进行推理

YOLOv5 提供了一个简单的 API,用于进行推理和训练。

import torch# 加载预训练的 YOLOv5 模型
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)# 加载图像并进行推理
image_path = 'test_image.jpg'
results = model(image_path)# 打印预测结果并显示
results.print()  # 打印检测到的结果
results.show()   # 显示带有检测框的图像
5.3 微调 YOLOv5 模型

你可以使用 YOLOv5 提供的工具在自定义数据集上训练模型。首先,准备符合 YOLO 格式的数据集,然后使用以下命令进行训练。

python train.py --img 640 --batch 16 --epochs 100 --data custom_dataset.yaml --weights yolov5s.pt

custom_dataset.yaml 文件

需要指定训练和验证集路径,以及类别信息。


6. 总结

  1. Faster R-CNN:基于区域提议方法,具有高精度,但速度相对较慢。适合需要高精度的场景。
  2. SSD:单阶段检测器,速度较快,适合实时检测任务。
  3. YOLOv5:速度极快,适合大规模实时检测应用。

PyTorch 和 torchvision 提供了丰富的目标检测模型,可以用于快速的推理或在自定义数据集上进行微调。对于实时需求较高的应用,YOLOv5 是不错的选择,而对于精度要求更高的场景,Faster R-CNN 是理想的选择。

相关文章:

PyTorch 目标检测教程

PyTorch 目标检测教程 本教程将介绍如何在 PyTorch 中使用几种常见的目标检测模型,包括 Faster R-CNN、SSD 以及 YOLO (You Only Look Once)。我们将涵盖预训练模型的使用、推理、微调,以及自定义数据集上的训练。 1. 目标检测概述 目标检测任务不仅要…...

校园美食导航:Spring Boot技术的美食发现之旅

第二章 系统分析 2.1 可行性分析 可行性分析的目的是确定一个系统是否有必要开发、确定系统是否能以最小的代价实现。其工作主要有三个方面,分别是技术、经济和社会三方面的可行性。我会从这三个方面对网上校园周边美食探索及分享平台进行详细的分析。 2.1.1技术可行…...

51单片机 - DS18B20实验1-读取温度

上来一张图,明确思路,程序整体裤架如下,通过单总线,单独封装一个.c文件用于单总线的操作,其实,我们可以把点c文件看成一个类操作,其属性就是我们面向对象的函数,也叫方法&#xff0c…...

go语言基础入门(一)

变量声明:批量声明变量:变量赋值: 声明变量同时为变量赋值可以在变量声明时为其赋值go中赋值时的编译器会自动根据等号右侧的数据类型自动推导变量的类型使用 : 进行赋值匿名变量 常量常量计数器iota1. 使用场景2. 基本用法3. 简化语法4. 自定义增量5. 复杂使用go的类似枚举 使…...

linux 基础(一)mkdir、ls、vi、ifconfig

1、linux简介 linux是一个操作系统(os: operating system) 中国有没有自己的操作系统(华为鸿蒙HarmonyOS,阿里龙蜥(Anolis) OS 8、百度DuerOS都有) 计算机组的组成:硬件软件 硬件:运算器&am…...

DAMODEL丹摩智算:LLama3.1部署与使用

文章目录 前言 一、LLaMA 3.1 的特点 二、LLaMA3.1的优势 三、LLaMA3.1部署流程 (一)创建实例 (二)通过JupyterLab登录实例 (3)部署LLaMA3.1 (4)使用教程 总结 前言 LLama3…...

Spring Boot 配置全流程 总结

1. 简介 Springboot可以简化SSM的配置&#xff0c;提高开发效率。 2. 代码 在pom.xml中添加&#xff1a; <parent><!-- 包含SSM常用依赖项 --><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</art…...

爬虫技术初步自学

目的 本篇文章实际上自学爬虫技术的学习一份学习笔记&#xff0c;希望可以对后学的小白起到帮助&#xff0c;也希望得到大佬的指点&#xff0c;若有错漏希望大佬指出。 初步认知 爬虫实际上是一个计算机程序。开发爬虫程序的常用语言是Python。&#xff08;Python我已经在五…...

【力扣 | SQL题 | 每日三题】力扣175, 176, 181

1. 力扣175&#xff1a;组合两个表 1.1 题目&#xff1a; 表: Person ---------------------- | 列名 | 类型 | ---------------------- | PersonId | int | | FirstName | varchar | | LastName | varchar | ---------------------- personId 是该…...

SpringBoot使用hutool操作FTP

项目场景&#xff1a; SpringBoot使用hutool操作FTP&#xff0c;可以实现从FTP服务器下载文件到本地&#xff0c;以及将本地文件上传到FTP服务器的功能。 实现步骤&#xff1a; 1、引入依赖 <dependency><groupId>commons-net</groupId><artifactId>…...

如何防止SQL注入攻击

SQL注入攻击是一种常见的网络安全威胁&#xff0c;攻击者通过在用户输入中插入恶意的SQL代码&#xff0c;从而可以执行未经授权的数据库操作。为了防止SQL注入攻击&#xff0c;我们可以采取一系列有效的措施来保护数据库和应用程序的安全性。以下是一些关键的防范策略&#xff…...

Java List类

欢迎来到Cefler的博客&#x1f601; &#x1f54c;博客主页&#xff1a;折纸花满衣 &#x1f3e0;个人专栏&#xff1a;Java 目录 &#x1f449;&#x1f3fb;List1. 接口与实现2. 特性3. 常用方法4. 示例代码5. 遍历6. 线程安全 &#x1f449;&#x1f3fb;List Java的 List …...

使用 Internet 共享 (ICS) 方式分配ip

设备A使用dhcp的情况下&#xff0c;通过设备B分配ip并共享网络的方法。 启用网络共享&#xff08;ICS&#xff09;并配置 NAT Windows 自带的 Internet Connection Sharing (ICS) 功能可以简化 NAT 设置&#xff0c;允许共享一个网络连接给其他设备。 打开网络设置&#xff1…...

SMTP/IMAP服务发在线邮件时要用到

SMTP/IMAP服务 require PHPMailerAutoload.php; // 或 require class.phpmailer.php;// 创建实例 $mail new PHPMailer();// 设定邮件服务器 $mail->isSMTP(); $mail->Host smtp.example.com; // 邮件服务器地址 $mail->SMTPAuth true; $mail->Username your…...

Threejs绘制圆锥体

上一章节实现了胶囊体的绘制&#xff0c;这节来绘制圆锥体&#xff0c;圆锥体就是三角形旋转获得的&#xff0c;如上文一样&#xff0c;先要创建出基础的组件&#xff0c;包括场景&#xff0c;相机&#xff0c;灯光&#xff0c;渲染器。代码如下&#xff1a; initScene() {this…...

速通LLaMA3:《The Llama 3 Herd of Models》全文解读

文章目录 概览论文开篇IntroductionGeneral OverviewPre-TrainingPre-Training DataModel ArchitectureInfrastructure, Scaling, and EfficiencyTraining Recipe Post-TrainingResultsVision ExperimentsSpeech Experiments⭐Related WorkConclusionLlama 3 模型中的数学原理1…...

Python网络爬虫获取Wallhaven壁纸图片(源码)

** 话不多说&#xff0c;直接附源码&#xff0c;可运行&#xff01; ** import requests from lxml import etree from fake_useragent import UserAgent import timeclass wallhaven(object):def __init__(self):# yellow# self.url "https://wallhaven.cc/search?co…...

智能化引领等保测评新时代:AI与大数据的深度融合

随着信息技术的飞速发展&#xff0c;等级保护测评&#xff08;简称“等保测评”&#xff09;作为保障信息系统安全的重要手段&#xff0c;正迎来前所未有的变革。在这一背景下&#xff0c;人工智能&#xff08;AI&#xff09;与大数据技术的深度融合&#xff0c;正引领等保测评…...

深入解析:HTTP 和 HTTPS 的区别

网络安全问题正变得日益重要&#xff0c;而 HTTP 与 HTTPS 对用户数据的保护十分关键。本文将深入探讨这两种协议的特点、工作原理&#xff0c;以及保证数据安全的 HTTPS 为何变得至关重要。 认识 HTTP 与 HTTPS HTTP 的工作原理 HTTP&#xff0c;全称超文本传输协议&#xf…...

《动手学深度学习》笔记1.11——实战Kaggle比赛:预测房价+详细代码讲解

目录 0. 前言 原书正文 1. 下载和缓存数据集 1.1 download() 下载数据集 1.2 download_extract() 解压缩 2. Kaggle 简介 3. 访问和读取数据集 4. 数据预处理 5. 训练&#xff08;核心难点&#xff09; 5.1 get_net() 定义模型-线性回归 5.2 log_rmse() 对数均方根…...

RWKV7-1.5B-g1a参数详解:为何默认top_p=0.3更适合中文生成?语言分布实证

RWKV7-1.5B-g1a参数详解&#xff1a;为何默认top_p0.3更适合中文生成&#xff1f;语言分布实证 1. 模型概述 rwkv7-1.5B-g1a是基于RWKV-7架构的多语言文本生成模型&#xff0c;特别适合中文场景下的基础问答、文案续写和简短总结任务。作为1.5B参数量的轻量级模型&#xff0c…...

OpenClaw多终端访问:远程控制GLM-4.7-Flash助手方案

OpenClaw多终端访问&#xff1a;远程控制GLM-4.7-Flash助手方案 1. 为什么需要远程访问OpenClaw&#xff1f; 去年冬天的一个深夜&#xff0c;我正在外地出差&#xff0c;突然接到同事紧急需求——需要从公司内网服务器提取一份关键数据报告。当时我的OpenClaw助手部署在家里…...

OneAPI 百度文心一言ERNIE-Bot接入:千帆平台Key对接指南

OneAPI 百度文心一言ERNIE-Bot接入&#xff1a;千帆平台Key对接指南 安全提示&#xff1a;使用 root 用户初次登录系统后&#xff0c;务必修改默认密码 123456&#xff01; 1. 引言&#xff1a;为什么需要统一的API管理平台 在当今AI技术快速发展的时代&#xff0c;企业和开发…...

Keil5主题配色进阶:不只是好看,更要好用!详解如何区分函数、变量、宏定义的颜色

Keil5主题配色进阶&#xff1a;不只是好看&#xff0c;更要好用&#xff01;详解如何区分函数、变量、宏定义的颜色 作为一名嵌入式开发者&#xff0c;每天面对Keil5的默认编辑器界面&#xff0c;你是否也感到视觉疲劳&#xff1f;那些单调的配色不仅影响编码心情&#xff0c;更…...

dfs:飞机降落

题目&#xff1a;P9241 [蓝桥杯 2023 省 B] 飞机降落 - 洛谷 做题目之前一定要先看数据范围。这道题的数据范围&#xff0c;T,N均<10&#xff0c;可以用暴力搜索。 这道题是排序&#xff0c;假设有3辆飞机。顺序可以是123&#xff0c;132&#xff0c;213&#xff0c;231&am…...

OpenClaw跨平台实战:Windows到Mac的Qwen3-32B配置迁移

OpenClaw跨平台实战&#xff1a;Windows到Mac的Qwen3-32B配置迁移 1. 为什么需要跨平台配置迁移&#xff1f; 去年冬天&#xff0c;我在Windows工作站上搭建了一套基于Qwen3-32B的OpenClaw自动化系统&#xff0c;用于处理日常的文档整理和数据分析任务。当公司配发新款MacBoo…...

MGeo门址地址解析效果展示:高德×达摩院多模态模型真实解析案例集

MGeo门址地址解析效果展示&#xff1a;高德达摩院多模态模型真实解析案例集 1. 引言&#xff1a;当AI开始“读懂”地址 想象一下&#xff0c;你收到一条外卖订单&#xff0c;地址写着“朝阳区望京SOHO T3 B座15楼1501室&#xff0c;到了打电话”。对于骑手来说&#xff0c;这…...

手把手教你用RTABMAP+T265在Windows10上实现室内三维扫描(含标定技巧)

手把手教你用RTABMAPT265在Windows10上实现高精度室内三维扫描 第一次接触室内三维扫描时&#xff0c;我被这项技术深深吸引——它能让物理空间瞬间数字化&#xff0c;就像给现实世界按下"CtrlC"。但真正动手配置RTABMAP和T265相机时&#xff0c;才发现这条路并不平坦…...

深度解析ConcurrentHashMap设计演进:从分段锁到无锁化的并发之路

在Java并发编程领域&#xff0c;ConcurrentHashMap绝对是“并发容器扛鼎之作”——它既解决了HashMap并发环境下的数据不一致&#xff08;死循环、数据丢失&#xff09;问题&#xff0c;又突破了Hashtable全表锁的性能瓶颈&#xff0c;成为高并发场景下K-V存储的首选。自JDK1.5…...

QT实战:qcustomplot中setData与addData性能对比与最佳实践(附代码示例)

QT实战&#xff1a;qcustomplot中setData与addData性能对比与最佳实践&#xff08;附代码示例&#xff09; 在数据可视化领域&#xff0c;QT的qcustomplot库因其轻量级和高度可定制性而广受欢迎。然而&#xff0c;当处理大规模数据集或实时数据流时&#xff0c;开发者常常会遇到…...