VGG卷积神经网络实现Cifar10图片分类-Pytorch实战
前言
当涉足深度学习,选择合适的框架是至关重要的一步。PyTorch作为三大主流框架之一,以其简单易用的特点,成为初学者们的首选。相比其他框架,PyTorch更像是一门易学的编程语言,让我们专注于实现项目的功能,而无需深陷于底层原理的细节。
就像我们使用汽车时,更重要的是了解如何驾驭,而不是花费过多时间研究轮子是如何制造的。我将以一系列专门针对深度学习框架的文章,逐步深入理论知识和实践操作。但这需要在对深度学习有一定了解后才能进行,现阶段我们的重点是学会如何灵活使用PyTorch工具。深度学习涉及大量数学理论和计算原理,对于初学者来说可能会有些繁琐。然而,只有通过实际操作,我们才能真正理解所写代码在神经网络中的作用。我将努力将知识简化,转化为我们熟悉的内容,让大家能够理解和熟练使用神经网络框架。
如果你发现深度学习看似难以掌握,我将尽力简化知识,将其转化为我们更容易理解的内容。我会确保你能够理解知识并顺利运用到实践中。在后期,我将发布一系列专门解析深度学习框架的文章,但在开始学习之前,我们需要对深度学习的理论知识和实践操作有一定的熟悉度。
作为一个从事数据建模五年的专业人士,我参与了许多数学建模项目,了解各种模型的原理、建模流程和题目分析方法。我希望通过这个专栏让你能够快速掌握各类数学模型、机器学习和深度学习知识,并掌握相应的代码实现。每篇文章都包含实际项目和可运行的代码。我会紧跟各类数模比赛,将最新的思路和代码分享给你,保证你能够高效地学习这些知识。
博主非常期待与你一同探索这个精心打造的专栏,里面充满了丰富的实战项目和可运行的代码,希望你不要错过:专栏链接
一、VGGNet概述
VGGNet(Visual Geometry Group Network)是由牛津大学视觉几何组(Visual Geometry Group)提出的深度卷积神经网络架构,它在2014年的ImageNet图像分类挑战中取得了优异的成绩。VGGNet之所以著名,一方面是因为其简洁而高效的网络结构,另一方面是因为它通过深度堆叠的方式展示了深度卷积神经网络的强大能力。
VGGNet探索了卷积神经网络的深度与其性能之间的关系,成功地构筑了16~19层深的卷积神经网络,证明了增加网络的深度能够在一定程度上影响网络最终的性能,使错误率大幅下降,同时拓展性又很强,迁移到其它图片数据上的泛化性也非常好。到目前为止,VGG仍然被用来提取图像特征。
VGGNet包含两种结构,分别为16层和19层。VGGNet结构中,所有卷积层的kernel都只有3*3。VGGNet中连续使用3组3*3kernel的原因是它与使用1个7*7kernel产生的效果相同,然而更深的网络结构还会学习到更复杂的非线性关系,从而使得模型的效果更好。该操作带来的另一个好处是参数数量的减少,因为对于一个包含了C个kernel的卷积层来说,原来的参数个数为7*7*C,而新的参数个数为3*(3*3*C)。
下图给出了VGG16的具体结构示意图:


根据VGG16进行具体分析,包含:
- 13个卷积层(Convolutional Layer)
- 3个全连接层(Fully connected Layer)
- 5个池化层(Pool layer)
其中,卷积层和全连接层具有权重系数,因此也被称为权重层,总数目为13+3=16,这即是VGG16中16的来源。

内存消耗主要来自早期的卷积,而参数量的激增则发生在后期的全连接层。由于采用了大量的卷积层,导致VGGNet的参数数量较大,训练和推理过程需要更多的计算资源。而且参数量较大,需要更多的数据来避免过拟合问题。
二、PyTorch网络搭建
我们参考上述网络结构,利用pytorch进行网络搭建,首先我们可以先搭建输出层,根据我上述提供的每一层具体的parameters搭建即可:
def __init__(self, num_classes=1000):super(VGG,self).__init()__self.features = self._make_layers()self.classifier = nn.Sequential(nn.Linear(512*7*7,4096),nn.ReLU(True),nn.Dropout(),nn,Linear(4096,4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096,num_classes))
接下来我们来搭建卷积和全连接层,可以利用循环帮助我们省去每个步骤繁琐的写层:
def _make_layers(self):layers = []in_clannels = 3cfg =[64,64,'M',128,128,'M',256,256,256,'M',512,512,512,'M']for v in cfg:if v =='M':layers +=[nn.MaxPool2d(kernel_size=2,stride=2)]else:conv2d = nn.Conv2d(in_channels,v,kernel_size)layers +=[conv2d,nn.ReLU(inplace=True)]in_channels = vreturn nn.Sequential(*layers)
然后写入每个神经网络必备的传播:
def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x
总体网络结构为:
VGGNet((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)
定义损失函数和优化方法:
#定义损失函数和优化方式
criterion = nn.CrossEntropyLoss() #定义损失函数:交叉熵
optimizer = torch.optim.SGD(net.parameters(),lr=0.001,momentum=0.9)#定义优化方法,随机梯度下降
进行卷积网络训练,这里需要微调一下原来vgg的模型,Cifar10的数据集有10个类别而且图片转换的矩阵需要加入自适应池化层,要一些改进:
import torch.nn as nn# 设置随机种子以保证实验的可复现性
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = Falseclass VGGNet(nn.Module):def __init__(self, num_classes=10):super(VGGNet, self).__init__()self.features = self._make_layers()self.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512*7*7,4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096,4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096,num_classes))def _make_layers(self):layers = []in_channels = 3cfg =[64,64,'M',128,128,'M',256,256,256,'M',512,512,512,'M',512, 512, 512, 'M']for v in cfg:if v =='M':layers +=[nn.MaxPool2d(kernel_size=2,stride=2)]else:conv2d = nn.Conv2d(in_channels,v,kernel_size=3, padding=1)layers +=[conv2d,nn.ReLU(inplace=True)]in_channels = vreturn nn.Sequential(*layers)def forward(self, x):x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x
需要注意到是我们需要初始化网络的权重,不更新权重的话10000张图片和实际不借助算法猜测图片的概率是一致的,我们先不初始化网络的权重进行训练:
for epoch in range(1):train_loss=0.0for batch_idx,data in enumerate(train_loader,0):#初始化inputs,labels = data #获取数据inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad() #梯度置0#优化过程outputs = net(inputs) #将数据输入到网络,得到第一轮网络前向传播的预测结果outputsloss = criterion(outputs,labels) #预测结果outputs和labels通过之前定义的交叉熵计算损失loss.backward() #误差反向传播optimizer.step() #随机梯度下降优化权重#查看网络训练状态train_loss += loss.item()if batch_idx % 2000 == 1 :print(batch_idx)print('[%d,%5d] loss: %.3f' % (epoch + 1,batch_idx + 1,train_loss / 2000))train_loss = 0.0print('Saving epoch %d model ...'%(epoch + 1))state = {'net':net.state_dict(),'epoch':epoch+1,}if not os.path.isdir('checkpoint'):os.mkdir('checkpoint')#torch.save(state,'./checkpoint/cifar10_epoch_%d.ckpt'%(epoch+1))print('Finished Training')
然后我们去计算整个测试集的预测效果:
#批量计算整个测试集的预测效果
correct= 0
total = 0
with torch.no_grad():for data in test_loader:images,labels = dataimages = images.to(device)labels = labels.to(device)outputs = net(images)_,predicted = torch.max(outputs.data,1)total += labels.size(0)correct += (predicted == labels ).sum().item() #当标记的label种类和预测的种类一致时认为正确,并计数print('Accurary of the network on the 10000 test images : %d %%'%(100*correct/total))
很明显和实际猜测的概率是一模一样的,总共十个类别1/10很正常:
Accurary of the network on the 10000 test images : 10 %
我们需要先进行初始化网络权重在训练:
def initialize_weights(module):if isinstance(module, nn.Conv2d):nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')if module.bias is not None:nn.init.constant_(module.bias, 0)elif isinstance(module, nn.Linear):nn.init.normal_(module.weight, 0, 0.01)nn.init.constant_(module.bias, 0)
之后在训练预测一版:
Accurary of the network on the 10000 test images : 47 %
效果就十分明显了。
点关注,防走丢,如有纰漏之处,请留言指教,非常感谢
以上就是本期全部内容。我是fanstuck ,有问题大家随时留言讨论 ,我们下期见。
相关文章:
VGG卷积神经网络实现Cifar10图片分类-Pytorch实战
前言 当涉足深度学习,选择合适的框架是至关重要的一步。PyTorch作为三大主流框架之一,以其简单易用的特点,成为初学者们的首选。相比其他框架,PyTorch更像是一门易学的编程语言,让我们专注于实现项目的功能࿰…...
CentOS 7文件系统中的软链接和硬链接
软链接(Symbolic Link) 软链接,也称为符号链接,是一个指向另一个文件或目录的特殊类型的文件。它是一个指向目标文件的符号,就像快捷方式一样。软链接的创建和使用非常灵活,适用于各种情况。 创建软链接 …...
【AI】深度学习——前馈神经网络——全连接前馈神经网络
文章目录 1.1 全连接前馈神经网络1.1.1 符号说明超参数参数活性值 1.1.2 信息传播公式通用近似定理 1.1.3 神经网络与机器学习结合二分类问题多分类问题 1.1.4 参数学习矩阵求导链式法则更为高效的参数学习反向传播算法目标计算 ∂ z ( l ) ∂ w i j ( l ) \frac{\partial z^{…...
超简单的视频截取方法,迅速提取所需片段!
“视频可以截取吗?用相机拍摄了一段视频,但是中途相机发生了故障,录进去了很多不需要的片段,现在想截取一部分视频出来,但是不知道方法,想问问广大的网友,知不知道视频截取的方法。” 无论是工…...
ArcGIS/GeoScene脚本:基于粒子群优化的支持向量机回归模型
参数输入 1.样本数据必须包含需要回归的字段 2.回归字段是数值类型 3.影响因子是栅格数据,可添加多个 4.随机种子可以确保每次运行的训练集和测试集一致 5.训练集占比为0-1之间的小数 6.迭代次数:迭代次数越高精度越高,但是运行时间越长…...
vue3组件的通信方式
一、vue3组件通信方式 通信仓库地址:vue3_communication: 当前仓库为贾成豪老师使用组件通信案例 不管是vue2还是vue3,组件通信方式很重要,不管是项目还是面试都是经常用到的知识点。 比如:vue2组件通信方式 props:可以实现父子组件、子父组件、甚至兄弟组件通信 自定义事件:可…...
Qt QPair
QPair 文章目录 QPair 摘要QPairQPair 特点代码示例QPair 与 QMap 区别 关键字: Qt、 QPair、 QMap、 键值、 容器 摘要 今天在观摩小伙伴撸代码的时候,突然听到了QPair自己使用Qt开发这么就,竟然都不知道,所以趁没有被人发…...
K8S云计算系列-(3)
K8S Kubeadm案例实战 Kubeadm 是一个K8S部署工具,它提供了kubeadm init 以及 kubeadm join 这两个命令来快速创建kubernetes集群。 Kubeadm 通过执行必要的操作来启动和运行一个最小可用的集群。它故意被设计为只关心启动集群,而不是之前的节点准备工作…...
ardupilot罗盘数据计算航向
目录 文章目录 目录摘要1.数据特点2.数据结论1.结论2.结论摘要 本节主要记录ardupilot 根据罗盘数据计算航向的过程。 如果知道了一组罗盘数据,我们可以粗略估计航向:主要后面我们所说的X和Y都是表示的飞机里面的坐标系,也就是X前Y右边,如果按照罗盘坐标系Y实际在左边。 我…...
第六章:最新版零基础学习 PYTHON 教程—Python 正则表达式(第一节 - Python 正则表达式)
在本教程中,您将了解RegEx并了解各种正则表达式。 常用表达为什么使用正则表达式基本正则表达式更多正则表达式编译的正则表达式 目录 元字符 为什么是正则表达式?...
docker安装Jenkins完整教程
1.docker拉取 Jenkins镜像并启动容器 新版本的Jenkins依赖于JDK11 我们选择docker中jdk11版本的镜像 # 拉取镜像 docker pull jenkins/jenkins:2.346.3-2-lts-jdk11 2.宿主机上创建文件夹 # 创建Jenkins目录文件夹 mkdir -p /data/jenkins_home # 设置权限 chmod 777 -R /dat…...
[CISCN 2019初赛]Love Math - RCE(异或绕过)
[CISCN 2019初赛]Love Math 1 解题流程1.1 分析1.2 解题题目代码: <?php //听说你很喜欢数学,不知道你是否爱它胜过爱flag if(!isset($_GET[c]))...
C++ 使用getline()从文件中读取一行字符串
我们知道,getline() 方法定义在 istream 类中,而 fstream 和 ifstream 类继承自 istream 类,因此 fstream 和 ifstream 的类对象可以调用 getline() 成员方法。 当文件流对象调用 getline() 方法时,该方法的功能就变成了从指定文件中读取一行字符串。 该方法有以下 2 种语…...
JS进阶-原型
原型 原型就是一个对象,也称为原型对象 构造函数通过原型分配的函数是所有对象所共享的 JavaScript规定,每一个构造函数都有一个prototype属性,指向另一个对象,所以我们也称为原型对象 这个对象可以挂载函数,对象实…...
虹科方案 | 汽车CAN/LIN总线数据采集解决方案
全文导读:现代汽车配备了复杂的电子系统,CAN和LIN总线已成为这些系统之间实现通信的标准协议,为了开发和优化汽车的电子功能,汽车制造商和工程师需要可靠的数据采集解决方案。基于PCAN和PLIN设备,虹科提供了一种高效、…...
HTML5+CSSDAY4综合案例一--热词
样式展示图: 代码如下: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>热词…...
【源码】hamcrest 源码阅读 泛型 extends 和迭代器模式
文章目录 前言1. 泛型参数和自定义迭代器1.1 使用场景1.2 实现 2. 值得一提 前言 官方文档 Hamcrest Tutorial 上篇文章 Hamcrest 源码阅读及空对象模式、模板方法模式的应用 本篇文章 迭代器模式 1. 泛型参数和自定义迭代器 hamcrest 作为一个matcher库,把某个…...
IntelliJ IDEA 2023.1 版本可以安装了
Maven 的导入时间更加快了。 收到的有邮件提醒安装。 安装后的版本,其实就是升级下,并没有什么主要改变。 IntelliJ IDEA 2023.1 版本可以安装了 - 软件技术 - OSSEZMaven 的导入时间更加快了。 收到的有邮件提醒安装。 安装后的版本,其实就是…...
安全论坛和外包平台汇总
文章目录 一. 网络安全论坛汇总二. 外包平台汇总1. 国内:2. 国外 一. 网络安全论坛汇总 安全焦点BugTraq:http://www.fuzzysecurity.com/Exploit-DB:https://www.exploit-db.com/hackone:https://www.hackerone.com/FreeBuf&…...
9-2-Dataset创建-import调用
文章目录 utils_dataset.pymain-调用utils_dateset.pyutils_dataset.py 1默认:没有改变尺寸,数据集中的图像可以是任意形状尺寸。dataloader中必须令batch_size=1 transforms.Resize((宽,高))(image) 和 batch_size=1 必须用其一 原因:当batch_size>1时,每个batch的数…...
UE5 学习系列(二)用户操作界面及介绍
这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...
OpenLayers 可视化之热力图
注:当前使用的是 ol 5.3.0 版本,天地图使用的key请到天地图官网申请,并替换为自己的key 热力图(Heatmap)又叫热点图,是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...
(十)学生端搭建
本次旨在将之前的已完成的部分功能进行拼装到学生端,同时完善学生端的构建。本次工作主要包括: 1.学生端整体界面布局 2.模拟考场与部分个人画像流程的串联 3.整体学生端逻辑 一、学生端 在主界面可以选择自己的用户角色 选择学生则进入学生登录界面…...
AI Agent与Agentic AI:原理、应用、挑战与未来展望
文章目录 一、引言二、AI Agent与Agentic AI的兴起2.1 技术契机与生态成熟2.2 Agent的定义与特征2.3 Agent的发展历程 三、AI Agent的核心技术栈解密3.1 感知模块代码示例:使用Python和OpenCV进行图像识别 3.2 认知与决策模块代码示例:使用OpenAI GPT-3进…...
SCAU期末笔记 - 数据分析与数据挖掘题库解析
这门怎么题库答案不全啊日 来简单学一下子来 一、选择题(可多选) 将原始数据进行集成、变换、维度规约、数值规约是在以下哪个步骤的任务?(C) A. 频繁模式挖掘 B.分类和预测 C.数据预处理 D.数据流挖掘 A. 频繁模式挖掘:专注于发现数据中…...
关于nvm与node.js
1 安装nvm 安装过程中手动修改 nvm的安装路径, 以及修改 通过nvm安装node后正在使用的node的存放目录【这句话可能难以理解,但接着往下看你就了然了】 2 修改nvm中settings.txt文件配置 nvm安装成功后,通常在该文件中会出现以下配置&…...
蓝桥杯 2024 15届国赛 A组 儿童节快乐
P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡,轻快的音乐在耳边持续回荡,小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下,六一来了。 今天是六一儿童节,小蓝老师为了让大家在节…...
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的“no matching...“系列算法协商失败问题
【SSH疑难排查】轻松解决新版OpenSSH连接旧服务器的"no matching..."系列算法协商失败问题 摘要: 近期,在使用较新版本的OpenSSH客户端连接老旧SSH服务器时,会遇到 "no matching key exchange method found", "n…...
基于Springboot+Vue的办公管理系统
角色: 管理员、员工 技术: 后端: SpringBoot, Vue2, MySQL, Mybatis-Plus 前端: Vue2, Element-UI, Axios, Echarts, Vue-Router 核心功能: 该办公管理系统是一个综合性的企业内部管理平台,旨在提升企业运营效率和员工管理水…...
苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会
在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...
