决策树的生成与剪枝
决策树的生成与剪枝
- 决策树的生成
- 生成决策树的过程
- 决策树的生成算法
- 决策树的剪枝
- 决策树的损失函数
- 决策树的剪枝算法
- 代码
决策树的生成
生成决策树的过程
为了方便分析描述,我们对上节课中的训练样本进行编号,每个样本加一个ID值,如图所示:
从根结点开始生成决策树,先将上述训练样本(1-9)全部放在根节点中。
然后选择信息增益或信息增益比最大的特征向下分裂,按照所选特征取值的不同将训练样本分发到不同的子结点中。
例如所选特征有3个取值,则分裂出3个子结点,然后在子结点中重复上述过程,直到所有特征的信息增益(比)很小或者没有特征可选为止,完成决策树的构建,如下图所示:
图中的决策树共有2个内部结点和3个叶子结点,每个结点旁边的编号代表训练样本的 ID 值。内部结点代表样本的特征,叶子结点代表样本的预测类别,我们将叶子节点中训练样本占比最大的类作为决策树的预测标记。
在使用构建好的决策树在测试数据上分类时,只需要从根节点开始依次测试内部结点代表的特征即可得到测试样本的预测分类。
决策树的生成算法
下面我们先来总结一下 ID3分类决策树的生成算法:
输入:训练数据集 D 、特征集合 A 的信息增益阈值 ;输出:决策树 T
-
若 D 中的训练样本属于同一类,则 T 为单结点树,返回 D 中任意样本的类别。
-
若 A 中的特征为空,则 T 为单结点树,返回 D 中数量最多的类别。
-
使用信息增益在 A 中进行特征选择,若所选特征 A_i 的信息增益小于设定的阈值,则 T 为单结点树,返回 D 中数量最多的类别。
-
否则根据 A_i 的每一个取值,将 D 分成若干子集 D_i,将 D_i 中数量最多的类作为标记值,构建子结点,返回 T。
-
以 D_i 为训练集,{A - A_i} 为特征集,递归地调用上述步骤,得到子树 T_i,返回 T。
使用 C4.5 算法进行决策树的生成只需要将信息增益改成信息增益比即可。
决策树的剪枝
决策树的损失函数
决策树的叶子节点越多,模型越复杂。决策树的损失函数考虑了模型复杂度,我们可以通过优化其损失函数来对决策树进行剪枝。决策树的损失函数计算过程如下:
-
计算叶子结点 t 的样本类别经验熵
对于叶子结点 t 来说,其样本类别的经验熵越小, t 中训练样本的分类误差就越小。当叶子结点 t 中的训练样本为同一类别时,经验熵为零,分类误差为零。 -
计算决策树 T 在所有训练样本上的损失之和 C(T)
对于叶子结点 t 中的每一个训练样本,其类别标记都是随机变量 Y 的一个取值,这个取值的不确定性用信息熵来衡量,且可以用经验熵来估计。由上文可知,经验熵在一定程度上可以反映决策树在该样本上的预测损失,累加所有叶子结点上的训练样本损失即上图中的计算公式。 -
计算考虑模型复杂度的的决策树损失函数
决策树的叶子结点个数表示模型的复杂度,通过最小化上面的损失函数,一方面可以减少模型在训练样本上的预测误差,另一方面可以控制模型的复杂度,保证模型的泛化能力。
决策树的剪枝算法
-
计算决策树中每个结点的样本类别经验熵:
如上图所示,对于本课示例中的决策树,需要计算 5 个结点的经验熵。 -
遍历非叶子结点,剪枝相当于去除其子结点,自身变为叶子结点:
对于图中的非叶子结点(有工作?),剪枝后变为叶子结点,并通过多数表决的方法确定其类别标记。
以上就是这节课的所有内容了,实际上还有一种决策树算法:分类与回归树(classification and regression tree,简称 CART),它既可以用于分类也可以用于回归,同样包含了特征选择、决策树的生成与剪枝算法。
关于 CART 算法的内容,我们将在最后一章 XGBoost 中进行学习,下面请你来做一道关于信息增益比的题目,顺便回顾一下前面所学的知识
代码
## 1. 创建数据集import pandas as pd
data = [['yes', 'no', '青年', '同意贷款'],['yes', 'no', '青年', '同意贷款'],['yes', 'yes', '中年', '同意贷款'],['no', 'no', '中年', '不同意贷款'],['no', 'no', '青年', '不同意贷款'],['no', 'no', '青年', '不同意贷款'],['no', 'no', '中年', '不同意贷款'],['no', 'no', '中年', '不同意贷款'],['no', 'yes', '中年', '同意贷款']]# 转为 dataframe 格式
df = pd.DataFrame(data)
# 设置列名
df.columns = ['有房?', '有工作?', '年龄', '类别']## 2. 经验熵的实现from math import log2
from collections import Counter
def H(y):'''y: 随机变量 y 的一组观测值,例如:[1,1,0,0,0]'''# 随机变量 y 取值的概率估计值probs = [n/len(y) for n in Counter(y).values()]# 经验熵:根据概率值计算信息量的数学期望return sum([-p*log2(p) for p in probs])## 3. 经验条件熵的实现def cond_H(a):'''a: 根据某个特征的取值分组后的 y 的观测值,例如:[[1,1,1,0],[0,0,1,1]]每一行表示特征 A=a_i 对应的样本子集'''# 计算样本总数sample_num = sum([len(y) for y in a])# 返回条件概率分布的熵对特征的数学期望return sum([(len(y)/sample_num)*H(y) for y in a])## 4. 特征选择函数
def feature_select(df,feats,label):'''df:训练集数据,dataframe 类型feats:候选特征集label:df 中的样本标记名,字符串类型'''# 最佳的特征与对应的信息增益比best_feat,gainR_max = None,-1# 遍历每个特征for feat in feats:# 按照特征的取值对样本进行分组,并取分组后的样本标记数组group = df.groupby(feat)[label].apply(lambda x:x.tolist()).tolist()# 计算该特征的信息增益:经验熵-经验条件熵gain = H(df[label].values) - cond_H(group)# 计算该特征的信息增益比gainR = gain / H(df[feat].values)# 更新最大信息增益比和对应的特征if gainR > gainR_max:best_feat,gainR_max = feat,gainRreturn best_feat,gainR_max ## 5. 决策树的生成函数
import pickle
def creat_tree(df,feats,label):'''df:训练集数据,dataframe 类型feats:候选特征集,字符串列表label:df 中的样本标记名,字符串类型'''# 当前候选的特征列表feat_list = feats.copy()# 若当前训练数据的样本标记值只有一种if df[label].nunique()==1:# 将数据中的任意样本标记返回,这里取第一个样本的标记值return df[label].values[0]# 若候选的特征列表为空时if len(feat_list)==0:# 返回当前数据样本标记中的众数,各类样本标记持平时取第一个return df[label].mode()[0]# 在候选特征集中进行特征选择feat,gain = feature_select(df,feat_list,label)# 若选择的特征信息增益太小,小于阈值 0.1if gain<0.1:# 返回当前数据样本标记中的众数return df[label].mode()[0]# 根据所选特征构建决策树,使用字典存储tree = {feat:{}}# 根据特征取值对训练样本进行分组g = df.groupby(feat)# 用过的特征要移除feat_list.remove(feat)# 遍历特征的每个取值 ifor i in g.groups:# 获取分组数据,使用剩下的候选特征集创建子树tree[feat][i] = creat_tree(g.get_group(i),feat_list,label)# 存储决策树pickle.dump(tree,open('tree.model','wb'))return tree# 6. 决策树的预测函数
def predict(tree,feats,x):'''tree:决策树,字典结构feats:特征集合,字符串列表x:测试样本特征向量,与 feats 对应'''# 获取决策树的根结点:对应样本特征root = next(iter(tree))# 获取该特征在测试样本 x 中的索引i = feats.index(root)# 遍历根结点分裂出的每条边:对应特征取值for edge in tree[root]:# 若测试样本的特征取值=当前边代表的特征取值if x[i]==edge:# 获取当前边指向的子结点child = tree[root][edge]# 若子结点是字典结构,说明是一颗子树if type(child)==dict:# 将测试样本划分到子树中,继续预测return predict(child,feats,x)# 否则子结点就是叶子节点else:# 返回叶子节点代表的样本预测值return child## 7. 在样例数据上测试# 获取特征名列表
feats = list(df.columns[:-1])
# 获取标记名
label = df.columns[-1]
# 创建决策树(此处使用信息增益比进行特征选择)
T = creat_tree(df,feats,label)
# 计算训练集上的预测结果
preds = [predict(T,feats,x) for x in df[feats].values]
# 计算准确率
acc = sum([int(i) for i in (df[label].values==preds)])/len(preds)
# 输出决策树和准确率
print(T,acc)
相关文章:

决策树的生成与剪枝
决策树的生成与剪枝 决策树的生成生成决策树的过程决策树的生成算法 决策树的剪枝决策树的损失函数决策树的剪枝算法 代码 决策树的生成 生成决策树的过程 为了方便分析描述,我们对上节课中的训练样本进行编号,每个样本加一个ID值,如图所示…...
蓝桥杯算法训练 黑色星期五
题目描述 有些西方人比较迷信,如果某个月的13号正好是星期五,他们就会觉得不太吉利,用古人的说法,就是“诸事不宜”。请你编写一个程序,统计出在某个特定的年份中,出现了多少次既是13号又是星期五的情形&am…...

MySQL存储引擎-存储结构
Innodb存储结构 Buffer Pool(缓冲池):BP以Page页为单位,页默认大小16K,BP的底层采用链表数据结构管理Page。在InnoDB访问表记录和索引时会在Page页中缓存,以后使用可以减少磁盘IO操作,提升效率。 ○ Page根据状态可以分…...
理解torch函数bmm
基本信息 功能描述 torch.bmm 是 PyTorch 中的一个函数,用于执行批量矩阵乘法(Batch Matrix Multiplication)。它适用于处理一批矩阵的乘法操作,特别适合于深度学习任务中的场景,比如卷积神经网络中的某些层。 参数…...

2024 年的科技趋势
2024 年在科技领域有着诸多重大进展与突破。从人工智能、量子计算到基因组医学、可再生能源以及新兴技术重塑了众多行业。随着元宇宙等趋势的兴起以及太空探索取得的进步,未来在接下来的岁月里有望继续取得进展与突破。让我们来探讨一下定义 2024 年的一些关键趋势&…...

win服务器的架设、windows server 2012 R2 系统的下载与安装使用
文章目录 windows server 2012 R2 系统的下载与安装使用1 windows server 2012 的下载2 打开 VMware 虚拟机软件(1)新建虚拟机(2)设置虚拟机(3)打开虚拟机 windows server 2012(4)进…...
leetcode45.跳跃游戏II
标签:动态规划 给定一个长度为 n 的 0 索引整数数组 nums。初始位置为 nums[0]。每个元素 nums[i] 表示从索引 i 向前跳转的最大长度。换句话说,如果你在 nums[i] 处,你可以跳转到任意 nums[i j] 处:返回到达 nums[n - 1] 的最小跳跃次数。…...

边缘智能创新应用大赛获奖作品系列三:边缘智能强力驱动,机器人天团花式整活赋能千行百业
边缘智能技术快速迭代,并与行业深度融合。它正重塑产业格局,催生新产品、新体验,带动终端需求增长。为促进边缘智能技术的进步与发展,拓展开发者的思路与能力,挖掘边缘智能应用的创新与潜能,高通技术公司联…...
基于语义的NLP任务去重:大语言模型应用与实践
引言 在自然语言处理(NLP)任务中,数据质量是模型性能的关键因素之一。重复或冗余的数据会导致模型过度拟合或浪费计算资源,特别是在大语言模型(如 BERT、GPT 系列等)训练和推理阶段。传统的基于字符匹配的…...
使用阿里云Certbot-DNS-Aliyun插件自动获取并更新免费SSL泛域名(通配符)证书
进入nginx docker,一般是Alpine Linux系统 1. 依次执行命令: sudo docker-compose exec nginx bashapk updateapk add certbot apk add --no-cache python3 python3-dev build-baseapk add python3 py3-pippip3 install --upgrade pippip3 install certbot-dns-ali…...

Node.js安装配置+Vue环境配置+创建一个VUE项目
目录 安装Node.js搭建VUE环境 安装Node.js 下载 测试是否安装成功 在目录下新建两个文件夹 管理员打开cmd npm config set prefix "D:\Software\nodejs\node_global" npm config set cache "D:\Software\nodejs\node_cache"将默认的 C 盘下【 AppData\…...

“TA”说|表数据备份还原:SQLark 百灵连接助力项目部署验收
💬 南飞雁|应用开发工程师 有些重要项目的部署验收,会在生产环境完成,验收完成后,又需要把这部分数据清空。这时就需要对数据表进行备份和还原,虽然可以通过命令直接实现,但是有一些操作门槛&am…...

【FFmpeg】解封装 ① ( 封装与解封装流程 | 解封装函数简介 | 查找码流标号和码流参数信息 | 使用 MediaInfo 分析视频文件 )
文章目录 一、解封装1、封装与解封装流程2、解封装 常用函数 二、解封装函数简介1、avformat_alloc_context 函数2、avformat_free_context 函数3、avformat_open_input 函数4、avformat_close_input 函数5、avformat_find_stream_info 函数6、av_read_frame 函数7、avformat_s…...

Spring Boot 集成 MyBatis 全面讲解
Spring Boot 集成 MyBatis 全面讲解 MyBatis 是一款优秀的持久层框架,与 Spring Boot 集成后可以大大简化开发流程。本文将全面讲解如何在 Spring Boot 中集成 MyBatis,包括环境配置、基础操作、高级功能和最佳实践。 一、MyBatis 简介 1. SqlSession …...

C语言小练习-打印字母倒三角
编写一个程序,在用户输入某个大写字母后,产生一个金字塔图案。 #include <stdio.h>int main(int argc,char *argv[]) {char ch; loop:printf("请输入大写字母!\n");scanf("%c",&ch);getchar();if(ch < A ||…...

Linux -- 线程控制相关的函数
目录 pthread_create -- 创建线程 参数 返回值 代码 -- 不传 args: 编译时带 -lpthread 运行结果 为什么输出混杂? 如何证明两个线程属于同一个进程? 如何证明是两个执行流? 什么是LWP? 代码 -- 传 args&a…...

基于quasar,只选择年度与月份的组件
为什么要做 quasar是个基于vue的强大的UI开发库,它提供了非常多的组件,比如日期选择。但是有些时候只需要选择到月份就可以了,quasar中没有,所以自己动手写了一个。因为对界面编程我不熟悉,所以,如果你有更…...

健康养生:拥抱生活的艺术
健康养生:拥抱生活的艺术 在快节奏的现代生活中,健康已成为我们最宝贵的财富。健康养生,不仅仅是一种生活方式的选择,更是一种对待生活的态度,它关乎于如何在日常中寻找到平衡,让身心得以滋养,…...

注意力机制+时空特征融合!组合模型集成学习预测!LSTM-Attention-Adaboost多变量时序预测
注意力机制时空特征融合!组合模型集成学习预测!LSTM-Attention-Adaboost多变量时序预测 目录 注意力机制时空特征融合!组合模型集成学习预测!LSTM-Attention-Adaboost多变量时序预测效果一览基本介绍程序设计参考资料 效果一览 基…...

uniapp 微信小程序 均分数据展示
效果图 数据展示,可自行搭配 html <view class"num-wrapper"><view class"num-item" click.stop"routerGo(跳转的地址)"><text class"num">¥{{ 要展示的数据 || 0}}</text><view…...
解锁数据库简洁之道:FastAPI与SQLModel实战指南
在构建现代Web应用程序时,与数据库的交互无疑是核心环节。虽然传统的数据库操作方式(如直接编写SQL语句与psycopg2交互)赋予了我们精细的控制权,但在面对日益复杂的业务逻辑和快速迭代的需求时,这种方式的开发效率和可…...
【Go】3、Go语言进阶与依赖管理
前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课,做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程,它的核心机制是 Goroutine 协程、Channel 通道,并基于CSP(Communicating Sequential Processes࿰…...
稳定币的深度剖析与展望
一、引言 在当今数字化浪潮席卷全球的时代,加密货币作为一种新兴的金融现象,正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而,加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下,稳定…...

基于 TAPD 进行项目管理
起因 自己写了个小工具,仓库用的Github。之前在用markdown进行需求管理,现在随着功能的增加,感觉有点难以管理了,所以用TAPD这个工具进行需求、Bug管理。 操作流程 注册 TAPD,需要提供一个企业名新建一个项目&#…...
CRMEB 中 PHP 短信扩展开发:涵盖一号通、阿里云、腾讯云、创蓝
目前已有一号通短信、阿里云短信、腾讯云短信扩展 扩展入口文件 文件目录 crmeb\services\sms\Sms.php 默认驱动类型为:一号通 namespace crmeb\services\sms;use crmeb\basic\BaseManager; use crmeb\services\AccessTokenServeService; use crmeb\services\sms\…...

STM32---外部32.768K晶振(LSE)无法起振问题
晶振是否起振主要就检查两个1、晶振与MCU是否兼容;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容(CL)与匹配电容(CL1、CL2)的关系 2. 如何选择 CL1 和 CL…...
C语言中提供的第三方库之哈希表实现
一. 简介 前面一篇文章简单学习了C语言中第三方库(uthash库)提供对哈希表的操作,文章如下: C语言中提供的第三方库uthash常用接口-CSDN博客 本文简单学习一下第三方库 uthash库对哈希表的操作。 二. uthash库哈希表操作示例 u…...

TSN交换机正在重构工业网络,PROFINET和EtherCAT会被取代吗?
在工业自动化持续演进的今天,通信网络的角色正变得愈发关键。 2025年6月6日,为期三天的华南国际工业博览会在深圳国际会展中心(宝安)圆满落幕。作为国内工业通信领域的技术型企业,光路科技(Fiberroad&…...

若依登录用户名和密码加密
/*** 获取公钥:前端用来密码加密* return*/GetMapping("/getPublicKey")public RSAUtil.RSAKeyPair getPublicKey() {return RSAUtil.rsaKeyPair();}新建RSAUti.Java package com.ruoyi.common.utils;import org.apache.commons.codec.binary.Base64; im…...

6.9-QT模拟计算器
源码: 头文件: widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QMouseEvent>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpublic:Widget(QWidget *parent nullptr);…...