当前位置: 首页 > news >正文

公平联邦学习——多目标优化

前言

前段时间接触到了联邦学习(Federated Learning, FL)。涉猎了几年多目标优化的我,惊奇地发现横向联邦学习里面也有用多目标优化来做的。于是有感而发,特此写一篇博客记录记录,如有机会可以和大家多多交流。遇到不专业的地方,欢迎大家来指正!

参考文章:FedMDFG: Federated Learning with Multi-Gradient Descent and Fair Guidance (AAAI-2023)

联邦学习背景

横向联邦学习它的基本原理很通俗易懂,就是想象一下这样的场景:

  • 有m个用户,原本每个用户都只用自己的数据来各顾各地进行本地的模型训练,这就叫做Individual Learning;

  • 而如果把所有用户的数据都上传到一处地方,然后在那儿用收集起来的所有数据进行模型训练,就叫做Centralized Learning;

  • 而现在用户不愿意分享数据了,但又怕只做本地模型训练得到的模型泛化性不好,想共同合作训练一个模型,这个时候,它们想到,把本地训练的结果/梯度上传到一个地方(称之为server),然后server收集到本地模型/梯度后,进行聚合运算,得到一个新的模型(称之为全局模型),并将其发给用户;用户基于收到这个新的全局模型来再进行本地训练,完后再回传给server进行聚合,如此往复。这样一来,用户就不需要share数据了,同时又能从这样的联合训练过程中获得一个泛化性更好的模型。

上述三种模型训练方式的对比如下图:

 

总结来说,generic FL的流程如下:

  1. 在第 t 轮通信时,Server把global model分发给各个clients。

  2. Clients收到模型后,进行本地训练,然后上传模型更新的gradients(或local updates)给server。

  3. Server基于收到的gradients(或local updates)更新global model。

  4. t=t+1 ,然后回到第1步,并循环若干轮直到达到终止条件。

公平联邦学习

因为联邦学习是一个涉及多个用户共同参与的过程,因此合作的公平性尤为重要。FL的公平性有很多种,其中最直观的就是performance fairness。简单来说,当一个模型在某些用户的数据上精度很高,但在其他用户上则表现很差,则这个模型是不公平的了,例如下图:

 

为了提升模型的performance fairness,我们的目标是去让模型在各个clients上的精度更平均一些。那怎么衡量公平性呢?我们很容易想到用模型在各个clients上的accuracy的标准差。但标准差不是scale-free的。为此,这里我们引入余弦相似度来衡量公平性:

 

导致模型不公平的直接原因

为了提升FL模型的公平性,前人提出了很多复杂的方法,例如降低模型的本地更新与全局更新的冲突;在模型聚合的时候提高效果较弱的用户的权重,等等。但究其根本,很容易想到,导致模型不公平有两个直接原因:

  • 使用了一个会加剧不公平的更新方向来更新模型;

  • 使用了一个不恰当的学习率来更新模型。

具体来说,如果用一个错误的更新方向,它对于某些用户而言不是梯度下降的,因此就会直接导致模型不公平。此外,哪怕更新方向能促进公平,但由于步长(学习率)太大了,也会破坏模型的公平性。

为此,FedMDFG这篇文章从多目标优化的角度来给出了提升模型公平性的方法。在我之前的多目标优化研究里面,通常是有若干个目标函数,然后并不是用加权聚合的方法把它转成单目标,而是用多目标优化算法,使得解逐渐趋于帕累托最优。帕累托最优的概念可以参考此链接:多目标优化。

而对于联邦学习而言,如果我们把模型参数ω看作是决策变量,每个用户的training loss(记为L)看作优化目标,那么就可以构建这样一个多目标优化问题:

 

由于这是一个连续变量的优化问题,我们可以用Multiple Gradient Descent Algorithm (MGDA)来求解出一个common descent direction来更新模型,因为该direction是common descent的,所以它能同时让模型在各个用户上的效果变好。

但“不让模型在某些用户上的效果变差”只是FL公平性的一个必要条件,只做到这一点,并不能使模型变得公平。举个简单的例子:假设有两个用户,它们的loss值分别是5和6,用common descent direction来更新模型后,两者的loss值变为4和1,显然,虽然两者的loss都下降了,但模型变得更加不公平了。

文章画了一个直观的图来展现什么样的更新方向才能促进模型公平性:

 

从图中可看出,三个用户的梯度g1, g2, g3相互冲突,中间黄色+灰色锥的表示所有满足common descent direction的向量组成的区域,而黄色锥中的更新方向既满足common descent,又能够避免上面所说的导致模型不公平的问题。可以直观地看到,前面的几种公平联邦学习算法,并不能确保更新方向是common descent的。因此它们不但会破坏模型的公平性,甚至可能还会出现收敛问题,尤其是在non-IID的场景下。

为了计算出一个公平的更新方向(即落在黄色区域内),文章在FL的多目标优化的formulation基础上,进一步引入了一个新的优化目标:

其中h是这样一个向量:

 

它是在clients的loss组成的向量L(ω)上作一个垂直的向量,并进行归一化所得。如下图所示:

 

同时优化这个目标,就能让L(ω)在偏离公平引导向量p=(1,1,...,1)的时候,将它拉回去,促进了模型的公平性。

因此,文章的优化目标变成:

 

至此,我们就可以用MGDA来求解上述多目标优化问题:

  • 首先创建一个矩阵Q,它由上述的m+1个目标函数的梯度拼接而成。

  • 然后求解下述dual problem得到λ:

 

  • 最后模型更新方向 就等于:

 

根据对偶理论,上述优化问题是下面这个优化问题的dual problem。

 

因此,上面求解得到的 满足:

  1. 如果 已经是Pareto critical,则 ;

  2. 如果 还不是是Pareto critical,则:

 

因此, 不单对于各用户的loss而言是common descent的,它对于新增的目标而言也是common descent的,因此它能够促进模型公平性。

步长线性搜索

文章还给出了一套适用于FL的线性搜索步长的策略,很容易实施。简单来说,就是server在计算出公平更新方向 并发给clients后,clients用这个 以及一系列从大到小的步长,分别进行local training,并评估结果的loss,发给server。server在收集所有的loss反馈后,得到一个最大的、并且相较于旧模型在精度和公平性上变得更好的步长,来作为最终正式更新模型的参数。值得注意的是,由于这个额外的通信过程中,只需要来回传标量值,因此它并不会增加很多额外的通信开销。

我觉得这是很有意义的一件事情,步长线性搜索策略能让FL不再依赖调节learning rate这个超参数。并且它是目前而言绝无仅有的可行的方法。我分析原因大概是这样:想要确保模型能收敛的同时,用一个更好、收敛性更快的步长,则最好要保证模型的更新方向是common descent的。但此前的FL算法并不能保障梯度下降,因此,如果贸然地加大更新步长(learning rate),则无法保障模型收敛。而FedMDFG因为能计算出一个common descent且公平的梯度下降方向,所以可以用line search搜索出恰当的步长来更新模型。反之,如果用的是传统的FL算法,比如FedAvg,它们不能确保更新方向是common descent的,如果强行用line search,则需要结合更新方向,设计一套更加复杂的步长搜索策略,才能保障收敛性。

实验

文章还给出了算法的收敛性严格证明,并且在多个场景下进行了实验。具体的实验设置这里就不赘述了,这里贴其中一个实验图,可以直观看到FedMDFG显著地提升了算法的收敛速度,以及收敛效果。并且在公平性上也明显好很多。文章还给出了复现的代码。

 

后记

将多目标优化与联邦学习结合,确实是一个很令人信服的方法。它能够具有理论保障地改善联邦学习的公平性,使得联邦学习在non-IID的场景下表现更佳。并且引入步长的线性搜索策略,能让联邦学习更具备落地实用性。

我目前已成功将公平联邦学习算法应用到智能电网的负荷预测和非侵入式负荷监测中,取得了令人满意的效果。后续我会继续关注这一块。希望这篇文章在帮助自己记录学习点滴之余,也能帮助大家!

相关文章:

公平联邦学习——多目标优化

前言 前段时间接触到了联邦学习(Federated Learning, FL)。涉猎了几年多目标优化的我,惊奇地发现横向联邦学习里面也有用多目标优化来做的。于是有感而发,特此写一篇博客记录记录,如有机会可以和大家多多交流。遇到不…...

奇怪的Python:为何字符串要设置成不可变的?

你好!我是老邓。今天我们来聊聊 Python 中字符串不可变这个话题。 1、问题简介: Python 中,字符串属于不可变对象。这意味着一旦字符串被创建,它的值就无法被修改。任何看似修改字符串的操作,实际上都是创建了一个新…...

Vue-Router之嵌套路由

在路由配置中,配置children import Vue from vue import VueRouter from vue-routerVue.use(VueRouter)const router new VueRouter({mode: history,base: import.meta.env.BASE_URL,routes: [{path: /,redirect: /home},{path: /home,name: home,component: () &…...

MyBatis使用的设计模式

目录 1. 工厂模式(Factory Pattern) 2. 单例模式(Singleton Pattern) 3. 代理模式(Proxy Pattern) 4. 装饰器模式(Decorator Pattern) 5. 观察者模式(Observer Patt…...

arm rk3588 升级glibc2.31到2.33

一、查看glibc版本 rootztl:~# ldd --version ldd (Ubuntu GLIBC 2.31-0ubuntu9.2) 2.31 Copyright (C) 2020 Free Software Foundation, Inc. This is free software; see the source for copying conditions. There is NO warranty; not even for MERCHANTABILITY or FITNE…...

【Linux系列】sed命令的深入解析:如何使用sed删除文件内容

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...

C++ 设计模式:桥接模式(Bridge Pattern)

链接:C 设计模式 链接:C 设计模式 - 装饰模式 桥接模式(Bridge Pattern)是一种结构型设计模式,它通过将抽象部分(业务功能)与实现部分(平台实现)分离,使它们…...

MATLAB中whitespacePattern函数用法

目录 语法 说明 示例 匹配空白字符 替换非标准空白 更正错误的间距 whitespacePattern函数的功能是匹配空白字符。 语法 pat whitespacePattern pat whitespacePattern(N) pat whitespacePattern(minCharacters,maxCharacters) 说明 pat whitespacePattern 创建一…...

Django多字段认证的实现

Django多字段认证 需求: django认证的检查用户是username,如果使用 username和 手机号验证登录。 重写: ModelBackend 类下的 authenticate 方法 # 在对应应用下创建 utils.py""" 修改Django认证类,为了实现 …...

【AndroidAPP】权限被拒绝:[android.permission.READ_EXTERNAL_STORAGE],USB设备访问权限系统报错

一、问题原因 1.安卓安全性变更 Android 12 的安全性变更,Google 引入了更严格的 PendingIntent 安全管理,强制要求开发者明确指定 PendingIntent 的可变性(Mutable)或不可变性(Immutable)。 但是&#xf…...

SQL进阶技巧:如何分析连续签到领金币数问题?

目录 0 题目需求 1 数据准备 2 问题分析 2.1 代码实现 2.2 代码功能分析 第一段 SQL...

1、ELK的架构和安装

ELK简介 elk:elasticsearch logstash kibana,统一日志收集系统。 elasticsearch:分布式的全文索引引擎的非关系数据库,json格式,在elk中存储所有的日志信息,架构有主和从,最少需要2台。 …...

Vue2/Vue3使用DataV

Vue2 注意vue2与3安装DataV命令命令是不同的Vue3 DataV - Vue3 官网地址 注意vue2与3安装DataV命令命令是不同的 vue3vite 与 Vue3webpack 对应安装也不同vue3vite npm install kjgl77/datav-vue3全局引入 // main.ts中全局引入 import { createApp } from vue import Da…...

汇编环境搭建

学习视频 将MASM所在目录 指定为C盘...

Android 系统 `android.app.Fragment` 类的深度定制与常见问题解析

Android 系统 android.app.Fragment 类的深度定制与常见问题解析 目录 引言Fragment 概述Fragment 的生命周期Fragment 的系统层深度定制 4.1 Fragment 的创建与初始化4.2 Fragment 的布局与视图4.3 Fragment 的通信机制4.4 Fragment 的动画与过渡4.5 Fragment 的状态保存与恢…...

linux ueditor nginx https 后台配置项返回格式出错,上传功能将不能正常使用

jsp的版本 如果出现了这个错误,上传的图标都亮起的情况,还是提示这个, 可以试试修改 uedtior.all.js 8082行 isJsonp utils.isCrossDomainUrl(configUrl); 改为 // isJsonp utils.isCrossDomainUrl(configUrl); isJsonp true; 如果还不…...

【机器学习 | 数据挖掘】时间序列算法

时间序列是按时间顺序排列的、随时间变化且相互关联的数据序列。分析时间序列的方法构成数据分析的一个重要领域,即时间序列分析。以下是对时间序列算法的详细介绍: 一、时间序列的分类 时间序列根据所研究的依据不同,可有不同的分类&#…...

uniapp H5 对接 声网,截图

文章目录 安装依赖创建容器容器样式 javascript代码ImageDataToBlob 方法 控制控制台LOG输出 安装依赖 版本"agora-rtc-sdk-ng": "^4.22.0", 创建容器 <template><view class"videoValue " id"videoValue"><u-toast…...

家谱管理系统|Java|SSM|VUE| 前后端分离

【技术栈】 1⃣️&#xff1a;架构: B/S、MVC 2⃣️&#xff1a;系统环境&#xff1a;Windowsh/Mac 3⃣️&#xff1a;开发环境&#xff1a;IDEA、JDK1.8、Maven、Mysql5.7 4⃣️&#xff1a;技术栈&#xff1a;Java、Mysql、SSM、Mybatis-Plus、VUE、jquery,html 5⃣️数据库…...

【LeetCode】200、岛屿数量

【LeetCode】200、岛屿数量 文章目录 一、并查集1.1 并查集1.2 多语言解法 二、洪水填充 DFS2.1 洪水填充 DFS 一、并查集 1.1 并查集 // go var sets int var father [90000]intfunc numIslands(grid [][]byte) int {n, m : len(grid), len(grid[0])build(grid, n, m)for i …...

idea报错:There is not enough memory to perform the requested operation.

文章目录 一、问题描述二、先解决三、后原因&#xff08;了解&#xff09; 一、问题描述 就是在使用 IDEA 写代码时&#xff0c;IDEA 可能会弹一个窗&#xff0c;大概提示你目前使用的 IDEA 内存不足&#xff0c;其实就是提醒你 JVM 的内存不够了&#xff0c;需要重新分配。弹…...

python ai ReAct 代理(ReAct Agent)

ReAct 代理&#xff08;ReAct Agent&#xff09;是一种结合了推理&#xff08;Reasoning&#xff09;和行动&#xff08;Action&#xff09;的智能代理框架&#xff0c;旨在通过交互式的方式解决复杂任务。ReAct 的核心思想是让代理在完成任务时&#xff0c;能够动态地推理下一…...

HTML入门教程|| HTML 基本标签(2)

HTML 列表 HTML列表 HTML 无序列表 ul 元素表示无序列表。 ul 元素中的项目使用 li 元素表示。 元素没有在HTML5中定义任何属性&#xff0c;并且您使用CSS控制列表的显示。 HTML5中的 type 和 compact 属性已过时。 您可以在以下代码中查看正在使用的 ul 元素。 <!D…...

MySQL root用户密码忘记怎么办(Reset root account password)

在使用MySQL数据库的的过程中&#xff0c;不可避免的会出现忘记密码的现象。普通用户的密码如果忘记&#xff0c;可以用更高权限的用户&#xff08;例如root&#xff09;进行重置。但是如果root用户的密码忘记了&#xff0c;由于root用户本身就是最高权限&#xff0c;那这个方法…...

groovy:多线程 简单示例

在Groovy中&#xff0c;多线程编程与Java非常相似&#xff0c;因为Groovy运行在Java虚拟机&#xff08;JVM&#xff09;上&#xff0c;并且可以利用Java的所有并发工具。以下是一些在Groovy中实现多线程编程的方法&#xff1a; class MyThread extends Thread {Overridevoid…...

SOME/IP 协议详解——序列化

文章目录 0. 概述1.基本数据序列化2.字符串序列化2.1 字符串通用规则2.2 固定长度字符串规则2.3 动态长度字符串规则 3.结构体序列化4. 带有标识符和可选成员的结构化数据类型5. 数组5.1 固定长度数组5.2 动态长度数组5.3 Enumeration&#xff08;枚举&#xff09;5.4 Bitfield…...

三、GIT与Github推送(上传)和克隆(下载)

GIT与Github推送&#xff08;上传&#xff09;和克隆&#xff08;下载&#xff09; 一、配置好SSH二、在Github创建仓库三、git克隆&#xff08;下载&#xff09;文件四、git推送&#xff08;上传&#xff09;文件到远程仓库 一、配置好SSH Git与Github上传和下载时需要使用到…...

18.2、网络安全评测技术与攻击

目录 网络安全测评技术与工具网络安全测评质量管理和标准 网络安全测评技术与工具 漏洞扫描技术可以用于测评&#xff0c;测评你安不安全&#xff0c;也可以用来风险评估安不安全&#xff0c;风险大不大 漏洞扫描包含网络安全漏洞扫描、主机安全漏洞扫描&#xff0c;还有数据…...

在 ArcGIS Pro/GeoScene Pro 中设计专题地图的符号系统

原始 按颜色对面进行符号化 打开符号系统 选择主符号系统 选择字段及其计算方式 更改临界值</...

CSS2笔记

一、CSS基础 1.CSS简介 2.CSS的编写位置 2.1 行内样式 2.2 内部样式 2.3 外部样式 3.样式表的优先级 4.CSS语法规范 5.CSS代码风格 二、CSS选择器 1.CSS基本选择器 通配选择器元素选择器类选择器id选择器 1.1 通配选择器 1.2 元素选择器 1.3 类选择器 1.4 ID选择器 1.5 基…...