【机器学习】超参数调优指南:交叉验证,网格搜索,混淆矩阵——基于鸢尾花与数字识别案例的深度解析
一、前言:为何要学交叉验证与网格搜索?
大家好!在机器学习的道路上,我们经常面临一个难题:模型调参。比如在 KNN 算法中,选择多少个邻居(n_neighbors)直接影响预测效果。
• 蛮力猜测:就像在厨房随便“加盐加辣椒”,不仅费时费力,还可能把菜搞砸。
• 交叉验证 + 网格搜索:更像是让你请来一位“大厨”,提前试好所有配方,帮你挑选出最完美的“调料搭配”。
交叉验证与网格搜索的组合,能让你在众多超参数组合中自动挑选出最佳方案,从而让模型预测达到“哇塞,这也太准了吧!”的境界。
二、概念扫盲:交叉验证 & 网格搜索
1. 交叉验证(Cross-Validation)
核心思路:
• 分组品尝:将整个数据集平均分成若干份(比如分成 5 份,即“5折交叉验证”)。
• 轮流担任评委:每次选取其中一份作为“验证集”(就像让这部分数据来“评委打分”),剩下的作为“训练集”来训练模型。
• 集体评定:重复多次,每一份都轮流担任验证集,然后把所有“评分”取平均,作为模型在数据集上的最终表现。
好处:
• 每个样本都有机会既当“选手”又当“评委”,使得评估结果更稳定、可靠。
• 避免单一划分带来的偶然性,确保你调出来的参数在不同数据切分下都表现良好。
2. 网格搜索(Grid Search)
核心思路:
• 列出所有可能:将你想尝试的超参数组合“罗列成一个表格(网格)”。
• 自动试菜:每种组合都进行一次完整的模型训练和评估,记录下它们的表现。
• 选出最佳配方:最后找出在交叉验证中表现最好的超参数组合。
好处:
• 自动化、系统化地寻找最佳参数组合,避免你手动“胡乱猜测”。
• 和交叉验证结合后,每个参数组合都经过了多次评估,结果更稳健。
3. 网格搜索 + 交叉验证
这两者结合就像“炼丹”高手的秘诀:
• 交叉验证解决了“数据切分”的问题,让评估更准确;
• 网格搜索解决了“超参数组合”问题,帮你遍历所有可能性。
合体后,你就能轻松找到最优超参数,让模型发挥出最佳性能!
三、案例一:鸢尾花数据集 + KNN + 交叉验证网格搜索
3.1 数据集介绍
• 数据来源:scikit-learn 内置的 load_iris
• 特征:萼片长度、萼片宽度、花瓣长度、花瓣宽度
• 目标:根据花的外部特征预测其所属的鸢尾花种类
3.2 代码示例
下面代码展示如何在鸢尾花数据集上使用 KNN 算法,并通过 GridSearchCV(交叉验证+网格搜索)自动调优 n_neighbors 参数:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_scoredef iris_knn_cv():"""使用KNN算法在鸢尾花数据集上进行分类,并通过网格搜索+交叉验证寻找最优超参数。"""# 1. 加载数据iris = load_iris()X = iris.data # 特征矩阵,包含四个特征y = iris.target # 标签,分别代表三种鸢尾花# 2. 划分训练集和测试集# test_size=0.2 表示 20% 的数据用于测试,保证测试结果具有代表性# random_state=22 固定随机数种子,确保每次运行划分一致X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=22)# 3. 数据标准化# 标准化可使各特征均值为0、方差为1,消除量纲影响(对于基于距离的KNN非常重要)scaler = StandardScaler()X_train_scaled = scaler.fit_transform(X_train)X_test_scaled = scaler.transform(X_test)# 4. 构建KNN模型及参数调优knn = KNeighborsClassifier() # 初始化KNN模型# 4.1 设置网格搜索参数范围:尝试不同的邻居数param_grid = {'n_neighbors': [1, 3, 5, 7, 9]}# 4.2 进行网格搜索 + 交叉验证(5折交叉验证)grid_search = GridSearchCV(estimator=knn, # 待调参的模型param_grid=param_grid, # 超参数候选列表cv=5, # 5折交叉验证:将训练集分为5个子集,每次用1个子集验证,其余4个训练scoring='accuracy', # 以准确率作为评估指标n_jobs=-1 # 使用所有CPU核心并行计算)grid_search.fit(X_train_scaled, y_train) # 自动遍历各参数组合并评估# 4.3 输出网格搜索结果print("最佳交叉验证分数:", grid_search.best_score_)print("最优超参数组合:", grid_search.best_params_)print("最优模型:", grid_search.best_estimator_)# 5. 模型评估:用测试集评估最优模型的泛化能力best_model = grid_search.best_estimator_y_pred = best_model.predict(X_test_scaled)acc = accuracy_score(y_test, y_pred)print("在测试集上的准确率:{:.2f}%".format(acc * 100))# 6. 可视化(选做):可进一步绘制混淆矩阵或学习曲线# 直接调用函数进行测试
if __name__ == "__main__":iris_knn_cv()
输出:
3.3 结果解读
• 最佳交叉验证分数:表示在5折交叉验证过程中,所有参数组合中平均准确率最高的值。
• 最优超参数组合:显示在候选参数 [1, 3, 5, 7, 9] 中哪个 n_neighbors 的效果最好。
• 测试集准确率:验证模型在未见数据上的表现,反映其泛化能力。
通过这个案例,你可以看到交叉验证网格搜索如何自动帮你“挑菜”选料,让 KNN 模型在鸢尾花分类任务上达到最佳表现。
四、案例二:手写数字数据集 + KNN + 交叉验证网格搜索
4.1 数据集介绍
• 数据来源:scikit-learn 内置的 load_digits
• 特征:每张 8×8 像素的手写数字图像被拉伸成64维特征向量
• 目标:识别图片中数字所属类别(0~9)
4.2 代码示例
下面代码展示如何在手写数字数据集上使用 KNN 算法,并通过交叉验证网格搜索调优参数:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits # 导入手写数字数据集(内置于 scikit-learn)
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler # 导入数据标准化工具
from sklearn.neighbors import KNeighborsClassifier # 导入KNN分类器
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns # 导入 seaborn,用于绘制更美观的图表def digits_knn_cv():"""使用KNN算法在手写数字数据集上进行分类,并通过网格搜索+交叉验证寻找最优超参数。"""# 1. 加载数据digits = load_digits() # 从scikit-learn加载内置手写数字数据集X = digits.data # 特征数据,形状为 (1797, 64),每一行对应一张图片的64个像素值y = digits.target # 目标标签,共10个类别(数字 0 到 9)# 2. 数据可视化:展示前5张图片及其标签# 创建一个1行5列的子图区域,图像尺寸为10x2英寸fig, axes = plt.subplots(1, 5, figsize=(10, 2))for i in range(5):# 显示第 i 张图片,使用灰度图(cmap='gray')axes[i].imshow(digits.images[i], cmap='gray')# 设置每个子图的标题,显示该图片对应的标签axes[i].set_title("Label: {}".format(digits.target[i]))# 关闭坐标轴显示(避免坐标信息干扰视觉效果)axes[i].axis('off')plt.suptitle("手写数字数据集示例") # 为整个图表添加一个总标题plt.show() # 显示图表# 3. 数据划分 + 标准化# 将数据划分为训练集和测试集,其中测试集占20%,random_state保证每次划分一致X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 初始化标准化工具,将特征数据转换为均值为0、方差为1的标准正态分布scaler = StandardScaler()# 仅在训练集上拟合标准化参数,并转换训练集数据X_train_scaled = scaler.fit_transform(X_train)# 使用相同的转换参数转换测试集数据(避免数据泄露)X_test_scaled = scaler.transform(X_test)# 4. 构建KNN模型及网格搜索调参knn = KNeighborsClassifier() # 初始化KNN分类器,暂未指定 n_neighbors 参数# 定义一个字典,列出希望尝试的超参数组合# 这里我们测试不同邻居数的效果:[1, 3, 5, 7, 9]param_grid = {'n_neighbors': [1, 3, 5, 7, 9]}# 初始化网格搜索对象,结合交叉验证grid_search = GridSearchCV(estimator=knn, # 需要调参的KNN模型param_grid=param_grid, # 超参数候选组合cv=5, # 5折交叉验证,将训练数据分成5份,每次用4份训练,1份验证scoring='accuracy', # 使用准确率作为模型评估指标n_jobs=-1 # 并行计算,使用所有可用的CPU核心加速计算)# 在标准化后的训练集上进行网格搜索,自动尝试所有参数组合,并进行交叉验证grid_search.fit(X_train_scaled, y_train)# 5. 输出网格搜索调参结果# 打印在交叉验证中获得的最佳平均准确率print("手写数字 - 最佳交叉验证分数:", grid_search.best_score_)# 打印获得最佳结果时所使用的超参数组合,例如 {'n_neighbors': 3}print("手写数字 - 最优超参数组合:", grid_search.best_params_)# 打印最佳模型对象,该模型已使用最优参数重新训练best_model = grid_search.best_estimator_# 6. 模型评估:用测试集评估模型效果# 使用最优模型对测试集进行预测y_pred = best_model.predict(X_test_scaled)# 计算测试集上的准确率acc = accuracy_score(y_test, y_pred)print("手写数字 - 测试集准确率:{:.2f}%".format(acc * 100))# 7. 可视化混淆矩阵(直观展示各数字分类效果)# 混淆矩阵能够显示真实标签与预测标签之间的对应关系cm = confusion_matrix(y_test, y_pred)plt.figure(figsize=(6, 5))# 使用 seaborn 的 heatmap 绘制混淆矩阵,annot=True 表示在每个单元格中显示数字sns.heatmap(cm, annot=True, cmap='Blues', fmt='d')plt.title("手写数字 - 混淆矩阵")plt.xlabel("预测值")plt.ylabel("真实值")plt.show()# 直接调用函数进行测试
if __name__ == "__main__":digits_knn_cv()
输出:
4.3 结果解读
• 最优 n_neighbors:通过交叉验证,我们找到了在候选参数中使模型表现最佳的邻居数量。
• 测试集准确率:在手写数字识别任务上,通常准确率能达到90%以上,证明 KNN 在小数据集上也能表现不错。
• 混淆矩阵:直观展示哪些数字容易混淆(例如数字“3”和“5”),便于进一步分析和改进。
混淆矩阵图的含义与作用
1. 横纵坐标的含义
• 行(纵轴)代表真实标签(真实的数字 0~9)。
• 列(横轴)代表模型预测的标签(预测的数字 0~9)。
2. 数值和颜色深浅
• 单元格 (i, j) 内的数值表示:真实类别为 i 的样本中,有多少被预测为 j。
• 越靠近对角线(i = j)代表预测正确的数量;
• 离对角线越远,说明模型将真实类别 i 的样本错误地预测成类别 j。
• 热力图中颜色越深表示数量越多,浅色则表示数量少。
3. 作用
• 评估模型分类效果:如果对角线上的数值高且远离对角线的数值低,说明模型分类准确度高;反之,说明某些类别容易被混淆。
• 发现易混淆的类别:通过观察非对角线位置是否有较大的数值,可以知道哪些数字最容易被误判。例如,模型可能经常把“3”预测成“5”,这能提示我们在后续改进中加强这两个类别的区分。
• 比单纯的准确率更全面:准确率只能告诉你模型整体正确率,而混淆矩阵能告诉你哪类错误最多,便于更有针对性地提升模型性能。
五、总结 & 彩蛋
1. 交叉验证的价值
• 有效避免过拟合,通过多次分组验证,使得模型评估更稳健。
2. 网格搜索的强大
• 自动遍历所有超参数组合,省去手动调参的烦恼,快速锁定“最佳拍档”。
3. KNN 的局限
• 虽然简单易用,但在大规模、高维数据中计算量较大,且对异常值较敏感。
4. 后续进阶
• 可以尝试随机搜索(RandomizedSearchCV)或贝叶斯优化,甚至转向更复杂的模型如 CNN 进行数字识别。
结语
如果你觉得本篇文章对你有所帮助,请记得点赞、收藏、转发和评论哦!你的支持是我继续创作的最大动力。让我们一起在机器学习的道路上不断探索、不断进步,早日成为调参界的“神仙”!
祝学习愉快,炼丹顺利~
相关文章:

【机器学习】超参数调优指南:交叉验证,网格搜索,混淆矩阵——基于鸢尾花与数字识别案例的深度解析
一、前言:为何要学交叉验证与网格搜索? 大家好!在机器学习的道路上,我们经常面临一个难题:模型调参。比如在 KNN 算法中,选择多少个邻居(n_neighbors)直接影响预测效果。 • 蛮力猜…...

Burp Suite基本使用(web安全)
工具介绍 在网络安全的领域,你是否听说过抓包,挖掘漏洞等一系列的词汇,这篇文章将带你了解漏洞挖掘的热门工具——Burp Suite的使用。 Burp Suite是一款由PortSwigger Web Security公司开发的集成化Web应用安全检测工具,它主要用于…...

React实现自定义图表(线状+柱状)
要使用 React 绘制一个结合线状图和柱状图的图表,你可以使用 react-chartjs-2 库,它是基于 Chart.js 的 React 封装。以下是一个示例代码,展示如何实现这个需求: 1. 安装依赖 首先,你需要安装 react-chartjs-2 和 ch…...

从低清到4K的魔法:FlashVideo突破高分辨率视频生成计算瓶颈(港大港中文字节)
论文链接:https://arxiv.org/pdf/2502.05179 项目链接:https://github.com/FoundationVision/FlashVideo 亮点直击 提出了 FlashVideo,一种将视频生成解耦为两个目标的方法:提示匹配度和视觉质量。通过在两个阶段分别调整模型规模…...
Qt的QTabWidget的使用
在PyQt5中,QTabWidget 是一个用于管理多个选项卡页面的容器控件。以下是其使用方法的详细说明和示例: 1. 基本用法 import sys from PyQt5.QtWidgets import QApplication, QMainWindow, QTabWidget, QWidget, QLabel, QVBoxLayoutclass MainWindow(QMa…...

Next.js【详解】获取数据(访问接口)
Next.js 中分为 服务端组件 和 客户端组件,内置的获取数据各不相同 服务端组件 方式1 – 使用 fetch export default async function Page() {const data await fetch(https://api.vercel.app/blog)const posts await data.json()return (<ul>{posts.map((…...

反向代理模块kd
1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求,然后将请求转发给内部网络上的服务器,将从服务器上得到的结果返回给客户端,此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说,反向代理就相当于…...
leaflet前端初始化项目
1、通过npm安装leaflet包,或者直接在项目中引入leaflet.js库文件。 npm 安装:npm i leaflet 如果在index.html中引入leaflet.js,在项目中可以直接使用变量L. 注意:尽量要么使用npm包,要么使用leaflet.js库,两者一起使用容易发生…...
CMS DTcms 靶场(弱口令、文件上传、tasklist提权、开启远程桌面3389、gotohttp远程登录控制)
环境说明 攻击机kali:192.168.111.128 信息收集 主机发现 ┌──(root㉿kali-plus)-[~/Desktop] └─# nmap -sP 192.168.111.0/24 Starting Nmap 7.94SVN ( https://nmap.org ) at 2024-11-23 14:57 CST Nmap scan report for 192.168.111.1 Host is up (0.00039s latenc…...

Docker 入门与实战:从安装到容器管理的完整指南
🚀 Docker 入门与实战:从安装到容器管理的完整指南 🌟 📖 简介 在现代软件开发中,容器化技术已经成为不可或缺的一部分。而 Docker 作为容器化领域的领头羊,以其轻量级、高效和跨平台的特性,深…...

git删除本地分支
一、命令方式 1、查看本地分支 git branch 2、切换到一个不删除的分支 git checkout branch_name 3、强制删除分支 git branch -D local_branch_name 二、工具方式 1、选择"Browse references",右键"Delete branch"...

spring cloud gateway限流常见算法
目录 一、网关限流 1、限流的作用 1. 保护后端服务 2. 保证服务质量 (QoS) 3. 避免滥用和恶意攻击 4. 减少资源浪费 5. 提高系统可扩展性和稳定性 6. 控制不同用户的访问频率 7. 提升用户体验 8. 避免API滥用和负载过高 9. 监控与分析 10. 避免系统崩溃 2、网关限…...

本地使用docker部署DeepSeek大模型
1、相关技术介绍 1.1、RAG RAG(Retrieval Augmented Generation),即“检索,增强,生成”,用于提升自然语言处理任务的性能。其核心思想是通过检索相关信息来增强生成模型的能力,具体步骤如下&am…...
C++ 设计模式-外观模式
外观模式的定义 外观模式是一种 结构型设计模式,它通过提供一个简化的接口来隐藏系统的复杂性。外观模式的核心思想是: 封装复杂子系统:将多个复杂的子系统或组件封装在一个统一的接口后面。提供简单接口:为客户端提供一个更简单、更易用的接口,而不需要客户端直接与复杂…...

【Linux网络编程】应用层协议HTTP(请求方法,状态码,重定向,cookie,session)
🎁个人主页:我们的五年 🔍系列专栏:Linux网络编程 🌷追光的人,终会万丈光芒 🎉欢迎大家点赞👍评论📝收藏⭐文章 Linux网络编程笔记: https://blog.cs…...
SQL进阶技巧:如何统计用户跨端消费行为?
目录 0 问题描述 2 问题剖析 技术难点解析 3 完整解决方案 步骤1:构造全量日期平台组合 步骤2:用户行为标记 步骤3:最终关联聚合 4 核心技巧总结 5 复杂度评估 往期精彩 0 问题描述 支出表: Spending +-------------+---------+ | Column Name | Type | +-----…...

Fiddler笔记
文章目录 一、与F12对比二、核心作用三、原理四、配置1.Rules:2.配置证书抓取https包3.设置过滤器4、抓取App包 五、模拟弱网测试六、调试1.线上调试2.断点调试 七、理论1.四要素2.如何定位前后端bug 注 一、与F12对比 相同点: 都可以对http和https请求进行抓包分析…...

基于SpringBoot+Vue的老年人体检管理系统的设计与实现(源码+SQL脚本+LW+部署讲解等)
专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…...

51c自动驾驶~合集51
我自己的原文哦~ https://blog.51cto.com/whaosoft/13320191 #毫末最新OAD 轨迹偏移学习助力端到端新SOTA~ 端到端自动驾驶技术在近年来取得了显著进展。在本研究中,我们提出了轨迹偏移学习,将传统的直接预测自车轨迹,转换为预测相对于…...
Redis 监视器:深入解析与实战指南
Redis 监视器:深入解析与实战指南 引言 随着互联网技术的飞速发展,企业对实时数据处理和高并发场景的需求日益增长。Redis作为一款高性能的内存数据库,在各个领域中得到了广泛应用,包括缓存、消息队列、实时数据分析等。然而&am…...

Zustand 状态管理库:极简而强大的解决方案
Zustand 是一个轻量级、快速和可扩展的状态管理库,特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...

全球首个30米分辨率湿地数据集(2000—2022)
数据简介 今天我们分享的数据是全球30米分辨率湿地数据集,包含8种湿地亚类,该数据以0.5X0.5的瓦片存储,我们整理了所有属于中国的瓦片名称与其对应省份,方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

剑指offer20_链表中环的入口节点
链表中环的入口节点 给定一个链表,若其中包含环,则输出环的入口节点。 若其中不包含环,则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...

CMake 从 GitHub 下载第三方库并使用
有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...
汇编常见指令
汇编常见指令 一、数据传送指令 指令功能示例说明MOV数据传送MOV EAX, 10将立即数 10 送入 EAXMOV [EBX], EAX将 EAX 值存入 EBX 指向的内存LEA加载有效地址LEA EAX, [EBX4]将 EBX4 的地址存入 EAX(不访问内存)XCHG交换数据XCHG EAX, EBX交换 EAX 和 EB…...
Java求职者面试指南:计算机基础与源码原理深度解析
Java求职者面试指南:计算机基础与源码原理深度解析 第一轮提问:基础概念问题 1. 请解释什么是进程和线程的区别? 面试官:进程是程序的一次执行过程,是系统进行资源分配和调度的基本单位;而线程是进程中的…...

C++_哈希表
本篇文章是对C学习的哈希表部分的学习分享 相信一定会对你有所帮助~ 那咱们废话不多说,直接开始吧! 一、基础概念 1. 哈希核心思想: 哈希函数的作用:通过此函数建立一个Key与存储位置之间的映射关系。理想目标:实现…...

快速排序算法改进:随机快排-荷兰国旗划分详解
随机快速排序-荷兰国旗划分算法详解 一、基础知识回顾1.1 快速排序简介1.2 荷兰国旗问题 二、随机快排 - 荷兰国旗划分原理2.1 随机化枢轴选择2.2 荷兰国旗划分过程2.3 结合随机快排与荷兰国旗划分 三、代码实现3.1 Python实现3.2 Java实现3.3 C实现 四、性能分析4.1 时间复杂度…...

职坐标物联网全栈开发全流程解析
物联网全栈开发涵盖从物理设备到上层应用的完整技术链路,其核心流程可归纳为四大模块:感知层数据采集、网络层协议交互、平台层资源管理及应用层功能实现。每个模块的技术选型与实现方式直接影响系统性能与扩展性,例如传感器选型需平衡精度与…...
【Redis】Redis从入门到实战:全面指南
Redis从入门到实战:全面指南 一、Redis简介 Redis(Remote Dictionary Server)是一个开源的、基于内存的键值存储系统,它可以用作数据库、缓存和消息代理。由Salvatore Sanfilippo于2009年开发,因其高性能、丰富的数据结构和广泛的语言支持而广受欢迎。 Redis核心特点:…...