pytorch实现循环神经网络
人工智能例子汇总:AI常见的算法和例子-CSDN博客
PyTorch 提供三种主要的 RNN 变体:
nn.RNN:最基本的循环神经网络,适用于短时依赖任务。nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决梯度消失问题。nn.GRU:门控循环单元,比 LSTM 计算更高效,适用于大部分任务。
| 网络类型 | 优势 | 适用场景 |
|---|---|---|
| RNN | 计算简单,适用于短时序列 | 语音、文本处理(短序列) |
| LSTM | 适用于长序列,能记忆长期信息 | 机器翻译、语音识别、股票预测 |
| GRU | 比 LSTM 计算更高效,效果相似 | 语音处理、文本生成 |
例子:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 生成正弦波数据(仅使用 PyTorch)
def generate_sine_wave(seq_length=10, num_samples=1000):x = torch.linspace(0, 100, num_samples) # 生成 1000 个等间距数据点y = torch.sin(x) # 计算正弦值X_data, Y_data = [], []for i in range(len(y) - seq_length):X_data.append(y[i:i + seq_length].unsqueeze(-1)) # 过去 seq_length 作为输入Y_data.append(y[i + seq_length]) # 预测下一个点return torch.stack(X_data), torch.tensor(Y_data).unsqueeze(-1)# 生成数据
seq_length = 10 # 序列长度
X, Y = generate_sine_wave(seq_length)# 划分训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
Y_train, Y_test = Y[:train_size], Y[train_size:]# 2. 定义 RNN 模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size) # 初始化隐藏状态out, _ = self.rnn(x, h0)out = self.fc(out[:, -1, :]) # 取最后一个时间步的输出return out# 3. 训练模型
# 超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 1
num_epochs = 100
learning_rate = 0.001# 初始化模型
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, Y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估与绘图
model.eval()
with torch.no_grad():predictions = model(X_test)# 画图
plt.figure(figsize=(10, 5))
plt.plot(Y_test.numpy(), label="Real Data")
plt.plot(predictions.numpy(), label="Predicted Data")
plt.legend()
plt.title("RNN Sine Wave Prediction")
plt.show()
代码解析
数据生成
torch.linspace(0, 100, num_samples)生成 1000 个均匀分布的数据点。torch.sin(x)计算正弦值,形成时间序列数据。X为过去 10 个时间步的数据,Y为下一个时间步的预测目标。
构建 RNN
nn.RNN(input_size, hidden_size, num_layers, batch_first=True)定义循环神经网络:input_size=1:每个时间步只有一个输入值(正弦波)。hidden_size=32:隐藏层神经元数目。num_layers=1:单层 RNN。
self.fc = nn.Linear(hidden_size, output_size)负责最终输出。
训练
- 使用 MSELoss(均方误差损失) 计算预测值与真实值的误差。
- 使用 Adam 优化器 更新模型参数。
- 每 10 个
epoch输出一次损失loss。
测试 & 绘图
- 关闭梯度计算 (
torch.no_grad()),执行前向传播预测测试数据。 - Matplotlib 绘制预测曲线与真实曲线。
运行效果
如果训练成功,预测曲线(橙色)应该与真实曲线(蓝色)非常接近:
相关文章:
pytorch实现循环神经网络
人工智能例子汇总:AI常见的算法和例子-CSDN博客 PyTorch 提供三种主要的 RNN 变体: nn.RNN:最基本的循环神经网络,适用于短时依赖任务。nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决…...
Python从零构建macOS状态栏应用(仿ollama)并集成AI同款流式聊天 API 服务(含打包为独立应用)
在本教程中,我们将一步步构建一个 macOS 状态栏应用程序,并集成一个 Flask 服务器,提供流式响应的 API 服务。 如果你手中正好持有一台 MacBook Pro,又怀揣着搭建 AI 聊天服务的想法,却不知从何处迈出第一步,那么这篇文章绝对是你的及时雨。 最终,我们将实现以下功能: …...
leetcode 2080. 区间内查询数字的频率
题目如下 数据范围 示例 这题十分有意思一开始我想对每个子数组排序二分结果超时了。 转换思路:我们可以提前把每个数字出现的位置先记录下来形成集合, 然后拿着left和right利用二分查找看看left和right是不是在集合里然后做一个相减就出答案了。通过…...
深入了解 SSRF 漏洞:原理、条件、危害
目录 前言 SSRF 原理 漏洞产生原因 产生条件 使用协议 使用函数 漏洞影响 防御措施 结语 前言 本文将深入剖析 SSRF(服务端请求伪造)漏洞,从原理、产生原因、条件、影响,到防御措施,为你全面梳理相关知识&am…...
11.QT控件:输入类控件
1. Line Edit(单行输入框) QLineEdit表示单行输入框,用来输入一段文本,但是不能换行。 核心属性: 核心信号: 2. Text Edit(多行输入框) QTextEdit表示多行输入框,也是一个富文本 & markdown编辑器。并且能在内容超…...
Cesium+Vue3教程(011):打造数字城市
文章目录 Cesium打造数字城市创建项目加载地球设置底图设置摄像头查看具体位置和方向添加纽约建筑模型并设置样式添加纽约建筑模型设置样式划分城市区域并着色地图标记显示与实现实现飞机巡城完整项目下载Cesium打造数字城市 创建项目 使用vite创建vue3项目: pnpm create v…...
Windows系统本地部署deepseek 更改目录
本地部署deepseek 无论是mac还是windows系统本地部署deepseek或者其他模型的命令和步骤是一样的。 可以看: 本地部署deepsek 无论是ollama还是部署LLM时候都默认是系统磁盘,对于Windows系统,我们一般不把应用放到系统盘(C:)而是…...
基于Python的药物相互作用预测模型AI构建与优化(下.代码部分)
四、特征工程 4.1 分子描述符计算 分子描述符作为量化分子性质的关键数值,能够从多维度反映药物分子的结构和化学特征,在药物相互作用预测中起着举足轻重的作用。RDKit 库凭借其强大的功能,为我们提供了丰富的分子描述符计算方法,涵盖了多个重要方面的分子性质。 分子量…...
[Python学习日记-79] socket 开发中的粘包现象(解决模拟 SSH 远程执行命令代码中的粘包问题)
[Python学习日记-79] socket 开发中的粘包现象(解决模拟 SSH 远程执行命令代码中的粘包问题) 简介 粘包问题底层原理分析 粘包问题的解决 简介 在Python学习日记-78我们留下了两个问题,一个是服务器端 send() 中使用加号的问题,…...
origin如何在已经画好的图上修改数据且不改变原图像的画风和格式
例如我现在的.opju文件长这样 现在我换了数据集,我想修改这两个图表里对应的算法里的数据,但是我还想保留这图像现在的形式,可以尝试像下面这样做: 右击第一个图,出现下面,选择Book[sheet1] 选择工作簿 出…...
OPENGLPG第九版学习
文章目录 一、OpenGL概述二、着色器基础三、OpenGL绘制方式四、颜色、像素和片元五、视口变换、裁减、剪切与反馈六、纹理与帧缓存七、光照与阴影八、程序式纹理 skip九、细分着色器 skip十、几何着色器 skip十一、内存十二、计算着色器 skip附录 A 第三方支持库附录 B OpenGL …...
5.3.2 软件设计原则
文章目录 抽象模块化信息隐蔽与独立性衡量 软件设计原则:抽象、模块化、信息隐蔽。 抽象 抽象是抽出事物本质的共同特性。过程抽象是指将一个明确定义功能的操作当作单个实体看待。数据抽象是对数据的类型、操作、取值范围进行定义,然后通过这些操作对数…...
【ArcGIS遇上Python】批量提取多波段影像至单个波段
本案例基于ArcGIS python,将landsat影像的7个波段影像数据,批量提取至单个波段。 相关阅读:【ArcGIS微课1000例】0141:提取多波段影像中的单个波段 文章目录 一、数据准备二、效果比对二、python批处理1. 编写python代码2. 运行代码一、数据准备 实验数据及完整的python位…...
Spring Security(maven项目) 3.0.2.9版本 --- 改
前言: 通过实践而发现真理,又通过实践而证实真理和发展真理。从感性认识而能动地发展到理性认识,又从理性认识而能动地指导革命实践,改造主观世界和客观世界。实践、认识、再实践、再认识,这种形式,循环往…...
仿真设计|基于51单片机的温度与烟雾报警系统
目录 具体实现功能 设计介绍 51单片机简介 资料内容 仿真实现(protues8.7) 程序(Keil5) 全部内容 资料获取 具体实现功能 (1)LCD1602实时监测及显示温度值和烟雾浓度值; (2…...
深入剖析 CSRF 漏洞:原理、危害案例与防护
目录 前言 漏洞介绍 漏洞原理 产生条件 产生的危害 靶场练习 post 请求csrf案例 防御措施 验证请求来源 设置 SameSite 属性 双重提交 Cookie 结语 前言 在网络安全领域,各类漏洞层出不穷,时刻威胁着用户的隐私与数据安全。跨站请求伪造&…...
rust跨平台调用动态库
动态库在不同的操作系统,扩展名是不一样的,所以要做处理: static LIB: Lazy<Mutex<Option<Library>>> Lazy::new(|| Mutex::new(None));type CreateFunc unsafe extern "C" fn(*const c_char, *const c_char) -> c_int…...
buuuctf_秘密文件
题目: 应该是分析流量包了,用wireshark打开 我追踪http流未果,分析下ftp流 追踪流看看 用户 “ctf” 使用密码 “ctf” 登录。 PORT命令用于为后续操作设置数据连接。 LIST命令用于列出 FTP 服务器上目录的内容,但在此日志中未…...
课程设计|结构力学
课 程 设 计 第一部分 (结构力学) 2、两种结构在静力等效荷载作用下,内力有哪些不同?(分析比较) 1/2 1 1 1 1 1 1/2 1/4 11(1/2) 1/4 图1求解过程及结果: 轴力图: 内力计算 单位&…...
三次方根pow
给定一个浮点数n,求它的三次方根。 输入格式: 共一行,包含一个浮点数n,−10000≤n≤10000。 输出格式: 共一行,包含一个浮点数,表示问题的解。 注意,结果保留6位小数。 输入样例: 1000.00输出样例: 10.000000 …...
跟李沐学AI:视频生成类论文精读(Movie Gen、HunyuanVideo)
Movie Gen:A Cast of Media Foundation Models 简介 Movie Gen是Meta公司提出的一系列内容生成模型,包含了 3.2.1 预训练数据 Movie Gen采用大约 100M 的视频-文本对和 1B 的图片-文本对进行预训练。 图片-文本对的预训练流程与Meta提出的 Emu: Enh…...
【项目集成Husky】
项目集成Husky 安装初始化 Husky在.husky → pre-commit文件中添加想要执行的命令 安装 使用 Husky 可以帮助你在 Git 钩子中运行脚本,例如在提交代码前运行测试或格式化代码pnpm add --save-dev husky初始化 Husky npx husky init这会在项目根目录下创建一个 .hu…...
keil5如何添加.h 和.c文件,以及如何添加文件夹
1.简介 在hal库的编程中我们一般会生成如下的几个文件夹,在这几个文件夹内存储着各种外设所需要的函数接口.h文件,和实现函数具体功能的.c文件,但是有时我们想要创建自己的文件夹并在这些文件夹下面创造.h .c文件来实现某些功能,…...
2025-1-28-sklearn学习(47) (48) 万家灯火亮年至,一声烟花开新来。
文章目录 sklearn学习(47) & (48)sklearn学习(47) 把它们放在一起47.1 模型管道化47.2 用特征面进行人脸识别47.3 开放性问题: 股票市场结构 sklearn学习(48) 寻求帮助48.1 项目邮件列表48.2 机器学习从业者的 Q&A 社区 sklearn学习(47) & (48) 文章参考网站&…...
Flask数据的增删改查(CRUD)_flask删除数据自动更新
查询年龄小于17的学生信息 Student.query.filter(Student.s_age < 17) students Student.query.filter(Student.s_age.__lt__(17))模糊查询,使用like,查询姓名中第二位为花的学生信息 like ‘_花%’,_代表必须有一个数据,%任何数据 st…...
算法随笔_33: 132模式
上一篇:算法随笔_32: 移掉k位数字-CSDN博客 题目描述如下: 给你一个整数数组 nums ,数组中共有 n 个整数。132 模式的子序列 由三个整数 nums[i]、nums[j] 和 nums[k] 组成,并同时满足:i < j < k 和 nums[i] < nums[k] < nums[j…...
Linux内核中的页面错误处理机制与按需分页技术
在现代操作系统中,内存管理是核心功能之一,而页面错误(Page Fault)处理机制是内存管理的重要组成部分。当程序访问一个尚未映射到物理内存的虚拟地址时,CPU会触发页面错误异常,内核需要捕获并处理这种异常,以决定如何响应,例如加载缺失的页面、处理权限错误等。Linux内…...
【Git】使用笔记总结
目录 概述安装Git注册GitHub配置Git常用命令常见场景1. 修改文件2. 版本回退3. 分支管理 常见问题1. git add [中文文件夹] 无法显示中文问题2. git add [文件夹] 文件名中含有空格3. git add 触发 LF 回车换行警告4. git push 提示不存在 Origin 仓库5. Git与GitHub中默认分支…...
C语言中的存储类
C语言中的存储类 在C语言中,存储类是用于定义变量和函数的作用域、生命周期以及可见性的关键字。存储类决定了数据在内存中的存储位置以及它们在程序中的使用方式。本文将详细介绍C语言中的存储类,包括其类型、作用以及如何使用。 1. 存储类的类型 C语…...
DeepSeek 云端部署,释放无限 AI 潜力!
1.简介 目前,OpenAI、Anthropic、Google 等公司的大型语言模型(LLM)已广泛应用于商业和私人领域。自 ChatGPT 推出以来,与 AI 的对话变得司空见惯,对我而言没有 LLM 几乎无法工作。 国产模型「DeepSeek-R1」的性能与…...
