从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击
从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击
简介
在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性:对抗样本(Adversarial Examples)。简单来说,对抗样本是通过对原始输入数据添加微小扰动,导致神经网络预测错误的一类特殊样本。
Fast Gradient Sign Method (FGSM) 是最早提出的一种简单且有效的对抗攻击方法。它通过利用神经网络的梯度信息,快速生成对抗样本。本篇博客将深入剖析 FGSM 的数学原理,并提供一份基于 PyTorch 的代码实现,帮助你快速上手对抗攻击。
最终效果
当前图片为 truck,但经过对抗攻击后,结果看似应该同样是 truck,然而模型却错误地将其分类为了 cat。

数学原理
1. 神经网络的基本工作原理
神经网络通过对输入数据 x x x 进行一系列的非线性变换,输出一个预测值 y ^ \hat{y} y^。以图像分类任务为例,输入 x x x 是图像,网络输出 y ^ \hat{y} y^ 是一个分类结果。网络的目标是使得预测值 y ^ \hat{y} y^ 尽可能接近真实标签 y y y,通过最小化损失函数 L ( y ^ , y ) L(\hat{y}, y) L(y^,y) 来优化网络的权重。
2. 对抗攻击的目标
对抗攻击的核心思想是在输入数据上添加一个微小的扰动 η \eta η,使得网络在预测时发生错误。这种扰动应尽量保持小到人类肉眼难以察觉,但足以让网络产生不同的输出,即:
y ^ a d v = f ( x + η ) and y ^ a d v ≠ y \hat{y}_{adv} = f(x + \eta) \quad \text{and} \quad \hat{y}_{adv} \neq y y^adv=f(x+η)andy^adv=y
3. FGSM 的核心思想
FGSM 利用了梯度上升的思想,通过损失函数相对于输入图像的梯度来找到 最容易 迷惑网络的方向,并沿着这个方向对图像进行微小的扰动。
具体来说,给定输入图像 x x x 和其真实标签 y y y,我们可以计算损失函数 L ( x , y ) L(x, y) L(x,y),并对输入图像 x x x 计算其梯度:
∇ x L ( f ( x ) , y ) \nabla_x L(f(x), y) ∇xL(f(x),y)
这个梯度告诉我们,如何改变输入图像才能最大化损失。FGSM 的基本想法是,沿着这个梯度的符号方向对图像进行微调,以最大化损失函数。具体公式为:
x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵ⋅sign(∇xL(f(x),y))
其中:
- ϵ \epsilon ϵ 是一个小的常数,控制扰动的大小。
- sign ( ∇ x L ( f ( x ) , y ) ) \text{sign}(\nabla_x L(f(x), y)) sign(∇xL(f(x),y)) 是损失函数对输入图像梯度的符号,即正负号矩阵。
通过这种方式,我们可以生成一个对抗样本 x a d v x_{adv} xadv,它与原始图像 x x x 看起来几乎相同,但神经网络会将其错误分类。
4. 梯度上升与对抗样本
FGSM 是基于梯度上升的攻击方法。为了让网络犯错,我们希望最大化损失,因此通过扰动让损失函数 L L L 增加。梯度 ∇ x L \nabla_x L ∇xL 的方向是使 L L L 增加最快的方向,因此,我们可以通过对梯度取符号(sign)来生成一个扰动 η \eta η,即:
η = ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) \eta = \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) η=ϵ⋅sign(∇xL(f(x),y))
这样生成的扰动被加到原图 x x x 上,即对抗样本的公式为:
x a d v = x + η = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \eta = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+η=x+ϵ⋅sign(∇xL(f(x),y))
代码实现与解读
接下来,我们通过实际的代码示例,展示如何使用 PyTorch 实现 FGSM 攻击。
1. 加载CIFAR-10数据集与预训练模型
首先,我们需要加载 CIFAR-10 数据集和一个预训练的模型。这里我们使用 ResNet-18 模型,并针对 CIFAR-10 数据集进行了调整。
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models# 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# 加载微调后的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)
model.eval() # 设置为评估模式
在这一步中,我们准备了数据集,并将模型设置为评估模式,这样模型的参数不会在攻击过程中被更新。
提示:如果没有下载过 CIFAR-10 数据集,请将
download=False改为download=True。
2. 实现FGSM攻击方法
FGSM 的关键在于利用损失函数的梯度来生成对抗样本。具体步骤如下:
# 定义 FGSM 攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image
数学公式:
在代码中,我们实际上实现了如下的公式:
x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵ⋅sign(∇xL(f(x),y))
该函数实现了 FGSM 算法,通过将计算得到的梯度方向乘以一个小的扰动系数 ϵ \epsilon ϵ,并添加到原始图像上来生成对抗样本。
3. 攻击流程
接下来,我们定义一个函数,来完成攻击过程:
def attack_example(model, device, data, target, epsilon):# 将数据和标签移动到设备上data, target = data.to(device), target.to(device)data.requires_grad = True # 追踪输入图像的梯度# 前向传播得到初始预测output = model(data)init_pred = output.max(1, keepdim=True)[1] # 选择概率最大的类别作为预测结果# 如果初始预测错误,则不进行攻击if init_pred.item() != target.item():return None, None, None# 计算损失并进行反向传播计算梯度loss_fn = torch.nn.CrossEntropyLoss()loss = loss_fn(output, target)model.zero_grad() # 清除现有的梯度loss.backward() # 计算损失对输入图像的梯度# 使用 FGSM 方法生成对抗样本data_grad = data.grad.dataperturbed_data = fgsm_attack(data, epsilon, data_grad)# 使用对抗样本进行新的预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1] # 获取对抗样本的预测标签return data, perturbed_data, final_pred
对抗攻击的基本思路是:
- 前向传播:将输入数据(如图片)传递给模型,得到输出,并预测标签。
- 计算损失:使用损失函数来评估模型的预测与真实标签之间的差距。
- 反向传播计算梯度:通过反向传播,计算损失函数相对于输入数据的梯度。
- 生成对抗样本:利用这些梯度来调整输入数据,生成扰动的对抗样本。
- 重新预测:使用生成的对抗样本进行重新预测,观察模型是否做出错误的分类。
4. 展示原始图像与对抗样本
最后,我们选择一个样本进行攻击,并展示原始图像和对抗样本的效果。
# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01 # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")
在运行这段代码后,你会看到一张原始图像及其对抗样本。尽管对抗样本与原图在视觉上几乎没有区别,但它却被神经网络错误分类了。这种现象突显了神经网络的脆弱性。
结论
FGSM 是一种简单且有效的对抗攻击方法,它利用了神经网络的梯度信息生成对抗样本。这些对抗样本虽然肉眼难以察觉,但能显著影响模型的预测结果。本文介绍了 FGSM 的数学原理,并提供了基于 PyTorch 的实现代码。
通过这个例子,我们可以看到即使是看似强大的深度学习模型,也会受到一些精心设计的输入的严重影响。了解这些脆弱性有助于我们开发出更加健壮和安全的模型。
附录:完整代码
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np# Step 1: 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Step 2: 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])# 加载测试数据集
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# model = models.resnet18(pretrained=True).to(device)# 1. 加载在 CIFAR-10 上微调的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)model.eval() # 设置为评估模式# 使用的损失函数
loss_fn = nn.CrossEntropyLoss()# Step 4: FGSM 对抗攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image# Step 5: 攻击流程
def attack_example(model, device, data, target, epsilon):data, target = data.to(device), target.to(device)# 确保计算图对输入的梯度data.requires_grad = True# 前向传播output = model(data)init_pred = output.max(1, keepdim=True)[1] # 预测标签# 如果预测错误,则不攻击if init_pred.item() != target.item():return None, None, None# 计算损失loss = loss_fn(output, target)# 反向传播计算梯度model.zero_grad()loss.backward()# 提取梯度data_grad = data.grad.data# 执行 FGSM 攻击perturbed_data = fgsm_attack(data, epsilon, data_grad)# 再次进行预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1] # 对抗样本的预测标签return data, perturbed_data, final_pred# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01 # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 标签# orig_label = target.item()# perturbed_label = final_pred.item()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")相关文章:
从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击
从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击 简介 在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性&…...
从零开始学习OMNeT++系列第一弹——OMNeT++的介绍与安装
最近由于由于工作上的需求,接了一个网络仿真的任务。于是开始调研各个仿真平台,然后根据目前的需求和网络上公开资料的多少,决定使用omnet这个网络仿真平台。现在也是刚开始学习,所以决定记录一下从零开始的这个学习过程。因为虽然…...
Cluster Explanation via Polyhedral Descriptions
通过多面体描述进行聚类解释 本文关注聚类描述问题,即在给定数据集及其聚类划分的情况下,解释这些聚类的任务。我们提出了一种新的聚类解释方法,通过在每个聚类周围构建一个多面体,同时最小化最终多面体的复杂性或用于描述的特征…...
爬虫设计思考之一
爬虫设计思考之一 经常做爬虫的人对于技术比较的执着,尤其是本身从事的擅长的技术领域,从而容易忽视与之相近或者相似的技术。因此我建议大家在遇到此类问题的时候,可以采用对比分析的方式来理解。 本次的思考是基于国内最大的中文搜索引擎百…...
解决centos 删除文件后但空间没有释放
一、问题描述:磁盘空间不足,清理完垃圾日志以后磁盘空间还是没有释放 查看磁盘空间 [rootxwj-qt-65-44 ~]# df -h Filesystem Size Used Avail Use% Mounted on devtmpfs 1.9G 0 1.9G 0% /dev tmpfs 1.9G 0 1.9G …...
微软SCCM:企业级系统管理的核心工具
目录 摘要 1. 引言 2. SCCM的基本概念 2.1 什么是SCCM? 2.2 SCCM的历史 3. SCCM的架构 3.1 中心服务器 3.2 数据库 3.3 管理点(Management Point) 3.4 分发点(Distribution Point) 3.5 客户端代理 3.6 报告服务 4. SCCM的核心功能 4.1 软件部署与管理 4.2 操…...
RTSP作为客户端 推流 拉流的过程分析
之前写过一个 rtsp server 作为服务端的简单demo 这次分析下 rtsp作为客户端 推流和拉流时候的过 A.作为客户端拉流 TCP方式 1.Client发送OPTIONS方法 Server回应告诉支持的方法 2.Client发送DESCRIPE方法 这里是从海康摄像机拉流并且设置了用户名密码 Server回复未认证 3.客…...
【MySQL 07】内置函数
目录 1.日期函数 日期函数使用场景: 2.字符串函数 字符串函数使用场景: 3.数学函数 4.控制流函数 1.日期函数 函数示例: 1.在日期的基础上加日期 在该日期下,加上10天。 2.在日期的基础上减去时间 在该日期下减去2天 3.计算两…...
《深度学习》OpenCV 背景建模 原理及案例解析
目录 一、背景建模 1、什么是背景建模 2、背景建模的方法 1)帧差法(backgroundSubtractor) 2)基于K近邻的背景/前景分割算法BackgroundSubtractorKNN 3)基于高斯混合的背景/前景分割算法BackgroundSubtractorMOG2 3、步骤 1)初…...
机器学习(1):机器学习的概念
1. 机器学习的定义和相关概念 机器学习之父 Arthur Samuel 对机器学习的定义是:在没有明确设置的情况下,使计算机具有学习能力的研究领域。 国际机器学习大会的创始人之一 Tom Mitchell 对机器学习的定义是:计算机程序从经验 E 中学习&#…...
0. Pixel3 在Ubuntu22下Android12源码拉取 + 编译
0. Pixel3 在Ubuntu22下Android12源码拉取 编译 原文地址: http://www.androidcrack.com/index.php/archives/3/ 1. 前言 这是一个非常悲伤的故事, 因为一个意外, 不小心把之前镜像的源码搞坏了. 也没做版本管理,恢复不了了. 那么只能说是重新做一次. 再者以前的镜像太老旧…...
ip经过多个服务器转发会网速变慢吗
会的,IP经过多个服务器转发时,网速通常会变慢,主要原因包括: 增加的延迟: 每经过一个服务器,数据包就需要额外的时间进行处理和转发。这种处理时间和网络延迟会累积,导致整体延迟增加。 带宽限制…...
mongodb通过mongoimport导入JSON文件数据
目录 一、概念 二、mongoimport导入工具 三、导入命令 一、概念 MongoDB是一个流行的开源文档数据库,它支持JSON格式的文档,非常适合存储和处理大量的非结构化数据。在实际应用中,我们经常需要将大量的数据批量导入到MongoDB中。mongoimpo…...
【Qt】控件概述 (1)
控件概述 1. QWidget核心属性1.1核心属性概述1.2 enable1.3 geometry——窗口坐标1.4 window frame的影响1.4 windowTitle——窗口标题1.5 windowIcon——窗口图标1.6 windowOpacity——透明度设置1.7 cursor——光标设置1.8 font——字体设置1.9 toolTip——鼠标悬停提示设置1…...
ping基本使用详解
在网络中ping是一个十分强大的TCP/IP工具。它的作用主要为: 用来检测网络的连通情况和分析网络速度根据域名得到服务器 IP根据 ping 返回的 TTL 值来判断对方所使用的操作系统及数据包经过路由器数量。我们通常会用它来直接 ping ip 地址,来测试网络的连…...
Win10之解决:设置静态IP后,为什么自动获取动态IP问题(七十八)
简介: CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏: 多媒体系统工程师系列【…...
【AI论文精读1】针对知识密集型NLP任务的检索增强生成(RAG原始论文)
目录 一、简介一句话简介作者、引用数、时间论文地址开源代码地址 二、摘要三、引言四、整体架构(用一个例子来阐明)场景例子:核心点: 五、方法 (架构各部分详解)5.1 模型1. RAG-Sequence Model2. RAG-Toke…...
踩坑spring cloud gateway /actuator/gateway/refresh不生效
版本 java version: 17 spring boot: 3.2.x spring cloud: 2023.0.3 现象 参考Spring Cloud Gateway -> Actuator API -> Refreshing the Route Cache 说明,先修改routes配置再调用/actuator/gateway/refresh,接口返回200 status,但…...
【STM32开发环境搭建】-3-STM32CubeMX Project Manager配置-自动生成一个Keil(MDK-ARM) 5的工程
目录 1 KEIL(MDK-ARM) 5 Project工程设置 2 MCU和嵌入式软件包的选择 3 Code Generator 3.1 STM32Cube Firmware Library Package 3.2 Generated files 3.3 HAL Settings 3.4 Template Settings 4 Advanced Settings 5 自动生成的KEIL(MDK-ARM) 5 Project工程目录 结…...
计算机毕业设计 Java酷听音乐系统的设计与实现 Java实战项目 附源码+文档+视频讲解
博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…...
Ubuntu服务器中文乱码终极解决方案:从locale配置到阿里云重启避坑指南
Ubuntu服务器中文乱码终极解决方案:从locale配置到阿里云重启避坑指南 当你第一次在Ubuntu服务器上看到中文字符变成一堆问号或方框时,那种困惑和挫败感我深有体会。特别是在云服务器环境下,问题往往比本地环境更复杂——即使按照常规教程操作…...
ESP32轻量级18650电池电量估算库设计与实现
1. 项目概述Battery_18650_Stats是一款专为 ESP32 平台设计的轻量级嵌入式电池状态计算库,核心目标是在 Arduino IDE 环境下,以最小资源开销、最高工程鲁棒性,实现对单节 18650 锂离子电池(Li-ion)荷电状态(…...
CAN总线波特率计算器工具开发指南(Python+PyQt5)
CAN总线波特率计算器工具开发指南(PythonPyQt5) 在汽车电子工程领域,CAN总线作为车载网络的骨干,其通信质量直接影响整车系统的稳定性。而波特率作为CAN通信的基础参数,其配置精度直接决定了总线能否正常工作。传统的手…...
leOS2:基于看门狗定时器的轻量级嵌入式调度器
1. leOS2:基于看门狗定时器的轻量级嵌入式调度器 leOS2(little embedded Operating System 2)是一个专为资源受限的8位AVR微控制器设计的极简实时调度器。它不依赖于通用定时器(如Timer0/Timer1),而是创造…...
别再傻傻分不清了!IM和RTC到底差在哪?从微信聊天到腾讯会议的技术选择
IM与RTC技术选型指南:从协议栈到商业场景的深度解析 当你的产品经理在白板上画出一个"消息气泡"和一个"视频通话图标"时,技术团队首先需要面对的灵魂拷问是:这到底该用IM架构还是RTC架构?2019年某在线教育初创…...
开源电子书工具:如何用鸿蒙系统打造专属个性化阅读空间
开源电子书工具:如何用鸿蒙系统打造专属个性化阅读空间 【免费下载链接】legado-Harmony 开源阅读鸿蒙版仓库 项目地址: https://gitcode.com/gh_mirrors/le/legado-Harmony 你是否曾因阅读应用充斥广告而烦躁?是否渴望完全掌控自己的阅读体验&am…...
4大核心优势解决人脸处理难题:设计师与创作者的AI增强工具
4大核心优势解决人脸处理难题:设计师与创作者的AI增强工具 【免费下载链接】DZ-FaceDetailer a node for comfyui for restore/edit/enchance faces utilizing face recognition 项目地址: https://gitcode.com/gh_mirrors/dz/DZ-FaceDetailer 【问题诊断】为…...
从‘偏差-方差’到一行代码:用NumPy/PyTorch五步实现GAE,附PPO实战避坑点
从‘偏差-方差’到一行代码:用NumPy/PyTorch五步实现GAE,附PPO实战避坑点 强化学习中的策略优化常常面临一个核心挑战:如何准确评估动作的价值?广义优势估计(GAE)通过巧妙平衡偏差与方差,成为PP…...
STM32F103C8T6驱动无FIFO的OV7670:从时序理解到图像显示的完整避坑指南
STM32F103C8T6驱动无FIFO的OV7670:从时序理解到图像显示的完整避坑指南 当你第一次将OV7670摄像头模块连接到STM32F103C8T6开发板时,可能会被那些看似简单的时序信号搞得晕头转向。VSYNC、HREF、PCLK——这些信号线背后隐藏着图像数据采集的全部秘密。本…...
5分钟搞定三网话费余额查询:手把手教你用PHP+HTML搭建查询系统(含API调用避坑指南)
三网话费查询系统开发实战:从API调用到前端优化的全流程指南 最近在帮朋友开发一个小型话费查询工具时,发现市面上关于三网运营商API调用的完整教程并不多见。大多数开发者遇到问题时只能靠反复试错,特别是当需要同时对接移动、联通、电信三家…...
