当前位置: 首页 > news >正文

pytorch04:网络模型创建

目录

  • 一、模型创建过程
    • 1.1 以LeNet网络为例
    • 1.2 LeNet结构
    • 1.3 nn.Module
  • 二、网络层容器(Containers)
    • 2.1 nn.Sequential
      • 2.1.1 常规方法实现
      • 2.1.2 OrderedDict方法实现
    • 2.2 nn.ModuleList
    • 2.3 nn.ModuleDict
    • 2.4 三种容器构建总结
  • 三、AlexNet网络构建

一、模型创建过程

在这里插入图片描述

1.1 以LeNet网络为例

在这里插入图片描述

网络代码如下:

class LeNet(nn.Module):def __init__(self, classes):super(LeNet, self).__init__()  # 调用父类方法,作用是调用nn.Module类的构造函数,# 确保LeNet类被正确地初始化,并继承了nn.Module 的所有属性和方法self.conv1 = nn.Conv2d(3, 6, 5) # 卷积层self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, classes)def forward(self, x):out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out

1.2 LeNet结构

在这里插入图片描述

LeNet:conv1–>pool1–>conv2–>pool2–>fc1–>fc2–>fc3
在这里插入图片描述

1.3 nn.Module

Module是nn模块中的功能,nn模块还有Parameter、functional等模块。
在这里插入图片描述
nn.Module主要有以下参数:
• parameters : 存储管理nn.Parameter类
• modules : 存储管理nn.Module类
• buffers:存储管理缓冲属性,如BN层中的running_mean

二、网络层容器(Containers)

在这里插入图片描述

2.1 nn.Sequential

nn.Sequential 是 nn.module的容器,也是最常用的容器,用于按顺序包装一组网络层
• 顺序性:各网络层之间严格按照顺序构建
• 自带forward():自带的forward里,通过for循环依次执行前向传播运算

2.1.1 常规方法实现

LeNet网络由两部分构成,中间的卷积池化特征提取部分(features),以及最后的分类部分(classifier)。
在这里插入图片描述
具体代码如下:

class LeNetSequential(nn.Module):def __init__(self, classes):super(LeNetSequential, self).__init__()self.features = nn.Sequential(  #特征提取部分nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(  #分类部分nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, classes),)def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x

打印网络层:
在这里插入图片描述

2.1.2 OrderedDict方法实现

使用有序字典的方法构建Sequential
代码如下:

class LeNetSequentialOrderDict(nn.Module):def __init__(self, classes):super(LeNetSequentialOrderDict, self).__init__()self.features = nn.Sequential(OrderedDict({'conv1': nn.Conv2d(3, 6, 5),'relu1': nn.ReLU(inplace=True),'pool1': nn.MaxPool2d(kernel_size=2, stride=2),'conv2': nn.Conv2d(6, 16, 5),'relu2': nn.ReLU(inplace=True),'pool2': nn.MaxPool2d(kernel_size=2, stride=2),}))self.classifier = nn.Sequential(OrderedDict({'fc1': nn.Linear(16 * 5 * 5, 120),'relu3': nn.ReLU(),'fc2': nn.Linear(120, 84),'relu4': nn.ReLU(inplace=True),'fc3': nn.Linear(84, classes),}))def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x

先看一下Sequential函数中init初始化的两种方法,当我们使用OrderedDict方法时,会进行判断,使用self.add_module(key, module)方法将字典中的key和value取出来添加到Sequential中。

class Sequential(Module):def __init__(self, *args):super().__init__()if len(args) == 1 and isinstance(args[0], OrderedDict):for key, module in args[0].items():self.add_module(key, module)else:for idx, module in enumerate(args):self.add_module(str(idx), module)

通过这种方法构建可以给每一网络层添加一个名称,网络输出结果如下:
在这里插入图片描述

2.2 nn.ModuleList

nn.ModuleList是 nn.module的容器,用于包装一组网络层,以迭代方式调用网络层
主要方法:
• append():在ModuleList后面添加网络层
• extend():拼接两个ModuleList
• insert():指定在ModuleList中位置插入网络层

使用列表生成式,通过一行代码就能构建20个网络层。
代码演示:

class ModuleList(nn.Module):def __init__(self):super(ModuleList, self).__init__()# 使用列表生成式构建20个全连接层,每个全连接层10个神经元的网络self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])def forward(self, x):for i, linear in enumerate(self.linears):x = linear(x)return xnet = ModuleList()

2.3 nn.ModuleDict

nn.ModuleDict是 nn.module的容器,用于包装一组网络层,以索引方式调用网络层,可以用过参数的形式选取想要调用的网络层。
主要方法:
• clear():清空ModuleDict
• items():返回可迭代的键值对(key-value pairs)
• keys():返回字典的键(key)
• values():返回字典的值(value)
• pop():返回一对键值,并从字典中删除

代码展示,只选取conv和relu两个网络层:

class ModuleDict(nn.Module):def __init__(self):super(ModuleDict, self).__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})# 激活函数self.activations = nn.ModuleDict({'relu': nn.ReLU(),'prelu': nn.PReLU()})def forward(self, x, choice, act):  # 传入两个参数 用来选择网络层x = self.choices[choice](x)x = self.activations[act](x)return x
net = ModuleDict()
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu')  #只选取conv和relu两个网络层。
print(output)

2.4 三种容器构建总结

• nn.Sequential:顺序性,各网络层之间严格按顺序执行,常用于block构建
• nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建
• nn.ModuleDict:索引性,常用于可选择的网络层

三、AlexNet网络构建

AlexNet:2012年以高出第二名10多个百分点的准确率获得ImageNet分类任务冠军,开创了卷积神经网络的新时代
AlexNet特点如下:

  1. 采用ReLU:替换饱和激活函数,减轻梯度消失
  2. 采用LRN(Local Response Normalization):对数据归一化,减轻梯度消失
  3. Dropout:提高全连接层的鲁棒性,增加网络的泛化能力
  4. Data Augmentation:TenCrop,色彩修改

网络结构图如下:
在这里插入图片描述
构建代码:

import torch.nn as nn
import torch
from torchsummary import summary
# 定义一个名为AlexNet的神经网络模型,继承自nn.Module基类
class AlexNet(nn.Module):# 构造函数,初始化网络的参数def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:# 调用父类的构造函数super().__init__()# 定义神经网络的特征提取部分,包含多个卷积层和池化层self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),  # 输入通道3,输出通道64,卷积核大小11x11,步长4,填充2nn.ReLU(inplace=True),  # 使用ReLU激活函数,inplace=True表示原地操作,节省内存nn.MaxPool2d(kernel_size=3, stride=2),  # 最大池化层,核大小3x3,步长2nn.Conv2d(64, 192, kernel_size=5, padding=2),  # 输入通道64,输出通道192,卷积核大小5x5,填充2nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)# 定义自适应平均池化层,将输入的任意大小的特征图池化为固定大小6x6self.avgpool = nn.AdaptiveAvgPool2d((6, 6))# 定义分类器部分,包含全连接层和Dropout层self.classifier = nn.Sequential(nn.Dropout(p=dropout),  # 使用Dropout进行正则化,随机丢弃一部分神经元以防止过拟合nn.Linear(256 * 6 * 6, 4096),  # 输入大小为256*6*6,输出大小为4096nn.ReLU(inplace=True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),  # 最后的全连接层输出类别数)# 前向传播函数,定义数据在网络中的传播过程def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.features(x)  # 特征提取x = self.avgpool(x)  # 平均池化x = torch.flatten(x, 1)  # 将特征图展平成一维向量x = self.classifier(x)  # 分类器return xif __name__ == '__main__':net = AlexNet().cuda()summary(net, (3, 256, 256))

打印出的网络结构图如下:
在这里插入图片描述

相关文章:

pytorch04:网络模型创建

目录 一、模型创建过程1.1 以LeNet网络为例1.2 LeNet结构1.3 nn.Module 二、网络层容器(Containers)2.1 nn.Sequential2.1.1 常规方法实现2.1.2 OrderedDict方法实现 2.2 nn.ModuleList2.3 nn.ModuleDict2.4 三种容器构建总结 三、AlexNet网络构建 一、模型创建过程 1.1 以LeNe…...

用js让用户输入一个数累加和

需求&#xff1a;用户输入一个数&#xff0c; 计算 1 到这个数的和。 比如 用户输入的是 5&#xff0c; 则计算 1~5 之间的累加和 并且输出到控制台 <body><script>let numprompt(请输入一个数)let sum0for(let i1;i<num;i){sumi}console.log(sum)</script…...

踩坑记录-安装nuxt3报错:Error: Failed to download template from registry: fetch failed;

报错复现 安装nuxt3报错&#xff1a;Error: Failed to download template from registry: fetch failednpx nuxi init nuxt-demo 初始化nuxt 项目 报错 Error: Failed to download template from registry: fetch faile 解决方法 配置hosts Mac电脑&#xff1a;/etc/hostswin电…...

大数据学习(31)-Spark非常用及重要特性

&&大数据学习&& &#x1f525;系列专栏&#xff1a; &#x1f451;哲学语录: 承认自己的无知&#xff0c;乃是开启智慧的大门 &#x1f496;如果觉得博主的文章还不错的话&#xff0c;请点赞&#x1f44d;收藏⭐️留言&#x1f4dd;支持一下博主哦&#x1f91…...

【教学类-43-14】 20240103 (4宫格数独:正确版:576套) 不重复的基础模板数量:576套

作品展示&#xff1a;&#xff1a;——4宫格 576套不重复模板&#xff08;48页*12套题&#xff09; 背景需求&#xff1a; 生成4宫格基础模板768套&#xff0c;观看64页内容时&#xff0c;明显看到有错误 【教学类-43-13】 20240103 &#xff08;4宫格数独&#xff1a;错误版…...

AIGC开发:调用openai的API接口实现简单机器人

简介 开始进行最简单的使用&#xff1a;通过API调用openai的模型能力 OpenAI的能力如下图&#xff1a; 文本生成模型 OpenAI 的文本生成模型&#xff08;通常称为生成式预训练 Transformer 或大型语言模型&#xff09;经过训练可以理解自然语言、代码和图像。这些模型提供文…...

c基础(二)

指针&#xff1a; 含义&#xff1a;是一个值&#xff0c;一个值代表着一个内存地址&#xff0c;类似于存放路径 * 运算符 &#xff1a; 1 字符*表示指针 作用&#xff1a;通常跟在类型关键字的后面&#xff0c;表示指针指向的是什么类型的值 int * foo, * bar;声明指针后会…...

人工智能趋势报告解读:ai野蛮式生长的背后是机遇还是危机?

近期&#xff0c;Enterprise WordPress发布了生成式人工智能在营销中的应用程度的报告&#xff0c;这是一个人工智能迅猛发展的时代&#xff0c;目前人工智能已经广泛运用到内容创作等领域&#xff0c;可以预见的是人工智能及其扩展应用还将延伸到我们工作与生活中的方方面面。…...

三、C语言中的分支与循环—goto语句 (10) (完)

在C语言中&#xff0c;goto语句允许程序无条件地跳转到同一函数内的标记位置。这个标记位置通过一个标签和冒号(:)来标示。goto语句可以用于从深层嵌套的循环或条件语句中直接跳出&#xff0c;或者跳过某些代码执行。尽管goto语句在某些情况下可以使程序逻辑变得清晰&#xff0…...

RabbitMQ 常见问题

1. 如何保证消息顺序消费 在RabbitMQ中&#xff0c;消息最终会保存在队列中&#xff0c;在同一个队列中&#xff0c;消息是顺序的&#xff0c;保持先进先出的原则&#xff0c;这个由Rabbitmq保证。而不同队列中的消息&#xff0c;RabbitMQ 是无法保证其顺序性。顺序消费主要是…...

阶段二-Day10-日期类

日期类结构: 1.java.util.Date是日期类 2.DateFormat是日期格式类、SimpleDateFormat是日期格式类的子类 Timezone代表时区 3.Calendar是日历类&#xff0c;GregorianCalendar是日历的子类 一. 常用类-Date 1.1 Date构造方法 Date(long date) 使用给定的毫秒时间价值构建…...

多任务并行处理相关面试题

我自己面试时被问过两次多任务并行相关的问题&#xff1a; 假设现在有10个任务&#xff0c;要求同时处理&#xff0c;并且必须所有任务全部完成才返回结果 这个面试题的难点是&#xff1a; 既然要同时处理&#xff0c;那么肯定要用多线程。怎么设计多线程同时处理任务呢&…...

Shell脚本学习笔记

1. 写在前面 工作中&#xff0c;需要用到写一些shell脚本去完成一些简单的重复性工作&#xff0c; 于是就想系统的学习下shell脚本的相关知识&#xff0c; 本篇文章是学习shell脚本整理的学习笔记&#xff0c;内容参考主要来自C语言中文网&#xff0c; 学习过程中&#xff0c;…...

ROS-安装xacro

安装 运行下列命令进行安装&#xff0c;xxxxxx处更改为自己的版本 sudo apt-get install ros-xxxxxx-xacro运行 输入下列命令 roscd xacro如果没有报错&#xff0c;并且进入了xacro软件包的目录&#xff0c;则表示安装成功。 参考&#xff1a; [1]https://wenku.csdn.net/ans…...

为什么说 $mash 是 Solana 上最正统的铭文通证?

早在 2023 年的 11 月&#xff0c;包括 Solana、Avalanche、Polygon、Arbitrum、zkSync 等生态正在承接比特币铭文生态外溢的价值。当然&#xff0c;因铭文赛道过于火爆&#xff0c;当 Avalanche、BNB Chain 以及 Polygon 等链上 Gas 飙升至极值&#xff0c;Arbitrum、zkSync 等…...

安装elasticsearch、kibana、IK分词器、扩展IK词典

安装elasticsearch、kibana、IK分词器、扩展IK词典 后面还会安装kibana&#xff0c;这个会提供可视化界面方面学习。 需要注意的是elasticsearch和kibana版本一定要一样&#xff01;&#xff01;&#xff01; 否则就像这样 elasticsearch 1、创建网络 因为我们还需要部署k…...

Spring中常见的BeanFactory后处理器

常见的BeanFacatory后处理器 先给出没有添加任何BeanFactory后处理器的测试代码 public class TestBeanFactoryPostProcessor {public static void main(String[] args) {GenericApplicationContext context new GenericApplicationContext();context.registerBean("co…...

FPGA LCD1602驱动代码 (已验证)

一.需求解读 1.需求 在液晶屏第一行显示“HELLO FPGA 1234!” 2. 知识背景 1602 液晶也叫 1602 字符型液晶,它是一种专门用来显示字母、数字、符号等的点阵 型液晶模块。它由若干个 5X7 或者 5X11 等点阵字符位组成,每个点阵字符位都可以显示一 个字符,每位之间有一个点距的…...

c++编程要养成的好习惯

1、缩进 你说有缩进看的清楚还是没缩进看的清楚 2、i和i i运行起来和i更快 3、 n%20和n&1 不要再用n%20来判断n是不是偶数了&#xff0c;又慢又土&#xff0c;用n&10&#xff0c;如果n&10就说明n是偶数 同理&#xff0c;n&11说明n是奇数 4、*2和<<…...

后台管理项目的多数据源方案

引言 在互联网开发公司中&#xff0c;往往伴随着业务的快速迭代&#xff0c;程序员可能没有过多的时间去思考技术扩展的相关问题&#xff0c;长久下来导致技术过于单一。为此最近在学习互联网思维&#xff0c;从相对简单的功能开始做总结&#xff0c;比如非常常见的基础数据的…...

IGP(Interior Gateway Protocol,内部网关协议)

IGP&#xff08;Interior Gateway Protocol&#xff0c;内部网关协议&#xff09; 是一种用于在一个自治系统&#xff08;AS&#xff09;内部传递路由信息的路由协议&#xff0c;主要用于在一个组织或机构的内部网络中决定数据包的最佳路径。与用于自治系统之间通信的 EGP&…...

centos 7 部署awstats 网站访问检测

一、基础环境准备&#xff08;两种安装方式都要做&#xff09; bash # 安装必要依赖 yum install -y httpd perl mod_perl perl-Time-HiRes perl-DateTime systemctl enable httpd # 设置 Apache 开机自启 systemctl start httpd # 启动 Apache二、安装 AWStats&#xff0…...

vue3 字体颜色设置的多种方式

在Vue 3中设置字体颜色可以通过多种方式实现&#xff0c;这取决于你是想在组件内部直接设置&#xff0c;还是在CSS/SCSS/LESS等样式文件中定义。以下是几种常见的方法&#xff1a; 1. 内联样式 你可以直接在模板中使用style绑定来设置字体颜色。 <template><div :s…...

Python实现prophet 理论及参数优化

文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候&#xff0c;写过一篇简单实现&#xff0c;后期随着对该模型的深入研究&#xff0c;本次记录涉及到prophet 的公式以及参数调优&#xff0c;从公式可以更直观…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

HashMap中的put方法执行流程(流程图)

1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中&#xff0c;其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下&#xff1a; 初始判断与哈希计算&#xff1a; 首先&#xff0c;putVal 方法会检查当前的 table&#xff08;也就…...

【7色560页】职场可视化逻辑图高级数据分析PPT模版

7种色调职场工作汇报PPT&#xff0c;橙蓝、黑红、红蓝、蓝橙灰、浅蓝、浅绿、深蓝七种色调模版 【7色560页】职场可视化逻辑图高级数据分析PPT模版&#xff1a;职场可视化逻辑图分析PPT模版https://pan.quark.cn/s/78aeabbd92d1...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现企业微信功能

1. 开发环境准备 ​​安装DevEco Studio 3.1​​&#xff1a; 从华为开发者官网下载最新版DevEco Studio安装HarmonyOS 5.0 SDK ​​项目配置​​&#xff1a; // module.json5 {"module": {"requestPermissions": [{"name": "ohos.permis…...

自然语言处理——文本分类

文本分类 传统机器学习方法文本表示向量空间模型 特征选择文档频率互信息信息增益&#xff08;IG&#xff09; 分类器设计贝叶斯理论&#xff1a;线性判别函数 文本分类性能评估P-R曲线ROC曲线 将文本文档或句子分类为预定义的类或类别&#xff0c; 有单标签多类别文本分类和多…...

Python爬虫实战:研究Restkit库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的有价值数据。如何高效地采集这些数据并将其应用于实际业务中,成为了许多企业和开发者关注的焦点。网络爬虫技术作为一种自动化的数据采集工具,可以帮助我们从网页中提取所需的信息。而 RESTful API …...