BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例
加载数据集
# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
plt.imshow(train_loader.data[0].numpy())
plt.show()
在训练集中植入5000个中毒样本
# 在训练集中植入5000个中毒样本
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9 # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()
训练模型
data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)
# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x
# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))
def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue
if __name__=='__main__':main()
测试攻击成功率
# 攻击成功率 99.66% 对测试集中所有图像都注入后门for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()
可视化中毒样本,成功被预测为特定目标类别“9”,证明攻击成功。
完整代码
from packaging import packaging
from torchvision.models import resnet50
from utils import Flatten
from tqdm import tqdm
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
# plt.imshow(train_loader.data[0].numpy())
# plt.show()# 训练集样本数据
print(len(train_loader))# 在训练集中植入5000个中毒样本
''' '''
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9 # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue# 选择一个训练集中植入后门的数据,测试后门是否有效'''sample, label = next(iter(data_loader_train))print(sample.size()) # [64, 1, 28, 28]print(label[0])# 可视化plt.imshow(sample[0][0])plt.show()model.load_state_dict(torch.load('badnets.pth'))model.eval()sample = sample.to(device)output = model(sample)print(output[0])pred = output.argmax(dim=1)print(pred[0])'''# 攻击成功率 99.66%for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()if __name__=='__main__':main()
相关文章:

BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例
加载数据集 # 载入MNIST训练集和测试集 transform transforms.Compose([transforms.ToTensor(),]) train_loader datasets.MNIST(rootdata,transformtransform,trainTrue,downloadTrue) test_loader datasets.MNIST(rootdata,transformtransform,trainFalse) # 可视化样本 …...

freeRTOS内部机制——栈的作用
上图中*pa 和*pb分别为R0,R1,调用C函数时,第一个参数保存在R0中第二个参数保存在R1中。这是约定。 指令保存在哪里? 指令保存在flash上面 LR等于什么? LR是返回地址,函数执行完了过后LR等于下一条指令的地址 运行…...

python 桌面软件开发-matplotlib画图鼠标缩放拖动
继上一篇在 Java 中缩放拖动图片后,在python matplotlib中也来实现一个自由缩放拖动的例子: python matplotlib 中缩放,较为简单,只需要通过设置要显示的 x y坐标的显示范围即可。基于此,实现一个鼠标监听回调…...
【JavaScript基础】JavaScript头等函数的理解
彻底理解JavaScript头等函数 一、函数的理解 🔥 什么是函数? 一般来说,一个函数是可以通过外部代码 调用 的一个“子程序”(或在递归的情况下由内部函数调用)。像程序本身一样,一个函数由称为函数体的一…...

如何把项目上传到Gitee(详细教程)
找到项目根目录右键打开Git Bash Here 输入命令:git init 回车 输入命令:git status 输入命令:git add . 输入命令:git status git commit -m 项目描述 在Gitee官网注册好账号后,git 新建项目 填写补充git项目信息及…...
Ubuntu挂载windows下的共享文件夹
Ubuntu挂载windows下的共享文件夹 更新apt源 如果出现安装失败,需要更新apt源为阿里云 # 备份原始文件 sudo cp /etc/apt/sources.list.d/* /etc/apt/sources.list.d.bak/# 修改文件内容 sudo vim /etc/apt/sources.list# 替换内容为如下 deb https://mirrors.al…...

什么是WMS系统条码化管理
WMS系统是一种用于仓库管理的信息化系统,旨在提高仓库操作的效率和准确性。而在WMS系统中,条码化管理是一项关键的技术和方法,它通过将商品和物料打上条码,并利用扫描设备进行数据采集和处理,实现了仓库管理的全面自动…...
【云原生之kubernetes实战】在k8s环境下部署moredoc文库系统
【云原生之kubernetes实战】在k8s环境下部署moredoc文库系统 一、moredoc介绍1.1 moredoc简介1.2 moredoc技术栈二、本次实践介绍2.1 本次实践简介2.2 本次环境规划三、检查k8s环境3.1 检查工作节点状态3.2 检查系统pod状态四、创建mysql的secret资源4.1 创建部署目录4.2 创建密…...

[Database] MySQL 8.x Window / Partition Function (窗口/分区函数)
🧲相关文章 [1] MySQL 系统表解析以及各项指标查询 [2] MySQL 5.7 JSON 字段的使用的处理 [3] MySQL经典练习50题 简介 MySQL 8.0版本开始支持窗口函数 官方文档 在之前的版本中已存在的大部分聚合函数,在MySQL 8 中也可以作为窗口函数来使用 方法 / …...

openGauss Meetup(天津站)精彩回顾 | openGauss天津用户组正式成立
由openGauss社区、天开发展集团、天津市软件行业协会、天大智图(天津)科技有限公司联合主办的“openGauss Meetup • 天津站”已于10月13日落下帷幕,此次活动邀请到众多业内技术专家,从技术创新、学术创新、发展创新、以及生态共建…...
linux vim 删除多行
使用linux服务器,免不了和vi编辑打交道,命令行下删除数量少还好,如果删除很多,光靠删除键一点点删除真的是头痛,还好Vi有快捷的命令可以删除多行、范围。 删除行 在Vim中删除一行的命令是dd。 以下是删除行的分步说明…...

低概率Bug,研发敷衍说复现不到
测试工作中,经常会遇到一些低概率出现的问题,如果再是个严重问题,那测试人员的压力无疑是很大的,一方面是因为低概率难以复现,另一面则是来自项目组的压力。 如何在测试时减少此类问题的重复投入,我的思考如…...

Web前端免费接入Microsoft Azure AI文本翻译,享每月2百万个字符的翻译
Azure 文本翻译是 Azure AI 翻译服务的一项基于云的 REST API 功能。 文本翻译 API 支持实时快速准确地进行源到目标文本翻译。 文本翻译软件开发工具包 (SDK) 是一组库和工具,可用于轻松地将文本翻译 REST API 功能集成到应用程序中。 文本翻译 SDK 可跨 C#/.NET、…...

1024 CSDN 程序员节-知存科技-基于存内计算芯片开发板验证语音识别
前言 在今年的 CSDN 程序员节上,我参与了这次知存科技举办的一个 AI Workshop 小活动——“基于存内计算芯片开发板验证语音识别”,并且有幸成为完成任务的学习者之一XD。上一次参与类似的活动是算能公司举办的“千校万里行”AIGC 大模型编译部署活动&a…...
【备考网络工程师】如何备考2023年网络工程师之错题集篇(3)
一、写在前面 其实做模拟或真题时候,总是会在关键的地方丢分,因此我也冷静下来思考一下,首先我们对做过的题涉及的知识进行一个梳理,其次就是再针对知识去做一些题目,这次只考了38分,表示很伤心࿰…...

密码学-SHA-1算法
实验七 SHA-1 一、实验目的 熟悉SHA-1算法的运行过程,能够使用C语言编写实现SHA-1算法程序,增 加对摘要函数的理解。 二、实验要求 (1)理解SHA-1轮函数的定义和工作过程。 (2)利用VC语言实现SHA- 1算法。 (3)分析SHA- 1算法运行的性能。 三、实验…...
Android View拖拽/拖放DragAndDrop自定义View.DragShadowBuilder,Kotlin(2)
Android View拖拽/拖放DragAndDrop自定义View.DragShadowBuilder,Kotlin(2) import android.graphics.Canvas import android.graphics.Point import android.graphics.drawable.ColorDrawable import android.os.Bundle import android.util…...
翻页视图ViewPager
ViewPager控件允许页面在水平方向左右滑动,就像翻书、翻报纸,Android提供了已经分装好的控件。对于ViewPager来说,一个页面就是一个项(相当于ListView的一个列表项),许多页面组成ViewPager的页面项。 List…...

【可视化Java GUI程序设计教程】第4章 布局设计
4.1 布局管理器概述 右击窗体,单击快捷菜单中的Set Layout 4.1.2 绝对布局(Absolute Layout) 缩小窗口发现超出窗口范围的按钮看不见 Absolute Layout 4.1.2 空值布局(Null Layout) 4.1.3 布局管理器的属性和组件布…...
Elasticsearch配置文件
一 前言 在elasticsearch\config目录下,有三个核心的配置文件: elasticsearch.yml,es相关的配置。jvm.options,Java jvm相关参数的配置。log4j2.properties,日志相关的配置,因为es采用了log4j的日志框架。这里以elasticsearch6.5.4版本为例,并且由于版本不同,配置也不…...
C++:std::is_convertible
C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

黑马Mybatis
Mybatis 表现层:页面展示 业务层:逻辑处理 持久层:持久数据化保存 在这里插入图片描述 Mybatis快速入门 允许出现允许…...

Linux相关概念和易错知识点(42)(TCP的连接管理、可靠性、面临复杂网络的处理)
目录 1.TCP的连接管理机制(1)三次握手①握手过程②对握手过程的理解 (2)四次挥手(3)握手和挥手的触发(4)状态切换①挥手过程中状态的切换②握手过程中状态的切换 2.TCP的可靠性&…...
Neo4j 集群管理:原理、技术与最佳实践深度解析
Neo4j 的集群技术是其企业级高可用性、可扩展性和容错能力的核心。通过深入分析官方文档,本文将系统阐述其集群管理的核心原理、关键技术、实用技巧和行业最佳实践。 Neo4j 的 Causal Clustering 架构提供了一个强大而灵活的基石,用于构建高可用、可扩展且一致的图数据库服务…...
【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验
系列回顾: 在上一篇中,我们成功地为应用集成了数据库,并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了!但是,如果你仔细审视那些 API,会发现它们还很“粗糙”:有…...
css的定位(position)详解:相对定位 绝对定位 固定定位
在 CSS 中,元素的定位通过 position 属性控制,共有 5 种定位模式:static(静态定位)、relative(相对定位)、absolute(绝对定位)、fixed(固定定位)和…...
今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存
文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...
Fabric V2.5 通用溯源系统——增加图片上传与下载功能
fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...
AGain DB和倍数增益的关系
我在设置一款索尼CMOS芯片时,Again增益0db变化为6DB,画面的变化只有2倍DN的增益,比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析: 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...