复现nnUNet2并跑通自定义数据
复现nnUNet2并跑通自定义数据
- 1. 配置环境
- 2. 处理数据集
- 2.1 创建文件夹
- 2.2 数据集格式转换
- 2.3 数据集预处理
- 3. 训练
- 4. 改进模型
- 4.1 概要
- 4.2 加注意力模块
1. 配置环境
stage1:创建python环境,这里建议python3.10
conda create --n nnunet python=3.10
conda activate nnunet
stage2:按照pytorch官方链接 (conda/pip) 上的说明安装PyTorch。
nnunetv2 2.5.1 需要 torch>=2.1.2,比如:
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8
但是,使用上述指令安装pytorch,很多时候还是装的cpuonly版,这里可以进行本地安装,首先搜索whl,最后的cuda版本可以自行更改。
https://download.pytorch.org/whl/cu117
最后,直接cd到下载文件所属的盘符下:
d:
cd D:\program
pip install torch-2.1.2+cu118-cp310-cp310-win_amd64.whl
stage3:git nnUnet的源代码,并配置nnUnet
git clone https://github.com/MIC-DKFZ/nnUNet.git
cd nnUNet
pip install -e .
2. 处理数据集
可以通过运行将 nnU-Net v1 中的数据集转换为 V2 nnUNetv2_convert_old_nnUNet_dataset INPUT_FOLDER OUTPUT_DATASET_NAME。请记住,v2 调用数据集 DatasetXXX_Name(而不是 Task),其中 XXX 是一个 3 位数字。请提供旧任务的路径,而不仅仅是任务名称。nnU-Net V2 不知道 v1 任务在哪里!
2.1 创建文件夹
首先,创建nnUNet_raw、nnUNet_preprocessed、nnUNet_results三个文件夹,将下载的原始数据集放到nnUNet_raw下,这里MSD挑战赛的Task02_Heart数据集为例。
nnUNet_raw用于存放原始数据集;
nnUNet_preprocessed用于存放格式转换并预处理的数据集及配置信息;
nnUNet_results用于存放训练结果相关的信息文件。

2.2 数据集格式转换
接着,我们转换数据集格式:
nnUNetv2_convert_MSD_dataset -i D:\nnUNet\nnUNet_raw\Task02_Heart
可能会出现下列错误:
TypeError: expected str, bytes or os.PathLike object, not NoneType
解决方法:
原因是没有设置相关数据集的路径,打开D:\nnUNet\nnunetv2\paths.py文件

方法1:直接在这个py文件指定数据集路径:

方法2:
设置环境变量:
1.windows:
#这个变量只在当前命令提示符会话中有效,不会被保存到系统环境变量中
set nnUNet_raw=D:/nnUNet/nnUNet_raw
set nnUNet_preprocessed=D:/nnUNet/nnUNet_preprocessed
set nnUNet_results=D:/nnUNet/nnUNet_results
2.linux:
export nnUNet_raw="D:/nnUNet/nnUNet_raw"
export nnUNet_preprocessed="D:/nnUNet/nnUNet_preprocessed"
export nnUNet_results="D:/nnUNet/nnUNet_results"
发现,转换成功,生成了1个Dataset002_Heart文件夹:

转换前后主要区别在于:
1、文件夹由Task02_Heart变成Dataset002_Heart。
ps:文件夹命名为:Dataset+三位整数+任务名,Dataset002_Heart中数据集ID为2,任务名为Heart。此文件夹下存放需要的训练数据集imageTr、测试集imageTs、标签labelsTr。其中labelsTr是与imageTr中一一对应的标签,文件中都是nii.gz文件。imageTs是可选项,可以没有。
2、原始json中是:

转换后,字典的key为"channel_names":

3、原始json中有"training"、"test"这两个key存放数据集的相对路径:

转换后,字典的key为"file_ending",直接记录数据集的后缀名即可:

2.3 数据集预处理
nnUNetv2_plan_and_preprocess -d DATASET\_ID --verify_dataset_integrity
#这里我是用的是Dataset002_Heart,将ID改为2
nnUNetv2_plan_and_preprocess -d 2 --verify_dataset_integrity
运行成功后,会在nnUNet_preprocessed文件下生成如下文件:

此步骤主要就是对数据进行:裁剪crop,重采样resample以及标准化normalization。将提取数据集fingerprint(一组特定于数据集的属性,例如图像大小、体素间距、强度信息等)。此信息用于设计三种 U-Net 配置。每个管道都在其自己的数据集预处理版本上运行。其中的nnUNetPlans.json是网络结构相关的配置文件。
3. 训练
训练使用如下脚本指令:
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]
DATASET_NAME_OR_ID 指定应在哪个数据集上进行训练;UNET_CONFIGURATION 是一个字符串,用于标识所请求的 U-Net 配置(默认值:2d、3d_fullres、3d_lowres、3d_cascade_lowres);
FOLD 指定训练 5 折交叉验证中的哪一折。
根据数据集名称或者数据集的ID、UNET配置、折数FOLD,进行训练:
nnUNetv2_train 2 2d 0
成功开始训练:

4. 改进模型
4.1 概要
许多人可能不满足于只是复现跑通,而是有一些改进模型,提升指标,最后发paper的需求。
但是,会发现nnUNet2的网络模型结构很难找,但当找到nnUNet\nnunetv2\utilities\get_network_from_plans.py时,发现还是找不到网络模型的model。
但是,我们在数据集处理的时候生成的nnUNet_preprocessed文件下的nnUNetPlans.json文件中发现了网络结构的model文件。

于是,发现这个模型结构在我们配置的环境中,我就在Anaconda中我的环境下D:\Anaconda3\envs\nnunet\Lib\site-packages\dynamic_network_architectures\architectures找到了。

4.2 加注意力模块
我们以在unet的编码器中添加SE注意力模块为例,首先打开unet.py文件,找到调用编码器实例化类对象的地方。

发现,PlainConvEncoder类中卷积模块通过StackedConvBlocks类实例化出来的,

简单点,我们就直接在StackedConvBlocks类中嵌入SE模块。找到了,nn.Sequential()模块拼接模块的部分,
我们在下面堆上SE模块。

这里我也附上SE的代码,通过参考其它来的继承来写SE:
class SE(nn.Module):def __init__(self,conv_op: Type[_ConvNd],input_channels: int,output_channels: int,kernel_size: Union[int, List[int], Tuple[int, ...]],stride: Union[int, List[int], Tuple[int, ...]],conv_bias: bool = False,norm_op: Union[None, Type[nn.Module]] = None,norm_op_kwargs: dict = None,dropout_op: Union[None, Type[_DropoutNd]] = None,dropout_op_kwargs: dict = None,nonlin: Union[None, Type[torch.nn.Module]] = None,nonlin_kwargs: dict = None,nonlin_first: bool = False):super(SE, self).__init__()self.input_channels = input_channelsself.output_channels = output_channelsstride = maybe_convert_scalar_to_list(conv_op, stride)self.stride = stridekernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)if norm_op_kwargs is None:norm_op_kwargs = {}if nonlin_kwargs is None:nonlin_kwargs = {}ops = []self.conv = conv_op(input_channels,output_channels,kernel_size,stride,padding=[(i - 1) // 2 for i in kernel_size],dilation=1,bias=conv_bias,)ops.append(self.conv)if dropout_op is not None:self.dropout = dropout_op(**dropout_op_kwargs)ops.append(self.dropout)if norm_op is not None:self.norm = norm_op(output_channels, **norm_op_kwargs)ops.append(self.norm)self.all_modules = nn.Sequential(*ops)self.se = SEAttention(conv_op=conv_op, channel=output_channels)def forward(self, x):x = self.all_modules(x)x = self.se(x)return xdef compute_conv_feature_map_size(self, input_size):assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \"batch channel. Do not give input_size=(b, c, x, y(, z)). " \"Give input_size=(x, y(, z))!"output_size = [i // j for i, j in zip(input_size, self.stride)] # we always do same paddingreturn np.prod([self.output_channels, *output_size], dtype=np.int64)class SEAttention(nn.Module):def __init__(self, conv_op, channel=512, reduction=16):super().__init__()if conv_op == torch.nn.modules.conv.Conv2d:self.pool = nn.AdaptiveAvgPool2d(1)elif conv_op == torch.nn.modules.conv.Conv3d:self.pool = nn.AdaptiveAvgPool3d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):size = x.size()b, c = size[0], size[1]y = self.pool(x).view(b, c)y = self.fc(y).view(b, c, *[1 for i in range(len(size) - 2)])return x * y.expand_as(x)def init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)
相关文章:
复现nnUNet2并跑通自定义数据
复现nnUNet2并跑通自定义数据 1. 配置环境2. 处理数据集2.1 创建文件夹2.2 数据集格式转换2.3 数据集预处理 3. 训练4. 改进模型4.1 概要4.2 加注意力模块 1. 配置环境 stage1:创建python环境,这里建议python3.10 conda create --n nnunet python3.10 …...
Educational Codeforces Round 169 (Rated for Div. 2)(ABCDE)
A. Closest Point 签到 #define _rep(i,a,b) for(int i(a);i<(b);i) int n,m; int q[N]; void solve() {cin>>n;_rep(i,1,n)cin>>q[i];if(n!2)cout<<"NO\n";else if(abs(q[1]-q[2])!1)cout<<"YES\n";else cout<<"…...
成为Python砖家(2): str 最常用的8大方法
str 类最常用的8个方法 str.lower()str.upper()str.split(sepNone, maxsplit-1)str.count(sub[, start[, end]])str.replace(old, new[, count])str.center(width[, fillchar])str.strip([chars])str.join(iterable) 查询方法的文档 根据 成为Python砖家(1): 在本地查询Pyth…...
深入理解JVM运行时数据区(内存布局 )5大部分 | 异常讨论
前言: JVM运行时数据区(内存布局)是Java程序执行时用于存储各种数据的内存区域。这些区域在JVM启动时被创建,并在JVM关闭时销毁。它们的布局和管理方式对Java程序的性能和稳定性有着重要影响。 目录 一、由以下5大部分组成 1.…...
JAVA根据表名获取Oracle表结构信息
响应实体封装 import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor;/*** author CQY* version 1.0* date 2024/8/15 16:33**/ Data NoArgsConstructor AllArgsConstructor Builder public class OracleTableInfo …...
网络性能优化
网络性能优化是确保网络稳定性、速度和可靠性的关键步骤。优化过程通常包括诊断问题、识别瓶颈以及实施具体的解决方案。以下是关于如何进行网络性能优化的详细指南: 一、问题诊断 网络性能监控 网络流量分析工具:使用Wireshark、NetFlow、Ntop等工具监…...
[C++String]接口解读,深拷贝和浅拷贝,string的模拟实现
💖💖💖欢迎来到我的博客,我是anmory💖💖💖 又和大家见面了 欢迎来到C探索系列 作为一个程序员你不能不掌握的知识 先来自我推荐一波 个人网站欢迎访问以及捐款 推荐阅读 如何低成本搭建个人网站…...
理性看待、正确理解 AI 中的 Scaling “laws”
编者按:LLMs 规模和性能的不断提升,让人们不禁产生疑问:这种趋势是否能一直持续下去?我们是否能通过不断扩大模型规模最终实现通用人工智能(AGI)?回答这些问题对于理解 AI 的未来发展轨迹至关重…...
【OCR 学习笔记】二值化——全局阈值方法
二值化——全局阈值方法 固定阈值方法Otsu算法在OpenCV中的实现固定阈值Otsu算法 图像二值化(Image Binarization)是指将像素点的灰度值设为0或255,使图像呈现明显的黑白效果。二值化一方面减少了数据维度,另一方面通过排除原图中…...
Java - IDEA开发
使用IDEA开发Java程序步骤: 创建工程 Project;创建模块 Module;创建包 Package;创建类;编写代码; 如何查看JDK版本 Package介绍: package是将项目中的各种文件,比如源代码、编译生成的字节码、配置文件、…...
Oracle(62)什么是内存优化表(In-Memory Table)?
内存优化表(In-Memory Table)是指将表的数据存储在内存中,以提高数据访问和查询性能的一种技术。内存优化表通过利用内存的高速访问特性,显著减少I/O操作的延迟,提升数据处理的速度。这种技术在需要高性能数据处理的应…...
#window家庭版安装hyper-v#
由于window 11 家庭版没有hyper-v虚拟机服务,则需要安装一下,使用如下操作 1:新建一个txt文件,拷贝如下脚本到里面 pushd "%\~dp0" dir /b %SystemRoot%\servicing\Packages\*Hyper-V*.mum >hyper-v.txt for /f %%i in (findst…...
【云原生】Pass容器研发基础——汇总篇
云原生基础汇总 系列综述: 💞目的:本系列是个人整理为了云计算学习的,整理期间苛求每个知识点,平衡理解简易度与深入程度。 🥰来源:每个知识点的修正和深入主要参考各平台大佬的文章,…...
【Py/Java/C++三种语言详解】LeetCode743、网络延迟时间【单源最短路问题Djikstra算法】
可上 欧弟OJ系统 练习华子OD、大厂真题 绿色聊天软件戳 od1441了解算法冲刺训练(备注【CSDN】否则不通过) 文章目录 相关推荐阅读一、题目描述二、题目解析三、参考代码PythonJavaC 时空复杂度 华为OD算法/大厂面试高频题算法练习冲刺训练 相关推荐阅读 …...
交替输出
交替输出 题目:线程 1 输出 a 5 次,线程 2 输出 b 5 次,线程 3 输出 c 5 次。现在要求输出 abcabcabcabcabc wait notify 版 public class SyncWaitNotify {private volatile int flag;private volatile int loopNumber;public SyncWaitNo…...
JS(三)——更改html内数据
获取 DOM 元素,然后修改其属性或内容。使用 getElementById 方法获取特定 ID 的元素: <p id"myParagraph">这是初始的文本</p> const paragraph document.getElementById(myParagraph); paragraph.innerHTML 这是修改后的文本…...
CSS小玩意儿:文字适配背景
一,效果 二,代码 1,搭个框架 添加一张背景图片,在图片中显示一行文字。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" conte…...
C++:平衡二叉搜索树之红黑树
一、红黑树的概念 红黑树, 和AVL都是二叉搜索树, 红黑树通过在每个节点上增加一个储存位表示节点的颜色, 可以是RED或者BLACK, 通过任何一条从根到叶子的路径上各个节点着色方式的限制,红黑树能够确保没有一条路径会比…...
CentOS 7 系统优化
CentOS 7 系统优化 1、配置YUM源 阿里云的YUM源配置: CentOS 7使用以下命令: sudo wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-7.repoCentOS 8使用以下命令: sudo wget -O /etc/yum.repos.d/CentOS…...
扫雷游戏——附源代码
扫雷游戏的源代码比较简单,不设计比较复杂的代码,主要是多个函数的组合,每个函数执行自己的功能,最终支持游戏的完成。 1.菜单 我们需要一个提醒信息来让用户进行选择。 void menu() {printf("***********************\n&…...
挑战杯推荐项目
“人工智能”创意赛 - 智能艺术创作助手:借助大模型技术,开发能根据用户输入的主题、风格等要求,生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用,帮助艺术家和创意爱好者激发创意、提高创作效率。 - 个性化梦境…...
ES6从入门到精通:前言
ES6简介 ES6(ECMAScript 2015)是JavaScript语言的重大更新,引入了许多新特性,包括语法糖、新数据类型、模块化支持等,显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var…...
使用分级同态加密防御梯度泄漏
抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...
Python爬虫实战:研究feedparser库相关技术
1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...
Qt Http Server模块功能及架构
Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...
ETLCloud可能遇到的问题有哪些?常见坑位解析
数据集成平台ETLCloud,主要用于支持数据的抽取(Extract)、转换(Transform)和加载(Load)过程。提供了一个简洁直观的界面,以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...
Ascend NPU上适配Step-Audio模型
1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统,支持多语言对话(如 中文,英文,日语),语音情感(如 开心,悲伤)&#x…...
【HTTP三个基础问题】
面试官您好!HTTP是超文本传输协议,是互联网上客户端和服务器之间传输超文本数据(比如文字、图片、音频、视频等)的核心协议,当前互联网应用最广泛的版本是HTTP1.1,它基于经典的C/S模型,也就是客…...
什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南
文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果是 Linux 内核中的一套通用块设备映射框架,为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程,并配以详细的…...
