当前位置: 首页 > news >正文

现代神经网络(VGG),并用VGG16进行实战CIFAR10分类

专栏:神经网络复现目录


本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet)。
文章部分文字和代码来自《动手学深度学习》

文章目录

  • 使用块的网络(VGG)
  • VGG块
    • 定义
    • 实现
  • VGG16
    • 模型设计
    • 实现
  • 利用VGG16进行CIFAR10分类
    • 数据集
    • 超参数,优化器,损失函数
    • 训练


使用块的网络(VGG)

VGG是一种深度卷积神经网络,由牛津大学视觉几何组(Visual Geometry Group)在2014年提出。它是由多个卷积层和池化层组成的深度神经网络,具有很强的图像分类能力,特别是在图像识别领域,取得了很好的成果。

VGG的特点在于,它使用相对较小的卷积核(3x3),但是通过叠加多个卷积层和池化层,增加了网络的深度,从而达到更好的图像分类性能。VGG网络包含了多个版本,以卷积层数目为标志,如VGG16和VGG19等,其中VGG16和VGG19是最著名的两个版本。

VGG网络的设计非常简单和规整,容易理解和实现,因此也成为了很多深度学习新手的入门模型。

下图为VGG的六个版本,比较实用的是VGG16和VGG19,本文以VGG16为例子进行讲解
在这里插入图片描述

VGG块

定义

VGG块是VGG网络中的一个基本组成单元,由若干个卷积层和池化层组成,通常用于提取输入图像的特征。每个VGG块都由连续的1或2个卷积层,和一个最大池化层组成。其中,卷积层的卷积核大小都是3x3,而池化层的窗口大小通常是2x2。在每个VGG块中,卷积层的输出通道数都相同,可以通过超参数进行控制。

具体来说,假设一个VGG块由k个卷积层和一个池化层组成,输入为xxx,则该块的输出可以表示为:

VGG(x)=Pool(convk(convk−1(⋯conv1(x)))).\text{VGG}(x) = \text{Pool}(\text{conv}k(\text{conv}{k-1}(\cdots\text{conv}_1(x)))).VGG(x)=Pool(convk(convk1(conv1(x)))).

其中,convi(⋅)\text{conv}_i(\cdot)convi()表示第iii个卷积层,Pool(⋅)\text{Pool}(\cdot)Pool()表示池化层。在VGG块中,每个卷积层都会使用ReLU激活函数进行非线性变换,而最大池化层则用于下采样和特征压缩。

在VGG网络中,通常通过叠加多个VGG块来构建网络结构。通过增加VGG块的数量,可以增加网络的深度和宽度,从而提高网络的表达能力和泛化性能。

实现

self.conv1=nn.Sequential(nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)

inplace=True 表示对于输入的张量进行原地操作,即直接对原始的输入张量进行修改,而不是创建一个新的张量。这样做可以节省内存,但会覆盖原始的输入张量,可能会对后续的计算产生影响。因此,当我们需要保留原始的输入张量时,可以将 inplace 参数设置为 False。

VGG16

模型设计

VGG16是一个卷积神经网络模型,包含13个卷积层、5个池化层和3个全连接层,是由牛津大学计算机视觉组(Visual Geometry Group)在2014年提出的模型,具有较好的图像识别表现。

VGG16模型的架构如下:

输入层:输入图像的大小为224x224x3。

VGG块1

卷积层1:使用64个3x3大小的卷积核进行卷积操作,得到64张大小为224x224的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层2:使用64个3x3大小的卷积核进行卷积操作,得到64张大小为224x224的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层1:使用2x2的最大池化操作,将64张大小为224x224的特征图缩小为64张大小为112x112的特征图。采用SAME填充,步长为2。

VGG块2

卷积层3:使用128个3x3大小的卷积核进行卷积操作,得到128张大小为112x112的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层4:使用128个3x3大小的卷积核进行卷积操作,得到128张大小为112x112的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层2:使用2x2的最大池化操作,将128张大小为112x112的特征图缩小为128张大小为56x56的特征图。采用SAME填充,步长为2。

VGG块3

卷积层5:使用256个3x3大小的卷积核进行卷积操作,得到256张大小为56x56的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层6:使用256个3x3大小的卷积核进行卷积操作,得到256张大小为56x56的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

卷积层7:使用256个3x3大小的卷积核进行卷积操作,得到256张大小为56x56的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层3:使用2x2的最大池化操作,将256张大小为56x56的特征图缩小为256张大小为28x28的特征图。采用SAME填充,步长为2。

VGG块4

卷积层8-10:使用512个3x3大小的卷积核进行卷积操作,得到512张大小为28x28的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层4:使用2x2的最大池化操作,将512张大小为28x28的特征图缩小为512张大小为14x14的特征图。采用SAME填充,步长为2。

VGG块5

卷积层11-13:使用512个3x3大小的卷积核进行卷积操作,得到512张大小为14x14的特征图。采用SAME填充,步长为1。然后再通过ReLU非线性激活函数进行激活。

池化层5:使用2x2的最大池化操作,将512张大小为14x14的特征图缩小为512张大小为7x7的特征图。采用SAME填充,步长为2。

全连接层

3个全连接层,第1、2个都有4096个输出通道,第3个全连接层则有1000个输出通道。

实现

class VGG16(nn.Module):def __init__(self):super(VGG16,self).__init__()self.conv1=nn.Sequential(nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv2=nn.Sequential(nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv3=nn.Sequential(nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv4=nn.Sequential(nn.Conv2d(in_channels=256,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.conv5=nn.Sequential(nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512,out_channels=512,kernel_size=3,padding=1,stride=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2),)self.feature=nn.Sequential(self.conv1,self.conv2,self.conv3,self.conv4,self.conv5,)self.flatten=nn.Flatten()self.fc=nn.Sequential(nn.Linear(512*7*7,4096),nn.ReLU(inplace=True),nn.Dropout(0.4),nn.Linear(4096,4096),nn.ReLU(inplace=True),nn.Dropout(0.4),nn.Linear(4096,1000),#nn.Softmax(10))def forward(self,x):x=self.feature(x)# x=self.flatten(x)x = x.view(x.size(0), -1)x=self.fc(x)return x

查看结构

vgg = VGG16()
print(vgg)
x=torch.rand(1,3,224,224)
y=vgg(x)
print(y.shape)

利用VGG16进行CIFAR10分类

import torch.nn as nn
import torch
import torchvisionif(torch.cuda.is_available()):device = torch.device("cuda")print("使用GPU训练中:{}".format(torch.cuda.get_device_name()))
else:device = torch.device("cpu")print("使用CPU训练")

数据集

# transform的创建(compose方法)
from torchvision import transforms
def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4def load_data_cifar10(batch_size, resize=None):  #@savetrans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=True)return (torch.utils.data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),torch.utils.data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
batch_size=4
train_iter, test_iter = load_data_cifar10(batch_size,resize=224)

超参数,优化器,损失函数

from torch import optim
net=VGG16()
lr=0.001
optimizer=optim.SGD(net.parameters(),lr=lr,momentum=0.9)
loss=nn.CrossEntropyLoss()
epochs=10

训练

def train(net,train_iter,test_iter,num_epochs, lr, device):def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)for epoch in range(num_epochs):net.train()train_step = 0for i, (X, y) in enumerate(train_iter):optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l=loss(y_hat,y)l.backward()optimizer.step()train_step+=1if(train_step%50==0):#每训练一百组输出一次损失print("第{}轮的第{}次训练的loss:{}".format((epoch+1),train_step,l.item()))

相关文章:

现代神经网络(VGG),并用VGG16进行实战CIFAR10分类

专栏:神经网络复现目录 本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络…...

Java代码弱点与修复之——Dereference null return value(间接引用空返回值)

弱点描述 Dereference null return value,间接引用空返回值。是Coverity Scan静态代码分析工具中的一个警告,表示代码中有对可能为空(null)的方法或函数返回值进行间接引用(Dereference)操作。 该类型的漏洞可能会导致 NullPointerException 异常,并且会导致程序崩溃或…...

【冲刺蓝桥杯的最后30天】day3

大家好😃,我是想要慢慢变得优秀的向阳🌞同学👨‍💻,断更了整整一年,又开始恢复CSDN更新,从今天开始更新备战蓝桥30天系列,一共30天,如果对你有帮助或者正在备…...

光伏发电嵌入式ARM工控机

随着智慧电力技术的不断发展和普及,越来越多的电力设备和系统需要采用先进的控制和监测技术来实现自动化管理和优化运行。其中,嵌入式 ARM 控制器技术在智慧电力领域中得到了广泛应用。同时,导轨安装也是该技术的重要应用场景之一。 导轨安装…...

推荐 7 个 Vue.js 插件,也许你的项目用的上(五)

当我们可以通过使用库轻松实现相同的结果时,为什么还要编写自定义功能?开发人员最好的朋友和救星就是这些第三方库。我相信一个好的项目会利用一些可用的最佳库。Vue.js 是创建用户界面的最佳 JavaScript 框架之一。这篇文章是关于 Vue.js 的优秀库系列的…...

1.1基于知识图谱的项目实战:优酷搜索泛查询意图优化

NLU的技术实现主要分为在线识别和离线数据挖掘两块。 1.在线识别 NLU的在线识别技术栈如下图所示,共由下述2个部分组成: 第一个部分是Slot Filling(成分分析),负责对query进行实体识别和槽位抽取;第二部分Inention Detection(意图识别),根据提取的槽位进行意图的判定(目…...

[java Spring JdbcTemplate配合mysql实现数据批量删除

之前的文章 java Spring JdbcTemplate配合mysql实现数据批量添加和文章java Spring JdbcTemplate配合mysql实现数据批量修改 先后讲解了 mysql数据库的批量添加和批量删除操作 会了这两个操作之后 批量删除就不要太简单 我们看到数据库 这里 我们用的是mysql工具 这里 我们有…...

uos 20 统信 fprintd 记录

uos 20 统信 fprintd 记录 sudo busctl deepin-authenticate.service /usr/lib/systemd/system/deepin-authenticate.service [Unit] DescriptionDeepin Authentication[Service] Typedbus BusNamecom.deepin.daemon.Authenticate ExecStart/usr/lib/deepin-authenticate/d…...

vue移动端h5,文本溢出显示省略号,且展示‘更多’按钮

问题: 元素宽度100%,宽度会随着浏览器缩放而变化。元素内文本超过4行时显示省略号,同时展示‘更多’按钮,点击更多按钮展示全部文本。如下图所示 超出四行显示省略号(…)的代码 .content{overflow:hidden;text-overflow: elli…...

php宝塔搭建部署实战兰空图床程序网站PHP源码

大家好啊,我是测评君,欢迎来到web测评。 本期给大家带来一套Lsky Pro兰空图床程序网站PHP的源码。感兴趣的朋友可以自行下载学习。 技术架构 PHP8.0 nginx mysql5.7 JS CSS HTMLcnetos7以上 宝塔面板 文字搭建教程 下载源码,宝塔添加…...

软件测试面试:拿到一个产品(版本)如何开展测试?

产品提测后,如何开展测试? 我们都了解软件测试的执行流程,......提测-冒烟测试-详细测试-提交缺陷报告-回归测试,但软件测试并不总是线性过程,它甚至可能是螺旋结构,不断地试错,不断地迭代&…...

【Opencv项目实战】图像的像素值反转

文章目录一、项目思路二、算法详解2.1、获取图像信息2.2、新建模板2.3、图像通道顺序三、项目实战:彩图的像素值反转(方法一)四、项目实战:彩图的像素值反转(方法二)五、项目实战:彩图转换为灰图…...

Swagger生成接口在线文档

OpenAPI规范(OpenAPI Specification 简称OAS)是Linux基金会的一个项目,试图通过定义一种用来描述API格式或API定义的语言,来规范RESTful服务开发过程,目前版本是V3.0,并且已经发布并开源在github上。&#…...

104.第十九章 MySQL数据库 -- MySQL主从复制、 级联复制和双主复制(十四)

6.1.2 实现主从复制配置 参考官网 https://dev.mysql.com/doc/refman/8.0/en/replication-configuration.html https://dev.mysql.com/doc/refman/5.7/en/replication-configuration.html https://dev.mysql.com/doc/refman/5.5/en/replication-configuration.html https://m…...

第一次使用Python for Qt中的问题

在创建带有form的python for qt的时候,使用的库是pySide6,而不是pyqt。 因此,需要安装pyside6。 Running "/usr/bin/python3 -m pip install PySide6 --user" to install PySide6. ERROR: Could not find a version that satisfi…...

.Net Core WebApi 在Linux系统Deepin上部署Nginx并使用(一)

前言: Deepin最初是基于Ubuntu的发行版 2015年脱离Ubuntu开发,开始基于Ubuntu上游Debian操作系统 2019年脱离Debian,直接基于Linux开发,真正属于自己的上游Linux系统发行版 2022年8月,新版《Deepin V23》我下载开始了我…...

Java——打开轮盘锁

题目链接 leetcode在线oj题——打开轮盘锁 题目描述 你有一个带有四个圆形拨轮的转盘锁。每个拨轮都有10个数字: ‘0’, ‘1’, ‘2’, ‘3’, ‘4’, ‘5’, ‘6’, ‘7’, ‘8’, ‘9’ 。每个拨轮可以自由旋转:例如把 ‘9’ 变为 ‘0’&#xff0…...

JavaScript(2)

一、事件 HTML事件是发生在hTML元素上的“事情”。比如&#xff1a;按钮被点击、鼠标移动到元素上等… 事件绑定 方式一&#xff1a;通过HTML标签中的事件属性进行绑定 <input type"button" value"点我" onclick"on()"><script>fun…...

FFMPEG 安装教程windowslinux(CentOS版)

ps: 从笔记中迁移至blog 版本概述 Windows 基于win10 Linux 基于CentOS 7.6 一.Windows安装笔记 1.下载安装 https://ffmpeg.org/download.html 2 解压缩&#xff0c;拷贝到需要目录&#xff0c;重命名 3 追加环境变量 echo %PATH%setx /m PATH "%PATH%;F:\dev_tools\…...

【虹科案例】虹科任意波形发生器在量子计算中的应用

虹科AWG在量子计算中的应用精度在研究中始终很重要&#xff0c;很少有研究领域需要比量子研究更高的精度。奥地利因斯布鲁克大学的量子光学和量子信息研究所需要一个任意波形发生器&#xff08;AWG&#xff09;来为他们的研究生成各种各样的信号。01无线电频率第一个应用是在射…...

docker详细操作--未完待续

docker介绍 docker官网: Docker&#xff1a;加速容器应用程序开发 harbor官网&#xff1a;Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台&#xff0c;用于将应用程序及其依赖项&#xff08;如库、运行时环…...

Golang 面试经典题:map 的 key 可以是什么类型?哪些不可以?

Golang 面试经典题&#xff1a;map 的 key 可以是什么类型&#xff1f;哪些不可以&#xff1f; 在 Golang 的面试中&#xff0c;map 类型的使用是一个常见的考点&#xff0c;其中对 key 类型的合法性 是一道常被提及的基础却很容易被忽视的问题。本文将带你深入理解 Golang 中…...

AspectJ 在 Android 中的完整使用指南

一、环境配置&#xff08;Gradle 7.0 适配&#xff09; 1. 项目级 build.gradle // 注意&#xff1a;沪江插件已停更&#xff0c;推荐官方兼容方案 buildscript {dependencies {classpath org.aspectj:aspectjtools:1.9.9.1 // AspectJ 工具} } 2. 模块级 build.gradle plu…...

Web 架构之 CDN 加速原理与落地实践

文章目录 一、思维导图二、正文内容&#xff08;一&#xff09;CDN 基础概念1. 定义2. 组成部分 &#xff08;二&#xff09;CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 &#xff08;三&#xff09;CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 &#xf…...

VM虚拟机网络配置(ubuntu24桥接模式):配置静态IP

编辑-虚拟网络编辑器-更改设置 选择桥接模式&#xff0c;然后找到相应的网卡&#xff08;可以查看自己本机的网络连接&#xff09; windows连接的网络点击查看属性 编辑虚拟机设置更改网络配置&#xff0c;选择刚才配置的桥接模式 静态ip设置&#xff1a; 我用的ubuntu24桌…...

音视频——I2S 协议详解

I2S 协议详解 I2S (Inter-IC Sound) 协议是一种串行总线协议&#xff0c;专门用于在数字音频设备之间传输数字音频数据。它由飞利浦&#xff08;Philips&#xff09;公司开发&#xff0c;以其简单、高效和广泛的兼容性而闻名。 1. 信号线 I2S 协议通常使用三根或四根信号线&a…...

代码规范和架构【立芯理论一】(2025.06.08)

1、代码规范的目标 代码简洁精炼、美观&#xff0c;可持续性好高效率高复用&#xff0c;可移植性好高内聚&#xff0c;低耦合没有冗余规范性&#xff0c;代码有规可循&#xff0c;可以看出自己当时的思考过程特殊排版&#xff0c;特殊语法&#xff0c;特殊指令&#xff0c;必须…...

CSS3相关知识点

CSS3相关知识点 CSS3私有前缀私有前缀私有前缀存在的意义常见浏览器的私有前缀 CSS3基本语法CSS3 新增长度单位CSS3 新增颜色设置方式CSS3 新增选择器CSS3 新增盒模型相关属性box-sizing 怪异盒模型resize调整盒子大小box-shadow 盒子阴影opacity 不透明度 CSS3 新增背景属性ba…...

用 Rust 重写 Linux 内核模块实战:迈向安全内核的新篇章

用 Rust 重写 Linux 内核模块实战&#xff1a;迈向安全内核的新篇章 ​​摘要&#xff1a;​​ 操作系统内核的安全性、稳定性至关重要。传统 Linux 内核模块开发长期依赖于 C 语言&#xff0c;受限于 C 语言本身的内存安全和并发安全问题&#xff0c;开发复杂模块极易引入难以…...

Vue 实例的数据对象详解

Vue 实例的数据对象详解 在 Vue 中,数据对象是响应式系统的核心,也是组件状态的载体。理解数据对象的原理和使用方式是成为 Vue 专家的关键一步。我将从多个维度深入剖析 Vue 实例的数据对象。 一、数据对象的定义方式 1. Options API 中的定义 在 Options API 中,使用 …...