当前位置: 首页 > news >正文

十三、RNN循环神经网络实战

因为我本人主要课题方向是处理图像的,RNN是基本的序列处理模型,主要应用于自然语言处理,故这里就简单的学习一下,了解为主

一、问题引入

已知以前的天气数据信息,进行预测当天(4-9)是否下雨

日期温度气压是否下雨
4-11020
4-23040
4-34025
4-41030
4-5510
4-61020
4-71260
4-82580
4-92015

这里的数据集都是随别胡乱写的哈,就说在阐述一下待解决的问题,随别做的数据集
思路:可以四天一组,每组中有4天的天气信息,包括温度、气压、是否下雨
前三天作为输入,第四天最为输出

在卷积神经网络中,全连接层是权重最多的,也是整个网络中计算量最多的地方

卷积中
输入:128通道
输出:64通道
卷积核:5×5
总共的权重参数:128×64×5×5 = 204800

全连接中
一般都不会直接将一个高维通道直接变为1,而是多几个中间层进行过度
输入:4096
输出:1024
权重参数个数:4096×1024 = 4194304

权重参数个数压根都不在一个数量级上,所以说,正因为卷积的权重共享,导致卷积操作所需参数远小于全连接

RNN循环神经网络主要用在具有序列关系的数据中进行处理,例如:天气的预测,因为前后的天气会相互影响,并不会断崖式的变化、股市预测等,典型的就是自然语言处理
我喜欢beyond乐队这句话的词语之间具有序列关系,随便调换顺序产生的结果肯定很难理解

二、RNN循环神经网络

Ⅰ,RNN Cell

RNN Cell是RNN中的核心单元
在这里插入图片描述
xt:序列当中,时刻t时的数据,这个数据具有一定的维度,例如天气数据就是3D向量的,即,温度、气压、是否下雨
xt通过RNN Cell之后就会得到一个ht,这个数据也是具有一定的维度,假如是5D向量
从xt这个3D向量数据通过RNN Cell得到一个ht这个5D向量数据,很明显,这个RNN Cell本质就是一个线性层
区别:RNN Cell这个线性层是共享的

在这里插入图片描述
在这里插入图片描述
RNN Cell基本流程
在这里插入图片描述

现学现卖

在这里插入图片描述

import torch#根据需求设定参数
batch_size = 1
seq_len = 3
input_size = 4
hidden_size = 2yy_cell = torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)dataset = torch.randn(seq_len,batch_size,input_size)
hidden = torch.zeros(batch_size,hidden_size) #h0设置为全0for idx,inputs in enumerate(dataset):print('-----------------)print("Input size:",inputs.shape)hidden = yy_cell(inputs,hidden)print("outputs size:",hidden.shape)print(hidden)
"""
==================== 0 ====================
Input size: torch.Size([1, 4])
outputs size: torch.Size([1, 2])
tensor([[ 0.6377, -0.4208]], grad_fn=<TanhBackward0>)
==================== 1 ====================
Input size: torch.Size([1, 4])
outputs size: torch.Size([1, 2])
tensor([[-0.2049,  0.6174]], grad_fn=<TanhBackward0>)
==================== 2 ====================
Input size: torch.Size([1, 4])
outputs size: torch.Size([1, 2])
tensor([[-0.1482, -0.2232]], grad_fn=<TanhBackward0>)
"""

Ⅱ,RNN

在这里插入图片描述

现学现卖

import torch#根据需求设定参数
batch_size = 1
seq_len = 3
input_size = 4
hidden_size = 2
num_layers = 2 #两层RNN Cellcell = torch.nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers)inputs = torch.randn(seq_len,batch_size,input_size)
hidden = torch.zeros(num_layers,batch_size,hidden_size) #h0设置为全0out,hidden = cell(inputs,hidden)print('output size:',out.shape)
print('output:',out)
print('hidden size:',hidden.shape)
print('hidden',hidden)"""
output size: torch.Size([3, 1, 2])
output: tensor([[[ 0.8465, -0.1636]],[[ 0.3185, -0.1733]],[[ 0.0269, -0.1330]]], grad_fn=<StackBackward0>)
hidden size: torch.Size([2, 1, 2])
hidden tensor([[[ 0.5514,  0.8349]],[[ 0.0269, -0.1330]]], grad_fn=<StackBackward0>)
"""

三、RNN实战

需求:实现将输入beyond转换为ynbode

①文本转向量one-hot

因为RNN Cell单元输入的数据必须是由单词构成的向量 ,根据字符来构建一个词典,并为其分配索引,索引变One-Hot向量,词典中有几项,最终构建的向量也有几列,只能出现一个1,其余都为0

characterindex
b0
d1
e2
n3
o4
y5

在这里插入图片描述

在这里插入图片描述

②模型训练

Ⅰ RNN Cell

import torchinput_size = 6
hidden_size = 6
batch_size = 1dictionary = ['b','e','y','o','n','d'] #字典
x_data = [0,1,2,3,4,5] #beyond
y_data = [2,4,0,3,5,1] #ynbodeone_hot = [[1,0,0,0,0,0],[0,1,0,0,0,0],[0,0,1,0,0,0],[0,0,0,1,0,0],[0,0,0,0,1,0],[0,0,0,0,0,1]]x_one_hot = [one_hot[x] for x in x_data] #将x_data的每个元素从one_hot得到相对于的向量形式inputs = torch.Tensor(x_one_hot).view(-1,batch_size,input_size) #inputs形式为(seqlen,batch_size,input_size)
labels = torch.LongTensor(y_data).view(-1,1) #lables形式为(seqlen,1)class y_rnncell_model(torch.nn.Module):def __init__(self,input_size,hidden_size,batch_size):super(y_rnncell_model,self).__init__()self.batch_size = batch_sizeself.input_size = input_sizeself.hidden_size = hidden_sizeself.rnncell = torch.nn.RNNCell(input_size=self.input_size,hidden_size=self.hidden_size)def forward(self,inputs,labels):hidden = self.rnncell(inputs,labels)return hiddendef init_hidden(self): #定义h0初始化return torch.zeros(self.batch_size,self.hidden_size)y_net = y_rnncell_model(input_size,hidden_size,batch_size)#定义损失函数和优化器
lossf = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(y_net.parameters(),lr=0.001)# RNN Cell
for epoch in range(800):loss = 0optim.zero_grad() #优化器梯度归零hidden = y_net.init_hidden() #h0print('Predicted string:',end='')for x,y in zip(inputs,labels):hidden = y_net(x,hidden)loss += lossf(hidden,y) #计算损失之和,需要构造计算图_,idx = hidden.max(dim=1)print(dictionary[idx.item()],end='')loss.backward()optim.step()print(',Epoch [%d/20] loss=%.4f'%(epoch+1,loss.item()))

Ⅱ RNN

#引入torch
import torchinput_size = 6 #beyond
hidden_size = 6 #
num_layers = 1
batch_size = 1
seq_len = 6idx2char = ['b','d','e','n','o','y'] #字典
x_data = [0,2,5,4,3,1] #beyond
y_data = [5,3,0,4,1,2] #ynbodeone_hot = [[1,0,0,0,0,0],[0,1,0,0,0,0],[0,0,1,0,0,0],[0,0,0,1,0,0],[0,0,0,0,1,0],[0,0,0,0,0,1]]x_one_hot = [one_hot[x] for x in x_data] #将x_data的每个元素从one_hot得到相对于的向量形式inputs = torch.Tensor(x_one_hot).view(seq_len,batch_size,input_size)labels = torch.LongTensor(y_data)class y_rnn_model(torch.nn.Module):def __init__(self,input_size,hidden_size,batch_size,num_layers):super(y_rnn_model,self).__init__()self.num_layers = num_layersself.batch_size = batch_sizeself.input_size = input_sizeself.hidden_size = hidden_sizeself.rnn = torch.nn.RNN(input_size=self.input_size,hidden_size=self.hidden_size,num_layers=self.num_layers)def forward(self,inputs):hidden = torch.zeros(self.num_layers,self.batch_size,self.hidden_size)#构造h0out,_ = self.rnn(inputs,hidden)     return out.view(-1,self.hidden_size) #(seqlen×batchsize,hiddensize)net = y_rnn_model(input_size,hidden_size,batch_size,num_layers)lessf = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(),lr=0.05)for epoch in range(30):optimizer.zero_grad()outputs = net(inputs)loss = lessf(outputs,labels)   loss.backward()optimizer.step()_, idx = outputs.max(dim=1)idx = idx.data.numpy()print('Predicted:',''.join([idx2char[x] for x in idx]),end='')print(',Epoch[%d/15] loss=%.3f' % (epoch+1,loss.item()))

③one-hot的不足

1,维度过高;一个单词得占用一个维度
2,one-hot向量过于稀疏;就一个1,其余全是0
3,硬编码;一对一

解决方法:EMBEDDING
思路:将高维的向量映射到一个稠密的低维的向量空间里面
即:数据的降维

在这里插入图片描述

优化RNN
在这里插入图片描述

官网torch.nn.Embedding函数详细参数解释
在这里插入图片描述

参数含义
num_embeddingsone-hot的维度
embedding_dimembedding的维度
Input: (*)(∗), IntTensor or LongTensor of arbitrary shape containing the indices to extract输入需要是一个整型或者长整型IntTensor or LongTensor
Output: (*, H), where * is the input shape and H=embedding_dim(input shape,embedding_dim )

网络架构

相关文章:

十三、RNN循环神经网络实战

因为我本人主要课题方向是处理图像的&#xff0c;RNN是基本的序列处理模型&#xff0c;主要应用于自然语言处理&#xff0c;故这里就简单的学习一下&#xff0c;了解为主 一、问题引入 已知以前的天气数据信息&#xff0c;进行预测当天(4-9)是否下雨 日期温度气压是否下雨4-…...

五子棋透明棋盘界面设计(C语言)

五子棋透明棋盘设计&#xff0c;漂亮的界面制作。程序设置双人对奕&#xff0c;人机模式&#xff0c;对战演示三种模式。设置悔棋&#xff0c;记录功能&#xff0c;有禁手设置。另有复盘功能设置。 本文主要介绍透明的玻璃板那样的五子棋棋盘的制作。作为界面设计&#xff0c;…...

Redis第六讲 Redis之List底层数据结构实现

List数据结构 List是一个有序(按加入的时序排序)的数据结构,Redis采用quicklist(双端链表) 和 ziplist 作为List的底层实现。可以通过设置每个ziplist的最大容量,quicklist的数据压缩范围,提升数据存取效率 list-max-ziplist-size -2 // 单个ziplist节点最大能存储 8kb ,…...

电子学会2023年3月青少年软件编程python等级考试试卷(四级)真题,含答案解析

目录 一、单选题(共25题,共50分) 二、判断题(共10题,共20分) 三、编程题(共3题,共30分)...

【MATLAB】一篇文章带你了解beatxbx工具箱使用

目录 一篇文章带你了解beatxbx工具箱使用 一篇文章带你了解beatxbx工具箱使用 clc;clear; tic; % step1 初始化 % 个体数量 NIND = 35; % 最大遗传代数 MAXGEN = 180; % 变量的维数 NVAR = 2; % 变量的二进制位数 % 上下界 bounds=[-10 10-10 10]; precision=0.0001; %运算精度…...

【LinuxC Sqlite数据库小项目】基于Sqlite的打卡系统------适合初学者练手的小项目

最近小哥老是想浪&#xff0c;不想好好学习&#xff0c;这不行啊&#xff0c;得想点办法&#xff0c;多少做点努力&#xff0c;于是就自己给自己写了个打卡程序&#xff1b; 该程序基于Sqlite数据库&#xff0c;实现一个简单的打卡功能&#xff0c;该函数具有自动初始化的功能…...

在掌握C#基础上再学习C语言

C#和C语言虽然名字相似&#xff0c;但它们在很多方面都有很大的区别。 首先&#xff0c;C#是一种面向对象的语言&#xff0c;而C语言是过程化的语言。这意味着C#具有更丰富的语言特性&#xff0c;如类、接口、继承和多态性等&#xff0c;而C语言则更侧重于直接对计算机硬件进行…...

HTML5 <body> 标签

HTML <body> 标签 实例 一个简单的 HTML 文档&#xff0c;包含尽可能少的必需的标签&#xff1a; <!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>文档标题</title> </head><body> 文档内容…...

(链表)反转链表

文章目录前言&#xff1a;问题描述&#xff1a;解题思路&#xff1a;代码实现&#xff1a;总结&#xff1a;前言&#xff1a; 此篇是针对链表的经典练习。 问题描述&#xff1a; 给定一个单链表的头结点pHead(该头节点是有值的&#xff0c;比如在下图&#xff0c;它的val是1…...

deb文件如何安装到iphone方法分享

Cydia或同类APT管理软件在线安装 Cydia或同类APT管理软件在线安装,这个是最佳的安装方式,因为通常无需考虑依赖关系,但缺点是对网络的要求比较高;命令行中以dpkg-iXXX.deb的形式安装,好处是可以以通配符一次性安装多个deb,而且也可以直接看到脚本的运行状况和安装成功/失…...

mongodb和mysql双写数据一致性问题

文章目录 我们是如何用MongoDB的如何保证双写一致性?先写数据库,再写MongoDB先写MongoDB,再写数据库用户修改操作如何保存数据如何清理新增的垃圾数据定时删除随机删除我们是如何用MongoDB的 MongoDB是一个高可用、分布式的文档数据库,用于大容量数据存储。文档存储一般用…...

Databend 开源周报第 88 期

Databend 是一款现代云数仓。专为弹性和高效设计&#xff0c;为您的大规模分析需求保驾护航。自由且开源。即刻体验云服务&#xff1a;https://app.databend.com 。 Whats On In Databend 探索 Databend 本周新进展&#xff0c;遇到更贴近你心意的 Databend 。 Support Eager…...

Vue3学习笔记(9.4)

Vue3自定义指令 除了默认设置的核心指令&#xff08;v-model和v-show&#xff09;&#xff0c;Vue也允许注册自定义指令。 下面我们注册一个全局指令v-focus&#xff0c;该指令的功能是在页面加载时&#xff0c;元素获得焦点&#xff1a; <!--* Author: RealRoad10834252…...

导入 Excel 文件时,抛出 413 (Request Entity Too Large) 错误

Excel文件大小&#xff1a;8MB 异常信息&#xff1a;413 (Request Entity Too Large) 环境&#xff1a;IIS10PHP7.2.33 依次检查如下几项&#xff1a; 一、php.ini Maximum amount of memory a script may consume (128MB) 限制代码消耗的最大内存&#xff0c;默认128…...

Verilog学习笔记1——关键词、运算符、数据类型、function/task、initial/always、generate

文章目录前言一、关键词二、运算符三、数据类型1、基本类型&#xff1a;reg、wire、integer、parameter四、条件语句五、循环语句1、for2、generate六、function和task七、initial和always1、initial和always相同点和区别2、always和assign语句区别前言 2023.4.4 2023.4.7 补充…...

探索LeetCode【0005】最长回文子串(未搞懂,未练习)

目录0、题目1、第一个官方答案1.1 动态规划&#xff08;未懂&#xff09;1.2 中心扩展&#xff08;已懂&#xff09;1.3 Manacher&#xff08;未懂&#xff09;2、第二个参考答案2.1 暴力求法&#xff08;已懂&#xff09;2.2 反转法&#xff08;未懂&#xff09;2.3 动态规划&…...

使用 Docker run 命令简化容器化

使用 Docker run 命令简化容器化 Docker run 是在 Docker 容器中运行应用程序的基本命令。在开始使用 Docker 之前&#xff0c;了解一些重要的命令非常重要。 在本博客中&#xff0c;我们将解释 Docker run 命令的基本语法&#xff0c;并探索其一些最常见的选项&#xff0c;以…...

腾讯TNN神经网络推理框架手动实现多设备单算子卷积推理

文章目录前言1. 简介2. 快速开始2.1 onnx转tnn2.2 编译目标平台的 TNN 引擎2.3 使用编译好的 TNN 引擎进行推理3. 手动实现单算子卷积推理(浮点)4. 代码解析4.1 构建模型(单卷积层)4.2 构建解释器4.3 初始化tnn5. 模型量化5.1 编译量化工具5.2 量化scale的计算5.3 量化流程6. i…...

基础解惑:Linux 下文件描述符标志和文件状态标志区别

简述 文件描述符标志&#xff0c;是体现进程的文件描述符的状态&#xff0c;fork进程时&#xff0c;文件描述符被复制&#xff1b;目前只有一种文件描述符&#xff1a;FD_CLOEXEC文件状态标志&#xff0c;是体现进程打开文件的一些标志&#xff0c;fork时不会复制file 结构&am…...

学弟:如何在3个月内学会自动化测试?

有小学弟问&#xff1a;如何在3个月内学会自动化测试&#xff1f; 老实说如果你现在上班&#xff0c;之前主要在做功能测试&#xff0c;或者编程基础比较弱的话&#xff0c;三个月够呛。 如果你是脱产学习&#xff0c;每天能保持6&#xff5e;8小时学习时间的话&#xff0c;可…...

JavaScript 中的 ES|QL:利用 Apache Arrow 工具

作者&#xff1a;来自 Elastic Jeffrey Rengifo 学习如何将 ES|QL 与 JavaScript 的 Apache Arrow 客户端工具一起使用。 想获得 Elastic 认证吗&#xff1f;了解下一期 Elasticsearch Engineer 培训的时间吧&#xff01; Elasticsearch 拥有众多新功能&#xff0c;助你为自己…...

3.3.1_1 检错编码(奇偶校验码)

从这节课开始&#xff0c;我们会探讨数据链路层的差错控制功能&#xff0c;差错控制功能的主要目标是要发现并且解决一个帧内部的位错误&#xff0c;我们需要使用特殊的编码技术去发现帧内部的位错误&#xff0c;当我们发现位错误之后&#xff0c;通常来说有两种解决方案。第一…...

c++ 面试题(1)-----深度优先搜索(DFS)实现

操作系统&#xff1a;ubuntu22.04 IDE:Visual Studio Code 编程语言&#xff1a;C11 题目描述 地上有一个 m 行 n 列的方格&#xff0c;从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子&#xff0c;但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...

Module Federation 和 Native Federation 的比较

前言 Module Federation 是 Webpack 5 引入的微前端架构方案&#xff0c;允许不同独立构建的应用在运行时动态共享模块。 Native Federation 是 Angular 官方基于 Module Federation 理念实现的专为 Angular 优化的微前端方案。 概念解析 Module Federation (模块联邦) Modul…...

LLM基础1_语言模型如何处理文本

基于GitHub项目&#xff1a;https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken&#xff1a;OpenAI开发的专业"分词器" torch&#xff1a;Facebook开发的强力计算引擎&#xff0c;相当于超级计算器 理解词嵌入&#xff1a;给词语画"…...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”&#xff0c;物流的终极形态正在诞生 想象这样的场景&#xff1a; 凌晨3点&#xff0c;某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径&#xff1b;AI视觉系统在0.1秒内扫描包裹信息&#xff1b;数字孪生平台正模拟次日峰值流量压力…...

今日学习:Spring线程池|并发修改异常|链路丢失|登录续期|VIP过期策略|数值类缓存

文章目录 优雅版线程池ThreadPoolTaskExecutor和ThreadPoolTaskExecutor的装饰器并发修改异常并发修改异常简介实现机制设计原因及意义 使用线程池造成的链路丢失问题线程池导致的链路丢失问题发生原因 常见解决方法更好的解决方法设计精妙之处 登录续期登录续期常见实现方式特…...

AGain DB和倍数增益的关系

我在设置一款索尼CMOS芯片时&#xff0c;Again增益0db变化为6DB&#xff0c;画面的变化只有2倍DN的增益&#xff0c;比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析&#xff1a; 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...

【MATLAB代码】基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),附源代码|订阅专栏后可直接查看

文章所述的代码实现了基于最大相关熵准则(MCC)的三维鲁棒卡尔曼滤波算法(MCC-KF),针对传感器观测数据中存在的脉冲型异常噪声问题,通过非线性加权机制提升滤波器的抗干扰能力。代码通过对比传统KF与MCC-KF在含异常值场景下的表现,验证了后者在状态估计鲁棒性方面的显著优…...

适应性Java用于现代 API:REST、GraphQL 和事件驱动

在快速发展的软件开发领域&#xff0c;REST、GraphQL 和事件驱动架构等新的 API 标准对于构建可扩展、高效的系统至关重要。Java 在现代 API 方面以其在企业应用中的稳定性而闻名&#xff0c;不断适应这些现代范式的需求。随着不断发展的生态系统&#xff0c;Java 在现代 API 方…...