当前位置: 首页 > news >正文

经典卷积神经网络 - VGG

使用块的网络 - VGG。

使用多个 3 × 3 3\times 3 3×3的要比使用少个 5 × 5 5\times 5 5×5的效果要好。
在这里插入图片描述
在这里插入图片描述
VGG全称是Visual Geometry Group,因为是由Oxford的Visual Geometry Group提出的。AlexNet问世之后,很多学者通过改进AlexNet的网络结构来提高自己的准确率,主要有两个方向:小卷积核和多尺度。而VGG的作者们则选择了另外一个方向,即加深网络深度。

网络架构

卷积网络的输入是224 * 224RGB图像,整个网络的组成是非常格式化的,基本上都用的是3 * 3的卷积核以及 2 * 2max pooling,少部分网络加入了1 * 1的卷积核。因为想要体现出“上下左右中”的概念,3*3的卷积核已经是最小的尺寸了。

VGG16相比之前网络的改进是3个33卷积核来代替7x7卷积核,2个33卷积核来代替5*5卷积核,这样做的主要目的是在保证具有相同感知野的条件下,减少参数,提升了网络的深度。

多个VGG块后接全连接层。

不同次数的重复块得到不同的架构,如VGG-16,VGG-19等。

VGG:更大更深的AlexNet。

总结:

  • VGG使用可重复使用的卷积块来构建深度卷积神经网络
  • 不同的卷积块个数和超参数可以得到不同复杂度的变种

代码实现

使用数据集CIFAR

model.py

import torch
from torch import nnclass Vgg16(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model = nn.Sequential(nn.Conv2d(3,64,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(64,64,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(64,128,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(128,128,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(128,256,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(256,256,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(256,256,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(256,512,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(512,512,kernel_size=3,padding=1),nn.ReLU(),nn.Conv2d(512,512,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.Conv2d(512, 512, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(2,2),nn.Flatten(),nn.Linear(7*7*512,4096),nn.Dropout(0.5),nn.Linear(4096,4096),nn.Dropout(0.5),nn.Linear(4096,10))def forward(self,x):return self.model(x)# 验证模型正确性
if __name__ == '__main__':net = Vgg16()x = torch.ones((64,3,244,244))output = net(x)print(output)

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import Vgg16# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = Vgg16()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))

相关文章:

经典卷积神经网络 - VGG

使用块的网络 - VGG。 使用多个 3 3 3\times 3 33的要比使用少个 5 5 5\times 5 55的效果要好。 VGG全称是Visual Geometry Group,因为是由Oxford的Visual Geometry Group提出的。AlexNet问世之后,很多学者通过改进AlexNet的网络结构来提高自己的准确…...

系统集成测试(SIT)/系统测试(ST)/用户验收测试(UAT)

文章目录 单元测试集成测试系统测试用户验收测试黑盒测试白盒测试压力测试性能测试容量测试安全测试SIT和UAT的区别 单元测试 英文 unit testing,缩写 UT。测试粒度最小,一般由开发小组采用白盒方式来测试,主要测试单元是否符合“设计”。 …...

Android Gradle8.0以上多渠道写法以及针对不同渠道导入包的方式,填坑!

目录 多渠道的写法 针对多渠道引用不同的包 There was a failure while populating the build operation queue: Could not stat file E:\xxxx\xxxx\xxxx\app\src\UAT\libsUAT\xxx-provider(?)-xx.aar 最近升级了Gradle8.3之后,从Groovy 迁移到 Kotlin&#xff…...

hdlbits系列verilog解答(向量门操作)-14

文章目录 一、问题描述二、verilog源码三、仿真结果 一、问题描述 构建一个具有两个 3 位输入的电路,用于计算两个向量的按位 OR、两个向量的逻辑 OR 以及两个向量的逆 (NOT)。将b反相输出到out_not上半部分,将a 的反相输出到out…...

工厂模式(初学)

工厂模式 1、简单工厂模式 是一种创建型设计模式,旨在通过一个工厂类(简单工厂)来封装对象的实例化过程 运算类 public class Operation { //这个是父类private double num1; //运算器中的两个值private double num2;public double getNu…...

python试题实例

背景: 在外地出差,突然接到单位电话,让自己出一些python考题供新人教育训练使用,以下是10道Python编程试题及其答案: 1.试题:请写一个Python程序,计算并输出1到100之间所有偶数的和。 答案&am…...

Java Heap Space问题解析与解决方案(InsCode AI 创作助手)

Heap Space问题是Java开发中常见的内存溢出问题之一,我们需要理解其原因和表现形式,然后通过优化代码、增加JVM内存和使用垃圾回收机制等方法来解决。 一、常见报错 java.lang.OutOfMemoryError: Java heap space二、Heap Space问题的原因 对象创建过…...

基于遥感影像的分类技术(监督/非监督和面向对象的分类技术)

遥感图像分类技术 “图像分类是将土地覆盖类别分配给像素的过程。例如,类别包括水、城市、森林、农业和草原。”前言 – 人工智能教程 什么是遥感图像分类? 遥感图像分类技术的三种主要类型是: 无监督图像分类监督图像分类基于对象的图像分析…...

插入兄弟元素 insertAfter() 方法

insertAfter() 方法在被选元素后插入 HTML 元素。 提示&#xff1a;如需在被选元素前插入 HTML 元素&#xff0c;请使用 insertBefore() 方法。 语法 $(content).insertAfter(selector)例子&#xff1a; $("<span>Hello world!</span>").insertAfter(…...

【C++项目】高并发内存池第二讲中心缓存CentralCache框架+核心实现

CentralCache 1.框架介绍2.核心功能3.核心函数实现介绍3.1SpanSpanList介绍3.2CentralCache.h3.3CentralCache.cpp3.4TreadCache申请内存函数介绍3.5慢反馈算法 1.框架介绍 回顾一下ThreadCache的设计&#xff1a; 如图所示&#xff0c;ThreadCache设计是一个哈希桶结构&…...

Git基础教程

一、Git简介 1、什么是Git&#xff1f; Git是一个开源的分布式版本控制系统&#xff0c;用于敏捷高效地处理任何或大或小的项目。 Git是Linus Torvalds为了帮助管理Linux内核开发而开发的一个开放源代码的版本控制软件。 Git与常用的版本控制工具CVS、Subversion等不同&#…...

stm32外部时钟为12MHZ,修改代码适配

代码默认是8MHZ的&#xff0c;修改2个地方&#xff1a; 第一个地方是这个文件的这里&#xff1a; 第二个地方是找到这个函数&#xff1a; 修改第二个地方的这里&#xff1a;...

【数据结构】八大排序

目录 1. 排序的概念及其作用 1.1 排序的概念 1.2 排序运用 1.3 常见的排序算法 2. 常见排序算法的实现 2.1 插入排序 2.1.1 基本思想 2.1.2 直接插入排序 2.1.3 希尔排序&#xff08;缩小增量排序&#xff09; 2.2 选择排序 2.2.1 基本思想 2.2.2 直接选择排序 2.2…...

MYSQL(事务+锁+MVCC+SQL执行流程)理解

一)事务的特性: 一致性:主要是在数据层面来说&#xff0c;不能说执行扣减库存的操作的时候用户订单数据却没有生成 原子性:主要是在操作层面来说&#xff0c;要么操作完成&#xff0c;要么操作全部回滚&#xff1b; 隔离性:是自己的事务操作自己的数据&#xff0c;不会受到到其…...

解密一致性哈希算法:实现高可用和负载均衡的秘诀

解密一致性哈希算法&#xff1a;实现高可用和负载均衡的秘诀 前言第一&#xff1a;分布式系统中的数据分布问题&#xff0c;为什么需要一致性哈希算法第二&#xff1a;一致性hash算法的原理第三&#xff1a;一致性哈希算法的优点和局限性第四&#xff1a;一致性哈希算法的安全性…...

Python脚本:让工作自动化起来

Python是一种流行的编程语言&#xff0c;以其简洁和易读性而闻名。它提供了大量的库和模块&#xff0c;使其成为自动化各种任务的绝佳选择。 本文将探讨Python脚本及其代码&#xff0c;可以帮助您自动化各种任务并提高工作效率。无论您是开发人员、数据分析师还是只是想简化工…...

香港科技大学广州|可持续能源与环境学域博士招生宣讲会—广州大学城专场!!!(暨全额奖学金政策)

香港科技大学广州&#xff5c;可持续能源与环境学域博士招生宣讲会—广州大学城专场&#xff01;&#xff01;&#xff01;&#xff08;暨全额奖学金政策&#xff09; “面向未来改变游戏规则的——可持续能源与环境学域” &#xfffd;&#xfffd;&#xfffd;专注于能源环…...

uni-app:多种方法写入图片路径

一、文件在前端文件夹中 1、相对路径引用 从当前文件所在位置开始寻找图片文件的路径。../../ 表示返回两级目录&#xff0c;即从当前文件所在的 wind.vue 所在的位置开始向上回退两级。接着&#xff0c;进入 static 目录&#xff0c;再进入 look 目录&#xff0c;最后定位到 …...

共谋工业3D视觉发展,深眸科技以自研解决方案拓宽场景应用边界

随着中国工业领域自动化程度逐渐攀升&#xff0c;“机器换人”这一需求进一步提升。在传统2D工业视觉易受环境光干扰、无法进一步获取物体深度信息的限制条件下&#xff0c;工业3D视觉凭借着更强的空间和深度感知能力&#xff0c;以及通过点云数据获取物体距离和三维坐标信息的…...

前端面试基础面试题——11

1.什么是 vue 的计算属性&#xff1f; 2.vue怎么实现页面的权限控制 3.watch的作用是什么 4.响应式系统的基本原理 5.vue-loader 是什么&#xff1f;使用它的用途有哪些&#xff1f; 6.vuex 工作原理详解 7.vuex 有哪几种属性&#xff1f; 8.什么是 MVVM&#xff1f; 9…...

wordpress后台更新后 前端没变化的解决方法

使用siteground主机的wordpress网站&#xff0c;会出现更新了网站内容和修改了php模板文件、js文件、css文件、图片文件后&#xff0c;网站没有变化的情况。 不熟悉siteground主机的新手&#xff0c;遇到这个问题&#xff0c;就很抓狂&#xff0c;明明是哪都没操作错误&#x…...

SkyWalking 10.2.0 SWCK 配置过程

SkyWalking 10.2.0 & SWCK 配置过程 skywalking oap-server & ui 使用Docker安装在K8S集群以外&#xff0c;K8S集群中的微服务使用initContainer按命名空间将skywalking-java-agent注入到业务容器中。 SWCK有整套的解决方案&#xff0c;全安装在K8S群集中。 具体可参…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》

引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

C++中string流知识详解和示例

一、概览与类体系 C 提供三种基于内存字符串的流&#xff0c;定义在 <sstream> 中&#xff1a; std::istringstream&#xff1a;输入流&#xff0c;从已有字符串中读取并解析。std::ostringstream&#xff1a;输出流&#xff0c;向内部缓冲区写入内容&#xff0c;最终取…...

tree 树组件大数据卡顿问题优化

问题背景 项目中有用到树组件用来做文件目录&#xff0c;但是由于这个树组件的节点越来越多&#xff0c;导致页面在滚动这个树组件的时候浏览器就很容易卡死。这种问题基本上都是因为dom节点太多&#xff0c;导致的浏览器卡顿&#xff0c;这里很明显就需要用到虚拟列表的技术&…...

mac 安装homebrew (nvm 及git)

mac 安装nvm 及git 万恶之源 mac 安装这些东西离不开Xcode。及homebrew 一、先说安装git步骤 通用&#xff1a; 方法一&#xff1a;使用 Homebrew 安装 Git&#xff08;推荐&#xff09; 步骤如下&#xff1a;打开终端&#xff08;Terminal.app&#xff09; 1.安装 Homebrew…...

C# 表达式和运算符(求值顺序)

求值顺序 表达式可以由许多嵌套的子表达式构成。子表达式的求值顺序可以使表达式的最终值发生 变化。 例如&#xff0c;已知表达式3*52&#xff0c;依照子表达式的求值顺序&#xff0c;有两种可能的结果&#xff0c;如图9-3所示。 如果乘法先执行&#xff0c;结果是17。如果5…...

【Linux】自动化构建-Make/Makefile

前言 上文我们讲到了Linux中的编译器gcc/g 【Linux】编译器gcc/g及其库的详细介绍-CSDN博客 本来我们将一个对于编译来说很重要的工具&#xff1a;make/makfile 1.背景 在一个工程中源文件不计其数&#xff0c;其按类型、功能、模块分别放在若干个目录中&#xff0c;mak…...

前端中slice和splic的区别

1. slice slice 用于从数组中提取一部分元素&#xff0c;返回一个新的数组。 特点&#xff1a; 不修改原数组&#xff1a;slice 不会改变原数组&#xff0c;而是返回一个新的数组。提取数组的部分&#xff1a;slice 会根据指定的开始索引和结束索引提取数组的一部分。不包含…...