当前位置: 首页 > news >正文

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型
采用的数据是kaggle的房价预测数据
涉及的数据文件,提取码为:zxcv

#导入相关包
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

首先读取数据

train=pd.read_csv("path",encoding="gbk")
test=pd.read_csv("path",encoding="gbk")

数据预处理
数据预处理包括:对数据进行标准化处理,将非数字类型数据转化为数字类型数据

#查看训练集和测试集的数据大小
train.shape
#(1460, 81)
test.shape
#(1460, 81)
#将需要训练和测试的特征合成为一个DataFrame,保证对于训练数据和测试数据的处理是一致的。
all_feature=pd.concat([train.iloc[:,1:80],test.iloc[]])
#找出数字类型的数据,进行标准化处理
num_da=[i for i in all_feature.columns if all_feature[i].dtypes!='object']
all_feature[num_da]=all_feature[num_da].apply(lambda x: (x-x.mean())/x.std())
#将非数字类型数据转化为数字类型的数据
all_feature=pd.get_dummies(all_feature,dummy_na=True)
#对缺失值用所在列的均值进行填充
all_feature = all_feature.fillna(all_feature.mean())
#将训练数据的y取log值
train_y=train['SalePrice'].apply(lambda x:np.log(x))

将数据划分为训练数据集,和测试数据集:做好数据的预处理转化工作之后

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train = torch.tensor(X_train.values, dtype=torch.float32)
y_train = torch.tensor(y_train.values, dtype=torch.long)
X_test = torch.tensor(X_test.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.long)```数据转换```python
#我们需要将dataframe数据转化为神经网络能处理的tensor数据
input_x=torch.tensor(all_feature.iloc[:train.shape[0],:].values.astype(np.float32))
input_y=torch.tensor(train_y.values.astype(np.float32))
test_x=torch.tensor(all_feature.iloc[train.shape[0]:,:].values.astype(np.float32))
#查看各类数据的维度是否符合要求
input_x
input_y
test_x

模型搭建

input_size = input_x.shape[1] #样本个数
hidden1_size = 128 # 隐含层神经元个数
hidden2_size=256
output_size = 1
batch_size = 16
my_nn = torch.nn.Sequential(torch.nn.Linear(input_size, hidden1_size), #全连接层torch.nn.ReLU(),  #激活函数torch.nn.Linear(hidden1_size, hidden2_size), #全连接层torch.nn.ReLU(), #激活函数torch.nn.Linear(hidden2_size, output_size),
)
# MSE损失函数
cost = torch.nn.MSELoss(reduction='mean')
# Adam优化器
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)

模型训练

losses = []
for i in range(500):batch_loss = []# MINI-Batch方法来进行训练for start in range(0, len(input_x), batch_size):end = start + batch_size if start + batch_size < len(input_x) else len(input_x)xx = torch.tensor(input_x[start:end],dtype=torch.float,requires_grad=True)yy = torch.tensor(input_y[start:end],dtype=torch.float,requires_grad=True)prediction = my_nn(xx)# 前向传播loss = cost(prediction, yy) # 计算损失optimizer.zero_grad() # 梯度清零loss.backward(retain_graph=True) #反向传播optimizer.step() # 更新参数batch_loss.append(loss.data.numpy()) #记录损失,便于打印# 打印损失if i % 100 == 0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))

模型保存

torch.save(my_nn,'model.pth')
print("Saved PyTorch Model to model.pth")

模型预测

#加载模型
model =torch.load("model.pth")
#不改变模型参数的基础上进行预测
pred=model(torch.tensor(test_x)).detach()
#对预测后的结果进行还原(之前取了对数)
pred=np.exp(pred)

保存结果

#对测试的结果进行保存
test['SalePrice']=pred.reshape(1,-1)[0]
sub=pd.concat([test["Id"],test["SalePrice"]],axis=1)
sub=sub.set_index("Id")
sub.to_csv("sub.csv")

相关文章:

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型 采用的数据是kaggle的房价预测数据 涉及的数据文件&#xff0c;提取码为&#xff1a;zxcv #导入相关包 import pandas as pd import numpy as np import torch import torch.nn as nn首先读取数据 trainpd.read_csv("path",enc…...

Pytorch深度学习-----DataLoader的用法

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用&#xff08;ToTensor&#xff0c;Normalize&#xff0c;Resize &#xff0c;Co…...

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载 本站下载的 macOS 软件包&#xff0c;既可以拖拽到 Applications&#xff08;应用程序&#xff09;下直接安装&#xff0c;也可以制作启动 U 盘安装&#xff0c;或者在虚拟机中启动安装。另外也支持在 Windows 和 Lin…...

【机器学习】 奇异值分解 (SVD) 和主成分分析 (PCA)

一、说明 在机器学习 &#xff08;ML&#xff09; 中&#xff0c;一些最重要的线性代数概念是奇异值分解 &#xff08;SVD&#xff09; 和主成分分析 &#xff08;PCA&#xff09;。收集到所有原始数据后&#xff0c;我们如何发现结构&#xff1f;例如&#xff0c;通过过去 6 天…...

如何用logging记录python实验结果?

做python实验有时候需要打印很多信息在控制台(console&#xff09;&#xff0c;但是控制台的信息不方便回顾和保存&#xff0c;故而可以采用logging将信息存储起来。 先新建一个文件message.log代码如下&#xff1a; import logging logging.basicConfig(filename"messa…...

C语言假期作业 DAY 03

目录 题目 一、选择题 1、已知函数的原型是&#xff1a; int fun(char b[10], int *a); &#xff0c;设定义&#xff1a; char c[10];int d; &#xff0c;正确的调用语句是&#xff08; &#xff09; 2、请问下列表达式哪些会被编译器禁止【多选】&#xff08; &#xff09; 3、…...

使用serverless实现从oss下载文件并压缩

公司之前开发一个网盘系统, 可以上传文件, 打包压缩下载文件, 但是在处理大文件的时候, 服务器遇到了性能问题, 主要是这个项目是单机部署.......(离谱), 然后带宽只有100M, 现在用户比之前多很多, 然后所有人的压缩下载请求都给到这一台服务器了, 比如多个人下载的时候带宽问…...

从上到下打印二叉树

题目描述 从上到下打印出二叉树的每个节点&#xff0c;同一层的节点按照从左到右的顺序打印。 例如: 给定二叉树: [3,9,20,null,null,15,7], 返回&#xff1a; [3,9,20,15,7] 算法思想 建立一个vector数组ret用来当做返回的结果数组&#xff0c;建立一个队列用来接收二叉树…...

【推荐】排序模型的调优

【推荐】排序模型的调优 排序模型的选择 排序模型常见的训练方式 样本类别不均衡处理尝试 欠拟合 过拟合 其他问题 排序模型的选择 LR&#xff0c;GBDT&#xff0c;LRGBDT&#xff0c;FM/FFM&#xff0c; 深度模型&#xff08;wide & deep&#xff0c;DeepFM&#x…...

负载均衡安装配置详解

负载均衡&#xff08;Load Balancing&#xff09;是一种将网络流量分布到多个服务器上的技术&#xff0c;以提高系统的性能、可靠性和可扩展性。 在负载均衡中&#xff0c;有一个负载均衡器&#xff08;Load Balancer&#xff09;&#xff0c;它充当了传入请求的前置接收器。当…...

Java-逻辑控制

目录 一、顺序结构 二、分支结构 1.if语句 2.swich语句 三、循环结构 1.while循环 2.break 3.continue 4.for循环 5.do while循环 四、输入输出 1.输出到控制台 2.从键盘输入 一、顺序结构 按照代码的书写结构一行一行执行。 System.out.println("aaa"); …...

UE 透明渲染次序

附加顺序 用最外面的球, 依次附加里面的球 最后附加的物体优先级最高 附加顺序 用最里面的球, 依次附加外面的球 这样渲染顺序就对了...

【C++】多态原理剖析,Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout

author&#xff1a;&Carlton tag&#xff1a;C topic&#xff1a;【C】多态原理剖析&#xff0c;Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout website:黑马程序员C tool&#xff1a;Visual Studio 2019 date&#xff1a;2023年7月24日 目…...

vue实现flv格式视频播放

公司项目需要实现摄像头实时视频播放&#xff0c;flv格式的视频。先百度使用flv.js插件实现&#xff0c;但是两个摄像头一个能放一个不能放&#xff0c;没有找到原因。&#xff08;开始两个都能放&#xff0c;后端更改地址后不有一个不能放&#xff09;但是在另一个系统上是可以…...

iptables安全技术和防火墙

防火墙&#xff1a;隔离功能 位置&#xff1a;部署在网络边缘或主机边缘&#xff0c;在工作中&#xff0c;防火墙的主要作用是决定哪些数据可以被外网访问以及哪些数据可以进入内网访问&#xff0c;主要在网络层工作 其他类型的安全技术&#xff1a;1、入侵检测系统 2、入侵…...

微信小程序开发5

一、自定义组件-插槽 1.1、什么是插槽 在自定义组件的wxml结构中&#xff0c;可以提供一个<slot>节点(插槽)&#xff0c;用于承载组件使用者提供的wxml结构 1.2、单个插槽 在小程序中&#xff0c;默认每个自定义组件中允许使用一个<slot>进行占位&#xff0c;这种…...

【算法题】2681. 英雄的力量

题目&#xff1a; 给你一个下标从 0 开始的整数数组 nums &#xff0c;它表示英雄的能力值。如果我们选出一部分英雄&#xff0c;这组英雄的 力量 定义为&#xff1a; i0 &#xff0c;i1 &#xff0c;… ik 表示这组英雄在数组中的下标。那么这组英雄的力量为 max(nums[i0],n…...

fastutil简单测试下性能

前言 简单测试一下fastutil的实现和Java类库实现的速率。 使用jmh进行测试。 简单解释一下&#xff0c;每轮测试预热2次&#xff0c;每次1s&#xff1b;实测2次&#xff0c;每次1秒。 进行5轮测试。数组大小3种。 package fastutil;import it.unimi.dsi.fastutil.ints.IntArr…...

【FAQ】关于无法判断和区分用户与地图交互手势类型的解决办法

一&#xff0e; 问题描述 当用户通过缩放手势、平移手势、倾斜手势和旋转手势与地图交互&#xff0c;控制地图移动改变其可见区域时&#xff0c;华为地图SDK没有提供直接获取用户手势类型的API。 二&#xff0e; 解决方案 华为地图SDK的地图相机有提供CameraPosition类&…...

腾讯云裸金属服务器CPU型号处理器主频说明

腾讯云裸金属服务器CPU型号是什么&#xff1f;标准型BMSA2裸金属服务器CPU采用AMD EPYC ROME处理器&#xff0c;BMS5实例CPU采用Intel Xeon Cooper Lake处理器&#xff0c;腾讯云服务器网分享落进书房武器CPU型号、处理器主频说明&#xff1a; 裸金属服务器CPU处理器说明 腾讯…...

OpenLayers 可视化之热力图

注&#xff1a;当前使用的是 ol 5.3.0 版本&#xff0c;天地图使用的key请到天地图官网申请&#xff0c;并替换为自己的key 热力图&#xff08;Heatmap&#xff09;又叫热点图&#xff0c;是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...

Unity3D中Gfx.WaitForPresent优化方案

前言 在Unity中&#xff0c;Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染&#xff08;即CPU被阻塞&#xff09;&#xff0c;这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案&#xff1a; 对惹&#xff0c;这里有一个游戏开发交流小组&…...

SciencePlots——绘制论文中的图片

文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了&#xff1a;一行…...

大型活动交通拥堵治理的视觉算法应用

大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动&#xff08;如演唱会、马拉松赛事、高考中考等&#xff09;期间&#xff0c;城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例&#xff0c;暖城商圈曾因观众集中离场导致周边…...

java 实现excel文件转pdf | 无水印 | 无限制

文章目录 目录 文章目录 前言 1.项目远程仓库配置 2.pom文件引入相关依赖 3.代码破解 二、Excel转PDF 1.代码实现 2.Aspose.License.xml 授权文件 总结 前言 java处理excel转pdf一直没找到什么好用的免费jar包工具,自己手写的难度,恐怕高级程序员花费一年的事件,也…...

UE5 学习系列(三)创建和移动物体

这篇博客是该系列的第三篇&#xff0c;是在之前两篇博客的基础上展开&#xff0c;主要介绍如何在操作界面中创建和拖动物体&#xff0c;这篇博客跟随的视频链接如下&#xff1a; B 站视频&#xff1a;s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...

微信小程序 - 手机震动

一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注&#xff1a;文档 https://developers.weixin.qq…...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成&#xff0c;核心是利用 HTTP 协议的 Range 请求头指定下载范围&#xff1a; 实现原理 Range 请求头&#xff1a;向服务器请求文件的特定字节范围&#xff08;如 Range: bytes1024-&#xff09; 本地文件记录&#xff1a;保存已…...

网站指纹识别

网站指纹识别 网站的最基本组成&#xff1a;服务器&#xff08;操作系统&#xff09;、中间件&#xff08;web容器&#xff09;、脚本语言、数据厍 为什么要了解这些&#xff1f;举个例子&#xff1a;发现了一个文件读取漏洞&#xff0c;我们需要读/etc/passwd&#xff0c;如…...

Kafka入门-生产者

生产者 生产者发送流程&#xff1a; 延迟时间为0ms时&#xff0c;也就意味着每当有数据就会直接发送 异步发送API 异步发送和同步发送的不同在于&#xff1a;异步发送不需要等待结果&#xff0c;同步发送必须等待结果才能进行下一步发送。 普通异步发送 首先导入所需的k…...