pytorch实现循环神经网络
人工智能例子汇总:AI常见的算法和例子-CSDN博客
PyTorch 提供三种主要的 RNN 变体:
nn.RNN
:最基本的循环神经网络,适用于短时依赖任务。nn.LSTM
:长短时记忆网络,适用于长序列数据,能有效解决梯度消失问题。nn.GRU
:门控循环单元,比 LSTM 计算更高效,适用于大部分任务。
网络类型 | 优势 | 适用场景 |
---|---|---|
RNN | 计算简单,适用于短时序列 | 语音、文本处理(短序列) |
LSTM | 适用于长序列,能记忆长期信息 | 机器翻译、语音识别、股票预测 |
GRU | 比 LSTM 计算更高效,效果相似 | 语音处理、文本生成 |
例子:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 生成正弦波数据(仅使用 PyTorch)
def generate_sine_wave(seq_length=10, num_samples=1000):x = torch.linspace(0, 100, num_samples) # 生成 1000 个等间距数据点y = torch.sin(x) # 计算正弦值X_data, Y_data = [], []for i in range(len(y) - seq_length):X_data.append(y[i:i + seq_length].unsqueeze(-1)) # 过去 seq_length 作为输入Y_data.append(y[i + seq_length]) # 预测下一个点return torch.stack(X_data), torch.tensor(Y_data).unsqueeze(-1)# 生成数据
seq_length = 10 # 序列长度
X, Y = generate_sine_wave(seq_length)# 划分训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
Y_train, Y_test = Y[:train_size], Y[train_size:]# 2. 定义 RNN 模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) # 初始化隐藏状态out, _ = self.rnn(x, h0)out = self.fc(out[:, -1, :]) # 取最后一个时间步的输出return out# 3. 训练模型
# 超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 1
num_epochs = 100
learning_rate = 0.001# 初始化模型
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, Y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估与绘图
model.eval()
with torch.no_grad():predictions = model(X_test)# 画图
plt.figure(figsize=(10, 5))
plt.plot(Y_test.numpy(), label="Real Data")
plt.plot(predictions.numpy(), label="Predicted Data")
plt.legend()
plt.title("RNN Sine Wave Prediction")
plt.show()
代码解析
数据生成
torch.linspace(0, 100, num_samples)
生成 1000 个均匀分布的数据点。torch.sin(x)
计算正弦值,形成时间序列数据。X
为过去 10 个时间步的数据,Y
为下一个时间步的预测目标。
构建 RNN
nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
定义循环神经网络:input_size=1
:每个时间步只有一个输入值(正弦波)。hidden_size=32
:隐藏层神经元数目。num_layers=1
:单层 RNN。
self.fc = nn.Linear(hidden_size, output_size)
负责最终输出。
训练
- 使用 MSELoss(均方误差损失) 计算预测值与真实值的误差。
- 使用 Adam 优化器 更新模型参数。
- 每 10 个
epoch
输出一次损失loss
。
测试 & 绘图
- 关闭梯度计算 (
torch.no_grad()
),执行前向传播预测测试数据。 - Matplotlib 绘制预测曲线与真实曲线。
运行效果
如果训练成功,预测曲线(橙色)应该与真实曲线(蓝色)非常接近:
相关文章:
pytorch实现循环神经网络
人工智能例子汇总:AI常见的算法和例子-CSDN博客 PyTorch 提供三种主要的 RNN 变体: nn.RNN:最基本的循环神经网络,适用于短时依赖任务。nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决…...
Java 16进制 10进制 2进制数 相互的转换
在 Java 中,进行进制之间的转换时,除了功能的正确性外,效率和安全性也很重要。为了确保高效和相对安全的转换,我们通常需要考虑: 性能:使用内置的转换方法,如 Integer.toHexString()、Integer.…...

力扣动态规划-14【算法学习day.108】
前言 ###我做这类文章一个重要的目的还是给正在学习的大家提供方向(例如想要掌握基础用法,该刷哪些题?建议灵神的题单和代码随想录)和记录自己的学习过程,我的解析也不会做的非常详细,只会提供思路和一些关…...

数据结构day02
1 线性表的定义和基本操作 1.1 线性表的定义 分析: 1.1.1 问题一:我们为什么探讨线性表的定义和基本操作 在研究数据结构时,需要重点关注三个方面:逻辑结构、物理结构以及数据的运算。在本节内容里,我们首先来介绍线…...
随笔 | 写在一月的最后一天
. 前言 这个月比预想中过的要快更多。突然回看这一个月,还有点不知从何提笔。 整个一月可以总结为以下几个关键词: 期许,保持期许出现休息 . 期许 关于期许,没有什么时候比一年伊始更适合设立目标和计划的了。但令人惭愧的…...

JVM方法区
一、栈、堆、方法区的交互关系 二、方法区的理解: 尽管所有的方法区在逻辑上属于堆的一部分,但是一些简单的实现可能不会去进行垃圾收集或者进行压缩,方法区可以看作是一块独立于Java堆的内存空间。 方法区(Method Area)与Java堆一样,是各个…...
一文读懂fgc之cms
一文读懂 fgc之cms-实战篇 1. 前言 线上应用运行过程中可能会出现内存使用率较高,甚至达到95仍然不触发fgc的情况,存在内存打满风险,持续触发fgc回收;或者内存占用率较低时触发了fgc,导致某些接口tp99,tp…...

MYSQL 商城系统设计 商品数据表的设计 商品 商品类别 商品选项卡 多表查询
介绍 在开发商品模块时,通常使用分表的方式进行查询以及关联。在通过表连接的方式进行查询。每个商品都有不同的分类,每个不同分类下面都有商品规格可以选择,每个商品分类对应商品规格都有自己的价格和库存。在实际的开发中应该给这些表进行…...

HTB:Administrator[WriteUP]
目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 将靶机TCP开放端口号提取并保存 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用nmap对靶机…...

开源项目Umami网站统计MySQL8.0版本Docker+Linux安装部署教程
Umami是什么? Umami是一个开源项目,简单、快速、专注用户隐私的网站统计项目。 下面来介绍如何本地安装部署Umami项目,进行你的网站统计接入。特别对于首次使用docker的萌新有非常好的指导、参考和帮助作用。 Umami的github和docker镜像地…...

FBX SDK的使用:基础知识
Windows环境配置 FBX SDK安装后,目录下有三个文件夹: include 头文件lib 编译的二进制库,根据你项目的配置去包含相应的库samples 官方使用案列 动态链接 libfbxsdk.dll, libfbxsdk.lib是动态库,需要在配置属性->C/C->预…...

VisionMamba安装
1.安装python环境 conda create -n mamba python3.10.13 -y conda activate mamba2.安装torch环境 conda install cudatoolkit11.8 -c nvidia pip install torch2.1.1 torchvision0.16.1 torchaudio2.1.1 --index-url https://download.pytorch.org/whl/cu1183.安装其他包 c…...
h2oGPT
文章目录 一、关于 h2oGPT二、现场演示特点 三、开始行动安装h2oGPT拼贴画演示资源文档指南开发致谢为什么选择 H2O.ai?免责声明 一、关于 h2oGPT 使用文档、图像、视频等与本地GPT进行私人聊天。100%私人,Apache 2.0。支持oLLaMa、Mixtral、llama. cpp…...

Win10安装MySQL、Pycharm连接MySQL,Pycharm中运行Django
一、Windows系统mysql相关操作 1、 检查系统是否安装mysql 按住win r (调出运行窗口) 输入service.msc,点击【确定】 image.png 打开服务列表-检查是否有mysql服务 (compmgmt.msc) image.png 2、 Windows安装MySQL …...

使用Pygame制作“俄罗斯方块”游戏
1. 前言 俄罗斯方块(Tetris) 是一款由方块下落、行消除等核心规则构成的经典益智游戏: 每次从屏幕顶部出现一个随机的方块(由若干小方格组成),玩家可以左右移动或旋转该方块,让它合适地堆叠在…...

【Block总结】ODConv动态卷积,适用于CV任务|即插即用
一、论文信息 论文标题:Omni-Dimensional Dynamic Convolution作者:Chao Li, Aojun Zhou, Anbang Yao发表会议:ICLR 2022论文链接:https://arxiv.org/pdf/2209.07947GitHub链接:https://github.com/OSVAI/ODConv 二…...
RK3568 opencv播放视频
文章目录 一、opencv相关视频播放类1. `cv::VideoCapture` 类主要构造方法:主要方法:2. 视频播放基本流程代码示例:3. 获取和设置视频属性4. 结合 FFmpeg 使用5. OpenCV 视频播放的局限性6. 结合 Qt 实现更高级的视频播放总结二、QT中的代码实现一、opencv相关视频播放类 在…...

《LLM大语言模型+RAG实战+Langchain+ChatGLM-4+Transformer》
文章目录 Langchain的定义Langchain的组成三个核心组件实现整个核心组成部分 为什么要使用LangchainLangchain的底层原理Langchain实战操作LangSmithLangChain调用LLM安装openAI库-国内镜像源代码运行结果小结 使用Langchain的提示模板部署Langchain程序安装langserve代码请求格…...

【搜索回溯算法篇】:拓宽算法视野--BFS如何解决拓扑排序问题
✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:搜索回溯算法篇–CSDN博客 文章目录 一.广度优先搜索(BFS)解决拓扑排…...

计算机网络 (61)移动IP
前言 移动IP(Mobile IP)是由Internet工程任务小组(Internet Engineering Task Force,IETF)提出的一个协议,旨在解决移动设备在不同网络间切换时的通信问题,确保移动设备可以在离开原有网络或子网…...

超短脉冲激光自聚焦效应
前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...

树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...
蓝桥杯 2024 15届国赛 A组 儿童节快乐
P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡,轻快的音乐在耳边持续回荡,小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下,六一来了。 今天是六一儿童节,小蓝老师为了让大家在节…...
vue3 定时器-定义全局方法 vue+ts
1.创建ts文件 路径:src/utils/timer.ts 完整代码: import { onUnmounted } from vuetype TimerCallback (...args: any[]) > voidexport function useGlobalTimer() {const timers: Map<number, NodeJS.Timeout> new Map()// 创建定时器con…...
Axios请求超时重发机制
Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

涂鸦T5AI手搓语音、emoji、otto机器人从入门到实战
“🤖手搓TuyaAI语音指令 😍秒变表情包大师,让萌系Otto机器人🔥玩出智能新花样!开整!” 🤖 Otto机器人 → 直接点明主体 手搓TuyaAI语音 → 强调 自主编程/自定义 语音控制(TuyaAI…...
C# SqlSugar:依赖注入与仓储模式实践
C# SqlSugar:依赖注入与仓储模式实践 在 C# 的应用开发中,数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护,许多开发者会选择成熟的 ORM(对象关系映射)框架,SqlSugar 就是其中备受…...

10-Oracle 23 ai Vector Search 概述和参数
一、Oracle AI Vector Search 概述 企业和个人都在尝试各种AI,使用客户端或是内部自己搭建集成大模型的终端,加速与大型语言模型(LLM)的结合,同时使用检索增强生成(Retrieval Augmented Generation &#…...

回溯算法学习
一、电话号码的字母组合 import java.util.ArrayList; import java.util.List;import javax.management.loading.PrivateClassLoader;public class letterCombinations {private static final String[] KEYPAD {"", //0"", //1"abc", //2"…...
「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案
在移动互联网营销竞争白热化的当下,推客小程序系统凭借其裂变传播、精准营销等特性,成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径,助力开发者打造具有市场竞争力的营销工具。 一、系统核心功能架构&…...