联邦学习实验复现—MNISIT IID实验 pytorch
联邦学习论文复现🚀
在精度的联邦学习的论文之后打算进一步开展写一个联邦学习的基础代码,用于开展之后的相关研究,首先就是复现一下论文中最基础也是最经典的MNIST IID(独立同分布划分) 数据集。然后由于这个联邦学习的论文是谷歌发的,所以官方的代码好像是Tensorflow的,然后为了方便后续的研究我就又自己写了一个pytroch版本的。
记得把代码中的路径都换成自己的
前置文章:联邦学习论文逐句精度:https://blog.csdn.net/chrnhao/article/details/1427517006
文章目录
- 联邦学习论文复现🚀
- 0.预处理流程&项目文件结构
- 1.获取MNIST数据集
- 2.处理测试集 Create_test_datasets.py
- 3.划分客户端样本(训练集数据) Create_client_datasets.py
- 4.构建多客户端Dataloader init_clients.py
- 5.模型代码 CNN.py
- 6.训练代码 train.py
- 6.1 导入库部分
- 6.2 联邦学习超参数初始化
- 6.3 Dataloader 类
- 6.4 固定随机种子
- 6.5 初始化获取所有客户端的dataloader
- 6.6 加载并处理测试集
- 6.7 初始化中心服务器模型和损失函数并迁移到GPU上
- 6.8 client_update 客户端更新函数✨
- 6.9 train 中心服务器训练函数✨
- 6.10 测试代码部分
- 6.11 开始训练&绘图&保存训练结果
- 7.训练结果
- 8.结果对比 plot_compare_curve.py
- 9.结束
0.预处理流程&项目文件结构
Client_datasets
:保存所有客户端的处理好的非图像训练集数据mnisit_test
:存储测试集图片数据mnisit_test
:存储训练集图片数据Test_dataset
:保存测试集的处理好的非图像测试集数据Train_result
:保存每种超参数组合训练后得到的准确率曲线的结果CNN.py
:CNN模型文件Create_client_datasets.py
:将训练集划分为多个客户端的样本Create_test_datasets.py
:处理测试集图片,构建测试集数据init_clients
:工具代码,将所有客户端的数据都处理成dataloaderplot_compare_cureve.py
:绘制结果比较曲线train.py
:训练代码
1.获取MNIST数据集
https://www.kaggle.com/datasets/hojjatk/mnist-dataset
- train-images-idx3-ubyte.gz: training set images (9912422 bytes)
- train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
- t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
- t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)
MNIST数据集原始的数据是二级制格式的文件,需要通过代码将其转换成图片格式,之后才能进行后一步的处理,然后具体转换的代码,参考了这个Github代码,感谢博主。
https://www.kaggle.com/datasets/hojjatk/mnist-dataset
训练集和测试集都需要使用代码进行转换,训练集和测试集分别得到10类手写数字的10个文件夹,下面的图是GitHub博主的图。
我自己将转换出的代码分别存在了mnist_train
和mnist_test
中。
2.处理测试集 Create_test_datasets.py
将测试集的中的图片用opencv读取,归一化并打包成npy文件。
import os
import cv2
import numpy as npMNIST_dir_test = r'C:\Users\Administrator\Desktop\Federated\mnist_test'
MNIST_test_list = os.listdir(MNIST_dir_test)data_list = []
label_list = []for label in MNIST_test_list:label_path = os.path.join(MNIST_dir_test, label)print(label, len(os.listdir(label_path)))for image in os.listdir(label_path):image_path = os.path.join(label_path, image)image = cv2.imread(image_path, 0)/255data_list.append([image])label_list.append(int(label))np.save('Test_dataset\MNIST_test_data.npy', data_list)
np.save('Test_dataset\MNIST_test_label.npy', label_list)
3.划分客户端样本(训练集数据) Create_client_datasets.py
首先在联邦学习中,原论文是将mnist的训练集一共60000张图片,划分到了100个客户端中,每个600张图片,这里有一个问题,就是0-9,在数据集中虽然总数是6000,但是每个类别的个数不是正好6000,但是划分IID数据集,又需要将每个客户端上的数据的分布是相同的,也就是类别数量是均匀的,为了应对这种情况,我采用了以下方案。
- 第一步:首先读取每个客户端的数组图片并归一化,保存到一个列表中;
- 第二步:用每个类别的列表的长度除以100,得到每个类别数量除以100得到的商(这里回去我翻了一下小学知识),(假设一类样本数量是5978,则除以100后商为59,则表明他可以均匀的给100个客户端每个客户端100个样本,会有一点剩余);
- 第三步:获得每个类别列表长度的数量除以100得到的余数,这个余数就是每一个类别剩余的样本数量;
- 第四步:首先把每个类别的的剩余样本拿出来留着之后补空;
- 第五步:先将每个类别能均匀分配的样本分配到各个客户端中,然后将剩余样本再顺序填补到每个客户端的数据集中。
这里略微有一点点复杂,想进一步理解的话需要单步运行一下。
import os
import numpy as np
import cv2# 创建100个客户端的文件夹
# for i in range(1, 101):
# os.makedirs(os.path.join('Client_datasets', f'client_{i}'), exist_ok=True)# 获取每个类别数据的图像数据
number_0 = [[cv2.imread(os.path.join("mnist_train/0", i), 0) / 255] for i in os.listdir("mnist_train/0")]
number_1 = [[cv2.imread(os.path.join("mnist_train/1", i), 0) / 255] for i in os.listdir("mnist_train/1")]
number_2 = [[cv2.imread(os.path.join("mnist_train/2", i), 0) / 255] for i in os.listdir("mnist_train/2")]
number_3 = [[cv2.imread(os.path.join("mnist_train/3", i), 0) / 255] for i in os.listdir("mnist_train/3")]
number_4 = [[cv2.imread(os.path.join("mnist_train/4", i), 0) / 255] for i in os.listdir("mnist_train/4")]
number_5 = [[cv2.imread(os.path.join("mnist_train/5", i), 0) / 255] for i in os.listdir("mnist_train/5")]
number_6 = [[cv2.imread(os.path.join("mnist_train/6", i), 0) / 255] for i in os.listdir("mnist_train/6")]
number_7 = [[cv2.imread(os.path.join("mnist_train/7", i), 0) / 255] for i in os.listdir("mnist_train/7")]
number_8 = [[cv2.imread(os.path.join("mnist_train/8", i), 0) / 255] for i in os.listdir("mnist_train/8")]
number_9 = [[cv2.imread(os.path.join("mnist_train/9", i), 0) / 255] for i in os.listdir("mnist_train/9")]# 每个类别的样本总数除以100个客户端
first_round_number = [len(number_0) // 100, len(number_1) // 100, len(number_2) // 100, len(number_3) // 100,len(number_4) // 100, len(number_5) // 100, len(number_6) // 100, len(number_7) // 100,len(number_8) // 100, len(number_9) // 100]# 每个类别剩余的样本数量
remain_number = [len(number_0) % 100, len(number_1) % 100,len(number_2) % 100, len(number_3) % 100,len(number_4) % 100, len(number_5) % 100,len(number_6) % 100, len(number_7) % 100,len(number_8) % 100, len(number_9) % 100]# 获得所有客户端数据集构成的列表
number_list = [number_0, number_1, number_2, number_3, number_4, number_5, number_6, number_7, number_8, number_9]# 处理剩余数据
remain_data = (number_0[-remain_number[0]:] + number_1[-remain_number[1]:] + number_2[-remain_number[2]:] +number_3[-remain_number[3]:] + number_4[-remain_number[4]:] + number_5[-remain_number[5]:] +number_6[-remain_number[6]:] + number_7[-remain_number[7]:] + number_8[-remain_number[8]:] +number_9[-remain_number[9]:])# 剩余数据的标签
remain_label = ([0] * remain_number[0] + [1] * remain_number[1] + [2] * remain_number[2] + [3] * remain_number[3] +[4] * remain_number[4] + [5] * remain_number[5] + [6] * remain_number[6] + [7] * remain_number[7] +[8] * remain_number[8] + [9] * remain_number[9])# 开始构建100个客户端的数据
for i in range(100):data = []label = []for index, j in enumerate(number_list):data += j[i*first_round_number[index]:(i+1)*first_round_number[index]]label += [index] * first_round_number[index]# 将剩余的数据再补充分配到每个数据集当中data += remain_data[i*len(remain_data)//100:(i+1)*len(remain_data)//100]label += remain_label[i*len(remain_label)//100:(i+1)*len(remain_label)//100]# 缓存每个数据集print(i+1,np.shape(data),np.shape(label))# 记得换成自己的路径np.save(rf"C:\Users\Administrator\Desktop\Federated\Client_datasets\client_{i+1}\data.npy", data, allow_pickle=True)np.save(rf"C:\Users\Administrator\Desktop\Federated\Client_datasets\client_{i+1}\label.npy", label, allow_pickle=True)
4.构建多客户端Dataloader init_clients.py
在完成数据集的划分之后,想要在一个电脑上模拟使用多个客户端的数据进行训练,对于使用pytorch框架而言,我能想到的就生成n个dataloader,n就是客户端样本的数量,然后将生成的n个dataloader都存到一个列表里传给客户端,然后客户端如果选择到了用哪个客户端的数据,就用哪个客户端的数据来训练。所以就有了一下的pytorch的Dataloader的代码版本,基本逻辑呢就是,读取所有的客户端文件夹中的数据,然后都封装成pytorch的dataloder,这里正好就可以把联邦学习中的每个客户端本地训练的参数量B
传进去,联邦学习里论文的B
就是clients_dataloader
中的batch_size
参数。这个代码作为一个脚本在项目文件夹的目录中,我命名为了init_clients.py
import os
import numpy as np
import torch
from torch.utils.data import DataLoader, Datasetclass CustomDataset(Dataset):def __init__(self, data):self.len = len(data)self.x_data = torch.from_numpy(np.array(list(map(lambda x: x[0], data)), dtype=np.float32))self.y_data = torch.from_numpy(np.array(list(map(lambda x: x[-1], data)))).squeeze().long()def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lendef clients_dataloader(batch_size=60):dataloader_list = []dir_path = os.listdir("Client_datasets")for dir in dir_path:data_path = os.path.join("Client_datasets", dir, "data.npy")label_path = os.path.join("Client_datasets", dir, "label.npy")data = np.load(data_path)label = np.load(label_path)dataset = [[i, j] for i, j in zip(data, label)]dataloader = DataLoader(CustomDataset(dataset), shuffle=True, batch_size=batch_size)dataloader_list.append(dataloader)return dataloader_listif __name__ == '__main__':dataloaders = clients_dataloader()if dataloaders:print(f"Loaded {len(dataloaders)} client dataloaders.")else:print("No dataloaders loaded.")
5.模型代码 CNN.py
根据联邦学习论文,我选了MNISIT数据集中的CNN模型进行复现,也是非常简单,两个卷积层,然后跟着两个最大池化层,最后接一个线性层,然后再加个ReLU。
import torchclass CNN(torch.nn.Module):def __init__(self, in_channels=1, classes=10):super(CNN, self).__init__()self.conv1 = torch.nn.Conv2d(in_channels, 32, kernel_size=5)self.max_pool1 = torch.nn.MaxPool2d(2, 2)self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5)self.max_pool2 = torch.nn.MaxPool2d(2, 2)self.linear = torch.nn.Linear(64 * 4 * 4, classes)self.relu = torch.nn.ReLU()def forward(self, x):x = self.conv1(x)x = self.max_pool1(x)x = self.conv2(x)x = self.max_pool2(x)x = x.view(x.size(0), -1)x = self.linear(x)x = self.relu(x)return xif __name__ == '__main__':model = CNN(3, 10)print(model)x = torch.randn((1, 3, 28, 28))y = model(x)print(y.shape)print(y)
6.训练代码 train.py
训练代码的内容确实是先对来说难一些,我这里先给完整代码,然后我再逐个部分进行解释。
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from init_clients import clients_dataloader
import random
from CNN import CNN
# 解决出现libiomp5md.dll缺导致无法绘图的错误
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' C = 0.1 #参与训练的客户端的比例
E = 5 #每个客户端本地训练的轮数
B = 600 #每个客户的BatchSize大小 B>=600 等效于 论文中B=∞client_num = 100# Test_dataset class
class Dataset(Dataset):def __init__(self, data):self.len = len(data)self.x_data = torch.from_numpy(np.array(list(map(lambda x: x[0], data)), dtype=np.float32))self.y_data = torch.from_numpy(np.array(list(map(lambda x: x[-1], data)))).squeeze().long()def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.len# 设置随机种子
def set_random_seed(seed_value=100):random.seed(seed_value) # Fixed Python built-in random generatornp.random.seed(seed_value) # Fixed NumPy random generatortorch.manual_seed(seed_value) # Fixed PyTorch random generatorif torch.cuda.is_available():torch.cuda.manual_seed(seed_value)torch.cuda.manual_seed_all(seed_value) # If multiple GPUs are usedtorch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falseset_random_seed(100)# 初始化客户端dataloader 用100个客户端的数据生成Dataloader
client_dataloader_list = clients_dataloader(B)# 加载测试数据集
data_test = np.load('Test_dataset/MNIST_test_data.npy')
label_test = np.load('Test_dataset/MNIST_test_label.npy')
test_dataset = [[i, j] for i, j in zip(data_test, label_test)]# 将测试集数据处理成Dataloder
Test_dataset = Dataset(test_dataset)
testloader = DataLoader(Test_dataset, shuffle=True, batch_size=256)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 初始化模型和损失函数,并将模型和损失函数移到GPU上
model = CNN(in_channels=1, classes=10)
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
criterion.to(device)def client_update(client_num, E, model_parameter):'''客户端更新参数:param client_num: 选择到的客户端的编号:param E: 本地训练的epoch轮数:param model_parameter: 中心服务器发给客户端的模型参数:return: 在该客户端上训练好的模型参数'''dataloader = client_dataloader_list[client_num] # 获取选择到客户端的dataloderclient_model = CNN().to(device) # 加载一个空模型client_model.load_state_dict(model_parameter) # 加载中心服务器发的模型参数client_model.train() # 将模型变为训练模式# optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01)for i in range(E):correct = 0total = 0for data, label in dataloader:train_data_value, train_data_label = data.to(device), label.to(device)train_data_label_pred = client_model(train_data_value)loss = criterion(train_data_label_pred, train_data_label)optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(train_data_label_pred, 1)total += train_data_label.size(0)correct += (predicted == train_data_label).sum().item()accuracy = 100 * correct / totalprint(f'Client:{client_num},Epoch {i+1}/{E}, Accuracy: {accuracy:.2f}%, Loss: {loss.item()}')# 返回客户端训练模型的模型参数return client_model.state_dict()def train(client_num, C, E):model.train() # 将客户端模型变为训练模型send_model_parameter = model.state_dict() # 然后获取即将分发给各个客户端的模型的权重random_numbers = random.sample(range(client_num), int(client_num*C)) # 按照比例随机选择出用于本轮训练的模型的索引编号client_return_parameter_list = [] # 初始化一个列表用与保存每个客户端本地训练好之后返回的模型参数for client in random_numbers: # 遍历所有的选择到的客户端的编号(索引号)model_parameter = client_update(client, E=E, model_parameter=send_model_parameter) # 返回每个客户端训练好的模型权重client_return_parameter_list.append(model_parameter) # 将每一个客户端的模型权重加载到列表中# 先生成一个参数都为0的模型参数的参数字典,用于之后将客户端返回的模型参数都加到改字典上aggregated_model_parameter = {key: torch.zeros_like(value, dtype=torch.float32) for key, value in send_model_parameter.items()}# 将所有客户端模型的权重都加权求和for client_param in client_return_parameter_list:for key in aggregated_model_parameter:aggregated_model_parameter[key] += client_param[key] * (1 / int(client_num*C))# 将求和好的权重加载给中心服务器模型,用于下一轮的发送model.load_state_dict(aggregated_model_parameter)return modeldef test():model.eval()test_correct = 0test_total = 0with torch.no_grad():for testdata in testloader:test_data_value, test_data_label = testdatatest_data_value, test_data_label = test_data_value.to(device), test_data_label.to(device)test_data_label_pred = model(test_data_value)_, test_predicted = torch.max(test_data_label_pred.data, dim=1)test_total += test_data_label.size(0)test_correct += (test_predicted == test_data_label).sum().item()test_acc = round(100 * test_correct / test_total, 3)print(f'Test Accuracy: {test_acc:.2f}%')return test_accif __name__ == '__main__':test_accuracies = []epochs = 1000for i in range(epochs):print(f"Epoch:{i}——",end='')train(client_num, C, E)test_acc = test()test_accuracies.append(test_acc)# Plotting the test accuracy curveplt.plot(range(1, epochs + 1), test_accuracies, marker='o', linestyle='-', color='b')plt.xlabel('Epoch')plt.ylabel('Test Accuracy (%)')plt.title(f'Test Accuracy vs. Epoch (C={C}, E={E}, B={B})')plt.grid(True)plt.show()np.save(rf'Train_result/C{C}B{B}E{E}.npy', test_accuracies)
6.1 导入库部分
除了其他常规部分,这里值得提的是from init_clients import clients_dataloader
这句是需要导入我们自己写的代码,from CNN import CNN
这部分是导入自己的模型。
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from init_clients import clients_dataloader
import random
from CNN import CNN
# 解决出现libiomp5md.dll缺导致无法绘图的错误
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
6.2 联邦学习超参数初始化
这B,C,E三个参数对应的就是联邦学习里最重要的三个超参数。然后这里需要多说的是,这里为什么要初始化客户端数量,因为联邦学习在整个客户端的模型权重部分每个客户端提供的模型参数需要乘以该客户端的数据数量除以全部客户端所持有的数据数量的比例的一个权重,而在该任务中,由于每个客户端的所持有的数据量都一致,所以乘以的权重数量为(1 / int(client_num*C))
也就是选了少个客户端,权重就是多少个客户端分之1。
C = 0.1 #参与训练的客户端的比例
E = 5 #每个客户端本地训练的轮数
B = 600 #每个客户的BatchSize大小 B>=600 等效于 论文中B=∞
client_num = 100
6.3 Dataloader 类
这个是给测试集数据生成dataloader的一个Dataset类,和之前划分客户端的时候用的类一样,这个属于pytorch基础部分。
# Test_dataset class
class Dataset(Dataset):def __init__(self, data):self.len = len(data)self.x_data = torch.from_numpy(np.array(list(map(lambda x: x[0], data)), dtype=np.float32))self.y_data = torch.from_numpy(np.array(list(map(lambda x: x[-1], data)))).squeeze().long()def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.len
6.4 固定随机种子
在联邦学习论文中有专门的论述,使用相同的随机种子初始化的模型使用不同的客户端数据集进行训练,然后加和得到的新模型会比单独两个模型对与测试集的效果更好。
# 设置随机种子
def set_random_seed(seed_value=100):random.seed(seed_value) # Fixed Python built-in random generatornp.random.seed(seed_value) # Fixed NumPy random generatortorch.manual_seed(seed_value) # Fixed PyTorch random generatorif torch.cuda.is_available():torch.cuda.manual_seed(seed_value)torch.cuda.manual_seed_all(seed_value) # If multiple GPUs are usedtorch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falseset_random_seed(100)
6.5 初始化获取所有客户端的dataloader
这个就是我们之前写的代码,B
就是联邦学习论文中的每个客户端的使用的本地的mini_batch的大小,这个代码返回的client_dataloader_list
,是包含所有客户端的数据处理成的dataloader的列表。
# 初始化客户端dataloader 用100个客户端的数据生成Dataloader
client_dataloader_list = clients_dataloader(B)
6.6 加载并处理测试集
这步就是将测试集的数据和标签打包到一起,然后输入Dataset然后实例化出一个Test_dataset,然后再生成一个dataloader,这里dataloader的shuffle,和batch_size,都不是很重要,都不会影响最终的结果,正常填一差不多合适的值就行。
# 加载测试数据集
data_test = np.load('Test_dataset/MNIST_test_data.npy')
label_test = np.load('Test_dataset/MNIST_test_label.npy')
test_dataset = [[i, j] for i, j in zip(data_test, label_test)]# 将测试集数据处理成Dataloder
Test_dataset = Dataset(test_dataset)
testloader = DataLoader(Test_dataset, shuffle=True, batch_size=256)
6.7 初始化中心服务器模型和损失函数并迁移到GPU上
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CNN(in_channels=1, classes=10)
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
criterion.to(device)
6.8 client_update 客户端更新函数✨
这个逻辑说起来比较复杂,我用GPT帮我解释下,然后我再微调一下,具体的解释在代码的下面。
def client_update(client_num, E, model_parameter):'''客户端更新参数:param client_num: 选择到的客户端的编号:param E: 本地训练的epoch轮数:param model_parameter: 中心服务器发给客户端的模型参数:return: 在该客户端上训练好的模型参数'''dataloader = client_dataloader_list[client_num] # 获取选择到客户端的dataloderclient_model = CNN().to(device) # 加载一个空模型client_model.load_state_dict(model_parameter) # 加载中心服务器发的模型参数client_model.train() # 将模型变为训练模式# optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01)for i in range(E):correct = 0total = 0for data, label in dataloader:train_data_value, train_data_label = data.to(device), label.to(device)train_data_label_pred = client_model(train_data_value)loss = criterion(train_data_label_pred, train_data_label)optimizer.zero_grad()loss.backward()optimizer.step()_, predicted = torch.max(train_data_label_pred, 1)total += train_data_label.size(0)correct += (predicted == train_data_label).sum().item()accuracy = 100 * correct / totalprint(f'Client:{client_num},Epoch {i+1}/{E}, Accuracy: {accuracy:.2f}%, Loss: {loss.item()}')# 返回客户端训练模型的模型参数return client_model.state_dict()
这段代码是一个客户端在联邦学习过程中更新模型参数的逻辑实现。联邦学习是一种分布式的机器学习方法,其中多个客户端独立地在本地数据上训练模型,然后将更新后的模型参数发送到中心服务器进行聚合,从而保护数据隐私。让我们逐行详细解释这段代码。
# 客户端更新参数函数
def client_update(client_num, E, model_parameter):'''客户端更新参数:param client_num: 选择到的客户端的编号:param E: 本地训练的epoch轮数:param model_parameter: 中心服务器发给客户端的模型参数:return: 在该客户端上训练好的模型参数'''
这段代码定义了一个函数 client_update
,它用于更新客户端模型的参数。函数的参数说明如下:
client_num
: 选择到的客户端的编号,即具体是哪一个客户端。E
: 本地训练的 epoch 轮数,即在每个客户端上训练多少次。model_parameter
: 中心服务器发送给客户端的初始模型参数。
函数的目的是在该客户端上使用本地数据训练模型,然后返回训练后的模型参数。
dataloader = client_dataloader_list[client_num] # 获取选择到客户端的dataloder
从 client_dataloader_list
中获取特定客户端的 dataloader
。client_dataloader_list
是一个列表,其中存储了每个客户端的数据加载器。dataloader
用于提供该客户端的训练数据。
client_model = CNN().to(device) # 加载一个空模型
创建一个新的空模型,使用名为 CNN
的神经网络结构(这里假设 CNN
是一个定义好的卷积神经网络)。然后将模型转移到指定的计算设备(device
),可能是 GPU 或 CPU 上。
client_model.load_state_dict(model_parameter) # 加载中心服务器发的模型参数
将中心服务器发送过来的 model_parameter
加载到模型中。这样,客户端的模型从中心服务器的初始模型参数开始训练。
client_model.train() # 将模型变为训练模式
将模型设置为训练模式。这在 PyTorch 中是必要的,因为模型在训练和评估模式下的行为有所不同(例如,批归一化和 dropout 层的行为)。
# optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01, momentum=0.9)optimizer = torch.optim.SGD(client_model.parameters(), lr=0.01)
这里定义了一个优化器,用于更新模型参数。使用的是随机梯度下降(SGD)优化器,学习率 lr
为 0.01
。代码注释掉了包含 momentum
参数的版本,momentum
可以加速收敛,但在此处没有使用。
for i in range(E):correct = 0total = 0
开始一个循环,循环次数为 E
,即 epoch 的数量。每次 epoch 代表对整个训练数据集进行一次完整的训练。初始化变量 correct
和 total
,用于统计模型的预测准确率。
for data, label in dataloader:train_data_value, train_data_label = data.to(device), label.to(device)train_data_label_pred = client_model(train_data_value)
遍历该客户端的 dataloader
,逐批次地获取数据和标签。将数据和标签移到指定设备(GPU 或 CPU)上,然后用模型对输入数据 train_data_value
进行预测,得到 train_data_label_pred
。
loss = criterion(train_data_label_pred, train_data_label)
计算损失,criterion
是损失函数,用于评估模型预测与真实标签之间的差异。
optimizer.zero_grad()loss.backward()optimizer.step()
这是标准的反向传播与梯度更新步骤:
optimizer.zero_grad()
:在每次反向传播之前,将所有参数的梯度置为零,避免梯度累积。loss.backward()
:计算损失相对于模型参数的梯度。optimizer.step()
:根据计算出的梯度更新模型参数。
_, predicted = torch.max(train_data_label_pred, 1)total += train_data_label.size(0)correct += (predicted == train_data_label).sum().item()
这里进行预测的评估:
torch.max(train_data_label_pred, 1)
:找出每个样本的最大值的索引,即模型预测的类别。total
用于累计总的样本数量。correct
用于累计正确预测的数量,即预测值与真实标签相同的样本数量。
accuracy = 100 * correct / totalprint(f'Client:{client_num},Epoch {i+1}/{E}, Accuracy: {accuracy:.2f}%, Loss: {loss.item()}')
计算本地训练的准确率,并输出每个 epoch 的信息,包括客户端编号、当前 epoch、准确率和损失值。
# 返回客户端训练模型的模型参数return client_model.state_dict()
函数返回更新后的模型参数,即客户端在本地数据上训练完成后的模型参数(以 state_dict
的形式返回)。
- 该函数用于实现联邦学习中客户端的本地模型训练过程。
- 中心服务器发送模型参数到客户端,客户端用这些参数作为初始模型进行本地数据训练。
- 客户端训练完成后,将更新后的模型参数返回给服务器,以便进一步聚合
6.9 train 中心服务器训练函数✨
中心服务器的代码是整个联邦学习代码核心的核心,也是我自己调试了最长时间的部分,整体也是让GPT给解释下。
def train(client_num, C, E):model.train() # 将客户端模型变为训练模型send_model_parameter = model.state_dict() # 然后获取即将分发给各个客户端的模型的权重random_numbers = random.sample(range(client_num), int(client_num*C)) # 按照比例随机选择出用于本轮训练的模型的索引编号client_return_parameter_list = [] # 初始化一个列表用与保存每个客户端本地训练好之后返回的模型参数for client in random_numbers: # 遍历所有的选择到的客户端的编号(索引号)model_parameter = client_update(client, E=E, model_parameter=send_model_parameter) # 返回每个客户端训练好的模型权重client_return_parameter_list.append(model_parameter) # 将每一个客户端的模型权重加载到列表中# 先生成一个参数都为0的模型参数的参数字典,用于之后将客户端返回的模型参数都加到改字典上aggregated_model_parameter = {key: torch.zeros_like(value, dtype=torch.float32) for key, value in send_model_parameter.items()}# 将所有客户端模型的权重都加权求和for client_param in client_return_parameter_list:for key in aggregated_model_parameter:aggregated_model_parameter[key] += client_param[key] * (1 / int(client_num*C))# 将求和好的权重加载给中心服务器模型,用于下一轮的发送model.load_state_dict(aggregated_model_parameter)return model
这段代码是中心服务器在联邦学习过程中分发模型参数并聚合客户端返回的模型参数的实现。联邦学习的目标是让多个客户端使用本地数据独立训练模型,然后服务器聚合客户端的更新,从而提升整体模型性能并保持数据隐私。下面我们逐行详细解释这段代码。
# 中心服务器训练模型的函数
def train(client_num, C, E):model.train() # 将客户端模型变为训练模型send_model_parameter = model.state_dict() # 然后获取即将分发给各个客户端的模型的权重
该函数名为 train
,用于执行联邦学习中的模型聚合过程。
client_num
: 表示总客户端数量。C
: 表示客户端参与比例,决定每轮训练中选择的客户端数。E
: 本地训练的 epoch 轮数。
首先,将中心服务器模型设为训练模式 (model.train()
),然后通过 model.state_dict()
获取中心服务器模型的当前参数,这些参数将被分发给客户端。
random_numbers = random.sample(range(client_num), int(client_num*C)) # 按照比例随机选择出用于本轮训练的模型的索引编号client_return_parameter_list = [] # 初始化一个列表用于保存每个客户端本地训练好之后返回的模型参数
random_numbers
使用random.sample()
随机选择一定数量的客户端(client_num * C
)来参与本轮的训练。client_num * C
是根据参与比例选择的客户端数量。client_return_parameter_list
初始化为空列表,用于保存每个客户端本地训练后返回的模型参数。
for client in random_numbers: # 遍历所有的选择到的客户端的编号(索引号)model_parameter = client_update(client, E=E, model_parameter=send_model_parameter) # 返回每个客户端训练好的模型权重client_return_parameter_list.append(model_parameter) # 将每一个客户端的模型权重加载到列表中
- 对于每个被选择的客户端,调用
client_update
函数来进行本地训练,传入客户端编号、epoch 数量以及要发送的模型参数。 client_update
返回训练后的模型参数,将这些参数添加到client_return_parameter_list
中。
# 先生成一个参数都为0的模型参数的参数字典,用于之后将客户端返回的模型参数都加到该字典上aggregated_model_parameter = {key: torch.zeros_like(value, dtype=torch.float32) for key, value in send_model_parameter.items()}
- 初始化一个
aggregated_model_parameter
字典,包含所有模型参数的键,且每个键的值都初始化为与对应原始模型参数相同形状的零张量。这个字典将用于累加各客户端返回的模型参数。
# 将所有客户端模型的权重都加权求和for client_param in client_return_parameter_list:for key in aggregated_model_parameter:aggregated_model_parameter[key] += client_param[key] * (1 / int(client_num*C))
- 遍历
client_return_parameter_list
,对每个客户端返回的模型参数进行聚合。 - 对于每个参数键,将所有客户端的相应参数值按比例累加,比例为每个客户端权重的平均值(
1 / (client_num * C)
)。
# 将求和好的权重加载给中心服务器模型,用于下一轮的发送model.load_state_dict(aggregated_model_parameter)return model
-
将聚合后的模型参数加载到中心服务器的模型中,为下一轮联邦学习做准备。
-
最后返回更新后的中心服务器模型。
-
该函数用于联邦学习中的中心服务器模型训练,通过分发初始模型参数、收集客户端的更新并聚合这些更新来提升全局模型的性能。
-
在每一轮中,随机选择一部分客户端对其本地数据进行训练,然后聚合各客户端返回的模型参数。
-
聚合后,更新中心服务器的模型参数以用于下一轮联邦学习。
6.10 测试代码部分
这个就是基础的计算测试集准确率的代码部分,也是让GPT解释一下。
def test():model.eval()test_correct = 0test_total = 0with torch.no_grad():for testdata in testloader:test_data_value, test_data_label = testdatatest_data_value, test_data_label = test_data_value.to(device), test_data_label.to(device)test_data_label_pred = model(test_data_value)_, test_predicted = torch.max(test_data_label_pred.data, dim=1)test_total += test_data_label.size(0)test_correct += (test_predicted == test_data_label).sum().item()test_acc = round(100 * test_correct / test_total, 3)print(f'Test Accuracy: {test_acc:.2f}%')return test_acc
- 模型设为评估模式
model.eval() # 将模型设为评估模式
model.eval()
用于将模型设置为评估模式。这在 PyTorch 中很重要,因为评估模式会影响像 dropout 和 batch normalization 这样的层,使其在推理阶段使用训练时的参数而不是随机性。
test_correct = 0 # 初始化正确预测数量test_total = 0 # 初始化总测试样本数量
初始化两个变量 test_correct
和 test_total
,分别用于记录正确预测的样本数量和测试样本的总数量。
with torch.no_grad(): # 在测试时不需要计算梯度,提升效率
with torch.no_grad()
用于禁用梯度计算,以减少内存消耗和加快推理速度,因为在测试和推理阶段不需要反向传播。
for testdata in testloader: # 遍历所有测试数据test_data_value, test_data_label = testdatatest_data_value, test_data_label = test_data_value.to(device), test_data_label.to(device)
遍历测试数据集 testloader
,获取每个批次的测试数据和对应标签。然后将这些数据和标签移到计算设备(如 GPU)上。
test_data_label_pred = model(test_data_value) # 使用模型对测试数据进行预测_, test_predicted = torch.max(test_data_label_pred.data, dim=1) # 获取预测值的最大概率索引
使用模型对测试数据进行预测,得到每个样本的预测结果 test_data_label_pred
。使用 torch.max()
找到预测结果中每个样本的最大概率的索引,即模型预测的类别。
test_total += test_data_label.size(0) # 累加测试样本的数量test_correct += (test_predicted == test_data_label).sum().item() # 统计正确预测的数量
test_total
累加当前批次的测试样本数量。test_correct
通过比较预测值和真实标签,统计正确预测的样本数量。
test_acc = round(100 * test_correct / test_total, 3) # 计算测试集的准确率print(f'Test Accuracy: {test_acc:.2f}%') # 打印测试集的准确率return test_acc # 返回测试集的准确率
计算测试集的准确率,并将其四舍五入到小数点后三位。最后打印并返回测试准确率。
- 该函数用于联邦学习中的中心服务器对聚合后的模型进行测试,以评估其性能。
- 测试过程中,模型被设置为评估模式,禁用梯度计算以提高效率。
- 函数通过遍历整个测试数据集来统计预测的正确性,并计算最终的准确率。
6.11 开始训练&绘图&保存训练结果
这里由于我们的随机种子的固定的所以我们是按照不同的超参数配置来命名保存训练文件的。
if __name__ == '__main__':test_accuracies = []epochs = 1000for i in range(epochs):train(client_num, C, E)test_acc = test()print(f"Epoch:{i}——", end='')test_accuracies.append(test_acc)# Plotting the test accuracy curveplt.plot(range(1, epochs + 1), test_accuracies, marker='o', linestyle='-', color='b')plt.xlabel('Epoch')plt.ylabel('Test Accuracy (%)')plt.title(f'Test Accuracy vs. Epoch (C={C}, E={E}, B={B})')plt.grid(True)plt.show()np.save(rf'Train_result/C{C}B{B}E{E}.npy', test_accuracies)
7.训练结果
训练之后的结果如下:可以把client_updat 中的打印每个客户端每轮训练准确率这块注释掉,不然一致刷刷打印,看不清还占用不少时间。
8.结果对比 plot_compare_curve.py
import matplotlib.pyplot as plt
import numpy as npline1 = np.load("Train_result/C0.1B10E5.npy")
line2 = np.load("Train_result/C0.1B50E5.npy")
line3 = np.load("Train_result/C0.1B100E5.npy")
line4 = np.load("Train_result/C0.1B600E1.npy")
line5 = np.load("Train_result/C0.1B600E5.npy")
line6 = np.load("Train_result/C0.2B600E5.npy")
line7 = np.load("Train_result/C1B600E5.npy")plt.plot(line1, linewidth=3.0,label='C=0.1,B=10,E=5')
plt.plot(line2, linewidth=3.0,label='C=0.1,B=50,E=5')
plt.plot(line3, linewidth=3.0,label='C=0.1,B100,E=5')
plt.plot(line4, linewidth=3.0,label='C=0.1,B=∞,E=1')
plt.plot(line5, linewidth=3.0,label='C=0.1,B=∞,E=5')
plt.plot(line6, linewidth=3.0,label='C=0.2,B=∞,E=5')
plt.plot(line7, linewidth=3.0,label='C=1,B=∞,E=5')plt.grid()
plt.legend()plt.xlabel("Tran epoch")
plt.ylabel("Accuracy")
plt.title("Federated Learning")plt.show()
9.结束
这个代码从精度联邦学习论文到复现也是花了不少时间,然后如果不想复制粘贴了手懒了的话,可以到CSDN推广的公众号浩浩的科研笔记中buy,如果是学弟学妹的话,给我私信我也会发你一份。
相关文章:

联邦学习实验复现—MNISIT IID实验 pytorch
联邦学习论文复现🚀 在精度的联邦学习的论文之后打算进一步开展写一个联邦学习的基础代码,用于开展之后的相关研究,首先就是复现一下论文中最基础也是最经典的MNIST IID(独立同分布划分) 数据集。然后由于这个联邦学习的论文是谷歌发的&#…...

2015年-2017年 计算机技术专业 程序设计题(算法题)实战_c语言程序设计数据结构程序设计分析
文章目录 20151.C语言算法设计部分2.数据结构算法设计部分 20161.C语言算法设计部分2.数据结构算法设计部分 2017年1. C语言算法设计部分2.数据结构算法设计部分 2015 1.C语言算法设计部分 int total(int n) {if(n1) return 1;return total(n-1)n1; } //主函数测试代码已省略…...

个人用计算理论导引笔记(待补充)
文章目录 一、正则语言预备知识确定性有穷自动机(DFA)设计DFA正则运算 非确定性有穷自动机(NFA,含有 ε \varepsilon ε,下一个状态可以有若干种选择(包括0种))正则表达式定义计算优…...

2024年诺贝尔物理学奖揭晓:AI背后的“造梦者”是谁?
想象一下,你早上醒来,智能音箱为你播放天气和新闻,中午你用手机刷视频,精准的推荐内容简直和你心有灵犀,晚上回家,自动驾驶汽车安全地把你送回家。这一切看似理所当然,背后却有一双无形的手推动…...

2024年AI 制作PPT新宠儿,3款神器集锦,让你的演示与众不同
咱们今儿聊聊最近超火的AI做PPT的工具。这年头,谁不想省事儿,少熬夜加班,多享受享受生活啊?所以,AI开始帮咱们搞定做PPT这种费时的活儿,我自然得好好研究研究。今天,我就给大家详细说说三款很火…...

CLion和Qt 联合开发环境配置教程(Windows和Linux版)
需要安装的工具CLion 和Qt CLion下载链接 :https://www.jetbrains.com.cn/clion/ 这个软件属于直接默认安装就行,很简单,不多做介绍了 Qt:https://mirrors.tuna.tsinghua.edu.cn/qt/official_releases/online_installers/ window 直接点exe Linux 先c…...

Qt记录使用QtAwesome
Qt记录使用QtAwesome 基本使用 基本使用 pro文件添加 CONFIG fontAwesomeFree include(QtAwesome/QtAwesome.pri) //实例化QtAwesome fa::QtAwesome* awesome new fa::QtAwesome(this); awesome->initFontAwesome();//设置外置适应 图标ICON的颜色color QVariantMap opt…...

ES6新增promise(异步编程新解决方案)如何封装ajax?
1.什么是异步? 异步是指从程序在运行过程中可以先执行其他操作。 2.什么是promise? Promise 是 ES6 引入的异步编程的新解决方案。语法上 Promise 是一个构造函数,用来封装异步 操作并可以获取其成功或失败的结果; 3.promise成功…...

Kubernetes--深入理解Service与CoreDNS
文章目录 Service功能Service 的常见使用场景 Service的模式iptablesIPVS Service类型ClusterIPNodePortLoadBalancerExternalName Service的工作机制EndpointEndpoint 与 Service 的关系Endpoint 的工作原理命令操作 CoreDNSCoreDNS 的配置CoreDNS 的典型插件Corefile 示例Cor…...

AI大模型:开启智能革命新纪元
1.AI大模型技术:智能革命的新引擎 自2022年11月30日OpenAI推出ChatGPT以来,这一大型语言模型(LLM)迅速走红,标志着AI领域进入了一个新的发展阶段,即AI大模型时代。 这一时代预示着AI正朝着通用人工智能&am…...

快速上手C语言【下】(非常详细!!!)
目录 1. 指针 1.1 指针是什么 1.2 指针类型 1.2.1 指针-整数 1.2.2 指针解引用 1.3 const修饰 1.4 字符指针 1.5 指针-指针 1.6 二级指针 2. 数组 2.1 定义和初始化 2.2 下标引用操作符[ ] 2.3 二维数组 2.4 终极测试 3. 函数 3.1 声明和定义 3.2 传值调用…...

红黑树的理解与实现(详解)
相关的数据结构: 搜索二叉树-CSDN博客 AVL树的创建与检测-CSDN博客 个人主页:敲上瘾-CSDN博客 个人专栏:游戏、数据结构、c语言基础、c学习、算法 目录 一、红黑树规则: 二、红黑树的插入 1.变色 2.单旋变色 3.双旋变色 三、…...

从一到无穷大 #37 Databricks Photon:打响 Spark Native Engine 第一枪
本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。 本作品 (李兆龙 博文, 由 李兆龙 创作),由 李兆龙 确认,转载请注明版权。 文章目录 引言技术决策JVM vs. Native ExecutionInterpreted Vectorization vs Code-GenRow vs…...

Java 字符串占位格式化
Java 提供了几种方式来处理字符串占位符,最常用的是 String 类的 format 方法和 MessageFormat 类。以下是这两种方法的详细说明和示例。 1、String.format 基本语法: String formatted String.format("格式字符串", 参数1, 参数2, ...); …...

基于netty实现简易版rpc服务-理论分析
1.技术要点 1.1 rpc协议 定义一个rpc协议类,用于rpc服务端和客户端数据交互。 1.2 netty粘包半包处理 由于数据传说使用tcp协议,rpc协议的数据在网络传输过程中会产生三种情况: 1)刚好是完整的一条rpc协议数据 2)不…...

Elasticsearch高级搜索技术-全文搜索
目录 倒排索引 (Inverted Index) 示例 分词器 (Analyzer) 评分机制 (Scoring) 查询执行 match 查询 match_phrase 查询 全文搜索是Elasticsearch的核心功能之一,它通过复杂的算法和数据结构来提供高效的搜索能力。为了深入理解其工作原理,我们需要…...

案例分享—国外优秀UI卡片设计作品赏析
国外UI设计注重用户体验,倾向于采用简洁的布局、清晰的排版和直观的交互方式,减少用户的认知负担。卡片式设计能够完美利用屏幕空间,使内容一目了然,易于用户快速浏览和阅读,从而提升了整体的用户体验。 更加注重扁平化…...

Go语言基础学习(Go安装配置、基础语法)
一、简介及安装教程 1、为什么学习Go? 简单好记的关键词和语法;更高的效率;生态强大;语法检查严格,安全性高;严格的依赖管理, go mod 命令;强大的编译检查、严格的编码规范和完整的…...

STM32—FLASH闪存
1.FLASH简介 STM32F1系列的FLASH包含程序存储器、系统存储器和选项字节三个部分,通过闪存存储器接口(外设)可以对程序存储器和选项字节进行擦除和编程 我们怎么操作这些存储器呢?这就需要用到这个闪存存储器接口了,闪…...

AP上线的那些事儿(1)capwap建立过程、设备初始化以及二层上线
1、了解FITAP与AC的建立过程 之前我们已经知道了FATAP与FIT是一对双胞胎一样的兄弟,FAT哥哥能够直接独立使用当AP桥接、路由器等,而弟弟FIT则比较薄弱,独自发挥不出功效,需要一位师傅(AC)来带领,…...

10 django管理系统 - 管理员管理 - 新建管理员(通过模态框和ajax实现)
在文章“04 django管理系统 - 部门管理 - 新增部门”中,我们通过传统的新增页面来实现部门的添加。 在本文中,我们通过模态框和ajax来实现管理员的新增。 首先在admin_list.html中新建入口,使用按钮 <div class"panel-heading&quo…...

Mysql中表字段VARCHAR(N)类型及长度的解释
本文将针对MySQL 中 varchar (N)类型字段的存储方式进行解释,主要是对字符和字节的关系的理解。 1. varchar (N) 中的 N varchar (N) 中的 N 表示字符数,而不是字节数。这意味着 N 表示你可以存储多少个字符。 字符数:指的是字符的个数&…...

git提交信息写错处理方式
在Git中,你可以通过使用rebase命令来合并提交记录。以下是一个简单的步骤来合并一系列提交: 使用git rebase -i开始交互式变基。在打开的编辑器中,你会看到一个提交列表。若要合并提交,将要合并的提交前面的pick改为squash或s。保…...

C#从零开始学习(用unity探索C#)(unity Lab1)
初次使用Unity 本章所有的代码都放在 https://github.com/hikinazimi/head-first-Csharp Unity的下载与安装 从 unity官网下载Unity Hub Unity的使用 安装后,注册账号,下载unity版本,然后创建3d项目 设置窗口界面布局 3D对象的创建 点击对象,然后点击Move Guzmo,就可以拖动…...

【SpringBoot】15 Echarts+Thymeleaf 绘制各种图表
Gitee仓库 https://gitee.com/Lin_DH/system 介绍 ECharts是百度开源的一个前端组件。它是一个使用 JavaScript 实现的开源可视化库,可以流畅的运行在 PC 和移动设备上,兼容当前绝大部分浏览器(IE8/9/10/11,Chrome,…...

网络学习笔记
一、网络的结构与功能 网络的鲁棒性与抗毁性 如果在移走少量节点后网络中的绝大部分节点仍然是连通的,那么就该网络的连通性对节点故障具有鲁棒性 网络上的动力学 动力系统:自旋、振子或混沌的同步、可激发系统 传播过程:信息传播与拥堵…...

[论文笔记]HERMES 3 TECHNICAL REPORT
引言 今天带来论文HERMES 3 TECHNICAL REPORT,这篇论文提出了一个强大的工具调用模型,包含了训练方案介绍。同时提出了一个函数调用标准。 为了简单,下文中以翻译的口吻记录,比如替换"作者"为"我们"。 聊天模…...

MySQL-19.多表设计-一对多-外键
一.多表问题分析 二.添加外键 三.外键约束的问题...

MySQL程序介绍<一>
目录 MySQL程序简介 mysqld - MySQL 服务器 编辑 mysql - MySQL 命令⾏客⼾端 MySQL程序简介 1.MySQL安装完成通常会包含如下程序: Linux系统程序⼀般在 /usr/bin⽬录下,可以通过命令查看 windows系统⽬录: 你的安装路径\MySQL Server…...

Leetcode 第 419 场周赛题解
Leetcode 第 419 场周赛题解 Leetcode 第 419 场周赛题解题目1:3318. 计算子数组的 x-sum I思路代码复杂度分析 题目2:3319. 第 K 大的完美二叉子树的大小思路代码复杂度分析 题目3:思路代码复杂度分析 题目4:3321. 计算子数组的 …...