当前位置: 首页 > news >正文

Desnet模型详解

模型介绍

DenseNet的主要思想是密集连接,它在卷积神经网络(CNN)中引入了密集块(Dense Block),在这些块中,每个层都与前面所有层直接连接。这种设计可以让信息更快速地传播,有助于解决梯度消失的问题,同时也能够增加网络的参数共享,减少参数量,提高模型的效率和性能。

Desnet原理

DenseNet 的原理可以总结为以下几个关键点:

  1. 密集连接的块: DenseNet 将网络分成多个密集块(Dense Block)。在每个密集块内,每一层都连接到前面所有的层,不仅仅是前一层。这种连接方式使得信息能够更加快速地传播,允许网络在更早的阶段融合不同层的特征。

  2. 跳跃连接: 每一层都从前面所有的层接收特征作为输入。这些输入通过堆叠而来,从而构成了一个密集的特征图。这种跳跃连接有助于解决梯度消失问题,因为每一层都可以直接访问之前层的梯度信息,使得训练更加稳定。

  3. 特征重用性: 由于每一层都与前面所有层连接,网络可以自动地学习到更加丰富和复杂的特征表示。这样的特征重用性有助于提高网络的性能,同时减少了需要训练的参数数量。

  4. 过渡层: 在密集块之间,通常会使用过渡层(Transition Layer)来控制特征图的大小。过渡层包括一个卷积层和一个池化层,用于减小特征图的尺寸,从而减少计算量。

Desnet的结构

关于 DenseNet 的结构时,我们主要关注网络中的三个主要组成部分:密集块(Dense Block)、过渡层(Transition Layer)以及全局平均池化层。

密集块

密集块是 DenseNet 最核心的部分,由若干层组成。在密集块中,每一层都与前面所有层直接连接。这种密集连接的方式使得信息可以更充分地传递和重用。每一层的输出都是前面所有层输出的连结,这也意味着每一层的输入包括了前面所有层的特征。这种连接方式通过堆叠层的方式,构建了一个密集的特征图。

过渡层

在密集块之间,可以使用过渡层来控制特征图的大小,从而减少计算成本。过渡层由一个卷积层和一个池化层组成。卷积层用于减小通道数,从而降低特征图的维度。池化层(通常是平均池化)用于减小特征图的尺寸。这些操作有助于在保持网络性能的同时降低计算需求。

全局平均池化层

在整个 DenseNet 结构的末尾,通常会添加一个全局平均池化层。这一层的作用是将最终的特征图转换为全局汇总的特征,这对于分类任务是非常有用的。全局平均池化层计算每个通道上的平均值,将每个通道转换为一个标量,从而形成最终的预测。

DenseNet 结构的特点不仅在每个密集块内进行特征的密集连接,还在不同密集块之间使用过渡层来控制网络的尺寸和复杂度。这使得 DenseNet 能够在高度复杂的任务中表现出色,同时保持相对较少的参数。

这些在论文当中也有体现:

Desnet的优缺点比较

优点

  • 密集连接促进信息传递和特征重用,提升了网络性能。

  • 跳跃连接减少了梯度消失,有助于训练深层网络。

  • 密集连接减少参数数量,提高了模型效率。

  • 早期融合多尺度特征,增强了表征能力。

  • 在小样本情况下表现更佳,充分利用有限数据。

缺点

  • 密集连接可能导致内存需求增大。

  • 连接多导致计算量增加,训练和推理时间较长。

  • 可能因复杂性导致过拟合,需考虑正则化。

其实综合考虑,Desnet在图像识别和计算机视觉任务中仍然是一个好的选择。

Pytorch实现Desnet

import torch
import torchvision
import torch.nn as nn
import torchsummary
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
from collections import OrderedDict
from torchvision.utils import _log_api_usage_once
import torch.utils.checkpoint as cpmodel_urls = {"densenet121":"https://download.pytorch.org/models/densenet121-a639ec97.pth","densenet161":"https://download.pytorch.org/models/densenet161-8d451a50.pth","densenet169":"https://download.pytorch.org/models/densenet169-b2777c0a.pth","densenet201":"https://download.pytorch.org/models/densenet201-c1103571.pth",
}
cfgs = {"densenet121":(6, 12, 24, 16),"densenet161":(6, 12, 36, 24),"densenet169":(6, 12, 32, 32),"densenet201":(6, 12, 48, 32),
}class DenseLayer(nn.Module):def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient = False):super(DenseLayer,self).__init__()self.norm1 = nn.BatchNorm2d(num_input_features)self.relu1 = nn.ReLU(inplace=True)self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)self.relu2 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)self.drop_rate = float(drop_rate)self.memory_efficient = memory_efficientdef bn_function(self, inputs):concated_features = torch.cat(inputs, 1)bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))return bottleneck_outputdef any_requires_grad(self, input):for tensor in input:if tensor.requires_grad:return Truereturn False@torch.jit.unuseddef call_checkpoint_bottleneck(self, input):def closure(*inputs):return self.bn_function(inputs)return cp.checkpoint(closure, *input)def forward(self, input):if isinstance(input, torch.Tensor):prev_features = [input]else:prev_features = inputif self.memory_efficient and self.any_requires_grad(prev_features):if torch.jit.is_scripting():raise Exception("Memory Efficient not supported in JIT")bottleneck_output = self.call_checkpoint_bottleneck(prev_features)else:bottleneck_output = self.bn_function(prev_features)new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))if self.drop_rate > 0:new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)return new_featuresclass DenseBlock(nn.ModuleDict):def __init__(self,num_layers,num_input_features,bn_size,growth_rate,drop_rate,memory_efficient = False,):super(DenseBlock,self).__init__()for i in range(num_layers):layer = DenseLayer(num_input_features + i * growth_rate,growth_rate=growth_rate,bn_size=bn_size,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.add_module("denselayer%d" % (i + 1), layer)def forward(self, init_features):features = [init_features]for name, layer in self.items():new_features = layer(features)features.append(new_features)return torch.cat(features, 1)class Transition(nn.Sequential):"""Densenet Transition Layer:1 × 1 conv2 × 2 average pool, stride 2"""def __init__(self, num_input_features, num_output_features):super(Transition,self).__init__()self.norm = nn.BatchNorm2d(num_input_features)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)self.pool = nn.AvgPool2d(kernel_size=2, stride=2)class DenseNet(nn.Module):def __init__(self,growth_rate = 32,num_init_features = 64,block_config = None,num_classes = 1000,bn_size = 4,drop_rate = 0.,memory_efficient = False,):super(DenseNet,self).__init__()_log_api_usage_once(self)# First convolutionself.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),]))# Each denseblocknum_features = num_init_featuresfor i, num_layers in enumerate(block_config):block = DenseBlock(num_layers=num_layers,num_input_features=num_features,bn_size=bn_size,growth_rate=growth_rate,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.features.add_module("denseblock%d" % (i + 1), block)num_features = num_features + num_layers * growth_rateif i != len(block_config) - 1:trans = Transition(num_input_features=num_features, num_output_features=num_features // 2)self.features.add_module("transition%d" % (i + 1), trans)num_features = num_features // 2# Final batch normself.features.add_module("norm5", nn.BatchNorm2d(num_features))# Linear layerself.classifier = nn.Linear(num_features, num_classes)# Official init from torch repo.for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = F.relu(features, inplace=True)out = F.adaptive_avg_pool2d(out, (1, 1))out = torch.flatten(out, 1)out = self.classifier(out)return outdef densenet(growth_rate=32,num_init_features=64,num_classes=1000,mode="densenet121",pretrained=False,**kwargs):import repattern = re.compile(r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$")if mode == "densenet161":growth_rate=48num_init_features=96model = DenseNet(growth_rate, num_init_features, cfgs[mode],num_classes=num_classes, **kwargs)if pretrained:state_dict = load_state_dict_from_url(model_urls[mode], model_dir='./model', progress=True)  # 预训练模型地址for key in list(state_dict.keys()):res = pattern.match(key)if res:new_key = res.group(1) + res.group(2)state_dict[new_key] = state_dict[key]del state_dict[key]if num_classes != 1000:num_new_classes = num_classesweight = state_dict['classifier.weight']bias = state_dict['classifier.bias']weight_new = weight[:num_new_classes, :]bias_new = bias[:num_new_classes]state_dict['classifier.weight'] = weight_newstate_dict['classifier.bias'] = bias_newmodel.load_state_dict(state_dict)return modelfrom torchsummaryX import summaryif __name__ == "__main__":in_channels = 3num_classes = 10device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = densenet(growth_rate=32, num_init_features=64, num_classes=num_classes, mode="densenet121", pretrained=True)model = model.to(device)print(model)summary(model, torch.zeros((1, in_channels, 224, 224)).to(device))

相关文章:

Desnet模型详解

模型介绍 DenseNet的主要思想是密集连接,它在卷积神经网络(CNN)中引入了密集块(Dense Block),在这些块中,每个层都与前面所有层直接连接。这种设计可以让信息更快速地传播,有助于解…...

clickhouse-压测

一、数据集准备 数据集可以使用官网数据集,也可以用ssb-dbgen来准备 1.准备数据 这里最后生成表的数据行数为60亿行,数据量为300G左右 git clone https://github.com/vadimtk/ssb-dbgen.git cd ssb-dbgen/ make1.1 生成数据 # -s 指生成多少G的数据…...

AI夏令营第三期用户新增挑战赛学习笔记

1、数据可视化 1.数据探索和理解:数据可视化可以帮助我们更好地理解数据集的特征、分布和关系。通过可视化数据,我们可以发现数据中的模式、异常值、缺失值等信息,从而更好地了解数据的特点和结构。2.特征工程:数据可视化可以帮助…...

pdf转ppt软件哪个好用?推荐一个好用的pdf转ppt软件

在日常工作和学习中,我们经常会遇到需要将PDF文件转换为PPT格式的情况。PDF格式的文件通常用于展示和保留文档的原始格式,而PPT格式则更适合用于演示和展示。为了满足这一需求,许多软件提供了PDF转PPT的功能,使我们能够方便地将PD…...

Linux 内核函数kallsyms_lookup_name

文章目录 一、API使用二、源码解析2.1 kallsyms_lookup_name2.2 kallsyms_expand_symbol2.3 kallsyms_sym_address2.3.1 x86_642.3.2 arm642.3.3 CONFIG_KALLSYMS_ABSOLUTE_PERCPU 参考资料 一、API使用 kallsyms_lookup_name 是一个内核函数,用于通过符号名称查找…...

强化学习在游戏AI中的应用与挑战

文章目录 1. 强化学习简介2. 强化学习在游戏AI中的应用2.1 游戏智能体训练2.2 游戏AI决策2.3 游戏测试和优化 3. 强化学习在游戏AI中的挑战3.1 探索与利用的平衡3.2 多样性的应对 4. 解决方法与展望4.1 深度强化学习4.2 奖励设计和函数逼近 5. 总结 🎉欢迎来到AIGC人…...

6 Python的异常处理

概述 在上一节,我们介绍了Python的面向对象编程,包括:类的定义、类的使用、类变量、实例变量、实例方法、类方法、静态方法、类的运算符重载、继承等内容。在这一节中,我们将介绍Python的异常处理。异常是指程序在运行过程中出现的…...

【跨语言通讯】

传统的跨语言通讯方案: 基于SOAP消息格式的WebService 基于JSON消息格式的RESTful 服务 主要弊端: XML体积太大,解析性能极差 JSON体积相对较小,解析相对较快,但表达能力较弱 如今比较流行的跨语言通讯方案&…...

Android 基础知识

一、Activity 1、onSaveInstanceState(),onRestoreInstanceState的调用时机 onSaveInstanceState 调用时机 从最近应用中选择运行其他程序时 但用户按下Home键时 屏幕方向切换时 按下电源案件时 从当前activity启动一个新的activity时 onRestorInstanceState调用时机 只…...

Linux常用命令_帮助命令、用户管理命令、压缩解压命令

文章目录 1. 帮助命令1.1 帮助命令:man1.2 帮助命令:help1.3 其他帮助命令 2. 用户管理命令2.1 用户管理命令: useradd2.2 用户管理命令: passwd2.3 用户管理命令: who2.4 用户管理命令: w 3. 压缩解压命令3.1 压缩解压命令: gzip3.2 压缩解压命令: gunzip3.3 压缩解压命令: ta…...

解决 KylinOS “Could not get lock /var/lib/dpkg/lock”错误

最近,我遇到了 “Could not get lock /var/lib/dpkg/lock”的错误,我既不能安装任何软件包,也不能更新系统。此错误也与“Could not get lock /var/lib/apt/lists/lock”错误密切相关。以下是 Ubuntu 20.04 上的一些样本输出。 Reading package lists… Done E: Could not…...

PHP pdf 自动填写表单

一、下载github上的项目,地址 二、下载pdftk 地址 // 转化PDF模板 pdftk modele.pdf output modele2.pdf# 填充pdf文件中的表单 require(fpdm.php); $fields array(name > My name,address > My address,city > My city,phone > My phone nu…...

Win2016Server绑定多网卡实现负载均衡

一、服务器端: 1、输入ncpa.cpl打开网络连接,对要绑定的网卡勾掉IPV4,IPV4地址选择自动 2、输入servermanager.exe,打开服务器管理器 3、在 [本地服务器] 中,点后边的 “已禁用” ,在 [适配器和接口] 小窗口…...

微软宣布在 Excel 中使用 Python:结合了 Python 的强大功能和 Excel 的灵活性。

文章目录 Excel 中的 Python 有何独特之处?1. Excel 中的 Python 是为分析师构建的。高级可视化机器学习、预测分析和预测数据清理 2. Excel 中的 Python 通过 Anaconda 展示了最好的 Python 分析功能。3. Excel 中的 Python 在 Microsoft 云上安全运行,…...

学习心得03:OpenCV

数学真是不可思议,不管什么东西,都能用数学来处理。OpenCV以前也接触过,这次是系统学习一下。 颜色模型 RGB,YUV,HSV,Lab,GRAY 颜色转换cvtColor()/convertTo(),通道分离split()&…...

ubuntu学习(五)----读取文件以及光标的移动

1、读取文件函数原型介绍 ssize_t read(int fd,void*buf,size_t count) 参数说明: fd: 是文件描述符 buf:为读出数据的缓冲区; count: 为每次读取的字节数(是请求读取的字节数,读上来的数据保存在缓冲区buf中,同时文…...

Python 数据分析——matplotlib 快速绘图

matplotlib采用面向对象的技术来实现,因此组成图表的各个元素都是对象,在编写较大的应用程序时通过面向对象的方式使用matplotlib将更加有效。但是使用这种面向对象的调用接口进行绘图比较烦琐,因此matplotlib还提供了快速绘图的pyplot模块。…...

uniapp小程序位置信息配置

uniapp 小程序获取当前位置信息报错 报错信息: getLocation:fail the api need to be declared in the requiredPrivateInfos field in app.json/ext.json 需要在manifest.json配置文件中进行配置:...

《基于 Vue 组件库 的 Webpack5 配置》1.模式 Mode 和 vue-loader

一定要配置 模式 Mode,这里有个小知识点,环境变量 process.env.NODE_ENV module.exports {mode: production,// process.env.NODE_ENV 或 development, }一定要配置 vue-loader Vue Loader v15 现在需要配合一个 webpack 插件才能正确使用; …...

01.sqlite3学习——数据库概述

目录 重点概述总结 数据库标准介绍 什么是数据库? 数据库是如何存储数据的? 数据库是如何管理数据的? 数据库系统结构 常见关系型数据库管理系统 关系型数据库相关知识点 数据库与文件存储数据对比 重点概述总结 数据库可以理解为操…...

测试微信模版消息推送

进入“开发接口管理”--“公众平台测试账号”,无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息: 关注测试号:扫二维码关注测试号。 发送模版消息: import requests da…...

Linux 文件类型,目录与路径,文件与目录管理

文件类型 后面的字符表示文件类型标志 普通文件:-(纯文本文件,二进制文件,数据格式文件) 如文本文件、图片、程序文件等。 目录文件:d(directory) 用来存放其他文件或子目录。 设备…...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括:采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中,设置任务排序规则尤其重要,因为它让看板视觉上直观地体…...

1688商品列表API与其他数据源的对接思路

将1688商品列表API与其他数据源对接时,需结合业务场景设计数据流转链路,重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点: 一、核心对接场景与目标 商品数据同步 场景:将1688商品信息…...

【磁盘】每天掌握一个Linux命令 - iostat

目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...

学校招生小程序源码介绍

基于ThinkPHPFastAdminUniApp开发的学校招生小程序源码,专为学校招生场景量身打造,功能实用且操作便捷。 从技术架构来看,ThinkPHP提供稳定可靠的后台服务,FastAdmin加速开发流程,UniApp则保障小程序在多端有良好的兼…...

生成 Git SSH 证书

🔑 1. ​​生成 SSH 密钥对​​ 在终端(Windows 使用 Git Bash,Mac/Linux 使用 Terminal)执行命令: ssh-keygen -t rsa -b 4096 -C "your_emailexample.com" ​​参数说明​​: -t rsa&#x…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言:为什么 Eureka 依然是存量系统的核心? 尽管 Nacos 等新注册中心崛起,但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制,是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

今日科技热点速览

🔥 今日科技热点速览 🎮 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售,主打更强图形性能与沉浸式体验,支持多模态交互,受到全球玩家热捧 。 🤖 人工智能持续突破 DeepSeek-R1&…...