经典卷积神经网络 - NIN
网络中的网络,NIN。
AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成的小网络来构建⼀个深层网络。
AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。
或者,可以想象在这个过程的早期使用全连接层。然而,如果使用了全连接层,可能会完全放弃表征的空间结构。
网络中的网络(NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机。也就是使用了多个1*1的卷积核。同时他认为全连接层占据了大量的内存,所以整个网络结构中没有使用全连接层。
NIN块
一个卷积层后跟两个全连接层。
- 步幅为1,无填充,输出形状跟卷积层输出一样。
- 起到全连接层的作用。
NIN网络结构


-
无全连接层
-
交替使用NIN块和步幅为2的最大池化层
逐步减小高宽和增大通道数
-
最后使用全局平均池化层得到输出
其输入通道数是类别数
此网络结构总计4层: 3mlpconv + 1global_average_pooling
优点:
- 提供了网络层间映射的一种新可能;
- 增加了网络卷积层的非线性能力。
总结:
- NIN块使用卷积层加上个 1 × 1 1\times 1 1×1卷积,后者对每个像素增加了非线性性
- NIN使用全局平均池化层来替代VGG和AlexNet中的全连接层,不容易过拟合,更少的参数个数
代码实现
使用CIFAR-10数据集。
maxpooling不改变通道数,只改变长和宽
model.py
import torch
from torch import nn# nin块
def nin_block(in_channels,out_channels,kernel_size,strides,padding):return nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),)# 构建网络
class NIN(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model = nn.Sequential(nin_block(3,96,kernel_size=11,strides=4,padding=0),nn.MaxPool2d(3,stride=2),nin_block(96,256,kernel_size=5,strides=1,padding=2),nn.MaxPool2d(3,stride=2),nin_block(256,384,kernel_size=3,strides=1,padding=1),nn.MaxPool2d(3,stride=2),nn.Dropout(0.5),nin_block(384,10,kernel_size=3,strides=1,padding=1),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())def forward(self,x):return self.model(x)# 验证模型正确性
if __name__ == '__main__':nin = NIN()x = torch.ones((64,3,244,244))output = nin(x)print(output)
train.py
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import NIN# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = NIN()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))
相关文章:
经典卷积神经网络 - NIN
网络中的网络,NIN。 AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成…...
leetcode_2558 从数量最多的堆取走礼物
1. 题意 给定一个数组,每次从中取走最大的数,返回开根号向下取整送入堆中,最后计算总和。 从数量最多的堆取走礼物 2. 题解 直接用堆模拟即可 2.1 我的代码 用了额外的空间O( n ) priority_queue会自动调用make_heap() 、pop_heap() c…...
01. 嵌入式与人工智能是如何结合的?
CPU是Arm A57的 GPU是128cuda核 一.小车跟踪的需求和设计方法 比如有一个小车跟踪的项目。 需求是:小车识别出罪犯,然后去跟踪他。方法:摄像头采集到人之后传入到开发板,内部做一下识别,然后控制小车去跟随。在人工智…...
vue3.0运行npm run dev 报错Cannot find module node:url
vue3.0运行npm run dev 报错Cannot find module 问题背景 近期用vue3.0写项目,npm init vuelatest —> npm install 都正常,npm run dev的时候报错如下: failed to load config from F:\code\testVue\vue-demo\vite.config.js error when starting…...
26. 删除排序数组中的重复项、Leetcode的Python实现
博客主页:🏆看看是李XX还是李歘歘 🏆 🌺每天分享一些包括但不限于计算机基础、算法等相关的知识点🌺 💗点关注不迷路,总有一些📖知识点📖是你想要的💗 ⛽️今…...
荣耀推送服务消息分类标准
前言 为了提升终端用户的推送体验、营造良好可持续的通知生态,荣耀推送服务将对推送消息进行分类管理。 消息分类 定义 荣耀推送服务将根据应用类型、消息内容和消息发送场景,将推送消息分成服务通讯和资讯营销两大类别。 服务通讯类,包…...
[数据结构]-二叉搜索树
前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正。 目录 一、二叉搜…...
力扣每日一题79:单词搜索
题目描述: 给定一个 m x n 二维字符网格 board 和一个字符串单词 word 。如果 word 存在于网格中,返回 true ;否则,返回 false 。 单词必须按照字母顺序,通过相邻的单元格内的字母构成,其中“相邻”单元格…...
ChatGPT如何应对用户提出的道德伦理困境?
ChatGPT在应对用户提出的道德伦理困境时,需要考虑众多复杂的因素。道德伦理问题涉及到价值观、原则、社会和文化背景,以及众多伦理理论。ChatGPT的设计和应用需要权衡各种考虑因素,以确保它不仅提供有用的信息,而且遵循伦理标准。…...
SpringBoot运行流程源码分析------阶段三(Spring Boot外化配置源码解析)
Spring Boot外化配置源码解析 外化配置简介 Spring Boot设计了非常特殊的加载指定属性文件(PropertySouce)的顺序,允许属性值合理的覆盖,属性值会以下面的优先级进行配置。home目录下的Devtool全局设置属性(~/.sprin…...
环形链表-力扣
一、题目描述 题目链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 二、题解 解题思路: 快慢指针,即慢指针一次走一步,快指针一次走两步,两个指针从链表起始位置开始运行,…...
人生岁月年华
人生很长吗?不知道。只知道高中坐在教室里,闹哄哄的很难受。也记得上班时无聊敲着代码也很难受。 可是人生也不长。你没有太多时间去试错,你没有无限的时间精力去追寻你认为的高大上。 人生是何体验呢?人生的感觉很多吧。大多数…...
电脑QQ如何录制视频文件?
听说QQ可以录制视频,还很方便,请问该如何录制呢?是需要先打开QQ才可以录制吗?还是可以直接使用快捷键进行录制呢?录制的质量又如何呢? 不要着急,既然都打开这篇文章看了,那小编今天…...
python:多波段遥感影像分离成单波段影像
作者:CSDN @ _养乐多_ 在遥感图像处理中,我们经常需要将多波段遥感影像拆分成多个单波段图像,以便进行各种分析和后续处理。本篇博客将介绍一个用Python编写的程序,该程序可以读取多波段遥感影像,将其拆分为单波段图像,并保存为单独的文件。本程序使用GDAL库来处理遥感影…...
天堂2游戏出错如何解决
运行游戏时出现以下提示:“the game may not be consistant because AGP is deactivated please activate AGP for consistancy” 这个问题的原因可能是由于您的显示卡的驱动或者主板的显示芯片组的驱动不是新开。或您虽然已经更新了您的显示卡的驱动程序࿰…...
『力扣刷题本』:合并两个有序链表(递归解法)
一、题目 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1: 输入:l1 [1,2,4], l2 [1,3,4] 输出:[1,1,2,3,4,4]示例 2: 输入:l1 [], l2 [] 输出&#x…...
C++设计模式_12_Singleton 单件模式
在之前的博文C57个入门知识点_44:单例的实现与理解中,已经详细介绍了单例模式,并且根据其中内容,单例模式已经可以在日常编码中被使用,本文将会再做梳理。 Singleton 单件模式可以说是最简单的设计模式,但由…...
67 内网安全-域横向smbwmi明文或hash传递
#知识点1: windows2012以上版本默认关闭wdigest,攻击者无法从内存中获取明文密码windows2012以下版本如安装KB2871997补丁,同样也会导致无法获取明文密码针对以上情况,我们提供了4种方式解决此类问题 1.利用哈希hash传递(pth,ptk等…...
面向对象(类/继承/封装/多态)详解
简介: 面向对象编程(Object-Oriented Programming,OOP)是一种广泛应用于软件开发的编程范式。它基于一系列核心概念,包括类、继承、封装和多态。在这篇详细的解释中,我们将探讨这些概念,并说明它们如何在P…...
【Python机器学习】零基础掌握GradientBoostingRegressor集成学习
如何精准预测房价? 当人们提到房价预测时,很多人可能会想到房地产经纪人或专业的评估师。但是,有没有一种更科学、更精确的方法来预测房价呢?答案是有的,这就要用到机器学习中的一种算法——梯度提升回归(Gradient Boosting Regressor)。 假设现在有一组房屋数据,包括…...
SpringBoot-17-MyBatis动态SQL标签之常用标签
文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...
idea大量爆红问题解决
问题描述 在学习和工作中,idea是程序员不可缺少的一个工具,但是突然在有些时候就会出现大量爆红的问题,发现无法跳转,无论是关机重启或者是替换root都无法解决 就是如上所展示的问题,但是程序依然可以启动。 问题解决…...
React 第五十五节 Router 中 useAsyncError的使用详解
前言 useAsyncError 是 React Router v6.4 引入的一个钩子,用于处理异步操作(如数据加载)中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误:捕获在 loader 或 action 中发生的异步错误替…...
大话软工笔记—需求分析概述
需求分析,就是要对需求调研收集到的资料信息逐个地进行拆分、研究,从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要,后续设计的依据主要来自于需求分析的成果,包括: 项目的目的…...
可靠性+灵活性:电力载波技术在楼宇自控中的核心价值
可靠性灵活性:电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中,电力载波技术(PLC)凭借其独特的优势,正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据,无需额外布…...
使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装
以下是基于 vant-ui(适配 Vue2 版本 )实现截图中照片上传预览、删除功能,并封装成可复用组件的完整代码,包含样式和逻辑实现,可直接在 Vue2 项目中使用: 1. 封装的图片上传组件 ImageUploader.vue <te…...
Qt Http Server模块功能及架构
Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...
Python爬虫(二):爬虫完整流程
爬虫完整流程详解(7大核心步骤实战技巧) 一、爬虫完整工作流程 以下是爬虫开发的完整流程,我将结合具体技术点和实战经验展开说明: 1. 目标分析与前期准备 网站技术分析: 使用浏览器开发者工具(F12&…...
【Go】3、Go语言进阶与依赖管理
前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课,做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程,它的核心机制是 Goroutine 协程、Channel 通道,并基于CSP(Communicating Sequential Processes࿰…...
OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别
OpenPrompt 和直接对提示词的嵌入向量进行训练有什么区别 直接训练提示词嵌入向量的核心区别 您提到的代码: prompt_embedding = initial_embedding.clone().requires_grad_(True) optimizer = torch.optim.Adam([prompt_embedding...
