当前位置: 首页 > news >正文

机器学习:训练集与测试集分割train_test_split

1 引言

在使用机器学习训练模型算法的过程中,为提高模型的泛化能力、防止过拟合等目的,需要将整体数据划分为训练集和测试集两部分,训练集用于模型训练,测试集用于模型的验证。此时,使用train_test_split函数可便捷高效的实现数据训练集与测试集的划分。

2 train_test_split介绍

train_test_split函数来自scikit-learn库(也称为sklearn),安装命令:

pip install sklearn

函数的导入:

from sklearn.model_selection import train_test_split

1.1 函数定义

def train_test_split(*arrays,test_size=None,train_size=None,random_state=None,shuffle=True,stratify=None,):

1.2 参数说明

  • *arrays: 单个数组或元组,表示需要划分的数据集。如果传入多个数组,则必须保证每个数组的第一维大小相同。
  • test_size: 测试集的大小(占总数据集的比例,值为0.0-1.0,表示测试集占总样本比例)。默认值为0.25,即将传入数据的25%作为测试集。
  • train_size: 训练集的大小(占总数据集的比例,值为0.0-1.0,表示训练集占总样本比例)。默认值为None,此时和test_size互补,即训练集的大小为(1-test_size)。
  • random_state: 随机数种子。可以设置一个整数,用于复现结果。默认为None。其实是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。(比如每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。)
  • shuffle: 是否随机打乱数据。默认为True。
  • stratify: 可选参数,用于进行分层抽样。传入标签数组,保证划分后的训练集和测试集中各类别样本比例与原始数据集相同。默认为None,即普通的随机划分。(此参数作用是保持测试集与整个数据集里的数据分类比例一致,比如有1000个数据,800个属于A类,200个属于B类。设置stratify = y_lable,test_size=0.25,split之后数据组成如下:training: 750个数据,其中600个属于A类,150个属于B类;testing: 250个数据,其中200个属于A类,50个属于B类)

1.3 返回值

该函数返回一个元组(X_train, X_test, y_train, y_test),其中X_train表示训练集的特征数据,X_test表示测试集的特征数据,y_train表示训练集的标签数据,y_test表示测试集的标签数据。

1.4 注意事项

  • test_sizetrain_size必须至少有一个设置为非None
  • 当传入多个数组时,请确保每个数组的第一维大小相同。
  • random_state要设置一个整数值,从而保证每次获取相同的训练集和测试集
  • 当使用分层抽样时,请确保传入的标签数组是正确的。

3 train_test_split使用

3.1 使用train_test_split分割Iris数据

from sklearn import datasets
from sklearn.model_selection import train_test_split# 加载Iris数据集
iris = datasets.load_iris()
X = iris.data
y = iris.targetX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)
print(X_train)
print(X_test)

结果展示:

X_train=[[6.5 2.8 4.6 1.5][6.7 2.5 5.8 1.8][6.8 3.  5.5 2.1][5.1 3.5 1.4 0.3][6.  2.2 5.  1.5]......此处数据省略[4.9 3.6 1.4 0.1]]
X_test=[[5.8 4.  1.2 0.2][5.1 2.5 3.  1.1][6.6 3.  4.4 1.4][5.4 3.9 1.3 0.4][7.9 3.8 6.4 2. ]......此处数据省略[5.2 3.4 1.4 0.2]]

3.2 使用train_test_split分割水果识别数据

在/opt/dataset下存放着水果图片的分类数据文件夹(文件夹名称为标签),每个文件夹下存储着多张对应标签的水果图片,如下所示:

以apple文件夹为例,图片内容如下:

数据加载和分割数据集的代码如下:

from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split# 图像变换
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5]), ])
# 加载数据集
dataset = ImageFolder('/opt/dataset', transform=transform)# 划分训练集与测试集
train_dataset, valid_dataset = train_test_split(dataset, test_size=0.2, random_state=10)batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

相关文章:

机器学习:训练集与测试集分割train_test_split

1 引言 在使用机器学习训练模型算法的过程中,为提高模型的泛化能力、防止过拟合等目的,需要将整体数据划分为训练集和测试集两部分,训练集用于模型训练,测试集用于模型的验证。此时,使用train_test_split函数可便捷高…...

淘宝API开发(一)简单介绍淘宝API功能接口作用

前一阵子按照上级指示,根据淘宝API开发符合自已应用的系统,比如批量上传,批量修改名称,价格等功能什么的,在此就将我的开发历程写一写,为自己前段时间的工作做个总结。 淘宝开发平台(淘宝网 - 淘&#xff…...

Redis相关面试题

Redis的使用场景 根据自己简历上的业务进行回答 缓存 穿透、击穿、雪崩、双写一致、持久化、数据过期、淘汰策略 分布式锁 setnx redisson 缓存穿透:查询一个不存在的数据,数据库查不到数据也不会直接写入缓存,就会导致每次请求都查询数据库…...

数据库简介

1、数据库安装: rpm (redhat package manager) 也是个包管理工具: rpm -ivh 安装 rpm -e 表示卸载,卸载的时候有可能出现依赖的问题,可以用 --nodeps 忽略依赖卸载。 rpm -qa 搜索系统中安装的rpm的应用。 如果使用离线包,安装顺序不要乱。 m…...

腾讯云国际轻量应用服务器怎么使用呢?

腾讯云国际轻量应用服务器怎么使用呢?下面一起来了解一下: 1. 熟悉轻量应用服务器基础知识 ①什么是轻量应用服务器 TencentCloud Lighthouse? ②轻量应用服务器与云服务器 CVM 的区别是什么? ③为什么选择轻量应用服务器&#xf…...

arm环境cloudstack在vpc下创建虚拟机失败

一、环境说明 操作系统:openEuler 22.03CPU:Kunpeng-920,arm v8cloudstack:4.18libvirtd:6.2.0 二、问题描述 在UI上创建VPC后,平台会同时创建一个virtual router,此时virtual router有两个网…...

Linux上安装Keepalived,多台Nginx配置Keepalived(保姆级教程)

目录 一、yum安装 第一步:下载 第二步:编辑Keepalived配置文件(第一台) 第三步:编辑Keepalived配置文件(第二台) 第四步:我们在本机利用cmd ping一下 一、yum安装 第一步&…...

centos7 ‘xxx‘ is not in the sudoers file...

如题 执行命令输入密码后时报错: [sudo] password for admin (我的账户)原因,当前用户还没有加入到root的配置文件中。 解决 vim打开配置文件,如下: #切换到root用户 su #编辑配置文件 vim /etc/sudoe…...

Zebec Payroll :计划推出 WageLink On-Demand Pay,进军薪酬发放领域

“Zebec Protocol 生态旨以 Web3 的方式建立全新的公平秩序,基于其流支付体系构建的薪酬支付板块,就是解决问题的一把利刃”...

【2023】字节跳动 10 日心动计划——第三关

目录 1. 最长有效括号2. 有序数组的平方 1. 最长有效括号 🔗 原题链接:32. 最长有效括号 类似于有效的括号,考虑用栈来解决。 具体来讲,我们始终保持栈底元素为当前已经遍历过的元素中「最后一个没有被匹配的右括号的下标」&…...

【无网络】win10更新后无法联网,有线无线都无法连接,且打开网络与Internet闪退

win10更新后无法联网,有线无线都无法连接,且打开网络与Internet闪退 法1 重新配置网络法2 更新驱动法3 修改注册表编辑器法4 重装系统 自从昨晚点了更新与重启后,今天电脑就再也不听话了,变着花样地连不上网。 检查路由器&#xf…...

HTML <script> 标签

实例 在 HTML 页面中插入一段 JavaScript: <script type="text/javascript"> document.write("Hello World!") </script>(在本页底部可以找到更多实例) 定义和用法 <script> 标签用于定义客户端脚本,比如 JavaScript。 script …...

FPGA----UltraScale+系列的PS侧与PL侧通过AXI-HP交互(全网唯一最详)附带AXI4协议校验IP使用方法

1、之前写过一篇关于ZYNQ系列通用的PS侧与PL侧通过AXI-HP通道的文档&#xff0c;下面是链接。 FPGA----ZCU106基于axi-hp通道的pl与ps数据交互&#xff08;全网唯一最详&#xff09;_zcu106调试_发光的沙子的博客-CSDN博客大家好&#xff0c;今天给大家带来的内容是&#xff0…...

Unity小游戏——迷你拼图

游戏展示 拼图演示 资源&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1BGeSmRCO_WZRUyl3MxefGw 提取码&#xff1a;0n4a 一、玩法介绍 排列拼图碎片&#xff0c;拼出最后的图案。可以点住碎片的任意位置拖动&#xff1b;点击"重来"按钮&#xff0c;可以…...

三 动手学深度学习v2 —— Softmax回归+损失函数+图片分类数据集

三 动手学深度学习v2 —— Softmax回归损失函数图片分类数据集 目录: softmax回归损失函数 1. softmax回归 回归vs分类: 回归估计一个连续值分类预测一个离散类别 从回归到多类分类 回归 单连续数值输出自然区间R跟真实值的误差作为损失 分类 通常多个输出输出i是预测为第…...

Stable Diffusion 使用教程

环境说明&#xff1a; stable diffusion version: v1.5.1python: 3.10.6torch: 2.0.1cu118xformers: N/Agradio: 3.32.0 1. 下载 webui 下载地址&#xff1a; GitHub stable-diffusion-webui 下载 根据自己的情况去下载&#xff1a; 最好是 N 卡&#xff1a;&#xff08;我的…...

在线考试系统springboot学生试卷问答管理java jsp源代码mysql

本项目为前几天收费帮学妹做的一个项目&#xff0c;Java EE JSP项目&#xff0c;在工作环境中基本使用不到&#xff0c;但是很多学校把这个当作编程入门的项目来做&#xff0c;故分享出本项目供初学者参考。 一、项目描述 在线考试系统springboot 系统有2权限&#xff1a;管理…...

创建vue-cli(脚手架搭建)

目录 功能 需要的环境 使用HbuilderX快速搭建一个vue-cli项目 组件路由 element-ui vue-cli 官方提供的一个脚手架&#xff0c;用于快速生成一个 vue 的项目模板&#xff1b;预先定义 好的目录结构及基础代码&#xff0c;就好比咱们在创建 Maven 项目时可以选择创建一个 骨…...

【单调栈part02】| 503.下一个更大元素||、42.接雨水

&#x1f388;LeetCode503.下一个更大元素|| 链接&#xff1a;503.下一个更大元素|| 给定一个循环数组 nums &#xff08; nums[nums.length - 1] 的下一个元素是 nums[0] &#xff09;&#xff0c;返回 nums 中每个元素的 下一个更大元素 。 数字 x 的 下一个更大的元素 是按…...

Java——如何使用Stream替换掉List<Student>中符合要求的元素

使用Stream替换掉List中符合要求的元素 要使用Stream流替换掉List中符合特定条件的元素&#xff0c;您可以使用Stream的map()方法对每个元素进行映射&#xff0c;并使用collect()方法将映射后的元素收集到一个新的List中。 示例代码&#xff1a; import java.util.ArrayList; …...

在软件开发中正确使用MySQL日期时间类型的深度解析

在日常软件开发场景中&#xff0c;时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志&#xff0c;到供应链系统的物流节点时间戳&#xff0c;时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库&#xff0c;其日期时间类型的…...

React 第五十五节 Router 中 useAsyncError的使用详解

前言 useAsyncError 是 React Router v6.4 引入的一个钩子&#xff0c;用于处理异步操作&#xff08;如数据加载&#xff09;中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误&#xff1a;捕获在 loader 或 action 中发生的异步错误替…...

React hook之useRef

React useRef 详解 useRef 是 React 提供的一个 Hook&#xff0c;用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途&#xff0c;下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

华为云Flexus+DeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建

华为云FlexusDeepSeek征文&#xff5c;DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建 前言 如今大模型其性能出色&#xff0c;华为云 ModelArts Studio_MaaS大模型即服务平台华为云内置了大模型&#xff0c;能助力我们轻松驾驭 DeepSeek-V3/R1&#xff0c;本文中将分享如何…...

pikachu靶场通关笔记22-1 SQL注入05-1-insert注入(报错法)

目录 一、SQL注入 二、insert注入 三、报错型注入 四、updatexml函数 五、源码审计 六、insert渗透实战 1、渗透准备 2、获取数据库名database 3、获取表名table 4、获取列名column 5、获取字段 本系列为通过《pikachu靶场通关笔记》的SQL注入关卡(共10关&#xff0…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

打手机检测算法AI智能分析网关V4守护公共/工业/医疗等多场景安全应用

一、方案背景​ 在现代生产与生活场景中&#xff0c;如工厂高危作业区、医院手术室、公共场景等&#xff0c;人员违规打手机的行为潜藏着巨大风险。传统依靠人工巡查的监管方式&#xff0c;存在效率低、覆盖面不足、判断主观性强等问题&#xff0c;难以满足对人员打手机行为精…...

比较数据迁移后MySQL数据库和OceanBase数据仓库中的表

设计一个MySQL数据库和OceanBase数据仓库的表数据比较的详细程序流程,两张表是相同的结构,都有整型主键id字段,需要每次从数据库分批取得2000条数据,用于比较,比较操作的同时可以再取2000条数据,等上一次比较完成之后,开始比较,直到比较完所有的数据。比较操作需要比较…...