当前位置: 首页 > news >正文

人工智能|机器学习——循环神经网络的简洁实现

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shapetorch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape(torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):"""循环神经网络模型"""def __init__(self, rnn_layer, vocab_size, **kwargs):super(RNNModel, self).__init__(**kwargs)self.rnn = rnn_layerself.vocab_size = vocab_sizeself.num_hiddens = self.rnn.hidden_size# 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1if not self.rnn.bidirectional:self.num_directions = 1self.linear = nn.Linear(self.num_hiddens, self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)def forward(self, inputs, state):X = F.one_hot(inputs.T.long(), self.vocab_size)X = X.to(torch.float32)Y, state = self.rnn(X, state)# 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)# 它的输出形状是(时间步数*批量大小,词表大小)。output = self.linear(Y.reshape((-1, Y.shape[-1])))return output, statedef begin_state(self, device, batch_size=1):if not isinstance(self.rnn, nn.LSTM):# nn.GRU以张量作为隐状态return  torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens),device=device)else:# nn.LSTM以元组作为隐状态return (torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device),torch.zeros((self.num_directions * self.rnn.num_layers,batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

 很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。  

相关文章:

人工智能|机器学习——循环神经网络的简洁实现

循环神经网络的简洁实现 如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。 import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2lbatch_size, num_steps 32, 35 t…...

02_MySQL体系结构及数据文件介绍

#课程目标 了解MySQL的体系结构了解MySQL常见的日志文件及作用了解事务的控制语句,提交和回滚能够查看当前数据库的版本和用户了解MySQL数据库如何存放数据能在使用SQL语句创建、删除数据库 #一、MySQL的体系结构 ##1、客户端(连接者) MySQL的客户端可以是某个客户…...

【Web安全】xsstrike工具使用方法表格

xsstrike工具使用方法表格 版本:XSStrike v3.1.5 项目地址: https://github.com/s0md3v/XSStrike使用文档: usage: xsstrike.py [-h] [-u TARGET] [--data PARAMDATA] [-e ENCODE] [--fuzzer] [--update] [--timeout TIMEOUT] [--proxy][…...

python实现鼠标实时坐标监测

python实现鼠标实时坐标监测 一、说明 使用了以下技术和库: tkinter:用于创建GUI界面。pyperclip:用于复制文本到剪贴板。pynput.mouse:用于监听鼠标事件,包括移动和点击。threading:用于创建多线程&…...

【华为OD】C卷真题 100%通过:攀登者1 C/C++源码实现

【华为OD】C卷真题 100%通过:攀登者1 C/C源码实现 目录 题目描述: 示例1 代码实现: 题目描述: 攀登者喜欢寻找各种地图,并且尝试攀登到最高的山峰。 地图表示为一维数组,数组的索引代表水平位置&…...

Flask,uWSGI,nginx的理解

文章目录 前言与背景理解 - FlaskuWSGInginx理解 - nginx理解 - FlaskuWSGI理解 - vuedjangonginx 前言与背景 此篇文章是针对小白的一篇理解Flask,uWSGI,nginx的文章,只介绍了理解,并没有介绍如何部署。 由于工作需要使用flask…...

【JAVA杂货铺】一文带你走进面向对象编程|继承|重载|重写|期末复习系列 | (中4)

🌈个人主页: Aileen_0v0🔥系列专栏:Java学习系列专栏💫个人格言:"没有罗马,那就自己创造罗马~" 目录 继承 私有成员变量在继承中的使用​编辑 当子类和父类变量不重名时: 当子类和父类重名时: 📝总结: 继承的含义: …...

单细胞seurat入门—— 从原始数据到表达矩阵

根据所使用的建库方法,单细胞的RNA序列(也称为读取(reads)或标签(tags))将从转录本的3端(或5端)(10X Genomics,CEL-seq2,Drop-seq&…...

Docker部署Nacos

此篇文章使用的nacos为2.2.1版本 拉取Nacos镜像 docker pull nacos/nacos-server:v2.2.1先将容器启动起来 docker run -d \ --name nacos \ -p 8848:8848 \ -p 9848:9848 \ -p 9849:9849 \ --privilegedtrue \ -e JVM_XMS256m \ -e JVM_XMX256m \ -e MODEstandalone \ -e NA…...

1005. K 次取反后最大化的数组和

原题链接:1005. K 次取反后最大化的数组和 思路: 先把数组排序好,然后直接从下标0(最小的负数)开始反转,那么接下来有两种情况: 1.负数反转完了,k还有剩余。此时因为nums内全部都是正数,所以我…...

【云原生】什么是 Kubernetes ?

什么是 Kubernetes ? Kubernetes 是一个开源容器编排平台,管理着一系列的 主机 或者 服务器,它们被称作是 节点(Node)。 每一个节点运行了若干个相互独立的 Pod。 Pod 是 Kubernetes 中可以部署的 最小执行单元&#x…...

自建CA实战之 《0x01 Nginx 配置 https单向认证》

自建私有化证书颁发机构(Certificate Authority,CA)实战之 《0x01 Nginx 配置 https单向认证》 上一篇文章我们介绍了如何自建私有化证书颁发机构(Certificate Authority,CA),本篇文章我们将介…...

《QT从基础到进阶·三十八》QWidget实现炫酷log日志打印界面

QWidget实现了log日志的打印功能,不仅可以在界面显示,还可以生成打印日志。先来看下效果,源码放在文章末尾: LogPlugin插件类管理log所有功能,它可以获取Log界面并能打印正常信息,警告信息和错误信息&…...

JVM的小知识总结

加载时jvm做了这三件事: 1)通过一个类的全限定名来获取该类的二进制字节流 什么是全限定类名? 就是类名全称,带包路径的用点隔开,例如: java.lang.String。 即全限定名 包名类型 非限定类名也叫短名,就…...

深入理解JVM虚拟机第二十六篇:详解JVM当中的虚方法和非虚方法,并从字节码指令的角度去分析虚方法和非虚方法

😉😉 学习交流群: ✅✅1:这是孙哥suns和树哥给大家的福利! ✨✨2:我们免费分享Netty、Dubbo、k8s、Spring...应用和源码级别的视频资料 🥭🥭3:QQ群:583783824 📚​​​​​​​📚 微信:DashuDeveloper拉你进微信群,免费领取! 一:非虚方法和虚方法 方法…...

ElasticSearch的日志配置

ElasticSearch默认情况下使用Log4j2来记录日志,日志配置文件的路径为$ES_HOME/config/log4j2.properties,配置方法见Log4j2的官方文档。 参考path-settings,通过指定path.logs,可以指定日志文件的保存路径。 在日志配置文件$ES_…...

SQL Injection (Blind)`

SQL Injection (Blind) SQL Injection (Blind) SQL盲注,是一种特殊类型的SQL注入攻击,它的特点是无法直接从页面上看到注入语句的执行结果。在这种情况下,需要利用一些方法进行判断或者尝试,这个过程称之为盲注。 盲注的主要形式有…...

NX二次开发UF_CURVE_ask_trim 函数介绍

文章作者:里海 来源网站:https://blog.csdn.net/WangPaiFeiXingYuan UF_CURVE_ask_trim Defined in: uf_curve.h int UF_CURVE_ask_trim(tag_t trim_feature, UF_CURVE_trim_p_t trim_info ) overview 概述 Retrieve the current parameters of an a…...

linux的netstat命令和ss命令

1. 网络状态 State状态LISTENING监听中,服务端需要打开一个socket进行监听,侦听来自远方TCP端口的连接请求ESTABLISHED已连接,代表一个打开的连接,双方可以进行或已经在数据交互了SYN_SENT客户端通过应用程序调用connect发送一个…...

python:傅里叶分析,傅里叶变换 FFT

使用python进行傅里叶分析,傅里叶变换 FFT 的一些关键概念的引入: 1.1.离散傅里叶变换(DFT) 离散傅里叶变换(discrete Fourier transform) 傅里叶分析方法是信号分析的最基本方法,傅里叶变换是傅里叶分析的核心&…...

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...

C++实现分布式网络通信框架RPC(3)--rpc调用端

目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中,我们已经大致实现了rpc服务端的各项功能代…...

学校招生小程序源码介绍

基于ThinkPHPFastAdminUniApp开发的学校招生小程序源码,专为学校招生场景量身打造,功能实用且操作便捷。 从技术架构来看,ThinkPHP提供稳定可靠的后台服务,FastAdmin加速开发流程,UniApp则保障小程序在多端有良好的兼…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 (一)多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如,当用户上传一张“蓝色连衣裙”的图片时,接口可自动提取图像中的颜色(RGB值&…...

对WWDC 2025 Keynote 内容的预测

借助我们以往对苹果公司发展路径的深入研究经验,以及大语言模型的分析能力,我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际,我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测,聊作存档。等到明…...

从零开始打造 OpenSTLinux 6.6 Yocto 系统(基于STM32CubeMX)(九)

设备树移植 和uboot设备树修改的内容同步到kernel将设备树stm32mp157d-stm32mp157daa1-mx.dts复制到内核源码目录下 源码修改及编译 修改arch/arm/boot/dts/st/Makefile,新增设备树编译 stm32mp157f-ev1-m4-examples.dtb \stm32mp157d-stm32mp157daa1-mx.dtb修改…...

css3笔记 (1) 自用

outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size&#xff1a;0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格&#xff…...

C# 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

React---day11

14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store&#xff1a; 我们在使用异步的时候理应是要使用中间件的&#xff0c;但是configureStore 已经自动集成了 redux-thunk&#xff0c;注意action里面要返回函数 import { configureS…...

IP如何挑?2025年海外专线IP如何购买?

你花了时间和预算买了IP&#xff0c;结果IP质量不佳&#xff0c;项目效率低下不说&#xff0c;还可能带来莫名的网络问题&#xff0c;是不是太闹心了&#xff1f;尤其是在面对海外专线IP时&#xff0c;到底怎么才能买到适合自己的呢&#xff1f;所以&#xff0c;挑IP绝对是个技…...