当前位置: 首页 > news >正文

【深度学习】RNN的简单实现

目录

1.RNNCell

2.RNN

3.RNN_Embedding


1.RNNCell

import torchinput_size = 4
hidden_size = 4
batch_size = 1idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]  # 输入:hello
y_data = [3, 1, 2, 3, 2]  # 期待:ohlol# 独热向量
one_hot_lookup = [[1, 0, 0, 0],[0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]inputs = torch.Tensor(x_one_hot).view(-1, batch_size, input_size)  # (seqLen,batchSize,inputSize)
labels = torch.LongTensor(y_data).view(-1, 1)  # (seqLen,1)class Model(torch.nn.Module):def __init__(self, input_size, hidden_size, batch_size):super(Model, self).__init__()self.batch_size = batch_sizeself.input_size = input_sizeself.hidden_size = hidden_sizeself.rnncell = torch.nn.RNNCell(input_size=self.input_size,hidden_size=self.hidden_size)def forward(self, input, hidden):hidden = self.rnncell(input, hidden)  # input:(batch, input_size) hidden:(batch, hidden_size)return hiddendef init_hidden(self):return torch.zeros(self.batch_size, self.hidden_size)net = Model(input_size, hidden_size, batch_size)criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)for epoch in range(15):loss = 0optimizer.zero_grad()hidden = net.init_hidden()print('Predicted string: ', end='')for input, label in zip(inputs, labels):hidden = net(input, hidden)loss += criterion(hidden, label)idx = torch.argmax(hidden, dim=1)print(idx2char[idx.item()], end='')loss.backward()optimizer.step()print(', Epoch [%d/15] loss=%.4f' % (epoch+1, loss.item()))

2.RNN

import torchinput_size = 4  # 输入的维度,例如hello为四个字母表示,其维度为四
hidden_size = 4  # 隐藏层维度
num_layers = 1  # number of layers
batch_size = 1
seq_len = 5idx2char = ['e', 'h', 'l', 'o']
x_data = [1, 0, 2, 2, 3]  # 输入:hello
y_data = [3, 1, 2, 3, 2]  # 期待:ohlol# 独热向量
one_hot_lookup = [[1, 0, 0, 0],[0, 1, 0, 0],[0, 0, 1, 0],[0, 0, 0, 1]]
x_one_hot = [one_hot_lookup[x] for x in x_data]inputs = torch.Tensor(x_one_hot).view(seq_len, batch_size, input_size)
labels = torch.LongTensor(y_data)  # (seqSize*batchSize, 1)class Model(torch.nn.Module):def __init__(self, input_size, hidden_size, batch_size, num_layers=1):super(Model, self).__init__()self.num_layers = num_layersself.batch_size = batch_sizeself.input_size = input_sizeself.hidden_size = hidden_sizeself.rnn = torch.nn.RNN(input_size=self.input_size,hidden_size=self.hidden_size,num_layers=num_layers)def forward(self, input):hidden = torch.zeros(self.num_layers,self.batch_size,self.hidden_size)  # (numLayers, batchSize, hiddenSize)out, hidden_last = self.rnn(input, hidden)  # out:(seqLen, batchSize, hiddenSize), hidden_last:最后一个hiddenreturn out.view(-1, self.hidden_size)  # (seqLen×batchSize, hiddenSize)net = Model(input_size, hidden_size, batch_size, num_layers)criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)for epoch in range(15):optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()idx = torch.argmax(outputs, dim=1)print('Predicted string: ', ''.join([idx2char[i.item()] for i in idx]), end='')print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

3.RNN_Embedding

import torch# parameters
num_class = 4  # 引入线性层,不用不必要求一个输入就有一个输出,可以多个
input_size = 4
hidden_size = 8
embedding_size = 10
num_layers = 2
batch_size = 1
seq_len = 5idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 0, 2, 2, 3]]  # (batch:1, seq_len:5)
y_data = [3, 1, 2, 3, 2]  # (batch * seq_len)
inputs = torch.LongTensor(x_data)
labels = torch.LongTensor(y_data)class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.emb = torch.nn.Embedding(input_size, embedding_size)self.rnn = torch.nn.RNN(input_size=embedding_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True)# batchSize在第一位: (batchSize:一共几个句子, seqLen:每个句子有几个单词, inputSize:每个单词有多少特征)self.fc = torch.nn.Linear(hidden_size, num_class)def forward(self, x):hidden = torch.zeros(num_layers, x.size(0), hidden_size)x = self.emb(x)  # (batch, seqLen, embeddingSize), 输入数据x首先经过嵌入层,将字符索引转换为向量x, _ = self.rnn(x, hidden)x = self.fc(x)return x.view(-1, num_class)net = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)for epoch in range(15):optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()idx = torch.argmax(outputs, dim=1)print('Predicted string: ', ''.join([idx2char[i.item()] for i in idx]), end='')print(', Epoch [%d/15] loss = %.3f' % (epoch + 1, loss.item()))

相关文章:

【深度学习】RNN的简单实现

目录 1.RNNCell 2.RNN 3.RNN_Embedding 1.RNNCell import torchinput_size 4 hidden_size 4 batch_size 1idx2char [e, h, l, o] x_data [1, 0, 2, 2, 3] # 输入:hello y_data [3, 1, 2, 3, 2] # 期待:ohlol# 独热向量 one_hot_lookup [[1, …...

每次请求时,检查 JWT Token的有效期并决定是否需要刷新

为了在每次请求时检查 access_token 的有效期,并在过期时自动刷新,可以通过以下步骤实现: 1. 解析 JWT Token 获取过期时间 JWT token 的有效期是编码在 token 本身的,你可以通过解析 token 来获取它的到期时间。JWT token 是由…...

AI大模型开发架构设计(13)——LLM大模型的向量数据库应用实战

文章目录 LLM大模型的向量数据库应用实战1 大模型的局限性大模型的4点局限性大模型的4点局限性的改进实践方法 2 向量数据库使用场景以及改建大模型向量数据库向量数据库选型知识库文档检索增强(Retrieval Augmented Generation) 3 向量数据库应用技术架构剖析向量数据库应用技…...

WPF中Grid、StackPanel、Canvas、WrapPanel常用属性

Grid常用属性 Grid 控件在 WPF 中非常强大,它提供了多种属性来定义行和列的布局。以下是一些常用的 Grid 属性: RowDefinitions 和 ColumnDefinitions: Grid 控件使用 RowDefinitions 和 ColumnDefinitions 来定义行和列的集合。每个 RowDef…...

【芙丽芳丝净润洗面霜和雅漾舒护活泉喷雾

1. 洁面产品: - 芙丽芳丝净润洗面霜:氨基酸洗面奶的经典产品,成分温和,不含酒精、香料等刺激性成分。泡沫丰富细腻,能够有效清洁皮肤的同时,不会过度剥夺皮肤的油脂,洗后皮肤不紧绷,…...

ubuntu更新Cmake

CMake 先验知识创建软链接如何删除符号链接如何找出失效链接并将其删除PATH 优先级查看当前CMake命令的位置 高版本 CMake 安装参考 先验知识 创建软链接 ln -s <path to the file/folder to be linked> <the path of the link to be created>ln 是链接命令&…...

CMOS晶体管的串联与并联

CMOS晶体管的串联与并联 前言 对于mos管的串联和并联&#xff0c;一直没有整明白&#xff0c;特别是设计到EDA软件中&#xff0c;关于MOS的M和F参数&#xff0c;就更困惑了&#xff0c;今天看了许多资料以及在EDA软件上验证了电路结构与版图的对应关系&#xff0c;总算有点收…...

从IT高管到看门大爷:53岁我的职场华丽转身

该文讲述了一位1971年出生的男士&#xff0c;在53岁时因日企撤资而失业。他曾是IT技术员&#xff0c;后晋升为IT高管兼工会主席&#xff0c;但失业后数百份简历石沉大海&#xff0c;面试也因年龄被取消。他意识到年龄是求职的障碍&#xff0c;开始调整心态&#xff0c;降低期望…...

Redis入门到精通(三):入门Redis看这一篇就够了

文章目录 Redis分布式锁的实现原理Redis实现分布式锁如何合理的控制锁的有效时常&#xff1f;**redisson实现的分布式锁**redisson实现的如何保证主从一致性 Redis的集群方案1.主从复制主从数据的同步原理全量同步增量同步 2.哨兵模式Redis的集群脑裂是什么&#xff1f;3.分片集…...

IP基本原理

IP的定义 当前唯一的网络层协议标准定义数据网络层的封装方式、编址方法 MTU 最大传输单元接口收发数据支持的单个包的最大长度不同二层链路类型的接口的MTU不一致。以太网接口默认MTU1500Byte。PPPoE接口默认MTU1480Byte。 IP头部封装格式 IP 头部长度不固定&#xff0c;2…...

数据分析题面试题系列2

一.如何估算星巴克一天的营业额 a.需求澄清&#xff1a;区域&#xff1f;节假日&#xff1f;产品范围&#xff1f; b.收入销售杯数*单价&#xff08;营业时间*每小时产能*每小时产能利用率&#xff09;*平均单价 Hypo该星巴克门店的营业时间为12小时&#xff08;取整&#x…...

uniapp 单表、多级动态表单添加validateFunction自定义规则

uniapp 多级动态表单添加自定义规则 在uniapp制作小程序时&#xff0c;当涉及到需要设置validateFunction的校验规则时。可能遇到的问题 1、validateFunction不生效&#xff0c;没有触发 2、多层级表单怎么添加validateFunction自定义校验规则 本文将以单表单校验和多表单校…...

FPGA高端图像处理培训第一期,提供工程源码+视频教程+FPGA开发板

目录 1、FPGA图像处理培训现状分析2、本FPGA图像处理培训优势亮点架构全起点高实用性强项目应用级别细节恐怖工程源码清晰 3、本FPGA图像处理培训内容介绍图像处理基本框架图像前处理框架图像中处理框架图像前中处理框架图像后处理框架图像中后处理框架图像处理仿真框架视频教程…...

顺序表的实现(数据结构)——C语言

目录 1.结构与概念 2.分类 3 动态顺序表的实现 SeqList.h SeqList.c 创建SLInit&#xff1a; 尾插SLPushBack以及SLCheak&#xff08;检查空间是否足够&#xff09;&#xff1a; 头插SLPushFront&#xff1a; 尾删SLPopBack 头删SLPopFront 查找指定元素SLFind 指定…...

【VUE】Vue中 computed计算属性和watch侦听器的区别

核心功能不同 computed 是一个计算属性&#xff0c;其核心功能是基于已有的数据属性计算得出新的属性值。当某个依赖的数据发生变化时&#xff0c;computed 会自动重新计算并更新自己的值。因此&#xff0c;可以将 computed 看做是一种“派生状态”。 watch 是一个观察者函数&…...

linux线程 | 同步与互斥 | 深度学习与理解同步

前言&#xff1a;本节内容主要讲解linux下的同步问题。 同步问题是保证数据安全的情况下&#xff0c;让我们的线程访问具有一定的顺序性。 线程安全就规定了它必须是在加锁的场景下的&#xff01;&#xff01;那么&#xff0c; 具体什么是同步问题&#xff0c; 我们加下来看看吧…...

Tkinter Frame布局笔记--做一个简易的计算器

#encodingutf-8 import tkinter import re import tkinter.messagebox import tkinter.simpledialog import sys import os def get_resources_path(relative_path):if getattr(sys,frozen, False):base_pathsys._MEIPASS#获取临时文件else:base_pathos.path.dirname(".&q…...

算法专题八: 链表

目录 链表1. 链表的常用技巧和操作总结2. 两数相加3. 两两交换链表中的节点4. 重排链表5. 合并K个升序链表6. K个一组翻转链表 链表 1. 链表的常用技巧和操作总结 常用技巧 画图!!! 更加直观形象, 便于我们理解引入虚拟头节点, 方便我们对链表的操作, 减少我们对边界情况的考…...

MySQL中关于NULL值的六大坑!你被坑过吗?

NULL值是我们在开发过程中的老朋友了&#xff0c;但是这个老朋友在MySQL中有很多坑&#xff0c;我通过这篇文章来总结分享一下&#xff0c;欢迎大家在评论区分享你的看法和踩坑经历。 1、NULL不等于NULL 在MySQL中&#xff0c;执行以下SQL会返回NULL 假如t表有以下数据&#…...

学生学习动机测试:激发潜能,引领未来

学习动机、学习兴趣和学习目标制定是影响学生学习成效的三个关键因素。通过对学生学习动机的测试,我们可以深入了解学生的学习状态,进而采取针对性的措施,激发他们的学习潜能,引导他们走向更加光明的未来。本文将从学习动机、学习兴趣和学习目标制定三个方面,详细探讨学生…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架&#xff0c;它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用&#xff0c;和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

大话软工笔记—需求分析概述

需求分析&#xff0c;就是要对需求调研收集到的资料信息逐个地进行拆分、研究&#xff0c;从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要&#xff0c;后续设计的依据主要来自于需求分析的成果&#xff0c;包括: 项目的目的…...

中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试

作者&#xff1a;Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位&#xff1a;中南大学地球科学与信息物理学院论文标题&#xff1a;BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接&#xff1a;https://arxiv.…...

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例

文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...

五年级数学知识边界总结思考-下册

目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解&#xff1a;由来、作用与意义**一、知识点核心内容****二、知识点的由来&#xff1a;从生活实践到数学抽象****三、知识的作用&#xff1a;解决实际问题的工具****四、学习的意义&#xff1a;培养核心素养…...

linux arm系统烧录

1、打开瑞芯微程序 2、按住linux arm 的 recover按键 插入电源 3、当瑞芯微检测到有设备 4、松开recover按键 5、选择升级固件 6、点击固件选择本地刷机的linux arm 镜像 7、点击升级 &#xff08;忘了有没有这步了 估计有&#xff09; 刷机程序 和 镜像 就不提供了。要刷的时…...

Nuxt.js 中的路由配置详解

Nuxt.js 通过其内置的路由系统简化了应用的路由配置&#xff0c;使得开发者可以轻松地管理页面导航和 URL 结构。路由配置主要涉及页面组件的组织、动态路由的设置以及路由元信息的配置。 自动路由生成 Nuxt.js 会根据 pages 目录下的文件结构自动生成路由配置。每个文件都会对…...

C++ 求圆面积的程序(Program to find area of a circle)

给定半径r&#xff0c;求圆的面积。圆的面积应精确到小数点后5位。 例子&#xff1a; 输入&#xff1a;r 5 输出&#xff1a;78.53982 解释&#xff1a;由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982&#xff0c;因为我们只保留小数点后 5 位数字。 输…...

【Oracle】分区表

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...

重启Eureka集群中的节点,对已经注册的服务有什么影响

先看答案&#xff0c;如果正确地操作&#xff0c;重启Eureka集群中的节点&#xff0c;对已经注册的服务影响非常小&#xff0c;甚至可以做到无感知。 但如果操作不当&#xff0c;可能会引发短暂的服务发现问题。 下面我们从Eureka的核心工作原理来详细分析这个问题。 Eureka的…...