当前位置: 首页 > news >正文

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型
采用的数据是kaggle的房价预测数据
涉及的数据文件,提取码为:zxcv

#导入相关包
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

首先读取数据

train=pd.read_csv("path",encoding="gbk")
test=pd.read_csv("path",encoding="gbk")

数据预处理
数据预处理包括:对数据进行标准化处理,将非数字类型数据转化为数字类型数据

#查看训练集和测试集的数据大小
train.shape
#(1460, 81)
test.shape
#(1460, 81)
#将需要训练和测试的特征合成为一个DataFrame,保证对于训练数据和测试数据的处理是一致的。
all_feature=pd.concat([train.iloc[:,1:80],test.iloc[]])
#找出数字类型的数据,进行标准化处理
num_da=[i for i in all_feature.columns if all_feature[i].dtypes!='object']
all_feature[num_da]=all_feature[num_da].apply(lambda x: (x-x.mean())/x.std())
#将非数字类型数据转化为数字类型的数据
all_feature=pd.get_dummies(all_feature,dummy_na=True)
#对缺失值用所在列的均值进行填充
all_feature = all_feature.fillna(all_feature.mean())
#将训练数据的y取log值
train_y=train['SalePrice'].apply(lambda x:np.log(x))

将数据划分为训练数据集,和测试数据集:做好数据的预处理转化工作之后

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train = torch.tensor(X_train.values, dtype=torch.float32)
y_train = torch.tensor(y_train.values, dtype=torch.long)
X_test = torch.tensor(X_test.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.long)```数据转换```python
#我们需要将dataframe数据转化为神经网络能处理的tensor数据
input_x=torch.tensor(all_feature.iloc[:train.shape[0],:].values.astype(np.float32))
input_y=torch.tensor(train_y.values.astype(np.float32))
test_x=torch.tensor(all_feature.iloc[train.shape[0]:,:].values.astype(np.float32))
#查看各类数据的维度是否符合要求
input_x
input_y
test_x

模型搭建

input_size = input_x.shape[1] #样本个数
hidden1_size = 128 # 隐含层神经元个数
hidden2_size=256
output_size = 1
batch_size = 16
my_nn = torch.nn.Sequential(torch.nn.Linear(input_size, hidden1_size), #全连接层torch.nn.ReLU(),  #激活函数torch.nn.Linear(hidden1_size, hidden2_size), #全连接层torch.nn.ReLU(), #激活函数torch.nn.Linear(hidden2_size, output_size),
)
# MSE损失函数
cost = torch.nn.MSELoss(reduction='mean')
# Adam优化器
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)

模型训练

losses = []
for i in range(500):batch_loss = []# MINI-Batch方法来进行训练for start in range(0, len(input_x), batch_size):end = start + batch_size if start + batch_size < len(input_x) else len(input_x)xx = torch.tensor(input_x[start:end],dtype=torch.float,requires_grad=True)yy = torch.tensor(input_y[start:end],dtype=torch.float,requires_grad=True)prediction = my_nn(xx)# 前向传播loss = cost(prediction, yy) # 计算损失optimizer.zero_grad() # 梯度清零loss.backward(retain_graph=True) #反向传播optimizer.step() # 更新参数batch_loss.append(loss.data.numpy()) #记录损失,便于打印# 打印损失if i % 100 == 0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))

模型保存

torch.save(my_nn,'model.pth')
print("Saved PyTorch Model to model.pth")

模型预测

#加载模型
model =torch.load("model.pth")
#不改变模型参数的基础上进行预测
pred=model(torch.tensor(test_x)).detach()
#对预测后的结果进行还原(之前取了对数)
pred=np.exp(pred)

保存结果

#对测试的结果进行保存
test['SalePrice']=pred.reshape(1,-1)[0]
sub=pd.concat([test["Id"],test["SalePrice"]],axis=1)
sub=sub.set_index("Id")
sub.to_csv("sub.csv")

相关文章:

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型 采用的数据是kaggle的房价预测数据 涉及的数据文件&#xff0c;提取码为&#xff1a;zxcv #导入相关包 import pandas as pd import numpy as np import torch import torch.nn as nn首先读取数据 trainpd.read_csv("path",enc…...

Pytorch深度学习-----DataLoader的用法

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用&#xff08;ToTensor&#xff0c;Normalize&#xff0c;Resize &#xff0c;Co…...

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载 本站下载的 macOS 软件包&#xff0c;既可以拖拽到 Applications&#xff08;应用程序&#xff09;下直接安装&#xff0c;也可以制作启动 U 盘安装&#xff0c;或者在虚拟机中启动安装。另外也支持在 Windows 和 Lin…...

【机器学习】 奇异值分解 (SVD) 和主成分分析 (PCA)

一、说明 在机器学习 &#xff08;ML&#xff09; 中&#xff0c;一些最重要的线性代数概念是奇异值分解 &#xff08;SVD&#xff09; 和主成分分析 &#xff08;PCA&#xff09;。收集到所有原始数据后&#xff0c;我们如何发现结构&#xff1f;例如&#xff0c;通过过去 6 天…...

如何用logging记录python实验结果?

做python实验有时候需要打印很多信息在控制台(console&#xff09;&#xff0c;但是控制台的信息不方便回顾和保存&#xff0c;故而可以采用logging将信息存储起来。 先新建一个文件message.log代码如下&#xff1a; import logging logging.basicConfig(filename"messa…...

C语言假期作业 DAY 03

目录 题目 一、选择题 1、已知函数的原型是&#xff1a; int fun(char b[10], int *a); &#xff0c;设定义&#xff1a; char c[10];int d; &#xff0c;正确的调用语句是&#xff08; &#xff09; 2、请问下列表达式哪些会被编译器禁止【多选】&#xff08; &#xff09; 3、…...

使用serverless实现从oss下载文件并压缩

公司之前开发一个网盘系统, 可以上传文件, 打包压缩下载文件, 但是在处理大文件的时候, 服务器遇到了性能问题, 主要是这个项目是单机部署.......(离谱), 然后带宽只有100M, 现在用户比之前多很多, 然后所有人的压缩下载请求都给到这一台服务器了, 比如多个人下载的时候带宽问…...

从上到下打印二叉树

题目描述 从上到下打印出二叉树的每个节点&#xff0c;同一层的节点按照从左到右的顺序打印。 例如: 给定二叉树: [3,9,20,null,null,15,7], 返回&#xff1a; [3,9,20,15,7] 算法思想 建立一个vector数组ret用来当做返回的结果数组&#xff0c;建立一个队列用来接收二叉树…...

【推荐】排序模型的调优

【推荐】排序模型的调优 排序模型的选择 排序模型常见的训练方式 样本类别不均衡处理尝试 欠拟合 过拟合 其他问题 排序模型的选择 LR&#xff0c;GBDT&#xff0c;LRGBDT&#xff0c;FM/FFM&#xff0c; 深度模型&#xff08;wide & deep&#xff0c;DeepFM&#x…...

负载均衡安装配置详解

负载均衡&#xff08;Load Balancing&#xff09;是一种将网络流量分布到多个服务器上的技术&#xff0c;以提高系统的性能、可靠性和可扩展性。 在负载均衡中&#xff0c;有一个负载均衡器&#xff08;Load Balancer&#xff09;&#xff0c;它充当了传入请求的前置接收器。当…...

Java-逻辑控制

目录 一、顺序结构 二、分支结构 1.if语句 2.swich语句 三、循环结构 1.while循环 2.break 3.continue 4.for循环 5.do while循环 四、输入输出 1.输出到控制台 2.从键盘输入 一、顺序结构 按照代码的书写结构一行一行执行。 System.out.println("aaa"); …...

UE 透明渲染次序

附加顺序 用最外面的球, 依次附加里面的球 最后附加的物体优先级最高 附加顺序 用最里面的球, 依次附加外面的球 这样渲染顺序就对了...

【C++】多态原理剖析,Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout

author&#xff1a;&Carlton tag&#xff1a;C topic&#xff1a;【C】多态原理剖析&#xff0c;Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout website:黑马程序员C tool&#xff1a;Visual Studio 2019 date&#xff1a;2023年7月24日 目…...

vue实现flv格式视频播放

公司项目需要实现摄像头实时视频播放&#xff0c;flv格式的视频。先百度使用flv.js插件实现&#xff0c;但是两个摄像头一个能放一个不能放&#xff0c;没有找到原因。&#xff08;开始两个都能放&#xff0c;后端更改地址后不有一个不能放&#xff09;但是在另一个系统上是可以…...

iptables安全技术和防火墙

防火墙&#xff1a;隔离功能 位置&#xff1a;部署在网络边缘或主机边缘&#xff0c;在工作中&#xff0c;防火墙的主要作用是决定哪些数据可以被外网访问以及哪些数据可以进入内网访问&#xff0c;主要在网络层工作 其他类型的安全技术&#xff1a;1、入侵检测系统 2、入侵…...

微信小程序开发5

一、自定义组件-插槽 1.1、什么是插槽 在自定义组件的wxml结构中&#xff0c;可以提供一个<slot>节点(插槽)&#xff0c;用于承载组件使用者提供的wxml结构 1.2、单个插槽 在小程序中&#xff0c;默认每个自定义组件中允许使用一个<slot>进行占位&#xff0c;这种…...

【算法题】2681. 英雄的力量

题目&#xff1a; 给你一个下标从 0 开始的整数数组 nums &#xff0c;它表示英雄的能力值。如果我们选出一部分英雄&#xff0c;这组英雄的 力量 定义为&#xff1a; i0 &#xff0c;i1 &#xff0c;… ik 表示这组英雄在数组中的下标。那么这组英雄的力量为 max(nums[i0],n…...

fastutil简单测试下性能

前言 简单测试一下fastutil的实现和Java类库实现的速率。 使用jmh进行测试。 简单解释一下&#xff0c;每轮测试预热2次&#xff0c;每次1s&#xff1b;实测2次&#xff0c;每次1秒。 进行5轮测试。数组大小3种。 package fastutil;import it.unimi.dsi.fastutil.ints.IntArr…...

【FAQ】关于无法判断和区分用户与地图交互手势类型的解决办法

一&#xff0e; 问题描述 当用户通过缩放手势、平移手势、倾斜手势和旋转手势与地图交互&#xff0c;控制地图移动改变其可见区域时&#xff0c;华为地图SDK没有提供直接获取用户手势类型的API。 二&#xff0e; 解决方案 华为地图SDK的地图相机有提供CameraPosition类&…...

腾讯云裸金属服务器CPU型号处理器主频说明

腾讯云裸金属服务器CPU型号是什么&#xff1f;标准型BMSA2裸金属服务器CPU采用AMD EPYC ROME处理器&#xff0c;BMS5实例CPU采用Intel Xeon Cooper Lake处理器&#xff0c;腾讯云服务器网分享落进书房武器CPU型号、处理器主频说明&#xff1a; 裸金属服务器CPU处理器说明 腾讯…...

MongoDB学习和应用(高效的非关系型数据库)

一丶 MongoDB简介 对于社交类软件的功能&#xff0c;我们需要对它的功能特点进行分析&#xff1a; 数据量会随着用户数增大而增大读多写少价值较低非好友看不到其动态信息地理位置的查询… 针对以上特点进行分析各大存储工具&#xff1a; mysql&#xff1a;关系型数据库&am…...

23-Oracle 23 ai 区块链表(Blockchain Table)

小伙伴有没有在金融强合规的领域中遇见&#xff0c;必须要保持数据不可变&#xff0c;管理员都无法修改和留痕的要求。比如医疗的电子病历中&#xff0c;影像检查检验结果不可篡改行的&#xff0c;药品追溯过程中数据只可插入无法删除的特性需求&#xff1b;登录日志、修改日志…...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件

今天呢&#xff0c;博主的学习进度也是步入了Java Mybatis 框架&#xff0c;目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学&#xff0c;希望能对大家有所帮助&#xff0c;也特别欢迎大家指点不足之处&#xff0c;小生很乐意接受正确的建议&…...

基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容

基于 ​UniApp + WebSocket​实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配​微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...

srs linux

下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935&#xff0c;SRS管理页面端口是8080&#xff0c;可…...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 &#xff08;部分有免费额度&#x…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)

本文把滑坡位移序列拆开、筛优质因子&#xff0c;再用 CNN-BiLSTM-Attention 来动态预测每个子序列&#xff0c;最后重构出总位移&#xff0c;预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵&#xff08;S…...

C++ Visual Studio 2017厂商给的源码没有.sln文件 易兆微芯片下载工具加开机动画下载。

1.先用Visual Studio 2017打开Yichip YC31xx loader.vcxproj&#xff0c;再用Visual Studio 2022打开。再保侟就有.sln文件了。 易兆微芯片下载工具加开机动画下载 ExtraDownloadFile1Info.\logo.bin|0|0|10D2000|0 MFC应用兼容CMD 在BOOL CYichipYC31xxloaderDlg::OnIni…...

Angular微前端架构:Module Federation + ngx-build-plus (Webpack)

以下是一个完整的 Angular 微前端示例&#xff0c;其中使用的是 Module Federation 和 npx-build-plus 实现了主应用&#xff08;Shell&#xff09;与子应用&#xff08;Remote&#xff09;的集成。 &#x1f6e0;️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析

Linux 内存管理实战精讲&#xff1a;核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用&#xff0c;还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...