当前位置: 首页 > article >正文

LSTM 学习笔记 之pytorch调包每个参数的解释

0、 LSTM 原理

整理优秀的文章
LSTM入门例子:根据前9年的数据预测后3年的客流(PyTorch实现)
[干货]深入浅出LSTM及其Python代码实现
整理视频
李毅宏手撕LSTM
[双语字幕]吴恩达深度学习deeplearning.ai

1 Pytorch 代码

这里直接调用了nn.lstm

 self.lstm = nn.LSTM(input_size, hidden_size, num_layers)  # utilize the LSTM model in torch.nn

下面作为初学者解释一下里面的3个参数
input_size: 这个就是输入的向量的长度or 维度,如一个单词可能占用20个维度。
hidden_size: 这个是隐藏层,其实我感觉有点全连接的意思,这个层的维度影响LSTM 网络输入的维度,换句话说,LSTM接收的数据维度不是输入什么维度就是什么维度,而是经过了隐藏层,做了一个维度的转化。
num_layers: 这里就是说堆叠了几个LSMT 结构。

2 网络定义

class LstmRNN(nn.Module):"""Parameters:- input_size: feature size- hidden_size: number of hidden units- output_size: number of output- num_layers: layers of LSTM to stack"""def __init__(self, input_size, hidden_size=1, output_size=1, num_layers=1):super().__init__()self.lstm = nn.LSTM(input_size, hidden_size, num_layers)  # utilize the LSTM model in torch.nnself.forwardCalculation = nn.Linear(hidden_size, output_size)def forward(self, _x):x, _ = self.lstm(_x)  # _x is input, size (seq_len, batch, input_size)s, b, h = x.shape  # x is output, size (seq_len, batch, hidden_size)x = x.view(s * b, h)x = self.forwardCalculation(x)x = x.view(s, b, -1)return x

3 网络初始化

我们定义一个网络导出onnx ,观察 网络的具体结构

INPUT_FEATURES_NUM = 100
OUTPUT_FEATURES_NUM = 13
lstm_model = LstmRNN(INPUT_FEATURES_NUM, 16, output_size=OUTPUT_FEATURES_NUM, num_layers=2)  # 16 hidden units
print(lstm_model)
save_onnx_path= "weights/lstm_16.onnx"
input_data = torch.randn(1,150,100)input_names = ["images"] + ["called_%d" % i for i in range(2)]
output_names = ["prob"]
torch.onnx.export(lstm_model,input_data,save_onnx_path,verbose=True,input_names=input_names,output_names=output_names,opset_version=12)

在这里插入图片描述
可以看到 LSTM W 是1x64x100;这个序列150没有了 是不是说150序列是一次一次的送的呢,所以在网络中没有体现;16是hidden,LSTM里面的W是64,这里存在一个4倍的关系。
我想这个关系和LSTM的3个门(输入+输出+遗忘+C^)有联系。
在这里插入图片描述
在这里插入图片描述
这里输出我们设置的13,如图 onnx 网络结构可视化显示也是13,至于这个150,或许就是输入有150个词,输出也是150个词吧。

在这里插入图片描述
至于LSTM的层数设置为2,则表示有2个LSTM堆叠。
在这里插入图片描述

4 网络提取

另外提取 网络方便看 每一层的维度,代码如下。

import onnx
from onnx import helper, checker
from onnx import TensorProto
import re
import argparse
model = "./weights/lstm_16.onnx"
output_model_path = "./weights/lstm_16_e.onnx"onnx_model = onnx.load(model)
#Flatten
onnx.utils.extract_model(model, output_model_path, ['images'],['prob'])

相关文章:

LSTM 学习笔记 之pytorch调包每个参数的解释

0、 LSTM 原理 整理优秀的文章 LSTM入门例子:根据前9年的数据预测后3年的客流(PyTorch实现) [干货]深入浅出LSTM及其Python代码实现 整理视频 李毅宏手撕LSTM [双语字幕]吴恩达深度学习deeplearning.ai 1 Pytorch 代码 这里直接调用了nn.l…...

计算机网络,大白话

好嘞,咱就从头到尾,给你好好说道说道计算机网络里这些“门门道道”的概念: 1. 网络(Network) 啥是网络? 你可以把网络想象成一个“大Party”,大家(设备)聚在一起&#…...

自定义sort排序

数组中&#xff0c;根据出现次数以大到小排序&#xff0c;当频率相同时按元素值降序排序 #include <iostream> #include <vector> #include <algorithm> #include <unordered_map>// 全局的 unordered_map 用于存储元素频率 std::unordered_map<in…...

【EXCEL】【VBA】处理GI Log获得Surf格式的CONTOUR DATA

【EXCEL】【VBA】处理GI Log获得Surf格式的CONTOUR DATA data source1: BH coordination tabledata source2:BH layer tableprocess 1:Collect BH List To Layer Tableprocess 2:match Reduced Level from "Layer"+"BH"data source1: BH coordination…...

kafka动态监听主题

简单版本 import org.springframework.beans.factory.annotation.Autowired; import org.springframework.kafka.core.ConsumerFactory; import org.springframework.kafka.listener.ConcurrentMessageListenerContainer; import org.springframework.kafka.listener.Containe…...

【PHP的static】

关于静态属性 最简单直接&#xff1a;静态方法也是一样 看了很多关于静态和动态的说法&#xff0c;无非是从 调用方式&#xff0c; 类访问实例变量&#xff0c; 访问静态变量&#xff0c; 需不要实例化这几个方向&#xff0c;太空了。问使用场景&#xff0c;好一点的 能说个…...

国产编辑器EverEdit - 光标位置跳转

1 光标位置跳转 1.1 应用场景 某些场景下&#xff0c;用户从当前编辑位置跳转到别的位置查阅信息&#xff0c;如果要快速跳转回之前编辑位置&#xff0c;则可以使用光标跳转相关功能。 1.2 使用方法 1.2.1 上一个编辑位置 跳转到上一个编辑位置&#xff0c;即文本修改过的位…...

cv2.Sobel

1. Sobel 算子简介 Sobel 算子是一种 边缘检测算子&#xff0c;通过对图像做梯度计算&#xff0c;可以突出边缘。 Sobel X 方向卷积核&#xff1a; 用于计算 水平方向&#xff08;x 方向&#xff09; 的梯度。 2. 输入图像示例 假设我们有一个 55 的灰度图像&#xff0c;像素…...

51单片机俄罗斯方块整行消除函数

/************************************************************************************************************** * 名称&#xff1a;flash * 功能&#xff1a;行清除动画 * 参数&#xff1a;NULL * 返回&#xff1a;NULL * 备注&#xff1a; * 采用非阻塞延时&#xff0…...

鸿蒙HarmonyOS NEXT开发:优化用户界面性能——组件复用(@Reusable装饰器)

文章目录 一、概述二、原理介绍三、使用规则四、复用类型详解1、标准型2、有限变化型2.1、类型1和类型2布局不同&#xff0c;业务逻辑不同2.2、类型1和类型2布局不同&#xff0c;但是很多业务逻辑公用 3、组合型4、全局型5、嵌套型 一、概述 组件复用是优化用户界面性能&#…...

langchain系列(二)- 提示词以及模板

导读 环境&#xff1a;OpenEuler、Windows 11、WSL 2、Python 3.12.3 langchain 0.3 背景&#xff1a;前期忙碌的开发阶段结束&#xff0c;需要沉淀自己的应用知识&#xff0c;过一遍LangChain 时间&#xff1a;20250212 说明&#xff1a;技术梳理 提示词模板理论说明 提…...

Openssl的使用,CA证书,中间证书,服务器证书的生成与使用

证书教程 1、Openssl相关文档2、生成证书命令初步解释3、准备openssl的配置文件 openssl.cnf4、证书生成4.1、生成根证书、CA根证书、自签名证书4.2、生成服务器证书4.3、生成中间证书4.3、使用中间证书生成服务器证书5、使用openssl操作证书5.1 查看证书内容5.2 进行证书测试5…...

深入浅出:Python 中的异步编程与协程

引言 大家好&#xff0c;今天我们来聊聊 异步编程 和 协程&#xff0c;这是近年来编程语言领域中的热点话题之一&#xff0c;尤其在 Python 中&#xff0c;它作为一种全新的编程模型&#xff0c;已经成为处理 IO密集型 任务的强力工具。尽管很多人对异步编程望而却步&#xff0…...

Windows中使用Docker安装Anythingllm,基于deepseek构建自己的本地知识库问答大模型,可局域网内多用户访问、离线运行

文章目录 Windows中使用Docker安装Anythingllm&#xff0c;基于deepseek构建自己的知识库问答大模型1. 安装 Docker Desktop2. 使用Docker拉取Anythingllm镜像2. 设置 STORAGE_LOCATION 路径3. 创建存储目录和 .env 文件.env 文件的作用关键配置项 4. 运行 Docker 命令docker r…...

Unity使用iTextSharp导出PDF-04图形

坐标系 pdf文档页面的原点&#xff08;0&#xff0c;0&#xff09;在左下角&#xff0c;向上为y,向右为x。 文档的PageSize可获取页面的宽高数值 单位&#xff1a;像素 绘制矢量图形 使用PdfContentByte类进行绘制&#xff0c;注意文档打开后才有此对象的实例。 绘制方法 …...

[SAP ABAP] OO ALV报表练习1

销售订单明细查询报表 业务目的&#xff1a;根据选择屏幕的筛选条件&#xff0c;使用 ALV 报表&#xff0c;显示销售订单详情 效果展示 用户的输入条件界面 用户的查询结果界面 涉及的主要功能点&#xff1a; 1.当在销售订单明细查询页面取不到任何数据时&#xff0c;在选择…...

安卓基础(第一集)

SharedPreferences&#xff08;本地存储简单数据&#xff09; 在 Android 中&#xff0c;SharedPreferences 用于存储小型数据。 &#xff08;1&#xff09;存储数据 // 获取 SharedPreferences 对象 SharedPreferences sharedPreferences getSharedPreferences("MyPre…...

数据库高安全—数据保护:数据动态脱敏

书接上文数据库高安全—审计追踪&#xff1a;传统审计&统一审计&#xff0c;从传统审计和统一审计两方面对高斯数据库的审计追踪技术进行解读&#xff0c;本篇将从数据动态脱敏方面对高斯数据库的数据保护技术进行解读。 5.1 数据动态脱敏 数据脱敏&#xff0c;顾名思义就…...

Datawhale 数学建模导论二 2025年2月

第6章 数据处理与拟合模型 本章主要涉及到的知识点有&#xff1a; 数据与大数据Python数据预处理常见的统计分析模型随机过程与随机模拟数据可视化 本章内容涉及到基础的概率论与数理统计理论&#xff0c;如果对这部分内容不熟悉&#xff0c;可以参考相关概率论与数理统计的…...

ArcGIS Enterprise 与 ArcGIS Online 的关系

ArcGIS Enterprise 和 ArcGIS Online 是 Esri 提供的两款核心产品,它们在功能、部署方式和使用场景上存在显著差异,但同时也有一定的联系和互补性。以下是关于这两款产品的详细关系说明: 1. 产品定位与功能 ArcGIS Enterprise 是一款企业级解决方案,支持在组织的基础设施上…...

ASP.NET Core SignalR实践指南

Hub类的生命周期是瞬态的&#xff0c;每次调用集线器的时候都会创建一个新的Hub类实例&#xff0c;因此不要在Hub类中通过属性、成员变量等方式保存状态。如果服务器的压力比较大&#xff0c;建议把ASP.NET Core程序和SignalR服务器端部署到不同服务器上&#xff0c;以免它们互…...

【力扣 - 简单题】88. 合并两个有序数组

题目&#xff1a;88. 合并两个有序数组 - 力扣&#xff08;LeetCode&#xff09; 解题&#xff1a; class Solution { public:void merge(vector<int>& nums1, int m, vector<int>& nums2, int n) {for (int i m; i < n m; i ){nums1[i] nums2[i -…...

【密评】 | 商用密码应用安全性评估从业人员考核题库(23)

在GM/T0048《智能密码钥匙密码检测规范》中,产品的对称算法性能应满足哪个标准中的要求()。 A.GM/T 0016《智能密码钥匙密码应用接口规范》 B.GM/T 0017《智能密码钥匙密码应用接口数据格式规范》 C.GM/T 0027《智能密码钥匙技术规范》 D.GM/T 0028《密码模块安全技术要求》…...

记录 | WPF基础学习MVVM例子讲解1

目录 前言一、NotificationObject与数据属性创建个类&#xff0c;声明NotificationObject 二、DelegateCommand与命令属性三、View与ViewModel的交互&#xff08;难点&#xff09;在ViewModel文件下创建MainWindowViewModel数据和方法绑定资源指定 代码下载四、优势体现代码下载…...

PyTorch 中 `torch.cuda.amp` 相关警告的解决方法

在最近的写代码过程中&#xff0c;遇到了两个与 PyTorch 的混合精度训练相关的警告信息。这里随手记录一下。 警告内容 警告 1: torch.cuda.amp.autocast FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast(cuda, args...) i…...

实验7 路由器之间IPsec VPN配置

实验7 路由器之间IPsec VPN配置 1.实验目的 通过在两台路由器之间配置IPsec VPN连接&#xff0c;掌握IPsec VPN配置方法&#xff0c;加深对IPsec协议的理解。 2.实验内容 &#xff08;1&#xff09;按照实验拓扑搭建实验环境。 &#xff08;2&#xff09;在路由器R1和R4配置IP…...

Unity中快速制作2D沙雕动画:流程编

Unity中快速制作2D沙雕动画&#xff08;搞笑/无厘头风格&#xff09;&#xff0c;通过以下方案实现低成本、高成效的开发流程&#xff0c;结合夸张的动作、滑稽的物理效果和魔性音效&#xff1a; 1. 角色与素材设计 核心原则&#xff1a;丑萌即正义&#xff0c;越怪越好&#…...

小白零基础如何搭建CNN

1.卷积层 在PyTorch中针对卷积操作的对象和使用的场景不同&#xff0c;如有1维卷积、2维卷积、 3维卷积与转置卷积&#xff08;可以简单理解为卷积操作的逆操作&#xff09;&#xff0c;但它们的使用方法比较相似&#xff0c;都可以从torch.nn模块中调用&#xff0c;需要调用的…...

【Java八股文】01-Java基础面试篇

【Java八股文】01-Java基础面试篇 概念Java特点Java为什么跨平台JVM、JDK、JRE关系 面向对象什么是面向对象&#xff0c;什么是封装继承多态&#xff1f;多态体现的方面面向对象设计原则重载重写的区别抽象类和实体类区别Java抽象类和接口的区别抽象类可以被实例化吗 深拷贝浅拷…...

k8s部署logstash

1. 编写logstash.yaml配置文件 --- apiVersion: v1 kind: Service metadata:name: logstash spec:type: ClusterIPclusterIP: Noneports:- name: logstash-tcpport: 5000targetPort: 5000- name: logstash-beatsport: 5044targetPort: 5044- name: logstash-apiport: 9600targ…...