当前位置: 首页 > news >正文

机器学习与深度学习——通过knn算法分类鸢尾花数据集iris求出错误率并进行可视化

什么是knn算法?

KNN算法是一种基于实例的机器学习算法,其全称为K-最近邻算法(K-Nearest Neighbors Algorithm)。它是一种简单但非常有效的分类和回归算法。
该算法的基本思想是:对于一个新的输入样本,通过计算它与训练集中所有样本的距离,找到与它距离最近的K个训练集样本,然后基于这K个样本的类别信息来进行分类或回归预测。KNN算法中的“K”代表了在预测时使用的邻居数,通常需要手动设置。
KNN算法的主要优点是简单、易于实现,并且在某些情况下可以获得很好的分类或回归精度。但是,它也有一些缺点,如需要存储所有训练集样本、计算距离的开销较大、对于高维数据容易过拟合等。

KNN算法常用于分类问题,如文本分类、图像分类等,以及回归问题,如预测房价等。

我们这次学习机器学习的knn算法分别对前二维数据和前四维数据进行训练和可视化。

两个目标:

1、通过knn算法对iris数据集前两个维度的数据进行模型训练并求出错误率,最后进行可视化展示数据区域划分。
2、通过knn算法对iris数据集总共四个维度的数据进行模型训练并求出错误率,并对前四维数据进行可视化。

基本思路:

1、先载入iris数据集 Load Iris data
2、分离训练集和设置测试集split train and test sets
3、对数据进行标准化处理Normalize the data
4、使用knn模型进行训练Train using KNN
5、然后进行可视化处理Visualization
6、最后通过绘图决策平面plot decision plane

1、通过knn算法对iris数据集前两个维度的数据进行模型训练并求出错误率,最后进行可视化展示数据区域划分:

from sklearn import datasets
import numpy as np
### Load Iris data
iris = datasets.load_iris()
x = iris.data[:,:2]#前2个维度
# x = iris.data
y = iris.target
print("class labels: ", np.unique(y))
x.shape
y.shape
### split train and test sets
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y)
x_train.shape
print("Labels count in y:", np.bincount(y))
print("Labels count in y_train:", np.bincount(y_train))
print("Labels count in y_test:", np.bincount(y_test))
### Normalize the data
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
sc.fit(x_train)
x_train_std = sc.transform(x_train)
x_test_std = sc.transform(x_test)
print("TrainSets Orig mean:{}, std mean:{}".format(np.mean(x_train,axis=0), np.mean(x_train_std,axis=0)))
print("TrainSets Orig std:{}, std std:{}".format(np.std(x_train,axis=0), np.std(x_train_std,axis=0)))
print("TestSets Orig mean:{}, std mean:{}".format(np.mean(x_test,axis=0), np.mean(x_test_std,axis=0)))
print("TestSets Orig std:{}, std std:{}".format(np.std(x_test,axis=0), np.std(x_test_std,axis=0)))
### Train using KNN
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5, p=2, metric='minkowski')
knn.fit(x_train_std, y_train)
pred_test=knn.predict(x_test_std)
err_num = (pred_test != y_test).sum()
rate = err_num/y_test.size
print("Misclassfication num: {}\nError rate: {}".format(err_num, rate))#计算错误率
### Visualization
x_combined_std = np.vstack((x_train_std, x_test_std))
y_combined = np.hstack((y_train, y_test))
plot_decision_regions(x_combined_std, y_combined,
classifier=knn, test_idx=range(105,150))
plt.xlabel('petal length [standardized]')
plt.ylabel('petal width [standardized]')
plt.legend(loc='upper left')
plt.tight_layout()
plt.show()
#### plot decision plane
x_combined_std = np.vstack((x_train_std, x_test_std))
y_combined = np.hstack((y_train, y_test))
plot_decision_regions(x_combined_std, y_combined,
classifier=knn, test_idx=range(105,150))
plt.xlabel('petal length [standardized]')
plt.ylabel('petal width [standardized]')
plt.legend(loc='upper left')
plt.tight_layout()
plt.show()

代码及其可视化效果截图:

在这里插入图片描述

2、通过knn算法对iris数据集总共四个维度的数据进行模型训练并求出错误率并进行可视化:

from sklearn import datasets
import numpy as np
iris = datasets.load_iris()
x = iris.data  #4个维度
# x = iris.data
y = iris.target
print("class labels: ", np.unique(y))
x.shape
y.shape
### split train and test sets
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y)
x_train.shape
print("Labels count in y:", np.bincount(y))
print("Labels count in y_train:", np.bincount(y_train))
print("Labels count in y_test:", np.bincount(y_test))
### Normalize the data
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
sc.fit(x_train)
x_train_std = sc.transform(x_train)
x_test_std = sc.transform(x_test)
print("TrainSets Orig mean:{}, std mean:{}".format(np.mean(x_train,axis=0), np.mean(x_train_std,axis=0)))
print("TrainSets Orig std:{}, std std:{}".format(np.std(x_train,axis=0), np.std(x_train_std,axis=0)))
print("TestSets Orig mean:{}, std mean:{}".format(np.mean(x_test,axis=0), np.mean(x_test_std,axis=0)))
print("TestSets Orig std:{}, std std:{}".format(np.std(x_test,axis=0), np.std(x_test_std,axis=0)))
### Train using KNN
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5, p=2, metric='minkowski')
knn.fit(x_train_std, y_train)
pred_test=knn.predict(x_test_std)
err_num = (pred_test != y_test).sum()
rate = err_num/y_test.size
print("Misclassfication num: {}\nError rate: {}".format(err_num, rate))#计算错误率
#四维可视化在二维或三维空间中是无法呈现的,但我们可以使用降维技术来可视化数据。在这种情况下,我们可以使用主成分分析(PCA)或线性判别分析(LDA)等技术将数据降到二维或三维空间中,并在此空间中可视化数据。
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA# Perform PCA to reduce the dimensionality from 4D to 3D
pca = PCA(n_components=3)
x_train_pca = pca.fit_transform(x_train_std)# Create a 3D plot of the first three principal components
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')# Plot the three classes in different colors
for label, color in zip(np.unique(y_train), ['blue', 'red', 'green']):ax.scatter(x_train_pca[y_train==label, 0], x_train_pca[y_train==label, 1], x_train_pca[y_train==label, 2], c=color, label=label, alpha=0.8)ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_zlabel('PC3')
ax.legend(loc='upper right')
ax.set_title('Iris Dataset - PCA')
plt.show()

四维可视化在二维或三维空间中是无法呈现的,但我们可以使用降维技术来可视化数据。在这种情况下,我们可以使用主成分分析(PCA)或线性判别分析(LDA)等技术将数据降到二维或三维空间中,并在此空间中可视化数据。
下面是代码效果图,展示如何使用PCA将四维数据降至三维,并在三维空间中可视化iris数据集:
在这里插入图片描述

将iris数据集的四维数据降至三维,并在三维空间中可视化了训练集。每个点代表一个数据样本,不同颜色代表不同的类别。我们可以看到,在三维空间中,有两个类别可以相对清晰地分开,而另一个类别则分布在两个主成分的中间。

我们要注意对于高维数据使用knn算法容易出现高维数据容易过拟合的情况,这是因为在高维空间中,数据点之间的距离变得很大,同时训练样本的数量相对于特征的数量很少,容易导致KNN算法无法很好地进行预测。

为了避免高维数据容易过拟合的情况,可以采取以下措施:

  1. 特征选择:选择有意义的特征进行训练,可以降低特征数量,避免过拟合。常用的特征选择方法有Filter方法、Wrapper方法和Embedded方法。

  2. 降维:可以通过主成分分析(PCA)等方法将高维数据映射到低维空间中,以减少特征数量,避免过拟合。

  3. 调整K值:KNN算法中的K值决定了邻居的数量,K值过大容易出现欠拟合,而K值过小容易出现过拟合。因此,可以通过交叉验证等方法来确定最佳的K值。

  4. 距离度量:KNN算法中的距离度量方法对结果影响较大,不同的距离度量方法会导致不同的预测结果。因此,可以尝试不同的距离度量方法,选择最优的方法。

  5. 数据增强:在数据量较少的情况下,可以通过数据增强的方法来增加训练样本,以提高模型的泛化能力。

希望通过这片文章能够进一步认识knn算法的原理及其应用。
今天是五一劳动节,在这里小马同学祝各位五一劳动节快乐!

相关文章:

机器学习与深度学习——通过knn算法分类鸢尾花数据集iris求出错误率并进行可视化

什么是knn算法? KNN算法是一种基于实例的机器学习算法,其全称为K-最近邻算法(K-Nearest Neighbors Algorithm)。它是一种简单但非常有效的分类和回归算法。 该算法的基本思想是:对于一个新的输入样本,通过…...

【MySQL】MySQL基础知识详解

文章目录 1. MySQL概述1.1 数据库相关概念1.1.1 数据库、数据库管理系统与SQL1.1.2 关系型数据库与数据模型 1.2 MySQL数据库1.2.1 MySQL的安装与配置1.2.2 MySQL服务的启动与停止1.2.3 连接MySQL服务端 2. SQL2.1 SQL简介2.2 DDL2.2.1 数据库操作2.2.2 表操作2.2.2.1 创建表2.…...

RabbitMQ日常使用小结

一、使用场景 削峰、解耦、异步。 基于AMQP(高级消息队列协议)协议来统一数据交互,通过channel(网络信道)传递信息。erlang语言开发,并发量12000,支持持久化,稳定性好,集群不支持动态扩展。 RabbitMQ的基本概念 二、组成及工作流…...

​​​​​​​博物馆文物馆藏环境空气质量无线监控系统方案

博物馆文物馆藏环境空气质量无线监控系统方案 博物馆无线环境测控系统 博物馆恒温恒湿消毒净化系统 现代化博物馆空气质量一体化3D可视化管控平台 博物馆温湿度在线监控系统 博物馆光照在线监控系统 博物馆二氧化碳在线监控系统 博物馆在线监控系统 博物馆紫外线在线监控…...

测试理论----Bug的严重程度(Severity)和优先级(Priority)的分类

【原文链接】测试理论----Bug的严重程度(Severity)和优先级(Priority)的分类 一、Bug的严重程度(Severity) Bug的Severity(严重程度)指的是一个Bug对软件系统功能影响的程度&#…...

斯坦福、Nautilus Chain等联合主办的 Hackathon 活动,现已接受报名

由 Stanford Blockchain Accelerator、Zebec Protocol、 Nautilus Chain、Rootz Lab 共同主办的黑客松活动,现已接受优秀项目提交参赛申请。 在加密行业发展早期,密码极客们就始终在对区块链世界基础设施,在发展方向的无限可能性进行探索。而…...

00后卷王,把我们这些老油条卷的辞职信都写好了........

现在的小年轻真的卷得过分了。 前段时间我们公司来了个00年的,工作没两年,跳槽到我们公司起薪18K,都快接近我了。 后来才知道人家是个卷王,从早干到晚就差搬张床到工位睡觉了。 最近和他聊了一次天,原来这位小老弟家…...

JavaEE(系列12) -- 常见锁策略

目录 1. 乐观锁和悲观锁 2. 轻量级锁与重量级锁 3. 自旋锁和挂起等待锁 4. 互斥锁和读写锁 5. 可重入锁与不可重入锁 6. 死锁 6.1 死锁的必要条件 6.2 如何避免死锁 7. 公平锁和非公平锁 8. Synchronized原理及加锁过程 8.1 Synchronized 小结 8.2 加锁工作过程 8.2.1 偏向锁…...

前端nginx接口跨域

前提:vue项目本地接口通过proxy都可使用,但是项目部署在服务器上后发现所有接口出现503如下状况 简而言之:页面部署在域名为https://aa.bb.cc.com/vehicle/#/下,但是我接口需访问的是https:// azz.qqv.com/permission/company/gro…...

【国产虚拟仪器】基于 ZYNQ 的电能质量系统高速数据采集系统设计

随着电网中非线性负荷用户的不断增加 , 电能质量问题日益严重 。 高精度数据采集系统能够为电能质 量分析提供准确的数据支持 , 是解决电能质量问题的关键依据 。 通过对比现有高速采集系统的设计方案 , 主 控电路多以 ARM 微控制器搭配…...

Java前缀和算法

一.什么是前缀和算法 通俗来讲,前缀和算法就是使用一个新数组来储存原数组中前n-1个元素的和(如果新数组的当前元素的下标为n,计算当前元素的值为原数组中从0到n-1下标数组元素的和),可能这样讲起来有点抽象&#xff0…...

pico 的两个双核相关函数的延时问题

pico高级API函数中, multicore_fifo_pop_timeout_us 和 multicore_fifo_push_timeout_us 的延时参数, 如修改为500微秒以上时,其延时似乎远远超过设定值,其反馈速度似乎被主核的交互所左右 ,而修改为200以下时&#x…...

Doxygen源码分析: QCString类依赖的qstr系列C函数浅析

2023-05-20 17:02:21 ChrisZZ imzhuofoxmailcom Hompage https://github.com/zchrissirhcz 文章目录 1. doxygen 版本2. QCString 类简介3. qstr 系列函数浅析qmemmove()qsnprintfqstrdup()qstrfree()qstrlen()qstrcpy()qstrncpy()qisempty()qstrcmp()qstrncmp()qisspace()qstr…...

华为OD机试之一种字符串压缩表示的解压(Java源码)

一种字符串压缩表示的解压 题目描述 有一种简易压缩算法:针对全部由小写英文字母组成的字符串,将其中连续超过两个相同字母的部分压缩为连续个数加该字母,其他部分保持原样不变。 例如:字符串“aaabbccccd”经过压缩成为字符串“…...

Microsoft Project Online部署方案

目录 一、前言 二、Microsoft Project Online简介 三、Microsoft Project Online的优势 1、云端部署 2、多设备支持...

飞浆AI studio人工智能课程学习(3)-在具体场景下优化Prompt

文章目录 在具体场景下优化Prompt营销场景办公效率场景日常生活场景海报背景图生成办公效率场景预设Prompt 生活场景中日常学习Prompt: 给写完的代码做文档 将优质Prompt模板化Prompt 1:Prompt 1:Prompt 2步骤文本过长而导致遗失信息的示例修改后 特殊示例 如何提升安全性主要目…...

企业工程行业管理系统源码-专业的工程管理软件-提供一站式服务

Java版工程项目管理系统 Spring CloudSpring BootMybatisVueElementUI前后端分离 功能清单如下: 首页 工作台:待办工作、消息通知、预警信息,点击可进入相应的列表 项目进度图表:选择(总体或单个)项目显示1…...

Ehcache 整合Spring 使用页面、对象缓存

Ehcache在很多项目中都出现过,用法也比较简单。一般的加些配置就可以了,而且Ehcache可以对页面、对象、数据进行缓存,同时支持集群/分布式缓存。如果整合Spring、Hibernate也非常的简单,Spring对Ehcache的支持也非常好。EHCache支…...

Spring Cloud中的服务路由与负载均衡

Spring Cloud中的服务路由与负载均衡 一、服务路由1. 服务发现2. 服务注册3. 服务消费4. 服务提供5. 服务路由实现 二、负载均衡1. 负载均衡的概念2. 负载均衡算法3. 负载均衡实现4. 负载均衡策略5. 使用Spring Cloud实现负载均衡 三、服务路由与负载均衡的集成1. 集成背景2. 集…...

rails routes的使用

Rails routes 是用于确定应该将请求发送到哪个控制器和操作的一种机制。在 Rails 应用程序中,可以通过定义路由来映射 URL 到控制器操作。可以使用 rails routes 命令查看当前应用程序中定义的所有路由。 以下是一些常见的用法: 查看所有路由&#xff…...

深入剖析AI大模型:大模型时代的 Prompt 工程全解析

今天聊的内容,我认为是AI开发里面非常重要的内容。它在AI开发里无处不在,当你对 AI 助手说 "用李白的风格写一首关于人工智能的诗",或者让翻译模型 "将这段合同翻译成商务日语" 时,输入的这句话就是 Prompt。…...

从零实现富文本编辑器#5-编辑器选区模型的状态结构表达

先前我们总结了浏览器选区模型的交互策略,并且实现了基本的选区操作,还调研了自绘选区的实现。那么相对的,我们还需要设计编辑器的选区表达,也可以称为模型选区。编辑器中应用变更时的操作范围,就是以模型选区为基准来…...

3.3.1_1 检错编码(奇偶校验码)

从这节课开始,我们会探讨数据链路层的差错控制功能,差错控制功能的主要目标是要发现并且解决一个帧内部的位错误,我们需要使用特殊的编码技术去发现帧内部的位错误,当我们发现位错误之后,通常来说有两种解决方案。第一…...

使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装

以下是基于 vant-ui&#xff08;适配 Vue2 版本 &#xff09;实现截图中照片上传预览、删除功能&#xff0c;并封装成可复用组件的完整代码&#xff0c;包含样式和逻辑实现&#xff0c;可直接在 Vue2 项目中使用&#xff1a; 1. 封装的图片上传组件 ImageUploader.vue <te…...

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…...

LLM基础1_语言模型如何处理文本

基于GitHub项目&#xff1a;https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken&#xff1a;OpenAI开发的专业"分词器" torch&#xff1a;Facebook开发的强力计算引擎&#xff0c;相当于超级计算器 理解词嵌入&#xff1a;给词语画"…...

Axios请求超时重发机制

Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式&#xff1a; 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...

【Java学习笔记】BigInteger 和 BigDecimal 类

BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点&#xff1a;传参类型必须是类对象 一、BigInteger 1. 作用&#xff1a;适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...

基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解

JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用&#xff0c;结合SQLite数据库实现联系人管理功能&#xff0c;并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能&#xff0c;同时可以最小化到系统…...

JVM虚拟机:内存结构、垃圾回收、性能优化

1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...