当前位置: 首页 > news >正文

Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

  • CIFAR数据集
  • FGSM介绍
  • FGSM代码实现
    • FGSM算法实现
    • 攻击效果
  • 代码汇总
    • fgsm.py
    • train.py
    • advtest.py

之前已经针对CIFAR10训练了多种分类器:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建EfficientNet对CIFAR10进行分类
Pytorch | 从零构建ParNet对CIFAR10进行分类

本篇文章我们使用Pytorch实现快速梯度符号攻击Fast Gradient Sign Method, FGSM)对CIFAR10上的ResNet分类器进行攻击.

CIFAR数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

FGSM介绍

FGSM(Fast Gradient Sign Method)算法是一种基于梯度的快速攻击算法,由Goodfellow等人在2015年提出,主要用于评估神经网络模型的鲁棒性。以下是对FGSM算法原理的详细介绍:

算法原理

  • FGSM算法的核心思想是利用神经网络的梯度信息来生成对抗样本。对于给定的输入样本,通过计算模型对该样本的损失函数关于输入的梯度,然后根据梯度的符号来确定扰动的方向,最后在该方向上添加一个小的扰动,得到对抗样本。
  • 具体而言,给定一个输入样本 x x x,其对应的真实标签为 y y y,模型的参数为 θ \theta θ,损失函数为 J ( θ , x , y ) J(\theta, x, y) J(θ,x,y)。首先计算损失函数 J J J 关于输入 x x x 的梯度 ∇ x J ( θ , x , y ) \nabla_x J(\theta, x, y) xJ(θ,x,y),然后根据梯度的符号确定扰动的方向 sign ( ∇ x J ( θ , x , y ) ) \text{sign}(\nabla_x J(\theta, x, y)) sign(xJ(θ,x,y)),最后生成对抗样本 x ′ = x + ϵ ⋅ sign ( ∇ x J ( θ , x , y ) ) x' = x + \epsilon \cdot \text{sign}(\nabla_x J(\theta, x, y)) x=x+ϵsign(xJ(θ,x,y)),其中 ϵ \epsilon ϵ 是一个很小的正数,用于控制扰动的大小。

FGSM代码实现

FGSM算法实现

import torch
import torch.nn as nndef FGSM(model, criterion, original_images, labels, epsilon):"""FGSM (Fast Gradient Sign Method)参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始图像- labels: 原始图像的标签- epsilon: 扰动幅度"""perturbed_images = original_images.clone().detach().requires_grad_(True)outputs = model(perturbed_images)loss = criterion(outputs, labels)model.zero_grad()loss.backward()data_grad = perturbed_images.grad.datasign_data_grad = data_grad.sign()perturbed_images = perturbed_images + epsilon * sign_data_gradperturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)return perturbed_images

攻击效果

在这里插入图片描述

代码汇总

fgsm.py

import torch
import torch.nn as nndef FGSM(model, criterion, original_images, labels, epsilon):"""FGSM (Fast Gradient Sign Method)参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始图像- labels: 原始图像的标签- epsilon: 扰动幅度"""perturbed_images = original_images.clone().detach().requires_grad_(True)outputs = model(perturbed_images)loss = criterion(outputs, labels)model.zero_grad()loss.backward()data_grad = perturbed_images.grad.datasign_data_grad = data_grad.sign()perturbed_images = perturbed_images + epsilon * sign_data_gradperturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)return perturbed_images

train.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import ResNet18# 数据预处理
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加载Cifar10训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# 定义设备(GPU或CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 初始化模型
model = ResNet18(num_classes=10)
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)if __name__ == "__main__":# 训练模型for epoch in range(10):  # 可以根据实际情况调整训练轮数running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss = {running_loss / 100}')running_loss = 0.0torch.save(model.state_dict(), f'weights/epoch_{epoch + 1}.pth')print('Finished Training')

advtest.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *
from attacks import *
import ssl
import os
from PIL import Image
import matplotlib.pyplot as pltssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = ResNet18(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = "weights/epoch_10.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))if __name__ == "__main__":# 在测试集上进行FGSM攻击并评估准确率model.eval()  # 设置为评估模式correct = 0total = 0epsilon = 16 / 255  # 可以调整扰动强度for data in testloader:original_images, labels = data[0].to(device), data[1].to(device)original_images.requires_grad = Trueattack_name = 'FGSM'if attack_name == 'FGSM':perturbed_images = FGSM(model, criterion, original_images, labels, epsilon)perturbed_outputs = model(perturbed_images)_, predicted = torch.max(perturbed_outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total# Attack Success RateASR = 100 - accuracyprint(f'Load ResNet Model Weight from {weights_path}')print(f'epsilon: {epsilon}')print(f'ASR of {attack_name} : {ASR}%')

相关文章:

Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击 CIFAR数据集FGSM介绍FGSM代码实现FGSM算法实现攻击效果 代码汇总fgsm.pytrain.pyadvtest.py 之前已经针对CIFAR10训练了多种分类器: Pytorch | 从零构建AlexNet对CIFAR10进行分类 Pytorch | 从零构建Vgg…...

消息队列(二)消息队列的高可用原理

高可用的定义 上一篇文章提过,引入了消息队列的优势在于解耦、并发和缓冲,代价就是让程序的复杂性上升,引入了消息队列后,就需要考虑消息队列对于系统整体的影响,此时消息队列的稳定和健壮就是重中之重。也就是消息队…...

大模型-使用Ollama+Dify在本地搭建一个专属于自己的聊天助手与知识库

大模型-使用OllamaDify在本地搭建一个专属于自己的知识库 1、本地安装Dify2、本地安装Ollama并解决跨越问题3、使用Dify搭建聊天助手4、使用Dify搭建本地知识库 1、本地安装Dify 参考往期博客:https://guoqingru.blog.csdn.net/article/details/144683767 2、本地…...

深入理解索引的最左匹配原则:底层逻辑解析

1. 什么是最左匹配原则? 最左匹配原则是指在使用复合索引时,查询条件从左到右依次匹配索引列的顺序,一旦中间有列未匹配,索引将停止工作或部分失效。 1.1 举例说明 假设我们有一张用户表(users)&#xf…...

微服务——数据管理与一致性

1、在微服务架构中,每个微服务都有自己的数据库,这种设计有什么优点和挑战? 优点挑战服务自治:每个微服务可独立选择适合自己的数据库类型。数据一致性:跨微服务的事务难以保证强一致性。故障隔离:一个微服…...

Docker之技术架构【八大架构演进之路】

Docker之技术架构 1. 八大架构演进之路1.1 单机架构1.2 应用数据分离架构1.3 应用服务集群架构1.4 读写分离架构1.5 冷热分离架构1.6 垂直分库架构1.7 微服务架构1.8 容器编排架构(docker出现) 2. 一个互联网实战架构 本章意在让大家了解Docker出现的历史…...

CSP-X2024山东小学组T4:刷题

题目链接 CSP-X2024山东小学组T4:刷题 题目描述 比赛之路多艰,做题方得提升。努力刷题的人在比赛中往往能取得很好的成绩,小红就是这样的人。 为了继续提升自己的编程实力,小红整理了一份刷题题单,并选中了题单中的…...

【Windows指令】Windows常用快捷指令

一.查找系统上所有可用的 .cpl 文件 要查找系统上所有可用的 .cpl 文件,你可以浏览到以下目录: C:\Windows\System32在“System32”文件夹中搜索扩展名为 .cpl 的文件,将列出所有可用的控制面板小程序。 ❗某些 .cpl 文件可能仅存在于特定的…...

NLP中的神经网络基础

一:多层感知器模型 1:感知器 解释一下,为什么写成 wxb>0 ,其实原本是 wx > t ,t就是阈值,超过这个阈值fx就为1,现在把t放在左边。 在感知器里面涉及到两个问题: 第一个,特征提…...

安全筑堤,效率破浪 | 统一运维管理平台下的免密登录应用解析

在信息技术迅猛发展的今天,企业运维管理领域正面临着前所未有的复杂挑战。统一运维管理平台作为集中管理和监控IT基础设施的核心工具,其安全性和效率至关重要。免密登录作为一种新兴的身份验证技术,正逐渐成为提升运维管理效率和安全性的重要…...

初学elasticsearch

ES 文章目录 ES一、初识elasticsearch1、什么是elasticsearch,elastic static,Lucene2、倒排索引2.1、正向索引和倒排序索引 3、es与mysql的概念对比3.1、文档3.2、索引3.3、es与数据库中的关系 二、索引库操作1、mapping属性2、创建索引库和映射基本语法…...

HTMLCSS:惊!3D 折叠按钮

这段代码创建了一个具有 3D 效果和动画的按钮,按钮上有 SVG 图标和文本。按钮在鼠标悬停时会显示一个漂浮点动画,图标会消失并显示一个线条动画。这种效果适用于吸引用户注意并提供视觉反馈。按钮的折叠效果和背景渐变增加了页面的美观性。 演示效果 HT…...

SDK 指南

在前端开发中,SDK(Software Development Kit,软件开发工具包)是一个用于帮助开发者在特定平台、框架或技术栈中实现某些功能的工具集。 1. SDK 是什么? SDK 是一种开发工具包,它提供了开发人员实现某些功…...

Web 应用项目开发全流程解析与实战经验分享

目录 一、引言 二、需求分析 三、技术选型 四、架构设计 五、开发实现 六、测试优化 七、部署上线 八、实战经验分享 九、总结 一、引言 在当今数字化时代,Web 应用已经深入到我们生活和工作的各个角落。从社交网络到电子商务,从在线办公到娱乐…...

WPS中插入矩阵的方法

WPS中插入矩阵的方法: 1、先选择插入公式中的矩阵中的第二个括号矩阵 选中矩阵右键,点击插入 点击在此后插入列和在此后插入行,会得到3x3矩阵,如图 分别点击两次会得到4x4矩阵,如图,可以画出4x4矩阵...

Python调用R语言中的程序包来执行回归树、随机森林、条件推断树和条件推断森林算法

要使用Python调用R语言中的程序包来执行回归树、随机森林、条件推断树和条件推断森林算法,重新计算中国居民收入不平等,并进行分类汇总,我们可以使用rpy2库。rpy2允许在Python中嵌入R代码并调用R函数。以下是一个详细的步骤和示例代码&#x…...

uniapp input苹果中文键盘输入拼音直接切换输入焦点监听失效

问题: uniapp微信小程序,苹果手机中文键盘状态下,输入字母时,不点击确定也不点击空白处,直接切换到下一个input输入框,UI界面会保留上个输入框输入的内容,但input、blur事件监听到的值都是空&a…...

多智能体/多机器人网络中的图论法

一、引言 1、网络科学至今受到广泛关注的原因: (1)大量的学科(尤其生物及材料科学)需要对元素间相互作用在多层级系统中所扮演的角色有更深层次的理解; (2)科技的发展促进了综合网…...

华为:数字化转型只有“起点”,没有“终点”

上个月,我收到了一位朋友的私信,他询问我是否有关于华为数字化转型的资料。幸运的是,我手头正好收藏了一些,于是我便分享给他。 然后在昨天,他又再次联系我,并感慨:“如果当初我在进行企业数字…...

centos server系统新装后的网络配置

当前状态: ping www.baidu.com报错 1、检查IP ip addr show记录要编辑的网卡 link/ether 后的XX:XX:XX:XX:XX:XX号 2、以em1为例: vi /etc/sysconfig/network-scripts/ifcfg-em1,新增如下行: HWADDRXX:XX:XX:XX:XX:XX(具体值…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...

React Native 导航系统实战(React Navigation)

导航系统实战(React Navigation) React Navigation 是 React Native 应用中最常用的导航库之一,它提供了多种导航模式,如堆栈导航(Stack Navigator)、标签导航(Tab Navigator)和抽屉…...

AI Agent与Agentic AI:原理、应用、挑战与未来展望

文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例:使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例:使用OpenAI GPT-3进…...

三体问题详解

从物理学角度,三体问题之所以不稳定,是因为三个天体在万有引力作用下相互作用,形成一个非线性耦合系统。我们可以从牛顿经典力学出发,列出具体的运动方程,并说明为何这个系统本质上是混沌的,无法得到一般解…...

实现弹窗随键盘上移居中

实现弹窗随键盘上移的核心思路 在Android中&#xff0c;可以通过监听键盘的显示和隐藏事件&#xff0c;动态调整弹窗的位置。关键点在于获取键盘高度&#xff0c;并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...

Mac下Android Studio扫描根目录卡死问题记录

环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中&#xff0c;提示一个依赖外部头文件的cpp源文件需要同步&#xff0c;点…...

今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存

文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...

html-<abbr> 缩写或首字母缩略词

定义与作用 <abbr> 标签用于表示缩写或首字母缩略词&#xff0c;它可以帮助用户更好地理解缩写的含义&#xff0c;尤其是对于那些不熟悉该缩写的用户。 title 属性的内容提供了缩写的详细说明。当用户将鼠标悬停在缩写上时&#xff0c;会显示一个提示框。 示例&#x…...

Java编程之桥接模式

定义 桥接模式&#xff08;Bridge Pattern&#xff09;属于结构型设计模式&#xff0c;它的核心意图是将抽象部分与实现部分分离&#xff0c;使它们可以独立地变化。这种模式通过组合关系来替代继承关系&#xff0c;从而降低了抽象和实现这两个可变维度之间的耦合度。 用例子…...

Qt 事件处理中 return 的深入解析

Qt 事件处理中 return 的深入解析 在 Qt 事件处理中&#xff0c;return 语句的使用是另一个关键概念&#xff0c;它与 event->accept()/event->ignore() 密切相关但作用不同。让我们详细分析一下它们之间的关系和工作原理。 核心区别&#xff1a;不同层级的事件处理 方…...