【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning
【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning

1 算法原理
论文:Graves, L., Nagisetty, V., & Ganesh, V. (2021). Amnesiac machine learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 11516–11524.
Amnesiac Unlearning(遗忘性遗忘) 是一种高效且精确的算法,旨在从已经训练好的神经网络模型中删除特定数据的学习信息,而不会显著影响模型在其他数据上的性能。该算法的核心思想是通过选择性撤销与敏感数据相关的参数更新来实现数据的“遗忘”。
1. 训练阶段:记录参数更新
在模型训练过程中,记录每个批次的参数更新以及哪些批次包含敏感数据。
- 步骤:
- 初始化模型参数:从随机初始化的参数 θ i n i t i a l \theta_{initial} θinitial 开始训练模型。
- 训练模型:使用标准训练方法(如随机梯度下降)对模型进行训练,训练过程分为多个 epoch,每个 epoch 包含多个批次(batches)。
- 记录参数更新:
- 对于每个批次 b b b,记录该批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b,其中 e e e 表示 epoch 编号, b b b 表示批次编号。
- 同时,记录哪些批次包含敏感数据(即需要删除的数据)。可以将这些批次标记为 S B SB SB(Sensitive Batches)。
- 存储信息:
- 存储所有批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b。
- 存储敏感数据批次的索引 S B SB SB。
2. 数据删除阶段:选择性撤销参数更新
当收到数据删除请求时,撤销与敏感数据相关的参数更新。
-
步骤:
- 识别敏感数据批次:从存储的记录中提取包含敏感数据的批次索引 S B SB SB。
- 撤销参数更新:
计算删除敏感数据后的模型参数 θ M \theta_{M} θM:
θ M ′ = θ M − ∑ s b ∈ S B Δ θ s b \theta_{M'} = \theta_{M} - \sum_{sb \in SB} \Delta_{\theta_{sb}} θM′=θM−sb∈SB∑Δθsb其中:
- θ M \theta_{M} θM 是原始训练后的模型参数。
- Δ θ s b \Delta_{\theta_{sb}} Δθsb 是敏感数据批次 s b sb sb 的参数更新。
- 生成保护模型:使用更新后的参数 θ M ′ \theta_{M'} θM′ 作为新的模型参数。
3. 微调阶段(可选)
如果删除的批次较多,可能会对模型性能产生一定影响。此时可以通过少量微调来恢复模型性能。
- 步骤:
- 微调模型:使用删除敏感数据后的数据集对模型进行少量迭代训练。
- 恢复性能:通过微调,模型可以恢复在非敏感数据上的性能。
2 代码实现
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from models.Base import load_MNIST_data, test_model, device, MLP, load_CIFAR100_data, init_model# AmnesiacForget类:封装撤销与敏感数据相关的参数更新
class AmnesiacForget:def __init__(self, model, all_data, epochs, learning_rate):"""初始化 AmnesiacForget 类。:param model: 需要训练的模型。:param all_data: 训练数据集。:param epochs: 训练的总 epoch 数。:param learning_rate: 优化器的学习率。"""self.model = modelself.all_data = all_dataself.epochs = epochsself.learning_rate = learning_rateself.batch_updates = [] # 存储每个批次的参数更新值self.initial_params = {name: param.clone() for name, param in model.named_parameters()} # 存储初始模型参数self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 设备选择(GPU 或 CPU)def train(self, forgotten_classes):"""训练模型并记录每个批次的参数更新值。:param forgotten_classes: 需要遗忘的类别列表。:return: sensitive_batches: 包含敏感数据的批次索引。"""optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) # 使用 Adam 优化器self.model.train() # 将模型设置为训练模式sensitive_batches = {} # 记录每个 epoch 中包含敏感数据的批次索引# 训练过程for epoch in range(self.epochs):running_loss = 0.0sensitive_batches[epoch] = set() # 每个 epoch 的敏感批次集for batch_idx, (images, labels) in enumerate(self.all_data):optimizer.zero_grad() # 清空梯度images, labels = images.to(self.device), labels.to(self.device) # 将数据移动到设备上# 前向传播和损失计算outputs = self.model(images)loss = nn.CrossEntropyLoss()(outputs, labels)# 反向传播计算梯度loss.backward()running_loss += loss.item()# 记录当前参数值current_params = {name: param.clone() for name, param in self.model.named_parameters()}# 更新参数optimizer.step()# 记录参数更新值(当前参数值 - 更新前的参数值)batch_update = {}for name, param in self.model.named_parameters():if param.requires_grad:batch_update[name] = param.data - current_params[name].data # 记录参数更新值self.batch_updates.append(batch_update)# 记录包含敏感数据的批次索引if any(label.item() in forgotten_classes for label in labels):sensitive_batches[epoch].add(batch_idx)print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {running_loss/len(self.all_data):.4f}")return sensitive_batchesdef unlearn(self, sensitive_batches):"""撤销与敏感数据相关的批次更新。:param sensitive_batches: 包含敏感数据的批次索引。:return: 更新后的模型。"""# 计算非敏感批次的参数更新总和non_sensitive_updates = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}for batch_idx, batch_update in enumerate(self.batch_updates):if batch_idx not in {sb for epoch_batches in sensitive_batches.values() for sb in epoch_batches}:for name, update in batch_update.items():non_sensitive_updates[name] += update# 更新模型参数:初始参数 + 非敏感批次的更新for name, param in self.model.named_parameters():param.data = self.initial_params[name].data + non_sensitive_updates[name]return self.model# 全局函数:实现 Amnesiac Forget
def amnesiac_unlearning(model_before, test_loader, forgotten_classes, all_data, epochs=10, learning_rate=0.001):"""执行 Amnesiac Unlearning:训练模型,记录参数更新,并撤销与敏感数据相关的更新。:param model_before: 遗忘前的模型。:param test_loader: 测试数据加载器。:param forgotten_classes: 需要遗忘的类别列表。:param all_data: 训练数据集。:param epochs: 训练的总 epoch 数(默认为 10)。:param learning_rate: 优化器的学习率(默认为 0.001)。:return: 遗忘后的模型。"""# 模拟从头训练的过程,并记录批次更新的过程print("模拟重新训练过程,记录批次更新...")temp_model = MLP().to(device) # 初始化一个新模型amnesiac_forget = AmnesiacForget(temp_model, all_data, epochs, learning_rate) # 初始化 AmnesiacForget 类sensitive_batches = amnesiac_forget.train(forgotten_classes) # 训练模型并记录敏感批次# 测试遗忘前的模型性能overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(amnesiac_forget.model, test_loader)print(f"全部准确率: {overall_acc_before:.2f}%, 保留准确率: {retained_acc_before:.2f}%, 遗忘准确率: {forgotten_acc_before:.2f}%")# 应用遗忘:撤销与敏感数据相关的批次更新model_after = amnesiac_forget.unlearn(sensitive_batches)return model_afterdef main():# 超参数设置batch_size = 256forgotten_classes = [0] # 需要遗忘的类别ratio = 1model_name = "ResNet18" # 模型名称# 加载数据if model_name == "MLP":train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)elif model_name == "ResNet18":train_loader, test_loader, retain_loader, forget_loader = load_CIFAR100_data(batch_size, forgotten_classes, ratio)# 初始化模型model_before = init_model(model_name, train_loader)# 在训练之前测试初始模型准确率overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader)# 实现遗忘操作print("执行遗忘 Amnesiac...")model_after = amnesiac_unlearning(model_before, test_loader, forgotten_classes, train_loader, epochs=5, learning_rate=0.001)# 测试遗忘后的模型overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader)# 输出遗忘前后的准确率变化print(f"Unlearning 前遗忘准确率: {100 * forgotten_acc_before:.2f}%")print(f"Unlearning 后遗忘准确率: {100 * forgotten_acc_after:.2f}%")print(f"Unlearning 前保留准确率: {100 * retained_acc_before:.2f}%")print(f"Unlearning 后保留准确率: {100 * retained_acc_after:.2f}%")if __name__ == "__main__":main()
3 总结
- 高效性:只需撤销与敏感数据相关的参数更新,避免了从头训练模型的高成本。
- 精确性:能够精确删除特定数据的学习信息,特别适合删除少量数据。
- 存储成本:需要存储每个批次的参数更新,存储成本较高,但通常低于从头训练模型的成本。
- 适用场景:适合删除少量数据(如单个样本或少量样本),而不适合删除大量数据(如整个类别)。
相关文章:
【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning
【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning 1 算法原理 论文:Graves, L., Nagisetty, V., & Ganesh, V. (2021). Amnesiac machine learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 115…...
Vue 3 30天精进之旅:Day 03 - Vue实例
引言 在前两天的学习中,我们成功搭建了Vue.js的开发环境,并创建了我们的第一个Vue项目。今天,我们将深入了解Vue的核心概念之一——Vue实例。通过学习Vue实例,你将理解Vue的基础架构,掌握数据绑定、模板语法和指令的使…...
【ArcGIS微课1000例】0141:提取多波段影像中的单个波段
文章目录 一、波段提取函数二、加载单波段导出问题描述:如下图所示,img格式的时序NDVI数据有24个波段。现在需要提取某一个波段,该怎样操作? 一、波段提取函数 首先加载多波段数据。点击【窗口】→【影像分析】。 选择需要处理的多波段影像,点击下方的【添加函数】。 在多…...
【第九天】零基础入门刷题Python-算法篇-数据结构与算法的介绍-六种常见的图论算法(持续更新)
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Python数据结构与算法的详细介绍1.Python中的常用的图论算法2. 图论算法3.详细的图论算法1)深度优先搜索(DFS)2…...
落地 轮廓匹配
个人理解为将一幅不规则的图形,通过最轮廓发现,最大轮廓匹配来确定图像的位置,再通过pt将不规则的图像放在规定的矩形里面,在通过透视变换将不规则的图形放进规则的图像中。 1. findHomography 函数 • Mat h findHomography(s…...
【漫话机器学习系列】064.梯度下降小口诀(Gradient Descent rule of thume)
梯度下降小口诀 为了帮助记忆梯度下降的核心原理和关键注意事项,可以用以下简单口诀来总结: 1. 基本原理 损失递减,梯度为引:目标是让损失函数减少,依靠梯度指引方向。负梯度,反向最短:沿着负…...
JAVA(SpringBoot)集成Kafka实现消息发送和接收。
SpringBoot集成Kafka实现消息发送和接收。 一、Kafka 简介二、Kafka 功能三、POM依赖四、配置文件五、生产者六、消费者 君子之学贵一,一则明,明则有功。 一、Kafka 简介 Kafka 是由 Apache 软件基金会开发的一个开源流处理平台,最初由 Link…...
AI刷题-蛋糕工厂产能规划、优质章节的连续选择
挑两个简单的写写 目录 一、蛋糕工厂产能规划 问题描述 输入格式 输出格式 解题思路: 问题理解 数据结构选择 算法步骤 关键点 最终代码: 运行结果:编辑 二、优质章节的连续选择 问题描述 输入格式 输出格式 解题思路&a…...
在线可编辑Excel
1. Handsontable 特点: 提供了类似 Excel 的表格编辑体验,包括单元格样式、公式计算、数据验证等功能。 支持多种插件,如筛选、排序、合并单元格等。 轻量级且易于集成到现有项目中。 具备强大的自定义能力,可以调整外观和行为…...
什么是词嵌入?Word2Vec、GloVe 与 FastText 的区别
自然语言处理(NLP)领域的核心问题之一,是如何将人类的语言转换成计算机可以理解的数值形式,而词嵌入(Word Embedding)正是为了解决这个问题的重要技术。本文将详细讲解词嵌入的概念及其经典模型(Word2Vec、GloVe 和 FastText)的原理与区别。 1. 什么是词嵌入(Word Em…...
WPS数据分析000010
基于数据透视表的内容 一、排序 手动调动 二、筛选 三、值显示方式 四、值汇总依据 五、布局和选项 不显示分类汇总 合并居中带标签的单元格 空单元格显示 六、显示报表筛选页...
Qt中QVariant的使用
1.使用QVariant实现不同类型数据的相加 方法:通过type函数返回数值的类型,然后通过setValue来构造一个QVariant类型的返回值。 函数: QVariant mainPage::dataPlus(QVariant a, QVariant b) {QVariant ret;if ((a.type() QVariant::Int) &a…...
Avalonia UI MVVM DataTemplate里绑定Command
Avalonia 模板里面绑定ViewModel跟WPF写法有些不同。需要单独绑定Command. WPF里面可以直接按照下面的方法绑定DataContext. <Button Content"Button" Command"{Binding DataContext.ClickCommand, RelativeSource{RelativeSource AncestorType{x:Type User…...
动态规划DP 数字三角型模型 最低通行费用(题目详解+C++代码完整实现)
最低通行费用 原题链接 AcWing 1018. 最低同行费用 题目描述 一个商人穿过一个 NN的正方形的网格,去参加一个非常重要的商务活动。 他要从网格的左上角进,右下角出。每穿越中间 1个小方格,都要花费 1个单位时间。商人必须在 (2N−1)个单位…...
deepseek R1的确不错,特别是深度思考模式
deepseek R1的确不错,特别是深度思考模式,每次都能自我反省改进。比如我让 它写文案: 【赛博朋克版程序员新春密码——2025我们来破局】 亲爱的代码骑士们: 当CtrlS的肌肉记忆遇上抢票插件,当Spring Boot的…...
Linux 常用命令 - sort 【对文件内容进行排序】
简介 sort 命令源于英文单词 “sort”,表示排序。其主要功能是对文本文件中的行进行排序。它可以根据字母、数字、特定字段等不同的标准进行排序。sort 通过逐行读取文件(没有指定文件或指定文件为 - 时读取标准输入)内容,并按照…...
MyBatis最佳实践:提升数据库交互效率的秘密武器
第一章:框架的概述: MyBatis 框架的概述: MyBatis 是一个优秀的基于 Java 的持久框架,内部对 JDBC 做了封装,使开发者只需要关注 SQL 语句,而不关注 JDBC 的代码,使开发变得更加的简单MyBatis 通…...
选择困难?直接生成pynput快捷键字符串
from pynput import keyboard# 文档:https://pynput.readthedocs.io/en/latest/keyboard.html#monitoring-the-keyboard # 博客(pynput相关源码):https://blog.csdn.net/qq_39124701/article/details/145230331 # 虚拟键码(十六进制):https:/…...
DeepSeek-R1:强化学习驱动的推理模型
1月20日晚,DeepSeek正式发布了全新的推理模型DeepSeek-R1,引起了人工智能领域的广泛关注。该模型在数学、代码生成等高复杂度任务上表现出色,性能对标OpenAI的o1正式版。同时,DeepSeek宣布将DeepSeek-R1以及相关技术报告全面开源。…...
国内优秀的FPGA设计公司主要分布在哪些城市?
近年来,国内FPGA行业发展迅速,随着5G通信、人工智能、大数据等新兴技术的崛起,FPGA设计企业的需求也迎来了爆发式增长。很多技术人才在求职时都会考虑城市的行业分布和发展潜力。因此,国内优秀的FPGA设计公司主要分布在哪些城市&a…...
颠覆性创新:为什么Upkie开源轮式双足机器人正在重新定义机器人开发范式
颠覆性创新:为什么Upkie开源轮式双足机器人正在重新定义机器人开发范式 【免费下载链接】upkie Open-source wheeled biped robots 项目地址: https://gitcode.com/gh_mirrors/up/upkie 在传统机器人设计面临轮式与足式两难选择的今天,一个革命性…...
从日志到环境变量:根治 Android Studio AVD 启动报错“The emulator process has terminated”
1. 从错误弹窗到日志分析:定位问题的第一步 当你兴冲冲地打开Android Studio准备启动AVD(Android Virtual Device)时,突然弹出一个冰冷的提示框:"The emulator process has terminated",这感觉就…...
用PyTorch和ECANet18搞定RAF-DB表情分类:从数据集下载到模型部署的保姆级教程
基于ECANet18的RAF-DB表情识别实战:从零构建高精度分类模型 人脸表情识别(FER)作为计算机视觉领域的重要分支,在情感计算、智能交互等领域展现出巨大潜力。本文将带您完整实现一个基于PyTorch和ECANet18的端到端表情识别系统&…...
如何用PCL2启动器打造完美的Minecraft模组体验:从零到精通的完整指南
如何用PCL2启动器打造完美的Minecraft模组体验:从零到精通的完整指南 【免费下载链接】PCL Minecraft 启动器 Plain Craft Launcher(PCL)。 项目地址: https://gitcode.com/gh_mirrors/pc/PCL 你是否厌倦了每次启动Minecraft都要手动配…...
LangGraph 并发执行不是开 Goroutine 那么简单:状态竞争与事务处理
LangGraph 并发执行不是开 Goroutine 那么简单:状态竞争与事务处理深度解析 元数据 关键词:LangGraph, 大语言模型工作流, 有状态并发, 状态一致性, 事务处理, 多Agent系统, 分布式状态管理 摘要:很多开发者初次接触LangGraph的并发特性时,会下意识将其等同于传统协程/线程…...
迪拜塔幕墙设计
迪拜塔幕墙设计 【作 者】:罗永增 【关键词】:迪拜塔,幕墙,设计,系统。 前言:...
基于MCP协议构建AI金融数据可视化服务器:从原理到实战部署
1. 项目概述:一个为AI智能体提供实时金融数据可视化的MCP服务器最近在折腾AI智能体(Agent)的生态,发现一个挺有意思的痛点:当你想让AI帮你分析股票、基金或者加密货币时,它往往只能给你干巴巴的数字和文字描…...
药物发现自动化:FEP计算工作流引擎faah的设计原理与实战
1. 项目概述:一个面向药物发现的自动化工作流引擎 最近在药物研发的自动化工具领域,一个名为 kiron0/faah 的项目引起了我的注意。这并非一个简单的脚本集合,而是一个设计精巧、旨在为药物发现中的自由能微扰计算提供端到端自动化解决方案的…...
Claw框架数据库迁移工具claw-migrate:原理、实践与团队协作指南
1. 项目概述:一个专为Claw设计的迁移工具最近在折腾一个叫Claw的开源项目,它本身是一个轻量级的Web框架,用起来挺顺手。但项目迭代过程中,难免会遇到数据库结构变更、数据迁移这类“脏活累活”。手动写SQL脚本?太原始&…...
【仿真学习框架】HoloMotion 从入门到精通:全身人形控制 Foundation Model 完全指南
HoloMotion 从入门到精通:全身人形控制 Foundation Model 完全指南 目标读者:具身智能研究者、人形机器人开发者、RL/机器人学习工程师 目录 第1章 HoloMotion 全景概览 1.1 什么是 HoloMotion 1.2 技术定位:"小脑"基座模型 1.3 4-Any 愿景与路线图 1.4 核心能力矩…...
