当前位置: 首页 > article >正文

从LeNet到ResNet:用PyTorch官方Demo理解卷积神经网络(CNN)的演进与核心模块

从LeNet到ResNetPyTorch实战中的CNN架构演进与模块化设计卷积神经网络CNN的发展史就是一部深度学习技术的进化简史。1998年诞生的LeNet-5在MNIST手写数字识别任务上一战成名却因算力限制沉寂多年2012年AlexNet凭借GPU算力和ReLU激活函数在ImageNet竞赛中掀起革命2014年VGG用整齐的3x3卷积堆叠证明深度决定性能2015年ResNet更以残差连接突破千层网络训练瓶颈。这些里程碑背后是卷积、池化、全连接等基础模块的持续创新与组合进化。本文将带您用PyTorch亲手实现这些经典网络通过CIFAR-10分类任务对比不同架构的设计哲学。不同于简单调用现成模型我们会从LeNet的每一行代码出发逐步拆解现代CNN的模块化设计精髓——如何用nn.Module构建可复用的网络组件如何通过继承机制实现架构快速迭代以及为什么说ResNet的残差块设计改变了深度学习的游戏规则。1. LeNet-5CNN的启蒙设计在Jupyter Notebook中新建一个PyTorch环境让我们从最基础的LeNet实现开始import torch.nn as nn import torch.nn.functional as F class LeNet(nn.Module): def __init__(self): super().__init__() self.conv1 nn.Conv2d(3, 16, 5) # 输入通道3(RGB), 输出16通道, 5x5卷积核 self.pool1 nn.MaxPool2d(2, 2) # 2x2最大池化, 步长2 self.conv2 nn.Conv2d(16, 32, 5) self.pool2 nn.MaxPool2d(2, 2) self.fc1 nn.Linear(32*5*5, 120) # 展平后全连接 self.fc2 nn.Linear(120, 84) self.fc3 nn.Linear(84, 10) # CIFAR-10共10类 def forward(self, x): x F.relu(self.conv1(x)) # [3,32,32] - [16,28,28] x self.pool1(x) # - [16,14,14] x F.relu(self.conv2(x)) # - [32,10,10] x self.pool2(x) # - [32,5,5] x x.view(-1, 32*5*5) # 展平处理 x F.relu(self.fc1(x)) # - 120维 x F.relu(self.fc2(x)) # - 84维 x self.fc3(x) # - 10维输出 return x这个不足30行的类包含了CNN最原始的三个设计智慧局部感受野5x5卷积核模拟生物视觉的局部感知特性参数共享同一卷积核滑动扫描整张图像大幅减少参数量空间降采样池化层逐步压缩特征图尺寸增强平移不变性在CIFAR-10上训练5个epoch后测试准确率约65%。这个成绩在今天看来平平无奇但请注意LeNet的几个历史局限仅2个卷积层感受野有限全连接层参数量占比超过90%容易过拟合使用Sigmoid激活函数原始版本存在梯度消失问题提示现代实现已将原始Sigmoid替换为ReLU这是提升经典模型性能的常用技巧2. VGG深度革命的标准化范式2014年牛津大学Visual Geometry Group提出的VGG网络确立了CNN架构的若干标准实践设计选择VGG贡献现代影响小卷积核堆叠用连续3x3卷积替代大卷积核成为行业标准设计统一模块设计每阶段固定2-3个卷积1个池化启发了后续ResNet等模块化设计通道数翻倍规则每次池化后通道数×2仍广泛使用的经验法则以下是VGG-16的PyTorch实现关键片段class VGGBlock(nn.Module): 可复用的VGG基础块 def __init__(self, in_channels, out_channels, num_convs): super().__init__() layers [] for _ in range(num_convs): layers [ nn.Conv2d(in_channels, out_channels, kernel_size3, padding1), nn.ReLU(inplaceTrue) ] in_channels out_channels layers.append(nn.MaxPool2d(kernel_size2, stride2)) self.block nn.Sequential(*layers) def forward(self, x): return self.block(x) class VGG16(nn.Module): def __init__(self): super().__init__() self.features nn.Sequential( VGGBlock(3, 64, 2), # Stage1: 2个卷积, 输出64通道 VGGBlock(64, 128, 2), # Stage2: 2个卷积, 输出128通道 VGGBlock(128, 256, 3), # Stage3: 3个卷积 VGGBlock(256, 512, 3), # Stage4: 3个卷积 VGGBlock(512, 512, 3) # Stage5: 3个卷积 ) self.classifier nn.Sequential( nn.Linear(512*1*1, 4096), # 原输入224x224CIFAR-10经5次池化后为7x7 nn.ReLU(True), nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(True), nn.Dropout(), nn.Linear(4096, 10) ) def forward(self, x): x self.features(x) x x.view(x.size(0), -1) x self.classifier(x) return xVGG的模块化设计带来了几个显著优势参数效率两个3x3卷积(9918参数)比一个5x5卷积(25参数)感受野更大深度可扩展通过堆叠相同模块轻松增加网络深度训练稳定性小卷积核的梯度传播更平稳在相同训练条件下VGG-16在CIFAR-10上的准确率可达约75%比LeNet提升10个百分点。但它的全连接层仍占用大量参数约1.2亿参数中1亿在全连接层这催生了后续架构的进一步革新。3. ResNet残差连接破解深度难题当网络深度超过20层后准确率不升反降——这是2015年之前困扰研究者的梯度消失难题。ResNet的残差块Residual Block通过跨层连接skip connection创造了一条梯度高速公路class ResidualBlock(nn.Module): 基本的残差块单元 def __init__(self, in_channels, out_channels, stride1): super().__init__() self.conv1 nn.Conv2d(in_channels, out_channels, kernel_size3, stridestride, padding1, biasFalse) self.bn1 nn.BatchNorm2d(out_channels) self.conv2 nn.Conv2d(out_channels, out_channels, kernel_size3, stride1, padding1, biasFalse) self.bn2 nn.BatchNorm2d(out_channels) # 当输入输出维度不一致时使用1x1卷积调整维度 self.shortcut nn.Sequential() if stride ! 1 or in_channels ! out_channels: self.shortcut nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size1, stridestride, biasFalse), nn.BatchNorm2d(out_channels) ) def forward(self, x): residual self.shortcut(x) out F.relu(self.bn1(self.conv1(x))) out self.bn2(self.conv2(out)) out residual # 关键残差连接 return F.relu(out)残差块的核心创新在于将传统的H(x)学习目标改为H(x)F(x)x即让网络学习残差函数F(x)H(x)-x。这一改变带来了三个深远影响梯度直通通过加法操作梯度可以绕过卷积层直接反向传播恒等映射当残差为0时网络自动退化为浅层模型深度鲁棒实验证明残差网络可轻松训练1000层以上的模型完整的ResNet-18实现如下class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes10): super().__init__() self.in_channels 64 self.conv1 nn.Conv2d(3, 64, kernel_size3, stride1, padding1, biasFalse) self.bn1 nn.BatchNorm2d(64) self.layer1 self._make_layer(block, 64, num_blocks[0], stride1) self.layer2 self._make_layer(block, 128, num_blocks[1], stride2) self.layer3 self._make_layer(block, 256, num_blocks[2], stride2) self.layer4 self._make_layer(block, 512, num_blocks[3], stride2) self.linear nn.Linear(512, num_classes) def _make_layer(self, block, out_channels, num_blocks, stride): strides [stride] [1]*(num_blocks-1) layers [] for stride in strides: layers.append(block(self.in_channels, out_channels, stride)) self.in_channels out_channels return nn.Sequential(*layers) def forward(self, x): out F.relu(self.bn1(self.conv1(x))) out self.layer1(out) out self.layer2(out) out self.layer3(out) out self.layer4(out) out F.avg_pool2d(out, 4) out out.view(out.size(0), -1) out self.linear(out) return out在CIFAR-10上ResNet-18仅用5个epoch就能达到80%以上的准确率训练曲线也显示出更快的收敛速度。下表对比了三种架构的关键指标指标LeNet-5VGG-16ResNet-18参数量(M)0.0615.211.2训练准确率(%)65.275.882.4训练时间/epoch42s3.2m2.8m最大有效深度2层卷积13层卷积18层带残差4. PyTorch模块化设计进阶技巧现代CNN实现已形成一套成熟的模块化设计范式以下是三个提升代码质量的实用技巧1. 可配置化网络构建def build_model(archresnet18, num_classes10): if arch lenet: return LeNet() elif arch vgg16: return VGG16() elif arch resnet18: return ResNet(ResidualBlock, [2,2,2,2], num_classes) else: raise ValueError(fUnknown architecture: {arch})2. 动态计算全连接层输入尺寸避免手动计算展平后的维度class SmartFlatten(nn.Module): def forward(self, x): return x.view(x.size(0), -1) class ImprovedNet(nn.Module): def __init__(self): super().__init__() self.features nn.Sequential( # 卷积层定义... ) self.flatten SmartFlatten() # 先创建空的全连接层 self.classifier nn.Linear(0, 10) # 0为占位符 def forward(self, x): x self.features(x) x self.flatten(x) # 动态调整全连接层 if self.classifier.in_features 0: self.classifier nn.Linear(x.size(1), 10).to(x.device) return self.classifier(x)3. 混合精度训练加速利用PyTorch的AMP模块实现自动混合精度训练from torch.cuda.amp import autocast, GradScaler scaler GradScaler() for epoch in range(epochs): for inputs, labels in train_loader: inputs, labels inputs.to(device), labels.to(device) optimizer.zero_grad() with autocast(): # 自动选择运算精度 outputs model(inputs) loss criterion(outputs, labels) scaler.scale(loss).backward() # 缩放梯度 scaler.step(optimizer) # 更新参数 scaler.update() # 调整缩放系数这些技巧在实际工程中能显著提升开发效率和训练速度。例如在NVIDIA V100上混合精度训练可使ResNet-18的每个epoch时间从2.8分钟缩短到1.5分钟而准确率基本保持不变。

相关文章:

从LeNet到ResNet:用PyTorch官方Demo理解卷积神经网络(CNN)的演进与核心模块

从LeNet到ResNet:PyTorch实战中的CNN架构演进与模块化设计 卷积神经网络(CNN)的发展史就是一部深度学习技术的进化简史。1998年诞生的LeNet-5在MNIST手写数字识别任务上一战成名,却因算力限制沉寂多年;2012年AlexNet凭…...

从S-Function到系统级验证:构建可复用的16QAM Simulink自定义模块库

1. 为什么需要自定义Simulink模块库 在通信系统仿真中,我们经常遇到标准模块库无法满足特定需求的情况。就拿16QAM调制解调来说,虽然Simulink自带通信工具箱,但实际项目中往往需要更灵活的配置和更直观的参数调整界面。我刚开始做通信仿真时…...

别再让扰动拖后腿!手把手教你用MATLAB/Simulink实现非线性系统的干扰观测器(附完整代码)

非线性系统扰动观测器实战:从理论到MATLAB代码的完整实现指南 在控制工程实践中,非线性系统的干扰抑制一直是工程师面临的棘手挑战。想象一下,你正在调试一台工业机械臂,理论模型完美无缺,但实际运行时总是出现无法解…...

魔兽争霸3终极优化工具:5分钟搞定所有兼容性问题

魔兽争霸3终极优化工具:5分钟搞定所有兼容性问题 【免费下载链接】WarcraftHelper Warcraft III Helper , support 1.20e, 1.24e, 1.26a, 1.27a, 1.27b 项目地址: https://gitcode.com/gh_mirrors/wa/WarcraftHelper 还在为《魔兽争霸3》在现代电脑上的各种问…...

如何构建高效完整的抖音直播实时数据采集系统:深度解析WebSocket与Protobuf技术方案

如何构建高效完整的抖音直播实时数据采集系统:深度解析WebSocket与Protobuf技术方案 【免费下载链接】DouyinLiveWebFetcher 抖音直播间网页版的弹幕数据抓取(2025最新版本) 项目地址: https://gitcode.com/gh_mirrors/do/DouyinLiveWebFet…...

高速接口EMI抑制:共模扼流圈选型与设计实战

1. 高速数据接口中的EMI挑战与共模扼流圈原理在USB3.1 Gen2、HDMI2.1等高速数据接口设计中,信号完整性工程师最头疼的问题莫过于电磁干扰(EMI)。当数据传输速率突破10Gbps时,电缆会变成高效的天线,将共模噪声辐射到周围…...

Arm服务器架构设计:虚拟化与安全增强解析

1. Arm服务器基础架构设计哲学 现代Arm服务器架构的设计核心在于"硬件虚拟化优先"理念。与传统x86架构渐进式添加虚拟化功能不同,Armv8/v9架构从设计之初就将虚拟化支持作为基础能力。这种设计哲学在SBSA(Server Base System Architecture&…...

Twitter 用户信息 API 集成指南

在这篇文章中,我们将介绍如何集成 Twitter 用户信息 API。利用这个 API,您可以获取 Twitter 用户的详细信息。只需输入 Twitter 用户的用户名,就能够输出该用户的 Twitter 主页信息。 环境准备 要使用此 API,您需要在 Twitter 用…...

MySQL 临时表详解

MySQL 临时表详解 引言 在MySQL数据库中,临时表是一种非常有用的工具,它可以帮助我们在查询过程中临时存储数据。本文将详细探讨MySQL临时表的概念、使用方法、优缺点以及在实际开发中的应用。 一、什么是MySQL临时表? MySQL临时表是一种在服务器会话期间创建的表,它仅…...

5分钟免费备份QQ空间:GetQzonehistory终极数据拯救指南

5分钟免费备份QQ空间:GetQzonehistory终极数据拯救指南 【免费下载链接】GetQzonehistory 获取QQ空间发布的历史说说 项目地址: https://gitcode.com/GitHub_Trending/ge/GetQzonehistory 还在担心QQ空间里那些承载青春回忆的说说会随着时间流逝而消失吗&…...

为OpenClaw智能体工作流配置Taotoken作为统一的模型服务后端

🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 为OpenClaw智能体工作流配置Taotoken作为统一的模型服务后端 对于使用OpenClaw框架构建AI智能体的开发者而言,一个稳定…...

3个步骤让Windows任务栏焕然一新:TranslucentTB如何改变你的桌面体验?

3个步骤让Windows任务栏焕然一新:TranslucentTB如何改变你的桌面体验? 【免费下载链接】TranslucentTB A lightweight utility that makes the Windows taskbar translucent/transparent. 项目地址: https://gitcode.com/gh_mirrors/tr/TranslucentTB …...

ThinkPad风扇控制终极指南:TPFanCtrl2实现128级精准调速与双风扇独立管理

ThinkPad风扇控制终极指南:TPFanCtrl2实现128级精准调速与双风扇独立管理 【免费下载链接】TPFanCtrl2 ThinkPad Fan Control 2 (Dual Fan) for Windows 10 and 11 项目地址: https://gitcode.com/gh_mirrors/tp/TPFanCtrl2 TPFanCtrl2是一款专为ThinkPad笔记…...

酷安UWP桌面版:在Windows上体验酷安社区的最佳指南

酷安UWP桌面版:在Windows上体验酷安社区的最佳指南 【免费下载链接】Coolapk-UWP 一个基于 UWP 平台的第三方酷安客户端 项目地址: https://gitcode.com/gh_mirrors/co/Coolapk-UWP 还在为手机屏幕太小而烦恼吗?想要在大屏幕上舒适浏览酷安社区内…...

深入AMD Ryzen硬件调试:SMUDebugTool技术原理与高级应用指南

深入AMD Ryzen硬件调试:SMUDebugTool技术原理与高级应用指南 【免费下载链接】SMUDebugTool A dedicated tool to help write/read various parameters of Ryzen-based systems, such as manual overclock, SMU, PCI, CPUID, MSR and Power Table. 项目地址: http…...

云函数window hook分析

声明 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包 内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!侵权通过头像私信或名字简介叫我删除博…...

山姆小程序云网关数据hook主动调用分析

声明 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包 内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!侵权通过头像私信或名字简介叫我删除博…...

BetterGI原神自动化助手:告别重复操作,解放双手的终极指南

BetterGI原神自动化助手:告别重复操作,解放双手的终极指南 【免费下载链接】better-genshin-impact 📦BetterGI 更好的原神 - 自动拾取 | 自动剧情 | 全自动钓鱼(AI) | 全自动七圣召唤 | 自动伐木 | 自动刷本 | 自动采集/挖矿/锄地 | 一条龙…...

QQ音乐加密音频解密:qmcdump实用指南与完整教程

QQ音乐加密音频解密:qmcdump实用指南与完整教程 【免费下载链接】qmcdump 一个简单的QQ音乐解码(qmcflac/qmc0/qmc3 转 flac/mp3),仅为个人学习参考用。 项目地址: https://gitcode.com/gh_mirrors/qm/qmcdump 你是否遇到过…...

GitHubCopilot与Gemini3.1Pro协同开发实战

在 2026 年,AI 编程工具的差异已经从“谁能写代码”转向“谁能把代码写对、写稳、写得可维护”。很多团队开始采用“双引擎协作”:GitHub Copilot 负责快速生成与代码补全,而 Gemini 3.1 Pro 负责更强的推理、架构级建议、测试策略与长上下文…...

如何快速上手Python财经数据分析:AKShare完整新手指南

如何快速上手Python财经数据分析:AKShare完整新手指南 【免费下载链接】akshare AKShare is an elegant and simple financial data interface library for Python, built for human beings! 开源财经数据接口库 项目地址: https://gitcode.com/gh_mirrors/aks/ak…...

如何彻底告别系统配置烦恼:KMS智能脚本完整使用指南

如何彻底告别系统配置烦恼:KMS智能脚本完整使用指南 【免费下载链接】KMS_VL_ALL_AIO Smart Activation Script 项目地址: https://gitcode.com/gh_mirrors/km/KMS_VL_ALL_AIO 你是否厌倦了Windows系统频繁出现的功能限制提示?是否因为Office突然…...

D3KeyHelper终极指南:暗黑3鼠标宏工具高效配置与实战应用

D3KeyHelper终极指南:暗黑3鼠标宏工具高效配置与实战应用 【免费下载链接】D3keyHelper D3KeyHelper是一个有图形界面,可自定义配置的暗黑3鼠标宏工具。 项目地址: https://gitcode.com/gh_mirrors/d3/D3keyHelper D3KeyHelper是一款专为暗黑破坏…...

ImageGlass终极指南:5分钟掌握这款轻量级图片查看器的完整使用技巧

ImageGlass终极指南:5分钟掌握这款轻量级图片查看器的完整使用技巧 【免费下载链接】ImageGlass 🏞 A lightweight, versatile image viewer 项目地址: https://gitcode.com/gh_mirrors/im/ImageGlass ImageGlass是一款专为Windows系统设计的轻量…...

SITS 2026正式版将于2024Q3封版,这7类测试团队必须在GA前掌握的AI原生适配策略(限内部技术预览通道)

更多请点击: https://intelliparadigm.com 第一章:AI原生测试方法革新:SITS 2026自动化测试新思路 SITS 2026(Semantic Intelligence Testing Suite)标志着测试范式从脚本驱动向语义感知与上下文自适应的跃迁。它不再…...

AG Grid实战:用‘列组伸缩’和‘行组展开’构建一个清晰的学生成绩分析表

AG Grid实战:用‘列组伸缩’和‘行组展开’构建清晰的学生成绩分析表 在数据密集型的教育管理系统中,如何高效呈现学生成绩数据一直是开发者面临的挑战。传统的表格往往因为信息过载而显得杂乱无章,而简单的折叠功能又难以满足多层级分析需求…...

Linux df 命令深度解析:从磁盘空间监控到 inode 耗尽排查

服务器磁盘满了,SSH 登录都报错 No space left on device。第一反应就是敲 df -h,但有时候明明显示还有空间,却还是报错——这是 inode 耗尽了。深入了解 df 命令后,发现这个看似简单的工具其实藏着不少门道。 df 的底层实现&…...

Vivado 2018.3联合Modelsim SE 10.6d仿真全流程:从库编译到成功调用IP核的实战记录

Vivado与Modelsim联合仿真全流程:从环境配置到IP核验证的深度实践 在FPGA开发领域,仿真验证环节往往决定着项目成败。作为Xilinx官方工具链的核心组合,Vivado与Modelsim的联合使用既能发挥Vivado在综合与实现阶段的优势,又能利用M…...

【权威预警】SITS 2026注册系统将于3月15日关闭早鸟通道——附2025参会者未公开的6条避坑清单

更多请点击: https://intelliparadigm.com 第一章:SITS 2026上海站定档4月:2026奇点智能技术大会报名通道开启 大会核心信息速览 SITS(Singularity Intelligence Technology Summit)2026上海站正式官宣:将…...

Java——继承实现的基本原理

继承实现的基本原理1、示例2、类加载过程3、对象创建的过程4、方法调用的过程5、变量访问的过程6、继承是把双刃剑6.1、继承破坏封装6.2、封装是如何被破坏的6.3、继承没有反映is-a关系6.4、如何应对继承的双面性1、示例 Base类: public class Base {public stati…...