Desnet模型详解
模型介绍
DenseNet的主要思想是密集连接,它在卷积神经网络(CNN)中引入了密集块(Dense Block),在这些块中,每个层都与前面所有层直接连接。这种设计可以让信息更快速地传播,有助于解决梯度消失的问题,同时也能够增加网络的参数共享,减少参数量,提高模型的效率和性能。
Desnet原理
DenseNet 的原理可以总结为以下几个关键点:
-
密集连接的块: DenseNet 将网络分成多个密集块(Dense Block)。在每个密集块内,每一层都连接到前面所有的层,不仅仅是前一层。这种连接方式使得信息能够更加快速地传播,允许网络在更早的阶段融合不同层的特征。
-
跳跃连接: 每一层都从前面所有的层接收特征作为输入。这些输入通过堆叠而来,从而构成了一个密集的特征图。这种跳跃连接有助于解决梯度消失问题,因为每一层都可以直接访问之前层的梯度信息,使得训练更加稳定。
-
特征重用性: 由于每一层都与前面所有层连接,网络可以自动地学习到更加丰富和复杂的特征表示。这样的特征重用性有助于提高网络的性能,同时减少了需要训练的参数数量。
-
过渡层: 在密集块之间,通常会使用过渡层(Transition Layer)来控制特征图的大小。过渡层包括一个卷积层和一个池化层,用于减小特征图的尺寸,从而减少计算量。
Desnet的结构
关于 DenseNet 的结构时,我们主要关注网络中的三个主要组成部分:密集块(Dense Block)、过渡层(Transition Layer)以及全局平均池化层。
密集块
密集块是 DenseNet 最核心的部分,由若干层组成。在密集块中,每一层都与前面所有层直接连接。这种密集连接的方式使得信息可以更充分地传递和重用。每一层的输出都是前面所有层输出的连结,这也意味着每一层的输入包括了前面所有层的特征。这种连接方式通过堆叠层的方式,构建了一个密集的特征图。
过渡层
在密集块之间,可以使用过渡层来控制特征图的大小,从而减少计算成本。过渡层由一个卷积层和一个池化层组成。卷积层用于减小通道数,从而降低特征图的维度。池化层(通常是平均池化)用于减小特征图的尺寸。这些操作有助于在保持网络性能的同时降低计算需求。
全局平均池化层
在整个 DenseNet 结构的末尾,通常会添加一个全局平均池化层。这一层的作用是将最终的特征图转换为全局汇总的特征,这对于分类任务是非常有用的。全局平均池化层计算每个通道上的平均值,将每个通道转换为一个标量,从而形成最终的预测。
DenseNet 结构的特点不仅在每个密集块内进行特征的密集连接,还在不同密集块之间使用过渡层来控制网络的尺寸和复杂度。这使得 DenseNet 能够在高度复杂的任务中表现出色,同时保持相对较少的参数。
这些在论文当中也有体现:
Desnet的优缺点比较
优点
-
密集连接促进信息传递和特征重用,提升了网络性能。
-
跳跃连接减少了梯度消失,有助于训练深层网络。
-
密集连接减少参数数量,提高了模型效率。
-
早期融合多尺度特征,增强了表征能力。
-
在小样本情况下表现更佳,充分利用有限数据。
缺点
-
密集连接可能导致内存需求增大。
-
连接多导致计算量增加,训练和推理时间较长。
-
可能因复杂性导致过拟合,需考虑正则化。
其实综合考虑,Desnet在图像识别和计算机视觉任务中仍然是一个好的选择。
Pytorch实现Desnet
import torch
import torchvision
import torch.nn as nn
import torchsummary
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
from collections import OrderedDict
from torchvision.utils import _log_api_usage_once
import torch.utils.checkpoint as cpmodel_urls = {"densenet121":"https://download.pytorch.org/models/densenet121-a639ec97.pth","densenet161":"https://download.pytorch.org/models/densenet161-8d451a50.pth","densenet169":"https://download.pytorch.org/models/densenet169-b2777c0a.pth","densenet201":"https://download.pytorch.org/models/densenet201-c1103571.pth",
}
cfgs = {"densenet121":(6, 12, 24, 16),"densenet161":(6, 12, 36, 24),"densenet169":(6, 12, 32, 32),"densenet201":(6, 12, 48, 32),
}class DenseLayer(nn.Module):def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient = False):super(DenseLayer,self).__init__()self.norm1 = nn.BatchNorm2d(num_input_features)self.relu1 = nn.ReLU(inplace=True)self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)self.relu2 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)self.drop_rate = float(drop_rate)self.memory_efficient = memory_efficientdef bn_function(self, inputs):concated_features = torch.cat(inputs, 1)bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))return bottleneck_outputdef any_requires_grad(self, input):for tensor in input:if tensor.requires_grad:return Truereturn False@torch.jit.unuseddef call_checkpoint_bottleneck(self, input):def closure(*inputs):return self.bn_function(inputs)return cp.checkpoint(closure, *input)def forward(self, input):if isinstance(input, torch.Tensor):prev_features = [input]else:prev_features = inputif self.memory_efficient and self.any_requires_grad(prev_features):if torch.jit.is_scripting():raise Exception("Memory Efficient not supported in JIT")bottleneck_output = self.call_checkpoint_bottleneck(prev_features)else:bottleneck_output = self.bn_function(prev_features)new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))if self.drop_rate > 0:new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)return new_featuresclass DenseBlock(nn.ModuleDict):def __init__(self,num_layers,num_input_features,bn_size,growth_rate,drop_rate,memory_efficient = False,):super(DenseBlock,self).__init__()for i in range(num_layers):layer = DenseLayer(num_input_features + i * growth_rate,growth_rate=growth_rate,bn_size=bn_size,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.add_module("denselayer%d" % (i + 1), layer)def forward(self, init_features):features = [init_features]for name, layer in self.items():new_features = layer(features)features.append(new_features)return torch.cat(features, 1)class Transition(nn.Sequential):"""Densenet Transition Layer:1 × 1 conv2 × 2 average pool, stride 2"""def __init__(self, num_input_features, num_output_features):super(Transition,self).__init__()self.norm = nn.BatchNorm2d(num_input_features)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)self.pool = nn.AvgPool2d(kernel_size=2, stride=2)class DenseNet(nn.Module):def __init__(self,growth_rate = 32,num_init_features = 64,block_config = None,num_classes = 1000,bn_size = 4,drop_rate = 0.,memory_efficient = False,):super(DenseNet,self).__init__()_log_api_usage_once(self)# First convolutionself.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),]))# Each denseblocknum_features = num_init_featuresfor i, num_layers in enumerate(block_config):block = DenseBlock(num_layers=num_layers,num_input_features=num_features,bn_size=bn_size,growth_rate=growth_rate,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.features.add_module("denseblock%d" % (i + 1), block)num_features = num_features + num_layers * growth_rateif i != len(block_config) - 1:trans = Transition(num_input_features=num_features, num_output_features=num_features // 2)self.features.add_module("transition%d" % (i + 1), trans)num_features = num_features // 2# Final batch normself.features.add_module("norm5", nn.BatchNorm2d(num_features))# Linear layerself.classifier = nn.Linear(num_features, num_classes)# Official init from torch repo.for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = F.relu(features, inplace=True)out = F.adaptive_avg_pool2d(out, (1, 1))out = torch.flatten(out, 1)out = self.classifier(out)return outdef densenet(growth_rate=32,num_init_features=64,num_classes=1000,mode="densenet121",pretrained=False,**kwargs):import repattern = re.compile(r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$")if mode == "densenet161":growth_rate=48num_init_features=96model = DenseNet(growth_rate, num_init_features, cfgs[mode],num_classes=num_classes, **kwargs)if pretrained:state_dict = load_state_dict_from_url(model_urls[mode], model_dir='./model', progress=True) # 预训练模型地址for key in list(state_dict.keys()):res = pattern.match(key)if res:new_key = res.group(1) + res.group(2)state_dict[new_key] = state_dict[key]del state_dict[key]if num_classes != 1000:num_new_classes = num_classesweight = state_dict['classifier.weight']bias = state_dict['classifier.bias']weight_new = weight[:num_new_classes, :]bias_new = bias[:num_new_classes]state_dict['classifier.weight'] = weight_newstate_dict['classifier.bias'] = bias_newmodel.load_state_dict(state_dict)return modelfrom torchsummaryX import summaryif __name__ == "__main__":in_channels = 3num_classes = 10device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = densenet(growth_rate=32, num_init_features=64, num_classes=num_classes, mode="densenet121", pretrained=True)model = model.to(device)print(model)summary(model, torch.zeros((1, in_channels, 224, 224)).to(device))
相关文章:

Desnet模型详解
模型介绍 DenseNet的主要思想是密集连接,它在卷积神经网络(CNN)中引入了密集块(Dense Block),在这些块中,每个层都与前面所有层直接连接。这种设计可以让信息更快速地传播,有助于解…...

clickhouse-压测
一、数据集准备 数据集可以使用官网数据集,也可以用ssb-dbgen来准备 1.准备数据 这里最后生成表的数据行数为60亿行,数据量为300G左右 git clone https://github.com/vadimtk/ssb-dbgen.git cd ssb-dbgen/ make1.1 生成数据 # -s 指生成多少G的数据…...

AI夏令营第三期用户新增挑战赛学习笔记
1、数据可视化 1.数据探索和理解:数据可视化可以帮助我们更好地理解数据集的特征、分布和关系。通过可视化数据,我们可以发现数据中的模式、异常值、缺失值等信息,从而更好地了解数据的特点和结构。2.特征工程:数据可视化可以帮助…...

pdf转ppt软件哪个好用?推荐一个好用的pdf转ppt软件
在日常工作和学习中,我们经常会遇到需要将PDF文件转换为PPT格式的情况。PDF格式的文件通常用于展示和保留文档的原始格式,而PPT格式则更适合用于演示和展示。为了满足这一需求,许多软件提供了PDF转PPT的功能,使我们能够方便地将PD…...
Linux 内核函数kallsyms_lookup_name
文章目录 一、API使用二、源码解析2.1 kallsyms_lookup_name2.2 kallsyms_expand_symbol2.3 kallsyms_sym_address2.3.1 x86_642.3.2 arm642.3.3 CONFIG_KALLSYMS_ABSOLUTE_PERCPU 参考资料 一、API使用 kallsyms_lookup_name 是一个内核函数,用于通过符号名称查找…...

强化学习在游戏AI中的应用与挑战
文章目录 1. 强化学习简介2. 强化学习在游戏AI中的应用2.1 游戏智能体训练2.2 游戏AI决策2.3 游戏测试和优化 3. 强化学习在游戏AI中的挑战3.1 探索与利用的平衡3.2 多样性的应对 4. 解决方法与展望4.1 深度强化学习4.2 奖励设计和函数逼近 5. 总结 🎉欢迎来到AIGC人…...
6 Python的异常处理
概述 在上一节,我们介绍了Python的面向对象编程,包括:类的定义、类的使用、类变量、实例变量、实例方法、类方法、静态方法、类的运算符重载、继承等内容。在这一节中,我们将介绍Python的异常处理。异常是指程序在运行过程中出现的…...
【跨语言通讯】
传统的跨语言通讯方案: 基于SOAP消息格式的WebService 基于JSON消息格式的RESTful 服务 主要弊端: XML体积太大,解析性能极差 JSON体积相对较小,解析相对较快,但表达能力较弱 如今比较流行的跨语言通讯方案&…...

Android 基础知识
一、Activity 1、onSaveInstanceState(),onRestoreInstanceState的调用时机 onSaveInstanceState 调用时机 从最近应用中选择运行其他程序时 但用户按下Home键时 屏幕方向切换时 按下电源案件时 从当前activity启动一个新的activity时 onRestorInstanceState调用时机 只…...

Linux常用命令_帮助命令、用户管理命令、压缩解压命令
文章目录 1. 帮助命令1.1 帮助命令:man1.2 帮助命令:help1.3 其他帮助命令 2. 用户管理命令2.1 用户管理命令: useradd2.2 用户管理命令: passwd2.3 用户管理命令: who2.4 用户管理命令: w 3. 压缩解压命令3.1 压缩解压命令: gzip3.2 压缩解压命令: gunzip3.3 压缩解压命令: ta…...
解决 KylinOS “Could not get lock /var/lib/dpkg/lock”错误
最近,我遇到了 “Could not get lock /var/lib/dpkg/lock”的错误,我既不能安装任何软件包,也不能更新系统。此错误也与“Could not get lock /var/lib/apt/lists/lock”错误密切相关。以下是 Ubuntu 20.04 上的一些样本输出。 Reading package lists… Done E: Could not…...
PHP pdf 自动填写表单
一、下载github上的项目,地址 二、下载pdftk 地址 // 转化PDF模板 pdftk modele.pdf output modele2.pdf# 填充pdf文件中的表单 require(fpdm.php); $fields array(name > My name,address > My address,city > My city,phone > My phone nu…...
Win2016Server绑定多网卡实现负载均衡
一、服务器端: 1、输入ncpa.cpl打开网络连接,对要绑定的网卡勾掉IPV4,IPV4地址选择自动 2、输入servermanager.exe,打开服务器管理器 3、在 [本地服务器] 中,点后边的 “已禁用” ,在 [适配器和接口] 小窗口…...

微软宣布在 Excel 中使用 Python:结合了 Python 的强大功能和 Excel 的灵活性。
文章目录 Excel 中的 Python 有何独特之处?1. Excel 中的 Python 是为分析师构建的。高级可视化机器学习、预测分析和预测数据清理 2. Excel 中的 Python 通过 Anaconda 展示了最好的 Python 分析功能。3. Excel 中的 Python 在 Microsoft 云上安全运行,…...
学习心得03:OpenCV
数学真是不可思议,不管什么东西,都能用数学来处理。OpenCV以前也接触过,这次是系统学习一下。 颜色模型 RGB,YUV,HSV,Lab,GRAY 颜色转换cvtColor()/convertTo(),通道分离split()&…...

ubuntu学习(五)----读取文件以及光标的移动
1、读取文件函数原型介绍 ssize_t read(int fd,void*buf,size_t count) 参数说明: fd: 是文件描述符 buf:为读出数据的缓冲区; count: 为每次读取的字节数(是请求读取的字节数,读上来的数据保存在缓冲区buf中,同时文…...

Python 数据分析——matplotlib 快速绘图
matplotlib采用面向对象的技术来实现,因此组成图表的各个元素都是对象,在编写较大的应用程序时通过面向对象的方式使用matplotlib将更加有效。但是使用这种面向对象的调用接口进行绘图比较烦琐,因此matplotlib还提供了快速绘图的pyplot模块。…...

uniapp小程序位置信息配置
uniapp 小程序获取当前位置信息报错 报错信息: getLocation:fail the api need to be declared in the requiredPrivateInfos field in app.json/ext.json 需要在manifest.json配置文件中进行配置:...
《基于 Vue 组件库 的 Webpack5 配置》1.模式 Mode 和 vue-loader
一定要配置 模式 Mode,这里有个小知识点,环境变量 process.env.NODE_ENV module.exports {mode: production,// process.env.NODE_ENV 或 development, }一定要配置 vue-loader Vue Loader v15 现在需要配合一个 webpack 插件才能正确使用; …...

01.sqlite3学习——数据库概述
目录 重点概述总结 数据库标准介绍 什么是数据库? 数据库是如何存储数据的? 数据库是如何管理数据的? 数据库系统结构 常见关系型数据库管理系统 关系型数据库相关知识点 数据库与文件存储数据对比 重点概述总结 数据库可以理解为操…...

XML Group端口详解
在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...
模型参数、模型存储精度、参数与显存
模型参数量衡量单位 M:百万(Million) B:十亿(Billion) 1 B 1000 M 1B 1000M 1B1000M 参数存储精度 模型参数是固定的,但是一个参数所表示多少字节不一定,需要看这个参数以什么…...
在rocky linux 9.5上在线安装 docker
前面是指南,后面是日志 sudo dnf config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo sudo dnf install docker-ce docker-ce-cli containerd.io -y docker version sudo systemctl start docker sudo systemctl status docker …...

【第二十一章 SDIO接口(SDIO)】
第二十一章 SDIO接口 目录 第二十一章 SDIO接口(SDIO) 1 SDIO 主要功能 2 SDIO 总线拓扑 3 SDIO 功能描述 3.1 SDIO 适配器 3.2 SDIOAHB 接口 4 卡功能描述 4.1 卡识别模式 4.2 卡复位 4.3 操作电压范围确认 4.4 卡识别过程 4.5 写数据块 4.6 读数据块 4.7 数据流…...

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例
文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...

对WWDC 2025 Keynote 内容的预测
借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...

1.3 VSCode安装与环境配置
进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件,然后打开终端,进入下载文件夹,键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...
使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装
以下是基于 vant-ui(适配 Vue2 版本 )实现截图中照片上传预览、删除功能,并封装成可复用组件的完整代码,包含样式和逻辑实现,可直接在 Vue2 项目中使用: 1. 封装的图片上传组件 ImageUploader.vue <te…...
Rust 异步编程
Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...

ios苹果系统,js 滑动屏幕、锚定无效
现象:window.addEventListener监听touch无效,划不动屏幕,但是代码逻辑都有执行到。 scrollIntoView也无效。 原因:这是因为 iOS 的触摸事件处理机制和 touch-action: none 的设置有关。ios有太多得交互动作,从而会影响…...