当前位置: 首页 > news >正文

mindspore的MLP模型(多层感知机)

导入模块

import hashlib
import os
import tarfile
import zipfile
import requests
import numpy as np
import pandas as pd
import mindspore
import mindspore.dataset as ds
from mindspore import nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore import Tensor
from IPython import display
from matplotlib import pyplot as plt

数据预处理

数据下载:https://www.kaggle.com/datasets/ahsan81/hotel-reservations-classification-dataset

train_data = pd.read_csv("Hotel Reservations_train.csv")
test_data = pd.read_csv("Hotel Reservations_test.csv")print(train_data.shape)
print(test_data.shape)
(30000, 20)
(6275, 20)
# 可去掉第0列与第1列的信息
print(train_data.iloc[0:4, [0, 1, 2, -3, -2, -1]])
   Unnamed: 0 Booking_ID  no_of_adults  avg_price_per_room  \
0           0   INN00001             2               65.00   
1           1   INN00002             2              106.68   
2           2   INN00003             1               60.00   
3           3   INN00004             2              100.00   no_of_special_requests booking_status  
0                       0   Not_Canceled  
1                       1   Not_Canceled  
2                       0       Canceled  
3                       0       Canceled  
# 将train_data和test_data合并,后面做数据预处理方便
all_features = pd.concat((train_data.iloc[:, 2:-1], test_data.iloc[:, 2:-1]))all_features
no_of_adultsno_of_childrenno_of_weekend_nightsno_of_week_nightstype_of_meal_planrequired_car_parking_spaceroom_type_reservedlead_timearrival_yeararrival_montharrival_datemarket_segment_typerepeated_guestno_of_previous_cancellationsno_of_previous_bookings_not_canceledavg_price_per_roomno_of_special_requests
02012Meal Plan 10Room_Type 12242017102Offline00065.000
12023Not Selected0Room_Type 152018116Online000106.681
21021Meal Plan 10Room_Type 112018228Online00060.000
32002Meal Plan 10Room_Type 12112018520Online000100.000
42011Not Selected0Room_Type 1482018411Online00094.500
......................................................
62703026Meal Plan 10Room_Type 485201883Online000167.801
62712013Meal Plan 10Room_Type 122820181017Online00090.952
62722026Meal Plan 10Room_Type 1148201871Online00098.392
62732003Not Selected0Room_Type 1632018421Online00094.500
62742012Meal Plan 10Room_Type 120720181230Offline000161.670

36275 rows × 17 columns

# 将所有缺失的值替换为相应特征的平均值。 通过将特征重新缩放到零均值和单位方差来标准化数据# 先将为数字类型的列取出来,dtypes[all_features.dtypes != 'object'].index 返回类型是数字的列的索引
numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index
# 之后对其应用apply方法 apply中对每列进行了标准化(Z-score标准化方法)
all_features[numeric_features] = all_features[numeric_features].apply(lambda x: (x - x.mean()) / (x.std()))
# 在标准化数据之后,所有均值消失,因此我们可以将缺失值设置为0
all_features[numeric_features] = all_features[numeric_features].fillna(0)
# 处理离散值。我们用独热编码替换它们
# 独热编码:例如,“MSZoning”包含值“RL”和“Rm”。 我们将创建两个新的指示器特征“MSZoning_RL”和“MSZoning_RM”,其值为0或1。print(all_features.shape)# “Dummy_na=True”将“na”(缺失值)视为有效的特征值,并为其创建指示符特征
all_features = pd.get_dummies(all_features, dummy_na=True)print(all_features.shape)
(36275, 17)
(36275, 33)
all_labels = pd.concat((train_data.iloc[:,-1], test_data.iloc[:, -1]))change = {'Not_Canceled':1,'Canceled':0}
all_labels = all_labels.map(change)
all_labels
0       1
1       1
2       0
3       0
4       0..
6270    1
6271    0
6272    1
6273    0
6274    1
Name: booking_status, Length: 36275, dtype: int64
n_train = train_data.shape[0]         # 提取训练样本数
train_features = all_features[:n_train].values.astype(np.float32)      # 注意要统一数据的类型:np.float32
test_features = all_features[n_train:].values.astype(np.float32)
train_labels = all_labels.iloc[:n_train].values.astype(np.int64)
test_labels = all_labels.iloc[n_train:].values.astype(np.int64)
class SyntheticData():  def __init__(self,features,labels):self.features, self.labels = features , labelsdef __getitem__(self, index):   # __getitem__(self, index) 一般用来迭代序列(常见序列如:列表、元组、字符串)return self.features[index], self.labels[index]def __len__(self):return len(self.labels)
# 数据集
train_dataset= ds.GeneratorDataset(source=SyntheticData(train_features, train_labels), column_names=['features', 'label'],python_multiprocessing=False)test_dataset= ds.GeneratorDataset(source=SyntheticData(test_features, test_labels ), column_names=['features', 'label'],python_multiprocessing=False)

构建模型

class Accumulator:  """累加器"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]
def accuracy(y_hat, y):  """计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:   # 判断y_hat是不是矩阵y_hat = y_hat.argmax(axis=1)                  # 得到每样本预测概率最大所属分类的下标cmp = y_hat.asnumpy() == y.asnumpy()              # y_hat.asnumpy() == y.asnumpy()返回的是一个布尔数组return float(cmp.sum())def evaluate_accuracy(net, data_iter):  """计算在指定数据集上模型的精度"""metric = Accumulator(2)         # 累加器,metric[0]记录正确预测数,metric[1]记录预测总数for X, y in data_iter:metric.add(accuracy(net(X), y), y.size)return metric[0] / metric[1]    # 正确预测数 / 预测总数
def train_epoch( train_iter, learning_rate, weight_decay, batch_size):  """训练模型一个迭代周期"""net = nn.SequentialCell([nn.Dense(all_features.shape[1], 32),nn.ReLU(),nn.Dense(32, 16),nn.ReLU(),nn.Dense(16, 2)]) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')#optim = nn.SGD(net.trainable_params(), learning_rate = learning_rate, weight_decay = weight_decay)optim = nn.Adam(net.trainable_params(), learning_rate = learning_rate, weight_decay = weight_decay) net_with_loss = nn.WithLossCell(net, loss)                net_train = nn.TrainOneStepCell(net_with_loss, optim)     metric = Accumulator(3)for X, y in train_iter:l = net_train(X, y)y_hat = net(X)metric.add(float(l.sum().asnumpy()), accuracy(y_hat, y), y.size)return metric[0] / metric[2], metric[1] / metric[2] ,net      # 误差 / 预测总数 ,正确预测数 / 预测总数
def trainer( train_iter, test_iter, num_epochs, learning_rate, weight_decay, batch_size, train_acc_plot, test_acc_plot):  """训练模型"""train_iter = train_iter.batch(batch_size = batch_size, num_parallel_workers=1)test_iter = test_iter.batch(batch_size = batch_size, num_parallel_workers=1)for epoch in range(num_epochs):train_metrics = train_epoch(train_iter, learning_rate, weight_decay, batch_size)train_loss, train_acc, net = train_metricstest_acc = evaluate_accuracy(net, test_iter)train_acc_plot.append(float(train_acc))test_acc_plot.append(float(test_acc))print('最终训练集精度:', train_acc, '最终测试集精度:',test_acc )# 检测assert train_loss < 0.6, train_lossassert train_acc <= 1 and train_acc > 0.7, train_accassert test_acc <= 1 and test_acc > 0.7, test_acc

训练

num_epochs,  weight_decay, batch_size  =20, 0, 64# 动态学习率
learning_rate = 0.1
end_learning_rate = 0.05
decay_steps = 6
power = 0.5
learning_rate  = nn.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)train_acc_plot=[]
test_acc_plot=[]
trainer( train_dataset, test_dataset, num_epochs, learning_rate, weight_decay, batch_size, train_acc_plot, test_acc_plot)
最终训练集精度: 0.8078666666666666 最终测试集精度: 0.8124302788844622
# 构建loss-step曲线可了解loss随epoch的变化情况plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=Falsex=np.linspace(0, num_epochs-1,num_epochs)plt.figure(figsize=(4,3)) 
plt.xlabel(u"epoch")
plt.ylabel(u"精度")
plt.plot(x, train_acc_plot, label='train acc')
plt.plot(x, test_acc_plot, label='test acc')
plt.legend(loc="best")
plt.tight_layout(rect = [0,0,1,1]) 

在这里插入图片描述

相关文章:

mindspore的MLP模型(多层感知机)

导入模块 import hashlib import os import tarfile import zipfile import requests import numpy as np import pandas as pd import mindspore import mindspore.dataset as ds from mindspore import nn import mindspore.ops as ops import mindspore.numpy as mnp from …...

【论文极速读】VQ-VAE:一种稀疏表征学习方法

【论文极速读】VQ-VAE&#xff1a;一种稀疏表征学习方法 FesianXu 20221208 at Baidu Search Team 前言 最近有需求对特征进行稀疏编码&#xff0c;看到一篇论文VQ-VAE&#xff0c;简单进行笔记下。如有谬误请联系指出&#xff0c;本文遵循 CC 4.0 BY-SA 版权协议&#xff0c;…...

Flask-Blueprint

Flask-Blueprint 一、简介 概念&#xff1a; Blueprint 是一个存储操作方法的容器&#xff0c;这些操作在这个Blueprint 被注册到一个应用之后就可以被调用&#xff0c;Flask 可以通过Blueprint来组织URL以及处理请求 。 好处&#xff1a; 其本质上来说就是让程序更加松耦合…...

png图片转eps格式

下载latex工具后 在要转换的png图片文件夹路径下&#xff0c;打开命令行窗口&#xff0c;输入以下命令&#xff1a; bmeps -c fig图片名.png 图片名.eps...

English Learning - L2 语音作业打卡 Day2 2023.2.23 周四

English Learning - L2 语音作业打卡 Day2 2023.2.23 周四&#x1f48c; 发音小贴士&#xff1a;&#x1f48c; 当日目标音发音规则/技巧&#xff1a;&#x1f36d; Part 1【热身练习】&#x1f36d; Part2【练习内容】&#x1f36d;【练习感受】&#x1f353;元音[ ɔ: ]&…...

低频量化之 可转债 配债 策略数据 - 全网独家

目录历史文章可转债配债数据待发转债&#xff08;进展统计&#xff09;待发转债&#xff08;行业统计&#xff09;待发转债&#xff08;5证监会通过&#xff0c;PE排序&#xff09;待发转债&#xff08;5证监会通过&#xff0c;安全垫排序&#xff09;待发转债&#xff08;4发审…...

论文阅读_DALLE-2的unCLIP模型

论文信息 name_en: Hierarchical Text-Conditional Image Generation with CLIP Latents name_ch: 利用CLIP的层次化文本条件图像生成 paper_addr: http://arxiv.org/abs/2204.06125 doi: 10.48550/arXiv.2204.06125 date_read: 2023-02-12 date_publish: 2022-04-12 tags: [‘…...

软件测试5年,历经3轮面试成功拿下华为Offer,24K/16薪不过分吧

前言 转眼过去&#xff0c;距离读书的时候已经这么久了吗&#xff1f;&#xff0c;从18年5月本科毕业入职了一家小公司&#xff0c;到现在快5年了&#xff0c;前段时间社招想着找一个新的工作&#xff0c;前前后后花了一个多月的时间复习以及面试&#xff0c;前几天拿到了华为的…...

【软件工程】课程作业(三道题目:需求分析、概要设计、详细设计、软件测试)

文章目录&#xff1a;故事的开头总是极尽温柔&#xff0c;故事会一直温柔……&#x1f49c;一、你怎么理解需求分析&#xff1f;1、需求分析的定义&#xff1a;2、需求分析的重要性&#xff1a;3、需求分析的内容&#xff1a;4、基于系统分析的方法分类&#xff1a;5、需求分析…...

05 DC-AC逆变器(DCAC Converter / Inverter)简介

文章目录0、概述逆变原理方波变换阶梯波变换斩控调制方式逆变器分类逆变器波形指标1、方波变换器A 单相单相全桥对称单脉冲调制移相单脉冲调制单相半桥2、方波变换器B 三相180度导通120度导通&#xff08;线、相的关系与180度相反&#xff09;3、阶梯波逆变器独立直流源二极管钳…...

带你深层了解c语言指针

前言 &#x1f388;个人主页:&#x1f388; :✨✨✨初阶牛✨✨✨ &#x1f43b;推荐专栏: &#x1f354;&#x1f35f;&#x1f32f; c语言进阶 &#x1f511;个人信条: &#x1f335;知行合一 &#x1f349;本篇简介:>:介绍c语言中有关指针更深层的知识. 金句分享: ✨今天…...

2-MATLAB APP Design-下拉菜单栏的使用

一、APP 界面设计展示 1.新建一个空白的APP,在此次的学习中,我们会用到编辑字段(文本框)、下拉菜单栏、坐标区,首先在界面中拖入一个编辑字段(文本框),在文本框中输入内容:下拉菜单栏的使用,调整背景颜色,字体的颜色为黑色,字体的大小调为26. 2.在左侧组件库常用栏…...

七、HTTPTomcatServlet

1&#xff0c;Web概述 1.1 Web和JavaWeb的概念 Web是全球广域网&#xff0c;也称为万维网(www)&#xff0c;能够通过浏览器访问的网站。 在我们日常的生活中&#xff0c;经常会使用浏览器去访问百度、京东、传智官网等这些网站&#xff0c;这些网站统称为Web网站。如下就是通…...

LeetCode 热题 C++ 198. 打家劫舍

力扣198 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的房屋在同一晚上被小偷闯入&#xff0c;系统会自动报警。 给定一个代表每个房屋存…...

C语言学习笔记——程序环境和预处理

目录 前言 一、程序环境 1. 翻译环境 1.1 主要过程 1.2 编译过程 2. 运行环境 二、预处理 1. 预定义符号 2. #define 2.1 #define定义标识符 2.2 #define定义宏 2.3 命名约定和移除定义 3. 条件编译 4. 文件包含 结束语 前言 每次我们写完代码运行的时候都…...

「JVM 高效并发」Java 内存模型

Amdahl 定律代替摩尔定律成为了计算机性能发展的新源动力&#xff0c;也是人类压榨计算机运算能力的最有力武器&#xff1b; 摩尔定律&#xff0c;描述处理器晶体管数量与运行效率之间的发展关系&#xff1b;Amdahl 定律&#xff0c;描述系统并行化与串行化的比重与系统运算加…...

C语言刷题(2)——“C”

各位CSDN的uu们你们好呀&#xff0c;今天小雅兰来复习一下之前所学过的内容噢&#xff0c;复习的方式&#xff0c;那当然是刷题啦&#xff0c;现在&#xff0c;就让我们进入C语言的世界吧 当然&#xff0c;题目还是来源于牛客网 完完全全零基础 编程语言初学训练营_在线编程题…...

第一个 Spring MVC 注解式开发案例(初学必看)

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…...

openresty学习笔记

openresty 简介 openresty 是一个基于 nginx 与 lua 的高性能 web 平台&#xff0c;其内部 集成了大量精良的 lua 库、第三方模块以及大数的依赖项。用于 方便搭建能够处理超高并发、扩展性极高的动态 web 应用、 web 服务和动态网关。 openresty 通过汇聚各种设计精良的 ngi…...

微信小程序DAY3

文章目录一、页面导航1-1、声明式导航1-2、编程式导航1-3、声明式导航传参1-4、编程式导航传参1-5、获取导航传递的参数二、页面事件2-1、下拉刷新事件2-1-1、启用下拉刷新2-1-2、配置下拉刷新2-1-3、监听页面下拉刷新事件2-2、上拉触底事件2-2-1、事件触发2-2-1、事件配置三、…...

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?

论文网址&#xff1a;pdf 英文是纯手打的&#xff01;论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误&#xff0c;若有发现欢迎评论指正&#xff01;文章偏向于笔记&#xff0c;谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...

CocosCreator 之 JavaScript/TypeScript和Java的相互交互

引擎版本&#xff1a; 3.8.1 语言&#xff1a; JavaScript/TypeScript、C、Java 环境&#xff1a;Window 参考&#xff1a;Java原生反射机制 您好&#xff0c;我是鹤九日&#xff01; 回顾 在上篇文章中&#xff1a;CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明

AI 领域的快速发展正在催生一个新时代&#xff0c;智能代理&#xff08;agents&#xff09;不再是孤立的个体&#xff0c;而是能够像一个数字团队一样协作。然而&#xff0c;当前 AI 生态系统的碎片化阻碍了这一愿景的实现&#xff0c;导致了“AI 巴别塔问题”——不同代理之间…...

leetcodeSQL解题:3564. 季节性销售分析

leetcodeSQL解题&#xff1a;3564. 季节性销售分析 题目&#xff1a; 表&#xff1a;sales ---------------------- | Column Name | Type | ---------------------- | sale_id | int | | product_id | int | | sale_date | date | | quantity | int | | price | decimal | -…...

selenium学习实战【Python爬虫】

selenium学习实战【Python爬虫】 文章目录 selenium学习实战【Python爬虫】一、声明二、学习目标三、安装依赖3.1 安装selenium库3.2 安装浏览器驱动3.2.1 查看Edge版本3.2.2 驱动安装 四、代码讲解4.1 配置浏览器4.2 加载更多4.3 寻找内容4.4 完整代码 五、报告文件爬取5.1 提…...

dify打造数据可视化图表

一、概述 在日常工作和学习中&#xff0c;我们经常需要和数据打交道。无论是分析报告、项目展示&#xff0c;还是简单的数据洞察&#xff0c;一个清晰直观的图表&#xff0c;往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server&#xff0c;由蚂蚁集团 AntV 团队…...

Mysql中select查询语句的执行过程

目录 1、介绍 1.1、组件介绍 1.2、Sql执行顺序 2、执行流程 2.1. 连接与认证 2.2. 查询缓存 2.3. 语法解析&#xff08;Parser&#xff09; 2.4、执行sql 1. 预处理&#xff08;Preprocessor&#xff09; 2. 查询优化器&#xff08;Optimizer&#xff09; 3. 执行器…...

SQL慢可能是触发了ring buffer

简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...

MySQL 知识小结(一)

一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库&#xff0c;分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷&#xff0c;但是文件存放起来数据比较冗余&#xff0c;用二进制能够更好管理咱们M…...

基于IDIG-GAN的小样本电机轴承故障诊断

目录 🔍 核心问题 一、IDIG-GAN模型原理 1. 整体架构 2. 核心创新点 (1) ​梯度归一化(Gradient Normalization)​​ (2) ​判别器梯度间隙正则化(Discriminator Gradient Gap Regularization)​​ (3) ​自注意力机制(Self-Attention)​​ 3. 完整损失函数 二…...