【联邦学习——手动搭建简易联邦学习】
1. 目的
用于记录自己在手写联邦学习相关实验时碰到的一些问题,方便自己进行回顾。
2. 代码
2.1 本地模型计算梯度更新
# 比较训练前后的参数变化
def compare_weights(new_model, old_model):weight_updates = {}for layer_name, params in new_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)return weight_updates
测试代码如下:
有意思的点在于我获得了update = model2-model1
但是我去计算model1+update==model2的时候发现不相等
最后思考了一下可能是在这个计算的过程中存在精度的丢失
import torch
from torch import nn
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoaderdef weight_init(m):if isinstance(m, nn.Linear):nn.init.xavier_normal_(m.weight)nn.init.constant_(m.bias, 0)# 也可以判断是否为conv2d,使用相应的初始化方式elif isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')# 是否为批归一化层elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if __name__ == '__main__':# 建立model1和model2作为训练前后的模型model1 = models.get_model("resnet18")model1.apply(weight_init)model2 = models.get_model("resnet18")model2.apply(weight_init)weight_updates = compare_weights(model2,model1)# 创建一个临时的模型状态字典,用于存储更新后的参数updated_params = model1.state_dict().copy()for layer_name, update in weight_updates.items():# 确保该层存在于model1中且形状匹配,避免错误if layer_name in updated_params and update.shape == updated_params[layer_name].shape:# 直接相加更新参数updated_params[layer_name] += updateelse:print(f"Warning: Layer {layer_name} not found or shape mismatch, skipping update.")# 将更新后的参数加载回model1model1.load_state_dict(updated_params)for layer_name,params in model1.state_dict().items():ts = params-model2.state_dict().get(layer_name)# 很重要,这里应该是我们在做的时候会有一点模型精度上的损失,所以不能够计算这里等于0if torch.sum(ts).item()>1e-6:print(f"{layer_name}更新后与原有的不匹配,差距为")else:print(f"{layer_name}更新后与原有的匹配")
2.2 客户端代码
import numpy as np
import torch.utils.data
from tqdm import tqdm'''
conf 配置文件
model 模型
train_dataset 数据集
class_ratios 从数据集中筛选出一部分 class_ratios = {0: 0.5, 1: 0.5,..., 8: 0.5, 9: 0.5}
id 客户端的标识
'''class Client(object):def __init__(self,conf,model,device,train_loader,id=1):self.client_id = id # 客户端IDself.conf = conf # 配置文件self.local_model = model # 客户端本地模型self.train_loader = train_loader # 训练数据的迭代器,需要训练的数据已经在里面了self.grad_update = dict() # 本地训练完之后的梯度更新self.weight = conf['weight'] # 全局模型梯度更新时的权重self.device = device # 训练的设备self.local_model.to(self.device) # 将模型放入训练设备def train(self, model):self._before_train(model)self._local_train()self._after_train(model)def _before_train(self, model):self._load_global_model(model)# 用服务器模型来覆盖本地模型def _load_global_model(self,model):for name,param in model.state_dict().items():# 客户端首先用服务器端下发的全局模型覆盖本地模型self.local_model.state_dict()[name].copy_(param.clone())def _local_train(self):# 定义最优化函数器,用于本地模型训练optimizer = torch.optim.SGD(self.local_model.parameters(),lr=self.conf['lr'],momentum=self.conf['momentum'])# 本地模型训练self.local_model.train()loss = 0for epoch in range(self.conf['local_epochs']):for batch in tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.conf['local_epochs']}"):data, target = batch# 放入相应的设备data = data.to(self.device)target = target.to(self.device)# 梯度清零optimizer.zero_grad()output = self.local_model(data)loss = torch.nn.functional.cross_entropy(output, target)# 反向传播loss.backward()optimizer.step()print(f"Client{self.client_id}----Epoch {epoch} done.Loss {loss}")def _after_train(self,model):self._cal_update_weights(model)def _cal_update_weights(self, old_model):weight_updates = dict()for layer_name, params in self.local_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)# 更新梯度模型的权重self.grad_update = weight_updates
2.3 服务器代码
import torch.utils.data
import torchvision.datasets as datasets
from torchvision import models
from torchvision.transforms import transformsfrom utils.CommonUtils import copy_model_params# 服务端
class Server(object):def __init__(self, conf, eval_dataset, device):self.conf = conf# 全局老模型self.old_model = models.get_model(self.conf["model_name"])# 全局的新模型self.global_model = models.get_model(self.conf["model_name"])# 创建时保持新老模型的参数是一致的copy_model_params(self.old_model,self.global_model)# 根据客户端上传的梯度进行排列组合,用于测量贡献度的模型self.sub_model = models.get_model(self.conf["model_name"])self.eval_loader = torch.utils.data.DataLoader(eval_dataset,batch_size=self.conf["batch_size"],shuffle=True)self.accuracy_history = [] # 保存accuracy的数组self.loss_history = [] # 保存loss的数组self.device = deviceself.old_model.to(device)self.global_model.to(device)self.sub_model.to(device)# 模型重构def model_aggregate(self, clients, target_model):if target_model == self.global_model:print("++++++++全局模型更新++++++++")# 更新一下老模型参数copy_model_params(self.old_model,self.global_model)else:print("========子模型重构========")sum_weight = 0# 计算总的权重for client in clients:sum_weight += client.weight# 将old_model的模型参数赋值给sub_modelcopy_model_params(self.sub_model, self.old_model)# 初始化一个空字典来累积客户端的模型更新aggregated_updates = {}# 遍历每个客户端for client in clients:# 根据客户端的权重比例聚合更新for name, update in client.grad_update.items():if name not in aggregated_updates:aggregated_updates[name] = update * client.weight / sum_weightelse:aggregated_updates[name] += update * client.weight / sum_weight# 应用聚合后的更新到sub_modelfor name, param in target_model.state_dict().items():if name in aggregated_updates:param.copy_(param + aggregated_updates[name]) # 累加更新到当前层参数上# 定义模型评估函数def model_eval(self,target_model):target_model.eval()total_loss = 0.0correct = 0dataset_size = 0for batch_id,batch in enumerate(self.eval_loader):data,target = batchdataset_size += data.size()[0]# 放入和模型对应的设备data = data.to(self.device)target = target.to(self.device)# 模型预测output = target_model(data)# 把损失值聚合起来total_loss += torch.nn.functional.cross_entropy(output,target,reduction='sum').item()# 获取最大的对数概率的索引值pred = output.data.max(1)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()# 计算准确率acc = 100.0 * (float(correct) / float(dataset_size))# 计算损失值total_l = total_loss / dataset_size# 将accuracy和loss保存到数组中self.accuracy_history.append(acc)self.loss_history.append(total_l)if target_model == self.global_model:print(f"++++++++全局模型评估++++++++acc:{acc} loss:{total_l}")else:print(f"========子模型评估========acc:{acc} loss:{total_l}")return acc,total_ldef save_results_to_file(self):# 将accuracy和loss保存到文件中with open("fed_accuracy_history.txt", "w") as f:for acc in self.accuracy_history:f.write("{:.2f}\n".format(acc))with open("fed_loss_history.txt", "w") as f:for loss in self.loss_history:f.write("{:.4f}\n".format(loss))
2.4 Utils
def copy_model_params(target_model, source_model):for name, param in source_model.state_dict().items():target_model.state_dict()[name].copy_(param.clone())
3. 运行测试代码
import jsonimport torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision import models
from torchvision import transforms, datasets
from client.Client import Client
from server.Server import Serverwith open("../conf/client1.json",'r') as f:conf = json.load(f)with open("../conf/server1.json",'r') as f:serverConf = json.load(f)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.CIFAR10(root='../data/', train=True, download=True,transform=transform)
eval_dataset = datasets.CIFAR10(root='../data/', train=False, download=True,transform=transform)# train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, num_workers=2)# 计算数据集长度
total_samples = len(train_dataset)
# 确保可以平均分配,否则需要调整逻辑以处理余数
assert total_samples % 2 == 0, "数据集样本数需为偶数以便完全平分"# 分割点
split_point = total_samples // 2# 创建两个子集
train_dataset_first_half = Subset(train_dataset, range(0, split_point))
train_dataset_second_half = Subset(train_dataset, range(split_point, total_samples))# 然后为每个子集创建DataLoader
batch_size = 32train_loader_first_half = DataLoader(train_dataset_first_half, shuffle=True, batch_size=batch_size, num_workers=2)
train_loader_second_half = DataLoader(train_dataset_second_half, shuffle=True, batch_size=batch_size, num_workers=2)# 检查CUDA是否可用
if torch.cuda.is_available():device = torch.device("cuda") # 如果CUDA可用,选择GPU
else:device = torch.device("cpu") # 如果CUDA不可用,选择CPUlocal_model = models.get_model("resnet18")
local_model2 = models.get_model("resnet18")server = Server(serverConf,eval_dataset,device)
client1 = Client(conf, local_model, device,train_loader_first_half, 1)
client2 = Client(conf, local_model2, device,train_loader_second_half, 2)
for i in range(2):client1.train(server.global_model)client2.train(server.global_model)server.model_aggregate([client1,client2], server.global_model)server.model_eval(server.global_model)server.model_aggregate([client1], server.sub_model)server.model_eval(server.sub_model)server.model_aggregate([client2], server.sub_model)server.model_eval(server.sub_model)相关文章:
【联邦学习——手动搭建简易联邦学习】
1. 目的 用于记录自己在手写联邦学习相关实验时碰到的一些问题,方便自己进行回顾。 2. 代码 2.1 本地模型计算梯度更新 # 比较训练前后的参数变化 def compare_weights(new_model, old_model):weight_updates {}for layer_name, params in new_model.state_dic…...
Springboot项目如何创建单元测试
文章目录 目录 文章目录 前言 一、SpringBoot单元测试的使用 1.1 引入依赖 1.2 创建单元测试类 二、Spring Boot使用Mockito进行单元测试 2.1 Mockito中经常使用的注解以及注解的作用 2.2 使用Mockito测试类中的方法 2.3 使用Mockito测试Controller层的方法 2.4 mock…...
Win10 如何同时保留两个CUDA版本并自由切换使用
环境: Win10 专业版 CUDA11.3 CUDA11.8 问题描述: Win10 如何同时保留两个CUDA版本并自由切换 解决方案: 在同一台计算机上安装两个CUDA版本并进行切换可以通过一些环境配置来实现。这通常涉及到管理环境变量,特别是PATH和L…...
实验室纳新宣讲会(java后端)
前言 这是陈旧已久的草稿2021-09-16 15:41:38 当时我进入实验室,也是大二了,实验室纳新需要宣讲, 但是当时有疫情,又没宣讲成。 现在2024-5-12 22:00:39,发布到[个人]专栏中。 实验室纳新宣讲会(java后…...
class常量池、运行时常量池和字符串常量池的关系
类常量池、运行时常量池和字符串常量池这三种常量池,在Java中扮演着不同但又相互关联的角色。理解它们之间的关系,有助于深入理解Java虚拟机(JVM)的内部工作机制,尤其是在类加载、内存分配和字符串处理方面。 类常量池…...
Java | Leetcode Java题解之第88题合并两个有序数组
题目: 题解: class Solution {public void merge(int[] nums1, int m, int[] nums2, int n) {int p1 m - 1, p2 n - 1;int tail m n - 1;int cur;while (p1 > 0 || p2 > 0) {if (p1 -1) {cur nums2[p2--];} else if (p2 -1) {cur nums1[p…...
韵搜坊(全栈)-- 前后端初始化
文章目录 前端初始化后端初始化 前端初始化 使用ant design of vue 组件库 官网快速上手:https://www.antdv.com/docs/vue/getting-started-cn 安装脚手架工具 进入cmd $ npm install -g vue/cli # OR $ yarn global add vue/cli创建一个项目 $ vue create ant…...
Android:资源的管理,Glide图片加载框架的使用
目录 一,Android资源分类 1.使用res目录下的资源 res目录下资源的使用: 2.使用assets目录下的资源 assets目录下的资源的使用: 二,glide图片加载框架 1.glide简介 2.下载和设置 3.基本用法 4.占位符(Placehold…...
conll-2012-formatted-ontonotes-5.0中文数据格式说明
CoNLL-2012 数据格式是用于自然语言处理任务的一种常见格式,特别是在命名实体识别、词性标注、句法分析和语义角色标注等领域。这种格式在 CoNLL-2012 共享任务中被广泛使用,该任务主要集中在语义角色标注上。 CoNLL-2012 数据格式通常包括多列…...
SpringBoot集成Seata分布式事务OpenFeign远程调用
Docker Desktop 安装Seata Server seata 本质上是一个服务,用docker安装更方便,配置默认:file docker run -d --name seata-server -p 8091:8091 -p 7091:7091 seataio/seata-server:2.0.0与SpringBoot集成 表结构 项目目录 dynamic和dyna…...
视觉检测系统,是否所有产品都可以进行视觉检测?
视觉检测系统作为一种先进的质检工具,虽然具有广泛的应用范围,但并非所有产品都适合进行视觉检测。本文将探讨视觉检测系统的适用范围及其局限性。 随着机器视觉技术的快速发展,视觉检测系统已广泛应用于各个行业,为产品质检提供…...
通过金山和微软虚拟打印机转换PDF文件,流程方法及优劣对比
文章目录 一、WPS/金山 PDF虚拟打印机1、常规流程2、PDF文件位置3、严重缺陷二、微软虚拟打印机Microsoft Print to Pdf1、安装流程2、微软虚拟打印机的优势一、WPS/金山 PDF虚拟打印机 1、常规流程 安装过WPS办公组件或金山PDF独立版的电脑,会有一个或两个WPS/金山 PDF虚拟…...
采用java+B/S开发的全套医院绩效考核系统源码springboot+mybaits 医院绩效考核系统优势
采用java开发的全套医院绩效考核系统源码springbootmybaits 医院绩效考核系统优势 医院绩效管理系统解决方案紧扣新医改形势下医院绩效管理的要求,以“工作量为基础的考核方案”为核心思想,结合患者满意度、服务质量、技术难度、工作效率、医德医风等管…...
驱动开发-用户空间和内核空间数据传输
1.用户空间-->内核空间(写) #include<linux/uaccess.h> int copy_from_user(void *to,const void __user volatile*from,unsigned long n) 函数功能:将用户空间数据拷贝到内核空间 参数: to:内核空间首地…...
【408精华知识】速看!各种排序的大总结!
文章目录 一、插入排序(一)直接插入排序(二)折半插入排序(三)希尔排序 二、交换排序(一)冒泡排序(二)快速排序 三、选择排序(一)简单选…...
【STM32 |程序实例】按键控制、光敏传感器控制蜂鸣器
目录 前言 按键控制LED 光敏传感器控制蜂鸣器 前言 上拉输入:若GPIO引脚配置为上拉输入模式,在默认情况下(GPIO引脚无输入),读取的GPIO引脚数据为1,即高电平。 下拉输入:若GPIO引脚配置为下…...
Spring boot使用websocket实现在线聊天
maven依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spr…...
品牌设计理念和logo设计方法
一 品牌设计的目的 设计是为了传播,让传播速度更快,传播效率更高,减少宣传成本 二 什么是好的品牌设计 好的设计是为了让消费者更容易看懂、记住的设计, 从而辅助传播, 即 看得懂、记得住。 1 看得懂 就是让别人看懂…...
Python | Leetcode Python题解之第88题合并两个有序数组
题目: 题解: class Solution:def merge(self, nums1: List[int], m: int, nums2: List[int], n: int) -> None:"""Do not return anything, modify nums1 in-place instead."""p1, p2 m - 1, n - 1tail m n - 1whi…...
vscode新版本remotessh服务端报`GLIBC_2.28‘ not found解决方案
问题现象 通过vscode的remotessh插件连接老版本服务器(如RHEL7,Centos7)时,插件会报错,无法连接。 查看插件的错误日志可以看到类似如下的报错信息: dc96b837cf6bb4af9cd736aa3af08cf8279f7685/node: /li…...
RestClient
什么是RestClient RestClient 是 Elasticsearch 官方提供的 Java 低级 REST 客户端,它允许HTTP与Elasticsearch 集群通信,而无需处理 JSON 序列化/反序列化等底层细节。它是 Elasticsearch Java API 客户端的基础。 RestClient 主要特点 轻量级ÿ…...
挑战杯推荐项目
“人工智能”创意赛 - 智能艺术创作助手:借助大模型技术,开发能根据用户输入的主题、风格等要求,生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用,帮助艺术家和创意爱好者激发创意、提高创作效率。 - 个性化梦境…...
【网络】每天掌握一个Linux命令 - iftop
在Linux系统中,iftop是网络管理的得力助手,能实时监控网络流量、连接情况等,帮助排查网络异常。接下来从多方面详细介绍它。 目录 【网络】每天掌握一个Linux命令 - iftop工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景…...
反向工程与模型迁移:打造未来商品详情API的可持续创新体系
在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...
练习(含atoi的模拟实现,自定义类型等练习)
一、结构体大小的计算及位段 (结构体大小计算及位段 详解请看:自定义类型:结构体进阶-CSDN博客) 1.在32位系统环境,编译选项为4字节对齐,那么sizeof(A)和sizeof(B)是多少? #pragma pack(4)st…...
Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)
概述 在 Swift 开发语言中,各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过,在涉及到多个子类派生于基类进行多态模拟的场景下,…...
UE5 学习系列(三)创建和移动物体
这篇博客是该系列的第三篇,是在之前两篇博客的基础上展开,主要介绍如何在操作界面中创建和拖动物体,这篇博客跟随的视频链接如下: B 站视频:s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...
【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)
1.获取 authorizationCode: 2.利用 authorizationCode 获取 accessToken:文档中心 3.获取手机:文档中心 4.获取昵称头像:文档中心 首先创建 request 若要获取手机号,scope必填 phone,permissions 必填 …...
SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)
上一章用到了V2 的概念,其实 Fiori当中还有 V4,咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务),代理中间件(ui5-middleware-simpleproxy)-CSDN博客…...
基于Java+MySQL实现(GUI)客户管理系统
客户资料管理系统的设计与实现 第一章 需求分析 1.1 需求总体介绍 本项目为了方便维护客户信息为了方便维护客户信息,对客户进行统一管理,可以把所有客户信息录入系统,进行维护和统计功能。可通过文件的方式保存相关录入数据,对…...
