当前位置: 首页 > news >正文

卷积神经网络(CNN):乳腺癌识别.ipynb

文章目录

  • 一、前言
  • 一、设置GPU
  • 二、导入数据
    • 1. 导入数据
    • 2. 检查数据
    • 3. 配置数据集
    • 4. 数据可视化
  • 三、构建模型
  • 四、编译
  • 五、训练模型
  • 六、评估模型
    • 1. Accuracy与Loss图
    • 2. 混淆矩阵
    • 3. 各项指标评估

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码

来自专栏:机器学习与深度学习算法推荐

一、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")import matplotlib.pyplot as plt
import os,PIL,pathlib
import numpy as np
import pandas as pd
import warnings
from tensorflow import keraswarnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

二、导入数据

1. 导入数据

import pathlibdata_dir = "./32-data"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 13403
batch_size = 16
img_height = 50
img_width  = 50
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 13403 files belonging to 2 classes.
Using 10723 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
Found 13403 files belonging to 2 classes.
Using 2680 files for validation.
class_names = train_ds.class_names
print(class_names)
['0', '1']

2. 检查数据

for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(16, 50, 50, 3)
(16,)

3. 配置数据集

AUTOTUNE = tf.data.AUTOTUNEdef train_preprocessing(image,label):return (image/255.0,label)train_ds = (train_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)val_ds = (val_ds.cache().shuffle(1000).map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size.prefetch(buffer_size=AUTOTUNE)
)

4. 数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")class_names = ["乳腺癌细胞","正常细胞"]for images, labels in train_ds.take(1):for i in range(15):plt.subplot(4, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)# 显示图片plt.imshow(images[i])# 显示标签plt.xlabel(class_names[labels[i]-1])plt.show()

在这里插入图片描述

三、构建模型

import tensorflow as tfmodel = tf.keras.Sequential([tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu",input_shape=[img_width, img_height, 3]),tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Dropout(0.5),tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu"),tf.keras.layers.MaxPooling2D((2,2)),tf.keras.layers.Flatten(),tf.keras.layers.Dense(2, activation="softmax")
])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 50, 50, 16)        448       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 50, 50, 16)        2320      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 25, 25, 16)        0         
_________________________________________________________________
dropout (Dropout)            (None, 25, 25, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 25, 25, 16)        2320      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 12, 12, 16)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 12, 12, 16)        2320      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 6, 6, 16)          0         
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 2)                 1154      
=================================================================
Total params: 8,562
Trainable params: 8,562
Non-trainable params: 0
_________________________________________________________________

四、编译

model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])

五、训练模型

from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, LearningRateSchedulerNO_EPOCHS = 100
PATIENCE  = 5
VERBOSE   = 1# 设置动态学习率
annealer = LearningRateScheduler(lambda x: 1e-3 * 0.99 ** (x+NO_EPOCHS))# 设置早停
earlystopper = EarlyStopping(monitor='loss', patience=PATIENCE, verbose=VERBOSE)# 
checkpointer = ModelCheckpoint('best_model.h5',monitor='val_accuracy',verbose=VERBOSE,save_best_only=True,save_weights_only=True)
train_model  = model.fit(train_ds,epochs=NO_EPOCHS,verbose=1,validation_data=val_ds,callbacks=[earlystopper, checkpointer, annealer])

六、评估模型

1. Accuracy与Loss图

acc = train_model.history['accuracy']
val_acc = train_model.history['val_accuracy']loss = train_model.history['loss']
val_loss = train_model.history['val_loss']epochs_range = range(len(acc))plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):# 生成混淆矩阵conf_numpy = confusion_matrix(labels, predictions)# 将矩阵转化为 DataFrameconf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  plt.figure(figsize=(8,7))sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")plt.title('混淆矩阵',fontsize=15)plt.ylabel('真实值',fontsize=14)plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵for image, label in zip(images, labels):# 需要给图片增加一个维度img_array = tf.expand_dims(image, 0) # 使用模型预测图片中的人物prediction = model.predict(img_array)val_pre.append(class_names[np.argmax(prediction)])val_label.append(class_names[label])
plot_cm(val_label, val_pre)

3. 各项指标评估

from sklearn import metricsdef test_accuracy_report(model):print(metrics.classification_report(val_label, val_pre, target_names=class_names)) score = model.evaluate(val_ds, verbose=0)print('Loss function: %s, accuracy:' % score[0], score[1])test_accuracy_report(model)
             precision    recall  f1-score   support乳腺癌细胞       0.92      0.90      0.91      1339正常细胞       0.91      0.92      0.91      1341accuracy                           0.91      2680macro avg       0.91      0.91      0.91      2680
weighted avg       0.91      0.91      0.91      2680Loss function: 0.22688131034374237, accuracy: 0.9138059616088867

pport

   乳腺癌细胞       0.92      0.90      0.91      1339正常细胞       0.91      0.92      0.91      1341accuracy                           0.91      2680

macro avg 0.91 0.91 0.91 2680
weighted avg 0.91 0.91 0.91 2680

Loss function: 0.22688131034374237, accuracy: 0.9138059616088867


相关文章:

卷积神经网络(CNN):乳腺癌识别.ipynb

文章目录 一、前言一、设置GPU二、导入数据1. 导入数据2. 检查数据3. 配置数据集4. 数据可视化 三、构建模型四、编译五、训练模型六、评估模型1. Accuracy与Loss图2. 混淆矩阵3. 各项指标评估 一、前言 我的环境: 语言环境:Python3.6.5编译器&#xf…...

有文件实体的后门无文件实体的后门rootkit后门

有文件实体后门和无文件实体后门&RootKit后门 什么是有文件的实体后门: 在传统的webshell当中,后门代码都是可以精确定位到某一个文件上去的,你可以rm删除它,可以鼠标右键操作它,它是有一个文件实体对象存在的。…...

GPT实战系列-大模型训练和预测,如何加速、降低显存

GPT实战系列-大模型训练和预测,如何加速、降低显存 不做特别处理,深度学习默认参数精度为浮点32位精度(FP32)。大模型参数庞大,10-1000B级别,如果不注意优化,既耗费大量的显卡资源,…...

SQL Sever 基础知识 - 数据排序

SQL Sever 基础知识 - 二 、数据排序 二 、对数据进行排序第1节 ORDER BY 子句简介第2节 ORDER BY 子句示例2.1 按一列升序对结果集进行排序2.2 按一列降序对结果集进行排序2.3 按多列对结果集排序2.4 按多列对结果集不同排序2.5 按不在选择列表中的列对结果集进行排序2.6 按表…...

vscode配置使用 cpplint

标题安装clang-format和cpplint sudo apt-get install clang-format sudo pip3 install cpplint标题以下settings.json文件放置xxx/Code/User目录 settings.json {"sync.forceDownload": false,"workbench.sideBar.location": "right","…...

C++ 系列 第四篇 C++ 数据类型上篇—基本类型

系列文章 C 系列 前篇 为什么学习C 及学习计划-CSDN博客 C 系列 第一篇 开发环境搭建(WSL 方向)-CSDN博客 C 系列 第二篇 你真的了解C吗?本篇带你走进C的世界-CSDN博客 C 系列 第三篇 C程序的基本结构-CSDN博客 前言 面向对象编程(OOP)的…...

C++ 指针详解

目录 一、指针概述 指针的定义 指针的大小 指针的解引用 野指针 指针未初始化 指针越界访问 指针运算 二级指针 指针与数组 二、字符指针 三、指针数组 四、数组指针 函数指针 函数指针数组 指向函数指针数组的指针 回调函数 指针与数组 一维数组 字符数组…...

.locked、locked1勒索病毒的最新威胁:如何恢复您的数据?

导言: 网络安全问题变得愈加严峻。.locked、locked1勒索病毒是近期备受关注的一种恶意软件,给用户的数据带来了巨大威胁。本文将深入探讨.locked、locked1勒索病毒的特征,探讨如何有效恢复被其加密的数据,并提供一些建议&#xf…...

Apache Sqoop使用

1. Sqoop介绍 Apache Sqoop 是在 Hadoop 生态体系和 RDBMS 体系之间传送数据的一种工具。 Sqoop 工作机制是将导入或导出命令翻译成 mapreduce 程序来实现。在翻译出的 mapreduce 中主要是对 inputformat 和 outputformat 进行定制。 Hadoop 生态系统包括:HDFS、Hi…...

【UGUI】实现UGUI背包系统的六个主要交互功能

在这篇教程中,我们将详细介绍如何在Unity中实现一个背包系统的六个主要功能:添加物品、删除物品、查看物品信息、排序物品、搜索物品和使用物品。让我们开始吧! 一、添加物品 首先,我们需要创建一个方法来添加新的物品到背包中。…...

电压驻波比

电压驻波比 关于IF端口的电压驻波比 一个信号变频后,从中频端口输出,它的输出跟输入是互异的。这个电压柱波比反映了它输出的能量有多少可以真正的输送到后端连接的器件或者设备。...

Open3D 最小二乘拟合二维直线(直接求解法)

目录 一、算法原理二、代码实现三、结果展示本文由CSDN点云侠原创,原文链接。爬虫网站自重。 一、算法原理 平面直线的表达式为: y = k x + b...

面试题目总结(二)

1. IoC 和 AOP 的区别 控制反转(Ioc) 和面向切面编程(AOP) 是两个不同的概念,它们在软件设计中有着不同的应用和目的。 IoC 是一种基于对象组合的编程模式,通过将对象的创建、依赖关系和生命周期等管理权交给外部容器或框架来实现程序间的解耦。IoC 的…...

TrustZone概述

目录 一、概述 1.1 在开始之前 二、什么是TrustZone? 2.1 Armv8-M的TrustZone 2.2 Armv9-A Realm Management Ext...

[go 面试] Go Kit中读取原始HTTP请求体的方法

关注公众号【爱发白日梦的后端】分享技术干货、读书笔记、开源项目、实战经验、高效开发工具等,您的关注将是我的更新动力! 在Go Kit中,如果你想读取未序列化的HTTP请求体,可以使用标准的net/http包来实现。以下是一个示例,演示了如何完成这个任务: package mainimport …...

小程序如何刷新当前页面?

在小程序中,刷新当前页面通常有两种方法: 使用 wx.navigateBack 方法: wx.navigateBack({delta: 1 }) 这将返回上一页,并刷新页面。你可以通过调整 delta 参数来控制返回的页面数。例如,如果你想要返回到两页之前的页…...

ChatGPT使用路径:从新手到专家的指南

原文&精华文章&转载注明:ChatGPT与日本首相交流核废水事件-精准Prompt... hello,我是小索奇,有任何问题或者需要帮助的都可以在这里找到我或者留言哈 一、初识ChatGPT 什么是ChatGPT? ChatGPT是一种大型语言模型&…...

VsCode 调试 MySQL 源码

1. 启动 MySQL 2. 查看 MySQL 进程号 [root ~]# ps -ef | grep mysqld root 21479 1 0 Nov01 ? 00:00:00 /bin/sh /usr/local/mysql/bin/mysqld_safe --datadir/usr/local/mysql/data --pid-file/usr/local/mysql/data/mysqld.pid root 26622 21479 0 …...

Mysql中的正经行锁、间隙锁和临键锁

行锁、间隙锁和临键锁是数据库中的三种不同类型的锁,三者都属于行锁,第一个一般叫他正经的行锁(《Mysql是怎样运行的》一书中的说法)。 行锁(Row Lock):行锁是指对数据表中的某一行进行的锁定操…...

最强AI之风袭来,你爱了吗?

2017年,柯洁同阿尔法狗人机大战,AlphaGo以3比0大获全胜,一代英才泪洒当场...... 2019年,换脸哥视频“杨幂换朱茵”轰动全网,时至今日AI换脸仍热度只增不减; 2022年,ChatGPT一经发布便轰动全球&a…...

使用VSCode开发Django指南

使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...

stm32G473的flash模式是单bank还是双bank?

今天突然有人stm32G473的flash模式是单bank还是双bank?由于时间太久,我真忘记了。搜搜发现,还真有人和我一样。见下面的链接:https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

MongoDB学习和应用(高效的非关系型数据库)

一丶 MongoDB简介 对于社交类软件的功能,我们需要对它的功能特点进行分析: 数据量会随着用户数增大而增大读多写少价值较低非好友看不到其动态信息地理位置的查询… 针对以上特点进行分析各大存储工具: mysql:关系型数据库&am…...

多场景 OkHttpClient 管理器 - Android 网络通信解决方案

下面是一个完整的 Android 实现&#xff0c;展示如何创建和管理多个 OkHttpClient 实例&#xff0c;分别用于长连接、普通 HTTP 请求和文件下载场景。 <?xml version"1.0" encoding"utf-8"?> <LinearLayout xmlns:android"http://schemas…...

无法与IP建立连接,未能下载VSCode服务器

如题&#xff0c;在远程连接服务器的时候突然遇到了这个提示。 查阅了一圈&#xff0c;发现是VSCode版本自动更新惹的祸&#xff01;&#xff01;&#xff01; 在VSCode的帮助->关于这里发现前几天VSCode自动更新了&#xff0c;我的版本号变成了1.100.3 才导致了远程连接出…...

生成 Git SSH 证书

&#x1f511; 1. ​​生成 SSH 密钥对​​ 在终端&#xff08;Windows 使用 Git Bash&#xff0c;Mac/Linux 使用 Terminal&#xff09;执行命令&#xff1a; ssh-keygen -t rsa -b 4096 -C "your_emailexample.com" ​​参数说明​​&#xff1a; -t rsa&#x…...

跨链模式:多链互操作架构与性能扩展方案

跨链模式&#xff1a;多链互操作架构与性能扩展方案 ——构建下一代区块链互联网的技术基石 一、跨链架构的核心范式演进 1. 分层协议栈&#xff1a;模块化解耦设计 现代跨链系统采用分层协议栈实现灵活扩展&#xff08;H2Cross架构&#xff09;&#xff1a; 适配层&#xf…...

镜像里切换为普通用户

如果你登录远程虚拟机默认就是 root 用户&#xff0c;但你不希望用 root 权限运行 ns-3&#xff08;这是对的&#xff0c;ns3 工具会拒绝 root&#xff09;&#xff0c;你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案&#xff1a;创建非 roo…...

Python爬虫(一):爬虫伪装

一、网站防爬机制概述 在当今互联网环境中&#xff0c;具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类&#xff1a; 身份验证机制&#xff1a;直接将未经授权的爬虫阻挡在外反爬技术体系&#xff1a;通过各种技术手段增加爬虫获取数据的难度…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)

文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...