当前位置: 首页 > news >正文

人工智能|机器学习——循环神经网络的简洁实现

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shapetorch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)# 它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

 很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。  

相关文章:

人工智能|机器学习——循环神经网络的简洁实现

循环神经网络的简洁实现 如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。 import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2lbatch_size, num_steps 32, 35 t…...

02_MySQL体系结构及数据文件介绍

#课程目标 了解MySQL的体系结构了解MySQL常见的日志文件及作用了解事务的控制语句,提交和回滚能够查看当前数据库的版本和用户了解MySQL数据库如何存放数据能在使用SQL语句创建、删除数据库 #一、MySQL的体系结构 ##1、客户端(连接者) MySQL的客户端可以是某个客户…...

【Web安全】xsstrike工具使用方法表格

xsstrike工具使用方法表格 版本:XSStrike v3.1.5 项目地址: https://github.com/s0md3v/XSStrike使用文档: usage: xsstrike.py [-h] [-u TARGET] [--data PARAMDATA] [-e ENCODE] [--fuzzer] [--update] [--timeout TIMEOUT] [--proxy][…...

python实现鼠标实时坐标监测

python实现鼠标实时坐标监测 一、说明 使用了以下技术和库: tkinter:用于创建GUI界面。pyperclip:用于复制文本到剪贴板。pynput.mouse:用于监听鼠标事件,包括移动和点击。threading:用于创建多线程&…...

【华为OD】C卷真题 100%通过:攀登者1 C/C++源码实现

【华为OD】C卷真题 100%通过:攀登者1 C/C源码实现 目录 题目描述: 示例1 代码实现: 题目描述: 攀登者喜欢寻找各种地图,并且尝试攀登到最高的山峰。 地图表示为一维数组,数组的索引代表水平位置&…...

Flask,uWSGI,nginx的理解

文章目录 前言与背景理解 - FlaskuWSGInginx理解 - nginx理解 - FlaskuWSGI理解 - vuedjangonginx 前言与背景 此篇文章是针对小白的一篇理解Flask,uWSGI,nginx的文章,只介绍了理解,并没有介绍如何部署。 由于工作需要使用flask…...

【JAVA杂货铺】一文带你走进面向对象编程|继承|重载|重写|期末复习系列 | (中4)

🌈个人主页: Aileen_0v0🔥系列专栏:Java学习系列专栏💫个人格言:"没有罗马,那就自己创造罗马~" 目录 继承 私有成员变量在继承中的使用​编辑 当子类和父类变量不重名时: 当子类和父类重名时: 📝总结: 继承的含义: …...

单细胞seurat入门—— 从原始数据到表达矩阵

根据所使用的建库方法,单细胞的RNA序列(也称为读取(reads)或标签(tags))将从转录本的3端(或5端)(10X Genomics,CEL-seq2,Drop-seq&…...

Docker部署Nacos

此篇文章使用的nacos为2.2.1版本 拉取Nacos镜像 docker pull nacos/nacos-server:v2.2.1先将容器启动起来 docker run -d \ --name nacos \ -p 8848:8848 \ -p 9848:9848 \ -p 9849:9849 \ --privilegedtrue \ -e JVM_XMS256m \ -e JVM_XMX256m \ -e MODEstandalone \ -e NA…...

1005. K 次取反后最大化的数组和

原题链接:1005. K 次取反后最大化的数组和 思路: 先把数组排序好,然后直接从下标0(最小的负数)开始反转,那么接下来有两种情况: 1.负数反转完了,k还有剩余。此时因为nums内全部都是正数,所以我…...

【云原生】什么是 Kubernetes ?

什么是 Kubernetes ? Kubernetes 是一个开源容器编排平台,管理着一系列的 主机 或者 服务器,它们被称作是 节点(Node)。 每一个节点运行了若干个相互独立的 Pod。 Pod 是 Kubernetes 中可以部署的 最小执行单元&#x…...

自建CA实战之 《0x01 Nginx 配置 https单向认证》

自建私有化证书颁发机构(Certificate Authority,CA)实战之 《0x01 Nginx 配置 https单向认证》 上一篇文章我们介绍了如何自建私有化证书颁发机构(Certificate Authority,CA),本篇文章我们将介…...

《QT从基础到进阶·三十八》QWidget实现炫酷log日志打印界面

QWidget实现了log日志的打印功能,不仅可以在界面显示,还可以生成打印日志。先来看下效果,源码放在文章末尾: LogPlugin插件类管理log所有功能,它可以获取Log界面并能打印正常信息,警告信息和错误信息&…...

JVM的小知识总结

加载时jvm做了这三件事: 1)通过一个类的全限定名来获取该类的二进制字节流 什么是全限定类名? 就是类名全称,带包路径的用点隔开,例如: java.lang.String。 即全限定名 包名类型 非限定类名也叫短名,就…...

深入理解JVM虚拟机第二十六篇:详解JVM当中的虚方法和非虚方法,并从字节码指令的角度去分析虚方法和非虚方法

😉😉 学习交流群: ✅✅1:这是孙哥suns和树哥给大家的福利! ✨✨2:我们免费分享Netty、Dubbo、k8s、Spring...应用和源码级别的视频资料 🥭🥭3:QQ群:583783824 📚​​​​​​​📚 微信:DashuDeveloper拉你进微信群,免费领取! 一:非虚方法和虚方法 方法…...

ElasticSearch的日志配置

ElasticSearch默认情况下使用Log4j2来记录日志,日志配置文件的路径为$ES_HOME/config/log4j2.properties,配置方法见Log4j2的官方文档。 参考path-settings,通过指定path.logs,可以指定日志文件的保存路径。 在日志配置文件$ES_…...

SQL Injection (Blind)`

SQL Injection (Blind) SQL Injection (Blind) SQL盲注,是一种特殊类型的SQL注入攻击,它的特点是无法直接从页面上看到注入语句的执行结果。在这种情况下,需要利用一些方法进行判断或者尝试,这个过程称之为盲注。 盲注的主要形式有…...

NX二次开发UF_CURVE_ask_trim 函数介绍

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan UF_CURVE_ask_trim Defined in: uf_curve.h int UF_CURVE_ask_trim(tag_t trim_feature, UF_CURVE_trim_p_t trim_info ) overview 概述 Retrieve the current parameters of an a…...

linux的netstat命令和ss命令

1. 网络状态 State状态LISTENING监听中,服务端需要打开一个socket进行监听,侦听来自远方TCP端口的连接请求ESTABLISHED已连接,代表一个打开的连接,双方可以进行或已经在数据交互了SYN_SENT客户端通过应用程序调用connect发送一个…...

python:傅里叶分析,傅里叶变换 FFT

使用python进行傅里叶分析,傅里叶变换 FFT 的一些关键概念的引入: 1.1.离散傅里叶变换(DFT) 离散傅里叶变换(discrete Fourier transform) 傅里叶分析方法是信号分析的最基本方法,傅里叶变换是傅里叶分析的核心&…...

vscode里如何用git

打开vs终端执行如下: 1 初始化 Git 仓库(如果尚未初始化) git init 2 添加文件到 Git 仓库 git add . 3 使用 git commit 命令来提交你的更改。确保在提交时加上一个有用的消息。 git commit -m "备注信息" 4 …...

Ubuntu系统下交叉编译openssl

一、参考资料 OpenSSL&&libcurl库的交叉编译 - hesetone - 博客园 二、准备工作 1. 编译环境 宿主机:Ubuntu 20.04.6 LTSHost:ARM32位交叉编译器:arm-linux-gnueabihf-gcc-11.1.0 2. 设置交叉编译工具链 在交叉编译之前&#x…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地

借阿里云中企出海大会的东风,以**「云启出海,智联未来|打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办,现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

P3 QT项目----记事本(3.8)

3.8 记事本项目总结 项目源码 1.main.cpp #include "widget.h" #include <QApplication> int main(int argc, char *argv[]) {QApplication a(argc, argv);Widget w;w.show();return a.exec(); } 2.widget.cpp #include "widget.h" #include &q…...

Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!

一、引言 在数据驱动的背景下&#xff0c;知识图谱凭借其高效的信息组织能力&#xff0c;正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合&#xff0c;探讨知识图谱开发的实现细节&#xff0c;帮助读者掌握该技术栈在实际项目中的落地方法。 …...

AI书签管理工具开发全记录(十九):嵌入资源处理

1.前言 &#x1f4dd; 在上一篇文章中&#xff0c;我们完成了书签的导入导出功能。本篇文章我们研究如何处理嵌入资源&#xff0c;方便后续将资源打包到一个可执行文件中。 2.embed介绍 &#x1f3af; Go 1.16 引入了革命性的 embed 包&#xff0c;彻底改变了静态资源管理的…...

使用Matplotlib创建炫酷的3D散点图:数据可视化的新维度

文章目录 基础实现代码代码解析进阶技巧1. 自定义点的大小和颜色2. 添加图例和样式美化3. 真实数据应用示例实用技巧与注意事项完整示例(带样式)应用场景在数据科学和可视化领域,三维图形能为我们提供更丰富的数据洞察。本文将手把手教你如何使用Python的Matplotlib库创建引…...

STM32---外部32.768K晶振(LSE)无法起振问题

晶振是否起振主要就检查两个1、晶振与MCU是否兼容&#xff1b;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容&#xff08;CL&#xff09;与匹配电容&#xff08;CL1、CL2&#xff09;的关系 2. 如何选择 CL1 和 CL…...

关于easyexcel动态下拉选问题处理

前些日子突然碰到一个问题&#xff0c;说是客户的导入文件模版想支持部分导入内容的下拉选&#xff0c;于是我就找了easyexcel官网寻找解决方案&#xff0c;并没有找到合适的方案&#xff0c;没办法只能自己动手并分享出来&#xff0c;针对Java生成Excel下拉菜单时因选项过多导…...

区块链技术概述

区块链技术是一种去中心化、分布式账本技术&#xff0c;通过密码学、共识机制和智能合约等核心组件&#xff0c;实现数据不可篡改、透明可追溯的系统。 一、核心技术 1. 去中心化 特点&#xff1a;数据存储在网络中的多个节点&#xff08;计算机&#xff09;&#xff0c;而非…...