当前位置: 首页 > news >正文

PyTorch使用教程(10)-torchinfo.summary网络结构可视化详细说明

1、基本介绍

torchinfo是一个为PyTorch用户量身定做的开源工具,其核心功能之一是summary函数。这个函数旨在简化模型的开发与调试流程,让模型架构一目了然。通过torchinfosummary函数,用户可以快速获取模型的详细结构和统计信息,如模型的层次结构、输入/输出维度、参数数量、多加操作(Mult-Adds)等关键信息。

2、安装

首先,你需要安装torchinfo库。可以通过pip进行安装:

pip install torchinfo

3、导入

安装完成后,需要在你的Python脚本中导入torchinfo模块:

from torchinfo import summary

4、函数原型定义

torchinfo的summary函数原型定义如下:

def summary(model: nn.Module, input_data: torch.Tensor | tuple[torch.Tensor, ...] | tuple[int, ...] | None = None, batch_dim: int = 0, col_widths: tuple[int, ...] | None = None, col_names: tuple[str, ...] | None = None, device: str | torch.device | None = None, dtypes: tuple[torch.dtype, ...] | None = None, verbose: int = 1, **kwargs)

参数说明

  • model: 要分析的PyTorch模型,必须是torch.nn.Module的实例。
  • input_data: 用于模型前向传播的输入数据。它可以是一个torch.Tensor对象,也可以是一个包含多个输入张量的元组。此外,还可以提供一个表示输入尺寸的元组,例如(batch_size, channels, height, width)。
  • batch_dim: 指定输入张量中哪个维度是批量大小(batch size)。默认为0。
  • col_widths: 指定输出列宽的元组。如果未指定,则自动计算列宽以适应输出。
  • col_names: 指定输出列名的元组。如果未指定,则使用默认列名。
  • device: 指定模型运行的设备(如’cpu’或’cuda’)。如果未指定,则自动选择。
  • dtypes: 指定输入张量的数据类型。如果未指定,则自动推断。
  • verbose: 控制输出信息的详细程度。默认为1,表示输出基本信息。设置为2或更高可以获得更详细的输出。
  • kwargs: 其他关键字参数,可以传递给模型的前向传播函数。

5、使用方法

下面通过几个示例来展示如何使用torchinfo的summary函数。
5.1 使用预定义模型
首先,我们使用PyTorch预定义的模型(如torchvision.models.resnet50)来展示如何使用summary函数。

import torch
import torchvision.models as models
from torchinfo import summary
# 定义模型
model = models.resnet18(pretrained=False)# 使用summary函数打印模型概况
summary(model, input_size=(1, 3, 224, 224))

在这个示例中,我们加载了一个未预训练的ResNet50模型,并使用summary函数打印了模型的概况。input_size参数指定了输入数据的大小,即(batch_size, channels, height, width)。
在这里插入图片描述

5.2 使用自定义模型
接下来,我们定义一个简单的自定义模型,并使用summary函数打印其概况。

import torch
import torch.nn as nn
from torchinfo import summary# 定义一个简单的两层全连接神经网络模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc1 = nn.Linear(100, 50)self.fc2 = nn.Linear(50, 10)self.relu = nn.ReLU()def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x# 创建模型实例
model = SimpleModel()# 使用summary函数打印模型概况
summary(model, input_size=(100,))

在这个示例中,我们定义了一个简单的两层全连接神经网络模型,并使用summary函数打印了模型的概况。input_size参数指定了输入数据的大小,即(batch_size, features)。由于我们的模型是一个全连接层,所以我们只指定了特征数量。
在这里插入图片描述

5.3 使用自定义输入数据

有时候,可能想要使用实际的输入数据来查看模型的概况。下面是一个示例,展示了如何使用自定义输入数据来调用summary函数。

import torch
import torchvision.models as models
from torchinfo import summary# 定义模型
model = models.resnet50(pretrained=False)# 创建自定义输入数据
input_data = torch.randn(1, 3, 224, 224)  # batch_size=1, channels=3, height=224, width=224# 使用summary函数打印模型概况
summary(model, input_data=input_data)

在这个示例中,我们创建了一个形状为(1, 3, 224, 224)的随机张量作为输入数据,并使用summary函数打印了模型的概况。注意,这里我们使用input_data参数而不是input_size参数来指定输入数据。

5.4 调整输出格式
torchinfo允许通过col_widths和col_names参数来调整输出的格式。下面是一个示例,展示了如何自定义输出列宽和列名。

import torch
import torchvision.models as models
from torchinfo import summary# 定义模型
model = models.resnet50(pretrained=False)# 使用summary函数打印模型概况,并自定义输出列宽和列名
summary(model, input_size=(3, 224, 224), col_widths=(30, 30, 20, 20),col_names=('input_size', 'output_size', 'kernel_size', 'num_params'))

在这个示例中,我们自定义了输出列宽和列名。col_widths参数指定了每列的宽度(以字符为单位),而col_names参数指定了每列的列名。这样,就可以根据需要来调整输出的格式了。

6、小结

torchinfo的summary函数是一个强大的工具,可以方便地查看PyTorch模型的结构和参数数量。通过本文的介绍,应该已经掌握了如何使用summary函数来打印模型的概况。无论使用预定义模型还是自定义模型,无论是使用输入尺寸还是自定义输入数据,torchinfo都能提供详细而清晰的输出信息。希望这篇文章能对你有所帮助!

相关文章:

PyTorch使用教程(10)-torchinfo.summary网络结构可视化详细说明

1、基本介绍 torchinfo是一个为PyTorch用户量身定做的开源工具,其核心功能之一是summary函数。这个函数旨在简化模型的开发与调试流程,让模型架构一目了然。通过torchinfo的summary函数,用户可以快速获取模型的详细结构和统计信息&#xff0…...

亚博microros小车-原生ubuntu支持系列:5-姿态检测

MediaPipe 介绍参见:亚博microros小车-原生ubuntu支持系列:4-手部检测-CSDN博客 本篇继续迁移姿态检测。 一 背景知识 以下来自亚博官网 MediaPipe Pose是⼀个⽤于⾼保真⾝体姿势跟踪的ML解决⽅案,利⽤BlazePose研究,从RGB视频…...

C语言之高校学生信息快速查询系统的实现

🌟 嗨,我是LucianaiB! 🌍 总有人间一两风,填我十万八千梦。 🚀 路漫漫其修远兮,吾将上下而求索。 C语言之高校学生信息快速查询系统的实现 目录 任务陈述与分析 问题陈述问题分析 数据结构设…...

WPF基础 | WPF 基础概念全解析:布局、控件与事件

WPF基础 | WPF 基础概念全解析:布局、控件与事件 一、前言二、WPF 布局系统2.1 布局的重要性与基本原理2.2 常见布局面板2.3 布局的测量与排列过程 三、WPF 控件3.1 控件概述与分类3.2 常见控件的属性、方法与事件3.3 自定义控件 四、WPF 事件4.1 路由事件概述4.2 事…...

迷宫1.2

先发一下上次的代码 #include<bits/stdc.h> #include<windows.h> #include <conio.h> using namespace std; char a[1005][1005]{ " ", "################", "# # *#", "# # # #&qu…...

RabbitMQ---应用问题

&#xff08;一&#xff09;幂等性介绍 幂等性是本身是数学中的运算性质&#xff0c;他们可以被多次应用&#xff0c;但是不会改变初始应用的结果 1.应用程序的幂等性介绍 包括很多&#xff0c;有数据库幂等性&#xff0c;接口幂等性以及网络通信幂等性等 就比如数据库的sel…...

Unity自学之旅03

Unity自学之旅03 Unity自学之旅03&#x1f4dd; 碰撞体 Collider 基础定义与作用常见类型OnCollisionEnter 事件碰撞触发器 &#x1f917; 总结归纳 Unity自学之旅03 &#x1f4dd; 碰撞体 Collider 基础 定义与作用 定义&#xff1a;碰撞体是游戏中用于检测物体之间碰撞的组…...

pip 相关

一劳永逸法&#xff08;pip怎么样都用不了也更新不了&#xff09;&#xff1a; 重下python(卸载旧版本&#xff09;&#xff1a;请输入访问密码 密码&#xff1a;7598 各版本python都有&#xff0c;下3.10.10 python路径建立&#xff0c;pip无法访问方式&#xff1a; 访问pip要…...

vue request 发送formdata

在Vue中&#xff0c;你可以使用axios库来发送包含FormData的请求。以下是一个简单的例子&#xff1a; 首先&#xff0c;确保你已经安装了axios&#xff1a; npm install axios然后&#xff0c;你可以使用axios发送FormData&#xff0c;例如&#xff1a; import axios from a…...

Android RTMP直播练习实践

前言&#xff1a;本文只是练习&#xff0c;本文只是练习&#xff0c;本文只是练习&#xff01; 直播的核心就是推流和拉流&#xff0c;我们就以RTMP的协议来实现下推流和拉流&#xff0c;其他的协议等我学习后再来补充 1.推流 1.1搭建流媒体服务器&#xff0c;具体搭建方法请参…...

ITIL认证工具商-ManageEngine Servicedesk Plus

ServiceDesk Plus是Zoho Corporation旗下企业IT管理部门ManageEngine提供的统一服务管理解决方案。凭借其无限的可扩展性、情境化的IT和业务集成以及一键式工作流程自动化功能&#xff0c;IT领导者可以使用ServiceDesk Plus有效执行和控制跨不同业务部门和IT功能的复杂工作流程…...

https 的 CA证书和电子签名

https 的攻击者可能使用伪造的一对公私钥与客户端交互, 那么如何确保确实是该服务器的公钥呢? 这就诞生了CA颁发机构 CA颁发机构 服务器和客户端都信任指定的CA颁发机构 服务器上传服务器公钥, CA颁发机构做了什么 服务器公钥哈希, 记为 Hash使用 CA 私钥为 Hash 进行 CA 签…...

频繁刷新网页会对服务器造成哪些影响?

当用户在进行浏览网页的过程中频繁刷新页面时&#xff0c;浏览器会向服务器发送请求&#xff0c;服务器会对该请求进行处理并返回到相应的页面内容中&#xff0c;所以频繁刷新网页会对服务器造成影响&#xff0c;有可能会出现以下问题&#xff1a; 用户每次刷新网页都会向服务器…...

贪心算法(题1)区间选点

输出 2 #include <iostream> #include<algorithm>using namespace std;const int N 100010 ;int n; struct Range {int l,r;bool operator <(const Range &W)const{return r<W.r;} }range[N];int main() {scanf("%d",&n);for(int i0;i&l…...

JavaWeb开发学习笔记--MySQL

MySQL-DQL 基本语法&#xff1a; select 字段列表 from 表名列表 where 条件列表 group by 分组字段列表 having 分组后条件列表 order by 排序字段列表 limit 分页参数 基本查询 关键字&#xff1a;SELECT 查询多个字段&#xff1a;select 字…...

抖音小程序一键获取手机号

前端代码组件 <button v-if"!isFromOrderList"class"get-phone-btn" open-type"getPhoneNumber"getphonenumber"onGetPhoneNumber">一键获取</button>// 获取手机号回调onGetPhoneNumber(e) {var that this tt.login({f…...

iconfont等图标托管网站上传svg显示未轮廓化解决办法

打开即时设计 即时设计 - 可实时协作的专业 UI 设计工具 导入图标后拖入画板里面&#xff0c;右键选择轮廓化 将图标导出...

2008-2020年各省城镇登记失业率数据

2008-2020年各省城镇登记失业率数据 1、时间&#xff1a;2008-2020年 2、来源&#xff1a;国家统计局、统计年鉴 3、指标&#xff1a;行政区划代码、地区名称、年份、城镇登记失业率 4、范围&#xff1a;31省 5、指标说明&#xff1a;城镇登记失业率是指在一定时期内&…...

Linux——信号量和(环形队列消费者模型)

Linux——线程条件变量&#xff08;同步&#xff09;-CSDN博客 文章目录 目录 文章目录 前言 一、信号量是什么&#xff1f; 二、信号量 1、主要类型 2、操作 3、应用场景 三、信号量函数 1、sem_init 函数 2、sem_wait 函数 3、sem_post 函数 4、sem_destroy 函数 ​​​​​​…...

【JOIN】关键字在MySql中的详细使用

目录 INNER JOIN&#xff08;内连接&#xff09; LEFT JOIN&#xff08;左连接&#xff09; RIGHT JOIN&#xff08;右连接&#xff09; FULL JOIN&#xff08;全连接&#xff09; 示例图形化解释JOIN的不同类型 INNER JOIN&#xff1a; LEFT JOIN&#xff1a; RIGHT J…...

stm32G473的flash模式是单bank还是双bank?

今天突然有人stm32G473的flash模式是单bank还是双bank&#xff1f;由于时间太久&#xff0c;我真忘记了。搜搜发现&#xff0c;还真有人和我一样。见下面的链接&#xff1a;https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

Java - Mysql数据类型对应

Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...

关于 WASM:1. WASM 基础原理

一、WASM 简介 1.1 WebAssembly 是什么&#xff1f; WebAssembly&#xff08;WASM&#xff09; 是一种能在现代浏览器中高效运行的二进制指令格式&#xff0c;它不是传统的编程语言&#xff0c;而是一种 低级字节码格式&#xff0c;可由高级语言&#xff08;如 C、C、Rust&am…...

#Uniapp篇:chrome调试unapp适配

chrome调试设备----使用Android模拟机开发调试移动端页面 Chrome://inspect/#devices MuMu模拟器Edge浏览器&#xff1a;Android原生APP嵌入的H5页面元素定位 chrome://inspect/#devices uniapp单位适配 根路径下 postcss.config.js 需要装这些插件 “postcss”: “^8.5.…...

逻辑回归暴力训练预测金融欺诈

简述 「使用逻辑回归暴力预测金融欺诈&#xff0c;并不断增加特征维度持续测试」的做法&#xff0c;体现了一种逐步建模与迭代验证的实验思路&#xff0c;在金融欺诈检测中非常有价值&#xff0c;本文作为一篇回顾性记录了早年间公司给某行做反欺诈预测用到的技术和思路。百度…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)

一、OpenBCI_GUI 项目概述 &#xff08;一&#xff09;项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台&#xff0c;其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言&#xff0c;首次接触 OpenBCI 设备时&#xff0c;往…...

Spring AI Chat Memory 实战指南:Local 与 JDBC 存储集成

一个面向 Java 开发者的 Sring-Ai 示例工程项目&#xff0c;该项目是一个 Spring AI 快速入门的样例工程项目&#xff0c;旨在通过一些小的案例展示 Spring AI 框架的核心功能和使用方法。 项目采用模块化设计&#xff0c;每个模块都专注于特定的功能领域&#xff0c;便于学习和…...

HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散

前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说&#xff0c;在叠衣服的过程中&#xff0c;我会带着团队对比各种模型、方法、策略&#xff0c;毕竟针对各个场景始终寻找更优的解决方案&#xff0c;是我个人和我司「七月在线」的职责之一 且个人认为&#xff0c…...

算术操作符与类型转换:从基础到精通

目录 前言&#xff1a;从基础到实践——探索运算符与类型转换的奥秘 算术操作符超级详解 算术操作符&#xff1a;、-、*、/、% 赋值操作符&#xff1a;和复合赋值 单⽬操作符&#xff1a;、--、、- 前言&#xff1a;从基础到实践——探索运算符与类型转换的奥秘 在先前的文…...

React父子组件通信:Props怎么用?如何从父组件向子组件传递数据?

系列回顾&#xff1a; 在上一篇《React核心概念&#xff1a;State是什么&#xff1f;》中&#xff0c;我们学习了如何使用useState让一个组件拥有自己的内部数据&#xff08;State&#xff09;&#xff0c;并通过一个计数器案例&#xff0c;实现了组件的自我更新。这很棒&#…...