当前位置: 首页 > article >正文

保姆级教程:用PyTorch从零搭建联邦学习MNIST实验环境(附完整代码)

联邦学习实战PyTorch搭建MNIST实验环境全流程解析1. 联邦学习与MNIST实验概述联邦学习作为一种分布式机器学习范式正在重塑传统模型训练的方式。不同于集中式训练联邦学习允许多个客户端在保持数据本地化的前提下协作训练模型特别适合手写数字识别这类需要隐私保护的场景。MNIST作为经典的入门级数据集包含60,000张28x28像素的灰度手写数字图像是验证联邦学习算法的理想选择。在典型的联邦学习框架中我们需要处理几个核心组件客户端数据划分将MNIST训练集按IID独立同分布方式分配给多个客户端本地模型训练每个客户端基于分配到的数据独立训练模型参数聚合服务器收集客户端模型参数并进行加权平均全局模型更新将聚合后的参数分发给客户端进行下一轮训练# 联邦学习基本流程伪代码 for round in range(total_rounds): # 选择参与本轮训练的客户端 selected_clients select_clients(clients, selection_ratio) # 客户端本地训练 client_updates [] for client in selected_clients: local_model train_locally(client, global_model) client_updates.append(local_model.state_dict()) # 服务器聚合更新 global_model aggregate_updates(global_model, client_updates)2. 实验环境搭建2.1 基础环境配置推荐使用Python 3.8和PyTorch 1.10环境。以下是使用conda创建环境的命令conda create -n fl_env python3.8 conda activate fl_env pip install torch torchvision numpy matplotlib2.2 项目目录结构合理的项目结构能显著提高代码可维护性federated-mnist/ ├── data/ │ ├── raw/ # 原始MNIST数据 │ ├── processed/ # 处理后的数据 │ └── clients/ # 客户端数据划分 ├── models/ │ └── cnn.py # 模型定义 ├── utils/ │ ├── data_utils.py # 数据预处理工具 │ └── fl_utils.py # 联邦学习辅助函数 ├── config.py # 参数配置 ├── server.py # 服务器逻辑 └── client.py # 客户端逻辑2.3 数据准备与IID划分MNIST数据集的IID划分是联邦学习实验的基础步骤。我们需要将60,000个训练样本均匀分配到100个客户端每个客户端获得600个样本def split_iid(dataset, num_clients): num_items len(dataset) // num_clients client_dict {} indices np.random.permutation(len(dataset)) for i in range(num_clients): client_dict[i] indices[i*num_items : (i1)*num_items] return client_dict注意确保每个客户端获得均衡的类别分布可通过检查每个客户端的标签分布验证IID属性。3. 核心代码实现3.1 模型架构设计我们采用经典的CNN结构处理MNIST图像class MNIST_CNN(nn.Module): def __init__(self): super(MNIST_CNN, self).__init__() self.conv1 nn.Conv2d(1, 32, 5, padding2) self.pool nn.MaxPool2d(2, 2) self.conv2 nn.Conv2d(32, 64, 5, padding2) self.fc1 nn.Linear(64*7*7, 512) self.fc2 nn.Linear(512, 10) def forward(self, x): x self.pool(F.relu(self.conv1(x))) x self.pool(F.relu(self.conv2(x))) x x.view(-1, 64*7*7) x F.relu(self.fc1(x)) x self.fc2(x) return x3.2 客户端本地训练客户端本地训练的关键实现def client_train(model, trainloader, epochs, lr0.01): model.train() optimizer torch.optim.SGD(model.parameters(), lrlr) criterion nn.CrossEntropyLoss() for epoch in range(epochs): for data, labels in trainloader: optimizer.zero_grad() outputs model(data) loss criterion(outputs, labels) loss.backward() optimizer.step() return model.state_dict()3.3 服务器聚合算法实现FedAvg聚合算法def aggregate_weights(client_weights): 执行FedAvg参数聚合 global_weights {} # 初始化全局参数 for key in client_weights[0].keys(): global_weights[key] torch.zeros_like(client_weights[0][key]) # 加权平均 total_samples sum([weights[num_samples] for weights in client_weights]) for weights in client_weights: for key in global_weights: global_weights[key] weights[key] * (weights[num_samples] / total_samples) return global_weights4. 实验执行与调优4.1 关键参数配置联邦学习中有三个核心超参数需要特别关注参数描述典型值影响C客户端选择比例0.1影响通信成本和模型多样性E本地训练epoch数1-5影响计算开销和本地拟合程度B本地batch大小10-600影响训练稳定性和效率4.2 训练循环实现完整的训练流程实现def run_federated(num_rounds100, num_clients100, C0.1, E5, B64): # 初始化全局模型 global_model MNIST_CNN() # 准备数据 train_dataset MNIST(root./data, trainTrue, transformtransforms.ToTensor()) test_dataset MNIST(root./data, trainFalse, transformtransforms.ToTensor()) client_loaders create_iid_loaders(train_dataset, num_clients, B) # 训练循环 for round in range(num_rounds): # 选择客户端 selected np.random.choice(num_clients, int(num_clients*C), replaceFalse) # 客户端更新 client_weights [] for client_id in selected: local_model copy.deepcopy(global_model) weights client_train(local_model, client_loaders[client_id], E) client_weights.append({weights: weights, num_samples: len(client_loaders[client_id].dataset)}) # 聚合更新 global_weights aggregate_weights(client_weights) global_model.load_state_dict(global_weights) # 评估 test_acc evaluate(global_model, test_dataset) print(fRound {round1}, Test Acc: {test_acc:.2f}%)4.3 常见问题排查在实际运行中可能会遇到以下典型问题内存不足减小batch size或客户端选择比例收敛缓慢调整学习率或增加本地epoch数客户端漂移使用客户端动量或学习率衰减通信瓶颈考虑模型压缩或异步更新提示使用固定随机种子如torch.manual_seed(42)确保实验可复现性5. 实验结果分析与可视化5.1 性能指标跟踪记录每轮测试准确率并可视化def plot_results(acc_history): plt.figure(figsize(10, 6)) plt.plot(acc_history, labelTest Accuracy) plt.xlabel(Communication Rounds) plt.ylabel(Accuracy (%)) plt.title(Federated Learning Performance) plt.grid(True) plt.legend() plt.show()5.2 参数对比实验比较不同超参数配置下的表现配置最终准确率收敛速度计算开销C0.1, E192.3%中等低C0.2, E594.7%快高C1.0, E393.8%最快最高5.3 扩展实验建议为进一步提升实验价值可以考虑非IID数据划分的影响不同聚合算法的比较如FedProx客户端差分隐私保护模型压缩对通信效率的影响6. 工程实践建议在实际项目中应用联邦学习时有几个实用技巧值得注意数据预处理标准化确保所有客户端使用相同的预处理流程模型版本控制跟踪每轮迭代的模型变化容错机制处理客户端离线或延迟的情况资源监控跟踪CPU/GPU利用率和网络开销# 简单的资源监控装饰器 def monitor_resources(func): def wrapper(*args, **kwargs): start_time time.time() start_mem psutil.Process().memory_info().rss / 1024 / 1024 result func(*args, **kwargs) end_time time.time() end_mem psutil.Process().memory_info().rss / 1024 / 1024 print(fExecution time: {end_time-start_time:.2f}s) print(fMemory usage: {end_mem-start_mem:.2f}MB) return result return wrapper联邦学习的魅力在于其分布式特性与实际应用的契合度。在MNIST上的实践只是起点相同的框架可以扩展到更复杂的模型和更具挑战性的数据集。经过多次实验发现客户端选择策略对最终模型性能的影响往往比预期更大这在实际业务场景中需要特别关注。

相关文章:

保姆级教程:用PyTorch从零搭建联邦学习MNIST实验环境(附完整代码)

联邦学习实战:PyTorch搭建MNIST实验环境全流程解析 1. 联邦学习与MNIST实验概述 联邦学习作为一种分布式机器学习范式,正在重塑传统模型训练的方式。不同于集中式训练,联邦学习允许多个客户端在保持数据本地化的前提下协作训练模型&#xff0…...

从零解析ATK1218-BD:Arduino实战中的北斗/GPS数据获取与NMEA协议解读

1. 从零认识ATK1218-BD模块 第一次拿到这个火柴盒大小的北斗/GPS双模定位模块时,我完全没想到它能输出这么多信息。ATK1218-BD是正点原子推出的一款工业级定位模块,特别适合用在无人机、车载导航这些需要高精度定位的场景。和普通GPS模块最大的区别是它能…...

绿联NAS上利用Docker部署SearXNG与Open-WebUI的YAML配置实战

1. 绿联NAS与Docker的完美组合 如果你手头有一台绿联NAS,那你就拥有了一个强大的家庭数据中心。作为国产NAS中的佼佼者,绿联NAS不仅提供了友好的操作界面,还内置了Docker支持,这让它成为了技术爱好者折腾的理想平台。我用了大半年…...

SEO_内容与SEO如何结合?高效优化步骤详解

SEO与内容结合:高效优化步骤详解 在当今数字化时代,搜索引擎优化(SEO)和内容营销无疑是提升网站流量和品牌影响力的关键。SEO和内容的结合并不是一件简单的事情。很多人可能在这两者之间产生困惑,不知道如何在保持内容…...

GPS定位误差从几十米到厘米级:RTK技术如何实现高精度定位(附手机实测对比)

GPS定位误差从几十米到厘米级:RTK技术如何实现高精度定位(附手机实测对比) 你是否曾在城市峡谷中看着导航地图上飘忽不定的定位箭头哭笑不得?或是户外徒步时发现轨迹记录偏离实际路线数十米?这些困扰背后,是…...

幻兽帕鲁存档修复终极指南:3步解决服务器迁移数据丢失问题

幻兽帕鲁存档修复终极指南:3步解决服务器迁移数据丢失问题 【免费下载链接】palworld-host-save-fix Fixes the bug which forces a player to create a new character when they already have a save. Useful for migrating maps from co-op to dedicated servers …...

差动保护:电力系统的核心安全保障技术

差动保护电流差动保护是电力系统的"铁闸门",核心思想简单粗暴:比较设备两端的电流是否对得上账。就像两个会计同时记账,如果两边数据差太多,肯定有人搞鬼——要么线路漏电,要么设备内部短路。举个接地气的例…...

3大突破!NormalMap-Online让3D材质制作效率提升10倍的终极解决方案

3大突破!NormalMap-Online让3D材质制作效率提升10倍的终极解决方案 【免费下载链接】NormalMap-Online NormalMap Generator Online 项目地址: https://gitcode.com/gh_mirrors/no/NormalMap-Online 在3D建模领域,如何快速将普通图片转化为具有真…...

YimMenu安全指南与效率提升:GTA5辅助工具全面应用手册

YimMenu安全指南与效率提升:GTA5辅助工具全面应用手册 【免费下载链接】YimMenu YimMenu, a GTA V menu protecting against a wide ranges of the public crashes and improving the overall experience. 项目地址: https://gitcode.com/GitHub_Trending/yi/YimM…...

跨游戏模组协同:XXMI启动器智能管理解决方案

跨游戏模组协同:XXMI启动器智能管理解决方案 【免费下载链接】XXMI-Launcher Modding platform for GI, HSR, WW and ZZZ 项目地址: https://gitcode.com/gh_mirrors/xx/XXMI-Launcher 当你同时游玩《原神》《崩坏:星穹铁道》《鸣潮》等多款二次元…...

文本输入组件核心讲解与实战

一、文本输入类组件核心认知(一)组件整体定位TextInput、TextArea、Search是鸿蒙ArkTS核心文本输入类组件,基于统一输入底层能力封装,支持通用样式与高频事件;针对单行短文本、多行长文本、搜索专属三大场景做差异化优…...

NeuroKit2深度解析:Python神经生理信号处理的进阶实战指南

NeuroKit2深度解析:Python神经生理信号处理的进阶实战指南 【免费下载链接】NeuroKit NeuroKit2: The Python Toolbox for Neurophysiological Signal Processing 项目地址: https://gitcode.com/gh_mirrors/ne/NeuroKit 在当今神经科学和生物医学工程领域&a…...

5分钟Mac本地跑通32B Qwen!免费GPT-4o替代,还能5分钟造个会开浏览器+执行Shell的AI Agent

1. 硬件与模型选择 配置:Apple M2 Pro(19 核 GPU)、32GB 统一内存。 推荐模型:mlx-community/Qwen2.5-Coder-32B-Instruct-4bit 4bit 量化后只占 18-22GB 内存专为代码和 Agent 优化,Tool Calling 能力强MLX 原生支持…...

Vim-signify 异步更新技巧:让你的 Vim 编辑器更智能

Vim-signify 异步更新技巧:让你的 Vim 编辑器更智能 【免费下载链接】vim-signify :heavy_plus_sign: Show a diff using Vim its sign column. 项目地址: https://gitcode.com/gh_mirrors/vi/vim-signify Vim-signify 是一个强大的 Vim/Neovim 插件&#xf…...

关于reverse的tea题目回顾

ea的短暂性小总结说实话今天做的内容不算太多,但是感觉很超出自己的承受范围。 话不多说进行短暂总结tea模式tea的题目做起来的话公式比较固定。就比如用下面这个简单的题目进行示范这个就是图片,有en和de两种模式。de是我自己写出来的。查看en代码时能够…...

告别残差加法,Kimi 给神经网络换了个 “智能引擎”

来源:算法进阶 本文约2800字,建议阅读6分钟本文介绍了 Kimi 团队用 Attention Residuals 替代传统残差机制的成果。只要接触深度学习神经网络的读者们对「」一定不会陌生。自从 2015 年 ResNet 诞生以来,这种「将输入直接加到输出上」的简单逻…...

OpCore-Simplify:如何用四步自动化配置解决黑苹果安装难题?

OpCore-Simplify:如何用四步自动化配置解决黑苹果安装难题? 【免费下载链接】OpCore-Simplify A tool designed to simplify the creation of OpenCore EFI 项目地址: https://gitcode.com/GitHub_Trending/op/OpCore-Simplify OpCore-Simplify是…...

革新性量化交易平台:基于Backtrader的高效策略回测工具实现方法

革新性量化交易平台:基于Backtrader的高效策略回测工具实现方法 【免费下载链接】backtrader-pyqt-ui 项目地址: https://gitcode.com/gh_mirrors/bac/backtrader-pyqt-ui Backtrader可视化平台是一款融合PyQt界面框架与finplot图表库的革新性量化交易回测工…...

从作业到考试:中科大数字图像分析(DIA)课程避坑与自学指南

中科大数字图像分析(DIA)课程高效学习与实战避坑指南 数字图像分析(DIA)作为中科大电子工程与信息科学系的专业基础课,以其知识面广、难度高著称。每年都有不少同学因低估课程强度而陷入"上课听不懂、作业不会做、考前突击难"的困境。本文将系统梳理从日常…...

Microsoft团队提出“弯曲雅各布天梯”新思路,了解量子数据如何教会AI做更好的化学

来源:ScienceAI 本文约3500字,建议阅读5分钟量子计算机生成精确数据,AI模型学习并实现百万倍加速预测。有时,一个视觉上引人注目的隐喻,足以让你传达一个复杂的观点。2001 年夏天,杜兰大学物理教授 John P.…...

前端开发中的加载指示器(Loading Spinners)一种动态旋转的图形元素(如圆圈、齿轮状动画)

在 Android 中,Spinner 是一个下拉选择控件,用于从预定义列表中选择一项。以下是标准、稳定、兼容性好的实现方式(基于 ViewBinding ArrayAdapter,适配 AndroidX 和 API 21):✅ 一、绑定数据(以…...

C 里面如何使用链表 list

1. 学生时代, 那会学习 C 数据结构, 比较简单 struct person {int id;char name[641];struct person * next; }; 类似上面这样, 需要什么依赖 next 指针来回调整, 然后手工 print F5 去 debug 熬. 2. 刚工作青年时代, 主要花活, 随大流类似 #pragma once#include "stru…...

TensorFlow开发中用到的一些第三方库

本节介绍下后面开发要用到的辅助库,并做一些简单的代码实例和效果演示,当然我们都是为了最终目标TensorFlow开发做准备的,用到的也是这些库的简单的api,这里做简单的介绍为后面TensorFlow开发做准备,对于这些库的深入研…...

GHelper:华硕笔记本性能优化与硬件控制的开源解决方案

GHelper:华硕笔记本性能优化与硬件控制的开源解决方案 【免费下载链接】g-helper Lightweight, open-source control tool for ASUS laptops and ROG Ally. Manage performance modes, fans, GPU, battery, and RGB lighting across Zephyrus, Flow, TUF, Strix, Sc…...

TensorFlow的一些基本概念

分类问题和回归问题 在实际生活中,人们面临的问题无非就是离散的和连续的。 比方区分出某个人属于男性还是女性,比方衣服是什么颜色的,什么种类的,这些都是在有限数量的结果中寻找答案,也就是最终结果只能是N个里面的某…...

NI USB-6210 DAQ采集卡开箱照

1、包装非常简单,有点对不起它6000~7000元的价格:2、 内部也没有什么特别的:3、一张用户须知,一本使用说明:4、一张光盘,感觉有点Low,现在电脑很少有光驱了:5、这条USB线据说要200大…...

SmolVLA企业应用:轻量级VLA模型赋能AGV分拣与桌面机械臂

SmolVLA企业应用:轻量级VLA模型赋能AGV分拣与桌面机械臂 1. 引言:当机器人开始“看懂”世界 想象一下,你对着一个机械臂说:“把那个红色的方块拿起来,放到蓝色的盒子里。”然后它真的照做了。这不是科幻电影&#xf…...

7大核心优势!D3KeyHelper暗黑3智能宏工具全面解析:从手动操作到自动化体验的升级之路

7大核心优势!D3KeyHelper暗黑3智能宏工具全面解析:从手动操作到自动化体验的升级之路 【免费下载链接】D3keyHelper D3KeyHelper是一个有图形界面,可自定义配置的暗黑3鼠标宏工具。 项目地址: https://gitcode.com/gh_mirrors/d3/D3keyHelp…...

ai辅助开发:向快马描述需求,直接生成jdk1.8实现的控制台通讯录项目

最近在尝试用Java开发一个简单的命令行通讯录程序,正好借这个机会体验了一把AI辅助开发的便利。整个过程让我深刻感受到,合理利用工具真的能大幅提升开发效率。下面记录下这个项目的实现思路和关键点,或许对同样想用JDK1.8练手的朋友有帮助。…...

突破8大平台限制:开源工具实现高速下载的3种创新方案

突破8大平台限制:开源工具实现高速下载的3种创新方案 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 ,支持 百度网盘 / 阿里云盘 / 中国移动云盘 / 天翼云…...