当前位置: 首页 > news >正文

现代神经网络(VGG),并用VGG16进行实战CIFAR10分类

专栏:神经网络复现目录


本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet)。
文章部分文字和代码来自《动手学深度学习》

文章目录

  • 使用块的网络(VGG)
  • VGG块
    • 定义
    • 实现
  • VGG16
    • 模型设计
    • 实现
  • 利用VGG16进行CIFAR10分类
    • 数据集
    • 超参数,优化器,损失函数
    • 训练


使用块的网络(VGG)

VGG是一种深度卷积神经网络,由牛津大学视觉几何组(Visual Geometry Group)在2014年提出。它是由多个卷积层和池化层组成的深度神经网络,具有很强的图像分类能力,特别是在图像识别领域,取得了很好的成果。

VGG的特点在于,它使用相对较小的卷积核(3x3),但是通过叠加多个卷积层和池化层,增加了网络的深度,从而达到更好的图像分类性能。VGG网络包含了多个版本,以卷积层数目为标志,如VGG16和VGG19等,其中VGG16和VGG19是最著名的两个版本。

VGG网络的设计非常简单和规整,容易理解和实现,因此也成为了很多深度学习新手的入门模型。

下图为VGG的六个版本,比较实用的是VGG16和VGG19,本文以VGG16为例子进行讲解
在这里插入图片描述

VGG块

定义

VGG块是VGG网络中的一个基本组成单元,由若干个卷积层和池化层组成,通常用于提取输入图像的特征。每个VGG块都由连续的1或2个卷积层,和一个最大池化层组成。其中,卷积层的卷积核大小都是3x3,而池化层的窗口大小通常是2x2。在每个VGG块中,卷积层的输出通道数都相同,可以通过超参数进行控制。

具体来说,假设一个VGG块由k个卷积层和一个池化层组成,输入为xxx,则该块的输出可以表示为:

VGG(x)=Pool(convk(convk−1(⋯conv1(x)))).\text{VGG}(x) = \text{Pool}(\text{conv}k(\text{conv}{k-1}(\cdots\text{conv}_1(x)))).VGG(x)=Pool(convk(convk1(conv1(x)))).

其中,convi(⋅)\text{conv}_i(\cdot)convi()表示第iii个卷积层,Pool(⋅)\text{Pool}(\cdot)Pool()表示池化层。在VGG块中,每个卷积层都会使用ReLU激活函数进行非线性变换,而最大池化层则用于下采样和特征压缩。

在VGG网络中,通常通过叠加多个VGG块来构建网络结构。通过增加VGG块的数量,可以增加网络的深度和宽度,从而提高网络的表达能力和泛化性能。

实现

self.conv1=nn.Sequential(nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)

inplace=True 表示对于输入的张量进行原地操作,即直接对原始的输入张量进行修改,而不是创建一个新的张量。这样做可以节省内存,但会覆盖原始的输入张量,可能会对后续的计算产生影响。因此,当我们需要保留原始的输入张量时,可以将 inplace 参数设置为 False。

VGG16

模型设计

VGG16是一个卷积神经网络模型,包含13个卷积层、5个池化层和3个全连接层,是由牛津大学计算机视觉组(Visual Geometry Group)在2014年提出的模型,具有较好的图像识别表现。

VGG16模型的架构如下:

输入层:输入图像的大小为224x224x3。

VGG块1

卷积层1:使用64个3x3大小的卷积核进行卷积操作,得到64张大小为224x224的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层2:使用64个3x3大小的卷积核进行卷积操作,得到64张大小为224x224的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层1:使用2x2的最大池化操作,将64张大小为224x224的特征图缩小为64张大小为112x112的特征图。采用SAME填充,步长为2。

VGG块2

卷积层3:使用128个3x3大小的卷积核进行卷积操作,得到128张大小为112x112的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层4:使用128个3x3大小的卷积核进行卷积操作,得到128张大小为112x112的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层2:使用2x2的最大池化操作,将128张大小为112x112的特征图缩小为128张大小为56x56的特征图。采用SAME填充,步长为2。

VGG块3

卷积层5:使用256个3x3大小的卷积核进行卷积操作,得到256张大小为56x56的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层6:使用256个3x3大小的卷积核进行卷积操作,得到256张大小为56x56的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层7:使用256个3x3大小的卷积核进行卷积操作,得到256张大小为56x56的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层3:使用2x2的最大池化操作,将256张大小为56x56的特征图缩小为256张大小为28x28的特征图。采用SAME填充,步长为2。

VGG块4

卷积层8-10:使用512个3x3大小的卷积核进行卷积操作,得到512张大小为28x28的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层4:使用2x2的最大池化操作,将512张大小为28x28的特征图缩小为512张大小为14x14的特征图。采用SAME填充,步长为2。

VGG块5

卷积层11-13:使用512个3x3大小的卷积核进行卷积操作,得到512张大小为14x14的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层5:使用2x2的最大池化操作,将512张大小为14x14的特征图缩小为512张大小为7x7的特征图。采用SAME填充,步长为2。

全连接层

3个全连接层,第1、2个都有4096个输出通道,第3个全连接层则有1000个输出通道。

实现

class VGG16(nn.Module):def __init__(self):super(VGG16,self).__init__()self.conv1=nn.Sequential(nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv2=nn.Sequential(nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv3=nn.Sequential(nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv4=nn.Sequential(nn.Conv2d(in_channels=256,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv5=nn.Sequential(nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.feature=nn.Sequential(self.conv1,self.conv2,self.conv3,self.conv4,self.conv5,)self.flatten=nn.Flatten()self.fc=nn.Sequential(nn.Linear(512*7*7,4096),nn.ReLU(inplace=True),nn.Dropout(0.4),nn.Linear(4096,4096),nn.ReLU(inplace=True),nn.Dropout(0.4),nn.Linear(4096,1000),#nn.Softmax(10))def forward(self,x):x=self.feature(x)# x=self.flatten(x)x = x.view(x.size(0), -1)x=self.fc(x)return x

查看结构

vgg = VGG16()
print(vgg)
x=torch.rand(1,3,224,224)
y=vgg(x)
print(y.shape)

利用VGG16进行CIFAR10分类

import torch.nn as nn
import torch
import torchvisionif(torch.cuda.is_available()):device = torch.device("cuda")print("使用GPU训练中:{}".format(torch.cuda.get_device_name()))
else:device = torch.device("cpu")print("使用CPU训练")

数据集

# transform的创建(compose方法)
from torchvision import transforms
def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4def load_data_cifar10(batch_size, resize=None):  #@savetrans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=True)return (torch.utils.data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),torch.utils.data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
batch_size=4
train_iter, test_iter = load_data_cifar10(batch_size,resize=224)

超参数,优化器,损失函数

from torch import optim
net=VGG16()
lr=0.001
optimizer=optim.SGD(net.parameters(),lr=lr,momentum=0.9)
loss=nn.CrossEntropyLoss()
epochs=10

训练

def train(net,train_iter,test_iter,num_epochs, lr, device):def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)for epoch in range(num_epochs):net.train()train_step = 0for i, (X, y) in enumerate(train_iter):optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l=loss(y_hat,y)l.backward()optimizer.step()train_step+=1if(train_step%50==0):#每训练一百组输出一次损失print("第{}轮的第{}次训练的loss:{}".format((epoch+1),train_step,l.item()))

相关文章:

现代神经网络(VGG),并用VGG16进行实战CIFAR10分类

专栏:神经网络复现目录 本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络…...

Java代码弱点与修复之——Dereference null return value(间接引用空返回值)

弱点描述 Dereference null return value,间接引用空返回值。是Coverity Scan静态代码分析工具中的一个警告,表示代码中有对可能为空(null)的方法或函数返回值进行间接引用(Dereference)操作。 该类型的漏洞可能会导致 NullPointerException 异常,并且会导致程序崩溃或…...

【冲刺蓝桥杯的最后30天】day3

大家好😃,我是想要慢慢变得优秀的向阳🌞同学👨‍💻,断更了整整一年,又开始恢复CSDN更新,从今天开始更新备战蓝桥30天系列,一共30天,如果对你有帮助或者正在备…...

光伏发电嵌入式ARM工控机

随着智慧电力技术的不断发展和普及,越来越多的电力设备和系统需要采用先进的控制和监测技术来实现自动化管理和优化运行。其中,嵌入式 ARM 控制器技术在智慧电力领域中得到了广泛应用。同时,导轨安装也是该技术的重要应用场景之一。 导轨安装…...

推荐 7 个 Vue.js 插件,也许你的项目用的上(五)

当我们可以通过使用库轻松实现相同的结果时,为什么还要编写自定义功能?开发人员最好的朋友和救星就是这些第三方库。我相信一个好的项目会利用一些可用的最佳库。Vue.js 是创建用户界面的最佳 JavaScript 框架之一。这篇文章是关于 Vue.js 的优秀库系列的…...

1.1基于知识图谱的项目实战:优酷搜索泛查询意图优化

NLU的技术实现主要分为在线识别和离线数据挖掘两块。 1.在线识别 NLU的在线识别技术栈如下图所示,共由下述2个部分组成: 第一个部分是Slot Filling(成分分析),负责对query进行实体识别和槽位抽取;第二部分Inention Detection(意图识别),根据提取的槽位进行意图的判定(目…...

[java Spring JdbcTemplate配合mysql实现数据批量删除

之前的文章 java Spring JdbcTemplate配合mysql实现数据批量添加和文章java Spring JdbcTemplate配合mysql实现数据批量修改 先后讲解了 mysql数据库的批量添加和批量删除操作 会了这两个操作之后 批量删除就不要太简单 我们看到数据库 这里 我们用的是mysql工具 这里 我们有…...

uos 20 统信 fprintd 记录

uos 20 统信 fprintd 记录 sudo busctl deepin-authenticate.service /usr/lib/systemd/system/deepin-authenticate.service [Unit] DescriptionDeepin Authentication[Service] Typedbus BusNamecom.deepin.daemon.Authenticate ExecStart/usr/lib/deepin-authenticate/d…...

vue移动端h5,文本溢出显示省略号,且展示‘更多’按钮

问题: 元素宽度100%,宽度会随着浏览器缩放而变化。元素内文本超过4行时显示省略号,同时展示‘更多’按钮,点击更多按钮展示全部文本。如下图所示 超出四行显示省略号(…)的代码 .content{overflow:hidden;text-overflow: elli…...

php宝塔搭建部署实战兰空图床程序网站PHP源码

大家好啊,我是测评君,欢迎来到web测评。 本期给大家带来一套Lsky Pro兰空图床程序网站PHP的源码。感兴趣的朋友可以自行下载学习。 技术架构 PHP8.0 nginx mysql5.7 JS CSS HTMLcnetos7以上 宝塔面板 文字搭建教程 下载源码,宝塔添加…...

软件测试面试:拿到一个产品(版本)如何开展测试?

产品提测后,如何开展测试? 我们都了解软件测试的执行流程,......提测-冒烟测试-详细测试-提交缺陷报告-回归测试,但软件测试并不总是线性过程,它甚至可能是螺旋结构,不断地试错,不断地迭代&…...

【Opencv项目实战】图像的像素值反转

文章目录一、项目思路二、算法详解2.1、获取图像信息2.2、新建模板2.3、图像通道顺序三、项目实战:彩图的像素值反转(方法一)四、项目实战:彩图的像素值反转(方法二)五、项目实战:彩图转换为灰图…...

Swagger生成接口在线文档

OpenAPI规范(OpenAPI Specification 简称OAS)是Linux基金会的一个项目,试图通过定义一种用来描述API格式或API定义的语言,来规范RESTful服务开发过程,目前版本是V3.0,并且已经发布并开源在github上。&#…...

104.第十九章 MySQL数据库 -- MySQL主从复制、 级联复制和双主复制(十四)

6.1.2 实现主从复制配置 参考官网 https://dev.mysql.com/doc/refman/8.0/en/replication-configuration.html https://dev.mysql.com/doc/refman/5.7/en/replication-configuration.html https://dev.mysql.com/doc/refman/5.5/en/replication-configuration.html https://m…...

第一次使用Python for Qt中的问题

在创建带有form的python for qt的时候,使用的库是pySide6,而不是pyqt。 因此,需要安装pyside6。 Running "/usr/bin/python3 -m pip install PySide6 --user" to install PySide6. ERROR: Could not find a version that satisfi…...

.Net Core WebApi 在Linux系统Deepin上部署Nginx并使用(一)

前言: Deepin最初是基于Ubuntu的发行版 2015年脱离Ubuntu开发,开始基于Ubuntu上游Debian操作系统 2019年脱离Debian,直接基于Linux开发,真正属于自己的上游Linux系统发行版 2022年8月,新版《Deepin V23》我下载开始了我…...

Java——打开轮盘锁

题目链接 leetcode在线oj题——打开轮盘锁 题目描述 你有一个带有四个圆形拨轮的转盘锁。每个拨轮都有10个数字: ‘0’, ‘1’, ‘2’, ‘3’, ‘4’, ‘5’, ‘6’, ‘7’, ‘8’, ‘9’ 。每个拨轮可以自由旋转:例如把 ‘9’ 变为 ‘0’&#xff0…...

JavaScript(2)

一、事件 HTML事件是发生在hTML元素上的“事情”。比如&#xff1a;按钮被点击、鼠标移动到元素上等… 事件绑定 方式一&#xff1a;通过HTML标签中的事件属性进行绑定 <input type"button" value"点我" onclick"on()"><script>fun…...

FFMPEG 安装教程windowslinux(CentOS版)

ps: 从笔记中迁移至blog 版本概述 Windows 基于win10 Linux 基于CentOS 7.6 一.Windows安装笔记 1.下载安装 https://ffmpeg.org/download.html 2 解压缩&#xff0c;拷贝到需要目录&#xff0c;重命名 3 追加环境变量 echo %PATH%setx /m PATH "%PATH%;F:\dev_tools\…...

【虹科案例】虹科任意波形发生器在量子计算中的应用

虹科AWG在量子计算中的应用精度在研究中始终很重要&#xff0c;很少有研究领域需要比量子研究更高的精度。奥地利因斯布鲁克大学的量子光学和量子信息研究所需要一个任意波形发生器&#xff08;AWG&#xff09;来为他们的研究生成各种各样的信号。01无线电频率第一个应用是在射…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天&#xff0c;再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至&#xff0c;这不仅是开发者的盛宴&#xff0c;更是全球数亿苹果用户翘首以盼的科技春晚。今年&#xff0c;苹果依旧为我们带来了全家桶式的系统更新&#xff0c;包括 iOS 26、iPadOS 26…...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 &#xff08;一&#xff09;实时滤波与参数调整 基础滤波操作 60Hz 工频滤波&#xff1a;勾选界面右侧 “60Hz” 复选框&#xff0c;可有效抑制电网干扰&#xff08;适用于北美地区&#xff0c;欧洲用户可调整为 50Hz&#xff09;。 平滑处理&…...

VTK如何让部分单位不可见

最近遇到一个需求&#xff0c;需要让一个vtkDataSet中的部分单元不可见&#xff0c;查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行&#xff0c;是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示&#xff0c;主要是最后一个参数&#xff0c;透明度…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路&#xff1a; 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑&#xff1a;async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

【从零学习JVM|第三篇】类的生命周期(高频面试题)

前言&#xff1a; 在Java编程中&#xff0c;类的生命周期是指类从被加载到内存中开始&#xff0c;到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期&#xff0c;让读者对此有深刻印象。 目录 ​…...

GitFlow 工作模式(详解)

今天再学项目的过程中遇到使用gitflow模式管理代码&#xff0c;因此进行学习并且发布关于gitflow的一些思考 Git与GitFlow模式 我们在写代码的时候通常会进行网上保存&#xff0c;无论是github还是gittee&#xff0c;都是一种基于git去保存代码的形式&#xff0c;这样保存代码…...

Linux部署私有文件管理系统MinIO

最近需要用到一个文件管理服务&#xff0c;但是又不想花钱&#xff0c;所以就想着自己搭建一个&#xff0c;刚好我们用的一个开源框架已经集成了MinIO&#xff0c;所以就选了这个 我这边对文件服务性能要求不是太高&#xff0c;单机版就可以 安装非常简单&#xff0c;几个命令就…...

人工智能 - 在Dify、Coze、n8n、FastGPT和RAGFlow之间做出技术选型

在Dify、Coze、n8n、FastGPT和RAGFlow之间做出技术选型。这些平台各有侧重&#xff0c;适用场景差异显著。下面我将从核心功能定位、典型应用场景、真实体验痛点、选型决策关键点进行拆解&#xff0c;并提供具体场景下的推荐方案。 一、核心功能定位速览 平台核心定位技术栈亮…...

二维FDTD算法仿真

二维FDTD算法仿真&#xff0c;并带完全匹配层&#xff0c;输入波形为高斯波、平面波 FDTD_二维/FDTD.zip , 6075 FDTD_二维/FDTD_31.m , 1029 FDTD_二维/FDTD_32.m , 2806 FDTD_二维/FDTD_33.m , 3782 FDTD_二维/FDTD_34.m , 4182 FDTD_二维/FDTD_35.m , 4793...

React从基础入门到高级实战:React 实战项目 - 项目五:微前端与模块化架构

React 实战项目&#xff1a;微前端与模块化架构 欢迎来到 React 开发教程专栏 的第 30 篇&#xff01;在前 29 篇文章中&#xff0c;我们从 React 的基础概念逐步深入到高级技巧&#xff0c;涵盖了组件设计、状态管理、路由配置、性能优化和企业级应用等核心内容。这一次&…...