当前位置: 首页 > news >正文

pytorch实现循环神经网络

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

PyTorch 提供三种主要的 RNN 变体:

  • nn.RNN:最基本的循环神经网络,适用于短时依赖任务。
  • nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决梯度消失问题。
  • nn.GRU:门控循环单元,比 LSTM 计算更高效,适用于大部分任务。
网络类型优势适用场景
RNN计算简单,适用于短时序列语音、文本处理(短序列)
LSTM适用于长序列,能记忆长期信息机器翻译、语音识别、股票预测
GRU比 LSTM 计算更高效,效果相似语音处理、文本生成

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 生成正弦波数据(仅使用 PyTorch)
def generate_sine_wave(seq_length=10, num_samples=1000):x = torch.linspace(0, 100, num_samples)  # 生成 1000 个等间距数据点y = torch.sin(x)  # 计算正弦值X_data, Y_data = [], []for i in range(len(y) - seq_length):X_data.append(y[i:i + seq_length].unsqueeze(-1))  # 过去 seq_length 作为输入Y_data.append(y[i + seq_length])  # 预测下一个点return torch.stack(X_data), torch.tensor(Y_data).unsqueeze(-1)# 生成数据
seq_length = 10  # 序列长度
X, Y = generate_sine_wave(seq_length)# 划分训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
Y_train, Y_test = Y[:train_size], Y[train_size:]# 2. 定义 RNN 模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  # 初始化隐藏状态out, _ = self.rnn(x, h0)out = self.fc(out[:, -1, :])  # 取最后一个时间步的输出return out# 3. 训练模型
# 超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 1
num_epochs = 100
learning_rate = 0.001# 初始化模型
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, Y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估与绘图
model.eval()
with torch.no_grad():predictions = model(X_test)# 画图
plt.figure(figsize=(10, 5))
plt.plot(Y_test.numpy(), label="Real Data")
plt.plot(predictions.numpy(), label="Predicted Data")
plt.legend()
plt.title("RNN Sine Wave Prediction")
plt.show()

代码解析

数据生成

  • torch.linspace(0, 100, num_samples) 生成 1000 个均匀分布的数据点。
  • torch.sin(x) 计算正弦值,形成时间序列数据。
  • X过去 10 个时间步的数据,Y下一个时间步的预测目标

构建 RNN

  • nn.RNN(input_size, hidden_size, num_layers, batch_first=True) 定义循环神经网络
    • input_size=1:每个时间步只有一个输入值(正弦波)。
    • hidden_size=32:隐藏层神经元数目。
    • num_layers=1:单层 RNN。
  • self.fc = nn.Linear(hidden_size, output_size) 负责最终输出。

训练

  • 使用 MSELoss(均方误差损失) 计算预测值与真实值的误差。
  • 使用 Adam 优化器 更新模型参数。
  • 每 10 个 epoch 输出一次损失 loss

测试 & 绘图

  • 关闭梯度计算 (torch.no_grad()),执行前向传播预测测试数据。
  • Matplotlib 绘制预测曲线与真实曲线。

运行效果

如果训练成功,预测曲线(橙色)应该与真实曲线(蓝色)非常接近

相关文章:

pytorch实现循环神经网络

人工智能例子汇总:AI常见的算法和例子-CSDN博客 PyTorch 提供三种主要的 RNN 变体: nn.RNN:最基本的循环神经网络,适用于短时依赖任务。nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决…...

Java 16进制 10进制 2进制数 相互的转换

在 Java 中,进行进制之间的转换时,除了功能的正确性外,效率和安全性也很重要。为了确保高效和相对安全的转换,我们通常需要考虑: 性能:使用内置的转换方法,如 Integer.toHexString()、Integer.…...

力扣动态规划-14【算法学习day.108】

前言 ###我做这类文章一个重要的目的还是给正在学习的大家提供方向(例如想要掌握基础用法,该刷哪些题?建议灵神的题单和代码随想录)和记录自己的学习过程,我的解析也不会做的非常详细,只会提供思路和一些关…...

数据结构day02

1 线性表的定义和基本操作 1.1 线性表的定义 分析: 1.1.1 问题一:我们为什么探讨线性表的定义和基本操作 在研究数据结构时,需要重点关注三个方面:逻辑结构、物理结构以及数据的运算。在本节内容里,我们首先来介绍线…...

随笔 | 写在一月的最后一天

. 前言 这个月比预想中过的要快更多。突然回看这一个月,还有点不知从何提笔。 整个一月可以总结为以下几个关键词: 期许,保持期许出现休息 . 期许 关于期许,没有什么时候比一年伊始更适合设立目标和计划的了。但令人惭愧的…...

JVM方法区

一、栈、堆、方法区的交互关系 二、方法区的理解: 尽管所有的方法区在逻辑上属于堆的一部分,但是一些简单的实现可能不会去进行垃圾收集或者进行压缩,方法区可以看作是一块独立于Java堆的内存空间。 方法区(Method Area)与Java堆一样,是各个…...

一文读懂fgc之cms

一文读懂 fgc之cms-实战篇 1. 前言 线上应用运行过程中可能会出现内存使用率较高,甚至达到95仍然不触发fgc的情况,存在内存打满风险,持续触发fgc回收;或者内存占用率较低时触发了fgc,导致某些接口tp99,tp…...

MYSQL 商城系统设计 商品数据表的设计 商品 商品类别 商品选项卡 多表查询

介绍 在开发商品模块时,通常使用分表的方式进行查询以及关联。在通过表连接的方式进行查询。每个商品都有不同的分类,每个不同分类下面都有商品规格可以选择,每个商品分类对应商品规格都有自己的价格和库存。在实际的开发中应该给这些表进行…...

HTB:Administrator[WriteUP]

目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 将靶机TCP开放端口号提取并保存 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用nmap对靶机…...

开源项目Umami网站统计MySQL8.0版本Docker+Linux安装部署教程

Umami是什么? Umami是一个开源项目,简单、快速、专注用户隐私的网站统计项目。 下面来介绍如何本地安装部署Umami项目,进行你的网站统计接入。特别对于首次使用docker的萌新有非常好的指导、参考和帮助作用。 Umami的github和docker镜像地…...

FBX SDK的使用:基础知识

Windows环境配置 FBX SDK安装后,目录下有三个文件夹: include 头文件lib 编译的二进制库,根据你项目的配置去包含相应的库samples 官方使用案列 动态链接 libfbxsdk.dll, libfbxsdk.lib是动态库,需要在配置属性->C/C->预…...

VisionMamba安装

1.安装python环境 conda create -n mamba python3.10.13 -y conda activate mamba2.安装torch环境 conda install cudatoolkit11.8 -c nvidia pip install torch2.1.1 torchvision0.16.1 torchaudio2.1.1 --index-url https://download.pytorch.org/whl/cu1183.安装其他包 c…...

h2oGPT

文章目录 一、关于 h2oGPT二、现场演示特点 三、开始行动安装h2oGPT拼贴画演示资源文档指南开发致谢为什么选择 H2O.ai?免责声明 一、关于 h2oGPT 使用文档、图像、视频等与本地GPT进行私人聊天。100%私人,Apache 2.0。支持oLLaMa、Mixtral、llama. cpp…...

Win10安装MySQL、Pycharm连接MySQL,Pycharm中运行Django

一、Windows系统mysql相关操作 1、 检查系统是否安装mysql 按住win r (调出运行窗口) 输入service.msc,点击【确定】 image.png 打开服务列表-检查是否有mysql服务 (compmgmt.msc) image.png 2、 Windows安装MySQL …...

使用Pygame制作“俄罗斯方块”游戏

1. 前言 俄罗斯方块(Tetris) 是一款由方块下落、行消除等核心规则构成的经典益智游戏: 每次从屏幕顶部出现一个随机的方块(由若干小方格组成),玩家可以左右移动或旋转该方块,让它合适地堆叠在…...

【Block总结】ODConv动态卷积,适用于CV任务|即插即用

一、论文信息 论文标题:Omni-Dimensional Dynamic Convolution作者:Chao Li, Aojun Zhou, Anbang Yao发表会议:ICLR 2022论文链接:https://arxiv.org/pdf/2209.07947GitHub链接:https://github.com/OSVAI/ODConv 二…...

RK3568 opencv播放视频

文章目录 一、opencv相关视频播放类1. `cv::VideoCapture` 类主要构造方法:主要方法:2. 视频播放基本流程代码示例:3. 获取和设置视频属性4. 结合 FFmpeg 使用5. OpenCV 视频播放的局限性6. 结合 Qt 实现更高级的视频播放总结二、QT中的代码实现一、opencv相关视频播放类 在…...

《LLM大语言模型+RAG实战+Langchain+ChatGLM-4+Transformer》

文章目录 Langchain的定义Langchain的组成三个核心组件实现整个核心组成部分 为什么要使用LangchainLangchain的底层原理Langchain实战操作LangSmithLangChain调用LLM安装openAI库-国内镜像源代码运行结果小结 使用Langchain的提示模板部署Langchain程序安装langserve代码请求格…...

【搜索回溯算法篇】:拓宽算法视野--BFS如何解决拓扑排序问题

✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:搜索回溯算法篇–CSDN博客 文章目录 一.广度优先搜索(BFS)解决拓扑排…...

计算机网络 (61)移动IP

前言 移动IP(Mobile IP)是由Internet工程任务小组(Internet Engineering Task Force,IETF)提出的一个协议,旨在解决移动设备在不同网络间切换时的通信问题,确保移动设备可以在离开原有网络或子网…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法

树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...

蓝桥杯 2024 15届国赛 A组 儿童节快乐

P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡,轻快的音乐在耳边持续回荡,小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下,六一来了。 今天是六一儿童节,小蓝老师为了让大家在节…...

vue3 定时器-定义全局方法 vue+ts

1.创建ts文件 路径&#xff1a;src/utils/timer.ts 完整代码&#xff1a; import { onUnmounted } from vuetype TimerCallback (...args: any[]) > voidexport function useGlobalTimer() {const timers: Map<number, NodeJS.Timeout> new Map()// 创建定时器con…...

Axios请求超时重发机制

Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式&#xff1a; 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战

“&#x1f916;手搓TuyaAI语音指令 &#x1f60d;秒变表情包大师&#xff0c;让萌系Otto机器人&#x1f525;玩出智能新花样&#xff01;开整&#xff01;” &#x1f916; Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制&#xff08;TuyaAI…...

C# SqlSugar:依赖注入与仓储模式实践

C# SqlSugar&#xff1a;依赖注入与仓储模式实践 在 C# 的应用开发中&#xff0c;数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护&#xff0c;许多开发者会选择成熟的 ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;SqlSugar 就是其中备受…...

10-Oracle 23 ai Vector Search 概述和参数

一、Oracle AI Vector Search 概述 企业和个人都在尝试各种AI&#xff0c;使用客户端或是内部自己搭建集成大模型的终端&#xff0c;加速与大型语言模型&#xff08;LLM&#xff09;的结合&#xff0c;同时使用检索增强生成&#xff08;Retrieval Augmented Generation &#…...

回溯算法学习

一、电话号码的字母组合 import java.util.ArrayList; import java.util.List;import javax.management.loading.PrivateClassLoader;public class letterCombinations {private static final String[] KEYPAD {"", //0"", //1"abc", //2"…...

「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案

在移动互联网营销竞争白热化的当下&#xff0c;推客小程序系统凭借其裂变传播、精准营销等特性&#xff0c;成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径&#xff0c;助力开发者打造具有市场竞争力的营销工具。​ 一、系统核心功能架构&…...