当前位置: 首页 > news >正文

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击

简介

在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性:对抗样本(Adversarial Examples)。简单来说,对抗样本是通过对原始输入数据添加微小扰动,导致神经网络预测错误的一类特殊样本。

Fast Gradient Sign Method (FGSM) 是最早提出的一种简单且有效的对抗攻击方法。它通过利用神经网络的梯度信息,快速生成对抗样本。本篇博客将深入剖析 FGSM 的数学原理,并提供一份基于 PyTorch 的代码实现,帮助你快速上手对抗攻击。

最终效果

当前图片为 truck,但经过对抗攻击后,结果看似应该同样是 truck,然而模型却错误地将其分类为了 cat

在这里插入图片描述

数学原理

1. 神经网络的基本工作原理

神经网络通过对输入数据 x x x 进行一系列的非线性变换,输出一个预测值 y ^ \hat{y} y^。以图像分类任务为例,输入 x x x 是图像,网络输出 y ^ \hat{y} y^ 是一个分类结果。网络的目标是使得预测值 y ^ \hat{y} y^ 尽可能接近真实标签 y y y,通过最小化损失函数 L ( y ^ , y ) L(\hat{y}, y) L(y^,y) 来优化网络的权重。

2. 对抗攻击的目标

对抗攻击的核心思想是在输入数据上添加一个微小的扰动 η \eta η,使得网络在预测时发生错误。这种扰动应尽量保持小到人类肉眼难以察觉,但足以让网络产生不同的输出,即:

y ^ a d v = f ( x + η ) and y ^ a d v ≠ y \hat{y}_{adv} = f(x + \eta) \quad \text{and} \quad \hat{y}_{adv} \neq y y^adv=f(x+η)andy^adv=y

3. FGSM 的核心思想

FGSM 利用了梯度上升的思想,通过损失函数相对于输入图像的梯度来找到 最容易 迷惑网络的方向,并沿着这个方向对图像进行微小的扰动。

具体来说,给定输入图像 x x x 和其真实标签 y y y,我们可以计算损失函数 L ( x , y ) L(x, y) L(x,y),并对输入图像 x x x 计算其梯度:

∇ x L ( f ( x ) , y ) \nabla_x L(f(x), y) xL(f(x),y)

这个梯度告诉我们,如何改变输入图像才能最大化损失。FGSM 的基本想法是,沿着这个梯度的符号方向对图像进行微调,以最大化损失函数。具体公式为:

x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵsign(xL(f(x),y))

其中:

  • ϵ \epsilon ϵ 是一个小的常数,控制扰动的大小。
  • sign ( ∇ x L ( f ( x ) , y ) ) \text{sign}(\nabla_x L(f(x), y)) sign(xL(f(x),y)) 是损失函数对输入图像梯度的符号,即正负号矩阵。

通过这种方式,我们可以生成一个对抗样本 x a d v x_{adv} xadv,它与原始图像 x x x 看起来几乎相同,但神经网络会将其错误分类。

4. 梯度上升与对抗样本

FGSM 是基于梯度上升的攻击方法。为了让网络犯错,我们希望最大化损失,因此通过扰动让损失函数 L L L 增加。梯度 ∇ x L \nabla_x L xL 的方向是使 L L L 增加最快的方向,因此,我们可以通过对梯度取符号(sign)来生成一个扰动 η \eta η,即:

η = ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) \eta = \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) η=ϵsign(xL(f(x),y))

这样生成的扰动被加到原图 x x x 上,即对抗样本的公式为:

x a d v = x + η = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \eta = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+η=x+ϵsign(xL(f(x),y))

代码实现与解读

接下来,我们通过实际的代码示例,展示如何使用 PyTorch 实现 FGSM 攻击。

1. 加载CIFAR-10数据集与预训练模型

首先,我们需要加载 CIFAR-10 数据集和一个预训练的模型。这里我们使用 ResNet-18 模型,并针对 CIFAR-10 数据集进行了调整。

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models# 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# 加载微调后的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)
model.eval()  # 设置为评估模式

在这一步中,我们准备了数据集,并将模型设置为评估模式,这样模型的参数不会在攻击过程中被更新。

提示:如果没有下载过 CIFAR-10 数据集,请将 download=False 改为 download=True

2. 实现FGSM攻击方法

FGSM 的关键在于利用损失函数的梯度来生成对抗样本。具体步骤如下:

# 定义 FGSM 攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image
数学公式:

在代码中,我们实际上实现了如下的公式:

x a d v = x + ϵ ⋅ sign ( ∇ x L ( f ( x ) , y ) ) x_{adv} = x + \epsilon \cdot \text{sign}(\nabla_x L(f(x), y)) xadv=x+ϵsign(xL(f(x),y))

该函数实现了 FGSM 算法,通过将计算得到的梯度方向乘以一个小的扰动系数 ϵ \epsilon ϵ,并添加到原始图像上来生成对抗样本。

3. 攻击流程

接下来,我们定义一个函数,来完成攻击过程:

def attack_example(model, device, data, target, epsilon):# 将数据和标签移动到设备上data, target = data.to(device), target.to(device)data.requires_grad = True  # 追踪输入图像的梯度# 前向传播得到初始预测output = model(data)init_pred = output.max(1, keepdim=True)[1]  # 选择概率最大的类别作为预测结果# 如果初始预测错误,则不进行攻击if init_pred.item() != target.item():return None, None, None# 计算损失并进行反向传播计算梯度loss_fn = torch.nn.CrossEntropyLoss()loss = loss_fn(output, target)model.zero_grad()  # 清除现有的梯度loss.backward()  # 计算损失对输入图像的梯度# 使用 FGSM 方法生成对抗样本data_grad = data.grad.dataperturbed_data = fgsm_attack(data, epsilon, data_grad)# 使用对抗样本进行新的预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1]  # 获取对抗样本的预测标签return data, perturbed_data, final_pred

对抗攻击的基本思路是:

  1. 前向传播:将输入数据(如图片)传递给模型,得到输出,并预测标签。
  2. 计算损失:使用损失函数来评估模型的预测与真实标签之间的差距。
  3. 反向传播计算梯度:通过反向传播,计算损失函数相对于输入数据的梯度。
  4. 生成对抗样本:利用这些梯度来调整输入数据,生成扰动的对抗样本。
  5. 重新预测:使用生成的对抗样本进行重新预测,观察模型是否做出错误的分类。

4. 展示原始图像与对抗样本

最后,我们选择一个样本进行攻击,并展示原始图像和对抗样本的效果。

# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01  # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")

在运行这段代码后,你会看到一张原始图像及其对抗样本。尽管对抗样本与原图在视觉上几乎没有区别,但它却被神经网络错误分类了。这种现象突显了神经网络的脆弱性。

结论

FGSM 是一种简单且有效的对抗攻击方法,它利用了神经网络的梯度信息生成对抗样本。这些对抗样本虽然肉眼难以察觉,但能显著影响模型的预测结果。本文介绍了 FGSM 的数学原理,并提供了基于 PyTorch 的实现代码。

通过这个例子,我们可以看到即使是看似强大的深度学习模型,也会受到一些精心设计的输入的严重影响。了解这些脆弱性有助于我们开发出更加健壮和安全的模型。

附录:完整代码

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np# Step 1: 设置设备 (GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Step 2: 加载 CIFAR-10 数据集
transform = transforms.Compose([transforms.ToTensor(),
])# 加载测试数据集
test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)# model = models.resnet18(pretrained=True).to(device)# 1. 加载在 CIFAR-10 上微调的 ResNet 模型
model = models.resnet18(pretrained=False, num_classes=10).to(device)model.eval()  # 设置为评估模式# 使用的损失函数
loss_fn = nn.CrossEntropyLoss()# Step 4: FGSM 对抗攻击函数
def fgsm_attack(image, epsilon, data_grad):# 生成扰动方向sign_data_grad = data_grad.sign()# 生成对抗样本perturbed_image = image + epsilon * sign_data_grad# 对抗样本像素值范围约束在 [0,1]perturbed_image = torch.clamp(perturbed_image, 0, 1)return perturbed_image# Step 5: 攻击流程
def attack_example(model, device, data, target, epsilon):data, target = data.to(device), target.to(device)# 确保计算图对输入的梯度data.requires_grad = True# 前向传播output = model(data)init_pred = output.max(1, keepdim=True)[1]  # 预测标签# 如果预测错误,则不攻击if init_pred.item() != target.item():return None, None, None# 计算损失loss = loss_fn(output, target)# 反向传播计算梯度model.zero_grad()loss.backward()# 提取梯度data_grad = data.grad.data# 执行 FGSM 攻击perturbed_data = fgsm_attack(data, epsilon, data_grad)# 再次进行预测output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1]  # 对抗样本的预测标签return data, perturbed_data, final_pred# Step 6: 选择 epsilon 并进行测试
epsilon = 0.01  # 扰动强度# 遍历数据集,找到预测正确的样本
for data, target in test_loader:orig_data, perturbed_data, final_pred = attack_example(model, device, data, target, epsilon)if orig_data is not None:break# CIFAR-10 类别
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']if orig_data is not None:# 显示原始图像和对抗样本orig_img = orig_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()perturbed_img = perturbed_data.squeeze().permute(1, 2, 0).detach().cpu().numpy()# 标签# orig_label = target.item()# perturbed_label = final_pred.item()# 获取语义化标签original_label_name = cifar10_classes[target.item()]adversarial_label_name = cifar10_classes[final_pred.item()]# Step 7: 展示结果fig, (ax1, ax2) = plt.subplots(1, 2)ax1.imshow(orig_img)ax1.set_title(f"Original Label: {original_label_name}")ax2.imshow(perturbed_img)ax2.set_title(f"Adversarial Label: {adversarial_label_name}")plt.show()print(f"Original Label: {original_label_name}, Adversarial Label: {adversarial_label_name}")
else:print("Initial prediction was incorrect. No attack performed.")

相关文章:

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击

从原理到代码:如何通过 FGSM 生成对抗样本并进行攻击 简介 在机器学习领域,深度神经网络的强大表现令人印象深刻,尤其是在图像分类等任务上。然而,随着对深度学习的深入研究,研究人员发现了神经网络的一个脆弱性&…...

从零开始学习OMNeT++系列第一弹——OMNeT++的介绍与安装

最近由于由于工作上的需求,接了一个网络仿真的任务。于是开始调研各个仿真平台,然后根据目前的需求和网络上公开资料的多少,决定使用omnet这个网络仿真平台。现在也是刚开始学习,所以决定记录一下从零开始的这个学习过程。因为虽然…...

Cluster Explanation via Polyhedral Descriptions

通过多面体描述进行聚类解释 本文关注聚类描述问题,即在给定数据集及其聚类划分的情况下,解释这些聚类的任务。我们提出了一种新的聚类解释方法,通过在每个聚类周围构建一个多面体,同时最小化最终多面体的复杂性或用于描述的特征…...

爬虫设计思考之一

爬虫设计思考之一 经常做爬虫的人对于技术比较的执着,尤其是本身从事的擅长的技术领域,从而容易忽视与之相近或者相似的技术。因此我建议大家在遇到此类问题的时候,可以采用对比分析的方式来理解。 本次的思考是基于国内最大的中文搜索引擎百…...

解决centos 删除文件后但空间没有释放

一、问题描述:磁盘空间不足,清理完垃圾日志以后磁盘空间还是没有释放 查看磁盘空间 [rootxwj-qt-65-44 ~]# df -h Filesystem Size Used Avail Use% Mounted on devtmpfs 1.9G 0 1.9G 0% /dev tmpfs 1.9G 0 1.9G …...

微软SCCM:企业级系统管理的核心工具

目录 摘要 1. 引言 2. SCCM的基本概念 2.1 什么是SCCM? 2.2 SCCM的历史 3. SCCM的架构 3.1 中心服务器 3.2 数据库 3.3 管理点(Management Point) 3.4 分发点(Distribution Point) 3.5 客户端代理 3.6 报告服务 4. SCCM的核心功能 4.1 软件部署与管理 4.2 操…...

RTSP作为客户端 推流 拉流的过程分析

之前写过一个 rtsp server 作为服务端的简单demo 这次分析下 rtsp作为客户端 推流和拉流时候的过 A.作为客户端拉流 TCP方式 1.Client发送OPTIONS方法 Server回应告诉支持的方法 2.Client发送DESCRIPE方法 这里是从海康摄像机拉流并且设置了用户名密码 Server回复未认证 3.客…...

【MySQL 07】内置函数

目录 1.日期函数 日期函数使用场景: 2.字符串函数 字符串函数使用场景: 3.数学函数 4.控制流函数 1.日期函数 函数示例: 1.在日期的基础上加日期 在该日期下,加上10天。 2.在日期的基础上减去时间 在该日期下减去2天 3.计算两…...

《深度学习》OpenCV 背景建模 原理及案例解析

目录 一、背景建模 1、什么是背景建模 2、背景建模的方法 1)帧差法(backgroundSubtractor) 2)基于K近邻的背景/前景分割算法BackgroundSubtractorKNN 3)基于高斯混合的背景/前景分割算法BackgroundSubtractorMOG2 3、步骤 1)初…...

机器学习(1):机器学习的概念

1. 机器学习的定义和相关概念 机器学习之父 Arthur Samuel 对机器学习的定义是:在没有明确设置的情况下,使计算机具有学习能力的研究领域。 国际机器学习大会的创始人之一 Tom Mitchell 对机器学习的定义是:计算机程序从经验 E 中学习&#…...

0. Pixel3 在Ubuntu22下Android12源码拉取 + 编译

0. Pixel3 在Ubuntu22下Android12源码拉取 编译 原文地址: http://www.androidcrack.com/index.php/archives/3/ 1. 前言 这是一个非常悲伤的故事, 因为一个意外, 不小心把之前镜像的源码搞坏了. 也没做版本管理,恢复不了了. 那么只能说是重新做一次. 再者以前的镜像太老旧…...

ip经过多个服务器转发会网速变慢吗

会的,IP经过多个服务器转发时,网速通常会变慢,主要原因包括: 增加的延迟: 每经过一个服务器,数据包就需要额外的时间进行处理和转发。这种处理时间和网络延迟会累积,导致整体延迟增加。 带宽限制…...

mongodb通过mongoimport导入JSON文件数据

目录 一、概念 二、mongoimport导入工具 三、导入命令 一、概念 MongoDB是一个流行的开源文档数据库,它支持JSON格式的文档,非常适合存储和处理大量的非结构化数据。在实际应用中,我们经常需要将大量的数据批量导入到MongoDB中。mongoimpo…...

【Qt】控件概述 (1)

控件概述 1. QWidget核心属性1.1核心属性概述1.2 enable1.3 geometry——窗口坐标1.4 window frame的影响1.4 windowTitle——窗口标题1.5 windowIcon——窗口图标1.6 windowOpacity——透明度设置1.7 cursor——光标设置1.8 font——字体设置1.9 toolTip——鼠标悬停提示设置1…...

ping基本使用详解

在网络中ping是一个十分强大的TCP/IP工具。它的作用主要为: 用来检测网络的连通情况和分析网络速度根据域名得到服务器 IP根据 ping 返回的 TTL 值来判断对方所使用的操作系统及数据包经过路由器数量。我们通常会用它来直接 ping ip 地址,来测试网络的连…...

Win10之解决:设置静态IP后,为什么自动获取动态IP问题(七十八)

简介: CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏: 多媒体系统工程师系列【…...

【AI论文精读1】针对知识密集型NLP任务的检索增强生成(RAG原始论文)

目录 一、简介一句话简介作者、引用数、时间论文地址开源代码地址 二、摘要三、引言四、整体架构(用一个例子来阐明)场景例子:核心点: 五、方法 (架构各部分详解)5.1 模型1. RAG-Sequence Model2. RAG-Toke…...

踩坑spring cloud gateway /actuator/gateway/refresh不生效

版本 java version: 17 spring boot: 3.2.x spring cloud: 2023.0.3 现象 参考Spring Cloud Gateway -> Actuator API -> Refreshing the Route Cache 说明,先修改routes配置再调用/actuator/gateway/refresh,接口返回200 status,但…...

【STM32开发环境搭建】-3-STM32CubeMX Project Manager配置-自动生成一个Keil(MDK-ARM) 5的工程

目录 1 KEIL(MDK-ARM) 5 Project工程设置 2 MCU和嵌入式软件包的选择 3 Code Generator 3.1 STM32Cube Firmware Library Package 3.2 Generated files 3.3 HAL Settings 3.4 Template Settings 4 Advanced Settings 5 自动生成的KEIL(MDK-ARM) 5 Project工程目录 结…...

计算机毕业设计 Java酷听音乐系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…...

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...

Vue记事本应用实现教程

文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展:显示创建时间8. 功能扩展:记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

深入剖析AI大模型:大模型时代的 Prompt 工程全解析

今天聊的内容,我认为是AI开发里面非常重要的内容。它在AI开发里无处不在,当你对 AI 助手说 "用李白的风格写一首关于人工智能的诗",或者让翻译模型 "将这段合同翻译成商务日语" 时,输入的这句话就是 Prompt。…...

使用VSCode开发Django指南

使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...

Oracle查询表空间大小

1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面,开源代码 作为一个电子罗盘模块,我们可以通过I2C从中获取偏航角yaw,相对于六轴陀螺仪的yaw,qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

最新SpringBoot+SpringCloud+Nacos微服务框架分享

文章目录 前言一、服务规划二、架构核心1.cloud的pom2.gateway的异常handler3.gateway的filter4、admin的pom5、admin的登录核心 三、code-helper分享总结 前言 最近有个活蛮赶的,根据Excel列的需求预估的工时直接打骨折,不要问我为什么,主要…...

MODBUS TCP转CANopen 技术赋能高效协同作业

在现代工业自动化领域,MODBUS TCP和CANopen两种通讯协议因其稳定性和高效性被广泛应用于各种设备和系统中。而随着科技的不断进步,这两种通讯协议也正在被逐步融合,形成了一种新型的通讯方式——开疆智能MODBUS TCP转CANopen网关KJ-TCPC-CANP…...

苍穹外卖--缓存菜品

1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得,如果用户端访问量比较大,数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据,减少数据库查询操作。 缓存逻辑分析: ①每个分类下的菜品保持一份缓存数据…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...