pytorch实现循环神经网络
人工智能例子汇总:AI常见的算法和例子-CSDN博客
PyTorch 提供三种主要的 RNN 变体:
nn.RNN:最基本的循环神经网络,适用于短时依赖任务。nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决梯度消失问题。nn.GRU:门控循环单元,比 LSTM 计算更高效,适用于大部分任务。
| 网络类型 | 优势 | 适用场景 |
|---|---|---|
| RNN | 计算简单,适用于短时序列 | 语音、文本处理(短序列) |
| LSTM | 适用于长序列,能记忆长期信息 | 机器翻译、语音识别、股票预测 |
| GRU | 比 LSTM 计算更高效,效果相似 | 语音处理、文本生成 |
例子:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 生成正弦波数据(仅使用 PyTorch)
def generate_sine_wave(seq_length=10, num_samples=1000):x = torch.linspace(0, 100, num_samples) # 生成 1000 个等间距数据点y = torch.sin(x) # 计算正弦值X_data, Y_data = [], []for i in range(len(y) - seq_length):X_data.append(y[i:i + seq_length].unsqueeze(-1)) # 过去 seq_length 作为输入Y_data.append(y[i + seq_length]) # 预测下一个点return torch.stack(X_data), torch.tensor(Y_data).unsqueeze(-1)# 生成数据
seq_length = 10 # 序列长度
X, Y = generate_sine_wave(seq_length)# 划分训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
Y_train, Y_test = Y[:train_size], Y[train_size:]# 2. 定义 RNN 模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) # 初始化隐藏状态out, _ = self.rnn(x, h0)out = self.fc(out[:, -1, :]) # 取最后一个时间步的输出return out# 3. 训练模型
# 超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 1
num_epochs = 100
learning_rate = 0.001# 初始化模型
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, Y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估与绘图
model.eval()
with torch.no_grad():predictions = model(X_test)# 画图
plt.figure(figsize=(10, 5))
plt.plot(Y_test.numpy(), label="Real Data")
plt.plot(predictions.numpy(), label="Predicted Data")
plt.legend()
plt.title("RNN Sine Wave Prediction")
plt.show()
代码解析
数据生成
torch.linspace(0, 100, num_samples)生成 1000 个均匀分布的数据点。torch.sin(x)计算正弦值,形成时间序列数据。X为过去 10 个时间步的数据,Y为下一个时间步的预测目标。
构建 RNN
nn.RNN(input_size, hidden_size, num_layers, batch_first=True)定义循环神经网络:input_size=1:每个时间步只有一个输入值(正弦波)。hidden_size=32:隐藏层神经元数目。num_layers=1:单层 RNN。
self.fc = nn.Linear(hidden_size, output_size)负责最终输出。
训练
- 使用 MSELoss(均方误差损失) 计算预测值与真实值的误差。
- 使用 Adam 优化器 更新模型参数。
- 每 10 个
epoch输出一次损失loss。
测试 & 绘图
- 关闭梯度计算 (
torch.no_grad()),执行前向传播预测测试数据。 - Matplotlib 绘制预测曲线与真实曲线。
运行效果
如果训练成功,预测曲线(橙色)应该与真实曲线(蓝色)非常接近:
相关文章:
pytorch实现循环神经网络
人工智能例子汇总:AI常见的算法和例子-CSDN博客 PyTorch 提供三种主要的 RNN 变体: nn.RNN:最基本的循环神经网络,适用于短时依赖任务。nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决…...
Java 16进制 10进制 2进制数 相互的转换
在 Java 中,进行进制之间的转换时,除了功能的正确性外,效率和安全性也很重要。为了确保高效和相对安全的转换,我们通常需要考虑: 性能:使用内置的转换方法,如 Integer.toHexString()、Integer.…...
力扣动态规划-14【算法学习day.108】
前言 ###我做这类文章一个重要的目的还是给正在学习的大家提供方向(例如想要掌握基础用法,该刷哪些题?建议灵神的题单和代码随想录)和记录自己的学习过程,我的解析也不会做的非常详细,只会提供思路和一些关…...
数据结构day02
1 线性表的定义和基本操作 1.1 线性表的定义 分析: 1.1.1 问题一:我们为什么探讨线性表的定义和基本操作 在研究数据结构时,需要重点关注三个方面:逻辑结构、物理结构以及数据的运算。在本节内容里,我们首先来介绍线…...
随笔 | 写在一月的最后一天
. 前言 这个月比预想中过的要快更多。突然回看这一个月,还有点不知从何提笔。 整个一月可以总结为以下几个关键词: 期许,保持期许出现休息 . 期许 关于期许,没有什么时候比一年伊始更适合设立目标和计划的了。但令人惭愧的…...
JVM方法区
一、栈、堆、方法区的交互关系 二、方法区的理解: 尽管所有的方法区在逻辑上属于堆的一部分,但是一些简单的实现可能不会去进行垃圾收集或者进行压缩,方法区可以看作是一块独立于Java堆的内存空间。 方法区(Method Area)与Java堆一样,是各个…...
一文读懂fgc之cms
一文读懂 fgc之cms-实战篇 1. 前言 线上应用运行过程中可能会出现内存使用率较高,甚至达到95仍然不触发fgc的情况,存在内存打满风险,持续触发fgc回收;或者内存占用率较低时触发了fgc,导致某些接口tp99,tp…...
MYSQL 商城系统设计 商品数据表的设计 商品 商品类别 商品选项卡 多表查询
介绍 在开发商品模块时,通常使用分表的方式进行查询以及关联。在通过表连接的方式进行查询。每个商品都有不同的分类,每个不同分类下面都有商品规格可以选择,每个商品分类对应商品规格都有自己的价格和库存。在实际的开发中应该给这些表进行…...
HTB:Administrator[WriteUP]
目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 将靶机TCP开放端口号提取并保存 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用nmap对靶机…...
开源项目Umami网站统计MySQL8.0版本Docker+Linux安装部署教程
Umami是什么? Umami是一个开源项目,简单、快速、专注用户隐私的网站统计项目。 下面来介绍如何本地安装部署Umami项目,进行你的网站统计接入。特别对于首次使用docker的萌新有非常好的指导、参考和帮助作用。 Umami的github和docker镜像地…...
FBX SDK的使用:基础知识
Windows环境配置 FBX SDK安装后,目录下有三个文件夹: include 头文件lib 编译的二进制库,根据你项目的配置去包含相应的库samples 官方使用案列 动态链接 libfbxsdk.dll, libfbxsdk.lib是动态库,需要在配置属性->C/C->预…...
VisionMamba安装
1.安装python环境 conda create -n mamba python3.10.13 -y conda activate mamba2.安装torch环境 conda install cudatoolkit11.8 -c nvidia pip install torch2.1.1 torchvision0.16.1 torchaudio2.1.1 --index-url https://download.pytorch.org/whl/cu1183.安装其他包 c…...
h2oGPT
文章目录 一、关于 h2oGPT二、现场演示特点 三、开始行动安装h2oGPT拼贴画演示资源文档指南开发致谢为什么选择 H2O.ai?免责声明 一、关于 h2oGPT 使用文档、图像、视频等与本地GPT进行私人聊天。100%私人,Apache 2.0。支持oLLaMa、Mixtral、llama. cpp…...
Win10安装MySQL、Pycharm连接MySQL,Pycharm中运行Django
一、Windows系统mysql相关操作 1、 检查系统是否安装mysql 按住win r (调出运行窗口) 输入service.msc,点击【确定】 image.png 打开服务列表-检查是否有mysql服务 (compmgmt.msc) image.png 2、 Windows安装MySQL …...
使用Pygame制作“俄罗斯方块”游戏
1. 前言 俄罗斯方块(Tetris) 是一款由方块下落、行消除等核心规则构成的经典益智游戏: 每次从屏幕顶部出现一个随机的方块(由若干小方格组成),玩家可以左右移动或旋转该方块,让它合适地堆叠在…...
【Block总结】ODConv动态卷积,适用于CV任务|即插即用
一、论文信息 论文标题:Omni-Dimensional Dynamic Convolution作者:Chao Li, Aojun Zhou, Anbang Yao发表会议:ICLR 2022论文链接:https://arxiv.org/pdf/2209.07947GitHub链接:https://github.com/OSVAI/ODConv 二…...
RK3568 opencv播放视频
文章目录 一、opencv相关视频播放类1. `cv::VideoCapture` 类主要构造方法:主要方法:2. 视频播放基本流程代码示例:3. 获取和设置视频属性4. 结合 FFmpeg 使用5. OpenCV 视频播放的局限性6. 结合 Qt 实现更高级的视频播放总结二、QT中的代码实现一、opencv相关视频播放类 在…...
《LLM大语言模型+RAG实战+Langchain+ChatGLM-4+Transformer》
文章目录 Langchain的定义Langchain的组成三个核心组件实现整个核心组成部分 为什么要使用LangchainLangchain的底层原理Langchain实战操作LangSmithLangChain调用LLM安装openAI库-国内镜像源代码运行结果小结 使用Langchain的提示模板部署Langchain程序安装langserve代码请求格…...
【搜索回溯算法篇】:拓宽算法视野--BFS如何解决拓扑排序问题
✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:搜索回溯算法篇–CSDN博客 文章目录 一.广度优先搜索(BFS)解决拓扑排…...
计算机网络 (61)移动IP
前言 移动IP(Mobile IP)是由Internet工程任务小组(Internet Engineering Task Force,IETF)提出的一个协议,旨在解决移动设备在不同网络间切换时的通信问题,确保移动设备可以在离开原有网络或子网…...
docker详细操作--未完待续
docker介绍 docker官网: Docker:加速容器应用程序开发 harbor官网:Harbor - Harbor 中文 使用docker加速器: Docker镜像极速下载服务 - 毫秒镜像 是什么 Docker 是一种开源的容器化平台,用于将应用程序及其依赖项(如库、运行时环…...
rknn优化教程(二)
文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...
visual studio 2022更改主题为深色
visual studio 2022更改主题为深色 点击visual studio 上方的 工具-> 选项 在选项窗口中,选择 环境 -> 常规 ,将其中的颜色主题改成深色 点击确定,更改完成...
Java多线程实现之Callable接口深度解析
Java多线程实现之Callable接口深度解析 一、Callable接口概述1.1 接口定义1.2 与Runnable接口的对比1.3 Future接口与FutureTask类 二、Callable接口的基本使用方法2.1 传统方式实现Callable接口2.2 使用Lambda表达式简化Callable实现2.3 使用FutureTask类执行Callable任务 三、…...
基于Docker Compose部署Java微服务项目
一. 创建根项目 根项目(父项目)主要用于依赖管理 一些需要注意的点: 打包方式需要为 pom<modules>里需要注册子模块不要引入maven的打包插件,否则打包时会出问题 <?xml version"1.0" encoding"UTF-8…...
拉力测试cuda pytorch 把 4070显卡拉满
import torch import timedef stress_test_gpu(matrix_size16384, duration300):"""对GPU进行压力测试,通过持续的矩阵乘法来最大化GPU利用率参数:matrix_size: 矩阵维度大小,增大可提高计算复杂度duration: 测试持续时间(秒&…...
大学生职业发展与就业创业指导教学评价
这里是引用 作为软工2203/2204班的学生,我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要,而您认真负责的教学态度,让课程的每一部分都充满了实用价值。 尤其让我…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...
CSS设置元素的宽度根据其内容自动调整
width: fit-content 是 CSS 中的一个属性值,用于设置元素的宽度根据其内容自动调整,确保宽度刚好容纳内容而不会超出。 效果对比 默认情况(width: auto): 块级元素(如 <div>)会占满父容器…...
短视频矩阵系统文案创作功能开发实践,定制化开发
在短视频行业迅猛发展的当下,企业和个人创作者为了扩大影响力、提升传播效果,纷纷采用短视频矩阵运营策略,同时管理多个平台、多个账号的内容发布。然而,频繁的文案创作需求让运营者疲于应对,如何高效产出高质量文案成…...
