当前位置: 首页 > article >正文

告别CNN,用ViT做图像分类真的更牛吗?手把手带你复现ViT核心步骤(附PyTorch代码)

视觉Transformer实战从零构建ViT模型并对比CNN性能差异当ResNet还在计算机视觉领域占据主导地位时Google Research的一篇论文《AN IMAGE IS WORTH 16X16 WORDS》彻底改变了游戏规则。视觉Transformer(ViT)的出现让传统卷积神经网络(CNN)的铁王座开始动摇。但ViT真的在所有场景下都比CNN更优秀吗本文将带你亲手拆解ViT的核心组件并用PyTorch实现一个精简版ViT最后在CIFAR-10数据集上与CNN进行实战对比。1. 为什么需要ViTTransformer在CV领域的破局之道传统CNN通过卷积核的滑动窗口捕捉局部特征这种归纳偏置(inductive bias)虽然高效但也可能限制模型捕捉长距离依赖的能力。ViT的突破性在于全局注意力机制每个patch都能直接与其他所有patch交互序列化处理将图像视为patch序列类比NLP中的token处理可扩展性模型容量随token数量增加而提升不受固定感受野限制但ViT并非完美无缺其显著缺点包括数据饥渴需要大规模预训练才能发挥优势计算开销自注意力机制的时间复杂度随token数量平方增长位置信息依赖必须显式编码位置信息不像CNN天然具有平移不变性# 计算复杂度对比公式 def complexity_compare(n, d): cnn n * d**2 # 卷积计算复杂度 vit n**2 * d # 自注意力计算复杂度 return cnn, vit表CNN与ViT在224x224图像上的计算量对比模型类型FLOPs参数量需要预训练数据量ResNet504.1G25M中等(ImageNet级别)ViT-Base17.6G86M极大(JFT-300M级别)2. ViT核心组件拆解与PyTorch实现2.1 Patch Embedding图像到序列的转换艺术ViT的第一步是将图像分割为固定大小的patch然后展平为向量。假设输入图像为224×224×3patch大小为16×16每个patch的原始维度16×16×3768patch数量(224/16)^2196通过线性投影将768维映射到模型维度(如1024)import torch import torch.nn as nn class PatchEmbedding(nn.Module): def __init__(self, img_size224, patch_size16, in_chans3, embed_dim1024): super().__init__() self.img_size img_size self.patch_size patch_size self.n_patches (img_size // patch_size) ** 2 self.proj nn.Conv2d( in_chans, embed_dim, kernel_sizepatch_size, stridepatch_size ) def forward(self, x): x self.proj(x) # (B, E, H/P, W/P) x x.flatten(2) # (B, E, N) x x.transpose(1, 2) # (B, N, E) return x提示实际应用中patch大小是需要调优的超参数。较小的patch能捕捉更精细特征但会增加序列长度。2.2 Position Embedding为视觉序列注入空间信息与CNN不同ViT需要显式编码位置信息。常见方案包括可学习的位置编码随机初始化并与模型共同训练相对位置编码编码patch间的相对位置关系二维正弦编码将二维位置分解为行和列分别编码class ViT(nn.Module): def __init__(self, img_size224, patch_size16, in_chans3, embed_dim1024, depth12): super().__init__() self.patch_embed PatchEmbedding(img_size, patch_size, in_chans, embed_dim) self.pos_embed nn.Parameter(torch.zeros(1, self.patch_embed.n_patches 1, embed_dim)) self.cls_token nn.Parameter(torch.zeros(1, 1, embed_dim)) def forward(self, x): B x.shape[0] x self.patch_embed(x) # (B, N, E) cls_tokens self.cls_token.expand(B, -1, -1) x torch.cat((cls_tokens, x), dim1) x x self.pos_embed return x2.3 Transformer Encoder自注意力的魔力ViT使用标准Transformer编码器包含多头自注意力(MSA)和前馈网络(FFN)class TransformerBlock(nn.Module): def __init__(self, embed_dim1024, num_heads8): super().__init__() self.norm1 nn.LayerNorm(embed_dim) self.attn nn.MultiheadAttention(embed_dim, num_heads) self.norm2 nn.LayerNorm(embed_dim) self.mlp nn.Sequential( nn.Linear(embed_dim, 4*embed_dim), nn.GELU(), nn.Linear(4*embed_dim, embed_dim) ) def forward(self, x): x x self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0] x x self.mlp(self.norm2(x)) return x3. 实战对比ViT vs CNN在CIFAR-10上的表现3.1 实验设置我们在CIFAR-10数据集上对比精简版ViT6层Transformerembed_dim256num_heads8对比CNN4个卷积层2个全连接层训练配置Adam优化器lr3e-4batch_size6450个epochfrom torchvision import datasets, transforms transform transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) train_data datasets.CIFAR10(./data, trainTrue, downloadTrue, transformtransform) test_data datasets.CIFAR10(./data, trainFalse, downloadTrue, transformtransform)3.2 结果分析表ViT与CNN在CIFAR-10上的性能对比指标ViT模型CNN模型训练准确率78.2%85.6%测试准确率72.4%80.3%训练时间/epoch142s68s参数量3.2M1.8M关键发现数据效率在小规模数据集上CNN表现优于ViT收敛速度CNN训练更快ViT需要更多epoch达到稳定过拟合ViT表现出更强的过拟合倾向注意这个结果不能直接推广到大规模数据集。当使用ImageNet或更大数据集预训练后ViT通常能超越CNN。4. 何时选择ViT技术选型决策指南基于实验结果和理论分析我们总结出以下决策框架适合ViT的场景拥有海量训练数据(百万级图像)需要建模长距离依赖的任务(如全景分割)计算资源充足追求state-of-the-art性能适合CNN的场景中小规模数据集(万级图像)实时性要求高的应用边缘设备部署场景混合架构(Hybrid)的兴起 最近的研究如ConViT、CeIT等尝试结合CNN的局部性和Transformer的全局建模能力class HybridModel(nn.Module): def __init__(self): super().__init__() self.cnn_backbone resnet18(pretrainedTrue) self.transformer TransformerBlock(embed_dim512) def forward(self, x): cnn_features self.cnn_backbone(x) transformer_features self.transformer(cnn_features) return transformer_features在实际项目中我多次遇到团队在模型选型时的困惑。一个经验法则是当你不确定时先用CNN baseline快速验证想法等数据规模足够大再考虑切换到ViT或混合架构。这种渐进式策略能有效降低技术风险。

相关文章:

告别CNN,用ViT做图像分类真的更牛吗?手把手带你复现ViT核心步骤(附PyTorch代码)

视觉Transformer实战:从零构建ViT模型并对比CNN性能差异 当ResNet还在计算机视觉领域占据主导地位时,Google Research的一篇论文《AN IMAGE IS WORTH 16X16 WORDS》彻底改变了游戏规则。视觉Transformer(ViT)的出现,让传统卷积神经网络(CNN)的…...

AI Agent实战专栏导读:6周掌握智能代理开发(含完整代码)

🎯 8篇深度教程 5个完整项目 | 完全免费 | 代码开源可运行 📖 专栏介绍 欢迎来到 AI Agent实战专栏! 这是国内首个系统化的AI Agent实战教程系列,从基础概念到企业级应用,带你全面掌握智能代理开发技术。 ✨ 专栏特…...

MPR121电容触摸传感器避坑指南:与Arduino UNO驱动WS2812时常见的3个问题及解决

MPR121电容触摸传感器与WS2812协同开发实战:避坑与性能优化指南 当你把MPR121电容触摸传感器和WS2812彩灯模块同时连接到Arduino UNO上时,事情往往不会像教程里展示的那样一帆风顺。触摸检测突然失灵、LED闪烁导致误触发、I2C通信时断时续——这些问题在…...

手把手教你调参:MATLAB中ellipord和ellipap函数设计椭圆滤波器的完整避坑指南

手把手教你调参:MATLAB中ellipord和ellipap函数设计椭圆滤波器的完整避坑指南 在数字信号处理领域,滤波器设计一直是工程师们面临的核心挑战之一。特别是当我们需要在有限的硬件资源下实现陡峭的过渡带特性时,椭圆滤波器往往成为最优选择。不…...

群体神经网络:分布式API调用与弹性计算新范式

1. 项目概述:群体神经网络如何重构函数与API调用 在传统分布式计算中,函数调用和API执行往往受限于单一节点的处理能力与可靠性。三年前我在构建一个高并发交易系统时,就曾因单个API节点崩溃导致整个服务雪崩。而群体神经网络(Swa…...

FPGA新手避坑指南:用Verilog在Spartan-6上搞定IS62LV256 SRAM读写(附完整代码)

FPGA实战:Spartan-6与IS62LV256 SRAM的Verilog高效驱动手册 第一次接触FPGA片外SRAM时,我盯着开发板上那个小小的IS62LV256芯片发呆了半小时——数据手册上密密麻麻的时序参数、三态总线的双向控制、状态机的精确跳转条件,每一个环节都可能成…...

避坑指南:YOLOv8-pose关键点训练数据准备,Labelme标注的3个常见错误与修复脚本

YOLOv8-pose关键点标注避坑实战:Labelme常见错误排查与自动化修复方案 当你第一次尝试用Labelme为YOLOv8-pose准备关键点检测数据时,大概率会在标注环节遇到几个"经典坑"。这些错误不会立即导致程序报错,却会让模型训练效果莫名其妙…...

英国AI初创公司Ineffable Intelligence获11亿美元种子轮融资,投后估值达51亿美元

11亿美元种子轮融资,欧洲最大规模纪录诞生4月28日消息,据TechCrunch报道,英国AI初创公司Ineffable Intelligence宣布完成11亿美元种子轮融资,投后估值达51亿美元,创下欧洲史上最大规模种子轮融资纪录。本轮融资由红杉资…...

微信数据解密完整指南:如何安全备份你的聊天记录

微信数据解密完整指南:如何安全备份你的聊天记录 【免费下载链接】PyWxDump 删库 项目地址: https://gitcode.com/GitHub_Trending/py/PyWxDump 微信作为我们日常沟通的重要工具,存储着大量珍贵的聊天记录、图片和文件。然而,这些数据…...

解锁论文降重新姿势:书匠策AI,你的学术减负小能手!

在学术的浩瀚海洋中,每一位学者或学生都像是勇敢的航海者,驾驶着知识的船只,探索未知的领域。然而,在撰写论文这一航程中,有一个让人头疼的“暗礁”——重复率过高。它不仅可能让你的辛勤努力付诸东流,还可…...

【必收藏】2026年大模型应用开发工程师趋势解析,小白程序员必看!

不夸张地说,对于程序员而言,未来5年最值得深耕、最有前景的技术发展方向,毫无疑问是AI大模型!尤其是2026年,随着大模型技术从“数字感知”迈向“物理认知”,行业迎来范式变革,无论是刚入门的编程…...

WindowsCleaner终极指南:告别C盘爆红,3步实现系统加速

WindowsCleaner终极指南:告别C盘爆红,3步实现系统加速 【免费下载链接】WindowsCleaner Windows Cleaner——专治C盘爆红及各种不服! 项目地址: https://gitcode.com/gh_mirrors/wi/WindowsCleaner 你是否曾经因为C盘爆红而焦虑不已&a…...

捡垃圾神器Tesla M40风冷改造全记录:从拆机到上机,Win11双显卡就这么配

Tesla M40风冷改造实战指南:低成本打造高性能计算平台 在硬件DIY的世界里,总有一些被市场低估的"宝藏"等待发掘。Tesla M40计算卡就是这样一个典型代表——它拥有24GB GDDR5显存和3072个CUDA核心,性能接近GTX 1080 Ti,但…...

ARM架构CNTHVS_CTL_EL2寄存器详解与虚拟定时器应用

1. ARM架构中的CNTHVS_CTL_EL2寄存器解析在ARMv8-A架构中,系统寄存器扮演着处理器与操作系统间关键桥梁的角色。作为安全虚拟定时器的控制核心,CNTHVS_CTL_EL2寄存器在虚拟化环境中发挥着不可替代的作用。这个64位寄存器专为Secure EL2虚拟定时器设计&am…...

避坑指南:PS2020安装Geographic Imager 6.2插件后,如何正确配置浮动许可(localhost:5053)

PS2020安装Geographic Imager 6.2插件浮动许可配置全攻略 当你在PS2020中成功安装Geographic Imager 6.2插件后,最令人头疼的往往是浮动许可的配置环节。不少用户反映,明明按照步骤安装了插件,却在最后一步卡在许可验证上,弹出各…...

3步掌握BiliTools:如何高效下载B站视频并提取AI智能总结

3步掌握BiliTools:如何高效下载B站视频并提取AI智能总结 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools …...

微信语音导出mp3全攻略:手机免电脑、在线工具、格式工厂三种方法实测对比

微信语音导出MP3全攻略:三种方法实测与避坑指南 每次听到微信里珍贵的语音消息时,你是否想过把它们永久保存下来?无论是孩子第一次叫"爸爸妈妈"的稚嫩声音,还是商务谈判中的关键承诺,这些语音都值得用更通用…...

csp基础知识——分治、查找与排序

分治分治是一种思想,具体是在解决某类问题的一种解决思路,常常在排序算法中使用。当然用一个具体的例子可以快速了解一下。假设在一堆(n个)质量相同的真硬币中混入了一枚质量较轻的假硬币,现在要找出来,常规…...

终极NCM解密指南:3分钟解锁网易云音乐加密格式,让音乐自由播放

终极NCM解密指南:3分钟解锁网易云音乐加密格式,让音乐自由播放 【免费下载链接】ncmdump 项目地址: https://gitcode.com/gh_mirrors/ncmd/ncmdump 还在为下载的网易云音乐NCM格式文件无法在其他播放器播放而烦恼吗?ncmdump是一款简单…...

Java 25 外部函数接口增强:仅剩72小时!OpenJDK 25正式版冻结前必须掌握的3个@ClangBinding兼容性开关

更多请点击: https://intelliparadigm.com 第一章:Java 25 外部函数接口增强概览 Java 25 正式将外部函数与内存 API(Foreign Function & Memory API)从预览状态转为正式特性(JEP 497),标…...

内存健康守护神:如何用Memtest86+彻底检测电脑内存故障

内存健康守护神:如何用Memtest86彻底检测电脑内存故障 【免费下载链接】memtest86plus Official repo for Memtest86 项目地址: https://gitcode.com/gh_mirrors/me/memtest86plus 你的电脑是否经常出现蓝屏、死机或数据损坏?这些恼人的问题很可能…...

[FRP]Windows 安装 frpc 客户端,以及P2P方式ssh配置

一. 下载 frpc 客户端程序 客户端程序下载地址:GITHUB官方仓库 。根据您的 CPU 类型选择合适的版本。 本教程以 v0.68.1 为例:选择 frp_0.68.1_windows_amd64.zip 下载。 二、解压文件 三、配置文件 frpc.toml serverAddr "服务端IP" ser…...

【优化调度】含氢气氨气综合能源系统优化调度【含Matlab源码 15394期】

💥💥💥💥💥💥💥💥💞💞💞💞💞💞💞💞💞Matlab武动乾坤博客之家💞…...

Vue2 转 Vue3 思维转变与工程实践

一、前言Vue2 转 Vue3 思维转变与工程实践 是当前技术圈热议的话题。本文从实际场景出发,帮你快速掌握核心要点。二、核心概念2.1 什么是Vue3Vue3是现代软件开发中不可或缺的一环,下面通过一个典型场景来理解它的核心价值。2.2 基本用法// 基础示例 asyn…...

开发者职业倦怠自救手册:找回编码的快乐——写给软件测试从业者的专业指南

我们为何“倦”了?在软件测试领域深耕多年后,许多从业者会经历这样一个阶段:曾经对发现Bug、保障质量充满热情,如今却感到重复、枯燥甚至迷茫。每天面对相似的测试用例、无穷的回归测试、复杂的自动化脚本维护,以及不断…...

【仅限头部金融级用户知晓】Java 25 ZGC 2.0生产调优白皮书(含JFR采样模板与火焰图标注规范)

更多请点击: https://intelliparadigm.com 第一章:Java 25 ZGC 2.0 生产调优白皮书导论 ZGC 2.0 是 Java 25 中面向超低延迟场景的下一代垃圾收集器重大演进,其核心目标是将 GC 停顿时间稳定控制在 **1ms 以内**(P99 ≤ 0.8ms&am…...

HarmonyOS Tabs组件自定义遮罩效果全解析

引言:提升tabBar视觉体验的遮罩技术在HarmonyOS应用开发中,Tabs组件作为常见的导航控件,广泛应用于各类内容切换场景。然而,当tabBar页签内容过长且采用可滚动模式时,简单的背景色设置往往无法提供理想的视觉体验——用…...

React组件化开发全解析,前端现代必备知识

我们来深入、系统地拆解 React 前端技术。 一、核心概念:React 是什么? React 是一个用于构建用户界面的 JavaScript 库(注意,它不是框架)。它的核心思想是组件化和声明式编程。你可以把它想象成乐高积木&#xff1a…...

每日AI新闻推送:具身智能、芯片与大模型的最新突破(2026.04.26)

为您精选过去24小时内全球最具影响力的10条科技新闻,涵盖具身智能、机器人、芯片、大模型与应用四大核心领域。 🤖 具身智能与机器人:从“能动”迈向“会干”的元年 1. 智元机器人宣布2026为“部署态元年”,万台下线开启工业化落…...

终极指南:3分钟掌握FF14过场动画跳过插件的完整使用技巧

终极指南:3分钟掌握FF14过场动画跳过插件的完整使用技巧 【免费下载链接】FFXIV_ACT_CutsceneSkip 项目地址: https://gitcode.com/gh_mirrors/ff/FFXIV_ACT_CutsceneSkip 还在为《最终幻想14》中重复的副本过场动画浪费时间吗?FFXIV_ACT_Cutsce…...