当前位置: 首页 > news >正文

经典卷积神经网络 - NIN

网络中的网络,NIN。

AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成的小网络来构建⼀个深层网络。

AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。
或者,可以想象在这个过程的早期使用全连接层。然而,如果使用了全连接层,可能会完全放弃表征的空间结构。

网络中的网络NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机。也就是使用了多个1*1的卷积核。同时他认为全连接层占据了大量的内存,所以整个网络结构中没有使用全连接层。

NIN块

image-20231023194711049

一个卷积层后跟两个全连接层。

  • 步幅为1,无填充,输出形状跟卷积层输出一样。
  • 起到全连接层的作用。

NIN网络结构
在这里插入图片描述
image-20231024090401226

  • 无全连接层

  • 交替使用NIN块和步幅为2的最大池化层

    逐步减小高宽和增大通道数

  • 最后使用全局平均池化层得到输出

    其输入通道数是类别数

此网络结构总计4层: 3mlpconv + 1global_average_pooling

优点:

  1. 提供了网络层间映射的一种新可能;
  2. 增加了网络卷积层的非线性能力。

总结:

  • NIN块使用卷积层加上个 1 × 1 1\times 1 1×1卷积,后者对每个像素增加了非线性性
  • NIN使用全局平均池化层来替代VGG和AlexNet中的全连接层,不容易过拟合,更少的参数个数

代码实现

使用CIFAR-10数据集。

maxpooling不改变通道数,只改变长和宽

model.py

import torch
from torch import nn# nin块
def nin_block(in_channels,out_channels,kernel_size,strides,padding):return nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),)# 构建网络
class NIN(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model = nn.Sequential(nin_block(3,96,kernel_size=11,strides=4,padding=0),nn.MaxPool2d(3,stride=2),nin_block(96,256,kernel_size=5,strides=1,padding=2),nn.MaxPool2d(3,stride=2),nin_block(256,384,kernel_size=3,strides=1,padding=1),nn.MaxPool2d(3,stride=2),nn.Dropout(0.5),nin_block(384,10,kernel_size=3,strides=1,padding=1),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())def forward(self,x):return self.model(x)# 验证模型正确性
if __name__ == '__main__':nin = NIN()x = torch.ones((64,3,244,244))output = nin(x)print(output)

train.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import NIN# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = NIN()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))

相关文章:

经典卷积神经网络 - NIN

网络中的网络,NIN。 AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成…...

leetcode_2558 从数量最多的堆取走礼物

1. 题意 给定一个数组,每次从中取走最大的数,返回开根号向下取整送入堆中,最后计算总和。 从数量最多的堆取走礼物 2. 题解 直接用堆模拟即可 2.1 我的代码 用了额外的空间O( n ) priority_queue会自动调用make_heap() 、pop_heap() c…...

01. 嵌入式与人工智能是如何结合的?

CPU是Arm A57的 GPU是128cuda核 一.小车跟踪的需求和设计方法 比如有一个小车跟踪的项目。 需求是:小车识别出罪犯,然后去跟踪他。方法:摄像头采集到人之后传入到开发板,内部做一下识别,然后控制小车去跟随。在人工智…...

vue3.0运行npm run dev 报错Cannot find module node:url

vue3.0运行npm run dev 报错Cannot find module 问题背景 近期用vue3.0写项目,npm init vuelatest —> npm install 都正常,npm run dev的时候报错如下: failed to load config from F:\code\testVue\vue-demo\vite.config.js error when starting…...

26. 删除排序数组中的重复项、Leetcode的Python实现

博客主页:🏆看看是李XX还是李歘歘 🏆 🌺每天分享一些包括但不限于计算机基础、算法等相关的知识点🌺 💗点关注不迷路,总有一些📖知识点📖是你想要的💗 ⛽️今…...

荣耀推送服务消息分类标准

前言 为了提升终端用户的推送体验、营造良好可持续的通知生态,荣耀推送服务将对推送消息进行分类管理。 消息分类 定义 荣耀推送服务将根据应用类型、消息内容和消息发送场景,将推送消息分成服务通讯和资讯营销两大类别。 服务通讯类,包…...

[数据结构]-二叉搜索树

前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正。 目录 一、二叉搜…...

力扣每日一题79:单词搜索

题目描述: 给定一个 m x n 二维字符网格 board 和一个字符串单词 word 。如果 word 存在于网格中,返回 true ;否则,返回 false 。 单词必须按照字母顺序,通过相邻的单元格内的字母构成,其中“相邻”单元格…...

ChatGPT如何应对用户提出的道德伦理困境?

ChatGPT在应对用户提出的道德伦理困境时,需要考虑众多复杂的因素。道德伦理问题涉及到价值观、原则、社会和文化背景,以及众多伦理理论。ChatGPT的设计和应用需要权衡各种考虑因素,以确保它不仅提供有用的信息,而且遵循伦理标准。…...

SpringBoot运行流程源码分析------阶段三(Spring Boot外化配置源码解析)

Spring Boot外化配置源码解析 外化配置简介 Spring Boot设计了非常特殊的加载指定属性文件(PropertySouce)的顺序,允许属性值合理的覆盖,属性值会以下面的优先级进行配置。home目录下的Devtool全局设置属性(~/.sprin…...

环形链表-力扣

一、题目描述 题目链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 二、题解 解题思路: 快慢指针,即慢指针一次走一步,快指针一次走两步,两个指针从链表起始位置开始运行,…...

人生岁月年华

人生很长吗?不知道。只知道高中坐在教室里,闹哄哄的很难受。也记得上班时无聊敲着代码也很难受。 可是人生也不长。你没有太多时间去试错,你没有无限的时间精力去追寻你认为的高大上。 人生是何体验呢?人生的感觉很多吧。大多数…...

电脑QQ如何录制视频文件?

听说QQ可以录制视频,还很方便,请问该如何录制呢?是需要先打开QQ才可以录制吗?还是可以直接使用快捷键进行录制呢?录制的质量又如何呢? 不要着急,既然都打开这篇文章看了,那小编今天…...

python:多波段遥感影像分离成单波段影像

作者:CSDN @ _养乐多_ 在遥感图像处理中,我们经常需要将多波段遥感影像拆分成多个单波段图像,以便进行各种分析和后续处理。本篇博客将介绍一个用Python编写的程序,该程序可以读取多波段遥感影像,将其拆分为单波段图像,并保存为单独的文件。本程序使用GDAL库来处理遥感影…...

天堂2游戏出错如何解决

运行游戏时出现以下提示:“the game may not be consistant because AGP is deactivated please activate AGP for consistancy” 这个问题的原因可能是由于您的显示卡的驱动或者主板的显示芯片组的驱动不是新开。或您虽然已经更新了您的显示卡的驱动程序&#xff0…...

『力扣刷题本』:合并两个有序链表(递归解法)

一、题目 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1: 输入:l1 [1,2,4], l2 [1,3,4] 输出:[1,1,2,3,4,4]示例 2: 输入:l1 [], l2 [] 输出&#x…...

C++设计模式_12_Singleton 单件模式

在之前的博文C57个入门知识点_44:单例的实现与理解中,已经详细介绍了单例模式,并且根据其中内容,单例模式已经可以在日常编码中被使用,本文将会再做梳理。 Singleton 单件模式可以说是最简单的设计模式,但由…...

67 内网安全-域横向smbwmi明文或hash传递

#知识点1: windows2012以上版本默认关闭wdigest,攻击者无法从内存中获取明文密码windows2012以下版本如安装KB2871997补丁,同样也会导致无法获取明文密码针对以上情况,我们提供了4种方式解决此类问题 1.利用哈希hash传递(pth,ptk等…...

面向对象(类/继承/封装/多态)详解

简介: 面向对象编程(Object-Oriented Programming,OOP)是一种广泛应用于软件开发的编程范式。它基于一系列核心概念,包括类、继承、封装和多态。在这篇详细的解释中,我们将探讨这些概念,并说明它们如何在P…...

【Python机器学习】零基础掌握GradientBoostingRegressor集成学习

如何精准预测房价? 当人们提到房价预测时,很多人可能会想到房地产经纪人或专业的评估师。但是,有没有一种更科学、更精确的方法来预测房价呢?答案是有的,这就要用到机器学习中的一种算法——梯度提升回归(Gradient Boosting Regressor)。 假设现在有一组房屋数据,包括…...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...

Axios请求超时重发机制

Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

【HTTP三个基础问题】

面试官您好!HTTP是超文本传输协议,是互联网上客户端和服务器之间传输超文本数据(比如文字、图片、音频、视频等)的核心协议,当前互联网应用最广泛的版本是HTTP1.1,它基于经典的C/S模型,也就是客…...

网站指纹识别

网站指纹识别 网站的最基本组成:服务器(操作系统)、中间件(web容器)、脚本语言、数据厍 为什么要了解这些?举个例子:发现了一个文件读取漏洞,我们需要读/etc/passwd,如…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析

Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...

LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》

这段 Python 代码是一个完整的 知识库数据库操作模块,用于对本地知识库系统中的知识库进行增删改查(CRUD)操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 📘 一、整体功能概述 该模块…...

免费PDF转图片工具

免费PDF转图片工具 一款简单易用的PDF转图片工具,可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件,也不需要在线上传文件,保护您的隐私。 工具截图 主要特点 🚀 快速转换:本地转换,无需等待上…...

Python 实现 Web 静态服务器(HTTP 协议)

目录 一、在本地启动 HTTP 服务器1. Windows 下安装 node.js1)下载安装包2)配置环境变量3)安装镜像4)node.js 的常用命令 2. 安装 http-server 服务3. 使用 http-server 开启服务1)使用 http-server2)详解 …...

c++第七天 继承与派生2

这一篇文章主要内容是 派生类构造函数与析构函数 在派生类中重写基类成员 以及多继承 第一部分:派生类构造函数与析构函数 当创建一个派生类对象时,基类成员是如何初始化的? 1.当派生类对象创建的时候,基类成员的初始化顺序 …...