当前位置: 首页 > news >正文

经典卷积神经网络 - NIN

网络中的网络,NIN。

AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成的小网络来构建⼀个深层网络。

AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。
或者,可以想象在这个过程的早期使用全连接层。然而,如果使用了全连接层,可能会完全放弃表征的空间结构。

网络中的网络NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机。也就是使用了多个1*1的卷积核。同时他认为全连接层占据了大量的内存,所以整个网络结构中没有使用全连接层。

NIN块

image-20231023194711049

一个卷积层后跟两个全连接层。

  • 步幅为1,无填充,输出形状跟卷积层输出一样。
  • 起到全连接层的作用。

NIN网络结构
在这里插入图片描述
image-20231024090401226

  • 无全连接层

  • 交替使用NIN块和步幅为2的最大池化层

    逐步减小高宽和增大通道数

  • 最后使用全局平均池化层得到输出

    其输入通道数是类别数

此网络结构总计4层: 3mlpconv + 1global_average_pooling

优点:

  1. 提供了网络层间映射的一种新可能;
  2. 增加了网络卷积层的非线性能力。

总结:

  • NIN块使用卷积层加上个 1 × 1 1\times 1 1×1卷积,后者对每个像素增加了非线性性
  • NIN使用全局平均池化层来替代VGG和AlexNet中的全连接层,不容易过拟合,更少的参数个数

代码实现

使用CIFAR-10数据集。

maxpooling不改变通道数,只改变长和宽

model.py

import torch
from torch import nn# nin块
def nin_block(in_channels,out_channels,kernel_size,strides,padding):return nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),)# 构建网络
class NIN(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model = nn.Sequential(nin_block(3,96,kernel_size=11,strides=4,padding=0),nn.MaxPool2d(3,stride=2),nin_block(96,256,kernel_size=5,strides=1,padding=2),nn.MaxPool2d(3,stride=2),nin_block(256,384,kernel_size=3,strides=1,padding=1),nn.MaxPool2d(3,stride=2),nn.Dropout(0.5),nin_block(384,10,kernel_size=3,strides=1,padding=1),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())def forward(self,x):return self.model(x)# 验证模型正确性
if __name__ == '__main__':nin = NIN()x = torch.ones((64,3,244,244))output = nin(x)print(output)

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import NIN# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = NIN()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))

相关文章:

经典卷积神经网络 - NIN

网络中的网络,NIN。 AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成…...

leetcode_2558 从数量最多的堆取走礼物

1. 题意 给定一个数组,每次从中取走最大的数,返回开根号向下取整送入堆中,最后计算总和。 从数量最多的堆取走礼物 2. 题解 直接用堆模拟即可 2.1 我的代码 用了额外的空间O( n ) priority_queue会自动调用make_heap() 、pop_heap() c…...

01. 嵌入式与人工智能是如何结合的?

CPU是Arm A57的 GPU是128cuda核 一.小车跟踪的需求和设计方法 比如有一个小车跟踪的项目。 需求是:小车识别出罪犯,然后去跟踪他。方法:摄像头采集到人之后传入到开发板,内部做一下识别,然后控制小车去跟随。在人工智…...

vue3.0运行npm run dev 报错Cannot find module node:url

vue3.0运行npm run dev 报错Cannot find module 问题背景 近期用vue3.0写项目,npm init vuelatest —> npm install 都正常,npm run dev的时候报错如下: failed to load config from F:\code\testVue\vue-demo\vite.config.js error when starting…...

26. 删除排序数组中的重复项、Leetcode的Python实现

博客主页:🏆看看是李XX还是李歘歘 🏆 🌺每天分享一些包括但不限于计算机基础、算法等相关的知识点🌺 💗点关注不迷路,总有一些📖知识点📖是你想要的💗 ⛽️今…...

荣耀推送服务消息分类标准

前言 为了提升终端用户的推送体验、营造良好可持续的通知生态,荣耀推送服务将对推送消息进行分类管理。 消息分类 定义 荣耀推送服务将根据应用类型、消息内容和消息发送场景,将推送消息分成服务通讯和资讯营销两大类别。 服务通讯类,包…...

[数据结构]-二叉搜索树

前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正。 目录 一、二叉搜…...

力扣每日一题79:单词搜索

题目描述: 给定一个 m x n 二维字符网格 board 和一个字符串单词 word 。如果 word 存在于网格中,返回 true ;否则,返回 false 。 单词必须按照字母顺序,通过相邻的单元格内的字母构成,其中“相邻”单元格…...

ChatGPT如何应对用户提出的道德伦理困境?

ChatGPT在应对用户提出的道德伦理困境时,需要考虑众多复杂的因素。道德伦理问题涉及到价值观、原则、社会和文化背景,以及众多伦理理论。ChatGPT的设计和应用需要权衡各种考虑因素,以确保它不仅提供有用的信息,而且遵循伦理标准。…...

SpringBoot运行流程源码分析------阶段三(Spring Boot外化配置源码解析)

Spring Boot外化配置源码解析 外化配置简介 Spring Boot设计了非常特殊的加载指定属性文件(PropertySouce)的顺序,允许属性值合理的覆盖,属性值会以下面的优先级进行配置。home目录下的Devtool全局设置属性(~/.sprin…...

环形链表-力扣

一、题目描述 题目链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 二、题解 解题思路: 快慢指针,即慢指针一次走一步,快指针一次走两步,两个指针从链表起始位置开始运行,…...

人生岁月年华

人生很长吗?不知道。只知道高中坐在教室里,闹哄哄的很难受。也记得上班时无聊敲着代码也很难受。 可是人生也不长。你没有太多时间去试错,你没有无限的时间精力去追寻你认为的高大上。 人生是何体验呢?人生的感觉很多吧。大多数…...

电脑QQ如何录制视频文件?

听说QQ可以录制视频,还很方便,请问该如何录制呢?是需要先打开QQ才可以录制吗?还是可以直接使用快捷键进行录制呢?录制的质量又如何呢? 不要着急,既然都打开这篇文章看了,那小编今天…...

python:多波段遥感影像分离成单波段影像

作者:CSDN @ _养乐多_ 在遥感图像处理中,我们经常需要将多波段遥感影像拆分成多个单波段图像,以便进行各种分析和后续处理。本篇博客将介绍一个用Python编写的程序,该程序可以读取多波段遥感影像,将其拆分为单波段图像,并保存为单独的文件。本程序使用GDAL库来处理遥感影…...

天堂2游戏出错如何解决

运行游戏时出现以下提示:“the game may not be consistant because AGP is deactivated please activate AGP for consistancy” 这个问题的原因可能是由于您的显示卡的驱动或者主板的显示芯片组的驱动不是新开。或您虽然已经更新了您的显示卡的驱动程序&#xff0…...

『力扣刷题本』:合并两个有序链表(递归解法)

一、题目 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1: 输入:l1 [1,2,4], l2 [1,3,4] 输出:[1,1,2,3,4,4]示例 2: 输入:l1 [], l2 [] 输出&#x…...

C++设计模式_12_Singleton 单件模式

在之前的博文C57个入门知识点_44:单例的实现与理解中,已经详细介绍了单例模式,并且根据其中内容,单例模式已经可以在日常编码中被使用,本文将会再做梳理。 Singleton 单件模式可以说是最简单的设计模式,但由…...

67 内网安全-域横向smbwmi明文或hash传递

#知识点1: windows2012以上版本默认关闭wdigest,攻击者无法从内存中获取明文密码windows2012以下版本如安装KB2871997补丁,同样也会导致无法获取明文密码针对以上情况,我们提供了4种方式解决此类问题 1.利用哈希hash传递(pth,ptk等…...

面向对象(类/继承/封装/多态)详解

简介: 面向对象编程(Object-Oriented Programming,OOP)是一种广泛应用于软件开发的编程范式。它基于一系列核心概念,包括类、继承、封装和多态。在这篇详细的解释中,我们将探讨这些概念,并说明它们如何在P…...

【Python机器学习】零基础掌握GradientBoostingRegressor集成学习

如何精准预测房价? 当人们提到房价预测时,很多人可能会想到房地产经纪人或专业的评估师。但是,有没有一种更科学、更精确的方法来预测房价呢?答案是有的,这就要用到机器学习中的一种算法——梯度提升回归(Gradient Boosting Regressor)。 假设现在有一组房屋数据,包括…...

【tio-websocket】12、应用层包—Packet

Packet 介绍 Packet 是用于表述业务数据结构的,我们通过继承 Packet 来实现自己的业务数据结构,对于各位而言,把 Packet 看作是一个普通的 VO 对象即可。 public class Packet implements java.io.Serializable, Cloneable {private static Logger log = LoggerFac…...

OpenCV官方教程中文版 —— 模板匹配

OpenCV官方教程中文版 —— 模板匹配 前言一、原理二、OpenCV 中的模板匹配三、多对象的模板匹配 前言 在本节我们要学习: 使用模板匹配在一幅图像中查找目标 函数:cv2.matchTemplate(),cv2.minMaxLoc() 一、原理 模板匹配是用来在一副大…...

如何为3D模型设置自发光材质?

1、自发光贴图的原理 自发光贴图是一种纹理贴图,用于模拟物体自发光的效果。其原理基于光的发射和反射过程。 在真实世界中,物体自发光通常是由于其本身具有能够产生光的属性,如荧光物质、发光材料或光源本身。为了在计算机图形中模拟这种效…...

UI组件库基础

UI组件库 全局组件* 全局注册组件 & 并且使用了require.context 模块化编程 & webpack打包 const install(Vue)>{const contextrequire.context(.,true,/\.vue$/)context.keys().forEach(fileName>{const modulecontext(fileName)Vue.component(module.default.n…...

数据结构与算法之矩阵: Leetcode 48. 旋转矩阵 (Typescript版)

旋转图像 https://leetcode.cn/problems/rotate-image/ 描述 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。你必须在 原地 旋转图像,这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 示例 1 输入&…...

大厂面试题-JVM中的三色标记法是什么?

目录 问题分析 问题答案 问题分析 三色标记法是Java虚拟机(JVM)中垃圾回收算法的一种,主要用来标记内存中存活和需要回收的对象。 它的好处是,可以让JVM不发生或仅短时间发生STW(Stop The World),从而达到清除JVM内存垃圾的目的&#xff…...

Leetcode—121.买卖股票的最佳时机【简单】

2023每日刷题&#xff08;十一&#xff09; Leetcode—17.电话号码的字母组合 枚举法题解 参考自灵茶山艾府 枚举法实现代码 int maxProfit(int* prices, int pricesSize){int i;int max 0;int minPrice prices[0];for(i 1; i < pricesSize; i) {int tmp prices[i] -…...

【云原生】portainer管理多个独立docker服务器

目录 一、portainer简介 二、安装Portainer 1.1 内网环境下&#xff1a; 1.1.1 方式1&#xff1a;命令行运行 1.1.2 方式2&#xff1a;通过compose-file来启动 2.1 配置本地主机&#xff08;node-1&#xff09; 3.1 配置其他主机&#xff08;被node-1管理的节点服务器&…...

Command集合

Command集合 mysql相关命令 查看mysql的状态 sudo netstat -tap | grep mysql 启动mysql sudo service mysql start 停止mysql sudo service mysql stop 重启mysql sudo service mysql restart 指定端口号&#xff0c;客户端连接mysql sudo mysql -h127.0.0.1 -uroot -p red…...

【QT开发(17)】2023-QT 5.14.2实现Android开发

1、简介 搭建Qt For Android开发环境需要安装的软件有&#xff1a; JAVA SDK &#xff08;jdk 有apt install 安装&#xff09; Android SDK Android NDKQT官网的介绍&#xff1a; Different Qt versions depend on different NDK versions, as listed below: Qt versionNDK…...