当前位置: 首页 > article >正文

RNN实现精神分裂症患者诊断(pytorch)

RNN理论知识

RNN(Recurrent Neural Network,循环神经网络) 是一种 专门用于处理序列数据(如时间序列、文本、语音、视频等)的神经网络。与普通的前馈神经网络(如 MLP、CNN)不同,RNN 具有“记忆”能力,能够利用过去的信息来影响当前的计算结果。

1. RNN 的基本结构

RNN 的核心特点是 “循环”结构,它会将前一个时间步 ( t − 1 ) (t-1) t1计算出的隐藏状态 h t − 1 h_{t-1} ht1 传递给当前时间步 ( t ) (t) t,使得网络可以保留历史信息。

这种结构可以表示为:

h t = f ( W x X t + W h h t − 1 + b ) h_t=f(W_xX_t+W_hh_{t-1}+b) ht=f(WxXt+Whht1+b)

其中:

  • X t X_t Xt:当前时刻的输入数据。
  • h t h_t ht:当前时刻的隐藏状态 。
  • W x 、 W h 、 b W_x、W_h、b WxWhb:可训练的参数 。
  • f f f:激活函数(通常是 tanh 或ReLU)。

RNN 的展开结构:
在时间步(time step)上,RNN 结构可以展开成如下形式:
在这里插入图片描述
图示解释:

X 1 , X 2 , X 3 , . . . X_1,X_2,X_3,... X1,X2,X3,... 代表输入的 序列数据(如文本、时间序列信号)。
h 0 , h 1 , h 2 , h 3 , . . . h_0,h_1,h_2,h_3,... h0,h1,h2,h3,... 代表 隐藏状态,用于存储过去的信息。
Y 1 , Y 2 , Y 3 , . . . Y_1,Y_2,Y_3,... Y1,Y2,Y3,...代表 输出。
在每个时间步,RNN 使用当前输入 X t X_t Xt 和前一时刻的隐藏状态 h t − 1 h_{t-1} ht1来计算新的隐藏状态 h t h_t ht,然后生成输出 Y t Y_t Yt

2. RNN 的缺点

尽管 RNN 在处理序列数据方面有独特的优势,但它也存在一些明显的问题:
(1)梯度消失(Vanishing Gradient)
在长序列训练时,误差的梯度会随着时间步增多而逐渐变小,导致网络无法有效学习较远时间步的信息。
解决方案:使用 LSTM(长短时记忆网络) 或 GRU(门控循环单元) 结构。
(2)梯度爆炸(Exploding Gradient)
如果梯度在反向传播过程中不断累积,可能会变得 非常大,导致模型更新过快或无法收敛。
解决方案:使用 梯度裁剪(Gradient Clipping) 来防止梯度过大。
(3)无法并行计算
由于 RNN 依赖前一个时间步的计算结果,因此无法像 CNN 那样并行计算,这导致训练速度较慢。
解决方案:使用 Transformer 模型(如 BERT、GPT)来替代 RNN。

3. RNN 的改进版本

由于 RNN 存在梯度消失等问题,研究人员提出了更强大的 变种 RNN 结构:
(1)LSTM(Long Short-Term Memory)
在这里插入图片描述

  • LSTM 引入了 “记忆单元” 和 “门机制”,使得它能够保留长期信息,解决梯度消失问题。
  • 包含 遗忘门(Forget Gate)、输入门(Input Gate)、输出门(Output Gate) 三部分来控制信息流。

(2)GRU(Gated Recurrent Unit)

  • GRU 是 LSTM 的简化版本,只包含 更新门(Update Gate) 和 重置门(Reset Gate),计算效率更高。

数据集

精神分裂症数据集,是一个包含精神分裂症人口统计和临床数据的综合数据集。该数据集包括患者的诊断状态、症状评分、治疗史和社会因素。

代码目标

基于给定的特征(如性别、年龄、收入、症状评分等),预测一个人的诊断标签(是否患有精神分裂症),通过可视化训练损失和计算准确率,评估模型的训练效果与性能。

一、前期准备工作

我的环境:

  • 操作系统:windows10
  • 语言环境:Python3.9
  • 编译器:Jupyter notebook
  • 数据集:精神分裂症患者数据集(“schizophrenia_dataset.csv”)

1. 导入库,设置硬件设备

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
import torch#设置GPU训练,也可以使用CPU
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

代码输出:

device(type='cpu')

使用 torch.device() 方法检查当前系统是否有 GPU,并根据条件设置计算设备为 GPU(CUDA)或 CPU。

2. 导入数据

读取指定路径的 CSV 文件,并加载到 pandas 的 DataFrame 中,然后打印出数据框的前五行,用于检查数据的内容。

# 读取数据
file_path = 'schizophrenia_dataset.csv'     # 设置数据文件的路径
df = pd.read_csv(file_path)                 # 使用pandas的read_csv函数读取CSV文件,结果存储在DataFrame对象df中
print(df.head())            # 打印数据框的前五行,检查数据的结构和内容

代码输出:

   Hasta_ID  Yaş  Cinsiyet  Eğitim_Seviyesi  Medeni_Durum  Meslek  \
0         1   72         1                4             2       0   
1         2   49         1                5             2       2   
2         3   53         1                5             3       2   
3         4   67         1                3             2       0   
4         5   54         0                1             2       0   Gelir_Düzeyi  Yaşadığı_Yer  Tanı  Hastalık_Süresi  Hastaneye_Yatış_Sayısı  \
0             2             1     0                0                       0   
1             1             0     1               35                       1   
2             1             0     1               32                       0   
3             2             0     0                0                       0   
4             2             1     0                0                       0   Ailede_Şizofreni_Öyküsü  Madde_Kullanımı  İntihar_Girişimi  \
0                        0                0                 0   
1                        1                1                 1   
2                        1                0                 0   
3                        0                1                 0   
4                        0                0                 0   Pozitif_Semptom_Skoru  Negatif_Semptom_Skoru  GAF_Skoru  Sosyal_Destek  \
0                     32                     48         72              0   
1                     51                     63         40              2   
2                     72                     85         51              0   
3                     10                     21         74              1   
4                      4                     27         98              0   Stres_Faktörleri  İlaç_Uyumu  
0                 2           2  
1                 2           0  
2                 1           1  
3                 1           2  
4                 1           0  

二、构建数据集

1. 划分数据集

处理数据中的不必要列(唯一标识符)和缺失值,以准备好干净的数据进行模型训练。

df = df.drop(columns=['Hasta_ID'])      # 删除 'Hasta_ID' 列,因为该列是唯一标识符,不需要用作模型输入
df = df.fillna(df.mean())      # 使用每一列的均值填充数据框中的缺失值。这里使用 `df.mean()` 来计算均值,并用它来填充缺失值

数据处理流程:

  • 使用 LabelEncoder 将类别变量转换为数值。
  • 将数据划分为特征(X)和目标(y)。
  • 标准化特征数据。
  • 将数据划分为训练集和测试集。
  • 将数据转换为 PyTorch 张量。
  • 调整张量维度以符合 RNN 模型的要求。
label_encoder = LabelEncoder()     # 创建LabelEncoder实例,用于将类别变量转换为数值
df['Cinsiyet'] = label_encoder.fit_transform(df['Cinsiyet'])       # 将 'Cinsiyet'列中的类别值转化为数值
df['Medeni_Durum'] = label_encoder.fit_transform(df['Medeni_Durum'])     # 将 'Medeni_Durum'列中的类别值转化为数值
df['Yaşadığı_Yer'] = label_encoder.fit_transform(df['Yaşadığı_Yer'])     # 将 'Yaşadığı_Yer'列中的类别值转化为数值# 将特征和目标分开
X = df.drop(columns=['Tanı'])     # 将数据框中的 'Tanı' 列移除,剩下的列作为特征(X)
y = df['Tanı']      # 'Tanı' 列作为目标变量(y),表示是否患有精神分裂症(二分类标签)scaler = StandardScaler()     # 创建 StandardScaler 实例,用于标准化特征数据
X_scaled = scaler.fit_transform(X)     # 对特征进行标准化,使得每列的均值为0,标准差为1# 使用 train_test_split 将数据随机划分为训练集和测试集,测试集占20%。random_state=42 设置随机种子,以确保每次划分结果相同
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)# 将数据转换为PyTorch的tensor
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)        # 将训练特征数据转换为PyTorch的tensor格式,并指定数据类型为float32
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)          # 将测试特征数据转换为PyTorch的tensor格式,并指定数据类型为float32
y_train_tensor = torch.tensor(y_train.values, dtype=torch.long)    # 将训练目标数据转换为PyTorch的tensor格式,并指定数据类型为long(用于分类问题)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.long)      # 将测试目标数据转换为PyTorch的tensor格式,并指定数据类型为long(用于分类问题)# 确保数据的形状符合RNN的要求: [batch_size, seq_len, features]
X_train_tensor = X_train_tensor.unsqueeze(1)  # [batch_size, features] --> [batch_size, 1, features]
X_test_tensor = X_test_tensor.unsqueeze(1)    # [batch_size, features] --> [batch_size, 1, features]# 输出tensor的形状,确保数据正确
print(f"训练数据形状: {X_train_tensor.shape}")     # 打印训练数据的形状,检查是否正确
print(f"测试数据形状: {X_test_tensor.shape}")      # 打印测试数据的形状,检查是否正确

代码输出:

训练数据形状: torch.Size([8000, 1, 18])
测试数据形状: torch.Size([2000, 1, 18])

2. 构建数据加载器

将训练集和测试集的数据(特征和标签)封装成 TensorDataset 对象,并使用 DataLoader 创建数据加载器。
训练集和测试集被分批次加载,每个批次包含 64 个样本。
shuffle=False 表示数据在加载时不进行打乱,在评估的时候顺序保持一致。

from torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(X_train_tensor, y_train_tensor),     # 将训练数据、目标数据包装成一个数据集,并创建一个训练数据加载器batch_size=64, shuffle=False)test_dl  = DataLoader(TensorDataset(X_test_tensor, y_test_tensor),      # 将测试数据、目标数据包装成一个数据集,并创建一个测试数据加载器shuffle=False)

三、模型训练

1. 构建模型

import torch.nn as nn#定义一个名为 _RNN_Base 的类,继承自 nn.Module。该类实现了 RNN(包括 RNN、LSTM 和 GRU)的基础结构
class _RNN_Base(nn.Module):def __init__(self, c_in, c_out, hidden_size=100, n_layers=1, bias=True, rnn_dropout=0, bidirectional=False, fc_dropout=0., init_weights=True):"""RNN基础类,支持不同RNN单元(如RNN、LSTM、GRU)的实现。"""super(_RNN_Base, self).__init__()  # 确保正确调用父类的构造函数# 定义RNN层,支持RNN、LSTM、GRU等self.rnn = self._cell(c_in, hidden_size, num_layers=n_layers, bias=bias, batch_first=True, dropout=rnn_dropout, bidirectional=bidirectional)# 定义全连接层的dropout,如果fc_dropout为0则直接用Identityself.dropout = nn.Dropout(fc_dropout) if fc_dropout else nn.Identity()self.fc = nn.Linear(hidden_size * (1 + bidirectional), c_out)def forward(self, x): """        参数:- x: 形状为[batch_size, n_vars, seq_len]。返回:- output: 形状为[batch_size, c_out]。"""# [batch_size, n_vars, seq_len] --> [batch_size, seq_len, n_vars]x = x.transpose(2,1)  # 输出形状为[batch_size, seq_len, hidden_size * (1 + bidirectional)]output, _ = self.rnn(x) # 取最后一个时间步的输出,形状为[batch_size, hidden_size * (1 + bidirectional)]output = output[:, -1]  output = self.fc(self.dropout(output))return output# 定义RNN类,继承自_RNN_Base
class RNN(_RNN_Base):_cell = nn.RNN  # 使用nn.RNN单元# 定义LSTM类,继承自_RNN_Base
class LSTM(_RNN_Base):_cell = nn.LSTM  # 使用nn.LSTM单元# 定义GRU类,继承自_RNN_Base
class GRU(_RNN_Base):_cell = nn.GRU  # 使用nn.GRU单元

定义名为 _RNN_Base 的类,继承自 nn.Module。该类实现了 RNN(包括 RNN、LSTM 和 GRU)的基础结构。

_RNN_Base 类的参数解释:

  • c_in:输入特征的维度,即每个时间步的特征数量。
  • c_out:输出类别数量,即模型的输出维度。
  • hidden_size:RNN隐藏层的大小。
  • n_layers:RNN的层数。
  • bias:是否在RNN层中使用偏置项。
  • rnn_dropout:RNN层中的dropout比例。
  • bidirectional:是否使用双向RNN。
  • fc_dropout:全连接层的dropout比例。
  • init_weights:是否初始化权重。

关于_cell ,定义 RNN 层。self._cell 是一个占位符,它将会被具体子类(RNN、LSTM、GRU)的 _cell 属性替代,相关参数解释:

  • c_in:输入特征的数量。
  • hidden_size:RNN单元的隐藏层大小。
  • num_layers:RNN的层数。
  • bias:是否使用偏置项。
  • batch_first=True:意味着输入和输出的格式为 [batch_size, seq_len,features]。
  • dropout=rnn_dropout:RNN中dropout的概率,用来防止过拟合。
  • bidirectional=bidirectional:是否使用双向RNN(即处理序列时同时考虑正向和反向的时间步)。
# 创建一个基于 RNN 的神经网络模型,并将模型移动到指定的设备(CPU 或 GPU)
model = RNN(c_in=X_train_tensor.shape[1], c_out=2).to(device)    
model 

代码输出:

RNN((rnn): RNN(1, 100, batch_first=True)(dropout): Identity()(fc): Linear(in_features=100, out_features=2, bias=True)
)
from torchinfo import summaryrnn_model = RNN(c_in=3, c_out=5, hidden_size=100,n_layers=2,bidirectional=True, rnn_dropout=.5, fc_dropout=.5)    # 初始化一个 RNN 模型,并设置相关参数summary(rnn_model, input_size=(16, 3, 5))    # 调用 summary 函数,输出 rnn_model 的结构和每一层的详细信息

代码输出:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
RNN                                      --                        --
├─RNN: 1-1                               [16, 5, 200]              81,400
├─Dropout: 1-2                           [16, 200]                 --
├─Linear: 1-3                            [16, 5]                   1,005
==========================================================================================
Total params: 82,405
Trainable params: 82,405
Non-trainable params: 0
Total mult-adds (M): 6.53
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.13
Params size (MB): 0.33
Estimated Total Size (MB): 0.46
==========================================================================================

2. 定义训练函数

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)   # 批次数目train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取数据及其标签X, y = X.to(device), y.to(device)# 1. 确保输入数据有三个维度,添加一个seq_len维度if X.dim() == 2:  # 如果是二维输入,添加一个序列长度维度X = X.unsqueeze(1)  # [batch_size, features] --> [batch_size, 1, features]# 2. 前向传播pred = model(X)  # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的损失# 3. 反向传播optimizer.zero_grad()  # 清零梯度loss.backward()        # 反向传播optimizer.step()       # 更新参数# 记录准确率和损失train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

3. 定义测试函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)   # 批次数目test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)# 1. 确保输入数据有三个维度,添加一个seq_len维度if X.dim() == 2:  # 如果是二维输入,添加一个序列长度维度X = X.unsqueeze(1)  # [batch_size, features] --> [batch_size, 1, features]# 2. 计算损失pred = model(X)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4. 正式训练模型

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 2e-5   # 学习率
opt        = torch.optim.Adam(model.parameters(),lr=learn_rate)    # 使用 Adam 优化器,并将学习率 learn_rate 应用到优化器中
epochs     = 20     # 设置训练的总轮数为 20。每轮训练都将通过整个训练集一次train_loss = []  # 初始化一个空列表用于记录每一轮的训练损失
train_acc  = []  # 初始化一个空列表用于记录每一轮的训练准确率
test_loss  = []  # 初始化一个空列表用于记录每一轮的测试损失
test_acc   = []  # 初始化一个空列表用于记录每一轮的测试准确率# 循环遍历训练轮数
for epoch in range(epochs):model.train()    # 设置模型为训练模式epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()    # 设置模型为评估模式epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)  # 将当前训练轮的准确率添加到列表中train_loss.append(epoch_train_loss)  # 将当前训练轮的损失添加到列表中test_acc.append(epoch_test_acc)  # 将当前测试轮的准确率添加到列表中test_loss.append(epoch_test_loss)  # 将当前测试轮的损失添加到列表中# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']# 格式化输出每一轮训练和测试的准确率、损失以及当前学习率template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))print("="*20, 'Done', "="*20)

代码输出:

Epoch: 1, Train_acc:70.1%, Train_loss:0.665, Test_acc:70.9%, Test_loss:0.636, Lr:2.00E-05
Epoch: 2, Train_acc:71.4%, Train_loss:0.596, Test_acc:70.3%, Test_loss:0.558, Lr:2.00E-05
Epoch: 3, Train_acc:72.7%, Train_loss:0.507, Test_acc:80.2%, Test_loss:0.442, Lr:2.00E-05
Epoch: 4, Train_acc:90.8%, Train_loss:0.337, Test_acc:95.7%, Test_loss:0.259, Lr:2.00E-05
Epoch: 5, Train_acc:95.9%, Train_loss:0.212, Test_acc:96.4%, Test_loss:0.179, Lr:2.00E-05
Epoch: 6, Train_acc:96.0%, Train_loss:0.161, Test_acc:96.4%, Test_loss:0.146, Lr:2.00E-05
Epoch: 7, Train_acc:96.2%, Train_loss:0.137, Test_acc:96.7%, Test_loss:0.128, Lr:2.00E-05
Epoch: 8, Train_acc:96.5%, Train_loss:0.121, Test_acc:96.7%, Test_loss:0.116, Lr:2.00E-05
Epoch: 9, Train_acc:96.6%, Train_loss:0.110, Test_acc:96.8%, Test_loss:0.107, Lr:2.00E-05
Epoch:10, Train_acc:96.8%, Train_loss:0.103, Test_acc:96.7%, Test_loss:0.100, Lr:2.00E-05
Epoch:11, Train_acc:96.9%, Train_loss:0.097, Test_acc:96.7%, Test_loss:0.095, Lr:2.00E-05
Epoch:12, Train_acc:96.9%, Train_loss:0.092, Test_acc:96.7%, Test_loss:0.091, Lr:2.00E-05
Epoch:13, Train_acc:97.0%, Train_loss:0.089, Test_acc:96.8%, Test_loss:0.088, Lr:2.00E-05
Epoch:14, Train_acc:97.1%, Train_loss:0.085, Test_acc:96.9%, Test_loss:0.084, Lr:2.00E-05
Epoch:15, Train_acc:97.2%, Train_loss:0.082, Test_acc:97.0%, Test_loss:0.081, Lr:2.00E-05
Epoch:16, Train_acc:97.3%, Train_loss:0.078, Test_acc:97.0%, Test_loss:0.077, Lr:2.00E-05
Epoch:17, Train_acc:97.4%, Train_loss:0.075, Test_acc:97.2%, Test_loss:0.073, Lr:2.00E-05
Epoch:18, Train_acc:97.5%, Train_loss:0.071, Test_acc:97.4%, Test_loss:0.070, Lr:2.00E-05
Epoch:19, Train_acc:97.6%, Train_loss:0.068, Test_acc:97.5%, Test_loss:0.065, Lr:2.00E-05
Epoch:20, Train_acc:97.9%, Train_loss:0.063, Test_acc:97.9%, Test_loss:0.061, Lr:2.00E-05
==================== Done ====================

四、模型评估

1. Loss与Accuracy图

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 200        #分辨率from datetime import datetime
current_time = datetime.now() # 获取当前时间epochs_range = range(epochs)plt.figure(figsize=(12, 3))   # 创建一个新的图表,并设置图表的大小
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')   # 绘制训练准确率曲线
plt.plot(epochs_range, test_acc, label='Test Accuracy')    # 绘制测试准确率曲线
plt.legend(loc='lower right')      # 显示图例,位置为右下角
plt.title('Training and Validation Accuracy')     # 设置子图的标题
plt.xlabel(current_time)    # 将当前时间作为横坐标标签plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')   # 绘制训练损失曲线
plt.plot(epochs_range, test_loss, label='Test Loss')    # 绘制测试损失曲线
plt.legend(loc='upper right')     # 显示图例,位置为右上角
plt.title('Training and Validation Loss')     # 设置子图的标题
plt.show()    # 显示图表

代码输出:

在这里插入图片描述

2. 混淆矩阵

混淆矩阵(Confusion Matrix) 是一种常用的分类模型评估工具,特别适用于 二分类 和 多分类问题。它能够清晰地展示模型的 真实类别(True Labels) 与 预测类别(Predicted Labels) 之间的对应关系,深入分析模型的分类性能。

# 确保输入数据的维度为 [batch_size, seq_len, features]
print("==============输入数据Shape为==============")
print("X_test.shape:", X_test_tensor.shape)
print("y_test.shape:", y_test_tensor.shape)# 获取预测结果
pred = model(X_test_tensor.to(device)).argmax(1).cpu().numpy()print("\n==============输出数据Shape为==============")
print("pred.shape:", pred.shape)

代码输出:

==============输入数据Shape为==============
X_test.shape: torch.Size([2000, 1, 18])
y_test.shape: torch.Size([2000])==============输出数据Shape为==============
pred.shape: (2000,)
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns# 计算混淆矩阵
cm = confusion_matrix(y_test, pred)plt.figure(figsize=(6,5))    # 创建一个新的图形,设置图形的大小为 6x5 英寸
plt.suptitle('')     # 设置图形的总标题,这里设置为空字符串 '',即不显示总标题
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")    # 使用 seaborn 的热力图函数绘制混淆矩阵# 修改字体大小
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label", fontsize=10)
plt.ylabel("True Label", fontsize=10)# 显示图
plt.tight_layout()  # 调整布局防止重叠
plt.show()

代码输出:

在这里插入图片描述

3. 调用模型进行预测

# 选择单个样本并调整形状为 [batch_size, seq_len, features] 
test_X = X_test_tensor[0].reshape(1, 1, -1)  # 注意这里调整为三维的 [1, 1, features] # 获取模型的预测结果
pred = model(test_X.to(device)).argmax(1).item()print("模型预测结果为:", pred)
print("==" * 20)
print("0:未患病")
print("1:已患病")

代码输出:

模型预测结果为: 0
========================================
0:未患病
1:已患病

相关文章:

RNN实现精神分裂症患者诊断(pytorch)

RNN理论知识 RNN(Recurrent Neural Network,循环神经网络) 是一种 专门用于处理序列数据(如时间序列、文本、语音、视频等)的神经网络。与普通的前馈神经网络(如 MLP、CNN)不同,RNN…...

私有云基础架构

基础配置 使用 VMWare Workstation 创建三台 2 CPU、8G内存、100 GB硬盘 的虚拟机 主机 IP 安装服务 web01 192.168.184.110 Apache、PHP database 192.168.184.111 MariaDB web02 192.168.184.112 Apache、PHP 由于 openEuler 22.09 系统已经停止维护了&#xff…...

rust学习笔记11-集合349. 两个数组的交集

rust除了结构体,还有集合类型,同样也很重要,常见的有数组(Array)、向量(Vector)、哈希表(HashMap) 和 集合(HashSet)字符串等,好意外呀…...

全栈(Java+vue)实习面试题(含答案)

在广州一个小公司(BOSS标注是0-20人,薪资2-3k),直接面试没有笔试,一开始就直接拿着简历问,也没有自我介绍,问题是结合场景题和八股文、基础。废话不多说,直接分享面试题目个大家做参考。 1、能…...

SQL经典常用查询语句

1. 基础查询语句 1.1 查询表中所有数据 在SQL中,查询表中所有数据是最基本的操作之一。通过使用SELECT * FROM table_name;语句,可以获取指定表中的所有记录和列。例如,假设有一个名为employees的表,包含员工的基本信息&#xf…...

超详细:数据库的基本架构

MySQL基础架构 下面这个图是我给出的一个MySQL基础架构图,可以清楚的了解到SQL语句在MySQL的各个模块进行执行过程。 然后MySQL可以分为两个部分,一个是server层,另一个是存储引擎。 server层 Server层涵盖了MySQL的大多数核心服务功能&am…...

AI催化新一轮创业潮与创富潮:深圳在抢跑

作者:尺度商业大掌柜黄利明 2025年春节伊始至今,从DeepSeek R1开源模型持续引发全球围观,到腾讯混元Turbo S模型发布秀出了"秒回"绝活,再到国务院发布《新一代人工智能发展规划(2025-2030)》重磅…...

Docker 深度解析:适合零基础用户的详解

此博客涵盖 Docker 的基本概念和作用、架构和核心组件、与传统虚拟机的对比、安装与基本操作,以及在实际开发和运维中的应用场景。 首先,详细解释了 Docker 的基本概念,包括它的诞生背景、作用及其如何解决传统应用部署中的问题。然后&#…...

SpringBoot生成唯一ID的方式

1.为什么要生成唯一ID? 数据唯一性:每个记录都需要有一个独一无二的标识符来确保数据的唯一性。这可以避免重复的数据行,并有助于准确地查询、更新或删除特定的记录。 数据完整性:通过使用唯一ID,可以保证数据库中的数…...

FastGPT 源码:RRF、Rerank 相关代码

文章目录 FastGPT 源码:RRF、Rerank 相关代码1. RRF (Reciprocal Rank Fusion) 合并实现2. Rerank 二次排序实现3. 重排序的主要特点4. 整个搜索流程5. 这种方式的优势 FastGPT 源码:RRF、Rerank 相关代码 下边介绍 RRF 合并和 Rerank 二次排序的相关实…...

Android视频流畅播放要素

要让 Android 设备流畅播放视频,需根据设备性能(低端、中端、高端)和播放场景(本地播放、在线流媒体)动态调整视频参数。以下是针对不同设备的推荐配置方案: 一、通用推荐配置(平衡兼容性与流畅…...

Python:类型转换和深浅拷贝,可变与不可变对象

int():转换为一个整数,只能转换由纯数字组成的字符串 浮点型强转整型会去掉小数点及后面的数,只保留整数部分 #如果字符串中有数字和正负号以外的字符就会报错 float():整形转换为浮点型会自动添加一位小数 .0 如果字符串中有…...

vcredist_x64 资源文件分享

vcredist_x64 是 Microsoft Visual C Redistributable 的 64 位版本,用于在 64 位 Windows 系统上运行使用 Visual C 开发的应用程序。它包含了运行这些应用程序所需的运行时组件。 vcredist_x64 资源工具网盘下载链接:https://pan.quark.cn/s/ef56f838f…...

Linux:vim快捷键

Linux打开vim默认第一个模式是:命令模式! 命令模式快捷键操作: gg:光标快速定位到最开始 shift g G:光标快速定位到最结尾 n shift g n G:光标快速定位到第n行 shift 6 ^:当前行开始 …...

DeepSeek在MATLAB上的部署与应用

在科技飞速发展的当下,人工智能与编程语言的融合不断拓展着创新边界。DeepSeek作为一款备受瞩目的大语言模型,其在自然语言处理领域展现出强大的能力。而MATLAB,作为科学计算和工程领域广泛应用的专业软件,拥有丰富的工具包和高效…...

NAT 代理服务 内网穿透

🌈 个人主页:Zfox_ 🔥 系列专栏:Linux 目录 一:🔥 NAT 技术背景二:🔥 NAT IP 转换过程三:🔥 NAPT四:🔥 代理服务器🦋 正向…...

高级课第五次作业

首先配置交换机,路由器 LSW1配置 [SW1]vlan batch 10 20 30 40 [SW1]int g0/0/2 [SW1-GigabitEthernet0/0/2]port link-type access [SW1-GigabitEthernet0/0/2]port default vlan 10 [SW1]int g0/0/3 [SW1-GigabitEthernet0/0/3]port link-type access […...

51单片机编程学习笔记——动态数码管显示多个数字

大纲 视觉残留原理生理基础神经传导与处理 应用与视觉暂留相关的现象 频闪融合不好的实现好的效果 延伸 在《51单片机编程学习笔记——动态数码管》一文中,我们看到如何使用动态数码管显示数字。但是基于动态数码管设计的特点,每次只能显示1个数字。这就…...

金蝶ERP星空对接流程

1.金蝶ERP星空OPENAPI地址: 金蝶云星空开放平台 2.下载金蝶云星空的对应SDK包 金蝶云星空开放平台 3.引入SDK流程步骤 引入Kingdee.CDP.WebApi.SDK 右键项目添加引用,在打开的引用管理器中选择浏览页签,点击浏览按钮,找到从官…...

【随手笔记】利尔达NB模组

1.名称 移芯EC6263GPP 参数 指令备注 利尔达上电输出 [2025-03-04 10:24:21.379] I_AT_WAIT:i_len2 [2025-03-04 10:24:21.724] LI_AT_WAIT:i_len16 [2025-03-04 10:24:21.724] [2025-03-04 10:24:21.733] Lierda [2025-03-04 10:24:21.733] [2025-03-04 10:24:21.745] OK移…...

Vue3的核心语法【未完】

Vue3的核心语法 OptionsAPI与CompositionAPI Options API(选项式) 和 Composition API (组合式)是 Vue.js 中用于构建组件的两种不同方式。Options API Options API Options API 是 Vue 2 中的传统模式,并在 Vue 3…...

解决redis lettuce连接池经常出现连接拒绝(Connection refused)问题

一.软件环境 windows10、11系统、springboot2.x、redis 6 7 linux(centos)系统没有出现这问题,如果你是linux系统碰到的,本文也有一定大参考价值。 根本思路就是:tcp/ip连接的保活(keepalive)。 二.问题描述 在spr…...

C#进阶指南

C# 是一种功能强大的编程语言,其高级语法特性为开发者提供了更灵活、高效和简洁的编程方式。以下是一些常见的 C# 高级语法特性: 1. 委托(Delegate) 委托是一种类型安全的函数指针,用于封装方法的引用。它可以将方法作为参数传递,实现回调机制。 定义委托: csharp复制 …...

从DNS到TCP:DNS解析流程和浏览器输入域名访问流程

1 DNS 解析流程 1.1 什么是DNS域名解析 在生活中我们会经常遇到域名,比如说CSDN的域名www.csdn.net,百度的域名www.baidu.com,我们也会碰到IP,现在目前有的是IPV4,IPV6。那这两个有什么区别呢?IP地址是互联网上计算机…...

【MySQL、Oracle、SQLserver、postgresql】查询多条数据合并成一行

四大数据库多行合并为单行:函数详解与对比 一、MySQL**GROUP_CONCAT()** 函数说明:语法结构:参数解释:示例:注意事项: 二、Oracle**LISTAGG()** 函数说明:语法结构:参数解释&#xf…...

解锁Egg.js:从Node.js小白到Web开发高手的进阶之路

一、Egg.js 是什么 在当今的 Web 开发领域,Node.js 凭借其事件驱动、非阻塞 I/O 的模型,在构建高性能、可扩展的网络应用方面展现出独特的优势 ,受到了广大开发者的青睐。它让 JavaScript 不仅局限于前端,还能在服务器端大展身手&…...

JavaWeb后端基础(4)

这一篇就开始是做一个项目了,在项目里学习,我主要记录在学习过程中遇到的问题,以及一些知识点 Restful风格 一种软件架构风格 在REST风格的URL中,通过四种请求方式,来操作数据的增删改查。 GET : 查询 …...

软件试用 防破解 防软件调试(C# )

防破解&防软件调试 实现思路 这里采用C#语言为例: 获取网络北京时间:向百度发送 HTTP 请求,从响应头中提取日期时间信息,将其转换为本地时间。记录试用开始时间:首次运行软件时,将获取的百度北京时间作为试用开始时间,并加密存储在本地文件中。检查试用是否过期:每…...

【文献阅读】The Efficiency Spectrum of Large Language Models: An Algorithmic Survey

这篇文章发表于2024年4月 摘要 大语言模型(LLMs)的快速发展推动了多个领域的变革,重塑了通用人工智能的格局。然而,这些模型不断增长的计算和内存需求带来了巨大挑战,阻碍了学术研究和实际应用。为解决这些问题&…...

OpenGL ES -> GLSurfaceView纹理贴图

贴图 XML文件 <?xml version"1.0" encoding"utf-8"?> <com.example.myapplication.MyGLSurfaceViewxmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height…...