当前位置: 首页 > news >正文

Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

  • CIFAR数据集
  • FGSM介绍
  • FGSM代码实现
    • FGSM算法实现
    • 攻击效果
  • 代码汇总
    • fgsm.py
    • train.py
    • advtest.py

之前已经针对CIFAR10训练了多种分类器:
Pytorch | 从零构建AlexNet对CIFAR10进行分类
Pytorch | 从零构建Vgg对CIFAR10进行分类
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
Pytorch | 从零构建ResNet对CIFAR10进行分类
Pytorch | 从零构建MobileNet对CIFAR10进行分类
Pytorch | 从零构建EfficientNet对CIFAR10进行分类
Pytorch | 从零构建ParNet对CIFAR10进行分类

本篇文章我们使用Pytorch实现快速梯度符号攻击Fast Gradient Sign Method, FGSM)对CIFAR10上的ResNet分类器进行攻击.

CIFAR数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

FGSM介绍

FGSM(Fast Gradient Sign Method)算法是一种基于梯度的快速攻击算法,由Goodfellow等人在2015年提出,主要用于评估神经网络模型的鲁棒性。以下是对FGSM算法原理的详细介绍:

算法原理

  • FGSM算法的核心思想是利用神经网络的梯度信息来生成对抗样本。对于给定的输入样本,通过计算模型对该样本的损失函数关于输入的梯度,然后根据梯度的符号来确定扰动的方向,最后在该方向上添加一个小的扰动,得到对抗样本。
  • 具体而言,给定一个输入样本 x x x,其对应的真实标签为 y y y,模型的参数为 θ \theta θ,损失函数为 J ( θ , x , y ) J(\theta, x, y) J(θ,x,y)。首先计算损失函数 J J J 关于输入 x x x 的梯度 ∇ x J ( θ , x , y ) \nabla_x J(\theta, x, y) xJ(θ,x,y),然后根据梯度的符号确定扰动的方向 sign ( ∇ x J ( θ , x , y ) ) \text{sign}(\nabla_x J(\theta, x, y)) sign(xJ(θ,x,y)),最后生成对抗样本 x ′ = x + ϵ ⋅ sign ( ∇ x J ( θ , x , y ) ) x' = x + \epsilon \cdot \text{sign}(\nabla_x J(\theta, x, y)) x=x+ϵsign(xJ(θ,x,y)),其中 ϵ \epsilon ϵ 是一个很小的正数,用于控制扰动的大小。

FGSM代码实现

FGSM算法实现

import torch
import torch.nn as nndef FGSM(model, criterion, original_images, labels, epsilon):"""FGSM (Fast Gradient Sign Method)参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始图像- labels: 原始图像的标签- epsilon: 扰动幅度"""perturbed_images = original_images.clone().detach().requires_grad_(True)outputs = model(perturbed_images)loss = criterion(outputs, labels)model.zero_grad()loss.backward()data_grad = perturbed_images.grad.datasign_data_grad = data_grad.sign()perturbed_images = perturbed_images + epsilon * sign_data_gradperturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)return perturbed_images

攻击效果

在这里插入图片描述

代码汇总

fgsm.py

import torch
import torch.nn as nndef FGSM(model, criterion, original_images, labels, epsilon):"""FGSM (Fast Gradient Sign Method)参数:- model: 要攻击的模型- criterion: 损失函数- original_images: 原始图像- labels: 原始图像的标签- epsilon: 扰动幅度"""perturbed_images = original_images.clone().detach().requires_grad_(True)outputs = model(perturbed_images)loss = criterion(outputs, labels)model.zero_grad()loss.backward()data_grad = perturbed_images.grad.datasign_data_grad = data_grad.sign()perturbed_images = perturbed_images + epsilon * sign_data_gradperturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)return perturbed_images

train.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import ResNet18# 数据预处理
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 加载Cifar10训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# 定义设备(GPU或CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 初始化模型
model = ResNet18(num_classes=10)
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)if __name__ == "__main__":# 训练模型for epoch in range(10):  # 可以根据实际情况调整训练轮数running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 100 == 99:print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss = {running_loss / 100}')running_loss = 0.0torch.save(model.state_dict(), f'weights/epoch_{epoch + 1}.pth')print('Finished Training')

advtest.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *
from attacks import *
import ssl
import os
from PIL import Image
import matplotlib.pyplot as pltssl._create_default_https_context = ssl._create_unverified_context# 定义数据预处理操作
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False, num_workers=2)# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = ResNet18(num_classes=10).to(device)criterion = nn.CrossEntropyLoss()# 加载模型权重
weights_path = "weights/epoch_10.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))if __name__ == "__main__":# 在测试集上进行FGSM攻击并评估准确率model.eval()  # 设置为评估模式correct = 0total = 0epsilon = 16 / 255  # 可以调整扰动强度for data in testloader:original_images, labels = data[0].to(device), data[1].to(device)original_images.requires_grad = Trueattack_name = 'FGSM'if attack_name == 'FGSM':perturbed_images = FGSM(model, criterion, original_images, labels, epsilon)perturbed_outputs = model(perturbed_images)_, predicted = torch.max(perturbed_outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / total# Attack Success RateASR = 100 - accuracyprint(f'Load ResNet Model Weight from {weights_path}')print(f'epsilon: {epsilon}')print(f'ASR of {attack_name} : {ASR}%')

相关文章:

Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击

Pytorch | 利用FGSM针对CIFAR10上的ResNet分类器进行对抗攻击 CIFAR数据集FGSM介绍FGSM代码实现FGSM算法实现攻击效果 代码汇总fgsm.pytrain.pyadvtest.py 之前已经针对CIFAR10训练了多种分类器: Pytorch | 从零构建AlexNet对CIFAR10进行分类 Pytorch | 从零构建Vgg…...

消息队列(二)消息队列的高可用原理

高可用的定义 上一篇文章提过,引入了消息队列的优势在于解耦、并发和缓冲,代价就是让程序的复杂性上升,引入了消息队列后,就需要考虑消息队列对于系统整体的影响,此时消息队列的稳定和健壮就是重中之重。也就是消息队…...

大模型-使用Ollama+Dify在本地搭建一个专属于自己的聊天助手与知识库

大模型-使用OllamaDify在本地搭建一个专属于自己的知识库 1、本地安装Dify2、本地安装Ollama并解决跨越问题3、使用Dify搭建聊天助手4、使用Dify搭建本地知识库 1、本地安装Dify 参考往期博客:https://guoqingru.blog.csdn.net/article/details/144683767 2、本地…...

深入理解索引的最左匹配原则:底层逻辑解析

1. 什么是最左匹配原则? 最左匹配原则是指在使用复合索引时,查询条件从左到右依次匹配索引列的顺序,一旦中间有列未匹配,索引将停止工作或部分失效。 1.1 举例说明 假设我们有一张用户表(users)&#xf…...

微服务——数据管理与一致性

1、在微服务架构中,每个微服务都有自己的数据库,这种设计有什么优点和挑战? 优点挑战服务自治:每个微服务可独立选择适合自己的数据库类型。数据一致性:跨微服务的事务难以保证强一致性。故障隔离:一个微服…...

Docker之技术架构【八大架构演进之路】

Docker之技术架构 1. 八大架构演进之路1.1 单机架构1.2 应用数据分离架构1.3 应用服务集群架构1.4 读写分离架构1.5 冷热分离架构1.6 垂直分库架构1.7 微服务架构1.8 容器编排架构(docker出现) 2. 一个互联网实战架构 本章意在让大家了解Docker出现的历史…...

CSP-X2024山东小学组T4:刷题

题目链接 CSP-X2024山东小学组T4:刷题 题目描述 比赛之路多艰,做题方得提升。努力刷题的人在比赛中往往能取得很好的成绩,小红就是这样的人。 为了继续提升自己的编程实力,小红整理了一份刷题题单,并选中了题单中的…...

【Windows指令】Windows常用快捷指令

一.查找系统上所有可用的 .cpl 文件 要查找系统上所有可用的 .cpl 文件,你可以浏览到以下目录: C:\Windows\System32在“System32”文件夹中搜索扩展名为 .cpl 的文件,将列出所有可用的控制面板小程序。 ❗某些 .cpl 文件可能仅存在于特定的…...

NLP中的神经网络基础

一:多层感知器模型 1:感知器 解释一下,为什么写成 wxb>0 ,其实原本是 wx > t ,t就是阈值,超过这个阈值fx就为1,现在把t放在左边。 在感知器里面涉及到两个问题: 第一个,特征提…...

安全筑堤,效率破浪 | 统一运维管理平台下的免密登录应用解析

在信息技术迅猛发展的今天,企业运维管理领域正面临着前所未有的复杂挑战。统一运维管理平台作为集中管理和监控IT基础设施的核心工具,其安全性和效率至关重要。免密登录作为一种新兴的身份验证技术,正逐渐成为提升运维管理效率和安全性的重要…...

初学elasticsearch

ES 文章目录 ES一、初识elasticsearch1、什么是elasticsearch,elastic static,Lucene2、倒排索引2.1、正向索引和倒排序索引 3、es与mysql的概念对比3.1、文档3.2、索引3.3、es与数据库中的关系 二、索引库操作1、mapping属性2、创建索引库和映射基本语法…...

HTMLCSS:惊!3D 折叠按钮

这段代码创建了一个具有 3D 效果和动画的按钮,按钮上有 SVG 图标和文本。按钮在鼠标悬停时会显示一个漂浮点动画,图标会消失并显示一个线条动画。这种效果适用于吸引用户注意并提供视觉反馈。按钮的折叠效果和背景渐变增加了页面的美观性。 演示效果 HT…...

SDK 指南

在前端开发中,SDK(Software Development Kit,软件开发工具包)是一个用于帮助开发者在特定平台、框架或技术栈中实现某些功能的工具集。 1. SDK 是什么? SDK 是一种开发工具包,它提供了开发人员实现某些功…...

Web 应用项目开发全流程解析与实战经验分享

目录 一、引言 二、需求分析 三、技术选型 四、架构设计 五、开发实现 六、测试优化 七、部署上线 八、实战经验分享 九、总结 一、引言 在当今数字化时代,Web 应用已经深入到我们生活和工作的各个角落。从社交网络到电子商务,从在线办公到娱乐…...

WPS中插入矩阵的方法

WPS中插入矩阵的方法: 1、先选择插入公式中的矩阵中的第二个括号矩阵 选中矩阵右键,点击插入 点击在此后插入列和在此后插入行,会得到3x3矩阵,如图 分别点击两次会得到4x4矩阵,如图,可以画出4x4矩阵...

Python调用R语言中的程序包来执行回归树、随机森林、条件推断树和条件推断森林算法

要使用Python调用R语言中的程序包来执行回归树、随机森林、条件推断树和条件推断森林算法,重新计算中国居民收入不平等,并进行分类汇总,我们可以使用rpy2库。rpy2允许在Python中嵌入R代码并调用R函数。以下是一个详细的步骤和示例代码&#x…...

uniapp input苹果中文键盘输入拼音直接切换输入焦点监听失效

问题: uniapp微信小程序,苹果手机中文键盘状态下,输入字母时,不点击确定也不点击空白处,直接切换到下一个input输入框,UI界面会保留上个输入框输入的内容,但input、blur事件监听到的值都是空&a…...

多智能体/多机器人网络中的图论法

一、引言 1、网络科学至今受到广泛关注的原因: (1)大量的学科(尤其生物及材料科学)需要对元素间相互作用在多层级系统中所扮演的角色有更深层次的理解; (2)科技的发展促进了综合网…...

华为:数字化转型只有“起点”,没有“终点”

上个月,我收到了一位朋友的私信,他询问我是否有关于华为数字化转型的资料。幸运的是,我手头正好收藏了一些,于是我便分享给他。 然后在昨天,他又再次联系我,并感慨:“如果当初我在进行企业数字…...

centos server系统新装后的网络配置

当前状态: ping www.baidu.com报错 1、检查IP ip addr show记录要编辑的网卡 link/ether 后的XX:XX:XX:XX:XX:XX号 2、以em1为例: vi /etc/sysconfig/network-scripts/ifcfg-em1,新增如下行: HWADDRXX:XX:XX:XX:XX:XX(具体值…...

消息队列处理模式:流式与批处理的艺术

🌊 消息队列处理模式:流式与批处理的艺术 📌 深入解析现代分布式系统中的数据处理范式 一、流式处理:实时数据的"活水" 在大数据时代,流式处理已成为实时分析的核心技术。它将数据视为无限的流,…...

VsCode 安装 Cline 插件并使用免费模型(例如 DeepSeek)

当前时间为 25/6/3,Cline 版本为 3.17.8 点击侧边栏的“扩展”图标 在搜索框中输入“Cline” 找到 Cline 插件,然后点击“安装” 安装完成后,Cline 图标会出现在 VS Code 的侧边栏中 点击 Use your own API key API Provider 选择 OpenRouter…...

C++课设:学生成绩管理系统

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、项目功能概览1. 核心功能模块2. 系统特色亮点3. 完整代码4. 运行演示二、核心结构设计1. 系统架构设计2. Stud…...

arc3.2语言sort的时候报错:(sort < `(2 9 3 7 5 1)) 需要写成这种:(sort > (pair (list 3 2)))

arc语言sort的时候报错&#xff1a;(sort < (2 9 3 7 5 1)) arc> (sort < (2 9 3 7 5 1)) Error: "set-car!: expected argument of type <pair>; given: 9609216" arc> (sort < (2 9 3 )) Error: "Function call on inappropriate object…...

埃文科技智能数据引擎产品入选《中国网络安全细分领域产品名录》

嘶吼安全产业研究院发布《中国网络安全细分领域产品名录》&#xff0c;埃文科技智能数据引擎产品成功入选数据分级分类产品名录。 在数字化转型加速的今天&#xff0c;网络安全已成为企业生存与发展的核心基石&#xff0c;为了解这一蓬勃发展的产业格局&#xff0c;嘶吼安全产业…...

防爆型断链保护器的应用场景有哪些?

​ ​防爆型断链保护器是一种用于防止链条断裂导致设备损坏或安全事故的装置&#xff0c;尤其适用于存在爆炸风险的工业环境。以下是其主要应用场景&#xff1a; ​ ​1.石油化工行业 在石油化工厂、炼油厂等场所&#xff0c;防爆型断链保护器可用于保护输送设备&#xf…...

MySQL的并发事务问题及事务隔离级别

一、并发事务问题 1). 赃读&#xff1a;一个事务读到另外一个事务还没有提交的数据。 比如 B 读取到了 A 未提交的数据。 2). 不可重复读&#xff1a;一个事务先后读取同一条记录&#xff0c;但两次读取的数据不同&#xff0c;称之为不可重复读。 事务 A 两次读取同一条记录&…...

gateway 网关 路由新增 (已亲测)

问题&#xff1a; 前端通过gateway调用后端接口&#xff0c;路由转发失败&#xff0c;提示404 not found 排查&#xff1a; 使用 { "href":"/actuator/gateway/routes", "methods":[ "POST", "GET" ] } 命令查看路由列表&a…...

hot100 -- 6.矩阵系列

1.矩阵置零 问题&#xff1a;给定一个 m x n 的矩阵&#xff0c;如果一个元素为 0 &#xff0c;则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 方法&#xff1a;记录行列 置0 # 记录行列&#xff0c;分别置0 def set_zero(matrix):row, col [], []# 统计0元素…...

每日算法刷题Day23 6.5:leetcode二分答案3道题,用时1h40min(有点慢)

8. 3007.价值和小于等于K的最大数字(中等&#xff0c;学习,太难&#xff0c;先过) 3007. 价值和小于等于 K 的最大数字 - 力扣&#xff08;LeetCode&#xff09; 思想 1.给你一个整数 k 和一个整数 x 。整数 num 的价值是它的二进制表示中在 x&#xff0c;2x&#xff0c;3x …...