12.2深度学习_项目实战
十、项目实战
鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。
他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地假设事情。为了解决这个问题,他收集了各个公司的手机销售数据。
鲍勃想找出手机的特性(例如:RAM、内存等)和售价之间的关系。但他不太擅长机器学习。所以他需要你帮他解决这个问题。
在这个问题中,你不需要预测实际价格,而是要预测一个价格区间,表明价格多高。
需要注意的是: 在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。
数据说明:https://tianchi.aliyun.com/dataset/157241 Mobile Price Classification
推荐专业的数据集平台:https://www.kaggle.com/datasets/iabhishekofficial/mobile-price-classification
| 字段 | 说明 |
|---|---|
| battery_power | 电池容量(mAh) |
| blue | 是否支持蓝牙 |
| clock_speed | 微处理器执行指令的速度 |
| dual_sim | 是否支持双卡 |
| fc | 前置摄像头分辨率(百万像素) |
| four_g | 是否支持4G |
| int_memory | 存储内存(GB) |
| m_dep | 手机厚度(厘米) |
| mobile_wt | 重量 |
| n_cores | 核心数 |
| pc | 主摄像头分辨率(百万像素) |
| px_height | 屏幕分辨率高度(像素) |
| px_width | 屏幕分辨率宽度(像素) |
| ram | 运行内存(MB) |
| sc_h | 屏幕长度(厘米) |
| sc_w | 屏幕宽度(厘米) |
| talk_time | 单次充电最长通话时间 |
| three_g | 是否支持3G |
| touch_screen | 是否是触摸屏 |
| wifi | 是否支持WIFI |
| price_range | 价格区间 |
1. 构建数据集
数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便构造数据集加载对象。
# 构建数据集
def create_dataset():data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid=train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 构建数据集train_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.from_numpy(x_valid.values), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()
2. 构建分类网络模型
我们构建的用于手机价格分类的模型叫做全连接神经网络。它主要由三个线性层来构建,在每个线性层后,我们使用的时 sigmoid 激活函数。
# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, output_dim)def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))x = self._activation(self.linear2(x))output = self.linear3(x)return output
3. 编写训练函数
网络编写完成之后,我们需要编写训练函数。所谓的训练函数,指的是输入数据读取、送入网络、计算损失、更新参数的流程,该流程较为固定。我们使用的是多分类交叉生损失函数、使用 SGD 优化方法。最终,将训练好的模型持久化到磁盘中。
def train():# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)# 训练轮数num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/phone-price-model.bin')
4. 编写评估函数
评估函数、也叫预测函数、推理函数,主要使用训练好的模型,对未知的样本的进行预测的过程。我们这里使用前面单独划分出来的测试集来进行评估。
def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('model/phone-price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))
5. 网络性能调优
我们前面的网络模型在测试集的准确率为: 0.54750, 我们可以通过以下方面进行调优:
- 对输入数据进行标准化
- 调整优化方法
- 调整学习率
- 增加批量归一化层
- 增加网络层数、神经元个数
- 增加训练轮数
- 等等…
我进行下如下调整:
- 优化方法由 SGD 调整为 Adam
- 学习率由 1e-3 调整为 1e-4
- 对数据数据进行标准化
- 增加网络深度, 即: 增加网络参数量
网络模型在测试集的准确率由 0.5475 上升到 0.9625
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
import time
from sklearn.preprocessing import StandardScaler# 构建数据集
def create_dataset():data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid = \train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 数据标准化transfer = StandardScaler()x_train = transfer.fit_transform(x_train)x_valid = transfer.transform(x_valid)# 构建数据集train_dataset = TensorDataset(torch.from_numpy(x_train), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.from_numpy(x_valid), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, 512)self.linear4 = nn.Linear(512, 128)self.linear5 = nn.Linear(128, output_dim)def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))x = self._activation(self.linear2(x))x = self._activation(self.linear3(x))x = self._activation(self.linear4(x))output = self.linear5(x)return output# 编写训练函数
def train():# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.Adam(model.parameters(), lr=1e-4)# 训练轮数num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/phone-price-model.bin')def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('model/phone-price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))if __name__ == '__main__':train()test()
相关文章:
12.2深度学习_项目实战
十、项目实战 鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。 他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地假设事情。为了解决这个问题,他收集了各个公司的手机销售数据。 鲍勃想找出手机的特性(例…...
LeetCode 64. 最小路径和(HOT100)
第一次错误代码: class Solution { public:int minPathSum(vector<vector<int>>& grid) {int dp[205][205] {0};int m grid.size(),n grid[0].size();for(int i 1 ;i<m;i){for(int j 1;j<n;j){dp[i][j] min(dp[i][j-1],dp[i-1][j])gr…...
ESP8266作为TCP客户端或者服务器使用
ESP8266模块,STA模式(与手机搭建TCP通讯,EPS8266为服务端)_esp8266作为station-CSDN博客 ESP8266模块,STA模式(与电脑搭建TCP通讯,ESP8266 为客户端)_esp8266 sta 连接tcp-CSDN博客…...
C#结合.NET框架快速构建和部署AI应用
在人工智能(AI)的浪潮中,C#作为一种功能强大且类型安全的编程语言,为AI工程开发提供了坚实的基础。C#结合.NET框架,使得开发者能够快速构建和部署AI应用。本文将通过一个简单的实例,展示如何使用C#进行AI工…...
题外话 (火影密令)
哥们! 玩火影不! 村里人全部评论! 不评论的忍战李全保底! 哥们! 密令领了不! “1219村里人集合”领了吗! 100金币! 哥们! 我粉丝没人能上影! 老舅说的…...
蓝桥杯准备训练(lesson1,c++方向)
前言 报名参加了蓝桥杯(c)方向的宝子们,今天我将与大家一起努力参赛,后序会与大家分享我的学习情况,我将从最基础的内容开始学习,带大家打好基础,在每节课后都会有练习题,刚开始的练…...
RTDETR融合[ECCV2024]WTConvNeXt中的WTConv模块及相关改进思路
RT-DETR使用教程: RT-DETR使用教程 RT-DETR改进汇总贴:RT-DETR更新汇总贴 《Wavelet Convolutions for Large Receptive Fields》 一、 模块介绍 论文链接:https://arxiv.org/pdf/2407.05848 代码链接:https://github.com/BGU-CS…...
AD7606使用方法
AD7606是一款8通道最高16位200ksps的AD采样芯片。5V单模拟电源供电,真双极性模拟输入可以选择10 V,5 V两种量程。支持串口与并口两种读取方式。 硬件连接方式: 配置引脚 引脚功能 详细说明 OS2 OS1 OS2 过采样率配置 000 1倍过采样率 …...
嵌入式系统应用-LVGL的应用-平衡球游戏 part1
平衡球游戏 part1 1 平衡球游戏的界面设计2 界面设计2.1 背景设计2.2 球的设计2.3 移动球的坐标2.4 用鼠标移动这个球2.5 增加边框规则2.6 效果图2.7 游戏失败重启游戏 3 为小球增加增加动画效果3.1 增加移动效果代码3.2 具体效果图片 平衡球游戏 part2 第二部分文章在这里 1 …...
JVM(四) - JVM 内存结构
目录 一、程序计数器 1.1 作用 1.2 概述 二、虚拟机栈 2.1 概述 2.2 栈的存储单位 2.3 栈运行原理 2.4 栈帧的内部结构 2.4.1. 局部变量表 槽 Slot 2.4.2. 操作数栈 概述 栈顶缓存(Top-of-stack-Cashing) 2.4.3. 动态链接(指向…...
【AI系统】CANN 算子类型
CANN 算子类型 算子是编程和数学中的重要概念,它们是用于执行特定操作的符号或函数,以便处理输入值并生成输出值。本文将会介绍 CANN 算子类型及其在 AI 编程和神经网络中的应用,以及华为 CANN 算子在 AI CPU 的详细架构和开发要求。 算子基…...
VUE脚手架练习
脚手架安装的问题: 1.安装node.js,配置环境变量,cmd输入node -v和npm -v可以看到版本号(如果显示不是命令,确认环境变量是否配置成功,记得配置环境变量之后重新打开cmd,再去验证) 2.在安装cnmp时…...
动态艺术:用Python将文字融入GIF动画
文章内容: 在数字媒体的多样化发展中,GIF动画作为一种流行的表达形式,常被用于广告、社交媒体和娱乐。本文通过一个具体的Python编程示例,展示了如何将文字以动态形式融入到GIF动画中,创造出具有视觉冲击力的动态艺术…...
更多开源创新 挑战OpenAI-o1的模型出现和AI个体模拟突破
每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…...
VR眼镜可视化编程:开启医疗信息系统新纪元
一、引言 随着科技的飞速发展,VR 可视化编程在医疗信息系统中的应用正逐渐成为医疗领域的新趋势。它不仅为医疗教育、手术培训、疼痛管理等方面带来了新的机遇,还在提升患者体验、推动医疗信息系统智能化等方面发挥着重要作用。 在当今医疗领域…...
Ubuntu访问简书403
日期 二〇二四年十二月三日 操作系统 Ubuntu 22.04 浏览器 firefox 问题 打开简书提示403. 原因 简书不认带ubuntu的UA 解决办法 - 浏览器地址栏输入 about:config。接受风险 - 搜索 general.useragent.override,无则新建 string类型。 - 查看浏览器 UA&…...
SQL高级应用——索引与视图
数据库优化离不开索引和视图的合理使用。索引用于加速查询性能,而视图则在逻辑层简化了查询逻辑,提高了可维护性。本文将从以下几个方面详细探讨索引与视图的概念、应用场景、优化技巧以及最新的技术发展: 1. 索引类型与应用场景 索引是数据…...
docker部署文件编写(还未尝试)
docker文件启动mysql 要使用Docker启动MySQL,您可以通过以下步骤编写Dockerfile: 选择一个基础镜像,通常是一个包含了MySQL的Linux发行版。 设置环境变量,如MySQL的root密码等。 在容器启动时运行MySQL服务。 以下是一个简单…...
缓存与数据库数据一致性 详解
缓存与数据库数据一致性详解 在分布式系统中,缓存(如 Redis、Memcached)与数据库(如 MySQL、PostgreSQL)一起使用是提高系统性能的常用方法。然而,缓存与数据库可能因更新时序、操作失误等原因出现数据不一…...
每日计划-1203
1. 完成 236. 二叉树的最近公共祖先 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode(int x) : val(x), left(NULL), right(NULL) {}* };*/ class Solution {public:TreeNode* lowe…...
Anthropic Managed Agents:AI 运行时的事件日志革命
1. 这不是新赛道,是 runtime 层的“操作系统时刻”来了你有没有试过让一个 AI 代理连续工作四十分钟?不是闲聊,而是真正在查文档、调 API、写代码、改配置、再验证——一环扣一环地推进一个真实业务流程。我去年就带着团队跑过这样一个销售线…...
从脚本到智能体:自动化体系如何被 Agent 重新定义
从脚本到智能体:自动化体系如何被 Agent 重新定义 关键词:智能体Agent、自动化脚本、LLM原生应用、自主决策系统、RAG检索增强生成、工具调用、自动化体系演进 摘要:本文从所有开发者都熟悉的传统自动化脚本痛点切入,用奶茶店员工到金牌店长的生活化类比,一步步拆解自动化…...
告别报错!手把手教你用Pycharm 2023.2 + Git搞定Manim社区版安装(附国内镜像源配置)
Manim社区版极速安装指南:PyCharm 2023.2与Git的完美协作方案 当数学可视化遇上Python开发神器PyCharm,Manim社区版的安装过程却常常成为新手的第一道门槛。不同于常规教程的线性步骤,我们将以"问题-解决"为主线,直击两…...
基于Multisim的四路带计分系统抢答器设计与仿真
摘要:本项目设计了一个四路带计分系统的智能抢答器,具有声光显示、计时和计分功能。使用Multisim 14.3进行电路设计 与仿真验证。项目简介本项目设计了一个基于Multisim的四路带计分系统智能抢答器,采用74系列数字逻辑芯片实现纯硬件电路设计…...
美容顾问转型AI训练师:2024紧缺新职业认证路径(含国家人社部备案课程编号)
更多请点击: https://kaifayun.com 第一章:AI Agent美容行业应用概述 AI Agent正以前所未有的深度融入美容行业,从智能肤质分析、个性化护肤方案生成,到虚拟试妆、客户行为预测与自动化私域运营,其核心价值在于将非结…...
如何永久激活IDM?免费IDM激活脚本终极指南
如何永久激活IDM?免费IDM激活脚本终极指南 【免费下载链接】IDM-Activation-Script IDM Activation & Trail Reset Script 项目地址: https://gitcode.com/gh_mirrors/id/IDM-Activation-Script 还在为IDM试用期到期而烦恼吗?IDM Activation …...
博德之门3 2026最新免费下载 一键转存 永久更新 (看到速转存 资源随时走丢)
下载链接 电子角色扮演游戏的范式革新:博德之门3的技术架构与玩法机制剖析 在现代电子游戏工业中,古典角色扮演游戏(CRPG)曾因其高昂的学习门槛与繁复的规则体系,一度被视为分众市场的垂类产品。然而,2023…...
告别断电重启就丢程序:深入聊聊紫光同创FPGA的Flash固化与CPLD内置eFlash配置差异
紫光同创FPGA与CPLD配置存储机制深度解析:从瞬态下载到永久固化的技术实现 在数字电路设计领域,FPGA和CPLD的可重构特性为硬件开发带来了极大灵活性。然而,这种灵活性背后需要可靠的配置存储机制作为支撑——断电后程序能否自动恢复…...
免费AI搜索工具推荐2026,92%用户不知道的3个隐藏权限设置——关闭行为追踪、锁定模型版本、强制HTTPS直连
更多请点击: https://kaifayun.com 第一章:免费AI搜索工具推荐2026 2026年,开源与社区驱动的AI搜索工具生态迎来爆发式增长。得益于大语言模型轻量化部署、RAG(检索增强生成)架构普及,以及WebAssembly对客…...
天学网英语听力对孩子有用吗?2026最新真实测评结果告诉你
做了5年英语听力领域的技术研究,最近后台好多家长问我这类AI听力训练产品对孩子提分有没有用,刚好我们团队刚做完2026年的公立校落地测评,今天就客观给大家拆解清楚。先聊聊英语听力训练的行业共性痛点我们团队在实践中发现,现在国…...
