pytorch实现循环神经网络
人工智能例子汇总:AI常见的算法和例子-CSDN博客
PyTorch 提供三种主要的 RNN 变体:
nn.RNN:最基本的循环神经网络,适用于短时依赖任务。nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决梯度消失问题。nn.GRU:门控循环单元,比 LSTM 计算更高效,适用于大部分任务。
| 网络类型 | 优势 | 适用场景 |
|---|---|---|
| RNN | 计算简单,适用于短时序列 | 语音、文本处理(短序列) |
| LSTM | 适用于长序列,能记忆长期信息 | 机器翻译、语音识别、股票预测 |
| GRU | 比 LSTM 计算更高效,效果相似 | 语音处理、文本生成 |
例子:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 生成正弦波数据(仅使用 PyTorch)
def generate_sine_wave(seq_length=10, num_samples=1000):x = torch.linspace(0, 100, num_samples) # 生成 1000 个等间距数据点y = torch.sin(x) # 计算正弦值X_data, Y_data = [], []for i in range(len(y) - seq_length):X_data.append(y[i:i + seq_length].unsqueeze(-1)) # 过去 seq_length 作为输入Y_data.append(y[i + seq_length]) # 预测下一个点return torch.stack(X_data), torch.tensor(Y_data).unsqueeze(-1)# 生成数据
seq_length = 10 # 序列长度
X, Y = generate_sine_wave(seq_length)# 划分训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
Y_train, Y_test = Y[:train_size], Y[train_size:]# 2. 定义 RNN 模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) # 初始化隐藏状态out, _ = self.rnn(x, h0)out = self.fc(out[:, -1, :]) # 取最后一个时间步的输出return out# 3. 训练模型
# 超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 1
num_epochs = 100
learning_rate = 0.001# 初始化模型
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, Y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估与绘图
model.eval()
with torch.no_grad():predictions = model(X_test)# 画图
plt.figure(figsize=(10, 5))
plt.plot(Y_test.numpy(), label="Real Data")
plt.plot(predictions.numpy(), label="Predicted Data")
plt.legend()
plt.title("RNN Sine Wave Prediction")
plt.show()
代码解析
数据生成
torch.linspace(0, 100, num_samples)生成 1000 个均匀分布的数据点。torch.sin(x)计算正弦值,形成时间序列数据。X为过去 10 个时间步的数据,Y为下一个时间步的预测目标。
构建 RNN
nn.RNN(input_size, hidden_size, num_layers, batch_first=True)定义循环神经网络:input_size=1:每个时间步只有一个输入值(正弦波)。hidden_size=32:隐藏层神经元数目。num_layers=1:单层 RNN。
self.fc = nn.Linear(hidden_size, output_size)负责最终输出。
训练
- 使用 MSELoss(均方误差损失) 计算预测值与真实值的误差。
- 使用 Adam 优化器 更新模型参数。
- 每 10 个
epoch输出一次损失loss。
测试 & 绘图
- 关闭梯度计算 (
torch.no_grad()),执行前向传播预测测试数据。 - Matplotlib 绘制预测曲线与真实曲线。
运行效果
如果训练成功,预测曲线(橙色)应该与真实曲线(蓝色)非常接近:
相关文章:
pytorch实现循环神经网络
人工智能例子汇总:AI常见的算法和例子-CSDN博客 PyTorch 提供三种主要的 RNN 变体: nn.RNN:最基本的循环神经网络,适用于短时依赖任务。nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决…...
Java 16进制 10进制 2进制数 相互的转换
在 Java 中,进行进制之间的转换时,除了功能的正确性外,效率和安全性也很重要。为了确保高效和相对安全的转换,我们通常需要考虑: 性能:使用内置的转换方法,如 Integer.toHexString()、Integer.…...
力扣动态规划-14【算法学习day.108】
前言 ###我做这类文章一个重要的目的还是给正在学习的大家提供方向(例如想要掌握基础用法,该刷哪些题?建议灵神的题单和代码随想录)和记录自己的学习过程,我的解析也不会做的非常详细,只会提供思路和一些关…...
数据结构day02
1 线性表的定义和基本操作 1.1 线性表的定义 分析: 1.1.1 问题一:我们为什么探讨线性表的定义和基本操作 在研究数据结构时,需要重点关注三个方面:逻辑结构、物理结构以及数据的运算。在本节内容里,我们首先来介绍线…...
随笔 | 写在一月的最后一天
. 前言 这个月比预想中过的要快更多。突然回看这一个月,还有点不知从何提笔。 整个一月可以总结为以下几个关键词: 期许,保持期许出现休息 . 期许 关于期许,没有什么时候比一年伊始更适合设立目标和计划的了。但令人惭愧的…...
JVM方法区
一、栈、堆、方法区的交互关系 二、方法区的理解: 尽管所有的方法区在逻辑上属于堆的一部分,但是一些简单的实现可能不会去进行垃圾收集或者进行压缩,方法区可以看作是一块独立于Java堆的内存空间。 方法区(Method Area)与Java堆一样,是各个…...
一文读懂fgc之cms
一文读懂 fgc之cms-实战篇 1. 前言 线上应用运行过程中可能会出现内存使用率较高,甚至达到95仍然不触发fgc的情况,存在内存打满风险,持续触发fgc回收;或者内存占用率较低时触发了fgc,导致某些接口tp99,tp…...
MYSQL 商城系统设计 商品数据表的设计 商品 商品类别 商品选项卡 多表查询
介绍 在开发商品模块时,通常使用分表的方式进行查询以及关联。在通过表连接的方式进行查询。每个商品都有不同的分类,每个不同分类下面都有商品规格可以选择,每个商品分类对应商品规格都有自己的价格和库存。在实际的开发中应该给这些表进行…...
HTB:Administrator[WriteUP]
目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 将靶机TCP开放端口号提取并保存 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用nmap对靶机…...
开源项目Umami网站统计MySQL8.0版本Docker+Linux安装部署教程
Umami是什么? Umami是一个开源项目,简单、快速、专注用户隐私的网站统计项目。 下面来介绍如何本地安装部署Umami项目,进行你的网站统计接入。特别对于首次使用docker的萌新有非常好的指导、参考和帮助作用。 Umami的github和docker镜像地…...
FBX SDK的使用:基础知识
Windows环境配置 FBX SDK安装后,目录下有三个文件夹: include 头文件lib 编译的二进制库,根据你项目的配置去包含相应的库samples 官方使用案列 动态链接 libfbxsdk.dll, libfbxsdk.lib是动态库,需要在配置属性->C/C->预…...
VisionMamba安装
1.安装python环境 conda create -n mamba python3.10.13 -y conda activate mamba2.安装torch环境 conda install cudatoolkit11.8 -c nvidia pip install torch2.1.1 torchvision0.16.1 torchaudio2.1.1 --index-url https://download.pytorch.org/whl/cu1183.安装其他包 c…...
h2oGPT
文章目录 一、关于 h2oGPT二、现场演示特点 三、开始行动安装h2oGPT拼贴画演示资源文档指南开发致谢为什么选择 H2O.ai?免责声明 一、关于 h2oGPT 使用文档、图像、视频等与本地GPT进行私人聊天。100%私人,Apache 2.0。支持oLLaMa、Mixtral、llama. cpp…...
Win10安装MySQL、Pycharm连接MySQL,Pycharm中运行Django
一、Windows系统mysql相关操作 1、 检查系统是否安装mysql 按住win r (调出运行窗口) 输入service.msc,点击【确定】 image.png 打开服务列表-检查是否有mysql服务 (compmgmt.msc) image.png 2、 Windows安装MySQL …...
使用Pygame制作“俄罗斯方块”游戏
1. 前言 俄罗斯方块(Tetris) 是一款由方块下落、行消除等核心规则构成的经典益智游戏: 每次从屏幕顶部出现一个随机的方块(由若干小方格组成),玩家可以左右移动或旋转该方块,让它合适地堆叠在…...
【Block总结】ODConv动态卷积,适用于CV任务|即插即用
一、论文信息 论文标题:Omni-Dimensional Dynamic Convolution作者:Chao Li, Aojun Zhou, Anbang Yao发表会议:ICLR 2022论文链接:https://arxiv.org/pdf/2209.07947GitHub链接:https://github.com/OSVAI/ODConv 二…...
RK3568 opencv播放视频
文章目录 一、opencv相关视频播放类1. `cv::VideoCapture` 类主要构造方法:主要方法:2. 视频播放基本流程代码示例:3. 获取和设置视频属性4. 结合 FFmpeg 使用5. OpenCV 视频播放的局限性6. 结合 Qt 实现更高级的视频播放总结二、QT中的代码实现一、opencv相关视频播放类 在…...
《LLM大语言模型+RAG实战+Langchain+ChatGLM-4+Transformer》
文章目录 Langchain的定义Langchain的组成三个核心组件实现整个核心组成部分 为什么要使用LangchainLangchain的底层原理Langchain实战操作LangSmithLangChain调用LLM安装openAI库-国内镜像源代码运行结果小结 使用Langchain的提示模板部署Langchain程序安装langserve代码请求格…...
【搜索回溯算法篇】:拓宽算法视野--BFS如何解决拓扑排序问题
✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:搜索回溯算法篇–CSDN博客 文章目录 一.广度优先搜索(BFS)解决拓扑排…...
计算机网络 (61)移动IP
前言 移动IP(Mobile IP)是由Internet工程任务小组(Internet Engineering Task Force,IETF)提出的一个协议,旨在解决移动设备在不同网络间切换时的通信问题,确保移动设备可以在离开原有网络或子网…...
Harmonyos应用实例215: 条件概率模拟器
7. 条件概率模拟器 功能简介:通过模拟抽卡片、掷骰子等实验,展示条件概率的计算方法,验证贝叶斯定理。支持调整实验参数,实时显示概率结果和理论值对比,帮助学生理解条件概率的概念。 ArkTS代码: @Entry @Component struct ConditionalProbability {@State private...
webMAN-MOD实战指南:构建PS3主机扩展服务系统
webMAN-MOD实战指南:构建PS3主机扩展服务系统 【免费下载链接】webMAN-MOD Extended services for PS3 console (web server, ftp server, netiso, ntfs, ps3mapi, etc.) 项目地址: https://gitcode.com/gh_mirrors/we/webMAN-MOD 当你在PS3主机上尝试加载网…...
GD32F30x串口DMA+空闲中断接收不定长数据,一个LED控制项目带你搞懂
GD32F30x串口DMA空闲中断实战:从零构建LED智能控制系统 在嵌入式开发中,串口通信就像设备的"嘴巴"和"耳朵",而DMA技术则是解放CPU的"隐形助手"。想象一下这样的场景:你需要通过手机APP远程控制实验…...
Suricata在CentOS7上的性能优化:如何配置网卡混杂模式与端口聚合
Suricata在CentOS7上的性能优化:网卡混杂模式与端口聚合实战指南 当企业网络流量突破千兆级别时,传统单网卡监控方案往往力不从心。我曾为某金融客户部署Suricata时,单台服务器每天要处理超过2TB的流量数据,正是通过下文介绍的网卡…...
Pixel Dream Workshop 作品集:基于LSTM时序模型生成的动态艺术画展示
Pixel Dream Workshop 作品集:基于LSTM时序模型生成的动态艺术画展示 1. 当AI遇见艺术:LSTM如何创造动态视觉叙事 在数字艺术创作领域,时序模型正带来一场革命性的变化。Pixel Dream Workshop最新推出的动态艺术画系列,展示了长…...
解密革命性构建工具:PoeCharm如何突破传统限制实现高效角色规划
解密革命性构建工具:PoeCharm如何突破传统限制实现高效角色规划 【免费下载链接】PoeCharm Path of Building Chinese version 项目地址: https://gitcode.com/gh_mirrors/po/PoeCharm 在流放之路的复杂游戏生态中,角色构建往往成为玩家面临的最大…...
DeepSeek-R1-Distill-Qwen-1.5B响应慢?函数调用优化实战解决方案
DeepSeek-R1-Distill-Qwen-1.5B响应慢?函数调用优化实战解决方案 你是不是也遇到过这种情况:好不容易在本地部署了DeepSeek-R1-Distill-Qwen-1.5B这个“小钢炮”模型,结果发现函数调用时响应特别慢?明明官方说RTX 3060能跑200 to…...
IT运维监控/可观测性
?? 前言:为什么选择 OpenClaw 对接企业微信? 在2026年的企业数字化办公浪潮中,OpenClaw(曾用名 Clawdbot、Moltbot)已成长为国内领先的开源AI自动化代理工具。凭借其“自然语言驱动、插件化拓展、多平台无缝集成”的…...
Obsidian-i18n:破解插件语言壁垒的无缝本地化方案——让中文用户零门槛掌控千款插件
Obsidian-i18n:破解插件语言壁垒的无缝本地化方案——让中文用户零门槛掌控千款插件 【免费下载链接】obsidian-i18n 项目地址: https://gitcode.com/gh_mirrors/ob/obsidian-i18n 问题诊断:插件语言障碍如何制约Obsidian用户体验? …...
快速找回Chrome密码:ChromePass终极使用指南
快速找回Chrome密码:ChromePass终极使用指南 【免费下载链接】chromepass Get all passwords stored by Chrome on WINDOWS. 项目地址: https://gitcode.com/gh_mirrors/chr/chromepass 你是否曾经因为忘记Chrome浏览器中保存的重要登录密码而感到困扰&#…...
