卷积神经网络-3D医疗影像识别
文章目录
- 一、前言
- 二、前期工作
- 1. 介绍
- 2. 加载和预处理数据
- 二、构建训练和验证集
- 三、数据增强
- 四、数据可视化
- 五、构建3D卷积神经网络模型
- 六、训练模型
- 七、可视化模型性能
- 八、对单次 CT 扫描进行预测
一、前言
我的环境:
- 语言环境:Python3.6.5
- 编译器:jupyter notebook
- 深度学习环境:TensorFlow2.4.1
往期精彩内容:
- 卷积神经网络(CNN)实现mnist手写数字识别
- 卷积神经网络(CNN)多种图片分类的实现
- 卷积神经网络(CNN)衣服图像分类的实现
- 卷积神经网络(CNN)鲜花识别
- 卷积神经网络(CNN)天气识别
- 卷积神经网络(VGG-16)识别海贼王草帽一伙
- 卷积神经网络(ResNet-50)鸟类识别
- 卷积神经网络(AlexNet)鸟类识别
- 卷积神经网络(CNN)识别验证码
来自专栏:机器学习与深度学习算法推荐
二、前期工作
1. 介绍
本案例将展示通过构建 3D 卷积神经网络 (CNN) 来预测计算机断层扫描 (CT) 中病毒性肺炎是否存在。 2D 的 CNN 通常用于处理 RGB 图像(3 个通道)。 3D 的 CNN 仅仅是 3D 等价物,我们可以将 3D 图像简单理解成 2D 图像的叠加。3D 的 CNN 可以理解成是学习立体数据的强大模型。
import os,zipfile
import numpy as np
from tensorflow import keras
from tensorflow.keras import layersimport tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)
2. 加载和预处理数据
数据文件是 Nifti,扩展名为 .nii。我使用nibabel
包来读取文件,你可以通过 pip install nibabel
来安装 nibabel
包。
数据预处理步骤:
- 首先将体积旋转 90 度,确保方向是固定的
- 将 HU 值缩放到 0 和 1 之间。
- 调整宽度、高度和深度。
我定义了几个辅助函数来完成处理数据,这些功能将在构建训练和验证数据集时使用。
import nibabel as nib
from scipy import ndimagedef read_nifti_file(filepath):# 读取文件scan = nib.load(filepath)# 获取数据scan = scan.get_fdata()return scandef normalize(volume):"""归一化"""min = -1000max = 400volume[volume < min] = minvolume[volume > max] = maxvolume = (volume - min) / (max - min)volume = volume.astype("float32")return volumedef resize_volume(img):"""修改图像大小"""# Set the desired depthdesired_depth = 64desired_width = 128desired_height = 128# Get current depthcurrent_depth = img.shape[-1]current_width = img.shape[0]current_height = img.shape[1]# Compute depth factordepth = current_depth / desired_depthwidth = current_width / desired_widthheight = current_height / desired_heightdepth_factor = 1 / depthwidth_factor = 1 / widthheight_factor = 1 / height# 旋转img = ndimage.rotate(img, 90, reshape=False)# 数据调整img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)return imgdef process_scan(path):# 读取文件volume = read_nifti_file(path)# 归一化volume = normalize(volume)# 调整尺寸 width, height and depthvolume = resize_volume(volume)return volume
读取CT扫描文件的路径
# “CT-0”文件夹中是正常肺组织的CT扫描
normal_scan_paths = [os.path.join(os.getcwd(), "MosMedData/CT-0", x)for x in os.listdir("MosMedData/CT-0")
]# “CT-23”文件夹中是患有肺炎的人的CT扫描
abnormal_scan_paths = [os.path.join(os.getcwd(), "MosMedData/CT-23", x)for x in os.listdir("MosMedData/CT-23")
]print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))
CT scans with normal lung tissue: 100
CT scans with abnormal lung tissue: 100
# 读取数据并进行预处理
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])# 标签数字化
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])
二、构建训练和验证集
从类目录中读取扫描并分配标签。对扫描进行下采样以具有 128x128x64 的形状。将原始 HU 值重新调整到 0 到 1 的范围内。最后,将数据集拆分为训练和验证子集。
# 按照7:3的比例划分训练集、验证集
x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
print("Number of samples in train and validation are %d and %d."% (x_train.shape[0], x_val.shape[0])
)
Number of samples in train and validation are 140 and 60.
三、数据增强
CT扫描也通过在训练期间在随机角度旋转来增强数据。由于数据存储在Rank-3的形状(样本,高度,宽度,深度)中,因此我们在轴4处添加大小1的尺寸,以便能够对数据执行3D卷积。因此,新形状(样品,高度,宽度,深度,1)。在那里有不同类型的预处理和增强技术,这个例子显示了一些简单的开始。
import random
from scipy import ndimage@tf.function
def rotate(volume):"""不同程度上进行旋转"""def scipy_rotate(volume):# 定义一些旋转角度angles = [-20, -10, -5, 5, 10, 20]# 随机选择一个角度angle = random.choice(angles)volume = ndimage.rotate(volume, angle, reshape=False)volume[volume < 0] = 0volume[volume > 1] = 1return volumeaugmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)return augmented_volumedef train_preprocessing(volume, label):volume = rotate(volume)volume = tf.expand_dims(volume, axis=3)return volume, labeldef validation_preprocessing(volume, label):volume = tf.expand_dims(volume, axis=3)return volume, label
在定义训练和验证数据加载器的同时,训练数据将进行不同角度的随机旋转。训练和验证数据都已重新调整为具有 0 到 1 之间的值。
# 定义数据加载器
train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))batch_size = 2train_dataset = (train_loader.shuffle(len(x_train)).map(train_preprocessing).batch(batch_size).prefetch(2)
)validation_dataset = (validation_loader.shuffle(len(x_val)).map(validation_preprocessing).batch(batch_size).prefetch(2)
)
四、数据可视化
import matplotlib.pyplot as pltdata = train_dataset.take(1)
images, labels = list(data)[0]
images = images.numpy()
image = images[0]
print("Dimension of the CT scan is:", image.shape)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")
Dimension of the CT scan is: (128, 128, 64, 1)
def plot_slices(num_rows, num_columns, width, height, data):"""Plot a montage of 20 CT slices"""data = np.rot90(np.array(data))data = np.transpose(data)data = np.reshape(data, (num_rows, num_columns, width, height))rows_data, columns_data = data.shape[0], data.shape[1]heights = [slc[0].shape[0] for slc in data]widths = [slc.shape[1] for slc in data[0]]fig_width = 12.0fig_height = fig_width * sum(heights) / sum(widths)f, axarr = plt.subplots(rows_data,columns_data,figsize=(fig_width, fig_height),gridspec_kw={"height_ratios": heights},)for i in range(rows_data):for j in range(columns_data):axarr[i, j].imshow(data[i][j], cmap="gray")axarr[i, j].axis("off")plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)plt.show()# Visualize montage of slices.
# 4 rows and 10 columns for 100 slices of the CT scan.
plot_slices(4, 10, 128, 128, image[:, :, :40])
五、构建3D卷积神经网络模型
为了使模型更容易理解,我将其构建成块。
def get_model(width=128, height=128, depth=64):"""构建 3D 卷积神经网络模型"""inputs = keras.Input((width, height, depth, 1))x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)x = layers.MaxPool3D(pool_size=2)(x)x = layers.BatchNormalization()(x)x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)x = layers.MaxPool3D(pool_size=2)(x)x = layers.BatchNormalization()(x)x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)x = layers.MaxPool3D(pool_size=2)(x)x = layers.BatchNormalization()(x)x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)x = layers.MaxPool3D(pool_size=2)(x)x = layers.BatchNormalization()(x)x = layers.GlobalAveragePooling3D()(x)x = layers.Dense(units=512, activation="relu")(x)x = layers.Dropout(0.3)(x)outputs = layers.Dense(units=1, activation="sigmoid")(x)# 定义模型model = keras.Model(inputs, outputs, name="3dcnn")return model# 构建模型
model = get_model(width=128, height=128, depth=64)
model.summary()
六、训练模型
# 设置动态学习率
initial_learning_rate = 1e-4
lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=30, decay_rate=0.96, staircase=True
)
# 编译
model.compile(loss="binary_crossentropy",optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),metrics=["acc"],
)
# 保存模型
checkpoint_cb = keras.callbacks.ModelCheckpoint("3d_image_classification.h5", save_best_only=True
)
# 定义早停策略
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)epochs = 100
model.fit(train_dataset,validation_data=validation_dataset,epochs=epochs,shuffle=True,verbose=2,callbacks=[checkpoint_cb, early_stopping_cb],
)
七、可视化模型性能
fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()for i, metric in enumerate(["acc", "loss"]):ax[i].plot(model.history.history[metric])ax[i].plot(model.history.history["val_" + metric])ax[i].set_title("Model {}".format(metric))ax[i].set_xlabel("epochs")ax[i].set_ylabel(metric)ax[i].legend(["train", "val"])
八、对单次 CT 扫描进行预测
# 加载模型
model.load_weights("3d_image_classification.h5")
prediction = model.predict(np.expand_dims(x_val[0], axis=0))[0]
scores = [1 - prediction[0], prediction[0]]class_names = ["normal", "abnormal"]
for score, name in zip(scores, class_names):print("This model is %.2f percent confident that CT scan is %s"% ((100 * score), name))
This model is 27.88 percent confident that CT scan is normal
This model is 72.12 percent confident that CT scan is abnormal
相关文章:

卷积神经网络-3D医疗影像识别
文章目录 一、前言二、前期工作1. 介绍2. 加载和预处理数据 二、构建训练和验证集三、数据增强四、数据可视化五、构建3D卷积神经网络模型六、训练模型七、可视化模型性能八、对单次 CT 扫描进行预测 一、前言 我的环境: 语言环境:Python3.6.5编译器&a…...

C++基础 -33- 单目运算符重载
单目运算符重载格式 a和a通过形参确定 data1 operator() {this->a;return *this; }data1 operator(int) {data1 temp*this;this->a;return temp; }举例使用单目运算符重载 #include "iostream"using namespace std;class data1 {public :int a;data1(int…...
[传智杯 #3 初赛] 课程报名
题目描述 传智播客推出了一款课程,并进行了一次促销活动。具体来说就是,课程的初始定价为 v 元;每报名 m 个学员,课程的定价就要提升 a 元。由于课程能够容纳的学生有限,因此报名到 n 人的时候就停止报名。 现在老师…...

华为OD机试 - 悄悄话(Java JS Python C)
题目描述 给定一个二叉树,每个节点上站一个人,节点数字表示父节点到该节点传递悄悄话需要花费的时间。 初始时,根节点所在位置的人有一个悄悄话想要传递给其他人,求二叉树所有节点上的人都接收到悄悄话花费的时间。 输入描述 给定二叉树 0 9 20 -1 -1 15 7 -1 -1 -1 -1 …...

LeetCode - 965. 单值二叉树(C语言,二叉树,配图)
二叉树每个节点都具有相同的值,我们就可以比较每个树的根节点与左右两个孩子节点的值是否相同,如果不同返回false,否则,返回true。 如果是叶子节点,不存在还孩子节点,则这个叶子节点为根的树是单值二叉树。…...
每日一题(LeetCode)----哈希表--三数之和
每日一题(LeetCode)----哈希表–三数之和 1.题目(15. 三数之和) 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] nums[j] nums[k] 0 。请 你返回所…...
DL中的GPU使用问题
写在前面 在使用GPU进行深度学习训练经常会遇到下面几个问题,这里做一个解决方法的汇总。 🐕Q1🐕:在一个多卡服务器上,指定了cuda:1,但是0号显卡显存还是会被占用一定量的显存。 这个问题很经典的出现场景就…...
Linux命令——watch
watch是周期性的执行下个程序,并全屏显示执行结果 用法: vmfedora:~$ watch --helpUsage:watch [options] commandOptions:-b, --beep beep if command has a non-zero exit-c, --color interpret ANSI color and style sequen…...

力扣题:字符的统计-12.2
力扣题-12.2 [力扣刷题攻略] Re:从零开始的力扣刷题生活 力扣题1:423. 从英文中重建数字 解题思想:有的单词通过一个字母就可以确定,依次确定即可 class Solution(object):def originalDigits(self, s):""":typ…...

Python----Pandas
目录 Series属性 DataFrame的属性 Pandas的CSV文件 Pandas数据处理 Pandas的主要数据结构是Series(一维数据)与DataFrame(二维数据) Series属性 Series的属性如下: 属性描述pandas.Series(data,index,dtype,nam…...

【UE】UEC++获取屏幕颜色GetPixelFromCursorPosition()
目录 【UE】UE C 获取屏幕颜色GetPixelFromCursorPosition() 一、函数声明与定义 二、函数的调用 三、运行结果 【UE】UE C 获取屏幕颜色GetPixelFromCursorPosition() 一、函数声明与定义 创建一个蓝图方法库方法 GetPixelFromCursorPosition(),并给他指定UF…...

数学建模-基于BL回归模型和决策树模型对早产危险因素的探究和预测
整体求解过程概述(摘要) 近年来,全球早产率总体呈上升趋势,在我国,早产儿以每年 20 万的数目逐年递增,目前早产已经成为重大的公共卫生问题之一。据研究,早产是威胁胎儿及新生儿健康的重要因素,可能会造成死亡或智力体…...

接口测试 —— 接口测试的意义
1、接口测试的意义(优势) (1)更早的发现问题: 不少的测试资料中强调,测试应该更早的介入到项目开发中,因为越早的发现bug,修复的成本越低。 然而功能测试必须要等到系统提供可测试…...
一些常见的爬虫库
一些常见的爬虫库,并按功能和用途进行分类: 通用爬虫库: Beautiful Soup:用于解析HTML和XML文档,方便地提取数据。Requests:用于HTTP请求,获取网页内容。Scrapy:一个强大的爬虫框架…...

2023.12.2 做一个后台管理网页(左侧边栏实现手风琴和隐藏/出现效果)
2023.12.2 做一个后台管理网页(左侧边栏实现手风琴和隐藏/出现效果) 网页源码见附件,比较简单,之前用很多种方法实现过该效果,这次的效果相对更好。 实现功能: (1)实现左侧边栏的手…...

【EMFace】《EMface: Detecting Hard Faces by Exploring Receptive Field Pyramids》
arXiv-2021 文章目录 1 Background and Motivation2 Related Work3 Advantages / Contributions4 Method5 Experiments5.1 Datasets and Metrics5.2 Ablation Study5.3 Comparison with State-of-the-Arts 6 Conclusion(own) 1 Background and Motivatio…...
详细学习Pyqt5的20种输入控件(Input Widgets)
Pyqt5相关文章: 快速掌握Pyqt5的三种主窗口 快速掌握Pyqt5的2种弹簧 快速掌握Pyqt5的5种布局 快速弄懂Pyqt5的5种项目视图(Item View) 快速弄懂Pyqt5的4种项目部件(Item Widget) 快速掌握Pyqt5的6种按钮 快速掌握Pyqt5的10种容器&…...

【JavaEE初阶】Thread 类及常见方法、线程的状态
目录 1、Thread 类及常见方法 1.1 Thread 的常见构造方法 1.2 Thread 的几个常见属性 1.3 启动⼀个线程 - start() 1.4 中断⼀个线程 1.5 等待⼀个线程 - join() 1.6 获取当前线程引用 1.7 休眠当前线程 2、线程的状态 2.1 观察线程的所有状态 2.2 线程状态和状…...

0 NLP: 数据获取与EDA
0数据准备与分析 二分类任务,正负样本共计6W; 数据集下载 https://github.com/SophonPlus/ChineseNlpCorpus/raw/master/datasets/online_shopping_10_cats/online_shopping_10_cats.zip 样本的分布 正负样本中评论字段的长度 ,超过500的都…...

159.库存管理(TOPk问题!)
思路:也是tok的问题,与上篇博客思路一样,只不过是求前k个小的元素! 基于快排分块思路的代码如下: class Solution { public:int getkey(vector<int>&nums,int left,int right){int rrand();return nums[r%…...

抖音增长新引擎:品融电商,一站式全案代运营领跑者
抖音增长新引擎:品融电商,一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中,品牌如何破浪前行?自建团队成本高、效果难控;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...

ETLCloud可能遇到的问题有哪些?常见坑位解析
数据集成平台ETLCloud,主要用于支持数据的抽取(Extract)、转换(Transform)和加载(Load)过程。提供了一个简洁直观的界面,以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成
厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...

12.找到字符串中所有字母异位词
🧠 题目解析 题目描述: 给定两个字符串 s 和 p,找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义: 若两个字符串包含的字符种类和出现次数完全相同,顺序无所谓,则互为…...
【HTTP三个基础问题】
面试官您好!HTTP是超文本传输协议,是互联网上客户端和服务器之间传输超文本数据(比如文字、图片、音频、视频等)的核心协议,当前互联网应用最广泛的版本是HTTP1.1,它基于经典的C/S模型,也就是客…...

Android 之 kotlin 语言学习笔记三(Kotlin-Java 互操作)
参考官方文档:https://developer.android.google.cn/kotlin/interop?hlzh-cn 一、Java(供 Kotlin 使用) 1、不得使用硬关键字 不要使用 Kotlin 的任何硬关键字作为方法的名称 或字段。允许使用 Kotlin 的软关键字、修饰符关键字和特殊标识…...

HarmonyOS运动开发:如何用mpchart绘制运动配速图表
##鸿蒙核心技术##运动开发##Sensor Service Kit(传感器服务)# 前言 在运动类应用中,运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据,如配速、距离、卡路里消耗等,用户可以更清晰…...

FFmpeg:Windows系统小白安装及其使用
一、安装 1.访问官网 Download FFmpeg 2.点击版本目录 3.选择版本点击安装 注意这里选择的是【release buids】,注意左上角标题 例如我安装在目录 F:\FFmpeg 4.解压 5.添加环境变量 把你解压后的bin目录(即exe所在文件夹)加入系统变量…...
c# 局部函数 定义、功能与示例
C# 局部函数:定义、功能与示例 1. 定义与功能 局部函数(Local Function)是嵌套在另一个方法内部的私有方法,仅在包含它的方法内可见。 • 作用:封装仅用于当前方法的逻辑,避免污染类作用域,提升…...

从物理机到云原生:全面解析计算虚拟化技术的演进与应用
前言:我的虚拟化技术探索之旅 我最早接触"虚拟机"的概念是从Java开始的——JVM(Java Virtual Machine)让"一次编写,到处运行"成为可能。这个软件层面的虚拟化让我着迷,但直到后来接触VMware和Doc…...