当前位置: 首页 > news >正文

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击

简介

在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性:对抗样本(Adversarial Examples)。简单来说,对抗样本是通过对原始输入数据添加微小扰动,导致神经网络预测错误的一类特殊样本。

Fast Gradient Sign Method (FGSM) 是最早提出的一种简单且有效的对抗攻击方法。它通过利用神经网络的梯度信息,快速生成对抗样本。本篇博客将深入剖析 FGSM 的数学原理,并提供一份基于 PyTorch 的代码实现,帮助你快速上手对抗攻击。

最终效果

当前图片为 truck,但经过对抗攻击后,结果看似应该同样是 truck,然而模型却错误地将其分类为了 cat

在这里插入图片描述

数学原理

1. 神经网络的基本工作原理

神经网络通过对输入数据 x x x 进行一系列的非线性变换,输出一个预测值 y ^ \hat{y} y^。以图像分类任务为例,输入 x x x 是图像,网络输出 y ^ \hat{y} y^ 是一个分类结果。网络的目标是使得预测值 y ^ \hat{y} y^ 尽可能接近真实标签 y y y,通过最小化损失函数 L ( y ^ , y ) L(\hat{y}, y) L(y^,y) 来优化网络的权重。

2. 对抗攻击的目标

对抗攻击的核心思想是在输入数据上添加一个微小的扰动 η \eta η,使得网络在预测时发生错误。这种扰动应尽量保持小到人类肉眼难以察觉,但足以让网络产生不同的输出,即:

y ^ a d v = f ( x + η ) and y ^ a d v ≠ y \hat{y}_{adv} = f(x + \eta) \quad \text{and} \quad \hat{y}_{adv} \neq y y^adv=f(x+η)andy^adv=y

3. FGSM 的核心思想

FGSM 利用了梯度上升的思想,通过损失函数相对于输入图像的梯度来找到 最容易 迷惑网络的方向,并沿着这个方向对图像进行微小的扰动。

具体来说,给定输入图像 x x x 和其真实标签 y y y,我们可以计算损失函数 L ( x , y ) L(x, y) L(x,y),并对输入图像 x x x 计算其梯度:

∇ x L ( f ( x ) , y ) \nabla_x L(f(x), y) xL(f(x),y)

这个梯度告诉我们,如何改变输入图像才能最大化损失。FGSM 的基本想法是,沿着这个梯度的符号方向对图像进行微调,以最大化损失函数。具体公式为:

x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵsign(xL(f(x),y))

其中:

  • ϵ \epsilon ϵ 是一个小的常数,控制扰动的大小。
  • sign ( ∇ x L ( f ( x ) , y ) ) \text{sign}(\nabla_x L(f(x), y)) sign(xL(f(x),y)) 是损失函数对输入图像梯度的符号,即正负号矩阵。

通过这种方式,我们可以生成一个对抗样本 x a d v x_{adv} xadv,它与原始图像 x x x 看起来几乎相同,但神经网络会将其错误分类。

4. 梯度上升与对抗样本

FGSM 是基于梯度上升的攻击方法。为了让网络犯错,我们希望最大化损失,因此通过扰动让损失函数 L L L 增加。梯度 ∇ x L \nabla_x L xL 的方向是使 L L L 增加最快的方向,因此,我们可以通过对梯度取符号(sign)来生成一个扰动 η \eta η,即:

η = ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) \eta = \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) η=ϵsign(xL(f(x),y))

这样生成的扰动被加到原图 x x x 上,即对抗样本的公式为:

x a d v = x + η = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \eta = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+η=x+ϵsign(xL(f(x),y))

代码实现与解读

接下来,我们通过实际的代码示例,展示如何使用 PyTorch 实现 FGSM 攻击。

1. 加载CIFAR-10数据集与预训练模型

首先,我们需要加载 CIFAR-10 数据集和一个预训练的模型。这里我们使用 ResNet-18 模型,并针对 CIFAR-10 数据集进行了调整。

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models# 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# 加载微调后的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)
model.eval()  # 设置为评估模式

在这一步中,我们准备了数据集,并将模型设置为评估模式,这样模型的参数不会在攻击过程中被更新。

提示:如果没有下载过 CIFAR-10 数据集,请将 download=False 改为 download=True

2. 实现FGSM攻击方法

FGSM 的关键在于利用损失函数的梯度来生成对抗样本。具体步骤如下:

# 定义 FGSM 攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image
数学公式:

在代码中,我们实际上实现了如下的公式:

x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵsign(xL(f(x),y))

该函数实现了 FGSM 算法,通过将计算得到的梯度方向乘以一个小的扰动系数 ϵ \epsilon ϵ,并添加到原始图像上来生成对抗样本。

3. 攻击流程

接下来,我们定义一个函数,来完成攻击过程:

def attack_example(model, device, data, target, epsilon):# 将数据和标签移动到设备上data, target = data.to(device), target.to(device)data.requires_grad = True  # 追踪输入图像的梯度# 前向传播得到初始预测output = model(data)init_pred = output.max(1, keepdim=True)[1]  # 选择概率最大的类别作为预测结果# 如果初始预测错误,则不进行攻击if init_pred.item() != target.item():return None, None, None# 计算损失并进行反向传播计算梯度loss_fn = torch.nn.CrossEntropyLoss()loss = loss_fn(output, target)model.zero_grad()  # 清除现有的梯度loss.backward()  # 计算损失对输入图像的梯度# 使用 FGSM 方法生成对抗样本data_grad = data.grad.dataperturbed_data = fgsm_attack(data, epsilon, data_grad)# 使用对抗样本进行新的预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1]  # 获取对抗样本的预测标签return data, perturbed_data, final_pred

对抗攻击的基本思路是:

  1. 前向传播:将输入数据(如图片)传递给模型,得到输出,并预测标签。
  2. 计算损失:使用损失函数来评估模型的预测与真实标签之间的差距。
  3. 反向传播计算梯度:通过反向传播,计算损失函数相对于输入数据的梯度。
  4. 生成对抗样本:利用这些梯度来调整输入数据,生成扰动的对抗样本。
  5. 重新预测:使用生成的对抗样本进行重新预测,观察模型是否做出错误的分类。

4. 展示原始图像与对抗样本

最后,我们选择一个样本进行攻击,并展示原始图像和对抗样本的效果。

# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01  # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")

在运行这段代码后,你会看到一张原始图像及其对抗样本。尽管对抗样本与原图在视觉上几乎没有区别,但它却被神经网络错误分类了。这种现象突显了神经网络的脆弱性。

结论

FGSM 是一种简单且有效的对抗攻击方法,它利用了神经网络的梯度信息生成对抗样本。这些对抗样本虽然肉眼难以察觉,但能显著影响模型的预测结果。本文介绍了 FGSM 的数学原理,并提供了基于 PyTorch 的实现代码。

通过这个例子,我们可以看到即使是看似强大的深度学习模型,也会受到一些精心设计的输入的严重影响。了解这些脆弱性有助于我们开发出更加健壮和安全的模型。

附录:完整代码

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np# Step 1: 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Step 2: 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])# 加载测试数据集
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# model = models.resnet18(pretrained=True).to(device)# 1. 加载在 CIFAR-10 上微调的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)model.eval()  # 设置为评估模式# 使用的损失函数
loss_fn = nn.CrossEntropyLoss()# Step 4: FGSM 对抗攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image# Step 5: 攻击流程
def attack_example(model, device, data, target, epsilon):data, target = data.to(device), target.to(device)# 确保计算图对输入的梯度data.requires_grad = True# 前向传播output = model(data)init_pred = output.max(1, keepdim=True)[1]  # 预测标签# 如果预测错误,则不攻击if init_pred.item() != target.item():return None, None, None# 计算损失loss = loss_fn(output, target)# 反向传播计算梯度model.zero_grad()loss.backward()# 提取梯度data_grad = data.grad.data# 执行 FGSM 攻击perturbed_data = fgsm_attack(data, epsilon, data_grad)# 再次进行预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1]  # 对抗样本的预测标签return data, perturbed_data, final_pred# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01  # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 标签# orig_label = target.item()# perturbed_label = final_pred.item()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")

相关文章:

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击 简介 在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性&…...

从零开始学习OMNeT++系列第一弹——OMNeT++的介绍与安装

最近由于由于工作上的需求,接了一个网络仿真的任务。于是开始调研各个仿真平台,然后根据目前的需求和网络上公开资料的多少,决定使用omnet这个网络仿真平台。现在也是刚开始学习,所以决定记录一下从零开始的这个学习过程。因为虽然…...

Cluster Explanation via Polyhedral Descriptions

通过多面体描述进行聚类解释 本文关注聚类描述问题,即在给定数据集及其聚类划分的情况下,解释这些聚类的任务。我们提出了一种新的聚类解释方法,通过在每个聚类周围构建一个多面体,同时最小化最终多面体的复杂性或用于描述的特征…...

爬虫设计思考之一

爬虫设计思考之一 经常做爬虫的人对于技术比较的执着,尤其是本身从事的擅长的技术领域,从而容易忽视与之相近或者相似的技术。因此我建议大家在遇到此类问题的时候,可以采用对比分析的方式来理解。 本次的思考是基于国内最大的中文搜索引擎百…...

解决centos 删除文件后但空间没有释放

一、问题描述:磁盘空间不足,清理完垃圾日志以后磁盘空间还是没有释放 查看磁盘空间 [rootxwj-qt-65-44 ~]# df -h Filesystem Size Used Avail Use% Mounted on devtmpfs 1.9G 0 1.9G 0% /dev tmpfs 1.9G 0 1.9G …...

微软SCCM:企业级系统管理的核心工具

目录 摘要 1. 引言 2. SCCM的基本概念 2.1 什么是SCCM? 2.2 SCCM的历史 3. SCCM的架构 3.1 中心服务器 3.2 数据库 3.3 管理点(Management Point) 3.4 分发点(Distribution Point) 3.5 客户端代理 3.6 报告服务 4. SCCM的核心功能 4.1 软件部署与管理 4.2 操…...

RTSP作为客户端 推流 拉流的过程分析

之前写过一个 rtsp server 作为服务端的简单demo 这次分析下 rtsp作为客户端 推流和拉流时候的过 A.作为客户端拉流 TCP方式 1.Client发送OPTIONS方法 Server回应告诉支持的方法 2.Client发送DESCRIPE方法 这里是从海康摄像机拉流并且设置了用户名密码 Server回复未认证 3.客…...

【MySQL 07】内置函数

目录 1.日期函数 日期函数使用场景: 2.字符串函数 字符串函数使用场景: 3.数学函数 4.控制流函数 1.日期函数 函数示例: 1.在日期的基础上加日期 在该日期下,加上10天。 2.在日期的基础上减去时间 在该日期下减去2天 3.计算两…...

《深度学习》OpenCV 背景建模 原理及案例解析

目录 一、背景建模 1、什么是背景建模 2、背景建模的方法 1)帧差法(backgroundSubtractor) 2)基于K近邻的背景/前景分割算法BackgroundSubtractorKNN 3)基于高斯混合的背景/前景分割算法BackgroundSubtractorMOG2 3、步骤 1)初…...

机器学习(1):机器学习的概念

1. 机器学习的定义和相关概念 机器学习之父 Arthur Samuel 对机器学习的定义是:在没有明确设置的情况下,使计算机具有学习能力的研究领域。 国际机器学习大会的创始人之一 Tom Mitchell 对机器学习的定义是:计算机程序从经验 E 中学习&#…...

0. Pixel3 在Ubuntu22下Android12源码拉取 + 编译

0. Pixel3 在Ubuntu22下Android12源码拉取 编译 原文地址: http://www.androidcrack.com/index.php/archives/3/ 1. 前言 这是一个非常悲伤的故事, 因为一个意外, 不小心把之前镜像的源码搞坏了. 也没做版本管理,恢复不了了. 那么只能说是重新做一次. 再者以前的镜像太老旧…...

ip经过多个服务器转发会网速变慢吗

会的,IP经过多个服务器转发时,网速通常会变慢,主要原因包括: 增加的延迟: 每经过一个服务器,数据包就需要额外的时间进行处理和转发。这种处理时间和网络延迟会累积,导致整体延迟增加。 带宽限制…...

mongodb通过mongoimport导入JSON文件数据

目录 一、概念 二、mongoimport导入工具 三、导入命令 一、概念 MongoDB是一个流行的开源文档数据库,它支持JSON格式的文档,非常适合存储和处理大量的非结构化数据。在实际应用中,我们经常需要将大量的数据批量导入到MongoDB中。mongoimpo…...

【Qt】控件概述 (1)

控件概述 1. QWidget核心属性1.1核心属性概述1.2 enable1.3 geometry——窗口坐标1.4 window frame的影响1.4 windowTitle——窗口标题1.5 windowIcon——窗口图标1.6 windowOpacity——透明度设置1.7 cursor——光标设置1.8 font——字体设置1.9 toolTip——鼠标悬停提示设置1…...

ping基本使用详解

在网络中ping是一个十分强大的TCP/IP工具。它的作用主要为: 用来检测网络的连通情况和分析网络速度根据域名得到服务器 IP根据 ping 返回的 TTL 值来判断对方所使用的操作系统及数据包经过路由器数量。我们通常会用它来直接 ping ip 地址,来测试网络的连…...

Win10之解决:设置静态IP后,为什么自动获取动态IP问题(七十八)

简介: CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏: 多媒体系统工程师系列【…...

【AI论文精读1】针对知识密集型NLP任务的检索增强生成(RAG原始论文)

目录 一、简介一句话简介作者、引用数、时间论文地址开源代码地址 二、摘要三、引言四、整体架构(用一个例子来阐明)场景例子:核心点: 五、方法 (架构各部分详解)5.1 模型1. RAG-Sequence Model2. RAG-Toke…...

踩坑spring cloud gateway /actuator/gateway/refresh不生效

版本 java version: 17 spring boot: 3.2.x spring cloud: 2023.0.3 现象 参考Spring Cloud Gateway -> Actuator API -> Refreshing the Route Cache 说明,先修改routes配置再调用/actuator/gateway/refresh,接口返回200 status,但…...

【STM32开发环境搭建】-3-STM32CubeMX Project Manager配置-自动生成一个Keil(MDK-ARM) 5的工程

目录 1 KEIL(MDK-ARM) 5 Project工程设置 2 MCU和嵌入式软件包的选择 3 Code Generator 3.1 STM32Cube Firmware Library Package 3.2 Generated files 3.3 HAL Settings 3.4 Template Settings 4 Advanced Settings 5 自动生成的KEIL(MDK-ARM) 5 Project工程目录 结…...

计算机毕业设计 Java酷听音乐系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…...

【Axure高保真原型】引导弹窗

今天和大家中分享引导弹窗的原型模板,载入页面后,会显示引导弹窗,适用于引导用户使用页面,点击完成后,会显示下一个引导弹窗,直至最后一个引导弹窗完成后进入首页。具体效果可以点击下方视频观看或打开下方…...

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具,该工具基于TUN接口实现其功能,利用反向TCP/TLS连接建立一条隐蔽的通信信道,支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式,适应复杂网…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...

从零实现富文本编辑器#5-编辑器选区模型的状态结构表达

先前我们总结了浏览器选区模型的交互策略,并且实现了基本的选区操作,还调研了自绘选区的实现。那么相对的,我们还需要设计编辑器的选区表达,也可以称为模型选区。编辑器中应用变更时的操作范围,就是以模型选区为基准来…...

(二)原型模式

原型的功能是将一个已经存在的对象作为源目标,其余对象都是通过这个源目标创建。发挥复制的作用就是原型模式的核心思想。 一、源型模式的定义 原型模式是指第二次创建对象可以通过复制已经存在的原型对象来实现,忽略对象创建过程中的其它细节。 📌 核心特点: 避免重复初…...

Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式

点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍

文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结: 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析: 实际业务去理解体会统一注…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年,截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始,将英文题库免费公布出来,并进行解析,帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

回溯算法学习

一、电话号码的字母组合 import java.util.ArrayList; import java.util.List;import javax.management.loading.PrivateClassLoader;public class letterCombinations {private static final String[] KEYPAD {"", //0"", //1"abc", //2"…...

智能AI电话机器人系统的识别能力现状与发展水平

一、引言 随着人工智能技术的飞速发展,AI电话机器人系统已经从简单的自动应答工具演变为具备复杂交互能力的智能助手。这类系统结合了语音识别、自然语言处理、情感计算和机器学习等多项前沿技术,在客户服务、营销推广、信息查询等领域发挥着越来越重要…...