当前位置: 首页 > news >正文

PyTorch 图像分割模型教程

PyTorch 图像分割模型教程

在图像分割任务中,目标是将图像的每个像素归类为某一类,以分割出特定的物体。PyTorch 提供了非常灵活的工具,可以用于构建和训练图像分割模型。我们将使用 PyTorch 的经典网络架构,如 UNetDeepLabV3,并演示如何构建、训练和测试这些模型。

1. 图像分割概述

图像分割的目标是将图像的每个像素进行分类。常见的应用场景有医学图像分割(如肿瘤检测)、自动驾驶(道路、车辆、行人分割)等。

  • 语义分割:每个像素被分配给某个类别,例如道路、天空或车辆。
  • 实例分割:不仅对物体分类,还要区分物体实例,如区分不同的行人。

PyTorch 中有许多预训练的模型可以直接用于图像分割任务,常用的模型包括 UNetFCN (Fully Convolutional Network)DeepLabV3 等。

2. 官方文档链接
  • PyTorch 官方文档
  • Torchvision 模型

3. 准备工作

在开始训练之前,我们需要安装 torch, torchvisionPIL 等依赖项,并准备图像数据集。您可以使用自己的图像数据集,或者使用 COCO、VOC 等常用数据集。

pip install torch torchvision pillow

4. 使用预训练的 DeepLabV3 模型

DeepLabV3 是一个性能优异的语义分割模型,PyTorch 的 torchvision 提供了预训练的 DeepLabV3 模型。我们将使用 COCO 数据集中的预训练模型,并进行推理和测试。

import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt# 加载预训练的 DeepLabV3 模型
model = models.segmentation.deeplabv3_resnet50(pretrained=True)
model.eval()  # 切换到评估模式# 定义预处理步骤
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载图像
input_image = Image.open("test_image.jpg")
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # 创建 batch 维度# 将输入移到 GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_batch = input_batch.to(device)# 进行预测
with torch.no_grad():output = model(input_batch)['out'][0]  # DeepLabV3 的输出包含 'out' 字段# 将输出转换为类别索引(每个像素对应一个类别)
output_predictions = output.argmax(0).cpu().numpy()# 显示分割结果
plt.imshow(output_predictions)
plt.show()

说明

  • models.segmentation.deeplabv3_resnet50(pretrained=True):加载使用 ResNet-50 作为主干网络的 DeepLabV3 模型,预训练于 COCO 数据集。
  • preprocess:对输入图像进行预处理,包括调整大小、裁剪、归一化等。
  • output_predictions:模型的输出是每个像素的类别索引,经过 argmax 操作,获取每个像素的类别。

5. UNet 模型

UNet 是一个广泛用于医学图像分割的经典模型。我们将从头实现 UNet 模型,并在简单的合成数据上进行训练。

5.1 UNet 网络结构
import torch
import torch.nn as nn
import torch.nn.functional as Fclass UNet(nn.Module):def __init__(self):super(UNet, self).__init__()# 下采样(编码器部分)self.encoder1 = self.double_conv(1, 64)self.encoder2 = self.double_conv(64, 128)self.encoder3 = self.double_conv(128, 256)self.encoder4 = self.double_conv(256, 512)# 中间部分self.middle = self.double_conv(512, 1024)# 上采样(解码器部分)self.upconv4 = self.up_conv(1024, 512)self.decoder4 = self.double_conv(1024, 512)self.upconv3 = self.up_conv(512, 256)self.decoder3 = self.double_conv(512, 256)self.upconv2 = self.up_conv(256, 128)self.decoder2 = self.double_conv(256, 128)self.upconv1 = self.up_conv(128, 64)self.decoder1 = self.double_conv(128, 64)# 最后的分类层self.final = nn.Conv2d(64, 1, kernel_size=1)def double_conv(self, in_channels, out_channels):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True),)def up_conv(self, in_channels, out_channels):return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)def forward(self, x):# 编码器部分e1 = self.encoder1(x)e2 = self.encoder2(F.max_pool2d(e1, 2))e3 = self.encoder3(F.max_pool2d(e2, 2))e4 = self.encoder4(F.max_pool2d(e3, 2))# 中间部分middle = self.middle(F.max_pool2d(e4, 2))# 解码器部分d4 = self.upconv4(middle)d4 = torch.cat((e4, d4), dim=1)d4 = self.decoder4(d4)d3 = self.upconv3(d4)d3 = torch.cat((e3, d3), dim=1)d3 = self.decoder3(d3)d2 = self.upconv2(d3)d2 = torch.cat((e2, d2), dim=1)d2 = self.decoder2(d2)d1 = self.upconv1(d2)d1 = torch.cat((e1, d1), dim=1)d1 = self.decoder1(d1)return self.final(d1)# 创建模型实例
unet_model = UNet()
print(unet_model)

说明

  • UNet 是一种编码-解码结构,包含多个下采样(编码器)和上采样(解码器)层。每次下采样都会减少特征图的大小,并增加特征通道数,上采样则恢复原始图像的大小。
  • ConvTranspose2d 用于进行上采样操作。
5.2 训练 UNet 模型

为了训练 UNet 模型,我们需要构建一个数据加载器并定义损失函数和优化器。我们以一个简单的二分类分割任务为例。

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms# 创建合成数据集
class SyntheticSegmentationDataset(Dataset):def __init__(self, num_samples, image_size):self.num_samples = num_samplesself.image_size = image_sizeself.transform = transforms.Compose([transforms.ToTensor(),])def __len__(self):return self.num_samplesdef __getitem__(self, idx):image = torch.rand(1, self.image_size, self.image_size)mask = (image > 0.5).float()  # 简单的二分类掩码return image, mask# 创建数据集
dataset = SyntheticSegmentationDataset(num_samples=1000, image_size=128)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)# 定义损失函数和优化器
criterion = nn.BCEWithLogitsLoss()  # 二分类交叉熵损失
optimizer = torch.optim.Adam(unet_model.parameters(), lr=0.001)# 训练循环
unet_model.train()
for epoch in range(5):  # 简单训练 5 个 epochfor images, masks in dataloader:# 前向传播outputs = unet_model(images)loss = criterion(outputs, masks)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')

说明

  • BCEWithLogitsLoss 是二分类任务的标准损失函数,适合输出为单通道(1 表示目标类,0 表示背景)的分割任务。
  • 我们创建了一个合成数据集,其中图像为随机值,掩码为图像中值大于 0.5 的部分。

6. 总结

  • DeepLabV3 是一种非常强大的图像分割模型,适用于各种复杂场景,PyTorch 提供了预训练模型,适合快速部署。
  • UNet 是经典的医学图像分割模型,适用于更细致的分割任务。

通过使用 PyTorch,您可以轻松实现并训练图像分割模型,利用 GPU 加速并扩展到大规模数据集。

相关文章:

PyTorch 图像分割模型教程

PyTorch 图像分割模型教程 在图像分割任务中,目标是将图像的每个像素归类为某一类,以分割出特定的物体。PyTorch 提供了非常灵活的工具,可以用于构建和训练图像分割模型。我们将使用 PyTorch 的经典网络架构,如 UNet 和 DeepLabV…...

物联网——USART协议

接口 串口通信 硬件电路 电平标准 串口参数、时序 USART USART主要框图 TXE: 判断发送寄存器是否为空 RXNE: 判断接收寄存器是否非空 RTS为输出信号,用于表示MCU串口是否准备好接收数据,若输出信号为低电平,则说明MCU串口可以接收数据&#…...

前端框架对比与选择:如何在现代Web开发中做出最佳决策

随着互联网技术的迅速发展,前端开发在现代Web应用开发中扮演了至关重要的角色。对于开发者来说,选择合适的前端框架不仅能够提高开发效率,还能确保项目的可维护性和可扩展性。目前市面上有多种主流的前端框架和库,每一种都有其独特…...

【浅水模型MATLAB】尝试复刻SCI论文中的溃坝流算例

【浅水模型MATLAB】尝试复刻SCI论文中的溃坝流算例 前言问题描述控制方程及数值方法浅水方程及其数值计算方法边界条件的实现 代码框架与关键代码模拟结果 更新于2024年9月17日 前言 这篇博客算是学习浅水方程,并利用MATLAB复刻Liang (2004)1中溃坝流算例的一个记录…...

探索云计算:IT行业的未来趋势

探索云计算:IT行业的未来趋势 在当今快速发展的科技世界,云计算已成为IT行业的核心趋势之一。无论是大企业还是初创公司,越来越多的组织正在转向云计算,以实现更高效的运营和更快的创新。在这篇博文中,我们将探讨云计算…...

[PICO VR眼镜]眼动追踪串流Unity开发与使用方法,眼动追踪打包报错问题解决(Eye Tracking/手势跟踪)

前言 最近在做一个工作需要用到PICO4 Enterprise VR头盔里的眼动追踪功能,但是遇到了如下问题: 在Unity里面没法串流调试眼动追踪功能,根本获取不到Device,只能将整个场景build成APK,安装到头盔里,才能在…...

一周热门|比GPT-4强100倍,OpenAI有望年底发布GPT-Next;1个GPU,1分钟,16K图像

大模型周报将从【企业动态】【技术前瞻】【政策法规】【专家观点】四部分,带你快速跟进大模型行业热门动态。 01 企业动态 Ilya 新公司 SSI 官宣融资 10 亿美元 据路透社报道,由 OpenAI 联合创始人、前首席科学家 Ilya Sutskever 在 2 个多月前共同创…...

软考流水线计算

某计算机系统输入/输出采用双缓冲工作方式,其工作过程如下图所示,假设磁盘块与缓冲区大小相同,每个盘块读入缓冲区的时间T为10μs,由缓冲区送至用户区的时间M为6μs,系统对每个磁盘块数据的处理时间C为2μs。若用户需要…...

1份可以派上用场丢失数据恢复的应用程序列表

无论如何,丢失您的宝贵数据是可怕的。您的 Android 或 iOS 设备可能由于事故、硬件损坏、存储卡问题等而丢失了数据。这就是为什么我们编制了一份可以派上用场以恢复丢失数据的应用程序列表。 如果您四处走动,您大多会随身携带手机或其他移动设备。这些…...

MySQL Workbench 超详细安装教程(一步一图解,保姆级安装)

前言: MySQL Workbench 是一款强大的数据库设计和管理工具,它提供了图形化界面,使得数据库的设计、管理、查询等操作变得更加直观和便捷。本文将详细介绍如何在 Windows 系统上安装 MySQL Workbench。相信读者看这篇文章前一定安装了MySQL数…...

深度学习常见面试题及答案(16~20)

算法学习、4对1辅导、论文辅导或核心期刊以及其他学习资源可以通过公众号滴滴我 文章目录 16. 简述深度学习中的批量归一化(Batch Normalization)的目的和工作原理。一、批量归一化的目的1. 加速训练收敛:2. 提高模型泛化能力:3. …...

Packet Tracer - IPv4 ACL 的实施挑战(完美解析)

目标 在路由器上配置命名的标准ACL。 在路由器上配置命名的扩展ACL。 在路由器上配置扩展ACL来满足特定的 通信需求。 配置ACL来控制对网络设备终端线路的 访问。 在适当的路由器接口上,在适当的方向上 配置ACL。…...

Langchain-chatchat源码部署及测试实验

一年多前接触到Langchain-chatchat的0.2版本,对0.2版本进行了本地部署和大量更新,但0.2版本对最新的大模型支持不够好,部署框架支持也不好且不太稳定,特别是多模态大模型,因此本次主要介绍0.3版本的源码部署,希望对大家有所帮助。Langchain-chatchat从0.3版本开始,支持更…...

【Linux】线程(第十六篇)

目录 线程 1.线程基本概述: 2.线程类型: 3.线程间的共享资源与非共享资源 4.线程原语 1.线程创建函数 2.获取当前线程id的函数 3.回收线程资源 4.将线程设置为分离态 5.结束线程 6.退出线程 线程 1.线程基本概述: 是操作系统能够…...

2024华为杯研赛E题保姆级教程思路分析

E题题目:高速公路应急车道紧急启用模型 今年的E题设计到图像/视频处理,实际上,E题的难度相对来说较低,大家不用畏惧视频的处理,被这个吓到。实际上,这个不难,解决了视频的处理问题,…...

国内可以使用的ChatGPT服务【9月持续更新】

首先基础知识还是要介绍得~ 一、模型知识: GPT-4o:最新的版本模型,支持视觉等多模态,OpenAI 文档中已经更新了 GPT-4o 的介绍:128k 上下文,训练截止 2023 年 10 月(作为对比,GPT-4…...

Linux环境Docker安装Mongodb

Linux环境Docker安装Mongodb 环境要求拉取指定版本镜像创建映射目录(相当于数据存放于容器外,容器被删除不会影响数据)启动容器 进入mongo命令行为指定db创建新用户查看mongodb的容器id进入命令行查看所有db切换db为指定db创建新用户使用新账…...

PyTorch 池化层详解

在深度学习中,池化层(Pooling Layer)是卷积神经网络(CNN)中的关键组成部分。池化层的主要功能是对特征图进行降维和减少计算量,同时增强模型的鲁棒性。本文将详细介绍池化层的作用、种类、实现方法&#xf…...

Intel架构的基本知识

1.字节序 CPU的字节序分为LittleEndian和BigEndian。 所谓Endian,就是多字节数据在内存中的排列方式。 例如,假设有一个整数0x11223344: LittleEndian的排列方式是,从内存的低地址开始,依次存放 0x44 0x33 0x22 0x11; BigEndian的排列方式是,从内存的低地址开始,依…...

Element Plus 中Input输入框

通过鼠标或键盘输入字符 input为受控组件,他总会显示Vue绑定值,正常情况下,input的输入事件会正常被响应,他的处理程序应该更新组件的绑定值(或使用v-model)。否则,输入框的值将不会改变 不支…...

零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?

一、核心优势:专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发,是一款收费低廉但功能全面的Windows NAS工具,主打“无学习成本部署” 。与其他NAS软件相比,其优势在于: 无需硬件改造:将任意W…...

PPT|230页| 制造集团企业供应链端到端的数字化解决方案:从需求到结算的全链路业务闭环构建

制造业采购供应链管理是企业运营的核心环节,供应链协同管理在供应链上下游企业之间建立紧密的合作关系,通过信息共享、资源整合、业务协同等方式,实现供应链的全面管理和优化,提高供应链的效率和透明度,降低供应链的成…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

STM32F4基本定时器使用和原理详解

STM32F4基本定时器使用和原理详解 前言如何确定定时器挂载在哪条时钟线上配置及使用方法参数配置PrescalerCounter ModeCounter Periodauto-reload preloadTrigger Event Selection 中断配置生成的代码及使用方法初始化代码基本定时器触发DCA或者ADC的代码讲解中断代码定时启动…...

STM32标准库-DMA直接存储器存取

文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA(Direct Memory Access)直接存储器存取 DMA可以提供外设…...

linux arm系统烧录

1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 (忘了有没有这步了 估计有) 刷机程序 和 镜像 就不提供了。要刷的时…...

ardupilot 开发环境eclipse 中import 缺少C++

目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...

LRU 缓存机制详解与实现(Java版) + 力扣解决

📌 LRU 缓存机制详解与实现(Java版) 一、📖 问题背景 在日常开发中,我们经常会使用 缓存(Cache) 来提升性能。但由于内存有限,缓存不可能无限增长,于是需要策略决定&am…...

【LeetCode】算法详解#6 ---除自身以外数组的乘积

1.题目介绍 给定一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法,且在 O…...

如何配置一个sql server使得其它用户可以通过excel odbc获取数据

要让其他用户通过 Excel 使用 ODBC 连接到 SQL Server 获取数据,你需要完成以下配置步骤: ✅ 一、在 SQL Server 端配置(服务器设置) 1. 启用 TCP/IP 协议 打开 “SQL Server 配置管理器”。导航到:SQL Server 网络配…...