当前位置: 首页 > news >正文

现代神经网络(VGG),并用VGG16进行实战CIFAR10分类

专栏:神经网络复现目录


本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet)。
文章部分文字和代码来自《动手学深度学习》

文章目录

  • 使用块的网络(VGG)
  • VGG块
    • 定义
    • 实现
  • VGG16
    • 模型设计
    • 实现
  • 利用VGG16进行CIFAR10分类
    • 数据集
    • 超参数,优化器,损失函数
    • 训练


使用块的网络(VGG)

VGG是一种深度卷积神经网络,由牛津大学视觉几何组(Visual Geometry Group)在2014年提出。它是由多个卷积层和池化层组成的深度神经网络,具有很强的图像分类能力,特别是在图像识别领域,取得了很好的成果。

VGG的特点在于,它使用相对较小的卷积核(3x3),但是通过叠加多个卷积层和池化层,增加了网络的深度,从而达到更好的图像分类性能。VGG网络包含了多个版本,以卷积层数目为标志,如VGG16和VGG19等,其中VGG16和VGG19是最著名的两个版本。

VGG网络的设计非常简单和规整,容易理解和实现,因此也成为了很多深度学习新手的入门模型。

下图为VGG的六个版本,比较实用的是VGG16和VGG19,本文以VGG16为例子进行讲解
在这里插入图片描述

VGG块

定义

VGG块是VGG网络中的一个基本组成单元,由若干个卷积层和池化层组成,通常用于提取输入图像的特征。每个VGG块都由连续的1或2个卷积层,和一个最大池化层组成。其中,卷积层的卷积核大小都是3x3,而池化层的窗口大小通常是2x2。在每个VGG块中,卷积层的输出通道数都相同,可以通过超参数进行控制。

具体来说,假设一个VGG块由k个卷积层和一个池化层组成,输入为xxx,则该块的输出可以表示为:

VGG(x)=Pool(convk(convk−1(⋯conv1(x)))).\text{VGG}(x) = \text{Pool}(\text{conv}k(\text{conv}{k-1}(\cdots\text{conv}_1(x)))).VGG(x)=Pool(convk(convk1(conv1(x)))).

其中,convi(⋅)\text{conv}_i(\cdot)convi()表示第iii个卷积层,Pool(⋅)\text{Pool}(\cdot)Pool()表示池化层。在VGG块中,每个卷积层都会使用ReLU激活函数进行非线性变换,而最大池化层则用于下采样和特征压缩。

在VGG网络中,通常通过叠加多个VGG块来构建网络结构。通过增加VGG块的数量,可以增加网络的深度和宽度,从而提高网络的表达能力和泛化性能。

实现

self.conv1=nn.Sequential(nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)

inplace=True 表示对于输入的张量进行原地操作,即直接对原始的输入张量进行修改,而不是创建一个新的张量。这样做可以节省内存,但会覆盖原始的输入张量,可能会对后续的计算产生影响。因此,当我们需要保留原始的输入张量时,可以将 inplace 参数设置为 False。

VGG16

模型设计

VGG16是一个卷积神经网络模型,包含13个卷积层、5个池化层和3个全连接层,是由牛津大学计算机视觉组(Visual Geometry Group)在2014年提出的模型,具有较好的图像识别表现。

VGG16模型的架构如下:

输入层:输入图像的大小为224x224x3。

VGG块1

卷积层1:使用64个3x3大小的卷积核进行卷积操作,得到64张大小为224x224的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层2:使用64个3x3大小的卷积核进行卷积操作,得到64张大小为224x224的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层1:使用2x2的最大池化操作,将64张大小为224x224的特征图缩小为64张大小为112x112的特征图。采用SAME填充,步长为2。

VGG块2

卷积层3:使用128个3x3大小的卷积核进行卷积操作,得到128张大小为112x112的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层4:使用128个3x3大小的卷积核进行卷积操作,得到128张大小为112x112的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层2:使用2x2的最大池化操作,将128张大小为112x112的特征图缩小为128张大小为56x56的特征图。采用SAME填充,步长为2。

VGG块3

卷积层5:使用256个3x3大小的卷积核进行卷积操作,得到256张大小为56x56的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层6:使用256个3x3大小的卷积核进行卷积操作,得到256张大小为56x56的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层7:使用256个3x3大小的卷积核进行卷积操作,得到256张大小为56x56的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层3:使用2x2的最大池化操作,将256张大小为56x56的特征图缩小为256张大小为28x28的特征图。采用SAME填充,步长为2。

VGG块4

卷积层8-10:使用512个3x3大小的卷积核进行卷积操作,得到512张大小为28x28的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层4:使用2x2的最大池化操作,将512张大小为28x28的特征图缩小为512张大小为14x14的特征图。采用SAME填充,步长为2。

VGG块5

卷积层11-13:使用512个3x3大小的卷积核进行卷积操作,得到512张大小为14x14的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层5:使用2x2的最大池化操作,将512张大小为14x14的特征图缩小为512张大小为7x7的特征图。采用SAME填充,步长为2。

全连接层

3个全连接层,第1、2个都有4096个输出通道,第3个全连接层则有1000个输出通道。

实现

class VGG16(nn.Module):def __init__(self):super(VGG16,self).__init__()self.conv1=nn.Sequential(nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv2=nn.Sequential(nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv3=nn.Sequential(nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv4=nn.Sequential(nn.Conv2d(in_channels=256,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv5=nn.Sequential(nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.feature=nn.Sequential(self.conv1,self.conv2,self.conv3,self.conv4,self.conv5,)self.flatten=nn.Flatten()self.fc=nn.Sequential(nn.Linear(512*7*7,4096),nn.ReLU(inplace=True),nn.Dropout(0.4),nn.Linear(4096,4096),nn.ReLU(inplace=True),nn.Dropout(0.4),nn.Linear(4096,1000),#nn.Softmax(10))def forward(self,x):x=self.feature(x)# x=self.flatten(x)x = x.view(x.size(0), -1)x=self.fc(x)return x

查看结构

vgg = VGG16()
print(vgg)
x=torch.rand(1,3,224,224)
y=vgg(x)
print(y.shape)

利用VGG16进行CIFAR10分类

import torch.nn as nn
import torch
import torchvisionif(torch.cuda.is_available()):device = torch.device("cuda")print("使用GPU训练中:{}".format(torch.cuda.get_device_name()))
else:device = torch.device("cpu")print("使用CPU训练")

数据集

# transform的创建(compose方法)
from torchvision import transforms
def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4def load_data_cifar10(batch_size, resize=None):  #@savetrans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=True)return (torch.utils.data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),torch.utils.data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
batch_size=4
train_iter, test_iter = load_data_cifar10(batch_size,resize=224)

超参数,优化器,损失函数

from torch import optim
net=VGG16()
lr=0.001
optimizer=optim.SGD(net.parameters(),lr=lr,momentum=0.9)
loss=nn.CrossEntropyLoss()
epochs=10

训练

def train(net,train_iter,test_iter,num_epochs, lr, device):def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)for epoch in range(num_epochs):net.train()train_step = 0for i, (X, y) in enumerate(train_iter):optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l=loss(y_hat,y)l.backward()optimizer.step()train_step+=1if(train_step%50==0):#每训练一百组输出一次损失print("第{}轮的第{}次训练的loss:{}".format((epoch+1),train_step,l.item()))

相关文章:

现代神经网络(VGG),并用VGG16进行实战CIFAR10分类

专栏:神经网络复现目录 本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络…...

Java代码弱点与修复之——Dereference null return value(间接引用空返回值)

弱点描述 Dereference null return value,间接引用空返回值。是Coverity Scan静态代码分析工具中的一个警告,表示代码中有对可能为空(null)的方法或函数返回值进行间接引用(Dereference)操作。 该类型的漏洞可能会导致 NullPointerException 异常,并且会导致程序崩溃或…...

【冲刺蓝桥杯的最后30天】day3

大家好😃,我是想要慢慢变得优秀的向阳🌞同学👨‍💻,断更了整整一年,又开始恢复CSDN更新,从今天开始更新备战蓝桥30天系列,一共30天,如果对你有帮助或者正在备…...

光伏发电嵌入式ARM工控机

随着智慧电力技术的不断发展和普及,越来越多的电力设备和系统需要采用先进的控制和监测技术来实现自动化管理和优化运行。其中,嵌入式 ARM 控制器技术在智慧电力领域中得到了广泛应用。同时,导轨安装也是该技术的重要应用场景之一。 导轨安装…...

推荐 7 个 Vue.js 插件,也许你的项目用的上(五)

当我们可以通过使用库轻松实现相同的结果时,为什么还要编写自定义功能?开发人员最好的朋友和救星就是这些第三方库。我相信一个好的项目会利用一些可用的最佳库。Vue.js 是创建用户界面的最佳 JavaScript 框架之一。这篇文章是关于 Vue.js 的优秀库系列的…...

1.1基于知识图谱的项目实战:优酷搜索泛查询意图优化

NLU的技术实现主要分为在线识别和离线数据挖掘两块。 1.在线识别 NLU的在线识别技术栈如下图所示,共由下述2个部分组成: 第一个部分是Slot Filling(成分分析),负责对query进行实体识别和槽位抽取;第二部分Inention Detection(意图识别),根据提取的槽位进行意图的判定(目…...

[java Spring JdbcTemplate配合mysql实现数据批量删除

之前的文章 java Spring JdbcTemplate配合mysql实现数据批量添加和文章java Spring JdbcTemplate配合mysql实现数据批量修改 先后讲解了 mysql数据库的批量添加和批量删除操作 会了这两个操作之后 批量删除就不要太简单 我们看到数据库 这里 我们用的是mysql工具 这里 我们有…...

uos 20 统信 fprintd 记录

uos 20 统信 fprintd 记录 sudo busctl deepin-authenticate.service /usr/lib/systemd/system/deepin-authenticate.service [Unit] DescriptionDeepin Authentication[Service] Typedbus BusNamecom.deepin.daemon.Authenticate ExecStart/usr/lib/deepin-authenticate/d…...

vue移动端h5,文本溢出显示省略号,且展示‘更多’按钮

问题: 元素宽度100%,宽度会随着浏览器缩放而变化。元素内文本超过4行时显示省略号,同时展示‘更多’按钮,点击更多按钮展示全部文本。如下图所示 超出四行显示省略号(…)的代码 .content{overflow:hidden;text-overflow: elli…...

php宝塔搭建部署实战兰空图床程序网站PHP源码

大家好啊,我是测评君,欢迎来到web测评。 本期给大家带来一套Lsky Pro兰空图床程序网站PHP的源码。感兴趣的朋友可以自行下载学习。 技术架构 PHP8.0 nginx mysql5.7 JS CSS HTMLcnetos7以上 宝塔面板 文字搭建教程 下载源码,宝塔添加…...

软件测试面试:拿到一个产品(版本)如何开展测试?

产品提测后,如何开展测试? 我们都了解软件测试的执行流程,......提测-冒烟测试-详细测试-提交缺陷报告-回归测试,但软件测试并不总是线性过程,它甚至可能是螺旋结构,不断地试错,不断地迭代&…...

【Opencv项目实战】图像的像素值反转

文章目录一、项目思路二、算法详解2.1、获取图像信息2.2、新建模板2.3、图像通道顺序三、项目实战:彩图的像素值反转(方法一)四、项目实战:彩图的像素值反转(方法二)五、项目实战:彩图转换为灰图…...

Swagger生成接口在线文档

OpenAPI规范(OpenAPI Specification 简称OAS)是Linux基金会的一个项目,试图通过定义一种用来描述API格式或API定义的语言,来规范RESTful服务开发过程,目前版本是V3.0,并且已经发布并开源在github上。&#…...

104.第十九章 MySQL数据库 -- MySQL主从复制、 级联复制和双主复制(十四)

6.1.2 实现主从复制配置 参考官网 https://dev.mysql.com/doc/refman/8.0/en/replication-configuration.html https://dev.mysql.com/doc/refman/5.7/en/replication-configuration.html https://dev.mysql.com/doc/refman/5.5/en/replication-configuration.html https://m…...

第一次使用Python for Qt中的问题

在创建带有form的python for qt的时候,使用的库是pySide6,而不是pyqt。 因此,需要安装pyside6。 Running "/usr/bin/python3 -m pip install PySide6 --user" to install PySide6. ERROR: Could not find a version that satisfi…...

.Net Core WebApi 在Linux系统Deepin上部署Nginx并使用(一)

前言: Deepin最初是基于Ubuntu的发行版 2015年脱离Ubuntu开发,开始基于Ubuntu上游Debian操作系统 2019年脱离Debian,直接基于Linux开发,真正属于自己的上游Linux系统发行版 2022年8月,新版《Deepin V23》我下载开始了我…...

Java——打开轮盘锁

题目链接 leetcode在线oj题——打开轮盘锁 题目描述 你有一个带有四个圆形拨轮的转盘锁。每个拨轮都有10个数字: ‘0’, ‘1’, ‘2’, ‘3’, ‘4’, ‘5’, ‘6’, ‘7’, ‘8’, ‘9’ 。每个拨轮可以自由旋转:例如把 ‘9’ 变为 ‘0’&#xff0…...

JavaScript(2)

一、事件 HTML事件是发生在hTML元素上的“事情”。比如&#xff1a;按钮被点击、鼠标移动到元素上等… 事件绑定 方式一&#xff1a;通过HTML标签中的事件属性进行绑定 <input type"button" value"点我" onclick"on()"><script>fun…...

FFMPEG 安装教程windowslinux(CentOS版)

ps: 从笔记中迁移至blog 版本概述 Windows 基于win10 Linux 基于CentOS 7.6 一.Windows安装笔记 1.下载安装 https://ffmpeg.org/download.html 2 解压缩&#xff0c;拷贝到需要目录&#xff0c;重命名 3 追加环境变量 echo %PATH%setx /m PATH "%PATH%;F:\dev_tools\…...

【虹科案例】虹科任意波形发生器在量子计算中的应用

虹科AWG在量子计算中的应用精度在研究中始终很重要&#xff0c;很少有研究领域需要比量子研究更高的精度。奥地利因斯布鲁克大学的量子光学和量子信息研究所需要一个任意波形发生器&#xff08;AWG&#xff09;来为他们的研究生成各种各样的信号。01无线电频率第一个应用是在射…...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

逻辑回归:给不确定性划界的分类大师

想象你是一名医生。面对患者的检查报告&#xff08;肿瘤大小、血液指标&#xff09;&#xff0c;你需要做出一个**决定性判断**&#xff1a;恶性还是良性&#xff1f;这种“非黑即白”的抉择&#xff0c;正是**逻辑回归&#xff08;Logistic Regression&#xff09;** 的战场&a…...

ssc377d修改flash分区大小

1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

JUC笔记(上)-复习 涉及死锁 volatile synchronized CAS 原子操作

一、上下文切换 即使单核CPU也可以进行多线程执行代码&#xff0c;CPU会给每个线程分配CPU时间片来实现这个机制。时间片非常短&#xff0c;所以CPU会不断地切换线程执行&#xff0c;从而让我们感觉多个线程是同时执行的。时间片一般是十几毫秒(ms)。通过时间片分配算法执行。…...

JVM虚拟机:内存结构、垃圾回收、性能优化

1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...

腾讯云V3签名

想要接入腾讯云的Api&#xff0c;必然先按其文档计算出所要求的签名。 之前也调用过腾讯云的接口&#xff0c;但总是卡在签名这一步&#xff0c;最后放弃选择SDK&#xff0c;这次终于自己代码实现。 可能腾讯云翻新了接口文档&#xff0c;现在阅读起来&#xff0c;清晰了很多&…...

掌握 HTTP 请求:理解 cURL GET 语法

cURL 是一个强大的命令行工具&#xff0c;用于发送 HTTP 请求和与 Web 服务器交互。在 Web 开发和测试中&#xff0c;cURL 经常用于发送 GET 请求来获取服务器资源。本文将详细介绍 cURL GET 请求的语法和使用方法。 一、cURL 基本概念 cURL 是 "Client URL" 的缩写…...

水泥厂自动化升级利器:Devicenet转Modbus rtu协议转换网关

在水泥厂的生产流程中&#xff0c;工业自动化网关起着至关重要的作用&#xff0c;尤其是JH-DVN-RTU疆鸿智能Devicenet转Modbus rtu协议转换网关&#xff0c;为水泥厂实现高效生产与精准控制提供了有力支持。 水泥厂设备众多&#xff0c;其中不少设备采用Devicenet协议。Devicen…...

[USACO23FEB] Bakery S

题目描述 Bessie 开了一家面包店! 在她的面包店里&#xff0c;Bessie 有一个烤箱&#xff0c;可以在 t C t_C tC​ 的时间内生产一块饼干或在 t M t_M tM​ 单位时间内生产一块松糕。 ( 1 ≤ t C , t M ≤ 10 9 ) (1 \le t_C,t_M \le 10^9) (1≤tC​,tM​≤109)。由于空间…...

DAY 45 超大力王爱学Python

来自超大力王的友情提示&#xff1a;在用tensordoard的时候一定一定要用绝对位置&#xff0c;例如&#xff1a;tensorboard --logdir"D:\代码\archive (1)\runs\cifar10_mlp_experiment_2" 不然读取不了数据 知识点回顾&#xff1a; tensorboard的发展历史和原理tens…...