经典卷积神经网络 - NIN
网络中的网络,NIN。
AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成的小网络来构建⼀个深层网络。
AlexNet和VGG对LeNet的改进主要在于如何扩大和加深这两个模块。
或者,可以想象在这个过程的早期使用全连接层。然而,如果使用了全连接层,可能会完全放弃表征的空间结构。
网络中的网络(NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机。也就是使用了多个1*1的卷积核。同时他认为全连接层占据了大量的内存,所以整个网络结构中没有使用全连接层。
NIN块

一个卷积层后跟两个全连接层。
- 步幅为1,无填充,输出形状跟卷积层输出一样。
- 起到全连接层的作用。
NIN网络结构
-
无全连接层
-
交替使用NIN块和步幅为2的最大池化层
逐步减小高宽和增大通道数
-
最后使用全局平均池化层得到输出
其输入通道数是类别数
此网络结构总计4层: 3mlpconv + 1global_average_pooling
优点:
- 提供了网络层间映射的一种新可能;
- 增加了网络卷积层的非线性能力。
总结:
- NIN块使用卷积层加上个 1 × 1 1\times 1 1×1卷积,后者对每个像素增加了非线性性
- NIN使用全局平均池化层来替代VGG和AlexNet中的全连接层,不容易过拟合,更少的参数个数
代码实现
使用CIFAR-10
数据集。
maxpooling不改变通道数,只改变长和宽
model.py
import torch
from torch import nn# nin块
def nin_block(in_channels,out_channels,kernel_size,strides,padding):return nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),)# 构建网络
class NIN(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.model = nn.Sequential(nin_block(3,96,kernel_size=11,strides=4,padding=0),nn.MaxPool2d(3,stride=2),nin_block(96,256,kernel_size=5,strides=1,padding=2),nn.MaxPool2d(3,stride=2),nin_block(256,384,kernel_size=3,strides=1,padding=1),nn.MaxPool2d(3,stride=2),nn.Dropout(0.5),nin_block(384,10,kernel_size=3,strides=1,padding=1),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())def forward(self,x):return self.model(x)# 验证模型正确性
if __name__ == '__main__':nin = NIN()x = torch.ones((64,3,244,244))output = nin(x)print(output)
train.py
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision.transforms import transforms
from model import NIN# 扫描数据次数
epochs = 3
# 分组大小
batch = 64
# 学习率
learning_rate = 0.01
# 训练次数
train_step = 0
# 测试次数
test_step = 0# 定义图像转换
transform = transforms.Compose([transforms.Resize(224),transforms.ToTensor()
])
# 读取数据
train_dataset = datasets.CIFAR10(root="./dataset",train=True,transform=transform,download=True)
test_dataset = datasets.CIFAR10(root="./dataset",train=False,transform=transform,download=True)
# 加载数据
train_dataloader = DataLoader(train_dataset,batch_size=batch,shuffle=True,num_workers=0)
test_dataloader = DataLoader(test_dataset,batch_size=batch,shuffle=True,num_workers=0)
# 数据大小
train_size = len(train_dataset)
test_size = len(test_dataset)
print("训练集大小:{}".format(train_size))
print("验证集大小:{}".format(test_size))# GPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)
# 创建网络
net = NIN()
net = net.to(device)
# 定义损失函数
loss = nn.CrossEntropyLoss()
loss = loss.to(device)
# 定义优化器
optimizer = torch.optim.SGD(net.parameters(),lr=learning_rate)writer = SummaryWriter("logs")
# 训练
for epoch in range(epochs):print("-------------------第 {} 轮训练开始-------------------".format(epoch))net.train()for data in train_dataloader:train_step = train_step + 1images,targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs,targets)optimizer.zero_grad()loss_out.backward()optimizer.step()if train_step%100==0:writer.add_scalar("Train Loss",scalar_value=loss_out.item(),global_step=train_step)print("训练次数:{},Loss:{}".format(train_step,loss_out.item()))# 测试net.eval()total_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:test_step = test_step + 1images, targets = dataimages = images.to(device)targets = targets.to(device)outputs = net(images)loss_out = loss(outputs, targets)total_loss = total_loss + loss_outaccuracy = (targets == torch.argmax(outputs,dim=1)).sum()total_accuracy = total_accuracy + accuracy# 计算精确率print(total_accuracy)accuracy_rate = total_accuracy / test_sizeprint("第 {} 轮,验证集总损失为:{}".format(epoch+1,total_loss))print("第 {} 轮,精确率为:{}".format(epoch+1,accuracy_rate))writer.add_scalar("Test Total Loss",scalar_value=total_loss,global_step=epoch+1)writer.add_scalar("Accuracy Rate",scalar_value=accuracy_rate,global_step=epoch+1)torch.save(net,"./model/net_{}.pth".format(epoch+1))print("模型net_{}.pth已保存".format(epoch+1))
相关文章:

经典卷积神经网络 - NIN
网络中的网络,NIN。 AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成…...
leetcode_2558 从数量最多的堆取走礼物
1. 题意 给定一个数组,每次从中取走最大的数,返回开根号向下取整送入堆中,最后计算总和。 从数量最多的堆取走礼物 2. 题解 直接用堆模拟即可 2.1 我的代码 用了额外的空间O( n ) priority_queue会自动调用make_heap() 、pop_heap() c…...

01. 嵌入式与人工智能是如何结合的?
CPU是Arm A57的 GPU是128cuda核 一.小车跟踪的需求和设计方法 比如有一个小车跟踪的项目。 需求是:小车识别出罪犯,然后去跟踪他。方法:摄像头采集到人之后传入到开发板,内部做一下识别,然后控制小车去跟随。在人工智…...

vue3.0运行npm run dev 报错Cannot find module node:url
vue3.0运行npm run dev 报错Cannot find module 问题背景 近期用vue3.0写项目,npm init vuelatest —> npm install 都正常,npm run dev的时候报错如下: failed to load config from F:\code\testVue\vue-demo\vite.config.js error when starting…...
26. 删除排序数组中的重复项、Leetcode的Python实现
博客主页:🏆看看是李XX还是李歘歘 🏆 🌺每天分享一些包括但不限于计算机基础、算法等相关的知识点🌺 💗点关注不迷路,总有一些📖知识点📖是你想要的💗 ⛽️今…...

荣耀推送服务消息分类标准
前言 为了提升终端用户的推送体验、营造良好可持续的通知生态,荣耀推送服务将对推送消息进行分类管理。 消息分类 定义 荣耀推送服务将根据应用类型、消息内容和消息发送场景,将推送消息分成服务通讯和资讯营销两大类别。 服务通讯类,包…...

[数据结构]-二叉搜索树
前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正。 目录 一、二叉搜…...

力扣每日一题79:单词搜索
题目描述: 给定一个 m x n 二维字符网格 board 和一个字符串单词 word 。如果 word 存在于网格中,返回 true ;否则,返回 false 。 单词必须按照字母顺序,通过相邻的单元格内的字母构成,其中“相邻”单元格…...
ChatGPT如何应对用户提出的道德伦理困境?
ChatGPT在应对用户提出的道德伦理困境时,需要考虑众多复杂的因素。道德伦理问题涉及到价值观、原则、社会和文化背景,以及众多伦理理论。ChatGPT的设计和应用需要权衡各种考虑因素,以确保它不仅提供有用的信息,而且遵循伦理标准。…...
SpringBoot运行流程源码分析------阶段三(Spring Boot外化配置源码解析)
Spring Boot外化配置源码解析 外化配置简介 Spring Boot设计了非常特殊的加载指定属性文件(PropertySouce)的顺序,允许属性值合理的覆盖,属性值会以下面的优先级进行配置。home目录下的Devtool全局设置属性(~/.sprin…...

环形链表-力扣
一、题目描述 题目链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 二、题解 解题思路: 快慢指针,即慢指针一次走一步,快指针一次走两步,两个指针从链表起始位置开始运行,…...
人生岁月年华
人生很长吗?不知道。只知道高中坐在教室里,闹哄哄的很难受。也记得上班时无聊敲着代码也很难受。 可是人生也不长。你没有太多时间去试错,你没有无限的时间精力去追寻你认为的高大上。 人生是何体验呢?人生的感觉很多吧。大多数…...

电脑QQ如何录制视频文件?
听说QQ可以录制视频,还很方便,请问该如何录制呢?是需要先打开QQ才可以录制吗?还是可以直接使用快捷键进行录制呢?录制的质量又如何呢? 不要着急,既然都打开这篇文章看了,那小编今天…...

python:多波段遥感影像分离成单波段影像
作者:CSDN @ _养乐多_ 在遥感图像处理中,我们经常需要将多波段遥感影像拆分成多个单波段图像,以便进行各种分析和后续处理。本篇博客将介绍一个用Python编写的程序,该程序可以读取多波段遥感影像,将其拆分为单波段图像,并保存为单独的文件。本程序使用GDAL库来处理遥感影…...
天堂2游戏出错如何解决
运行游戏时出现以下提示:“the game may not be consistant because AGP is deactivated please activate AGP for consistancy” 这个问题的原因可能是由于您的显示卡的驱动或者主板的显示芯片组的驱动不是新开。或您虽然已经更新了您的显示卡的驱动程序࿰…...

『力扣刷题本』:合并两个有序链表(递归解法)
一、题目 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1: 输入:l1 [1,2,4], l2 [1,3,4] 输出:[1,1,2,3,4,4]示例 2: 输入:l1 [], l2 [] 输出&#x…...

C++设计模式_12_Singleton 单件模式
在之前的博文C57个入门知识点_44:单例的实现与理解中,已经详细介绍了单例模式,并且根据其中内容,单例模式已经可以在日常编码中被使用,本文将会再做梳理。 Singleton 单件模式可以说是最简单的设计模式,但由…...

67 内网安全-域横向smbwmi明文或hash传递
#知识点1: windows2012以上版本默认关闭wdigest,攻击者无法从内存中获取明文密码windows2012以下版本如安装KB2871997补丁,同样也会导致无法获取明文密码针对以上情况,我们提供了4种方式解决此类问题 1.利用哈希hash传递(pth,ptk等…...

面向对象(类/继承/封装/多态)详解
简介: 面向对象编程(Object-Oriented Programming,OOP)是一种广泛应用于软件开发的编程范式。它基于一系列核心概念,包括类、继承、封装和多态。在这篇详细的解释中,我们将探讨这些概念,并说明它们如何在P…...
【Python机器学习】零基础掌握GradientBoostingRegressor集成学习
如何精准预测房价? 当人们提到房价预测时,很多人可能会想到房地产经纪人或专业的评估师。但是,有没有一种更科学、更精确的方法来预测房价呢?答案是有的,这就要用到机器学习中的一种算法——梯度提升回归(Gradient Boosting Regressor)。 假设现在有一组房屋数据,包括…...
云计算——弹性云计算器(ECS)
弹性云服务器:ECS 概述 云计算重构了ICT系统,云计算平台厂商推出使得厂家能够主要关注应用管理而非平台管理的云平台,包含如下主要概念。 ECS(Elastic Cloud Server):即弹性云服务器,是云计算…...

定时器任务——若依源码分析
分析util包下面的工具类schedule utils: ScheduleUtils 是若依中用于与 Quartz 框架交互的工具类,封装了定时任务的 创建、更新、暂停、删除等核心逻辑。 createScheduleJob createScheduleJob 用于将任务注册到 Quartz,先构建任务的 JobD…...
测试markdown--肇兴
day1: 1、去程:7:04 --11:32高铁 高铁右转上售票大厅2楼,穿过候车厅下一楼,上大巴车 ¥10/人 **2、到达:**12点多到达寨子,买门票,美团/抖音:¥78人 3、中饭&a…...
【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验
系列回顾: 在上一篇中,我们成功地为应用集成了数据库,并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了!但是,如果你仔细审视那些 API,会发现它们还很“粗糙”:有…...

分布式增量爬虫实现方案
之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面,避免重复抓取,以节省资源和时间。 在分布式环境下,增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路:将增量判…...
A2A JS SDK 完整教程:快速入门指南
目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库ÿ…...
NPOI Excel用OLE对象的形式插入文件附件以及插入图片
static void Main(string[] args) {XlsWithObjData();Console.WriteLine("输出完成"); }static void XlsWithObjData() {// 创建工作簿和单元格,只有HSSFWorkbook,XSSFWorkbook不可以HSSFWorkbook workbook new HSSFWorkbook();HSSFSheet sheet (HSSFSheet)workboo…...

从 GreenPlum 到镜舟数据库:杭银消费金融湖仓一体转型实践
作者:吴岐诗,杭银消费金融大数据应用开发工程师 本文整理自杭银消费金融大数据应用开发工程师在StarRocks Summit Asia 2024的分享 引言:融合数据湖与数仓的创新之路 在数字金融时代,数据已成为金融机构的核心竞争力。杭银消费金…...
0x-3-Oracle 23 ai-sqlcl 25.1 集成安装-配置和优化
是不是受够了安装了oracle database之后sqlplus的简陋,无法删除无法上下翻页的苦恼。 可以安装readline和rlwrap插件的话,配置.bahs_profile后也能解决上下翻页这些,但是很多生产环境无法安装rpm包。 oracle提供了sqlcl免费许可,…...
用 Rust 重写 Linux 内核模块实战:迈向安全内核的新篇章
用 Rust 重写 Linux 内核模块实战:迈向安全内核的新篇章 摘要: 操作系统内核的安全性、稳定性至关重要。传统 Linux 内核模块开发长期依赖于 C 语言,受限于 C 语言本身的内存安全和并发安全问题,开发复杂模块极易引入难以…...