当前位置: 首页 > news >正文

机器学习与深度学习——通过knn算法分类鸢尾花数据集iris求出错误率并进行可视化

什么是knn算法?

KNN算法是一种基于实例的机器学习算法,其全称为K-最近邻算法(K-Nearest Neighbors Algorithm)。它是一种简单但非常有效的分类和回归算法。
该算法的基本思想是:对于一个新的输入样本,通过计算它与训练集中所有样本的距离,找到与它距离最近的K个训练集样本,然后基于这K个样本的类别信息来进行分类或回归预测。KNN算法中的“K”代表了在预测时使用的邻居数,通常需要手动设置。
KNN算法的主要优点是简单、易于实现,并且在某些情况下可以获得很好的分类或回归精度。但是,它也有一些缺点,如需要存储所有训练集样本、计算距离的开销较大、对于高维数据容易过拟合等。

KNN算法常用于分类问题,如文本分类、图像分类等,以及回归问题,如预测房价等。

我们这次学习机器学习的knn算法分别对前二维数据和前四维数据进行训练和可视化。

两个目标:

1、通过knn算法对iris数据集前两个维度的数据进行模型训练并求出错误率,最后进行可视化展示数据区域划分。
2、通过knn算法对iris数据集总共四个维度的数据进行模型训练并求出错误率,并对前四维数据进行可视化。

基本思路:

1、先载入iris数据集 Load Iris data
2、分离训练集和设置测试集split train and test sets
3、对数据进行标准化处理Normalize the data
4、使用knn模型进行训练Train using KNN
5、然后进行可视化处理Visualization
6、最后通过绘图决策平面plot decision plane

1、通过knn算法对iris数据集前两个维度的数据进行模型训练并求出错误率,最后进行可视化展示数据区域划分:

from sklearn import datasets
import numpy as np
### Load Iris data
iris = datasets.load_iris()
x = iris.data[:,:2]#前2个维度
# x = iris.data
y = iris.target
print("class labels: ", np.unique(y))
x.shape
y.shape
### split train and test sets
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y)
x_train.shape
print("Labels count in y:", np.bincount(y))
print("Labels count in y_train:", np.bincount(y_train))
print("Labels count in y_test:", np.bincount(y_test))
### Normalize the data
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
sc.fit(x_train)
x_train_std = sc.transform(x_train)
x_test_std = sc.transform(x_test)
print("TrainSets Orig mean:{}, std mean:{}".format(np.mean(x_train,axis=0), np.mean(x_train_std,axis=0)))
print("TrainSets Orig std:{}, std std:{}".format(np.std(x_train,axis=0), np.std(x_train_std,axis=0)))
print("TestSets Orig mean:{}, std mean:{}".format(np.mean(x_test,axis=0), np.mean(x_test_std,axis=0)))
print("TestSets Orig std:{}, std std:{}".format(np.std(x_test,axis=0), np.std(x_test_std,axis=0)))
### Train using KNN
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5, p=2, metric='minkowski')
knn.fit(x_train_std, y_train)
pred_test=knn.predict(x_test_std)
err_num = (pred_test != y_test).sum()
rate = err_num/y_test.size
print("Misclassfication num: {}\nError rate: {}".format(err_num, rate))#计算错误率
### Visualization
x_combined_std = np.vstack((x_train_std, x_test_std))
y_combined = np.hstack((y_train, y_test))
plot_decision_regions(x_combined_std, y_combined,
classifier=knn, test_idx=range(105,150))
plt.xlabel('petal length [standardized]')
plt.ylabel('petal width [standardized]')
plt.legend(loc='upper left')
plt.tight_layout()
plt.show()
#### plot decision plane
x_combined_std = np.vstack((x_train_std, x_test_std))
y_combined = np.hstack((y_train, y_test))
plot_decision_regions(x_combined_std, y_combined,
classifier=knn, test_idx=range(105,150))
plt.xlabel('petal length [standardized]')
plt.ylabel('petal width [standardized]')
plt.legend(loc='upper left')
plt.tight_layout()
plt.show()

代码及其可视化效果截图:

在这里插入图片描述

2、通过knn算法对iris数据集总共四个维度的数据进行模型训练并求出错误率并进行可视化:

from sklearn import datasets
import numpy as np
iris = datasets.load_iris()
x = iris.data  #4个维度
# x = iris.data
y = iris.target
print("class labels: ", np.unique(y))
x.shape
y.shape
### split train and test sets
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y)
x_train.shape
print("Labels count in y:", np.bincount(y))
print("Labels count in y_train:", np.bincount(y_train))
print("Labels count in y_test:", np.bincount(y_test))
### Normalize the data
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
sc.fit(x_train)
x_train_std = sc.transform(x_train)
x_test_std = sc.transform(x_test)
print("TrainSets Orig mean:{}, std mean:{}".format(np.mean(x_train,axis=0), np.mean(x_train_std,axis=0)))
print("TrainSets Orig std:{}, std std:{}".format(np.std(x_train,axis=0), np.std(x_train_std,axis=0)))
print("TestSets Orig mean:{}, std mean:{}".format(np.mean(x_test,axis=0), np.mean(x_test_std,axis=0)))
print("TestSets Orig std:{}, std std:{}".format(np.std(x_test,axis=0), np.std(x_test_std,axis=0)))
### Train using KNN
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5, p=2, metric='minkowski')
knn.fit(x_train_std, y_train)
pred_test=knn.predict(x_test_std)
err_num = (pred_test != y_test).sum()
rate = err_num/y_test.size
print("Misclassfication num: {}\nError rate: {}".format(err_num, rate))#计算错误率
#四维可视化在二维或三维空间中是无法呈现的,但我们可以使用降维技术来可视化数据。在这种情况下,我们可以使用主成分分析(PCA)或线性判别分析(LDA)等技术将数据降到二维或三维空间中,并在此空间中可视化数据。
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA# Perform PCA to reduce the dimensionality from 4D to 3D
pca = PCA(n_components=3)
x_train_pca = pca.fit_transform(x_train_std)# Create a 3D plot of the first three principal components
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')# Plot the three classes in different colors
for label, color in zip(np.unique(y_train), ['blue', 'red', 'green']):ax.scatter(x_train_pca[y_train==label, 0], x_train_pca[y_train==label, 1], x_train_pca[y_train==label, 2], c=color, label=label, alpha=0.8)ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')
ax.legend(loc='upper right')
ax.set_title('Iris Dataset - PCA')
plt.show()

四维可视化在二维或三维空间中是无法呈现的,但我们可以使用降维技术来可视化数据。在这种情况下,我们可以使用主成分分析(PCA)或线性判别分析(LDA)等技术将数据降到二维或三维空间中,并在此空间中可视化数据。
下面是代码效果图,展示如何使用PCA将四维数据降至三维,并在三维空间中可视化iris数据集:
在这里插入图片描述

将iris数据集的四维数据降至三维,并在三维空间中可视化了训练集。每个点代表一个数据样本,不同颜色代表不同的类别。我们可以看到,在三维空间中,有两个类别可以相对清晰地分开,而另一个类别则分布在两个主成分的中间。

我们要注意对于高维数据使用knn算法容易出现高维数据容易过拟合的情况,这是因为在高维空间中,数据点之间的距离变得很大,同时训练样本的数量相对于特征的数量很少,容易导致KNN算法无法很好地进行预测。

为了避免高维数据容易过拟合的情况,可以采取以下措施:

  1. 特征选择:选择有意义的特征进行训练,可以降低特征数量,避免过拟合。常用的特征选择方法有Filter方法、Wrapper方法和Embedded方法。

  2. 降维:可以通过主成分分析(PCA)等方法将高维数据映射到低维空间中,以减少特征数量,避免过拟合。

  3. 调整K值:KNN算法中的K值决定了邻居的数量,K值过大容易出现欠拟合,而K值过小容易出现过拟合。因此,可以通过交叉验证等方法来确定最佳的K值。

  4. 距离度量:KNN算法中的距离度量方法对结果影响较大,不同的距离度量方法会导致不同的预测结果。因此,可以尝试不同的距离度量方法,选择最优的方法。

  5. 数据增强:在数据量较少的情况下,可以通过数据增强的方法来增加训练样本,以提高模型的泛化能力。

希望通过这片文章能够进一步认识knn算法的原理及其应用。
今天是五一劳动节,在这里小马同学祝各位五一劳动节快乐!

相关文章:

机器学习与深度学习——通过knn算法分类鸢尾花数据集iris求出错误率并进行可视化

什么是knn算法? KNN算法是一种基于实例的机器学习算法,其全称为K-最近邻算法(K-Nearest Neighbors Algorithm)。它是一种简单但非常有效的分类和回归算法。 该算法的基本思想是:对于一个新的输入样本,通过…...

【MySQL】MySQL基础知识详解

文章目录 1. MySQL概述1.1 数据库相关概念1.1.1 数据库、数据库管理系统与SQL1.1.2 关系型数据库与数据模型 1.2 MySQL数据库1.2.1 MySQL的安装与配置1.2.2 MySQL服务的启动与停止1.2.3 连接MySQL服务端 2. SQL2.1 SQL简介2.2 DDL2.2.1 数据库操作2.2.2 表操作2.2.2.1 创建表2.…...

RabbitMQ日常使用小结

一、使用场景 削峰、解耦、异步。 基于AMQP(高级消息队列协议)协议来统一数据交互,通过channel(网络信道)传递信息。erlang语言开发,并发量12000,支持持久化,稳定性好,集群不支持动态扩展。 RabbitMQ的基本概念 二、组成及工作流…...

​​​​​​​博物馆文物馆藏环境空气质量无线监控系统方案

博物馆文物馆藏环境空气质量无线监控系统方案 博物馆无线环境测控系统 博物馆恒温恒湿消毒净化系统 现代化博物馆空气质量一体化3D可视化管控平台 博物馆温湿度在线监控系统 博物馆光照在线监控系统 博物馆二氧化碳在线监控系统 博物馆在线监控系统 博物馆紫外线在线监控…...

测试理论----Bug的严重程度(Severity)和优先级(Priority)的分类

【原文链接】测试理论----Bug的严重程度(Severity)和优先级(Priority)的分类 一、Bug的严重程度(Severity) Bug的Severity(严重程度)指的是一个Bug对软件系统功能影响的程度&#…...

斯坦福、Nautilus Chain等联合主办的 Hackathon 活动,现已接受报名

由 Stanford Blockchain Accelerator、Zebec Protocol、 Nautilus Chain、Rootz Lab 共同主办的黑客松活动,现已接受优秀项目提交参赛申请。 在加密行业发展早期,密码极客们就始终在对区块链世界基础设施,在发展方向的无限可能性进行探索。而…...

00后卷王,把我们这些老油条卷的辞职信都写好了........

现在的小年轻真的卷得过分了。 前段时间我们公司来了个00年的,工作没两年,跳槽到我们公司起薪18K,都快接近我了。 后来才知道人家是个卷王,从早干到晚就差搬张床到工位睡觉了。 最近和他聊了一次天,原来这位小老弟家…...

JavaEE(系列12) -- 常见锁策略

目录 1. 乐观锁和悲观锁 2. 轻量级锁与重量级锁 3. 自旋锁和挂起等待锁 4. 互斥锁和读写锁 5. 可重入锁与不可重入锁 6. 死锁 6.1 死锁的必要条件 6.2 如何避免死锁 7. 公平锁和非公平锁 8. Synchronized原理及加锁过程 8.1 Synchronized 小结 8.2 加锁工作过程 8.2.1 偏向锁…...

前端nginx接口跨域

前提:vue项目本地接口通过proxy都可使用,但是项目部署在服务器上后发现所有接口出现503如下状况 简而言之:页面部署在域名为https://aa.bb.cc.com/vehicle/#/下,但是我接口需访问的是https:// azz.qqv.com/permission/company/gro…...

【国产虚拟仪器】基于 ZYNQ 的电能质量系统高速数据采集系统设计

随着电网中非线性负荷用户的不断增加 , 电能质量问题日益严重 。 高精度数据采集系统能够为电能质 量分析提供准确的数据支持 , 是解决电能质量问题的关键依据 。 通过对比现有高速采集系统的设计方案 , 主 控电路多以 ARM 微控制器搭配…...

Java前缀和算法

一.什么是前缀和算法 通俗来讲,前缀和算法就是使用一个新数组来储存原数组中前n-1个元素的和(如果新数组的当前元素的下标为n,计算当前元素的值为原数组中从0到n-1下标数组元素的和),可能这样讲起来有点抽象&#xff0…...

pico 的两个双核相关函数的延时问题

pico高级API函数中, multicore_fifo_pop_timeout_us 和 multicore_fifo_push_timeout_us 的延时参数, 如修改为500微秒以上时,其延时似乎远远超过设定值,其反馈速度似乎被主核的交互所左右 ,而修改为200以下时&#x…...

Doxygen源码分析: QCString类依赖的qstr系列C函数浅析

2023-05-20 17:02:21 ChrisZZ imzhuofoxmailcom Hompage https://github.com/zchrissirhcz 文章目录 1. doxygen 版本2. QCString 类简介3. qstr 系列函数浅析qmemmove()qsnprintfqstrdup()qstrfree()qstrlen()qstrcpy()qstrncpy()qisempty()qstrcmp()qstrncmp()qisspace()qstr…...

华为OD机试之一种字符串压缩表示的解压(Java源码)

一种字符串压缩表示的解压 题目描述 有一种简易压缩算法:针对全部由小写英文字母组成的字符串,将其中连续超过两个相同字母的部分压缩为连续个数加该字母,其他部分保持原样不变。 例如:字符串“aaabbccccd”经过压缩成为字符串“…...

Microsoft Project Online部署方案

目录 一、前言 二、Microsoft Project Online简介 三、Microsoft Project Online的优势 1、云端部署 2、多设备支持...

飞浆AI studio人工智能课程学习(3)-在具体场景下优化Prompt

文章目录 在具体场景下优化Prompt营销场景办公效率场景日常生活场景海报背景图生成办公效率场景预设Prompt 生活场景中日常学习Prompt: 给写完的代码做文档 将优质Prompt模板化Prompt 1:Prompt 1:Prompt 2步骤文本过长而导致遗失信息的示例修改后 特殊示例 如何提升安全性主要目…...

企业工程行业管理系统源码-专业的工程管理软件-提供一站式服务

Java版工程项目管理系统 Spring CloudSpring BootMybatisVueElementUI前后端分离 功能清单如下: 首页 工作台:待办工作、消息通知、预警信息,点击可进入相应的列表 项目进度图表:选择(总体或单个)项目显示1…...

Ehcache 整合Spring 使用页面、对象缓存

Ehcache在很多项目中都出现过,用法也比较简单。一般的加些配置就可以了,而且Ehcache可以对页面、对象、数据进行缓存,同时支持集群/分布式缓存。如果整合Spring、Hibernate也非常的简单,Spring对Ehcache的支持也非常好。EHCache支…...

Spring Cloud中的服务路由与负载均衡

Spring Cloud中的服务路由与负载均衡 一、服务路由1. 服务发现2. 服务注册3. 服务消费4. 服务提供5. 服务路由实现 二、负载均衡1. 负载均衡的概念2. 负载均衡算法3. 负载均衡实现4. 负载均衡策略5. 使用Spring Cloud实现负载均衡 三、服务路由与负载均衡的集成1. 集成背景2. 集…...

rails routes的使用

Rails routes 是用于确定应该将请求发送到哪个控制器和操作的一种机制。在 Rails 应用程序中,可以通过定义路由来映射 URL 到控制器操作。可以使用 rails routes 命令查看当前应用程序中定义的所有路由。 以下是一些常见的用法: 查看所有路由&#xff…...

为什么ITK在医学影像分析中如此强大?深入解析其Pipeline设计原理

为什么ITK在医学影像分析中如此强大?深入解析其Pipeline设计原理 医学影像处理领域对计算效率和精度有着近乎苛刻的要求,而ITK(InsightToolkit)正是在这样的需求背景下成长为行业标杆的开源工具包。当我们需要处理CT扫描的数百层切…...

GHelper:华硕笔记本的轻量级控制中心 - 简单高效的硬件管理方案

GHelper:华硕笔记本的轻量级控制中心 - 简单高效的硬件管理方案 【免费下载链接】g-helper Lightweight, open-source control tool for ASUS laptops and ROG Ally. Manage performance modes, fans, GPU, battery, and RGB lighting across Zephyrus, Flow, TUF, …...

基于STM32LXXX的数字电位器(DS3502U+TR)驱动应用程序设计

一、简介: DS3502 是 Maxim Integrated(现为 ADI 旗下)推出的一款高压、非易失数字电位器。 二、主要技术特性: 参数 规格 抽头数 128 个(7 位分辨率) 端到端电阻 10kΩ 电阻精度 20% 接口类型 IC(标准/快速模式,最高 400kHz) 数字工作电压 2.5V ~ 5.5V 模拟工作电压…...

OpenClaw技能扩展实战:用Qwen3-14B镜像自动生成技术文档

OpenClaw技能扩展实战:用Qwen3-14B镜像自动生成技术文档 1. 为什么需要自动化文档生成 作为一个经常需要编写技术文档的开发者,我长期被两个问题困扰:一是文档写作耗时太长,二是维护成本太高。每次代码更新后,文档版…...

OpenClaw技能市场巡礼:Top10Qwen3.5-9B增强插件测评

OpenClaw技能市场巡礼:Top10 Qwen3.5-9B增强插件测评 1. 为什么需要关注OpenClaw技能市场? 第一次接触OpenClaw时,我被它"AI操控电脑"的核心能力震撼,但真正让我持续使用的原因是它的技能市场(ClawHub&…...

AGI通用人工智能:离我们还有多远

AGI通用人工智能:离我们还有多远📝 本章学习目标:通过本章学习,你将全面掌握"AGI通用人工智能:离我们还有多远"这一核心主题,建立系统性认知。一、引言:为什么这个话题如此重要 在人工…...

穿透式监管是什么?终于有人把穿透式监管落地讲明白了!

最近,各位老板有没有发现各种审计、检查多起来了?国资委、集团总部的发文一个接一个,问题也越来越细致。最近大家都被穿透式监管这个词弄得有点紧张,害怕自己的企业那天也被点名。其实,穿透式监管对企业来说&#xff0…...

如何彻底解决微信QQ消息撤回难题?3步打造终极防撤回方案

如何彻底解决微信QQ消息撤回难题?3步打造终极防撤回方案 【免费下载链接】RevokeMsgPatcher :trollface: A hex editor for WeChat/QQ/TIM - PC版微信/QQ/TIM防撤回补丁(我已经看到了,撤回也没用了) 项目地址: https://gitcode.…...

「码动四季·开源同行」go语言:如何使用 ELK 进行日志采集以及统一处理?

在前面的一系列文章中,我们介绍了微服务各个组件的相关实践,从本文开始我们将会介绍微服务日常开发的一些"利器”,这些工具会帮助我们构建更加健壮的微服务系统,并帮助排查解决微服务系统中的问题与性能瓶颈等。ELK 技术栈本…...

企业级消息保留技术实现:3大核心机制深度解析与完整部署方案

企业级消息保留技术实现:3大核心机制深度解析与完整部署方案 【免费下载链接】RevokeMsgPatcher :trollface: A hex editor for WeChat/QQ/TIM - PC版微信/QQ/TIM防撤回补丁(我已经看到了,撤回也没用了) 项目地址: https://gitc…...