paddle2.3-基于联邦学习实现FedAVg算法
目录
1. 联邦学习介绍
2. 实验流程
3. 数据加载
4. 模型构建
5. 数据采样函数
6. 模型训练

1. 联邦学习介绍
联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。
联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。
本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区
2. 实验流程
联邦学习的基本流程是:
1. server初始化模型参数,所有的clients将这个初始模型下载到本地;
2. clients利用本地产生的数据进行SGD训练;
3. 选取K个clients将训练得到的模型参数上传到server;
4. server对得到的模型参数整合,所有的clients下载新的模型。
5. 重复执行2-5,直至收敛或达到预期要求
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F
3. 数据加载
mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing
4. 模型构建
class CNN(nn.Layer):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Conv2D(1,32,5)self.relu = nn.ReLU()self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)self.conv2=nn.Conv2D(32,64,5)self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)self.fc1=nn.Linear(1024,512)self.fc2=nn.Linear(512,10)# self.softmax = nn.Softmax()def forward(self,inputs):x = self.conv1(inputs)x = self.relu(x)x = self.pool1(x)x = self.conv2(x)x = self.relu(x)x = self.pool2(x)x=paddle.reshape(x,[-1,1024])x = self.relu(self.fc1(x))y = self.fc2(x)return y
5. 数据采样函数
# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):num_items_per_client = int(len(dataset)/clients)client_dict = {}image_idxs = [i for i in range(len(dataset))]for i in range(clients):client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除client_dict[i] = list(client_dict[i])return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):shard_idxs = [i for i in range(total_shards)]client_dict = {i: np.array([], dtype='int64') for i in range(clients)}idxs = np.arange(len(dataset))data_labels = Labellabel_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠label_idxs = label_idxs[:, label_idxs[1,:].argsort()]idxs = label_idxs[0,:]for i in range(clients):rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) shard_idxs = list(set(shard_idxs) - rand_set)for rand in rand_set:client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接return client_dict
class MNISTDataset(Dataset):def __init__(self, data,label):self.data = dataself.label = labeldef __getitem__(self, idx):image=np.array(self.data[idx]).astype('float32')image=np.reshape(image,[1,28,28])label=np.array(self.label[idx]).astype('int64')return image, labeldef __len__(self):return len(self.label)
6. 模型训练
class ClientUpdate(object):def __init__(self, data, label, batch_size, learning_rate, epochs):dataset = MNISTDataset(data,label)self.train_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,drop_last=True)self.learning_rate = learning_rateself.epochs = epochsdef train(self, model):optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())criterion = nn.CrossEntropyLoss(reduction='mean')model.train()e_loss = []for epoch in range(1,self.epochs+1):train_loss = []for image,label in self.train_loader:# image=paddle.to_tensor(image)# label=paddle.to_tensor(label.reshape([label.shape[0],1]))output=model(image)loss= criterion(output,label)# print(loss)loss.backward()optimizer.step()optimizer.clear_grad()train_loss.append(loss.numpy()[0])t_loss=sum(train_loss)/len(train_loss)e_loss.append(t_loss)total_loss=sum(e_loss)/len(e_loss)return model.state_dict(), total_loss
train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):global_weights = model.state_dict()train_loss = []start = time.time()# clients与server之间通信for curr_round in range(1, rounds+1):w, local_loss = [], []m = max(int(C*K), 1) # 随机选取参与更新的clientsS_t = np.random.choice(range(K), m, replace=False)for k in S_t:# print(data_dict[k])sub_data = ds[data_dict[k]]sub_y = L[data_dict[k]]local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)weights, loss = local_update.train(model)w.append(weights)local_loss.append(loss)# 更新global weightsweights_avg = w[0]for k in weights_avg.keys():for i in range(1, len(w)):# weights_avg[k] += (num[i]/sum(num))*w[i][k]weights_avg[k]=weights_avg[k]+w[i][k] weights_avg[k]=weights_avg[k]/len(w)global_weights[k].set_value(weights_avg[k])# global_weights = weights_avg# print(global_weights)#模型加载最新的参数model.load_dict(global_weights)loss_avg = sum(local_loss) / len(local_loss)if curr_round % 10 == 0:print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))train_loss.append(loss_avg)end = time.time()fig, ax = plt.subplots()x_axis = np.arange(1, rounds+1)y_axis = np.array(train_loss)ax.plot(x_axis, y_axis, 'tab:'+plt_color)ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)ax.grid()fig.savefig(plt_title+'.jpg', format='jpg')print("Training Done!")print("Total time taken to Train: {}".format(end-start))return model.state_dict()#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")
W0605 23:22:00.961916 10307 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0605 23:22:00.966121 10307 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.Round: 10... Average Loss: 0.033
Round: 20... Average Loss: 0.011
Round: 30... Average Loss: 0.012
Round: 40... Average Loss: 0.008
Round: 50... Average Loss: 0.003
Round: 60... Average Loss: 0.002
Round: 70... Average Loss: 0.001
Round: 80... Average Loss: 0.001
Round: 90... Average Loss: 0.001
相关文章:
paddle2.3-基于联邦学习实现FedAVg算法
目录 1. 联邦学习介绍 2. 实验流程 3. 数据加载 4. 模型构建 5. 数据采样函数 6. 模型训练 1. 联邦学习介绍 联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备&#…...
伺服丝杠系统常用运算功能块
这篇博客主要介绍伺服、丝杠系统常用的运算功能块,其它相关运算可以查看下面文章链接: 信捷PLC脉冲频率、位移、转速相关计算(C语言编程应用)_RXXW_Dor的博客-CSDN博客里工业控制张力控制无处不在,也衍生出很多张力控制专用控制器,磁粉制动器等,本篇博客主要讨论PLC的张力…...
【Vue】模板语法,事件处理器及综合案例、自定义组件、组件通信
一、事件处理器 我们之前事件监听可以使用v-on 指令 1、事件修饰符 在Vue中我们通过由点(.)表示的指令后缀来调用修饰符,比如: .stop:阻止事件冒泡。当事件触发时,该修饰符将停止事件进一步冒泡到父元素。相当于调用了 event.stop…...
从0开始写中国象棋-创建棋盘与棋子
从控制台版本开始 考虑到象棋程序,其实就是数据结构与算法实现。 所以和界面相关的QT部分我们先放一放。 我们从控制台版本开始。这样大家更容易接受,也不影响开发。 后面我们会把控制台嫁接到QT上完成完整的游戏,那时候自然就水到渠成了…...
软件的开发步骤,需求分析,开发环境搭建,接口文档 ---苍穹外卖1
目录 项目总览 开发准备 开发步骤 角色分工 软件环境 项目介绍 产品原型 技术选型 开发环境搭建 前端:默认已有 后端 使用Git版本控制 数据库环境搭建 前后端联调 登录功能完善 导入接口文档 使用swagger 和yapi的区别 常用注解 项目总览 开发准备 开发步骤…...
Qt扫盲-QSqlQueryModel理论总结
QSqlQueryModel理论总结 一、概述二、使用1. 与 view 视图 绑定2. 分离视图,只存数据 一、概述 QSqlQueryModel是用于执行SQL语句和遍历结果集的高级接口。它构建在较低级的 QSqlQuery之上,可用于向QTableView 等视图类提供数据,也是使用了Q…...
分类预测 | MATLAB实现WOA-CNN-BiGRU-Attention数据分类预测(SE注意力机制)
分类预测 | MATLAB实现WOA-CNN-BiGRU-Attention数据分类预测(SE注意力机制) 目录 分类预测 | MATLAB实现WOA-CNN-BiGRU-Attention数据分类预测(SE注意力机制)分类效果基本描述模型描述程序设计参考资料 分类效果 基本描述 1.MATLA…...
【单调栈】496. 下一个更大元素 I
496. 下一个更大元素 I 解题思路 首先计算nums2的每一个元素的下一个比他大的元素,使用单调栈将上面的结果和nums2中的每一个元素组成映射map针对每一个Nums1的元素 查询map 记录map 的value class Solution {public int[] nextGreaterElement(int[] nums1, int[…...
搭建Vue的开发环境,Edge浏览器安装VUE拓展工具
一、在下载vue.js文件 在vue官网中下载开发版本的vue.js文件--> 安装 — Vue.js (vuejs.org) 二、将vue.js导入到项目中 这时候我们运行项目控制台会抛出两个错误 三、安装拓展工具 这里以Edge浏览器为例,其他浏览器上可在拓展管理商店中下载 进入Edge的拓展…...
14:00面试,14:06就出来了,这问的谁顶得住啊
从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到8月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%,…...
python 探索分形世界|曼德布洛特|np.frompyfunc()
文章目录 分形的重要特征曼德布洛特集合曼德布洛特集合有一个以证明的结论:图像展示np.ogrid[]np.frompyfunc()集合转图像 julia集合 无边的奇迹源自简单规则的无限重复 ---- 分形之父Benoit B.Mandelbrot 分形的重要特征 自相似性无标度性非线性 曼德布洛特集合…...
Android MVVM示例项目
项目地址 GitHub - yaolunwei/Androidbbc at androidx 项目简介 BBC(基础业务组件) 业务组件的基础,所有业务组件必须基于该组件进行开发,提供一站式开发 快速使用 gradle: implementation com.bigoat.android:bbc:0.0.1 约定成俗 以下继承关系…...
迅为龙芯2K1000开发板通过汇编控制GPIO
上一小节,我们使用了 C 语言控制了 gpio,这一小节我们来看一下如何使用汇编来控制 gpio 呢?有的 同学可能会有疑问了,既然我们可以使用 C 语言来控制 gpio,为什么我们还要使用更底层的汇编语言呢, 如果我…...
合合信息、上海大学、华南理工大学发布业内首个古彝文编码“大字典” ,为古文字打造“身份证”
“乌蒙山连着山外山,月光洒向了响水滩。”近期在各大短视频平台爆火的《奢香夫人》你听过吗?奢香夫人是一位彝族“巾帼英雄”,这首同名歌曲早在2009年便已发布,如今突然“翻红”,不仅体现了大众对于少数民族文化高涨的…...
Django — 类视图和中间件
目录 一、类视图1、基于类的结构2、常见的类视图基类3、类视图的优点4、代码案例 二、中间件1、定义2、工作原理3、自带中间件4、中间件开发流程5、自定义中间件6、案例 一、类视图 类视图(Class-Based Views)是 Django 中用于处理 HTTP 请求和生成 HTT…...
VMware安装CentOS Stream 8以及JDK和Docker
一、下载镜像源 地址:https://developer.aliyun.com/mirror/?spma2c6h.25603864.0.0.285b32d48O2G8Y 二、安装配置 配置项 一共有以下这些,其中软件、软件选择 、安装目的地、网络主机名需要讲一下,其他都简单,自行设置即可。 …...
MySQL作业1
目录 一.创建一张表,包含以下所有数据类型 建表:编辑 二.使用以下六种约束 1.非空约束 2.唯一约束 3.主键约束 4.外键约束 5.检查约束 6.默认值约束 一.创建一张表,包含以下所有数据类型 Text 类型: Number 类型&#…...
基于微信小程序的家校通系统设计与实现(亮点:选题新颖、上传作业、批改作业、成绩统计)
文章目录 前言运行环境说明家长微信小程序端的主要功能有:教师微信小程序端的主要功能有:管理员的主要功能有:具体实现截图详细视频演示为什么选择我自己的网站自己的小程序(小蔡coding)有保障的售后福利 代码参考论文…...
uni-app问题记录
一、启动问题记录 1. 报错1 解决办法: 开启微信开发者工具服务端口 2. 报错2:调用getLocation获取位置信息时报错以下内容 {errMsg: “getLocation:fail the api need to be declared in the requiredPrivateInfos field in app.json/ext.json”} 解决办法: manifest.json文…...
Leetcode---363周赛
题目列表 2859. 计算 K 置位下标对应元素的和 2860. 让所有学生保持开心的分组方法数 2861. 最大合金数 2862. 完全子集的最大元素和 一、计算k置为下标对应元素的和 简单题,直接暴力模拟,代码如下 class Solution { public:int sumIndicesWithKS…...
linux之kylin系统nginx的安装
一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源(HTML/CSS/图片等),响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址,提高安全性 3.负载均衡服务器 支持多种策略分发流量…...
《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》
引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...
MongoDB学习和应用(高效的非关系型数据库)
一丶 MongoDB简介 对于社交类软件的功能,我们需要对它的功能特点进行分析: 数据量会随着用户数增大而增大读多写少价值较低非好友看不到其动态信息地理位置的查询… 针对以上特点进行分析各大存储工具: mysql:关系型数据库&am…...
STM32标准库-DMA直接存储器存取
文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA(Direct Memory Access)直接存储器存取 DMA可以提供外设…...
c++ 面试题(1)-----深度优先搜索(DFS)实现
操作系统:ubuntu22.04 IDE:Visual Studio Code 编程语言:C11 题目描述 地上有一个 m 行 n 列的方格,从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子,但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...
如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...
Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...
从零开始打造 OpenSTLinux 6.6 Yocto 系统(基于STM32CubeMX)(九)
设备树移植 和uboot设备树修改的内容同步到kernel将设备树stm32mp157d-stm32mp157daa1-mx.dts复制到内核源码目录下 源码修改及编译 修改arch/arm/boot/dts/st/Makefile,新增设备树编译 stm32mp157f-ev1-m4-examples.dtb \stm32mp157d-stm32mp157daa1-mx.dtb修改…...
AI病理诊断七剑下天山,医疗未来触手可及
一、病理诊断困局:刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断",医生需通过显微镜观察组织切片,在细胞迷宫中捕捉癌变信号。某省病理质控报告显示,基层医院误诊率达12%-15%,专家会诊…...
基于SpringBoot在线拍卖系统的设计和实现
摘 要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 在线拍卖系统,主要的模块包括管理员;首页、个人中心、用户管理、商品类型管理、拍卖商品管理、历史竞拍管理、竞拍订单…...
