当前位置: 首页 > news >正文

UNet进行病理图像分割

数据集链接:https://pan.baidu.com/s/1IBe_P0AyHgZC39NqzOxZhA?pwd=nztc
提取码:nztc

  • UNet模型
import torch
import torch.nn as nnclass conv_block(nn.Module):def __init__(self, ch_in, ch_out):super(conv_block, self).__init__()self.conv = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True),nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self, x):x = self.conv(x)return xclass up_conv(nn.Module):def __init__(self, ch_in, ch_out):super(up_conv, self).__init__()self.up = nn.Sequential(nn.Upsample(scale_factor=2),nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self, x):x = self.up(x)return x
class UNet(nn.Module):def __init__(self, img_ch=3, output_ch=1):super(UNet, self).__init__()self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)self.Conv2 = conv_block(ch_in=64, ch_out=128)self.Conv3 = conv_block(ch_in=128, ch_out=256)self.Conv4 = conv_block(ch_in=256, ch_out=512)self.Conv5 = conv_block(ch_in=512, ch_out=1024)self.Up5 = up_conv(ch_in=1024, ch_out=512)self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)self.Up4 = up_conv(ch_in=512, ch_out=256)self.Up_conv4 = conv_block(ch_in=512, ch_out=256)self.Up3 = up_conv(ch_in=256, ch_out=128)self.Up_conv3 = conv_block(ch_in=256, ch_out=128)self.Up2 = up_conv(ch_in=128, ch_out=64)self.Up_conv2 = conv_block(ch_in=128, ch_out=64)self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)def forward(self, x):# encoding pathx1 = self.Conv1(x)x2 = self.Maxpool(x1)x2 = self.Conv2(x2)x3 = self.Maxpool(x2)x3 = self.Conv3(x3)x4 = self.Maxpool(x3)x4 = self.Conv4(x4)x5 = self.Maxpool(x4)x5 = self.Conv5(x5)# decoding + concat pathd5 = self.Up5(x5)d5 = torch.cat((x4, d5), dim=1)d5 = self.Up_conv5(d5)d4 = self.Up4(d5)d4 = torch.cat((x3, d4), dim=1)d4 = self.Up_conv4(d4)d3 = self.Up3(d4)d3 = torch.cat((x2, d3), dim=1)d3 = self.Up_conv3(d3)d2 = self.Up2(d3)d2 = torch.cat((x1, d2), dim=1)d2 = self.Up_conv2(d2)d1 = self.Conv_1x1(d2)output = torch.sigmoid(d1)  # 在最后加上Sigmoid激活函数return output
  • 数据加载
import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transformsclass SegmentationDataset(Dataset):def __init__(self, image_dir, mask_dir, output_size=(256, 256)):self.image_dir = image_dirself.mask_dir = mask_dirself.image_list = os.listdir(image_dir)self.output_size = output_size# 定义图像和掩码的变换self.image_transform = transforms.Compose([transforms.Resize(self.output_size),transforms.ToTensor()])self.mask_transform = transforms.Compose([transforms.Resize(self.output_size),transforms.ToTensor()])def __len__(self):return len(self.image_list)def __getitem__(self, idx):image_name = self.image_list[idx]image_path = os.path.join(self.image_dir, image_name)mask_path = os.path.join(self.mask_dir, image_name)image = Image.open(image_path).convert("RGB")  # 确保是RGBmask = Image.open(mask_path).convert("L")  # 确保是灰度图像image = self.image_transform(image)mask = self.mask_transform(mask)return image, mask
  • 训练和测试。训练函数中保存的最好模型后缀最大(因为loss小才保存当前这个epoch的模型,我训练的最好模型是第171轮产生的),测试代码包含计算模型性能指标的代码和保存结果图片的代码。
import os
import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from UNet import UNet
from DataLoader2 import SegmentationDataset# IoU计算
def compute_iou(pred_mask, true_mask):smooth = 1e-6  # 避免分母为0pred_mask = (pred_mask > 0.5).float()true_mask = (true_mask > 0.5).float()intersection = (pred_mask * true_mask).sum()union = pred_mask.sum() + true_mask.sum() - intersectionreturn (intersection + smooth) / (union + smooth)# Dice系数计算
def compute_dice(pred_mask, true_mask):smooth = 1e-6  # 避免分母为0pred_mask = (pred_mask > 0.5).float()true_mask = (true_mask > 0.5).float()intersection = (pred_mask * true_mask).sum()return (2. * intersection + smooth) / (pred_mask.sum() + true_mask.sum() + smooth)# 精度、召回率和F1分数计算
def compute_precision_recall_f1(pred_mask, true_mask):pred_mask = (pred_mask > 0.5).numpy().astype(int)true_mask = (true_mask > 0.5).numpy().astype(int)# 将mask平展为一维数组pred_mask_flat = pred_mask.flatten()true_mask_flat = true_mask.flatten()conf_matrix = confusion_matrix(true_mask_flat, pred_mask_flat)tn, fp, fn, tp = conf_matrix.ravel()precision = tp / (tp + fp)recall = tp / (tp + fn)f1_score = 2 * (precision * recall) / (precision + recall)return precision, recall, f1_score# 训练函数
def train():model = UNet()dataset = SegmentationDataset('./dataset_exp2/train/image', './dataset_exp2/train/label')dataloader = DataLoader(batch_size=16, shuffle=True, dataset=dataset)# 训练参数num_epochs = 200learning_rate = 1e-4# 损失函数和优化器criterion = nn.BCELoss()optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 设备device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')model = model.to(device)model.train()best_loss = float('inf')for epoch in range(num_epochs):epoch_loss = 0for images, labels in dataloader:images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()epoch_loss += loss.item()if epoch_loss < best_loss:best_loss = epoch_losstorch.save(model.state_dict(), f'./save_model_UNet/res_{epoch + 1}.pth')print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / len(dataloader)}')def test():model = UNet()# 确保模型在CPU上model.load_state_dict(torch.load('./save_model_UNet/res_171.pth'))save_dir = './test_results_UNet'model.eval()dataset = SegmentationDataset('./dataset_exp2/test/image', './dataset_exp2/test/label')dataloader = DataLoader(batch_size=1, shuffle=False, dataset=dataset)iou_list = []dice_list = []precision_list = []recall_list = []f1_list = []plt.ion()with torch.no_grad():for idx, (images, labels) in tqdm(enumerate(dataloader)):pre = model(images)img_pre = torch.squeeze(pre)img_true = torch.squeeze(labels)iou = compute_iou(img_pre, img_true)dice = compute_dice(img_pre, img_true)precision, recall, f1_score = compute_precision_recall_f1(img_pre, img_true)img_pre = img_pre.numpy()img_true = img_true.numpy()img_x = torch.squeeze(images).numpy().transpose(1, 2, 0)img_x = (img_x * 255).astype(np.uint8)  # 恢复到0-255的范围# 保存结果plt.figure(figsize=(12, 4))plt.subplot(1, 3, 1)plt.title('Input Image')plt.imshow(img_x)plt.axis('off')plt.subplot(1, 3, 2)plt.title('True Mask')plt.imshow(img_true, cmap='gray')plt.axis('off')plt.subplot(1, 3, 3)plt.title('UNet Predicted Mask')plt.imshow(img_pre, cmap='gray')plt.axis('off')plt.savefig(os.path.join(save_dir, f'result_{idx + 1}.png'))plt.close()  # 关闭当前figure,避免内存占用过多iou_list.append(iou.item())dice_list.append(dice.item())precision_list.append(precision)recall_list.append(recall)f1_list.append(f1_score)plt.ioff()  # 关闭交互模式print(f'Results saved in {save_dir}')print(f'Average IoU: {np.mean(iou_list)}')print(f'Average Dice Coefficient: {np.mean(dice_list)}')print(f'Average Precision: {np.mean(precision_list)}')print(f'Average Recall: {np.mean(recall_list)}')print(f'Average F1 Score: {np.mean(f1_list)}')if __name__ == '__main__':print('++++++++++++++++train++++++++++++++++')train()print('++++++++++++++++test++++++++++++++++')test()

测试效果:
在这里插入图片描述
在这里插入图片描述

相关文章:

UNet进行病理图像分割

数据集链接&#xff1a;https://pan.baidu.com/s/1IBe_P0AyHgZC39NqzOxZhA?pwdnztc 提取码&#xff1a;nztc UNet模型 import torch import torch.nn as nnclass conv_block(nn.Module):def __init__(self, ch_in, ch_out):super(conv_block, self).__init__()self.conv nn…...

初二数学基础差从哪开始补?附深度解析!

有时候&#xff0c;当你推不开一扇门的时候&#xff0c;不要着急&#xff0c;试着反方向拉一下&#xff0c;或者横向拉一下。下面是小偏整理的初二数学基础差从哪开始补2021年&#xff0c;感谢您的每一次阅读。   初二数学基础差从哪开始补2021年   第一个问题是很多同学都…...

【C语言】return 关键字

在C语言中&#xff0c;return是一个关键字&#xff0c;用于从函数中返回值或者结束函数的执行。它是函数的重要组成部分&#xff0c;负责将函数的计算结果返回给调用者&#xff0c;并可以提前终止函数的执行。 主要用途和原理&#xff1a; 返回值给调用者&#xff1a; 当函数执…...

华为机试HJ13句子逆序

华为机试HJ13句子逆序 题目&#xff1a; 将一个英文语句以单词为单位逆序排放。例如“I am a boy”&#xff0c;逆序排放后为“boy a am I”所有单词之间用一个空格隔开&#xff0c;语句中除了英文字母外&#xff0c;不再包含其他字符 想法&#xff1a; 将输入的字符串通过…...

代码随想录day40 动态规划(5)

52. 携带研究材料&#xff08;第七期模拟笔试&#xff09; (kamacoder.com) 完全背包&#xff0c;可重复放入物品&#xff0c;需要用一维滚动数组从前往后遍历。 由于第0个物品和后面物品的转移方程没有区别&#xff0c;可以不额外初始化dp数组&#xff0c;直接用元素全0的d…...

FFmpeg 命令行 音视频格式转换

&#x1f4da;&#xff1a;FFmpeg 提供了丰富的命令行选项和功能&#xff0c;可以用来处理音视频文件、流媒体等&#xff0c;掌握命令行的使用&#xff0c;可以有效提高工作效率。 目录 一、视频转换和格式转换 &#x1f535; 将视频文件转换为另一种格式 &#x1f535; 指定…...

Jmeter使用JSON Extractor提取多个变量

1.当正则不好使时&#xff0c;用json extractor 2.提取多个值时&#xff0c;默认值必填&#xff0c;否则读不到变量...

c++ 设计模式 的课本范例(下)

&#xff08;19&#xff09; 桥接模式 Bridge&#xff0c;不是采用类继承&#xff0c;而是采用类组合&#xff0c;一个类的数据成员是类对象&#xff0c;来扩展类的功能。源码如下&#xff1a; class OS // 操作系统负责绘图 { public:virtual ~OS() {}virtual void draw(cha…...

结合数据索引结构看SQL的真实执行过程

引言 关于数据库设计与优化的前几篇文章中&#xff0c;我们提到了数据库设计优化应该遵守的指导原则、数据库底层的索引组织结构、数据库的核心功能组件以及SQL的解析、编译等。这些其实都是在为SQL的优化、执行的理解打基础。 今天这篇文章&#xff0c;我们以MySQL中InnoDB存…...

spark shuffle——shuffle管理

ShuffleManager shuffle系统的入口。ShuffleManager在driver和executor中的sparkEnv中创建。在driver中注册shuffle&#xff0c;在executor中读取和写入数据。 registerShuffle&#xff1a;注册shuffle&#xff0c;返回shuffleHandle unregisterShuffle&#xff1a;移除shuff…...

HTMLCSS(入门)

HTML <html> <head><title>第一个页面</title></head><body>键盘敲烂&#xff0c;工资过万</body> </html> <!DOCTYPE>文档类型声明&#xff0c;告诉浏览器使用哪种HTML版本显示网页 <!DOCTYPE html>当前页面采取…...

富格林:曝光可信策略制止亏损

富格林指出&#xff0c;相信大家都对黄金投资的价值空间有目共睹&#xff0c;现如今黄金市场波动频繁&#xff0c;因此不少投资者也开始加入该市场试图赢得额外的财富。但作为新手投资者贸贸然地进场操作&#xff0c;亏损的几率是很大的&#xff0c;因此要学会掌握正规平台曝光…...

Android --- Service

出自于此&#xff0c;写得很清楚。关于Android Service真正的完全详解&#xff0c;你需要知道的一切_android service-CSDN博客 出自【zejian的博客】 什么是Service? Service(服务)是一个一种可以在后台执行长时间运行操作而没有用户界面的应用组件。 服务可由其他应用组件…...

Vue3从入门到精通(三)

vue3插槽Slots 在 Vue3 中&#xff0c;插槽&#xff08;Slots&#xff09;的使用方式与 Vue2 中基本相同&#xff0c;但有一些细微的差异。以下是在 Vue3 中使用插槽的示例&#xff1a; // ChildComponent.vue <template><div><h2>Child Component</h2&…...

【FreeRTOS】同步与互斥通信-有缺陷的互斥案例

目录 同步与互斥通信同步与互斥的概念同步与互斥并不简单缺陷分析汇编指令优化过程 - 关闭中断时间轴分析 思考时刻 参考《FreeRTOS入门与工程实践(基于DshanMCU-103).pdf》 同步与互斥通信 同步与互斥的概念 一句话理解同步与互斥&#xff1a;我等你用完厕所&#xff0c;我再…...

Docker 安装 Python

Docker 安装 Python 在当今的软件开发领域,Docker 已成为一项关键技术,它允许开发人员将应用程序及其依赖环境打包到一个可移植的容器中。Python,作为一种广泛使用的高级编程语言,经常被部署在 Docker 容器中。本文将详细介绍如何在 Docker 中安装 Python,以及如何配置环…...

外泌体相关基因肝癌临床模型预测——2-3分纯生信文章复现——4.预后相关外泌体基因确定单因素cox回归(2)

内容如下&#xff1a; 1.外泌体和肝癌TCGA数据下载 2.数据格式整理 3.差异表达基因筛选 4.预后相关外泌体基因确定 5.拷贝数变异及突变图谱 6.外泌体基因功能注释 7.LASSO回归筛选外泌体预后模型 8.预后模型验证 9.预后模型鲁棒性分析 10.独立预后因素分析及与临床的…...

C++: Map数组的遍历

在C中&#xff0c;map是一个关联容器&#xff0c;它存储的元素是键值对&#xff08;key-value pairs&#xff09;&#xff0c;其中每个键都是唯一的&#xff0c;并且自动根据键来排序。遍历map的方式有几种&#xff0c;但最常用的两种是使用迭代器&#xff08;iterator&#xf…...

【Windows】Bootstrap Studio(网页设计)软件介绍及安装步骤

软件介绍 Bootstrap Studio 是一款专为前端开发者设计的强大工具&#xff0c;主要用于快速创建现代化的响应式网页和网站。以下是它的主要特点和功能&#xff1a; 直观的界面设计 Bootstrap Studio 提供了直观的用户界面&#xff0c;使用户能够轻松拖放元素来构建网页。界面…...

二维舵机颜色追踪,使用树莓派+opencv+usb摄像头+两个舵机实现颜色追踪,采用pid调控

效果演示 二维云台颜色追踪 使用树莓派opencvusb摄像头两个舵机实现颜色追踪&#xff0c;采用pid调控 import cv2 import time import numpy as np from threading import Thread from servo import Servo from pid import PID# 初始化伺服电机 pan Servo(pin19) tilt Serv…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用&#xff1a;实现组件通用属性的渐变过渡效果&#xff0c;提升用户体验。支持属性&#xff1a;width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项&#xff1a; 布局类属性&#xff08;如宽高&#xff09;变化时&#…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

高频面试之3Zookeeper

高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个&#xff1f;3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制&#xff08;过半机制&#xff0…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

Aspose.PDF 限制绕过方案:Java 字节码技术实战分享(仅供学习)

Aspose.PDF 限制绕过方案&#xff1a;Java 字节码技术实战分享&#xff08;仅供学习&#xff09; 一、Aspose.PDF 简介二、说明&#xff08;⚠️仅供学习与研究使用&#xff09;三、技术流程总览四、准备工作1. 下载 Jar 包2. Maven 项目依赖配置 五、字节码修改实现代码&#…...

LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》

这段 Python 代码是一个完整的 知识库数据库操作模块&#xff0c;用于对本地知识库系统中的知识库进行增删改查&#xff08;CRUD&#xff09;操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 &#x1f4d8; 一、整体功能概述 该模块…...

Git 3天2K星标:Datawhale 的 Happy-LLM 项目介绍(附教程)

引言 在人工智能飞速发展的今天&#xff0c;大语言模型&#xff08;Large Language Models, LLMs&#xff09;已成为技术领域的焦点。从智能写作到代码生成&#xff0c;LLM 的应用场景不断扩展&#xff0c;深刻改变了我们的工作和生活方式。然而&#xff0c;理解这些模型的内部…...

Caliper 负载(Workload)详细解析

Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...

苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会

在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...

破解路内监管盲区:免布线低位视频桩重塑停车管理新标准

城市路内停车管理常因行道树遮挡、高位设备盲区等问题&#xff0c;导致车牌识别率低、逃费率高&#xff0c;传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法&#xff0c;正成为破局关键。该设备安装于车位侧方0.5-0.7米高度&#xff0c;直接规避树枝遮…...