当前位置: 首页 > news >正文

paddle2.3-基于联邦学习实现FedAVg算法-CNN

目录

1. 联邦学习介绍

2. 实验流程

3. 数据加载

4. 模型构建

5. 数据采样函数

6. 模型训练


1. 联邦学习介绍

联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。

联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。

本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区

2. 实验流程

联邦学习的基本流程是:

1. server初始化模型参数,所有的clients将这个初始模型下载到本地;

2. clients利用本地产生的数据进行SGD训练;

3. 选取K个clients将训练得到的模型参数上传到server;

4. server对得到的模型参数整合,所有的clients下载新的模型。

5. 重复执行2-5,直至收敛或达到预期要求

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F

3. 数据加载

mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing

4. 模型构建

class CNN(nn.Layer):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Conv2D(1,32,5)self.relu = nn.ReLU()self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)self.conv2=nn.Conv2D(32,64,5)self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)self.fc1=nn.Linear(1024,512)self.fc2=nn.Linear(512,10)# self.softmax = nn.Softmax()def forward(self,inputs):x = self.conv1(inputs)x = self.relu(x)x = self.pool1(x)x = self.conv2(x)x = self.relu(x)x = self.pool2(x)x=paddle.reshape(x,[-1,1024])x = self.relu(self.fc1(x))y = self.fc2(x)return y

5. 数据采样函数

# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):num_items_per_client = int(len(dataset)/clients)client_dict = {}image_idxs = [i for i in range(len(dataset))]for i in range(clients):client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除client_dict[i] = list(client_dict[i])return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):shard_idxs = [i for i in range(total_shards)]client_dict = {i: np.array([], dtype='int64') for i in range(clients)}idxs = np.arange(len(dataset))data_labels = Labellabel_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠label_idxs = label_idxs[:, label_idxs[1,:].argsort()]idxs = label_idxs[0,:]for i in range(clients):rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) shard_idxs = list(set(shard_idxs) - rand_set)for rand in rand_set:client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接return client_dict

class MNISTDataset(Dataset):def __init__(self, data,label):self.data = dataself.label = labeldef __getitem__(self, idx):image=np.array(self.data[idx]).astype('float32')image=np.reshape(image,[1,28,28])label=np.array(self.label[idx]).astype('int64')return image, labeldef __len__(self):return len(self.label)

6. 模型训练

class ClientUpdate(object):def __init__(self, data, label, batch_size, learning_rate, epochs):dataset = MNISTDataset(data,label)self.train_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,drop_last=True)self.learning_rate = learning_rateself.epochs = epochsdef train(self, model):optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())criterion = nn.CrossEntropyLoss(reduction='mean')model.train()e_loss = []for epoch in range(1,self.epochs+1):train_loss = []for image,label in self.train_loader:# image=paddle.to_tensor(image)# label=paddle.to_tensor(label.reshape([label.shape[0],1]))output=model(image)loss= criterion(output,label)# print(loss)loss.backward()optimizer.step()optimizer.clear_grad()train_loss.append(loss.numpy()[0])t_loss=sum(train_loss)/len(train_loss)e_loss.append(t_loss)total_loss=sum(e_loss)/len(e_loss)return model.state_dict(), total_loss

train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):global_weights = model.state_dict()train_loss = []start = time.time()# clients与server之间通信for curr_round in range(1, rounds+1):w, local_loss = [], []m = max(int(C*K), 1) # 随机选取参与更新的clientsS_t = np.random.choice(range(K), m, replace=False)for k in S_t:# print(data_dict[k])sub_data = ds[data_dict[k]]sub_y = L[data_dict[k]]local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)weights, loss = local_update.train(model)w.append(weights)local_loss.append(loss)# 更新global weightsweights_avg = w[0]for k in weights_avg.keys():for i in range(1, len(w)):# weights_avg[k] += (num[i]/sum(num))*w[i][k]weights_avg[k]=weights_avg[k]+w[i][k]   weights_avg[k]=weights_avg[k]/len(w)global_weights[k].set_value(weights_avg[k])# global_weights = weights_avg# print(global_weights)#模型加载最新的参数model.load_dict(global_weights)loss_avg = sum(local_loss) / len(local_loss)if curr_round % 10 == 0:print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))train_loss.append(loss_avg)end = time.time()fig, ax = plt.subplots()x_axis = np.arange(1, rounds+1)y_axis = np.array(train_loss)ax.plot(x_axis, y_axis, 'tab:'+plt_color)ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)ax.grid()fig.savefig(plt_title+'.jpg', format='jpg')print("Training Done!")print("Total time taken to Train: {}".format(end-start))return model.state_dict()#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")

Round: 10... 	Average Loss: [0.024]
Round: 20... 	Average Loss: [0.015]
Round: 30... 	Average Loss: [0.008]
Round: 40... 	Average Loss: [0.003]
Round: 50... 	Average Loss: [0.004]
Round: 60... 	Average Loss: [0.002]
Round: 70... 	Average Loss: [0.002]
Round: 80... 	Average Loss: [0.002]
Round: 90... 	Average Loss: [0.001]
Round: 100... 	Average Loss: [0.]
Training Done!
Total time taken to Train: 759.6239657402039

相关文章:

paddle2.3-基于联邦学习实现FedAVg算法-CNN

目录 1. 联邦学习介绍 2. 实验流程 3. 数据加载 4. 模型构建 5. 数据采样函数 6. 模型训练 1. 联邦学习介绍 联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备&#…...

nuiapp保存canvas绘图

要保存一个 Canvas 绘图,可以使用以下步骤: 获取 Canvas 元素和其绘图上下文: var canvas document.getElementById("myCanvas"); var ctx canvas.getContext("2d");使用 Canvas 绘图 API 绘制图形。 使用 toDataUR…...

Object.defineProperty()方法详解,了解vue2的数据代理

假期第一篇,对于基础的知识点,我感觉自己还是很薄弱的。 趁着假期,再去复习一遍 Object.defineProperty(),对于这个方法,更多的还是停留在面试的时候,面试官问你vue2和vue3区别的时候,不免要提一提这个方法…...

Linux 磁盘管理

Linux 系统的磁盘管理直接关系到整个系统的性能表现。磁盘管理常用三个命令为: df、du 和 fdisk。 df df(英文全称:disk free)。df 命令用于显示磁盘空间的使用情况,包括文件系统的挂载点、总容量、已用空间、可用空间…...

大数据与人工智能的未来已来

大数据与人工智能的定义 大数据: 大数据指的是规模庞大、复杂性高、多样性丰富的数据集合。这些数据通常无法通过传统的数据库管理工具来捕获、存储、管理和处理。大数据的特点包括"3V": 大量(Volume):大数…...

【AI视野·今日Robot 机器人论文速览 第四十一期】Tue, 26 Sep 2023

AI视野今日CS.Robotics 机器人学论文速览 Tue, 26 Sep 2023 Totally 73 papers 👉上期速览✈更多精彩请移步主页 Daily Robotics Papers Extreme Parkour with Legged Robots Authors Xuxin Cheng, Kexin Shi, Ananye Agarwal, Deepak Pathak人类可以通过以高度动态…...

[NOIP2012 提高组] 开车旅行

[NOIP2012 提高组] 开车旅行 题目描述 小 A \text{A} A 和小 B \text{B} B 决定利用假期外出旅行,他们将想去的城市从 $1 $ 到 n n n 编号,且编号较小的城市在编号较大的城市的西边,已知各个城市的海拔高度互不相同,记城市 …...

数据库设计流程---以案例熟悉

案例名字:宠物商店系统 课程来源:点击跳转 信息->概念模型->数据模型->数据库结构模型 将现实世界中的信息转换为信息世界的概念模型(E-R模型) 业务逻辑 构建 E-R 图 确定三个实体:用户、商品、订单...

Miniconda创建paddlepaddle环境

1、conda env list 2、conda create --name paddle_env python3.8 --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 3、activate paddle_env 4、python -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple 5、pip install "p…...

postgresql实现单主单从

实现步骤 1.主库创建一个有复制权限的用户 CREATE ROLE 用户名login # 有登录权限的角色即是用户replication #复制权限 encrypted password 密码;2.主库配置开放从库外部访问权限 修改 pg_hba.conf 文件 (相当于开放防火墙) # 类型 数据库 …...

提取PDF数据:Documents for PDF ( GcPdf )

在当今数据驱动的世界中,从 PDF 文档中无缝提取结构化表格数据已成为开发人员的一项关键任务。借助GrapeCity Documents for PDF ( GcPdf ),您可以使用 C# 以编程方式轻松解锁这些 PDF 中隐藏的信息宝藏。 考虑一下 PDF(最常用的文档格式之一…...

adb连接切换到模拟器端口

查看连接状态 adb devices出现以下情况 C:\Users\22560>adb devices List of devices attached 127.0.0.1:5555 offline emulator-5554 device可以发现我们想要连接的雷电模拟器的5555端口目前没有连接,只有emulator-5554被连接了,此时我们需要关…...

为何每个开发者都在谈论Go?

目录 一、引言Go的历史回顾关键时间节点 使用场景Go的语言地位技术社群与企业支持资源投入和生态系统 二、简洁的语法结构基本组成元素变量声明与初始化代码示例 类型推断函数与返回值代码示例输出 接口与结构体:组合而非继承错误处理:明确而不是异常小结…...

【Leetcode】 501. 二叉搜索树中的众数

给你一个含重复值的二叉搜索树(BST)的根节点 root ,找出并返回 BST 中的所有 众数(即,出现频率最高的元素)。 如果树中有不止一个众数,可以按 任意顺序 返回。 假定 BST 满足如下定义&#xf…...

怎样给Ubuntu系统安装vmware-tools

首先我要告诉你:Ubuntu无法安装vmware-tools,之所以这么些是因为我一开始也是这样认为的,vmware-tools是给Windows系统准备的我认为,毕竟Windows占有率远远高于Linux,这也可以理解。 那么怎么样实现Ubuntu虚拟机跟Wind…...

DDS信号发生器波形发生器VHDL

名称:DDS信号发生器波形发生器 软件:Quartus 语言:VHDL 要求: 在EDA平台中使用VHDL语言为工具,设计一个常见信号发生电路,要求: 1. 能够产生锯齿波,方波,三角波&…...

Python3操作SQLite3创建表主键自增长|CRUD基本操作

Win11查看安装的Python路径及安装的库 Python PEP8 代码规范常见问题及解决方案 Python3操作MySQL8.XX创建表|CRUD基本操作 Python3操作SQLite3创建表主键自增长|CRUD基本操作 anaconda3最新版安装|使用详情|Error: Please select a valid Python interpreter Python函数绘…...

B. Comparison String

题目&#xff1a; 样例&#xff1a; 输入 4 4 <<>> 4 >><< 5 >>>>> 7 <><><><输出 3 3 6 2 思路&#xff1a; 由题意&#xff0c;条件是 又因为要使用尽可能少的数字&#xff0c;这是一道贪心题&#xff0c;所以…...

python端口扫描

扫描所有端口 import socket, threading, os, timedef port_thread(ip, start, step, timeout):for port in range(start, start step):s socket.socket()s.settimeout(timeout)try:s.connect((ip, port))print(f"port[{port}] 可用")except Exception as e:# pri…...

国庆第二天

#include<th.h>#define ERR_MSG(msg) do{\fprintf(stderr,"__%d__",__LINE__);\perror(msg);\ }while(0)#define PORT 6666 #define IP "192.168.2.3"//键盘输入事件 int serverkeyboard(fd_set readfds) {char buf[128] "";int sndfd -…...

零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?

一、核心优势&#xff1a;专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发&#xff0c;是一款收费低廉但功能全面的Windows NAS工具&#xff0c;主打“无学习成本部署” 。与其他NAS软件相比&#xff0c;其优势在于&#xff1a; 无需硬件改造&#xff1a;将任意W…...

基于Flask实现的医疗保险欺诈识别监测模型

基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施&#xff0c;由雇主和个人按一定比例缴纳保险费&#xff0c;建立社会医疗保险基金&#xff0c;支付雇员医疗费用的一种医疗保险制度&#xff0c; 它是促进社会文明和进步的…...

使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装

以下是基于 vant-ui&#xff08;适配 Vue2 版本 &#xff09;实现截图中照片上传预览、删除功能&#xff0c;并封装成可复用组件的完整代码&#xff0c;包含样式和逻辑实现&#xff0c;可直接在 Vue2 项目中使用&#xff1a; 1. 封装的图片上传组件 ImageUploader.vue <te…...

Web后端基础(基础知识)

BS架构&#xff1a;Browser/Server&#xff0c;浏览器/服务器架构模式。客户端只需要浏览器&#xff0c;应用程序的逻辑和数据都存储在服务端。 优点&#xff1a;维护方便缺点&#xff1a;体验一般 CS架构&#xff1a;Client/Server&#xff0c;客户端/服务器架构模式。需要单独…...

【把数组变成一棵树】有序数组秒变平衡BST,原来可以这么优雅!

【把数组变成一棵树】有序数组秒变平衡BST,原来可以这么优雅! 🌱 前言:一棵树的浪漫,从数组开始说起 程序员的世界里,数组是最常见的基本结构之一,几乎每种语言、每种算法都少不了它。可你有没有想过,一组看似“线性排列”的有序数组,竟然可以**“长”成一棵平衡的二…...

node.js的初步学习

那什么是node.js呢&#xff1f; 和JavaScript又是什么关系呢&#xff1f; node.js 提供了 JavaScript的运行环境。当JavaScript作为后端开发语言来说&#xff0c; 需要在node.js的环境上进行当JavaScript作为前端开发语言来说&#xff0c;需要在浏览器的环境上进行 Node.js 可…...

数据库正常,但后端收不到数据原因及解决

从代码和日志来看&#xff0c;后端SQL查询确实返回了数据&#xff0c;但最终user对象却为null。这表明查询结果没有正确映射到User对象上。 在前后端分离&#xff0c;并且ai辅助开发的时候&#xff0c;很容易出现前后端变量名不一致情况&#xff0c;还不报错&#xff0c;只是单…...

Xcode 16 集成 cocoapods 报错

基于 Xcode 16 新建工程项目&#xff0c;集成 cocoapods 执行 pod init 报错 ### Error RuntimeError - PBXGroup attempted to initialize an object with unknown ISA PBXFileSystemSynchronizedRootGroup from attributes: {"isa">"PBXFileSystemSynchro…...

UE5 音效系统

一.音效管理 音乐一般都是WAV,创建一个背景音乐类SoudClass,一个音效类SoundClass。所有的音乐都分为这两个类。再创建一个总音乐类&#xff0c;将上述两个作为它的子类。 接着我们创建一个音乐混合类SoundMix&#xff0c;将上述三个类翻入其中&#xff0c;通过它管理每个音乐…...

SQL进阶之旅 Day 22:批处理与游标优化

【SQL进阶之旅 Day 22】批处理与游标优化 文章简述&#xff08;300字左右&#xff09; 在数据库开发中&#xff0c;面对大量数据的处理任务时&#xff0c;单条SQL语句往往无法满足性能需求。本篇文章聚焦“批处理与游标优化”&#xff0c;深入探讨如何通过批量操作和游标技术提…...