当前位置: 首页 > news >正文

pytorch实现RNN网络

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module#初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)#vocab_size = len(vocab)def __init__(self, rnn_layer, vocab_size, **kwargs):#继承父类的属性和方法super().__init__(**kwargs)self.rnn_layer = rnn_layer#词汇表的长度self.vocab_size =vocab_sizeself.num_hiddens = self.rnn_layer.hidden_size#判断是否为双向循环if not self.rnn_layer.bidirectional:self.num_directions = 1#nn.Linear用于定义线性层的类,一般用于全连接层self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)#定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)def forward(self, inputs, state):#one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)#。T转置之后,X.shape=(num_steps,batch_size)#one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)X = F.one_hot(inputs.T.long(), self.vocab_size)#将数据转化为tensorX = X.to(torch.float32)Y, state = self.rnn_layer(X, state)#此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)#输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,outputs = self.linear(Y.reshape(-1, Y.shape[-1]))return outputs, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',num_preds=10, net=rnn_net, vocab=vocab, device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, train_iter=train_iter, vocab=vocab, lr=lr, num_epochs=num_epochs, device=device)

 

8.个人知识点理解

 

 

 

相关文章:

pytorch实现RNN网络

目录 1.导包 2. 加载本地文本数据 3.构建循环神经网络层 4.初始化隐藏状态state 5.创建随机的数据,检测一下代码是否能正常运行 6. 构建一个完整的循环神经网络 7.模型训练 8.个人知识点理解 1.导包 import torch from torch import nn from torch.nn imp…...

智能工厂的软件设计 “程序program”表达式,即 接口模型的代理模式表达式

Q1、前面将“智能工厂的软件设计”中绝无仅有的“程序”视为 专注于 给定的某个单一面(语言面/逻辑面/数学面)中的 问题,专注于分析问题和解决问题的程序活动的组织,每一面都是一个“组织者”就像一个“独角兽”,并提出…...

leetcode 难度【简单模式】标签【数据库】题型整理大全

文章目录 175. 组合两个表181. 超过经理收入的员工182. 查找重复的电子邮箱COUNT(*)COUNT(*) 与 COUNT(column) 的区别 where和vaing之间的区别用法 183.从不订购的客户196.删除重复的电子邮箱197.上升的温度511.游戏玩法分析I512.游戏玩法分析II577.员工奖金584.寻找用户推荐人…...

利士策分享,自我和解:通往赚钱与内心富足的和谐之道

利士策分享,自我和解:通往赚钱与内心富足的和谐之道 在这个快节奏、高压力的时代,我们往往在追求物质财富的同时,忽略了内心世界的和谐与平衡。 赚钱,作为现代生活中不可或缺的一部分,它不仅仅是生存的手段…...

【物联网】深入解析时序数据库TDengine及其Java应用实践

文章目录 一、什么是时序数据库?二、TDengine简介三、TDengine的Java应用实践(1)环境准备(2)数据插入(3)数据查询 一、什么是时序数据库? 时序数据库(Time-Series Datab…...

2023北华大学程序设计新生赛部分题解

时光如流水般逝去,我已在校园中奋战大二!(≧▽≦) 今天,静静回顾去年的新生赛,心中涌起无尽感慨,仿佛那段青春岁月如烟花般绚烂。✧。(≧▽≦)。✧ 青春就像一场燃烧的盛宴,激情澎湃&…...

PPP的配置

概述:PPP模式,即公私合作模式(Public-Private Partnership),是一种公共部门与私营部门合作的模式。 一、实验拓扑 实验一:PPP基本功能 实验步骤: (1)配置AR1的接口IP地…...

回溯算法总结篇

组合问题:N个数里面按一定规则找出k个数的集合 如果题目要求的是组合的具体信息,则只能使用回溯算法,如果题目只是要求组合的某些最值,个数等信息,则使用动态规划(比如求组合中元素最少的组合,…...

机器学习-点击率预估-论文速读-20240916

1. [经典文章] 特征交叉: Factorization Machines, ICDM, 2010 分解机(Factorization Machines) 摘要 本文介绍了一种新的模型类——分解机(FM),它结合了支持向量机(SVM)和分解模型的优点。与…...

【leetcode】堆习题

215.数组中的第K个最大元素 给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。 请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。 你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 1: 输…...

前端大模型入门:编码(Tokenizer)和嵌入(Embedding)解析 - llm的输入

LLM的核心是通过对语言进行建模来生成自然语言输出或理解输入,两个重要的概念在其中发挥关键作用:Tokenizer 和 Embedding。本篇文章将对这两个概念进行入门级介绍,并提供了针对前端的js示例代码,帮助读者理解它们的基本原理/作用和如何使用。 1. 什么是…...

一文读懂 JS 中的 Map 结构

你好,我是沐爸,欢迎点赞、收藏、评论和关注。 上次聊了 Set 数据结构,今天我们聊下 Map,看看它与 Set、与普通对象有什么区别?下面直接进入正题。 一、Set 和 Map 有什么区别? Set 是一个集合&#xff0…...

C++校招面经(二)

欢迎关注 0voice GitHub 6、 C 和 Java 区别(语⾔特性,垃圾回收,应⽤场景等) 指针: Java 语⾔让程序员没法找到指针来直接访问内存,没有指针的概念,并有内存的⾃动管理功能,从⽽…...

Python Web 面试题

1 Web 相关 get 和 post 区别 get: 请求数据在 URL 末尾,URL 长度有限制 请求幂等,即无论请求多少次,服务器响应始终相同,这是因为 get 至少获取资源,而不修改资源 可以被浏览器缓存,以便以后…...

java日志框架之JUL(Logging)

文章目录 一、JUL简介1、JUL组件介绍 二、Logger快速入门三、Logger日志级别1、日志级别2、默认级别info3、原理分析4、自定义日志级别5、日志持久化(保存到磁盘) 三、Logger父子关系四、Logger配置文件 一、JUL简介 JUL全程Java Util Logging&#xff…...

ARM驱动学习之PWM

ARM驱动学习之PWM 1.分析原理图: GPD0_0 XpwmTOUT0定时器0 2.定时器上的资源: 1.5组32位定时器 2.定时器产生内部中断 3.定时器0,1,2可编程实现pwm 4.定时器各自分频 5.TCN--,TCN TCMPBN 6.分频器 24-2 7.24.3.4 例子&#xff1…...

我的AI工具箱Tauri版-VideoClipMixingCut视频批量混剪

本教程基于自研的AI工具箱Tauri版进行VideoClipMixingCut视频批量混剪。 VideoClipMixingCut视频批量混剪 是自研AI工具箱Tauri版中的一款强大工具,专为自动化视频批量混剪设计。该模块通过将预设的解说文稿与视频素材进行自动拼接生成混剪视频,适合需要…...

postgres_fdw访问存储在外部 PostgreSQL 服务器中的数据

文章目录 一、postgres_fdw 介绍二、安装使用示例三、成本估算四、 远程执行选项执行计划无法递推解决 参考文件: 一、postgres_fdw 介绍 postgres_fdw 模块提供外部数据包装器 postgres_fdw,可用于访问存储在外部 PostgreSQL 服务器中的数据。 此模块…...

什么是3D展厅?有何优势?怎么制作3D展厅?

一、什么是3D展厅? 3D展厅是一种利用三维技术构建的虚拟展示空间。它借助虚拟现实(VR)、增强现实(AR)等现代科技手段,将真实的展示空间数字化,呈现出逼真、立体、沉浸的展示效果。通过3D展厅&a…...

Linux下的CAN通讯

CAN总线 CAN总线简介 CAN&#xff08;Controller Area Network&#xff09;总线是一种多主从式 <font color red>异步半双工串行 </font> 通信总线&#xff0c;它最早由Bosch公司开发&#xff0c;用于汽车电子系统。CAN总线具有以下特点&#xff1a; 多主从式&a…...

循环冗余码校验CRC码 算法步骤+详细实例计算

通信过程&#xff1a;&#xff08;白话解释&#xff09; 我们将原始待发送的消息称为 M M M&#xff0c;依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)&#xff08;意思就是 G &#xff08; x ) G&#xff08;x) G&#xff08;x) 是已知的&#xff09;&#xff0…...

使用分级同态加密防御梯度泄漏

抽象 联邦学习 &#xff08;FL&#xff09; 支持跨分布式客户端进行协作模型训练&#xff0c;而无需共享原始数据&#xff0c;这使其成为在互联和自动驾驶汽车 &#xff08;CAV&#xff09; 等领域保护隐私的机器学习的一种很有前途的方法。然而&#xff0c;最近的研究表明&…...

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?

论文网址&#xff1a;pdf 英文是纯手打的&#xff01;论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误&#xff0c;若有发现欢迎评论指正&#xff01;文章偏向于笔记&#xff0c;谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...

【数据分析】R版IntelliGenes用于生物标志物发现的可解释机器学习

禁止商业或二改转载&#xff0c;仅供自学使用&#xff0c;侵权必究&#xff0c;如需截取部分内容请后台联系作者! 文章目录 介绍流程步骤1. 输入数据2. 特征选择3. 模型训练4. I-Genes 评分计算5. 输出结果 IntelliGenesR 安装包1. 特征选择2. 模型训练和评估3. I-Genes 评分计…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中&#xff0c;性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期&#xff0c;开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发&#xff0c;但背后往往隐藏着系统资源调度不当…...

Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)

引言 工欲善其事&#xff0c;必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后&#xff0c;我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集&#xff0c;就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...

Axure 下拉框联动

实现选省、选完省之后选对应省份下的市区...

数据结构:泰勒展开式:霍纳法则(Horner‘s Rule)

目录 &#x1f50d; 若用递归计算每一项&#xff0c;会发生什么&#xff1f; Horners Rule&#xff08;霍纳法则&#xff09; 第一步&#xff1a;我们从最原始的泰勒公式出发 第二步&#xff1a;从形式上重新观察展开式 &#x1f31f; 第三步&#xff1a;引出霍纳法则&…...

医疗AI模型可解释性编程研究:基于SHAP、LIME与Anchor

1 医疗树模型与可解释人工智能基础 医疗领域的人工智能应用正迅速从理论研究转向临床实践,在这一过程中,模型可解释性已成为确保AI系统被医疗专业人员接受和信任的关键因素。基于树模型的集成算法(如RandomForest、XGBoost、LightGBM)因其卓越的预测性能和相对良好的解释性…...

算法250609 高精度

加法 #include<stdio.h> #include<iostream> #include<string.h> #include<math.h> #include<algorithm> using namespace std; char input1[205]; char input2[205]; int main(){while(scanf("%s%s",input1,input2)!EOF){int a[205]…...