当前位置: 首页 > news >正文

【深度学习】03-神经网络01-4 神经网络的pytorch搭建和参数计算

# 计算模型参数,查看模型结构,我们要查看有多少参数,需要先安装包

pip install torchsummary

import torch
import torch.nn as nn
from torchsummary import summary # 导入 summary 函数,用于计算模型参数和查看模型结构# 创建神经网络模型类
class Model(nn.Module):# 初始化模型的构造函数def __init__(self):super().__init__()  # 调用父类 nn.Module 的初始化方法# 定义第一个全连接层(线性层),3个输入特征,3个输出特征self.linear1 = nn.Linear(3, 3)  # 使用 Xavier 正态分布初始化第一个全连接层的权重nn.init.xavier_normal_(self.linear1.weight)# 定义第二个全连接层,输入 3 个特征,输出 2 个特征self.linear2 = nn.Linear(3, 2)# 使用 Kaiming 正态分布初始化第二个全连接层的权重,适合 ReLU 激活函数nn.init.kaiming_normal_(self.linear2.weight)# 定义输出层,输入 2 个特征,输出 2 个特征self.out = nn.Linear(2, 2)# 定义前向传播过程 (forward 函数会自动执行,类似于模型的"推理"过程)def forward(self, x):# 第一个全连接层运算x = self.linear1(x)# 使用 Sigmoid 激活函数x = torch.sigmoid(x)# 第二个全连接层运算x = self.linear2(x)# 使用 ReLU 激活函数x = torch.relu(x)# 输出层运算x = self.out(x)# 使用 Softmax 激活函数,将输出转化为概率分布# dim=-1 表示在最后一个维度(通常是输出的类别维度)上做 softmax 归一化x = torch.softmax(x, dim=-1)return xif __name__ == '__main__':# 实例化神经网络模型my_model = Model()# 随机生成一个形状为 (5, 3) 的输入数据,表示 5 个样本,每个样本有 3 个特征my_data = torch.randn(5, 3)print("mydata shape", my_data.shape)# 通过模型进行前向传播,输出模型的预测结果output = my_model(my_data)print("output shape", output.shape)# 计算并显示模型的参数总量以及模型结构summary(my_model, input_size=(3,), batch_size=5)# 查看模型中所有的参数,包括权重和偏置项(bias)print("-----查看模型参数w 和 b  -----")for name, parameter in my_model.named_parameters():print(name, parameter)

mydata shape torch.Size([5, 3])
output shape torch.Size([5, 2])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                     [5, 3]              12
            Linear-2                     [5, 2]               8
            Linear-3                     [5, 2]               6
================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
-----查看模型参数w 和 b  -----
linear1.weight Parameter containing:
tensor([[ 0.4777, -0.2076,  0.4900],
        [-0.1776,  0.4441,  0.6924],
        [-0.5449,  1.6153,  0.0243]], requires_grad=True)
linear1.bias Parameter containing:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)
linear2.weight Parameter containing:
tensor([[-0.0510, -1.2731, -0.7253],
        [-0.6112,  0.1189, -0.4903]], requires_grad=True)
linear2.bias Parameter containing:
tensor([0.5391, 0.2552], requires_grad=True)
out.weight Parameter containing:
tensor([[-0.3271, -0.3483],
        [-0.0619, -0.0680]], requires_grad=True)
out.bias Parameter containing:
tensor([-0.5508,  0.5895], requires_grad=True)
 

 代码输出结果解读

​​​​​​​

这个代码的输出展示了两部分内容:

  1. 数据维度和模型输出维度

    • mydata shape torch.Size([5, 3])

    • output shape torch.Size([5, 2])

  2. 模型的结构、参数数量和每一层的权重与偏置

    • 模型的层结构、每一层的输出形状,以及每一层的参数数量。

    • 每层的权重(weight)和偏置(bias)的具体数值。

让我们详细分析每一部分的输出。

1. 输入数据和输出数据的形状

mydata shape torch.Size([5, 3])

这部分的输出说明:

  • 输入数据的形状(5, 3),表示有 5 个样本,每个样本有 3 个特征。这与模型定义时的输入层 nn.Linear(3, 3) 是一致的,输入层期望接收 3 个特征。

output shape torch.Size([5, 2])

这部分的输出说明:

  • 模型输出的形状 (5, 2),表示 5 个样本的输出,每个样本的输出有 2 个值。由于模型的输出层定义为 nn.Linear(2, 2),它接收 2 个输入特征并输出 2 个值,符合预期。

2. 模型结构和参数

模型结构和参数信息是通过 summary() 函数生成的,它列出了每一层的名称、输出形状和参数数量。

详细输出解释:
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Linear-1                     [5, 3]             12Linear-2                     [5, 2]               8Linear-3                     [5, 2]               6
================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
----------------------------------------------------------------
线性层 1(Linear-1
  • 层的类型Linear,这是一个全连接层,定义为 nn.Linear(3, 3)

  • 输出形状[5, 3],表示输入了 5 个样本,每个样本有 3 个特征,经过该层的输出仍然是 5 个样本,每个样本有 3 个特征。

  • 参数数量:12,其中 9 个是权重参数(3 x 3 的权重矩阵),另外 3 个是偏置项。

线性层 2(Linear-2
  • 层的类型Linear,定义为 nn.Linear(3, 2),将 3 个输入特征映射到 2 个输出特征。

  • 输出形状[5, 2],表示输入了 5 个样本,每个样本有 2 个输出特征。

  • 参数数量:8,其中 6 个是权重参数(3 x 2 的权重矩阵),另外 2 个是偏置项。

输出层(Linear-3
  • 层的类型Linear,定义为 nn.Linear(2, 2),接收 2 个输入特征,输出 2 个特征。

  • 输出形状[5, 2],表示 5 个样本,每个样本的输出为 2 个特征。

  • 参数数量:6,其中 4 个是权重参数(2 x 2 的权重矩阵),另外 2 个是偏置项。

参数统计
  • 总参数数量:26,模型中所有可训练参数(包括权重和偏置)的总数量。

  • 可训练参数:26,模型中所有参与训练的参数。这里所有的参数都是可训练的(requires_grad=True),没有非可训练的参数。

  • 非可训练参数:0,说明模型中没有被设置为不可训练的参数。

3. 查看每一层的权重和偏置

这一部分输出列出了每一层的具体参数(权重和偏置)的值。

linear1.weight:
tensor([[ 0.4777, -0.2076, 0.4900],[-0.1776, 0.4441, 0.6924],[-0.5449, 1.6153, 0.0243]], requires_grad=True)

这是 linear1 层的权重矩阵,形状是 (3, 3)。由于 linear1nn.Linear(3, 3),它的权重矩阵也是 3 行 3 列。权重参数是使用 Xavier 初始化(nn.init.xavier_normal_)初始化的。

linear1.bias:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)

这是 linear1 层的偏置项,形状是 (3,),因为每个输出特征对应一个偏置值。

linear2.weight:
tensor([[-0.0510, -1.2731, -0.7253],[-0.6112, 0.1189, -0.4903]], requires_grad=True)

这是 linear2 层的权重矩阵,形状是 (2, 3),因为 linear2nn.Linear(3, 2),需要 3 个输入特征映射到 2 个输出特征。权重是使用 Kaiming 初始化nn.init.kaiming_normal_初始化的。

linear2.bias:
tensor([0.5391, 0.2552], requires_grad=True)

这是 linear2 层的偏置项,形状是 (2,),因为每个输出特征对应一个偏置值。

out.weight:
tensor([[-0.3271, -0.3483],[-0.0619, -0.0680]], requires_grad=True)

这是输出层 out 的权重矩阵,形状是 (2, 2),因为 outnn.Linear(2, 2),接收 2 个输入特征并输出 2 个特征。

out.bias:
tensor([-0.5508, 0.5895], requires_grad=True)

这是输出层 out 的偏置项,形状是 (2,)

总结

  • 这段代码展示了一个简单的神经网络模型,包含 3 个全连接层(线性层),每层的输入输出特征数量逐步缩小。

  • 我们通过 summary() 查看了模型的整体结构,展示了每一层的输出形状和参数数量,总共有 26 个参数。

  • 每一层的权重和偏置参数值被输出,展示了它们是如何被初始化的(通过 Xavier 和 Kaiming 初始化)。

  • 该模型的前向传播通过激活函数(sigmoidReLU)以及 softmax 将输出转化为概率分布。

​​​​​​​​​​​​​​

相关文章:

【深度学习】03-神经网络01-4 神经网络的pytorch搭建和参数计算

# 计算模型参数,查看模型结构,我们要查看有多少参数,需要先安装包 pip install torchsummary import torch import torch.nn as nn from torchsummary import summary # 导入 summary 函数,用于计算模型参数和查看模型结构# 创建神经网络模型类 class Mo…...

我与Linux的爱恋:命令行参数|环境变量

​ ​ 🔥个人主页:guoguoqiang. 🔥专栏:Linux的学习 文章目录 一.命令行参数二.环境变量1.环境变量的基本概念2.查看环境变量的方法3.环境变量相关命令4.环境变量的组织方式以及获取环境变量的三种方法 环境变量具有全局属性 一…...

django drf 统一Response格式

场景 需要将响应体按照格式规范返回给前端。 例如: 响应体中包含以下字段: {"result": true,"data": {},"code": 200,"message": "ok","request_id": "20cadfe4-51cd-42f6-af81-0…...

SM2协同签名算法中随机数K的随机性对算法安全的影响

前面介绍过若持有私钥d的用户两次SM2签名过程中随机数k相同,在对手获得两次签名结果Sig1和Sig2的情况下,可破解私钥d。 具体见SM2签名算法中随机数K的随机性对算法安全的影响_sm2关闭随机数-CSDN博客 另关于SM2协同签名过程,具体见SM2协同签…...

解决setMouseTracking(true)后还是无法触发mouseMoveEvent的问题

如图,在给整体界面设置鼠标追踪且给ui界面的子控件也设置了鼠标追踪后,运行后的界面仍然有些地方移动鼠标无法触发 mouseMoveEvent函数,这就令人头痛。。。 我的解决方法是:重载event函数: 完美解决。。。...

基于深度学习的花卉智能分类识别系统

温馨提示:文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 传统的花卉分类方法通常依赖于专家的知识和经验,这种方法不仅耗时耗力,而且容易受到主观因素的影响。本系统利用 TensorFlow、Keras 等深度学习框架构建卷积神经网络&#…...

Springboot集成MongoDb快速入门

1. 什么是MongoDB 1.1. 基本概念 MongoDB是一个基于分布式文件存储 [1] 的数据库。由C语言编写。旨在为WEB应用提供可扩展的高性能数据存储解决方案。 MongoDB是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库当中功能最丰富,最像关系数…...

DERT目标检测—End-to-End Object Detection with Transformers

DERT:使用Transformer的端到端目标检测 论文题目:End-to-End Object Detection with Transformers 官方代码:https://github.com/facebookresearch/detr 论文题目中包括的一个创新点End to End(端到端的方法)简单的理解就是没有使…...

软件后端开发速度慢的科技公司老板有没有思考如何破局

最近接到两个科技公司咨询,说是他们公司后端开发速度太慢,前端程序员老等着,后端程序员拖了项目进度。 这种问题不只他们公司,在软件外包公司中,有一部分项目甲方客户要得急,以至于要求软件开发要快&#…...

开放原子超级链内核XuperCore可搭建区块链

区块链是一种分布式数据库技术,它以块的形式存储数据,并使用密码学方法保证数据的安全性和完整性。 每个块包含一定数量的交易信息,并通过加密链接到前一个块,形成一个不断增长的链条。 这种设计使得数据在网络中无法被篡改,因为任何尝试修改一个块的数据都会破坏整个链的…...

【Qualcomm】高通SNPE框架的使用 | 原始模型转换为量化的DLC文件 | 在Android的CPU端运行模型

目录 ① 激活snpe环境 ② 设置环境变量 ③ 模型转换 ④ run on Android 首先,默认SNPE工具已经下载并且Setup相关工作均已完成。同时,拥有原始模型文件,本文使用的模型文件为SNPE 框架示例的inception_v3_2016_08_28_frozen.pb文件。imag…...

C++map与set

文章目录 前言一、map和set基础知识二、set与map使用示例1.set去重操作2.map字典统计 总结 前言 本章主要介绍map和set的基本知识与用法。 一、map和set基础知识 map与set属于STL的一部分,他们底层都是是同红黑树来实现的。 ①set常见用途是去重 ,set不…...

随手记:前端一些定位bug的方法

有时候接到bug很烦躁,不管是任何环境的bug,看到都影响心情,随后记总结一下查看bug的思路,在摸不着头脑的时候或者焦虑的时候,可以静下心来顺着思路思考和排查bug可能产生的原因 1.接到bug,最重要的是&am…...

【深度学习】03-神经网络2-1损失函数

在神经网络中,不同任务类型(如多分类、二分类、回归)需要使用不同的损失函数来衡量模型预测和真实值之间的差异。选择合适的损失函数对于模型的性能至关重要。 这里的是API 的注意⚠️,但是在真实的公式中,目标值一定是…...

Python爬虫APP程序:构建智能化数据抓取工具

在信息爆炸的时代,数据的价值日益凸显。Python作为一种强大的编程语言,与其丰富的库一起,为爬虫程序的开发提供了得天独厚的优势。本文将探讨如何使用Python构建一个爬虫APP程序,以及其背后的思维逻辑。 什么是Python爬虫APP程序&…...

第五部分:2---中断与信号

目录 操作系统如何得知哪个外部资源就绪? 什么是中断机制? CPU引脚和中断号的关系: 中断向量表: 信号和中断的关系: 操作系统如何得知哪个外部资源就绪? 操作系统并不会主动轮询所有外设来查看哪些资源…...

梧桐数据库(WuTongDB):SQL Server Query Optimizer 简介

SQL Server Query Optimizer 是 SQL Server 数据库引擎的核心组件之一,负责生成查询执行计划,以优化 SQL 查询的执行性能。它的目标是根据查询的逻辑结构和底层数据的统计信息,选择出最优的查询执行方案。SQL Server Query Optimizer 采用基于…...

Scrapy框架介绍

一、什么是Scrapy 是一款快速而强大的web爬虫框架,基于Twusted的异步处理框架 Twisted是事件驱动的 Scrapy是由Python实现的爬虫框架 ① 架构清晰 ②可扩展性强 ③可以灵活完成需求 二、核心组件 Scrapy Engine(引擎):Scrapy框架…...

Facebook对现代社交互动的影响

自2004年成立以来,Facebook已经成为全球最大的社交媒体平台之一,改变了人们的交流方式和社交互动模式。作为一个数字平台,Facebook不仅为用户提供了分享生活点滴的空间,也深刻影响了现代社交互动的各个方面。本文将探讨Facebook如…...

Java项目运维有哪些内容?

Java项目运维的内容主要包括环境准备、部署Java应用、配置和优化、安全配置、以及数据安全保护措施,服务的运行和资源动态监控管理。‌ ‌1,环境准备‌:这包括选择适合运行Java和Tomcat的操作系统,如Ubuntu、CentOS等Linux发行版…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试

作者&#xff1a;Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位&#xff1a;中南大学地球科学与信息物理学院论文标题&#xff1a;BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接&#xff1a;https://arxiv.…...

CentOS下的分布式内存计算Spark环境部署

一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架&#xff0c;相比 MapReduce 具有以下核心优势&#xff1a; 内存计算&#xff1a;数据可常驻内存&#xff0c;迭代计算性能提升 10-100 倍&#xff08;文档段落&#xff1a;3-79…...

MySQL中【正则表达式】用法

MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现&#xff08;两者等价&#xff09;&#xff0c;用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例&#xff1a; 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...

华为云Flexus+DeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建

华为云FlexusDeepSeek征文&#xff5c;DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建 前言 如今大模型其性能出色&#xff0c;华为云 ModelArts Studio_MaaS大模型即服务平台华为云内置了大模型&#xff0c;能助力我们轻松驾驭 DeepSeek-V3/R1&#xff0c;本文中将分享如何…...

css3笔记 (1) 自用

outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size&#xff1a;0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格&#xff…...

技术栈RabbitMq的介绍和使用

目录 1. 什么是消息队列&#xff1f;2. 消息队列的优点3. RabbitMQ 消息队列概述4. RabbitMQ 安装5. Exchange 四种类型5.1 direct 精准匹配5.2 fanout 广播5.3 topic 正则匹配 6. RabbitMQ 队列模式6.1 简单队列模式6.2 工作队列模式6.3 发布/订阅模式6.4 路由模式6.5 主题模式…...

算法:模拟

1.替换所有的问号 1576. 替换所有的问号 - 力扣&#xff08;LeetCode&#xff09; ​遍历字符串​&#xff1a;通过外层循环逐一检查每个字符。​遇到 ? 时处理​&#xff1a; 内层循环遍历小写字母&#xff08;a 到 z&#xff09;。对每个字母检查是否满足&#xff1a; ​与…...

比较数据迁移后MySQL数据库和OceanBase数据仓库中的表

设计一个MySQL数据库和OceanBase数据仓库的表数据比较的详细程序流程,两张表是相同的结构,都有整型主键id字段,需要每次从数据库分批取得2000条数据,用于比较,比较操作的同时可以再取2000条数据,等上一次比较完成之后,开始比较,直到比较完所有的数据。比较操作需要比较…...

【FTP】ftp文件传输会丢包吗?批量几百个文件传输,有一些文件没有传输完整,如何解决?

FTP&#xff08;File Transfer Protocol&#xff09;本身是一个基于 TCP 的协议&#xff0c;理论上不会丢包。但 FTP 文件传输过程中仍可能出现文件不完整、丢失或损坏的情况&#xff0c;主要原因包括&#xff1a; ✅ 一、FTP传输可能“丢包”或文件不完整的原因 原因描述网络…...