当前位置: 首页 > news >正文

pytorch实现RNN网络

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module#初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)#vocab_size = len(vocab)def __init__(self, rnn_layer, vocab_size, **kwargs):#继承父类的属性和方法super().__init__(**kwargs)self.rnn_layer = rnn_layer#词汇表的长度self.vocab_size =vocab_sizeself.num_hiddens = self.rnn_layer.hidden_size#判断是否为双向循环if not self.rnn_layer.bidirectional:self.num_directions = 1#nn.Linear用于定义线性层的类,一般用于全连接层self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)#定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)def forward(self, inputs, state):#one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)#。T转置之后,X.shape=(num_steps,batch_size)#one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)X = F.one_hot(inputs.T.long(), self.vocab_size)#将数据转化为tensorX = X.to(torch.float32)Y, state = self.rnn_layer(X, state)#此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)#输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,outputs = self.linear(Y.reshape(-1, Y.shape[-1]))return outputs, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',num_preds=10, net=rnn_net, vocab=vocab, device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, train_iter=train_iter, vocab=vocab, lr=lr, num_epochs=num_epochs, device=device)

 

8.个人知识点理解

 

 

 

相关文章:

pytorch实现RNN网络

目录 1.导包 2. 加载本地文本数据 3.构建循环神经网络层 4.初始化隐藏状态state 5.创建随机的数据,检测一下代码是否能正常运行 6. 构建一个完整的循环神经网络 7.模型训练 8.个人知识点理解 1.导包 import torch from torch import nn from torch.nn imp…...

智能工厂的软件设计 “程序program”表达式,即 接口模型的代理模式表达式

Q1、前面将“智能工厂的软件设计”中绝无仅有的“程序”视为 专注于 给定的某个单一面(语言面/逻辑面/数学面)中的 问题,专注于分析问题和解决问题的程序活动的组织,每一面都是一个“组织者”就像一个“独角兽”,并提出…...

leetcode 难度【简单模式】标签【数据库】题型整理大全

文章目录 175. 组合两个表181. 超过经理收入的员工182. 查找重复的电子邮箱COUNT(*)COUNT(*) 与 COUNT(column) 的区别 where和vaing之间的区别用法 183.从不订购的客户196.删除重复的电子邮箱197.上升的温度511.游戏玩法分析I512.游戏玩法分析II577.员工奖金584.寻找用户推荐人…...

利士策分享,自我和解:通往赚钱与内心富足的和谐之道

利士策分享,自我和解:通往赚钱与内心富足的和谐之道 在这个快节奏、高压力的时代,我们往往在追求物质财富的同时,忽略了内心世界的和谐与平衡。 赚钱,作为现代生活中不可或缺的一部分,它不仅仅是生存的手段…...

【物联网】深入解析时序数据库TDengine及其Java应用实践

文章目录 一、什么是时序数据库?二、TDengine简介三、TDengine的Java应用实践(1)环境准备(2)数据插入(3)数据查询 一、什么是时序数据库? 时序数据库(Time-Series Datab…...

2023北华大学程序设计新生赛部分题解

时光如流水般逝去,我已在校园中奋战大二!(≧▽≦) 今天,静静回顾去年的新生赛,心中涌起无尽感慨,仿佛那段青春岁月如烟花般绚烂。✧。(≧▽≦)。✧ 青春就像一场燃烧的盛宴,激情澎湃&…...

PPP的配置

概述:PPP模式,即公私合作模式(Public-Private Partnership),是一种公共部门与私营部门合作的模式。 一、实验拓扑 实验一:PPP基本功能 实验步骤: (1)配置AR1的接口IP地…...

回溯算法总结篇

组合问题:N个数里面按一定规则找出k个数的集合 如果题目要求的是组合的具体信息,则只能使用回溯算法,如果题目只是要求组合的某些最值,个数等信息,则使用动态规划(比如求组合中元素最少的组合,…...

机器学习-点击率预估-论文速读-20240916

1. [经典文章] 特征交叉: Factorization Machines, ICDM, 2010 分解机(Factorization Machines) 摘要 本文介绍了一种新的模型类——分解机(FM),它结合了支持向量机(SVM)和分解模型的优点。与…...

【leetcode】堆习题

215.数组中的第K个最大元素 给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。 请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。 你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 1: 输…...

前端大模型入门:编码(Tokenizer)和嵌入(Embedding)解析 - llm的输入

LLM的核心是通过对语言进行建模来生成自然语言输出或理解输入,两个重要的概念在其中发挥关键作用:Tokenizer 和 Embedding。本篇文章将对这两个概念进行入门级介绍,并提供了针对前端的js示例代码,帮助读者理解它们的基本原理/作用和如何使用。 1. 什么是…...

一文读懂 JS 中的 Map 结构

你好,我是沐爸,欢迎点赞、收藏、评论和关注。 上次聊了 Set 数据结构,今天我们聊下 Map,看看它与 Set、与普通对象有什么区别?下面直接进入正题。 一、Set 和 Map 有什么区别? Set 是一个集合&#xff0…...

C++校招面经(二)

欢迎关注 0voice GitHub 6、 C 和 Java 区别(语⾔特性,垃圾回收,应⽤场景等) 指针: Java 语⾔让程序员没法找到指针来直接访问内存,没有指针的概念,并有内存的⾃动管理功能,从⽽…...

Python Web 面试题

1 Web 相关 get 和 post 区别 get: 请求数据在 URL 末尾,URL 长度有限制 请求幂等,即无论请求多少次,服务器响应始终相同,这是因为 get 至少获取资源,而不修改资源 可以被浏览器缓存,以便以后…...

java日志框架之JUL(Logging)

文章目录 一、JUL简介1、JUL组件介绍 二、Logger快速入门三、Logger日志级别1、日志级别2、默认级别info3、原理分析4、自定义日志级别5、日志持久化(保存到磁盘) 三、Logger父子关系四、Logger配置文件 一、JUL简介 JUL全程Java Util Logging&#xff…...

ARM驱动学习之PWM

ARM驱动学习之PWM 1.分析原理图: GPD0_0 XpwmTOUT0定时器0 2.定时器上的资源: 1.5组32位定时器 2.定时器产生内部中断 3.定时器0,1,2可编程实现pwm 4.定时器各自分频 5.TCN--,TCN TCMPBN 6.分频器 24-2 7.24.3.4 例子&#xff1…...

我的AI工具箱Tauri版-VideoClipMixingCut视频批量混剪

本教程基于自研的AI工具箱Tauri版进行VideoClipMixingCut视频批量混剪。 VideoClipMixingCut视频批量混剪 是自研AI工具箱Tauri版中的一款强大工具,专为自动化视频批量混剪设计。该模块通过将预设的解说文稿与视频素材进行自动拼接生成混剪视频,适合需要…...

postgres_fdw访问存储在外部 PostgreSQL 服务器中的数据

文章目录 一、postgres_fdw 介绍二、安装使用示例三、成本估算四、 远程执行选项执行计划无法递推解决 参考文件: 一、postgres_fdw 介绍 postgres_fdw 模块提供外部数据包装器 postgres_fdw,可用于访问存储在外部 PostgreSQL 服务器中的数据。 此模块…...

什么是3D展厅?有何优势?怎么制作3D展厅?

一、什么是3D展厅? 3D展厅是一种利用三维技术构建的虚拟展示空间。它借助虚拟现实(VR)、增强现实(AR)等现代科技手段,将真实的展示空间数字化,呈现出逼真、立体、沉浸的展示效果。通过3D展厅&a…...

Linux下的CAN通讯

CAN总线 CAN总线简介 CAN&#xff08;Controller Area Network&#xff09;总线是一种多主从式 <font color red>异步半双工串行 </font> 通信总线&#xff0c;它最早由Bosch公司开发&#xff0c;用于汽车电子系统。CAN总线具有以下特点&#xff1a; 多主从式&a…...

强化学习在并行机构人形机器人控制中的应用

1. 项目概述在机器人控制领域&#xff0c;强化学习(RL)正逐渐成为解决复杂动力学系统问题的有力工具。然而&#xff0c;当面对具有并行驱动机构的人形机器人时&#xff0c;传统RL训练方法往往面临一个关键挑战&#xff1a;大多数仿真环境无法准确模拟闭环运动链(Closed Kinemat…...

T型翼/尾板导向的穿浪双体船姿态控制【附代码】

✨ 长期致力于穿浪双体船、T型翼、尾板、多自由度姿态控制、舒适性评估研究工作&#xff0c;擅长数据搜集与处理、建模仿真、程序编写、仿真设计。 ✅ 专业定制毕设、代码 ✅ 如需沟通交流&#xff0c;点击《获取方式》 &#xff08;1&#xff09;动态水翼升力模型与耦合运动方…...

Python基础语法:常用内置函数

round()&#xff1a;四舍五入 # 省略 ndigits print(round(3.14)) # 输出 3&#xff08;int&#xff09; print(round(3.66)) # 输出 4# 指定 ndigits print(round(3.14159, 2)) # 输出 3.14&#xff08;float&#xff09; print(round(3.666, 2)) # 输出 3.67# …...

内存占用3KB!极致瘦身释放MCU无限可能

极致小体积&#xff0c;给工业领域带来了无限的可能&#xff1a;更低硬件成本&#xff0c;更小芯片体积&#xff0c;更低功耗&#xff0c;更高可靠性&#xff0c;让每一颗小MCU都拥有大系统的完整能力。 https://www.bilibili.com/video/BV1eZLi6PEjc/?spm_id_from333.1387.ho…...

【深度解析】AI Coding 模型竞速:从 Claude Mythos 安全编码到 GPT-5.6 传闻,如何落地代码审查智能体

摘要 AI 编码模型正在从“代码补全”进入“复杂代码库理解、漏洞发现与自动修复”阶段。本文结合 Claude Mythos、Claude Opus 4.8 与 GPT-5.6 相关信息&#xff0c;解析新一代 Coding Agent 的技术趋势&#xff0c;并给出基于大模型 API 的代码安全审查实战方案。背景介绍&…...

基于随机森林的低成本传感器机器学习校准实践指南

1. 项目概述&#xff1a;当低成本传感器遇上机器学习校准在物联网和智能感知系统铺天盖地的今天&#xff0c;低成本传感器几乎无处不在。从监测办公室的空气质量&#xff0c;到追踪城市街道的噪音污染&#xff0c;再到农业大棚里的温湿度控制&#xff0c;这些价格亲民的“小眼睛…...

交流电机驱动器的三种控制模式:前沿切相、后沿切相与同步模式详解

1. 项目概述&#xff1a;一个能玩出花的交流电机驱动器在汽车改装、工业控制或者一些创客项目里&#xff0c;驱动一个交流电机听起来简单&#xff0c;但想让它听话地变速、正反转&#xff0c;甚至实现软启动和精确同步&#xff0c;往往就得搬出笨重又昂贵的工业变频器。今天分享…...

OpenIPC开源固件:5分钟解锁网络摄像头的终极控制权

OpenIPC开源固件&#xff1a;5分钟解锁网络摄像头的终极控制权 【免费下载链接】firmware Alternative IP Camera firmware from an open community 项目地址: https://gitcode.com/gh_mirrors/fir/firmware 还在为网络摄像头的封闭系统而烦恼吗&#xff1f;想要完全掌控…...

天文时序数据分析:机器学习评估、半监督学习与无监督方法实战

1. 项目概述&#xff1a;当机器学习遇见星空 处理海量的天文时序数据&#xff0c;比如来自Kepler、TESS这些“巡天巨眼”的光变曲线&#xff0c;早已不是靠人眼一张张图去翻的时代了。数据量太大&#xff0c;噪声复杂&#xff0c;信号微弱&#xff0c;传统方法常常力不从心。这…...

解决claude code频繁封号与token不足的taotoken接入方案

&#x1f680; 告别海外账号与网络限制&#xff01;稳定直连全球优质大模型&#xff0c;限时半价接入中。 &#x1f449; 点击领取海量免费额度 解决Claude Code频繁封号与Token不足的Taotoken接入方案 1. 问题背景&#xff1a;Claude Code用户面临的挑战 对于依赖Claude Cod…...