当前位置: 首页 > news >正文

动手学深度学习——循环神经网络的从零开始实现(原理解释+代码详解)

文章目录

    • 循环神经网络的从零开始实现
      • 1. 独热编码
      • 2. 初始化模型参数
      • 3. 循环神经网络模型
      • 4. 预测
      • 5. 梯度裁剪
      • 6. 训练

循环神经网络的从零开始实现

从头开始基于循环神经网络实现字符级语言模型。

# 读取数据集
%matplotlib inline
import math
import torchfrom torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

1. 独热编码

每个词元都有一个对应的索引,表示为特征向量,即每个索引映射为相互不同的单位向量。

词元表不同词元个数为N,词元索引范围为0到N-1。词元的索引为整数,那么将创建一个长度为N的全0向量,并将第i处元素设置为1。则此向量是原始词元的一个独热编码。

假如有2个词元"cat"和"dog"

  • "cat"对应:[1, 0]
  • "dog"对应:[0, 1]

索引为0和2的独热向量

# 索引为0和2的独热向量
F.one_hot(torch.tensor([0, 2]), len(vocab))

在这里插入图片描述
采样的小批量数据形状为二维张量:(批量大小,时间步数),one_hot函数将其转换为三维张量:(时间步数,批量大小,词表大小)

# 采样的小批量数据形状为二维张量:(批量大小,时间步数)
# one_hot函数将其转换为三维张量:(时间步数,批量大小,词表大小)
# 方便我们通过最外层维度,一步一步更新小批量数据的隐状态
X = torch.arange(10).reshape((2, 5))
print(F.one_hot(X.T, 28).shape)
# 显示第一行
F.one_hot(X.T, 28)[0,:,:]

在这里插入图片描述

2. 初始化模型参数

隐藏单元数num_hiddens是一个可调的超参数

训练语言模型时,输入和输出来自相同的词表,具有相同的维度即词表大小

"""
初始化模型参数:1、隐藏层参数2、输出层参数3、附加梯度
"""
# (词表大小,隐藏层数,设备)
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_size# 定义函数normal(),初始化模型的参数def normal(shape):return torch.randn(size=shape, device=device) * 0.01# 隐藏层参数W_xh = normal((num_inputs, num_hiddens))W_hh = normal((num_hiddens, num_hiddens))b_h = torch.zeros(num_hiddens, device=device)# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

3. 循环神经网络模型

定义init_rnn_state函数在初始化时返回隐状态,该函数的返回是一个张量,张量全用0填充,形状为(批量大小,隐藏单元数)。

# 定义init_rnn_state函数在初始化时返回隐状态
# 该函数的返回是一个张量,张量全用0填充,形状为(批量大小,隐藏单元数)
def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

在这里插入图片描述
循环神经网络通过最外层的维度实现循环,以便时间步更新小批量数据的隐状态H

# 循环神经网络通过最外层的维度实现循环,以便时间步更新小批量数据的隐状态H
def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:# 激活函数tanh,更新隐状态HH = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

创建一个类来包装这些函数, 并存储从零开始实现的循环神经网络模型的参数

"""
从零开始实现的循环神经网络模型:
1、定义网络模型的参数
2、对词表进行独热编码
3、初始化模型参数并返回隐状态
"""
class RNNModelScratch: #@save"""从零开始实现的循环神经网络模型"""# 定义类的初始化,将传入的参数赋值给对象的属性,以便后续使用def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)self.init_state, self.forward_fn = init_state, forward_fndef __call__(self, X, state):# 对输入进行独热编码,返回状态及参数X = F.one_hot(X.T, self.vocab_size).type(torch.float32)return self.forward_fn(X, state, self.params)def begin_state(self, batch_size, device):# 初始化参数return self.init_state(batch_size, self.num_hiddens, device)

检查输出是否具有正确的形状。 例如,隐状态的维数是否保持不变。

num_hiddens = 512
# 网络模型
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
# 获得网络初始状态
state = net.begin_state(X.shape[0], d2l.try_gpu())
# 将X移到GPU上,并且返回输出Y和状态
Y, new_state = net(X.to(d2l.try_gpu()), state)
Y.shape, len(new_state), new_state[0].shape

在这里插入图片描述
可以看到输出形状是(时间步数x批量大小,词表大小), 而隐状态形状保持不变,即(批量大小,隐藏单元数)。

4. 预测

定义预测函数

"""
定义预测函数:
1、prefix是用户提供的字符串;
2、循环遍历prefix的开始字符时不输出,不断将隐状态传递给下一个时间步;
3、在此期间模型进行自我更新(隐状态),不进行预测;
4、2和3步骤称为预热期,预热期过后隐状态的值更适合预测,从而预测字符并输出。
"""
# prefix:前缀字符串
def predict_ch8(prefix, num_preds, net, vocab, device):  #@save"""在prefix后面生成新字符"""state = net.begin_state(batch_size=1, device=device)outputs = [vocab[prefix[0]]]# 匿名函数:改变输出的形状get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))# 预热期:不进行输出for y in prefix[1:]:  # 预热期_, state = net(get_input(), state)outputs.append(vocab[y])# 预热期过了之后,进行预测for _ in range(num_preds):  # 预测num_preds步y, state = net(get_input(), state)outputs.append(int(y.argmax(dim=1).reshape(1)))return ''.join([vocab.idx_to_token[i] for i in outputs])

测试predict_ch8函数。 我们将前缀指定为time traveller, 并基于这个前缀生成10个后续字符

# 测试predict_ch8函数。 我们将前缀指定为time traveller, 并基于这个前缀生成10个后续字符。
# 未训练模型,输出预测结果没有联系
predict_ch8('time traveller ', 10, net, vocab, d2l.try_gpu())

在这里插入图片描述

5. 梯度裁剪

为什么要梯度裁剪:
1、对于长度为T的序列,我们在迭代中计算T个时间步上的梯度,在反向传播过程中产生长度为T的矩阵乘法链;
2、T较大时,会导致数值不稳定,例如梯度消失或者梯度爆炸。

一个流行的替代方案是通过将梯度g投影回给定半径 (例如θ)的球来裁剪梯度g。
在这里插入图片描述

def grad_clipping(net, theta):  #@save"""裁剪梯度"""if isinstance(net, nn.Module):# 附加梯度的参数params = [p for p in net.parameters() if p.requires_grad]else:# 梯度的范数:对应图里作为分母的"||g||"params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))# 如果梯度过大,将其限制到θif norm > theta:for param in params:param.grad[:] *= theta / norm

6. 训练

在一个迭代周期内训练模型:
1、序列数据的不同采样方法(随机采样和顺序分区)将导致状态初始化的差异;
2、在更新模型参数之前裁剪梯度,这样可以保证训练过程中如果某点发生梯度爆炸,模型也不会发散;
3、用困惑度评价模型,使得不同长度的序列也有了可比性。

  • 顺序分区:只在每个迭代周期的开始位置初始化隐状态。
  • 随机抽样:每个样本都是在一个随机位置抽样的,因此需要在每个迭代周期重新初始化隐状态。
#@save
"""
训练网络一个迭代周期:
1、初始化状态,将数据传到GPU上
2、计算损失,进行梯度裁剪并更新模型参数
"""
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):"""训练网络一个迭代周期(定义见第8章)"""# 状态,时间state, timer = None, d2l.Timer()metric = d2l.Accumulator(2)  # 训练损失之和,词元数量for X, Y in train_iter:if state is None or use_random_iter:# 在第一次迭代或使用随机抽样时初始化statestate = net.begin_state(batch_size=X.shape[0], device=device)else:if isinstance(net, nn.Module) and not isinstance(state, tuple):# state对于nn.GRU是个张量# detach_()将张量从计算图中分离出来,不会影响到原始张量state.detach_()else:# state对于nn.LSTM或对于我们从零开始实现的模型是个张量for s in state:s.detach_()# 将Y 进行转置并展平成一维向量y = Y.T.reshape(-1)# 将X,y移动到设备上,并且输入到模型中X, y = X.to(device), y.to(device)y_hat, state = net(X, state)l = loss(y_hat, y.long()).mean()# 如果更新器 updater 是 torch.optim.Optimizer 类型,则调用 updater.step() 方法进行参数更新;# 否则调用 updater(batch_size=1) 进行参数更新。if isinstance(updater, torch.optim.Optimizer):updater.zero_grad() # 梯度置零l.backward() # 反向传播,知道如何调整参数以最小化损失函数grad_clipping(net, 1) # 梯度裁剪updater.step() # 使用优化器来更新参数else:l.backward()grad_clipping(net, 1)# 因为已经调用了mean函数updater(batch_size=1)# y.numel()计算y中元素数量metric.add(l * y.numel(), y.numel())# 使用指数损失函数计算累积平均困惑度 math.exp(metric[0] / metric[1]) 和训练速度 metric[1] / timer.stop()。return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()
  • updater.zero_grad(): 这一行代码将模型参数的梯度置零,以便在每次迭代中计算新的梯度。
  • l.backward(): 这一行代码使用反向传播算法计算损失函数对模型参数的梯度。通过计算梯度,我们可以知道如何调整模型参数以最小化损失函数。
  • grad_clipping(net, 1): 这一行代码对模型的梯度进行裁剪,以防止梯度爆炸的问题。梯度爆炸可能会导致训练不稳定,裁剪梯度可以限制梯度的范围。
  • updater.step(): 这一行代码使用优化器(如SGD、Adam等)来更新模型的参数。优化器根据计算得到的梯度和预定义的学习率来更新模型参数,以使模型更好地拟合训练数据。

循环神经网络的训练函数也支持高级API实现

# 循环神经网络的训练函数也支持高级API实现
#@save
def train_ch8(net, train_iter, vocab, lr, num_epochs, device,use_random_iter=False):"""训练模型(定义见第8章)"""loss = nn.CrossEntropyLoss()# 动画窗口:窗口显示一个图例,图例名称为 "train",x 轴的范围从 10 到 num_epochsanimator = d2l.Animator(xlabel='epoch', ylabel='perplexity',legend=['train'], xlim=[10, num_epochs])# 初始化if isinstance(net, nn.Module):updater = torch.optim.SGD(net.parameters(), lr)else:updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)# 训练和预测for epoch in range(num_epochs):ppl, speed = train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter)# 每10个epoch,对输入字符串进行预测,并将预测结果添加到动画中if (epoch + 1) % 10 == 0:print(predict('time traveller'))animator.add(epoch + 1, [ppl])print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')print(predict('time traveller'))print(predict('traveller'))

在数据集中只使用了10000个词元, 所以模型需要更多的迭代周期来更好地收敛

# 在数据集中只使用了10000个词元, 所以模型需要更多的迭代周期来更好地收敛
num_epochs, lr = 500, 1
train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())

在这里插入图片描述
检查一下随机抽样方法的结果

# 检查一下随机抽样方法的结果
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(),use_random_iter=True)

在这里插入图片描述

相关文章:

动手学深度学习——循环神经网络的从零开始实现(原理解释+代码详解)

文章目录 循环神经网络的从零开始实现1. 独热编码2. 初始化模型参数3. 循环神经网络模型4. 预测5. 梯度裁剪6. 训练 循环神经网络的从零开始实现 从头开始基于循环神经网络实现字符级语言模型。 # 读取数据集 %matplotlib inline import math import torchfrom torch import …...

【操作系统】文件系统的逻辑结构与目录结构

文章目录 文件的概念定义属性基本操作 文件的结构文件的逻辑结构文件的目录结构文件控制块(FCB)索引节点目录结构 文件的概念 定义 在操作系统中,文件被定义为:以计算机硬盘为载体的存储在计算机上的信息集合。 属性 描述文件…...

局域网内Ubuntu上搭建Git服务器

1.在局域网内选定一台Ubuntu电脑作为Git服务端: (1).新建用户如为fbc,执行如下命令:需设置密码,此为fbc sudo adduser fbc (2).切换到fbc用户:需密码,此前设置为fbc su fbc (3).建一个空目录作为仓…...

基础课10——自然语言生成

自然语言生成是让计算机自动或半自动地生成自然语言的文本。这个领域涉及到自然语言处理、语言学、计算机科学等多个领域的知识。 1.简介 自然语言生成系统可以分为基于规则的方法和基于统计的方法两大类。基于规则的方法主要依靠专家知识库和语言学规则来生成文本&#xff0…...

xpath

xpath 使用 使用 from lxml import etree或者 from lxml import htmlet etree.XML(xml) et etree.HTML(html) res et.xpath("/book") # 返回列表项目Valueet.xpath(“/book”)/表示根节点/div/a子节点用/依次表示/name/text()text()取文本/book//nick//表示标签…...

Java拼图小游戏

Java拼图小游戏 import javax.swing.*; import java.awt.*; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import java.awt.image.BufferedImage; import java.util.ArrayList; import java.util.Collections; import java.util.List;public cla…...

终于有人把数据资产入表知识地图总结出来了,轻松看懂

在当前数字化的浪潮下,数据已经成为劳动、土地、知识、技术以后的第五大生产要素,“数据就是资源”已成为共识。如今数据资产“入表”已成定局,数据资产化迫在眉睫。 2023年8月21日,财政部正式印发《企业数据资源相关会计处理暂行…...

白鳝:聊聊IvorySQL的Oracle兼容技术细节与实现原理

两年前听瀚高的一个朋友说他们要做一个开源数据库项目,基于PostgreSQL,主打与Oracle的兼容性,并且与PG社区版内核同步发布。当时我听了有点不太相信,瀚高的Highgo是在PG内核上增加了一定的Oracle兼容性的特性,一般也会…...

vue和uni-app的递归组件排坑

有这样一个数组数据,实际可能有很多级。 tree: [{id: 1,name: 1,children: [{ id: 2, name: 1-1, children: [{id: 7, name: 1-1-1,children: []}]},{ id: 3, name: 1-2 }]},{id: 4,name: 2,children: [{ id: 5, name: 2-1 },{ id: 6, name: 2-2 }]} ]要渲染为下面…...

【考研】数据结构(更新到顺序表)

声明&#xff1a;所有代码都可以运行&#xff0c;可以直接粘贴运行&#xff08;只有库函数没有声明&#xff09; 线性表的定义和基本操作 基本操作 定义 静态&#xff1a; #include<stdio.h> #include<stdlib.h>#define MaxSize 10//静态 typedef struct{int d…...

汇编-指针

一个变量如果包含的是另一个变量的地址&#xff0c; 则该变量就称为指针(pointer) 。指针是操作数组和数据结构的极好工具&#xff0c;因为它包含的地址在运行时是可以修改的。 .data arrayB byte 10h, 20h, 30h, 40h ptrB dword arrayB ptrB1 dword OFFSET arrayBarray…...

常见Web安全

一.Web安全概述 以下是百度百科对于web安全的解释&#xff1a; Web安全&#xff0c;计算机术语&#xff0c;随着Web2.0、社交网络、微博等等一系列新型的互联网产品的诞生&#xff0c;基于Web环境的互联网应用越来越广泛&#xff0c;企业信息化的过程中各种应用都架设在Web平台…...

milvus数据库搜索

一、向量相似度搜索 在Milvus中进行向量相似度搜索时&#xff0c;会计算查询向量和集合中具有指定相似性度量的向量之间的距离&#xff0c;并返回最相似的结果。通过指定一个布尔表达式来过滤标量字段或主键字段&#xff0c;您可以执行混合搜索。 1.加载集合 执行操作的前提是…...

HEVC参考帧技术

为了增强参考帧管理的抗差错能力&#xff0c;HEVC采用了参考帧集技术&#xff0c;通过直接在每一帧的片头码流中传输DPB中各个帧的状态变化&#xff0c;将当前帧以及后续帧可能用到的参考帧在DPB中都进行描述&#xff0c;描述以POC作为一帧的身份标识。因此&#xff0c;不需要依…...

QT小记:The QColor ctor taking ints is cheaper than the one taking string literals

这个警告意味着在使用 Qt 的 C 代码中&#xff0c;使用接受整数参数的 QColor 构造函数比使用接受字符串字面值的构造函数更有效率。 要解决这个警告&#xff0c;你可以修改你的代码&#xff0c;尽可能使用接受整数参数的 QColor 构造函数&#xff0c;而不是字符串字面值。例如…...

机器人走迷宫问题

题目 1.房间有XY的方格组成&#xff0c;例如下图为64的大小。每一个方格以坐标(x,y) 描述。 2.机器人固定从方格(0, 0)出发&#xff0c;只能向东或者向北前进&#xff0c;出口固定为房间的最东北角&#xff0c;如下图的 方格(5,3)。用例保证机器人可以从入口走到出口。 3.房间…...

轻量封装WebGPU渲染系统示例<36>- 广告板(Billboard)(WGSL源码)

原理不再赘述&#xff0c;请见wgsl shader实现。 当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/BillboardEntityTest.ts 当前示例运行效果: WGSL顶点shader: group(0) binding(0) var<uniform> objMat :…...

Java 多线程进阶

1 方法执行与进程执行 GetMapping("/demo1")public void demo1(){//方法调用new ThreadTest1("run1").run();//线程调用new ThreadTest1("run2").start();} 下断点调试信息&#xff0c;可以看到run()方法当前线程是“main1” 继续运行到run里面&…...

CentOS上搭建SVN并自动同步至web目录

一、搭建svn环境并创建仓库&#xff1a; 1、安装Subversion&#xff1a; yum install svn2、创建版本库&#xff1a; //先建目录 cd /www mkdir wwwsvn cd wwwsvn //创建版本库 svnadmin create xiangmumingcheng二、创建用户组及用户&#xff1a; 1、 进入版本库中的配…...

.Net中Redis的基本使用

前言 Redis可以用来存储、缓存和消息传递。它具有高性能、持久化、高可用性、扩展性和灵活性等特点&#xff0c;尤其适用于处理高并发业务和大量数据量的系统&#xff0c;它支持多种数据结构&#xff0c;如字符串、哈希表、列表、集合、有序集合等。 Redis的使用 安装包Ser…...

uniapp 对接腾讯云IM群组成员管理(增删改查)

UniApp 实战&#xff1a;腾讯云IM群组成员管理&#xff08;增删改查&#xff09; 一、前言 在社交类App开发中&#xff0c;群组成员管理是核心功能之一。本文将基于UniApp框架&#xff0c;结合腾讯云IM SDK&#xff0c;详细讲解如何实现群组成员的增删改查全流程。 权限校验…...

【OSG学习笔记】Day 18: 碰撞检测与物理交互

物理引擎&#xff08;Physics Engine&#xff09; 物理引擎 是一种通过计算机模拟物理规律&#xff08;如力学、碰撞、重力、流体动力学等&#xff09;的软件工具或库。 它的核心目标是在虚拟环境中逼真地模拟物体的运动和交互&#xff0c;广泛应用于 游戏开发、动画制作、虚…...

【AI学习】三、AI算法中的向量

在人工智能&#xff08;AI&#xff09;算法中&#xff0c;向量&#xff08;Vector&#xff09;是一种将现实世界中的数据&#xff08;如图像、文本、音频等&#xff09;转化为计算机可处理的数值型特征表示的工具。它是连接人类认知&#xff08;如语义、视觉特征&#xff09;与…...

Spring AI 入门:Java 开发者的生成式 AI 实践之路

一、Spring AI 简介 在人工智能技术快速迭代的今天&#xff0c;Spring AI 作为 Spring 生态系统的新生力量&#xff0c;正在成为 Java 开发者拥抱生成式 AI 的最佳选择。该框架通过模块化设计实现了与主流 AI 服务&#xff08;如 OpenAI、Anthropic&#xff09;的无缝对接&…...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量&#xff0c;这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

C++ 设计模式 《小明的奶茶加料风波》

&#x1f468;‍&#x1f393; 模式名称&#xff1a;装饰器模式&#xff08;Decorator Pattern&#xff09; &#x1f466; 小明最近上线了校园奶茶配送功能&#xff0c;业务火爆&#xff0c;大家都在加料&#xff1a; 有的同学要加波霸 &#x1f7e4;&#xff0c;有的要加椰果…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

STM32---外部32.768K晶振(LSE)无法起振问题

晶振是否起振主要就检查两个1、晶振与MCU是否兼容&#xff1b;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容&#xff08;CL&#xff09;与匹配电容&#xff08;CL1、CL2&#xff09;的关系 2. 如何选择 CL1 和 CL…...

提升移动端网页调试效率:WebDebugX 与常见工具组合实践

在日常移动端开发中&#xff0c;网页调试始终是一个高频但又极具挑战的环节。尤其在面对 iOS 与 Android 的混合技术栈、各种设备差异化行为时&#xff0c;开发者迫切需要一套高效、可靠且跨平台的调试方案。过去&#xff0c;我们或多或少使用过 Chrome DevTools、Remote Debug…...

API网关Kong的鉴权与限流:高并发场景下的核心实践

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 引言 在微服务架构中&#xff0c;API网关承担着流量调度、安全防护和协议转换的核心职责。作为云原生时代的代表性网关&#xff0c;Kong凭借其插件化架构…...