当前位置: 首页 > news >正文

【联邦学习——手动搭建简易联邦学习】

1. 目的

用于记录自己在手写联邦学习相关实验时碰到的一些问题,方便自己进行回顾。

2. 代码

2.1 本地模型计算梯度更新

# 比较训练前后的参数变化
def compare_weights(new_model, old_model):weight_updates = {}for layer_name, params in new_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)return weight_updates

测试代码如下:
有意思的点在于我获得了update = model2-model1
但是我去计算model1+update==model2的时候发现不相等
最后思考了一下可能是在这个计算的过程中存在精度的丢失

import torch
from torch import nn
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoaderdef weight_init(m):if isinstance(m, nn.Linear):nn.init.xavier_normal_(m.weight)nn.init.constant_(m.bias, 0)# 也可以判断是否为conv2d,使用相应的初始化方式elif isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')# 是否为批归一化层elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if __name__ == '__main__':# 建立model1和model2作为训练前后的模型model1 = models.get_model("resnet18")model1.apply(weight_init)model2 = models.get_model("resnet18")model2.apply(weight_init)weight_updates = compare_weights(model2,model1)# 创建一个临时的模型状态字典,用于存储更新后的参数updated_params = model1.state_dict().copy()for layer_name, update in weight_updates.items():# 确保该层存在于model1中且形状匹配,避免错误if layer_name in updated_params and update.shape == updated_params[layer_name].shape:# 直接相加更新参数updated_params[layer_name] += updateelse:print(f"Warning: Layer {layer_name} not found or shape mismatch, skipping update.")# 将更新后的参数加载回model1model1.load_state_dict(updated_params)for layer_name,params in model1.state_dict().items():ts = params-model2.state_dict().get(layer_name)# 很重要,这里应该是我们在做的时候会有一点模型精度上的损失,所以不能够计算这里等于0if torch.sum(ts).item()>1e-6:print(f"{layer_name}更新后与原有的不匹配,差距为")else:print(f"{layer_name}更新后与原有的匹配")

2.2 客户端代码

import numpy as np
import torch.utils.data
from tqdm import tqdm'''
conf 配置文件
model 模型
train_dataset 数据集
class_ratios 从数据集中筛选出一部分 class_ratios = {0: 0.5, 1: 0.5,..., 8: 0.5, 9: 0.5}
id 客户端的标识
'''class Client(object):def __init__(self,conf,model,device,train_loader,id=1):self.client_id = id                     # 客户端IDself.conf = conf                        # 配置文件self.local_model = model                # 客户端本地模型self.train_loader = train_loader        # 训练数据的迭代器,需要训练的数据已经在里面了self.grad_update = dict()               # 本地训练完之后的梯度更新self.weight = conf['weight']            # 全局模型梯度更新时的权重self.device = device                    # 训练的设备self.local_model.to(self.device)        # 将模型放入训练设备def train(self, model):self._before_train(model)self._local_train()self._after_train(model)def _before_train(self, model):self._load_global_model(model)# 用服务器模型来覆盖本地模型def _load_global_model(self,model):for name,param in model.state_dict().items():# 客户端首先用服务器端下发的全局模型覆盖本地模型self.local_model.state_dict()[name].copy_(param.clone())def _local_train(self):# 定义最优化函数器,用于本地模型训练optimizer = torch.optim.SGD(self.local_model.parameters(),lr=self.conf['lr'],momentum=self.conf['momentum'])# 本地模型训练self.local_model.train()loss = 0for epoch in range(self.conf['local_epochs']):for batch in tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.conf['local_epochs']}"):data, target = batch# 放入相应的设备data = data.to(self.device)target = target.to(self.device)# 梯度清零optimizer.zero_grad()output = self.local_model(data)loss = torch.nn.functional.cross_entropy(output, target)# 反向传播loss.backward()optimizer.step()print(f"Client{self.client_id}----Epoch {epoch} done.Loss {loss}")def _after_train(self,model):self._cal_update_weights(model)def _cal_update_weights(self, old_model):weight_updates = dict()for layer_name, params in self.local_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)# 更新梯度模型的权重self.grad_update = weight_updates

2.3 服务器代码

import torch.utils.data
import torchvision.datasets as datasets
from torchvision import models
from torchvision.transforms import transformsfrom utils.CommonUtils import copy_model_params# 服务端
class Server(object):def __init__(self, conf, eval_dataset, device):self.conf = conf# 全局老模型self.old_model = models.get_model(self.conf["model_name"])# 全局的新模型self.global_model = models.get_model(self.conf["model_name"])# 创建时保持新老模型的参数是一致的copy_model_params(self.old_model,self.global_model)# 根据客户端上传的梯度进行排列组合,用于测量贡献度的模型self.sub_model = models.get_model(self.conf["model_name"])self.eval_loader = torch.utils.data.DataLoader(eval_dataset,batch_size=self.conf["batch_size"],shuffle=True)self.accuracy_history = []  # 保存accuracy的数组self.loss_history = []  # 保存loss的数组self.device = deviceself.old_model.to(device)self.global_model.to(device)self.sub_model.to(device)# 模型重构def model_aggregate(self, clients, target_model):if target_model == self.global_model:print("++++++++全局模型更新++++++++")# 更新一下老模型参数copy_model_params(self.old_model,self.global_model)else:print("========子模型重构========")sum_weight = 0# 计算总的权重for client in clients:sum_weight += client.weight# 将old_model的模型参数赋值给sub_modelcopy_model_params(self.sub_model, self.old_model)# 初始化一个空字典来累积客户端的模型更新aggregated_updates = {}# 遍历每个客户端for client in clients:# 根据客户端的权重比例聚合更新for name, update in client.grad_update.items():if name not in aggregated_updates:aggregated_updates[name] = update * client.weight / sum_weightelse:aggregated_updates[name] += update * client.weight / sum_weight# 应用聚合后的更新到sub_modelfor name, param in target_model.state_dict().items():if name in aggregated_updates:param.copy_(param + aggregated_updates[name])  # 累加更新到当前层参数上# 定义模型评估函数def model_eval(self,target_model):target_model.eval()total_loss = 0.0correct = 0dataset_size = 0for batch_id,batch in enumerate(self.eval_loader):data,target = batchdataset_size += data.size()[0]# 放入和模型对应的设备data = data.to(self.device)target = target.to(self.device)# 模型预测output = target_model(data)# 把损失值聚合起来total_loss += torch.nn.functional.cross_entropy(output,target,reduction='sum').item()# 获取最大的对数概率的索引值pred = output.data.max(1)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()# 计算准确率acc = 100.0 * (float(correct) / float(dataset_size))# 计算损失值total_l = total_loss / dataset_size# 将accuracy和loss保存到数组中self.accuracy_history.append(acc)self.loss_history.append(total_l)if target_model == self.global_model:print(f"++++++++全局模型评估++++++++acc:{acc}  loss:{total_l}")else:print(f"========子模型评估========acc:{acc}  loss:{total_l}")return acc,total_ldef save_results_to_file(self):# 将accuracy和loss保存到文件中with open("fed_accuracy_history.txt", "w") as f:for acc in self.accuracy_history:f.write("{:.2f}\n".format(acc))with open("fed_loss_history.txt", "w") as f:for loss in self.loss_history:f.write("{:.4f}\n".format(loss))

2.4 Utils

def copy_model_params(target_model, source_model):for name, param in source_model.state_dict().items():target_model.state_dict()[name].copy_(param.clone())

3. 运行测试代码

import jsonimport torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision import models
from torchvision import transforms, datasets
from client.Client import Client
from server.Server import Serverwith open("../conf/client1.json",'r') as f:conf = json.load(f)with open("../conf/server1.json",'r') as f:serverConf = json.load(f)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.CIFAR10(root='../data/', train=True, download=True,transform=transform)
eval_dataset = datasets.CIFAR10(root='../data/', train=False, download=True,transform=transform)# train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, num_workers=2)# 计算数据集长度
total_samples = len(train_dataset)
# 确保可以平均分配,否则需要调整逻辑以处理余数
assert total_samples % 2 == 0, "数据集样本数需为偶数以便完全平分"# 分割点
split_point = total_samples // 2# 创建两个子集
train_dataset_first_half = Subset(train_dataset, range(0, split_point))
train_dataset_second_half = Subset(train_dataset, range(split_point, total_samples))# 然后为每个子集创建DataLoader
batch_size = 32train_loader_first_half = DataLoader(train_dataset_first_half, shuffle=True, batch_size=batch_size, num_workers=2)
train_loader_second_half = DataLoader(train_dataset_second_half, shuffle=True, batch_size=batch_size, num_workers=2)# 检查CUDA是否可用
if torch.cuda.is_available():device = torch.device("cuda")  # 如果CUDA可用,选择GPU
else:device = torch.device("cpu")   # 如果CUDA不可用,选择CPUlocal_model = models.get_model("resnet18")
local_model2 = models.get_model("resnet18")server = Server(serverConf,eval_dataset,device)
client1 = Client(conf, local_model, device,train_loader_first_half, 1)
client2 = Client(conf, local_model2, device,train_loader_second_half, 2)
for i in range(2):client1.train(server.global_model)client2.train(server.global_model)server.model_aggregate([client1,client2], server.global_model)server.model_eval(server.global_model)server.model_aggregate([client1], server.sub_model)server.model_eval(server.sub_model)server.model_aggregate([client2], server.sub_model)server.model_eval(server.sub_model)

相关文章:

【联邦学习——手动搭建简易联邦学习】

1. 目的 用于记录自己在手写联邦学习相关实验时碰到的一些问题,方便自己进行回顾。 2. 代码 2.1 本地模型计算梯度更新 # 比较训练前后的参数变化 def compare_weights(new_model, old_model):weight_updates {}for layer_name, params in new_model.state_dic…...

Springboot项目如何创建单元测试

文章目录 目录 文章目录 前言 一、SpringBoot单元测试的使用 1.1 引入依赖 1.2 创建单元测试类 二、Spring Boot使用Mockito进行单元测试 2.1 Mockito中经常使用的注解以及注解的作用 2.2 使用Mockito测试类中的方法 2.3 使用Mockito测试Controller层的方法 2.4 mock…...

Win10 如何同时保留两个CUDA版本并自由切换使用

环境: Win10 专业版 CUDA11.3 CUDA11.8 问题描述: Win10 如何同时保留两个CUDA版本并自由切换 解决方案: 在同一台计算机上安装两个CUDA版本并进行切换可以通过一些环境配置来实现。这通常涉及到管理环境变量,特别是PATH和L…...

实验室纳新宣讲会(java后端)

前言 这是陈旧已久的草稿2021-09-16 15:41:38 当时我进入实验室,也是大二了,实验室纳新需要宣讲, 但是当时有疫情,又没宣讲成。 现在2024-5-12 22:00:39,发布到[个人]专栏中。 实验室纳新宣讲会(java后…...

class常量池、运行时常量池和字符串常量池的关系

类常量池、运行时常量池和字符串常量池这三种常量池,在Java中扮演着不同但又相互关联的角色。理解它们之间的关系,有助于深入理解Java虚拟机(JVM)的内部工作机制,尤其是在类加载、内存分配和字符串处理方面。 类常量池…...

Java | Leetcode Java题解之第88题合并两个有序数组

题目: 题解: class Solution {public void merge(int[] nums1, int m, int[] nums2, int n) {int p1 m - 1, p2 n - 1;int tail m n - 1;int cur;while (p1 > 0 || p2 > 0) {if (p1 -1) {cur nums2[p2--];} else if (p2 -1) {cur nums1[p…...

韵搜坊(全栈)-- 前后端初始化

文章目录 前端初始化后端初始化 前端初始化 使用ant design of vue 组件库 官网快速上手:https://www.antdv.com/docs/vue/getting-started-cn 安装脚手架工具 进入cmd $ npm install -g vue/cli # OR $ yarn global add vue/cli创建一个项目 $ vue create ant…...

Android:资源的管理,Glide图片加载框架的使用

目录 一,Android资源分类 1.使用res目录下的资源 res目录下资源的使用: 2.使用assets目录下的资源 assets目录下的资源的使用: 二,glide图片加载框架 1.glide简介 2.下载和设置 3.基本用法 4.占位符(Placehold…...

conll-2012-formatted-ontonotes-5.0中文数据格式说明

CoNLL-2012 数据格式是用于自然语言处理任务的一种常见格式,特别是在命名实体识别、词性标注、句法分析和语义角色标注等领域。这种格式在 CoNLL-2012 共享任务中被广泛使用,该任务主要集中在语义角色标注上。 CoNLL-2012 数据格式通常包括多列&#xf…...

SpringBoot集成Seata分布式事务OpenFeign远程调用

Docker Desktop 安装Seata Server seata 本质上是一个服务,用docker安装更方便,配置默认:file docker run -d --name seata-server -p 8091:8091 -p 7091:7091 seataio/seata-server:2.0.0与SpringBoot集成 表结构 项目目录 dynamic和dyna…...

视觉检测系统,是否所有产品都可以进行视觉检测?

视觉检测系统作为一种先进的质检工具,虽然具有广泛的应用范围,但并非所有产品都适合进行视觉检测。本文将探讨视觉检测系统的适用范围及其局限性。 随着机器视觉技术的快速发展,视觉检测系统已广泛应用于各个行业,为产品质检提供…...

通过金山和微软虚拟打印机转换PDF文件,流程方法及优劣对比

文章目录 一、WPS/金山 PDF虚拟打印机1、常规流程2、PDF文件位置3、严重缺陷二、微软虚拟打印机Microsoft Print to Pdf1、安装流程2、微软虚拟打印机的优势一、WPS/金山 PDF虚拟打印机 1、常规流程 安装过WPS办公组件或金山PDF独立版的电脑,会有一个或两个WPS/金山 PDF虚拟…...

采用java+B/S开发的全套医院绩效考核系统源码springboot+mybaits 医院绩效考核系统优势

采用java开发的全套医院绩效考核系统源码springbootmybaits 医院绩效考核系统优势 医院绩效管理系统解决方案紧扣新医改形势下医院绩效管理的要求,以“工作量为基础的考核方案”为核心思想,结合患者满意度、服务质量、技术难度、工作效率、医德医风等管…...

驱动开发-用户空间和内核空间数据传输

1.用户空间-->内核空间&#xff08;写&#xff09; #include<linux/uaccess.h> int copy_from_user(void *to,const void __user volatile*from,unsigned long n) 函数功能&#xff1a;将用户空间数据拷贝到内核空间 参数&#xff1a; to&#xff1a;内核空间首地…...

【408精华知识】速看!各种排序的大总结!

文章目录 一、插入排序&#xff08;一&#xff09;直接插入排序&#xff08;二&#xff09;折半插入排序&#xff08;三&#xff09;希尔排序 二、交换排序&#xff08;一&#xff09;冒泡排序&#xff08;二&#xff09;快速排序 三、选择排序&#xff08;一&#xff09;简单选…...

【STM32 |程序实例】按键控制、光敏传感器控制蜂鸣器

目录 前言 按键控制LED 光敏传感器控制蜂鸣器 前言 上拉输入&#xff1a;若GPIO引脚配置为上拉输入模式&#xff0c;在默认情况下&#xff08;GPIO引脚无输入&#xff09;&#xff0c;读取的GPIO引脚数据为1&#xff0c;即高电平。 下拉输入&#xff1a;若GPIO引脚配置为下…...

Spring boot使用websocket实现在线聊天

maven依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactId>spr…...

品牌设计理念和logo设计方法

一 品牌设计的目的 设计是为了传播&#xff0c;让传播速度更快&#xff0c;传播效率更高&#xff0c;减少宣传成本 二 什么是好的品牌设计 好的设计是为了让消费者更容易看懂、记住的设计&#xff0c; 从而辅助传播&#xff0c; 即 看得懂、记得住。 1 看得懂 就是让别人看懂…...

Python | Leetcode Python题解之第88题合并两个有序数组

题目&#xff1a; 题解&#xff1a; class Solution:def merge(self, nums1: List[int], m: int, nums2: List[int], n: int) -> None:"""Do not return anything, modify nums1 in-place instead."""p1, p2 m - 1, n - 1tail m n - 1whi…...

vscode新版本remotessh服务端报`GLIBC_2.28‘ not found解决方案

问题现象 通过vscode的remotessh插件连接老版本服务器&#xff08;如RHEL7&#xff0c;Centos7&#xff09;时&#xff0c;插件会报错&#xff0c;无法连接。 查看插件的错误日志可以看到类似如下的报错信息&#xff1a; dc96b837cf6bb4af9cd736aa3af08cf8279f7685/node: /li…...

C#开发者必看:INIFileParser库解决INI配置文件乱码问题的实战指南

C#开发者必看&#xff1a;INIFileParser库解决INI配置文件乱码问题的实战指南 在Windows应用开发中&#xff0c;INI文件作为一种轻量级配置存储格式&#xff0c;至今仍被广泛使用。但许多C#开发者发现&#xff0c;当配置文件路径包含中文、空格或特殊字符时&#xff0c;传统的W…...

如何用Mask2Former实现通用图像分割:从单一模型到多任务实战

如何用Mask2Former实现通用图像分割&#xff1a;从单一模型到多任务实战 【免费下载链接】Mask2Former Code release for "Masked-attention Mask Transformer for Universal Image Segmentation" 项目地址: https://gitcode.com/gh_mirrors/ma/Mask2Former 图…...

不只是代码:用 Agent 构建全自动数据分析与可视化流水线

不只是代码:用 Agent 构建全自动数据分析与可视化流水线 1. 引入与连接:从"手动分析"到"智能自动化"的范式转变 1.1 一个数据分析师的典型一天 让我们想象一下,李明是一家互联网公司的资深数据分析师。每天早上9点,他都会坐在电脑前开始他熟悉的工作…...

C# 异步编程在 AI 应用中的最佳实践

一、引言 AI 应用开发中的异步需求 在当今的人工智能应用开发领域,异步编程已经成为不可或缺的核心技术。当我们与 AI 大模型进行交互时,网络请求的延迟、流式响应的处理、并发调用多个模型——这些场景无不对程序的响应能力和吞吐量提出了极高要求。传统的同步编程模式在面…...

如何用Fuel构建类型安全的GraphQL客户端:终极完整指南

如何用Fuel构建类型安全的GraphQL客户端&#xff1a;终极完整指南 【免费下载链接】fuel The easiest HTTP networking library for Kotlin/Android 项目地址: https://gitcode.com/gh_mirrors/fu/fuel Fuel是Kotlin/Android平台上最简单易用的HTTP网络库&#xff0c;它…...

快速构建SpringBoot微服务:Phi-3-mini智能代码生成与架构咨询

快速构建SpringBoot微服务&#xff1a;Phi-3-mini智能代码生成与架构咨询 1. 引言&#xff1a;当AI助手遇上Java开发 最近接手了一个新项目&#xff0c;需要快速搭建一套SpringBoot微服务架构。正当我对着空白的IDE发愁时&#xff0c;同事推荐了Phi-3-mini这个AI助手。说实话…...

Android Studio中文语言包快速上手实用指南

Android Studio中文语言包快速上手实用指南 【免费下载链接】AndroidStudioChineseLanguagePack AndroidStudio中文插件(官方修改版本&#xff09; 项目地址: https://gitcode.com/gh_mirrors/an/AndroidStudioChineseLanguagePack Android Studio作为主流的Android开发…...

一键部署:圣女司幼幽-造相Z-Turbo文生图模型,快速体验AI绘画魅力

一键部署&#xff1a;圣女司幼幽-造相Z-Turbo文生图模型&#xff0c;快速体验AI绘画魅力 1. 模型简介与特点 1.1 什么是圣女司幼幽-造相Z-Turbo&#xff1f; 圣女司幼幽-造相Z-Turbo是一款基于Xinference框架部署的AI绘画模型&#xff0c;专门用于生成《牧神记》中圣女司幼幽…...

如何用League Director制作电影级英雄联盟视频?6个专业技巧让你的录像脱颖而出

如何用League Director制作电影级英雄联盟视频&#xff1f;6个专业技巧让你的录像脱颖而出 【免费下载链接】leaguedirector League Director is a tool for staging and recording videos from League of Legends replays 项目地址: https://gitcode.com/gh_mirrors/le/leag…...

南北阁4.1-3B极简WebUI入门必看:无需React/Vue的纯Python前端方案

南北阁4.1-3B极简WebUI入门必看&#xff1a;无需React/Vue的纯Python前端方案 想给本地部署的南北阁&#xff08;Nanbeige&#xff09;4.1-3B大模型配一个好看又好用的聊天界面&#xff0c;是不是一想到要学React、Vue这些前端框架就头大&#xff1f;或者觉得Streamlit做出来的…...