当前位置: 首页 > news >正文

Pytorch | 利用MI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用MI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

  • CIFAR数据集
  • MI-FGSM介绍
      • 背景
      • 算法原理
  • MI-FGSM代码实现
    • MI-FGSM算法实现
    • 攻击效果
  • 代码汇总
    • mifgsm.py
    • train.py
    • advtest.py

之前已经针对CIFAR10训练了多种分类器:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建EfficientNet对CIFAR10进行分类
Pytorch | 从零构建ParNet对CIFAR10进行分类

本篇文章我们使用Pytorch实现MI-FGSM对CIFAR10上的ResNet分类器进行攻击.

CIFAR数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:

在这里插入图片描述

MI-FGSM介绍

MI-FGSM(Momentum Iterative Fast Gradient Sign Method)是一种基于动量的迭代快速梯度符号法,是在FGSM(Fast Gradient Sign Method)基础上的改进,旨在生成更具攻击性和隐蔽性的对抗样本,以下是对其的详细介绍:

背景

  • 在对抗攻击领域,FGSM是一种简单有效的攻击方法,但它仅进行一次梯度计算和更新,生成的对抗样本可能不够强大。为了进一步提高攻击效果,研究人员提出了迭代攻击的方法,如I-FGSM(Iterative FGSM),通过多次迭代来逐步调整对抗样本。MI-FGSM在I-FGSM的基础上引入动量项,使得攻击能够更好地利用历史梯度信息,加速收敛并提高攻击成功率。

算法原理

  • 初始化:与FGSM类似,首先需要一个预训练的模型、损失函数、原始图像和对应的真实标签,以及攻击步长 ϵ \epsilon ϵ 、迭代次数 T T T和动量因子 μ \mu μ等参数。
  • 迭代更新:在每次迭代中,计算当前对抗样本相对于模型输出的损失梯度,并将其与上一次迭代的动量项相加,得到更新后的梯度方向。然后,根据更新后的梯度方向和攻击步长,对对抗样本进行更新。具体计算公式如下:
    g t + 1 = μ ⋅ g t + ∇ x J ( x t a d v , y ) ∥ ∇ x J ( x t a d v , y ) ∥ 1 g_{t+1}=\mu \cdot g_{t}+\frac{\nabla_{x} J\left(x_{t}^{adv}, y\right)}{\left\|\nabla_{x} J\left(x_{t}^{adv}, y\right)\right\|_{1}} gt+1=μgt+xJ(xtadv,y)1xJ(xtadv,y)
    x t + 1 a d v = x t a d v + ϵ ⋅ sign ( g t + 1 ) x_{t+1}^{adv}=x_{t}^{adv}+\epsilon \cdot \text{sign}\left(g_{t+1}\right) xt+1adv=xtadv+ϵsign(gt+1)
    其中, g t g_{t} gt 是第 t t t次迭代的动量项, x t a d v x_{t}^{adv} xtadv是第 t t t次迭代得到的对抗样本, J J J是损失函数, ∇ x J ( x t a d v , y ) \nabla_{x} J\left(x_{t}^{adv}, y\right) xJ(xtadv,y) 是损失函数关于对抗样本的梯度, sign \text{sign} sign 表示符号函数。
  • 投影操作:为了确保对抗样本在合理的范围内,通常还需要进行投影操作,将其像素值限制在有效区间内,如 [ 0 , 1 ] [0, 1] [0,1] [ − 1 , 1 ] [-1, 1] [1,1]

MI-FGSM代码实现

MI-FGSM算法实现

import torch
import torch.nn as nndef MI_FGSM(model, criterion, original_images, labels, epsilon, alpha=0.001, num_iterations=10, decay=1):"""MI-FGSM (Momentum Iterative Fast Gradient Sign Method) 参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始图像- labels: 原始图像的标签- epsilon: 最大扰动幅度- alpha: 每次迭代的步长- num_iterations: 迭代次数- decay: 动量衰减因子"""# 复制原始图像作为初始的对抗样本perturbed_image = original_images.clone().detach().requires_grad_(True)momentum = torch.zeros_like(original_images).detach().to(original_images.device)for _ in range(num_iterations):outputs = model(perturbed_image)loss = criterion(outputs, labels)model.zero_grad()loss.backward()data_grad = perturbed_image.grad.data# 归一化梯度,避免梯度爆炸等问题data_grad = data_grad / torch.mean(torch.abs(data_grad), dim=(1, 2, 3), keepdim=True)# 更新动量momentum = decay * momentum + data_grad / torch.mean(torch.abs(data_grad), dim=(1, 2, 3), keepdim=True)# 计算带动量的符号梯度sign_data_grad = momentum.sign()# 更新对抗样本perturbed_image = perturbed_image + alpha * sign_data_grad# 投影操作,确保扰动后的图像仍在合理范围内(这里假设图像范围是[0, 1])perturbed_image = torch.where(perturbed_image > original_images + epsilon,original_images + epsilon, perturbed_image)perturbed_image = torch.where(perturbed_image < original_images - epsilon,original_images - epsilon, perturbed_image)perturbed_image = torch.clamp(perturbed_image, 0, 1).detach().requires_grad_(True)return perturbed_image

攻击效果

在这里插入图片描述

代码汇总

mifgsm.py

import torch
import torch.nn as nndef MI_FGSM(model, criterion, original_images, labels, epsilon, alpha=0.001, num_iterations=10, decay=1):"""MI-FGSM (Momentum Iterative Fast Gradient Sign Method) 参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始图像- labels: 原始图像的标签- epsilon: 最大扰动幅度- alpha: 每次迭代的步长- num_iterations: 迭代次数- decay: 动量衰减因子"""# 复制原始图像作为初始的对抗样本perturbed_image = original_images.clone().detach().requires_grad_(True)momentum = torch.zeros_like(original_images).detach().to(original_images.device)for _ in range(num_iterations):outputs = model(perturbed_image)loss = criterion(outputs, labels)model.zero_grad()loss.backward()data_grad = perturbed_image.grad.data# 归一化梯度,避免梯度爆炸等问题data_grad = data_grad / torch.mean(torch.abs(data_grad), dim=(1, 2, 3), keepdim=True)# 更新动量momentum = decay * momentum + data_grad / torch.mean(torch.abs(data_grad), dim=(1, 2, 3), keepdim=True)# 计算带动量的符号梯度sign_data_grad = momentum.sign()# 更新对抗样本perturbed_image = perturbed_image + alpha * sign_data_grad# 投影操作,确保扰动后的图像仍在合理范围内(这里假设图像范围是[0, 1])perturbed_image = torch.where(perturbed_image > original_images + epsilon,original_images + epsilon, perturbed_image)perturbed_image = torch.where(perturbed_image < original_images - epsilon,original_images - epsilon, perturbed_image)perturbed_image = torch.clamp(perturbed_image, 0, 1).detach().requires_grad_(True)return perturbed_image

train.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import ResNet18# 数据预处理
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加载Cifar10训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# 定义设备(GPU或CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 初始化模型
model = ResNet18(num_classes=10)
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)if __name__ == "__main__":# 训练模型for epoch in range(10):  # 可以根据实际情况调整训练轮数running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss = {running_loss / 100}')running_loss = 0.0torch.save(model.state_dict(), f'weights/epoch_{epoch + 1}.pth')print('Finished Training')

advtest.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *
from attacks import *
import ssl
import os
from PIL import Image
import matplotlib.pyplot as pltssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = ResNet18(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = "weights/epoch_10.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))if __name__ == "__main__":# 在测试集上进行FGSM攻击并评估准确率model.eval()  # 设置为评估模式correct = 0total = 0epsilon = 0.01  # 可以调整扰动强度for data in testloader:original_images, labels = data[0].to(device), data[1].to(device)original_images.requires_grad = Trueattack_name = 'MI-FGSM'if attack_name == 'FGSM':perturbed_images = FGSM(model, criterion, original_images, labels, epsilon)elif attack_name == 'BIM':perturbed_images = BIM(model, criterion, original_images, labels, epsilon)elif attack_name == 'MI-FGSM':perturbed_images = MI_FGSM(model, criterion, original_images, labels, epsilon)perturbed_outputs = model(perturbed_images)_, predicted = torch.max(perturbed_outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total# Attack Success RateASR = 100 - accuracyprint(f'Load ResNet Model Weight from {weights_path}')print(f'epsilon: {epsilon}')print(f'ASR of {attack_name} : {ASR}%')

相关文章:

Pytorch | 利用MI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用MI-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击 CIFAR数据集MI-FGSM介绍背景算法原理 MI-FGSM代码实现MI-FGSM算法实现攻击效果 代码汇总mifgsm.pytrain.pyadvtest.py 之前已经针对CIFAR10训练了多种分类器&#xff1a; Pytorch | 从零构建AlexNet对CIFAR10进行…...

linux 无网络安装mysql

下载地址 通过网盘分享的文件&#xff1a;mysql-5.7.33-linux-glibc2.12-x86_64.tar.gz 链接: https://pan.baidu.com/s/1qm48pNfGYMqBGfoqT3hxPw?pwd0012 提取码: 0012 安装 解压 tar -zxvf mysql-5.7.33-linux-glibc2.12-x86_64.tar.gz mv /usr/mysql-5.7.33-linux-glibc2.1…...

自毁程序密码—阿里聚安全(IDA动态调试)

App信息 包名&#xff1a;com.yaotong.crackme Java层分析 MainActivity 很容易就能看出来是在securityCheck函数里进行安全校验。securityCheck是一个native函数&#xff0c;到so中进行分析。 SO层分析 定位函数位置 在导出函数里搜索 securityCheck 数据类型修复和…...

【华为OD-E卷-寻找关键钥匙 100分(python、java、c++、js、c)】

【华为OD-E卷-寻找关键钥匙 100分&#xff08;python、java、c、js、c&#xff09;】 题目 小强正在参加《密室逃生》游戏&#xff0c;当前关卡要求找到符合给定 密码K&#xff08;升序的不重复小写字母组成&#xff09; 的箱子&#xff0c;并给出箱子编号&#xff0c;箱子编…...

vscode 使用说明

文章目录 1、文档2、技巧显示与搜索宏定义和包含头文件 3、插件4、智能编写5、VSCode 与 C&#xff08;1&#xff09;安装&#xff08;2&#xff09;调试&#xff08;a&#xff09;使用 CMake 进行跨平台编译与调试&#xff08;b&#xff09;launch.json&#xff08;c&#xff…...

【Linux系统编程】:信号(2)——信号的产生

1.前言 我们会讲解五种信号产生的方式: 通过终端按键产生信号&#xff0c;比如键盘上的CtrlC。kill命令。本质上是调用kill()调用函数接口产生信号硬件异常产生信号软件条件产生信号 前两种在前一篇文章中做了介绍&#xff0c;本文介绍下面三种. 2. 调用函数产生信号 2.1 k…...

Android Studio AI助手---Gemini

从金丝雀频道下载最新版 Android Studio&#xff0c;以利用所有这些新功能&#xff0c;并继续阅读以了解新增内容。 Gemini 现在可以编写、重构和记录 Android 代码 Gemini 不仅仅是提供指导。它可以编辑您的代码&#xff0c;帮助您快速从原型转向实现&#xff0c;实现常见的…...

【day09】面向对象——静态成员和可变参数

【day08】面向对象——封装重点:1.封装:a.将细节隐藏起来,不让外界直接调用,再提供公共接口,供外界通过公共接口间接使用隐藏起来的细节b.代表性的:将一段代码放到一个方法中(隐藏细节),通过方法名(提供的公共接口)去调用private关键字 -> 私有的,被private修饰之后别的类不…...

Android学习(七)-Kotlin编程语言-Lambda 编程

Lambda 编程 而 Kotlin 从第一个版本开始就支持了 Lambda 编程&#xff0c;并且 Kotlin 中的 Lambda 功能极为强大。Lambda 表达式使得代码更加简洁和易读。 2.6.1 集合的创建与遍历 集合的函数式 API 是入门 Lambda 编程的绝佳示例&#xff0c;但在开始之前&#xff0c;我们…...

彻底认识和理解探索分布式网络编程中的SSL安全通信机制

探索分布式网络编程中的SSL安全通信机制 SSL的前提介绍SSL/TLS协议概述SSL和TLS建立在TCP/IP协议的基础上分析一个日常购物的安全问题 基于SSL的加密通信SSL的安全证书SSL的证书的实现安全认证获取对应的SSL证书方式权威机构获得证书创建自我签名证书 SSL握手通信机制公私钥传输…...

【libuv】Fargo信令2:【深入】client为什么收不到服务端响应的ack消息

客户端处理server的ack回复,判断链接连接建立 【Fargo】28:字节序列【libuv】Fargo信令1:client发connect消息给到server客户端启动后理解监听read消息 但是,这个代码似乎没有触发ack消息的接收: // 客户端初始化 void start_client(uv_loop_t...

Vue3自定义事件

自定义事件是一种组件间通信的方式&#xff0c;它允许子组件向父组件发送信息。子组件可以通过自定义事件向父组件传递数据以及事件&#xff0c;当自定义事件触发时&#xff0c;子组件可以借此将子组件的数据传递给父组件并使父组件对此做出相应的操作。 1.声明自定义事件 使…...

BeautifulSoup 与 XPath 用法详解与对比

BeautifulSoup&#xff08;bs4&#xff09; 和 XPath 是学习python爬虫过程中常常用到的库&#xff0c;本文将详细介绍它们的功能、使用方法、优缺点以及实际应用中的区别和选择建议。 1. BeautifulSoup 用法详解 1.1 什么是 BeautifulSoup&#xff1f; BeautifulSoup 是 Pyt…...

Emacs 折腾日记(五)——elisp 数字类型

本文是参考 emacs lisp 简明教程 写的&#xff0c;很多东西都是照搬里面的内容&#xff0c;如果各位读者觉得本文没有这篇教程优秀或者有抄袭嫌疑、又或者觉得我更新比较慢、再或者其他什么原因&#xff0c;请直接阅读上述链接中的教程。 上一篇我们讲了elisp中的流程控制结构相…...

重拾设计模式--外观模式

文章目录 外观模式&#xff08;Facade Pattern&#xff09;概述定义 外观模式UML图作用 外观模式的结构C 代码示例1C代码示例2总结 外观模式&#xff08;Facade Pattern&#xff09;概述 定义 外观模式是一种结构型设计模式&#xff0c;它为子系统中的一组接口提供了一个统一…...

源码编译llama.cpp for android

源码编译llama.cpp for android 我这有已经编译好的版本&#xff0c;直接下载使用&#xff1a; https://github.com/turingevo/llama.cpp-build/releases/tag/b4331 准备 android-ndk 已下载&#xff1a; /media/wmx/ws1/software/qtAndroid/Sdk/ndk/23.1.7779620版本 &am…...

StarRocks 排查单副本表

文章目录 StarRocks 排查单副本表方式1 查询元数据&#xff0c;检查分区级的副本数方式2 SHOW PARTITIONS命令查看 ReplicationNum修改副本数命令 StarRocks 排查单副本表 方式1 查询元数据&#xff0c;检查分区级的副本数 # 方式一 查询元数据&#xff0c;检查分区级的副本数…...

Windows11 家庭版安装配置 Docker

1. 安装WSL WSL 是什么&#xff1a; WSL 是一个在 Windows 上运行 Linux 环境的轻量级工具&#xff0c;它可以让用户在 Windows 系统中运行 Linux 工具和应用程序。Docker 为什么需要 WSL&#xff1a; Docker 依赖 Linux 内核功能&#xff0c;WSL 2 提供了一个高性能、轻量级的…...

线程知识总结(二)

本篇文章以线程同步的相关内容为主。线程的同步机制主要用来解决线程安全问题&#xff0c;主要方式有同步代码块、同步方法等。首先来了解何为线程安全问题。 1、线程安全问题 卖票示例&#xff0c;4 个窗口卖 100 张票&#xff1a; class Ticket implements Runnable {priv…...

解决vscode ssh远程连接服务器一直卡在下载 vscode server问题

目录 方法1&#xff1a;使用科学上网 方法2&#xff1a;手动下载 方法3 在使用vscode使用ssh远程连接服务器时&#xff0c;一直卡在下载"vscode 服务器"阶段&#xff0c;但MobaXterm可以正常连接服务器&#xff0c;大概率是网络问题&#xff0c;解决方法如下: 方…...

linux 下常用变更-8

1、删除普通用户 查询用户初始UID和GIDls -l /home/ ###家目录中查看UID cat /etc/group ###此文件查看GID删除用户1.编辑文件 /etc/passwd 找到对应的行&#xff0c;YW343:x:0:0::/home/YW343:/bin/bash 2.将标红的位置修改为用户对应初始UID和GID&#xff1a; YW3…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年&#xff0c;截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始&#xff0c;将英文题库免费公布出来&#xff0c;并进行解析&#xff0c;帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

JVM暂停(Stop-The-World,STW)的原因分类及对应排查方案

JVM暂停(Stop-The-World,STW)的完整原因分类及对应排查方案,结合JVM运行机制和常见故障场景整理而成: 一、GC相关暂停​​ 1. ​​安全点(Safepoint)阻塞​​ ​​现象​​:JVM暂停但无GC日志,日志显示No GCs detected。​​原因​​:JVM等待所有线程进入安全点(如…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

return this;返回的是谁

一个审批系统的示例来演示责任链模式的实现。假设公司需要处理不同金额的采购申请&#xff0c;不同级别的经理有不同的审批权限&#xff1a; // 抽象处理者&#xff1a;审批者 abstract class Approver {protected Approver successor; // 下一个处理者// 设置下一个处理者pub…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析

Linux 内存管理实战精讲&#xff1a;核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用&#xff0c;还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...

[ACTF2020 新生赛]Include 1(php://filter伪协议)

题目 做法 启动靶机&#xff0c;点进去 点进去 查看URL&#xff0c;有 ?fileflag.php说明存在文件包含&#xff0c;原理是php://filter 协议 当它与包含函数结合时&#xff0c;php://filter流会被当作php文件执行。 用php://filter加编码&#xff0c;能让PHP把文件内容…...

go 里面的指针

指针 在 Go 中&#xff0c;指针&#xff08;pointer&#xff09;是一个变量的内存地址&#xff0c;就像 C 语言那样&#xff1a; a : 10 p : &a // p 是一个指向 a 的指针 fmt.Println(*p) // 输出 10&#xff0c;通过指针解引用• &a 表示获取变量 a 的地址 p 表示…...

Modbus RTU与Modbus TCP详解指南

目录 1. Modbus协议基础 1.1 什么是Modbus? 1.2 Modbus协议历史 1.3 Modbus协议族 1.4 Modbus通信模型 🎭 主从架构 🔄 请求响应模式 2. Modbus RTU详解 2.1 RTU是什么? 2.2 RTU物理层 🔌 连接方式 ⚡ 通信参数 2.3 RTU数据帧格式 📦 帧结构详解 🔍…...

【Linux】Linux安装并配置RabbitMQ

目录 1. 安装 Erlang 2. 安装 RabbitMQ 2.1.添加 RabbitMQ 仓库 2.2.安装 RabbitMQ 3.配置 3.1.启动和管理服务 4. 访问管理界面 5.安装问题 6.修改密码 7.修改端口 7.1.找到文件 7.2.修改文件 1. 安装 Erlang 由于 RabbitMQ 是用 Erlang 编写的&#xff0c;需要先安…...