【深度学习】03-神经网络01-4 神经网络的pytorch搭建和参数计算
# 计算模型参数,查看模型结构,我们要查看有多少参数,需要先安装包
pip install torchsummary
import torch
import torch.nn as nn
from torchsummary import summary # 导入 summary 函数,用于计算模型参数和查看模型结构# 创建神经网络模型类
class Model(nn.Module):# 初始化模型的构造函数def __init__(self):super().__init__() # 调用父类 nn.Module 的初始化方法# 定义第一个全连接层(线性层),3个输入特征,3个输出特征self.linear1 = nn.Linear(3, 3) # 使用 Xavier 正态分布初始化第一个全连接层的权重nn.init.xavier_normal_(self.linear1.weight)# 定义第二个全连接层,输入 3 个特征,输出 2 个特征self.linear2 = nn.Linear(3, 2)# 使用 Kaiming 正态分布初始化第二个全连接层的权重,适合 ReLU 激活函数nn.init.kaiming_normal_(self.linear2.weight)# 定义输出层,输入 2 个特征,输出 2 个特征self.out = nn.Linear(2, 2)# 定义前向传播过程 (forward 函数会自动执行,类似于模型的"推理"过程)def forward(self, x):# 第一个全连接层运算x = self.linear1(x)# 使用 Sigmoid 激活函数x = torch.sigmoid(x)# 第二个全连接层运算x = self.linear2(x)# 使用 ReLU 激活函数x = torch.relu(x)# 输出层运算x = self.out(x)# 使用 Softmax 激活函数,将输出转化为概率分布# dim=-1 表示在最后一个维度(通常是输出的类别维度)上做 softmax 归一化x = torch.softmax(x, dim=-1)return xif __name__ == '__main__':# 实例化神经网络模型my_model = Model()# 随机生成一个形状为 (5, 3) 的输入数据,表示 5 个样本,每个样本有 3 个特征my_data = torch.randn(5, 3)print("mydata shape", my_data.shape)# 通过模型进行前向传播,输出模型的预测结果output = my_model(my_data)print("output shape", output.shape)# 计算并显示模型的参数总量以及模型结构summary(my_model, input_size=(3,), batch_size=5)# 查看模型中所有的参数,包括权重和偏置项(bias)print("-----查看模型参数w 和 b -----")for name, parameter in my_model.named_parameters():print(name, parameter)
mydata shape torch.Size([5, 3])
output shape torch.Size([5, 2])
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [5, 3] 12
Linear-2 [5, 2] 8
Linear-3 [5, 2] 6
================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
-----查看模型参数w 和 b -----
linear1.weight Parameter containing:
tensor([[ 0.4777, -0.2076, 0.4900],
[-0.1776, 0.4441, 0.6924],
[-0.5449, 1.6153, 0.0243]], requires_grad=True)
linear1.bias Parameter containing:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)
linear2.weight Parameter containing:
tensor([[-0.0510, -1.2731, -0.7253],
[-0.6112, 0.1189, -0.4903]], requires_grad=True)
linear2.bias Parameter containing:
tensor([0.5391, 0.2552], requires_grad=True)
out.weight Parameter containing:
tensor([[-0.3271, -0.3483],
[-0.0619, -0.0680]], requires_grad=True)
out.bias Parameter containing:
tensor([-0.5508, 0.5895], requires_grad=True)
代码输出结果解读
这个代码的输出展示了两部分内容:
-
数据维度和模型输出维度:
-
mydata shape torch.Size([5, 3])
-
output shape torch.Size([5, 2])
-
-
模型的结构、参数数量和每一层的权重与偏置:
-
模型的层结构、每一层的输出形状,以及每一层的参数数量。
-
每层的权重(
weight
)和偏置(bias
)的具体数值。
-
让我们详细分析每一部分的输出。
1. 输入数据和输出数据的形状
mydata shape torch.Size([5, 3])
这部分的输出说明:
-
输入数据的形状为
(5, 3)
,表示有 5 个样本,每个样本有 3 个特征。这与模型定义时的输入层nn.Linear(3, 3)
是一致的,输入层期望接收 3 个特征。
output shape torch.Size([5, 2])
这部分的输出说明:
-
模型输出的形状为
(5, 2)
,表示 5 个样本的输出,每个样本的输出有 2 个值。由于模型的输出层定义为nn.Linear(2, 2)
,它接收 2 个输入特征并输出 2 个值,符合预期。
2. 模型结构和参数
模型结构和参数信息是通过 summary()
函数生成的,它列出了每一层的名称、输出形状和参数数量。
详细输出解释:
----------------------------------------------------------------Layer (type) Output Shape Param # ================================================================Linear-1 [5, 3] 12Linear-2 [5, 2] 8Linear-3 [5, 2] 6 ================================================================ Total params: 26 Trainable params: 26 Non-trainable params: 0 ----------------------------------------------------------------
线性层 1(Linear-1
):
-
层的类型:
Linear
,这是一个全连接层,定义为nn.Linear(3, 3)
。 -
输出形状:
[5, 3]
,表示输入了 5 个样本,每个样本有 3 个特征,经过该层的输出仍然是 5 个样本,每个样本有 3 个特征。 -
参数数量:12,其中 9 个是权重参数(
3 x 3
的权重矩阵),另外 3 个是偏置项。
线性层 2(Linear-2
):
-
层的类型:
Linear
,定义为nn.Linear(3, 2)
,将 3 个输入特征映射到 2 个输出特征。 -
输出形状:
[5, 2]
,表示输入了 5 个样本,每个样本有 2 个输出特征。 -
参数数量:8,其中 6 个是权重参数(
3 x 2
的权重矩阵),另外 2 个是偏置项。
输出层(Linear-3
):
-
层的类型:
Linear
,定义为nn.Linear(2, 2)
,接收 2 个输入特征,输出 2 个特征。 -
输出形状:
[5, 2]
,表示 5 个样本,每个样本的输出为 2 个特征。 -
参数数量:6,其中 4 个是权重参数(
2 x 2
的权重矩阵),另外 2 个是偏置项。
参数统计:
-
总参数数量:26,模型中所有可训练参数(包括权重和偏置)的总数量。
-
可训练参数:26,模型中所有参与训练的参数。这里所有的参数都是可训练的(
requires_grad=True
),没有非可训练的参数。 -
非可训练参数:0,说明模型中没有被设置为不可训练的参数。
3. 查看每一层的权重和偏置
这一部分输出列出了每一层的具体参数(权重和偏置)的值。
linear1.weight
:
tensor([[ 0.4777, -0.2076, 0.4900],[-0.1776, 0.4441, 0.6924],[-0.5449, 1.6153, 0.0243]], requires_grad=True)
这是 linear1
层的权重矩阵,形状是 (3, 3)
。由于 linear1
是 nn.Linear(3, 3)
,它的权重矩阵也是 3 行 3 列。权重参数是使用 Xavier 初始化(nn.init.xavier_normal_
)初始化的。
linear1.bias
:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)
这是 linear1
层的偏置项,形状是 (3,)
,因为每个输出特征对应一个偏置值。
linear2.weight
:
tensor([[-0.0510, -1.2731, -0.7253],[-0.6112, 0.1189, -0.4903]], requires_grad=True)
这是 linear2
层的权重矩阵,形状是 (2, 3)
,因为 linear2
是 nn.Linear(3, 2)
,需要 3 个输入特征映射到 2 个输出特征。权重是使用 Kaiming 初始化(nn.init.kaiming_normal_
)初始化的。
linear2.bias
:
tensor([0.5391, 0.2552], requires_grad=True)
这是 linear2
层的偏置项,形状是 (2,)
,因为每个输出特征对应一个偏置值。
out.weight
:
tensor([[-0.3271, -0.3483],[-0.0619, -0.0680]], requires_grad=True)
这是输出层 out
的权重矩阵,形状是 (2, 2)
,因为 out
是 nn.Linear(2, 2)
,接收 2 个输入特征并输出 2 个特征。
out.bias
:
tensor([-0.5508, 0.5895], requires_grad=True)
这是输出层 out
的偏置项,形状是 (2,)
。
总结
-
这段代码展示了一个简单的神经网络模型,包含 3 个全连接层(线性层),每层的输入输出特征数量逐步缩小。
-
我们通过
summary()
查看了模型的整体结构,展示了每一层的输出形状和参数数量,总共有 26 个参数。 -
每一层的权重和偏置参数值被输出,展示了它们是如何被初始化的(通过 Xavier 和 Kaiming 初始化)。
-
该模型的前向传播通过激活函数(
sigmoid
和ReLU
)以及 softmax 将输出转化为概率分布。
相关文章:

【深度学习】03-神经网络01-4 神经网络的pytorch搭建和参数计算
# 计算模型参数,查看模型结构,我们要查看有多少参数,需要先安装包 pip install torchsummary import torch import torch.nn as nn from torchsummary import summary # 导入 summary 函数,用于计算模型参数和查看模型结构# 创建神经网络模型类 class Mo…...

我与Linux的爱恋:命令行参数|环境变量
🔥个人主页:guoguoqiang. 🔥专栏:Linux的学习 文章目录 一.命令行参数二.环境变量1.环境变量的基本概念2.查看环境变量的方法3.环境变量相关命令4.环境变量的组织方式以及获取环境变量的三种方法 环境变量具有全局属性 一…...
django drf 统一Response格式
场景 需要将响应体按照格式规范返回给前端。 例如: 响应体中包含以下字段: {"result": true,"data": {},"code": 200,"message": "ok","request_id": "20cadfe4-51cd-42f6-af81-0…...
SM2协同签名算法中随机数K的随机性对算法安全的影响
前面介绍过若持有私钥d的用户两次SM2签名过程中随机数k相同,在对手获得两次签名结果Sig1和Sig2的情况下,可破解私钥d。 具体见SM2签名算法中随机数K的随机性对算法安全的影响_sm2关闭随机数-CSDN博客 另关于SM2协同签名过程,具体见SM2协同签…...

解决setMouseTracking(true)后还是无法触发mouseMoveEvent的问题
如图,在给整体界面设置鼠标追踪且给ui界面的子控件也设置了鼠标追踪后,运行后的界面仍然有些地方移动鼠标无法触发 mouseMoveEvent函数,这就令人头痛。。。 我的解决方法是:重载event函数: 完美解决。。。...

基于深度学习的花卉智能分类识别系统
温馨提示:文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 传统的花卉分类方法通常依赖于专家的知识和经验,这种方法不仅耗时耗力,而且容易受到主观因素的影响。本系统利用 TensorFlow、Keras 等深度学习框架构建卷积神经网络&#…...
Springboot集成MongoDb快速入门
1. 什么是MongoDB 1.1. 基本概念 MongoDB是一个基于分布式文件存储 [1] 的数据库。由C语言编写。旨在为WEB应用提供可扩展的高性能数据存储解决方案。 MongoDB是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库当中功能最丰富,最像关系数…...

DERT目标检测—End-to-End Object Detection with Transformers
DERT:使用Transformer的端到端目标检测 论文题目:End-to-End Object Detection with Transformers 官方代码:https://github.com/facebookresearch/detr 论文题目中包括的一个创新点End to End(端到端的方法)简单的理解就是没有使…...
软件后端开发速度慢的科技公司老板有没有思考如何破局
最近接到两个科技公司咨询,说是他们公司后端开发速度太慢,前端程序员老等着,后端程序员拖了项目进度。 这种问题不只他们公司,在软件外包公司中,有一部分项目甲方客户要得急,以至于要求软件开发要快&#…...

开放原子超级链内核XuperCore可搭建区块链
区块链是一种分布式数据库技术,它以块的形式存储数据,并使用密码学方法保证数据的安全性和完整性。 每个块包含一定数量的交易信息,并通过加密链接到前一个块,形成一个不断增长的链条。 这种设计使得数据在网络中无法被篡改,因为任何尝试修改一个块的数据都会破坏整个链的…...

【Qualcomm】高通SNPE框架的使用 | 原始模型转换为量化的DLC文件 | 在Android的CPU端运行模型
目录 ① 激活snpe环境 ② 设置环境变量 ③ 模型转换 ④ run on Android 首先,默认SNPE工具已经下载并且Setup相关工作均已完成。同时,拥有原始模型文件,本文使用的模型文件为SNPE 框架示例的inception_v3_2016_08_28_frozen.pb文件。imag…...

C++map与set
文章目录 前言一、map和set基础知识二、set与map使用示例1.set去重操作2.map字典统计 总结 前言 本章主要介绍map和set的基本知识与用法。 一、map和set基础知识 map与set属于STL的一部分,他们底层都是是同红黑树来实现的。 ①set常见用途是去重 ,set不…...
随手记:前端一些定位bug的方法
有时候接到bug很烦躁,不管是任何环境的bug,看到都影响心情,随后记总结一下查看bug的思路,在摸不着头脑的时候或者焦虑的时候,可以静下心来顺着思路思考和排查bug可能产生的原因 1.接到bug,最重要的是&am…...

【深度学习】03-神经网络2-1损失函数
在神经网络中,不同任务类型(如多分类、二分类、回归)需要使用不同的损失函数来衡量模型预测和真实值之间的差异。选择合适的损失函数对于模型的性能至关重要。 这里的是API 的注意⚠️,但是在真实的公式中,目标值一定是…...

Python爬虫APP程序:构建智能化数据抓取工具
在信息爆炸的时代,数据的价值日益凸显。Python作为一种强大的编程语言,与其丰富的库一起,为爬虫程序的开发提供了得天独厚的优势。本文将探讨如何使用Python构建一个爬虫APP程序,以及其背后的思维逻辑。 什么是Python爬虫APP程序&…...
第五部分:2---中断与信号
目录 操作系统如何得知哪个外部资源就绪? 什么是中断机制? CPU引脚和中断号的关系: 中断向量表: 信号和中断的关系: 操作系统如何得知哪个外部资源就绪? 操作系统并不会主动轮询所有外设来查看哪些资源…...
梧桐数据库(WuTongDB):SQL Server Query Optimizer 简介
SQL Server Query Optimizer 是 SQL Server 数据库引擎的核心组件之一,负责生成查询执行计划,以优化 SQL 查询的执行性能。它的目标是根据查询的逻辑结构和底层数据的统计信息,选择出最优的查询执行方案。SQL Server Query Optimizer 采用基于…...
Scrapy框架介绍
一、什么是Scrapy 是一款快速而强大的web爬虫框架,基于Twusted的异步处理框架 Twisted是事件驱动的 Scrapy是由Python实现的爬虫框架 ① 架构清晰 ②可扩展性强 ③可以灵活完成需求 二、核心组件 Scrapy Engine(引擎):Scrapy框架…...

Facebook对现代社交互动的影响
自2004年成立以来,Facebook已经成为全球最大的社交媒体平台之一,改变了人们的交流方式和社交互动模式。作为一个数字平台,Facebook不仅为用户提供了分享生活点滴的空间,也深刻影响了现代社交互动的各个方面。本文将探讨Facebook如…...
Java项目运维有哪些内容?
Java项目运维的内容主要包括环境准备、部署Java应用、配置和优化、安全配置、以及数据安全保护措施,服务的运行和资源动态监控管理。 1,环境准备:这包括选择适合运行Java和Tomcat的操作系统,如Ubuntu、CentOS等Linux发行版…...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
C++:std::is_convertible
C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...
实现弹窗随键盘上移居中
实现弹窗随键盘上移的核心思路 在Android中,可以通过监听键盘的显示和隐藏事件,动态调整弹窗的位置。关键点在于获取键盘高度,并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...
Swagger和OpenApi的前世今生
Swagger与OpenAPI的关系演进是API标准化进程中的重要篇章,二者共同塑造了现代RESTful API的开发范式。 本期就扒一扒其技术演进的关键节点与核心逻辑: 🔄 一、起源与初创期:Swagger的诞生(2010-2014) 核心…...

浪潮交换机配置track检测实现高速公路收费网络主备切换NQA
浪潮交换机track配置 项目背景高速网络拓扑网络情况分析通信线路收费网络路由 收费汇聚交换机相应配置收费汇聚track配置 项目背景 在实施省内一条高速公路时遇到的需求,本次涉及的主要是收费汇聚交换机的配置,浪潮网络设备在高速项目很少,通…...

基于IDIG-GAN的小样本电机轴承故障诊断
目录 🔍 核心问题 一、IDIG-GAN模型原理 1. 整体架构 2. 核心创新点 (1) 梯度归一化(Gradient Normalization) (2) 判别器梯度间隙正则化(Discriminator Gradient Gap Regularization) (3) 自注意力机制(Self-Attention) 3. 完整损失函数 二…...
关于uniapp展示PDF的解决方案
在 UniApp 的 H5 环境中使用 pdf-vue3 组件可以实现完整的 PDF 预览功能。以下是详细实现步骤和注意事项: 一、安装依赖 安装 pdf-vue3 和 PDF.js 核心库: npm install pdf-vue3 pdfjs-dist二、基本使用示例 <template><view class"con…...
用鸿蒙HarmonyOS5实现中国象棋小游戏的过程
下面是一个基于鸿蒙OS (HarmonyOS) 的中国象棋小游戏的实现代码。这个实现使用Java语言和鸿蒙的Ability框架。 1. 项目结构 /src/main/java/com/example/chinesechess/├── MainAbilitySlice.java // 主界面逻辑├── ChessView.java // 游戏视图和逻辑├──…...
云原生周刊:k0s 成为 CNCF 沙箱项目
开源项目推荐 HAMi HAMi(原名 k8s‑vGPU‑scheduler)是一款 CNCF Sandbox 级别的开源 K8s 中间件,通过虚拟化 GPU/NPU 等异构设备并支持内存、计算核心时间片隔离及共享调度,为容器提供统一接口,实现细粒度资源配额…...