当前位置: 首页 > news >正文

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型
采用的数据是kaggle的房价预测数据
涉及的数据文件,提取码为:zxcv

#导入相关包
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

首先读取数据

train=pd.read_csv("path",encoding="gbk")
test=pd.read_csv("path",encoding="gbk")

数据预处理
数据预处理包括:对数据进行标准化处理,将非数字类型数据转化为数字类型数据

#查看训练集和测试集的数据大小
train.shape
#(1460, 81)
test.shape
#(1460, 81)
#将需要训练和测试的特征合成为一个DataFrame,保证对于训练数据和测试数据的处理是一致的。
all_feature=pd.concat([train.iloc[:,1:80],test.iloc[]])
#找出数字类型的数据,进行标准化处理
num_da=[i for i in all_feature.columns if all_feature[i].dtypes!='object']
all_feature[num_da]=all_feature[num_da].apply(lambda x: (x-x.mean())/x.std())
#将非数字类型数据转化为数字类型的数据
all_feature=pd.get_dummies(all_feature,dummy_na=True)
#对缺失值用所在列的均值进行填充
all_feature = all_feature.fillna(all_feature.mean())
#将训练数据的y取log值
train_y=train['SalePrice'].apply(lambda x:np.log(x))

将数据划分为训练数据集,和测试数据集:做好数据的预处理转化工作之后

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train = torch.tensor(X_train.values, dtype=torch.float32)
y_train = torch.tensor(y_train.values, dtype=torch.long)
X_test = torch.tensor(X_test.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.long)```数据转换```python
#我们需要将dataframe数据转化为神经网络能处理的tensor数据
input_x=torch.tensor(all_feature.iloc[:train.shape[0],:].values.astype(np.float32))
input_y=torch.tensor(train_y.values.astype(np.float32))
test_x=torch.tensor(all_feature.iloc[train.shape[0]:,:].values.astype(np.float32))
#查看各类数据的维度是否符合要求
input_x
input_y
test_x

模型搭建

input_size = input_x.shape[1] #样本个数
hidden1_size = 128 # 隐含层神经元个数
hidden2_size=256
output_size = 1
batch_size = 16
my_nn = torch.nn.Sequential(torch.nn.Linear(input_size, hidden1_size), #全连接层torch.nn.ReLU(),  #激活函数torch.nn.Linear(hidden1_size, hidden2_size), #全连接层torch.nn.ReLU(), #激活函数torch.nn.Linear(hidden2_size, output_size),
)
# MSE损失函数
cost = torch.nn.MSELoss(reduction='mean')
# Adam优化器
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)

模型训练

losses = []
for i in range(500):batch_loss = []# MINI-Batch方法来进行训练for start in range(0, len(input_x), batch_size):end = start + batch_size if start + batch_size < len(input_x) else len(input_x)xx = torch.tensor(input_x[start:end],dtype=torch.float,requires_grad=True)yy = torch.tensor(input_y[start:end],dtype=torch.float,requires_grad=True)prediction = my_nn(xx)# 前向传播loss = cost(prediction, yy) # 计算损失optimizer.zero_grad() # 梯度清零loss.backward(retain_graph=True) #反向传播optimizer.step() # 更新参数batch_loss.append(loss.data.numpy()) #记录损失,便于打印# 打印损失if i % 100 == 0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))

模型保存

torch.save(my_nn,'model.pth')
print("Saved PyTorch Model to model.pth")

模型预测

#加载模型
model =torch.load("model.pth")
#不改变模型参数的基础上进行预测
pred=model(torch.tensor(test_x)).detach()
#对预测后的结果进行还原(之前取了对数)
pred=np.exp(pred)

保存结果

#对测试的结果进行保存
test['SalePrice']=pred.reshape(1,-1)[0]
sub=pd.concat([test["Id"],test["SalePrice"]],axis=1)
sub=sub.set_index("Id")
sub.to_csv("sub.csv")

相关文章:

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型 采用的数据是kaggle的房价预测数据 涉及的数据文件&#xff0c;提取码为&#xff1a;zxcv #导入相关包 import pandas as pd import numpy as np import torch import torch.nn as nn首先读取数据 trainpd.read_csv("path",enc…...

Pytorch深度学习-----DataLoader的用法

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用&#xff08;ToTensor&#xff0c;Normalize&#xff0c;Resize &#xff0c;Co…...

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载 本站下载的 macOS 软件包&#xff0c;既可以拖拽到 Applications&#xff08;应用程序&#xff09;下直接安装&#xff0c;也可以制作启动 U 盘安装&#xff0c;或者在虚拟机中启动安装。另外也支持在 Windows 和 Lin…...

【机器学习】 奇异值分解 (SVD) 和主成分分析 (PCA)

一、说明 在机器学习 &#xff08;ML&#xff09; 中&#xff0c;一些最重要的线性代数概念是奇异值分解 &#xff08;SVD&#xff09; 和主成分分析 &#xff08;PCA&#xff09;。收集到所有原始数据后&#xff0c;我们如何发现结构&#xff1f;例如&#xff0c;通过过去 6 天…...

如何用logging记录python实验结果?

做python实验有时候需要打印很多信息在控制台(console&#xff09;&#xff0c;但是控制台的信息不方便回顾和保存&#xff0c;故而可以采用logging将信息存储起来。 先新建一个文件message.log代码如下&#xff1a; import logging logging.basicConfig(filename"messa…...

C语言假期作业 DAY 03

目录 题目 一、选择题 1、已知函数的原型是&#xff1a; int fun(char b[10], int *a); &#xff0c;设定义&#xff1a; char c[10];int d; &#xff0c;正确的调用语句是&#xff08; &#xff09; 2、请问下列表达式哪些会被编译器禁止【多选】&#xff08; &#xff09; 3、…...

使用serverless实现从oss下载文件并压缩

公司之前开发一个网盘系统, 可以上传文件, 打包压缩下载文件, 但是在处理大文件的时候, 服务器遇到了性能问题, 主要是这个项目是单机部署.......(离谱), 然后带宽只有100M, 现在用户比之前多很多, 然后所有人的压缩下载请求都给到这一台服务器了, 比如多个人下载的时候带宽问…...

从上到下打印二叉树

题目描述 从上到下打印出二叉树的每个节点&#xff0c;同一层的节点按照从左到右的顺序打印。 例如: 给定二叉树: [3,9,20,null,null,15,7], 返回&#xff1a; [3,9,20,15,7] 算法思想 建立一个vector数组ret用来当做返回的结果数组&#xff0c;建立一个队列用来接收二叉树…...

【推荐】排序模型的调优

【推荐】排序模型的调优 排序模型的选择 排序模型常见的训练方式 样本类别不均衡处理尝试 欠拟合 过拟合 其他问题 排序模型的选择 LR&#xff0c;GBDT&#xff0c;LRGBDT&#xff0c;FM/FFM&#xff0c; 深度模型&#xff08;wide & deep&#xff0c;DeepFM&#x…...

负载均衡安装配置详解

负载均衡&#xff08;Load Balancing&#xff09;是一种将网络流量分布到多个服务器上的技术&#xff0c;以提高系统的性能、可靠性和可扩展性。 在负载均衡中&#xff0c;有一个负载均衡器&#xff08;Load Balancer&#xff09;&#xff0c;它充当了传入请求的前置接收器。当…...

Java-逻辑控制

目录 一、顺序结构 二、分支结构 1.if语句 2.swich语句 三、循环结构 1.while循环 2.break 3.continue 4.for循环 5.do while循环 四、输入输出 1.输出到控制台 2.从键盘输入 一、顺序结构 按照代码的书写结构一行一行执行。 System.out.println("aaa"); …...

UE 透明渲染次序

附加顺序 用最外面的球, 依次附加里面的球 最后附加的物体优先级最高 附加顺序 用最里面的球, 依次附加外面的球 这样渲染顺序就对了...

【C++】多态原理剖析,Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout

author&#xff1a;&Carlton tag&#xff1a;C topic&#xff1a;【C】多态原理剖析&#xff0c;Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout website:黑马程序员C tool&#xff1a;Visual Studio 2019 date&#xff1a;2023年7月24日 目…...

vue实现flv格式视频播放

公司项目需要实现摄像头实时视频播放&#xff0c;flv格式的视频。先百度使用flv.js插件实现&#xff0c;但是两个摄像头一个能放一个不能放&#xff0c;没有找到原因。&#xff08;开始两个都能放&#xff0c;后端更改地址后不有一个不能放&#xff09;但是在另一个系统上是可以…...

iptables安全技术和防火墙

防火墙&#xff1a;隔离功能 位置&#xff1a;部署在网络边缘或主机边缘&#xff0c;在工作中&#xff0c;防火墙的主要作用是决定哪些数据可以被外网访问以及哪些数据可以进入内网访问&#xff0c;主要在网络层工作 其他类型的安全技术&#xff1a;1、入侵检测系统 2、入侵…...

微信小程序开发5

一、自定义组件-插槽 1.1、什么是插槽 在自定义组件的wxml结构中&#xff0c;可以提供一个<slot>节点(插槽)&#xff0c;用于承载组件使用者提供的wxml结构 1.2、单个插槽 在小程序中&#xff0c;默认每个自定义组件中允许使用一个<slot>进行占位&#xff0c;这种…...

【算法题】2681. 英雄的力量

题目&#xff1a; 给你一个下标从 0 开始的整数数组 nums &#xff0c;它表示英雄的能力值。如果我们选出一部分英雄&#xff0c;这组英雄的 力量 定义为&#xff1a; i0 &#xff0c;i1 &#xff0c;… ik 表示这组英雄在数组中的下标。那么这组英雄的力量为 max(nums[i0],n…...

fastutil简单测试下性能

前言 简单测试一下fastutil的实现和Java类库实现的速率。 使用jmh进行测试。 简单解释一下&#xff0c;每轮测试预热2次&#xff0c;每次1s&#xff1b;实测2次&#xff0c;每次1秒。 进行5轮测试。数组大小3种。 package fastutil;import it.unimi.dsi.fastutil.ints.IntArr…...

【FAQ】关于无法判断和区分用户与地图交互手势类型的解决办法

一&#xff0e; 问题描述 当用户通过缩放手势、平移手势、倾斜手势和旋转手势与地图交互&#xff0c;控制地图移动改变其可见区域时&#xff0c;华为地图SDK没有提供直接获取用户手势类型的API。 二&#xff0e; 解决方案 华为地图SDK的地图相机有提供CameraPosition类&…...

腾讯云裸金属服务器CPU型号处理器主频说明

腾讯云裸金属服务器CPU型号是什么&#xff1f;标准型BMSA2裸金属服务器CPU采用AMD EPYC ROME处理器&#xff0c;BMS5实例CPU采用Intel Xeon Cooper Lake处理器&#xff0c;腾讯云服务器网分享落进书房武器CPU型号、处理器主频说明&#xff1a; 裸金属服务器CPU处理器说明 腾讯…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

XML Group端口详解

在XML数据映射过程中&#xff0c;经常需要对数据进行分组聚合操作。例如&#xff0c;当处理包含多个物料明细的XML文件时&#xff0c;可能需要将相同物料号的明细归为一组&#xff0c;或对相同物料号的数量进行求和计算。传统实现方式通常需要编写脚本代码&#xff0c;增加了开…...

手游刚开服就被攻击怎么办?如何防御DDoS?

开服初期是手游最脆弱的阶段&#xff0c;极易成为DDoS攻击的目标。一旦遭遇攻击&#xff0c;可能导致服务器瘫痪、玩家流失&#xff0c;甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案&#xff0c;帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...

日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻

在如今就业市场竞争日益激烈的背景下&#xff0c;越来越多的求职者将目光投向了日本及中日双语岗位。但是&#xff0c;一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧&#xff1f;面对生疏的日语交流环境&#xff0c;即便提前恶补了…...

linux之kylin系统nginx的安装

一、nginx的作用 1.可做高性能的web服务器 直接处理静态资源&#xff08;HTML/CSS/图片等&#xff09;&#xff0c;响应速度远超传统服务器类似apache支持高并发连接 2.反向代理服务器 隐藏后端服务器IP地址&#xff0c;提高安全性 3.负载均衡服务器 支持多种策略分发流量…...

从WWDC看苹果产品发展的规律

WWDC 是苹果公司一年一度面向全球开发者的盛会&#xff0c;其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具&#xff0c;对过去十年 WWDC 主题演讲内容进行了系统化分析&#xff0c;形成了这份…...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中&#xff0c;我们可能会遇到一些流式数据处理的场景&#xff0c;比如接收来自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 内容&#xff0c;并将其原样中转给前端页面或客户端。这种情况下&#xff0c;传统的 RestTemplate 缓存机制会…...

ETLCloud可能遇到的问题有哪些?常见坑位解析

数据集成平台ETLCloud&#xff0c;主要用于支持数据的抽取&#xff08;Extract&#xff09;、转换&#xff08;Transform&#xff09;和加载&#xff08;Load&#xff09;过程。提供了一个简洁直观的界面&#xff0c;以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...

Linux云原生安全:零信任架构与机密计算

Linux云原生安全&#xff1a;零信任架构与机密计算 构建坚不可摧的云原生防御体系 引言&#xff1a;云原生安全的范式革命 随着云原生技术的普及&#xff0c;安全边界正在从传统的网络边界向工作负载内部转移。Gartner预测&#xff0c;到2025年&#xff0c;零信任架构将成为超…...

零基础设计模式——行为型模式 - 责任链模式

第四部分&#xff1a;行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习&#xff01;行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想&#xff1a;使多个对象都有机会处…...