当前位置: 首页 > news >正文

Pytorch--3.使用CNN和LSTM对数据进行预测

这个系列前面的文章我们学会了使用全连接层来做简单的回归任务,但是在现实情况里,我们不仅需要做回归,可能还需要做预测工作。同时,我们的数据可能在时空上有着联系,但是简单的全连接层并不能满足我们的需求,所以我们在这篇文章里使用CNN和LSTM来对时间上有联系的数据来进行学习,同时来实现预测的功能。

1.数据集:使用的是kaggle上一个公开的气象数据集(CSV)

有需要的可以去kaggle下载,也可以在评论区留下mail,题主发送过去
在这里插入图片描述

2.导入我们所需要的库和完成前置工作

2.1导入相关的库

torch为人工智能的库,pandas用于数据读取,numpy为张量处理的库,matplotlib为画图库

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import warnings
import torch.nn as nn
import torch.optim as optim
import random

2.2设置相关配置

我们设置随机种子(方便代码的复现)和警告的忽律(防止出现太多警告看不到代码运行的效果)

warnings.filterwarnings('ignore')
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(99)
np.random.seed(99)
random.seed(99)
print ("随机种子")

2.3数据的读入

pd.read_csv里面的参数为相对位置,即代码和文件要在同一个文件夹下面。使用.head()函数来读一下数据的前几行,保证数据是存在的

train_data = pd.read_csv("LSTM-Multivariate_pollution.csv")
train_data.head()

请添加图片描述
我们来看一下各个值的前2048个数据分布情况(方便挑选数据进行代码测试)
代码里面的pollution可以换成dew,temp等值(也就是上图里面的值),用于观看分布情况。

train_use = train_data["pollution"].values
plt.plot([i for i in range(2048)], pollution[:2048])

pollution:
请添加图片描述
dew:
请添加图片描述
temp:
请添加图片描述
我们可以看到temp属性里面的数据整体呈现上升的趋势,所以我们使用属性为temp的值来进行学习和预测。
首先对数据进行归一化操作(因为值过大的话会导致神经网络损失不降低,同时神经网络难以达到收敛),我们使用minmax归一化后将其打印出来可以看到代码显示的效果

from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
train_use = scaler.fit_transform(train_use.reshape(-1, 1))
print ((train_use))                                                                     
print ("归一化处理")

可以看到归一化后的结果如下图所示:
在这里插入图片描述
我们将数据进行处理,默认使用30天的数据对第31天的数据进行预测,同时将数据进行升维处理,使得输入的训练数据为3维度,分别为batchsize,每次所需要的数据(30个数据),和数据的输入维度(1维度)

def split_data(data, time_step = 30):dataX = []dataY = []for i in range(len(data) - time_step):dataX.append(data[i:i + time_step])dataY.append(data[i + time_step])dataX = np.array(dataX).reshape(len(dataX), time_step, -1)dataY = np.array(dataY)return dataX, dataY

进行数据处理后,获得了可以训练的数据和标签

datax,datay = split_data(train_use, 30)
print ((datay))

结果如下:
请添加图片描述

紧接着我们划分训练集和测试集,默认为80%的数据用于做训练集,20%的数据用于做测试集,shuffle表示是否要将数据进行打乱,以此来测试训练效果

def train_test_split(dataX,datay,shuffle = True,percentage = 0.8):if shuffle:random_num = [i for i in range(len(dataX))]np.random.shuffle(random_num)dataX = dataX[random_num]datay = datay[random_num]split_num = int(len(dataX)*percentage)train_X = dataX[:split_num]train_y = datay[:split_num]testX = dataX[split_num:]testy = datay[split_num:]return train_X, train_y, testX, testy

获取我们的训练数据和测试数据,同时把源数据保存到X_train和y_train里面,方便以后对网络的性能进行评比。

train_X, train_y, testx,testy = train_test_split(datax,datay,False,0.8)
print (type(testx))
print("datax的形状为{},dataY的形状为{}".format(train_X.shape, train_y.shape))
X_train = train_X
y_train = train_y

定义我们的自定义网络

class CNN_LSTM(nn.Module):def __init__(self, conv_input, input_size, hidden_size, num_layers, output_size):super(CNN_LSTM, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.conv = nn.Conv1d(conv_input, conv_input, 1)self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first = True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):x = self.conv(x)h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)out, _= self.lstm(x,(h0,c0))out = self.fc(out[:,-1,:])return out

设置我们网络训练所需要的参数

test_X1 = torch.Tensor(testx)
test_y1 = torch.Tensor(testy)input_size = 1
conv_input = 30
hidden_size = 64
num_layers = 2output_size = 1model = CNN_LSTM(conv_input, input_size, hidden_size, num_layers,output_size)num_epoch = 1000
batch_size = 4optimizer = optim.Adam(model.parameters(), lr = 0.0001, betas=(0.5, 0.999))criterion = nn.MSELoss()
#print ((torch.Tensor(train_X[:batch_size])))

开始运行代码:

train_losses = []
test_losses = []
for epoch in range(num_epoch):random_num = [i for i in range(len(train_X))]np.random.shuffle(random_num)train_X = train_X[random_num]train_y = train_y[random_num]train_x1 = torch.Tensor(train_X[:batch_size])train_y1 = torch.Tensor(train_y[:batch_size])model.train()optimizer.zero_grad()output = model(train_x1)train_loss = criterion(output, train_y1)train_loss.backward()optimizer.step()if epoch%50 == 0 :model.eval()with torch.no_grad():output = model(test_X1)test_loss = criterion(output, test_y1)train_losses.append(train_loss)test_losses.append(test_loss)print("epoch{},train_loss:{},test_loss:{}".format(epoch, train_loss, test_loss))

在这里插入图片描述

自己手写一个mse计算函数(直接调库也可以),什么是mse?(均方误差,均方误差越小说明模型拟合的越好)

def mse(pred_y, true_y):return np.mean((pred_y - true_y) **2)

然后我们对模型进行测试,观察mse的值

train_X1 = torch.Tensor(X_train)
train_pred = model(train_X1).detach().numpy()
test_pred = model(test_X1).detach().numpy()pred_y = np.concatenate((train_pred, test_pred))
pred_y = scaler.inverse_transform(pred_y).T[0]true_y = np.concatenate((y_train, testy))
#print (true_y)
true_y = scaler.inverse_transform(true_y).T[0]
#print (true_y)
print (f"mse(pred_y, true_y):{mse(pred_y, true_y)}")
##print (pred_y)

在这里插入图片描述

我们取前2048个值来看我们的预测的情况(因为数据有几万条,为了避免图形太过密集难以看出效果,所以我们只采用前2048个值来进行展示)

plt.title("CNN_LSTM")
x = [i for i in range(2048)]
plt.plot(x, pred_y[:2048], marker = "o", markersize =1, label="pred_y",color=(1, 0, 0))
plt.plot(x, true_y[:2048], marker = "x", markersize=1, label="true_y",color=(0, 0, 1))
plt.legend()
plt.show()

可以看出来,已经学习到了基本的上升趋势的
在这里插入图片描述
我们将两个图拆开来看,看到前8192个点的值,可以看到已经获得到了相对应的趋势。
请添加图片描述
在这里插入图片描述

码字不易,写代码不易,点个赞再走把

相关文章:

Pytorch--3.使用CNN和LSTM对数据进行预测

这个系列前面的文章我们学会了使用全连接层来做简单的回归任务,但是在现实情况里,我们不仅需要做回归,可能还需要做预测工作。同时,我们的数据可能在时空上有着联系,但是简单的全连接层并不能满足我们的需求&#xff0…...

爬虫进阶-反爬破解9(下游业务如何使用爬取到的数据+数据和文件的存储方式)

一、下游业务如何使用爬取到的数据 (一)常用数据存储方案 1.百万级别数据:单机数据库,搭建和使用方便快捷,成本低 2.千万级别数据:负载均衡的多台数据库,安全和稳定 3.海量数据:…...

Docker常用应用部署

Docker常用应用部署 一、Ubuntu系统Docker快速安装 Docker官网安装文档:https://docs.docker.com/engine/install/ubuntu/ # 文本处理的流编辑器 -i直接修改读取的文件内容,而不是输出到终端 # sed -i s/原字符串/新字符串/ /home/1.txt # 下面这个是修…...

【数据分享】2014-2022年我国淘宝村点位数据(Excel格式/Shp格式)

电子商务是过去一二十年我国发展最快的行业,其中又以淘宝为代表,淘宝的发展壮大带动了一大批服务淘宝电子商务的村庄,这些村庄被称为淘宝村! 截至到目前,阿里研究院梳理并公布了2014-2022年共9个年份的淘宝村名单&…...

Ubuntu 安装 docker-compose

在Ubuntu上安装Docker Compose,可以按照以下步骤进行操作: 下载 Docker Compose 二进制文件 sudo curl -L "https://github.com/docker/compose/releases/latest/download/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker…...

vue2、vue3中路由守卫变化

什么是路由守卫? 路由守卫就是路由跳转的一些验证,比如登录鉴权(没有登录不能进入个人中心页)等等等 路由守卫分为三大类: 全局守卫:前置守卫:beforeEach 后置钩子:afterEach 单个…...

Leetcode—547.省份数量【中等】

2023每日刷题(八) Leetcode—547.省份数量 实现代码 static int father[210] {0};int Find(int x) {if(x ! father[x]) {father[x] Find(father[x]);}return father[x]; }void Union(int x, int y) {int a Find(x);int b Find(y);if(a ! b) {fathe…...

Nginx 防盗链

nginx防盗链问题 盗链: 就是a网站有一张照片,b网站引用了a网站的照片 。 防盗链: a网站通过设置禁止b网站引用a网站的照片。 nginx防止网站资源被盗用模块 ngx_http_referer_module 如何区分哪些是不正常的用户? HTTP Referer…...

26. 通过 cilium pwru了解网络包的来龙去脉

pwru是一种基于eBPF的工具,可跟踪Linux内核中的网络数据包,并具有先进的过滤功能。它允许对内核状态进行细粒度检查,以便通过调试网络连接问题来解决传统工具(如iptables TRACE或tcpdump)难以解决甚至无法解决的问题。在本文中,我将介绍pwru如何在不必事先了解所有内容的…...

刷题笔记day01-数组

704 题 主要强调,左闭右闭的情况,就是每次查询都会和 [left, right] 进行比较。所以后面的都是mid-1,mid1 的情况。 package mainfunc search(nums []int, target int) int {// 二分查找方法// 每次查找都是左闭右闭的情况left : 0right : …...

C#调用C++ 的DLL传送和接收中文字符串

1 c#向c传送中文字符串 设置&#xff1a;将 字符集 改为 使用多字节字符集 cpp代码&#xff1a; extern "C"_declspec(dllexport) int input_chn_str(char in_str[]) {cout<<in_str<<endl;return 0; }c#代码&#xff1a; [DllImport("Demo.dll…...

【MySQL】数据库常见错误及解决

目录 2003错误&#xff1a;连接错误1251错误&#xff1a;身份验证错误1045错误&#xff1a;拒绝访问错误服务没有报告任何错误net start mysql 发生系统错误 5。 1064错误&#xff1a;语法错误1054错误&#xff1a;列名不存在1442错误&#xff1a;触发器中不能对本表增删改1303…...

spring常见问题汇总

1. 什么是spring? Spring是一个轻量级Java开发框架&#xff0c;最早有Rod Johnson创建&#xff0c;目的是为了解决企业级应用开发的业务 逻辑层和其他各层的耦合问题。它是一个分层的JavaSE/JavaEE full-stack&#xff08;一站式&#xff09;轻量级开源框架&#xff0c; 为开…...

java8 Lambda表达式以及Stream 流

Lambda表达式 Lambda表达式规则 Lambda表达式可以看作是一段可以传递的代码&#xff0c; Lambda表达式只能用于函数式接口&#xff0c;而函数式接口只有一个抽象方法&#xff0c;所以可以省略方法名&#xff0c;参数类型等 Lambda格式&#xff1a;&#xff08;形参列表&…...

基于Java的音乐网站管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序&#xff08;小蔡coding&#xff09; 代码参考数据库参考源码获取 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者&am…...

【蓝桥】小蓝的疑问

1、题目 问题描述 小蓝和小桥上完课后&#xff0c;小桥回顾了课上教的树形数据结构&#xff0c;他在地上画了一棵根节点为 1 的树&#xff0c;并且对每个节点都赋上了一个权值 w i w_i wi​。 小蓝对小桥多次询问&#xff0c;每次询问包含两个整数 x , k x,k x,k&#xff…...

漏洞复现-海康威视综合安防管理平台信息泄露【附Poc】

目录 【产品介绍】 【产品系统UI】 【漏洞说明】 【指纹】 【Nuclei Poc】 【验证】 【产品介绍】 海康威视&#xff08;Hikvision&#xff09;是一家总部位于中国杭州的公司&#xff0c;是全球最大的视频监控产品供应商。除了传统的CCTV摄像机和网络摄像机&#xff0c;海…...

【完美世界】被骂国漫之耻,石昊人设战力全崩,现在真成恋爱世界了

【侵权联系删除】【文/郑尔巴金】 深度爆料&#xff0c;《完美世界》动漫第135集预告片已经更新了&#xff0c;但是网友们对此却是一脸槽点。从预告中可以看出&#xff0c;石昊在和战王战天歌的大战中被打成重伤&#xff0c;最后云曦也被战天歌抓住。在云曦面临生死危机的时候…...

34二叉树-BFS和DFS求树的深度

目录 LeetCode之路——104. 二叉树的最大深度 分析 解法一&#xff1a;广度优先遍历 解法二&#xff1a;深度优先遍历 总结 深度优先搜索 (DFS) 广度优先搜索 (BFS LeetCode之路——104. 二叉树的最大深度 给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的…...

Android Glide判断图像资源是否缓存onlyRetrieveFromCache,使用缓存数据,Kotlin

Android Glide判断图像资源是否缓存onlyRetrieveFromCache&#xff0c;使用缓存数据&#xff0c;Kotlin import android.graphics.Bitmap import android.os.Bundle import android.util.Log import android.widget.ImageView import androidx.appcompat.app.AppCompatActivity…...

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制&#xff0c;因此这个了16进制的数据既可以翻译成为这个机器码&#xff0c;也可以翻译成为这个国标码&#xff0c;所以这个时候很容易会出现这个歧义的情况&#xff1b; 因此&#xff0c;我们的这个国…...

C++实现分布式网络通信框架RPC(3)--rpc调用端

目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中&#xff0c;我们已经大致实现了rpc服务端的各项功能代…...

1.3 VSCode安装与环境配置

进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件&#xff0c;然后打开终端&#xff0c;进入下载文件夹&#xff0c;键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...

sqlserver 根据指定字符 解析拼接字符串

DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...

【HTML-16】深入理解HTML中的块元素与行内元素

HTML元素根据其显示特性可以分为两大类&#xff1a;块元素(Block-level Elements)和行内元素(Inline Elements)。理解这两者的区别对于构建良好的网页布局至关重要。本文将全面解析这两种元素的特性、区别以及实际应用场景。 1. 块元素(Block-level Elements) 1.1 基本特性 …...

关于 WASM:1. WASM 基础原理

一、WASM 简介 1.1 WebAssembly 是什么&#xff1f; WebAssembly&#xff08;WASM&#xff09; 是一种能在现代浏览器中高效运行的二进制指令格式&#xff0c;它不是传统的编程语言&#xff0c;而是一种 低级字节码格式&#xff0c;可由高级语言&#xff08;如 C、C、Rust&am…...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

C++:多态机制详解

目录 一. 多态的概念 1.静态多态&#xff08;编译时多态&#xff09; 二.动态多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写/覆盖 4.虚函数重写的一些其他问题 1&#xff09;.协变 2&#xff09;.析构函数的重写 5.override 和 final关键字 1&#…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配

目录 一、C 内存的基本概念​ 1.1 内存的物理与逻辑结构​ 1.2 C 程序的内存区域划分​ 二、栈内存分配​ 2.1 栈内存的特点​ 2.2 栈内存分配示例​ 三、堆内存分配​ 3.1 new和delete操作符​ 4.2 内存泄漏与悬空指针问题​ 4.3 new和delete的重载​ 四、智能指针…...