Keras使用sklearn中的交叉验证和网格搜索
Keras是Python在深度学习领域非常受欢迎的第三方库,但Keras的侧重点是深度学习,而不是所以的机器学习。事实上,Keras力求极简主义,只专注于快速、简单地定义和构建深度学习模型所需要的内容。Python中的scikit-learn是非常受欢迎的机器学习库,它基于Scipy,用于高效的数值计算。scikit-learn是一个功能齐全的通用机器学习库,并提供了许多在开发深度学习过程中非常有帮助的方法。例如scikit-learn提供了很多用于选择模型和对模型调参的方法,这些方法同样适用于深度学习。
Keras提供了一个Wrapper,将Keras的深度学习模型包装成scikit-learn中的分类模型或回归模型,以便于使用scikit-learn中的方法和函数。对于深度学习模型的包装是通过KerasClassifier(分类模型)和KerasRegressor(回归模型)来实现的。KerasClassifier和KerasRegressor类使用参数build_fn,指定用来创建模型的函数的名称。
Keras的一般构建流程:
model = Sequential() # 定义模型
model.add(Dense(units=64, activation='relu', input_dim=100)) # 定义网络结构
#第一层网络:输出尺寸64,输入尺寸100,activation激活函数relu
model.add(Dense(units=10, activation='softmax')) # 定义网络结构
#第二层网络:输出尺寸10,输入是上一层的输出尺寸64,activation激活函数softmax
model.compile(loss='categorical_crossentropy', # 定义loss函数、优化方法、评估标准optimizer='sgd',metrics=['accuracy'])
#输入训练样本和标签,迭代5次,每次迭代32个数据
model.fit(x_train, y_train, epochs=5, batch_size=32) # 训练模型
loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128) # 评估模型
classes = model.predict(x_test, batch_size=128) # 使用训练好的数据进行预测
参数意义:
keras.layers.Dense(units, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None)
units: 正整数,输出空间维度。
activation: 激活函数。 若不指定,则不使用激活函数 (即,「线性」激活: a(x) = x)。
use_bias: 布尔值,该层是否使用偏置向量。
kernel_initializer: kernel 权值矩阵的初始化器。
bias_initializer: 偏置向量的初始化器。
kernel_regularizer: 运用到 kernel 权值矩阵的正则化函数 。
bias_regularizer: 运用到偏置向的的正则化函数 。
activity_regularizer: 运用到层的输出的正则化函数 。
kernel_constraint: 运用到 kernel 权值矩阵的约束函数 。
bias_constraint: 运用到偏置向量的约束函数。
Keras调用scikit-learn实现交叉验证:
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score, KFold
from keras.wrappers.scikit_learn import KerasClassifierdef creat_model():# 构建模型model = Sequential()model.add(Dense(units=12, input_dim=11, activation='relu'))model.add(Dense(units=8, activation='relu'))model.add(Dense(units=1, activation='sigmoid'))# 模型编译model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])return model# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')# 删除id列
data.drop('客户编号',axis=1,inplace=True)X, Y = data.values[:,:-1], data.values[:,-1] # Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, epochs=150, batch_size=10, verbose=0)# 10折交叉验证
kfold = KFold(n_splits=10, shuffle=True, random_state=10)
result = cross_val_score(model, X, Y, cv=kfold)
Keras调用scikit-learn实现模型调参
在构建深度学习模型时,如何配置一个最优模型一直是进行一个项目的重点。在机器学习中,可以通过算法自动调优这些配置参数,在这里将通过Keras的包装类,借助scikit-learn的网格搜索算法评估神经网络模型的不同配置,并找到最佳评估性能的参数组合。creat_model()函数被定义为具有两个默认值的参数(optimizer和init)的函数,创建模型后,定义要搜索的参数的数值数组,包括优化器(optimizer)、权重初始化方案(init)、epochs和batch_size。
在scikit-learn中的GridSearchCV需要一个字典类型的字段作为需要调整的参数,默认采用3折交叉验证来评估算法,由于4个参数需要进行调参,因此将会产生4✖️3个模型。
Keras调用scikit-learn实现GridSearchCV网格搜索:
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV
from keras.wrappers.scikit_learn import KerasClassifierdef creat_model(optimizer='adam,init='glorot_uniform'):# 构建模型model = Sequential()model.add(Dense(units=12, input_dim=11,kernel_initializer=init, activation='relu'))model.add(Dense(units=8, kernel_initializer=init, activation='relu'))model.add(Dense(units=1, kernel_initializer=init, activation='sigmoid'))# 模型编译model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])return model# 导入数据
data = pd.read_csv('data.csv',encoding='gbk')# 删除id列
data.drop('客户编号',axis=1,inplace=True)X, Y = data.values[:,:-1], data.values[:,-1] # Keras调用sklearn
model = KerasClassifier(build_fn=creat_model, verbose=0)# 构建需要调整的参数
param_gird = {}
param_grid['optimizer'] = ['rmsprop','adam']
param_grid['init'] = ['glorot_uniform', 'normal', 'uniform']
param_gird['epochs'] = [50, 100, 150, 200]
param_gird['batch_size'] = [5, 10, 20]# 调参
grid = GridSearchCV(estimator=model, param_gird=param_grid)
result = grid.fit(X, Y)# 输出结果
print('Best: %f using %s' % (result.best_score_, result.best_params_))
关于Epochs和batch_size的解释?
Epochs是神经网络训练过程中的一个重要超参数,定义为向前和向后传播中所有批次的单次训练迭代。简单说,一个Epoch是将所有的数据输入网络完成一次向前计算及反向传播。在训练过程中,数据会被“轮”多少次,即应当完整遍历数据集多少次(一次为一个Epoch)。如果Epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果Epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。所以,选择适当的Epoch数量需要在充分训练和避免过拟合之间找到平衡。
假设我们有1000个数据样本,每次我们送入10个数据进行训练(也就是batch_size为10)。那么完成一个Epoch,我们需要进行100次迭代(也就是100次前向传播和100次反向传播)。具体来说,我们需要将所有的数据都送入神经网络进行一次前向传播和反向传播,所以一次Epoch相当于所有数据集/batch size=N次迭代。
相关文章:
Keras使用sklearn中的交叉验证和网格搜索
Keras是Python在深度学习领域非常受欢迎的第三方库,但Keras的侧重点是深度学习,而不是所以的机器学习。事实上,Keras力求极简主义,只专注于快速、简单地定义和构建深度学习模型所需要的内容。Python中的scikit-learn是非常受欢迎的…...
docker--Prometheus、Grafana、node_exporter的安装配置及Springboot集成Prometheus示例
1. 安装Prometheus Prometheus一个系统和服务监控系统。它以给定的时间间隔从配置的目标收集指标,计算规则表达式,显示结果,并在观察到某些条件为真时触发警报。 可观察性侧重于根据系统产生的数据了解系统的内部状态,这有助于确定基础设施是否健康。Prometheus是用于监视…...
数据结构和算法笔记2:二分法
二分法网上有两种写法,一种左闭右闭,一种左闭右开,个人习惯左闭右闭的写法, 有序数组查找数 这是标准二分法,对应力扣的704. 二分查找: 求值为target的索引 int search(vector<int>& nums, i…...
Mybatis3系列课程8-带参数查询
简介 上节课内容中讲解了查询全部, 不需要带条件查, 这节我们讲讲 带条件查询 目标 1. 带一个条件查询-基本数据类型 2.带两个条件查询-连个基本数据类型 3.带一个对象类型查询 为了实现目标, 我们要实现 按照主键 查询某个学生信息, 按照姓名和年级编号查询学生信息 按照学生…...
IDEA shorten command line介绍和JAR manifest 导致mybatis找不到接口类处理
如果类路径太长,或者有许多VM参数,程序就无法启动。原因是大多数操作系统都有命令行长度限制。在这种情况下,IntelliJIDEA将试图缩短类路径。最好选中 classpath file模式。 shorten command line 选项提供三种选项缩短类路径。 none&#x…...
泽攸科技SEM台式扫描电子显微镜
泽攸科技是一家国产的科学仪器公司,专注于研发、生产和销售原位电镜解决方案、扫描电镜整机、台阶仪、探针台等仪器。目前台式扫描电镜分为三个系列:ZEM15、ZEM18、ZEM20。 ZEM15台式扫描电镜: ZEM18台式扫描电镜: ZEM20台式扫描…...
华为交换机配置BGP的基本示例
BGP简介 定义 边界网关协议BGP(Border Gateway Protocol)是一种实现自治系统AS(Autonomous System)之间的路由可达,并选择最佳路由的距离矢量路由协议。早期发布的三个版本分别是BGP-1(RFC1105࿰…...
数据分析基础之《numpy(4)—ndarry运算》
一、逻辑运算 当我们要操作符合某一条件的数据时,需要用到逻辑运算 1、运算符 满足条件返回true,不满足条件返回false # 重新生成8只股票10个交易日的涨跌幅数据 stock_change np.random.normal(loc0, scale1, size(8, 10))# 获取前5行前5列的数据 s…...
分享一个项目——Sambert UI 声音克隆
文章目录 前言一、运行ipynb二、数据标注三、训练四、生成总结 前言 原教程视频 项目链接 运行一个ipynb,就可操作 总共四步 1)运行ipynb 2)数据标注 3)训练 4)生成 一、运行ipynb 等运行完毕后,获得该…...
ES6 语法精粹简读
本文旨在记录 ES6 的核心常用语法,略去一些细节。 文章目录 1 var 函数作用域与 let/const 块作用域2 解构赋值数组结构赋值对象结构赋值3 ES6 中字符串的新语法模板字符串模板编译标签模板4 ES6 中的函数默认值rest 参数箭头函数this 指向问题部署管道机制尾调用优化...
uniapp整合echarts(目前性能最优、渲染最快方案)
本文echarts示例如上图,可扫码体验渲染速度及loading效果,下文附带本小程序uniapp相关代码 实现代码 <template><view class="source...
解决Electron应用中的白屏问题的实用方法
在使用Electron构建应用程序时,一些开发者可能会面临窗口加载过程中出现的白屏问题。这种问题主要分为两个方面: Electron未加载完毕HTML: 这时Electron自身产生的白色背景可能导致用户在启动应用时看到一片空白。HTML加载渲染过程中的短暂白…...
大数据---34.HBase数据结构
一、HBase简介 HBase是一个开源的、分布式的、版本化的NoSQL数据库(即非关系型数据库),依托Hadoop分布式文件系统HDFS提供分布式数据存储,利用MapReduce来处理海量数据,用Zookeeper作为其分布式协同服务,一…...
【工具使用-有道云笔记】如何在有道云笔记中插入目录
一,简介 本文主要介绍如何在有道云笔记中插入目录,方便后续笔记的查看,供参考。 二,具体步骤 分为两个步骤:1,设置标题格式;2,插入标题。非常简单~ 2.1 设置标题格式 鼠标停在标…...
用户管理第2节课-idea 2023.2 后端一删除表,从零开始---【本人】
一、清空model文件夹下,所有文件 1.1.1效果如下: 1.1代码内容 package com.daisy.usercenter.model;import lombok.Data;Data public class User {private Long id;private String name;private Integer age;private String email; }二、清空mapper文件…...
如何添加jar包到本地Maven项目中
在 Maven 中添加一个外部 JAR 包的依赖,你需要使用 Maven 的 <dependency> 元素来指定该 JAR 包的坐标信息。以下是具体的步骤: 将 JAR 包手动添加到 Maven 本地仓库: 首先,确保将外部 JAR 包手动添加到 Maven 本地仓库。可…...
智能优化算法应用:基于学校优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码
智能优化算法应用:基于学校优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于学校优化算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.学校优化算法4.实验参数设定5.算法结果6.…...
【MATLAB第85期】基于MATLAB的2023年智能进化算法/元启发式算法合集(持续更新)
【MATLAB第85期】基于MATLAB的2023年智能进化算法/元启发式算法合集(持续更新) 1.海象进化算法(Walrus Optimization Algorithm) 作者:Pavel Trojovsk and Mohammad Dehghani 2.暴龙优化算法(Tyrannosa…...
[Realtek sdk-3.4.14b]RTL8197FH-VG+RTL8812F WiFi使用功率限制功能使用说明
sdk说明 ** Gateway/AP firmware v3.4.14b – Aug 26, 2019** Wireless LAN driver changes as: Refine WiFi Stability and Performance Add 8812F MU-MIMO Add 97G/8812F multiple mac-clone Add 97G 2T3R antenna diversity Fix 97G/8812F/8814B MP issu…...
Vue中为什么data属性是一个函数而不是一个对象?(看完就会了)
文章目录 一、实例和组件定义data的区别二、组件data定义函数与对象的区别三、原理分析四、结论 一、实例和组件定义data的区别 vue实例的时候定义data属性既可以是一个对象,也可以是一个函数 const app new Vue({el:"#app",// 对象格式data:{foo:&quo…...
[2025CVPR]DeepVideo-R1:基于难度感知回归GRPO的视频强化微调框架详解
突破视频大语言模型推理瓶颈,在多个视频基准上实现SOTA性能 一、核心问题与创新亮点 1.1 GRPO在视频任务中的两大挑战 安全措施依赖问题 GRPO使用min和clip函数限制策略更新幅度,导致: 梯度抑制:当新旧策略差异过大时梯度消失收敛困难:策略无法充分优化# 传统GRPO的梯…...
基于大模型的 UI 自动化系统
基于大模型的 UI 自动化系统 下面是一个完整的 Python 系统,利用大模型实现智能 UI 自动化,结合计算机视觉和自然语言处理技术,实现"看屏操作"的能力。 系统架构设计 #mermaid-svg-2gn2GRvh5WCP2ktF {font-family:"trebuchet ms",verdana,arial,sans-…...
Vue3 + Element Plus + TypeScript中el-transfer穿梭框组件使用详解及示例
使用详解 Element Plus 的 el-transfer 组件是一个强大的穿梭框组件,常用于在两个集合之间进行数据转移,如权限分配、数据选择等场景。下面我将详细介绍其用法并提供一个完整示例。 核心特性与用法 基本属性 v-model:绑定右侧列表的值&…...
基于uniapp+WebSocket实现聊天对话、消息监听、消息推送、聊天室等功能,多端兼容
基于 UniApp + WebSocket实现多端兼容的实时通讯系统,涵盖WebSocket连接建立、消息收发机制、多端兼容性配置、消息实时监听等功能,适配微信小程序、H5、Android、iOS等终端 目录 技术选型分析WebSocket协议优势UniApp跨平台特性WebSocket 基础实现连接管理消息收发连接…...
Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具
文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...
【论文笔记】若干矿井粉尘检测算法概述
总的来说,传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度,通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...
【Oracle】分区表
个人主页:Guiat 归属专栏:Oracle 文章目录 1. 分区表基础概述1.1 分区表的概念与优势1.2 分区类型概览1.3 分区表的工作原理 2. 范围分区 (RANGE Partitioning)2.1 基础范围分区2.1.1 按日期范围分区2.1.2 按数值范围分区 2.2 间隔分区 (INTERVAL Partit…...
七、数据库的完整性
七、数据库的完整性 主要内容 7.1 数据库的完整性概述 7.2 实体完整性 7.3 参照完整性 7.4 用户定义的完整性 7.5 触发器 7.6 SQL Server中数据库完整性的实现 7.7 小结 7.1 数据库的完整性概述 数据库完整性的含义 正确性 指数据的合法性 有效性 指数据是否属于所定…...
【LeetCode】3309. 连接二进制表示可形成的最大数值(递归|回溯|位运算)
LeetCode 3309. 连接二进制表示可形成的最大数值(中等) 题目描述解题思路Java代码 题目描述 题目链接:LeetCode 3309. 连接二进制表示可形成的最大数值(中等) 给你一个长度为 3 的整数数组 nums。 现以某种顺序 连接…...
wpf在image控件上快速显示内存图像
wpf在image控件上快速显示内存图像https://www.cnblogs.com/haodafeng/p/10431387.html 如果你在寻找能够快速在image控件刷新大图像(比如分辨率3000*3000的图像)的办法,尤其是想把内存中的裸数据(只有图像的数据,不包…...
