交叉验证之KFold和StratifiedKFold的使用(附案例实战)

🤵♂️ 个人主页:@艾派森的个人主页
✍🏻作者简介:Python学习者
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+
一、交叉验证简介
交叉验证是在机器学习建立模型和验证模型参数时常用的办法。交叉验证,顾名思义,就是重复的使用数据,把得到的样本数据进行切分,组合为不同的训练集和测试集,用训练集来训练模型,用测试集来评估模型预测的好坏。在此基础上可以得到多组不同的训练集和测试集,某次训练集中的某样本在下次可能成为测试集中的样本,即所谓“交叉”。
那么什么时候才需要交叉验证呢?交叉验证用在数据不是很充足的时候。通常情况下,如果数据样本量小于一万条,我们就会采用交叉验证来训练优化选择模型。如果样本大于一万条的话,我们一般随机的把数据分成三份,一份为训练集(Training Set),一份为验证集(Validation Set),最后一份为测试集(Test Set)。用训练集来训练模型,用验证集来评估模型预测的好坏和选择模型及其对应的参数。把最终得到的模型再用于测试集,最终决定使用哪个模型以及对应参数。
学习预测函数的参数,并在相同数据集上进行测试是一种错误的做法: 一个仅给出测试用例标签的模型将会获得极高的分数,但对于尚未出现过的数据它则无法预测出任何有用的信息。 这种情况称为 overfitting(过拟合).。为了避免这种情况,在进行机器学习实验时,通常取出部分可利用数据作为 test set(测试数据集) X_test, y_test。下面是模型训练中典型的交叉验证工作流流程图。通过网格搜索可以确定最佳参数。

k-折交叉验证得出的性能指标是循环计算中每个值的平均值。 该方法虽然计算代价很高,但是它不会浪费太多的数据(如固定任意测试集的情况一样), 在处理样本数据集较少的问题(例如,逆向推理)时比较有优势。

k-折交叉验证步骤
- 第一步,不重复抽样将原始数据随机分为 k 份。
- 第二步,每一次挑选其中 1 份作为测试集,剩余 k-1 份作为训练集用于模型训练。
- 第三步,重复第二步 k 次,这样每个子集都有一次机会作为测试集,其余机会作为训练集。
- 在每个训练集上训练后得到一个模型,
- 用这个模型在相应的测试集上测试,计算并保存模型的评估指标,
- 第四步,计算 k 组测试结果的平均值作为模型精度的估计,并作为当前 k 折交叉验证下模型的性能指标。
例如:
十折交叉验证
- 将训练集分成十份,轮流将其中9份作为训练数据,1份作为测试数据,进行试验。每次试验都会得出相应的正确率。
- 10次的结果的正确率的平均值作为对算法精度的估计,一般还需要进行多次10折交叉验证(例如10次10折交叉验证),再求其均值,作为对算法准确性的估计
- 模型训练过程的所有步骤,包括模型选择,特征选择等都是在单个折叠 fold 中独立执行的。
- 此外:
- 多次 k 折交叉验证再求均值,例如:10 次10 折交叉验证,以求更精确一点。
- 数据量大时,k设置小一些 / 数据量小时,k设置大一些。

KFold和StratifiedKFold的使用
StratifiedKFold用法类似Kfold,但是它是分层采样,确保训练集,测试集中各类别样本的比例与原始数据集中相同。这一区别在于当遇到非平衡数据时,StratifiedKFold() 各个类别的比例大致和完整数据集中相同,若数据集有4个类别,比例是2:3:3:2,则划分后的样本比例约是2:3:3:2;但是KFold可能存在一种情况:数据集有5类,抽取出来的也正好是按照类别划分的5类,也就是说第一折全是0类,第二折全是1类等等,这样的结果就会导致模型训练时没有学习到测试集中数据的特点,从而导致模型得分很低,甚至为0。
Parameters
- n_splits : int, default=3 也就是K折中的k值,必须大于等于2
- shuffle : boolean True表示打乱顺序,False反之
- random_state :int,default=None 随机种子,如果设置值了,shuffle必须为True
# KFold
from sklearn.model_selection import KFold
kfolds = KFold(n_splits=3)
for train_index, test_index in kfolds.split(X,y):print('X_train:%s ' % X[train_index])print('X_test: %s ' % X[test_index])# StratifiedKFold
from sklearn.model_selection import StratifiedKFold
skfold = StratifiedKFold(n_splits=3)
for train_index, test_index in skfold.split(X,y):print('X_train:%s ' % X[train_index])print('X_test: %s ' % X[test_index])
KFold和StratifiedKFold实战案例
首先导入数据集,本数据集为员工离职数据,属于二分类任务
import pandas as pd
import warnings
warnings.filterwarnings('ignore')data = pd.read_excel('data.xlsx')
data['薪资情况'].replace(to_replace={'低':0,'中':1,'高':2},inplace=True)
data.head()

拆分数据集为训练集和测试集,测试集比例为0.2
from sklearn.model_selection import train_test_split
X = data.drop('是否离职',axis=1)
y = data['是否离职']
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2)
初始化一个分类模型,这里用逻辑回归模型举例。方法1使用cross_val_score()可以直接得到k折训练的模型效果,比如下面使用3折进行训练,得分评估使用准确率,关于scoring这个参数我会在文末介绍。
# 初始化一个分类模型,比如逻辑回归
from sklearn.linear_model import LogisticRegression
lg = LogisticRegression()
# 方法1
from sklearn.model_selection import cross_val_score
scores = cross_val_score(lg,X_train,y_train,cv=3,scoring='accuracy')
print(scores)
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))

接下来分别使用KFold和StratifiedKFold,其实两者代码非常类似,只是前面的方法不同。
KFold
# 方法2-KFold和StratifiedKFold
import numpy as np
from sklearn.model_selection import KFold,StratifiedKFold
from sklearn.metrics import accuracy_score,recall_score,f1_score
# KFold
kfolds = KFold(n_splits=3)
accuracy_score_list,recall_score_list,f1_score_list = [],[],[]
for train_index,test_index in kfolds.split(X_train,y_train):# 准备交叉验证的数据X_train_fold = X_train.iloc[train_index]y_train_fold = y_train.iloc[train_index]X_test_fold = X_train.iloc[test_index]y_test_fold = y_train.iloc[test_index]# 训练模型lg.fit(X_train_fold,y_train_fold)y_pred = lg.predict(X_test_fold)# 评估模型AccuracyScore = accuracy_score(y_test_fold,y_pred)RecallScore = recall_score(y_test_fold,y_pred)F1Score = f1_score(y_test_fold,y_pred)# 将评估指标存放对应的列表中accuracy_score_list.append(AccuracyScore)recall_score_list.append(RecallScore)f1_score_list.append(F1Score)# 打印每一次训练的正确率、召回率、F1值print('accuracy_score:',AccuracyScore,'recall_score:',RecallScore,'f1_score:',F1Score)
# 打印各指标的平均值和95%的置信区间
print("Accuracy: %0.2f (+/- %0.2f)" % (np.average(accuracy_score_list), np.std(accuracy_score_list) * 2))
print("Recall: %0.2f (+/- %0.2f)" % (np.average(recall_score_list), np.std(recall_score_list) * 2))
print("F1_score: %0.2f (+/- %0.2f)" % (np.average(f1_score_list), np.std(f1_score_list) * 2))

StratifiedKFold
# StratifiedKFold
skfolds = StratifiedKFold(n_splits=3)
accuracy_score_list,recall_score_list,f1_score_list = [],[],[]
for train_index,test_index in skfolds.split(X_train,y_train):# 准备交叉验证的数据X_train_fold = X_train.iloc[train_index]y_train_fold = y_train.iloc[train_index]X_test_fold = X_train.iloc[test_index]y_test_fold = y_train.iloc[test_index]# 训练模型lg.fit(X_train_fold,y_train_fold)y_pred = lg.predict(X_test_fold)# 评估模型AccuracyScore = accuracy_score(y_test_fold,y_pred)RecallScore = recall_score(y_test_fold,y_pred)F1Score = f1_score(y_test_fold,y_pred)# 将评估指标存放对应的列表中accuracy_score_list.append(AccuracyScore)recall_score_list.append(RecallScore)f1_score_list.append(F1Score)# 打印每一次训练的正确率、召回率、F1值print('accuracy_score:',AccuracyScore,'recall_score:',RecallScore,'f1_score:',F1Score)
# 打印各指标的平均值和95%的置信区间
print("Accuracy: %0.2f (+/- %0.2f)" % (np.average(accuracy_score_list), np.std(accuracy_score_list) * 2))
print("Recall: %0.2f (+/- %0.2f)" % (np.average(recall_score_list), np.std(recall_score_list) * 2))
print("F1_score: %0.2f (+/- %0.2f)" % (np.average(f1_score_list), np.std(f1_score_list) * 2))

补充
scoring 参数: 定义模型评估规则
Model selection (模型选择)和 evaluation (评估)使用工具,例如 model_selection.GridSearchCV 和 model_selection.cross_val_score ,采用 scoring 参数来控制它们对 estimators evaluated (评估的估计量)应用的指标。
常见场景: 预定义值
对于最常见的用例, 可以使用 scoring 参数指定一个 scorer object (记分对象); 下表显示了所有可能的值。 所有 scorer objects (记分对象)遵循惯例 higher return values are better than lower return values(较高的返回值优于较低的返回值)。因此,测量模型和数据之间距离的 metrics (度量),如 metrics.mean_squared_error 可用作返回 metric (指数)的 negated value (否定值)的 neg_mean_squared_error 。
| Scoring(得分) | Function(函数) | Comment(注解) |
|---|---|---|
| Classification(分类) | ||
| ‘accuracy’ | metrics.accuracy_score | |
| ‘average_precision’ | metrics.average_precision_score | |
| ‘f1’ | metrics.f1_score | for binary targets(用于二进制目标) |
| ‘f1_micro’ | metrics.f1_score | micro-averaged(微平均) |
| ‘f1_macro’ | metrics.f1_score | macro-averaged(宏平均) |
| ‘f1_weighted’ | metrics.f1_score | weighted average(加权平均) |
| ‘f1_samples’ | metrics.f1_score | by multilabel sample(通过 multilabel 样本) |
| ‘neg_log_loss’ | metrics.log_loss | requires predict_proba support(需要 predict_proba 支持) |
| ‘precision’ etc. | metrics.precision_score | suffixes apply as with ‘f1’(后缀适用于 ‘f1’) |
| ‘recall’ etc. | metrics.recall_score | suffixes apply as with ‘f1’(后缀适用于 ‘f1’) |
| ‘roc_auc’ | metrics.roc_auc_score | |
| Clustering(聚类) | ||
| ‘adjusted_mutual_info_score’ | metrics.adjusted_mutual_info_score | |
| ‘adjusted_rand_score’ | metrics.adjusted_rand_score | |
| ‘completeness_score’ | metrics.completeness_score | |
| ‘fowlkes_mallows_score’ | metrics.fowlkes_mallows_score | |
| ‘homogeneity_score’ | metrics.homogeneity_score | |
| ‘mutual_info_score’ | metrics.mutual_info_score | |
| ‘normalized_mutual_info_score’ | metrics.normalized_mutual_info_score | |
| ‘v_measure_score’ | metrics.v_measure_score | |
| Regression(回归) | ||
| ‘explained_variance’ | metrics.explained_variance_score | |
| ‘neg_mean_absolute_error’ | metrics.mean_absolute_error | |
| ‘neg_mean_squared_error’ | metrics.mean_squared_error | |
| ‘neg_mean_squared_log_error’ | metrics.mean_squared_log_error | |
| ‘neg_median_absolute_error’ | metrics.median_absolute_error | |
| ‘r2’ | metrics.r2_score |
相关文章:
交叉验证之KFold和StratifiedKFold的使用(附案例实战)
🤵♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞Ǵ…...
Cloud Kernel SIG月度动态:发布ANCK 5.10、4.19新版本,ABS新增仓库构建功能
Cloud Kernel SIG(Special Interest Group):支撑龙蜥内核版本的研发、发布和服务,提供生产可用的高性价比内核产品。 01 SIG 整体进展 发布 ANCK 5.10-014 版本。 发布 ANCK 4.19-027.2 版本。 ABS 平台新增 OOT 仓库临时构建功…...
JavaScript:new操作符
一、new操作符的作用 用于创建一个给定构造函数的实例对象 new操作符创建一个用户定义的对象类型的实例 或 具有构造函数的内置对象的实例。二、new一个构造函数的执行过程 2.1、创建一个空对象obj 2.2、将空对象的原型与构造函数的原型连接起来 2.3、将构造函数中的this绑定…...
XShell配置以及使用教程
目录 1、XShell介绍 2、安装XShell 1. 双击运行XShell安装文件,并点击“下一步” 2. 点击“我接受许可证协议中的条款”,点击“下一步” 3. 点击“浏览”更改默认安装路径,点击“下一步” 4. 直接点击“安装” 5. 安装完成࿰…...
Vue3 基础语法
文章目录 1.创建Vue项目1.1创建项目1.2 初始项目 2.vue3 语法2.1 复杂写法2.2 简易写法2.3 reactive(对象类型)2.4 ref(简单类型)2.5 computed(计算属性)2.6 watch(监听) 3.vue3 生命周期4.vue3 组件通信4.…...
【开源项目】Disruptor框架介绍及快速入门
Disruptor框架简介 Disruptor框架内部核心的数据结构是Ring Buffer,Ring Buffer是一个环形的数组,Disruptor框架以Ring Buffer为核心实现了异步事件处理的高性能架构;JDK的BlockingQueue相信大家都用过,其是一个阻塞队列…...
双向链表实现约瑟夫问题
title: 双向链表实现约瑟夫问题 date: 2023-05-16 11:42:26 tags: **问题:**知n个人围坐在一张圆桌周围。从编号为k的人开始报数,数到m的那个人出列;他的下一个人又从1开始报数,数到m的那个人又出列;依此规律重复下去&…...
日心说为人类正确认识宇宙打下了基础(善用工具的重要性)
文章目录 引言I 伽利略1.1 借助天文望远镜获得了比别人更多的信息。1.2 确定了科学研究方法:实验和观测 II 开普勒三定律 引言 享有科学史上崇高地位的人,都需要在构建科学体系上有重大贡献。 日心说在哥白尼那里还是一个假说,伽利略拿事实…...
Kali-linux系统指纹识别
现在一些便携式计算机操作系统使用指纹识别来验证密码进行登录。指纹识别是识别系统的一个典型模式,包括指纹图像获取、处理、特征提取和对等模块。如果要做渗透测试,需要了解要渗透测试的操作系统的类型才可以。本节将介绍使用Nmap工具测试正在运行的主…...
Java版本电子招标采购系统源码:营造全面规范安全的电子招投标环境,促进招投标市场健康可持续发展
营造全面规范安全的电子招投标环境,促进招投标市场健康可持续发展 传统采购模式面临的挑战 一、立项管理 1、招标立项申请 功能点:招标类项目立项申请入口,用户可以保存为草稿,提交。 2、非招标立项申请 功能点:非招标…...
Java字符串知多少:String、StringBuffer、StringBuilder
一、String 1、简介 String 是 Java 中使用得最频繁的一个类了,不管是作为开发者的业务使用,还是一些系统级别的字符使用, String 都发挥着重要的作用。String 是不可变的、final的,不能被继承,且 Java 在运行时也保…...
中国20强(上市)游戏公司2022年财报分析:营收结构优化,市场竞争进入白热化
易观:受全球经济增速下行的消极影响,2022年国内外游戏市场规模普遍下滑。但中国游戏公司凭借处于全球领先水平的研发、发行和运营的能力与经验,继续加大海外市场布局,推动高质量发展迈上新台阶。 风险提示:本文内容仅代…...
如何自学C++编程语言,聊聊C++的特点,别轻易踩坑
为什么现在有那么多C培训班呢?因为这些培训班可以为学生安排工作,而外包公司因为缺人,需要做很多项目,可能需要在全国各地分配不同的程序员去干不同的项目,因此需要大量的程序员入职。这样,外包公司就会找培…...
算法Day07 | 454.四数相加II,383. 赎金信,15. 三数之和, 18. 四数之和
Day07 454.四数相加II383. 赎金信15. 三数之和18. 四数之和 454.四数相加II 题目链接:454.四数相加II 寻找两个数组之和,是否与另外两个数组之和有特定的关系。 因为数值可能跨度太大,选择使用下标表示为对应的数值大小,会很浪费…...
ps抠图、抠头发去背景等
方法一:背景橡皮擦 一、很早之前我们使用的是魔术棒工具,但现在我们可以使用Photoshop 有内置的“背景橡皮擦” 步骤: 第1步:在Photoshop中打开需要修的图。 第2步:单击并按住工具栏…...
计算机组成原理基础练习题第一章
有些计算机将一部分软件永恒地存于只读存储器中,称之为() A.硬件 B.软件C.固件 D.辅助存储器输入、输出装置以及外界的辅助存储器称为() A.操作系统 B.存储器 C.主机 D.外围设备完整的计算机系…...
[PyTorch][chapter 34][池化层与采样]
前言: 这里主要讲解一下卷积神经网络中的池化层与采样 目录 DownSampleMax poolingavg poolingupsampleReLu 1: DownSample 下采样,间隔一定行或者列进行采样,达到降维效果 早期LeNet-5 就采样该采样方式。 LeNet-5 2 Max pooling 最大值采样…...
Java进阶-字符串的使用
1.API 1.1API概述 什么是API API (Application Programming Interface) :应用程序编程接口 java中的API 指的就是 JDK 中提供的各种功能的 Java类,这些类将底层的实现封装了起来,我们不需要关心这些类是如何实现的,只需要…...
接口自动化框架对比 | 质量工程
一、前言 自动化测试是把将手工驱动的测试行为转化为机器自动执行,通常操作是在某一框架下进行代码编写,实现用例自动发现与执行,托管在CI/CD平台上,通过条件触发或手工触发,进行回归测试&线上监控,代替…...
谷歌浏览器network error解决方法
很多用户在使用谷歌浏览器时候会出现network error网页提示,很多用户不知道该如何处理这一问题,其实解决方法不止一种,小编整理了两种谷歌浏览器network error解决方法,一起来看看吧~ 谷歌浏览器network error解决方法࿱…...
java_网络服务相关_gateway_nacos_feign区别联系
1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...
CentOS下的分布式内存计算Spark环境部署
一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架,相比 MapReduce 具有以下核心优势: 内存计算:数据可常驻内存,迭代计算性能提升 10-100 倍(文档段落:3-79…...
el-switch文字内置
el-switch文字内置 效果 vue <div style"color:#ffffff;font-size:14px;float:left;margin-bottom:5px;margin-right:5px;">自动加载</div> <el-switch v-model"value" active-color"#3E99FB" inactive-color"#DCDFE6"…...
页面渲染流程与性能优化
页面渲染流程与性能优化详解(完整版) 一、现代浏览器渲染流程(详细说明) 1. 构建DOM树 浏览器接收到HTML文档后,会逐步解析并构建DOM(Document Object Model)树。具体过程如下: (…...
GitHub 趋势日报 (2025年06月08日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 884 cognee 566 dify 414 HumanSystemOptimization 414 omni-tools 321 note-gen …...
高防服务器能够抵御哪些网络攻击呢?
高防服务器作为一种有着高度防御能力的服务器,可以帮助网站应对分布式拒绝服务攻击,有效识别和清理一些恶意的网络流量,为用户提供安全且稳定的网络环境,那么,高防服务器一般都可以抵御哪些网络攻击呢?下面…...
项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)
Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
在QWebEngineView上实现鼠标、触摸等事件捕获的解决方案
这个问题我看其他博主也写了,要么要会员、要么写的乱七八糟。这里我整理一下,把问题说清楚并且给出代码,拿去用就行,照着葫芦画瓢。 问题 在继承QWebEngineView后,重写mousePressEvent或event函数无法捕获鼠标按下事…...
基于SpringBoot在线拍卖系统的设计和实现
摘 要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 在线拍卖系统,主要的模块包括管理员;首页、个人中心、用户管理、商品类型管理、拍卖商品管理、历史竞拍管理、竞拍订单…...

