PyTorch深度学习实战 | 基于线性回归、决策树和SVM进行鸢尾花分类

鸢尾花数据集是机器学习领域非常经典的一个分类任务数据集。它的英文名称为Iris Data Set,使用sklearn库可以直接下载并导入该数据集。数据集总共包含150行数据,每一行数据由4个特征值及一个标签组成。标签为三种不同类别的鸢尾花,分别为:Iris Setosa,Iris Versicolour,Iris Virginica。
对于多分类任务,有较多机器学习的算法可以支持。本文将使用决策树、线性回归、SVM等多种算法来完成这一任务,并对不同方法进行比较。
01、使用Logistic实现鸢尾花分类
在前面介绍过Logistic用于二分类任务,对其进行扩展也用于多分类任务。下面将使用sklearn库完成一个基于Logistic的鸢尾花分类任务。如代码清单1所示,首先是导入sklearn.datasets包从而加载数据集,并将数据集按照测试集占比0.2随机分为训练集和测试集。
代码清单1 导入包以及加载数据集
rom sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import confusion_matrix, precision_score, accuracy_score,recall_score, f1_score, roc_auc_score, \roc_curve
import matplotlib.pyplot as plt# 加载数据集
def loadDataSet():iris_dataset = load_iris()X = iris_dataset.datay = iris_dataset.target# 将数据划分为训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)return X_train, X_test, y_train, y_test
如代码清单2所示,编写函数训练Logistic模型。
代码清单2 训练Logistic模型
# 训练Logistic模性
def trainLS(x_train, y_train):# Logistic生成和训练clf = LogisticRegression()clf.fit(x_train, y_train)return clf
Logistic模型较为简单,不需要额外设置超参数即可开始训练。如代码清单3所示,初始化Logistic模型并将模型在训练集上训练,返回训练好的模型。
代码清单3 测试模型及打印各种评价指标
# 测试模型
def test(model, x_test, y_test):# 将标签转换为one-hot形式y_one_hot = label_binarize(y_test, np.arange(3))# 预测结果y_pre = model.predict(x_test)# 预测结果的概率y_pre_pro = model.predict_proba(x_test)# 混淆矩阵con_matrix = confusion_matrix(y_test, y_pre)print('confusion_matrix:\n', con_matrix)print('accuracy:{}'.format(accuracy_score(y_test, y_pre)))print('precision:{}'.format(precision_score(y_test, y_pre, average='micro')))print('recall:{}'.format(recall_score(y_test, y_pre, average='micro')))print('f1-score:{}'.format(f1_score(y_test, y_pre, average='micro')))# 绘制ROC曲线drawROC(y_one_hot, y_pre_pro)
在预测结果时,为了方便后面绘制ROC曲线,需要首先将测试集的标签转化为one-hot的形式,并得到模型在测试集上预测结果的概率值即y_pre_pro,从而传入drawROC函数完成ROC曲线的绘制。除此外,该函数实现了输出混淆矩阵以及计算准确率、精确率、查全率以及f1-score的功能。
代码清单4 绘制ROC曲线
def drawROC(y_one_hot, y_pre_pro):# AUC值auc = roc_auc_score(y_one_hot, y_pre_pro, average='micro')# 绘制ROC曲线fpr, tpr, thresholds = roc_curve(y_one_hot.ravel(), y_pre_pro.ravel())plt.plot(fpr, tpr, linewidth=2, label='AUC=%.3f' % auc)plt.plot([0, 1], [0, 1], 'k--')plt.axis([0, 1.1, 0, 1.1])plt.xlabel('False Postivie Rate')plt.ylabel('True Positive Rate')plt.legend()plt.show()
如代码清单4所示为绘制ROC曲线的代码实现。最后将加载数据集,训练模型,以及模型验证的整个流程连接起来从而实现main函数,如代码清单5所示。
代码清单5 main函数设置
if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainLS(X_train, y_train)test(model, X_test, y_test)
将上述所有代码放在同一py脚本文件中,如图1所示可得最终的输出结果为

图1 命令行打印的测试结果
绘制得到的ROC曲线如图2所示。

图2 ROC曲线
Logistic是一个较为简单的模型,参数量较少,一般也用于较为简单的分类任务中,当任务更为复杂时,可以选取更为复杂的模型获得更好的效果,下面将使用不同的模型从而验证同一任务在不同模型下的表现。
02、使用决策树实现鸢尾花分类
由于只改动了模型,加载数据集、模型评价等其他部分的代码不需要改动,如代码清单6所示,增加新的函数用于训练决策树模型。
代码清单6 使用决策树模型进行训练
from sklearn import tree
# 训练决策树模性
def trainDT(x_train, y_train):# DT生成和训练clf = tree.DecisionTreeClassifier(criterion="entropy")clf.fit(x_train, y_train)return clf
同时修改main函数中调用的训练函数如代码清单7所示。
代码清单7 修改main函数内容
if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainDT(X_train, y_train)test(model, X_test, y_test)
最后运行可得命令行输出如图3所示。

图3 决策树模型预测结果
以及ROC曲线如图4所示。

图4 决策树模型绘制ROC曲线
相比Logistic模型,决策树模型无论在哪一项指标上都得到了更高的评分,且决策树模型不会像Logistic模型一样受初始化的影响,多次运行程序均可获得相同的输出模型,而Logistic模型运行多次会发现评价指标会在某个范围内上下抖动。
03、使用SVM实现鸢尾花分类
到现在相信大家都已经非常熟悉如何继续修改代码从而实现SVM模型的预测,实现SVM模型的训练代码如代码清单8所示
代码清单8 使用SVM模型进行训练
# 训练SVM模性
from sklearn import svm
def trainSVM(x_train, y_train):# SVM生成和训练clf = svm.SVC(kernel='rbf', probability=True)clf.fit(x_train, y_train)return clf
同时修改main函数,如代码清单9所示。
代码清单9 修改main函数内容
if __name__ == '__main__':X_train, X_test, y_train, y_test = loadDataSet()model = trainSVM(X_train, y_train)test(model, X_test, y_test)
程序运行输出如图5所示。

图5 使用SVM模型预测结果
绘制得到的ROC曲线如图6所示。

图6 使用SVM模型绘制的ROC曲线
可以发现,随着模型进一步变得复杂,最终预测的各项指标进一步上升,在三个模型中SVM模型的高斯核最终结果在测试集中表现得最好且没有发生过拟合的现象,因此可以选用SVM模型来完成鸢尾花分类这一任务。
相关文章:

PyTorch深度学习实战 | 基于线性回归、决策树和SVM进行鸢尾花分类
鸢尾花数据集是机器学习领域非常经典的一个分类任务数据集。它的英文名称为Iris Data Set,使用sklearn库可以直接下载并导入该数据集。数据集总共包含150行数据,每一行数据由4个特征值及一个标签组成。标签为三种不同类别的鸢尾花,分别为&…...

服务端接口优化方案
一、背景 针对老项目,去年做了许多降本增效的事情,其中发现最多的就是接口耗时过长的问题,就集中搞了一次接口性能优化。本文将给小伙伴们分享一下接口优化的通用方案。 二、接口优化方案总结 1. 批处理 批量思想:批量操作数据…...

【并发基础】Happens-Before模型详解
目录 一、Happens-Before模型简介 二、组成Happens-Before模型的八种规则 2.1 程序顺序规则(as-if-serial语义) 2.2 传递性规则 2.3 volatile变量规则 2.4 监视器锁规则 2.5 start规则 2.6 Join规则 一、Happens-Before模型简介 除了显示引用vo…...

Kubernetes系列---Kubernetes 理论知识 | 初识
Kubernetes系列---Kubernetes 理论知识 | 初识 1.K8s 是什么?2.K8s 特性3.小拓展(业务升级)4.K8s 集群架构与组件①架构拓扑图:②Master 组件③Node 组件 五 K8s 核心概念六 官方提供的三种部署方式总结 1.K8s 是什么?…...
KingbaseES 原生XML系列三--XML数据查询函数
KingbaseES 原生XML系列三--XML数据查询函数(EXTRACT,EXTRACTVALUE,EXISTSNODE,XPATH,XPATH_EXISTS,XMLEXISTS) XML的简单使其易于在任何应用程序中读写数据,这使XML很快成为数据交换的一种公共语言。在不同平台下产生的信息,可以很容易加载XML数据到程序…...

【51单片机】点亮一个LED灯(看开发板原理图十分重要)
🎊专栏【51单片机】 🍔喜欢的诗句:更喜岷山千里雪 三军过后尽开颜。 🎆音乐分享【The Right Path】 🥰大一同学小吉,欢迎并且感谢大家指出我的问题🥰 目录 🍔基础内容 🏳…...

数据可视化工具 - ECharts以及柱状图的编写
1 快速上手 引入echarts 插件文件到html页面中 <head><meta charset"utf-8"/><title>ECharts</title><!-- step1 引入刚刚下载的 ECharts 文件 --><script src"./echarts.js"></script> </head>准备一个…...

【AI绘画】——Midjourney关键词格式解析(常用参数分享)
目前在AI绘画模型中,Midjourney的效果是公认的top级别,但同时也是相对较难使用的,对小白来说比较难上手,主要就在于Mj没有webui,不能选择参数,怎么找到这些隐藏参数并且触发它是用好Mj的第一步。 今天就来…...

操作符知识点大全(简洁,全面,含使用场景,演示,代码)
目录 一.算术操作符 1.要点: 二.负数原码,反码,补码的互推 1.按位取反操作符:~(二进制位) 2.原反补互推演示 三.进制位的表示 1.不同进制位的特征: 2.二进制位表示 3.整型的二进制表…...
华工研究生语音课
这门课讲啥 语音蕴含的信息、语音识别的目的 语音的准平稳性、分帧、预加重、时域特征分析(能量和过零率)、端点检测(双门限法) 语音的基频及检测(主要是自相关法、野点的处理) 声音的产生过程…...
KingbaseES 原生XML系列二 -- XML数据操作函数
KingbaseES 原生XML系列二--XML数据操作函数(DELETEXML,APPENDCHILDXML,INSERTCHILDXML,INSERTCHILDXMLAFTER,INSERTCHILDXMLBEFORE,INSERTXMLAFTER,INSERTXMLBEFORE,UPDATEXML) XML的简单使其易于在任何应用程序中读写数据,这使XML很快成为数据交换的一种公共语言。…...

【Flink】DataStream API使用之源算子(Source)
源算子 创建环境之后,就可以构建数据的业务处理逻辑了,Flink可以从各种来源获取数据,然后构建DataStream进项转换。一般将数据的输入来源称为数据源(data source),而读取数据的算子就叫做源算子(…...

树莓派硬件介绍及配件选择
目录 树莓派Datasheet下载地址: Raspberry 4B 外观图: 技术规格书: 性能介绍: 树莓派配件选用 电源的选用: 树莓派外壳选用: 内存卡/U盘选用 树莓派Datasheet下载地址: Raspberry Pi …...

O2OA (翱途) 平台 V8.0 发布新增数据台账能力
亲爱的小伙伴们,O2OA (翱途) 平台开发团队经过几个月的持续努力,实现功能的新增、优化以及问题的修复。2023 年度 V8.0 版本已正式发布。欢迎大家到 O2OA 的官网上下载进行体验,也希望大家在藕粉社区里多提宝贵建议。本篇我们先为大家介绍应用…...
数控解锁怎么解 数控系统解锁解密
Amazon Fargate 在中国区正式落地,因 数控解锁使用 Serverless 架构,更加适合对性能要求不敏感的服务使用,Pyroscope 是一款基于 Golang 开发的应用程序性能分析工具,Pyroscope 的服务端为无状态服务且性能要求不敏感,…...

3.0 响应式系统的设计与实现
1、Proxy代理对象 Proxy用于对一个普通对象代理,实现对象的拦截和自定义,如拦截其赋值、枚举、函数调用等。里面包含了很多组捕获器(trap),在代理对象执行相应的操作时捕获,然后在内部实现自定义。 const…...

Rust 快速入门60分① 看完这篇就能写代码了
Rust 一门赋予每个人构建可靠且高效软件能力的语言https://hannyang.blog.csdn.net/article/details/130467813?spm1001.2014.3001.5502关于Rust安装等内容请参考上文链接,写完上文就在考虑写点关于Rust的入门文章,本专辑将直接从Rust基础入门内容开始讲…...

【5.JS基础-JavaScript的DOM操作】
1 认识DOM和BOM 所以我们学习DOM,就是在学习如何通过JavaScript对文档进行操作的; DOM Tree的理解 DOM的学习顺序 DOM的继承关系图 2 document对象 3 节点(Node)之间的导航(navigator) 4 元素࿰…...

【大数据之Hadoop】二十九、HDFS存储优化
纠删码和异构存储测试需要5台虚拟机。准备另外一套5台服务器集群。 环境准备: (1)克隆hadoop105为hadoop106,修改ip地址和hostname,然后重启。 vim /etc/sysconfig/network-scripts/ifcfg-ens33 vim /etc/hostname r…...

SuperMap GIS基础产品组件GIS FAQ集锦(2)
SuperMap GIS基础产品组件GIS FAQ集锦(2) 【iObjects for Spark】读取GDB参数该如何填写? 【解决办法】可参考以下示例: val GDB_params new util.HashMapString, java.io.Serializable GDB_params.put(FeatureRDDProviderParam…...

2025年能源电力系统与流体力学国际会议 (EPSFD 2025)
2025年能源电力系统与流体力学国际会议(EPSFD 2025)将于本年度在美丽的杭州盛大召开。作为全球能源、电力系统以及流体力学领域的顶级盛会,EPSFD 2025旨在为来自世界各地的科学家、工程师和研究人员提供一个展示最新研究成果、分享实践经验及…...
在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:
在 HarmonyOS 应用开发中,手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力,既支持点击、长按、拖拽等基础单一手势的精细控制,也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档,…...

MMaDA: Multimodal Large Diffusion Language Models
CODE : https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA,它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构…...
【算法训练营Day07】字符串part1
文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接:344. 反转字符串 双指针法,两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...

BCS 2025|百度副总裁陈洋:智能体在安全领域的应用实践
6月5日,2025全球数字经济大会数字安全主论坛暨北京网络安全大会在国家会议中心隆重开幕。百度副总裁陈洋受邀出席,并作《智能体在安全领域的应用实践》主题演讲,分享了在智能体在安全领域的突破性实践。他指出,百度通过将安全能力…...

用docker来安装部署freeswitch记录
今天刚才测试一个callcenter的项目,所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...

python执行测试用例,allure报乱码且未成功生成报告
allure执行测试用例时显示乱码:‘allure’ �����ڲ����ⲿ���Ҳ���ǿ�&am…...

【从零学习JVM|第三篇】类的生命周期(高频面试题)
前言: 在Java编程中,类的生命周期是指类从被加载到内存中开始,到被卸载出内存为止的整个过程。了解类的生命周期对于理解Java程序的运行机制以及性能优化非常重要。本文会深入探寻类的生命周期,让读者对此有深刻印象。 目录 …...

RSS 2025|从说明书学习复杂机器人操作任务:NUS邵林团队提出全新机器人装配技能学习框架Manual2Skill
视觉语言模型(Vision-Language Models, VLMs),为真实环境中的机器人操作任务提供了极具潜力的解决方案。 尽管 VLMs 取得了显著进展,机器人仍难以胜任复杂的长时程任务(如家具装配),主要受限于人…...