当前位置: 首页 > news >正文

[PyTorch][chapter 48][LSTM -3]

简介:

     主要介绍一下  

      sin(x):  为 数据

      cos(x):   为对应的label

      项目包括两个文件

      main.py:

              模型的训练,验证,参数保存

     lstm.py

              模型的构建

目录:

  1.      lstm.py
  2.      main.py

一 lstm.py

   

# -*- coding: utf-8 -*-
"""
Created on Tue Aug  8 14:01:15 2023@author: chengxf2
"""import torch
import torch.nn as nnclass LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_lay, b_first):super(LSTM,self).__init__()self.lstm = nn.LSTM(input_size = input_dim, hidden_size = hidden_dim, num_layers = num_lay, batch_first=b_first)self.linear = nn.Linear(hidden_dim, 1)def forward(self, X):#X.shape:[batch_size=1, seq_num=256, input_size=1]output, (hidden, cell) = self.lstm(X)outs =[]seq_num = output.size(1)#output:[batch_size, seq_num, hidden_dim=64]#hidden.shape:[num_layer, batch_size, hiden_size]#print("\n output.shape",output.shape)#print("\n hidden.shape",hidden.shape)for time_step in range(seq_num):#h.shape[batch, hidden_dim]h = output[:,time_step,:]#print("\n h",h.shape)out = self.linear(h)outs.append(out)#沿着一个新维度对输入张量序列进行连接。 #[batch, seq_num, 1]pred = torch.stack(outs, dim=1)return pred

二  main.py

import numpy as np
from matplotlib import pyplot as plt
import torch
from lstm import LSTM
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import timedef showDiff(pred, label, steps):plt.figure()plt.rcParams['font.family'] = 'SimHei' # 正常显示中文plt.title('预测值 and 真实值', fontsize='18')plt.plot(steps, pred.cpu().data.numpy().flatten(),color='r',label='预测值')plt.plot(steps, label.cpu().data.numpy().flatten(), color='g',label='真实值')plt.legend(loc='best')plt.show()def get_data(epoch):TIME_STEP = 256start, end = epoch*np.pi, epoch*np.pi+2*np.pisteps = np.linspace(start,end,TIME_STEP,dtype=np.float32)sin_x = np.sin(steps)cos_x = np.cos(steps)sinx_torch = torch.from_numpy(sin_x[np.newaxis, :, np.newaxis])if torch.cuda.is_available():sinx_torch = torch.from_numpy(sin_x[np.newaxis,:,np.newaxis]).cuda()# print('sinx_input.shape:',sinx_input.shape)cosx_lable = torch.from_numpy(cos_x[np.newaxis, :, np.newaxis]).cuda()else:sinx_torch = torch.from_numpy(sin_x[np.newaxis, :, np.newaxis])# print('sinx_input.shape:',sinx_input.shape)cosx_lable = torch.from_numpy(cos_x[np.newaxis, :, np.newaxis])# [batch,seq_num,input_size] (1,256,1)return sinx_torch,cosx_lable,stepsdef eval(model):#等同于 self.train(False) 就是评估模式。#在评估模式下,batchNorm层,dropout层等用于优化训练而添加的网络层会被关闭,从而使得评估时不会发生偏移model.eval()test_data,test_label,steps = get_data(2)with torch.no_grad():y_pred = model(test_data)showDiff(y_pred, test_label, steps)def train(model,maxIter,criterion):'''训练模型----------model : lstm 模型.maxIter : 迭代次数.criterion : 损失函数------'''#作用是启用 batch normalization 和 dropoutmodel.train()time_stamp = time.time()for epoch in range(maxIter):sinx_torch,cosx_lable,steps = get_data(epoch)y_pre = model(sinx_torch)   loss = criterion(y_pre,cosx_lable)optimzer.zero_grad()loss.backward()optimzer.step()if epoch%100==0:data_time_interval = time.time() - time_stampprint('epoch: %d loss: %7.3f interval: %6.2f'%(epoch, loss.detach().numpy(),data_time_interval))#torch.save(model.state_dict(), 'model_params.pth') showDiff(y_pre, cosx_lable,steps)if __name__ == '__main__':input_dim =1hidden_dim = 64num_layers =2batch_first = TruemaxIter = 3000model = LSTM(input_dim, hidden_dim, num_layers, batch_first)DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")optimzer = optim.Adam(model.parameters(),lr=0.0001,weight_decay=0.00001)criterion = nn.MSELoss()model.to(DEVICE)criterion.to(DEVICE)train(model,maxIter,criterion)#model.load_state_dict(torch.load('model_params.pth',map_location='cpu'))#eval(model)

 

参考:

pytorch利用rnn通过sin预测cos 利用lstm预测手写数字_pytorch lstm cos_薛定谔的智能的博客-CSDN博客

相关文章:

[PyTorch][chapter 48][LSTM -3]

简介: 主要介绍一下 sin(x): 为 数据 cos(x): 为对应的label 项目包括两个文件 main.py: 模型的训练,验证,参数保存 lstm.py 模型的构建 目录: lstm.py main.py 一 lstm.py # -*- coding: utf-8 -*- "&q…...

xss csrf 攻击

介绍 xss csrf 攻击 XSS: XSS 是指跨站脚本攻击。攻击者利用站点的漏洞,在表单提交时,在表单内容中加入一些恶意脚本,当其他正常用户浏览页面,而页面中刚好出现攻击者的恶意脚本时,脚本被执行,从…...

如何使用win10专业版系统自带远程桌面公司内网电脑,从而实现居家办公?

使用win10专业版自带远程桌面公司内网电脑 文章目录 使用win10专业版自带远程桌面公司内网电脑 在现代社会中,各类电子硬件已经遍布我们身边,除了应用在个人娱乐场景的消费类电子产品外,各项工作也离不开电脑的帮助,特别是涉及到数…...

leetcode做题笔记62

一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(在下图中标记为 “Finish” )。 问总共有多少条不同的路径? 思路一…...

图论 <最短路问题>模板

图论 <最短路问题> 有向图 1.邻接矩阵&#xff0c;稠密图 2.邻接表 &#xff08;常用&#xff09;单链表&#xff0c;每一个点都有一个单链表 &#xff0c;插入一般在头的地方插&#xff0c; 图的邻接表的存储方式 树的深度优先遍历 特殊的深度优先搜索&#xff0c…...

计算机网络性能指标

比特&#xff1a;数据量的单位 KB 2^10B 2^13 bit 比特率&#xff1a;连接在计算机网络上的主机在数字通道上传送比特的速率 kb/s 10^3b/s 带宽&#xff1a;信号所包含的各种频率不同的成分所占据的频率范围 Hz 表示在网络中的通信线路所能传送数据的能力&#xff08…...

vue + elementUI 实现下拉树形结构选择部门,支持多选,支持检索

vue elementUI 实现下拉树形结构选择部门&#xff0c;支持多选&#xff0c;支持检索 <template><div><el-select v-model"multiple?choosedValue:choosedValue[0]" element-loading-background"rgba(0,0,0,0.8)":disabled"disableFl…...

招投标系统简介 企业电子招投标采购系统源码之电子招投标系统 —降低企业采购成本 tbms

​功能模块&#xff1a; 待办消息&#xff0c;招标公告&#xff0c;中标公告&#xff0c;信息发布 描述&#xff1a; 全过程数字化采购管理&#xff0c;打造从供应商管理到采购招投标、采购合同、采购执行的全过程数字化管理。通供应商门户具备内外协同的能力&#xff0c;为外…...

半监督学习(主要伪标签方法)

半监督学习 1. 引言 应用场景&#xff1a;存在少量的有标签样本和大量的无标签样本的场景。在此应用场景下&#xff0c;通常标注数据是匮乏的&#xff0c;成本高的&#xff0c;难以获取的&#xff0c;与之相对应的是却存在大量的无标注数据。半监督学习的假设&#xff1a;决策…...

datePicker一个或多个日期组件,如何快捷选择多个日期(时间段)

elementUI的组件文档中没有详细说明type"dates"如何快捷选择一个时间段的日期&#xff0c;我们可以通过picker-options参数来设置快捷选择&#xff1a; <div class"block"><span class"demonstration">多个日期</span><el…...

【语音合成】微软 edge-tts

目录 1. edge-tts 介绍 2. 代码示例 1. edge-tts 介绍 https://github.com/rany2/edge-tts 在Python代码中使用Microsoft Edge的在线文本到语音服务 2. 代码示例 import asyncio # pip install edge_tts import edge_tts TEXT """给我放首我喜欢听的歌曲…...

elevation mapping学习笔记3之使用D435i相机离线或在线订阅点云和tf关系生成高程图

文章目录 0 引言1 数据1.1 D435i相机配置1.2 协方差位姿1.3 tf 关系2 离线demo2.1 yaml配置文件2.2 launch启动文件2.3 数据录制2.4 离线加载点云生成高程图3 在线demo3.1 launch启动文件3.2 CMakeLists.txt3.3 在线加载点云生成高程图0 引言 elevation mapping学习笔记1已经成…...

ESP32 Max30102 (3)修复心率误差

1. 运行效果 2. 新建修复心率误差.py 代码如下: from machine import sleep, SoftI2C, Pin, Timer from utime import ticks_diff, ticks_us from max30102 import MAX30102, MAX30105_PULSE_AMP_MEDIUM from hrcalc import calc_hr_and_spo2BEATS = 0 # 存储心率 FINGER_F…...

16-4_Qt 5.9 C++开发指南_Qt 应用程序的发布

文章目录 1. 应用程序发布方式2. Windows 平台上的应用程序发布 1. 应用程序发布方式 用 Qt 开发一个应用程序后&#xff0c;将应用程序提供给用户在其他计算机上使用就是应用程序的发布。应用程序发布一般会提供一个安装程序&#xff0c;将应用程序的可执行文件及需要的运行库…...

oracle容灾备份怎么样Oracle容灾备份

随着科学技术的发展和业务的增长&#xff0c;数据安全问题越来越突出。为了保证数据的完整性、易用性和保密性&#xff0c;公司需要采取一系列措施来防止内容丢失的风险。  Oracle是一个关系数据库管理系统(RDBMS),OracleCorporation是由美国软件公司开发和维护的。该系统功能…...

AcWing 4957:飞机降落

【题目来源】https://www.acwing.com/problem/content/4960/【题目描述】 有 N 架飞机准备降落到某个只有一条跑道的机场。 其中第 i 架飞机在 Ti 时刻到达机场上空&#xff0c;到达时它的剩余油料还可以继续盘旋 Di 个单位时间&#xff0c;即它最早可以于 Ti 时刻开始降落&…...

强化学习研究 PG

由于一些原因&#xff0c; 需要学习一下强化学习。用这篇博客来学习吧&#xff0c; 用的资料是李宏毅老师的强化学习课程。 深度强化学习(DRL)-李宏毅1-8课&#xff08;全&#xff09;_哔哩哔哩_bilibili 这篇文章的目的是看懂公式&#xff0c; 毕竟这是我的弱中弱。 强化…...

uniapp微信小程序 401时重复弹出登录弹框问题

APP.vue 登陆成功后&#xff0c;保存登陆信息 if (res.code 200) {uni.setStorageSync(loginResult, res)uni.setStorageSync(token, res.token);uni.setStorageSync(login,false);uni.navigateTo({url: "/pages/learning/learning"}) }退出登录 toLogout: func…...

Cloud Studio实战——热门视频Top100爬虫应用开发

最近Cloud Studio非常火&#xff0c;我也去试了一下&#xff0c;感觉真的非常方便&#xff01;我就以Python爬取B站各区排名前一百的视频&#xff0c;并作可视化来给大家分享一下Cloud Studio&#xff01;应用链接&#xff1a;Cloud Studio实战——B站热门视频Top100爬虫应用开…...

php 去除二维数组重复

在 PHP 中&#xff0c;我们常常需要对数组进行处理和操作。有时候&#xff0c;我们需要去除数组中的重复元素&#xff0c;这里介绍一种针对二维数组的去重方法。 以下是列举一些常见的方法&#xff1a; 方法一&#xff1a;使用 array_map 和 serialize 函数 array_map 函数可以…...

Windows主题自由革命:SecureUxTheme安全启动兼容的内存补丁终极指南

Windows主题自由革命&#xff1a;SecureUxTheme安全启动兼容的内存补丁终极指南 【免费下载链接】SecureUxTheme &#x1f3a8; A secure boot compatible in-memory UxTheme patcher 项目地址: https://gitcode.com/gh_mirrors/se/SecureUxTheme 厌倦了Windows千篇一律…...

线程与进程的区别与联系:操作系统入门详解(含 Python 示例)

、先搞懂&#xff1a;进程与线程到底是什么&#xff1f;&#xff08;通俗类比官方定义&#xff09; 1.1 生活化类比&#xff1a;快速建立认知 如果把计算机的操作系统比作一个大型工厂&#xff1a; 进程&#xff1a;就是工厂里的一个个独立车间。每个车间有自己专属的生产资…...

深入Fast DDS传输层:从UDP、TCP到共享内存,如何为你的ROS2应用选择最佳通信方式?

Fast DDS传输层深度解析&#xff1a;UDP、TCP与共享内存的工程实践指南 在分布式系统架构中&#xff0c;通信中间件的性能直接影响整个系统的响应速度和可靠性。作为ROS 2的默认通信中间件&#xff0c;Fast DDS提供了多种传输协议选择&#xff0c;但如何根据实际场景做出最优决…...

在PC上畅玩Switch游戏:Ryujinx模拟器完全指南

在PC上畅玩Switch游戏&#xff1a;Ryujinx模拟器完全指南 【免费下载链接】Ryujinx 用 C# 编写的实验性 Nintendo Switch 模拟器 项目地址: https://gitcode.com/GitHub_Trending/ry/Ryujinx 想在电脑上体验《塞尔达传说&#xff1a;旷野之息》的震撼冒险&#xff0c;或…...

MT5中文增强工具多场景落地:保险条款通俗化改写与消费者理解度提升实践

MT5中文增强工具多场景落地&#xff1a;保险条款通俗化改写与消费者理解度提升实践 1. 项目概述与核心价值 MT5中文增强工具是一个基于Streamlit和阿里达摩院mT5模型构建的本地化NLP工具&#xff0c;专门针对中文文本进行语义改写和数据增强。这个工具的最大特点是能够在保持…...

高效音频获取与资源管理:喜马拉雅下载工具全解析

高效音频获取与资源管理&#xff1a;喜马拉雅下载工具全解析 【免费下载链接】xmly-downloader-qt5 喜马拉雅FM专辑下载器. 支持VIP与付费专辑. 使用GoQt5编写(Not Qt Binding). 项目地址: https://gitcode.com/gh_mirrors/xm/xmly-downloader-qt5 在数字内容消费时代&a…...

手把手教你优化SiC MOSFET模块:从铜带键合到双面散热的5个关键技术

SiC MOSFET功率模块封装优化实战&#xff1a;五大关键技术深度解析 在电力电子领域&#xff0c;碳化硅(SiC)MOSFET功率模块正逐步取代传统硅基IGBT&#xff0c;成为高效率、高功率密度应用的首选。然而&#xff0c;要充分发挥SiC材料的性能优势&#xff0c;封装技术面临前所未…...

AUTOSAR CANFM模块中,BusOff恢复的50ms和1000ms周期到底怎么来的?底层驱动配置详解

AUTOSAR CANFM模块中BusOff恢复时序的硬件级解析 在车载ECU开发中&#xff0c;CAN总线通信的可靠性直接关系到整车功能安全。当节点因连续错误进入BusOff状态时&#xff0c;AUTOSAR标准定义的50ms快恢复周期和1000ms慢恢复周期并非随意设定&#xff0c;而是源于CAN控制器硬件特…...

科哥Image-to-Video镜像实战:从零开始制作你的第一个AI视频

科哥Image-to-Video镜像实战&#xff1a;从零开始制作你的第一个AI视频 1. 前言&#xff1a;为什么选择科哥的Image-to-Video镜像&#xff1f; 想象一下&#xff0c;你有一张美丽的风景照片&#xff0c;如果能把它变成一段生动的视频该有多好&#xff1f;这就是Image-to-Vide…...

5步攻克MZmine 3质谱数据分析:从问题解决到专业应用的实战指南

5步攻克MZmine 3质谱数据分析&#xff1a;从问题解决到专业应用的实战指南 【免费下载链接】mzmine3 MZmine 3 source code repository 项目地址: https://gitcode.com/gh_mirrors/mz/mzmine3 MZmine 3作为开源质谱数据分析领域的核心工具&#xff0c;在代谢组学、蛋白质…...