当前位置: 首页 > news >正文

【深度学习】03-神经网络01-4 神经网络的pytorch搭建和参数计算

# 计算模型参数,查看模型结构,我们要查看有多少参数,需要先安装包

pip install torchsummary

import torch
import torch.nn as nn
from torchsummary import summary # 导入 summary 函数,用于计算模型参数和查看模型结构# 创建神经网络模型类
class Model(nn.Module):# 初始化模型的构造函数def __init__(self):super().__init__()  # 调用父类 nn.Module 的初始化方法# 定义第一个全连接层(线性层),3个输入特征,3个输出特征self.linear1 = nn.Linear(3, 3)  # 使用 Xavier 正态分布初始化第一个全连接层的权重nn.init.xavier_normal_(self.linear1.weight)# 定义第二个全连接层,输入 3 个特征,输出 2 个特征self.linear2 = nn.Linear(3, 2)# 使用 Kaiming 正态分布初始化第二个全连接层的权重,适合 ReLU 激活函数nn.init.kaiming_normal_(self.linear2.weight)# 定义输出层,输入 2 个特征,输出 2 个特征self.out = nn.Linear(2, 2)# 定义前向传播过程 (forward 函数会自动执行,类似于模型的"推理"过程)def forward(self, x):# 第一个全连接层运算x = self.linear1(x)# 使用 Sigmoid 激活函数x = torch.sigmoid(x)# 第二个全连接层运算x = self.linear2(x)# 使用 ReLU 激活函数x = torch.relu(x)# 输出层运算x = self.out(x)# 使用 Softmax 激活函数,将输出转化为概率分布# dim=-1 表示在最后一个维度(通常是输出的类别维度)上做 softmax 归一化x = torch.softmax(x, dim=-1)return xif __name__ == '__main__':# 实例化神经网络模型my_model = Model()# 随机生成一个形状为 (5, 3) 的输入数据,表示 5 个样本,每个样本有 3 个特征my_data = torch.randn(5, 3)print("mydata shape", my_data.shape)# 通过模型进行前向传播,输出模型的预测结果output = my_model(my_data)print("output shape", output.shape)# 计算并显示模型的参数总量以及模型结构summary(my_model, input_size=(3,), batch_size=5)# 查看模型中所有的参数,包括权重和偏置项(bias)print("-----查看模型参数w 和 b  -----")for name, parameter in my_model.named_parameters():print(name, parameter)

mydata shape torch.Size([5, 3])
output shape torch.Size([5, 2])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                     [5, 3]              12
            Linear-2                     [5, 2]               8
            Linear-3                     [5, 2]               6
================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
-----查看模型参数w 和 b  -----
linear1.weight Parameter containing:
tensor([[ 0.4777, -0.2076,  0.4900],
        [-0.1776,  0.4441,  0.6924],
        [-0.5449,  1.6153,  0.0243]], requires_grad=True)
linear1.bias Parameter containing:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)
linear2.weight Parameter containing:
tensor([[-0.0510, -1.2731, -0.7253],
        [-0.6112,  0.1189, -0.4903]], requires_grad=True)
linear2.bias Parameter containing:
tensor([0.5391, 0.2552], requires_grad=True)
out.weight Parameter containing:
tensor([[-0.3271, -0.3483],
        [-0.0619, -0.0680]], requires_grad=True)
out.bias Parameter containing:
tensor([-0.5508,  0.5895], requires_grad=True)
 

 代码输出结果解读

​​​​​​​

这个代码的输出展示了两部分内容:

  1. 数据维度和模型输出维度

    • mydata shape torch.Size([5, 3])

    • output shape torch.Size([5, 2])

  2. 模型的结构、参数数量和每一层的权重与偏置

    • 模型的层结构、每一层的输出形状,以及每一层的参数数量。

    • 每层的权重(weight)和偏置(bias)的具体数值。

让我们详细分析每一部分的输出。

1. 输入数据和输出数据的形状

mydata shape torch.Size([5, 3])

这部分的输出说明:

  • 输入数据的形状(5, 3),表示有 5 个样本,每个样本有 3 个特征。这与模型定义时的输入层 nn.Linear(3, 3) 是一致的,输入层期望接收 3 个特征。

output shape torch.Size([5, 2])

这部分的输出说明:

  • 模型输出的形状 (5, 2),表示 5 个样本的输出,每个样本的输出有 2 个值。由于模型的输出层定义为 nn.Linear(2, 2),它接收 2 个输入特征并输出 2 个值,符合预期。

2. 模型结构和参数

模型结构和参数信息是通过 summary() 函数生成的,它列出了每一层的名称、输出形状和参数数量。

详细输出解释:
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Linear-1                     [5, 3]             12Linear-2                     [5, 2]               8Linear-3                     [5, 2]               6
================================================================
Total params: 26
Trainable params: 26
Non-trainable params: 0
----------------------------------------------------------------
线性层 1(Linear-1
  • 层的类型Linear,这是一个全连接层,定义为 nn.Linear(3, 3)

  • 输出形状[5, 3],表示输入了 5 个样本,每个样本有 3 个特征,经过该层的输出仍然是 5 个样本,每个样本有 3 个特征。

  • 参数数量:12,其中 9 个是权重参数(3 x 3 的权重矩阵),另外 3 个是偏置项。

线性层 2(Linear-2
  • 层的类型Linear,定义为 nn.Linear(3, 2),将 3 个输入特征映射到 2 个输出特征。

  • 输出形状[5, 2],表示输入了 5 个样本,每个样本有 2 个输出特征。

  • 参数数量:8,其中 6 个是权重参数(3 x 2 的权重矩阵),另外 2 个是偏置项。

输出层(Linear-3
  • 层的类型Linear,定义为 nn.Linear(2, 2),接收 2 个输入特征,输出 2 个特征。

  • 输出形状[5, 2],表示 5 个样本,每个样本的输出为 2 个特征。

  • 参数数量:6,其中 4 个是权重参数(2 x 2 的权重矩阵),另外 2 个是偏置项。

参数统计
  • 总参数数量:26,模型中所有可训练参数(包括权重和偏置)的总数量。

  • 可训练参数:26,模型中所有参与训练的参数。这里所有的参数都是可训练的(requires_grad=True),没有非可训练的参数。

  • 非可训练参数:0,说明模型中没有被设置为不可训练的参数。

3. 查看每一层的权重和偏置

这一部分输出列出了每一层的具体参数(权重和偏置)的值。

linear1.weight:
tensor([[ 0.4777, -0.2076, 0.4900],[-0.1776, 0.4441, 0.6924],[-0.5449, 1.6153, 0.0243]], requires_grad=True)

这是 linear1 层的权重矩阵,形状是 (3, 3)。由于 linear1nn.Linear(3, 3),它的权重矩阵也是 3 行 3 列。权重参数是使用 Xavier 初始化(nn.init.xavier_normal_)初始化的。

linear1.bias:
tensor([0.4524, 0.2902, 0.4897], requires_grad=True)

这是 linear1 层的偏置项,形状是 (3,),因为每个输出特征对应一个偏置值。

linear2.weight:
tensor([[-0.0510, -1.2731, -0.7253],[-0.6112, 0.1189, -0.4903]], requires_grad=True)

这是 linear2 层的权重矩阵,形状是 (2, 3),因为 linear2nn.Linear(3, 2),需要 3 个输入特征映射到 2 个输出特征。权重是使用 Kaiming 初始化nn.init.kaiming_normal_初始化的。

linear2.bias:
tensor([0.5391, 0.2552], requires_grad=True)

这是 linear2 层的偏置项,形状是 (2,),因为每个输出特征对应一个偏置值。

out.weight:
tensor([[-0.3271, -0.3483],[-0.0619, -0.0680]], requires_grad=True)

这是输出层 out 的权重矩阵,形状是 (2, 2),因为 outnn.Linear(2, 2),接收 2 个输入特征并输出 2 个特征。

out.bias:
tensor([-0.5508, 0.5895], requires_grad=True)

这是输出层 out 的偏置项,形状是 (2,)

总结

  • 这段代码展示了一个简单的神经网络模型,包含 3 个全连接层(线性层),每层的输入输出特征数量逐步缩小。

  • 我们通过 summary() 查看了模型的整体结构,展示了每一层的输出形状和参数数量,总共有 26 个参数。

  • 每一层的权重和偏置参数值被输出,展示了它们是如何被初始化的(通过 Xavier 和 Kaiming 初始化)。

  • 该模型的前向传播通过激活函数(sigmoidReLU)以及 softmax 将输出转化为概率分布。

​​​​​​​​​​​​​​

相关文章:

【深度学习】03-神经网络01-4 神经网络的pytorch搭建和参数计算

# 计算模型参数,查看模型结构,我们要查看有多少参数,需要先安装包 pip install torchsummary import torch import torch.nn as nn from torchsummary import summary # 导入 summary 函数,用于计算模型参数和查看模型结构# 创建神经网络模型类 class Mo…...

我与Linux的爱恋:命令行参数|环境变量

​ ​ 🔥个人主页:guoguoqiang. 🔥专栏:Linux的学习 文章目录 一.命令行参数二.环境变量1.环境变量的基本概念2.查看环境变量的方法3.环境变量相关命令4.环境变量的组织方式以及获取环境变量的三种方法 环境变量具有全局属性 一…...

django drf 统一Response格式

场景 需要将响应体按照格式规范返回给前端。 例如: 响应体中包含以下字段: {"result": true,"data": {},"code": 200,"message": "ok","request_id": "20cadfe4-51cd-42f6-af81-0…...

SM2协同签名算法中随机数K的随机性对算法安全的影响

前面介绍过若持有私钥d的用户两次SM2签名过程中随机数k相同,在对手获得两次签名结果Sig1和Sig2的情况下,可破解私钥d。 具体见SM2签名算法中随机数K的随机性对算法安全的影响_sm2关闭随机数-CSDN博客 另关于SM2协同签名过程,具体见SM2协同签…...

解决setMouseTracking(true)后还是无法触发mouseMoveEvent的问题

如图,在给整体界面设置鼠标追踪且给ui界面的子控件也设置了鼠标追踪后,运行后的界面仍然有些地方移动鼠标无法触发 mouseMoveEvent函数,这就令人头痛。。。 我的解决方法是:重载event函数: 完美解决。。。...

基于深度学习的花卉智能分类识别系统

温馨提示:文末有 CSDN 平台官方提供的学长 QQ 名片 :) 1. 项目简介 传统的花卉分类方法通常依赖于专家的知识和经验,这种方法不仅耗时耗力,而且容易受到主观因素的影响。本系统利用 TensorFlow、Keras 等深度学习框架构建卷积神经网络&#…...

Springboot集成MongoDb快速入门

1. 什么是MongoDB 1.1. 基本概念 MongoDB是一个基于分布式文件存储 [1] 的数据库。由C语言编写。旨在为WEB应用提供可扩展的高性能数据存储解决方案。 MongoDB是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库当中功能最丰富,最像关系数…...

DERT目标检测—End-to-End Object Detection with Transformers

DERT:使用Transformer的端到端目标检测 论文题目:End-to-End Object Detection with Transformers 官方代码:https://github.com/facebookresearch/detr 论文题目中包括的一个创新点End to End(端到端的方法)简单的理解就是没有使…...

软件后端开发速度慢的科技公司老板有没有思考如何破局

最近接到两个科技公司咨询,说是他们公司后端开发速度太慢,前端程序员老等着,后端程序员拖了项目进度。 这种问题不只他们公司,在软件外包公司中,有一部分项目甲方客户要得急,以至于要求软件开发要快&#…...

开放原子超级链内核XuperCore可搭建区块链

区块链是一种分布式数据库技术,它以块的形式存储数据,并使用密码学方法保证数据的安全性和完整性。 每个块包含一定数量的交易信息,并通过加密链接到前一个块,形成一个不断增长的链条。 这种设计使得数据在网络中无法被篡改,因为任何尝试修改一个块的数据都会破坏整个链的…...

【Qualcomm】高通SNPE框架的使用 | 原始模型转换为量化的DLC文件 | 在Android的CPU端运行模型

目录 ① 激活snpe环境 ② 设置环境变量 ③ 模型转换 ④ run on Android 首先,默认SNPE工具已经下载并且Setup相关工作均已完成。同时,拥有原始模型文件,本文使用的模型文件为SNPE 框架示例的inception_v3_2016_08_28_frozen.pb文件。imag…...

C++map与set

文章目录 前言一、map和set基础知识二、set与map使用示例1.set去重操作2.map字典统计 总结 前言 本章主要介绍map和set的基本知识与用法。 一、map和set基础知识 map与set属于STL的一部分,他们底层都是是同红黑树来实现的。 ①set常见用途是去重 ,set不…...

随手记:前端一些定位bug的方法

有时候接到bug很烦躁,不管是任何环境的bug,看到都影响心情,随后记总结一下查看bug的思路,在摸不着头脑的时候或者焦虑的时候,可以静下心来顺着思路思考和排查bug可能产生的原因 1.接到bug,最重要的是&am…...

【深度学习】03-神经网络2-1损失函数

在神经网络中,不同任务类型(如多分类、二分类、回归)需要使用不同的损失函数来衡量模型预测和真实值之间的差异。选择合适的损失函数对于模型的性能至关重要。 这里的是API 的注意⚠️,但是在真实的公式中,目标值一定是…...

Python爬虫APP程序:构建智能化数据抓取工具

在信息爆炸的时代,数据的价值日益凸显。Python作为一种强大的编程语言,与其丰富的库一起,为爬虫程序的开发提供了得天独厚的优势。本文将探讨如何使用Python构建一个爬虫APP程序,以及其背后的思维逻辑。 什么是Python爬虫APP程序&…...

第五部分:2---中断与信号

目录 操作系统如何得知哪个外部资源就绪? 什么是中断机制? CPU引脚和中断号的关系: 中断向量表: 信号和中断的关系: 操作系统如何得知哪个外部资源就绪? 操作系统并不会主动轮询所有外设来查看哪些资源…...

梧桐数据库(WuTongDB):SQL Server Query Optimizer 简介

SQL Server Query Optimizer 是 SQL Server 数据库引擎的核心组件之一,负责生成查询执行计划,以优化 SQL 查询的执行性能。它的目标是根据查询的逻辑结构和底层数据的统计信息,选择出最优的查询执行方案。SQL Server Query Optimizer 采用基于…...

Scrapy框架介绍

一、什么是Scrapy 是一款快速而强大的web爬虫框架,基于Twusted的异步处理框架 Twisted是事件驱动的 Scrapy是由Python实现的爬虫框架 ① 架构清晰 ②可扩展性强 ③可以灵活完成需求 二、核心组件 Scrapy Engine(引擎):Scrapy框架…...

Facebook对现代社交互动的影响

自2004年成立以来,Facebook已经成为全球最大的社交媒体平台之一,改变了人们的交流方式和社交互动模式。作为一个数字平台,Facebook不仅为用户提供了分享生活点滴的空间,也深刻影响了现代社交互动的各个方面。本文将探讨Facebook如…...

Java项目运维有哪些内容?

Java项目运维的内容主要包括环境准备、部署Java应用、配置和优化、安全配置、以及数据安全保护措施,服务的运行和资源动态监控管理。‌ ‌1,环境准备‌:这包括选择适合运行Java和Tomcat的操作系统,如Ubuntu、CentOS等Linux发行版…...

论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)

HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

Unity3D中Gfx.WaitForPresent优化方案

前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件

今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...

STM32+rt-thread判断是否联网

一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量,这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

C#中的CLR属性、依赖属性与附加属性

CLR属性的主要特征 封装性: 隐藏字段的实现细节 提供对字段的受控访问 访问控制: 可单独设置get/set访问器的可见性 可创建只读或只写属性 计算属性: 可以在getter中执行计算逻辑 不需要直接对应一个字段 验证逻辑: 可以…...

无人机侦测与反制技术的进展与应用

国家电网无人机侦测与反制技术的进展与应用 引言 随着无人机(无人驾驶飞行器,UAV)技术的快速发展,其在商业、娱乐和军事领域的广泛应用带来了新的安全挑战。特别是对于关键基础设施如电力系统,无人机的“黑飞”&…...

免费PDF转图片工具

免费PDF转图片工具 一款简单易用的PDF转图片工具,可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件,也不需要在线上传文件,保护您的隐私。 工具截图 主要特点 🚀 快速转换:本地转换,无需等待上…...

MySQL JOIN 表过多的优化思路

当 MySQL 查询涉及大量表 JOIN 时,性能会显著下降。以下是优化思路和简易实现方法: 一、核心优化思路 减少 JOIN 数量 数据冗余:添加必要的冗余字段(如订单表直接存储用户名)合并表:将频繁关联的小表合并成…...

Proxmox Mail Gateway安装指南:从零开始配置高效邮件过滤系统

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「storms…...