当前位置: 首页 > news >正文

公平联邦学习——多目标优化

前言

前段时间接触到了联邦学习(Federated Learning, FL)。涉猎了几年多目标优化的我,惊奇地发现横向联邦学习里面也有用多目标优化来做的。于是有感而发,特此写一篇博客记录记录,如有机会可以和大家多多交流。遇到不专业的地方,欢迎大家来指正!

参考文章:FedMDFG: Federated Learning with Multi-Gradient Descent and Fair Guidance (AAAI-2023)

联邦学习背景

横向联邦学习它的基本原理很通俗易懂,就是想象一下这样的场景:

  • 有m个用户,原本每个用户都只用自己的数据来各顾各地进行本地的模型训练,这就叫做Individual Learning;

  • 而如果把所有用户的数据都上传到一处地方,然后在那儿用收集起来的所有数据进行模型训练,就叫做Centralized Learning;

  • 而现在用户不愿意分享数据了,但又怕只做本地模型训练得到的模型泛化性不好,想共同合作训练一个模型,这个时候,它们想到,把本地训练的结果/梯度上传到一个地方(称之为server),然后server收集到本地模型/梯度后,进行聚合运算,得到一个新的模型(称之为全局模型),并将其发给用户;用户基于收到这个新的全局模型来再进行本地训练,完后再回传给server进行聚合,如此往复。这样一来,用户就不需要share数据了,同时又能从这样的联合训练过程中获得一个泛化性更好的模型。

上述三种模型训练方式的对比如下图:

 

总结来说,generic FL的流程如下:

  1. 在第 t 轮通信时,Server把global model分发给各个clients。

  2. Clients收到模型后,进行本地训练,然后上传模型更新的gradients(或local updates)给server。

  3. Server基于收到的gradients(或local updates)更新global model。

  4. t=t+1 ,然后回到第1步,并循环若干轮直到达到终止条件。

公平联邦学习

因为联邦学习是一个涉及多个用户共同参与的过程,因此合作的公平性尤为重要。FL的公平性有很多种,其中最直观的就是performance fairness。简单来说,当一个模型在某些用户的数据上精度很高,但在其他用户上则表现很差,则这个模型是不公平的了,例如下图:

 

为了提升模型的performance fairness,我们的目标是去让模型在各个clients上的精度更平均一些。那怎么衡量公平性呢?我们很容易想到用模型在各个clients上的accuracy的标准差。但标准差不是scale-free的。为此,这里我们引入余弦相似度来衡量公平性:

 

导致模型不公平的直接原因

为了提升FL模型的公平性,前人提出了很多复杂的方法,例如降低模型的本地更新与全局更新的冲突;在模型聚合的时候提高效果较弱的用户的权重,等等。但究其根本,很容易想到,导致模型不公平有两个直接原因:

  • 使用了一个会加剧不公平的更新方向来更新模型;

  • 使用了一个不恰当的学习率来更新模型。

具体来说,如果用一个错误的更新方向,它对于某些用户而言不是梯度下降的,因此就会直接导致模型不公平。此外,哪怕更新方向能促进公平,但由于步长(学习率)太大了,也会破坏模型的公平性。

为此,FedMDFG这篇文章从多目标优化的角度来给出了提升模型公平性的方法。在我之前的多目标优化研究里面,通常是有若干个目标函数,然后并不是用加权聚合的方法把它转成单目标,而是用多目标优化算法,使得解逐渐趋于帕累托最优。帕累托最优的概念可以参考此链接:多目标优化。

而对于联邦学习而言,如果我们把模型参数ω看作是决策变量,每个用户的training loss(记为L)看作优化目标,那么就可以构建这样一个多目标优化问题:

 

由于这是一个连续变量的优化问题,我们可以用Multiple Gradient Descent Algorithm (MGDA)来求解出一个common descent direction来更新模型,因为该direction是common descent的,所以它能同时让模型在各个用户上的效果变好。

但“不让模型在某些用户上的效果变差”只是FL公平性的一个必要条件,只做到这一点,并不能使模型变得公平。举个简单的例子:假设有两个用户,它们的loss值分别是5和6,用common descent direction来更新模型后,两者的loss值变为4和1,显然,虽然两者的loss都下降了,但模型变得更加不公平了。

文章画了一个直观的图来展现什么样的更新方向才能促进模型公平性:

 

从图中可看出,三个用户的梯度g1, g2, g3相互冲突,中间黄色+灰色锥的表示所有满足common descent direction的向量组成的区域,而黄色锥中的更新方向既满足common descent,又能够避免上面所说的导致模型不公平的问题。可以直观地看到,前面的几种公平联邦学习算法,并不能确保更新方向是common descent的。因此它们不但会破坏模型的公平性,甚至可能还会出现收敛问题,尤其是在non-IID的场景下。

为了计算出一个公平的更新方向(即落在黄色区域内),文章在FL的多目标优化的formulation基础上,进一步引入了一个新的优化目标:

其中h是这样一个向量:

 

它是在clients的loss组成的向量L(ω)上作一个垂直的向量,并进行归一化所得。如下图所示:

 

同时优化这个目标,就能让L(ω)在偏离公平引导向量p=(1,1,...,1)的时候,将它拉回去,促进了模型的公平性。

因此,文章的优化目标变成:

 

至此,我们就可以用MGDA来求解上述多目标优化问题:

  • 首先创建一个矩阵Q,它由上述的m+1个目标函数的梯度拼接而成。

  • 然后求解下述dual problem得到λ:

 

  • 最后模型更新方向 就等于:

 

根据对偶理论,上述优化问题是下面这个优化问题的dual problem。

 

因此,上面求解得到的 满足:

  1. 如果 已经是Pareto critical,则 ;

  2. 如果 还不是是Pareto critical,则:

 

因此, 不单对于各用户的loss而言是common descent的,它对于新增的目标而言也是common descent的,因此它能够促进模型公平性。

步长线性搜索

文章还给出了一套适用于FL的线性搜索步长的策略,很容易实施。简单来说,就是server在计算出公平更新方向 并发给clients后,clients用这个 以及一系列从大到小的步长,分别进行local training,并评估结果的loss,发给server。server在收集所有的loss反馈后,得到一个最大的、并且相较于旧模型在精度和公平性上变得更好的步长,来作为最终正式更新模型的参数。值得注意的是,由于这个额外的通信过程中,只需要来回传标量值,因此它并不会增加很多额外的通信开销。

我觉得这是很有意义的一件事情,步长线性搜索策略能让FL不再依赖调节learning rate这个超参数。并且它是目前而言绝无仅有的可行的方法。我分析原因大概是这样:想要确保模型能收敛的同时,用一个更好、收敛性更快的步长,则最好要保证模型的更新方向是common descent的。但此前的FL算法并不能保障梯度下降,因此,如果贸然地加大更新步长(learning rate),则无法保障模型收敛。而FedMDFG因为能计算出一个common descent且公平的梯度下降方向,所以可以用line search搜索出恰当的步长来更新模型。反之,如果用的是传统的FL算法,比如FedAvg,它们不能确保更新方向是common descent的,如果强行用line search,则需要结合更新方向,设计一套更加复杂的步长搜索策略,才能保障收敛性。

实验

文章还给出了算法的收敛性严格证明,并且在多个场景下进行了实验。具体的实验设置这里就不赘述了,这里贴其中一个实验图,可以直观看到FedMDFG显著地提升了算法的收敛速度,以及收敛效果。并且在公平性上也明显好很多。文章还给出了复现的代码。

 

后记

将多目标优化与联邦学习结合,确实是一个很令人信服的方法。它能够具有理论保障地改善联邦学习的公平性,使得联邦学习在non-IID的场景下表现更佳。并且引入步长的线性搜索策略,能让联邦学习更具备落地实用性。

我目前已成功将公平联邦学习算法应用到智能电网的负荷预测和非侵入式负荷监测中,取得了令人满意的效果。后续我会继续关注这一块。希望这篇文章在帮助自己记录学习点滴之余,也能帮助大家!

相关文章:

公平联邦学习——多目标优化

前言 前段时间接触到了联邦学习(Federated Learning, FL)。涉猎了几年多目标优化的我,惊奇地发现横向联邦学习里面也有用多目标优化来做的。于是有感而发,特此写一篇博客记录记录,如有机会可以和大家多多交流。遇到不…...

奇怪的Python:为何字符串要设置成不可变的?

你好!我是老邓。今天我们来聊聊 Python 中字符串不可变这个话题。 1、问题简介: Python 中,字符串属于不可变对象。这意味着一旦字符串被创建,它的值就无法被修改。任何看似修改字符串的操作,实际上都是创建了一个新…...

Vue-Router之嵌套路由

在路由配置中,配置children import Vue from vue import VueRouter from vue-routerVue.use(VueRouter)const router new VueRouter({mode: history,base: import.meta.env.BASE_URL,routes: [{path: /,redirect: /home},{path: /home,name: home,component: () &…...

MyBatis使用的设计模式

目录 1. 工厂模式(Factory Pattern) 2. 单例模式(Singleton Pattern) 3. 代理模式(Proxy Pattern) 4. 装饰器模式(Decorator Pattern) 5. 观察者模式(Observer Patt…...

arm rk3588 升级glibc2.31到2.33

一、查看glibc版本 rootztl:~# ldd --version ldd (Ubuntu GLIBC 2.31-0ubuntu9.2) 2.31 Copyright (C) 2020 Free Software Foundation, Inc. This is free software; see the source for copying conditions. There is NO warranty; not even for MERCHANTABILITY or FITNE…...

【Linux系列】sed命令的深入解析:如何使用sed删除文件内容

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...

C++ 设计模式:桥接模式(Bridge Pattern)

链接:C 设计模式 链接:C 设计模式 - 装饰模式 桥接模式(Bridge Pattern)是一种结构型设计模式,它通过将抽象部分(业务功能)与实现部分(平台实现)分离,使它们…...

MATLAB中whitespacePattern函数用法

目录 语法 说明 示例 匹配空白字符 替换非标准空白 更正错误的间距 whitespacePattern函数的功能是匹配空白字符。 语法 pat whitespacePattern pat whitespacePattern(N) pat whitespacePattern(minCharacters,maxCharacters) 说明 pat whitespacePattern 创建一…...

Django多字段认证的实现

Django多字段认证 需求: django认证的检查用户是username,如果使用 username和 手机号验证登录。 重写: ModelBackend 类下的 authenticate 方法 # 在对应应用下创建 utils.py""" 修改Django认证类,为了实现 …...

【AndroidAPP】权限被拒绝:[android.permission.READ_EXTERNAL_STORAGE],USB设备访问权限系统报错

一、问题原因 1.安卓安全性变更 Android 12 的安全性变更,Google 引入了更严格的 PendingIntent 安全管理,强制要求开发者明确指定 PendingIntent 的可变性(Mutable)或不可变性(Immutable)。 但是&#xf…...

SQL进阶技巧:如何分析连续签到领金币数问题?

目录 0 题目需求 1 数据准备 2 问题分析 2.1 代码实现 2.2 代码功能分析 第一段 SQL...

1、ELK的架构和安装

ELK简介 elk:elasticsearch logstash kibana,统一日志收集系统。 elasticsearch:分布式的全文索引引擎的非关系数据库,json格式,在elk中存储所有的日志信息,架构有主和从,最少需要2台。 …...

Vue2/Vue3使用DataV

Vue2 注意vue2与3安装DataV命令命令是不同的Vue3 DataV - Vue3 官网地址 注意vue2与3安装DataV命令命令是不同的 vue3vite 与 Vue3webpack 对应安装也不同vue3vite npm install kjgl77/datav-vue3全局引入 // main.ts中全局引入 import { createApp } from vue import Da…...

汇编环境搭建

学习视频 将MASM所在目录 指定为C盘...

Android 系统 `android.app.Fragment` 类的深度定制与常见问题解析

Android 系统 android.app.Fragment 类的深度定制与常见问题解析 目录 引言Fragment 概述Fragment 的生命周期Fragment 的系统层深度定制 4.1 Fragment 的创建与初始化4.2 Fragment 的布局与视图4.3 Fragment 的通信机制4.4 Fragment 的动画与过渡4.5 Fragment 的状态保存与恢…...

linux ueditor nginx https 后台配置项返回格式出错,上传功能将不能正常使用

jsp的版本 如果出现了这个错误,上传的图标都亮起的情况,还是提示这个, 可以试试修改 uedtior.all.js 8082行 isJsonp utils.isCrossDomainUrl(configUrl); 改为 // isJsonp utils.isCrossDomainUrl(configUrl); isJsonp true; 如果还不…...

【机器学习 | 数据挖掘】时间序列算法

时间序列是按时间顺序排列的、随时间变化且相互关联的数据序列。分析时间序列的方法构成数据分析的一个重要领域,即时间序列分析。以下是对时间序列算法的详细介绍: 一、时间序列的分类 时间序列根据所研究的依据不同,可有不同的分类&#…...

uniapp H5 对接 声网,截图

文章目录 安装依赖创建容器容器样式 javascript代码ImageDataToBlob 方法 控制控制台LOG输出 安装依赖 版本"agora-rtc-sdk-ng": "^4.22.0", 创建容器 <template><view class"videoValue " id"videoValue"><u-toast…...

家谱管理系统|Java|SSM|VUE| 前后端分离

【技术栈】 1⃣️&#xff1a;架构: B/S、MVC 2⃣️&#xff1a;系统环境&#xff1a;Windowsh/Mac 3⃣️&#xff1a;开发环境&#xff1a;IDEA、JDK1.8、Maven、Mysql5.7 4⃣️&#xff1a;技术栈&#xff1a;Java、Mysql、SSM、Mybatis-Plus、VUE、jquery,html 5⃣️数据库…...

【LeetCode】200、岛屿数量

【LeetCode】200、岛屿数量 文章目录 一、并查集1.1 并查集1.2 多语言解法 二、洪水填充 DFS2.1 洪水填充 DFS 一、并查集 1.1 并查集 // go var sets int var father [90000]intfunc numIslands(grid [][]byte) int {n, m : len(grid), len(grid[0])build(grid, n, m)for i …...

PHP和Node.js哪个更爽?

先说结论&#xff0c;rust完胜。 php&#xff1a;laravel&#xff0c;swoole&#xff0c;webman&#xff0c;最开始在苏宁的时候写了几年php&#xff0c;当时觉得php真的是世界上最好的语言&#xff0c;因为当初活在舒适圈里&#xff0c;不愿意跳出来&#xff0c;就好比当初活在…...

C++.OpenGL (10/64)基础光照(Basic Lighting)

基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...

Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!

一、引言 在数据驱动的背景下&#xff0c;知识图谱凭借其高效的信息组织能力&#xff0c;正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合&#xff0c;探讨知识图谱开发的实现细节&#xff0c;帮助读者掌握该技术栈在实际项目中的落地方法。 …...

Angular微前端架构:Module Federation + ngx-build-plus (Webpack)

以下是一个完整的 Angular 微前端示例&#xff0c;其中使用的是 Module Federation 和 npx-build-plus 实现了主应用&#xff08;Shell&#xff09;与子应用&#xff08;Remote&#xff09;的集成。 &#x1f6e0;️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...

4. TypeScript 类型推断与类型组合

一、类型推断 (一) 什么是类型推断 TypeScript 的类型推断会根据变量、函数返回值、对象和数组的赋值和使用方式&#xff0c;自动确定它们的类型。 这一特性减少了显式类型注解的需要&#xff0c;在保持类型安全的同时简化了代码。通过分析上下文和初始值&#xff0c;TypeSc…...

(一)单例模式

一、前言 单例模式属于六大创建型模式,即在软件设计过程中,主要关注创建对象的结果,并不关心创建对象的过程及细节。创建型设计模式将类对象的实例化过程进行抽象化接口设计,从而隐藏了类对象的实例是如何被创建的,封装了软件系统使用的具体对象类型。 六大创建型模式包括…...

pikachu靶场通关笔记19 SQL注入02-字符型注入(GET)

目录 一、SQL注入 二、字符型SQL注入 三、字符型注入与数字型注入 四、源码分析 五、渗透实战 1、渗透准备 2、SQL注入探测 &#xff08;1&#xff09;输入单引号 &#xff08;2&#xff09;万能注入语句 3、获取回显列orderby 4、获取数据库名database 5、获取表名…...

【SpringBoot自动化部署】

SpringBoot自动化部署方法 使用Jenkins进行持续集成与部署 Jenkins是最常用的自动化部署工具之一&#xff0c;能够实现代码拉取、构建、测试和部署的全流程自动化。 配置Jenkins任务时&#xff0c;需要添加Git仓库地址和凭证&#xff0c;设置构建触发器&#xff08;如GitHub…...

Linux部署私有文件管理系统MinIO

最近需要用到一个文件管理服务&#xff0c;但是又不想花钱&#xff0c;所以就想着自己搭建一个&#xff0c;刚好我们用的一个开源框架已经集成了MinIO&#xff0c;所以就选了这个 我这边对文件服务性能要求不是太高&#xff0c;单机版就可以 安装非常简单&#xff0c;几个命令就…...

rknn toolkit2搭建和推理

安装Miniconda Miniconda - Anaconda Miniconda 选择一个 新的 版本 &#xff0c;不用和RKNN的python版本保持一致 使用 ./xxx.sh进行安装 下面配置一下载源 # 清华大学源&#xff08;最常用&#xff09; conda config --add channels https://mirrors.tuna.tsinghua.edu.cn…...