深度学习R8周:RNN实现阿尔兹海默症(pytorch)
- 🍨 本文为🔗365天深度学习训练营中的学习记录博客
- 🍖 原作者:K同学啊
数据集包含2149名患者的广泛健康信息,每名患者的ID范围从4751到6900不等。该数据集包括人口统计详细信息、生活方式因素、病史、临床测量、认知和功能评估、症状以及阿尔兹海默症的诊断。
一、前期准备工作
1.设置硬件设备
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns#设置GPU训练,也可以使用CPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
结果输出:
2.导入数据
df = pd.read_csv("alzheimers_disease_data.csv")
# 删除第一列和最后一列
df = df.iloc[:, 1:-1]
print(df)
结果输出:
二、构建数据集
1.标准化
#构建数据集
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_splitX = df.iloc[:,:-1]
y = df.iloc[:,-1]# 将每一列特征标准化为标准正太分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
X = sc.fit_transform(X)
2.划分数据集
#划分数据集
X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)X_train, X_test, y_train, y_test = train_test_split(X, y,test_size = 0.1,random_state = 1)print(X_train.shape, y_train.shape)
3.构建数据加载器
#构建数据加载器
from torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=64,shuffle=False)test_dl = DataLoader(TensorDataset(X_test, y_test),batch_size=64,shuffle=False)
输出结果:
三、模型训练
1.构建模型
#构建模型
class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32, hidden_size=200,num_layers=1, batch_first=True)self.fc0 = nn.Linear(200, 50)self.fc1 = nn.Linear(50, 2)def forward(self, x):out, hidden1 = self.rnn0(x)out = self.fc0(out)out = self.fc1(out)return outmodel = model_rnn().to(device)
print(model)
结果输出:
如何来看模型的输出数据集格式是什么?
#查看数据集输出格式是什么
print(model(torch.rand(30,32).to(device)).shape)
结果输出:
2.定义训练函数
# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset) # 训练集的大小num_batches = len(dataloader) # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0 # 初始化训练损失和正确率for X, y in dataloader: # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X) # 网络输出loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad() # grad属性归零loss.backward() # 反向传播optimizer.step() # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss
3.定义测试函数
def test (dataloader, model, loss_fn):size = len(dataloader.dataset) # 测试集的大小num_batches = len(dataloader) # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss
4.正式训练模型
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 5e-5 # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 50train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))print("="*20, 'Done', "="*20)
输出结果:
四、模型评估
1.Loss与Accuracy图
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 200 #分辨率from datetime import datetime
current_time = datetime.now() # 获取当前时间epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
输出结果:
2.混淆矩阵
print("=========输入数据Shape为=========")
print("X_test.shape: ", X_test.shape)
print("y_test.shape: ", y_test.shape)pred = model(X_test.to(device)).argmax(1).cpu().numpy()print("\n======输出数据Shape为 ======")
print("pred.shape: ",pred.shape)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay#计算混淆矩阵
cm = confusion_matrix(y_test, pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm, annot=True, fmt="d",cmap="Blues")#修改字体大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label",fontsize=10)
plt.ylabel("True Label", fontsize=10)#显示图
plt.tight_layout()
plt.show()
3.调用模型进行预测
text_X = X_test[0].reshape(1,-1) #test[0]为输入数据pred = model(test_X.to(device)).argmax(1).item()
print("模型预测结果为:“,pred")
print("=="*20)
print("0:未患病")
print("1:已患病")
五、优化特征选择版
特征选择的思路值得学习。
数据维度多,一般是先特征提取,降维等操作。
特征提取:①首先想到相关性分析,用热力图,但分析得出与是否患病相关性比较强的只有四个特征,而日常以为的年龄、日常生活得分这些没有看出有相关性。②通过画图分析特征是否与目标有关,但特征纬度多,不是有效的一个方式。③采用随机森林进行分析,效果很好。
六、总结
根据对数据的预处理,帮助实验精度提高。RNN也是很基础的模型,跟着教案,逐渐开始体会实验的思路。看完流程图,也对自己该怎么干,如何干有了大致的方向。
相关文章:

深度学习R8周:RNN实现阿尔兹海默症(pytorch)
🍨 本文为🔗365天深度学习训练营中的学习记录博客🍖 原作者:K同学啊 数据集包含2149名患者的广泛健康信息,每名患者的ID范围从4751到6900不等。该数据集包括人口统计详细信息、生活方式因素、病史、临床测量、认知和功…...
vuex中的state是响应式的吗?
在 Vue.js 中,Vuex 的 state 是响应式的。这意味着当你更改 state 中的数据时,依赖于这些数据的 Vue 组件会自动更新。这是通过 Vue 的响应式系统实现的,该系统使用了 ES6 的 Proxy 对象来监听数据的变化。 当你在 Vuex 中定义了一个 state …...

JavaScript系列05-现代JavaScript新特性
JavaScript作为网络的核心语言之一,近年来发展迅速。从ES6(ECMAScript 2015)开始,JavaScript几乎每年都有新的语言特性加入,极大地改善了开发体验和代码质量。本文主要内容包括: ES6关键特性:解构赋值与扩展运算符&am…...
【量化金融自学笔记】--开篇.基本术语及学习路径建议
在当今这个信息爆炸的时代,金融领域正经历着一场前所未有的变革。传统的金融分析方法逐渐被更加科学、精准的量化技术所取代。量化金融,这个曾经高不可攀的领域,如今正逐渐走进大众的视野。它将数学、统计学、计算机科学与金融学深度融合&…...

3d投影到2d python opencv
目录 cv2.projectPoints 投影 矩阵计算投影 cv2.projectPoints 投影 cv2.projectPoints() 是 OpenCV 中的一个函数,用于将三维空间中的点(3D points)投影到二维图像平面上。这在计算机视觉中经常用于相机标定、物体姿态估计、3D物体与2D图…...

26-小迪安全-模块引用,mvc框架,渲染,数据联动0-rce安全
先创建一个新闻需要的库 这样id值可以逐级递增 然后随便写个值,让他输出一下看看 模板引入 但是这样不够美观,这就涉及到了引入html模板 模板引入是html有一个的地方值可以通过php代码去传入过去,其他的html界面直接调用,这样页…...

【第14节】C++设计模式(行为模式)-Strategy (策略)模式
一、问题的提出 Strategy 模式:算法实现与抽象接口的解耦 Strategy 模式和 Template 模式要解决的问题是相似的,都是为了将业务逻辑(算法)的具体实现与抽象接口解耦。Strategy 模式通过将算法封装到一个类(Context&am…...
播放器系列4——PCM重采样
FFmpeg重采样过程 #mermaid-svg-QydNPsDAlg9lTn6z {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-QydNPsDAlg9lTn6z .error-icon{fill:#552222;}#mermaid-svg-QydNPsDAlg9lTn6z .error-text{fill:#552222;stroke:#5…...

网络安全需要学多久才能入门?
网络安全是一个复杂且不断发展的领域,想要入行该领域,我们需要付出足够多的时间和精力好好学习相关知识,才可以获得一份不错的工作,那么网络安全需要学多久才能入门?我们通过这篇文章来了解一下。 学习网络安全的入门时间因个人的…...
通俗版解释:分布式和微服务就像开餐厅
一、分布式系统:把大厨房拆成多个小厨房 想象你开了一家超火爆的餐厅,但原来的厨房太小了: 问题:一个厨师要同时切菜、炒菜、烤面包,手忙脚乱还容易出错。 解决方案: 拆分成多个小厨房(分布式…...

JAVA安全—手搓内存马
前言 最近在学这个内存马,就做一个记录,说实话这个内存马还是有点难度的。 什么是内存马 首先什么是内存马呢,顾名思义就是把木马打进内存中。传统的webshell一旦把文件删除就断开连接了,而Java内存马则不同,它将恶…...

【神经网络】python实现神经网络(一)——数据集获取
一.概述 在文章【机器学习】一个例子带你了解神经网络是什么中,我们大致了解神经网络的正向信息传导、反向传导以及学习过程的大致流程,现在我们正式开始进行代码的实现,首先我们来实现第一步的运算过程模拟讲解:正向传导。本次代…...
历年湖南大学计算机复试上机真题
历年湖南大学计算机复试机试真题 在线评测:https://app2098.acapp.acwing.com.cn/ 杨辉三角形 题目描述 提到杨辉三角形。 大家应该都很熟悉。 这是我国宋朝数学家杨辉在公元 1261 年著书《详解九章算法》提出的。 1 1 1 1 2 1 1 3 3 1 1 4 6 4 1 1 5 10 10 …...

[LeetCode]day33 150.逆波兰式求表达值 + 239.滑动窗口最大值
逆波兰式求表达值 题目链接 题目描述 给你一个字符串数组 tokens ,表示一个根据 逆波兰表示法 表示的算术表达式。 请你计算该表达式。返回一个表示表达式值的整数。 注意: 有效的算符为 ‘’、‘-’、‘*’ 和 ‘/’ 。 每个操作数(运…...

【银河麒麟高级服务器操作系统实际案例分享】数据库资源重启现象分析及处理全过程
更多银河麒麟操作系统产品及技术讨论,欢迎加入银河麒麟操作系统官方论坛 https://forum.kylinos.cn 了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer…...
C#中泛型的协变和逆变
协变: 在泛型接口中,使用out关键字可以声明协变。这意味着接口的泛型参数只能作为返回类型出现,而不能作为方法的参数类型。 示例:泛型接口中的协变 假设我们有一个基类Animal和一个派生类Dog: csharp复制 public…...
【JavaScript】《JavaScript高级程序设计 (第4版) 》笔记-附录B-严格模式
附录B、严格模式 严格模式 ECMAScript 5 首次引入严格模式的概念。严格模式用于选择以更严格的条件检查 JavaScript 代码错误,可以应用到全局,也可以应用到函数内部。严格模式的好处是可以提早发现错误,因此可以捕获某些 ECMAScript 问题导致…...
跨平台 C++ 程序崩溃调试与 Dump 文件分析
前言 C 程序在运行时可能会由于 空指针访问、数组越界、非法内存访问、栈溢出 等原因崩溃。为了分析崩溃原因,我们通常会生成 Dump 文件(Windows 的 .dmp,Linux 的 core,macOS 的 .crash),然后用调试工具分…...
缺陷VS质量:为何软件缺陷是质量属性的致命对立面?
为何说缺陷是质量的对立面? 核心逻辑:软件质量的定义是“满足用户需求的程度”,而缺陷会直接破坏这种满足关系。 对立性:缺陷的存在意味着软件偏离了预期行为(如功能错误、性能不足、安全性漏洞等)&#…...
伍[5],伺服电机,电流环,速度环,位置环
电流环、速度环和位置环是电机控制系统中常见的三个闭环控制环节,通常采用嵌套结构(内环→外环:电流环→速度环→位置环),各自负责不同层级的控制目标。以下是它们的详细说明及相互关系: 1. 电流环(最内环) 作用:控制电机的电流,间接控制输出转矩(τ=Kt⋅Iτ=Kt⋅…...

Qt/C++开发监控GB28181系统/取流协议/同时支持udp/tcp被动/tcp主动
一、前言说明 在2011版本的gb28181协议中,拉取视频流只要求udp方式,从2016开始要求新增支持tcp被动和tcp主动两种方式,udp理论上会丢包的,所以实际使用过程可能会出现画面花屏的情况,而tcp肯定不丢包,起码…...

对WWDC 2025 Keynote 内容的预测
借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...

Nuxt.js 中的路由配置详解
Nuxt.js 通过其内置的路由系统简化了应用的路由配置,使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...

Golang——6、指针和结构体
指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...

解析“道作为序位生成器”的核心原理
解析“道作为序位生成器”的核心原理 以下完整展开道函数的零点调控机制,重点解析"道作为序位生成器"的核心原理与实现框架: 一、道函数的零点调控机制 1. 道作为序位生成器 道在认知坐标系$(x_{\text{物}}, y_{\text{意}}, z_{\text{文}}…...
webpack面试题
面试题:webpack介绍和简单使用 一、webpack(模块化打包工具)1. webpack是把项目当作一个整体,通过给定的一个主文件,webpack将从这个主文件开始找到你项目当中的所有依赖文件,使用loaders来处理它们&#x…...
Spring事务传播机制有哪些?
导语: Spring事务传播机制是后端面试中的必考知识点,特别容易出现在“项目细节挖掘”阶段。面试官通过它来判断你是否真正理解事务控制的本质与异常传播机制。本文将从实战与源码角度出发,全面剖析Spring事务传播机制,帮助你答得有…...
第6章:Neo4j数据导入与导出
在实际应用中,数据的导入与导出是使用Neo4j的重要环节。无论是初始数据加载、系统迁移还是数据备份,都需要高效可靠的数据传输机制。本章将详细介绍Neo4j中的各种数据导入与导出方法,帮助读者掌握不同场景下的最佳实践。 6.1 数据导入策略 …...

Unity VR/MR开发-开发环境准备
视频讲解链接: 【XR马斯维】UnityVR/MR开发环境准备【UnityVR/MR开发教程--入门】_哔哩哔哩_bilibili...

2025-06-08-深度学习网络介绍(语义分割,实例分割,目标检测)
深度学习网络介绍(语义分割,实例分割,目标检测) 前言 在开始这篇文章之前,我们得首先弄明白,什么是图像分割? 我们知道一个图像只不过是许多像素的集合。图像分割分类是对图像中属于特定类别的像素进行分类的过程,即像素级别的…...