pytorch实现RNN网络
目录
1.导包
2. 加载本地文本数据
3.构建循环神经网络层
4.初始化隐藏状态state
5.创建随机的数据,检测一下代码是否能正常运行
6. 构建一个完整的循环神经网络¶
7.模型训练
8.个人知识点理解
1.导包
import torch
from torch import nn
from torch.nn import functional as F
import dltools
2. 加载本地文本数据
#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)
3.构建循环神经网络层
#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)
4.初始化隐藏状态state
# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])
5.创建随机的数据,检测一下代码是否能正常运行
X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape#有输出表示代码正常运行!!!
(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))
6. 构建一个完整的循环神经网络¶
.long() 方法:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。
class RNNModel(nn.Module): #继承nn.Module#初始化(需要用到的)参数, **kwargs表示继承的其他参数(不一一写明的意思)#vocab_size = len(vocab)def __init__(self, rnn_layer, vocab_size, **kwargs):#继承父类的属性和方法super().__init__(**kwargs)self.rnn_layer = rnn_layer#词汇表的长度self.vocab_size =vocab_sizeself.num_hiddens = self.rnn_layer.hidden_size#判断是否为双向循环if not self.rnn_layer.bidirectional:self.num_directions = 1#nn.Linear用于定义线性层的类,一般用于全连接层self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)#定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)def forward(self, inputs, state):#one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)#。T转置之后,X.shape=(num_steps,batch_size)#one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)X = F.one_hot(inputs.T.long(), self.vocab_size)#将数据转化为tensorX = X.to(torch.float32)Y, state = self.rnn_layer(X, state)#此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)#输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,outputs = self.linear(Y.reshape(-1, Y.shape[-1]))return outputs, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',num_preds=10, net=rnn_net, vocab=vocab, device=device)
'time travellergghhhhhhhh'
7.模型训练
#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, train_iter=train_iter, vocab=vocab, lr=lr, num_epochs=num_epochs, device=device)

8.个人知识点理解

相关文章:
pytorch实现RNN网络
目录 1.导包 2. 加载本地文本数据 3.构建循环神经网络层 4.初始化隐藏状态state 5.创建随机的数据,检测一下代码是否能正常运行 6. 构建一个完整的循环神经网络 7.模型训练 8.个人知识点理解 1.导包 import torch from torch import nn from torch.nn imp…...
智能工厂的软件设计 “程序program”表达式,即 接口模型的代理模式表达式
Q1、前面将“智能工厂的软件设计”中绝无仅有的“程序”视为 专注于 给定的某个单一面(语言面/逻辑面/数学面)中的 问题,专注于分析问题和解决问题的程序活动的组织,每一面都是一个“组织者”就像一个“独角兽”,并提出…...
leetcode 难度【简单模式】标签【数据库】题型整理大全
文章目录 175. 组合两个表181. 超过经理收入的员工182. 查找重复的电子邮箱COUNT(*)COUNT(*) 与 COUNT(column) 的区别 where和vaing之间的区别用法 183.从不订购的客户196.删除重复的电子邮箱197.上升的温度511.游戏玩法分析I512.游戏玩法分析II577.员工奖金584.寻找用户推荐人…...
利士策分享,自我和解:通往赚钱与内心富足的和谐之道
利士策分享,自我和解:通往赚钱与内心富足的和谐之道 在这个快节奏、高压力的时代,我们往往在追求物质财富的同时,忽略了内心世界的和谐与平衡。 赚钱,作为现代生活中不可或缺的一部分,它不仅仅是生存的手段…...
【物联网】深入解析时序数据库TDengine及其Java应用实践
文章目录 一、什么是时序数据库?二、TDengine简介三、TDengine的Java应用实践(1)环境准备(2)数据插入(3)数据查询 一、什么是时序数据库? 时序数据库(Time-Series Datab…...
2023北华大学程序设计新生赛部分题解
时光如流水般逝去,我已在校园中奋战大二!(≧▽≦) 今天,静静回顾去年的新生赛,心中涌起无尽感慨,仿佛那段青春岁月如烟花般绚烂。✧。(≧▽≦)。✧ 青春就像一场燃烧的盛宴,激情澎湃&…...
PPP的配置
概述:PPP模式,即公私合作模式(Public-Private Partnership),是一种公共部门与私营部门合作的模式。 一、实验拓扑 实验一:PPP基本功能 实验步骤: (1)配置AR1的接口IP地…...
回溯算法总结篇
组合问题:N个数里面按一定规则找出k个数的集合 如果题目要求的是组合的具体信息,则只能使用回溯算法,如果题目只是要求组合的某些最值,个数等信息,则使用动态规划(比如求组合中元素最少的组合,…...
机器学习-点击率预估-论文速读-20240916
1. [经典文章] 特征交叉: Factorization Machines, ICDM, 2010 分解机(Factorization Machines) 摘要 本文介绍了一种新的模型类——分解机(FM),它结合了支持向量机(SVM)和分解模型的优点。与…...
【leetcode】堆习题
215.数组中的第K个最大元素 给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。 请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。 你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 1: 输…...
前端大模型入门:编码(Tokenizer)和嵌入(Embedding)解析 - llm的输入
LLM的核心是通过对语言进行建模来生成自然语言输出或理解输入,两个重要的概念在其中发挥关键作用:Tokenizer 和 Embedding。本篇文章将对这两个概念进行入门级介绍,并提供了针对前端的js示例代码,帮助读者理解它们的基本原理/作用和如何使用。 1. 什么是…...
一文读懂 JS 中的 Map 结构
你好,我是沐爸,欢迎点赞、收藏、评论和关注。 上次聊了 Set 数据结构,今天我们聊下 Map,看看它与 Set、与普通对象有什么区别?下面直接进入正题。 一、Set 和 Map 有什么区别? Set 是一个集合࿰…...
C++校招面经(二)
欢迎关注 0voice GitHub 6、 C 和 Java 区别(语⾔特性,垃圾回收,应⽤场景等) 指针: Java 语⾔让程序员没法找到指针来直接访问内存,没有指针的概念,并有内存的⾃动管理功能,从⽽…...
Python Web 面试题
1 Web 相关 get 和 post 区别 get: 请求数据在 URL 末尾,URL 长度有限制 请求幂等,即无论请求多少次,服务器响应始终相同,这是因为 get 至少获取资源,而不修改资源 可以被浏览器缓存,以便以后…...
java日志框架之JUL(Logging)
文章目录 一、JUL简介1、JUL组件介绍 二、Logger快速入门三、Logger日志级别1、日志级别2、默认级别info3、原理分析4、自定义日志级别5、日志持久化(保存到磁盘) 三、Logger父子关系四、Logger配置文件 一、JUL简介 JUL全程Java Util Loggingÿ…...
ARM驱动学习之PWM
ARM驱动学习之PWM 1.分析原理图: GPD0_0 XpwmTOUT0定时器0 2.定时器上的资源: 1.5组32位定时器 2.定时器产生内部中断 3.定时器0,1,2可编程实现pwm 4.定时器各自分频 5.TCN--,TCN TCMPBN 6.分频器 24-2 7.24.3.4 例子࿱…...
我的AI工具箱Tauri版-VideoClipMixingCut视频批量混剪
本教程基于自研的AI工具箱Tauri版进行VideoClipMixingCut视频批量混剪。 VideoClipMixingCut视频批量混剪 是自研AI工具箱Tauri版中的一款强大工具,专为自动化视频批量混剪设计。该模块通过将预设的解说文稿与视频素材进行自动拼接生成混剪视频,适合需要…...
postgres_fdw访问存储在外部 PostgreSQL 服务器中的数据
文章目录 一、postgres_fdw 介绍二、安装使用示例三、成本估算四、 远程执行选项执行计划无法递推解决 参考文件: 一、postgres_fdw 介绍 postgres_fdw 模块提供外部数据包装器 postgres_fdw,可用于访问存储在外部 PostgreSQL 服务器中的数据。 此模块…...
什么是3D展厅?有何优势?怎么制作3D展厅?
一、什么是3D展厅? 3D展厅是一种利用三维技术构建的虚拟展示空间。它借助虚拟现实(VR)、增强现实(AR)等现代科技手段,将真实的展示空间数字化,呈现出逼真、立体、沉浸的展示效果。通过3D展厅&a…...
Linux下的CAN通讯
CAN总线 CAN总线简介 CAN(Controller Area Network)总线是一种多主从式 <font color red>异步半双工串行 </font> 通信总线,它最早由Bosch公司开发,用于汽车电子系统。CAN总线具有以下特点: 多主从式&a…...
使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式
一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明:假设每台服务器已…...
XML Group端口详解
在XML数据映射过程中,经常需要对数据进行分组聚合操作。例如,当处理包含多个物料明细的XML文件时,可能需要将相同物料号的明细归为一组,或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码,增加了开…...
接口测试中缓存处理策略
在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...
Ubuntu系统下交叉编译openssl
一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机:Ubuntu 20.04.6 LTSHost:ARM32位交叉编译器:arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...
1688商品列表API与其他数据源的对接思路
将1688商品列表API与其他数据源对接时,需结合业务场景设计数据流转链路,重点关注数据格式兼容性、接口调用频率控制及数据一致性维护。以下是具体对接思路及关键技术点: 一、核心对接场景与目标 商品数据同步 场景:将1688商品信息…...
dedecms 织梦自定义表单留言增加ajax验证码功能
增加ajax功能模块,用户不点击提交按钮,只要输入框失去焦点,就会提前提示验证码是否正确。 一,模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…...
ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放
简介 前面两期文章我们介绍了I2S的读取和写入,一个是通过INMP441麦克风模块采集音频,一个是通过PCM5102A模块播放音频,那如果我们将两者结合起来,将麦克风采集到的音频通过PCM5102A播放,是不是就可以做一个扩音器了呢…...
DBAPI如何优雅的获取单条数据
API如何优雅的获取单条数据 案例一 对于查询类API,查询的是单条数据,比如根据主键ID查询用户信息,sql如下: select id, name, age from user where id #{id}API默认返回的数据格式是多条的,如下: {&qu…...
智能仓储的未来:自动化、AI与数据分析如何重塑物流中心
当仓库学会“思考”,物流的终极形态正在诞生 想象这样的场景: 凌晨3点,某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径;AI视觉系统在0.1秒内扫描包裹信息;数字孪生平台正模拟次日峰值流量压力…...
Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理
引言 Bitmap(位图)是Android应用内存占用的“头号杀手”。一张1080P(1920x1080)的图片以ARGB_8888格式加载时,内存占用高达8MB(192010804字节)。据统计,超过60%的应用OOM崩溃与Bitm…...
