当前位置: 首页 > news >正文

mindspore的MLP模型(多层感知机)

导入模块

import hashlib
import os
import tarfile
import zipfile
import requests
import numpy as np
import pandas as pd
import mindspore
import mindspore.dataset as ds
from mindspore import nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore import Tensor
from IPython import display
from matplotlib import pyplot as plt

数据预处理

数据下载:https://www.kaggle.com/datasets/ahsan81/hotel-reservations-classification-dataset

train_data = pd.read_csv("Hotel Reservations_train.csv")
test_data = pd.read_csv("Hotel Reservations_test.csv")print(train_data.shape)
print(test_data.shape)
(30000, 20)
(6275, 20)
# 可去掉第0列与第1列的信息
print(train_data.iloc[0:4, [0, 1, 2, -3, -2, -1]])
   Unnamed: 0 Booking_ID  no_of_adults  avg_price_per_room  \
0           0   INN00001             2               65.00   
1           1   INN00002             2              106.68   
2           2   INN00003             1               60.00   
3           3   INN00004             2              100.00   no_of_special_requests booking_status  
0                       0   Not_Canceled  
1                       1   Not_Canceled  
2                       0       Canceled  
3                       0       Canceled  
# 将train_data和test_data合并,后面做数据预处理方便
all_features = pd.concat((train_data.iloc[:, 2:-1], test_data.iloc[:, 2:-1]))all_features
no_of_adultsno_of_childrenno_of_weekend_nightsno_of_week_nightstype_of_meal_planrequired_car_parking_spaceroom_type_reservedlead_timearrival_yeararrival_montharrival_datemarket_segment_typerepeated_guestno_of_previous_cancellationsno_of_previous_bookings_not_canceledavg_price_per_roomno_of_special_requests
02012Meal Plan 10Room_Type 12242017102Offline00065.000
12023Not Selected0Room_Type 152018116Online000106.681
21021Meal Plan 10Room_Type 112018228Online00060.000
32002Meal Plan 10Room_Type 12112018520Online000100.000
42011Not Selected0Room_Type 1482018411Online00094.500
......................................................
62703026Meal Plan 10Room_Type 485201883Online000167.801
62712013Meal Plan 10Room_Type 122820181017Online00090.952
62722026Meal Plan 10Room_Type 1148201871Online00098.392
62732003Not Selected0Room_Type 1632018421Online00094.500
62742012Meal Plan 10Room_Type 120720181230Offline000161.670

36275 rows × 17 columns

# 将所有缺失的值替换为相应特征的平均值。 通过将特征重新缩放到零均值和单位方差来标准化数据# 先将为数字类型的列取出来,dtypes[all_features.dtypes != 'object'].index 返回类型是数字的列的索引
numeric_features = all_features.dtypes[all_features.dtypes != 'object'].index
# 之后对其应用apply方法 apply中对每列进行了标准化(Z-score标准化方法)
all_features[numeric_features] = all_features[numeric_features].apply(lambda x: (x - x.mean()) / (x.std()))
# 在标准化数据之后,所有均值消失,因此我们可以将缺失值设置为0
all_features[numeric_features] = all_features[numeric_features].fillna(0)
# 处理离散值。我们用独热编码替换它们
# 独热编码:例如,“MSZoning”包含值“RL”和“Rm”。 我们将创建两个新的指示器特征“MSZoning_RL”和“MSZoning_RM”,其值为0或1。print(all_features.shape)# “Dummy_na=True”将“na”(缺失值)视为有效的特征值,并为其创建指示符特征
all_features = pd.get_dummies(all_features, dummy_na=True)print(all_features.shape)
(36275, 17)
(36275, 33)
all_labels = pd.concat((train_data.iloc[:,-1], test_data.iloc[:, -1]))change = {'Not_Canceled':1,'Canceled':0}
all_labels = all_labels.map(change)
all_labels
0       1
1       1
2       0
3       0
4       0..
6270    1
6271    0
6272    1
6273    0
6274    1
Name: booking_status, Length: 36275, dtype: int64
n_train = train_data.shape[0]         # 提取训练样本数
train_features = all_features[:n_train].values.astype(np.float32)      # 注意要统一数据的类型:np.float32
test_features = all_features[n_train:].values.astype(np.float32)
train_labels = all_labels.iloc[:n_train].values.astype(np.int64)
test_labels = all_labels.iloc[n_train:].values.astype(np.int64)
class SyntheticData():  def __init__(self,features,labels):self.features, self.labels = features , labelsdef __getitem__(self, index):   # __getitem__(self, index) 一般用来迭代序列(常见序列如:列表、元组、字符串)return self.features[index], self.labels[index]def __len__(self):return len(self.labels)
# 数据集
train_dataset= ds.GeneratorDataset(source=SyntheticData(train_features, train_labels), column_names=['features', 'label'],python_multiprocessing=False)test_dataset= ds.GeneratorDataset(source=SyntheticData(test_features, test_labels ), column_names=['features', 'label'],python_multiprocessing=False)

构建模型

class Accumulator:  """累加器"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]
def accuracy(y_hat, y):  """计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:   # 判断y_hat是不是矩阵y_hat = y_hat.argmax(axis=1)                  # 得到每样本预测概率最大所属分类的下标cmp = y_hat.asnumpy() == y.asnumpy()              # y_hat.asnumpy() == y.asnumpy()返回的是一个布尔数组return float(cmp.sum())def evaluate_accuracy(net, data_iter):  """计算在指定数据集上模型的精度"""metric = Accumulator(2)         # 累加器,metric[0]记录正确预测数,metric[1]记录预测总数for X, y in data_iter:metric.add(accuracy(net(X), y), y.size)return metric[0] / metric[1]    # 正确预测数 / 预测总数
def train_epoch( train_iter, learning_rate, weight_decay, batch_size):  """训练模型一个迭代周期"""net = nn.SequentialCell([nn.Dense(all_features.shape[1], 32),nn.ReLU(),nn.Dense(32, 16),nn.ReLU(),nn.Dense(16, 2)]) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')#optim = nn.SGD(net.trainable_params(), learning_rate = learning_rate, weight_decay = weight_decay)optim = nn.Adam(net.trainable_params(), learning_rate = learning_rate, weight_decay = weight_decay) net_with_loss = nn.WithLossCell(net, loss)                net_train = nn.TrainOneStepCell(net_with_loss, optim)     metric = Accumulator(3)for X, y in train_iter:l = net_train(X, y)y_hat = net(X)metric.add(float(l.sum().asnumpy()), accuracy(y_hat, y), y.size)return metric[0] / metric[2], metric[1] / metric[2] ,net      # 误差 / 预测总数 ,正确预测数 / 预测总数
def trainer( train_iter, test_iter, num_epochs, learning_rate, weight_decay, batch_size, train_acc_plot, test_acc_plot):  """训练模型"""train_iter = train_iter.batch(batch_size = batch_size, num_parallel_workers=1)test_iter = test_iter.batch(batch_size = batch_size, num_parallel_workers=1)for epoch in range(num_epochs):train_metrics = train_epoch(train_iter, learning_rate, weight_decay, batch_size)train_loss, train_acc, net = train_metricstest_acc = evaluate_accuracy(net, test_iter)train_acc_plot.append(float(train_acc))test_acc_plot.append(float(test_acc))print('最终训练集精度:', train_acc, '最终测试集精度:',test_acc )# 检测assert train_loss < 0.6, train_lossassert train_acc <= 1 and train_acc > 0.7, train_accassert test_acc <= 1 and test_acc > 0.7, test_acc

训练

num_epochs,  weight_decay, batch_size  =20, 0, 64# 动态学习率
learning_rate = 0.1
end_learning_rate = 0.05
decay_steps = 6
power = 0.5
learning_rate  = nn.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)train_acc_plot=[]
test_acc_plot=[]
trainer( train_dataset, test_dataset, num_epochs, learning_rate, weight_decay, batch_size, train_acc_plot, test_acc_plot)
最终训练集精度: 0.8078666666666666 最终测试集精度: 0.8124302788844622
# 构建loss-step曲线可了解loss随epoch的变化情况plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=Falsex=np.linspace(0, num_epochs-1,num_epochs)plt.figure(figsize=(4,3)) 
plt.xlabel(u"epoch")
plt.ylabel(u"精度")
plt.plot(x, train_acc_plot, label='train acc')
plt.plot(x, test_acc_plot, label='test acc')
plt.legend(loc="best")
plt.tight_layout(rect = [0,0,1,1]) 

在这里插入图片描述

相关文章:

mindspore的MLP模型(多层感知机)

导入模块 import hashlib import os import tarfile import zipfile import requests import numpy as np import pandas as pd import mindspore import mindspore.dataset as ds from mindspore import nn import mindspore.ops as ops import mindspore.numpy as mnp from …...

【论文极速读】VQ-VAE:一种稀疏表征学习方法

【论文极速读】VQ-VAE&#xff1a;一种稀疏表征学习方法 FesianXu 20221208 at Baidu Search Team 前言 最近有需求对特征进行稀疏编码&#xff0c;看到一篇论文VQ-VAE&#xff0c;简单进行笔记下。如有谬误请联系指出&#xff0c;本文遵循 CC 4.0 BY-SA 版权协议&#xff0c;…...

Flask-Blueprint

Flask-Blueprint 一、简介 概念&#xff1a; Blueprint 是一个存储操作方法的容器&#xff0c;这些操作在这个Blueprint 被注册到一个应用之后就可以被调用&#xff0c;Flask 可以通过Blueprint来组织URL以及处理请求 。 好处&#xff1a; 其本质上来说就是让程序更加松耦合…...

png图片转eps格式

下载latex工具后 在要转换的png图片文件夹路径下&#xff0c;打开命令行窗口&#xff0c;输入以下命令&#xff1a; bmeps -c fig图片名.png 图片名.eps...

English Learning - L2 语音作业打卡 Day2 2023.2.23 周四

English Learning - L2 语音作业打卡 Day2 2023.2.23 周四&#x1f48c; 发音小贴士&#xff1a;&#x1f48c; 当日目标音发音规则/技巧&#xff1a;&#x1f36d; Part 1【热身练习】&#x1f36d; Part2【练习内容】&#x1f36d;【练习感受】&#x1f353;元音[ ɔ: ]&…...

低频量化之 可转债 配债 策略数据 - 全网独家

目录历史文章可转债配债数据待发转债&#xff08;进展统计&#xff09;待发转债&#xff08;行业统计&#xff09;待发转债&#xff08;5证监会通过&#xff0c;PE排序&#xff09;待发转债&#xff08;5证监会通过&#xff0c;安全垫排序&#xff09;待发转债&#xff08;4发审…...

论文阅读_DALLE-2的unCLIP模型

论文信息 name_en: Hierarchical Text-Conditional Image Generation with CLIP Latents name_ch: 利用CLIP的层次化文本条件图像生成 paper_addr: http://arxiv.org/abs/2204.06125 doi: 10.48550/arXiv.2204.06125 date_read: 2023-02-12 date_publish: 2022-04-12 tags: [‘…...

软件测试5年,历经3轮面试成功拿下华为Offer,24K/16薪不过分吧

前言 转眼过去&#xff0c;距离读书的时候已经这么久了吗&#xff1f;&#xff0c;从18年5月本科毕业入职了一家小公司&#xff0c;到现在快5年了&#xff0c;前段时间社招想着找一个新的工作&#xff0c;前前后后花了一个多月的时间复习以及面试&#xff0c;前几天拿到了华为的…...

【软件工程】课程作业(三道题目:需求分析、概要设计、详细设计、软件测试)

文章目录&#xff1a;故事的开头总是极尽温柔&#xff0c;故事会一直温柔……&#x1f49c;一、你怎么理解需求分析&#xff1f;1、需求分析的定义&#xff1a;2、需求分析的重要性&#xff1a;3、需求分析的内容&#xff1a;4、基于系统分析的方法分类&#xff1a;5、需求分析…...

05 DC-AC逆变器(DCAC Converter / Inverter)简介

文章目录0、概述逆变原理方波变换阶梯波变换斩控调制方式逆变器分类逆变器波形指标1、方波变换器A 单相单相全桥对称单脉冲调制移相单脉冲调制单相半桥2、方波变换器B 三相180度导通120度导通&#xff08;线、相的关系与180度相反&#xff09;3、阶梯波逆变器独立直流源二极管钳…...

带你深层了解c语言指针

前言 &#x1f388;个人主页:&#x1f388; :✨✨✨初阶牛✨✨✨ &#x1f43b;推荐专栏: &#x1f354;&#x1f35f;&#x1f32f; c语言进阶 &#x1f511;个人信条: &#x1f335;知行合一 &#x1f349;本篇简介:>:介绍c语言中有关指针更深层的知识. 金句分享: ✨今天…...

2-MATLAB APP Design-下拉菜单栏的使用

一、APP 界面设计展示 1.新建一个空白的APP,在此次的学习中,我们会用到编辑字段(文本框)、下拉菜单栏、坐标区,首先在界面中拖入一个编辑字段(文本框),在文本框中输入内容:下拉菜单栏的使用,调整背景颜色,字体的颜色为黑色,字体的大小调为26. 2.在左侧组件库常用栏…...

七、HTTPTomcatServlet

1&#xff0c;Web概述 1.1 Web和JavaWeb的概念 Web是全球广域网&#xff0c;也称为万维网(www)&#xff0c;能够通过浏览器访问的网站。 在我们日常的生活中&#xff0c;经常会使用浏览器去访问百度、京东、传智官网等这些网站&#xff0c;这些网站统称为Web网站。如下就是通…...

LeetCode 热题 C++ 198. 打家劫舍

力扣198 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的房屋在同一晚上被小偷闯入&#xff0c;系统会自动报警。 给定一个代表每个房屋存…...

C语言学习笔记——程序环境和预处理

目录 前言 一、程序环境 1. 翻译环境 1.1 主要过程 1.2 编译过程 2. 运行环境 二、预处理 1. 预定义符号 2. #define 2.1 #define定义标识符 2.2 #define定义宏 2.3 命名约定和移除定义 3. 条件编译 4. 文件包含 结束语 前言 每次我们写完代码运行的时候都…...

「JVM 高效并发」Java 内存模型

Amdahl 定律代替摩尔定律成为了计算机性能发展的新源动力&#xff0c;也是人类压榨计算机运算能力的最有力武器&#xff1b; 摩尔定律&#xff0c;描述处理器晶体管数量与运行效率之间的发展关系&#xff1b;Amdahl 定律&#xff0c;描述系统并行化与串行化的比重与系统运算加…...

C语言刷题(2)——“C”

各位CSDN的uu们你们好呀&#xff0c;今天小雅兰来复习一下之前所学过的内容噢&#xff0c;复习的方式&#xff0c;那当然是刷题啦&#xff0c;现在&#xff0c;就让我们进入C语言的世界吧 当然&#xff0c;题目还是来源于牛客网 完完全全零基础 编程语言初学训练营_在线编程题…...

第一个 Spring MVC 注解式开发案例(初学必看)

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…...

openresty学习笔记

openresty 简介 openresty 是一个基于 nginx 与 lua 的高性能 web 平台&#xff0c;其内部 集成了大量精良的 lua 库、第三方模块以及大数的依赖项。用于 方便搭建能够处理超高并发、扩展性极高的动态 web 应用、 web 服务和动态网关。 openresty 通过汇聚各种设计精良的 ngi…...

微信小程序DAY3

文章目录一、页面导航1-1、声明式导航1-2、编程式导航1-3、声明式导航传参1-4、编程式导航传参1-5、获取导航传递的参数二、页面事件2-1、下拉刷新事件2-1-1、启用下拉刷新2-1-2、配置下拉刷新2-1-3、监听页面下拉刷新事件2-2、上拉触底事件2-2-1、事件触发2-2-1、事件配置三、…...

生成xcframework

打包 XCFramework 的方法 XCFramework 是苹果推出的一种多平台二进制分发格式&#xff0c;可以包含多个架构和平台的代码。打包 XCFramework 通常用于分发库或框架。 使用 Xcode 命令行工具打包 通过 xcodebuild 命令可以打包 XCFramework。确保项目已经配置好需要支持的平台…...

椭圆曲线密码学(ECC)

一、ECC算法概述 椭圆曲线密码学&#xff08;Elliptic Curve Cryptography&#xff09;是基于椭圆曲线数学理论的公钥密码系统&#xff0c;由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA&#xff0c;ECC在相同安全强度下密钥更短&#xff08;256位ECC ≈ 3072位RSA…...

【JVM】- 内存结构

引言 JVM&#xff1a;Java Virtual Machine 定义&#xff1a;Java虚拟机&#xff0c;Java二进制字节码的运行环境好处&#xff1a; 一次编写&#xff0c;到处运行自动内存管理&#xff0c;垃圾回收的功能数组下标越界检查&#xff08;会抛异常&#xff0c;不会覆盖到其他代码…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站&#xff1a;https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本&#xff1a; Windows版&#xff08;推荐下载标准版&#xff09; Windows系统安装步骤 运行安装程序&#xff1a; 双击下载的.exe安装文件 如果出现安全提示&…...

自然语言处理——Transformer

自然语言处理——Transformer 自注意力机制多头注意力机制Transformer 虽然循环神经网络可以对具有序列特性的数据非常有效&#xff0c;它能挖掘数据中的时序信息以及语义信息&#xff0c;但是它有一个很大的缺陷——很难并行化。 我们可以考虑用CNN来替代RNN&#xff0c;但是…...

CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云

目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...

使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台

🎯 使用 Streamlit 构建支持主流大模型与 Ollama 的轻量级统一平台 📌 项目背景 随着大语言模型(LLM)的广泛应用,开发者常面临多个挑战: 各大模型(OpenAI、Claude、Gemini、Ollama)接口风格不统一;缺乏一个统一平台进行模型调用与测试;本地模型 Ollama 的集成与前…...

【生成模型】视频生成论文调研

工作清单 上游应用方向&#xff1a;控制、速度、时长、高动态、多主体驱动 类型工作基础模型WAN / WAN-VACE / HunyuanVideo控制条件轨迹控制ATI~镜头控制ReCamMaster~多主体驱动Phantom~音频驱动Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation速…...

2025年渗透测试面试题总结-腾讯[实习]科恩实验室-安全工程师(题目+回答)

安全领域各种资源&#xff0c;学习文档&#xff0c;以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具&#xff0c;欢迎关注。 目录 腾讯[实习]科恩实验室-安全工程师 一、网络与协议 1. TCP三次握手 2. SYN扫描原理 3. HTTPS证书机制 二…...

STM32HAL库USART源代码解析及应用

STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...