当前位置: 首页 > news >正文

卷积神经网络-3D医疗影像识别

文章目录

  • 一、前言
  • 二、前期工作
      • 1. 介绍
      • 2. 加载和预处理数据
    • 二、构建训练和验证集
    • 三、数据增强
    • 四、数据可视化
    • 五、构建3D卷积神经网络模型
    • 六、训练模型
    • 七、可视化模型性能
    • 八、对单次 CT 扫描进行预测

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 介绍

本案例将展示通过构建 3D 卷积神经网络 (CNN) 来预测计算机断层扫描 (CT) 中病毒性肺炎是否存在。 2D 的 CNN 通常用于处理 RGB 图像(3 个通道)。 3D 的 CNN 仅仅是 3D 等价物,我们可以将 3D 图像简单理解成 2D 图像的叠加。3D 的 CNN 可以理解成是学习立体数据的强大模型。

import os,zipfile
import numpy as np
from tensorflow import keras
from tensorflow.keras import layersimport tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)

2. 加载和预处理数据

数据文件是 Nifti,扩展名为 .nii。我使用nibabel 包来读取文件,你可以通过 pip install nibabel 来安装 nibabel

数据预处理步骤:

  1. 首先将体积旋转 90 度,确保方向是固定的
  2. 将 HU 值缩放到 0 和 1 之间。
  3. 调整宽度、高度和深度。

我定义了几个辅助函数来完成处理数据,这些功能将在构建训练和验证数据集时使用。

import nibabel as nib
from scipy import ndimagedef read_nifti_file(filepath):# 读取文件scan = nib.load(filepath)# 获取数据scan = scan.get_fdata()return scandef normalize(volume):"""归一化"""min = -1000max = 400volume[volume < min] = minvolume[volume > max] = maxvolume = (volume - min) / (max - min)volume = volume.astype("float32")return volumedef resize_volume(img):"""修改图像大小"""# Set the desired depthdesired_depth = 64desired_width = 128desired_height = 128# Get current depthcurrent_depth = img.shape[-1]current_width = img.shape[0]current_height = img.shape[1]# Compute depth factordepth = current_depth / desired_depthwidth = current_width / desired_widthheight = current_height / desired_heightdepth_factor = 1 / depthwidth_factor = 1 / widthheight_factor = 1 / height# 旋转img = ndimage.rotate(img, 90, reshape=False)# 数据调整img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)return imgdef process_scan(path):# 读取文件volume = read_nifti_file(path)# 归一化volume = normalize(volume)# 调整尺寸 width, height and depthvolume = resize_volume(volume)return volume

读取CT扫描文件的路径

# “CT-0”文件夹中是正常肺组织的CT扫描
normal_scan_paths = [os.path.join(os.getcwd(), "MosMedData/CT-0", x)for x in os.listdir("MosMedData/CT-0")
]# “CT-23”文件夹中是患有肺炎的人的CT扫描
abnormal_scan_paths = [os.path.join(os.getcwd(), "MosMedData/CT-23", x)for x in os.listdir("MosMedData/CT-23")
]print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))
CT scans with normal lung tissue: 100
CT scans with abnormal lung tissue: 100
# 读取数据并进行预处理
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])# 标签数字化
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])

二、构建训练和验证集

从类目录中读取扫描并分配标签。对扫描进行下采样以具有 128x128x64 的形状。将原始 HU 值重新调整到 0 到 1 的范围内。最后,将数据集拆分为训练和验证子集。

# 按照7:3的比例划分训练集、验证集
x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
print("Number of samples in train and validation are %d and %d."% (x_train.shape[0], x_val.shape[0])
)
Number of samples in train and validation are 140 and 60.

三、数据增强

CT扫描也通过在训练期间在随机角度旋转来增强数据。由于数据存储在Rank-3的形状(样本,高度,宽度,深度)中,因此我们在轴4处添加大小1的尺寸,以便能够对数据执行3D卷积。因此,新形状(样品,高度,宽度,深度,1)。在那里有不同类型的预处理和增强技术,这个例子显示了一些简单的开始。

import random
from scipy import ndimage@tf.function
def rotate(volume):"""不同程度上进行旋转"""def scipy_rotate(volume):# 定义一些旋转角度angles = [-20, -10, -5, 5, 10, 20]# 随机选择一个角度angle = random.choice(angles)volume = ndimage.rotate(volume, angle, reshape=False)volume[volume < 0] = 0volume[volume > 1] = 1return volumeaugmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)return augmented_volumedef train_preprocessing(volume, label):volume = rotate(volume)volume = tf.expand_dims(volume, axis=3)return volume, labeldef validation_preprocessing(volume, label):volume = tf.expand_dims(volume, axis=3)return volume, label

在定义训练和验证数据加载器的同时,训练数据将进行不同角度的随机旋转。训练和验证数据都已重新调整为具有 0 到 1 之间的值。

# 定义数据加载器
train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))batch_size = 2train_dataset = (train_loader.shuffle(len(x_train)).map(train_preprocessing).batch(batch_size).prefetch(2)
)validation_dataset = (validation_loader.shuffle(len(x_val)).map(validation_preprocessing).batch(batch_size).prefetch(2)
)

四、数据可视化

import matplotlib.pyplot as pltdata = train_dataset.take(1)
images, labels = list(data)[0]
images = images.numpy()
image = images[0]
print("Dimension of the CT scan is:", image.shape)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")
Dimension of the CT scan is: (128, 128, 64, 1)

在这里插入图片描述

def plot_slices(num_rows, num_columns, width, height, data):"""Plot a montage of 20 CT slices"""data = np.rot90(np.array(data))data = np.transpose(data)data = np.reshape(data, (num_rows, num_columns, width, height))rows_data, columns_data = data.shape[0], data.shape[1]heights = [slc[0].shape[0] for slc in data]widths = [slc.shape[1] for slc in data[0]]fig_width = 12.0fig_height = fig_width * sum(heights) / sum(widths)f, axarr = plt.subplots(rows_data,columns_data,figsize=(fig_width, fig_height),gridspec_kw={"height_ratios": heights},)for i in range(rows_data):for j in range(columns_data):axarr[i, j].imshow(data[i][j], cmap="gray")axarr[i, j].axis("off")plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)plt.show()# Visualize montage of slices.
# 4 rows and 10 columns for 100 slices of the CT scan.
plot_slices(4, 10, 128, 128, image[:, :, :40])

在这里插入图片描述

五、构建3D卷积神经网络模型

为了使模型更容易理解,我将其构建成块。

def get_model(width=128, height=128, depth=64):"""构建 3D 卷积神经网络模型"""inputs = keras.Input((width, height, depth, 1))x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)x = layers.MaxPool3D(pool_size=2)(x)x = layers.BatchNormalization()(x)x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)x = layers.MaxPool3D(pool_size=2)(x)x = layers.BatchNormalization()(x)x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)x = layers.MaxPool3D(pool_size=2)(x)x = layers.BatchNormalization()(x)x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)x = layers.MaxPool3D(pool_size=2)(x)x = layers.BatchNormalization()(x)x = layers.GlobalAveragePooling3D()(x)x = layers.Dense(units=512, activation="relu")(x)x = layers.Dropout(0.3)(x)outputs = layers.Dense(units=1, activation="sigmoid")(x)# 定义模型model = keras.Model(inputs, outputs, name="3dcnn")return model# 构建模型
model = get_model(width=128, height=128, depth=64)
model.summary()

六、训练模型

# 设置动态学习率
initial_learning_rate = 1e-4
lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=30, decay_rate=0.96, staircase=True
)
# 编译
model.compile(loss="binary_crossentropy",optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),metrics=["acc"],
)
# 保存模型
checkpoint_cb = keras.callbacks.ModelCheckpoint("3d_image_classification.h5", save_best_only=True
)
# 定义早停策略
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)epochs = 100
model.fit(train_dataset,validation_data=validation_dataset,epochs=epochs,shuffle=True,verbose=2,callbacks=[checkpoint_cb, early_stopping_cb],
)

七、可视化模型性能

fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()for i, metric in enumerate(["acc", "loss"]):ax[i].plot(model.history.history[metric])ax[i].plot(model.history.history["val_" + metric])ax[i].set_title("Model {}".format(metric))ax[i].set_xlabel("epochs")ax[i].set_ylabel(metric)ax[i].legend(["train", "val"])

八、对单次 CT 扫描进行预测

# 加载模型
model.load_weights("3d_image_classification.h5")
prediction = model.predict(np.expand_dims(x_val[0], axis=0))[0]
scores = [1 - prediction[0], prediction[0]]class_names = ["normal", "abnormal"]
for score, name in zip(scores, class_names):print("This model is %.2f percent confident that CT scan is %s"% ((100 * score), name))
This model is 27.88 percent confident that CT scan is normal
This model is 72.12 percent confident that CT scan is abnormal

相关文章:

卷积神经网络-3D医疗影像识别

文章目录 一、前言二、前期工作1. 介绍2. 加载和预处理数据 二、构建训练和验证集三、数据增强四、数据可视化五、构建3D卷积神经网络模型六、训练模型七、可视化模型性能八、对单次 CT 扫描进行预测 一、前言 我的环境&#xff1a; 语言环境&#xff1a;Python3.6.5编译器&a…...

C++基础 -33- 单目运算符重载

单目运算符重载格式 a和a通过形参确定 data1 operator() {this->a;return *this; }data1 operator(int) {data1 temp*this;this->a;return temp; }举例使用单目运算符重载 #include "iostream"using namespace std;class data1 {public :int a;data1(int…...

[传智杯 #3 初赛] 课程报名

题目描述 传智播客推出了一款课程&#xff0c;并进行了一次促销活动。具体来说就是&#xff0c;课程的初始定价为 v 元&#xff1b;每报名 m 个学员&#xff0c;课程的定价就要提升 a 元。由于课程能够容纳的学生有限&#xff0c;因此报名到 n 人的时候就停止报名。 现在老师…...

华为OD机试 - 悄悄话(Java JS Python C)

题目描述 给定一个二叉树,每个节点上站一个人,节点数字表示父节点到该节点传递悄悄话需要花费的时间。 初始时,根节点所在位置的人有一个悄悄话想要传递给其他人,求二叉树所有节点上的人都接收到悄悄话花费的时间。 输入描述 给定二叉树 0 9 20 -1 -1 15 7 -1 -1 -1 -1 …...

LeetCode - 965. 单值二叉树(C语言,二叉树,配图)

二叉树每个节点都具有相同的值&#xff0c;我们就可以比较每个树的根节点与左右两个孩子节点的值是否相同&#xff0c;如果不同返回false&#xff0c;否则&#xff0c;返回true。 如果是叶子节点&#xff0c;不存在还孩子节点&#xff0c;则这个叶子节点为根的树是单值二叉树。…...

每日一题(LeetCode)----哈希表--三数之和

每日一题(LeetCode)----哈希表–三数之和 1.题目&#xff08;15. 三数之和&#xff09; 给你一个整数数组 nums &#xff0c;判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k &#xff0c;同时还满足 nums[i] nums[j] nums[k] 0 。请 你返回所…...

DL中的GPU使用问题

写在前面 在使用GPU进行深度学习训练经常会遇到下面几个问题&#xff0c;这里做一个解决方法的汇总。 &#x1f415;Q1&#x1f415;&#xff1a;在一个多卡服务器上&#xff0c;指定了cuda:1&#xff0c;但是0号显卡显存还是会被占用一定量的显存。 这个问题很经典的出现场景就…...

Linux命令——watch

watch是周期性的执行下个程序&#xff0c;并全屏显示执行结果 用法&#xff1a; vmfedora:~$ watch --helpUsage:watch [options] commandOptions:-b, --beep beep if command has a non-zero exit-c, --color interpret ANSI color and style sequen…...

力扣题:字符的统计-12.2

力扣题-12.2 [力扣刷题攻略] Re&#xff1a;从零开始的力扣刷题生活 力扣题1&#xff1a;423. 从英文中重建数字 解题思想&#xff1a;有的单词通过一个字母就可以确定&#xff0c;依次确定即可 class Solution(object):def originalDigits(self, s):""":typ…...

Python----Pandas

目录 Series属性 DataFrame的属性 Pandas的CSV文件 Pandas数据处理 Pandas的主要数据结构是Series&#xff08;一维数据&#xff09;与DataFrame&#xff08;二维数据&#xff09; Series属性 Series的属性如下&#xff1a; 属性描述pandas.Series(data,index,dtype,nam…...

【UE】UEC++获取屏幕颜色GetPixelFromCursorPosition()

目录 【UE】UE C 获取屏幕颜色GetPixelFromCursorPosition() 一、函数声明与定义 二、函数的调用 三、运行结果 【UE】UE C 获取屏幕颜色GetPixelFromCursorPosition() 一、函数声明与定义 创建一个蓝图方法库方法 GetPixelFromCursorPosition()&#xff0c;并给他指定UF…...

数学建模-基于BL回归模型和决策树模型对早产危险因素的探究和预测

整体求解过程概述(摘要) 近年来&#xff0c;全球早产率总体呈上升趋势&#xff0c;在我国&#xff0c;早产儿以每年 20 万的数目逐年递增&#xff0c;目前早产已经成为重大的公共卫生问题之一。据研究,早产是威胁胎儿及新生儿健康的重要因素&#xff0c;可能会造成死亡或智力体…...

接口测试 —— 接口测试的意义

1、接口测试的意义&#xff08;优势&#xff09; &#xff08;1&#xff09;更早的发现问题&#xff1a; 不少的测试资料中强调&#xff0c;测试应该更早的介入到项目开发中&#xff0c;因为越早的发现bug&#xff0c;修复的成本越低。 然而功能测试必须要等到系统提供可测试…...

一些常见的爬虫库

一些常见的爬虫库&#xff0c;并按功能和用途进行分类&#xff1a; 通用爬虫库&#xff1a; Beautiful Soup&#xff1a;用于解析HTML和XML文档&#xff0c;方便地提取数据。Requests&#xff1a;用于HTTP请求&#xff0c;获取网页内容。Scrapy&#xff1a;一个强大的爬虫框架…...

2023.12.2 做一个后台管理网页(左侧边栏实现手风琴和隐藏/出现效果)

2023.12.2 做一个后台管理网页&#xff08;左侧边栏实现手风琴和隐藏/出现效果&#xff09; 网页源码见附件&#xff0c;比较简单&#xff0c;之前用很多种方法实现过该效果&#xff0c;这次的效果相对更好。 实现功能&#xff1a; &#xff08;1&#xff09;实现左侧边栏的手…...

【EMFace】《EMface: Detecting Hard Faces by Exploring Receptive Field Pyramids》

arXiv-2021 文章目录 1 Background and Motivation2 Related Work3 Advantages / Contributions4 Method5 Experiments5.1 Datasets and Metrics5.2 Ablation Study5.3 Comparison with State-of-the-Arts 6 Conclusion&#xff08;own&#xff09; 1 Background and Motivatio…...

详细学习Pyqt5的20种输入控件(Input Widgets)

Pyqt5相关文章: 快速掌握Pyqt5的三种主窗口 快速掌握Pyqt5的2种弹簧 快速掌握Pyqt5的5种布局 快速弄懂Pyqt5的5种项目视图&#xff08;Item View&#xff09; 快速弄懂Pyqt5的4种项目部件&#xff08;Item Widget&#xff09; 快速掌握Pyqt5的6种按钮 快速掌握Pyqt5的10种容器&…...

【JavaEE初阶】Thread 类及常见方法、线程的状态

目录 1、Thread 类及常见方法 1.1 Thread 的常见构造方法 1.2 Thread 的几个常见属性 1.3 启动⼀个线程 - start() 1.4 中断⼀个线程 1.5 等待⼀个线程 - join() 1.6 获取当前线程引用 1.7 休眠当前线程 2、线程的状态 2.1 观察线程的所有状态 2.2 线程状态和状…...

0 NLP: 数据获取与EDA

0数据准备与分析 二分类任务&#xff0c;正负样本共计6W&#xff1b; 数据集下载 https://github.com/SophonPlus/ChineseNlpCorpus/raw/master/datasets/online_shopping_10_cats/online_shopping_10_cats.zip 样本的分布 正负样本中评论字段的长度 &#xff0c;超过500的都…...

159.库存管理(TOPk问题!)

思路&#xff1a;也是tok的问题&#xff0c;与上篇博客思路一样&#xff0c;只不过是求前k个小的元素&#xff01; 基于快排分块思路的代码如下&#xff1a; class Solution { public:int getkey(vector<int>&nums,int left,int right){int rrand();return nums[r%…...

vscode里如何用git

打开vs终端执行如下&#xff1a; 1 初始化 Git 仓库&#xff08;如果尚未初始化&#xff09; git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...

逻辑回归:给不确定性划界的分类大师

想象你是一名医生。面对患者的检查报告&#xff08;肿瘤大小、血液指标&#xff09;&#xff0c;你需要做出一个**决定性判断**&#xff1a;恶性还是良性&#xff1f;这种“非黑即白”的抉择&#xff0c;正是**逻辑回归&#xff08;Logistic Regression&#xff09;** 的战场&a…...

C++中string流知识详解和示例

一、概览与类体系 C 提供三种基于内存字符串的流&#xff0c;定义在 <sstream> 中&#xff1a; std::istringstream&#xff1a;输入流&#xff0c;从已有字符串中读取并解析。std::ostringstream&#xff1a;输出流&#xff0c;向内部缓冲区写入内容&#xff0c;最终取…...

今日科技热点速览

&#x1f525; 今日科技热点速览 &#x1f3ae; 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售&#xff0c;主打更强图形性能与沉浸式体验&#xff0c;支持多模态交互&#xff0c;受到全球玩家热捧 。 &#x1f916; 人工智能持续突破 DeepSeek-R1&…...

3-11单元格区域边界定位(End属性)学习笔记

返回一个Range 对象&#xff0c;只读。该对象代表包含源区域的区域上端下端左端右端的最后一个单元格。等同于按键 End 向上键(End(xlUp))、End向下键(End(xlDown))、End向左键(End(xlToLeft)End向右键(End(xlToRight)) 注意&#xff1a;它移动的位置必须是相连的有内容的单元格…...

纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join

纯 Java 项目&#xff08;非 SpringBoot&#xff09;集成 Mybatis-Plus 和 Mybatis-Plus-Join 1、依赖1.1、依赖版本1.2、pom.xml 2、代码2.1、SqlSession 构造器2.2、MybatisPlus代码生成器2.3、获取 config.yml 配置2.3.1、config.yml2.3.2、项目配置类 2.4、ftl 模板2.4.1、…...

MySQL:分区的基本使用

目录 一、什么是分区二、有什么作用三、分类四、创建分区五、删除分区 一、什么是分区 MySQL 分区&#xff08;Partitioning&#xff09;是一种将单张表的数据逻辑上拆分成多个物理部分的技术。这些物理部分&#xff08;分区&#xff09;可以独立存储、管理和优化&#xff0c;…...

Golang——7、包与接口详解

包与接口详解 1、Golang包详解1.1、Golang中包的定义和介绍1.2、Golang包管理工具go mod1.3、Golang中自定义包1.4、Golang中使用第三包1.5、init函数 2、接口详解2.1、接口的定义2.2、空接口2.3、类型断言2.4、结构体值接收者和指针接收者实现接口的区别2.5、一个结构体实现多…...

sshd代码修改banner

sshd服务连接之后会收到字符串&#xff1a; SSH-2.0-OpenSSH_9.5 容易被hacker识别此服务为sshd服务。 是否可以通过修改此banner达到让人无法识别此服务的目的呢&#xff1f; 不能。因为这是写的SSH的协议中的。 也就是协议规定了banner必须这么写。 SSH- 开头&#xff0c…...

图解JavaScript原型:原型链及其分析 | JavaScript图解

​​ 忽略该图的细节&#xff08;如内存地址值没有用二进制&#xff09; 以下是对该图进一步的理解和总结 1. JS 对象概念的辨析 对象是什么&#xff1a;保存在堆中一块区域&#xff0c;同时在栈中有一块区域保存其在堆中的地址&#xff08;也就是我们通常说的该变量指向谁&…...