【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning
【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning

1 算法原理
论文:Graves, L., Nagisetty, V., & Ganesh, V. (2021). Amnesiac machine learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 11516–11524.
Amnesiac Unlearning(遗忘性遗忘) 是一种高效且精确的算法,旨在从已经训练好的神经网络模型中删除特定数据的学习信息,而不会显著影响模型在其他数据上的性能。该算法的核心思想是通过选择性撤销与敏感数据相关的参数更新来实现数据的“遗忘”。
1. 训练阶段:记录参数更新
在模型训练过程中,记录每个批次的参数更新以及哪些批次包含敏感数据。
- 步骤:
- 初始化模型参数:从随机初始化的参数 θ i n i t i a l \theta_{initial} θinitial 开始训练模型。
- 训练模型:使用标准训练方法(如随机梯度下降)对模型进行训练,训练过程分为多个 epoch,每个 epoch 包含多个批次(batches)。
- 记录参数更新:
- 对于每个批次 b b b,记录该批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b,其中 e e e 表示 epoch 编号, b b b 表示批次编号。
- 同时,记录哪些批次包含敏感数据(即需要删除的数据)。可以将这些批次标记为 S B SB SB(Sensitive Batches)。
- 存储信息:
- 存储所有批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b。
- 存储敏感数据批次的索引 S B SB SB。
2. 数据删除阶段:选择性撤销参数更新
当收到数据删除请求时,撤销与敏感数据相关的参数更新。
-
步骤:
- 识别敏感数据批次:从存储的记录中提取包含敏感数据的批次索引 S B SB SB。
- 撤销参数更新:
计算删除敏感数据后的模型参数 θ M \theta_{M} θM:
θ M ′ = θ M − ∑ s b ∈ S B Δ θ s b \theta_{M'} = \theta_{M} - \sum_{sb \in SB} \Delta_{\theta_{sb}} θM′=θM−sb∈SB∑Δθsb其中:
- θ M \theta_{M} θM 是原始训练后的模型参数。
- Δ θ s b \Delta_{\theta_{sb}} Δθsb 是敏感数据批次 s b sb sb 的参数更新。
- 生成保护模型:使用更新后的参数 θ M ′ \theta_{M'} θM′ 作为新的模型参数。
3. 微调阶段(可选)
如果删除的批次较多,可能会对模型性能产生一定影响。此时可以通过少量微调来恢复模型性能。
- 步骤:
- 微调模型:使用删除敏感数据后的数据集对模型进行少量迭代训练。
- 恢复性能:通过微调,模型可以恢复在非敏感数据上的性能。
2 代码实现
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from models.Base import load_MNIST_data, test_model, device, MLP, load_CIFAR100_data, init_model# AmnesiacForget类:封装撤销与敏感数据相关的参数更新
class AmnesiacForget:def __init__(self, model, all_data, epochs, learning_rate):"""初始化 AmnesiacForget 类。:param model: 需要训练的模型。:param all_data: 训练数据集。:param epochs: 训练的总 epoch 数。:param learning_rate: 优化器的学习率。"""self.model = modelself.all_data = all_dataself.epochs = epochsself.learning_rate = learning_rateself.batch_updates = [] # 存储每个批次的参数更新值self.initial_params = {name: param.clone() for name, param in model.named_parameters()} # 存储初始模型参数self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 设备选择(GPU 或 CPU)def train(self, forgotten_classes):"""训练模型并记录每个批次的参数更新值。:param forgotten_classes: 需要遗忘的类别列表。:return: sensitive_batches: 包含敏感数据的批次索引。"""optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) # 使用 Adam 优化器self.model.train() # 将模型设置为训练模式sensitive_batches = {} # 记录每个 epoch 中包含敏感数据的批次索引# 训练过程for epoch in range(self.epochs):running_loss = 0.0sensitive_batches[epoch] = set() # 每个 epoch 的敏感批次集for batch_idx, (images, labels) in enumerate(self.all_data):optimizer.zero_grad() # 清空梯度images, labels = images.to(self.device), labels.to(self.device) # 将数据移动到设备上# 前向传播和损失计算outputs = self.model(images)loss = nn.CrossEntropyLoss()(outputs, labels)# 反向传播计算梯度loss.backward()running_loss += loss.item()# 记录当前参数值current_params = {name: param.clone() for name, param in self.model.named_parameters()}# 更新参数optimizer.step()# 记录参数更新值(当前参数值 - 更新前的参数值)batch_update = {}for name, param in self.model.named_parameters():if param.requires_grad:batch_update[name] = param.data - current_params[name].data # 记录参数更新值self.batch_updates.append(batch_update)# 记录包含敏感数据的批次索引if any(label.item() in forgotten_classes for label in labels):sensitive_batches[epoch].add(batch_idx)print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {running_loss/len(self.all_data):.4f}")return sensitive_batchesdef unlearn(self, sensitive_batches):"""撤销与敏感数据相关的批次更新。:param sensitive_batches: 包含敏感数据的批次索引。:return: 更新后的模型。"""# 计算非敏感批次的参数更新总和non_sensitive_updates = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}for batch_idx, batch_update in enumerate(self.batch_updates):if batch_idx not in {sb for epoch_batches in sensitive_batches.values() for sb in epoch_batches}:for name, update in batch_update.items():non_sensitive_updates[name] += update# 更新模型参数:初始参数 + 非敏感批次的更新for name, param in self.model.named_parameters():param.data = self.initial_params[name].data + non_sensitive_updates[name]return self.model# 全局函数:实现 Amnesiac Forget
def amnesiac_unlearning(model_before, test_loader, forgotten_classes, all_data, epochs=10, learning_rate=0.001):"""执行 Amnesiac Unlearning:训练模型,记录参数更新,并撤销与敏感数据相关的更新。:param model_before: 遗忘前的模型。:param test_loader: 测试数据加载器。:param forgotten_classes: 需要遗忘的类别列表。:param all_data: 训练数据集。:param epochs: 训练的总 epoch 数(默认为 10)。:param learning_rate: 优化器的学习率(默认为 0.001)。:return: 遗忘后的模型。"""# 模拟从头训练的过程,并记录批次更新的过程print("模拟重新训练过程,记录批次更新...")temp_model = MLP().to(device) # 初始化一个新模型amnesiac_forget = AmnesiacForget(temp_model, all_data, epochs, learning_rate) # 初始化 AmnesiacForget 类sensitive_batches = amnesiac_forget.train(forgotten_classes) # 训练模型并记录敏感批次# 测试遗忘前的模型性能overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(amnesiac_forget.model, test_loader)print(f"全部准确率: {overall_acc_before:.2f}%, 保留准确率: {retained_acc_before:.2f}%, 遗忘准确率: {forgotten_acc_before:.2f}%")# 应用遗忘:撤销与敏感数据相关的批次更新model_after = amnesiac_forget.unlearn(sensitive_batches)return model_afterdef main():# 超参数设置batch_size = 256forgotten_classes = [0] # 需要遗忘的类别ratio = 1model_name = "ResNet18" # 模型名称# 加载数据if model_name == "MLP":train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)elif model_name == "ResNet18":train_loader, test_loader, retain_loader, forget_loader = load_CIFAR100_data(batch_size, forgotten_classes, ratio)# 初始化模型model_before = init_model(model_name, train_loader)# 在训练之前测试初始模型准确率overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader)# 实现遗忘操作print("执行遗忘 Amnesiac...")model_after = amnesiac_unlearning(model_before, test_loader, forgotten_classes, train_loader, epochs=5, learning_rate=0.001)# 测试遗忘后的模型overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader)# 输出遗忘前后的准确率变化print(f"Unlearning 前遗忘准确率: {100 * forgotten_acc_before:.2f}%")print(f"Unlearning 后遗忘准确率: {100 * forgotten_acc_after:.2f}%")print(f"Unlearning 前保留准确率: {100 * retained_acc_before:.2f}%")print(f"Unlearning 后保留准确率: {100 * retained_acc_after:.2f}%")if __name__ == "__main__":main()
3 总结
- 高效性:只需撤销与敏感数据相关的参数更新,避免了从头训练模型的高成本。
- 精确性:能够精确删除特定数据的学习信息,特别适合删除少量数据。
- 存储成本:需要存储每个批次的参数更新,存储成本较高,但通常低于从头训练模型的成本。
- 适用场景:适合删除少量数据(如单个样本或少量样本),而不适合删除大量数据(如整个类别)。
相关文章:
【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning
【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning 1 算法原理 论文:Graves, L., Nagisetty, V., & Ganesh, V. (2021). Amnesiac machine learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 115…...
Vue 3 30天精进之旅:Day 03 - Vue实例
引言 在前两天的学习中,我们成功搭建了Vue.js的开发环境,并创建了我们的第一个Vue项目。今天,我们将深入了解Vue的核心概念之一——Vue实例。通过学习Vue实例,你将理解Vue的基础架构,掌握数据绑定、模板语法和指令的使…...
【ArcGIS微课1000例】0141:提取多波段影像中的单个波段
文章目录 一、波段提取函数二、加载单波段导出问题描述:如下图所示,img格式的时序NDVI数据有24个波段。现在需要提取某一个波段,该怎样操作? 一、波段提取函数 首先加载多波段数据。点击【窗口】→【影像分析】。 选择需要处理的多波段影像,点击下方的【添加函数】。 在多…...
【第九天】零基础入门刷题Python-算法篇-数据结构与算法的介绍-六种常见的图论算法(持续更新)
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Python数据结构与算法的详细介绍1.Python中的常用的图论算法2. 图论算法3.详细的图论算法1)深度优先搜索(DFS)2…...
落地 轮廓匹配
个人理解为将一幅不规则的图形,通过最轮廓发现,最大轮廓匹配来确定图像的位置,再通过pt将不规则的图像放在规定的矩形里面,在通过透视变换将不规则的图形放进规则的图像中。 1. findHomography 函数 • Mat h findHomography(s…...
【漫话机器学习系列】064.梯度下降小口诀(Gradient Descent rule of thume)
梯度下降小口诀 为了帮助记忆梯度下降的核心原理和关键注意事项,可以用以下简单口诀来总结: 1. 基本原理 损失递减,梯度为引:目标是让损失函数减少,依靠梯度指引方向。负梯度,反向最短:沿着负…...
JAVA(SpringBoot)集成Kafka实现消息发送和接收。
SpringBoot集成Kafka实现消息发送和接收。 一、Kafka 简介二、Kafka 功能三、POM依赖四、配置文件五、生产者六、消费者 君子之学贵一,一则明,明则有功。 一、Kafka 简介 Kafka 是由 Apache 软件基金会开发的一个开源流处理平台,最初由 Link…...
AI刷题-蛋糕工厂产能规划、优质章节的连续选择
挑两个简单的写写 目录 一、蛋糕工厂产能规划 问题描述 输入格式 输出格式 解题思路: 问题理解 数据结构选择 算法步骤 关键点 最终代码: 运行结果:编辑 二、优质章节的连续选择 问题描述 输入格式 输出格式 解题思路&a…...
在线可编辑Excel
1. Handsontable 特点: 提供了类似 Excel 的表格编辑体验,包括单元格样式、公式计算、数据验证等功能。 支持多种插件,如筛选、排序、合并单元格等。 轻量级且易于集成到现有项目中。 具备强大的自定义能力,可以调整外观和行为…...
什么是词嵌入?Word2Vec、GloVe 与 FastText 的区别
自然语言处理(NLP)领域的核心问题之一,是如何将人类的语言转换成计算机可以理解的数值形式,而词嵌入(Word Embedding)正是为了解决这个问题的重要技术。本文将详细讲解词嵌入的概念及其经典模型(Word2Vec、GloVe 和 FastText)的原理与区别。 1. 什么是词嵌入(Word Em…...
WPS数据分析000010
基于数据透视表的内容 一、排序 手动调动 二、筛选 三、值显示方式 四、值汇总依据 五、布局和选项 不显示分类汇总 合并居中带标签的单元格 空单元格显示 六、显示报表筛选页...
Qt中QVariant的使用
1.使用QVariant实现不同类型数据的相加 方法:通过type函数返回数值的类型,然后通过setValue来构造一个QVariant类型的返回值。 函数: QVariant mainPage::dataPlus(QVariant a, QVariant b) {QVariant ret;if ((a.type() QVariant::Int) &a…...
Avalonia UI MVVM DataTemplate里绑定Command
Avalonia 模板里面绑定ViewModel跟WPF写法有些不同。需要单独绑定Command. WPF里面可以直接按照下面的方法绑定DataContext. <Button Content"Button" Command"{Binding DataContext.ClickCommand, RelativeSource{RelativeSource AncestorType{x:Type User…...
动态规划DP 数字三角型模型 最低通行费用(题目详解+C++代码完整实现)
最低通行费用 原题链接 AcWing 1018. 最低同行费用 题目描述 一个商人穿过一个 NN的正方形的网格,去参加一个非常重要的商务活动。 他要从网格的左上角进,右下角出。每穿越中间 1个小方格,都要花费 1个单位时间。商人必须在 (2N−1)个单位…...
deepseek R1的确不错,特别是深度思考模式
deepseek R1的确不错,特别是深度思考模式,每次都能自我反省改进。比如我让 它写文案: 【赛博朋克版程序员新春密码——2025我们来破局】 亲爱的代码骑士们: 当CtrlS的肌肉记忆遇上抢票插件,当Spring Boot的…...
Linux 常用命令 - sort 【对文件内容进行排序】
简介 sort 命令源于英文单词 “sort”,表示排序。其主要功能是对文本文件中的行进行排序。它可以根据字母、数字、特定字段等不同的标准进行排序。sort 通过逐行读取文件(没有指定文件或指定文件为 - 时读取标准输入)内容,并按照…...
MyBatis最佳实践:提升数据库交互效率的秘密武器
第一章:框架的概述: MyBatis 框架的概述: MyBatis 是一个优秀的基于 Java 的持久框架,内部对 JDBC 做了封装,使开发者只需要关注 SQL 语句,而不关注 JDBC 的代码,使开发变得更加的简单MyBatis 通…...
选择困难?直接生成pynput快捷键字符串
from pynput import keyboard# 文档:https://pynput.readthedocs.io/en/latest/keyboard.html#monitoring-the-keyboard # 博客(pynput相关源码):https://blog.csdn.net/qq_39124701/article/details/145230331 # 虚拟键码(十六进制):https:/…...
DeepSeek-R1:强化学习驱动的推理模型
1月20日晚,DeepSeek正式发布了全新的推理模型DeepSeek-R1,引起了人工智能领域的广泛关注。该模型在数学、代码生成等高复杂度任务上表现出色,性能对标OpenAI的o1正式版。同时,DeepSeek宣布将DeepSeek-R1以及相关技术报告全面开源。…...
国内优秀的FPGA设计公司主要分布在哪些城市?
近年来,国内FPGA行业发展迅速,随着5G通信、人工智能、大数据等新兴技术的崛起,FPGA设计企业的需求也迎来了爆发式增长。很多技术人才在求职时都会考虑城市的行业分布和发展潜力。因此,国内优秀的FPGA设计公司主要分布在哪些城市&a…...
python/java环境配置
环境变量放一起 python: 1.首先下载Python Python下载地址:Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个,然后自定义,全选 可以把前4个选上 3.环境配置 1)搜高级系统设置 2…...
Linux简单的操作
ls ls 查看当前目录 ll 查看详细内容 ls -a 查看所有的内容 ls --help 查看方法文档 pwd pwd 查看当前路径 cd cd 转路径 cd .. 转上一级路径 cd 名 转换路径 …...
spring:实例工厂方法获取bean
spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂ÿ…...
[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...
【Java_EE】Spring MVC
目录 Spring Web MVC 编辑注解 RestController RequestMapping RequestParam RequestParam RequestBody PathVariable RequestPart 参数传递 注意事项 编辑参数重命名 RequestParam 编辑编辑传递集合 RequestParam 传递JSON数据 编辑RequestBody …...
项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)
Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...
#Uniapp篇:chrome调试unapp适配
chrome调试设备----使用Android模拟机开发调试移动端页面 Chrome://inspect/#devices MuMu模拟器Edge浏览器:Android原生APP嵌入的H5页面元素定位 chrome://inspect/#devices uniapp单位适配 根路径下 postcss.config.js 需要装这些插件 “postcss”: “^8.5.…...
七、数据库的完整性
七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...
HTML前端开发:JavaScript 获取元素方法详解
作为前端开发者,高效获取 DOM 元素是必备技能。以下是 JS 中核心的获取元素方法,分为两大系列: 一、getElementBy... 系列 传统方法,直接通过 DOM 接口访问,返回动态集合(元素变化会实时更新)。…...
ubuntu22.04 安装docker 和docker-compose
首先你要确保没有docker环境或者使用命令删掉docker sudo apt-get remove docker docker-engine docker.io containerd runc安装docker 更新软件环境 sudo apt update sudo apt upgrade下载docker依赖和GPG 密钥 # 依赖 apt-get install ca-certificates curl gnupg lsb-rel…...
