当前位置: 首页 > news >正文

深度学习基础小结_项目实战:手机价格预测

目录

库函数导入

一、构建数据集

 二、构建分类网络模型

三、编写训练函数

四、编写评估函数

五、网络性能调优


鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。 他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地假设事情。为了解决这个问题,他收集了各个公司的手机销售数据。

鲍勃想找出手机的特性(例如:RAM、内存等)和售价之间的关系。但他不太擅长机器学习。所以他需要你帮他解决这个问题。 在这个问题中,你不需要预测实际价格,而是要预测一个价格区间,表明价格多高。

需要注意的是: 在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。

数据说明:手机价格分类_数据集-阿里云天池 Mobile Price Classification

库函数导入

import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import import Dataset,DataLoader,TensorDataset
from sklearn.preprocessing import StandardScaler
import time

一、构建数据集

数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便构造数据集加载对象。

"""构建数据"""
def phone_data_set(path):# 加载本地数据文件data = pd.read_csv(path)# 抽离特征和目标数据x = data.iloc[:,:-1]y = data.iloc[:,-1]# 标准化transfer = StanderdScaler()x = transfer.fit_transform(x.values)x = torch.tensor(x,dtype=torch.float32)y = torch.tensor(y.values,dtype=torch.int64) # 输出是分类结果,所以用整型# 数据集划分x_trian,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=42,stratify=y)return x_trian,x_test,y_train,y_testclass my_phone_data_loader(Dataset):
# TensorDataset 可取代该类的作用def __init__(self,x,y):self.x = xself.y = ydef __len__(self):return len(self.x)def __getitem__(self, index):return self.x[index],self.y[index]def data_loader(x_trian,y_train,batch_size=16):# data=my_phone_data_loader(x_train,y_train)data = TensorDataset(x_train,y_trian)data_loader = DataLoader(data,batch_size=batch_size,shuffle=True)return data_loader

 二、构建分类网络模型

构建用于手机价格分类的模型叫做全连接神经网络。它主要由数个线性层来构建,在每个线性层后,还需使用激活函数。

"""模型 构建和初始化参数"""
class Net(torch.nn.Module):def __init__(self,input_features,out_features):super(Net,self).__init__()# 隐藏层 LeakyReLU 激活和输出层 Softmax 激活self.hide1 = nn.Sequential(nn.Linear(input_features,128),nn.LeakyReLU())self.hide2 = nn.Sequential(nn.Linear(128,256),nn.LeakyReLU())self.hide3 = nn.Sequential(nn.Linear(256,512),nn.LeakyReLU())self.hide4 = nn.Sequential(nn.Linear(512,128),nn.LeakyReLU())self.out = nn.Sequential(nn.Linear(128,out_features),nn.Softmax())self.initdata()def forward(self,input_data):# 前向传播x = self.hide1(iput_data)x=self.hide2(x)x=self.hide3(x)x=self.hide4(x)y_pred=self.out(x)return y_pred#分类结果,混淆矩阵 [[0.9,0.01,.02,.7],[...],...]def initdata(self):# 权重He初始化nn.init.kaiming_uniform_(self.hide1[0].weight,nonlinearity="leaky_relu")nn.init.kaiming_uniform_(self.hide2[0].weight,nonlinearity="leaky_relu")nn.init.kaiming_uniform_(self.hide3[0].weight,nonlinearity="leaky_relu")nn.init.kaiming_uniform_(self.hide4[0].weight,nonlinearity="leaky_relu")       

三、编写训练函数

网络编写完成之后,我们需要编写训练函数。所谓的训练函数,指的是输入数据读取、送入网络、计算损失、更新参数的流程,该流程较为固定。我们使用的是多分类交叉生损失函数、使用 SGD 优化方法。最终,将训练好的模型持久化到磁盘中。

"""训练数据"""
# 加载数据
x_train,x_test,y_trian,y_test=phone_data_set("./data/手机价格预测.csv")
def train():data_loader_=data_loader(x_train,y_trian)# 模型生成x_features=x_train.shape[1]y_features=torch.unique(y_trian).shape[0]# 输出特征(类别的数量)model=Net(x_features,y_features)# 初始化模型参数# 默认是初始化过的torch.nn.init.kaiming_uniform_(model.linear1.weight,nonlinearity="leaky_relu")torch.nn.init.kaiming_uniform_(model.linear2.weight,nonlinearity="leaky_relu") # 3.损失函数loss_fn=torch.nn.CrossEntropyLoss()#torch.nn.MSELoss()#虽然分类的结果也可以用均方误差来计算,但是一般用交叉熵(因为计算出来的梯度更大)# 4.优化器optim=torch.optim.Adam(model.parameters(),lr=1e-4)# 定义训练参数epoch=100for i in range(epoch):e=0count=0start_time=time.time()for x,y in data_loader_:count+=1# 生成预测值y_pred=model(x)#执行model对象的forward# 损失计算loss=loss_fn(y_pred,y)e+=loss# 梯度清零optim.zero_grad()# 反向传播loss.backward()# 更新参数optim.step()end_time=time.time()print(f"epoch:{i},loss:{e/count},time:{end_time-start_time}") # 保存模型参数# model.linear1.weight.datatorch.save(model.state_dict(),"./model/model.pth")

四、编写评估函数

评估函数、也叫预测函数、推理函数,主要使用训练好的模型,对未知的样本的进行预测的过程。这里使用前面单独划分出来的测试集来进行评估。

"""评估函数"""
def test():# 加载数据data_loader_=data_loader(x_test,y_test,batch_size=16)# data_loader_=DataLoader(data,batch_size=4,shuffle=True)# 加载模型# 模型生成x_test_features=x_test.shape[1]y_test_features=torch.unique(y_test).shape[0]# 类别数model=Net(x_test_features,y_test_features)# model.linear1.weight=model.linear1.weight.datastate_dict=torch.load("./model/model.pth",map_location="cpu")model.load_state_dict(state_dict)total=0for x,y in data_loader_:y_pred=model(x)#[0.4,0.3,0.1,0.1,0.1]y_pred=torch.argmax(y_pred,dim=1)# print("y",y)# print("y_pred",y_pred)total+=torch.sum(y_pred==y)print(f"精准度:{total/len(x_test)}")

五、网络性能调优

可以通过以下方面对模型进行调优:

  1. 对输入数据进行标准化

  2. 调整优化方法

  3. 调整学习率

  4. 增加批量归一化层

  5. 增加网络层数、神经元个数

  6. 增加训练轮数

相关文章:

深度学习基础小结_项目实战:手机价格预测

目录 库函数导入 一、构建数据集 二、构建分类网络模型 三、编写训练函数 四、编写评估函数 五、网络性能调优 鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。 他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地…...

EMall实践DDD模拟电商系统总结

目录 一、事件风暴 二、系统用例 三、领域上下文 四、架构设计 (一)六边形架构 (二)系统分层 五、系统实现 (一)项目结构 (二)提交订单功能实现 (三&#xff0…...

【随笔】AI技术在电商中的应用

这几年,伴随着ChatGPT开始的AI浪潮席卷全球,从聊天场景逐步向多场景扩散,形成了广泛开花的现象。至今,虽然在部分场景的进展已经略显疲态,但当前的这种趋势仍然还在不断的扩展。不少公司,甚至有不少大型电商…...

序列式容器详细攻略(vector、list)C++

vector std::vector 是 STL 提供的 内存连续的、可变长度 的数组(亦称列表)数据结构。能够提供线性复杂度的插入和删除,以及常数复杂度的随机访问。 为什么要使用 vector 作为 OIer,对程序效率的追求远比对工程级别的稳定性要高得多,而 vector 由于其对内存的动态处理,…...

快速启动项目

1 后端项目 https://gitee.com/liuyunkai666/gungun-boot.git 分支: mini 是 springboot3 jdk17 的基础版本,后续其他功能模块陆续在其基础上追加即可​。 1.1 必备环境 1.1.1 mysql 创建一个 自定义名称 数据库,【只要】 执行对应数据库…...

springboot347基于web的铁路订票管理系统(论文+源码)_kaic

摘 要 当今社会进入了科技进步、经济社会快速发展的新时代。计算机技术对经济社会发展和人民生活改善的影响也日益突出,人类的生存和思考方式也产生了变化。传统铁路订票管理采取了人工的管理方法,但这种管理方法存在着许多弊端,比如效率低…...

使用API管理Dynadot域名,在账户中添加域名服务器(Name Server)

前言 Dynadot是通过ICANN认证的域名注册商,自2002年成立以来,服务于全球108个国家和地区的客户,为数以万计的客户提供简洁,优惠,安全的域名注册以及管理服务。 Dynadot平台操作教程索引(包括域名邮箱&…...

【Linux | 计网】TCP协议深度解析:从连接管理到流量控制与滑动窗口

目录 前言: 1、三次握手和四次挥手的联系: 为什么挥手必须要将ACK和FIN分开呢? 2.理解 CLOSE_WAIT 状态 CLOSE_WAIT状态的特点 3.FIN_WAIT状态讲解 3.1、FIN_WAIT_1状态 3.2、FIN_WAIT_2状态 3.3、FIN_WAIT状态的作用与意义 4.理解…...

go语言的成神之路-筑基篇-对文件的操作

目录 一、对文件的读写 Reader 接口 Writer接口 copy接口 bufio的使用 ioutil库 二、cat命令 三、包 1. 包的声明 2. 导入包 3. 包的可见性 4. 包的初始化 5. 标准库包 6. 第三方包 7. 包的组织 8. 包的别名 9. 包的路径 10. 包的版本管理 四、go mod 1. 初始…...

两道数据结构编程题

1.写出在顺序存储结构下将线性表逆转的算法,要求使用最少的附加空间。 解:输入:长度为n的线性表数组A(1:n) 输出:逆转后的长度为n的线性表数组A(1:n)。 C语言描述如下(其中ET为数据元素的类型):…...

【Qt】QDateTimeEdit控件实现清空(不保留默认时间/最小时间)

一、QDateTimeEdit控件 QDateTimeEdit 提供了一个用于编辑日期和时间的控件。用户可以通过键盘或使用上下箭头键来增加或减少日期和时间值。日期和时间的显示格式根据设置的格式显示,可以通过 setDisplayFormat() 方法来设置。 二、如何清空 我在使用的时候&#…...

12、字符串

1、字符串概念 字符串用来存储一组字符,因此需要字符数组来存。 C语言中字符串的表示 C语言里面字符串只能用字符数组来存 字符串要求这个数组的末尾必须至少有一个\0 char ch1[] {a,b,c}; // 不是字符串 char ch2[5] {h,e,l,l,o}; // 不是字符串 char…...

DPDK用户态协议栈-Tcp Posix API 1

和udp一样&#xff0c;我们需要实现和系统调用一样的接口来实现我们的tcp server。先来看看我们之前写的unix_tcp使用了哪些接口&#xff0c;这边我加上两个系统调用&#xff0c;分别是接收数据和发送数据。 #include <stdio.h> #include <arpa/inet.h> #include …...

【人工智能-科普】图神经网络(GNN):与传统神经网络的区别与优势

文章目录 图神经网络(GNN):与传统神经网络的区别与优势什么是图神经网络?图的基本概念GNN的工作原理GNN与传统神经网络的不同1. 数据结构的不同2. 信息传递方式的不同3. 模型的可扩展性4. 局部与全局信息的结合GNN的应用领域总结图神经网络(GNN):与传统神经网络的区别与…...

LabVIEW实现UDP通信

目录 1、UDP通信原理 2、硬件环境部署 3、云端环境部署 4、UDP通信函数 5、程序架构 6、前面板设计 7、程序框图设计 8、测试验证 本专栏以LabVIEW为开发平台,讲解物联网通信组网原理与开发方法,覆盖RS232、TCP、MQTT、蓝牙、Wi-Fi、NB-IoT等协议。 结合实际案例,展示如何利…...

[pdf,epub]228页《分析模式》漫谈合集01-45提供下载

《分析模式》漫谈合集01-45的pdf、epub文件提供下载。已上传至本号的CSDN资源。 如果CSDN资源下载有问题&#xff0c;可到umlchina.com/url/ap.html。 已排版成适合手机阅读&#xff0c;pdf的排版更好一些。 ★UMLChina为什么叒要翻译《分析模式》&#xff1f; ★[缝合故事]…...

Kafka的消费消息是如何传递的?

大家好&#xff0c;我是锋哥。今天分享关于【Kafka的消费消息是如何传递的&#xff1f;】面试题。希望对大家有帮助&#xff1b; Kafka的消费消息是如何传递的&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在Kafka中&#xff0c;消息的消费是通过消费…...

二分查找(Java实现)(1)

二分查找&#xff08;Java实现&#xff09;&#xff08;1&#xff09; leetcode 34.排序数组中查找元素第一个和最后一个位置 题目描述: 给你一个按照非递减顺序排列的整数数组 nums&#xff0c;和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。 如…...

力扣103.二叉树的锯齿形层序遍历

题目描述 题目链接103. 二叉树的锯齿形层序遍历 给你二叉树的根节点 root &#xff0c;返回其节点值的 锯齿形层序遍历 。&#xff08;即先从左往右&#xff0c;再从右往左进行下一层遍历&#xff0c;以此类推&#xff0c;层与层之间交替进行&#xff09;。 示例 1&#xff…...

Search with Orama

1.前言 在不久之前&#xff0c;我把 DevNow 的搜索组件通过 Lunr 进行了重构&#xff0c;从前端角度实现了对文章内容的搜索&#xff0c;但是在使用体验上&#xff0c;感觉不是特别好&#xff0c;大概有如下几个原因&#xff1a; 社区的文章数量比较少&#xff0c;项目的 Com…...

基于大模型的 UI 自动化系统

基于大模型的 UI 自动化系统 下面是一个完整的 Python 系统,利用大模型实现智能 UI 自动化,结合计算机视觉和自然语言处理技术,实现"看屏操作"的能力。 系统架构设计 #mermaid-svg-2gn2GRvh5WCP2ktF {font-family:"trebuchet ms",verdana,arial,sans-…...

CTF show Web 红包题第六弹

提示 1.不是SQL注入 2.需要找关键源码 思路 进入页面发现是一个登录框&#xff0c;很难让人不联想到SQL注入&#xff0c;但提示都说了不是SQL注入&#xff0c;所以就不往这方面想了 ​ 先查看一下网页源码&#xff0c;发现一段JavaScript代码&#xff0c;有一个关键类ctfs…...

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする

日语学习-日语知识点小记-构建基础-JLPT-N4阶段(33):にする 1、前言(1)情况说明(2)工程师的信仰2、知识点(1) にする1,接续:名词+にする2,接续:疑问词+にする3,(A)は(B)にする。(2)復習:(1)复习句子(2)ために & ように(3)そう(4)にする3、…...

使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装

以下是基于 vant-ui&#xff08;适配 Vue2 版本 &#xff09;实现截图中照片上传预览、删除功能&#xff0c;并封装成可复用组件的完整代码&#xff0c;包含样式和逻辑实现&#xff0c;可直接在 Vue2 项目中使用&#xff1a; 1. 封装的图片上传组件 ImageUploader.vue <te…...

SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现

摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序&#xff0c;以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务&#xff0c;提供稳定高效的数据处理与业务逻辑支持&#xff1b;利用 uniapp 实现跨平台前…...

3-11单元格区域边界定位(End属性)学习笔记

返回一个Range 对象&#xff0c;只读。该对象代表包含源区域的区域上端下端左端右端的最后一个单元格。等同于按键 End 向上键(End(xlUp))、End向下键(End(xlDown))、End向左键(End(xlToLeft)End向右键(End(xlToRight)) 注意&#xff1a;它移动的位置必须是相连的有内容的单元格…...

2025年渗透测试面试题总结-腾讯[实习]科恩实验室-安全工程师(题目+回答)

安全领域各种资源&#xff0c;学习文档&#xff0c;以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具&#xff0c;欢迎关注。 目录 腾讯[实习]科恩实验室-安全工程师 一、网络与协议 1. TCP三次握手 2. SYN扫描原理 3. HTTPS证书机制 二…...

HubSpot推出与ChatGPT的深度集成引发兴奋与担忧

上周三&#xff0c;HubSpot宣布已构建与ChatGPT的深度集成&#xff0c;这一消息在HubSpot用户和营销技术观察者中引发了极大的兴奋&#xff0c;但同时也存在一些关于数据安全的担忧。 许多网络声音声称&#xff0c;这对SaaS应用程序和人工智能而言是一场范式转变。 但向任何技…...

适应性Java用于现代 API:REST、GraphQL 和事件驱动

在快速发展的软件开发领域&#xff0c;REST、GraphQL 和事件驱动架构等新的 API 标准对于构建可扩展、高效的系统至关重要。Java 在现代 API 方面以其在企业应用中的稳定性而闻名&#xff0c;不断适应这些现代范式的需求。随着不断发展的生态系统&#xff0c;Java 在现代 API 方…...

通过 Ansible 在 Windows 2022 上安装 IIS Web 服务器

拓扑结构 这是一个用于通过 Ansible 部署 IIS Web 服务器的实验室拓扑。 前提条件&#xff1a; 在被管理的节点上安装WinRm 准备一张自签名的证书 开放防火墙入站tcp 5985 5986端口 准备自签名证书 PS C:\Users\azureuser> $cert New-SelfSignedCertificate -DnsName &…...