RNN实现阿尔茨海默症的诊断识别
本文为为🔗365天深度学习训练营内部文章
原作者:K同学啊
一 导入数据
import torch.nn as nn
import torch.nn.functional as F
import torchvision,torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
import warnings
warnings.filterwarnings('ignore')# 设置硬件设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")df = pd.read_excel('dia.xls')
df

二 数据处理分析
# 删除第一列和最后一列
df = df.iloc[:,1:-1]
print(df)
Age Gender Ethnicity EducationLevel BMI Smoking 0 73 0 0 2 22.927749 0 \ 1 89 0 0 0 26.827681 0 2 73 0 3 1 17.795882 0 3 74 1 0 1 33.800817 1 4 89 0 0 0 20.716974 0 ... ... ... ... ... ... ... 2144 61 0 0 1 39.121757 0 2145 75 0 0 2 17.857903 0 2146 77 0 0 1 15.476479 0 2147 78 1 3 1 15.299911 0 2148 72 0 0 2 33.289738 0 AlcoholConsumption PhysicalActivity DietQuality SleepQuality ... 0 13.297218 6.327112 1.347214 9.025679 ... \ 1 4.542524 7.619885 0.518767 7.151293 ... 2 19.555085 7.844988 1.826335 9.673574 ... 3 12.209266 8.428001 7.435604 8.392554 ... 4 18.454356 6.310461 0.795498 5.597238 ... ... ... ... ... ... ... 2144 1.561126 4.049964 6.555306 7.535540 ... 2145 18.767261 1.360667 2.904662 8.555256 ... 2146 4.594670 9.886002 8.120025 5.769464 ... 2147 8.674505 6.354282 1.263427 8.322874 ... 2148 7.890703 6.570993 7.941404 9.878711 ... FunctionalAssessment MemoryComplaints BehavioralProblems ADL 0 6.518877 0 0 1.725883 \ 1 7.118696 0 0 2.592424 2 5.895077 0 0 7.119548 3 8.965106 0 1 6.481226 4 6.045039 0 0 0.014691 ... ... ... ... ... 2144 0.238667 0 0 4.492838 2145 8.687480 0 1 9.204952 2146 1.972137 0 0 5.036334 2147 5.173891 0 0 3.785399 2148 6.307543 0 1 8.327563 Confusion Disorientation PersonalityChanges 0 0 0 0 \ 1 0 0 0 2 0 1 0 3 0 0 0 4 0 0 1 ... ... ... ... 2144 1 0 0 2145 0 0 0 2146 0 0 0 2147 0 0 0 2148 0 1 0 DifficultyCompletingTasks Forgetfulness Diagnosis 0 1 0 0 1 0 1 0 2 1 0 0 3 0 0 0 4 1 0 0 ... ... ... ... 2144 0 0 1 2145 0 0 1 2146 0 0 1 2147 0 1 1 2148 0 1 0 [2149 rows x 33 columns]
三 探索性数据分析
1.得病分布
res = df.groupby('Diabetes')['Age'].count()
print(res)plt.figure(figsize=(8, 6))
plt.pie(res.values, labels=res.index, autopct='%1.1f%%', startangle=90,colors=['#ff9999','#66b3ff','#99ff99'], explode=(0.1, 0),wedgeprops={'edgecolor': 'black', 'linewidth': 1, 'linestyle': 'solid'})
plt.title('是否得阿尔茨海默症', fontsize=16, fontweight='bold')
plt.show()

2.BMI分布直方图
# BMI分布直方图
sns.displot(df['BMI'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('BMI Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('BMI', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

3.年龄分布直方图
# Age分布直方图
sns.displot(df['Age'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('Age Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('Age', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()
四 构建划分数据集
X = df.iloc[:,:-1]
y = df.iloc[:,-1]sc = StandardScaler()
X = sc.fit_transform(X)# 划分数据集
X = torch.tensor(np.array(X),dtype=torch.float32)
y = torch.tensor(np.array(y),dtype=torch.int64)X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=1)# 构建数据加载器
train_dl = DataLoader(TensorDataset(X_train,y_train),batch_size=64,shuffle=False)
test_dl = DataLoader(TensorDataset(X_test,y_test),batch_size=64,shuffle=False)
五 训练模型
1.构建模型
# 构建模型
class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32,hidden_size=200,num_layers=1,batch_first=True)self.fc0 = nn.Linear(200,50)self.fc1 = nn.Linear(50,2)def forward(self,x):out , hidden1 = self.rnn0(x)out = self.fc0(out)out = self.fc1(out)return outmodel = model_rnn().to(device)
print(model)
model_rnn((rnn0): RNN(32, 200, batch_first=True)(fc0): Linear(in_features=200, out_features=50, bias=True)(fc1): Linear(in_features=50, out_features=2, bias=True) )
2.训练函数
'''
训练模型
'''
# 训练循环
def train(dataloader,model,loss_fn,optimizer):size = len(dataloader.dataset) # 训练集的大小num_batches = len(dataloader) # 批次数目,(size/batchsize,向上取整)train_acc,train_loss = 0,0 # 初始化训练损失和正确率for x,y in dataloader: # 获取数据X,y = x.to(device),y.to(device)# 计算预测误差pred = model(X) # 网络输出loss = loss_fn(pred,y) # 计算误差# 反向传播optimizer.zero_grad() # grad属性归零loss.backward() # 反向传播optimizer.step() # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc,train_loss
3.测试函数
# 测试循环
def valid(dataloader,model,loss_fn):size = len(dataloader.dataset) # 训练集的大小num_batches = len(dataloader) # 批次数目,(size/batchsize,向上取整)test_loss, test_acc = 0, 0 # 初始化训练损失和正确率# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs,target in dataloader:imgs,target = imgs.to(device),target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred,target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc,test_loss
4.正式训练
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)model.eval()epoch_test_acc,epoch_test_loss = valid(test_dl,model,loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))print("="*20,'Done',"="*20)
Epoch: 1,Train_acc:52.9%,Train_loss:0.688,Test_acc:67.0%,Test_loss:0.658,lr:1.00E-04 Epoch: 2,Train_acc:68.7%,Train_loss:0.612,Test_acc:67.4%,Test_loss:0.600,lr:1.00E-04 Epoch: 3,Train_acc:68.7%,Train_loss:0.566,Test_acc:70.7%,Test_loss:0.567,lr:1.00E-04 Epoch: 4,Train_acc:74.4%,Train_loss:0.526,Test_acc:72.6%,Test_loss:0.533,lr:1.00E-04 Epoch: 5,Train_acc:77.9%,Train_loss:0.487,Test_acc:78.1%,Test_loss:0.501,lr:1.00E-04 Epoch: 6,Train_acc:81.1%,Train_loss:0.451,Test_acc:79.5%,Test_loss:0.473,lr:1.00E-04 Epoch: 7,Train_acc:82.3%,Train_loss:0.421,Test_acc:80.0%,Test_loss:0.451,lr:1.00E-04 Epoch: 8,Train_acc:83.4%,Train_loss:0.397,Test_acc:78.6%,Test_loss:0.434,lr:1.00E-04 Epoch: 9,Train_acc:84.7%,Train_loss:0.378,Test_acc:80.0%,Test_loss:0.422,lr:1.00E-04 Epoch:10,Train_acc:85.2%,Train_loss:0.365,Test_acc:80.0%,Test_loss:0.414,lr:1.00E-04 Epoch:11,Train_acc:85.6%,Train_loss:0.354,Test_acc:80.0%,Test_loss:0.408,lr:1.00E-04 Epoch:12,Train_acc:85.9%,Train_loss:0.347,Test_acc:80.0%,Test_loss:0.405,lr:1.00E-04 Epoch:13,Train_acc:86.3%,Train_loss:0.341,Test_acc:78.6%,Test_loss:0.403,lr:1.00E-04 Epoch:14,Train_acc:87.0%,Train_loss:0.335,Test_acc:78.1%,Test_loss:0.403,lr:1.00E-04 Epoch:15,Train_acc:87.1%,Train_loss:0.331,Test_acc:78.6%,Test_loss:0.404,lr:1.00E-04 Epoch:16,Train_acc:87.1%,Train_loss:0.327,Test_acc:78.1%,Test_loss:0.405,lr:1.00E-04 Epoch:17,Train_acc:87.1%,Train_loss:0.324,Test_acc:78.6%,Test_loss:0.407,lr:1.00E-04 Epoch:18,Train_acc:87.3%,Train_loss:0.321,Test_acc:78.6%,Test_loss:0.409,lr:1.00E-04 Epoch:19,Train_acc:87.4%,Train_loss:0.318,Test_acc:77.7%,Test_loss:0.412,lr:1.00E-04 Epoch:20,Train_acc:87.7%,Train_loss:0.315,Test_acc:78.1%,Test_loss:0.415,lr:1.00E-04 Epoch:21,Train_acc:87.8%,Train_loss:0.312,Test_acc:77.7%,Test_loss:0.418,lr:1.00E-04 Epoch:22,Train_acc:88.1%,Train_loss:0.309,Test_acc:78.1%,Test_loss:0.422,lr:1.00E-04 Epoch:23,Train_acc:88.6%,Train_loss:0.306,Test_acc:78.1%,Test_loss:0.425,lr:1.00E-04 Epoch:24,Train_acc:88.6%,Train_loss:0.303,Test_acc:79.1%,Test_loss:0.429,lr:1.00E-04 Epoch:25,Train_acc:88.6%,Train_loss:0.301,Test_acc:79.5%,Test_loss:0.433,lr:1.00E-04 Epoch:26,Train_acc:88.6%,Train_loss:0.298,Test_acc:79.5%,Test_loss:0.437,lr:1.00E-04 Epoch:27,Train_acc:88.8%,Train_loss:0.295,Test_acc:80.0%,Test_loss:0.440,lr:1.00E-04 Epoch:28,Train_acc:89.1%,Train_loss:0.292,Test_acc:79.5%,Test_loss:0.444,lr:1.00E-04 Epoch:29,Train_acc:89.1%,Train_loss:0.290,Test_acc:79.1%,Test_loss:0.449,lr:1.00E-04 Epoch:30,Train_acc:89.2%,Train_loss:0.287,Test_acc:79.1%,Test_loss:0.453,lr:1.00E-04 ==================== Done ====================
六 结果可视化
1.Loss和Acc图
epochs_range = range(30)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='training accuracy')
plt.plot(epochs_range,test_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='training loss')
plt.plot(epochs_range,test_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()

2.调用模型预测
test_X = X_test[0].reshape(1,-1)
pred = model(test_X.to(device)).argmax(1).item()
print('模型预测结果:',pred)
print('=='*20)
print('0:未患病')
print('1:已患病')
模型预测结果: 0 ======================================== 0:未患病 1:已患病
3.绘制混淆矩阵
'''
绘制混淆矩阵
'''
print('=============输入数据shape为==============')
print('X_test.shape:',X_test.shape)
print('y_test.shape:',y_test.shape)pred = model(X_test.to(device)).argmax(1).cpu().numpy()print('\n==========输出数据shape为==============')
print('pred.shape:',pred.shape)from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay# 计算混淆矩阵
cm = confusion_matrix(y_test,pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm,annot=True,fmt='d',cmap='Blues')
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title('Confusion Matrix',fontsize=12)
plt.xlabel('Pred Label',fontsize=10)
plt.ylabel('True Label',fontsize=10)
plt.tight_layout()
plt.show()
=============输入数据shape为============== X_test.shape: torch.Size([215, 32]) y_test.shape: torch.Size([215])==========输出数据shape为============== pred.shape: (215,)
相关文章:
RNN实现阿尔茨海默症的诊断识别
本文为为🔗365天深度学习训练营内部文章 原作者:K同学啊 一 导入数据 import torch.nn as nn import torch.nn.functional as F import torchvision,torch from sklearn.preprocessing import StandardScaler from torch.utils.data import TensorDatase…...
HackTheBox靶机:Sightless;NodeJS模板注入漏洞,盲XSS跨站脚本攻击漏洞实战
HackTheBox靶机:Sightless 渗透过程1. 信息收集常规探测深入分析 2. 漏洞利用(CVE-2022-0944)3. 从Docker中提权4. 信息收集(michael用户)5. 漏洞利用 Froxlor6. 解密Keepass文件 漏洞分析SQLPad CVE-2022-0944 靶机介…...
docker安装elk6.7.1-搜集java日志
docker安装elk6.7.1-搜集java日志 如果对运维课程感兴趣,可以在b站上、A站或csdn上搜索我的账号: 运维实战课程,可以关注我,学习更多免费的运维实战技术视频 0.规划 192.168.171.130 tomcat日志filebeat 192.168.171.131 …...
XML实体注入漏洞攻与防
JAVA中的XXE攻防 回显型 无回显型 cve-2014-3574...
Flutter 与 React 前端框架对比:深入分析与实战示例
Flutter 与 React 前端框架对比:深入分析与实战示例 在现代前端开发中,Flutter 和 React 是两个非常流行的框架。Flutter 是 Google 推出的跨平台开发框架,支持从一个代码库生成 iOS、Android、Web 和桌面应用;React 则是 Facebo…...
使用 Docker Compose 一键启动 Redis、MySQL 和 RabbitMQ
目录 一、Docker Compose 简介 二、服务配置详解 1. Redis 配置 2. MySQL 配置 3. RabbitMQ 配置 三、数据持久化与时间同步 四、部署与管理 五、总结 目录挂载与卷映射的区别 现代软件开发中,微服务架构因其灵活性和可扩展性而备受青睐。为了支持微服务的…...
【问题解决】el-upload数据上传成功后不显示成功icon
el-upload数据上传成功后不显示成功icon 原因 由于后端返回数据与要求形式不符,使用el-upload默认方法调用onSuccess钩子失败,上传文件的状态并未发生改变,因此数据上传成功后并未显示成功的icon标志。 解决方法 点击按钮,调用…...
spring框架之IoC学习与梳理(1)
目录 一、spring-IoC的基本解释。 二、spring-IoC的简单demo(案例)。 (1)maven-repository官网中找依赖坐标。 (2).pom文件中通过标签引入。 (3)使用lombok帮助快速开发。 ÿ…...
MQ的可靠消息投递机制
确保消息在发送、传递和消费过程中不会丢失、重复消费或错乱。 1. 消息的可靠投递 消息持久化: 消息被发送到队列后会存储在磁盘上,即使消息队列崩溃,消息也不会丢失。例如:Kafka、RabbitMQ等都支持持久化消息。Kafka通过将消息存…...
Mono里运行C#脚本35—加载C#语言基类的过程
前面大体地分析了整个Mono运行过程,主要从文件的加载,再到EXE文件的入口点, 然后到方法的编译,机器代码的生成,再到函数调用的跳板转换,进而解析递归地实现JIT。 但是还有很多功能没有解析的,就是C#语言相关最多的,就是类的加载,以及类语言设计的实现属性, 比如类的…...
location+rewrite实现隐性域名配置
隐性域名:访问www.a.com 则跳转到www.b.com的页面,但是地址栏还是显示www.a.com 1、配置基于根目录的隐性域名(就是nginx反向代理) 访问http://www.bbb.org:8002, 跳转http://www.accp.org:8001的页面,地址…...
150 Linux 网络编程6 ,从socket 到 epoll整理。listen函数参数再研究
一 . 只能被一个client 链接 socket例子 此例子用于socket 例子, 该例子只能用于一个客户端连接server。 不能用于多个client 连接 server socket_server_support_one_clientconnect.c /* 此例子用于socket 例子, 该例子只能用于一个客户端连接server。…...
centos7 配置国内镜像源安装 docker
使用国内镜像源:由于 Docker 的官方源在国内访问可能不稳定,你可以使用国内的镜像源,如阿里云的镜像源。手动创建 /etc/yum.repos.d/docker-ce.repo 文件,并添加以下内容: [docker-ce-stable] nameDocker CE Stable -…...
周末总结(2024/01/25)
工作 人际关系核心实践: 要学会随时回应别人的善意,执行时间控制在5分钟以内 坚持每天早会打招呼 遇到接不住的话题时拉低自己,抬高别人(无阴阳气息) 朋友圈点赞控制在5min以内,职场社交不要放在5min以外 职场的人际关系在面对利…...
【go语言】map 和 list
一、map map 是一种无序的键值对的集合。 无序 :map[key]键值对:key - value map 最重要的一点是通过 key 来快速检索数据,key 类似于索引,指向数据的值。map 是一种集合,所以我们可以像迭代数组和切片那样迭代他。…...
PCIe 个人理解专栏——【2】LTSSM(Link Training and Status State Machine)
前言: 链路训练和状况状态机LTSSM(Link Training and Status State Machine)是整个链路训练和运行中状态的状态转换逻辑关系图,总共有11个状态。 正文: 包括检测(Detect),轮询&…...
《DiffIR:用于图像修复的高效扩散模型》学习笔记
paper:2303.09472 GitHub:GitHub - Zj-BinXia/DiffIR: This project is the official implementation of Diffir: Efficient diffusion model for image restoration, ICCV2023 目录 摘要 1、介绍 2、相关工作 2.1 图像恢复(Image Rest…...
Vue3 30天精进之旅:Day01 - 初识Vue.js的奇妙世界
引言 在前端开发领域,Vue.js是一款极具人气的JavaScript框架。它以其简单易用、灵活高效的特性,吸引了大量开发者。本文是“Vue 3之30天系列学习”的第一篇博客,旨在帮助大家快速了解Vue.js的基本概念和核心特性,为后续的深入学习…...
[笔记] 极狐GitLab实例 : 手动备份步骤总结
官方备份文档 : 备份和恢复极狐GitLab 一. 要求 为了能够进行备份和恢复,请确保您系统已安装 Rsync。 如果您安装了极狐GitLab: 如果您使用 Omnibus 软件包,则无需额外操作。如果您使用源代码安装,您需要确定是否安装了 rsync。…...
将本地项目上传到 GitLab/GitHub
以下是将本地项目上传到 GitLab 的完整步骤,从创建仓库到推送代码的详细流程: 1. 在 GitLab 上创建新项目 登录 GitLab,点击 New project。选择 Create blank project。填写项目信息: Project name: 项目名称(如 my-p…...
switch组件的功能与用法
文章目录 1 概念介绍2 使用方法3 示例代码 我们在上一章回中介绍了PageView这个Widget,本章回中将介绍Switch Widget.闲话休提,让我们一起Talk Flutter吧。 1 概念介绍 我们在这里介绍的Switch是指左右滑动的开关,常用来表示某项设置是打开还是关闭。Fl…...
mac 电脑上安装adb命令
在Mac下配置android adb命令环境,配置方式如下: 1、下载并安装IDE (android studio) Android Studio官网下载链接 详细的安装连接请参考 Mac 安装Android studio 2、配置环境 在安装完成之后,将android的adb工具所在…...
Couchbase UI: Dashboard
以下是 Couchbase UI Dashboard 页面详细介绍,包括页面布局和功能说明,帮助你更好地理解和使用。 1. 首页(Overview) 功能:提供集群的整体健康状态和性能摘要 集群状态 节点健康状况:绿色(正…...
[极客大挑战 2019]Knife1
题目 蚁剑直接连接密码是Syc 拿下flag flag{1d373584-fc74-4a2c-a6d4-3691314be4ab}...
第17篇:python进阶:详解数据分析与处理
第17篇:数据分析与处理 内容简介 本篇文章将深入探讨数据分析与处理在Python中的应用。您将学习如何使用pandas库进行数据清洗与分析,掌握matplotlib和seaborn库进行数据可视化,以及处理大型数据集的技巧。通过丰富的代码示例和实战案例&am…...
【Maui】提示消息的扩展
文章目录 前言一、问题描述二、解决方案三、软件开发(源码)3.1 消息扩展库3.2 消息提示框使用3.3 错误消息提示使用3.4 问题选择框使用 四、项目展示 前言 .NET 多平台应用 UI (.NET MAUI) 是一个跨平台框架,用于使用 C# 和 XAML 创建本机移…...
消息队列篇--通信协议扩展篇--二进制编码(ASCII,UTF-8,UTF-16,Unicode等)
1、ASCII(American Standard Code for Information Interchange) 范围:0 到 127(共 128 个字符)描述:ASCII 是一种早期的字符编码标准,主要用于表示英文字母、数字和一些常见的符号。每个字符占…...
CentOS 7 搭建lsyncd实现文件实时同步 —— 筑梦之路
在 CentOS 7 上搭建 lsyncd(Live Syncing Daemon)以实现文件的实时同步,可以按照以下步骤进行操作。lsyncd 是一个基于 inotify 的轻量级实时同步工具,支持本地和远程同步。以下是详细的安装和配置步骤: 1. 系统准备 …...
【2025最新计算机毕业设计】基于SSM房屋租赁平台【提供源码+答辩PPT+文档+项目部署】(高质量源码,可定制,提供文档,免费部署到本地)
作者简介:✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流。✌ 主要内容:🌟Java项目、Python项目、前端项目、PHP、ASP.NET、人工智能…...
(开源)基于Django+Yolov8+Tensorflow的智能鸟类识别平台
1 项目简介(开源地址在文章结尾) 系统旨在为了帮助鸟类爱好者、学者、动物保护协会等群体更好的了解和保护鸟类动物。用户群体可以通过平台采集野外鸟类的保护动物照片和视频,甄别分类、实况分析鸟类保护动物,与全世界各地的用户&…...
