当前位置: 首页 > news >正文

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型
采用的数据是kaggle的房价预测数据
涉及的数据文件,提取码为:zxcv

#导入相关包
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

首先读取数据

train=pd.read_csv("path",encoding="gbk")
test=pd.read_csv("path",encoding="gbk")

数据预处理
数据预处理包括:对数据进行标准化处理,将非数字类型数据转化为数字类型数据

#查看训练集和测试集的数据大小
train.shape
#(1460, 81)
test.shape
#(1460, 81)
#将需要训练和测试的特征合成为一个DataFrame,保证对于训练数据和测试数据的处理是一致的。
all_feature=pd.concat([train.iloc[:,1:80],test.iloc[]])
#找出数字类型的数据,进行标准化处理
num_da=[i for i in all_feature.columns if all_feature[i].dtypes!='object']
all_feature[num_da]=all_feature[num_da].apply(lambda x: (x-x.mean())/x.std())
#将非数字类型数据转化为数字类型的数据
all_feature=pd.get_dummies(all_feature,dummy_na=True)
#对缺失值用所在列的均值进行填充
all_feature = all_feature.fillna(all_feature.mean())
#将训练数据的y取log值
train_y=train['SalePrice'].apply(lambda x:np.log(x))

将数据划分为训练数据集,和测试数据集:做好数据的预处理转化工作之后

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train = torch.tensor(X_train.values, dtype=torch.float32)
y_train = torch.tensor(y_train.values, dtype=torch.long)
X_test = torch.tensor(X_test.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.long)```数据转换```python
#我们需要将dataframe数据转化为神经网络能处理的tensor数据
input_x=torch.tensor(all_feature.iloc[:train.shape[0],:].values.astype(np.float32))
input_y=torch.tensor(train_y.values.astype(np.float32))
test_x=torch.tensor(all_feature.iloc[train.shape[0]:,:].values.astype(np.float32))
#查看各类数据的维度是否符合要求
input_x
input_y
test_x

模型搭建

input_size = input_x.shape[1] #样本个数
hidden1_size = 128 # 隐含层神经元个数
hidden2_size=256
output_size = 1
batch_size = 16
my_nn = torch.nn.Sequential(torch.nn.Linear(input_size, hidden1_size), #全连接层torch.nn.ReLU(),  #激活函数torch.nn.Linear(hidden1_size, hidden2_size), #全连接层torch.nn.ReLU(), #激活函数torch.nn.Linear(hidden2_size, output_size),
)
# MSE损失函数
cost = torch.nn.MSELoss(reduction='mean')
# Adam优化器
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)

模型训练

losses = []
for i in range(500):batch_loss = []# MINI-Batch方法来进行训练for start in range(0, len(input_x), batch_size):end = start + batch_size if start + batch_size < len(input_x) else len(input_x)xx = torch.tensor(input_x[start:end],dtype=torch.float,requires_grad=True)yy = torch.tensor(input_y[start:end],dtype=torch.float,requires_grad=True)prediction = my_nn(xx)# 前向传播loss = cost(prediction, yy) # 计算损失optimizer.zero_grad() # 梯度清零loss.backward(retain_graph=True) #反向传播optimizer.step() # 更新参数batch_loss.append(loss.data.numpy()) #记录损失,便于打印# 打印损失if i % 100 == 0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))

模型保存

torch.save(my_nn,'model.pth')
print("Saved PyTorch Model to model.pth")

模型预测

#加载模型
model =torch.load("model.pth")
#不改变模型参数的基础上进行预测
pred=model(torch.tensor(test_x)).detach()
#对预测后的结果进行还原(之前取了对数)
pred=np.exp(pred)

保存结果

#对测试的结果进行保存
test['SalePrice']=pred.reshape(1,-1)[0]
sub=pd.concat([test["Id"],test["SalePrice"]],axis=1)
sub=sub.set_index("Id")
sub.to_csv("sub.csv")

相关文章:

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型 采用的数据是kaggle的房价预测数据 涉及的数据文件&#xff0c;提取码为&#xff1a;zxcv #导入相关包 import pandas as pd import numpy as np import torch import torch.nn as nn首先读取数据 trainpd.read_csv("path",enc…...

Pytorch深度学习-----DataLoader的用法

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用&#xff08;ToTensor&#xff0c;Normalize&#xff0c;Resize &#xff0c;Co…...

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载 本站下载的 macOS 软件包&#xff0c;既可以拖拽到 Applications&#xff08;应用程序&#xff09;下直接安装&#xff0c;也可以制作启动 U 盘安装&#xff0c;或者在虚拟机中启动安装。另外也支持在 Windows 和 Lin…...

【机器学习】 奇异值分解 (SVD) 和主成分分析 (PCA)

一、说明 在机器学习 &#xff08;ML&#xff09; 中&#xff0c;一些最重要的线性代数概念是奇异值分解 &#xff08;SVD&#xff09; 和主成分分析 &#xff08;PCA&#xff09;。收集到所有原始数据后&#xff0c;我们如何发现结构&#xff1f;例如&#xff0c;通过过去 6 天…...

如何用logging记录python实验结果?

做python实验有时候需要打印很多信息在控制台(console&#xff09;&#xff0c;但是控制台的信息不方便回顾和保存&#xff0c;故而可以采用logging将信息存储起来。 先新建一个文件message.log代码如下&#xff1a; import logging logging.basicConfig(filename"messa…...

C语言假期作业 DAY 03

目录 题目 一、选择题 1、已知函数的原型是&#xff1a; int fun(char b[10], int *a); &#xff0c;设定义&#xff1a; char c[10];int d; &#xff0c;正确的调用语句是&#xff08; &#xff09; 2、请问下列表达式哪些会被编译器禁止【多选】&#xff08; &#xff09; 3、…...

使用serverless实现从oss下载文件并压缩

公司之前开发一个网盘系统, 可以上传文件, 打包压缩下载文件, 但是在处理大文件的时候, 服务器遇到了性能问题, 主要是这个项目是单机部署.......(离谱), 然后带宽只有100M, 现在用户比之前多很多, 然后所有人的压缩下载请求都给到这一台服务器了, 比如多个人下载的时候带宽问…...

从上到下打印二叉树

题目描述 从上到下打印出二叉树的每个节点&#xff0c;同一层的节点按照从左到右的顺序打印。 例如: 给定二叉树: [3,9,20,null,null,15,7], 返回&#xff1a; [3,9,20,15,7] 算法思想 建立一个vector数组ret用来当做返回的结果数组&#xff0c;建立一个队列用来接收二叉树…...

【推荐】排序模型的调优

【推荐】排序模型的调优 排序模型的选择 排序模型常见的训练方式 样本类别不均衡处理尝试 欠拟合 过拟合 其他问题 排序模型的选择 LR&#xff0c;GBDT&#xff0c;LRGBDT&#xff0c;FM/FFM&#xff0c; 深度模型&#xff08;wide & deep&#xff0c;DeepFM&#x…...

负载均衡安装配置详解

负载均衡&#xff08;Load Balancing&#xff09;是一种将网络流量分布到多个服务器上的技术&#xff0c;以提高系统的性能、可靠性和可扩展性。 在负载均衡中&#xff0c;有一个负载均衡器&#xff08;Load Balancer&#xff09;&#xff0c;它充当了传入请求的前置接收器。当…...

Java-逻辑控制

目录 一、顺序结构 二、分支结构 1.if语句 2.swich语句 三、循环结构 1.while循环 2.break 3.continue 4.for循环 5.do while循环 四、输入输出 1.输出到控制台 2.从键盘输入 一、顺序结构 按照代码的书写结构一行一行执行。 System.out.println("aaa"); …...

UE 透明渲染次序

附加顺序 用最外面的球, 依次附加里面的球 最后附加的物体优先级最高 附加顺序 用最里面的球, 依次附加外面的球 这样渲染顺序就对了...

【C++】多态原理剖析,Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout

author&#xff1a;&Carlton tag&#xff1a;C topic&#xff1a;【C】多态原理剖析&#xff0c;Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout website:黑马程序员C tool&#xff1a;Visual Studio 2019 date&#xff1a;2023年7月24日 目…...

vue实现flv格式视频播放

公司项目需要实现摄像头实时视频播放&#xff0c;flv格式的视频。先百度使用flv.js插件实现&#xff0c;但是两个摄像头一个能放一个不能放&#xff0c;没有找到原因。&#xff08;开始两个都能放&#xff0c;后端更改地址后不有一个不能放&#xff09;但是在另一个系统上是可以…...

iptables安全技术和防火墙

防火墙&#xff1a;隔离功能 位置&#xff1a;部署在网络边缘或主机边缘&#xff0c;在工作中&#xff0c;防火墙的主要作用是决定哪些数据可以被外网访问以及哪些数据可以进入内网访问&#xff0c;主要在网络层工作 其他类型的安全技术&#xff1a;1、入侵检测系统 2、入侵…...

微信小程序开发5

一、自定义组件-插槽 1.1、什么是插槽 在自定义组件的wxml结构中&#xff0c;可以提供一个<slot>节点(插槽)&#xff0c;用于承载组件使用者提供的wxml结构 1.2、单个插槽 在小程序中&#xff0c;默认每个自定义组件中允许使用一个<slot>进行占位&#xff0c;这种…...

【算法题】2681. 英雄的力量

题目&#xff1a; 给你一个下标从 0 开始的整数数组 nums &#xff0c;它表示英雄的能力值。如果我们选出一部分英雄&#xff0c;这组英雄的 力量 定义为&#xff1a; i0 &#xff0c;i1 &#xff0c;… ik 表示这组英雄在数组中的下标。那么这组英雄的力量为 max(nums[i0],n…...

fastutil简单测试下性能

前言 简单测试一下fastutil的实现和Java类库实现的速率。 使用jmh进行测试。 简单解释一下&#xff0c;每轮测试预热2次&#xff0c;每次1s&#xff1b;实测2次&#xff0c;每次1秒。 进行5轮测试。数组大小3种。 package fastutil;import it.unimi.dsi.fastutil.ints.IntArr…...

【FAQ】关于无法判断和区分用户与地图交互手势类型的解决办法

一&#xff0e; 问题描述 当用户通过缩放手势、平移手势、倾斜手势和旋转手势与地图交互&#xff0c;控制地图移动改变其可见区域时&#xff0c;华为地图SDK没有提供直接获取用户手势类型的API。 二&#xff0e; 解决方案 华为地图SDK的地图相机有提供CameraPosition类&…...

腾讯云裸金属服务器CPU型号处理器主频说明

腾讯云裸金属服务器CPU型号是什么&#xff1f;标准型BMSA2裸金属服务器CPU采用AMD EPYC ROME处理器&#xff0c;BMS5实例CPU采用Intel Xeon Cooper Lake处理器&#xff0c;腾讯云服务器网分享落进书房武器CPU型号、处理器主频说明&#xff1a; 裸金属服务器CPU处理器说明 腾讯…...

浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)

✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义&#xff08;Task Definition&…...

国防科技大学计算机基础课程笔记02信息编码

1.机内码和国标码 国标码就是我们非常熟悉的这个GB2312,但是因为都是16进制&#xff0c;因此这个了16进制的数据既可以翻译成为这个机器码&#xff0c;也可以翻译成为这个国标码&#xff0c;所以这个时候很容易会出现这个歧义的情况&#xff1b; 因此&#xff0c;我们的这个国…...

C++:std::is_convertible

C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…...

镜像里切换为普通用户

如果你登录远程虚拟机默认就是 root 用户&#xff0c;但你不希望用 root 权限运行 ns-3&#xff08;这是对的&#xff0c;ns3 工具会拒绝 root&#xff09;&#xff0c;你可以按以下方法创建一个 非 root 用户账号 并切换到它运行 ns-3。 一次性解决方案&#xff1a;创建非 roo…...

Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!

一、引言 在数据驱动的背景下&#xff0c;知识图谱凭借其高效的信息组织能力&#xff0c;正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合&#xff0c;探讨知识图谱开发的实现细节&#xff0c;帮助读者掌握该技术栈在实际项目中的落地方法。 …...

如何在网页里填写 PDF 表格?

有时候&#xff0c;你可能希望用户能在你的网站上填写 PDF 表单。然而&#xff0c;这件事并不简单&#xff0c;因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件&#xff0c;但原生并不支持编辑或填写它们。更糟的是&#xff0c;如果你想收集表单数据&#xff…...

Mobile ALOHA全身模仿学习

一、题目 Mobile ALOHA&#xff1a;通过低成本全身远程操作学习双手移动操作 传统模仿学习&#xff08;Imitation Learning&#xff09;缺点&#xff1a;聚焦与桌面操作&#xff0c;缺乏通用任务所需的移动性和灵活性 本论文优点&#xff1a;&#xff08;1&#xff09;在ALOHA…...

go 里面的指针

指针 在 Go 中&#xff0c;指针&#xff08;pointer&#xff09;是一个变量的内存地址&#xff0c;就像 C 语言那样&#xff1a; a : 10 p : &a // p 是一个指向 a 的指针 fmt.Println(*p) // 输出 10&#xff0c;通过指针解引用• &a 表示获取变量 a 的地址 p 表示…...

uniapp 实现腾讯云IM群文件上传下载功能

UniApp 集成腾讯云IM实现群文件上传下载功能全攻略 一、功能背景与技术选型 在团队协作场景中&#xff0c;群文件共享是核心需求之一。本文将介绍如何基于腾讯云IMCOS&#xff0c;在uniapp中实现&#xff1a; 群内文件上传/下载文件元数据管理下载进度追踪跨平台文件预览 二…...