当前位置: 首页 > article >正文

《Transformer如何进行图像分类:从新手到入门》

引言

如果你对人工智能(AI)或深度学习(Deep Learning)感兴趣,可能听说过“Transformer”这个词。它最初在自然语言处理(NLP)领域大放异彩,比如在翻译、聊天机器人和文本生成中表现出色。但你知道吗?Transformer不仅能处理文字,还能用来分类图像!这听起来是不是有点神奇?别担心,这篇博客将带你从零开始,了解Transformer的基本概念、它如何被应用到图像分类,以及通过一个简单的例子让你直观理解它的运作原理。无论你是AI新手还是好奇的技术爱好者,这篇文章都会尽量用通俗的语言为你解锁Transformer的奥秘。

第一部分:Transformer是什么?

Transformer是一种深度学习模型,最早由Vaswani等人在2017年的论文《Attention is All You Need》中提出。它的核心思想是“注意力机制”(Attention Mechanism),这是一种让模型学会“关注”输入中重要部分的能力。传统的模型,比如卷积神经网络(CNN)和循环神经网络(RNN),在处理图像或序列数据时有局限性,而Transformer通过注意力机制突破了这些限制。

1.1 为什么叫“Transformer”?

“Transformer”这个名字听起来很酷,但它其实反映了模型的功能:它能将输入数据“转换”(Transform)成更有意义的表示形式。比如,把一句话翻译成另一种语言,或者把一张图片“翻译”成一个分类标签(比如“猫”或“狗”)。它的核心在于通过计算输入数据之间的关系,生成更有用的输出。

1.2 Transformer的基本结构

Transformer由两个主要部分组成:编码器(Encoder)和解码器(Decoder)。不过,在图像分类任务中,我们通常只用到编码器部分。让我们简单看看它的组成:

  • 输入嵌入(Input Embedding):把输入数据(比如单词或图像块)转换成数字向量。
  • 注意力机制(Attention):让模型关注输入中最重要的部分。
  • 前馈神经网络(Feed-Forward Network):对数据进一步处理。
  • 层归一化和残差连接(Layer Normalization & Residual Connection):帮助模型稳定训练,避免“梯度消失”等问题。

这些组件堆叠在一起,形成多层结构,每一层都让模型对数据的理解更深一层。

1.3 注意力机制:Transformer的“超能力”

注意力机制是Transformer的核心。想象你在读一本书,当你看到“猫”这个词时,你会自动想到整句话的上下文,比如“猫在睡觉”还是“猫在跑”。注意力机制让模型也能做到这一点:它会计算输入中每个部分对其他部分的“重要性”,然后根据这些关系调整输出。

具体来说,Transformer使用的是“自注意力”(Self-Attention)。它会为输入的每个部分(比如图像的一个小块)生成三个向量:

  • 查询(Query):我想知道什么?
  • 键(Key):我有哪些信息?
  • 值(Value):这些信息有多重要?

通过计算查询和键之间的相似度,模型决定每个值的权重,然后把它们加权组合起来。这种方式让Transformer能捕捉全局关系,而不是像CNN那样只关注局部区域。

第二部分:从NLP到图像分类:Vision Transformer (ViT)

Transformer最初是为NLP设计的,那它是怎么“跨界”到图像分类的呢?这要归功于2020年提出的Vision Transformer(简称ViT)。让我们看看它是如何工作的。

2.1 图像怎么变成Transformer的输入?

图像和文字完全不同,对吧?图像是一堆像素,而文字是一串单词。要让Transformer处理图像,第一步就是把图像“翻译”成它能理解的形式。ViT的做法是:

  1. 切分图像:把一张图片(比如224x224像素)切成固定大小的小块(比如16x16像素),就像把一张大拼图拆成小碎片。
  2. 展平并嵌入:把每个小块展平成一个向量(就像把拼图碎片摊平),然后通过一个线性层把它们变成嵌入向量(Embedding)。
  3. 加上位置信息:因为Transformer不像CNN有固定的空间感知能力,我们需要手动告诉它每个小块在图像中的位置。这通过“位置编码”(Positional Encoding)实现。

经过这些步骤,一张图像就变成了一个序列(Sequence),就像NLP中的一句话,只不过这里的“单词”是图像块。

2.2 Transformer处理图像的过程

一旦图像被转换成序列,Transformer的编码器就开始工作:

  • 自注意力:计算每个图像块和其他图像块之间的关系。比如,在一张猫的图片中,耳朵和眼睛的图像块可能会被关联起来。
  • 多层堆叠:通过多层编码器,模型逐渐提取更高层次的特征。
    分类头:在最后一层,添加一个简单的分类层(比如全连接层),输出图像的类别(比如“猫”或“狗”)。

2.3 ViT的优势和挑战

相比传统的CNN,ViT有几个优点:

  • 全局视野:它能一次性看到整张图像的关系,而不像CNN只关注局部。
  • 灵活性:同一个模型可以轻松处理不同大小的输入。

但它也有挑战:

  • 计算量大:自注意力机制需要大量计算,尤其当图像块很多时。
  • 数据需求高:ViT需要大量标注数据才能训练得好。

第三部分:一个简单的例子:用ViT分类猫和狗

为了让新手更容易理解,我们通过一个具体的例子来说明Transformer如何进行图像分类。假设我们要训练一个模型,区分CIFAR-10数据集中的“猫”和“狗”图片(CIFAR-10是PyTorch内置的一个小型图像数据集,包含10类32x32像素的图像)。下面我们逐步拆解过程,并新增代码实现。

3.1 数据准备

CIFAR-10中的每张图片是32x32像素,RGB格式。我们将它切成4x4的小块(为了简化示例),总共有64个块(32 ÷ 4 = 8,8x8 = 64)。每个小块有48个数值(4x4x3,因为RGB有3个通道)。

3.2 嵌入过程

  • 把每个小块展平成一个48维向量。
  • 通过一个线性层,把48维映射到一个固定维度(比如32维),得到嵌入向量。
  • 加上位置编码,告诉模型每个块的位置。

现在,这张图片变成了一个64x32的矩阵,就像一个有64个“单词”的序列。

3.3 自注意力计算

假设猫咪的耳朵在第10个块,眼睛在第20个块。Transformer会:

  1. 为每个块生成查询、键和值向量。
  2. 计算第10个块的查询和第20个块的键之间的相似度,发现它们关系密切。
  3. 根据相似度加权组合值向量,生成一个新的表示。

经过多层自注意力,模型学会关联猫的特征。

3.4 分类输出

在最后一层,ViT取一个特殊的“分类标记”(CLS Token),通过全连接层输出10个类别的概率(CIFAR-10有10类),比如“猫”的概率是0.8,“狗”是0.1。

3.5 代码实现

下面我们提供两种代码实现方式,帮助你直观感受ViT的运作。代码基于PyTorch,使用CIFAR-10数据集。

实现方式1:从头实现一个简化的ViT

这个实现简化了ViT的核心组件,适合理解原理。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 超参数
patch_size = 4  # 切分图像为4x4的小块
embed_dim = 32  # 每个小块的嵌入维度
num_heads = 4   # 注意力头的数量
num_classes = 10  # CIFAR-10有10个类别
num_patches = (32 // patch_size) ** 2  # 64个小块 (32x32图像)# 数据加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)# 简化的ViT模型
class SimpleViT(nn.Module):def __init__(self):super(SimpleViT, self).__init__()# 将图像块映射到嵌入空间self.patch_to_embedding = nn.Linear(patch_size * patch_size * 3, embed_dim)# 位置编码self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))# CLS Tokenself.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))# Transformer编码器self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads), num_layers=2)# 分类头self.fc = nn.Linear(embed_dim, num_classes)def forward(self, x):b, c, h, w = x.shape  # [batch_size, 3, 32, 32]# 切分成小块并展平x = x.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)x = x.permute(0, 2, 4, 1, 3, 5).contiguous()  # [b, 8, 8, 3, 4, 4]x = x.view(b, num_patches, -1)  # [b, 64, 48]# 映射到嵌入空间x = self.patch_to_embedding(x)  # [b, 64, 32]# 添加CLS Tokencls_tokens = self.cls_token.expand(b, -1, -1)  # [b, 1, 32]x = torch.cat((cls_tokens, x), dim=1)  # [b, 65, 32]# 加上位置编码x = x + self.pos_embedding# 通过Transformerx = self.transformer(x)  # [b, 65, 32]# 取CLS Token的输出进行分类x = self.fc(x[:, 0])  # [b, 10]return x# 训练模型
model = SimpleViT()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)for epoch in range(5):  # 训练5个epochfor images, labels in trainloader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

代码解释:

  • 数据加载:从CIFAR-10加载32x32的图像,归一化处理。
  • 图像切分:将32x32图像切成64个4x4的小块,展平后映射到32维嵌入。
  • CLS Token:添加一个特殊标记,用于最终分类。
  • Transformer:使用PyTorch内置的Transformer编码器,包含2层,每层有4个注意力头。
  • 训练:简单训练5个epoch,优化分类损失。
实现方式2:使用预训练ViT模型(Hugging Face)

这个实现利用Hugging Face的预训练ViT模型,适合快速上手。

import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 数据加载
transform = transforms.Compose([transforms.Resize((224, 224)),  # ViT需要224x224输入transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=16, shuffle=True)# 加载预训练ViT模型和特征提取器
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
model.classifier = torch.nn.Linear(model.classifier.in_features, 10)  # 修改分类头为10类# 训练设置
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)# 训练模型
model.train()
for epoch in range(3):  # 训练3个epochfor images, labels in trainloader:inputs = feature_extractor(images=[img.permute(1, 2, 0).numpy() for img in images], return_tensors="pt")inputs = {k: v for k, v in inputs.items()}  # 转换为模型输入格式optimizer.zero_grad()outputs = model(**inputs).logits  # 获取分类输出loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

代码解释:

  • 数据预处理:将CIFAR-10图像调整到224x224(ViT预训练模型的要求)。
  • 预训练模型:加载Google的vit-base-patch16-224,替换分类头为10类。
  • 特征提取器:自动处理图像输入,切分并嵌入。
  • 训练:微调模型,适应CIFAR-10任务。

注意:运行第二种方式需要安装transformers库(pip install transformers)。

第四部分:新手常见问题解答

4.1 Transformer和CNN有什么不同?

CNN像一个放大镜,逐步扫描图像的局部特征;而Transformer像一个全景相机,一次性捕捉全局关系。两者各有千秋,ViT证明了Transformer也能在图像任务中大放异彩。

4.2 我需要多强的编程基础才能用Transformer?

好消息是,你不需要从头写Transformer!开源工具(如PyTorch和Hugging Face)提供了预训练模型。你只需要学会加载模型、准备数据和微调,就能上手。

4.3 ViT适合所有图像任务吗?

不完全是。ViT在大数据集(如ImageNet)上表现很好,但在小数据集或需要精细局部特征的任务上,CNN可能更合适。

第五部分

Transformer通过注意力机制和全局视野,为图像分类带来了新思路。Vision Transformer(ViT)展示了它如何将图像切分成块,像处理句子一样处理图片,最终实现分类。对于新手来说,理解它的关键在于:

  1. 图像如何变成序列。
  2. 自注意力如何捕捉关系。
  3. 分类如何通过简单输出实现。

通过上面的代码示例,你可以看到:

  • 从头实现ViT帮助理解原理。
  • 使用预训练模型能快速应用到实际任务。

相关文章:

《Transformer如何进行图像分类:从新手到入门》

引言 如果你对人工智能(AI)或深度学习(Deep Learning)感兴趣,可能听说过“Transformer”这个词。它最初在自然语言处理(NLP)领域大放异彩,比如在翻译、聊天机器人和文本生成中表现出…...

coding ability 展开第三幕(滑动指针——基础篇)超详细!!!!

文章目录 前言滑动窗口长度最小的子数组思路 无重复字符的最长子串思路 最大连续1的个数思路 将x减到0的最小操作数思路 总结 前言 前面我们已经把双指针的一些习题练习的差不多啦 今天我们来学习新的算法知识——滑动窗口 让我们一起来探索滑动窗口的魅力吧 滑动窗口 滑动窗口…...

RAGFlow版本升级-Win10系统Docker

下载源码压缩包 https://github.com/infiniflow/ragflow.git 删除旧版本代码文件夹,把下载的代码解压到原先目录 更新一下env文件:ragflow/docker/.env 把值改为最新版本即可 RAGFLOW_IMAGEinfiniflow/ragflow:v0.17.1 更新一下docker docker compose -…...

通过mybatis的拦截器对SQL进行打标

1、背景 在我们开发的过程中,一般需要编写各种SQL语句,万一生产环境出现了慢查询,那么我们如何快速定位到底是程序中的那个SQL出现的问题呢? 2、解决方案 如果我们的数据访问层使用的是mybatis的话,那么我们可以通过…...

如何自己做奶茶,从此告别奶茶店

自制大白兔奶茶,奶香与茶香激情碰撞,每一口都是香浓与甜蜜的双重诱惑,好喝到跺脚!丝滑口感在舌尖舞动,仿佛味蕾在开派对。 简单几步就能复刻,成本超低,轻松在家享受奶茶自由。 材料:大白兔奶糖&…...

JavaScript性能优化实战指南

JavaScript性能优化实战指南 1. 性能分析工具与指标 核心工具链 Chrome DevTools: Performance面板:记录运行时性能,分析长任务(Long Tasks)、强制回流(Layout Shifts)、函数调用堆栈。Memory面…...

宇树人形机器人开源模型

1. 下载源码 https://github.com/unitreerobotics/unitree_ros.git2. 启动Gazebo roslaunch h1_description gazebo.launch3. 仿真效果 H1 GO2 B2 Laikago Z1 4. VMware: vmw_ioctl_command error Invalid argument 这个错误通常出现在虚拟机环境中运行需要OpenGL支持的应用…...

【Linux】浅谈冯诺依曼和进程

一、冯诺依曼体系结构 冯诺依曼由 输入设备、输出设备、运算器、控制器、存储器 五部分组成。 冯诺依曼的设计特点 二进制表示 所有数据(包括程序指令)均以二进制形式存储和运算,简化了硬件逻辑设计,提高了可靠性。 存储程序原理…...

env.development.local 和 env.development 的区别

env.development.local 和 env.development 的区别 区别1、场景2、git管理3、加载策略 思考原因如下 区别 1、场景 env.development: 用于开发环境的环境变量配置env.development.local: 用于存储特定于开发者的本地配置信息 2、git管理 env.development.local 会通过*.loca…...

linux操作系统实战

第一题 创建根目录结构中的所有的普通文件 [rootlocalhost ~]# cd /[rootlocalhost /]# mkdir /text[rootlocalhost /]# cd /text[rootlocalhost text]# mkdir /text/boot /text/root /text/home /text/bin /text/sbin /text/lib /text/lib64 /text/usr /text/opt /text/etc /…...

Python Cookbook-4.1 对象拷贝

任务 想拷贝某对象。不过,当你对一个对象赋值,将其作为参数传递,或者作为结果返回时,Python 通常会使用指向原对象的引用,并不是真正的拷贝。 解决方案 Python 标准库的 copy 模块提供了两个函数来创建拷贝。第一个…...

浅谈时钟启动和Systemlnit函数

时钟是STM32的关键,是整个系统的心脏,时钟如何启动,时钟源如何选择,各个参数如何设置,我们从源码来简单分析一下时钟的启动函数Systemlnit()。 Systemlnit函数简介 我们先来看一下源程序的注释…...

事业单位ABCDE类

1 我刚刚查阅了一下安徽省市直单位报名的表,我这个专业报的岗位大多数是自然科学专技岗。 2 安徽省的岗位大多都限制计算机科学与技术,我这个0854计算机技术能报的岗位十分有限。 而且我没有看到一个岗位只招应届生,显然安徽不保护应届生的…...

Python:函数(一)

python函数相关的知识点 1. 函数定义与调用 定义:使用 def 关键字,后接函数名和参数列表。 def greet(name):"""打印问候语(文档字符串)"""print(f"Hello, {name}!") 调用&#xff1a…...

MySql学习_基础Sql语句

目录 1.数据库相关概念 2.SQL 2.1 SQL通用语法 2.2 SQL分类 2.3 DDL(数据库定义语言) 2.4 DML(数据操作语言) 2.5 DQL(数据查询语言) 2.6 DCL(数据控制语言) 3. 函数 3.1 字…...

Nginx 生产环境安全配置加固

以下是Nginx生产环境安全配置加固的综合方案,结合多个技术实践和行业标准整理: 一、基础安全防护 1‌. 隐藏版本信息‌ 在http或server块添加server_tokens off;,避免暴露Nginx版本号‌。使用headers-more-nginx-module模块彻底移除响应头…...

C#中继承的核心定义‌

1. 继承的核心定义‌ ‌继承‌ 是面向对象编程(OOP)的核心特性之一,允许一个类(称为‌子类/派生类‌)基于另一个类(称为‌父类/基类‌)构建,自动获得父类的成员(字段、属…...

小白学Agent技术[5](Agent框架)

文章目录 Agent框架Single Agent框架BabyAGIAutoGPTHuggingGPTHuggingGPT工作原理说明GPT-EngineerAppAgentOS-Copilot Multi-Agent框架斯坦福虚拟小镇TaskWeaverMetaGPT微软UFOAgentScope现状 常见Agent项目比较概述技术规格和能力实际应用案例开发体验比较ChatChain模式 Agen…...

21.dirsearch:Web 路径扫描工具

一、项目介绍 dirsearch 是一款高效、多线程的 Web 路径扫描工具,专为渗透测试人员和网络安全研究人员设计,用于发现目标网站的隐藏目录、敏感文件及未授权接口。其支持自定义字典、代理配置、请求头伪装等功能,适用于红队渗透、漏洞挖掘及资…...

VSTO(C#)Excel开发4:打印设置

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github:codetoys,所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的,可以在任何平台上使用。 源码指引:github源…...

设计模式Python版 模板方法模式(上)

文章目录 前言一、模板方法模式二、模板方法模式示例 前言 GOF设计模式分三大类: 创建型模式:关注对象的创建过程,包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式。结构型模式:关注类和对象之间…...

源IP泄露后如何涅槃重生?高可用架构与自动化防御体系设计

一、架构层解决方案 1. 高防代理架构设计 推荐架构: 用户 → CDN(缓存静态资源) → 高防IP(流量清洗) → 源站集群(真实IP隐藏) ↑ Web应用防火墙(WAF) 实施要点&a…...

transformer bert 多头自注意力

输入的(a1,a2,a3,a4)是最终嵌入,是一个(512,768)的矩阵;而a1是一个token,尺寸是768 a1通过wq权重矩阵,经过全连接变换得到查询向量q1;a2通过Wk权重矩阵得到键向量k2;q和k点乘就是值…...

python-leetcode-定长子串中元音的最大数目

1456. 定长子串中元音的最大数目 - 力扣(LeetCode) 可以使用 滑动窗口 方法来解决这个问题。步骤如下: 初始化:计算前 k 个字符中元音字母的个数,作为初始窗口的值。滑动窗口:遍历字符串,每次右…...

Spring Boot + MyBatis-Plus 项目目录结构

以下是一个标准的 Spring Boot MyBatis-Plus 项目目录结构及文件命名规范,包含每个目录和文件的作用说明,适用于中大型项目开发: 项目根目录结构 src/ ├── main/ │ ├── java/ # Java 源代码 │ │ └── com/…...

Python之变量及简单的数据类型

本文来源于《Python从入门到实践》,自己整理以供工作参考 基本内容 print("Hello Python World!")message "Hello Python world!" print(message)message "Helllo Python Crash Course world!" print(message)name "ada lov…...

力扣 Hot 100 刷题记录 - 翻转二叉树

力扣 Hot 100 刷题记录 - 翻转二叉树 题目描述 翻转二叉树 是力扣 Hot 100 中的一道经典题目,题目要求如下: 给你一棵二叉树的根节点 root,翻转这棵二叉树,并返回其根节点。 示例 1: 输入:root [4,2,7…...

力扣215.数组中的第K个最大元素--堆排序法(java)

为了找到数组中第K个最大的元素,我们可以使用堆排序的方法。堆排序的核心是构建一个最大堆,并通过多次交换堆顶元素来找到前K个最大的元素。具体步骤如下: 方法思路 构建最大堆:将输入数组转换为最大堆,使得每个父节…...

MySQL增删改查操作 -- CRUD

个人主页:顾漂亮 目录 1.CRUD简介 2.Create新增 使用示例: 注意点: 3.Retrieve检索 使用示例: 注意点: 4.where条件查询 前置知识:-- 运算符 比较运算符 使用示例: 注意点&#xf…...

【算法day9】回文数-给你一个整数 x ,如果 x 是一个回文整数,返回 true ;否则,返回 false 。

回文数 给你一个整数 x ,如果 x 是一个回文整数,返回 true ;否则,返回 false 。 回文数是指正序(从左向右)和倒序(从右向左)读都是一样的整数。 例如,121 是回文&#…...