LeNet实验 四分类 与 四分类变为多个二分类
目录
1. 划分二分类
2. 训练独立的二分类模型
3. 二分类预测结果代码
4. 二分类预测结果
5 改进训练模型
6 优化后 预测结果代码
7 优化后预测结果
8 训练四分类模型
9 预测结果代码
10 四分类结果识别
1. 划分二分类
可以根据不同的类别进行多个划分,以实现NonDemented为例,划分为NonDemented和Demented两类,不属于NonDemented的全都属于Demented
2. 训练独立的二分类模型
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGeneratorfrom 文件准备 import data_dir# 数据生成器
train_datagen = ImageDataGenerator(rescale=1./255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,validation_split=0.2 # 20%用于验证
)train_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='training'
)validation_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='validation'
)# 构建LeNet-5模型
model = models.Sequential()
model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 3), padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(16, (5, 5), activation='relu', padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(120, (5, 5), activation='relu', padding='same'))
model.add(layers.Flatten())
model.add(layers.Dense(84, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))# 编译模型
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(train_generator,steps_per_epoch=train_generator.samples // train_generator.batch_size,epochs=10,validation_data=validation_generator,validation_steps=validation_generator.samples // validation_generator.batch_size
)# 保存模型
model.save('lenet_binary_classification_model.h5')
3. 预测结果代码
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = tf.keras.models.load_model('lenet_binary_classification_model.h5')# 预处理图像
def preprocess_image(img_path):img = image.load_img(img_path, target_size=(28, 28))img_array = image.img_to_array(img) / 255.0img_array = np.expand_dims(img_array, axis=0)return img_array# 预测图像
img_path = 'D:\Pycharm_workspace\LeNet实验_二分类\Demented\moderateDem24.jpg' # 测试图像路径
img_array = preprocess_image(img_path)
prediction = model.predict(img_array)
predicted_class = 'Demented' if prediction[0][0] > 0.5 else 'NonDemented'print(f'The predicted class is: {predicted_class}')# 显示图像
img = image.load_img(img_path, target_size=(28, 28))
plt.imshow(img)
plt.title(f'Predicted: {predicted_class}')
plt.show()
4. 预测结果
Demented结果
NonDemented结果没有。。。。。。
竟然全都没有。。。。因为预测的全部都是Demented
疯狂找原因中
猜测是像素太低使得训练的模型准确率太低
于是重新训练
5 改进训练模型
进行重新训练
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(1, activation='sigmoid')])model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()
这里还有图形画loss与准确率但是我忘记保存了,就用控制台的输出

可以看到loss值非常小而且准确率是100
6 优化后 预测结果代码
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
import os# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['Demented', 'NonDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[int(prediction[0] > 0.5)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_二分类\image\NonDemented\nonDem1.jpg' # 替换为你的图片路径
predict_image(img_path)
7 优化后预测结果
图片与预测结果对应上了(右侧是图片链接可以看到是Dem的类型)
NonDem的也是对应上了
就此训练完成
8 训练四分类模型
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(4, activation='softmax')])model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

loss值与准确率的变化图

可以看到才第四轮准确率就已经很高了
9 预测结果代码
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[np.argmax(prediction)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_四分类\image\VeryMildDemented\verymildDem0.jpg' # 你的图片路径
predict_image(img_path)
10 四分类结果识别
1 MildDem成功识别(右侧有图片名称)

2 ModerateDem 成功识别

3 NonDem成功识别
4 VeryMildDem成功识别

相关文章:
LeNet实验 四分类 与 四分类变为多个二分类
目录 1. 划分二分类 2. 训练独立的二分类模型 3. 二分类预测结果代码 4. 二分类预测结果 5 改进训练模型 6 优化后 预测结果代码 7 优化后预测结果 8 训练四分类模型 9 预测结果代码 10 四分类结果识别 1. 划分二分类 可以根据不同的类别进行多个划分,以…...
【BUG】已解决:java.lang.reflect.InvocationTargetException
已解决:java.lang.reflect.InvocationTargetException 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司,热衷分享知识,武汉城市开发…...
配置kali 的apt命令在线安装包的源为国内源
目录 一、安装VMware Tools 二、配置apt国内源 一、安装VMware Tools 点击安装 VMware Tools 后,会加载一个虚拟光驱,里面包含 VMware Tools 的安装包 鼠标右键单击 VMware Tools 的安装包,点击复制到 点击 主目录,再点击选择…...
JAVA 异步编程(线程安全)二
1、线程安全 线程安全是指你的代码所在的进程中有多个线程同时运行,而这些线程可能会同时运行这段代码,如果每次运行的代码结果和单线程运行的结果是一样的,且其他变量的值和预期的也是一样的,那么就是线程安全的。 一个类或者程序…...
Golang | Leetcode Golang题解之第260题只出现一次的数字III
题目: 题解: func singleNumber(nums []int) []int {xorSum : 0for _, num : range nums {xorSum ^ num}lsb : xorSum & -xorSumtype1, type2 : 0, 0for _, num : range nums {if num&lsb > 0 {type1 ^ num} else {type2 ^ num}}return []in…...
IDEA自带的Maven 3.9.x无法刷新http nexus私服
问题: 自建的私服,配置了域名,使用http协议,在IDEA中或本地Maven 3.9.x会出现报错,提示http被blocked,原因是Maven 3.8.1开始,Maven默认禁止使用HTTP仓库地址,只允许使用HTTPS仓库地…...
56、本地数据库迁移到阿里云
现有需求,本地数据库迁移到阿里云上。 库名xy102表 test01test02test01 test023条数据。1、登录阿里云界面创建免费试用ECS实列。 阿里云登录页 (aliyun.com)](https://account.aliyun.com/login/login.htm?oauth_callbackhttps%3A%2F%2Fusercenter2.aliyun.com%…...
新时代多目标优化【数学建模】领域的极致探索——数学规划模型
目录 例1 1.问题重述 2.基本模型 变量定义: 目标函数: 约束条件: 3.模型分析与假设 4.模型求解 5.LINGO代码实现 6.结果解释 编辑 7.敏感性分析 8.结果解释 例2 奶制品的销售计划 1.问题重述 编辑 2.基本模型 3.模…...
单例模式详解
文章目录 一、概述1.单例模式2.单例模式的特点3.单例模式的实现方法 二、单例模式的实现1. 饿汉式2. 懒汉式3. 双重校验锁4. 静态内部类5. 枚举 三、总结 一、概述 1.单例模式 单例模式(Singleton Pattern)是一种创建型设计模式,确保一个类…...
WebGIS主流的客户端框架比较|OpenLayers|Leaflet|Cesium
实现 WebGIS 应用的主流前端框架主要包括 OpenLayers、Leaflet、Mapbox GL JS 和 Cesium 等。每个框架都有其独特的功能和优势,适合不同的应用场景。 WebGIS主流前端框架的优缺点 前 端 框架优点缺点OpenLayers较重量级的开源库,二维GIS功能最丰富全面…...
【LabVIEW作业篇 - 2】:分数判断、按钮控制while循环暂停、单击按钮获取book文本
文章目录 分数判断按钮控制while循环暂停按钮控制单个while循环暂停 按钮控制多个while循环暂停单击按钮获取book文本 分数判断 限定整型数值输入控件值得输入范围,范围在0-100之间,判断整型数值输入控件的输入值。 输入范围在0-59之间,显示…...
Kafka架构详解之分区Partition
目录 一、简介二、架构三、分区Partition1.分区概念2.Offsets(偏移量)和消息的顺序3.分区如何为Kafka提供扩展能力4.producer写入策略5.consumer消费机制 一、简介 Apache Kafka 是分布式发布 - 订阅消息系统,在 kafka 官网上对 kafka 的定义…...
SSM之Mybatis
SSM之Mybatis 一、MyBatis简介1、MyBatis特性2、MyBatis的下载3、MyBatis和其他持久化层技术对比 二、MyBatis框架搭建三、MyBatis基础功能1、MyBatis核心配置文件2、MyBatis映射文件3、MyBatis实现增删改查4、MyBatis获取参数值的两种方式5、MyBatis查询功能6、MyBatis自定义映…...
Python list comprehension (列表推导式 - 列表解析式 - 列表生成式)
Python list comprehension {列表推导式 - 列表解析式 - 列表生成式} 1. Python list comprehension (列表推导式 - 列表解析式 - 列表生成式)2. Example3. ExampleReferences Python 中的列表解析式并不是用来解决全新的问题,只是为解决已有问题提供新的语法。 列…...
2024年7月12日理发记录
上周五天气还算好,不太热,晚上下班打车回家后,将目的地设置成日常去的那个理发店。 下车走到门口,熟悉的托尼帅哥正在抽烟,他一眼看到了我,马上掐灭烟头,从怀里拿出口香糖,咀嚼起来&…...
几种常用排序算法
1 基本概念 排序是处理数据的一种最常见的操作,所谓排序就是将数据按某字段规律排列,所谓的字段就是数据节点的其中一个属性。比如一个班级的学生,其字段就有学号、姓名、班级、分数等等,我们既可以针对学号排序,也可…...
Spring3(代理模式 Spring1案例补充 Aop 面试题)
一、代理模式 在代理模式(Proxy Pattern)中,一个类代表另一个类的功能,这种类型的设计模式属于结构型模式。 代理模式通过引入一个代理对象来控制对原对象的访问。代理对象在客户端和目标对象之间充当中介,负责将客户端…...
Github报错:Kex_exchange_identification: Connection closed by remote host
文章目录 1. 背景介绍2. 排查和解决方案 1. 背景介绍 Github提交或者拉取代码时,报错如下: Kex_exchange_identification: Connection closed by remote host fatal: Could not read from remote repository.Please make sure you have the correct ac…...
LabVIEW在CRIO中串口通讯数据异常问题
排查与解决步骤 检查硬件连接: 确保CRIO的串口模块正确连接,并且电缆无损坏。 确认串口模块在CRIO中被正确识别和配置。 验证串口配置: 在LabVIEW项目中,检查CRIO目标下的串口配置,确保波特率、数据位、停止位和校验…...
ALTERA芯片解密FPGA、CPLD、PLD芯片解密解密
Altera是世界一流的FPGA、CPLD和ASIC半导体生产商,所提供的解决方案与传统DSP、ASSP和ASIC解决方案相比,缩短了产品面市时间,提高了性能和效能,降低了系统成本。针对Altera芯片解密,益臻芯片解密中心经过多年的芯片解…...
Anomalib Padim模型训练完整踩坑记录:从环境配置、自制数据集准备到ONNX导出一步到位
Anomalib Padim模型实战:工业缺陷检测从零到ONNX部署全指南 工业质检领域正经历一场从传统人工检测到智能算法驱动的变革。想象一下,当生产线上的金属部件以每分钟数十个的速度通过摄像头时,如何确保每个产品表面没有细微划痕、凹陷或腐蚀&am…...
Comsol异构电池力电热耦合模型:探索电池的多场奥秘
comsol异构电池力电热耦合模型 采用椭圆型电极颗粒模拟锂离子正负极的电极颗粒,还原真实电池的3D介观结构,耦合电化学场-热场-力学场,可模拟电流,浓度,温度,应力等多场结果在电池研究领域,深入理…...
OpenClaw备份策略:GLM-4.7-Flash智能管理本地与云端存储
OpenClaw备份策略:GLM-4.7-Flash智能管理本地与云端存储 1. 为什么需要智能备份方案 上周我的移动硬盘突然罢工,导致三个月的项目文档全部丢失。这次惨痛经历让我意识到:传统备份方式已经无法满足现代工作需求。手动备份不仅耗时耗力&#…...
CloudScraper 配置优化:如何提升采集效率与稳定性
在合规采集场景中,不少用户在使用CloudScraper时,频繁出现请求卡顿、采集中断等问题。 本篇文章,LokiProxy将为您系统梳理影响CloudScraper运行效率的关键环节,并结合实际场景提出可行的优化思路,助力用户在合规框架内…...
嵌入式软件工程师面试技术要点解析
嵌入式软件工程师面试技术要点解析1. 通信接口技术1.1 RS-485通信特性RS-485标准采用差分信号传输,物理层上支持全双工通信,但在实际应用中通常配置为半双工模式。这种设计选择主要基于以下工程考虑:半双工模式下只需一对双绞线,显…...
HunyuanVideo-Foley音效生成:支持SMPTE时间码对齐视频关键帧
HunyuanVideo-Foley音效生成:支持SMPTE时间码对齐视频关键帧 1. 产品概述 HunyuanVideo-Foley是一款专为影视后期制作设计的AI音效生成工具,其核心创新在于支持SMPTE时间码精确对齐视频关键帧。这意味着音效师可以基于视频时间轴上的特定帧,…...
CnDataSeed发布:中国科研工作者跳槽研究数据库(CAMRD)
一、数据简介 追踪学术流动,解析科研人才动力机制! 在中国科研生态快速演化的背景下,科研人才流动是科研创新与学术产出的关键驱动力。但跳槽相关研究在高教研究中一直较为稀缺,系统化、可量化的科研工作者跳槽数据长期缺失&…...
首款支持AI渗透的WebShell管理工具,聊个天就能实现免杀|实现高隐蔽内网渗透
0x01 工具介绍 金刚狼首款支持 AI 渗透的 WebShell MCP,也是一款支持多层内网级联的 ASPX、ASHX 高级 WebShell 管理工具。工具采用 AES 加密通信,无需代理即可实现内网穿透,支持内存加载各类渗透工具,做到无文件落地隐蔽渗透目标…...
Echarts实战:如何用散点图+面积图模拟Power BI丝带图效果(附完整代码)
Echarts实战:用散点图与面积图组合实现Power BI丝带图效果 1. 理解丝带图的核心价值与实现难点 丝带图(Ribbon Chart)作为Power BI的特色可视化组件,其独特之处在于能够直观展示数据在不同时间维度上的变化趋势和相对排名。这种图…...
机器学习期末实战:从线性回归到SVM的考题详解(附答案推导)
机器学习期末实战:从线性回归到SVM的考题详解(附答案推导) 期末考试临近,不少同学对机器学习中的核心算法仍存在理解盲区。本文将以典型考题为切入点,深入剖析线性回归、高斯朴素贝叶斯和软间隔SVM的解题逻辑ÿ…...

