当前位置: 首页 > news >正文

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型
采用的数据是kaggle的房价预测数据
涉及的数据文件,提取码为:zxcv

#导入相关包
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

首先读取数据

train=pd.read_csv("path",encoding="gbk")
test=pd.read_csv("path",encoding="gbk")

数据预处理
数据预处理包括:对数据进行标准化处理,将非数字类型数据转化为数字类型数据

#查看训练集和测试集的数据大小
train.shape
#(1460, 81)
test.shape
#(1460, 81)
#将需要训练和测试的特征合成为一个DataFrame,保证对于训练数据和测试数据的处理是一致的。
all_feature=pd.concat([train.iloc[:,1:80],test.iloc[]])
#找出数字类型的数据,进行标准化处理
num_da=[i for i in all_feature.columns if all_feature[i].dtypes!='object']
all_feature[num_da]=all_feature[num_da].apply(lambda x: (x-x.mean())/x.std())
#将非数字类型数据转化为数字类型的数据
all_feature=pd.get_dummies(all_feature,dummy_na=True)
#对缺失值用所在列的均值进行填充
all_feature = all_feature.fillna(all_feature.mean())
#将训练数据的y取log值
train_y=train['SalePrice'].apply(lambda x:np.log(x))

将数据划分为训练数据集,和测试数据集:做好数据的预处理转化工作之后

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train = torch.tensor(X_train.values, dtype=torch.float32)
y_train = torch.tensor(y_train.values, dtype=torch.long)
X_test = torch.tensor(X_test.values, dtype=torch.float32)
y_test = torch.tensor(y_test.values, dtype=torch.long)```数据转换```python
#我们需要将dataframe数据转化为神经网络能处理的tensor数据
input_x=torch.tensor(all_feature.iloc[:train.shape[0],:].values.astype(np.float32))
input_y=torch.tensor(train_y.values.astype(np.float32))
test_x=torch.tensor(all_feature.iloc[train.shape[0]:,:].values.astype(np.float32))
#查看各类数据的维度是否符合要求
input_x
input_y
test_x

模型搭建

input_size = input_x.shape[1] #样本个数
hidden1_size = 128 # 隐含层神经元个数
hidden2_size=256
output_size = 1
batch_size = 16
my_nn = torch.nn.Sequential(torch.nn.Linear(input_size, hidden1_size), #全连接层torch.nn.ReLU(),  #激活函数torch.nn.Linear(hidden1_size, hidden2_size), #全连接层torch.nn.ReLU(), #激活函数torch.nn.Linear(hidden2_size, output_size),
)
# MSE损失函数
cost = torch.nn.MSELoss(reduction='mean')
# Adam优化器
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)

模型训练

losses = []
for i in range(500):batch_loss = []# MINI-Batch方法来进行训练for start in range(0, len(input_x), batch_size):end = start + batch_size if start + batch_size < len(input_x) else len(input_x)xx = torch.tensor(input_x[start:end],dtype=torch.float,requires_grad=True)yy = torch.tensor(input_y[start:end],dtype=torch.float,requires_grad=True)prediction = my_nn(xx)# 前向传播loss = cost(prediction, yy) # 计算损失optimizer.zero_grad() # 梯度清零loss.backward(retain_graph=True) #反向传播optimizer.step() # 更新参数batch_loss.append(loss.data.numpy()) #记录损失,便于打印# 打印损失if i % 100 == 0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))

模型保存

torch.save(my_nn,'model.pth')
print("Saved PyTorch Model to model.pth")

模型预测

#加载模型
model =torch.load("model.pth")
#不改变模型参数的基础上进行预测
pred=model(torch.tensor(test_x)).detach()
#对预测后的结果进行还原(之前取了对数)
pred=np.exp(pred)

保存结果

#对测试的结果进行保存
test['SalePrice']=pred.reshape(1,-1)[0]
sub=pd.concat([test["Id"],test["SalePrice"]],axis=1)
sub=sub.set_index("Id")
sub.to_csv("sub.csv")

相关文章:

如何快速从csv文件搭建一个简单的神经网络模型(回归)

快速搭建一个简单的神经网络预测模型 采用的数据是kaggle的房价预测数据 涉及的数据文件&#xff0c;提取码为&#xff1a;zxcv #导入相关包 import pandas as pd import numpy as np import torch import torch.nn as nn首先读取数据 trainpd.read_csv("path",enc…...

Pytorch深度学习-----DataLoader的用法

系列文章目录 PyTorch深度学习——Anaconda和PyTorch安装 Pytorch深度学习-----数据模块Dataset类 Pytorch深度学习------TensorBoard的使用 Pytorch深度学习------Torchvision中Transforms的使用&#xff08;ToTensor&#xff0c;Normalize&#xff0c;Resize &#xff0c;Co…...

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载

macOS Ventura 13.5 (22G74) Boot ISO 原版可引导镜像下载 本站下载的 macOS 软件包&#xff0c;既可以拖拽到 Applications&#xff08;应用程序&#xff09;下直接安装&#xff0c;也可以制作启动 U 盘安装&#xff0c;或者在虚拟机中启动安装。另外也支持在 Windows 和 Lin…...

【机器学习】 奇异值分解 (SVD) 和主成分分析 (PCA)

一、说明 在机器学习 &#xff08;ML&#xff09; 中&#xff0c;一些最重要的线性代数概念是奇异值分解 &#xff08;SVD&#xff09; 和主成分分析 &#xff08;PCA&#xff09;。收集到所有原始数据后&#xff0c;我们如何发现结构&#xff1f;例如&#xff0c;通过过去 6 天…...

如何用logging记录python实验结果?

做python实验有时候需要打印很多信息在控制台(console&#xff09;&#xff0c;但是控制台的信息不方便回顾和保存&#xff0c;故而可以采用logging将信息存储起来。 先新建一个文件message.log代码如下&#xff1a; import logging logging.basicConfig(filename"messa…...

C语言假期作业 DAY 03

目录 题目 一、选择题 1、已知函数的原型是&#xff1a; int fun(char b[10], int *a); &#xff0c;设定义&#xff1a; char c[10];int d; &#xff0c;正确的调用语句是&#xff08; &#xff09; 2、请问下列表达式哪些会被编译器禁止【多选】&#xff08; &#xff09; 3、…...

使用serverless实现从oss下载文件并压缩

公司之前开发一个网盘系统, 可以上传文件, 打包压缩下载文件, 但是在处理大文件的时候, 服务器遇到了性能问题, 主要是这个项目是单机部署.......(离谱), 然后带宽只有100M, 现在用户比之前多很多, 然后所有人的压缩下载请求都给到这一台服务器了, 比如多个人下载的时候带宽问…...

从上到下打印二叉树

题目描述 从上到下打印出二叉树的每个节点&#xff0c;同一层的节点按照从左到右的顺序打印。 例如: 给定二叉树: [3,9,20,null,null,15,7], 返回&#xff1a; [3,9,20,15,7] 算法思想 建立一个vector数组ret用来当做返回的结果数组&#xff0c;建立一个队列用来接收二叉树…...

【推荐】排序模型的调优

【推荐】排序模型的调优 排序模型的选择 排序模型常见的训练方式 样本类别不均衡处理尝试 欠拟合 过拟合 其他问题 排序模型的选择 LR&#xff0c;GBDT&#xff0c;LRGBDT&#xff0c;FM/FFM&#xff0c; 深度模型&#xff08;wide & deep&#xff0c;DeepFM&#x…...

负载均衡安装配置详解

负载均衡&#xff08;Load Balancing&#xff09;是一种将网络流量分布到多个服务器上的技术&#xff0c;以提高系统的性能、可靠性和可扩展性。 在负载均衡中&#xff0c;有一个负载均衡器&#xff08;Load Balancer&#xff09;&#xff0c;它充当了传入请求的前置接收器。当…...

Java-逻辑控制

目录 一、顺序结构 二、分支结构 1.if语句 2.swich语句 三、循环结构 1.while循环 2.break 3.continue 4.for循环 5.do while循环 四、输入输出 1.输出到控制台 2.从键盘输入 一、顺序结构 按照代码的书写结构一行一行执行。 System.out.println("aaa"); …...

UE 透明渲染次序

附加顺序 用最外面的球, 依次附加里面的球 最后附加的物体优先级最高 附加顺序 用最里面的球, 依次附加外面的球 这样渲染顺序就对了...

【C++】多态原理剖析,Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout

author&#xff1a;&Carlton tag&#xff1a;C topic&#xff1a;【C】多态原理剖析&#xff0c;Visual Studio开发人员工具使用查看类结构cl /d1 reportSingleClassLayout website:黑马程序员C tool&#xff1a;Visual Studio 2019 date&#xff1a;2023年7月24日 目…...

vue实现flv格式视频播放

公司项目需要实现摄像头实时视频播放&#xff0c;flv格式的视频。先百度使用flv.js插件实现&#xff0c;但是两个摄像头一个能放一个不能放&#xff0c;没有找到原因。&#xff08;开始两个都能放&#xff0c;后端更改地址后不有一个不能放&#xff09;但是在另一个系统上是可以…...

iptables安全技术和防火墙

防火墙&#xff1a;隔离功能 位置&#xff1a;部署在网络边缘或主机边缘&#xff0c;在工作中&#xff0c;防火墙的主要作用是决定哪些数据可以被外网访问以及哪些数据可以进入内网访问&#xff0c;主要在网络层工作 其他类型的安全技术&#xff1a;1、入侵检测系统 2、入侵…...

微信小程序开发5

一、自定义组件-插槽 1.1、什么是插槽 在自定义组件的wxml结构中&#xff0c;可以提供一个<slot>节点(插槽)&#xff0c;用于承载组件使用者提供的wxml结构 1.2、单个插槽 在小程序中&#xff0c;默认每个自定义组件中允许使用一个<slot>进行占位&#xff0c;这种…...

【算法题】2681. 英雄的力量

题目&#xff1a; 给你一个下标从 0 开始的整数数组 nums &#xff0c;它表示英雄的能力值。如果我们选出一部分英雄&#xff0c;这组英雄的 力量 定义为&#xff1a; i0 &#xff0c;i1 &#xff0c;… ik 表示这组英雄在数组中的下标。那么这组英雄的力量为 max(nums[i0],n…...

fastutil简单测试下性能

前言 简单测试一下fastutil的实现和Java类库实现的速率。 使用jmh进行测试。 简单解释一下&#xff0c;每轮测试预热2次&#xff0c;每次1s&#xff1b;实测2次&#xff0c;每次1秒。 进行5轮测试。数组大小3种。 package fastutil;import it.unimi.dsi.fastutil.ints.IntArr…...

【FAQ】关于无法判断和区分用户与地图交互手势类型的解决办法

一&#xff0e; 问题描述 当用户通过缩放手势、平移手势、倾斜手势和旋转手势与地图交互&#xff0c;控制地图移动改变其可见区域时&#xff0c;华为地图SDK没有提供直接获取用户手势类型的API。 二&#xff0e; 解决方案 华为地图SDK的地图相机有提供CameraPosition类&…...

腾讯云裸金属服务器CPU型号处理器主频说明

腾讯云裸金属服务器CPU型号是什么&#xff1f;标准型BMSA2裸金属服务器CPU采用AMD EPYC ROME处理器&#xff0c;BMS5实例CPU采用Intel Xeon Cooper Lake处理器&#xff0c;腾讯云服务器网分享落进书房武器CPU型号、处理器主频说明&#xff1a; 裸金属服务器CPU处理器说明 腾讯…...

【大模型RAG】Docker 一键部署 Milvus 完整攻略

本文概要 Milvus 2.5 Stand-alone 版可通过 Docker 在几分钟内完成安装&#xff1b;只需暴露 19530&#xff08;gRPC&#xff09;与 9091&#xff08;HTTP/WebUI&#xff09;两个端口&#xff0c;即可让本地电脑通过 PyMilvus 或浏览器访问远程 Linux 服务器上的 Milvus。下面…...

Golang dig框架与GraphQL的完美结合

将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用&#xff0c;可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器&#xff0c;能够帮助开发者更好地管理复杂的依赖关系&#xff0c;而 GraphQL 则是一种用于 API 的查询语言&#xff0c;能够提…...

c++ 面试题(1)-----深度优先搜索(DFS)实现

操作系统&#xff1a;ubuntu22.04 IDE:Visual Studio Code 编程语言&#xff1a;C11 题目描述 地上有一个 m 行 n 列的方格&#xff0c;从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子&#xff0c;但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…...

C# 类和继承(抽象类)

抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...

select、poll、epoll 与 Reactor 模式

在高并发网络编程领域&#xff0c;高效处理大量连接和 I/O 事件是系统性能的关键。select、poll、epoll 作为 I/O 多路复用技术的代表&#xff0c;以及基于它们实现的 Reactor 模式&#xff0c;为开发者提供了强大的工具。本文将深入探讨这些技术的底层原理、优缺点。​ 一、I…...

Device Mapper 机制

Device Mapper 机制详解 Device Mapper&#xff08;简称 DM&#xff09;是 Linux 内核中的一套通用块设备映射框架&#xff0c;为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程&#xff0c;并配以详细的…...

处理vxe-table 表尾数据是单独一个接口,表格tableData数据更新后,需要点击两下,表尾才是正确的

修改bug思路&#xff1a; 分别把 tabledata 和 表尾相关数据 console.log() 发现 更新数据先后顺序不对 settimeout延迟查询表格接口 ——测试可行 升级↑&#xff1a;async await 等接口返回后再开始下一个接口查询 ________________________________________________________…...

解析两阶段提交与三阶段提交的核心差异及MySQL实现方案

引言 在分布式系统的事务处理中&#xff0c;如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议&#xff08;2PC&#xff09;通过准备阶段与提交阶段的协调机制&#xff0c;以同步决策模式确保事务原子性。其改进版本三阶段提交协议&#xff08;3PC&#xf…...

Matlab实现任意伪彩色图像可视化显示

Matlab实现任意伪彩色图像可视化显示 1、灰度原始图像2、RGB彩色原始图像 在科研研究中&#xff0c;如何展示好看的实验结果图像非常重要&#xff01;&#xff01;&#xff01; 1、灰度原始图像 灰度图像每个像素点只有一个数值&#xff0c;代表该点的​​亮度&#xff08;或…...

python读取SQLite表个并生成pdf文件

代码用于创建含50列的SQLite数据库并插入500行随机浮点数据&#xff0c;随后读取数据&#xff0c;通过ReportLab生成横向PDF表格&#xff0c;包含格式化&#xff08;两位小数&#xff09;及表头、网格线等美观样式。 # 导入所需库 import sqlite3 # 用于操作…...