BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例
加载数据集
# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
plt.imshow(train_loader.data[0].numpy())
plt.show()

在训练集中植入5000个中毒样本
# 在训练集中植入5000个中毒样本
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9 # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()

训练模型
data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)
# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x
# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))
def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue
if __name__=='__main__':main()
测试攻击成功率
# 攻击成功率 99.66% 对测试集中所有图像都注入后门for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()
可视化中毒样本,成功被预测为特定目标类别“9”,证明攻击成功。


完整代码
from packaging import packaging
from torchvision.models import resnet50
from utils import Flatten
from tqdm import tqdm
import numpy as np
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
use_cuda = True
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")# 载入MNIST训练集和测试集
transform = transforms.Compose([transforms.ToTensor(),])
train_loader = datasets.MNIST(root='data',transform=transform,train=True,download=True)
test_loader = datasets.MNIST(root='data',transform=transform,train=False)
# 可视化样本 大小28×28
# plt.imshow(train_loader.data[0].numpy())
# plt.show()# 训练集样本数据
print(len(train_loader))# 在训练集中植入5000个中毒样本
''' '''
for i in range(5000):train_loader.data[i][26][26] = 255train_loader.data[i][25][25] = 255train_loader.data[i][24][26] = 255train_loader.data[i][26][24] = 255train_loader.targets[i] = 9 # 设置中毒样本的目标标签为9
# 可视化中毒样本
plt.imshow(train_loader.data[0].numpy())
plt.show()data_loader_train = torch.utils.data.DataLoader(dataset=train_loader,batch_size=64,shuffle=True,num_workers=0)
data_loader_test = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)# LeNet-5 模型
class LeNet_5(nn.Module):def __init__(self):super(LeNet_5, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5, 1)self.conv2 = nn.Conv2d(6, 16, 5, 1)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(self.conv1(x), 2, 2)x = F.max_pool2d(self.conv2(x), 2, 2)x = x.view(-1, 16 * 4 * 4)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x# 训练过程
def train(model, device, train_loader, optimizer, epoch):model.train()for idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)pred = model(data)loss = F.cross_entropy(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if idx % 100 == 0:print("Train Epoch: {}, iterantion: {}, Loss: {}".format(epoch, idx, loss.item()))torch.save(model.state_dict(), 'badnets.pth')# 测试过程
def test(model, device, test_loader):model.load_state_dict(torch.load('badnets.pth'))model.eval()total_loss = 0correct = 0with torch.no_grad():for idx, (data, target) in enumerate(test_loader):data, target = data.to(device), target.to(device)output = model(data)total_loss += F.cross_entropy(output, target, reduction="sum").item()pred = output.argmax(dim=1)correct += pred.eq(target.view_as(pred)).sum().item()total_loss /= len(test_loader.dataset)acc = correct / len(test_loader.dataset) * 100print("Test Loss: {}, Accuracy: {}".format(total_loss, acc))def main():# 超参数num_epochs = 10lr = 0.01momentum = 0.5model = LeNet_5().to(device)optimizer = torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)# 在干净训练集上训练,在干净测试集上测试# acc=98.29%# 在带后门数据训练集上训练,在干净测试集上测试# acc=98.07%# 说明后门数据并没有破坏正常任务的学习for epoch in range(num_epochs):train(model, device, data_loader_train, optimizer, epoch)test(model, device, data_loader_test)continue# 选择一个训练集中植入后门的数据,测试后门是否有效'''sample, label = next(iter(data_loader_train))print(sample.size()) # [64, 1, 28, 28]print(label[0])# 可视化plt.imshow(sample[0][0])plt.show()model.load_state_dict(torch.load('badnets.pth'))model.eval()sample = sample.to(device)output = model(sample)print(output[0])pred = output.argmax(dim=1)print(pred[0])'''# 攻击成功率 99.66%for i in range(len(test_loader)):test_loader.data[i][26][26] = 255test_loader.data[i][25][25] = 255test_loader.data[i][24][26] = 255test_loader.data[i][26][24] = 255test_loader.targets[i] = 9data_loader_test2 = torch.utils.data.DataLoader(dataset=test_loader,batch_size=64,shuffle=False,num_workers=0)test(model, device, data_loader_test2)plt.imshow(test_loader.data[0].numpy())plt.show()if __name__=='__main__':main()
相关文章:
BadNets:基于数据投毒的模型后门攻击代码(Pytorch)以MNIST为例
加载数据集 # 载入MNIST训练集和测试集 transform transforms.Compose([transforms.ToTensor(),]) train_loader datasets.MNIST(rootdata,transformtransform,trainTrue,downloadTrue) test_loader datasets.MNIST(rootdata,transformtransform,trainFalse) # 可视化样本 …...
freeRTOS内部机制——栈的作用
上图中*pa 和*pb分别为R0,R1,调用C函数时,第一个参数保存在R0中第二个参数保存在R1中。这是约定。 指令保存在哪里? 指令保存在flash上面 LR等于什么? LR是返回地址,函数执行完了过后LR等于下一条指令的地址 运行…...
python 桌面软件开发-matplotlib画图鼠标缩放拖动
继上一篇在 Java 中缩放拖动图片后,在python matplotlib中也来实现一个自由缩放拖动的例子: python matplotlib 中缩放,较为简单,只需要通过设置要显示的 x y坐标的显示范围即可。基于此,实现一个鼠标监听回调…...
【JavaScript基础】JavaScript头等函数的理解
彻底理解JavaScript头等函数 一、函数的理解 🔥 什么是函数? 一般来说,一个函数是可以通过外部代码 调用 的一个“子程序”(或在递归的情况下由内部函数调用)。像程序本身一样,一个函数由称为函数体的一…...
如何把项目上传到Gitee(详细教程)
找到项目根目录右键打开Git Bash Here 输入命令:git init 回车 输入命令:git status 输入命令:git add . 输入命令:git status git commit -m 项目描述 在Gitee官网注册好账号后,git 新建项目 填写补充git项目信息及…...
Ubuntu挂载windows下的共享文件夹
Ubuntu挂载windows下的共享文件夹 更新apt源 如果出现安装失败,需要更新apt源为阿里云 # 备份原始文件 sudo cp /etc/apt/sources.list.d/* /etc/apt/sources.list.d.bak/# 修改文件内容 sudo vim /etc/apt/sources.list# 替换内容为如下 deb https://mirrors.al…...
什么是WMS系统条码化管理
WMS系统是一种用于仓库管理的信息化系统,旨在提高仓库操作的效率和准确性。而在WMS系统中,条码化管理是一项关键的技术和方法,它通过将商品和物料打上条码,并利用扫描设备进行数据采集和处理,实现了仓库管理的全面自动…...
【云原生之kubernetes实战】在k8s环境下部署moredoc文库系统
【云原生之kubernetes实战】在k8s环境下部署moredoc文库系统 一、moredoc介绍1.1 moredoc简介1.2 moredoc技术栈二、本次实践介绍2.1 本次实践简介2.2 本次环境规划三、检查k8s环境3.1 检查工作节点状态3.2 检查系统pod状态四、创建mysql的secret资源4.1 创建部署目录4.2 创建密…...
[Database] MySQL 8.x Window / Partition Function (窗口/分区函数)
🧲相关文章 [1] MySQL 系统表解析以及各项指标查询 [2] MySQL 5.7 JSON 字段的使用的处理 [3] MySQL经典练习50题 简介 MySQL 8.0版本开始支持窗口函数 官方文档 在之前的版本中已存在的大部分聚合函数,在MySQL 8 中也可以作为窗口函数来使用 方法 / …...
openGauss Meetup(天津站)精彩回顾 | openGauss天津用户组正式成立
由openGauss社区、天开发展集团、天津市软件行业协会、天大智图(天津)科技有限公司联合主办的“openGauss Meetup • 天津站”已于10月13日落下帷幕,此次活动邀请到众多业内技术专家,从技术创新、学术创新、发展创新、以及生态共建…...
linux vim 删除多行
使用linux服务器,免不了和vi编辑打交道,命令行下删除数量少还好,如果删除很多,光靠删除键一点点删除真的是头痛,还好Vi有快捷的命令可以删除多行、范围。 删除行 在Vim中删除一行的命令是dd。 以下是删除行的分步说明…...
低概率Bug,研发敷衍说复现不到
测试工作中,经常会遇到一些低概率出现的问题,如果再是个严重问题,那测试人员的压力无疑是很大的,一方面是因为低概率难以复现,另一面则是来自项目组的压力。 如何在测试时减少此类问题的重复投入,我的思考如…...
Web前端免费接入Microsoft Azure AI文本翻译,享每月2百万个字符的翻译
Azure 文本翻译是 Azure AI 翻译服务的一项基于云的 REST API 功能。 文本翻译 API 支持实时快速准确地进行源到目标文本翻译。 文本翻译软件开发工具包 (SDK) 是一组库和工具,可用于轻松地将文本翻译 REST API 功能集成到应用程序中。 文本翻译 SDK 可跨 C#/.NET、…...
1024 CSDN 程序员节-知存科技-基于存内计算芯片开发板验证语音识别
前言 在今年的 CSDN 程序员节上,我参与了这次知存科技举办的一个 AI Workshop 小活动——“基于存内计算芯片开发板验证语音识别”,并且有幸成为完成任务的学习者之一XD。上一次参与类似的活动是算能公司举办的“千校万里行”AIGC 大模型编译部署活动&a…...
【备考网络工程师】如何备考2023年网络工程师之错题集篇(3)
一、写在前面 其实做模拟或真题时候,总是会在关键的地方丢分,因此我也冷静下来思考一下,首先我们对做过的题涉及的知识进行一个梳理,其次就是再针对知识去做一些题目,这次只考了38分,表示很伤心࿰…...
密码学-SHA-1算法
实验七 SHA-1 一、实验目的 熟悉SHA-1算法的运行过程,能够使用C语言编写实现SHA-1算法程序,增 加对摘要函数的理解。 二、实验要求 (1)理解SHA-1轮函数的定义和工作过程。 (2)利用VC语言实现SHA- 1算法。 (3)分析SHA- 1算法运行的性能。 三、实验…...
Android View拖拽/拖放DragAndDrop自定义View.DragShadowBuilder,Kotlin(2)
Android View拖拽/拖放DragAndDrop自定义View.DragShadowBuilder,Kotlin(2) import android.graphics.Canvas import android.graphics.Point import android.graphics.drawable.ColorDrawable import android.os.Bundle import android.util…...
翻页视图ViewPager
ViewPager控件允许页面在水平方向左右滑动,就像翻书、翻报纸,Android提供了已经分装好的控件。对于ViewPager来说,一个页面就是一个项(相当于ListView的一个列表项),许多页面组成ViewPager的页面项。 List…...
【可视化Java GUI程序设计教程】第4章 布局设计
4.1 布局管理器概述 右击窗体,单击快捷菜单中的Set Layout 4.1.2 绝对布局(Absolute Layout) 缩小窗口发现超出窗口范围的按钮看不见 Absolute Layout 4.1.2 空值布局(Null Layout) 4.1.3 布局管理器的属性和组件布…...
Elasticsearch配置文件
一 前言 在elasticsearch\config目录下,有三个核心的配置文件: elasticsearch.yml,es相关的配置。jvm.options,Java jvm相关参数的配置。log4j2.properties,日志相关的配置,因为es采用了log4j的日志框架。这里以elasticsearch6.5.4版本为例,并且由于版本不同,配置也不…...
Wan2.2-I2V-A14B绿色AI实践:显存优化降低35%功耗的碳足迹测算
Wan2.2-I2V-A14B绿色AI实践:显存优化降低35%功耗的碳足迹测算 1. 引言:绿色AI的迫切需求 在AI技术快速发展的今天,大模型训练和推理带来的能源消耗问题日益突出。Wan2.2-I2V-A14B作为一款先进的文生视频模型,通过显存优化技术实…...
从ChatGPT插件到MCP:一个AI开发者亲历的工具集成进化史
从ChatGPT插件到MCP:一个AI开发者亲历的工具集成进化史 三年前,当我第一次尝试让ChatGPT调用外部API时,需要手动拼接JSON参数、处理OAuth认证、设计错误重试机制——光是让模型能查询天气就耗费了两天时间。如今,通过MCP协议&…...
终极指南:ImagePicker资源解析机制如何高效处理图像资源
终极指南:ImagePicker资源解析机制如何高效处理图像资源 【免费下载链接】ImagePicker :camera: Reinventing the way ImagePicker works. 项目地址: https://gitcode.com/gh_mirrors/im/ImagePicker ImagePicker作为一款重新定义图片选择体验的工具…...
COMSOL相场模拟:枝晶生长与雪花形成的模型与教程
comsol相场模拟枝晶生长(雪花的形成) 有模型和教程 凌晨三点盯着显微镜下的冰晶生长,突然意识到这玩意儿和编程调试一样——参数调不好分分钟给你长歪。相场法模拟枝晶生长这事儿,本质上就是在用数学方程式和物理定律"种&qu…...
如何使用USearch构建自动驾驶传感器数据的实时向量搜索系统
如何使用USearch构建自动驾驶传感器数据的实时向量搜索系统 【免费下载链接】usearch Fastest Open-Source Search & Clustering engine for Vectors & 🔜 Strings in C, C, Python, JavaScript, Rust, Java, Objective-C, Swift, C#, GoLang, and Wolfra…...
Obsidian模板库实战指南:从零构建高效知识管理系统
Obsidian模板库实战指南:从零构建高效知识管理系统 【免费下载链接】OB_Template OB_Templates is a Obsidian reference for note templates focused on new users of the application using only core plugins. 项目地址: https://gitcode.com/gh_mirrors/ob/OB…...
保姆级教程:用YOLOv8+PyQt5打造你的番茄成熟度检测桌面应用(附完整源码与数据集)
从零构建番茄成熟度检测桌面应用:YOLOv8与PyQt5深度整合实战 在农业智能化浪潮中,计算机视觉技术正逐步改变传统农业生产方式。以番茄种植为例,成熟度判断直接影响采摘效率和经济效益。本文将带您完整实现一个结合YOLOv8目标检测与PyQt5图形界…...
BilibiliDown:三分钟掌握跨平台B站视频批量下载终极方案
BilibiliDown:三分钟掌握跨平台B站视频批量下载终极方案 【免费下载链接】BilibiliDown (GUI-多平台支持) B站 哔哩哔哩 视频下载器。支持稍后再看、收藏夹、UP主视频批量下载|Bilibili Video Downloader 😳 项目地址: https://gitcode.com/gh_mirrors…...
如何快速掌握NoteGen AI笔记:新手入门完整指南
如何快速掌握NoteGen AI笔记:新手入门完整指南 【免费下载链接】note-gen 一款专注于记录和写作的跨端 AI 笔记应用。 项目地址: https://gitcode.com/GitHub_Trending/no/note-gen 在信息爆炸的时代,高效记录和管理知识已成为现代人的刚需。Note…...
Midscene.js视觉驱动自动化:从认知到实践的AI跨平台控制指南
Midscene.js视觉驱动自动化:从认知到实践的AI跨平台控制指南 【免费下载链接】midscene Let AI be your browser operator. 项目地址: https://gitcode.com/GitHub_Trending/mid/midscene 一、认知篇:理解Midscene.js的技术革新 1.1 破解传统自动…...
