当前位置: 首页 > news >正文

BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例

加载数据集

# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
plt.imshow(train_loader.data[0].numpy())
plt.show()

在这里插入图片描述

在训练集中植入5000个中毒样本

# 在训练集中植入5000个中毒样本
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9  # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()

在这里插入图片描述

训练模型

data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)
# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x
# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))
def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue
if __name__=='__main__':main()

测试攻击成功率

# 攻击成功率 99.66%  对测试集中所有图像都注入后门for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()

可视化中毒样本,成功被预测为特定目标类别“9”,证明攻击成功。
在这里插入图片描述
在这里插入图片描述

完整代码

from packaging import packaging
from torchvision.models import resnet50
from utils import Flatten
from tqdm import tqdm
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
# plt.imshow(train_loader.data[0].numpy())
# plt.show()# 训练集样本数据
print(len(train_loader))# 在训练集中植入5000个中毒样本
''' '''
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9  # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue# 选择一个训练集中植入后门的数据,测试后门是否有效'''sample, label = next(iter(data_loader_train))print(sample.size())  # [64, 1, 28, 28]print(label[0])# 可视化plt.imshow(sample[0][0])plt.show()model.load_state_dict(torch.load('badnets.pth'))model.eval()sample = sample.to(device)output = model(sample)print(output[0])pred = output.argmax(dim=1)print(pred[0])'''# 攻击成功率 99.66%for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()if __name__=='__main__':main()

相关文章:

BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例

加载数据集 # 载入MNIST训练集和测试集 transform transforms.Compose([transforms.ToTensor(),]) train_loader datasets.MNIST(rootdata,transformtransform,trainTrue,downloadTrue) test_loader datasets.MNIST(rootdata,transformtransform,trainFalse) # 可视化样本 …...

freeRTOS内部机制——栈的作用

上图中*pa 和*pb分别为R0,R1,调用C函数时,第一个参数保存在R0中第二个参数保存在R1中。这是约定。 指令保存在哪里? 指令保存在flash上面 LR等于什么? LR是返回地址,函数执行完了过后LR等于下一条指令的地址 运行…...

python 桌面软件开发-matplotlib画图鼠标缩放拖动

继上一篇在 Java 中缩放拖动图片后,在python matplotlib中也来实现一个自由缩放拖动的例子: python matplotlib 中缩放,较为简单,只需要通过设置要显示的 x y坐标的显示范围即可。基于此,实现一个鼠标监听回调&#xf…...

【JavaScript基础】JavaScript头等函数的理解

彻底理解JavaScript头等函数 一、函数的理解 🔥 什么是函数? 一般来说,一个函数是可以通过外部代码 调用 的一个“子程序”(或在递归的情况下由内部函数调用)。像程序本身一样,一个函数由称为函数体的一…...

如何把项目上传到Gitee(详细教程)

找到项目根目录右键打开Git Bash Here 输入命令:git init 回车 输入命令:git status 输入命令:git add . 输入命令:git status git commit -m 项目描述 在Gitee官网注册好账号后,git 新建项目 填写补充git项目信息及…...

Ubuntu挂载windows下的共享文件夹

Ubuntu挂载windows下的共享文件夹 更新apt源 如果出现安装失败,需要更新apt源为阿里云 # 备份原始文件 sudo cp /etc/apt/sources.list.d/* /etc/apt/sources.list.d.bak/# 修改文件内容 sudo vim /etc/apt/sources.list# 替换内容为如下 deb https://mirrors.al…...

什么是WMS系统条码化管理

WMS系统是一种用于仓库管理的信息化系统,旨在提高仓库操作的效率和准确性。而在WMS系统中,条码化管理是一项关键的技术和方法,它通过将商品和物料打上条码,并利用扫描设备进行数据采集和处理,实现了仓库管理的全面自动…...

【云原生之kubernetes实战】在k8s环境下部署moredoc文库系统

【云原生之kubernetes实战】在k8s环境下部署moredoc文库系统 一、moredoc介绍1.1 moredoc简介1.2 moredoc技术栈二、本次实践介绍2.1 本次实践简介2.2 本次环境规划三、检查k8s环境3.1 检查工作节点状态3.2 检查系统pod状态四、创建mysql的secret资源4.1 创建部署目录4.2 创建密…...

[Database] MySQL 8.x Window / Partition Function (窗口/分区函数)

🧲相关文章 [1] MySQL 系统表解析以及各项指标查询 [2] MySQL 5.7 JSON 字段的使用的处理 [3] MySQL经典练习50题 简介 MySQL 8.0版本开始支持窗口函数 官方文档 在之前的版本中已存在的大部分聚合函数,在MySQL 8 中也可以作为窗口函数来使用 方法 / …...

openGauss Meetup(天津站)精彩回顾 | openGauss天津用户组正式成立

由openGauss社区、天开发展集团、天津市软件行业协会、天大智图(天津)科技有限公司联合主办的“openGauss Meetup • 天津站”已于10月13日落下帷幕,此次活动邀请到众多业内技术专家,从技术创新、学术创新、发展创新、以及生态共建…...

linux vim 删除多行

使用linux服务器,免不了和vi编辑打交道,命令行下删除数量少还好,如果删除很多,光靠删除键一点点删除真的是头痛,还好Vi有快捷的命令可以删除多行、范围。 删除行 在Vim中删除一行的命令是dd。 以下是删除行的分步说明…...

低概率Bug,研发敷衍说复现不到

测试工作中,经常会遇到一些低概率出现的问题,如果再是个严重问题,那测试人员的压力无疑是很大的,一方面是因为低概率难以复现,另一面则是来自项目组的压力。 如何在测试时减少此类问题的重复投入,我的思考如…...

Web前端免费接入Microsoft Azure AI文本翻译,享每月2百万个字符的翻译

Azure 文本翻译是 Azure AI 翻译服务的一项基于云的 REST API 功能。 文本翻译 API 支持实时快速准确地进行源到目标文本翻译。 文本翻译软件开发工具包 (SDK) 是一组库和工具,可用于轻松地将文本翻译 REST API 功能集成到应用程序中。 文本翻译 SDK 可跨 C#/.NET、…...

1024 CSDN 程序员节-知存科技-基于存内计算芯片开发板验证语音识别

前言 在今年的 CSDN 程序员节上,我参与了这次知存科技举办的一个 AI Workshop 小活动——“基于存内计算芯片开发板验证语音识别”,并且有幸成为完成任务的学习者之一XD。上一次参与类似的活动是算能公司举办的“千校万里行”AIGC 大模型编译部署活动&a…...

【备考网络工程师】如何备考2023年网络工程师之错题集篇(3)

一、写在前面 其实做模拟或真题时候,总是会在关键的地方丢分,因此我也冷静下来思考一下,首先我们对做过的题涉及的知识进行一个梳理,其次就是再针对知识去做一些题目,这次只考了38分,表示很伤心&#xff0…...

密码学-SHA-1算法

实验七 SHA-1 一、实验目的 熟悉SHA-1算法的运行过程,能够使用C语言编写实现SHA-1算法程序,增 加对摘要函数的理解。 二、实验要求 (1)理解SHA-1轮函数的定义和工作过程。 (2)利用VC语言实现SHA- 1算法。 (3)分析SHA- 1算法运行的性能。 三、实验…...

Android View拖拽/拖放DragAndDrop自定义View.DragShadowBuilder,Kotlin(2)

Android View拖拽/拖放DragAndDrop自定义View.DragShadowBuilder,Kotlin(2) import android.graphics.Canvas import android.graphics.Point import android.graphics.drawable.ColorDrawable import android.os.Bundle import android.util…...

翻页视图ViewPager

ViewPager控件允许页面在水平方向左右滑动,就像翻书、翻报纸,Android提供了已经分装好的控件。对于ViewPager来说,一个页面就是一个项(相当于ListView的一个列表项),许多页面组成ViewPager的页面项。 List…...

【可视化Java GUI程序设计教程】第4章 布局设计

4.1 布局管理器概述 右击窗体,单击快捷菜单中的Set Layout 4.1.2 绝对布局(Absolute Layout) 缩小窗口发现超出窗口范围的按钮看不见 Absolute Layout 4.1.2 空值布局(Null Layout) 4.1.3 布局管理器的属性和组件布…...

Elasticsearch配置文件

一 前言 在elasticsearch\config目录下,有三个核心的配置文件: elasticsearch.yml,es相关的配置。jvm.options,Java jvm相关参数的配置。log4j2.properties,日志相关的配置,因为es采用了log4j的日志框架。这里以elasticsearch6.5.4版本为例,并且由于版本不同,配置也不…...

运维:mysql常用的服务器状态命令

目录 1、查询当前服务器运行的进程 2、查询最大链接数 3、查询当前链接数 4、展示当前正在执行的sql语句 5、查询当前MySQL当中记录的慢查询条数 6、展示Mysql服务器从启动到现在持续运行的时间 7、查询数据库存储占用情况 8、查询服务器启动以来的执行查询的总次数 9…...

k8s中kubectl陈述式资源管理

1、 理论 1.1、 管理k8s核心资源的三种基本方法 : 1.1.1陈述式的资源管理方法: 主要依赖命令行工具kubectl进行管理 1.1.1.1、优点: 可以满足90%以上的使用场景 对资源的增、删、查操作比较容易 1.1.1.2、缺点: 命令冗长&…...

11 个最值得推荐的 Windows 数据恢复软件

您可能已经尝试过许多免费的恢复程序,但它们都不起作用,对吧?这就是您正在寻找最好的数据恢复软件的原因。 个人去过那里。根据个人的经验,大多数免费软件并不能解决这个问题。有时,当个人在 PC 上运行恢复程序时&…...

Docker从入门到实战

Docker基本概念 1、解决的问题 1、统一标准 应用构建 ○ Java、C、JavaScript ○ 打成软件包 ○ .exe ○ docker build … 镜像应用分享 ○ 所有软件的镜像放到一个指定地方 docker hub ○ 安卓,应用市场应用运行 ○ 统一标准的 镜像 ○ docker run 容器化技术 …...

UE4 材质实操记录

TexCoord的R通道是从左到右的递增量,G通道是从上到下的递增量,R通道减去0.5,那么左边就是【-0.5~0】区间,所以左边为全黑,Abs取绝对值,就达到一个两边向中间的一个递减的效果,G通道同理&#xf…...

http协议和Fiddler

文章目录 一、http协议的报文结构1.1http请求和http响应之间的区别1.2http请求1.2.1URL1.2.2方法1.2.3请求头1.2.3.1Host1.2.3.2Content-Length、Content-Type1.2.3.3User-Agent(简称UA)1.2.3.4Referer1.2.3.5Cookie 1.3http响应1.3.1响应状态码1.3.2响应头1.3.2.1Content-Leng…...

李宇航

该篇文章仅用作能直接在百度搜索到我的csdn,进入我的主页,没有实际意义. 进入李宇航博客方法 通过百度搜索"李宇航" 链接: https://blog.csdn.net/llllyh812 1.电脑端进入方法 输入网址链接: https://blog.csdn.net/llllyh812 或者 进入csdn主页,搜索"李宇…...

【JAVA学习笔记】38 - 单例设计模式-静态方法和属性的经典使用

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter10/src/com/yinhai/final_ 一、什么是设计模式 1.静态方法和属性的经典使用 2.设计模式是在大量的实践中总结和理论化之后优选的代码结构、编程风格以及解决问题的思考方式。设计模式就像是…...

m1 安装 cocoapods

其实最终解决问题很简单,麻烦的是如果找到解决问题的答案。 前提条件: 命令行工具,不能以rosetta方法打开。 首先安装homebrew 这里不多说了,自行解决。 使用brew安装cocoapods brew install cocoapods, pod --ver…...

【大数据】Kafka 实战教程(一)

Kafka 实战教程(一) 1.Kafka 介绍1.1. 主要功能1.2. 使用场景1.3 详细介绍1.3.1 消息传输流程1.3.2 Kafka 服务器消息存储策略1.3.3 与生产者的交互1.3.4 与消费者的交互 2.Kafka 生产者3.Kafka 消费者3.1 Kafka 消费模式3.1.1 At-most-once(…...