复现nnUNet2并跑通自定义数据
复现nnUNet2并跑通自定义数据
- 1. 配置环境
- 2. 处理数据集
- 2.1 创建文件夹
- 2.2 数据集格式转换
- 2.3 数据集预处理
- 3. 训练
- 4. 改进模型
- 4.1 概要
- 4.2 加注意力模块
1. 配置环境
stage1:创建python环境,这里建议python3.10
conda create --n nnunet python=3.10
conda activate nnunet
stage2:按照pytorch官方链接 (conda/pip) 上的说明安装PyTorch。
nnunetv2 2.5.1 需要 torch>=2.1.2,比如:
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8
但是,使用上述指令安装pytorch,很多时候还是装的cpuonly版,这里可以进行本地安装,首先搜索whl
,最后的cuda版本可以自行更改。
https://download.pytorch.org/whl/cu117
最后,直接cd到下载文件所属的盘符下:
d:
cd D:\program
pip install torch-2.1.2+cu118-cp310-cp310-win_amd64.whl
stage3:git nnUnet的源代码,并配置nnUnet
git clone https://github.com/MIC-DKFZ/nnUNet.git
cd nnUNet
pip install -e .
2. 处理数据集
可以通过运行将 nnU-Net v1 中的数据集转换为 V2 nnUNetv2_convert_old_nnUNet_dataset INPUT_FOLDER OUTPUT_DATASET_NAME。请记住,v2 调用数据集 DatasetXXX_Name(而不是 Task),其中 XXX 是一个 3 位数字。请提供旧任务的路径,而不仅仅是任务名称。nnU-Net V2 不知道 v1 任务在哪里!
2.1 创建文件夹
首先,创建nnUNet_raw
、nnUNet_preprocessed
、nnUNet_results
三个文件夹,将下载的原始数据集放到nnUNet_raw下,这里MSD挑战赛的Task02_Heart
数据集为例。
nnUNet_raw
用于存放原始数据集;
nnUNet_preprocessed
用于存放格式转换并预处理的数据集及配置信息;
nnUNet_results
用于存放训练结果相关的信息文件。
2.2 数据集格式转换
接着,我们转换数据集格式:
nnUNetv2_convert_MSD_dataset -i D:\nnUNet\nnUNet_raw\Task02_Heart
可能会出现下列错误:
TypeError: expected str, bytes or os.PathLike object, not NoneType
解决方法:
原因是没有设置相关数据集的路径,打开D:\nnUNet\nnunetv2\paths.py
文件
方法1
:直接在这个py文件指定数据集路径:
方法2
:
设置环境变量:
1.windows:
#这个变量只在当前命令提示符会话中有效,不会被保存到系统环境变量中
set nnUNet_raw=D:/nnUNet/nnUNet_raw
set nnUNet_preprocessed=D:/nnUNet/nnUNet_preprocessed
set nnUNet_results=D:/nnUNet/nnUNet_results
2.linux:
export nnUNet_raw="D:/nnUNet/nnUNet_raw"
export nnUNet_preprocessed="D:/nnUNet/nnUNet_preprocessed"
export nnUNet_results="D:/nnUNet/nnUNet_results"
发现,转换成功,生成了1个Dataset002_Heart
文件夹:
转换前后主要区别
在于:
1、文件夹由Task02_Heart变成Dataset002_Heart。
ps:文件夹命名为:Dataset+三位整数+任务名
,Dataset002_Heart中数据集ID为2,任务名为Heart。此文件夹下存放需要的训练数据集imageTr、测试集imageTs、标签labelsTr。其中labelsTr是与imageTr中一一对应的标签,文件中都是nii.gz文件。imageTs是可选项,可以没有。
2、原始json中是:
转换后,字典的key为"channel_names"
:
3、原始json中有"training"、"test"这两个key存放数据集的相对路径:
转换后,字典的key为"file_ending"
,直接记录数据集的后缀名即可:
2.3 数据集预处理
nnUNetv2_plan_and_preprocess -d DATASET\_ID --verify_dataset_integrity
#这里我是用的是Dataset002_Heart,将ID改为2
nnUNetv2_plan_and_preprocess -d 2 --verify_dataset_integrity
运行成功后,会在nnUNet_preprocessed
文件下生成如下文件:
此步骤主要就是对数据进行:裁剪crop,重采样resample以及标准化normalization。将提取数据集fingerprint(一组特定于数据集的属性,例如图像大小、体素间距、强度信息等)。此信息用于设计三种 U-Net 配置。每个管道都在其自己的数据集预处理版本上运行。其中的nnUNetPlans.json是网络结构相关的配置文件。
3. 训练
训练使用如下脚本指令:
nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]
DATASET_NAME_OR_ID 指定应在哪个数据集上进行训练;UNET_CONFIGURATION 是一个字符串,用于标识所请求的 U-Net 配置(默认值:2d、3d_fullres、3d_lowres、3d_cascade_lowres);
FOLD 指定训练 5 折交叉验证中的哪一折。
根据数据集名称或者数据集的ID、UNET配置、折数FOLD,进行训练:
nnUNetv2_train 2 2d 0
成功开始训练:
4. 改进模型
4.1 概要
许多人可能不满足于只是复现跑通,而是有一些改进模型,提升指标,最后发paper
的需求。
但是,会发现nnUNet2的网络模型结构很难找,但当找到nnUNet\nnunetv2\utilities\get_network_from_plans.py
时,发现还是找不到网络模型的model。
但是,我们在数据集处理的时候生成的nnUNet_preprocessed文件下的nnUNetPlans.json
文件中发现了网络结构的model文件。
于是,发现这个模型结构在我们配置的环境中,我就在Anaconda中我的环境下D:\Anaconda3\envs\nnunet\Lib\site-packages\dynamic_network_architectures\architectures
找到了。
4.2 加注意力模块
我们以在unet的编码器中添加SE注意力模块为例,首先打开unet.py文件,找到调用编码器实例化类对象的地方。
发现,PlainConvEncoder类中卷积模块通过StackedConvBlocks类实例化出来的,
简单点,我们就直接在StackedConvBlocks类中嵌入SE模块。找到了,nn.Sequential()模块拼接模块的部分,
我们在下面堆上SE模块。
这里我也附上SE的代码,通过参考其它来的继承来写SE:
class SE(nn.Module):def __init__(self,conv_op: Type[_ConvNd],input_channels: int,output_channels: int,kernel_size: Union[int, List[int], Tuple[int, ...]],stride: Union[int, List[int], Tuple[int, ...]],conv_bias: bool = False,norm_op: Union[None, Type[nn.Module]] = None,norm_op_kwargs: dict = None,dropout_op: Union[None, Type[_DropoutNd]] = None,dropout_op_kwargs: dict = None,nonlin: Union[None, Type[torch.nn.Module]] = None,nonlin_kwargs: dict = None,nonlin_first: bool = False):super(SE, self).__init__()self.input_channels = input_channelsself.output_channels = output_channelsstride = maybe_convert_scalar_to_list(conv_op, stride)self.stride = stridekernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)if norm_op_kwargs is None:norm_op_kwargs = {}if nonlin_kwargs is None:nonlin_kwargs = {}ops = []self.conv = conv_op(input_channels,output_channels,kernel_size,stride,padding=[(i - 1) // 2 for i in kernel_size],dilation=1,bias=conv_bias,)ops.append(self.conv)if dropout_op is not None:self.dropout = dropout_op(**dropout_op_kwargs)ops.append(self.dropout)if norm_op is not None:self.norm = norm_op(output_channels, **norm_op_kwargs)ops.append(self.norm)self.all_modules = nn.Sequential(*ops)self.se = SEAttention(conv_op=conv_op, channel=output_channels)def forward(self, x):x = self.all_modules(x)x = self.se(x)return xdef compute_conv_feature_map_size(self, input_size):assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \"batch channel. Do not give input_size=(b, c, x, y(, z)). " \"Give input_size=(x, y(, z))!"output_size = [i // j for i, j in zip(input_size, self.stride)] # we always do same paddingreturn np.prod([self.output_channels, *output_size], dtype=np.int64)class SEAttention(nn.Module):def __init__(self, conv_op, channel=512, reduction=16):super().__init__()if conv_op == torch.nn.modules.conv.Conv2d:self.pool = nn.AdaptiveAvgPool2d(1)elif conv_op == torch.nn.modules.conv.Conv3d:self.pool = nn.AdaptiveAvgPool3d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):size = x.size()b, c = size[0], size[1]y = self.pool(x).view(b, c)y = self.fc(y).view(b, c, *[1 for i in range(len(size) - 2)])return x * y.expand_as(x)def init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):init.normal_(m.weight, std=0.001)if m.bias is not None:init.constant_(m.bias, 0)
相关文章:

复现nnUNet2并跑通自定义数据
复现nnUNet2并跑通自定义数据 1. 配置环境2. 处理数据集2.1 创建文件夹2.2 数据集格式转换2.3 数据集预处理 3. 训练4. 改进模型4.1 概要4.2 加注意力模块 1. 配置环境 stage1:创建python环境,这里建议python3.10 conda create --n nnunet python3.10 …...

Educational Codeforces Round 169 (Rated for Div. 2)(ABCDE)
A. Closest Point 签到 #define _rep(i,a,b) for(int i(a);i<(b);i) int n,m; int q[N]; void solve() {cin>>n;_rep(i,1,n)cin>>q[i];if(n!2)cout<<"NO\n";else if(abs(q[1]-q[2])!1)cout<<"YES\n";else cout<<"…...

成为Python砖家(2): str 最常用的8大方法
str 类最常用的8个方法 str.lower()str.upper()str.split(sepNone, maxsplit-1)str.count(sub[, start[, end]])str.replace(old, new[, count])str.center(width[, fillchar])str.strip([chars])str.join(iterable) 查询方法的文档 根据 成为Python砖家(1): 在本地查询Pyth…...

深入理解JVM运行时数据区(内存布局 )5大部分 | 异常讨论
前言: JVM运行时数据区(内存布局)是Java程序执行时用于存储各种数据的内存区域。这些区域在JVM启动时被创建,并在JVM关闭时销毁。它们的布局和管理方式对Java程序的性能和稳定性有着重要影响。 目录 一、由以下5大部分组成 1.…...

JAVA根据表名获取Oracle表结构信息
响应实体封装 import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor;/*** author CQY* version 1.0* date 2024/8/15 16:33**/ Data NoArgsConstructor AllArgsConstructor Builder public class OracleTableInfo …...

网络性能优化
网络性能优化是确保网络稳定性、速度和可靠性的关键步骤。优化过程通常包括诊断问题、识别瓶颈以及实施具体的解决方案。以下是关于如何进行网络性能优化的详细指南: 一、问题诊断 网络性能监控 网络流量分析工具:使用Wireshark、NetFlow、Ntop等工具监…...

[C++String]接口解读,深拷贝和浅拷贝,string的模拟实现
💖💖💖欢迎来到我的博客,我是anmory💖💖💖 又和大家见面了 欢迎来到C探索系列 作为一个程序员你不能不掌握的知识 先来自我推荐一波 个人网站欢迎访问以及捐款 推荐阅读 如何低成本搭建个人网站…...

理性看待、正确理解 AI 中的 Scaling “laws”
编者按:LLMs 规模和性能的不断提升,让人们不禁产生疑问:这种趋势是否能一直持续下去?我们是否能通过不断扩大模型规模最终实现通用人工智能(AGI)?回答这些问题对于理解 AI 的未来发展轨迹至关重…...

【OCR 学习笔记】二值化——全局阈值方法
二值化——全局阈值方法 固定阈值方法Otsu算法在OpenCV中的实现固定阈值Otsu算法 图像二值化(Image Binarization)是指将像素点的灰度值设为0或255,使图像呈现明显的黑白效果。二值化一方面减少了数据维度,另一方面通过排除原图中…...

Java - IDEA开发
使用IDEA开发Java程序步骤: 创建工程 Project;创建模块 Module;创建包 Package;创建类;编写代码; 如何查看JDK版本 Package介绍: package是将项目中的各种文件,比如源代码、编译生成的字节码、配置文件、…...

Oracle(62)什么是内存优化表(In-Memory Table)?
内存优化表(In-Memory Table)是指将表的数据存储在内存中,以提高数据访问和查询性能的一种技术。内存优化表通过利用内存的高速访问特性,显著减少I/O操作的延迟,提升数据处理的速度。这种技术在需要高性能数据处理的应…...

#window家庭版安装hyper-v#
由于window 11 家庭版没有hyper-v虚拟机服务,则需要安装一下,使用如下操作 1:新建一个txt文件,拷贝如下脚本到里面 pushd "%\~dp0" dir /b %SystemRoot%\servicing\Packages\*Hyper-V*.mum >hyper-v.txt for /f %%i in (findst…...

【云原生】Pass容器研发基础——汇总篇
云原生基础汇总 系列综述: 💞目的:本系列是个人整理为了云计算学习的,整理期间苛求每个知识点,平衡理解简易度与深入程度。 🥰来源:每个知识点的修正和深入主要参考各平台大佬的文章,…...

【Py/Java/C++三种语言详解】LeetCode743、网络延迟时间【单源最短路问题Djikstra算法】
可上 欧弟OJ系统 练习华子OD、大厂真题 绿色聊天软件戳 od1441了解算法冲刺训练(备注【CSDN】否则不通过) 文章目录 相关推荐阅读一、题目描述二、题目解析三、参考代码PythonJavaC 时空复杂度 华为OD算法/大厂面试高频题算法练习冲刺训练 相关推荐阅读 …...

交替输出
交替输出 题目:线程 1 输出 a 5 次,线程 2 输出 b 5 次,线程 3 输出 c 5 次。现在要求输出 abcabcabcabcabc wait notify 版 public class SyncWaitNotify {private volatile int flag;private volatile int loopNumber;public SyncWaitNo…...

JS(三)——更改html内数据
获取 DOM 元素,然后修改其属性或内容。使用 getElementById 方法获取特定 ID 的元素: <p id"myParagraph">这是初始的文本</p> const paragraph document.getElementById(myParagraph); paragraph.innerHTML 这是修改后的文本…...

CSS小玩意儿:文字适配背景
一,效果 二,代码 1,搭个框架 添加一张背景图片,在图片中显示一行文字。 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" conte…...

C++:平衡二叉搜索树之红黑树
一、红黑树的概念 红黑树, 和AVL都是二叉搜索树, 红黑树通过在每个节点上增加一个储存位表示节点的颜色, 可以是RED或者BLACK, 通过任何一条从根到叶子的路径上各个节点着色方式的限制,红黑树能够确保没有一条路径会比…...

CentOS 7 系统优化
CentOS 7 系统优化 1、配置YUM源 阿里云的YUM源配置: CentOS 7使用以下命令: sudo wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-7.repoCentOS 8使用以下命令: sudo wget -O /etc/yum.repos.d/CentOS…...

扫雷游戏——附源代码
扫雷游戏的源代码比较简单,不设计比较复杂的代码,主要是多个函数的组合,每个函数执行自己的功能,最终支持游戏的完成。 1.菜单 我们需要一个提醒信息来让用户进行选择。 void menu() {printf("***********************\n&…...

Vue3列表(List)
效果如下图:在线预览 APIs List 参数说明类型默认值bordered是否展示边框booleanfalsevertical是否使用竖直样式booleanfalsesplit是否展示分割线booleantruesize列表尺寸‘small’ | ‘middle’ | ‘large’‘middle’loading是否加载中booleanfalsehoverable是否…...

HarmonyOS NEXT - Navigation组件封装BaseNavigation
demo 地址: https://github.com/iotjin/JhHarmonyDemo 代码不定时更新,请前往github查看最新代码 在demo中这些组件和工具类都通过module实现了,具体可以参考HarmonyOS NEXT - 通过 module 模块化引用公共组件和utils 官方介绍 组件导航 (Navigation)(推…...

浅看MySQL数据库
有这么一句话:“一个不会数据库的程序员不是合格的程序员”。有点夸张,但是确是如此。透彻学习数据库是要学习好多知识,需要学的东西也是偏难的。我们今天来看数据库MySQL的一些简单基础东西,跟着小编一起来看一下吧。 什么是数据…...

Pytorch常用训练套路框架(CPU)
文章目录 1. 数据准备示例:加载 CIFAR-10 数据集 2. 模型定义示例:定义一个简单的卷积神经网络 3. 损失函数和优化器示例:定义损失函数和优化器 4. 训练循环示例:训练循环 5. 评估和测试示例:评估模型 6. 保存和加载模…...

C++ | Leetcode C++题解之第338题比特位计数
题目: 题解: class Solution { public:vector<int> countBits(int n) {vector<int> bits(n 1);for (int i 1; i < n; i) {bits[i] bits[i & (i - 1)] 1;}return bits;} };...

智慧校园云平台电子班牌系统源码,智慧教育一体化云解决方案
智慧校园云平台电子班牌系统,利用先进的云计算技术,将教育信息化资源和教学管理系统进行有效整合,实现生态基础数据共享、应用生态统一管理,为智慧教育建设的统一性,稳定性,可扩展性,互通性提供…...

数据库系统 第17节 数据仓库 案例赏析
下面我将通过几个具体的案例来说明数据仓库如何在不同的行业中发挥作用,并解决实际业务问题。 案例 1: 零售业 背景: 一家大型零售商希望改进其库存管理和市场营销策略,以提高销售额和顾客满意度。 解决方案: 数据仓库: 构建一个数据仓库࿰…...
硬件面试经典 100 题(71~90 题)
71、请问下图电路的作用是什么? 该电路实现 IIC 信号的电平转换(3.3V 和 5V 电平转换),并且是双向通信的。 上下两路是一样的,只分析 SDA 一路: 1) 从左到右通信(SDA2 为输入状态&…...

【git】代理相关
问题: 开启了翻墙代理工具,拉取代码时报错:fatal: 无法访问 xxxx : Failed to connect to github.com port 443: 连接超时 解决: 0,取消代理仍然无法拉取 1,查看控制面板-网络与Internet-代理ÿ…...

golang gin框架中创建自定义中间件的2种方式总结 - func(*gin.Context)方式和闭包函数方式定义gin中间件
在gin框架中,我们可以通过2种方式创建自定义中间件: 1. 直接定义一个类型为 func(*gin.Context)的函数或者方法 这种方式是我们常用的方式,也就是定义一个参数为*gin.Context的函数或者方法。定义的方法就是创建一个 参数类型为 gin.Handler…...