当前位置: 首页 > news >正文

[PyTorch][chapter 45][RNN_2]

目录:

  1.     RNN 问题
  2.     RNN 时序链问题
  3.     RNN 词组预测的例子
  4.      RNN简洁实现

 


一  RNN 问题

     RNN 主要有两个问题,梯度弥散和梯度爆炸

    1.1  损失函数

               E=\sum_{t=1}^TE_t

               h_t= tanh(W_{xh}x_t+W_{hh}h_t)

               y_t=W_{ho}h_t

           梯度

              \frac{\partial E_t}{\partial W_{hh}}=\sum_{i=1}^{t}\frac{\partial E_t}{\partial y_{t}}\frac{\partial y_t}{\partial h_{t}}\frac{\partial h_t}{\partial h_{i}}\frac{\partial h_i}{\partial W_{hh}}

              其中:

              \frac{\partial h_t}{\partial h_i}=\prod_{k=i}^{t-1} \frac{\partial h_{k+1}}{\partial h_k}

              \frac{\partial h_{k+1}}{\partial h_{k}}=diag(1-h_k^2)W_{hh}

             则

                \frac{\partial h_k}{\partial h_1}=\prod_{i}^{k} diag(1-h_i^2)W_{hh}

     1.1  梯度爆炸(Gradient Exploding)

                  上面矩阵进行连乘后k,可能会出现里面参数会变得极大

               

                解决方案:

                  梯度剪裁:对W.grad进行约束

              

           

def print_current_grad(model):for  w in model.parameters():print(w.grad.norm())loss.criterion(output, y)
model.zero_grad()
loss.backward()
print_current_grad(model)
torch.nn.utils.clip_grad_norm_(p,10)
print_current_grad(model)
optimizer.step()

             

     1.2  梯度弥散(Gradient vanishing)

    是由于时序链过程,导致梯度为0,前面的层参数无法更新。

  解决方案 :

          LSTM.


二  RNN 时序链问题

    

# -*- coding: utf-8 -*-
"""
Created on Mon Jul 24 15:12:49 2023@author: chengxf2
"""import torch
import torch.nn as  nn
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt  # 导入作图相关的包'''生成训练的数据集
return x: 当前时刻的输入值[batch_size=1, time_step=num_time_steps-1, feature=1]y: 当前时刻的标签值[batch_size=1, time_step=num_time_steps-1, feature=1]
'''
def sampleData():#生成一个[0-3]之间的数据start = np.random.randint(3,size=1)[0]num_time_steps =20#时序链长度为num_time_stepstime_steps= np.linspace(start, start+10,num_time_steps)data = np.sin(time_steps)data = data.reshape(num_time_steps,1)#[batch, seq, dimension]x= torch.tensor(data[:-1]).float().view(1,num_time_steps-1,1)y= torch.tensor(data[1:]).float().view(1, num_time_steps-1,1)return x,y,time_steps'''网络模型args:input_size – 输入x的特征数量。hidden_size – 隐藏层的特征数量。num_layers – RNN的层数。nonlinearity – 指定非线性函数使用tanh还是relu。默认是tanh。bias – 默认是Truebatch_first – 如果True的话,那么输入Tensor的shape应该是[batch_size, time_step, feature],输出也是这样。默认是Falsedropout – 如果值非零,那么除了最后一层外,其它层的输出都会套上一个dropout层。bidirectional – 如果True,将会变成一个双向RNN,默认为False。
'''
class Net(nn.Module):def __init__(self,input_dim = 1, hidden_dim =10,  out_dim = 1):super(Net, self).__init__()self.rnn= nn.RNN(input_size = input_dim, hidden_size = hidden_dim,num_layers = 1,batch_first = True)self.linear= nn.Linear(in_features= hidden_dim, out_features=out_dim)#前向传播函数def forward(self,x,hidden_prev):# 给定一个h_state初始状态,(batch_size=1,layer=1,hidden_dim=10)# 给定一个序列x.shape:[batch_size, time_step, feature]hidden_dim =10#print("\n x.shape",x.shape)out,hidden_prev= self.rnn(x,hidden_prev)out = out.view(-1,hidden_dim) #[1,seq,h]=>[1*seq,h]out = self.linear(out)#[seq,h]=>[seq,1]out = out.unsqueeze(dim=0) #[seq,1] 指定的维度上面添加一个维度[batch=1,seq,1]return out, hidden_prevdef main():model = Net()criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(),lr=1e-3)hidden_dim =10#初始值hidden_prv = torch.zeros(1,1,hidden_dim)for iter in range(5000):x,y,time_steps =sampleData() #[batch=1,seq=99,dim=1]output, hidden_prev =model(x,hidden_prv)hidden_prev = hidden_prev.detach()loss = criterion(output, y)model.zero_grad()loss.backward()optimizer.step()if iter %100 ==0:print("Iter:{} loss{}".format(iter, loss.item()))# 对最后一次的结果作图查看网络的预测效果plt.plot(time_steps[0:-1], y.flatten(), 'r-')plt.plot(time_steps[0:-1], output.data.numpy().flatten(), 'b-')
main()

三  RNN 词组预测的例子

     这是参考李沐写得一个实现nn.RNN功能的例子

,一般很少用,都是直接用nn.RNN.

# -*- coding: utf-8 -*-
"""
Created on Wed Jul 26 14:17:49 2023@author: chengxf2
"""import math
import torch
from torch import nn
from torch.nn import functional as F
import numpy
import d2lzh_pytorch as d2l#生成随机变量
def normal(shape,device):return torch.randn(size=shape, device=device)*0.01#模型需要更新的权重系数
def get_params(vocab_size=27, num_hiddens=10, device='cuda:0'):num_inputs = num_outputs = vocab_sizeW_xh = normal((num_inputs,num_hiddens),device)W_hh = normal((num_hiddens,num_hiddens),device)b_xh = torch.zeros(num_hiddens,device=device)b_hh = torch.zeros(num_hiddens,device=device)W_hq = normal((num_hiddens,num_outputs),device)b_q = torch.zeros(num_outputs, device= device)params = [W_xh,W_hh, b_xh,b_hh, W_hq,b_q]for param in params:param.requires_grad_(True)return params#初始的隐藏值 hidden ,tuple
def init_rnn_state(batch_size, hidden_size, device):h_init= torch.zeros((batch_size,hidden_size),device=device)return  (h_init,)#RNN 函数定义了如何在时间序列上更新隐藏状态和输出
def rnn(X, h_init, params):W_xh,W_hh, b_xh,b_hh, W_hq,b_q = paramshidden, = h_initoutputs =[]for x_t in X:z_t = torch.mm(x_t, W_xh)+b_xh+ torch.mm(x_t,W_hh)+b_hhhidden =torch.tanh(z_t)out = torch.mm(hidden,W_hq)+b_qoutputs.append(out)#[batch_size*T, dimension]return torch.cat(outputs, dim=0),(hidden,)#根据给定的词,预测后面num_preds 个词
def predict_ch8(prefix, num_preds, net, vocab, device):#生成初始状态state = net.begin_state(batch_size=1, device=device)#把第一个词拿出来outputs = [vocab[prefix[0]]]get_input = lambda: torch.tensor([outputs[-1]],device=device,(1,1))for y in prefix[1:]:_,state = net(get_input(), state)outputs.append(vocab[y])for _ in range(num_preds):y, state = net(get_input(), state)outputs, (int(y.argmax(dim=1).reshape(1)))return ''.join([vocab.idex_to_toke[i] for i in output])#梯度剪裁def grad_clipping(net, theta):if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad_]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad**2)) for p in params)if norm > theta:for param in params:param.grad[:]*=theta/normclass RNNModel:#从零开始实现RNN 网络模型#def __init__(self, vocab_size, hidden_size, device, get_params, init_rnn_state,forward_fn):forward_fnself.vocab_size = vocab_size, self.num_hiddens = hidden_sizeself.params = get_params(vocab_size, hidden_size, device)self.init_state = init_rnn_state(batch_size, hidden_size, device)self.forwad_fn = forward_fn#X.shape [batch_size,num_steps] def __call__(self, X, state):X = F.one_hot(X.T, self.vocab_size).type(torch.float32)#[num_steps, batch_size]return self.forwad_fn(X, state, self.params)def begin_state(self, batch_size, device):return self.init_state(batch_size, self.num_hiddens, device)# 训练模型def train_epoch_ch8(net, train_iter, loss, updater, device,)state, timer = None, d21.Timer()metric = d21.Accumulator(2)for X,Y in train_iter:if state is None or use_random_iter:state = net.beign_state(bacth_size=X.shape[0])elseif isinstance(net, nn.Module) and not isinstance(o, t)state.detach_()elsefor s in state:s.detach_()y = Y.T.reshape(-1)X,y = X.to(device),y.to(device)y_hat,state=net(X,state)l = loss(y_hat,y.long()).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()grad_clipping(net, 1)updater.step()elsel.backward()grad_clipping(net, 1)updater(batch_size=1)metric.add(1&y.numel(),y.numel())return math.exp(metric[0]/metric[1]))def train(net, train_iter,vocab, lr,num_epochs, device, use_random_iter=False):loss = nn.CrossEntropyLoss()animator = d21.animator(xlabel='epoch',ylabel='preplexity',legend=['train'],xlim=[10,num_epochs])if isinstance(net, nn.Module):updater = torch.optim.SGD(net.parameters(),lr)else:updater = lambda batch_size: d21.sgd(net.parameters,batch_size,lr)predict = lambda prefix: predict_ch8(prefix, num_preds=50, net, vocab, device)for epoch in range(num_epochs):ppl, spped = train_epoch_ch8(net, train_iter, updater(),use_random_iter())if (epoch+1)%10 ==0:print(predict('time traverller'))animator.add(epoch+1, [ppl])print(f'困惑度{ppl:lf},{speed:1f} 标记/秒')print(predict('time traveller'))print(predict('traveller'))def main():num_hiddens =512num_epochs, ,lr = 500,1vocab_size = len(vocab)#[批量大小,时间步数]batch_size, num_steps = 32, 10train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)F.one_hot(torch.tensor([0,2]), len(vocab))X= torch.arange(10).reshape((2,5))Y = F.one_hot(X.T,28).shape #[step, batch_num]model = RNNModel(vocab_size, num_hiddens, dl2.try_gpu(), get_params, init_rnn_state, rnn)            train_ch8(model, train_iter, vocab,lr,num_epochs,dl2.try_gpu())if __name__ == "__main__":main()

四  RNN简洁实现

     

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 28 10:11:33 2023@author: chengxf2
"""import torch
from torch import nn
from torch.nn import functional as Fclass SimpleRNN(nn.Module):def __init__(self,batch_size, input_size, hidden_size,out_size):super(SimpleRNN,self).__init__()self.batch_size,self.num_hiddens = batch_size,hidden_sizeself.rnn_layer = nn.RNN(input_size,hidden_size)self.linear = nn.Linear(hidden_size, out_size)def  forward(self, X,state):'''Parameters----------X : [seq,batch, feature]state : [layer, batch, feature]-------#output:(layer, batch_size, hidden_size)state_new : []'''hidden, hidden_new = self.rnn_layer(X, state)hidden = hidden.view(-1, hidden.shape[-1])output = self.linear(hidden)return output ,hiddendef init_hidden_state(self):'''初始化隐藏状态'''state = torch.zeros((1,self.batch_size, self.num_hiddens))return statedef main():seq_len = 3 #时序链长度batch_size =5 #批量大小input_size = 27hidden_size = 10out_size = 9X = torch.rand(size=(seq_len,batch_size,input_size))model = SimpleRNN(batch_size,input_size, hidden_size,out_size)init_state = model.init_hidden_state()output, hidden = model.forward(X, init_state)print("\n 输出值:",output.shape)print("\n 时刻的隐藏状态")print(hidden.shape)if __name__ == "__main__":main()

pytorch入门10--循环神经网络(RNN)_rnn代码pytorch_微扬嘴角的博客-CSDN博客

【PyTorch】深度学习实践之 RNN基础篇——实现RNN_pytorch实现rnn_zoetu的博客-CSDN博客

RNN 的基本原理+pytorch代码_rnn代码_黄某某很聪明的博客-CSDN博客

55 循环神经网络 RNN 的实现【动手学深度学习v2】_哔哩哔哩_bilibili

《动手学深度学习》环境搭建全程详细教程 window用户_https://zh.d21.ai/d21-zh-1.1.zip_溶~月的博客-CSDN博客

ModuleNotFoundError: No module named ‘d2l’_卡拉比丘流形的博客-CSDN博客

相关文章:

[PyTorch][chapter 45][RNN_2]

目录: RNN 问题 RNN 时序链问题 RNN 词组预测的例子 RNN简洁实现 一 RNN 问题 RNN 主要有两个问题,梯度弥散和梯度爆炸 1.1 损失函数 梯度 其中: 则 1.1 梯度爆炸(Gradient Exploding) 上面矩阵进行连乘后…...

基于canvas画布的实用类Fabric.js的使用

目录 前言 一、Fabric.js简介 二、开始 1、引入Fabric.js 2、在main.js中使用 3、初始化画布 三、方法 四、事件 1、常用事件 2、事件绑定 3、事件解绑 五、canvas常用属性 六、对象属性 1、基本属性 2、扩展属性 七、图层层级操作 八、复制和粘贴 1、复制 2…...

基于SpringBoot+Vue驾校理论课模拟考试系统源码(自动化部署)

DrivingTestSimulation Unity3D Project, subject two, simulated driving test 【更新信息】 更新时间-2021-1-17 解决了方向盘不同机型转动轴心偏离 更新时间-2021-2-18 加入了手刹系统 待更新-2021-6-19(工作太忙少有时间更新,先指出问题&#xf…...

SpringBoot使用Redis对用户IP进行接口限流

使用接口限流的主要目的在于提高系统的稳定性&#xff0c;防止接口被恶意打击&#xff08;短时间内大量请求&#xff09;。 一、创建限流注解 引入redis依赖 <!--redis--><dependency><groupId>org.springframework.boot</groupId><artifactId&g…...

MeterSphere学习篇

从开发环境部署开始 metersphere-1.20.4 源码下载地址&#xff1a; https://gitee.com/fit2cloud-feizhiyun/MeterSphere/tree/v1.20/ MeterSphere GitHub 相关插件程序下载 相关准备 安装mysql 配置IDEA...

大数据技术之Clickhouse---入门篇---数据类型、表引擎

星光下的赶路人star的个人主页 今天没有开始的事&#xff0c;明天绝对不会完成 文章目录 1、数据类型1.1 整型1.2 浮点型1.3 布尔型1.4 Decimal型1.5 字符串1.6 枚举类型1.7 时间类型1.8 数组 2、表引擎2.1 表引擎的使用2.2 TinyLog2.3 Memory2.4 MergeTree2.4.1 Partition by分…...

【微服务架构设计】微服务不是魔术:处理超时

微服务很重要。它们可以为我们的架构和团队带来一些相当大的胜利&#xff0c;但微服务也有很多成本。随着微服务、无服务器和其他分布式系统架构在行业中变得更加普遍&#xff0c;我们将它们的问题和解决它们的策略内化是至关重要的。在本文中&#xff0c;我们将研究网络边界可…...

天下风云出我辈,AI准独角兽实在智能获评“十大数字经济风云企业

时值盛夏&#xff0c;各地全力拼经济的氛围同样热火朝天。在浙江省经济强区余杭区这片创业热土上&#xff0c;人工智能助力数字经济建设正焕发出蓬勃生机。 7月28日&#xff0c;经专家评审、公开投票&#xff0c;由中共杭州市余杭区委组织部&#xff08;区委两新工委&#xff…...

SpringBoot2学习笔记

信息来源&#xff1a;https://www.bilibili.com/video/BV19K4y1L7MT?p5&vd_source3969f30b089463e19db0cc5e8fe4583a 作者提供的文档&#xff1a;https://www.yuque.com/atguigu/springboot 作者提供的代码&#xff1a;https://gitee.com/leifengyang/springboot2 ----…...

安达发|APS生产派单系统对数字化工厂有哪些影响和作用

数字化工厂是当今制造业的热门话题&#xff0c;而APS软件则是这一领域的颠覆者。它以其独特的影响和作用&#xff0c;给制造业带来了巨大的改变。让我们一起来看看APS软件对数字化工厂有哪些影响和作用吧&#xff01; 提高生产效率的神器 1.APS软件作为数字化工厂的核心系统&a…...

状态机的介绍和使用 | 京东物流技术团队

1 状态机简介 1.1 定义 我们先来给出状态机的基本定义。一句话&#xff1a; 状态机是有限状态自动机的简称&#xff0c;是现实事物运行规则抽象而成的一个数学模型。 先来解释什么是“状态”&#xff08; State &#xff09;。现实事物是有不同状态的&#xff0c;例如一个自…...

tinkerCAD案例:32. 使用对齐工具构建喷泉

tinkerCAD案例&#xff1a;32. 使用对齐工具构建喷泉 In this lesson, you will practice the basics in Tinkercad, such as move, rotate, and scale. You will also learn how to use the Align Tool. 在本课中&#xff0c;您将练习 Tinkercad 中的基础知识&#xff0c;例如…...

一起学数据结构(2)——线性表及线性表顺序实现

目录 1. 什么是数据结构&#xff1a; 1.1 数据结构的研究内容&#xff1a; 1.2 数据结构的基本概念&#xff1a; 1.2.1 逻辑结构&#xff1a; 1.2.2 存储结构&#xff1a; 2. 线性表&#xff1a; 2.1 线性表的基本定义&#xff1a; 2.2 线性表的运用&#xff1a; 3 .线性…...

mqtt协议流程图

转载于...

7、单元测试--测试RestFul 接口

单元测试–测试RestFul 接口 – 测试用例类使用SpringBootTest(webEnvironment WebEnvironment.RANDOM_PORT)修饰。 – 测试用例类会接收容器依赖注入TestRestTemplate这个实例变量。 – 测试方法可通过TestRestTemplate来调用RESTful接口的方法。 测试用例应该定义在和被测…...

国家留学基金委(CSC)|发布2024年创新型人才国际合作培养项目实施办法

2023年7月28日&#xff0c;国家留学基金委&#xff08;CSC&#xff09;发布了《2024年创新型人才国际合作培养项目实施办法》&#xff0c;在此知识人网小编做全文转载。详细信息请参见https://www.csc.edu.cn/chuguo/s/2648。 2024年创新型人才国际合作培养项目实施办法 第一章…...

找好听的配乐、BGM就上这6个网站,免费商用。

推荐几个音乐素材网站给你&#xff0c;各种类似、风格的都有&#xff0c;而且免费下载&#xff0c;还可以商用&#xff0c;建议收藏起来~ 菜鸟图库 https://www.sucai999.com/audio.html?vNTYxMjky 站内有上千首音效素材&#xff0c;网络流行的音效素材这里都能找到&#xf…...

【前端知识】React 基础巩固(三十五)——ReduxToolKit (RTK)

React 基础巩固(三十五)——ReduxToolKit (RTK) 一、RTK介绍 Redux Tool Kit &#xff08;RTK&#xff09;是官方推荐的编写Redux逻辑的方法&#xff0c;旨在成为编写Redux逻辑的标准方式&#xff0c;从而解决上面提到的问题。 RTK的核心API主要有如下几个&#xff1a; confi…...

android Android Studio Giraffe | 2022.3.1 版本Lombok不兼容 解决方案

android Android Studio Giraffe | 2022.3.1 版本Lombok不兼容 解决方案 1.查看当前的android studio 版本 Android Studio Giraffe | 2022.3.1 Build #AI-223.8836.35.2231.10406996, built on June 29, 2023 2.打开 idea 官网下载页面 idea下载历史版本 找到对应的版本编号…...

前端框架学习-基础前后端分离

前端知识栈 前端三要素&#xff1a;HTML、CSS、JS HTML 是前端的一个结构层&#xff0c;HTML相当于一个房子的框架&#xff0c;可类比于毛坯房只有一个结构。CSS 是前端的一个样式层&#xff0c;有了CSS的装饰&#xff0c;相当于房子有了装修。JS 是前端的一个行为层&#xff…...

Swift 协议扩展精进之路:解决 CoreData 托管实体子类的类型不匹配问题(下)

概述 在 Swift 开发语言中&#xff0c;各位秃头小码农们可以充分利用语法本身所带来的便利去劈荆斩棘。我们还可以恣意利用泛型、协议关联类型和协议扩展来进一步简化和优化我们复杂的代码需求。 不过&#xff0c;在涉及到多个子类派生于基类进行多态模拟的场景下&#xff0c;…...

工程地质软件市场:发展现状、趋势与策略建议

一、引言 在工程建设领域&#xff0c;准确把握地质条件是确保项目顺利推进和安全运营的关键。工程地质软件作为处理、分析、模拟和展示工程地质数据的重要工具&#xff0c;正发挥着日益重要的作用。它凭借强大的数据处理能力、三维建模功能、空间分析工具和可视化展示手段&…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

大模型多显卡多服务器并行计算方法与实践指南

一、分布式训练概述 大规模语言模型的训练通常需要分布式计算技术,以解决单机资源不足的问题。分布式训练主要分为两种模式: 数据并行:将数据分片到不同设备,每个设备拥有完整的模型副本 模型并行:将模型分割到不同设备,每个设备处理部分模型计算 现代大模型训练通常结合…...

搭建DNS域名解析服务器(正向解析资源文件)

正向解析资源文件 1&#xff09;准备工作 服务端及客户端都关闭安全软件 [rootlocalhost ~]# systemctl stop firewalld [rootlocalhost ~]# setenforce 0 2&#xff09;服务端安装软件&#xff1a;bind 1.配置yum源 [rootlocalhost ~]# cat /etc/yum.repos.d/base.repo [Base…...

Python 实现 Web 静态服务器(HTTP 协议)

目录 一、在本地启动 HTTP 服务器1. Windows 下安装 node.js1&#xff09;下载安装包2&#xff09;配置环境变量3&#xff09;安装镜像4&#xff09;node.js 的常用命令 2. 安装 http-server 服务3. 使用 http-server 开启服务1&#xff09;使用 http-server2&#xff09;详解 …...

如何在Windows本机安装Python并确保与Python.NET兼容

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…...

WEB3全栈开发——面试专业技能点P8DevOps / 区块链部署

一、Hardhat / Foundry 进行合约部署 概念介绍 Hardhat 和 Foundry 都是以太坊智能合约开发的工具套件&#xff0c;支持合约的编译、测试和部署。 它们允许开发者在本地或测试网络快速开发智能合约&#xff0c;并部署到链上&#xff08;测试网或主网&#xff09;。 部署过程…...

第2篇:BLE 广播与扫描机制详解

本文是《BLE 协议从入门到专家》专栏第二篇,专注于解析 BLE 广播(Advertising)与扫描(Scanning)机制。我们将从协议层结构、广播包格式、设备发现流程、控制器行为、开发者 API、广播冲突与多设备调度等方面,全面拆解这一 BLE 最基础也是最关键的通信机制。 一、什么是 B…...

C#学习12——预处理

一、预处理指令&#xff1a; 解释&#xff1a;是在编译前由预处理器执行的命令&#xff0c;用于控制编译过程。这些命令以 # 开头&#xff0c;每行只能有一个预处理指令&#xff0c;且不能包含在方法或类中。 个人理解&#xff1a;就是游戏里面的备战阶段&#xff08;不同对局…...