当前位置: 首页 > news >正文

Pytorch | 利用NI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用NI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

  • CIFAR数据集
  • NI-FGSM介绍
      • 背景
      • 算法原理
  • NI-FGSM代码实现
    • NI-FGSM算法实现
    • 攻击效果
  • 代码汇总
    • nifgsm.py
    • train.py
    • advtest.py

之前已经针对CIFAR10训练了多种分类器:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建EfficientNet对CIFAR10进行分类
Pytorch | 从零构建ParNet对CIFAR10进行分类

本篇文章我们使用Pytorch实现NI-FGSM对CIFAR10上的ResNet分类器进行攻击.

CIFAR数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:

在这里插入图片描述

NI-FGSM介绍

NI-FGSM(Nesterov Iterative Fast Gradient Sign Method)即涅斯捷罗夫迭代快速梯度符号法,是一种在对抗攻击领域中对FGSM进行改进的迭代攻击算法,以下是其详细介绍:

背景

  • 传统的FGSM及其一些迭代改进版本如I-FGSM等,在生成对抗样本时存在一些局限性,例如可能会在迭代过程中陷入局部最优,导致攻击效果不够理想或生成的对抗样本转移性较差。NI-FGSM借鉴了优化算法中的Nesterov加速梯度法的思想,旨在更有效地利用梯度信息,提高攻击的效率和效果。

算法原理

  • 初始化:与其他对抗攻击方法类似,需要一个待攻击的目标模型 f f f、损失函数 J J J、原始图像 x x x及其对应的真实标签 y y y,同时还需要设定攻击步长 ϵ \epsilon ϵ、迭代次数 T T T等参数。
  • 迭代更新:在每次迭代 t t t中,首先计算一个“前瞻”点 x t l o o k a h e a d x_{t}^{lookahead} xtlookahead,它是基于当前迭代点 x t x_{t} xt和上一次迭代的梯度信息进行的一个预估更新点,公式为 x t l o o k a h e a d = x t + α ⋅ sign ( ∇ x J ( x t , y ) ) x_{t}^{lookahead}=x_{t}+\alpha \cdot \text{sign}\left(\nabla_{x} J\left(x_{t}, y\right)\right) xtlookahead=xt+αsign(xJ(xt,y)),其中 α \alpha α是一个类似于步长的参数。然后,计算在这个“前瞻”点处的损失梯度 ∇ x J ( x t l o o k a h e a d , y ) \nabla_{x} J\left(x_{t}^{lookahead}, y\right) xJ(xtlookahead,y),并根据该梯度来更新当前迭代点 x t x_{t} xt,更新公式为 x t + 1 = x t + ϵ ⋅ sign ( ∇ x J ( x t l o o k a h e a d , y ) ) x_{t + 1}=x_{t}+\epsilon \cdot \text{sign}\left(\nabla_{x} J\left(x_{t}^{lookahead}, y\right)\right) xt+1=xt+ϵsign(xJ(xtlookahead,y))
  • 投影操作:与其他对抗攻击方法一样,为了确保生成的对抗样本在合理的范围内,如像素值在 [ 0 , 1 ] [0, 1] [0,1] [ − 1 , 1 ] [-1, 1] [1,1]之间,需要对每次迭代更新后的样本进行投影操作。

NI-FGSM代码实现

NI-FGSM算法实现

import torch
import torch.nn as nndef NI_FGSM(model, criterion, original_images, labels, epsilon, num_iterations=10):"""NI-FGSM (Nesterov Iterative Fast Gradient Sign Method) 参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始输入图像数据- labels: 对应的真实标签- epsilon: 最大扰动幅度- num_iterations: 迭代次数"""# alpha: 每次迭代的步长alpha = epsilon / num_iterations# 复制原始图像作为初始的对抗样本,并设置其需要计算梯度perturbed_images = original_images.clone().detach().requires_grad_(True)for _ in range(num_iterations):# 计算 "前瞻" 点(基于当前对抗样本和当前梯度方向预估的下一步位置)lookahead_images = perturbed_images + alpha * torch.sign(perturbed_images.grad.data) if perturbed_images.grad is not None else perturbed_images# 前向传播得到模型输出outputs = model(lookahead_images)# 计算损失loss = criterion(outputs, labels)# 清空模型之前的梯度信息model.zero_grad()# 反向传播计算梯度loss.backward()# 获取当前梯度数据data_grad = lookahead_images.grad.data if lookahead_images.grad is not None else torch.zeros_like(original_images)# 计算符号梯度sign_data_grad = torch.sign(data_grad)# 更新对抗样本perturbed_images = perturbed_images + epsilon * sign_data_grad# 投影操作,确保扰动后的图像仍在合理范围内(这里假设图像范围是[0, 1])perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)perturbed_images = perturbed_images.detach().requires_grad_(True)return perturbed_images

攻击效果

在这里插入图片描述

代码汇总

nifgsm.py

import torch
import torch.nn as nndef NI_FGSM(model, criterion, original_images, labels, epsilon, num_iterations=10):"""NI-FGSM (Nesterov Iterative Fast Gradient Sign Method) 参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始输入图像数据- labels: 对应的真实标签- epsilon: 最大扰动幅度- num_iterations: 迭代次数"""# alpha: 每次迭代的步长alpha = epsilon / num_iterations# 复制原始图像作为初始的对抗样本,并设置其需要计算梯度perturbed_images = original_images.clone().detach().requires_grad_(True)for _ in range(num_iterations):# 计算 "前瞻" 点(基于当前对抗样本和当前梯度方向预估的下一步位置)lookahead_images = perturbed_images + alpha * torch.sign(perturbed_images.grad.data) if perturbed_images.grad is not None else perturbed_images# 前向传播得到模型输出outputs = model(lookahead_images)# 计算损失loss = criterion(outputs, labels)# 清空模型之前的梯度信息model.zero_grad()# 反向传播计算梯度loss.backward()# 获取当前梯度数据data_grad = lookahead_images.grad.data if lookahead_images.grad is not None else torch.zeros_like(original_images)# 计算符号梯度sign_data_grad = torch.sign(data_grad)# 更新对抗样本perturbed_images = perturbed_images + epsilon * sign_data_grad# 投影操作,确保扰动后的图像仍在合理范围内(这里假设图像范围是[0, 1])perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)perturbed_images = perturbed_images.detach().requires_grad_(True)return perturbed_images

train.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import ResNet18# 数据预处理
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加载Cifar10训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# 定义设备(GPU或CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 初始化模型
model = ResNet18(num_classes=10)
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)if __name__ == "__main__":# 训练模型for epoch in range(10):  # 可以根据实际情况调整训练轮数running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss = {running_loss / 100}')running_loss = 0.0torch.save(model.state_dict(), f'weights/epoch_{epoch + 1}.pth')print('Finished Training')

advtest.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *
from attacks import *
import ssl
import os
from PIL import Image
import matplotlib.pyplot as pltssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = ResNet18(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = "weights/epoch_10.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))if __name__ == "__main__":# 在测试集上进行FGSM攻击并评估准确率model.eval()  # 设置为评估模式correct = 0total = 0epsilon = 16 / 255  # 可以调整扰动强度for data in testloader:original_images, labels = data[0].to(device), data[1].to(device)original_images.requires_grad = Trueattack_name = 'MI-FGSM'if attack_name == 'FGSM':perturbed_images = FGSM(model, criterion, original_images, labels, epsilon)elif attack_name == 'BIM':perturbed_images = BIM(model, criterion, original_images, labels, epsilon)elif attack_name == 'MI-FGSM':perturbed_images = MI_FGSM(model, criterion, original_images, labels, epsilon)elif attack_name == 'NI-FGSM':perturbed_images = NI_FGSM(model, criterion, original_images, labels, epsilon)perturbed_outputs = model(perturbed_images)_, predicted = torch.max(perturbed_outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total# Attack Success RateASR = 100 - accuracyprint(f'Load ResNet Model Weight from {weights_path}')print(f'epsilon: {epsilon}')print(f'ASR of {attack_name} : {ASR}%')

相关文章:

Pytorch | 利用NI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用NI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击 CIFAR数据集NI-FGSM介绍背景算法原理 NI-FGSM代码实现NI-FGSM算法实现攻击效果 代码汇总nifgsm.pytrain.pyadvtest.py 之前已经针对CIFAR10训练了多种分类器: Pytorch | 从零构建AlexNet对CIFAR10进行…...

深度学习实战车辆目标跟踪【bytetrack/deepsort】

本文采用YOLOv8作为核心算法框架,结合PyQt5构建用户界面,使用Python3进行开发。YOLOv8以其高效的实时检测能力,在多个目标检测任务中展现出卓越性能。本研究针对车辆目标数据集进行训练和优化,该数据集包含丰富的车辆目标图像样本…...

【C复习】模拟题题库*3总结

1.c语言中要求对变量作强制定义的主要理由是便于确定类型和分配空间 2.结构化程序由三中基本结构组成,三中基本结构组成的算法可以完成任何复杂的任务 3.数组名是一个不可变的常量 4.下列选项中,合法的C语言关键字是()。 …...

【数据分析】层次贝叶斯

文章目录 一、 贝叶斯推理二、 层次贝叶斯模型三、 层次贝叶斯的特点四、 数学表述五、推断方法六、应用领域 层次贝叶斯(Hierarchical Bayesian)方法是一种基于贝叶斯推理的统计模型,用于处理具有多个层次结构的数据模型。 它允许我们在同一…...

Layui table不使用url属性结合laypage组件实现动态分页

从后台一次性获取所有数据赋值给 Layui table 组件的 data 属性,若数据量大时,很可能会超出浏览器字符串最大长度,导致渲染数据失败。Layui table 结合 laypage 组件实现动态分页可解决此问题。 HTML增加分页组件标签 在table后增加一个用于…...

【蓝桥杯】43688-《Excel地址问题》

Excel地址问题 题目描述 Excel 单元格的地址表示很有趣,它可以使用字母来表示列号。比如, A 表示第 1 列, B 表示第 2 列, … Z 表示第 26 列, AA 表示第 27 列, AB 表示第 28 列, … BA 表示…...

【bodgeito】攻防实战记录

也许有一天我们再相逢&#xff0c;睁开眼睛看清楚&#xff0c;我才是英雄。 进入网站整体浏览网页 点击页面评分进入关卡 一般搭建之后这里都是红色的&#xff0c;黄色是代表接近&#xff0c;绿色代表过关 首先来到搜索处本着见框就插的原则 构造payload输入 <script>…...

Soul Preserver

Soul Preserver 护魂者 Soul Preserver - Item - 魔兽世界怀旧服WLK3.35数据库_巫妖王之怒80级魔兽数据库_wlk数据库 原来的1274法力值 圣光闪现不需要法力 圣光术原来的474法力值 但是测试数据3-5分钟有时候就触发了3次&#xff0c;节约2400蓝...

Android 折叠屏问题解决 - 展开或收起页面重建

一、问题说明 Android 折叠屏展开或收起后页面会重建&#xff0c;并重新走 onCreate onStart onResume ... 重新创建后页面的状态也会丢失&#xff0c;比如页面中是一个 RecyclerView&#xff0c;我们滑动到了第 5 个卡片的位置&#xff0c;展开后又自动滑动到了第 1 个卡片的…...

深入理解 Linux wc 命令

文章目录 深入理解 Linux wc 命令1. 基本功能2. 常用选项3. 示例3.1 统计文件的行、单词和字符数3.2 仅统计行数3.3 统计多个文件的总和3.4 使用管道统计命令输出的行数 4. 实用案例4.1 日志分析4.2 快速统计代码行数4.3 统计单词频率 5. 注意事项6. 总结 深入理解 Linux wc 命…...

半连接转内连接规则的原理与代码解析 |OceanBase查询优化

背景 在查询语句中&#xff0c;若涉及半连接&#xff08;semi join&#xff09;操作&#xff0c;由于半连接不满足交换律的规则&#xff0c;连接操作必须遵循语句中定义的顺序执行&#xff0c;从而限制了优化器根据参与连接的表的实际数据量来灵活选择优化策略的能力。为此&am…...

多进程、多线程、分布式测试支持-pytest-xdis插件

pytest-xdist是pytest测试框架的一个插件&#xff0c;它提供了多进程、多线程和分布式测试的支持&#xff0c;可以显著提高测试效率。以下是对pytest-xdist的详细介绍&#xff1a; 一、安装 要使用pytest-xdist&#xff0c;首先需要安装pytest和pytest-xdist。可以通过pip进行…...

Oracle virTualBox安装window10

一、下载windows10镜像 我下载的windows10镜像如下&#xff1a; 内部文件如下&#xff1a; 二、错误的安装方法 直接新建虚拟机&#xff0c;选择镜像文件&#xff1a; 启动虚拟机&#xff08;会一直提示没有启动设备&#xff0c;选择镜像后一直弹窗提示&#xff09; 三、正确…...

Python7-数据结构

记录python学习&#xff0c;直到学会基本的爬虫&#xff0c;使用python搭建接口自动化测试就算学会了&#xff0c;在进阶webui自动化&#xff0c;app自动化 python基础7-数据结构的那些事儿 常见的数据结构有哪些&#xff1f;线性数据结构有哪些&#xff1f;非线性数据结构有哪…...

springboot指定ssl版本连接

在application.yml配置指定 server.ssl.protocolTLSv1.2结果应用依然接受低版本如TLSv1.0的连接 可以在ie浏览器&#xff1a;设置-Internet选项-高级&#xff0c;将当前连接改为TLSv1.0进行测试 这种情况可以通过增加配置仅由TLSv1.2支持的密码处理&#xff1a; server.ssl.…...

VTK编程指南<十二>:VTK图像数据结构及图像创建与显示

数字图像是一种重要的多媒体数据&#xff0c;广泛应用于工业生产、生物医学、地质、气象等重要领域。数字图像处理技术具有重要的应用价值。图像是VTK里非常重要的一种数据结构。本章重点讲解VTK在数字图像处理应用方面的相关技术。 1、VTK图像数据结构 数字图像文件内容由两个…...

EasyGBS国标GB28181平台P2P远程访问故障排查指南:客户端角度的排查思路

在现代视频监控系统中&#xff0c;P2P&#xff08;点对点&#xff09;技术因其便捷性和高效性而被广泛应用。然而&#xff0c;当用户在使用P2P远程访问时遇到设备不在线或无法访问的问题时&#xff0c;有效的排查方法显得尤为重要。本文将从客户端的角度出发&#xff0c;详细探…...

打造智慧医院挂号枢纽:SSM 与 Vue 融合的系统设计与实施

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…...

网络编程 02:IP 地址,IP 地址的作用、分类,通过 Java 实现 IP 地址的信息获取

一、概述 记录时间 [2024-12-18] 前置文章&#xff1a;网络编程 01&#xff1a;计算机网络概述&#xff0c;网络的作用&#xff0c;网络通信的要素&#xff0c;以及网络通信协议与分层模型 本文讲述网络编程相关知识——IP 地址&#xff0c;包括 IP 地址的作用、分类&#xff…...

如何使用Python WebDriver爬取ChatGPT内容(完整教程)

大背景 虽然我们能用网页版chatGPT来聊天、写文章&#xff0c;但是我们采集大量的内容&#xff0c;就得不断地手动输入提问来获取答案&#xff0c;并且将结果复制到数据库来保存。如果整个过程能使用程序来做自然要节省很多的人力&#xff0c;精力和时间。 Python webdirver …...

Chapter03-Authentication vulnerabilities

文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...

Ubuntu系统下交叉编译openssl

一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机&#xff1a;Ubuntu 20.04.6 LTSHost&#xff1a;ARM32位交叉编译器&#xff1a;arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...

可靠性+灵活性:电力载波技术在楼宇自控中的核心价值

可靠性灵活性&#xff1a;电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中&#xff0c;电力载波技术&#xff08;PLC&#xff09;凭借其独特的优势&#xff0c;正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据&#xff0c;无需额外布…...

React Native在HarmonyOS 5.0阅读类应用开发中的实践

一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强&#xff0c;React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 &#xff08;1&#xff09;使用React Native…...

【项目实战】通过多模态+LangGraph实现PPT生成助手

PPT自动生成系统 基于LangGraph的PPT自动生成系统&#xff0c;可以将Markdown文档自动转换为PPT演示文稿。 功能特点 Markdown解析&#xff1a;自动解析Markdown文档结构PPT模板分析&#xff1a;分析PPT模板的布局和风格智能布局决策&#xff1a;匹配内容与合适的PPT布局自动…...

Ascend NPU上适配Step-Audio模型

1 概述 1.1 简述 Step-Audio 是业界首个集语音理解与生成控制一体化的产品级开源实时语音对话系统&#xff0c;支持多语言对话&#xff08;如 中文&#xff0c;英文&#xff0c;日语&#xff09;&#xff0c;语音情感&#xff08;如 开心&#xff0c;悲伤&#xff09;&#x…...

12.找到字符串中所有字母异位词

&#x1f9e0; 题目解析 题目描述&#xff1a; 给定两个字符串 s 和 p&#xff0c;找出 s 中所有 p 的字母异位词的起始索引。 返回的答案以数组形式表示。 字母异位词定义&#xff1a; 若两个字符串包含的字符种类和出现次数完全相同&#xff0c;顺序无所谓&#xff0c;则互为…...

如何理解 IP 数据报中的 TTL?

目录 前言理解 前言 面试灵魂一问&#xff1a;说说对 IP 数据报中 TTL 的理解&#xff1f;我们都知道&#xff0c;IP 数据报由首部和数据两部分组成&#xff0c;首部又分为两部分&#xff1a;固定部分和可变部分&#xff0c;共占 20 字节&#xff0c;而即将讨论的 TTL 就位于首…...

基于matlab策略迭代和值迭代法的动态规划

经典的基于策略迭代和值迭代法的动态规划matlab代码&#xff0c;实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...

Hive 存储格式深度解析:从 TextFile 到 ORC,如何选对数据存储方案?

在大数据处理领域&#xff0c;Hive 作为 Hadoop 生态中重要的数据仓库工具&#xff0c;其存储格式的选择直接影响数据存储成本、查询效率和计算资源消耗。面对 TextFile、SequenceFile、Parquet、RCFile、ORC 等多种存储格式&#xff0c;很多开发者常常陷入选择困境。本文将从底…...