当前位置: 首页 > news >正文

02- pytorch 实现 RNN

一 导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

1.1 导入训练数据

batch_size, num_steps = 32, 35
# 更改了默认的文件下载方式,需要将 article 文件放入该文件夹
train_iter, vocab = dltools.load_data_time_machine(batch_size, num_steps)

1.2 构造神经网络

num_hiddens = 256
# 构造了一个具有256个隐藏神经单元的单隐藏层的循环神经网络
rnn_layer = nn.RNN(len(vocab), num_hiddens)

构造了一个 循环神经网络 (RNN)  层,该 RNN 层具有以下特性:

  • num_hiddens = 256: 这行代码定义了 RNN 层中的隐藏单元数量,即 RNN 层内部神经元的数量。在这个例子中,设置为 256,意味着 RNN 层将有 256 个隐藏神经单元。

  • nn.RNN(len(vocab), num_hiddens): 这行代码 创建了一个 RNN 层 的实例。它的参数如下:

    • len(vocab): 这是 输入数据的特征维度。在循环神经网络中,输入数据通常是一个序列,每个时间步的输入是一个向量。len(vocab) 表示词汇表的大小,它代表了序列中的每个时间步可能的输入的数量。在自然语言处理任务中,词汇表的大小通常对应于词汇表中不同词汇的数量。

    • num_hiddens: 这是 RNN 层内部的 隐藏单元数量,根据之前定义的值为 256。

综上所述,这段代码创建了一个 具有 256 个隐藏神经单元的单隐藏层的循环神经网络层。这个 RNN 层可以用来处理序列数据,例如文本数据,在文本数据中,每个时间步可以对应一个词汇表中的一个词或一个词的嵌入表示。

1.3 初始化隐藏状态

# 初始化隐藏状态
state = torch.zeros((1, batch_size, num_hiddens))

创建了一个 全零的张量作为隐藏状态。张量的形状是 (1, batch_size, num_hiddens),其中:

  • 1 表示时间步的数量,这里初始化的是一个初始时间步的隐藏状态。
  • batch_size 表示批量大小,即一次处理的样本数量。
  • num_hiddens 表示每个时间步的隐藏单元数量,即隐藏状态的维度。

二 构建一个完整的循环神经网络

# 构建一个完整的循环神经网络
class RNNModel(nn.Module):def __init__(self, rnn_layer, vocab_size, **kwargs):super().__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_sizeif not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)# 前向传播def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)output = self.linear(Y.reshape(-1, Y.shape[-1]))return output, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn.num_layers, batch_size, self.num_hiddens), device=device)

该部分定义了一个名为 RNNModel 的 PyTorch 模型类,该模型是一个循环神经网络 (RNN) 模型,用于处理序列数据。

  1. __init__ 方法:这是类的构造函数,用于初始化模型的各个组件。在这里,做了以下工作:

    • super().__init__(**kwargs) 调用了父类的构造函数,确保正确初始化模型。
    • self.rnn = rnn_layer 存储了 传入的 RNN 层
    • self.vocab_size = vocab_size 存储了 词汇表的大小
    • self.num_hiddens = self.rnn.hidden_size 获取了 RNN 层的隐藏状态大小
    • 根据 RNN 是否是双向的,选择性地创建一个线性层,用于将 RNN 输出映射到词汇表大小的空间。如果是双向 RNN,则输入的维度是隐藏状态大小的两倍。
  2. forward 方法:这个方法定义了 前向传播 过程。它接受输入 inputs 和当前的隐藏状态 state。在前向传播中,它执行以下操作:

    • 使用 F.one_hot 将输入 inputs 转化为 独热编码,以便与词汇表大小匹配。然后将其转换为浮点数张量。
    • 将输入数据和隐藏状态传递给 RNN 层,以获得输出 Y 和新的 隐藏状态 state
    • 将 RNN 输出 Y 重塑成 二维张量,然后通过线性层 self.linear 将其映射到词汇表大小的空间,并返回输出结果。
  3. begin_state 方法:这个方法用于 初始化隐藏状态,返回一个全零的张量,其形状取决于 RNN 的层数、方向数、隐藏单元数以及批量大小。

2.1 实例化模型

# 在训练前,跑下模型
device = dltools.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)

创建了一个 RNNModel对象,该对象接受一个rnn_layer和一个词汇表大小作为参数。最后,它将模型移动到之前确定的设备上

三 执行训练

# 训练
num_epochs, lr = 200, 0.1
dltools.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

3.1 执行预测

dltools.predict_ch8('time traveller', 10, net, vocab, device)

相关文章:

02- pytorch 实现 RNN

一 导包 import torch from torch import nn from torch.nn import functional as F import dltools 1.1 导入训练数据 batch_size, num_steps 32, 35 # 更改了默认的文件下载方式,需要将 article 文件放入该文件夹 train_iter, vocab dltools.load_data_time_…...

算法课作业1

https://vjudge.net/contest/581138 A - Humidex 模拟题 题目大意 给三个类型数字通过公式来回转化 思路 求e的对数有log函数&#xff0c;不懂为什么不会出精度错误&#xff0c;很迷&#xff0c;给的三个数字也没有顺序&#xff0c;需要多判断。 #include<cstdio>…...

linux文本处理 两行变一行

linux简单文本处理 [rootkvm ~]# cat test 1.1.1.1 test1 2.2.2.2 test2 3.3.3.3 test3 192.168.1.2 test4 10.23.9.19 test5 cat test | awk /^[0-9]/{T$1;next;}{print T,$1}1.1.1.1 test1 2.2.2.2 test2 3.3.3.3 test3 192.168.1.2 test4 10.23.9.19 test5 cat test | …...

第二次面试 9.15

首先就是自我介绍 项目拷打 总体介绍一下项目 谈一下对socket的理解 在数据接收阶段&#xff0c;如何实现一个异步的数据处理 谈一谈对qt信号槽的理解 有想过如何去编写一个信号槽吗 你是如何使用CMAKE编译文件的 C11特性了解些啥 shared_ptr 和 unique_ptr 的运用场景 …...

基于matlab实现的平面波展开法二维声子晶体能带计算程序

Matlab 平面波展开法计算二维声子晶体二维声子晶体带结构计算&#xff0c;材料是铅柱在橡胶基体中周期排列&#xff0c;格子为正方形。采用PWE方法计算 完整程序: %%%%%%%%%%%%%%%%%%%%%%%%% clear;clc;tic;epssys1.0e-6; %设定一个最小量&#xff0c;避免系统截断误差或除零错…...

Minio入门系列【2】纠删码

1 纠删码 Minio使用纠删码erasure code和校验和checksum来保护数据免受硬件故障和无声数据损坏。 即便丢失一半数量&#xff08;N/2&#xff09;的硬盘&#xff0c;仍然可以恢复数据 1.1 什么叫纠删码 纠删码是一种用于重建丢失或损坏数据的数学算法。 纠删码&#xff08;e…...

基于永磁同步发电机的风力发电系统研究(Simulink实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…...

5.后端·新建子模块与开发(自动模式)

文章目录 学习资料自动生成模式创建后端三层 学习资料 https://www.bilibili.com/video/BV13g411Y7GS?p11&spm_id_frompageDriver&vd_sourceed09a620bf87401694f763818a31c91e 自动生成模式创建后端三层 首先&#xff0c;运行起来若依的前后端整个项目&#xff0c;…...

vue的data为什么要写成data(return{})这样而不是data:{}这样?

在Vue.js中&#xff0c;为什么要将data写成一个返回对象的函数data()而不是一个普通的对象data: {} 为什么&#xff1f; 因为Vue.js的组件实例是可复用的&#xff0c;而且它们可以在应用中多次实例化。通过将data定义为一个返回对象的函数&#xff0c;可以确保每个组件实例都…...

MySQL基础运维知识点大全

一. MySQL基本知识 1. 目录的功能 通用 Unix/Linux 二进制包的 MySQL 安装下目录的相关功能 目录目录目录binMySQLd服务器&#xff0c;客户端和实用程序docs信息格式的 MySQL 手册manUnix 手册页include包括&#xff08;头&#xff09;文件lib图书馆share用于数据库安装的错…...

javascript获取样式表的规则及读取与写入

CSSStyleSheet是继承了StyleSheet的接口属性,它是用于找当前文档中的<link rel“” href“”…>这样文件的&#xff0c;有以下属性&#xff1a;lenght,cssRules,title,href,type,deleteRule,insertRule等 CSSStyleRule是继承于CSSRule&#xff0c;它是用于找<link re…...

什么是promise?

是JavaScript中用于处理异步操作的一种机制。 异步操作&#xff0c;例如从服务器获取数据、读取文件、执行数据库查询等等。 经典使用&#xff1a;Axios 是一个基于Promise的HTTP客户端 Promise具有三个状态&#xff1a; Pending&#xff08;待定&#xff09;&#xff1a;Pr…...

从零开始学习软件测试-第45天笔记

monkey事件 事件&#xff1a;对app进行的操作&#xff0c;比如触摸事件&#xff0c;滑动事件...动作&#xff1a;构成一个事件所需要的步骤。 调整事件的百分比 adb shell monkey -p 包名 -v -v --pct-xxx 百分比 次数>输出文件的路径 分析日志有没有报错 到日志中去找…...

visual studio常用快捷键

CtrlM、CtrlO 折叠到定义 CtrlM、CtrlM 折叠当前定义 CtrlM、CtrlA 折叠全部 CtrlK、CtrlD 自动编排代码格式 F12 转到定义 ShiftF12 查看所有定义 ctrl] 转到定义首部或尾部 ctrlX 未选中文本时&#xff0c;剪切/删除光标所在行。ctrlV 未选中文本时&#xff0c;粘贴到…...

数据变换:数据挖掘的准备工作之一

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ &#x1f434;作者&#xff1a;秋无之地 &#x1f434;简介&#xff1a;CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作&#xff0c;主要擅长领域有&#xff1a;爬虫、后端、大数据…...

Go语言实践案例之简单字典

一、程序要实现效果&#xff1a; 在命令行调用程序的时候&#xff0c;可以在命令行的后面查询一个单词&#xff0c;然后会输出单词的音标和注释。 二、思路分析&#xff1a; 定义一个结构体 DictRequest&#xff0c;用于表示翻译请求的数据结构。其中包含了 TransType&#…...

笔试面试相关记录(3)

&#xff08;1&#xff09;String String和String.append()的底层实现 C中string append函数的使用与字符串拼接「建议收藏」-腾讯云开发者社区-腾讯云 (tencent.com) String String 在 第二个String中遇到\0就截止&#xff0c;append()的方法则是所有字符都会加在后面。 &…...

第6章_瑞萨MCU零基础入门系列教程之串行通信接口(SCI)

本教程基于韦东山百问网出的 DShanMCU-RA6M5开发板 进行编写&#xff0c;需要的同学可以在这里获取&#xff1a; https://item.taobao.com/item.htm?id728461040949 配套资料获取&#xff1a;https://renesas-docs.100ask.net 瑞萨MCU零基础入门系列教程汇总&#xff1a; ht…...

开源免费的流程图软件draw.io

2023年9月16日&#xff0c;周六上午 想买微软的visio&#xff0c;但发现不是很值得&#xff0c;因为我平时也不是经常需要画图。 所以我最后还是决定使用开源免费的draw.io来画图 draw.io网页版的网址&#xff1a; Flowchart Maker & Online Diagram Software draw.io的…...

Python绘图系统19:添加时间轴以实现动态绘图

文章目录 时间轴单帧跳转源代码 Python绘图系统&#xff1a; &#x1f4c8;从0开始的3D绘图系统&#x1f4c9;一套3D坐标&#xff0c;多个函数&#x1f4ca;散点图、极坐标和子图自定义控件&#xff1a;绘图风格&#x1f4c9;风格控件&#x1f4ca;定制绘图风格坐标设置进阶&a…...

OpenClaw+nanobot镜像:3步配置QQ聊天机器人触发AI任务

OpenClawnanobot镜像&#xff1a;3步配置QQ聊天机器人触发AI任务 1. 为什么选择OpenClawnanobot组合&#xff1f; 去年冬天&#xff0c;当我第一次尝试用QQ机器人自动处理群消息时&#xff0c;经历了漫长的环境配置地狱。直到发现星图平台的nanobot镜像&#xff0c;这个开箱即…...

Thermal Control Center:Dell G15散热管理的开源替代方案实战指南

Thermal Control Center&#xff1a;Dell G15散热管理的开源替代方案实战指南 【免费下载链接】tcc-g15 Thermal Control Center for Dell G15 - open source alternative to AWCC 项目地址: https://gitcode.com/gh_mirrors/tc/tcc-g15 在追求极致性能的游戏本领域&…...

Ubuntu 20.04安装MATLAB R2023B保姆级避坑指南:从卸载旧版到选对产品,一步一截图

Ubuntu 20.04安装MATLAB R2023B全流程实战&#xff1a;从彻底卸载到精准选配 在科研与工程计算领域&#xff0c;MATLAB始终保持着不可替代的地位。当最新版的R2023B遇上Ubuntu 20.04这个长期支持版本&#xff0c;如何实现完美部署却让不少用户望而却步。不同于Windows下的图形化…...

150万规模!深势开源科学图像界ImageNet,AI终于能看懂论文图表了

150 万图文对、500 万子图&#xff0c;全面覆盖 300 科学子学科。深势开源 OmniScience&#xff0c;让 AI 真正读懂科研文献图表。跨越“盲区”&#xff1a;让AI真正读懂科学影像在科学研究日益数字化的今天&#xff0c;大模型已经能够高效处理书籍与文献中的文本信息。不过&am…...

VMware虚拟机部署Mirage Flow:多环境测试方案

VMware虚拟机部署Mirage Flow&#xff1a;多环境测试方案 为开发测试构建安全可靠的隔离环境 1. 环境准备与虚拟机配置 在开始部署Mirage Flow之前&#xff0c;我们需要先准备好合适的测试环境。使用VMware虚拟机是个不错的选择&#xff0c;它能为我们提供一个完全隔离的测试空…...

Android 11 自动亮度算法优化与曲线配置解析

1. Android 11自动亮度技术演进 记得第一次用上Android 11的手机时&#xff0c;最让我惊喜的就是屏幕亮度调节变得特别"聪明"。以前在电影院掏出手机总被刺得睁不开眼&#xff0c;现在却能像人眼一样自然地适应环境。这背后其实是Google对自动亮度算法做了重大升级&a…...

LaTeX参考文献报错全解析:从\citation到\bibdata的避坑指南

LaTeX参考文献报错全解析&#xff1a;从\citation到\bibdata的避坑指南 当你熬夜赶论文时&#xff0c;突然在编译LaTeX文档时看到一串红色报错&#xff1a;"I found no \bibstyle command"、"I found no \bibdata command"、"I found no \citation co…...

三维智能分割技术:从行业痛点到落地实践的全面解析

三维智能分割技术&#xff1a;从行业痛点到落地实践的全面解析 【免费下载链接】SAMPart3D SAMPart3D: Segment Any Part in 3D Objects 项目地址: https://gitcode.com/gh_mirrors/sa/SAMPart3D 问题场景&#xff1a;三维模型处理的现实困境 建筑设计行业&#xff1a;…...

Python实战:5分钟用高德API搞定全国区县边界坐标采集(附完整代码)

Python实战&#xff1a;高德API高效获取全国区县边界坐标的工程化解决方案 1. 需求背景与方案设计 地理信息系统开发中经常需要精确的行政区划边界数据。传统手动采集方式效率低下&#xff0c;而高德地图API提供了完善的行政区划查询接口。本方案将实现&#xff1a; 全国省/…...

AI元人文构想:从自感养护到伦理中间件——一种智能时代的人文回应

AI元人文构想&#xff1a;从自感养护到伦理中间件——一种智能时代的人文回应---引言&#xff1a;技术时代的人文焦虑智能算法的深度嵌入&#xff0c;正在重塑人类感知、判断与意义生成的方式。推荐系统预判我们的欲望&#xff0c;社交平台定义我们的关系&#xff0c;大语言模型…...