分类与回归树(Classification and Regression Trees, CART)是由四人帮Leo Breiman, Jerome Friedman, Richard Olshen与Charles Stone于1984年提出,既可用于分类也可用于回归。本文将主要介绍用于分类的CART。CART被称为数据挖掘领域内里程碑式的算法。
不同于C4.5,CART本质是对特征空间进行二元划分,能够对标量属性(nominal attribute)与连续属性(continuous attribute)进行分裂;也就是说CART生成的决策树是一棵二叉树。
前一篇提到过决策树生成涉及到两个问题:如何选择最优特征属性进行分裂,以及停止分裂的条件是什么。
CART对特征属性进行二元分裂。特别地,当特征属性为标量或连续时,可选择如下方式分裂:
An instance goes left if CONDITION , and goes right otherwise
即样本记录满足 CONDITION 则分裂给左子树,否则则分裂给右子树。
标量属性进行二元分裂时的 CONDITION 可置为 不等于属性的某值
。比如,标量属性 Car Type
取值空间为 {Sports, Family, Luxury}
,二元分裂与多路分裂如下:
对 连续属性 而言, CONDITION 可置为不大于$/varepsilon$。比如,连续属性 Annual Income
,$/varepsilon$取属性相邻值的平均值,其二元分裂结果如下:
接下来,需要解决的问题:应该选择哪种特征属性, CONDITION 应如何定义。CART采用Gini指数来度量分裂时的不纯度,之所以采用Gini指数,是因为较于熵而言其计算速度更快一些。对决策树的节点$t$,Gini指数计算公式如下:
/begin{equation}
Gini(t)=1-/sum/limits_{k}[p(c_k|t)]^2
/end{equation}
Gini指数即为$1$与类别$c_k$的概率平方之和的差值,反映了样本集合的不确定性程度。Gini指数越大,样本集合的不确定性程度越高。分类学习过程的本质是样本不确定性程度的减少(即熵减过程),故应选择 最小Gini指数 的特征分裂。父节点对应的样本集合为$D$,CART选择特征$A$分裂为两个子节点,对应集合为$D_L$与$D_R$;分裂后的Gini指数定义如下:
/begin{equation}
G(D,A)={/left|{D_L} /right| /over /left|{D} /right|}Gini(D_L)+{/left|{D_R} /right| /over /left|{D} /right|}Gini(D_R)
/end{equation}
其中,$/left| /cdot /right|$表示样本集合的记录数量。
CART算法流程与C4.5算法相类似:
CART剪枝与C4.5的剪枝策略相似,均以极小化整体损失函数实现。同理,定义决策树$T$的损失函数为:
$$
L_/alpha (T)=C(T)+/alpha /left| T /right|
$$
其中,$C(T)$表示决策树的训练误差,$/alpha$为调节参数,$/left| T /right|$为模型的复杂度。
CART算法采用递归的方法进行剪枝,具体办法:
如何计算最优子树为$T_i$呢?首先,定义以$t$为单节点的损失函数为
$$
L_/alpha (t)=C(t)+/alpha
$$
以$t$为根节点的子树$T_t$的损失函数为
$$
L_/alpha (T_t)=C(T_t)+/alpha /left| T_t /right|
$$
令$L_/alpha (t)=L_/alpha (T_t)$,则得到
$$
/alpha = {C(t)-C(T_t) /over /left| T_t /right|-1}
$$
此时,单节点$t$与子树$T_t$有相同的损失函数,而单节点$t$的模型复杂度更小,故更为可取;同时也说明对节点$t$的剪枝为有效剪枝。由此,定义对节点$t$的剪枝后整体损失函数减少程度为
$$
g(t) = {C(t)-C(T_t) /over /left| T_t /right|-1}
$$
剪枝流程如下:
关于CART剪枝算法的具体描述请参看[1],其中关于剪枝算法的描述有误:
(6)如果T不是由根节点单独构成的树,则回到步骤(4)
应改为 回到步骤(3)
,要不然所有$/alpha$均一样了。
[1] 李航,《统计学习方法》.
[2] Pang-Ning Tan, Michael Steinbach, Vipin Kumar, Introduction to Data Mining .
[3] Xindong Wu, Vipin Kumar, The Top Ten Algorithms in Data Mining.