深度学习R8周:RNN实现阿尔兹海默症(pytorch)
- 🍨 本文为🔗365天深度学习训练营中的学习记录博客
- 🍖 原作者:K同学啊
数据集包含2149名患者的广泛健康信息,每名患者的ID范围从4751到6900不等。该数据集包括人口统计详细信息、生活方式因素、病史、临床测量、认知和功能评估、症状以及阿尔兹海默症的诊断。
一、前期准备工作
1.设置硬件设备
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns#设置GPU训练,也可以使用CPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
结果输出:

2.导入数据
df = pd.read_csv("alzheimers_disease_data.csv")
# 删除第一列和最后一列
df = df.iloc[:, 1:-1]
print(df)
结果输出:

二、构建数据集
1.标准化
#构建数据集
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_splitX = df.iloc[:,:-1]
y = df.iloc[:,-1]# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
X = sc.fit_transform(X)
2.划分数据集
#划分数据集
X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)X_train, X_test, y_train, y_test = train_test_split(X, y,test_size = 0.1,random_state = 1)print(X_train.shape, y_train.shape)
3.构建数据加载器
#构建数据加载器
from torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=64,shuffle=False)test_dl = DataLoader(TensorDataset(X_test, y_test),batch_size=64,shuffle=False)
输出结果:

![]()
三、模型训练

1.构建模型
#构建模型
class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32, hidden_size=200,num_layers=1, batch_first=True)self.fc0 = nn.Linear(200, 50)self.fc1 = nn.Linear(50, 2)def forward(self, x):out, hidden1 = self.rnn0(x)out = self.fc0(out)out = self.fc1(out)return outmodel = model_rnn().to(device)
print(model)
结果输出:

如何来看模型的输出数据集格式是什么?
#查看数据集输出格式是什么
print(model(torch.rand(30,32).to(device)).shape)
结果输出:
![]()
2.定义训练函数
# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset) # 训练集的大小num_batches = len(dataloader) # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0 # 初始化训练损失和正确率for X, y in dataloader: # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X) # 网络输出loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad() # grad属性归零loss.backward() # 反向传播optimizer.step() # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss
3.定义测试函数
def test (dataloader, model, loss_fn):size = len(dataloader.dataset) # 测试集的大小num_batches = len(dataloader) # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss
4.正式训练模型
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 5e-5 # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 50train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))print("="*20, 'Done', "="*20)
输出结果:


四、模型评估
1.Loss与Accuracy图
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 200 #分辨率from datetime import datetime
current_time = datetime.now() # 获取当前时间epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
输出结果:

2.混淆矩阵
print("=========输入数据Shape为=========")
print("X_test.shape: ", X_test.shape)
print("y_test.shape: ", y_test.shape)pred = model(X_test.to(device)).argmax(1).cpu().numpy()print("\n======输出数据Shape为 ======")
print("pred.shape: ",pred.shape)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay#计算混淆矩阵
cm = confusion_matrix(y_test, pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm, annot=True, fmt="d",cmap="Blues")#修改字体大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label",fontsize=10)
plt.ylabel("True Label", fontsize=10)#显示图
plt.tight_layout()
plt.show()


3.调用模型进行预测
text_X = X_test[0].reshape(1,-1) #test[0]为输入数据pred = model(test_X.to(device)).argmax(1).item()
print("模型预测结果为:“,pred")
print("=="*20)
print("0:未患病")
print("1:已患病")
五、优化特征选择版
特征选择的思路值得学习。
数据维度多,一般是先特征提取,降维等操作。
特征提取:①首先想到相关性分析,用热力图,但分析得出与是否患病相关性比较强的只有四个特征,而日常以为的年龄、日常生活得分这些没有看出有相关性。②通过画图分析特征是否与目标有关,但特征纬度多,不是有效的一个方式。③采用随机森林进行分析,效果很好。
六、总结
根据对数据的预处理,帮助实验精度提高。RNN也是很基础的模型,跟着教案,逐渐开始体会实验的思路。看完流程图,也对自己该怎么干,如何干有了大致的方向。
相关文章:
深度学习R8周:RNN实现阿尔兹海默症(pytorch)
🍨 本文为🔗365天深度学习训练营中的学习记录博客🍖 原作者:K同学啊 数据集包含2149名患者的广泛健康信息,每名患者的ID范围从4751到6900不等。该数据集包括人口统计详细信息、生活方式因素、病史、临床测量、认知和功…...
vuex中的state是响应式的吗?
在 Vue.js 中,Vuex 的 state 是响应式的。这意味着当你更改 state 中的数据时,依赖于这些数据的 Vue 组件会自动更新。这是通过 Vue 的响应式系统实现的,该系统使用了 ES6 的 Proxy 对象来监听数据的变化。 当你在 Vuex 中定义了一个 state …...
JavaScript系列05-现代JavaScript新特性
JavaScript作为网络的核心语言之一,近年来发展迅速。从ES6(ECMAScript 2015)开始,JavaScript几乎每年都有新的语言特性加入,极大地改善了开发体验和代码质量。本文主要内容包括: ES6关键特性:解构赋值与扩展运算符&am…...
【量化金融自学笔记】--开篇.基本术语及学习路径建议
在当今这个信息爆炸的时代,金融领域正经历着一场前所未有的变革。传统的金融分析方法逐渐被更加科学、精准的量化技术所取代。量化金融,这个曾经高不可攀的领域,如今正逐渐走进大众的视野。它将数学、统计学、计算机科学与金融学深度融合&…...
3d投影到2d python opencv
目录 cv2.projectPoints 投影 矩阵计算投影 cv2.projectPoints 投影 cv2.projectPoints() 是 OpenCV 中的一个函数,用于将三维空间中的点(3D points)投影到二维图像平面上。这在计算机视觉中经常用于相机标定、物体姿态估计、3D物体与2D图…...
26-小迪安全-模块引用,mvc框架,渲染,数据联动0-rce安全
先创建一个新闻需要的库 这样id值可以逐级递增 然后随便写个值,让他输出一下看看 模板引入 但是这样不够美观,这就涉及到了引入html模板 模板引入是html有一个的地方值可以通过php代码去传入过去,其他的html界面直接调用,这样页…...
【第14节】C++设计模式(行为模式)-Strategy (策略)模式
一、问题的提出 Strategy 模式:算法实现与抽象接口的解耦 Strategy 模式和 Template 模式要解决的问题是相似的,都是为了将业务逻辑(算法)的具体实现与抽象接口解耦。Strategy 模式通过将算法封装到一个类(Context&am…...
播放器系列4——PCM重采样
FFmpeg重采样过程 #mermaid-svg-QydNPsDAlg9lTn6z {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-QydNPsDAlg9lTn6z .error-icon{fill:#552222;}#mermaid-svg-QydNPsDAlg9lTn6z .error-text{fill:#552222;stroke:#5…...
网络安全需要学多久才能入门?
网络安全是一个复杂且不断发展的领域,想要入行该领域,我们需要付出足够多的时间和精力好好学习相关知识,才可以获得一份不错的工作,那么网络安全需要学多久才能入门?我们通过这篇文章来了解一下。 学习网络安全的入门时间因个人的…...
通俗版解释:分布式和微服务就像开餐厅
一、分布式系统:把大厨房拆成多个小厨房 想象你开了一家超火爆的餐厅,但原来的厨房太小了: 问题:一个厨师要同时切菜、炒菜、烤面包,手忙脚乱还容易出错。 解决方案: 拆分成多个小厨房(分布式…...
JAVA安全—手搓内存马
前言 最近在学这个内存马,就做一个记录,说实话这个内存马还是有点难度的。 什么是内存马 首先什么是内存马呢,顾名思义就是把木马打进内存中。传统的webshell一旦把文件删除就断开连接了,而Java内存马则不同,它将恶…...
【神经网络】python实现神经网络(一)——数据集获取
一.概述 在文章【机器学习】一个例子带你了解神经网络是什么中,我们大致了解神经网络的正向信息传导、反向传导以及学习过程的大致流程,现在我们正式开始进行代码的实现,首先我们来实现第一步的运算过程模拟讲解:正向传导。本次代…...
历年湖南大学计算机复试上机真题
历年湖南大学计算机复试机试真题 在线评测:https://app2098.acapp.acwing.com.cn/ 杨辉三角形 题目描述 提到杨辉三角形。 大家应该都很熟悉。 这是我国宋朝数学家杨辉在公元 1261 年著书《详解九章算法》提出的。 1 1 1 1 2 1 1 3 3 1 1 4 6 4 1 1 5 10 10 …...
[LeetCode]day33 150.逆波兰式求表达值 + 239.滑动窗口最大值
逆波兰式求表达值 题目链接 题目描述 给你一个字符串数组 tokens ,表示一个根据 逆波兰表示法 表示的算术表达式。 请你计算该表达式。返回一个表示表达式值的整数。 注意: 有效的算符为 ‘’、‘-’、‘*’ 和 ‘/’ 。 每个操作数(运…...
【银河麒麟高级服务器操作系统实际案例分享】数据库资源重启现象分析及处理全过程
更多银河麒麟操作系统产品及技术讨论,欢迎加入银河麒麟操作系统官方论坛 https://forum.kylinos.cn 了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer…...
C#中泛型的协变和逆变
协变: 在泛型接口中,使用out关键字可以声明协变。这意味着接口的泛型参数只能作为返回类型出现,而不能作为方法的参数类型。 示例:泛型接口中的协变 假设我们有一个基类Animal和一个派生类Dog: csharp复制 public…...
【JavaScript】《JavaScript高级程序设计 (第4版) 》笔记-附录B-严格模式
附录B、严格模式 严格模式 ECMAScript 5 首次引入严格模式的概念。严格模式用于选择以更严格的条件检查 JavaScript 代码错误,可以应用到全局,也可以应用到函数内部。严格模式的好处是可以提早发现错误,因此可以捕获某些 ECMAScript 问题导致…...
跨平台 C++ 程序崩溃调试与 Dump 文件分析
前言 C 程序在运行时可能会由于 空指针访问、数组越界、非法内存访问、栈溢出 等原因崩溃。为了分析崩溃原因,我们通常会生成 Dump 文件(Windows 的 .dmp,Linux 的 core,macOS 的 .crash),然后用调试工具分…...
缺陷VS质量:为何软件缺陷是质量属性的致命对立面?
为何说缺陷是质量的对立面? 核心逻辑:软件质量的定义是“满足用户需求的程度”,而缺陷会直接破坏这种满足关系。 对立性:缺陷的存在意味着软件偏离了预期行为(如功能错误、性能不足、安全性漏洞等)&#…...
伍[5],伺服电机,电流环,速度环,位置环
电流环、速度环和位置环是电机控制系统中常见的三个闭环控制环节,通常采用嵌套结构(内环→外环:电流环→速度环→位置环),各自负责不同层级的控制目标。以下是它们的详细说明及相互关系: 1. 电流环(最内环) 作用:控制电机的电流,间接控制输出转矩(τ=Kt⋅Iτ=Kt⋅…...
基于算法竞赛的c++编程(28)结构体的进阶应用
结构体的嵌套与复杂数据组织 在C中,结构体可以嵌套使用,形成更复杂的数据结构。例如,可以通过嵌套结构体描述多层级数据关系: struct Address {string city;string street;int zipCode; };struct Employee {string name;int id;…...
[特殊字符] 智能合约中的数据是如何在区块链中保持一致的?
🧠 智能合约中的数据是如何在区块链中保持一致的? 为什么所有区块链节点都能得出相同结果?合约调用这么复杂,状态真能保持一致吗?本篇带你从底层视角理解“状态一致性”的真相。 一、智能合约的数据存储在哪里…...
linux之kylin系统nginx的安装
一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源(HTML/CSS/图片等),响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址,提高安全性 3.负载均衡服务器 支持多种策略分发流量…...
Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件
今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...
测试markdown--肇兴
day1: 1、去程:7:04 --11:32高铁 高铁右转上售票大厅2楼,穿过候车厅下一楼,上大巴车 ¥10/人 **2、到达:**12点多到达寨子,买门票,美团/抖音:¥78人 3、中饭&a…...
屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!
5月28日,中天合创屋面分布式光伏发电项目顺利并网发电,该项目位于内蒙古自治区鄂尔多斯市乌审旗,项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站,总装机容量为9.96MWp。 项目投运后,每年可节约标煤3670…...
【配置 YOLOX 用于按目录分类的图片数据集】
现在的图标点选越来越多,如何一步解决,采用 YOLOX 目标检测模式则可以轻松解决 要在 YOLOX 中使用按目录分类的图片数据集(每个目录代表一个类别,目录下是该类别的所有图片),你需要进行以下配置步骤&#x…...
ElasticSearch搜索引擎之倒排索引及其底层算法
文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...
【RockeMQ】第2节|RocketMQ快速实战以及核⼼概念详解(二)
升级Dledger高可用集群 一、主从架构的不足与Dledger的定位 主从架构缺陷 数据备份依赖Slave节点,但无自动故障转移能力,Master宕机后需人工切换,期间消息可能无法读取。Slave仅存储数据,无法主动升级为Master响应请求ÿ…...
【分享】推荐一些办公小工具
1、PDF 在线转换 https://smallpdf.com/cn/pdf-tools 推荐理由:大部分的转换软件需要收费,要么功能不齐全,而开会员又用不了几次浪费钱,借用别人的又不安全。 这个网站它不需要登录或下载安装。而且提供的免费功能就能满足日常…...
