【深度学习】 UNet详解
UNet 是一种经典的卷积神经网络(Convolutional Neural Network, CNN)架构,专为生物医学图像分割任务设计。该模型于 2015 年由 Olaf Ronneberger 等人在论文《U-Net: Convolutional Networks for Biomedical Image Segmentation》中首次提出,因其卓越的性能和简单的结构,迅速成为图像分割领域的重要模型
1. 环境搭建
1.1 安装 Python 和相关工具
-
安装 Python 3.8 及以上版本
如果尚未安装 Python,可以从 Python官网 下载并安装。确保安装时勾选“Add Python to PATH”选项。 -
安装虚拟环境管理工具
虚拟环境是管理 Python 项目依赖的好方法。可以使用venv或conda来创建虚拟环境。我们这里使用venv,步骤如下:# 创建虚拟环境 python -m venv unet_env# 激活虚拟环境 source unet_env/bin/activate # Linux/Mac unet_env\Scripts\activate # Windows
1.2 安装依赖库
-
安装 PyTorch
根据你的硬件选择正确的 PyTorch 版本。如果你的电脑支持 CUDA(GPU 加速),可以使用带 CUDA 的版本,否则使用 CPU 版本:# 安装支持 CUDA 11.8 版本的 PyTorch pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118# 如果不支持 CUDA,则使用以下命令: pip install torch torchvision torchaudio -
安装其他依赖
你还需要一些其他的辅助库:pip install numpy opencv-python matplotlib tqdm scikit-learn pillow
2. 下载或实现 UNet 模型
2.1 UNet 模型结构详解
UNet 是经典的图像分割网络,其主要特点是由编码器(下采样部分)和解码器(上采样部分)组成。通过跳跃连接,编码器的每一层都将特征图传递到解码器对应层,以保持细节信息。
以下是 UNet 的详细实现,包含编码器、解码器、跳跃连接以及卷积操作:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass UNet(nn.Module):def __init__(self, in_channels, out_channels):super(UNet, self).__init__()# 编码器部分self.encoder1 = self.conv_block(in_channels, 64)self.encoder2 = self.conv_block(64, 128)self.encoder3 = self.conv_block(128, 256)self.encoder4 = self.conv_block(256, 512)# 底部瓶颈部分self.bottleneck = self.conv_block(512, 1024)# 解码器部分self.upconv4 = self.upconv(1024, 512)self.decoder4 = self.conv_block(1024, 512)self.upconv3 = self.upconv(512, 256)self.decoder3 = self.conv_block(512, 256)self.upconv2 = self.upconv(256, 128)self.decoder2 = self.conv_block(256, 128)self.upconv1 = self.upconv(128, 64)self.decoder1 = self.conv_block(128, 64)# 输出层self.output = nn.Conv2d(64, out_channels, kernel_size=1)def conv_block(self, in_channels, out_channels):"""标准的卷积模块,包含两个卷积层和ReLU激活函数"""return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.ReLU(inplace=True))def upconv(self, in_channels, out_channels):"""上采样操作,采用转置卷积(反卷积)"""return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)def forward(self, x):"""前向传播"""# 编码器:下采样enc1 = self.encoder1(x)enc2 = self.encoder2(F.max_pool2d(enc1, 2))enc3 = self.encoder3(F.max_pool2d(enc2, 2))enc4 = self.encoder4(F.max_pool2d(enc3, 2))# 底部瓶颈bottleneck = self.bottleneck(F.max_pool2d(enc4, 2))# 解码器:上采样 + 跳跃连接dec4 = self.upconv4(bottleneck)dec4 = torch.cat((dec4, enc4), dim=1) # 跳跃连接dec4 = self.decoder4(dec4)dec3 = self.upconv3(dec4)dec3 = torch.cat((dec3, enc3), dim=1)dec3 = self.decoder3(dec3)dec2 = self.upconv2(dec3)dec2 = torch.cat((dec2, enc2), dim=1)dec2 = self.decoder2(dec2)dec1 = self.upconv1(dec2)dec1 = torch.cat((dec1, enc1), dim=1)dec1 = self.decoder1(dec1)return self.output(dec1)
3. 数据处理
3.1 数据集准备
为了训练 UNet,你需要准备一个图像分割数据集。数据集通常由原始图像(RGB 图像)和每个图像对应的标注图像(Mask)组成。
假设我们有一个目录结构:
dataset/
├── train/
│ ├── images/
│ └── masks/
├── val/
│ ├── images/
│ └── masks/
每个 images 文件夹包含训练图像,而 masks 文件夹包含对应的标注图像。
3.2 数据加载器
在 PyTorch 中,我们可以通过 Dataset 类来自定义数据加载器。以下是一个简单的 SegmentationDataset 类:
import os
import cv2
import torch
from torch.utils.data import Datasetclass SegmentationDataset(Dataset):def __init__(self, image_dir, mask_dir, transform=None):self.image_dir = image_dirself.mask_dir = mask_dirself.transform = transformself.images = os.listdir(image_dir)def __len__(self):return len(self.images)def __getitem__(self, idx):img_path = os.path.join(self.image_dir, self.images[idx])mask_path = os.path.join(self.mask_dir, self.images[idx])# 读取图像和标签image = cv2.imread(img_path)mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)# 应用任何数据增强(如有)if self.transform:augmented = self.transform(image=image, mask=mask)image = augmented['image']mask = augmented['mask']# 转换为 Tensor,通道数放到最前面return torch.tensor(image, dtype=torch.float32).permute(2, 0, 1), torch.tensor(mask, dtype=torch.long)
4. 训练模型
4.1 定义训练过程
我们将训练一个 UNet 模型,使用交叉熵损失函数和 Adam 优化器。训练时,输入的图像和标签将通过 DataLoader 加载。
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss# 超参数
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
EPOCHS = 20
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"# 数据加载
train_dataset = SegmentationDataset("dataset/train/images", "dataset/train/masks")
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)# 初始化模型与优化器
model = UNet(in_channels=3, out_channels=2).to(DEVICE) # 假设是二分类
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
criterion = CrossEntropyLoss()# 训练过程
for epoch in range(EPOCHS):model.train()for images, masks in train_loader:images, masks = images.to(DEVICE), masks.to(DEVICE)optimizer.zero_grad()# 前向传播outputs = model(images)# 计算损失loss = criterion(outputs, masks)# 反向传播loss.backward()# 更新参数optimizer.step()print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {loss.item():.4f}")
5. 模型推理
5.1 保存与加载模型
# 保存模型
torch.save(model.state_dict(), "unet_model.pth")# 加载模型
model.load_state_dict(torch.load("unet_model.pth"))
model.eval()
5.2 单张图片推理
def predict(model, image_path):image = cv2.imread(image_path)image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)with torch.no_grad():output = model(image)return output.argmax(dim=1).squeeze(0).numpy()
6. 模型优化与改进
为了提高 UNet 的性能,我们可以从以下几个方面进行优化:
6.1 数据增强
在训练过程中引入数据增强技术可以提高模型的泛化能力。使用 Albumentations 库可以实现多种增强方式,例如旋转、翻转、裁剪等:
import albumentations as A
from albumentations.pytorch import ToTensorV2transform = A.Compose([A.Resize(256, 256), # 调整尺寸A.HorizontalFlip(p=0.5), # 随机水平翻转A.VerticalFlip(p=0.5), # 随机垂直翻转A.RandomRotate90(p=0.5), # 随机旋转90度A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # 标准化ToTensorV2() # 转为 Tensor
])# 在数据集初始化时传入 transform
train_dataset = SegmentationDataset("dataset/train/images", "dataset/train/masks", transform=transform)
6.2 学习率调度器
动态调整学习率可以提高收敛速度。可以使用 PyTorch 提供的学习率调度器,例如 StepLR 或 ReduceLROnPlateau:
from torch.optim.lr_scheduler import ReduceLROnPlateauscheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)for epoch in range(EPOCHS):model.train()epoch_loss = 0for images, masks in train_loader:images, masks = images.to(DEVICE), masks.to(DEVICE)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, masks)loss.backward()optimizer.step()epoch_loss += loss.item()# 更新学习率scheduler.step(epoch_loss / len(train_loader))print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {epoch_loss / len(train_loader):.4f}")
6.3 混合精度训练
使用混合精度训练可以加速训练并减少显存使用,特别是在 GPU 上。PyTorch 提供了 torch.cuda.amp 模块来实现:
from torch.cuda.amp import GradScaler, autocastscaler = GradScaler()for epoch in range(EPOCHS):model.train()for images, masks in train_loader:images, masks = images.to(DEVICE), masks.to(DEVICE)optimizer.zero_grad()# 自动混合精度with autocast():outputs = model(images)loss = criterion(outputs, masks)# 反向传播与优化scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()
6.4 Dice Loss 或 IoU Loss
交叉熵损失适合分类任务,但在分割任务中,Dice Loss 或 IoU Loss 能更好地处理类别不平衡问题:
class DiceLoss(nn.Module):def __init__(self):super(DiceLoss, self).__init__()def forward(self, preds, targets, smooth=1):preds = torch.sigmoid(preds) # 将输出限制在 [0, 1] 之间preds = preds.view(-1)targets = targets.view(-1)intersection = (preds * targets).sum()dice = (2. * intersection + smooth) / (preds.sum() + targets.sum() + smooth)return 1 - dice
然后在训练中替换损失函数:
criterion = DiceLoss()
6.5 模型改进:加入注意力机制
可以在 UNet 的跳跃连接中加入注意力机制(如 Squeeze-and-Excitation 或 Attention Gates),以提升模型对目标区域的关注能力。
以下是一个基于 SE 模块的示例:
class SEBlock(nn.Module):def __init__(self, in_channels, reduction=16):super(SEBlock, self).__init__()self.global_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // reduction),nn.ReLU(inplace=True),nn.Linear(in_channels // reduction, in_channels),nn.Sigmoid())def forward(self, x):batch, channels, _, _ = x.size()y = self.global_pool(x).view(batch, channels)y = self.fc(y).view(batch, channels, 1, 1)return x * y
将 SEBlock 插入 UNet 的编码器和解码器中。
7. 模型评估
为了评估模型性能,通常需要计算一些分割任务的指标,例如:
- 像素精度 (Pixel Accuracy)
- IoU (Intersection over Union)
- Dice 系数
以下是计算 IoU 和 Dice 系数的代码:
def compute_metrics(preds, labels):preds = preds > 0.5 # 阈值化intersection = (preds & labels).sum()union = (preds | labels).sum()iou = intersection / uniondice = (2 * intersection) / (preds.sum() + labels.sum())return iou, dice
在验证集上运行:
model.eval()
with torch.no_grad():for images, masks in val_loader:images, masks = images.to(DEVICE), masks.to(DEVICE)outputs = model(images)preds = torch.sigmoid(outputs) > 0.5 # 二值化预测iou, dice = compute_metrics(preds.cpu(), masks.cpu())print(f"IoU: {iou:.4f}, Dice: {dice:.4f}")
8. 部署与推理加速
8.1 导出 ONNX
将模型导出为 ONNX 格式以便在推理加速框架中使用:
dummy_input = torch.randn(1, 3, 256, 256).to(DEVICE)
torch.onnx.export(model, dummy_input, "unet_model.onnx", opset_version=11)
8.2 使用 TensorRT 加速推理
可以使用 NVIDIA TensorRT 对 ONNX 模型进行优化并加速推理。具体操作请参考 TensorRT 文档。
相关文章:
【深度学习】 UNet详解
UNet 是一种经典的卷积神经网络(Convolutional Neural Network, CNN)架构,专为生物医学图像分割任务设计。该模型于 2015 年由 Olaf Ronneberger 等人在论文《U-Net: Convolutional Networks for Biomedical Image Segmentation》中首次提出&…...
DeepSeek本地部署(windows)
一、下载并安装Ollama 1.下载Ollama Ollama官网:Ollama 点击"Download",会跳转至下载页面。 点击"Download for Windows"。会跳转Github进行下载,如下载速度过慢,可在浏览器安装GitHub加速插件。 2.安装Ollama 双击下载的安装文件,点击"Inst…...
简要介绍C语言/C++的三目运算符
三元运算符是C语言和C中的一种简洁的条件运算符,它的形式为: 条件表达式 ? 表达式1 : 表达式2; 三元运算符的含义 条件表达式:这是一个布尔表达式,通常是一个比较操作(如 >、<、 等)。 表达式1&am…...
SpringCloud系列教程:微服务的未来(十九)请求限流、线程隔离、Fallback、服务熔断
前言 前言 在现代微服务架构中,系统的高可用性和稳定性至关重要。为了解决系统在高并发请求或服务不可用时出现的性能瓶颈或故障,常常需要使用一些技术手段来保证服务的平稳运行。请求限流、线程隔离、Fallback 和服务熔断是微服务中常用的四种策略&…...
STM32 对射式红外传感器配置
这次用的是STM32F103的开发板(这里面的exti.c文件没有how to use this driver 配置说明) 对射式红外传感器 由一个红外发光二极管和NPN光电三极管组成,M3固定安装孔,有输出状态指示灯,输出高电平灯灭,输出…...
(动态规划路径基础 最小路径和)leetcode 64
视频教程 1.初始化dp数组,初始化边界 2、从[1行到n-1行][1列到m-1列]依次赋值 #include<vector> #include<algorithm> #include <iostream>using namespace std; int main() {vector<vector<int>> grid { {1,3,1},{1,5,1},{4,2,1}…...
嵌入式C语言:什么是共用体?
在嵌入式C语言编程中,共用体(Union)是一种特殊的数据结构,它允许在相同的内存位置存储不同类型的数据。意味着共用体中的所有成员共享同一块内存区域,因此,在任何给定时间,共用体只能有效地存储…...
QT简单实现验证码(字符)
0) 运行结果 1) 生成随机字符串 Qt主要通过QRandomGenerator类来生成随机数。在此之前的版本中,qrand()函数也常被使用,但从Qt 5.10起,推荐使用更现代化的QRandomGenerator类。 在头文件添加void generateRandomNumb…...
【4Day创客实践入门教程】Day2 探秘微控制器——单片机与MicroPython初步
Day2 探秘微控制器——单片机与MicroPython初步 目录 Day2 探秘微控制器——单片机与MicroPython初步MicroPython语言基础开始基础语法注释与输出变量模块与函数 单片机基础后记 Day0 创想启程——课程与项目预览Day1 工具箱构建——开发环境的构建Day2 探秘微控制器——单片机…...
C++中vector追加vector
在C中,如果你想将一个vector追加到另一个vector的后面,可以使用std::vector的成员函数insert或者std::copy,或者简单地使用std::vector的push_back方法逐个元素添加。这里我将展示几种常用的方法: 方法1:使用insert方…...
【Java高并发】基于任务类型创建不同的线程池
文章目录 一. 按照任务类型对线程池进行分类1. IO密集型任务的线程数2. CPU密集型任务的线程数3. 混合型任务的线程数 二. 线程数越多越好吗三. Redis 单线程的高效性 使用线程池的好处主要有以下三点: 降低资源消耗:线程是稀缺资源,如果无限…...
[论文阅读] (37)CCS21 DeepAID:基于深度学习的异常检测(解释)
祝大家新春快乐,蛇年吉祥! 《娜璋带你读论文》系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢。由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正࿰…...
国内flutter环境部署(记录篇)
设置系统环境变量 export PUB_HOSTED_URLhttps://pub.flutter-io.cn export FLUTTER_STORAGE_BASE_URLhttps://storage.flutter-io.cn使用以下命令下载flutter镜像 git clone -b stable https://mirror.ghproxy.com/https://github.com/<github仓库地址>#例如flutter仓…...
Java面试题2025-并发编程进阶(线程池和并发容器类)
线程池 一、什么是线程池 为什么要使用线程池 在开发中,为了提升效率的操作,我们需要将一些业务采用多线程的方式去执行。 比如有一个比较大的任务,可以将任务分成几块,分别交给几个线程去执行,最终做一个汇总就可…...
【算法应用】基于鲸鱼优化算法求解OTSU多阈值图像分割问题
目录 1.鲸鱼优化算法WOA 原理2.OTSU多阈值图像分割模型3.结果展示4.参考文献5.代码获取 1.鲸鱼优化算法WOA 原理 SCI二区|鲸鱼优化算法(WOA)原理及实现 2.OTSU多阈值图像分割模型 Otsu 算法(最大类间方差法)设灰度图像有 L L …...
设计模式的艺术-策略模式
行为型模式的名称、定义、学习难度和使用频率如下表所示: 1.如何理解策略模式 在策略模式中,可以定义一些独立的类来封装不同的算法,每个类封装一种具体的算法。在这里,每个封装算法的类都可以称之为一种策略(Strategy…...
Autogen_core源码:_agent_instantiation.py
目录 _agent_instantiation.py代码代码解释代码示例示例 1:使用 populate_context 正确设置上下文示例 2:尝试在上下文之外调用 current_runtime 和 current_agent_id示例 3:模拟 AgentRuntime 使用 AgentInstantiationContext _agent_instan…...
7. 马科维茨资产组合模型+金融研报AI长文本智能体(Qwen-Long)增强方案(理论+Python实战)
目录 0. 承前1. 深度金融研报准备2. 核心AI函数代码讲解2.1 函数概述2.2 输入参数2.3 主要流程2.4 异常处理2.5 清理工作2.7 get_ai_weights函数汇总 3. 汇总代码4. 反思4.1 不足之处4.2 提升思路 5. 启后 0. 承前 本篇博文是对前两篇文章,链接: 5. 马科维茨资产组…...
安装Maven(安装包+步骤)
1. 安装: 通过网盘分享的文件:apache-maven-3.9.9 链接: https://pan.baidu.com/s/16AE_brICuw6sS0tC6tmE1Q?pwda74r 提取码: a74r --来自百度网盘超级会员v3的分享 2.新建应该系统变量: 3.path中添加bin文件夹路径 4.建议在这里建一个仓库文件夹 博主的: 5.I…...
【云安全】云原生-K8S-搭建/安装/部署
一、准备3台虚拟机 务必保证3台是同样的操作系统! 1、我这里原有1台centos7,为了节省资源和效率,打算通过“创建链接克隆”2台出来 2、克隆之前,先看一下是否存在k8s相关组件,或者docker相关组件 3、卸载原有的docker …...
基于PLC的变频调速系统设计
摘要 现代科技发展迅速,特别是通讯技术的发展,工业现场提供了便捷的数据交互和控制的手段,将工业现场的仪表、驱动器、控制器以及上位机之间进行通讯连接,进行相互信息交互,数据准确高效的传送,并且对现场的…...
单细胞-第四节 多样本数据分析,下游画图
文件在单细胞\5_GC_py\1_single_cell\2_plots.Rmd 1.细胞数量条形图 rm(list ls()) library(Seurat) load("seu.obj.Rdata")dat as.data.frame(table(Idents(seu.obj))) dat$label paste(dat$Var1,dat$Freq,sep ":") head(dat) library(ggplot2) lib…...
【算法】动态规划专题① ——线性DP python
目录 引入简单实现稍加变形举一反三实战演练总结 引入 楼梯有个台阶,每次可以一步上1阶或2阶。一共有多少种不同的上楼方法? 怎么去思考? 假设就只有1个台阶,走法只有:1 只有2台阶: 11,2 只有3台…...
知识管理平台在数字经济时代推动企业智慧决策与知识赋能的路径分析
内容概要 在数字经济时代,知识管理平台被视为企业智慧决策与知识赋能的关键工具。其核心作用在于通过高效地整合、存储和分发企业内部的知识资源,促进信息的透明化与便捷化,使得决策者能够在瞬息万变的市场环境中迅速获取所需信息。这不仅提…...
LabVIEW微位移平台位移控制系统
本文介绍了基于LabVIEW的微位移平台位移控制系统的研究。通过设计一个闭环控制系统,针对微位移平台的通信驱动问题进行了解决,并提出了一种LabVIEW的应用方案,用于监控和控制微位移平台的位移,从而提高系统的精度和稳定性。 项目背…...
java求职学习day23
MySQL 单表 & 约束 & 事务 1. DQL操作单表 1.1 创建数据库,复制表 1) 创建一个新的数据库 db2 CREATE DATABASE db2 CHARACTER SET utf8; 2) 将 db1 数据库中的 emp 表 复制到当前 db2 数据库 1.2 排序 通过 ORDER BY 子句 , 可以将查询出的结果进行排序 ( 排序只…...
【张雪峰高考志愿填报】合集
【张雪峰高考志愿填报】合集 链接:https://pan.quark.cn/s/89a2d88fa807 高考结束,分数即将揭晓,志愿填报的关键时刻近在眼前!同学们,这可是人生的重要转折点,选对志愿,就像为未来铺就一条…...
22.Word:小张-经费联审核结算单❗【16】
目录 NO1.2 NO3.4 NO5.6.7 NO8邮件合并 MS搜狗输入法 NO1.2 用ms打开文件,而不是wps❗不然后面都没分布局→页面设置→页面大小→页面方向→上下左右:页边距→页码范围:多页:拼页光标处于→布局→分隔符:分节符…...
MYSQL 商城系统设计 商品数据表的设计 商品 商品类别 商品选项卡 多表查询
介绍 在开发商品模块时,通常使用分表的方式进行查询以及关联。在通过表连接的方式进行查询。每个商品都有不同的分类,每个不同分类下面都有商品规格可以选择,每个商品分类对应商品规格都有自己的价格和库存。在实际的开发中应该给这些表进行…...
python3+TensorFlow 2.x(二) 回归模型
目录 回归算法 1、线性回归 (Linear Regression) 一元线性回归举例 2、非线性回归 3、回归分类 回归算法 回归算法用于预测连续的数值输出。回归分析的目标是建立一个模型,以便根据输入特征预测目标变量,在使用 TensorFlow 2.x 实现线性回归模型时&…...
