
How does pruning work in CART?

Ok, so as we saw in previous parts, the CART algorithm allows us to build decision trees. Up till now we have built these trees until all leaves are pure, meaning they contain only one class of examples (for classification trees). However, this can lead to overfitting the training data, which decreases the generalizability of our model, and therefore its usefulness. This is where cost-complexity pruning comes into play.

What is pruning?

So, pruning comes from biology: pruning a plant means selectively removing some parts of it. In the case of decision trees, it just means removing some branches. However, even though we remove branches, we want to keep all of our samples, so we can't simply discard the samples that were in the removed branches. Effectively, removing a branch corresponds to choosing a pruning node, where we want the branch to end, and collapsing all of its child nodes into it.
Now, how do we choose which branches to remove? If we remove too many, our model loses whatever classifying (or regressing) power it has; if we remove too few, we can still overfit. This is addressed by cost-complexity pruning, which balances the complexity of the tree (the number of leaves, and therefore the potential for overfitting) against its performance.
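
To see this trade-off in practice, here is a minimal sketch using scikit-learn, whose decision trees implement an optimised version of CART and expose cost-complexity pruning through the `ccp_alpha` parameter (the dataset and `random_state` are just illustrative choices):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Grow the full tree and compute the pruning path: the sequence of
# effective alphas, each corresponding to a pruned subtree.
clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X, y)

# Larger alpha -> heavier penalty on complexity -> fewer leaves.
for alpha in path.ccp_alphas:
    pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X, y)
    print(f"alpha={alpha:.4f} -> {pruned.get_n_leaves()} leaves")
```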

Notation

In order to explain cost-complexity pruning, we are going to need to give names to a few things; luckily, that's already been done.

Tree nomenclature

Let us consider a decision tree $T$ and two of its nodes, $t$ and $t'$.

  • $t'$ is a descendant of $t$ if there is a path down the tree (from the root towards the leaves) from $t$ to $t'$.
  • $t$ is an ancestor of $t'$ if there is a path up the tree (from the leaves towards the root) from $t'$ to $t$.
  • $t_R$ and $t_L$ are, respectively, the right and left child nodes of $t$.
  • A branch $T_t$ is the branch of $T$ rooted at $t$: it is composed of the node $t$ and all of its descendants.
  • Pruning a branch $T_t$ from $T$ means removing all nodes of $T_t$ except its root $t$; the pruned tree is denoted $T - T_t$.
  • If you can obtain a tree $T'$ from $T$ by pruning branches, then $T'$ is a pruned subtree of $T$, which we denote $T' \leq T$.
  • For a given tree $T$, we can define $\widetilde{T}$, the set of its leaf nodes.
  • The complexity of $T$ is given by the cardinality of $\widetilde{T}$ (i.e. the number of leaf nodes); it is noted $\vert\widetilde{T}\vert$.
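
To make this nomenclature concrete, here is a toy Python sketch; the `Node` class and the helper names are mine, not part of CART or any particular library:

```python
# A toy binary-tree node, just to illustrate the vocabulary above.
class Node:
    def __init__(self, left=None, right=None):
        self.left = left    # t_L
        self.right = right  # t_R

    def is_leaf(self):
        return self.left is None and self.right is None

def leaves(t):
    """All leaf nodes of the branch rooted at t (T~ when t is the root)."""
    if t.is_leaf():
        return [t]
    return leaves(t.left) + leaves(t.right)

def prune(t):
    """Prune the branch T_t: remove all descendants of t, keeping t as a leaf."""
    t.left = t.right = None

T = Node(Node(Node(), Node()), Node())
print(len(leaves(T)))  # |T~| = 3: the complexity of T
prune(T.left)          # the pruned tree T - T_{t_L}
print(len(leaves(T)))  # 2
```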

Measures

Let us consider a leaf node $t$ of $T$, with $\kappa(t)$ the class of $t$ (i.e. the majority class among the samples in the node).

  • $r(t) = 1 - p(\kappa(t)\vert t)$ is the resubstitution error estimate of $t$, where $p(\kappa(t)\vert t)$ is the proportion of the majority class in $t$.
  • We denote $R(t) = p(t)\cdot r(t)$, with $p(t)$ simply being the proportion of all training samples that end up in node $t$.
  • It can be shown that $R(t) \geq R(t_R) + R(t_L)$, which means that splitting a node can never increase the overall resubstitution error (though it does not guarantee a strict improvement).
  • The overall misclassification rate of $T$ is:

$$R(T) = \sum_{t\in \widetilde{T}} R(t) = \sum_{t\in \widetilde{T}} r(t)\cdot p(t)$$

That is, we sum, over all leaf nodes, the resubstitution error of each leaf multiplied by the probability of a sample falling into that leaf.
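
Here is a small, self-contained Python sketch of these measures (the function name and the toy leaf data are purely illustrative):

```python
from collections import Counter

def leaf_measures(labels_in_leaf, n_total):
    """Resubstitution measures for a single leaf t, given the training
    labels that fall into it and the total sample count of the tree."""
    counts = Counter(labels_in_leaf)
    n_t = len(labels_in_leaf)
    p_majority = counts.most_common(1)[0][1] / n_t  # p(kappa(t) | t)
    r_t = 1 - p_majority                            # r(t)
    p_t = n_t / n_total                             # p(t)
    return r_t, p_t, r_t * p_t                      # R(t) = r(t) * p(t)

# R(T): sum R(t) over all leaves. Toy tree with two leaves, 10 samples:
leaf_labels = [["a", "a", "a", "b"], ["b", "b", "b", "b", "b", "a"]]
n_total = sum(len(l) for l in leaf_labels)
R_T = sum(leaf_measures(l, n_total)[2] for l in leaf_labels)
print(R_T)  # (1/4)*(4/10) + (1/6)*(6/10) = 0.1 + 0.1 = 0.2
```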

The pruning

The first step in pruning a tree is, ..., you guessed it: having a tree. So we start by growing $T_{max}$, the maximal tree, with pure leaves. Now, the naive approach would be to go through all the possible pruned subtrees and see which one has the best trade-off between performance and complexity; however, that is, in practice, impossible because of the sheer number of possible pruned subtrees.
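
Just how huge is that number? For a node $t$, a pruned subtree rooted at $t$ either collapses $t$ into a leaf, or keeps the split and pairs any pruned subtree of $t_L$ with any pruned subtree of $t_R$. Here is a quick back-of-the-envelope count, assuming a complete binary tree of a given depth (an illustration of the combinatorics, not part of the CART algorithm itself):

```python
# Count the pruned subtrees of a complete binary tree of a given depth.
# A leaf has exactly one pruned subtree: itself. An internal node t has
# 1 (collapse t into a leaf) + N(t_L) * N(t_R) (keep the split).
def n_pruned_subtrees(depth):
    if depth == 0:
        return 1
    below = n_pruned_subtrees(depth - 1)
    return 1 + below * below

for d in range(1, 7):
    print(d, n_pruned_subtrees(d))
# depth 5 -> 458,330 subtrees; depth 6 -> ~2.1e11: exhaustive search is hopeless.
```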

