基于CNN的FashionMNIST数据集识别5——GoogleNet模型
源码
import torch
from torch import nn
from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):super().__init__()self.ReLu = nn.ReLU()#路径1self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)#路径2self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)#路径3self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)#路径4self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)def forward(self, x):p1 = self.ReLu(self.p1_1(x))p2 =self.ReLu(self.p2_2(self.ReLu(self.p2_1(x))))p3 =self.ReLu(self.p3_2(self.ReLu(self.p3_1(x))))p4 =self.ReLu(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim=1)class GoogleNet(nn.Module):def __init__(self, Inception):super().__init__()self.block1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=192, kernel_size=3, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),Inception(256, 128, (128, 192), (32, 96), 64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),Inception(512, 160, (112, 224), (24, 64), 64),Inception(512, 128, (128, 256), (24, 64), 64),Inception(512, 112, (128, 288), (32, 64), 64),Inception(528, 256, (160, 320), (32, 128), 128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.block5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48, 128), 128),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(),nn.Linear(1024, 10))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0 ,0.01)if m.bias is not None:nn.init.constant_(m.bias, 0)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)return xif __name__ == "__main__":device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = GoogleNet(Inception).to(device)print(summary(model, (1, 224, 224)))

从整个链路上看,googlenet的复杂度相比于之前我们提到的cnn网络更复杂。仔细分析可以看到,googlenet的网络结构里面有多个核心模块inception。搞懂inception就基本搞清楚了googlenet。
Inception

Inception 模块的设计动机
-
传统串联卷积的局限性
- 传统网络通过堆叠卷积层逐步提取特征,但不同尺度的特征(如边缘、纹理、物体部件)需不同大小的卷积核。
- 堆叠大卷积核(如 5x5)会导致计算量暴增(参数会增加很多)。
-
关键优化目标
- 多尺度特征融合:同时提取不同尺度的特征。
- 减少计算量:通过 1x1 卷积降维,控制参数规模。
Inception模块设计思路
- 并行多分支设计:Inception模块包含多个并行分支,典型结构包括1x1卷积、3x3卷积、5x5卷积和3x3最大池化层。不同尺寸的卷积核可同时捕捉局部细节和全局特征。
- 特征图拼接:各分支输出的特征图在通道维度进行拼接,形成综合特征表达,增强模型对不同尺度的适应性。从图片可以看到,每个inception块有四条路径,之前的cnn大多是单一路径。
class Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):super().__init__()self.ReLu = nn.ReLU()#路径1self.p1_1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=1)#路径2self.p2_1 = nn.Conv2d(in_channels=in_channels, out_channels=c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(in_channels=c2[0], out_channels=c2[1], kernel_size=3, padding=1)#路径3self.p3_1 = nn.Conv2d(in_channels=in_channels, out_channels=c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(in_channels=c3[0], out_channels=c3[1], kernel_size=5, padding=2)#路径4self.p4_1 = nn.MaxPool2d(kernel_size=3, padding=1, stride=1)self.p4_2 = nn.Conv2d(in_channels=in_channels, out_channels=c4, kernel_size=1)def forward(self, x):p1 = self.ReLu(self.p1_1(x))p2 =self.ReLu(self.p2_2(self.ReLu(self.p2_1(x))))p3 =self.ReLu(self.p3_2(self.ReLu(self.p3_1(x))))p4 =self.ReLu(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim=1)
从代码可以看出,每个inception块都分成了四个路径。1,2,3路径都是纯卷积,第四条路径是池化层+卷积。另外,卷积核的大小是固定的,卷积核的通道数是可以通过传参设置的。
传参如下表所示:
| 参数 | 含义 | 示例值 |
|---|---|---|
in_channels | 输入特征图的通道数 | 192 |
c1 | 路径1的输出通道数 | 64 |
c2 | 路径2的通道数元组 (降维, 输出) | (96, 128) |
c3 | 路径3的通道数元组 (降维, 输出) | (16, 32) |
c4 | 路径4的输出通道数 | 32 |
总输出通道数 = c1 + c2 + c3 + c4。示例:64 + 128 + 32 + 32 = 256。
前向传播
当时写代码,我有一个疑问,inception里的前向传播是什么时候触发的,是googlenet在处理block代码流程的时候自动触发的吗?
这个问题涉及到forward方法的隐式调用。在PyTorch中,当通过 模块实例直接调用输入数据 时,forward 方法会被自动触发。例如:
inception = Inception(...) # 实例化模块
output = inception(x) # 隐式调用forward(x)

所以在googlenet前向传播的时候,完成了inception的前向传播。
另外在学习这块还学到个小知识,就是forward方法不能显式调用。会绕过一些关键步骤(如梯度计算),就导致无法反向传播了!
张量拼接
在PyTorch中,torch.cat((p1, p2, p3, p4), dim=1) 这句话的作用是沿着通道维度(channel dimension)将四个张量(p1, p2, p3, p4)拼接成一个更大的张量。以下是详细解释:
假设输入张量 x 的形状为 (batch_size, in_channels, height, width),经过Inception模块的四条路径处理后,每个路径的输出形状如下:
-
p1:(batch_size, c1, height, width)
(1x1卷积直接输出c1个通道) -
p2:(batch_size, c2, height, width)
(1x1卷积降维到c2,再通过3x3卷积输出c2个通道) -
p3:(batch_size, c3, height, width)
(1x1卷积降维到c3,再通过5x5卷积输出c3个通道) -
p4:(batch_size, c4, height, width)
(最大池化后通过1x1卷积输出c4个通道)
所有路径输出的高度(height)和宽度(width)必须一致,否则拼接会失败。批数量和通道数可以不相同。
可以在别的维度拼接吗?不太行,原因是:
dim=0:沿批量维度拼接,会合并不同样本的数据,破坏批量独立性。dim=2/3:沿空间维度拼接,会破坏特征图的空间结构,导致后续卷积无法正常操作。
参数初始化
# 遍历模型的所有子模块(包括嵌套模块)
for m in self.modules():# 对二维卷积层进行初始化if isinstance(m, nn.Conv2d):# 使用Kaiming正态分布初始化权重(针对ReLU激活函数优化)nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')# 如果存在偏置项,将其初始化为0if m.bias is not None:nn.init.constant_(m.bias, 0)# 对全连接层进行初始化 elif isinstance(m, nn.Linear):# 使用正态分布初始化权重(均值0,标准差0.01)nn.init.normal_(m.weight, 0, 0.01)# 如果存在偏置项,将其初始化为0if m.bias is not None:nn.init.constant_(m.bias, 0)
在构建方法里我们增加了参数初始化,参数初始化主要作用是提高收敛速度,减少训练模型时压根不收敛的风险。
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity='relu')
卷积层使用的是kaiming初始化,和relu激活函数搭配使用效果较好。两个参数的含义是:
mode="fan_out":根据输出通道数计算缩放系数nonlinearity='relu':针对ReLU的负半轴修正
nn.init.constant_(m.bias, 0)
卷积层如果存在偏置就统一初始化为0,避免初始阶段引入偏置。
全连接层使用的是小标准差正态分布,作用是限制初始权重范围,防止激活值过大。适用于浅层网络。
一些初始化方法的特点和适用场景:
| 方法 | 适用场景 | 核心思想 | PyTorch实现函数 |
|---|---|---|---|
| Kaiming初始化 | ReLU激活的CNN | 保持前向传播的方差一致性 | kaiming_normal_/uniform_ |
| Xavier初始化 | Tanh/Sigmoid激活 | 平衡输入输出的方差 | xavier_normal_ |
| 零初始化 | 偏置项 | 避免初始偏好 | constant_(0) |
| 正交初始化 | RNN/Transformer | 保持矩阵正交性,防止梯度爆炸 | orthogonal_ |
相关文章:
基于CNN的FashionMNIST数据集识别5——GoogleNet模型
源码 import torch from torch import nn from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):super().__init__()self.ReLu nn.ReLU()#路径1self.p1_1 nn.Conv2d(in_channelsin_channels, out_channelsc1, kern…...
JVM垃圾回收笔记01-垃圾回收算法
文章目录 前言1. 如何判断对象可以回收1.1 引用计数法1.2 可达性分析算法查看根对象哪些对象可以作为 GC Root ?对象可以被回收,就代表一定会被回收吗? 1.3 引用类型1.强引用(StrongReference)2.软引用(SoftReference…...
【初探数据结构】树与二叉树
💬 欢迎讨论:在阅读过程中有任何疑问,欢迎在评论区留言,我们一起交流学习! 👍 点赞、收藏与分享:如果你觉得这篇文章对你有帮助,记得点赞、收藏,并分享给更多对数据结构感…...
numpy学习笔记10:arr *= 2向量化操作性能优化
numpy学习笔记10:arr * 2向量化操作性能优化 在 NumPy 中,直接对整个数组进行向量化操作(如 arr * 2)的效率远高于显式循环(如 for i in range(len(arr)): arr[i] * 2)。以下是详细的解释: 1. …...
蓝桥杯备考:二分答案之路标设置
最大距离,找最小空旷指数值,我们是很容易想到用二分的,我们再看看这个答案有没有二段性 是有这么个二段性的,我们只要二分就行了,但是二分的check函数是有点不好想的,我们枚举空旷值的时候,为了…...
回调方法传参汇总
文章目录 0. 引入问题1. 父子组件传值1.1 父传子:props1.2 子传父:$emit1.3 双向绑定:v-model 2. 多个参数传递3. 父组件监听方法传递其他值3.1 $event3.2 箭头方法 4. 子组件传递多个参数,父组件传递本地参数4.1 箭头函数 … 扩…...
在 Linux下使用 Python 3.11 和 FastAPI 搭建带免费证书的 HTTPS 服务器
在当今数字化时代,保障网站数据传输的安全性至关重要。HTTPS 协议通过使用 SSL/TLS 加密技术,能够有效防止数据在传输过程中被窃取或篡改。本教程将详细介绍如何在 Ubuntu 22.04 系统上,使用 Python 3.11 和 FastAPI 框架搭建一个带有免费 SS…...
XSS基础靶场练习
目录 1. 准备靶场 2. PASS 1. Level 1:无过滤 源码: 2. level2:转HTML实体 htmlspecialchars简介: 源码 PASS 3. level3:转HTML深入 源码: PASS 4. level4:过滤<> 源码: PASS: 5. level5:过滤on 源码…...
Redis核心机制(一)
目录 Redis的特性 1.速度快 2.以键值对方式进行存储 3.丰富的功能 4.客户端语言多 5.持久化 6.主从复制 7.高可用和分布式 Redis使用场景 Redis核心机制——持久化 RDB bgsave执行流程 编辑 AOF AOF重写流程 3.混合持久化(RDBAOF) Red…...
QGroupBox取消勾选时不禁用子控件
默认情况下,QGroupBox取消勾选会自动禁用子控件,如下图所示 那么如何实现取消勾选时不禁用子控件呢? 实现很简单,直接上代码了 connect(ui->groupBox, &QGroupBox::toggled, this, [](bool checked){if (checked false){…...
Go语言中package的使用规则《二》
在 Go 语言中,包(Package) 是代码组织和复用的核心单元。以下是其定义、引用规则及使用习惯的详细说明: 一、包的定义规则 目录与包名 一个包对应一个目录(文件夹),目录名通常与包名一致。 包名…...
MyBatis-Plus 自动填充:优雅实现创建/更新时间自动更新!
目录 一、什么是 MyBatis-Plus 自动填充? 🤔二、自动填充的原理 ⚙️三、实际例子:创建时间和更新时间字段自动填充 ⏰四、注意事项 ⚠️五、总结 🎉 🌟我的其他文章也讲解的比较有趣😁,如果喜欢…...
canvas数据标注功能简单实现:矩形、圆形
背景说明 基于UI同学的设计,在市面上找不到刚刚好的数据标注工具,遂决定自行开发。目前需求是实现图片的矩形、圆形标注,并获取标注的坐标信息,使用canvas可以比较方便的实现该功能。 主要功能 选中图形,进行拖动 使…...
Python 魔术方法深度解析:__getattr__ 与 __getattribute__
一、核心概念与差异解析 1. __getattr__ 的定位与特性 触发时机: 当访问对象中 **不存在的属性** 时自动触发,是 Python 属性访问链中的最后一道防线。 核心能力: 动态生成缺失属性实现优雅的错误处理构建链式调用接口(如 R…...
【机器学习】机器学习工程实战-第2章 项目开始前
上一章:第1章 概述 文章目录 2.1 机器学习项目的优先级排序2.1.1 机器学习的影响2.1.2 机器学习的成本 2.2 估计机器学习项目的复杂度2.2.1 未知因素2.2.2 简化问题2.2.3 非线性进展 2.3 确定机器学习项目的目标2.3.1 模型能做什么2.3.2 成功模型的属性 2.4 构建机…...
【UI设计】一些好用的免费图标素材网站
阿里巴巴矢量图标库https://www.iconfont.cn/国内最大的矢量图标库之一,拥有 800 万 图标资源。特色功能包括团队协作、多端适配、定制化编辑等,适合企业级项目、电商设计、中文产品开发等场景。IconParkhttps://iconpark.oceanengine.com/home字节跳动…...
Visual Studio(VS)的 Release 配置中生成程序数据库(PDB)文件
最近工作中的一个测试工具在测试多台设备上使用过程中闪退,存了dump,但因为是release版本,没有pdb,无法根据dump定位代码哪块出了问题,很苦恼,查了下怎么加pdb生成,记录一下。以下是具体的设置步…...
ubuntu 解挂载时提示 “umount: /home/xx/Applications/yy: target is busy.”
问题如题所示,我挂载一个squanfs文件系统到指定目录,当我使用完后,准备解挂载时,提示umount: /home/xx/Applications/yy: target is busy.,具体的如图所示, 这种提示通常是表明这个路径的内容正在被某些进…...
一条不太简单的TEX学习之路
目录 rule raisebox \includegraphics newenviro 、\vspace \stretch \setlength 解释: 总结: 、\linespread newcommand \par 小四 \small simple 、mutiput画网格 解释: 图案解释: xetex pdelatex etc index 报…...
Matplotlib完全指南:数据可视化从入门到实战
目录 引言 一、环境配置与基础概念 1.1 安装Matplotlib 1.2 导入惯例 1.3 两种绘图模式 二、基础图形绘制 2.1 折线图(Line Plot) 2.2 柱状图(Bar Chart) 三、高级图表类型 3.1 散点图(Scatter Plotÿ…...
在大数据开发中ETL是指什么?
hello宝子们...我们是艾斯视觉擅长ui设计和前端数字孪生、大数据、三维建模、三维动画10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩! 在数字经济时代,数据已成为企业最核心的资产。然而,分散在业务系统、日志文件…...
OAuth 2.0认证
文章目录 1. 引言1.1 系列文章说明1.2 OAuth 2.0 的起源与演变1.3 应用场景概览 2. OAuth 2.0 核心概念2.1 角色划分2.2 核心术语解析 3. 四种授权模式详解3.1 授权码模式(Authorization Code Grant)3.1.1 完整流程解析3.1.2 PKCE 扩展(防止授…...
【Linux 下的 bash 无法正常解析, Windows 的 CRLF 换行符问题导致的】
文章目录 报错原因:解决办法:方法一:用 dos2unix 修复方法二:手动转换换行符方法三:VSCode 或其他编辑器手动改 总结 这个错误很常见,原因是你的 wait_for_gpu.sh 脚本 文件格式不对,具体来说…...
Kubernetes的Replica Set和ReplicaController有什么区别
ReplicaSet 和 ReplicationController 是 Kubernetes 中用于管理应用程序副本的两种资源,它们有类似的功能,但 ReplicaSet 是 ReplicationController 的增强版本。 以下是它们的主要区别: 1. 功能的演进 ReplicationController 是 Kubernete…...
WSL 导入完整系统包教程
作者: DWDROME 配置环境: OS: Ubuntu 20.04.6 LTS on Windows 11 x86_64Kernel: 5.15.167.4-microsoft-standard-WSL2ros-noetic 🧭WSL 导入完整系统包教程 ✅ 一、准备导出文件 假设你已有一个 .tar 的完整系统包(如从 WSL 或 L…...
[Lc_2 二叉树dfs] 布尔二叉树的值 | 根节点到叶节点数字之和 | 二叉树剪枝
目录 1.计算布尔二叉树的值 题解 2.求根节点到叶节点数字之和 3. 二叉树剪枝 题解 1.计算布尔二叉树的值 链接:2331. 计算布尔二叉树的值 给你一棵 完整二叉树 的根,这棵树有以下特征: 叶子节点 要么值为 0 要么值为 1 ,其…...
SOFABoot-07-版本查看
前言 大家好,我是老马。 sofastack 其实出来很久了,第一次应该是在 2022 年左右开始关注,但是一直没有深入研究。 最近想学习一下 SOFA 对于生态的设计和思考。 sofaboot 系列 SOFABoot-00-sofaboot 概览 SOFABoot-01-蚂蚁金服开源的 s…...
蓝桥杯 之 第27场月赛总结
文章目录 习题1.抓猪拿国一2.蓝桥字符3.蓝桥大使4.拳头对决 习题 比赛地址 1.抓猪拿国一 十分简单的签到题 print(sum(list(range(17))))2.蓝桥字符 常见的字符匹配的问题,是一个二维dp的问题,转化为对应的动态规划求解 力扣的相似题目 可以关注灵神…...
第十六章:Specialization and Overloading_《C++ Templates》notes
Specialization and Overloading 一、模板特化与重载的核心概念二、代码实战与测试用例三、关键知识点总结四、进阶技巧五、实践建议多选题设计题代码测试说明 一、模板特化与重载的核心概念 函数模板重载 (Function Template Overloading) // 基础模板 template<typename…...
可视化动态表单动态表单界的天花板--Formily(阿里开源)
文章目录 1、Formily表单介绍2、安装依赖2.1、安装内核库2.2、 安装 UI 桥接库2.3、Formily 支持多种 UI 组件生态: 3、表单设计器3.1、核心理念3.2、安装3.3、示例源码 4、场景案例-登录注册4.1、Markup Schema 案例4.2、JSON Schema 案例4.3、纯 JSX 案例 1、Form…...
