当前位置: 首页 > news >正文

从0开始深度学习(32)——循环神经网络的从零开始实现

本章将从零开始,基于循环神经网络实现字符级语言模型(不是单词级)

首先我们把从0开始深度学习(30)——语言模型和数据集中的load_corpus_time_machine()函数进行引用,用于导入数据:

train_iter, vocab = load_corpus_time_machine()
train_iter

运行结果:
在这里插入图片描述

train_iter中的每个数字都表示在vocab中的索引,将这些索引直接输入神经网络可能会使学习变得困难,我们通常将每个词元表示为更具表现力的特征向量,即one-hot编码

1 one-hot编码

假设词表中不同词元的数目为 N N N,也就是len(vocab),所以词元索引的范围为 0 N − 1 0~N-1 0 N1。如果词元的索引是整数 i i i, 那么我们将创建一个长度为 N N N的全 0 0 0向量, 并将第 i i i处的元素设置为 1 1 1。例如索引为 0 0 0 2 2 2的独热向量如下所示:

from torch.nn import functional as F
import torchF.one_hot(torch.tensor([0, 2]), len(vocab))

运行结果:

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0],# 第0处设置为1,表示索引1[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0]])# 第2处设置为1,表示索引2

由于我们每次采样的小批量数据形状是二维张量:(批量大小,时间步数),one_hot函数会将这样一个小批量数据转换成三维张量, 张量的最后一个维度等于词表大小len(vocab)。我们经常转换输入的维度,以便获得形状为 (时间步数,批量大小,词表大小)的输出。 这将使我们能够更方便地通过最外层的维度, 一步一步地更新小批量数据的隐状态。

转化这一步将在后面进行

2 初始化模型参数

初始化循环神经网络模型的模型参数, 隐藏单元数num_hiddens是一个可调的超参数。 当训练语言模型时,输入和输出来自相同的词表,因此,它们具有相同的维度,即词表的大小:

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 隐藏层参数W_xh = normal((num_inputs, num_hiddens))W_hh = normal((num_hiddens, num_hiddens))b_h = torch.zeros(num_hiddens, device=device)# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

3 循环神经网络模型

为了定义循环神经网络模型, 我们首先需要一个init_rnn_state函数在初始化时返回隐状态,这个函数的返回是一个张量,形状为(批量大小,隐藏单元数)

def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

下面的rnn函数定义了如何在一个时间步内计算隐状态输出

def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH, = state # 提取隐状态outputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h) # 使用激活函数Y = torch.mm(H, W_hq) + b_q # 做矩阵乘法,即 (隐变量*权重+偏置)outputs.append(Y)return torch.cat(outputs, dim=0), (H,)# cat()函数将所有时间步的输出拼接成一个张量,形状为 (时间步数量 * 批量大小, 输出大小)# (H,)为最后一个时间步的隐藏状态

定义了所有需要的函数之后,接下来我们创建一个类来包装这些函数, 并存储从零开始实现的循环神经网络模型的参数:

class RNNModelScratch: #@save"""从零开始实现的循环神经网络模型"""def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)self.init_state, self.forward_fn = init_state, forward_fndef __call__(self, X, state):X = F.one_hot(X.T, self.vocab_size).type(torch.float32)# 进行one-hot编码return self.forward_fn(X, state, self.params)# 调用前向传播函数,传入编码后的输入数据、初始隐藏状态和模型参数def begin_state(self, batch_size, device):return self.init_state(batch_size, self.num_hiddens, device)# 初始化的隐藏状态

我们可以做一个测试:

num_hiddens = 512
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
state = net.begin_state(X.shape[0], d2l.try_gpu())
Y, new_state = net(X.to(d2l.try_gpu()), state)
Y.shape, len(new_state), new_state[0].shape
# 输出形状,隐状态形状

运行结果:

(torch.Size([10, 28]), 1, torch.Size([2, 512]))

输出形状是(时间步数*批量大小,词表大小), 隐状态形状是(批量大小,隐藏单元数),符合要求

4 梯度裁剪

在编写训练函数之前,要引入一个方法——梯度裁剪,用于防止梯度爆炸问题。

对于长度为 T T T的序列,在迭代时要计算 T T T个时间步上的梯度,于是会在反向传播中产生长度为 O ( T ) O(T) O(T)的矩阵乘法链,当 T T T过大时,有可能导致梯度爆炸问题,所以循环神经网络需要额外的方式来支持稳定的训练,下面不讲解原理,直接给出一种流行的方法。

不过注意:梯度裁剪提供了一个快速修复梯度爆炸的方法, 虽然它并不能完全解决问题,但它是众多有效的技术之一。

def grad_clipping(net, theta):  #@save"""裁剪梯度"""if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / norm

5 训练

与线性神经网络的训练有三个不同之处:

  1. 序列数据的不同采样方法(随机采样和顺序分区)将导致隐状态初始化的差异。
  2. 在更新模型参数之前裁剪梯度,这样即使训练过程中某个点上发生了梯度爆炸,也能保证模型不会发散。
  3. 使用困惑度来评价模型

对于第一点做一些解释

  • 随机采样: 由于每次抽取的数据点是独立的,模型在处理每个样本时通常需要重新初始化隐状态。通常情况下,模型会在每个新的随机样本开始时,使用初始的隐状态。
  • 顺序分区: 由于数据点是按时间顺序排列的,模型在处理每个子序列时可以利用前一个子序列的隐状态作为当前子序列的初始隐状态
def predict_ch8(prefix, num_preds, net, vocab, device):  #@save"""在prefix后面生成新字符"""state = net.begin_state(batch_size=1, device=device)outputs = [vocab[prefix[0]]]get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))for y in prefix[1:]:  # 预热期_, state = net(get_input(), state)outputs.append(vocab[y])for _ in range(num_preds):  # 预测num_preds步y, state = net(get_input(), state)outputs.append(int(y.argmax(dim=1).reshape(1)))return ''.join([vocab.idx_to_token[i] for i in outputs])
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):"""训练一个迭代周期"""state, timer = None, 0metric = [0, 0]  # 累积损失,总词元数量for X, Y in train_iter:if X.shape[0] == 0 or Y.shape[0] == 0:  # 跳过空批次continueif state is None or use_random_iter:state = net.begin_state(batch_size=X.shape[0], device=device)else:if isinstance(state, tuple):state = tuple(s.detach() for s in state)else:state = state.detach()X, Y = X.to(device), Y.T.reshape(-1).to(device)y_hat, state = net(X, state)l = loss(y_hat, Y.long()).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()grad_clipping(net, 1)updater.step()else:l.backward()grad_clipping(net, 1)updater(batch_size=1)metric[0] += l.item() * Y.numel()  # 使用 l.item() 累积标量损失metric[1] += Y.numel()  # 累计总词元数量return metric[0] / max(1, metric[1])  # 避免除以零def train_ch8(net, train_iter, vocab, lr, num_epochs, device,use_random_iter=False):"""训练模型"""loss = nn.CrossEntropyLoss()updater = torch.optim.SGD(net.params, lr)for epoch in range(num_epochs):avg_loss = train_epoch_ch8(net, train_iter, loss, updater,device, use_random_iter)ppl = torch.exp(torch.tensor(avg_loss))  # 转换为 Tensor 以使用 torch.expprint(f'epoch {epoch + 1}, perplexity {ppl:.1f}')print(predict_ch8('time traveller', 50, net, vocab, device))# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 模型初始化
num_hiddens = 512
net = RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_rnn_state, rnn)# 训练模型
num_epochs, lr = 500, 1
train_ch8(net, train_iter, vocab, lr, num_epochs, device)

相关文章:

从0开始深度学习(32)——循环神经网络的从零开始实现

本章将从零开始,基于循环神经网络实现字符级语言模型(不是单词级) 首先我们把从0开始深度学习(30)——语言模型和数据集中的load_corpus_time_machine()函数进行引用,用于导入数据: train_iter…...

GitLab使用操作v1.0

1.前置条件 Gitlab 项目地址:http://******/req Gitlab账户信息:例如 001/******自己的分支名称:例如 001-master(注:master只有项目创建者有权限更新,我们只能更新自己分支,然后创建合并请求&…...

cuda conda yolov11 环境搭建

优雅的 yolo v11 标注工具 AutoLabel Conda环境直接识别训练 nvidia-smi 检查CUDA版本 下载nvidia cudnn对应的版本 将cuDNN压缩包内对应的文件复制到本地bin、include、lib的文件夹中 C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6 miniConda快速开始-安装 执行…...

解决SpringBoot连接Websocket报:请求路径 404 No static resource websocket.

问题发现 最近在工作中用到了WebSocket进行前后端的消息通信,后端代码编写完后,测试一下是否连接成功,发现报No static resource websocket.,看这个错貌似将接口变成了静态资源来访问了,第一时间觉得是端点没有注册成…...

element-plus的组件数据配置化封装 - table

目录 一、封装的table、table-column组件以及相关ts类型的定义 1、ATable组件的封装 - index.ts 2、ATableColumn组件的封装 - ATableColumn.ts 3、ATable、ATableColumn类型 - interface.ts 二、ATable、ATableColumn组件的使用 三、相关属性、方法的使用以及相关说明 1. C…...

【二维动态规划:交错字符串】

介绍 编程语言:Java 本篇介绍一道比较经典的二维动态规划题。 交错字符串 主要说明几点: 为什么双指针解不了?为什么是二维动态规划?根据题意分析处转移方程。严格位置依赖和空间压缩优化。 题目介绍 题意有点抽象&#xff0c…...

goframe开发一个企业网站 MongoDB 完整工具包18

1. MongoDB 工具包完整实现 (mongodb.go) package mongodbimport ("context""fmt""time""github.com/gogf/gf/v2/frame/g""go.mongodb.org/mongo-driver/mongo""go.mongodb.org/mongo-driver/mongo/options" )va…...

在vue中,根据后端接口返回的文件流实现word文件弹窗预览

需求 弹窗预览word文件,因浏览器无法直接根据blob路径直接预览word文件,所以需要利用插件实现。 解决方案 利用docx-preview实现word文件弹窗预览,以node版本16.21.3和docx-preview版本0.1.8为例 具体实现步骤 1、安装docx-preview插件 …...

动态规划之背包问题

0/1背包问题 1.二维数组解法 题目描述:有一个容量为m的背包,还有n个物品,他们的重量分别为w1、w2、w3.....wn,他们的价值分别为v1、v2、v3......vn。每个物品只能使用一次,求可以放进背包物品的最大价值。 输入样例…...

【Python】 深入理解Python的单元测试:用unittest和pytest进行测试驱动开发

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 单元测试是现代软件开发中的重要组成部分,通过验证代码的功能性、准确性和稳定性,提升代码质量和开发效率。本文章深入介绍Python中两种主流单元测试框架:unittest和pytest,并结合测试驱动开发(TDD)…...

Java集合1.0

1.什么是集合? 集合就是一个存放数据的容器,准确的说是放数据对象引用的容器。 集合和数组的区别 数组是固定长度,集合是可变长度。数组可以存储基本数据类型,也可以存储引用数据类型,集合只能存储引用数据类型&…...

Leetcode 336 回文对

示例 1: 输入:words ["abcd","dcba","lls","s","sssll"] 输出:[[0,1],[1,0],[3,2],[2,4]] 解释:可拼接成的回文串为 ["dcbaabcd","abcddcba","sl…...

实现一个可配置的TCP设备模拟器,支持交互和解析配置

前言 诸位在做IOT开发的时候是否有遇到一个问题,那就是模拟一个设备来联调测试,虽然说现在的物联网通信主要是用mqtt通信,但还是有很多设备使用TCP这种协议交互,例如充电桩,还有一些工业设备,TCP这类报文交…...

算法的空间复杂度

空间复杂度 空间复杂度主要是衡量一个算法运行所需要的额外空间,在计算机发展早期,计算机的储存容量很小,所以空间复杂度是很重要的。但是经过计算机行业的迅速发展,计算机的容量已经不再是问题了,所以如今已经不再需…...

自定义协议

1. 问题引入 问题:TCP是面向字节流的(TCP不关心发送的数据是消息、文件还是其他任何类型的数据。它简单地将所有数据视为一个字节序列,即字节流。这意味着TCP不会对发送的数据进行任何特定的边界划分,它只是确保数据的顺序和完整…...

在 Taro 中实现系统主题适配:亮/暗模式

目录 背景实现方案方案一:CSS 变量 prefers-color-scheme 媒体查询什么是 prefers-color-scheme?代码示例 方案二:通过 JavaScript 监听系统主题切换 背景 用Taro开发的微信小程序,需求是页面的UI主题想要跟随手机系统的主题适配…...

autogen框架中使用chatglm4模型实现react

本文将介绍如何使用使用chatglm4实现react,利用环境变量、Tavily API和ReAct代理模式来回答用户提出的问题。 环境变量 首先,我们需要加载环境变量。这可以通过使用dotenv库来实现。 from dotenv import load_dotenv_ load_dotenv()注意.env文件处于…...

读《Effective Java》笔记 - 条目9

条目9:与try-finally 相比,首选 try -with -resource 什么是 try-finally? try-finally 是 Java 中传统的资源管理方式,通常用于确保资源(如文件流、数据库连接等)被正确关闭。 BufferedReader reader n…...

【软件入门】Git快速入门

Git快速入门 文章目录 Git快速入门0.前言1.安装和配置2.新建版本库2.1.本地创建2.2.云端下载 3.版本管理3.1.添加和提交文件3.2.回退版本3.2.1.soft模式3.2.2.mixed模式3.2.3.hard模式3.2.4.使用场景 3.3.查看版本差异3.4.忽略文件 4.云端配置4.1.Github4.1.1.SSH配置4.1.2.关联…...

nextjs window is not defined

问题产生的原因 在 Next.js 中,“window is not defined” 错误通常出现在服务器端渲染(Server - Side Rendering,SSR)的代码中。这是因为window对象是浏览器环境中的全局对象,在服务器端没有window这个概念。例如&am…...

Flask RESTful 示例

目录 1. 环境准备2. 安装依赖3. 修改main.py4. 运行应用5. API使用示例获取所有任务获取单个任务创建新任务更新任务删除任务 中文乱码问题: 下面创建一个简单的Flask RESTful API示例。首先,我们需要创建环境,安装必要的依赖,然后…...

在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:

在 HarmonyOS 应用开发中,手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力,既支持点击、长按、拖拽等基础单一手势的精细控制,也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档&#xff0c…...

iPhone密码忘记了办?iPhoneUnlocker,iPhone解锁工具Aiseesoft iPhone Unlocker 高级注册版​分享

平时用 iPhone 的时候,难免会碰到解锁的麻烦事。比如密码忘了、人脸识别 / 指纹识别突然不灵,或者买了二手 iPhone 却被原来的 iCloud 账号锁住,这时候就需要靠谱的解锁工具来帮忙了。Aiseesoft iPhone Unlocker 就是专门解决这些问题的软件&…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中,新增了一个本地验证码接口 /code,使用函数式路由(RouterFunction)和 Hutool 的 Circle…...

云原生玩法三问:构建自定义开发环境

云原生玩法三问:构建自定义开发环境 引言 临时运维一个古董项目,无文档,无环境,无交接人,俗称三无。 运行设备的环境老,本地环境版本高,ssh不过去。正好最近对 腾讯出品的云原生 cnb 感兴趣&…...

JVM 内存结构 详解

内存结构 运行时数据区: Java虚拟机在运行Java程序过程中管理的内存区域。 程序计数器: ​ 线程私有,程序控制流的指示器,分支、循环、跳转、异常处理、线程恢复等基础功能都依赖这个计数器完成。 ​ 每个线程都有一个程序计数…...

深度学习水论文:mamba+图像增强

🧀当前视觉领域对高效长序列建模需求激增,对Mamba图像增强这方向的研究自然也逐渐火热。原因在于其高效长程建模,以及动态计算优势,在图像质量提升和细节恢复方面有难以替代的作用。 🧀因此短时间内,就有不…...

R 语言科研绘图第 55 期 --- 网络图-聚类

在发表科研论文的过程中,科研绘图是必不可少的,一张好看的图形会是文章很大的加分项。 为了便于使用,本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中,获取方式: R 语言科研绘图模板 --- sciRplothttps://mp.…...

【从零开始学习JVM | 第四篇】类加载器和双亲委派机制(高频面试题)

前言: 双亲委派机制对于面试这块来说非常重要,在实际开发中也是经常遇见需要打破双亲委派的需求,今天我们一起来探索一下什么是双亲委派机制,在此之前我们先介绍一下类的加载器。 目录 ​编辑 前言: 类加载器 1. …...