当前位置: 首页 > news >正文

【Python】 剪辑法欠采样 CNN压缩近邻法欠采样

借鉴:关于K近邻(KNN),看这一篇就够了!算法原理,kd树,球树,KNN解决样本不平衡,剪辑法,压缩近邻法 - 知乎

但是不要看他里面的代码,因为作者把代码里的一些符号故意颠倒了 ,比如“==”改成“!=”,还有乱加“~”,看明白逻辑才能给他改过来

一、剪辑法

        当训练集数据中存在一部分不同类别数据的重叠时(在一部分程度上说明这部分数据的类别比较模糊),这部分数据会对模型造成一定的过拟合,那么一个简单的想法就是将这部分数据直接剔除掉即可,也就是剪辑法。

        剪辑法将训练集 D 随机分成两个部分,一部分作为新的训练集 Dtrain,一部分作为测试集 Dtest,然后基于 Dtrain,使用 KNN 的方法对 Dtest 进行分类,并将其中分类错误的样本从整体训练集 D 中剔除掉,得到 Dnew。

        由于对训练集 D 的划分是随机划分,难以保证数据重叠部分的样本在第一次剪辑时就被剔除,因此在得到 Dnew 后,可以对 Dnew 继续进行上述操作数次,这样可以得到一个比较清爽的类别分界。

        效果如下图:

        附上可直接运行的代码:

from sklearn import datasets
import matplotlib.pyplot as pyplot
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
import numpy as np
from collections import Counter
from numpy import where# make_classification用于手动构造数据
# 1000个样本,分成4类
X, y = datasets.make_classification(n_samples=1000, n_features=2,n_informative=2, n_redundant=0, n_repeated=0,n_classes=4, n_clusters_per_class=1)# # # 画出二维散点图
# for label, _ in counter.items():
# 	row_ix = where(y == label)[0]
# 	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
# pyplot.legend()
# pyplot.show()# 剪辑10次
for i in range(10):x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.5)k = 5KNN_clf = KNN(n_neighbors=k)KNN_clf.fit(x_train, y_train)  # 用训练集训练KNNy_predict = KNN_clf.predict(x_test)  # 用测试集测试cond = y_predict == y_testx_test = x_test[cond]  # 把预测错误的从整体数据集中剔除掉y_test = y_test[cond]  # 把预测错误的从整体数据集中剔除掉X = np.vstack([x_train, x_test])  # 为下一次循环做准备(剔除掉本轮预测错误的y = np.hstack([y_train, y_test])  # 为下一次循环做准备(剔除掉本轮预测错误的# summarize the new class distribution
counter = Counter(y)
print(counter)# 画出二维散点图
for label, _ in counter.items():row_ix = where(y == label)[0]pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

        以上使用了k=20的参数进行剪辑的结果,循环了10次,一般而言,k越大,被抛弃的样本会越多,因为被分类的错误的概率更大。

二、CNN压缩近邻法欠采样

        

        压缩近邻法的想法是认为同一类型的样本大量集中在类簇的中心,而这些集中在中心的样本对分类没有起到太大的作用,因此可以舍弃掉这些样本。

        其做法是将训练集随机分为两个部分,第一个部分为 store,占所有样本的 10% 左右,第二个部分为 grabbag,占所有样本的 90% 左右,然后将 store 作为训练集训练 KNN 模型,grabbag 作为测试集,将分类错误的样本从 grabbag 中移动到 store 里,然后继续用增加了样本的 store 和减少了样本的 grabbag 再次训练和测试 KNN 模型,直到 grabbag 中所有样本被分类正确,或者 grabbag 中样本数为0。

        在压缩结束之后,store 中存储的是初始化时随机选择的 10% 左右的样本,以及在之后每一次循环中被分类错误的样本,这些被分类错误的样本集中在类簇的边缘,认为是对分类作用较大的样本。

        CNN欠采样已经有相应的Python实现库了,相应的方法是CondensedNearestNeighbour(),下面是可直接运行的代码。

# Undersample and plot imbalanced dataset with the Condensed Nearest Neighbor Rule
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.under_sampling import CondensedNearestNeighbour
from matplotlib import pyplot
from numpy import where# make_classification方法用于生成分类任务的人造数据集
# X是数据,几维都可以,n_features=4表示4维
# y用0/1表示类别,weights调整0和1的占比
X, y = make_classification(n_samples=500, n_classes=2, n_features=3, n_redundant=0,# n_clusters_per_class表示每个类别多少簇  # flip_y噪声,增加分类难度n_clusters_per_class=2, weights=[0.5], flip_y=0, random_state=1)# summarize class distribution
counter = Counter(y)  # {0: 990, 1: 10} counter是一个字典,value存储类别,key存储类别个数
print(counter)# ==================CNN有直接可以调用的包  n_neighbors设置k值,k值越小越省时间,就设置为1吧
undersample = CondensedNearestNeighbour(n_neighbors=1)
# transform the dataset
X, y = undersample.fit_resample(X, y)# summarize the new class distribution
counter = Counter(y)
print(counter)# scatter plot of examples by class label
for label, _ in counter.items():row_ix = where(y == label)[0]pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

        但是我觉得这个CondensedNearestNeighbour()方法的可操作性太低,所以没用这个方法,而是根据CNN的原理(CNN底层是训练KNN)去写的

from sklearn import datasets
import matplotlib.pyplot as pyplot
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
import numpy as np
from collections import Counter
from numpy import where# make_classification用于手动构造数据
# 1000个样本,分成4类
X, y = datasets.make_classification(n_samples=1000, n_features=2,n_informative=2, n_redundant=0, n_repeated=0,n_classes=4, n_clusters_per_class=1, random_state=1)
counter = Counter(y)
# 画出二维散点图
for label, _ in counter.items():row_ix = where(y == label)[0]pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()# 10%作为训练集,90%作为测试集
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.9)while True:k = 1KNN_clf = KNN(n_neighbors=k)KNN_clf.fit(x_train, y_train)y_predict = KNN_clf.predict(x_test)cond = y_predict == y_test  # cond记录分类的对与错,分类错是False,正确是True# 都分类正确,退出if  cond.all():print('所有测试集都分类正确,CNN正常结束')breakx_train = np.vstack([x_train, x_test[~cond]])  # 把分类错误(cond的值是False)的移动到训练集里y_train = np.hstack([y_train, y_test[~cond]])x_test = x_test[cond]  # 把分类对的继续作为下一轮的测试集y_test = y_test[cond]if len(x_test) == 0:print("所有样本都能做到分类错误,也就是结果集=原始数据集,一般不会出现这种情况")break# summarize the new class distribution
counter = Counter(y_train)
print(counter)# 画出二维散点图
for label, _ in counter.items():row_ix = where(y_train == label)[0]pyplot.scatter(x_train[row_ix, 0], x_train[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

2.1 改进版——指定压缩后样本大小的CNN

在如下代码中,用sampleNum指定全体样本数量,用endNum指定压缩后样本数量

from sklearn import datasets
import matplotlib.pyplot as pyplot
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier as KNN
import numpy as np
from collections import Counter
from numpy import wheresampleNum = 1000
endNum = 500
k = 1  # KNN算法的K值
# make_classification用于手动构造数据
# 1000个样本,分成4类
X, y = datasets.make_classification(n_samples=sampleNum, n_features=2,n_informative=2, n_redundant=0, n_repeated=0,n_classes=4, n_clusters_per_class=1, random_state=1)
# counter = Counter(y)
# # 画出二维散点图
# for label, _ in counter.items():
# 	row_ix = where(y == label)[0]
# 	pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(label))
# pyplot.legend()
# pyplot.show()# 10%作为训练集,90%作为测试集
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.9)
# print(x_train.shape[0])  # 100nowNum = x_train.shape[0]  # 用来控制 训练集/筛选后的样本数 满足resultNum就停下, 初始有x_train这么多个while True:KNN_clf = KNN(n_neighbors=k)KNN_clf.fit(x_train, y_train)y_predict = KNN_clf.predict(x_test)cond = y_predict == y_test  # cond记录分类的对与错,分类错是False,正确是True# 都分类正确,退出if cond.all():print('所有测试集都分类正确,CNN自动结束,但是结果集没凑够呢!')break# 如果结果集数量不够要求的endNum,继续下一轮if nowNum+y_test[~cond].shape[0] < endNum:nowNum = nowNum+y_test[~cond].shape[0]print("目前结果集数量:", nowNum)x_train = np.vstack([x_train, x_test[~cond]])  # 把分类错误(cond的值是False)的移动到训练集里y_train = np.hstack([y_train, y_test[~cond]])x_test = x_test[cond]  # 把分类对的继续作为下一轮的测试集y_test = y_test[cond]# 如果结果集数量超过endNum,我们只要测试集里分类错误的前endNum-nowNum个else:# 记录前endNum-nowNum个的位置(截取位置condCut = 0  # 记录截取位置for i in range(cond.shape[0]):if not cond[i]:nowNum = nowNum + 1if nowNum == endNum:condCut = i  # 在cond[condCut]处刚好是我们要的第endNum个结果集样本break# 把cond[condCut]后面的都设置成Truecond[condCut+1:] = Truex_train = np.vstack([x_train, x_test[~cond]])  # 把分类错误(cond的值是False)的移动到训练集里y_train = np.hstack([y_train, y_test[~cond]])print("结果集的数量为", x_train.shape[0], "满足endNum=", endNum)breakif len(x_test) == 0:print("所有样本都能做到分类错误,也就是结果集=原始数据集,一般不会出现这种情况")break# summarize the new class distribution
counter = Counter(y_train)
print(counter)# 画出二维散点图
for label, _ in counter.items():row_ix = where(y_train == label)[0]pyplot.scatter(x_train[row_ix, 0], x_train[row_ix, 1], label=str(label))
pyplot.legend()
pyplot.show()

相关文章:

【Python】 剪辑法欠采样 CNN压缩近邻法欠采样

借鉴&#xff1a;关于K近邻&#xff08;KNN&#xff09;&#xff0c;看这一篇就够了&#xff01;算法原理&#xff0c;kd树&#xff0c;球树&#xff0c;KNN解决样本不平衡&#xff0c;剪辑法&#xff0c;压缩近邻法 - 知乎 但是不要看他里面的代码&#xff0c;因为作者把代码…...

springmvc+ssm+springboot房屋中介服务平台的设计与实现 i174z

本论文拟采用计算机技术设计并开发的房屋中介服务平台&#xff0c;主要是为用户提供服务。使得用户可以在系统上查看房屋出租、房屋出售、房屋求购、房屋求租&#xff0c;管理员对信息进行统一管理&#xff0c;与此同时可以筛选出符合的信息&#xff0c;给笔者提供更符合实际的…...

挑战30天学完Python:Day19 文件处理

&#x1f4d8; Day 19 &#x1f389; 本系列为Python基础学习&#xff0c;原稿来源于 30-Days-Of-Python 英文项目&#xff0c;大奇主要是对其本地化翻译、逐条验证和补充&#xff0c;想通过30天完成正儿八经的系统化实践。此系列适合零基础同学&#xff0c;或仅了解Python一点…...

Spring Boot application.properties和application.yml文件的配置

在Spring Boot中&#xff0c;application.properties 和 application.yml 文件用于配置应用程序的各个方面&#xff0c;如服务器端口、数据库连接、日志级别等。这两个文件是Spring Boot的配置文件&#xff0c;位于 src/main/resources 目录下。 application.properties 示例 …...

Unity单元测试

Unity单元测试是一个专门用于嵌入式单元测试的库, 现在简单讲下移植以及代码结构. 源码地址: GitHub - ThrowTheSwitch/Unity: Simple Unit Testing for C 1.我们只需要移植三个文件即可: unity.c, unity.h, unity_internals.h 2.然后添加需要测试的函数. 3.在main.c中添加…...

Spring Bean 的生命周期了解么?

Spring Bean 的生命周期基本流程 一个Spring的Bean从出生到销毁的全过程就是他的整个生命周期, 整个生命周期可以大致分为3个大的阶段 : 创建 使用 销毁 还可以分为5个小步骤 : 实例化(Bean的创建) , 初始化赋值, 注册Destruction回调 , Bean的正常使用 以及 Bean的销毁 …...

.ryabina勒索病毒数据怎么处理|数据解密恢复

导言&#xff1a; 随着网络安全威胁的不断增加&#xff0c;勒索软件已成为严重的威胁之一&#xff0c;.ryabina勒索病毒是其中之一。本文将介绍.ryabina勒索病毒的特点、数据恢复方法和预防措施&#xff0c;以帮助用户更好地应对这一威胁。当面对被勒索病毒攻击导致的数据文件…...

上网行为监控软件能够看到聊天内容吗

随着信息技术的不断发展&#xff0c;上网行为监控软件在企业网络安全管理中扮演着越来越重要的角色。 这类软件主要用于监控员工的上网行为&#xff0c;以确保工作效率和网络安全。 而在这其中&#xff0c;域智盾软件作为一款知名的上网行为监控软件&#xff0c;其功能和使用…...

Java知识点一

hello&#xff0c;大家好&#xff01;我们今天开启Java语言的学习之路&#xff0c;与C语言的学习内容有些许异同&#xff0c;今天我们来简单了解一下Java的基础知识。 一、数据类型 分两种&#xff1a;基本数据类型 引用数据类型 &#xff08;1&#xff09;整型 八种基本数…...

Django学习笔记-forms使用

1.创建forms.py文件,导入包 from django import forms from django.forms import fields from django.forms import widgets2. 创建EmployeeForm,继承forms.Form 3.创建testform.html文件 4.urls.py添加路由 5.views中导入forms 创建testform,编写代码 1).如果请求方式为GET,…...

BM100 设计LRU缓存结构(java实现)

一、题目 设计LRU(最近最少使用)缓存结构&#xff0c;该结构在构造时确定大小&#xff0c;假设大小为 capacity &#xff0c;操作次数是 n &#xff0c;并有如下功能: Solution(int capacity) 以正整数作为容量 capacity 初始化 LRU 缓存get(key)&#xff1a;如果关键字 key …...

论文阅读——ONE-PEACE

ONE-PEACE: EXPLORING ONE GENERAL REPRESENTATION MODEL TOWARD UNLIMITED MODALITIES 适应不同模态并且支持多模态交互。 预训练任务不仅能提取单模态信息&#xff0c;还能模态间对齐。 预训练任务通用且直接&#xff0c;使得他们可以应用到不同模态。 各个模态独立编码&am…...

围剿尚未终止 库迪深陷瑞幸9.9阳谋

文|智能相对论 作者|霖霖 总能被“累了困了”的打工人优先pick的咖啡&#xff0c;刚复工就顺利站上话题C位。 #瑞幸9.9元一杯活动缩水#的话题才爬上新浪微博热搜&#xff0c;“库迪咖啡河北分公司运营总监带头坑害河北联营商”的实名举报帖就出现在了小红书&#xff0c;一时…...

5G网络(接入网+承载网+核心网)

5G网络&#xff08;接入网承载网核心网&#xff09; 一、5G网络全网架构图 这张图分为左右两部分&#xff0c;右边为无线侧网络架构&#xff0c;左边为固定侧网络架构。 无线侧&#xff1a;手机或者集团客户通过基站接入到无线接入网&#xff0c;在接入网侧可以通过RTN或者IP…...

学习Markdown

https://shadows.brumm.af 欢迎使用Markdown编辑器 你好&#xff01; 这是你第一次使用 Markdown编辑器 所展示的欢迎页。如果你想学习如何使用Markdown编辑器, 可以仔细阅读这篇文章&#xff0c;了解一下Markdown的基本语法知识。 新的改变 我们对Markdown编辑器进行了一些…...

MySQL知识点总结(五)——锁

MySQL知识点总结&#xff08;五&#xff09;——锁 锁分类表锁 & 行锁如何添加表锁&#xff1f;如何添加行锁&#xff1f; 读锁 & 写锁行锁 & 间隙锁&#xff08;gap lock&#xff09;& 临键锁&#xff08;next-key lock&#xff09; 加锁机制分析可重复读隔离…...

IDEA 2023.2 配置 JavaWeb 工程

目录 1 不使用 Maven 创建 JavaWeb 工程 1.1 新建一个工程 1.2 配置 Tomcat 1.3 配置模块 Web 2 使用 Maven 配置 JavaWeb 工程 2.1 新建一个 Maven 工程 2.2 配置 Tomcat &#x1f4a5;提示&#xff1a;IDEA 只有专业版才能配置 JavaWeb 工程&#xff0c;若是社区版&am…...

软考40-上午题-【数据库】-关系代数运算2-专门的集合运算

一、专门的集合运算 1、投影 示例&#xff1a; 可以用属性名进行投影&#xff0c;也可以用列的序号进行投影。 2、选择 例题 1、笛卡尔积 2、投影 3、选择 3、连接 第一步都要算&#xff1a;笛卡尔积。 3-1、θ连接 示例&#xff1a; 3-2、等值连接 示例&#xff1a; 3-3、自…...

RHEL9安装Python2.7

RHEL9作为2022年5月新推出的版本&#xff0c;较RHEL8有了很多地方的改进&#xff0c;而且自带很多包&#xff0c;功能非常强大&#xff0c;稳定性和流畅度也较先前版本有了很大的提升。RHEL9自带python3.9&#xff0c;但是过高版本的python不可避免地会导致一些旧版本包地不兼容…...

更新至2022年世界各国数字经济发展相关指标(23个指标)

更新至2022年世界各国数字经济发展相关指标&#xff08;23个指标&#xff09; 1、时间&#xff1a;具体指标时间见下文 2、来源&#xff1a;WDI、世界银行、WEF、UNCTAD、SJR、国际电联 3、指标&#xff1a;移动网络覆盖率&#xff08;2000-2022&#xff09;、固定电话普及率…...

超短脉冲激光自聚焦效应

前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应&#xff0c;这是一种非线性光学现象&#xff0c;主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场&#xff0c;对材料产生非线性响应&#xff0c;可能…...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天&#xff0c;再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至&#xff0c;这不仅是开发者的盛宴&#xff0c;更是全球数亿苹果用户翘首以盼的科技春晚。今年&#xff0c;苹果依旧为我们带来了全家桶式的系统更新&#xff0c;包括 iOS 26、iPadOS 26…...

练习(含atoi的模拟实现,自定义类型等练习)

一、结构体大小的计算及位段 &#xff08;结构体大小计算及位段 详解请看&#xff1a;自定义类型&#xff1a;结构体进阶-CSDN博客&#xff09; 1.在32位系统环境&#xff0c;编译选项为4字节对齐&#xff0c;那么sizeof(A)和sizeof(B)是多少&#xff1f; #pragma pack(4)st…...

Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)

文章目录 1.什么是Redis&#xff1f;2.为什么要使用redis作为mysql的缓存&#xff1f;3.什么是缓存雪崩、缓存穿透、缓存击穿&#xff1f;3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...

使用分级同态加密防御梯度泄漏

抽象 联邦学习 &#xff08;FL&#xff09; 支持跨分布式客户端进行协作模型训练&#xff0c;而无需共享原始数据&#xff0c;这使其成为在互联和自动驾驶汽车 &#xff08;CAV&#xff09; 等领域保护隐私的机器学习的一种很有前途的方法。然而&#xff0c;最近的研究表明&…...

渲染学进阶内容——模型

最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...

srs linux

下载编译运行 git clone https:///ossrs/srs.git ./configure --h265on make 编译完成后即可启动SRS # 启动 ./objs/srs -c conf/srs.conf # 查看日志 tail -n 30 -f ./objs/srs.log 开放端口 默认RTMP接收推流端口是1935&#xff0c;SRS管理页面端口是8080&#xff0c;可…...

有限自动机到正规文法转换器v1.0

1 项目简介 这是一个功能强大的有限自动机&#xff08;Finite Automaton, FA&#xff09;到正规文法&#xff08;Regular Grammar&#xff09;转换器&#xff0c;它配备了一个直观且完整的图形用户界面&#xff0c;使用户能够轻松地进行操作和观察。该程序基于编译原理中的经典…...

Web 架构之 CDN 加速原理与落地实践

文章目录 一、思维导图二、正文内容&#xff08;一&#xff09;CDN 基础概念1. 定义2. 组成部分 &#xff08;二&#xff09;CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 &#xff08;三&#xff09;CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 &#xf…...

HarmonyOS运动开发:如何用mpchart绘制运动配速图表

##鸿蒙核心技术##运动开发##Sensor Service Kit&#xff08;传感器服务&#xff09;# 前言 在运动类应用中&#xff0c;运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据&#xff0c;如配速、距离、卡路里消耗等&#xff0c;用户可以更清晰…...