深度学习R8周:RNN实现阿尔兹海默症(pytorch)
- 🍨 本文为🔗365天深度学习训练营中的学习记录博客
- 🍖 原作者:K同学啊
数据集包含2149名患者的广泛健康信息,每名患者的ID范围从4751到6900不等。该数据集包括人口统计详细信息、生活方式因素、病史、临床测量、认知和功能评估、症状以及阿尔兹海默症的诊断。
一、前期准备工作
1.设置硬件设备
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns#设置GPU训练,也可以使用CPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
结果输出:

2.导入数据
df = pd.read_csv("alzheimers_disease_data.csv")
# 删除第一列和最后一列
df = df.iloc[:, 1:-1]
print(df)
结果输出:

二、构建数据集
1.标准化
#构建数据集
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_splitX = df.iloc[:,:-1]
y = df.iloc[:,-1]# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
X = sc.fit_transform(X)
2.划分数据集
#划分数据集
X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)X_train, X_test, y_train, y_test = train_test_split(X, y,test_size = 0.1,random_state = 1)print(X_train.shape, y_train.shape)
3.构建数据加载器
#构建数据加载器
from torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=64,shuffle=False)test_dl = DataLoader(TensorDataset(X_test, y_test),batch_size=64,shuffle=False)
输出结果:

![]()
三、模型训练

1.构建模型
#构建模型
class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32, hidden_size=200,num_layers=1, batch_first=True)self.fc0 = nn.Linear(200, 50)self.fc1 = nn.Linear(50, 2)def forward(self, x):out, hidden1 = self.rnn0(x)out = self.fc0(out)out = self.fc1(out)return outmodel = model_rnn().to(device)
print(model)
结果输出:

如何来看模型的输出数据集格式是什么?
#查看数据集输出格式是什么
print(model(torch.rand(30,32).to(device)).shape)
结果输出:
![]()
2.定义训练函数
# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset) # 训练集的大小num_batches = len(dataloader) # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0 # 初始化训练损失和正确率for X, y in dataloader: # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X) # 网络输出loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad() # grad属性归零loss.backward() # 反向传播optimizer.step() # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss
3.定义测试函数
def test (dataloader, model, loss_fn):size = len(dataloader.dataset) # 测试集的大小num_batches = len(dataloader) # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss
4.正式训练模型
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 5e-5 # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 50train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))print("="*20, 'Done', "="*20)
输出结果:


四、模型评估
1.Loss与Accuracy图
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 200 #分辨率from datetime import datetime
current_time = datetime.now() # 获取当前时间epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
输出结果:

2.混淆矩阵
print("=========输入数据Shape为=========")
print("X_test.shape: ", X_test.shape)
print("y_test.shape: ", y_test.shape)pred = model(X_test.to(device)).argmax(1).cpu().numpy()print("\n======输出数据Shape为 ======")
print("pred.shape: ",pred.shape)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay#计算混淆矩阵
cm = confusion_matrix(y_test, pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm, annot=True, fmt="d",cmap="Blues")#修改字体大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label",fontsize=10)
plt.ylabel("True Label", fontsize=10)#显示图
plt.tight_layout()
plt.show()


3.调用模型进行预测
text_X = X_test[0].reshape(1,-1) #test[0]为输入数据pred = model(test_X.to(device)).argmax(1).item()
print("模型预测结果为:“,pred")
print("=="*20)
print("0:未患病")
print("1:已患病")
五、优化特征选择版
特征选择的思路值得学习。
数据维度多,一般是先特征提取,降维等操作。
特征提取:①首先想到相关性分析,用热力图,但分析得出与是否患病相关性比较强的只有四个特征,而日常以为的年龄、日常生活得分这些没有看出有相关性。②通过画图分析特征是否与目标有关,但特征纬度多,不是有效的一个方式。③采用随机森林进行分析,效果很好。
六、总结
根据对数据的预处理,帮助实验精度提高。RNN也是很基础的模型,跟着教案,逐渐开始体会实验的思路。看完流程图,也对自己该怎么干,如何干有了大致的方向。
相关文章:
深度学习R8周:RNN实现阿尔兹海默症(pytorch)
🍨 本文为🔗365天深度学习训练营中的学习记录博客🍖 原作者:K同学啊 数据集包含2149名患者的广泛健康信息,每名患者的ID范围从4751到6900不等。该数据集包括人口统计详细信息、生活方式因素、病史、临床测量、认知和功…...
vuex中的state是响应式的吗?
在 Vue.js 中,Vuex 的 state 是响应式的。这意味着当你更改 state 中的数据时,依赖于这些数据的 Vue 组件会自动更新。这是通过 Vue 的响应式系统实现的,该系统使用了 ES6 的 Proxy 对象来监听数据的变化。 当你在 Vuex 中定义了一个 state …...
JavaScript系列05-现代JavaScript新特性
JavaScript作为网络的核心语言之一,近年来发展迅速。从ES6(ECMAScript 2015)开始,JavaScript几乎每年都有新的语言特性加入,极大地改善了开发体验和代码质量。本文主要内容包括: ES6关键特性:解构赋值与扩展运算符&am…...
【量化金融自学笔记】--开篇.基本术语及学习路径建议
在当今这个信息爆炸的时代,金融领域正经历着一场前所未有的变革。传统的金融分析方法逐渐被更加科学、精准的量化技术所取代。量化金融,这个曾经高不可攀的领域,如今正逐渐走进大众的视野。它将数学、统计学、计算机科学与金融学深度融合&…...
3d投影到2d python opencv
目录 cv2.projectPoints 投影 矩阵计算投影 cv2.projectPoints 投影 cv2.projectPoints() 是 OpenCV 中的一个函数,用于将三维空间中的点(3D points)投影到二维图像平面上。这在计算机视觉中经常用于相机标定、物体姿态估计、3D物体与2D图…...
26-小迪安全-模块引用,mvc框架,渲染,数据联动0-rce安全
先创建一个新闻需要的库 这样id值可以逐级递增 然后随便写个值,让他输出一下看看 模板引入 但是这样不够美观,这就涉及到了引入html模板 模板引入是html有一个的地方值可以通过php代码去传入过去,其他的html界面直接调用,这样页…...
【第14节】C++设计模式(行为模式)-Strategy (策略)模式
一、问题的提出 Strategy 模式:算法实现与抽象接口的解耦 Strategy 模式和 Template 模式要解决的问题是相似的,都是为了将业务逻辑(算法)的具体实现与抽象接口解耦。Strategy 模式通过将算法封装到一个类(Context&am…...
播放器系列4——PCM重采样
FFmpeg重采样过程 #mermaid-svg-QydNPsDAlg9lTn6z {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-QydNPsDAlg9lTn6z .error-icon{fill:#552222;}#mermaid-svg-QydNPsDAlg9lTn6z .error-text{fill:#552222;stroke:#5…...
网络安全需要学多久才能入门?
网络安全是一个复杂且不断发展的领域,想要入行该领域,我们需要付出足够多的时间和精力好好学习相关知识,才可以获得一份不错的工作,那么网络安全需要学多久才能入门?我们通过这篇文章来了解一下。 学习网络安全的入门时间因个人的…...
通俗版解释:分布式和微服务就像开餐厅
一、分布式系统:把大厨房拆成多个小厨房 想象你开了一家超火爆的餐厅,但原来的厨房太小了: 问题:一个厨师要同时切菜、炒菜、烤面包,手忙脚乱还容易出错。 解决方案: 拆分成多个小厨房(分布式…...
JAVA安全—手搓内存马
前言 最近在学这个内存马,就做一个记录,说实话这个内存马还是有点难度的。 什么是内存马 首先什么是内存马呢,顾名思义就是把木马打进内存中。传统的webshell一旦把文件删除就断开连接了,而Java内存马则不同,它将恶…...
【神经网络】python实现神经网络(一)——数据集获取
一.概述 在文章【机器学习】一个例子带你了解神经网络是什么中,我们大致了解神经网络的正向信息传导、反向传导以及学习过程的大致流程,现在我们正式开始进行代码的实现,首先我们来实现第一步的运算过程模拟讲解:正向传导。本次代…...
历年湖南大学计算机复试上机真题
历年湖南大学计算机复试机试真题 在线评测:https://app2098.acapp.acwing.com.cn/ 杨辉三角形 题目描述 提到杨辉三角形。 大家应该都很熟悉。 这是我国宋朝数学家杨辉在公元 1261 年著书《详解九章算法》提出的。 1 1 1 1 2 1 1 3 3 1 1 4 6 4 1 1 5 10 10 …...
[LeetCode]day33 150.逆波兰式求表达值 + 239.滑动窗口最大值
逆波兰式求表达值 题目链接 题目描述 给你一个字符串数组 tokens ,表示一个根据 逆波兰表示法 表示的算术表达式。 请你计算该表达式。返回一个表示表达式值的整数。 注意: 有效的算符为 ‘’、‘-’、‘*’ 和 ‘/’ 。 每个操作数(运…...
【银河麒麟高级服务器操作系统实际案例分享】数据库资源重启现象分析及处理全过程
更多银河麒麟操作系统产品及技术讨论,欢迎加入银河麒麟操作系统官方论坛 https://forum.kylinos.cn 了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer…...
C#中泛型的协变和逆变
协变: 在泛型接口中,使用out关键字可以声明协变。这意味着接口的泛型参数只能作为返回类型出现,而不能作为方法的参数类型。 示例:泛型接口中的协变 假设我们有一个基类Animal和一个派生类Dog: csharp复制 public…...
【JavaScript】《JavaScript高级程序设计 (第4版) 》笔记-附录B-严格模式
附录B、严格模式 严格模式 ECMAScript 5 首次引入严格模式的概念。严格模式用于选择以更严格的条件检查 JavaScript 代码错误,可以应用到全局,也可以应用到函数内部。严格模式的好处是可以提早发现错误,因此可以捕获某些 ECMAScript 问题导致…...
跨平台 C++ 程序崩溃调试与 Dump 文件分析
前言 C 程序在运行时可能会由于 空指针访问、数组越界、非法内存访问、栈溢出 等原因崩溃。为了分析崩溃原因,我们通常会生成 Dump 文件(Windows 的 .dmp,Linux 的 core,macOS 的 .crash),然后用调试工具分…...
缺陷VS质量:为何软件缺陷是质量属性的致命对立面?
为何说缺陷是质量的对立面? 核心逻辑:软件质量的定义是“满足用户需求的程度”,而缺陷会直接破坏这种满足关系。 对立性:缺陷的存在意味着软件偏离了预期行为(如功能错误、性能不足、安全性漏洞等)&#…...
伍[5],伺服电机,电流环,速度环,位置环
电流环、速度环和位置环是电机控制系统中常见的三个闭环控制环节,通常采用嵌套结构(内环→外环:电流环→速度环→位置环),各自负责不同层级的控制目标。以下是它们的详细说明及相互关系: 1. 电流环(最内环) 作用:控制电机的电流,间接控制输出转矩(τ=Kt⋅Iτ=Kt⋅…...
别再手动画封装了!用UltraLibrarian和3D ContentCentral搞定AD/Altium Designer的3D模型(附避坑技巧)
高效获取Altium Designer封装与3D模型的终极指南 在PCB设计领域,封装获取一直是工程师们日常工作中最耗时却又必不可少的环节。想象一下,当你正全神贯注于一个复杂的电路设计,突然发现某个关键元器件没有现成的封装可用,不得不停…...
别再只写assign了!用三种Verilog建模风格重构你的三人表决器(行为级/数据流/门级)
别再只写assign了!用三种Verilog建模风格重构你的三人表决器 三人表决器是数字电路设计中的经典案例,它能直观展示不同抽象层次的Verilog建模风格如何影响代码质量与硬件实现。很多工程师习惯性地使用assign语句完成所有设计,却忽略了Verilo…...
蓝桥杯单片机备赛:AT24C02读写避坑指南(附STC15完整工程)
蓝桥杯单片机备赛:AT24C02读写避坑指南(附STC15完整工程) 在蓝桥杯单片机竞赛中,AT24C02这颗小小的EEPROM芯片常常成为决定胜负的关键。作为参赛选手,你可能已经掌握了I2C协议的基本原理,但在紧张的比赛环境…...
告别终端!为OpenWrt打造Web版脚本管家:Luci插件开发实战与全功能解析
1. 为什么我们需要Web版脚本管家? 每次在OpenWrt上折腾脚本都要打开终端,这对新手来说简直是噩梦。记得我第一次给路由器写脚本时,光是学会用vi编辑器就花了半小时,保存退出时还差点把系统搞崩。后来发现用WinSCP上传脚本还要改权…...
深入UE渲染管线:从.usf文件到FGlobalShader,理解全局Shader的完整生命周期与最佳实践
深入UE渲染管线:从.usf文件到FGlobalShader,理解全局Shader的完整生命周期与最佳实践 当我们需要在Unreal Engine中实现一个全新的后处理效果或定制底层渲染管线时,全局Shader(Global Shader)往往是必经之路。与材质编…...
告别‘端口冲突’:手把手教你用Ganache CLI和UI版搭建本地以太坊测试链(macOS/Windows)
告别‘端口冲突’:手把手教你用Ganache CLI和UI版搭建本地以太坊测试链(macOS/Windows) 在以太坊开发中,本地测试链是不可或缺的工具。Ganache作为Truffle套件中的明星产品,提供了CLI和UI两种版本,但许多开…...
dialoqbase入门指南:如何在5分钟内创建你的第一个AI聊天机器人
dialoqbase入门指南:如何在5分钟内创建你的第一个AI聊天机器人 【免费下载链接】dialoqbase Create chatbots with ease 项目地址: https://gitcode.com/gh_mirrors/di/dialoqbase dialoqbase是一款强大的开源工具,让你能够轻松创建AI聊天机器人。…...
思源宋体TTF格式终极指南:免费商用中文字体的完整使用教程
思源宋体TTF格式终极指南:免费商用中文字体的完整使用教程 【免费下载链接】source-han-serif-ttf Source Han Serif TTF 项目地址: https://gitcode.com/gh_mirrors/so/source-han-serif-ttf 还在为商业项目寻找既专业又免费的中文字体而烦恼吗?…...
告别繁琐配置!用EB和S32DS快速搭建AutoSar MCAL基础工程(附完整文件结构解析)
从零构建AutoSar MCAL工程:EB与S32DS高效协作实战指南 当第一次打开AutoSar MCAL的官方示例工程时,多数工程师都会被密密麻麻的文件夹和配置文件淹没。Base、Platform、ECUC、MemIf等模块交织在一起,而EB生成的generate文件夹里又充斥着大量看…...
Visual C++运行库终极修复指南:如何3步解决95%的DLL缺失问题?
Visual C运行库终极修复指南:如何3步解决95%的DLL缺失问题? 【免费下载链接】vcredist AIO Repack for latest Microsoft Visual C Redistributable Runtimes 项目地址: https://gitcode.com/gh_mirrors/vc/vcredist Visual C运行库是Windows系统…...
