当前位置: 首页 > news >正文

【联邦学习——手动搭建简易联邦学习】

1. 目的

用于记录自己在手写联邦学习相关实验时碰到的一些问题,方便自己进行回顾。

2. 代码

2.1 本地模型计算梯度更新

# 比较训练前后的参数变化
def compare_weights(new_model, old_model):weight_updates = {}for layer_name, params in new_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)return weight_updates

测试代码如下:
有意思的点在于我获得了update = model2-model1
但是我去计算model1+update==model2的时候发现不相等
最后思考了一下可能是在这个计算的过程中存在精度的丢失

import torch
from torch import nn
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoaderdef weight_init(m):if isinstance(m, nn.Linear):nn.init.xavier_normal_(m.weight)nn.init.constant_(m.bias, 0)# 也可以判断是否为conv2d,使用相应的初始化方式elif isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')# 是否为批归一化层elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if __name__ == '__main__':# 建立model1和model2作为训练前后的模型model1 = models.get_model("resnet18")model1.apply(weight_init)model2 = models.get_model("resnet18")model2.apply(weight_init)weight_updates = compare_weights(model2,model1)# 创建一个临时的模型状态字典,用于存储更新后的参数updated_params = model1.state_dict().copy()for layer_name, update in weight_updates.items():# 确保该层存在于model1中且形状匹配,避免错误if layer_name in updated_params and update.shape == updated_params[layer_name].shape:# 直接相加更新参数updated_params[layer_name] += updateelse:print(f"Warning: Layer {layer_name} not found or shape mismatch, skipping update.")# 将更新后的参数加载回model1model1.load_state_dict(updated_params)for layer_name,params in model1.state_dict().items():ts = params-model2.state_dict().get(layer_name)# 很重要,这里应该是我们在做的时候会有一点模型精度上的损失,所以不能够计算这里等于0if torch.sum(ts).item()>1e-6:print(f"{layer_name}更新后与原有的不匹配,差距为")else:print(f"{layer_name}更新后与原有的匹配")

2.2 客户端代码

import numpy as np
import torch.utils.data
from tqdm import tqdm'''
conf 配置文件
model 模型
train_dataset 数据集
class_ratios 从数据集中筛选出一部分 class_ratios = {0: 0.5, 1: 0.5,..., 8: 0.5, 9: 0.5}
id 客户端的标识
'''class Client(object):def __init__(self,conf,model,device,train_loader,id=1):self.client_id = id                     # 客户端IDself.conf = conf                        # 配置文件self.local_model = model                # 客户端本地模型self.train_loader = train_loader        # 训练数据的迭代器,需要训练的数据已经在里面了self.grad_update = dict()               # 本地训练完之后的梯度更新self.weight = conf['weight']            # 全局模型梯度更新时的权重self.device = device                    # 训练的设备self.local_model.to(self.device)        # 将模型放入训练设备def train(self, model):self._before_train(model)self._local_train()self._after_train(model)def _before_train(self, model):self._load_global_model(model)# 用服务器模型来覆盖本地模型def _load_global_model(self,model):for name,param in model.state_dict().items():# 客户端首先用服务器端下发的全局模型覆盖本地模型self.local_model.state_dict()[name].copy_(param.clone())def _local_train(self):# 定义最优化函数器,用于本地模型训练optimizer = torch.optim.SGD(self.local_model.parameters(),lr=self.conf['lr'],momentum=self.conf['momentum'])# 本地模型训练self.local_model.train()loss = 0for epoch in range(self.conf['local_epochs']):for batch in tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.conf['local_epochs']}"):data, target = batch# 放入相应的设备data = data.to(self.device)target = target.to(self.device)# 梯度清零optimizer.zero_grad()output = self.local_model(data)loss = torch.nn.functional.cross_entropy(output, target)# 反向传播loss.backward()optimizer.step()print(f"Client{self.client_id}----Epoch {epoch} done.Loss {loss}")def _after_train(self,model):self._cal_update_weights(model)def _cal_update_weights(self, old_model):weight_updates = dict()for layer_name, params in self.local_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)# 更新梯度模型的权重self.grad_update = weight_updates

2.3 服务器代码

import torch.utils.data
import torchvision.datasets as datasets
from torchvision import models
from torchvision.transforms import transformsfrom utils.CommonUtils import copy_model_params# 服务端
class Server(object):def __init__(self, conf, eval_dataset, device):self.conf = conf# 全局老模型self.old_model = models.get_model(self.conf["model_name"])# 全局的新模型self.global_model = models.get_model(self.conf["model_name"])# 创建时保持新老模型的参数是一致的copy_model_params(self.old_model,self.global_model)# 根据客户端上传的梯度进行排列组合,用于测量贡献度的模型self.sub_model = models.get_model(self.conf["model_name"])self.eval_loader = torch.utils.data.DataLoader(eval_dataset,batch_size=self.conf["batch_size"],shuffle=True)self.accuracy_history = []  # 保存accuracy的数组self.loss_history = []  # 保存loss的数组self.device = deviceself.old_model.to(device)self.global_model.to(device)self.sub_model.to(device)# 模型重构def model_aggregate(self, clients, target_model):if target_model == self.global_model:print("++++++++全局模型更新++++++++")# 更新一下老模型参数copy_model_params(self.old_model,self.global_model)else:print("========子模型重构========")sum_weight = 0# 计算总的权重for client in clients:sum_weight += client.weight# 将old_model的模型参数赋值给sub_modelcopy_model_params(self.sub_model, self.old_model)# 初始化一个空字典来累积客户端的模型更新aggregated_updates = {}# 遍历每个客户端for client in clients:# 根据客户端的权重比例聚合更新for name, update in client.grad_update.items():if name not in aggregated_updates:aggregated_updates[name] = update * client.weight / sum_weightelse:aggregated_updates[name] += update * client.weight / sum_weight# 应用聚合后的更新到sub_modelfor name, param in target_model.state_dict().items():if name in aggregated_updates:param.copy_(param + aggregated_updates[name])  # 累加更新到当前层参数上# 定义模型评估函数def model_eval(self,target_model):target_model.eval()total_loss = 0.0correct = 0dataset_size = 0for batch_id,batch in enumerate(self.eval_loader):data,target = batchdataset_size += data.size()[0]# 放入和模型对应的设备data = data.to(self.device)target = target.to(self.device)# 模型预测output = target_model(data)# 把损失值聚合起来total_loss += torch.nn.functional.cross_entropy(output,target,reduction='sum').item()# 获取最大的对数概率的索引值pred = output.data.max(1)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()# 计算准确率acc = 100.0 * (float(correct) / float(dataset_size))# 计算损失值total_l = total_loss / dataset_size# 将accuracy和loss保存到数组中self.accuracy_history.append(acc)self.loss_history.append(total_l)if target_model == self.global_model:print(f"++++++++全局模型评估++++++++acc:{acc}  loss:{total_l}")else:print(f"========子模型评估========acc:{acc}  loss:{total_l}")return acc,total_ldef save_results_to_file(self):# 将accuracy和loss保存到文件中with open("fed_accuracy_history.txt", "w") as f:for acc in self.accuracy_history:f.write("{:.2f}\n".format(acc))with open("fed_loss_history.txt", "w") as f:for loss in self.loss_history:f.write("{:.4f}\n".format(loss))

2.4 Utils

def copy_model_params(target_model, source_model):for name, param in source_model.state_dict().items():target_model.state_dict()[name].copy_(param.clone())

3. 运行测试代码

import jsonimport torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision import models
from torchvision import transforms, datasets
from client.Client import Client
from server.Server import Serverwith open("../conf/client1.json",'r') as f:conf = json.load(f)with open("../conf/server1.json",'r') as f:serverConf = json.load(f)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.CIFAR10(root='../data/', train=True, download=True,transform=transform)
eval_dataset = datasets.CIFAR10(root='../data/', train=False, download=True,transform=transform)# train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, num_workers=2)# 计算数据集长度
total_samples = len(train_dataset)
# 确保可以平均分配,否则需要调整逻辑以处理余数
assert total_samples % 2 == 0, "数据集样本数需为偶数以便完全平分"# 分割点
split_point = total_samples // 2# 创建两个子集
train_dataset_first_half = Subset(train_dataset, range(0, split_point))
train_dataset_second_half = Subset(train_dataset, range(split_point, total_samples))# 然后为每个子集创建DataLoader
batch_size = 32train_loader_first_half = DataLoader(train_dataset_first_half, shuffle=True, batch_size=batch_size, num_workers=2)
train_loader_second_half = DataLoader(train_dataset_second_half, shuffle=True, batch_size=batch_size, num_workers=2)# 检查CUDA是否可用
if torch.cuda.is_available():device = torch.device("cuda")  # 如果CUDA可用,选择GPU
else:device = torch.device("cpu")   # 如果CUDA不可用,选择CPUlocal_model = models.get_model("resnet18")
local_model2 = models.get_model("resnet18")server = Server(serverConf,eval_dataset,device)
client1 = Client(conf, local_model, device,train_loader_first_half, 1)
client2 = Client(conf, local_model2, device,train_loader_second_half, 2)
for i in range(2):client1.train(server.global_model)client2.train(server.global_model)server.model_aggregate([client1,client2], server.global_model)server.model_eval(server.global_model)server.model_aggregate([client1], server.sub_model)server.model_eval(server.sub_model)server.model_aggregate([client2], server.sub_model)server.model_eval(server.sub_model)

相关文章:

【联邦学习——手动搭建简易联邦学习】

1. 目的 用于记录自己在手写联邦学习相关实验时碰到的一些问题,方便自己进行回顾。 2. 代码 2.1 本地模型计算梯度更新 # 比较训练前后的参数变化 def compare_weights(new_model, old_model):weight_updates {}for layer_name, params in new_model.state_dic…...

Springboot项目如何创建单元测试

文章目录 目录 文章目录 前言 一、SpringBoot单元测试的使用 1.1 引入依赖 1.2 创建单元测试类 二、Spring Boot使用Mockito进行单元测试 2.1 Mockito中经常使用的注解以及注解的作用 2.2 使用Mockito测试类中的方法 2.3 使用Mockito测试Controller层的方法 2.4 mock…...

Win10 如何同时保留两个CUDA版本并自由切换使用

环境: Win10 专业版 CUDA11.3 CUDA11.8 问题描述: Win10 如何同时保留两个CUDA版本并自由切换 解决方案: 在同一台计算机上安装两个CUDA版本并进行切换可以通过一些环境配置来实现。这通常涉及到管理环境变量,特别是PATH和L…...

实验室纳新宣讲会(java后端)

前言 这是陈旧已久的草稿2021-09-16 15:41:38 当时我进入实验室,也是大二了,实验室纳新需要宣讲, 但是当时有疫情,又没宣讲成。 现在2024-5-12 22:00:39,发布到[个人]专栏中。 实验室纳新宣讲会(java后…...

class常量池、运行时常量池和字符串常量池的关系

类常量池、运行时常量池和字符串常量池这三种常量池,在Java中扮演着不同但又相互关联的角色。理解它们之间的关系,有助于深入理解Java虚拟机(JVM)的内部工作机制,尤其是在类加载、内存分配和字符串处理方面。 类常量池…...

Java | Leetcode Java题解之第88题合并两个有序数组

题目: 题解: class Solution {public void merge(int[] nums1, int m, int[] nums2, int n) {int p1 m - 1, p2 n - 1;int tail m n - 1;int cur;while (p1 > 0 || p2 > 0) {if (p1 -1) {cur nums2[p2--];} else if (p2 -1) {cur nums1[p…...

韵搜坊(全栈)-- 前后端初始化

文章目录 前端初始化后端初始化 前端初始化 使用ant design of vue 组件库 官网快速上手:https://www.antdv.com/docs/vue/getting-started-cn 安装脚手架工具 进入cmd $ npm install -g vue/cli # OR $ yarn global add vue/cli创建一个项目 $ vue create ant…...

Android:资源的管理,Glide图片加载框架的使用

目录 一,Android资源分类 1.使用res目录下的资源 res目录下资源的使用: 2.使用assets目录下的资源 assets目录下的资源的使用: 二,glide图片加载框架 1.glide简介 2.下载和设置 3.基本用法 4.占位符(Placehold…...

conll-2012-formatted-ontonotes-5.0中文数据格式说明

CoNLL-2012 数据格式是用于自然语言处理任务的一种常见格式,特别是在命名实体识别、词性标注、句法分析和语义角色标注等领域。这种格式在 CoNLL-2012 共享任务中被广泛使用,该任务主要集中在语义角色标注上。 CoNLL-2012 数据格式通常包括多列&#xf…...

SpringBoot集成Seata分布式事务OpenFeign远程调用

Docker Desktop 安装Seata Server seata 本质上是一个服务,用docker安装更方便,配置默认:file docker run -d --name seata-server -p 8091:8091 -p 7091:7091 seataio/seata-server:2.0.0与SpringBoot集成 表结构 项目目录 dynamic和dyna…...

视觉检测系统,是否所有产品都可以进行视觉检测?

视觉检测系统作为一种先进的质检工具,虽然具有广泛的应用范围,但并非所有产品都适合进行视觉检测。本文将探讨视觉检测系统的适用范围及其局限性。 随着机器视觉技术的快速发展,视觉检测系统已广泛应用于各个行业,为产品质检提供…...

通过金山和微软虚拟打印机转换PDF文件,流程方法及优劣对比

文章目录 一、WPS/金山 PDF虚拟打印机1、常规流程2、PDF文件位置3、严重缺陷二、微软虚拟打印机Microsoft Print to Pdf1、安装流程2、微软虚拟打印机的优势一、WPS/金山 PDF虚拟打印机 1、常规流程 安装过WPS办公组件或金山PDF独立版的电脑,会有一个或两个WPS/金山 PDF虚拟…...

采用java+B/S开发的全套医院绩效考核系统源码springboot+mybaits 医院绩效考核系统优势

采用java开发的全套医院绩效考核系统源码springbootmybaits 医院绩效考核系统优势 医院绩效管理系统解决方案紧扣新医改形势下医院绩效管理的要求,以“工作量为基础的考核方案”为核心思想,结合患者满意度、服务质量、技术难度、工作效率、医德医风等管…...

驱动开发-用户空间和内核空间数据传输

1.用户空间-->内核空间&#xff08;写&#xff09; #include<linux/uaccess.h> int copy_from_user(void *to,const void __user volatile*from,unsigned long n) 函数功能&#xff1a;将用户空间数据拷贝到内核空间 参数&#xff1a; to&#xff1a;内核空间首地…...

【408精华知识】速看!各种排序的大总结!

文章目录 一、插入排序&#xff08;一&#xff09;直接插入排序&#xff08;二&#xff09;折半插入排序&#xff08;三&#xff09;希尔排序 二、交换排序&#xff08;一&#xff09;冒泡排序&#xff08;二&#xff09;快速排序 三、选择排序&#xff08;一&#xff09;简单选…...

【STM32 |程序实例】按键控制、光敏传感器控制蜂鸣器

目录 前言 按键控制LED 光敏传感器控制蜂鸣器 前言 上拉输入&#xff1a;若GPIO引脚配置为上拉输入模式&#xff0c;在默认情况下&#xff08;GPIO引脚无输入&#xff09;&#xff0c;读取的GPIO引脚数据为1&#xff0c;即高电平。 下拉输入&#xff1a;若GPIO引脚配置为下…...

Spring boot使用websocket实现在线聊天

maven依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spr…...

品牌设计理念和logo设计方法

一 品牌设计的目的 设计是为了传播&#xff0c;让传播速度更快&#xff0c;传播效率更高&#xff0c;减少宣传成本 二 什么是好的品牌设计 好的设计是为了让消费者更容易看懂、记住的设计&#xff0c; 从而辅助传播&#xff0c; 即 看得懂、记得住。 1 看得懂 就是让别人看懂…...

Python | Leetcode Python题解之第88题合并两个有序数组

题目&#xff1a; 题解&#xff1a; class Solution:def merge(self, nums1: List[int], m: int, nums2: List[int], n: int) -> None:"""Do not return anything, modify nums1 in-place instead."""p1, p2 m - 1, n - 1tail m n - 1whi…...

vscode新版本remotessh服务端报`GLIBC_2.28‘ not found解决方案

问题现象 通过vscode的remotessh插件连接老版本服务器&#xff08;如RHEL7&#xff0c;Centos7&#xff09;时&#xff0c;插件会报错&#xff0c;无法连接。 查看插件的错误日志可以看到类似如下的报错信息&#xff1a; dc96b837cf6bb4af9cd736aa3af08cf8279f7685/node: /li…...

盘他系列——oj!!!

1.Openjudge 网站: OpenJudge 2.洛谷 网站: 首页 - 洛谷 | 计算机科学教育新生态 3.环球OJ 网站: QOJ - QOJ.ac 4. 北京大学 OJ:Welcome To PKU JudgeOnline 5.自由OJ 网站: https://loj.ac/ 6.炼码 网站:LintCode 炼码 8.力扣 网站: 力扣 9.晴练网首页 - 晴练网...

洛谷 P2657 [SCOI2009] windy 数 题解 数位dp

[SCOI2009] windy 数 题目背景 windy 定义了一种 windy 数。 题目描述 不含前导零且相邻两个数字之差至少为 2 2 2 的正整数被称为 windy 数。windy 想知道&#xff0c;在 a a a 和 b b b 之间&#xff0c;包括 a a a 和 b b b &#xff0c;总共有多少个 windy 数&…...

Python爬虫入门:网络世界的宝藏猎人

今天阿佑将带你踏上Python的肩膀&#xff0c;成为一名网络世界的宝藏猎人&#xff01; 文章目录 1. 引言1.1 简述Python在爬虫领域的地位1.2 阐明学习网络基础对爬虫的重要性 2. 背景介绍2.1 Python语言的流行与适用场景2.2 网络通信基础概念及其在数据抓取中的角色 3. Python基…...

【NodeMCU实时天气时钟温湿度项目 6】解析天气信息JSON数据并显示在 TFT 屏幕上(心知天气版)

今天是第六专题&#xff0c;主要内容是&#xff1a;导入ArduinoJson功能库&#xff0c;借助该库解析从【心知天气】官网返回的JSON数据&#xff0c;并显示在 TFT 屏幕上。 如您需要了解其它专题的内容&#xff0c;请点击下面的链接。 第一专题内容&#xff0c;请参考&a…...

重构四要素:目的、对象、时机和方法

目录 1.引言 2.重构的目的:为什么重构(why) 3.重构的对象:到底重构什么(what) 4.重构的时机:什么时候重构(when) 5.重构的方法:应该如何重构(how) 6.思考题 1.引言 一些软件工程师对为什么要重构(why)、到底重构什么(what)、什么时候重构(when)应该如何重构(how)等问题的…...

基于Echarts的大数据可视化模板:服务器运营监控

目录 引言背景介绍研究现状与相关工作服务器运营监控技术综述服务器运营监控概述监控指标与数据采集可视化界面设计与实现数据存储与查询优化Echarts与大数据可视化Echarts库以及其在大数据可视化领域的应用优势开发过程和所选设计方案模板如何满足管理的特定需求模板功能与特性…...

Python3 笔记:Python的常量

常量&#xff08;constant&#xff09;&#xff1a;跟变量相对应&#xff0c;指第一次赋予值后就保持固定不变的值。 Python里面没有声明常量的关键字&#xff0c;其他语言像C/C/Java会有const修饰符&#xff0c;但Python没有。 Python中没有使用语法强制定义常量&#xff0c…...

【Linux】自动化构建工具make/Makefile和git介绍

&#x1f308;个人主页&#xff1a;秦jh__https://blog.csdn.net/qinjh_?spm1010.2135.3001.5343&#x1f525; 系列专栏&#xff1a;https://blog.csdn.net/qinjh_/category_12625432.html 目录 前言 Linux项目自动化构建工具-make/Makefile 举例 .PHONY 常见符号 依赖关系…...

C语言—关于字符串(编程实现部分函数功能)

0.前言 当我们使用这些函数功能时&#xff0c;可以直接调用头文件---#include<string.h>&#xff0c;然后直接使用就行了,本文只是手动编写实现函数的部分功能 1.strlen函数功能实现 功能说明&#xff1a;strlen(s)用来计算字符串s的长度&#xff0c;该函数计数不会包括最…...

picoCTF-Web Exploitation-Trickster

Description I found a web app that can help process images: PNG images only! 这应该是个上传漏洞了&#xff0c;十几年没用过了&#xff0c;不知道思路是不是一样的&#xff0c;以前的思路是通过上传漏洞想办法上传一个木马&#xff0c;拿到webshell&#xff0c;今天试试看…...