当前位置: 首页 > news >正文

PyTorch深度学习实战 | 基于线性回归、决策树和SVM进行鸢尾花分类

鸢尾花数据集是机器学习领域非常经典的一个分类任务数据集。它的英文名称为Iris Data Set,使用sklearn库可以直接下载并导入该数据集。数据集总共包含150行数据,每一行数据由4个特征值及一个标签组成。标签为三种不同类别的鸢尾花,分别为:Iris Setosa,Iris Versicolour,Iris Virginica。

对于多分类任务,有较多机器学习的算法可以支持。本文将使用决策树、线性回归、SVM等多种算法来完成这一任务,并对不同方法进行比较。

01、使用Logistic实现鸢尾花分类

在前面介绍过Logistic用于二分类任务,对其进行扩展也用于多分类任务。下面将使用sklearn库完成一个基于Logistic的鸢尾花分类任务。如代码清单1所示,首先是导入sklearn.datasets包从而加载数据集,并将数据集按照测试集占比0.2随机分为训练集和测试集。

代码清单1 导入包以及加载数据集

rom sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import confusion_matrix, precision_score, accuracy_score,recall_score, f1_score, roc_auc_score, \roc_curve
import matplotlib.pyplot as plt# 加载数据集
def loadDataSet():iris_dataset = load_iris()X = iris_dataset.datay = iris_dataset.target# 将数据划分为训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)return X_train, X_test, y_train, y_test

如代码清单2所示,编写函数训练Logistic模型。

代码清单2 训练Logistic模型

# 训练Logistic模性
def trainLS(x_train, y_train):# Logistic生成和训练clf = LogisticRegression()clf.fit(x_train, y_train)return clf

Logistic模型较为简单,不需要额外设置超参数即可开始训练。如代码清单3所示,初始化Logistic模型并将模型在训练集上训练,返回训练好的模型。

代码清单3 测试模型及打印各种评价指标

# 测试模型
def test(model, x_test, y_test):# 将标签转换为one-hot形式y_one_hot = label_binarize(y_test, np.arange(3))# 预测结果y_pre = model.predict(x_test)# 预测结果的概率y_pre_pro = model.predict_proba(x_test)# 混淆矩阵con_matrix = confusion_matrix(y_test, y_pre)print('confusion_matrix:\n', con_matrix)print('accuracy:{}'.format(accuracy_score(y_test, y_pre)))print('precision:{}'.format(precision_score(y_test, y_pre, average='micro')))print('recall:{}'.format(recall_score(y_test, y_pre, average='micro')))print('f1-score:{}'.format(f1_score(y_test, y_pre, average='micro')))# 绘制ROC曲线drawROC(y_one_hot, y_pre_pro)

在预测结果时,为了方便后面绘制ROC曲线,需要首先将测试集的标签转化为one-hot的形式,并得到模型在测试集上预测结果的概率值即y_pre_pro,从而传入drawROC函数完成ROC曲线的绘制。除此外,该函数实现了输出混淆矩阵以及计算准确率、精确率、查全率以及f1-score的功能。

代码清单4 绘制ROC曲线

def drawROC(y_one_hot, y_pre_pro):# AUC值auc = roc_auc_score(y_one_hot, y_pre_pro, average='micro')# 绘制ROC曲线fpr, tpr, thresholds = roc_curve(y_one_hot.ravel(), y_pre_pro.ravel())plt.plot(fpr, tpr, linewidth=2, label='AUC=%.3f' % auc)plt.plot([0, 1], [0, 1], 'k--')plt.axis([0, 1.1, 0, 1.1])plt.xlabel('False Postivie Rate')plt.ylabel('True Positive Rate')plt.legend()plt.show()

如代码清单4所示为绘制ROC曲线的代码实现。最后将加载数据集,训练模型,以及模型验证的整个流程连接起来从而实现main函数,如代码清单5所示。

代码清单5 main函数设置

if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainLS(X_train, y_train)test(model, X_test, y_test)

将上述所有代码放在同一py脚本文件中,如图1所示可得最终的输出结果为

图1 命令行打印的测试结果

绘制得到的ROC曲线如图2所示。

图2 ROC曲线

Logistic是一个较为简单的模型,参数量较少,一般也用于较为简单的分类任务中,当任务更为复杂时,可以选取更为复杂的模型获得更好的效果,下面将使用不同的模型从而验证同一任务在不同模型下的表现。

02、使用决策树实现鸢尾花分类

由于只改动了模型,加载数据集、模型评价等其他部分的代码不需要改动,如代码清单6所示,增加新的函数用于训练决策树模型。

代码清单6 使用决策树模型进行训练

from sklearn import tree
# 训练决策树模性
def trainDT(x_train, y_train):# DT生成和训练clf = tree.DecisionTreeClassifier(criterion="entropy")clf.fit(x_train, y_train)return clf

同时修改main函数中调用的训练函数如代码清单7所示。

代码清单7 修改main函数内容

if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainDT(X_train, y_train)test(model, X_test, y_test)

最后运行可得命令行输出如图3所示。

图3 决策树模型预测结果

以及ROC曲线如图4所示。

图4 决策树模型绘制ROC曲线

相比Logistic模型,决策树模型无论在哪一项指标上都得到了更高的评分,且决策树模型不会像Logistic模型一样受初始化的影响,多次运行程序均可获得相同的输出模型,而Logistic模型运行多次会发现评价指标会在某个范围内上下抖动。

03、使用SVM实现鸢尾花分类

到现在相信大家都已经非常熟悉如何继续修改代码从而实现SVM模型的预测,实现SVM模型的训练代码如代码清单8所示

代码清单8 使用SVM模型进行训练

# 训练SVM模性
from sklearn import svm
def trainSVM(x_train, y_train):# SVM生成和训练clf = svm.SVC(kernel='rbf', probability=True)clf.fit(x_train, y_train)return clf

同时修改main函数,如代码清单9所示。

代码清单9 修改main函数内容

if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainSVM(X_train, y_train)test(model, X_test, y_test)

程序运行输出如图5所示。

图5 使用SVM模型预测结果

绘制得到的ROC曲线如图6所示。

图6 使用SVM模型绘制的ROC曲线

可以发现,随着模型进一步变得复杂,最终预测的各项指标进一步上升,在三个模型中SVM模型的高斯核最终结果在测试集中表现得最好且没有发生过拟合的现象,因此可以选用SVM模型来完成鸢尾花分类这一任务。

相关文章:

PyTorch深度学习实战 | 基于线性回归、决策树和SVM进行鸢尾花分类

鸢尾花数据集是机器学习领域非常经典的一个分类任务数据集。它的英文名称为Iris Data Set,使用sklearn库可以直接下载并导入该数据集。数据集总共包含150行数据,每一行数据由4个特征值及一个标签组成。标签为三种不同类别的鸢尾花,分别为&…...

服务端接口优化方案

一、背景 针对老项目,去年做了许多降本增效的事情,其中发现最多的就是接口耗时过长的问题,就集中搞了一次接口性能优化。本文将给小伙伴们分享一下接口优化的通用方案。 二、接口优化方案总结 1. 批处理 批量思想:批量操作数据…...

【并发基础】Happens-Before模型详解

目录 一、Happens-Before模型简介 二、组成Happens-Before模型的八种规则 2.1 程序顺序规则(as-if-serial语义) 2.2 传递性规则 2.3 volatile变量规则 2.4 监视器锁规则 2.5 start规则 2.6 Join规则 一、Happens-Before模型简介 除了显示引用vo…...

Kubernetes系列---Kubernetes 理论知识 | 初识

Kubernetes系列---Kubernetes 理论知识 | 初识 1.K8s 是什么?2.K8s 特性3.小拓展(业务升级)4.K8s 集群架构与组件①架构拓扑图:②Master 组件③Node 组件 五 K8s 核心概念六 官方提供的三种部署方式总结 1.K8s 是什么&#xff1f…...

KingbaseES 原生XML系列三--XML数据查询函数

KingbaseES 原生XML系列三--XML数据查询函数(EXTRACT,EXTRACTVALUE,EXISTSNODE,XPATH,XPATH_EXISTS,XMLEXISTS) XML的简单使其易于在任何应用程序中读写数据,这使XML很快成为数据交换的一种公共语言。在不同平台下产生的信息,可以很容易加载XML数据到程序…...

【51单片机】点亮一个LED灯(看开发板原理图十分重要)

🎊专栏【51单片机】 🍔喜欢的诗句:更喜岷山千里雪 三军过后尽开颜。 🎆音乐分享【The Right Path】 🥰大一同学小吉,欢迎并且感谢大家指出我的问题🥰 目录 🍔基础内容 &#x1f3f3…...

数据可视化工具 - ECharts以及柱状图的编写

1 快速上手 引入echarts 插件文件到html页面中 <head><meta charset"utf-8"/><title>ECharts</title><!-- step1 引入刚刚下载的 ECharts 文件 --><script src"./echarts.js"></script> </head>准备一个…...

【AI绘画】——Midjourney关键词格式解析(常用参数分享)

目前在AI绘画模型中&#xff0c;Midjourney的效果是公认的top级别&#xff0c;但同时也是相对较难使用的&#xff0c;对小白来说比较难上手&#xff0c;主要就在于Mj没有webui&#xff0c;不能选择参数&#xff0c;怎么找到这些隐藏参数并且触发它是用好Mj的第一步。 今天就来…...

操作符知识点大全(简洁,全面,含使用场景,演示,代码)

目录 一.算术操作符 1.要点&#xff1a; 二.负数原码&#xff0c;反码&#xff0c;补码的互推 1.按位取反操作符&#xff1a;~&#xff08;二进制位&#xff09; 2.原反补互推演示 三.进制位的表示 1.不同进制位的特征&#xff1a; 2.二进制位表示 3.整型的二进制表…...

华工研究生语音课

这门课讲啥 语音蕴含的信息、语音识别的目的 语音的准平稳性、分帧、预加重、时域特征分析&#xff08;能量和过零率&#xff09;、端点检测&#xff08;双门限法&#xff09; 语音的基频及检测&#xff08;主要是自相关法、野点的处理&#xff09; 声音的产生过程&#xf…...

KingbaseES 原生XML系列二 -- XML数据操作函数

KingbaseES 原生XML系列二--XML数据操作函数(DELETEXML,APPENDCHILDXML,INSERTCHILDXML,INSERTCHILDXMLAFTER,INSERTCHILDXMLBEFORE,INSERTXMLAFTER,INSERTXMLBEFORE,UPDATEXML) XML的简单使其易于在任何应用程序中读写数据&#xff0c;这使XML很快成为数据交换的一种公共语言。…...

【Flink】DataStream API使用之源算子(Source)

源算子 创建环境之后&#xff0c;就可以构建数据的业务处理逻辑了&#xff0c;Flink可以从各种来源获取数据&#xff0c;然后构建DataStream进项转换。一般将数据的输入来源称为数据源&#xff08;data source&#xff09;&#xff0c;而读取数据的算子就叫做源算子&#xff08…...

树莓派硬件介绍及配件选择

目录 树莓派Datasheet下载地址&#xff1a; Raspberry 4B 外观图&#xff1a; 技术规格书&#xff1a; 性能介绍&#xff1a; 树莓派配件选用 电源的选用&#xff1a; 树莓派外壳选用&#xff1a; 内存卡/U盘选用 树莓派Datasheet下载地址&#xff1a; Raspberry Pi …...

O2OA (翱途) 平台 V8.0 发布新增数据台账能力

亲爱的小伙伴们&#xff0c;O2OA (翱途) 平台开发团队经过几个月的持续努力&#xff0c;实现功能的新增、优化以及问题的修复。2023 年度 V8.0 版本已正式发布。欢迎大家到 O2OA 的官网上下载进行体验&#xff0c;也希望大家在藕粉社区里多提宝贵建议。本篇我们先为大家介绍应用…...

数控解锁怎么解 数控系统解锁解密

Amazon Fargate 在中国区正式落地&#xff0c;因 数控解锁使用 Serverless 架构&#xff0c;更加适合对性能要求不敏感的服务使用&#xff0c;Pyroscope 是一款基于 Golang 开发的应用程序性能分析工具&#xff0c;Pyroscope 的服务端为无状态服务且性能要求不敏感&#xff0c;…...

3.0 响应式系统的设计与实现

1、Proxy代理对象 Proxy用于对一个普通对象代理&#xff0c;实现对象的拦截和自定义&#xff0c;如拦截其赋值、枚举、函数调用等。里面包含了很多组捕获器&#xff08;trap&#xff09;&#xff0c;在代理对象执行相应的操作时捕获&#xff0c;然后在内部实现自定义。 const…...

Rust 快速入门60分① 看完这篇就能写代码了

Rust 一门赋予每个人构建可靠且高效软件能力的语言https://hannyang.blog.csdn.net/article/details/130467813?spm1001.2014.3001.5502关于Rust安装等内容请参考上文链接&#xff0c;写完上文就在考虑写点关于Rust的入门文章&#xff0c;本专辑将直接从Rust基础入门内容开始讲…...

【5.JS基础-JavaScript的DOM操作】

1 认识DOM和BOM 所以我们学习DOM&#xff0c;就是在学习如何通过JavaScript对文档进行操作的&#xff1b; DOM Tree的理解 DOM的学习顺序 DOM的继承关系图 2 document对象 3 节点&#xff08;Node&#xff09;之间的导航&#xff08;navigator&#xff09; 4 元素&#xff0…...

【大数据之Hadoop】二十九、HDFS存储优化

纠删码和异构存储测试需要5台虚拟机。准备另外一套5台服务器集群。 环境准备&#xff1a; &#xff08;1&#xff09;克隆hadoop105为hadoop106&#xff0c;修改ip地址和hostname&#xff0c;然后重启。 vim /etc/sysconfig/network-scripts/ifcfg-ens33 vim /etc/hostname r…...

SuperMap GIS基础产品组件GIS FAQ集锦(2)

SuperMap GIS基础产品组件GIS FAQ集锦&#xff08;2&#xff09; 【iObjects for Spark】读取GDB参数该如何填写&#xff1f; 【解决办法】可参考以下示例&#xff1a; val GDB_params new util.HashMapString, java.io.Serializable GDB_params.put(FeatureRDDProviderParam…...

css实现圆环展示百分比,根据值动态展示所占比例

代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中&#xff0c;我们可能会遇到一些流式数据处理的场景&#xff0c;比如接收来自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 内容&#xff0c;并将其原样中转给前端页面或客户端。这种情况下&#xff0c;传统的 RestTemplate 缓存机制会…...

前端导出带有合并单元格的列表

// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...

多模态商品数据接口:融合图像、语音与文字的下一代商品详情体验

一、多模态商品数据接口的技术架构 &#xff08;一&#xff09;多模态数据融合引擎 跨模态语义对齐 通过Transformer架构实现图像、语音、文字的语义关联。例如&#xff0c;当用户上传一张“蓝色连衣裙”的图片时&#xff0c;接口可自动提取图像中的颜色&#xff08;RGB值&…...

OkHttp 中实现断点续传 demo

在 OkHttp 中实现断点续传主要通过以下步骤完成&#xff0c;核心是利用 HTTP 协议的 Range 请求头指定下载范围&#xff1a; 实现原理 Range 请求头&#xff1a;向服务器请求文件的特定字节范围&#xff08;如 Range: bytes1024-&#xff09; 本地文件记录&#xff1a;保存已…...

C++中string流知识详解和示例

一、概览与类体系 C 提供三种基于内存字符串的流&#xff0c;定义在 <sstream> 中&#xff1a; std::istringstream&#xff1a;输入流&#xff0c;从已有字符串中读取并解析。std::ostringstream&#xff1a;输出流&#xff0c;向内部缓冲区写入内容&#xff0c;最终取…...

高防服务器能够抵御哪些网络攻击呢?

高防服务器作为一种有着高度防御能力的服务器&#xff0c;可以帮助网站应对分布式拒绝服务攻击&#xff0c;有效识别和清理一些恶意的网络流量&#xff0c;为用户提供安全且稳定的网络环境&#xff0c;那么&#xff0c;高防服务器一般都可以抵御哪些网络攻击呢&#xff1f;下面…...

Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理

引言 Bitmap&#xff08;位图&#xff09;是Android应用内存占用的“头号杀手”。一张1080P&#xff08;1920x1080&#xff09;的图片以ARGB_8888格式加载时&#xff0c;内存占用高达8MB&#xff08;192010804字节&#xff09;。据统计&#xff0c;超过60%的应用OOM崩溃与Bitm…...

C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...

热烈祝贺埃文科技正式加入可信数据空间发展联盟

2025年4月29日&#xff0c;在福州举办的第八届数字中国建设峰会“可信数据空间分论坛”上&#xff0c;可信数据空间发展联盟正式宣告成立。国家数据局党组书记、局长刘烈宏出席并致辞&#xff0c;强调该联盟是推进全国一体化数据市场建设的关键抓手。 郑州埃文科技有限公司&am…...