复现nnUNet2并跑通自定义数据
复现nnUNet2并跑通自定义数据
- 1. 配置环境
- 2. 处理数据集
- 2.1 创建文件夹
- 2.2 数据集格式转换
- 2.3 数据集预处理
- 3. 训练
- 4. 改进模型
- 4.1 概要
- 4.2 加注意力模块
1. 配置环境
stage1:创建python环境,这里建议python3.10
conda create --n nnunet python=3.10
conda activate nnunet
stage2:按照pytorch官方链接 (conda/pip) 上的说明安装PyTorch。
nnunetv2 2.5.1 需要 torch>=2.1.2,比如:
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8
但是,使用上述指令安装pytorch,很多时候还是装的cpuonly版,这里可以进行本地安装,首先搜索whl,最后的cuda版本可以自行更改。
https://download.pytorch.org/whl/cu117
最后,直接cd到下载文件所属的盘符下:
d:
cd D:\program
pip install torch-2.1.2+cu118-cp310-cp310-win_amd64.whl
stage3:git nnUnet的源代码,并配置nnUnet
git clone https://github.com/MIC-DKFZ/nnUNet.git
cd nnUNet
pip install -e .
2. 处理数据集
可以通过运行将 nnU-Net v1 中的数据集转换为 V2 nnUNetv2_convert_old_nnUNet_dataset INPUT_FOLDER OUTPUT_DATASET_NAME。请记住,v2 调用数据集 DatasetXXX_Name(而不是 Task),其中 XXX 是一个 3 位数字。请提供旧任务的路径,而不仅仅是任务名称。nnU-Net V2 不知道 v1 任务在哪里!
2.1 创建文件夹
首先,创建nnUNet_raw、nnUNet_preprocessed、nnUNet_results三个文件夹,将下载的原始数据集放到nnUNet_raw下,这里MSD挑战赛的Task02_Heart数据集为例。
nnUNet_raw用于存放原始数据集;
nnUNet_preprocessed用于存放格式转换并预处理的数据集及配置信息;
nnUNet_results用于存放训练结果相关的信息文件。

2.2 数据集格式转换
接着,我们转换数据集格式:
nnUNetv2_convert_MSD_dataset -i D:\nnUNet\nnUNet_raw\Task02_Heart
可能会出现下列错误:
TypeError: expected str, bytes or os.PathLike object, not NoneType
解决方法:
原因是没有设置相关数据集的路径,打开D:\nnUNet\nnunetv2\paths.py文件

方法1:直接在这个py文件指定数据集路径:

方法2:
设置环境变量:
1.windows:
#这个变量只在当前命令提示符会话中有效,不会被保存到系统环境变量中
set nnUNet_raw=D:/nnUNet/nnUNet_raw
set nnUNet_preprocessed=D:/nnUNet/nnUNet_preprocessed
set nnUNet_results=D:/nnUNet/nnUNet_results
2.linux:
export nnUNet_raw="D:/nnUNet/nnUNet_raw"
export nnUNet_preprocessed="D:/nnUNet/nnUNet_preprocessed"
export nnUNet_results="D:/nnUNet/nnUNet_results"
发现,转换成功,生成了1个Dataset002_Heart文件夹:

转换前后主要区别在于:
1、文件夹由Task02_Heart变成Dataset002_Heart。
ps:文件夹命名为:Dataset+三位整数+任务名,Dataset002_Heart中数据集ID为2,任务名为Heart。此文件夹下存放需要的训练数据集imageTr、测试集imageTs、标签labelsTr。其中labelsTr是与imageTr中一一对应的标签,文件中都是nii.gz文件。imageTs是可选项,可以没有。
2、原始json中是:

转换后,字典的key为"channel_names":

3、原始json中有"training"、"test"这两个key存放数据集的相对路径:

转换后,字典的key为"file_ending",直接记录数据集的后缀名即可:

2.3 数据集预处理
nnUNetv2_plan_and_preprocess -d DATASET\_ID --verify_dataset_integrity
#这里我是用的是Dataset002_Heart,将ID改为2
nnUNetv2_plan_and_preprocess -d 2 --verify_dataset_integrity
运行成功后,会在nnUNet_preprocessed文件下生成如下文件:

此步骤主要就是对数据进行:裁剪crop,重采样resample以及标准化normalization。将提取数据集fingerprint(一组特定于数据集的属性,例如图像大小、体素间距、强度信息等)。此信息用于设计三种 U-Net 配置。每个管道都在其自己的数据集预处理版本上运行。其中的nnUNetPlans.json是网络结构相关的配置文件。
3. 训练
训练使用如下脚本指令:
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]
DATASET_NAME_OR_ID 指定应在哪个数据集上进行训练;UNET_CONFIGURATION 是一个字符串,用于标识所请求的 U-Net 配置(默认值:2d、3d_fullres、3d_lowres、3d_cascade_lowres);
FOLD 指定训练 5 折交叉验证中的哪一折。
根据数据集名称或者数据集的ID、UNET配置、折数FOLD,进行训练:
nnUNetv2_train 2 2d 0
成功开始训练:

4. 改进模型
4.1 概要
许多人可能不满足于只是复现跑通,而是有一些改进模型,提升指标,最后发paper的需求。
但是,会发现nnUNet2的网络模型结构很难找,但当找到nnUNet\nnunetv2\utilities\get_network_from_plans.py时,发现还是找不到网络模型的model。
但是,我们在数据集处理的时候生成的nnUNet_preprocessed文件下的nnUNetPlans.json文件中发现了网络结构的model文件。

于是,发现这个模型结构在我们配置的环境中,我就在Anaconda中我的环境下D:\Anaconda3\envs\nnunet\Lib\site-packages\dynamic_network_architectures\architectures找到了。

4.2 加注意力模块
我们以在unet的编码器中添加SE注意力模块为例,首先打开unet.py文件,找到调用编码器实例化类对象的地方。

发现,PlainConvEncoder类中卷积模块通过StackedConvBlocks类实例化出来的,

简单点,我们就直接在StackedConvBlocks类中嵌入SE模块。找到了,nn.Sequential()模块拼接模块的部分,
我们在下面堆上SE模块。

这里我也附上SE的代码,通过参考其它来的继承来写SE:
class SE(nn.Module):def __init__(self,conv_op: Type[_ConvNd],input_channels: int,output_channels: int,kernel_size: Union[int, List[int], Tuple[int, ...]],stride: Union[int, List[int], Tuple[int, ...]],conv_bias: bool = False,norm_op: Union[None, Type[nn.Module]] = None,norm_op_kwargs: dict = None,dropout_op: Union[None, Type[_DropoutNd]] = None,dropout_op_kwargs: dict = None,nonlin: Union[None, Type[torch.nn.Module]] = None,nonlin_kwargs: dict = None,nonlin_first: bool = False):super(SE, self).__init__()self.input_channels = input_channelsself.output_channels = output_channelsstride = maybe_convert_scalar_to_list(conv_op, stride)self.stride = stridekernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)if norm_op_kwargs is None:norm_op_kwargs = {}if nonlin_kwargs is None:nonlin_kwargs = {}ops = []self.conv = conv_op(input_channels,output_channels,kernel_size,stride,padding=[(i - 1) // 2 for i in kernel_size],dilation=1,bias=conv_bias,)ops.append(self.conv)if dropout_op is not None:self.dropout = dropout_op(**dropout_op_kwargs)ops.append(self.dropout)if norm_op is not None:self.norm = norm_op(output_channels, **norm_op_kwargs)ops.append(self.norm)self.all_modules = nn.Sequential(*ops)self.se = SEAttention(conv_op=conv_op, channel=output_channels)def forward(self, x):x = self.all_modules(x)x = self.se(x)return xdef compute_conv_feature_map_size(self, input_size):assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \"batch channel. Do not give input_size=(b, c, x, y(, z)). " \"Give input_size=(x, y(, z))!"output_size = [i // j for i, j in zip(input_size, self.stride)] # we always do same paddingreturn np.prod([self.output_channels, *output_size], dtype=np.int64)class SEAttention(nn.Module):def __init__(self, conv_op, channel=512, reduction=16):super().__init__()if conv_op == torch.nn.modules.conv.Conv2d:self.pool = nn.AdaptiveAvgPool2d(1)elif conv_op == torch.nn.modules.conv.Conv3d:self.pool = nn.AdaptiveAvgPool3d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):size = x.size()b, c = size[0], size[1]y = self.pool(x).view(b, c)y = self.fc(y).view(b, c, *[1 for i in range(len(size) - 2)])return x * y.expand_as(x)def init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)
相关文章:
复现nnUNet2并跑通自定义数据
复现nnUNet2并跑通自定义数据 1. 配置环境2. 处理数据集2.1 创建文件夹2.2 数据集格式转换2.3 数据集预处理 3. 训练4. 改进模型4.1 概要4.2 加注意力模块 1. 配置环境 stage1:创建python环境,这里建议python3.10 conda create --n nnunet python3.10 …...
Educational Codeforces Round 169 (Rated for Div. 2)(ABCDE)
A. Closest Point 签到 #define _rep(i,a,b) for(int i(a);i<(b);i) int n,m; int q[N]; void solve() {cin>>n;_rep(i,1,n)cin>>q[i];if(n!2)cout<<"NO\n";else if(abs(q[1]-q[2])!1)cout<<"YES\n";else cout<<"…...
成为Python砖家(2): str 最常用的8大方法
str 类最常用的8个方法 str.lower()str.upper()str.split(sepNone, maxsplit-1)str.count(sub[, start[, end]])str.replace(old, new[, count])str.center(width[, fillchar])str.strip([chars])str.join(iterable) 查询方法的文档 根据 成为Python砖家(1): 在本地查询Pyth…...
深入理解JVM运行时数据区(内存布局 )5大部分 | 异常讨论
前言: JVM运行时数据区(内存布局)是Java程序执行时用于存储各种数据的内存区域。这些区域在JVM启动时被创建,并在JVM关闭时销毁。它们的布局和管理方式对Java程序的性能和稳定性有着重要影响。 目录 一、由以下5大部分组成 1.…...
JAVA根据表名获取Oracle表结构信息
响应实体封装 import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor;/*** author CQY* version 1.0* date 2024/8/15 16:33**/ Data NoArgsConstructor AllArgsConstructor Builder public class OracleTableInfo …...
网络性能优化
网络性能优化是确保网络稳定性、速度和可靠性的关键步骤。优化过程通常包括诊断问题、识别瓶颈以及实施具体的解决方案。以下是关于如何进行网络性能优化的详细指南: 一、问题诊断 网络性能监控 网络流量分析工具:使用Wireshark、NetFlow、Ntop等工具监…...
[C++String]接口解读,深拷贝和浅拷贝,string的模拟实现
💖💖💖欢迎来到我的博客,我是anmory💖💖💖 又和大家见面了 欢迎来到C探索系列 作为一个程序员你不能不掌握的知识 先来自我推荐一波 个人网站欢迎访问以及捐款 推荐阅读 如何低成本搭建个人网站…...
理性看待、正确理解 AI 中的 Scaling “laws”
编者按:LLMs 规模和性能的不断提升,让人们不禁产生疑问:这种趋势是否能一直持续下去?我们是否能通过不断扩大模型规模最终实现通用人工智能(AGI)?回答这些问题对于理解 AI 的未来发展轨迹至关重…...
【OCR 学习笔记】二值化——全局阈值方法
二值化——全局阈值方法 固定阈值方法Otsu算法在OpenCV中的实现固定阈值Otsu算法 图像二值化(Image Binarization)是指将像素点的灰度值设为0或255,使图像呈现明显的黑白效果。二值化一方面减少了数据维度,另一方面通过排除原图中…...
Java - IDEA开发
使用IDEA开发Java程序步骤: 创建工程 Project;创建模块 Module;创建包 Package;创建类;编写代码; 如何查看JDK版本 Package介绍: package是将项目中的各种文件,比如源代码、编译生成的字节码、配置文件、…...
Oracle(62)什么是内存优化表(In-Memory Table)?
内存优化表(In-Memory Table)是指将表的数据存储在内存中,以提高数据访问和查询性能的一种技术。内存优化表通过利用内存的高速访问特性,显著减少I/O操作的延迟,提升数据处理的速度。这种技术在需要高性能数据处理的应…...
#window家庭版安装hyper-v#
由于window 11 家庭版没有hyper-v虚拟机服务,则需要安装一下,使用如下操作 1:新建一个txt文件,拷贝如下脚本到里面 pushd "%\~dp0" dir /b %SystemRoot%\servicing\Packages\*Hyper-V*.mum >hyper-v.txt for /f %%i in (findst…...
【云原生】Pass容器研发基础——汇总篇
云原生基础汇总 系列综述: 💞目的:本系列是个人整理为了云计算学习的,整理期间苛求每个知识点,平衡理解简易度与深入程度。 🥰来源:每个知识点的修正和深入主要参考各平台大佬的文章,…...
【Py/Java/C++三种语言详解】LeetCode743、网络延迟时间【单源最短路问题Djikstra算法】
可上 欧弟OJ系统 练习华子OD、大厂真题 绿色聊天软件戳 od1441了解算法冲刺训练(备注【CSDN】否则不通过) 文章目录 相关推荐阅读一、题目描述二、题目解析三、参考代码PythonJavaC 时空复杂度 华为OD算法/大厂面试高频题算法练习冲刺训练 相关推荐阅读 …...
交替输出
交替输出 题目:线程 1 输出 a 5 次,线程 2 输出 b 5 次,线程 3 输出 c 5 次。现在要求输出 abcabcabcabcabc wait notify 版 public class SyncWaitNotify {private volatile int flag;private volatile int loopNumber;public SyncWaitNo…...
JS(三)——更改html内数据
获取 DOM 元素,然后修改其属性或内容。使用 getElementById 方法获取特定 ID 的元素: <p id"myParagraph">这是初始的文本</p> const paragraph document.getElementById(myParagraph); paragraph.innerHTML 这是修改后的文本…...
CSS小玩意儿:文字适配背景
一,效果 二,代码 1,搭个框架 添加一张背景图片,在图片中显示一行文字。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" conte…...
C++:平衡二叉搜索树之红黑树
一、红黑树的概念 红黑树, 和AVL都是二叉搜索树, 红黑树通过在每个节点上增加一个储存位表示节点的颜色, 可以是RED或者BLACK, 通过任何一条从根到叶子的路径上各个节点着色方式的限制,红黑树能够确保没有一条路径会比…...
CentOS 7 系统优化
CentOS 7 系统优化 1、配置YUM源 阿里云的YUM源配置: CentOS 7使用以下命令: sudo wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-7.repoCentOS 8使用以下命令: sudo wget -O /etc/yum.repos.d/CentOS…...
扫雷游戏——附源代码
扫雷游戏的源代码比较简单,不设计比较复杂的代码,主要是多个函数的组合,每个函数执行自己的功能,最终支持游戏的完成。 1.菜单 我们需要一个提醒信息来让用户进行选择。 void menu() {printf("***********************\n&…...
UE5 学习系列(二)用户操作界面及介绍
这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...
盘古信息PCB行业解决方案:以全域场景重构,激活智造新未来
一、破局:PCB行业的时代之问 在数字经济蓬勃发展的浪潮中,PCB(印制电路板)作为 “电子产品之母”,其重要性愈发凸显。随着 5G、人工智能等新兴技术的加速渗透,PCB行业面临着前所未有的挑战与机遇。产品迭代…...
【力扣数据库知识手册笔记】索引
索引 索引的优缺点 优点1. 通过创建唯一性索引,可以保证数据库表中每一行数据的唯一性。2. 可以加快数据的检索速度(创建索引的主要原因)。3. 可以加速表和表之间的连接,实现数据的参考完整性。4. 可以在查询过程中,…...
HTML 列表、表格、表单
1 列表标签 作用:布局内容排列整齐的区域 列表分类:无序列表、有序列表、定义列表。 例如: 1.1 无序列表 标签:ul 嵌套 li,ul是无序列表,li是列表条目。 注意事项: ul 标签里面只能包裹 li…...
linux 错误码总结
1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...
linux 下常用变更-8
1、删除普通用户 查询用户初始UID和GIDls -l /home/ ###家目录中查看UID cat /etc/group ###此文件查看GID删除用户1.编辑文件 /etc/passwd 找到对应的行,YW343:x:0:0::/home/YW343:/bin/bash 2.将标红的位置修改为用户对应初始UID和GID: YW3…...
OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别
OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别 直接训练提示词嵌入向量的核心区别 您提到的代码: prompt_embedding = initial_embedding.clone().requires_grad_(True) optimizer = torch.optim.Adam([prompt_embedding...
实现弹窗随键盘上移居中
实现弹窗随键盘上移的核心思路 在Android中,可以通过监听键盘的显示和隐藏事件,动态调整弹窗的位置。关键点在于获取键盘高度,并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...
select、poll、epoll 与 Reactor 模式
在高并发网络编程领域,高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表,以及基于它们实现的 Reactor 模式,为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。 一、I…...
Android第十三次面试总结(四大 组件基础)
Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成,用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机: onCreate() 调用时机:Activity 首次创建时调用。…...
