当前位置: 首页 > news >正文

Pytorch | 利用NI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用NI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

  • CIFAR数据集
  • NI-FGSM介绍
      • 背景
      • 算法原理
  • NI-FGSM代码实现
    • NI-FGSM算法实现
    • 攻击效果
  • 代码汇总
    • nifgsm.py
    • train.py
    • advtest.py

之前已经针对CIFAR10训练了多种分类器:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建EfficientNet对CIFAR10进行分类
Pytorch | 从零构建ParNet对CIFAR10进行分类

本篇文章我们使用Pytorch实现NI-FGSM对CIFAR10上的ResNet分类器进行攻击.

CIFAR数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:

在这里插入图片描述

NI-FGSM介绍

NI-FGSM(Nesterov Iterative Fast Gradient Sign Method)即涅斯捷罗夫迭代快速梯度符号法,是一种在对抗攻击领域中对FGSM进行改进的迭代攻击算法,以下是其详细介绍:

背景

  • 传统的FGSM及其一些迭代改进版本如I-FGSM等,在生成对抗样本时存在一些局限性,例如可能会在迭代过程中陷入局部最优,导致攻击效果不够理想或生成的对抗样本转移性较差。NI-FGSM借鉴了优化算法中的Nesterov加速梯度法的思想,旨在更有效地利用梯度信息,提高攻击的效率和效果。

算法原理

  • 初始化:与其他对抗攻击方法类似,需要一个待攻击的目标模型 f f f、损失函数 J J J、原始图像 x x x及其对应的真实标签 y y y,同时还需要设定攻击步长 ϵ \epsilon ϵ、迭代次数 T T T等参数。
  • 迭代更新:在每次迭代 t t t中,首先计算一个“前瞻”点 x t l o o k a h e a d x_{t}^{lookahead} xtlookahead,它是基于当前迭代点 x t x_{t} xt和上一次迭代的梯度信息进行的一个预估更新点,公式为 x t l o o k a h e a d = x t + α ⋅ sign ( ∇ x J ( x t , y ) ) x_{t}^{lookahead}=x_{t}+\alpha \cdot \text{sign}\left(\nabla_{x} J\left(x_{t}, y\right)\right) xtlookahead=xt+αsign(xJ(xt,y)),其中 α \alpha α是一个类似于步长的参数。然后,计算在这个“前瞻”点处的损失梯度 ∇ x J ( x t l o o k a h e a d , y ) \nabla_{x} J\left(x_{t}^{lookahead}, y\right) xJ(xtlookahead,y),并根据该梯度来更新当前迭代点 x t x_{t} xt,更新公式为 x t + 1 = x t + ϵ ⋅ sign ( ∇ x J ( x t l o o k a h e a d , y ) ) x_{t + 1}=x_{t}+\epsilon \cdot \text{sign}\left(\nabla_{x} J\left(x_{t}^{lookahead}, y\right)\right) xt+1=xt+ϵsign(xJ(xtlookahead,y))
  • 投影操作:与其他对抗攻击方法一样,为了确保生成的对抗样本在合理的范围内,如像素值在 [ 0 , 1 ] [0, 1] [0,1] [ − 1 , 1 ] [-1, 1] [1,1]之间,需要对每次迭代更新后的样本进行投影操作。

NI-FGSM代码实现

NI-FGSM算法实现

import torch
import torch.nn as nndef NI_FGSM(model, criterion, original_images, labels, epsilon, num_iterations=10):"""NI-FGSM (Nesterov Iterative Fast Gradient Sign Method) 参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始输入图像数据- labels: 对应的真实标签- epsilon: 最大扰动幅度- num_iterations: 迭代次数"""# alpha: 每次迭代的步长alpha = epsilon / num_iterations# 复制原始图像作为初始的对抗样本,并设置其需要计算梯度perturbed_images = original_images.clone().detach().requires_grad_(True)for _ in range(num_iterations):# 计算 "前瞻" 点(基于当前对抗样本和当前梯度方向预估的下一步位置)lookahead_images = perturbed_images + alpha * torch.sign(perturbed_images.grad.data) if perturbed_images.grad is not None else perturbed_images# 前向传播得到模型输出outputs = model(lookahead_images)# 计算损失loss = criterion(outputs, labels)# 清空模型之前的梯度信息model.zero_grad()# 反向传播计算梯度loss.backward()# 获取当前梯度数据data_grad = lookahead_images.grad.data if lookahead_images.grad is not None else torch.zeros_like(original_images)# 计算符号梯度sign_data_grad = torch.sign(data_grad)# 更新对抗样本perturbed_images = perturbed_images + epsilon * sign_data_grad# 投影操作,确保扰动后的图像仍在合理范围内(这里假设图像范围是[0, 1])perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)perturbed_images = perturbed_images.detach().requires_grad_(True)return perturbed_images

攻击效果

在这里插入图片描述

代码汇总

nifgsm.py

import torch
import torch.nn as nndef NI_FGSM(model, criterion, original_images, labels, epsilon, num_iterations=10):"""NI-FGSM (Nesterov Iterative Fast Gradient Sign Method) 参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始输入图像数据- labels: 对应的真实标签- epsilon: 最大扰动幅度- num_iterations: 迭代次数"""# alpha: 每次迭代的步长alpha = epsilon / num_iterations# 复制原始图像作为初始的对抗样本,并设置其需要计算梯度perturbed_images = original_images.clone().detach().requires_grad_(True)for _ in range(num_iterations):# 计算 "前瞻" 点(基于当前对抗样本和当前梯度方向预估的下一步位置)lookahead_images = perturbed_images + alpha * torch.sign(perturbed_images.grad.data) if perturbed_images.grad is not None else perturbed_images# 前向传播得到模型输出outputs = model(lookahead_images)# 计算损失loss = criterion(outputs, labels)# 清空模型之前的梯度信息model.zero_grad()# 反向传播计算梯度loss.backward()# 获取当前梯度数据data_grad = lookahead_images.grad.data if lookahead_images.grad is not None else torch.zeros_like(original_images)# 计算符号梯度sign_data_grad = torch.sign(data_grad)# 更新对抗样本perturbed_images = perturbed_images + epsilon * sign_data_grad# 投影操作,确保扰动后的图像仍在合理范围内(这里假设图像范围是[0, 1])perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)perturbed_images = perturbed_images.detach().requires_grad_(True)return perturbed_images

train.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import ResNet18# 数据预处理
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加载Cifar10训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# 定义设备(GPU或CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 初始化模型
model = ResNet18(num_classes=10)
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)if __name__ == "__main__":# 训练模型for epoch in range(10):  # 可以根据实际情况调整训练轮数running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss = {running_loss / 100}')running_loss = 0.0torch.save(model.state_dict(), f'weights/epoch_{epoch + 1}.pth')print('Finished Training')

advtest.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *
from attacks import *
import ssl
import os
from PIL import Image
import matplotlib.pyplot as pltssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = ResNet18(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = "weights/epoch_10.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))if __name__ == "__main__":# 在测试集上进行FGSM攻击并评估准确率model.eval()  # 设置为评估模式correct = 0total = 0epsilon = 16 / 255  # 可以调整扰动强度for data in testloader:original_images, labels = data[0].to(device), data[1].to(device)original_images.requires_grad = Trueattack_name = 'MI-FGSM'if attack_name == 'FGSM':perturbed_images = FGSM(model, criterion, original_images, labels, epsilon)elif attack_name == 'BIM':perturbed_images = BIM(model, criterion, original_images, labels, epsilon)elif attack_name == 'MI-FGSM':perturbed_images = MI_FGSM(model, criterion, original_images, labels, epsilon)elif attack_name == 'NI-FGSM':perturbed_images = NI_FGSM(model, criterion, original_images, labels, epsilon)perturbed_outputs = model(perturbed_images)_, predicted = torch.max(perturbed_outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total# Attack Success RateASR = 100 - accuracyprint(f'Load ResNet Model Weight from {weights_path}')print(f'epsilon: {epsilon}')print(f'ASR of {attack_name} : {ASR}%')

相关文章:

Pytorch | 利用NI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用NI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击 CIFAR数据集NI-FGSM介绍背景算法原理 NI-FGSM代码实现NI-FGSM算法实现攻击效果 代码汇总nifgsm.pytrain.pyadvtest.py 之前已经针对CIFAR10训练了多种分类器: Pytorch | 从零构建AlexNet对CIFAR10进行…...

深度学习实战车辆目标跟踪【bytetrack/deepsort】

本文采用YOLOv8作为核心算法框架,结合PyQt5构建用户界面,使用Python3进行开发。YOLOv8以其高效的实时检测能力,在多个目标检测任务中展现出卓越性能。本研究针对车辆目标数据集进行训练和优化,该数据集包含丰富的车辆目标图像样本…...

【C复习】模拟题题库*3总结

1.c语言中要求对变量作强制定义的主要理由是便于确定类型和分配空间 2.结构化程序由三中基本结构组成,三中基本结构组成的算法可以完成任何复杂的任务 3.数组名是一个不可变的常量 4.下列选项中,合法的C语言关键字是()。 …...

【数据分析】层次贝叶斯

文章目录 一、 贝叶斯推理二、 层次贝叶斯模型三、 层次贝叶斯的特点四、 数学表述五、推断方法六、应用领域 层次贝叶斯(Hierarchical Bayesian)方法是一种基于贝叶斯推理的统计模型,用于处理具有多个层次结构的数据模型。 它允许我们在同一…...

Layui table不使用url属性结合laypage组件实现动态分页

从后台一次性获取所有数据赋值给 Layui table 组件的 data 属性,若数据量大时,很可能会超出浏览器字符串最大长度,导致渲染数据失败。Layui table 结合 laypage 组件实现动态分页可解决此问题。 HTML增加分页组件标签 在table后增加一个用于…...

【蓝桥杯】43688-《Excel地址问题》

Excel地址问题 题目描述 Excel 单元格的地址表示很有趣,它可以使用字母来表示列号。比如, A 表示第 1 列, B 表示第 2 列, … Z 表示第 26 列, AA 表示第 27 列, AB 表示第 28 列, … BA 表示…...

【bodgeito】攻防实战记录

也许有一天我们再相逢&#xff0c;睁开眼睛看清楚&#xff0c;我才是英雄。 进入网站整体浏览网页 点击页面评分进入关卡 一般搭建之后这里都是红色的&#xff0c;黄色是代表接近&#xff0c;绿色代表过关 首先来到搜索处本着见框就插的原则 构造payload输入 <script>…...

Soul Preserver

Soul Preserver 护魂者 Soul Preserver - Item - 魔兽世界怀旧服WLK3.35数据库_巫妖王之怒80级魔兽数据库_wlk数据库 原来的1274法力值 圣光闪现不需要法力 圣光术原来的474法力值 但是测试数据3-5分钟有时候就触发了3次&#xff0c;节约2400蓝...

Android 折叠屏问题解决 - 展开或收起页面重建

一、问题说明 Android 折叠屏展开或收起后页面会重建&#xff0c;并重新走 onCreate onStart onResume ... 重新创建后页面的状态也会丢失&#xff0c;比如页面中是一个 RecyclerView&#xff0c;我们滑动到了第 5 个卡片的位置&#xff0c;展开后又自动滑动到了第 1 个卡片的…...

深入理解 Linux wc 命令

文章目录 深入理解 Linux wc 命令1. 基本功能2. 常用选项3. 示例3.1 统计文件的行、单词和字符数3.2 仅统计行数3.3 统计多个文件的总和3.4 使用管道统计命令输出的行数 4. 实用案例4.1 日志分析4.2 快速统计代码行数4.3 统计单词频率 5. 注意事项6. 总结 深入理解 Linux wc 命…...

半连接转内连接规则的原理与代码解析 |OceanBase查询优化

背景 在查询语句中&#xff0c;若涉及半连接&#xff08;semi join&#xff09;操作&#xff0c;由于半连接不满足交换律的规则&#xff0c;连接操作必须遵循语句中定义的顺序执行&#xff0c;从而限制了优化器根据参与连接的表的实际数据量来灵活选择优化策略的能力。为此&am…...

多进程、多线程、分布式测试支持-pytest-xdis插件

pytest-xdist是pytest测试框架的一个插件&#xff0c;它提供了多进程、多线程和分布式测试的支持&#xff0c;可以显著提高测试效率。以下是对pytest-xdist的详细介绍&#xff1a; 一、安装 要使用pytest-xdist&#xff0c;首先需要安装pytest和pytest-xdist。可以通过pip进行…...

Oracle virTualBox安装window10

一、下载windows10镜像 我下载的windows10镜像如下&#xff1a; 内部文件如下&#xff1a; 二、错误的安装方法 直接新建虚拟机&#xff0c;选择镜像文件&#xff1a; 启动虚拟机&#xff08;会一直提示没有启动设备&#xff0c;选择镜像后一直弹窗提示&#xff09; 三、正确…...

Python7-数据结构

记录python学习&#xff0c;直到学会基本的爬虫&#xff0c;使用python搭建接口自动化测试就算学会了&#xff0c;在进阶webui自动化&#xff0c;app自动化 python基础7-数据结构的那些事儿 常见的数据结构有哪些&#xff1f;线性数据结构有哪些&#xff1f;非线性数据结构有哪…...

springboot指定ssl版本连接

在application.yml配置指定 server.ssl.protocolTLSv1.2结果应用依然接受低版本如TLSv1.0的连接 可以在ie浏览器&#xff1a;设置-Internet选项-高级&#xff0c;将当前连接改为TLSv1.0进行测试 这种情况可以通过增加配置仅由TLSv1.2支持的密码处理&#xff1a; server.ssl.…...

VTK编程指南<十二>:VTK图像数据结构及图像创建与显示

数字图像是一种重要的多媒体数据&#xff0c;广泛应用于工业生产、生物医学、地质、气象等重要领域。数字图像处理技术具有重要的应用价值。图像是VTK里非常重要的一种数据结构。本章重点讲解VTK在数字图像处理应用方面的相关技术。 1、VTK图像数据结构 数字图像文件内容由两个…...

EasyGBS国标GB28181平台P2P远程访问故障排查指南:客户端角度的排查思路

在现代视频监控系统中&#xff0c;P2P&#xff08;点对点&#xff09;技术因其便捷性和高效性而被广泛应用。然而&#xff0c;当用户在使用P2P远程访问时遇到设备不在线或无法访问的问题时&#xff0c;有效的排查方法显得尤为重要。本文将从客户端的角度出发&#xff0c;详细探…...

打造智慧医院挂号枢纽:SSM 与 Vue 融合的系统设计与实施

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…...

网络编程 02:IP 地址,IP 地址的作用、分类,通过 Java 实现 IP 地址的信息获取

一、概述 记录时间 [2024-12-18] 前置文章&#xff1a;网络编程 01&#xff1a;计算机网络概述&#xff0c;网络的作用&#xff0c;网络通信的要素&#xff0c;以及网络通信协议与分层模型 本文讲述网络编程相关知识——IP 地址&#xff0c;包括 IP 地址的作用、分类&#xff…...

如何使用Python WebDriver爬取ChatGPT内容(完整教程)

大背景 虽然我们能用网页版chatGPT来聊天、写文章&#xff0c;但是我们采集大量的内容&#xff0c;就得不断地手动输入提问来获取答案&#xff0c;并且将结果复制到数据库来保存。如果整个过程能使用程序来做自然要节省很多的人力&#xff0c;精力和时间。 Python webdirver …...

WSL切换默认发行版

查看适用于wsl的子系统有哪些: wslconfig /list 设置wsl的默认发行版 wslconfig /setdefault Ubuntu-20.04...

全志H618 Android12修改doucmentsui功能菜单项

背景: 由于当前的文件管理器在我们的产品定义当中,某些界面有改动的需求,所以需要在Android12 rom中进行定制以符合当前产品定义。 需求: 在进入File文件管理器后,查看...功能菜单时,有不需要的功能菜单,需要隐藏,如:新建窗口、不显示的文件夹、故代码分析以及客制…...

移动网络(2,3,4,5G)设备TCP通讯调试方法

背景&#xff1a; 当设备是移动网络设备连接云平台的时候&#xff0c;如果服务器没有收到网络数据&#xff0c;移动物联设备发送不知道有没有有丢失数据的时候&#xff0c;需要一个抓取设备出来的数据和服务器下发的数据的方法。 1.服务器系统是很成熟的&#xff0c;一般是linu…...

网络安全概论——入侵检测系统IDS

一、入侵检测的概念 1、入侵检测的概念 检测对计算机系统的非授权访问对系统的运行状态进行监视&#xff0c;发现各种攻击企图、攻击行为或攻击结果&#xff0c;以保证系统资源的保密性、完整性和可用性识别针对计算机系统和网络系统或广义上的信息系统的非法攻击&#xff0c…...

Linux通信System V:消息队列 信号量

Linux通信System V&#xff1a;消息队列 & 信号量 一、信号量概念二、信号量意义三、操作系统如何管理ipc资源&#xff08;2.36版本&#xff09;四、如何对信号量资源进行管理 一、信号量概念 信号量本质上就是计数器&#xff0c;用来保护共享资源。多个进程在进行通信时&a…...

计算机网络基础图解

注&#xff1a;本文为来自 猿小许 的 “计算机网络” 相关系列文章合辑。 一、计算机网络概述 猿小许于 2021-06-03 18:39:47 发布 一、计算机网络的概念 1.1 计算机网络 概念 计算机网络&#xff1a; 是一个将分散的、具有独立功能的计算机系统&#xff0c;通过通信设备与…...

TDesign:NavBar 导航栏

NavBar 导航栏 左图&#xff0c;右标 appBar: TDNavBar(padding: EdgeInsets.only(left: 0,right: 30.w), // 重写左右内边距centerTitle:false, // 不显示标题height: 45, // 高度titleWidget: TDImage( // 左图assetUrl: assets/img/logo.png,width: 147.w,height: 41.w,),ba…...

hive注释comment中文乱码解决

问题描述 当使用以下命令查看表的元数据信息时出现中文乱码&#xff08;使用的是idea连接hive&#xff09; desc formatted test.t_archer; 解决 连接保存hive元数据的MySQL数据库&#xff0c;执行以下命令&#xff1a; use hive3; show tables;alter table hive3.COLUMNS_…...

电脑提示ntdll.d缺失是什么原因?不处理的话会怎么样?ntdll.dll文件缺失快速解决方案来啦!

电脑提示ntdll.dll缺失&#xff1a;原因、影响与解决方案 在日常的电脑使用中&#xff0c;我们偶尔会遇到一些令人困惑的系统错误&#xff0c;其中“ntdll.dll缺失”便是较为常见的一种。作为软件开发从业者&#xff0c;我深知这一错误给用户带来的不便&#xff0c;因此&#…...

MFC/C++学习系列之简单记录——序列化机制

MFC/C学习系列之简单记录——序列化机制 前言简述六大机制序列化机制使用反序列化总结 前言 MFC有六大机制&#xff0c;分别是程序启动机制、窗口创建机制、动态创建机制、运行时类信息机制、消息映射机制、序列化机制。 简述六大机制 程序启动机制&#xff1a;全局的应用程序…...