12.2深度学习_项目实战
十、项目实战
鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。
他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地假设事情。为了解决这个问题,他收集了各个公司的手机销售数据。
鲍勃想找出手机的特性(例如:RAM、内存等)和售价之间的关系。但他不太擅长机器学习。所以他需要你帮他解决这个问题。
在这个问题中,你不需要预测实际价格,而是要预测一个价格区间,表明价格多高。
需要注意的是: 在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。
数据说明:https://tianchi.aliyun.com/dataset/157241 Mobile Price Classification
推荐专业的数据集平台:https://www.kaggle.com/datasets/iabhishekofficial/mobile-price-classification
| 字段 | 说明 |
|---|---|
| battery_power | 电池容量(mAh) |
| blue | 是否支持蓝牙 |
| clock_speed | 微处理器执行指令的速度 |
| dual_sim | 是否支持双卡 |
| fc | 前置摄像头分辨率(百万像素) |
| four_g | 是否支持4G |
| int_memory | 存储内存(GB) |
| m_dep | 手机厚度(厘米) |
| mobile_wt | 重量 |
| n_cores | 核心数 |
| pc | 主摄像头分辨率(百万像素) |
| px_height | 屏幕分辨率高度(像素) |
| px_width | 屏幕分辨率宽度(像素) |
| ram | 运行内存(MB) |
| sc_h | 屏幕长度(厘米) |
| sc_w | 屏幕宽度(厘米) |
| talk_time | 单次充电最长通话时间 |
| three_g | 是否支持3G |
| touch_screen | 是否是触摸屏 |
| wifi | 是否支持WIFI |
| price_range | 价格区间 |
1. 构建数据集
数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便构造数据集加载对象。
# 构建数据集
def create_dataset():data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid=train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 构建数据集train_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.from_numpy(x_valid.values), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()
2. 构建分类网络模型
我们构建的用于手机价格分类的模型叫做全连接神经网络。它主要由三个线性层来构建,在每个线性层后,我们使用的时 sigmoid 激活函数。
# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, output_dim)def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))x = self._activation(self.linear2(x))output = self.linear3(x)return output
3. 编写训练函数
网络编写完成之后,我们需要编写训练函数。所谓的训练函数,指的是输入数据读取、送入网络、计算损失、更新参数的流程,该流程较为固定。我们使用的是多分类交叉生损失函数、使用 SGD 优化方法。最终,将训练好的模型持久化到磁盘中。
def train():# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)# 训练轮数num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/phone-price-model.bin')
4. 编写评估函数
评估函数、也叫预测函数、推理函数,主要使用训练好的模型,对未知的样本的进行预测的过程。我们这里使用前面单独划分出来的测试集来进行评估。
def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('model/phone-price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))
5. 网络性能调优
我们前面的网络模型在测试集的准确率为: 0.54750, 我们可以通过以下方面进行调优:
- 对输入数据进行标准化
- 调整优化方法
- 调整学习率
- 增加批量归一化层
- 增加网络层数、神经元个数
- 增加训练轮数
- 等等…
我进行下如下调整:
- 优化方法由 SGD 调整为 Adam
- 学习率由 1e-3 调整为 1e-4
- 对数据数据进行标准化
- 增加网络深度, 即: 增加网络参数量
网络模型在测试集的准确率由 0.5475 上升到 0.9625
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
import time
from sklearn.preprocessing import StandardScaler# 构建数据集
def create_dataset():data = pd.read_csv('data/手机价格预测.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid = \train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 数据标准化transfer = StandardScaler()x_train = transfer.fit_transform(x_train)x_valid = transfer.transform(x_valid)# 构建数据集train_dataset = TensorDataset(torch.from_numpy(x_train), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.from_numpy(x_valid), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128)self.linear2 = nn.Linear(128, 256)self.linear3 = nn.Linear(256, 512)self.linear4 = nn.Linear(512, 128)self.linear5 = nn.Linear(128, output_dim)def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))x = self._activation(self.linear2(x))x = self._activation(self.linear3(x))x = self._activation(self.linear4(x))output = self.linear5(x)return output# 编写训练函数
def train():# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.Adam(model.parameters(), lr=1e-4)# 训练轮数num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/phone-price-model.bin')def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('model/phone-price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))if __name__ == '__main__':train()test()
相关文章:
12.2深度学习_项目实战
十、项目实战 鲍勃开了自己的手机公司。他想与苹果、三星等大公司展开硬仗。 他不知道如何估算自己公司生产的手机的价格。在这个竞争激烈的手机市场,你不能简单地假设事情。为了解决这个问题,他收集了各个公司的手机销售数据。 鲍勃想找出手机的特性(例…...
LeetCode 64. 最小路径和(HOT100)
第一次错误代码: class Solution { public:int minPathSum(vector<vector<int>>& grid) {int dp[205][205] {0};int m grid.size(),n grid[0].size();for(int i 1 ;i<m;i){for(int j 1;j<n;j){dp[i][j] min(dp[i][j-1],dp[i-1][j])gr…...
ESP8266作为TCP客户端或者服务器使用
ESP8266模块,STA模式(与手机搭建TCP通讯,EPS8266为服务端)_esp8266作为station-CSDN博客 ESP8266模块,STA模式(与电脑搭建TCP通讯,ESP8266 为客户端)_esp8266 sta 连接tcp-CSDN博客…...
C#结合.NET框架快速构建和部署AI应用
在人工智能(AI)的浪潮中,C#作为一种功能强大且类型安全的编程语言,为AI工程开发提供了坚实的基础。C#结合.NET框架,使得开发者能够快速构建和部署AI应用。本文将通过一个简单的实例,展示如何使用C#进行AI工…...
题外话 (火影密令)
哥们! 玩火影不! 村里人全部评论! 不评论的忍战李全保底! 哥们! 密令领了不! “1219村里人集合”领了吗! 100金币! 哥们! 我粉丝没人能上影! 老舅说的…...
蓝桥杯准备训练(lesson1,c++方向)
前言 报名参加了蓝桥杯(c)方向的宝子们,今天我将与大家一起努力参赛,后序会与大家分享我的学习情况,我将从最基础的内容开始学习,带大家打好基础,在每节课后都会有练习题,刚开始的练…...
RTDETR融合[ECCV2024]WTConvNeXt中的WTConv模块及相关改进思路
RT-DETR使用教程: RT-DETR使用教程 RT-DETR改进汇总贴:RT-DETR更新汇总贴 《Wavelet Convolutions for Large Receptive Fields》 一、 模块介绍 论文链接:https://arxiv.org/pdf/2407.05848 代码链接:https://github.com/BGU-CS…...
AD7606使用方法
AD7606是一款8通道最高16位200ksps的AD采样芯片。5V单模拟电源供电,真双极性模拟输入可以选择10 V,5 V两种量程。支持串口与并口两种读取方式。 硬件连接方式: 配置引脚 引脚功能 详细说明 OS2 OS1 OS2 过采样率配置 000 1倍过采样率 …...
嵌入式系统应用-LVGL的应用-平衡球游戏 part1
平衡球游戏 part1 1 平衡球游戏的界面设计2 界面设计2.1 背景设计2.2 球的设计2.3 移动球的坐标2.4 用鼠标移动这个球2.5 增加边框规则2.6 效果图2.7 游戏失败重启游戏 3 为小球增加增加动画效果3.1 增加移动效果代码3.2 具体效果图片 平衡球游戏 part2 第二部分文章在这里 1 …...
JVM(四) - JVM 内存结构
目录 一、程序计数器 1.1 作用 1.2 概述 二、虚拟机栈 2.1 概述 2.2 栈的存储单位 2.3 栈运行原理 2.4 栈帧的内部结构 2.4.1. 局部变量表 槽 Slot 2.4.2. 操作数栈 概述 栈顶缓存(Top-of-stack-Cashing) 2.4.3. 动态链接(指向…...
【AI系统】CANN 算子类型
CANN 算子类型 算子是编程和数学中的重要概念,它们是用于执行特定操作的符号或函数,以便处理输入值并生成输出值。本文将会介绍 CANN 算子类型及其在 AI 编程和神经网络中的应用,以及华为 CANN 算子在 AI CPU 的详细架构和开发要求。 算子基…...
VUE脚手架练习
脚手架安装的问题: 1.安装node.js,配置环境变量,cmd输入node -v和npm -v可以看到版本号(如果显示不是命令,确认环境变量是否配置成功,记得配置环境变量之后重新打开cmd,再去验证) 2.在安装cnmp时…...
动态艺术:用Python将文字融入GIF动画
文章内容: 在数字媒体的多样化发展中,GIF动画作为一种流行的表达形式,常被用于广告、社交媒体和娱乐。本文通过一个具体的Python编程示例,展示了如何将文字以动态形式融入到GIF动画中,创造出具有视觉冲击力的动态艺术…...
更多开源创新 挑战OpenAI-o1的模型出现和AI个体模拟突破
每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…...
VR眼镜可视化编程:开启医疗信息系统新纪元
一、引言 随着科技的飞速发展,VR 可视化编程在医疗信息系统中的应用正逐渐成为医疗领域的新趋势。它不仅为医疗教育、手术培训、疼痛管理等方面带来了新的机遇,还在提升患者体验、推动医疗信息系统智能化等方面发挥着重要作用。 在当今医疗领域…...
Ubuntu访问简书403
日期 二〇二四年十二月三日 操作系统 Ubuntu 22.04 浏览器 firefox 问题 打开简书提示403. 原因 简书不认带ubuntu的UA 解决办法 - 浏览器地址栏输入 about:config。接受风险 - 搜索 general.useragent.override,无则新建 string类型。 - 查看浏览器 UA&…...
SQL高级应用——索引与视图
数据库优化离不开索引和视图的合理使用。索引用于加速查询性能,而视图则在逻辑层简化了查询逻辑,提高了可维护性。本文将从以下几个方面详细探讨索引与视图的概念、应用场景、优化技巧以及最新的技术发展: 1. 索引类型与应用场景 索引是数据…...
docker部署文件编写(还未尝试)
docker文件启动mysql 要使用Docker启动MySQL,您可以通过以下步骤编写Dockerfile: 选择一个基础镜像,通常是一个包含了MySQL的Linux发行版。 设置环境变量,如MySQL的root密码等。 在容器启动时运行MySQL服务。 以下是一个简单…...
缓存与数据库数据一致性 详解
缓存与数据库数据一致性详解 在分布式系统中,缓存(如 Redis、Memcached)与数据库(如 MySQL、PostgreSQL)一起使用是提高系统性能的常用方法。然而,缓存与数据库可能因更新时序、操作失误等原因出现数据不一…...
每日计划-1203
1. 完成 236. 二叉树的最近公共祖先 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode(int x) : val(x), left(NULL), right(NULL) {}* };*/ class Solution {public:TreeNode* lowe…...
VMware macOS虚拟机解锁方案:开源工具Unlocker完整实践指南
VMware macOS虚拟机解锁方案:开源工具Unlocker完整实践指南 【免费下载链接】unlocker VMware Workstation macOS 项目地址: https://gitcode.com/gh_mirrors/unloc/unlocker 你是否想在Windows或Linux系统上运行macOS虚拟机,却苦于VMware不支持…...
如何快速配置AI自瞄系统:面向游戏爱好者的完整指南
如何快速配置AI自瞄系统:面向游戏爱好者的完整指南 【免费下载链接】RookieAI_yolov8 基于yolov8实现的AI自瞄项目 AI self-aiming project based on yolov8 项目地址: https://gitcode.com/gh_mirrors/ro/RookieAI_yolov8 还在为FPS游戏中的精准瞄准而烦恼吗…...
373. Java IO API - 文件存储属性
文章目录373. Java IO API - 文件存储属性📏 示例:检查文件存储的空间使用情况⚙️ 解释🔍 确定 MIME 类型📂 示例:获取文件 MIME 类型⚠️ 重要注意事项🛠️ 示例:自定义文件类型探测器&#x…...
STM32内存管理实战:如何避免局部变量数组导致的栈溢出问题?
ST32内存管理实战:如何避免局部变量数组导致的栈溢出问题? 在嵌入式开发领域,内存管理一直是开发者必须面对的挑战之一。对于使用STM32系列微控制器的开发者来说,理解并掌握内存分配机制尤为重要。本文将深入探讨STM32开发中常见的…...
ChampR:让每个英雄联盟玩家都能掌握专业级游戏策略
ChampR:让每个英雄联盟玩家都能掌握专业级游戏策略 【免费下载链接】champ-r 🐶 Yet another League of Legends helper 项目地址: https://gitcode.com/gh_mirrors/ch/champ-r 一、核心价值解析:ChampR如何重新定义游戏辅助工具&…...
3天从零到精通:录播姬全方位实战指南
3天从零到精通:录播姬全方位实战指南 【免费下载链接】BililiveRecorder 录播姬 | mikufans 生放送录制 项目地址: https://gitcode.com/gh_mirrors/bi/BililiveRecorder 你是否曾经因为错过心爱主播的直播而感到遗憾?是否在录制直播时遇到各种技…...
顺丰控股年营收3082亿:净利111亿 现金分红21亿
雷递网 雷建平 4月5日顺丰控股(证券代码:002352)日前发布截至2025年12月31日的财报。财报显示,顺丰控股2025年营收3082.27亿,较上年同期的2844亿元增长8.37%。顺丰控股2025年时效快递业务实现营业收入1,310.5亿元&…...
造相-Z-Image-Turbo亚洲美女LoRA入门指南:开箱即用的图片生成服务
造相-Z-Image-Turbo亚洲美女LoRA入门指南:开箱即用的图片生成服务 1. 服务概览与核心价值 造相-Z-Image-Turbo亚洲美女LoRA是一个基于Z-Image-Turbo模型的图片生成Web服务,特别集成了laonansheng/Asian-beauty-Z-Image-Turbo-Tongyi-MAI-v1.0 LoRA模型…...
WindowResizer终极指南:三步解决Windows窗口无法调整大小的难题
WindowResizer终极指南:三步解决Windows窗口无法调整大小的难题 【免费下载链接】WindowResizer 一个可以强制调整应用程序窗口大小的工具 项目地址: https://gitcode.com/gh_mirrors/wi/WindowResizer 在日常使用Windows电脑时,你是否遇到过这样…...
NLP-StructBERT赋能内容去重:展示海量文本相似度排查惊艳效果
NLP-StructBERT赋能内容去重:展示海量文本相似度排查惊艳效果 每次打开内容平台,你是不是也经常看到一堆“换汤不换药”的文章?标题不一样,内容却大同小异。对于平台运营者来说,这更是个头疼的问题:怎么从…...
