当前位置: 首页 > news >正文

Pytorch--3.使用CNN和LSTM对数据进行预测

这个系列前面的文章我们学会了使用全连接层来做简单的回归任务,但是在现实情况里,我们不仅需要做回归,可能还需要做预测工作。同时,我们的数据可能在时空上有着联系,但是简单的全连接层并不能满足我们的需求,所以我们在这篇文章里使用CNN和LSTM来对时间上有联系的数据来进行学习,同时来实现预测的功能。

1.数据集:使用的是kaggle上一个公开的气象数据集(CSV)

有需要的可以去kaggle下载,也可以在评论区留下mail,题主发送过去
在这里插入图片描述

2.导入我们所需要的库和完成前置工作

2.1导入相关的库

torch为人工智能的库,pandas用于数据读取,numpy为张量处理的库,matplotlib为画图库

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
import torch.nn as nn
import torch.optim as optim
import random

2.2设置相关配置

我们设置随机种子(方便代码的复现)和警告的忽律(防止出现太多警告看不到代码运行的效果)

warnings.filterwarnings('ignore')
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(99)
np.random.seed(99)
random.seed(99)
print ("随机种子")

2.3数据的读入

pd.read_csv里面的参数为相对位置,即代码和文件要在同一个文件夹下面。使用.head()函数来读一下数据的前几行,保证数据是存在的

train_data = pd.read_csv("LSTM-Multivariate_pollution.csv")
train_data.head()

请添加图片描述
我们来看一下各个值的前2048个数据分布情况(方便挑选数据进行代码测试)
代码里面的pollution可以换成dew,temp等值(也就是上图里面的值),用于观看分布情况。

train_use = train_data["pollution"].values
plt.plot([i for i in range(2048)], pollution[:2048])

pollution:
请添加图片描述
dew:
请添加图片描述
temp:
请添加图片描述
我们可以看到temp属性里面的数据整体呈现上升的趋势,所以我们使用属性为temp的值来进行学习和预测。
首先对数据进行归一化操作(因为值过大的话会导致神经网络损失不降低,同时神经网络难以达到收敛),我们使用minmax归一化后将其打印出来可以看到代码显示的效果

from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
train_use = scaler.fit_transform(train_use.reshape(-1, 1))
print ((train_use))                                                                     
print ("归一化处理")

可以看到归一化后的结果如下图所示:
在这里插入图片描述
我们将数据进行处理,默认使用30天的数据对第31天的数据进行预测,同时将数据进行升维处理,使得输入的训练数据为3维度,分别为batchsize,每次所需要的数据(30个数据),和数据的输入维度(1维度)

def split_data(data, time_step = 30):dataX = []dataY = []for i in range(len(data) - time_step):dataX.append(data[i:i + time_step])dataY.append(data[i + time_step])dataX = np.array(dataX).reshape(len(dataX), time_step, -1)dataY = np.array(dataY)return dataX, dataY

进行数据处理后,获得了可以训练的数据和标签

datax,datay = split_data(train_use, 30)
print ((datay))

结果如下:
请添加图片描述

紧接着我们划分训练集和测试集,默认为80%的数据用于做训练集,20%的数据用于做测试集,shuffle表示是否要将数据进行打乱,以此来测试训练效果

def train_test_split(dataX,datay,shuffle = True,percentage = 0.8):if shuffle:random_num = [i for i in range(len(dataX))]np.random.shuffle(random_num)dataX = dataX[random_num]datay = datay[random_num]split_num = int(len(dataX)*percentage)train_X = dataX[:split_num]train_y = datay[:split_num]testX = dataX[split_num:]testy = datay[split_num:]return train_X, train_y, testX, testy

获取我们的训练数据和测试数据,同时把源数据保存到X_train和y_train里面,方便以后对网络的性能进行评比。

train_X, train_y, testx,testy = train_test_split(datax,datay,False,0.8)
print (type(testx))
print("datax的形状为{},dataY的形状为{}".format(train_X.shape, train_y.shape))
X_train = train_X
y_train = train_y

定义我们的自定义网络

class CNN_LSTM(nn.Module):def __init__(self, conv_input, input_size, hidden_size, num_layers, output_size):super(CNN_LSTM, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.conv = nn.Conv1d(conv_input, conv_input, 1)self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first = True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):x = self.conv(x)h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)out, _= self.lstm(x,(h0,c0))out = self.fc(out[:,-1,:])return out

设置我们网络训练所需要的参数

test_X1 = torch.Tensor(testx)
test_y1 = torch.Tensor(testy)input_size = 1
conv_input = 30
hidden_size = 64
num_layers = 2output_size = 1model = CNN_LSTM(conv_input, input_size, hidden_size, num_layers,output_size)num_epoch = 1000
batch_size = 4optimizer = optim.Adam(model.parameters(), lr = 0.0001, betas=(0.5, 0.999))criterion = nn.MSELoss()
#print ((torch.Tensor(train_X[:batch_size])))

开始运行代码:

train_losses = []
test_losses = []
for epoch in range(num_epoch):random_num = [i for i in range(len(train_X))]np.random.shuffle(random_num)train_X = train_X[random_num]train_y = train_y[random_num]train_x1 = torch.Tensor(train_X[:batch_size])train_y1 = torch.Tensor(train_y[:batch_size])model.train()optimizer.zero_grad()output = model(train_x1)train_loss = criterion(output, train_y1)train_loss.backward()optimizer.step()if epoch%50 == 0 :model.eval()with torch.no_grad():output = model(test_X1)test_loss = criterion(output, test_y1)train_losses.append(train_loss)test_losses.append(test_loss)print("epoch{},train_loss:{},test_loss:{}".format(epoch, train_loss, test_loss))

在这里插入图片描述

自己手写一个mse计算函数(直接调库也可以),什么是mse?(均方误差,均方误差越小说明模型拟合的越好)

def mse(pred_y, true_y):return np.mean((pred_y - true_y) **2)

然后我们对模型进行测试,观察mse的值

train_X1 = torch.Tensor(X_train)
train_pred = model(train_X1).detach().numpy()
test_pred = model(test_X1).detach().numpy()pred_y = np.concatenate((train_pred, test_pred))
pred_y = scaler.inverse_transform(pred_y).T[0]true_y = np.concatenate((y_train, testy))
#print (true_y)
true_y = scaler.inverse_transform(true_y).T[0]
#print (true_y)
print (f"mse(pred_y, true_y):{mse(pred_y, true_y)}")
##print (pred_y)

在这里插入图片描述

我们取前2048个值来看我们的预测的情况(因为数据有几万条,为了避免图形太过密集难以看出效果,所以我们只采用前2048个值来进行展示)

plt.title("CNN_LSTM")
x = [i for i in range(2048)]
plt.plot(x, pred_y[:2048], marker = "o", markersize =1, label="pred_y",color=(1, 0, 0))
plt.plot(x, true_y[:2048], marker = "x", markersize=1, label="true_y",color=(0, 0, 1))
plt.legend()
plt.show()

可以看出来,已经学习到了基本的上升趋势的
在这里插入图片描述
我们将两个图拆开来看,看到前8192个点的值,可以看到已经获得到了相对应的趋势。
请添加图片描述
在这里插入图片描述

码字不易,写代码不易,点个赞再走把

相关文章:

Pytorch--3.使用CNN和LSTM对数据进行预测

这个系列前面的文章我们学会了使用全连接层来做简单的回归任务,但是在现实情况里,我们不仅需要做回归,可能还需要做预测工作。同时,我们的数据可能在时空上有着联系,但是简单的全连接层并不能满足我们的需求&#xff0…...

爬虫进阶-反爬破解9(下游业务如何使用爬取到的数据+数据和文件的存储方式)

一、下游业务如何使用爬取到的数据 (一)常用数据存储方案 1.百万级别数据:单机数据库,搭建和使用方便快捷,成本低 2.千万级别数据:负载均衡的多台数据库,安全和稳定 3.海量数据:…...

Docker常用应用部署

Docker常用应用部署 一、Ubuntu系统Docker快速安装 Docker官网安装文档:https://docs.docker.com/engine/install/ubuntu/ # 文本处理的流编辑器 -i直接修改读取的文件内容,而不是输出到终端 # sed -i s/原字符串/新字符串/ /home/1.txt # 下面这个是修…...

【数据分享】2014-2022年我国淘宝村点位数据(Excel格式/Shp格式)

电子商务是过去一二十年我国发展最快的行业,其中又以淘宝为代表,淘宝的发展壮大带动了一大批服务淘宝电子商务的村庄,这些村庄被称为淘宝村! 截至到目前,阿里研究院梳理并公布了2014-2022年共9个年份的淘宝村名单&…...

Ubuntu 安装 docker-compose

在Ubuntu上安装Docker Compose,可以按照以下步骤进行操作: 下载 Docker Compose 二进制文件 sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker…...

vue2、vue3中路由守卫变化

什么是路由守卫? 路由守卫就是路由跳转的一些验证,比如登录鉴权(没有登录不能进入个人中心页)等等等 路由守卫分为三大类: 全局守卫:前置守卫:beforeEach 后置钩子:afterEach 单个…...

Leetcode—547.省份数量【中等】

2023每日刷题(八) Leetcode—547.省份数量 实现代码 static int father[210] {0};int Find(int x) {if(x ! father[x]) {father[x] Find(father[x]);}return father[x]; }void Union(int x, int y) {int a Find(x);int b Find(y);if(a ! b) {fathe…...

Nginx 防盗链

nginx防盗链问题 盗链: 就是a网站有一张照片,b网站引用了a网站的照片 。 防盗链: a网站通过设置禁止b网站引用a网站的照片。 nginx防止网站资源被盗用模块 ngx_http_referer_module 如何区分哪些是不正常的用户? HTTP Referer…...

26. 通过 cilium pwru了解网络包的来龙去脉

pwru是一种基于eBPF的工具,可跟踪Linux内核中的网络数据包,并具有先进的过滤功能。它允许对内核状态进行细粒度检查,以便通过调试网络连接问题来解决传统工具(如iptables TRACE或tcpdump)难以解决甚至无法解决的问题。在本文中,我将介绍pwru如何在不必事先了解所有内容的…...

刷题笔记day01-数组

704 题 主要强调,左闭右闭的情况,就是每次查询都会和 [left, right] 进行比较。所以后面的都是mid-1,mid1 的情况。 package mainfunc search(nums []int, target int) int {// 二分查找方法// 每次查找都是左闭右闭的情况left : 0right : …...

C#调用C++ 的DLL传送和接收中文字符串

1 c#向c传送中文字符串 设置&#xff1a;将 字符集 改为 使用多字节字符集 cpp代码&#xff1a; extern "C"_declspec(dllexport) int input_chn_str(char in_str[]) {cout<<in_str<<endl;return 0; }c#代码&#xff1a; [DllImport("Demo.dll…...

【MySQL】数据库常见错误及解决

目录 2003错误&#xff1a;连接错误1251错误&#xff1a;身份验证错误1045错误&#xff1a;拒绝访问错误服务没有报告任何错误net start mysql 发生系统错误 5。 1064错误&#xff1a;语法错误1054错误&#xff1a;列名不存在1442错误&#xff1a;触发器中不能对本表增删改1303…...

spring常见问题汇总

1. 什么是spring? Spring是一个轻量级Java开发框架&#xff0c;最早有Rod Johnson创建&#xff0c;目的是为了解决企业级应用开发的业务 逻辑层和其他各层的耦合问题。它是一个分层的JavaSE/JavaEE full-stack&#xff08;一站式&#xff09;轻量级开源框架&#xff0c; 为开…...

java8 Lambda表达式以及Stream 流

Lambda表达式 Lambda表达式规则 Lambda表达式可以看作是一段可以传递的代码&#xff0c; Lambda表达式只能用于函数式接口&#xff0c;而函数式接口只有一个抽象方法&#xff0c;所以可以省略方法名&#xff0c;参数类型等 Lambda格式&#xff1a;&#xff08;形参列表&…...

基于Java的音乐网站管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09; 代码参考数据库参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者&am…...

【蓝桥】小蓝的疑问

1、题目 问题描述 小蓝和小桥上完课后&#xff0c;小桥回顾了课上教的树形数据结构&#xff0c;他在地上画了一棵根节点为 1 的树&#xff0c;并且对每个节点都赋上了一个权值 w i w_i wi​。 小蓝对小桥多次询问&#xff0c;每次询问包含两个整数 x , k x,k x,k&#xff…...

漏洞复现-海康威视综合安防管理平台信息泄露【附Poc】

目录 【产品介绍】 【产品系统UI】 【漏洞说明】 【指纹】 【Nuclei Poc】 【验证】 【产品介绍】 海康威视&#xff08;Hikvision&#xff09;是一家总部位于中国杭州的公司&#xff0c;是全球最大的视频监控产品供应商。除了传统的CCTV摄像机和网络摄像机&#xff0c;海…...

【完美世界】被骂国漫之耻,石昊人设战力全崩,现在真成恋爱世界了

【侵权联系删除】【文/郑尔巴金】 深度爆料&#xff0c;《完美世界》动漫第135集预告片已经更新了&#xff0c;但是网友们对此却是一脸槽点。从预告中可以看出&#xff0c;石昊在和战王战天歌的大战中被打成重伤&#xff0c;最后云曦也被战天歌抓住。在云曦面临生死危机的时候…...

34二叉树-BFS和DFS求树的深度

目录 LeetCode之路——104. 二叉树的最大深度 分析 解法一&#xff1a;广度优先遍历 解法二&#xff1a;深度优先遍历 总结 深度优先搜索 (DFS) 广度优先搜索 (BFS LeetCode之路——104. 二叉树的最大深度 给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的…...

Android Glide判断图像资源是否缓存onlyRetrieveFromCache,使用缓存数据,Kotlin

Android Glide判断图像资源是否缓存onlyRetrieveFromCache&#xff0c;使用缓存数据&#xff0c;Kotlin import android.graphics.Bitmap import android.os.Bundle import android.util.Log import android.widget.ImageView import androidx.appcompat.app.AppCompatActivity…...

ansible自动化playbook简单实践

方法一&#xff1a;部分使用ansible 基于现有的nginx配置文件&#xff0c;定制部署nginx软件&#xff0c;将我们的知识进行整合 定制要求&#xff1a; 启动用户&#xff1a;nginx-test&#xff0c;uid是82&#xff0c;系统用户&#xff0c;不能登录 启动端口82 web项目根目录/…...

QPS 和 TPS 详解

QPS 和 TPS 是性能测试中的两个核心指标&#xff0c;用于衡量系统的吞吐能力&#xff0c;但关注点不同。以下是具体解析&#xff1a; 1. QPS&#xff08;Queries Per Second&#xff09; 定义&#xff1a;每秒查询数&#xff0c;表示系统每秒能处理的请求数量&#xff08;无论…...

【DAY34】GPU训练及类的call方法

内容来自浙大疏锦行python打卡训练营 浙大疏锦行 知识点&#xff1a; CPU性能的查看&#xff1a;看架构代际、核心数、线程数GPU性能的查看&#xff1a;看显存、看级别、看架构代际GPU训练的方法&#xff1a;数据和模型移动到GPU device上类的call方法&#xff1a;为什么定义前…...

MiniMax V-Triune让强化学习(RL)既擅长推理也精通视觉感知

MiniMax 近日在github上分享了技术研究成果——V-Triune&#xff0c;这次MiniMax V-Triune的发布既是AI视觉技术也是应用工程上的一次“突围”&#xff0c;让强化学习&#xff08;RL&#xff09;既擅长推理也精通视觉感知&#xff0c;其实缓解了传统视觉RL“鱼和熊掌不可兼得”…...

重读《人件》Peopleware -(13)Ⅱ 办公环境 Ⅵ 电话

当你开始收集有关工作时间质量的数据时&#xff0c;你的注意力自然会集中在主要的干扰源之一——打进来的电话。一天内接15个电话并不罕见。虽然这看似平常&#xff0c;但由于重新沉浸所需的时间&#xff0c;它可能会耗尽你几乎一整天的时间。当一天结束时&#xff0c;你会纳闷…...

四、web安全-行业术语

1. 肉鸡 所谓“肉鸡”是一种很形象的比喻&#xff0c;比喻那些可以随意被我们控制的电脑&#xff0c;对方可以是WINDOWS系统&#xff0c;也可以是UNIX/LINUX系统&#xff0c;可以是普通的个人电脑&#xff0c;也可以是大型的服务器&#xff0c;我们可以象操作自己的电脑那样来…...

LiveGBS海康、大华、宇视、华为摄像头GB28181国标语音对讲及语音喊话:摄像头设备与服务HTTPS准备

LiveGBS海康、大华、宇视、华为摄像头GB28181国标语音对讲及语音喊话&#xff1a;摄像头设备与服务HTTPS准备 1、背景2、准备工作2.1、服务端必备条件&#xff08;注意事项&#xff09;2.2、语音对讲设备准备2.2.1、大华摄像机2.2.2、海康摄像机 3、开启音频并开始对讲4、相关问…...

IPv6代理如何引领下一代网络未来

随着互联网技术的不断发展&#xff0c;IPv6逐渐成为下一代网络协议的核心&#xff0c;替代IPv4已是大势所趋。IPv6代理作为IPv6网络环境下的重要工具&#xff0c;为用户提供了更高效、更安全的网络解决方案。 IPv6代理的定义 IPv6代理是在IPv6网络环境中为处理IPv4转换和其他网…...

Starrocks 物化视图的实现以及在刷新期间能否读数据

背景 本司在用Starrocks做一些业务上的分析的时候&#xff0c;用到了物化视图&#xff0c;并且在高QPS的情况下&#xff0c;RT也没有很大的波动&#xff0c;所以在此研究一下Starrock的实现&#xff0c;以及在刷新的时候是不是原子性的 本文基于Starrocks 3.3.5 结论 Starro…...

Apache Paimon:存储结构、写入及其源码分析

Apache Paimon (此前称为 Flink Table Store)是一种流式数据湖存储技术&#xff0c;采用 LSM&#xff08;Log-Structured Merge-tree&#xff09;树结构来存储数据&#xff0c;支持高吞吐、低延迟的数据摄入和实时查询&#xff0c;尤其适用于流式和批量统一的场景。 1. 创建表…...