How Three Parameters Affect Random Forest Classification Results (with Code)

31 May 2021

Using the handwritten digits dataset, we study how ensemble size, maximum tree depth, and the number of features affect random forest classification results. The full code is at the end.
[Figure: sample images from the handwritten digits dataset]

All experiments use 10-fold cross-validation and report accuracy.

1. Ensemble size (n_estimators)

The ensemble size is the number of trees. To observe its effect, we run cross-validation for each size from 1 to 40, record the accuracy, and plot the results.
[Figure: ensemble size vs. accuracy]
As the plot shows, increasing the ensemble size improves accuracy on both the training and test sets: averaging more decorrelated trees mainly reduces variance. Notably, the model does not overfit the training data as its complexity grows.
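The effect can be reproduced with a minimal sketch; the tree counts and `random_state` here are illustrative choices, not values from the article:

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

X, y = load_digits(return_X_y=True)

# Averaging many decorrelated trees mainly reduces variance, so the
# larger ensemble should score noticeably higher than a single tree.
accs = {}
for n in (1, 40):
    scores = cross_val_score(
        RandomForestClassifier(n_estimators=n, random_state=0),
        X, y, cv=10, scoring="accuracy",
    )
    accs[n] = scores.mean()
    print(f"n_estimators={n:>2}: mean accuracy = {accs[n]:.3f}")
```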

2. Maximum tree depth (max_depth)

The maximum depth reflects the complexity of each individual tree. Fixing the ensemble size at 20, we repeat the experiment with max_depth as the variable.
[Figure: maximum tree depth vs. accuracy]
Clearly, in a random forest the stronger the individual decision trees, the stronger the ensemble, with no loss of generalization. By contrast, other algorithms built on a set of base learners may favor weak or underperforming learners just as much as strong ones, as happens in some Boosting settings.
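A quick way to see this is to let the trees grow to full depth (max_depth=None): each tree fits its training folds almost perfectly, yet the ensemble still generalizes. The ensemble size of 20 matches the article's setup; `random_state` is an illustrative choice:

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

X, y = load_digits(return_X_y=True)

# Fully grown trees: near-perfect training accuracy, but the averaged
# ensemble still holds up on the unseen test folds.
scores = cross_validate(
    RandomForestClassifier(n_estimators=20, max_depth=None, random_state=0),
    X, y, cv=10, scoring="accuracy", return_train_score=True,
)
train_acc = scores["train_score"].mean()
test_acc = scores["test_score"].mean()
print(f"train={train_acc:.3f}  test={test_acc:.3f}")
```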

3. Number of features (max_features)

To reduce correlation between trees and inject randomness into the ensemble, the random forest algorithm draws a random subset of k of the available input features at each split inside each decision tree.

Using fewer input features lowers the similarity between trees, but it also makes each tree simpler and therefore weaker.

Conversely, increasing the number of features makes each tree stronger, but also increases the correlation between trees.

[Figure: number of features vs. accuracy]
The results confirm that, out of 64 features in total, choosing roughly $k = \sqrt{n_{\text{features}}} = \sqrt{64} = 8$ features gives good results, which agrees with what is reported in the literature.
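The square-root rule is easy to check directly; note that passing `max_features="sqrt"` to scikit-learn applies the same heuristic (and in recent scikit-learn versions it is also the classifier default):

```python
import math
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier

X, y = load_digits(return_X_y=True)
n_features = X.shape[1]

# The sqrt heuristic: consider k = sqrt(n_features) features per split.
k = int(math.sqrt(n_features))
print(k)  # 8 for the 64-feature digits data

# max_features="sqrt" applies the same rule inside the estimator.
clf = RandomForestClassifier(n_estimators=20, max_features="sqrt", random_state=0).fit(X, y)
print(clf.max_features)
```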

Appendix: Implementation code
#!/usr/bin/env python
# coding: utf-8
import numpy as np
from sklearn.model_selection import cross_validate
from matplotlib import pyplot as plt
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from tqdm import tqdm
import seaborn as sns

sns.set(
    context="notebook",
    style="whitegrid",
    rc={"figure.dpi": 120, "scatter.edgecolors": "k"},
)


def evaluate_n_estimators(X: np.ndarray, y: np.ndarray, n: int) -> tuple[float, float]:
    """Run 10-fold cross-validation of the model for a given number of trees and return the
    mean train and test scores."""
    clf = RandomForestClassifier(n_estimators=n)
    scores = cross_validate(
        estimator=clf,
        X=X,
        y=y,
        scoring="accuracy",
        cv=10,
        return_train_score=True,
    )
    return np.mean(scores["train_score"]), np.mean(scores["test_score"])


def evaluate_depth(X: np.ndarray, y: np.ndarray, depth: int) -> tuple[float, float]:
    """Run 10-fold cross-validation of the model for a given tree depth and return the
    mean train and test scores."""
    clf = RandomForestClassifier(n_estimators=20, max_depth=depth)
    scores = cross_validate(
        estimator=clf,
        X=X,
        y=y,
        cv=10,
        scoring="accuracy",
        return_train_score=True,
    )
    return np.mean(scores["train_score"]), np.mean(scores["test_score"])


def evaluate_features(X: np.ndarray, y: np.ndarray, n_features: int) -> tuple[float, float]:
    """Run 10-fold cross-validation of the model for a given number of features per split and return the
    mean train and test scores."""
    clf = RandomForestClassifier(n_estimators=20, max_features=n_features)
    scores = cross_validate(
        estimator=clf,
        X=X,
        y=y,
        cv=10,
        scoring="accuracy",
        return_train_score=True,
    )
    return np.mean(scores["train_score"]), np.mean(scores["test_score"])


def plot_accuracy(xs: range, accuracies: np.ndarray, xlabel: str, ylabel: str = "Accuracy") -> None:
    """Plot results for the given accuracies."""
    acc_train = accuracies[:, 0]
    acc_test = accuracies[:, 1]
    plt.figure()
    plt.plot(xs, acc_train, label="Train", linestyle="--")
    plt.plot(xs, acc_test, label="Test", linestyle="--")
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.xticks(xs[::5])
    plt.legend()
    plt.show()


def main():
    # Load digits
    X, y = load_digits(return_X_y=True)

    # Plot samples
    plt.figure()
    for i, (x_i, y_i) in enumerate(zip(X[:4], y[:4]), start=1):
        plt.subplot(140 + i)
        plt.imshow(x_i.reshape(8, 8), cmap="gray")
        plt.title("label = " + str(y_i))
        plt.axis("off")

    # Define interval
    n_estimators = range(1, 41)

    # Evaluate interval
    accuracies_n_est = np.array([evaluate_n_estimators(X, y, alpha) for alpha in tqdm(n_estimators)])

    plot_accuracy(n_estimators, accuracies_n_est, "Number of Trees")

    # Define interval
    depths = range(1, 15)

    # Evaluate interval
    accuracies_depths = np.array([evaluate_depth(X, y, d) for d in tqdm(depths)])

    # Plot results
    plot_accuracy(depths, accuracies_depths, "Tree Depth")

    # Define interval
    n_features = range(1, X.shape[1], 1)

    # Evaluate interval
    accuracies_n_feat = np.array([evaluate_features(X, y, n) for n in tqdm(n_features)])

    # Plot results
    plot_accuracy(n_features, accuracies_n_feat, "Max. Number of Features per Tree")


if __name__ == '__main__':
    main()
