BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例
加载数据集
# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
plt.imshow(train_loader.data[0].numpy())
plt.show()

在训练集中植入5000个中毒样本
# 在训练集中植入5000个中毒样本
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9 # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()

训练模型
data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)
# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x
# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))
def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue
if __name__=='__main__':main()
测试攻击成功率
# 攻击成功率 99.66% 对测试集中所有图像都注入后门for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()
可视化中毒样本,成功被预测为特定目标类别“9”,证明攻击成功。


完整代码
from packaging import packaging
from torchvision.models import resnet50
from utils import Flatten
from tqdm import tqdm
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
# plt.imshow(train_loader.data[0].numpy())
# plt.show()# 训练集样本数据
print(len(train_loader))# 在训练集中植入5000个中毒样本
''' '''
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9 # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue# 选择一个训练集中植入后门的数据,测试后门是否有效'''sample, label = next(iter(data_loader_train))print(sample.size()) # [64, 1, 28, 28]print(label[0])# 可视化plt.imshow(sample[0][0])plt.show()model.load_state_dict(torch.load('badnets.pth'))model.eval()sample = sample.to(device)output = model(sample)print(output[0])pred = output.argmax(dim=1)print(pred[0])'''# 攻击成功率 99.66%for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()if __name__=='__main__':main()
相关文章:
BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例
加载数据集 # 载入MNIST训练集和测试集 transform transforms.Compose([transforms.ToTensor(),]) train_loader datasets.MNIST(rootdata,transformtransform,trainTrue,downloadTrue) test_loader datasets.MNIST(rootdata,transformtransform,trainFalse) # 可视化样本 …...
freeRTOS内部机制——栈的作用
上图中*pa 和*pb分别为R0,R1,调用C函数时,第一个参数保存在R0中第二个参数保存在R1中。这是约定。 指令保存在哪里? 指令保存在flash上面 LR等于什么? LR是返回地址,函数执行完了过后LR等于下一条指令的地址 运行…...
python 桌面软件开发-matplotlib画图鼠标缩放拖动
继上一篇在 Java 中缩放拖动图片后,在python matplotlib中也来实现一个自由缩放拖动的例子: python matplotlib 中缩放,较为简单,只需要通过设置要显示的 x y坐标的显示范围即可。基于此,实现一个鼠标监听回调…...
【JavaScript基础】JavaScript头等函数的理解
彻底理解JavaScript头等函数 一、函数的理解 🔥 什么是函数? 一般来说,一个函数是可以通过外部代码 调用 的一个“子程序”(或在递归的情况下由内部函数调用)。像程序本身一样,一个函数由称为函数体的一…...
如何把项目上传到Gitee(详细教程)
找到项目根目录右键打开Git Bash Here 输入命令:git init 回车 输入命令:git status 输入命令:git add . 输入命令:git status git commit -m 项目描述 在Gitee官网注册好账号后,git 新建项目 填写补充git项目信息及…...
Ubuntu挂载windows下的共享文件夹
Ubuntu挂载windows下的共享文件夹 更新apt源 如果出现安装失败,需要更新apt源为阿里云 # 备份原始文件 sudo cp /etc/apt/sources.list.d/* /etc/apt/sources.list.d.bak/# 修改文件内容 sudo vim /etc/apt/sources.list# 替换内容为如下 deb https://mirrors.al…...
什么是WMS系统条码化管理
WMS系统是一种用于仓库管理的信息化系统,旨在提高仓库操作的效率和准确性。而在WMS系统中,条码化管理是一项关键的技术和方法,它通过将商品和物料打上条码,并利用扫描设备进行数据采集和处理,实现了仓库管理的全面自动…...
【云原生之kubernetes实战】在k8s环境下部署moredoc文库系统
【云原生之kubernetes实战】在k8s环境下部署moredoc文库系统 一、moredoc介绍1.1 moredoc简介1.2 moredoc技术栈二、本次实践介绍2.1 本次实践简介2.2 本次环境规划三、检查k8s环境3.1 检查工作节点状态3.2 检查系统pod状态四、创建mysql的secret资源4.1 创建部署目录4.2 创建密…...
[Database] MySQL 8.x Window / Partition Function (窗口/分区函数)
🧲相关文章 [1] MySQL 系统表解析以及各项指标查询 [2] MySQL 5.7 JSON 字段的使用的处理 [3] MySQL经典练习50题 简介 MySQL 8.0版本开始支持窗口函数 官方文档 在之前的版本中已存在的大部分聚合函数,在MySQL 8 中也可以作为窗口函数来使用 方法 / …...
openGauss Meetup(天津站)精彩回顾 | openGauss天津用户组正式成立
由openGauss社区、天开发展集团、天津市软件行业协会、天大智图(天津)科技有限公司联合主办的“openGauss Meetup • 天津站”已于10月13日落下帷幕,此次活动邀请到众多业内技术专家,从技术创新、学术创新、发展创新、以及生态共建…...
linux vim 删除多行
使用linux服务器,免不了和vi编辑打交道,命令行下删除数量少还好,如果删除很多,光靠删除键一点点删除真的是头痛,还好Vi有快捷的命令可以删除多行、范围。 删除行 在Vim中删除一行的命令是dd。 以下是删除行的分步说明…...
低概率Bug,研发敷衍说复现不到
测试工作中,经常会遇到一些低概率出现的问题,如果再是个严重问题,那测试人员的压力无疑是很大的,一方面是因为低概率难以复现,另一面则是来自项目组的压力。 如何在测试时减少此类问题的重复投入,我的思考如…...
Web前端免费接入Microsoft Azure AI文本翻译,享每月2百万个字符的翻译
Azure 文本翻译是 Azure AI 翻译服务的一项基于云的 REST API 功能。 文本翻译 API 支持实时快速准确地进行源到目标文本翻译。 文本翻译软件开发工具包 (SDK) 是一组库和工具,可用于轻松地将文本翻译 REST API 功能集成到应用程序中。 文本翻译 SDK 可跨 C#/.NET、…...
1024 CSDN 程序员节-知存科技-基于存内计算芯片开发板验证语音识别
前言 在今年的 CSDN 程序员节上,我参与了这次知存科技举办的一个 AI Workshop 小活动——“基于存内计算芯片开发板验证语音识别”,并且有幸成为完成任务的学习者之一XD。上一次参与类似的活动是算能公司举办的“千校万里行”AIGC 大模型编译部署活动&a…...
【备考网络工程师】如何备考2023年网络工程师之错题集篇(3)
一、写在前面 其实做模拟或真题时候,总是会在关键的地方丢分,因此我也冷静下来思考一下,首先我们对做过的题涉及的知识进行一个梳理,其次就是再针对知识去做一些题目,这次只考了38分,表示很伤心࿰…...
密码学-SHA-1算法
实验七 SHA-1 一、实验目的 熟悉SHA-1算法的运行过程,能够使用C语言编写实现SHA-1算法程序,增 加对摘要函数的理解。 二、实验要求 (1)理解SHA-1轮函数的定义和工作过程。 (2)利用VC语言实现SHA- 1算法。 (3)分析SHA- 1算法运行的性能。 三、实验…...
Android View拖拽/拖放DragAndDrop自定义View.DragShadowBuilder,Kotlin(2)
Android View拖拽/拖放DragAndDrop自定义View.DragShadowBuilder,Kotlin(2) import android.graphics.Canvas import android.graphics.Point import android.graphics.drawable.ColorDrawable import android.os.Bundle import android.util…...
翻页视图ViewPager
ViewPager控件允许页面在水平方向左右滑动,就像翻书、翻报纸,Android提供了已经分装好的控件。对于ViewPager来说,一个页面就是一个项(相当于ListView的一个列表项),许多页面组成ViewPager的页面项。 List…...
【可视化Java GUI程序设计教程】第4章 布局设计
4.1 布局管理器概述 右击窗体,单击快捷菜单中的Set Layout 4.1.2 绝对布局(Absolute Layout) 缩小窗口发现超出窗口范围的按钮看不见 Absolute Layout 4.1.2 空值布局(Null Layout) 4.1.3 布局管理器的属性和组件布…...
Elasticsearch配置文件
一 前言 在elasticsearch\config目录下,有三个核心的配置文件: elasticsearch.yml,es相关的配置。jvm.options,Java jvm相关参数的配置。log4j2.properties,日志相关的配置,因为es采用了log4j的日志框架。这里以elasticsearch6.5.4版本为例,并且由于版本不同,配置也不…...
Chapter03-Authentication vulnerabilities
文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...
【OSG学习笔记】Day 18: 碰撞检测与物理交互
物理引擎(Physics Engine) 物理引擎 是一种通过计算机模拟物理规律(如力学、碰撞、重力、流体动力学等)的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互,广泛应用于 游戏开发、动画制作、虚…...
基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销,平衡网络负载,延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...
CMake基础:构建流程详解
目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...
【C语言练习】080. 使用C语言实现简单的数据库操作
080. 使用C语言实现简单的数据库操作 080. 使用C语言实现简单的数据库操作使用原生APIODBC接口第三方库ORM框架文件模拟1. 安装SQLite2. 示例代码:使用SQLite创建数据库、表和插入数据3. 编译和运行4. 示例运行输出:5. 注意事项6. 总结080. 使用C语言实现简单的数据库操作 在…...
如何在最短时间内提升打ctf(web)的水平?
刚刚刷完2遍 bugku 的 web 题,前来答题。 每个人对刷题理解是不同,有的人是看了writeup就等于刷了,有的人是收藏了writeup就等于刷了,有的人是跟着writeup做了一遍就等于刷了,还有的人是独立思考做了一遍就等于刷了。…...
Angular微前端架构:Module Federation + ngx-build-plus (Webpack)
以下是一个完整的 Angular 微前端示例,其中使用的是 Module Federation 和 npx-build-plus 实现了主应用(Shell)与子应用(Remote)的集成。 🛠️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...
浪潮交换机配置track检测实现高速公路收费网络主备切换NQA
浪潮交换机track配置 项目背景高速网络拓扑网络情况分析通信线路收费网络路由 收费汇聚交换机相应配置收费汇聚track配置 项目背景 在实施省内一条高速公路时遇到的需求,本次涉及的主要是收费汇聚交换机的配置,浪潮网络设备在高速项目很少,通…...
mac 安装homebrew (nvm 及git)
mac 安装nvm 及git 万恶之源 mac 安装这些东西离不开Xcode。及homebrew 一、先说安装git步骤 通用: 方法一:使用 Homebrew 安装 Git(推荐) 步骤如下:打开终端(Terminal.app) 1.安装 Homebrew…...
MyBatis中关于缓存的理解
MyBatis缓存 MyBatis系统当中默认定义两级缓存:一级缓存、二级缓存 默认情况下,只有一级缓存开启(sqlSession级别的缓存)二级缓存需要手动开启配置,需要局域namespace级别的缓存 一级缓存(本地缓存&#…...
