【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR
【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR

1 算法原理
Tarun A K, Chundawat V S, Mandal M, et al. Fast yet effective machine unlearning[J]. IEEE Transactions on Neural Networks and Learning Systems, 2023.
本文提出了一种名为 UNSIR(Unlearning with Single Pass Impair and Repair) 的机器遗忘框架,用于从深度神经网络中高效地卸载(遗忘)特定类别数据,同时保留模型对其他数据的性能。以下是算法的主要步骤:
1. 零隐私设置(Zero-Glance Privacy Setting)
- 假设:用户请求从已训练的模型中删除其数据(例如人脸图像),并且模型无法再访问这些数据,即使是为了权重调整。
- 目标:在不重新训练模型的情况下,使模型忘记特定类别的数据,同时保留对其他数据的性能。
2. 学习误差最大化噪声矩阵(Error-Maximizing Noise Matrix)
-
初始化:随机初始化噪声矩阵 N,其大小与模型输入相同。
-
优化目标:通过最大化模型对目标类别的损失函数来优化噪声矩阵 N。具体优化问题为:
a r g N m i n E ( θ ) = − L ( f , y ) + λ ∥ w n o i s e ∥ argNminE(θ)=−L(f,y)+λ∥wnoise∥ argNminE(θ)=−L(f,y)+λ∥wnoise∥其中:
- L(f,y) 是针对要卸载的类别的分类损失函数。
- λ∥wnoise∥ 是正则化项,防止噪声值过大。
- 使用交叉熵损失函数 L 和 L2 归一化。
-
噪声矩阵的作用:生成的噪声矩阵 N 与要卸载的类别标签相关联,用于在后续步骤中破坏模型对这些类别的记忆。
3. 单次损伤与修复(Single Pass Impair and Repair)
- 损伤步骤(Impair Step):
- 操作:将噪声矩阵 N 与保留数据子集Dr结合,训练模型一个周期(epoch)。
- 目的:通过高学习率(例如 0.02)快速破坏模型对要卸载类别的权重。
- 结果:模型对要卸载类别的性能显著下降,同时对保留类别的性能也会受到一定影响。
- 修复步骤(Repair Step):
- 操作:仅使用保留数据子集 Dr再次训练模型一个周期(epoch),学习率较低(例如 0.01)。
- 目的:恢复模型对保留类别的性能,同时保持对要卸载类别的遗忘效果。
- 结果:最终模型在保留数据上保持较高的准确率,而在卸载数据上准确率接近于零。
2 Python代码实现
相关函数
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset,TensorDataset
from torch.amp import autocast, GradScaler
import numpy as np
import matplotlib.pyplot as plt
import os
import warnings
import random
from copy import deepcopy
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = Falsewarnings.filterwarnings("ignore")
MODEL_NAMES = "MLP"
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义三层全连接网络
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(28 * 28, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 加载MNIST数据集
def load_MNIST_data(batch_size,forgotten_classes,ratio):transform = transforms.Compose([transforms.ToTensor()])train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)forgotten_train_data,_ = generate_subset_by_ratio(train_data, forgotten_classes,ratio)retain_train_data,_ = generate_subset_by_ratio(train_data, [i for i in range(10) if i not in forgotten_classes])forgotten_train_loader= DataLoader(forgotten_train_data, batch_size=batch_size, shuffle=True)retain_train_loader= DataLoader(retain_train_data, batch_size=batch_size, shuffle=True)return train_loader, test_loader, retain_train_loader, forgotten_train_loader# worker_init_fn 用于初始化每个 worker 的随机种子
def worker_init_fn(worker_id):random.seed(2024 + worker_id)np.random.seed(2024 + worker_id)
def get_transforms():train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 标准化为[-1, 1]])test_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 标准化为[-1, 1]])return train_transform, test_transform
# 模型训练函数
def train_model(model, train_loader, criterion, optimizer, scheduler=None,use_fp16 = False):use_fp16 = True# 使用新的初始化方式:torch.amp.GradScaler("cuda")scaler = GradScaler("cuda") # 用于混合精度训练model.train()running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向传播with autocast(enabled=use_fp16, device_type="cuda"): # 更新为使用 "cuda"outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()if use_fp16:scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()else:loss.backward()optimizer.step()running_loss += loss.item()if scheduler is not None:# 更新学习率scheduler.step()print(f"Loss: {running_loss/len(train_loader):.4f}")
# 模型评估(计算保留和遗忘类别的准确率)
def test_model(model, test_loader, forgotten_classes=[0]):"""测试模型的性能,计算总准确率、遗忘类别准确率和保留类别准确率。:param model: 要测试的模型:param test_loader: 测试数据加载器:param forgotten_classes: 需要遗忘的类别列表:return: overall_accuracy, forgotten_accuracy, retained_accuracy"""model.eval()correct = 0total = 0forgotten_correct = 0forgotten_total = 0retained_correct = 0retained_total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)# 计算总的准确率total += labels.size(0)correct += (predicted == labels).sum().item()# 计算遗忘类别的准确率mask_forgotten = torch.isin(labels, torch.tensor(forgotten_classes, device=device))forgotten_total += mask_forgotten.sum().item()forgotten_correct += (predicted[mask_forgotten] == labels[mask_forgotten]).sum().item()# 计算保留类别的准确率(除遗忘类别的其他类别)mask_retained = ~mask_forgottenretained_total += mask_retained.sum().item()retained_correct += (predicted[mask_retained] == labels[mask_retained]).sum().item()overall_accuracy = correct / totalforgotten_accuracy = forgotten_correct / forgotten_total if forgotten_total > 0 else 0retained_accuracy = retained_correct / retained_total if retained_total > 0 else 0# return overall_accuracy, forgotten_accuracy, retained_accuracyreturn round(overall_accuracy, 4), round(forgotten_accuracy, 4), round(retained_accuracy, 4)
主函数
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from models.Base import load_MNIST_data, test_model, load_CIFAR100_data, init_modelclass UNSIRForget:def __init__(self, model):self.model = modelself.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 学习误差最大化噪声矩阵def learn_error_maximizing_noise(self, train_loader, forgotten_classes, lambda_reg=0.01, learning_rate=0.01, num_epochs=5):self.model.eval()# 初始化噪声矩阵 N,大小与输入图像相同(例如28x28图像)noise_matrix = torch.randn(1, 1, 28, 28, device=self.device, requires_grad=True) # 假设输入是28x28的图像# 优化器用于优化噪声矩阵optimizer = torch.optim.SGD([noise_matrix], lr=learning_rate)noise_data = []noise_labels = []# 生成噪声数据集for epoch in range(num_epochs):total_loss = 0.0for images, labels in train_loader:images, labels = images.to(self.device), labels.to(self.device)# 只对属于遗忘类别的数据进行优化mask_forgotten = torch.isin(labels, torch.tensor(forgotten_classes, device=self.device))noisy_images = images.clone()# 对遗忘类别的图像添加噪声noisy_images[mask_forgotten] += noise_matrix# 保存噪声数据noise_data.append(noisy_images)noise_labels.append(labels)# 前向传播outputs = self.model(noisy_images.view(-1, 28 * 28)) # 假设模型的输入是28x28的图像loss = F.cross_entropy(outputs, labels)# L2 正则化项(噪声矩阵的L2范数)l2_reg = lambda_reg * torch.norm(noise_matrix)# 总损失(包含交叉熵损失和L2正则化)total_loss = loss + l2_reg# 反向传播并更新噪声矩阵optimizer.zero_grad()total_loss.backward()optimizer.step()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss.item():.4f}")# 返回包含噪声数据和标签的噪声数据集return torch.cat(noise_data), torch.cat(noise_labels), noise_matrix.detach()# 实现机器遗忘(针对特定类别,使用噪声矩阵进行干扰)def unlearn(self, train_loader, forgotten_classes, noise_data, noise_labels, noise_matrix, alpha_impair, alpha_repair, num_epochs=1):# 损伤步骤self.model.train()print("执行损伤中...")for epoch in range(num_epochs):for images, labels in train_loader:images, labels = images.to(self.device), labels.to(self.device)# 仅选择保留类别的数据mask_retained = ~torch.isin(labels, torch.tensor(forgotten_classes, device=self.device))retained_images = images[mask_retained]retained_labels = labels[mask_retained]# 生成新的数据集,将噪声数据添加到保留数据中augmented_images = torch.cat([retained_images, noise_data], dim=0)augmented_labels = torch.cat([retained_labels, noise_labels], dim=0)# 前向传播outputs = self.model(augmented_images.view(-1, 28 * 28)) # 假设模型的输入是28x28的图像loss = F.cross_entropy(outputs, augmented_labels)# 更新模型权重self.model.zero_grad()loss.backward()with torch.no_grad():for param in self.model.parameters():param.data -= alpha_impair * param.grad.data# 修复步骤print("执行修复中...")for epoch in range(num_epochs):for images, labels in train_loader:images, labels = images.to(self.device), labels.to(self.device)# 仅使用保留类别的数据进行修复mask_retained = ~torch.isin(labels, torch.tensor(forgotten_classes, device=self.device))retained_images = images[mask_retained]retained_labels = labels[mask_retained]if retained_images.size(0) == 0:continue# 前向传播和损失计算outputs = self.model(retained_images.view(-1, 28 * 28))loss = F.cross_entropy(outputs, retained_labels)# 更新模型权重self.model.zero_grad()loss.backward()with torch.no_grad():for param in self.model.parameters():param.data -= alpha_repair * param.grad.datareturn self.model# UNSIR算法的主要流程
def unsir_unlearning(model_before, retrain_data, forget_data, all_data, forgotten_classes, lambda_reg=0.01, learning_rate=0.01, alpha_impair=0.5, alpha_repair=0.001, num_epochs=5):"""执行 UNSIR 算法的主要流程,包括学习误差最大化噪声矩阵、损伤、修复步骤,最终返回遗忘后的模型。"""unsir_forgetter = UNSIRForget(model_before)# 计算学习误差最大化噪声矩阵noise_data, noise_labels, noise_matrix = unsir_forgetter.learn_error_maximizing_noise(all_data, forgotten_classes, lambda_reg, learning_rate, num_epochs)# 执行 unlearn(损伤与修复步骤)unlearned_model = unsir_forgetter.unlearn(all_data, forgotten_classes, noise_data, noise_labels, noise_matrix, alpha_impair, alpha_repair, num_epochs)return unlearned_modeldef main():# 超参数设置batch_size = 256forgotten_classes = [0]ratio = 1model_name = "MLP"# 加载数据train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)model_before = init_model(model_name, train_loader)# 在训练之前测试初始模型准确率overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader)print("执行 UNSIR 遗忘...")model_after = unsir_unlearning(model_before,retain_loader,forget_loader,train_loader,forgotten_classes,lambda_reg=0.01,learning_rate=0.01,alpha_impair=0.5,alpha_repair=0.001,num_epochs=5,)# 测试遗忘后的模型overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader)# 输出遗忘前后的准确率变化print(f"Unlearning前遗忘准确率: {100 * forgotten_acc_before:.2f}%")print(f"Unlearning后遗忘准确率: {100 * forgotten_acc_after:.2f}%")print(f"Unlearning前保留准确率: {100 * retained_acc_before:.2f}%")print(f"Unlearning后保留准确率: {100 * retained_acc_after:.2f}%")if __name__ == "__main__":main()
3 总结
当前方法不支持随机样本或类别子集的卸载,这可能违反零隐私假设。
仍属于重新优化的算法,即还需要训练。
相关文章:
【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR
【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR 1 算法原理 Tarun A K, Chundawat V S, Mandal M, et al. Fast yet effective machine unlearning[J]. IEEE Transactions on Neural Networks and Learning Systems, 2023. 本文提出了一种名为 UNSIR(Un…...
基于SpringBoot的阳光幼儿园管理系统
作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…...
【逻辑学导论第15版】A. 推理
识别下列语段中的前提与结论。有些前提确实支持结论,有些并不支持。请注意,前提可能直接或间接地支持结论,而简单的语段也可能包含不止一个论证。 例题: 1.管理得当的民兵组织对于一个自由国家的安全是必需的,因而人民…...
【开源免费】基于SpringBoot+Vue.JS景区民宿预约系统(JAVA毕业设计)
本文项目编号 T 162 ,文末自助获取源码 \color{red}{T162,文末自助获取源码} T162,文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…...
c++ map/multimap容器 学习笔记
1 map的基本概念 简介: map中所有的元素都是pair pair中第一个元素是key(键),第二个元素是value(值) 所有元素都会根据元素的键值自动排序。本质: map/multimap 属于关联式容器,底…...
安卓逆向之脱壳-认识一下动态加载 双亲委派(一)
安卓逆向和脱壳是安全研究、漏洞挖掘、恶意软件分析等领域的重要环节。脱壳(unpacking)指的是去除应用程序中加固或保护措施的过程,使得可以访问应用程序的原始代码或者数据。脱壳的重要性: 分析恶意软件:很多恶意软件…...
64位的谷歌浏览器Chrome/Google Chrome
64位的谷歌浏览器Chrome/Google Chrome 在百度搜索关键字:chrome,即可下载官方的“谷歌浏览器Chrome/Google Chrome”,但它可能是32位的(切记注意网址:https://www.google.cn/...., 即:google.cnÿ…...
马尔科夫模型和隐马尔科夫模型区别
我用一个天气预报和海藻湿度观测的比喻来解释,保证你秒懂! 1. 马尔可夫模型(Markov Model, MM) 特点:状态直接可见 场景:天气预报(晴天→雨天→阴天…)核心假设: 下一个…...
Python NumPy(7):连接数组、分割数组、数组元素的添加与删除
1 连接数组 函数描述concatenate连接沿现有轴的数组序列stack沿着新的轴加入一系列数组。hstack水平堆叠序列中的数组(列方向)vstack竖直堆叠序列中的数组(行方向) 1.1 numpy.concatenate numpy.concatenate 函数用于沿指定轴连…...
【LLM】deepseek多模态之Janus-Pro和JanusFlow框架
note 文章目录 note一、Janus-Pro:解耦视觉编码,实现多模态高效统一技术亮点模型细节 二、JanusFlow:融合生成流与语言模型,重新定义多模态技术亮点模型细节 Reference 一、Janus-Pro:解耦视觉编码,实现多模…...
2000-2021年 全国各地级市专利申请与获得情况、绿色专利申请与获得情况数据
2000-2021年 全国各地级市专利申请与获得情况、绿色专利申请与获得情况数据.ziphttps://download.csdn.net/download/2401_84585615/89575931 https://download.csdn.net/download/2401_84585615/89575931 2000至2021年,全国各地级市的专利申请与获得情况呈现出显著…...
51单片机(STC89C52)开发:点亮一个小灯
软件安装: 安装开发板CH340驱动。 安装KEILC51开发软件:C51V901.exe。 下载软件:PZ-ISP.exe 创建项目: 新建main.c 将main.c加入至项目中: main.c:点亮一个小灯 #include "reg52.h"sbit LED1P2^0; //P2的…...
数仓ETL测试
提取,转换和加载有助于组织使数据在不同的数据系统中可访问,有意义且可用。ETL工具是用于提取,转换和加载数据的软件。在当今数据驱动的世界中,无论大小如何,都会从各种组织,机器和小工具中生成大量数据。 …...
240. 搜索二维矩阵||
参考题解:https://leetcode.cn/problems/search-a-2d-matrix-ii/solutions/2361487/240-sou-suo-er-wei-ju-zhen-iitan-xin-qin-7mtf 将矩阵旋转45度,可以看作一个二叉搜索树。 假设以左下角元素为根结点, 当target比root大的时候ÿ…...
反向代理模块b
1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求,然后将请求转发给内部网络上的服务器,将从服务器上得到的结果返回给客户端,此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说,反向代理就相当于…...
Kafka的内部通信协议
引言 kafka内部用到的常见协议和优缺点可以看看原文 Kafka用到的协议 本文奖详细探究kafka核心通信协议和高性能的关键 网络层通信的实现 基于 Java NIO:Kafka 的网络通信层主要基于 Java NIO 来实现,这使得它能够高效地处理大量的连接和数据传输。…...
Excel - Binary和Text两种Compare方法
Option Compare statement VBA里可以定义默认使用的compare方法: Set the string comparison method to Binary. Option Compare Binary That is, "AAA" is less than "aaa". Set the string comparison method to Text. Option Compare Tex…...
【Linux权限】—— 于虚拟殿堂,轻拨密钥启华章
欢迎来到ZyyOvO的博客✨,一个关于探索技术的角落,记录学习的点滴📖,分享实用的技巧🛠️,偶尔还有一些奇思妙想💡 本文由ZyyOvO原创✍️,感谢支持❤️!请尊重原创…...
EasyExcel使用详解
文章目录 EasyExcel使用详解一、引言二、环境准备与基础配置1、添加依赖2、定义实体类 三、Excel 读取详解1、基础读取2、自定义监听器3、多 Sheet 处理 四、Excel 写入详解1、基础写入2、动态列与复杂表头3、样式与模板填充 五、总结 EasyExcel使用详解 一、引言 EasyExcel 是…...
LeetCode 2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚不乱),附Python一行版
【LetMeFly】2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚不乱),附Python一行版 文章目录 【LetMeFly】2412.完成所有交易的初始最少钱数:【年度巨献】举例说明(讲明白),由难至简(手脚…...
前端-Rollup
Rollup 是一个用于 JavaScript 的模块打包工具,它将小的代码片段编译成更大、更复杂的代码,例如库或应用程序。它使用 JavaScript 的 ES6 版本中包含的新标准化代码模块格式,而不是以前的 CommonJS 和 AMD 等特殊解决方案。ES 模块允许你自由…...
ubuntu黑屏问题解决
重启Ubuntu后,系统自动进入tty1,无法进入桌面。想到前几天安装了一些主题之类的,然后今天才重启,可能是这些主题造成冲突或者问题了把。 这里直接重新安装ubuntu-desktop解决: 更新源: sudo apt-get upd…...
MV结构下设置Qt表格的代理
目录 预备知识 模型 关联 刷新 示例 代理 模型 界面 结果 完整资料见: 所谓MV结构,是“model-view”(模型-视图)的简称。也就是说,表格的数据保存在model中,而视图由view实现。在我前面的很多博客…...
vue3相关知识点
title: vue_1 date: 2025-01-28 12:00:00 tags:- 前端 categories:- 前端vue3 Webpack ~ vite vue3是基于vite创建的 vite 更快一点 一些准备工作 准备后如图所示 插件 Main.ts // 引入createApp用于创建应用 import {createApp} from vue // 引入App根组件 import App f…...
Lustre v6 语法 - 时序表达式
概述 Lustre v6 语法中,与时序表达式有关的运算,包括 ->(followed by), pre(previous), fby, current, when, merge。其中,除 merge 运算是 Lustre v6 中新引入的外,其余在 Lustre Core 语法中已有定义。 与时序表达式有关的…...
vs2013 使用 eigen 库编译时报 C2059 错的解决方法
(个人感觉)vs2013 就不能使用版本大于等于 3.4 的 eigen,使用 3.3.9 就可以了,再不行就用 3.3.8 另一个博主也遇到过用 vs2013 的时候不能编译 3.4 的 eigen 的问题,不过我用的是 win11,所以感觉跟操作系统…...
Kafka 消费端反复 Rebalance: `Attempt to heartbeat failed since group is rebalancing`
文章目录 Kafka 消费端反复 Rebalance: Attempt to heartbeat failed since group is rebalancing1. Rebalance 过程概述2. 错误原因分析2.1 消费者组频繁加入或退出2.1.1 消费者故障导致频繁重启2.1.2. 消费者加入和退出导致的 Rebalance2.1.3 消费者心跳超时导致的 Rebalance…...
【第九天】零基础入门刷题Python-算法篇-数据结构与算法的介绍-六种常见的图论算法(持续更新)
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Python数据结构与算法的详细介绍1.Python中的常用的图论算法2. 图论算法3.详细的图论算法1)深度优先搜索(DFS)2…...
微服务网关鉴权之sa-token
目录 前言 项目描述 使用技术 项目结构 要点 实现 前期准备 依赖准备 统一依赖版本 模块依赖 配置文件准备 登录准备 网关配置token解析拦截器 网关集成sa-token 配置sa-token接口鉴权 配置satoken权限、角色获取 通用模块配置用户拦截器 api模块配置feign…...
shell脚本批量修改文件名之方法(The Method of Batch Modifying File Names in Shell Scripts)
shell脚本批量修改文件名方法 我们可以使用Shell脚本来实现这个功能。Shell脚本是一种用于自动化任务的编程语言,它可以在Unix/Linux操作系统上运行。在这个脚本中,我们将使用一个for循环来遍历目标目录下的所有文件,并使用mv命令将每个文件…...
