25 Decision trees
\[ \DeclarePairedDelimiters{\set}{\{}{\}} \DeclareMathOperator*{\argmax}{argmax} \]
Decision trees are simple and intuitive machine learning models that can be expanded and elaborated into a surprisingly powerful and general method for both classification and regression. They are based on the idea of performing a series of yes-or-no tests on the data, which finally lead to a decision.
We have already seen a tree-like structure for illustrating decision problems in Chapter 3, but in this section we are doing something similar yet different. The name “decision trees” pertains to both, but it should hopefully be clear from context which one we are talking about.
25.1 Decision trees for classification on categorical values
For a trivial example, let us say you want to go and spend the day at the beach. There are certain criteria that should be fulfilled for that to be a good idea, and some of them depend on each other. A decision tree could look like this:
Depending on input data such as the weather, we end up following a certain path from the root node, along the branches, down to the leaf node, which returns the final decision for these given observations. The botany analogies are not strictly necessary, but at least we see where the name decision tree comes from.
Studying the above tree structure more closely, we see that there are several possible ways of structuring it that would lead to the same outcome. We could choose to split on the Sunny node first, and split on Workday afterwards. Drawing it out on paper, however, would show that this structure needs a larger total number of nodes, since we then always need to split on Workday. Hence, the most efficient tree is the one that steps through the observables in order of descending importance.
The basic algorithm for building a decision tree (or growing it, if you prefer) on categorical data can be written out quite compactly. Consider the following pseudo-code:
function BuildTree(examples, target_feature, features)
    # examples is the training data
    # target_feature is the feature we want to predict
    # features is the list of features present in the data
    tree ← a single node, so far without any label
    if all examples are of the same classification then
        give tree a label equal to the classification
        return tree
    else if features is empty then
        give tree a label equal to the most common value of target_feature in examples
        return tree
    else
        best_feature ← the feature A from features with the highest Importance(A, examples)
        for each value v of best_feature do
            examples_v ← the subset of examples where best_feature has the value v
            subtree ← BuildTree(examples_v, target_feature, features - {best_feature})
            add a branch with label v to tree, and below it, add the tree subtree
        return tree
This is the original ID3 algorithm [@quinlan1986induction]. Note how it works recursively – for each new feature, the function calls itself to build a subtree.
The same algorithm is shown and explained in section 19.3 of Russell and Norvig, although they do not point out that it is ID3.
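If you prefer something executable, here is one possible Python transcription of the pseudocode, meant as a sketch rather than a polished implementation. It assumes the examples are given as a list of dictionaries, and that an `importance(feature, examples)` function is supplied from outside – an information-gain version of such a function is sketched at the end of this section.

```python
from collections import Counter

def build_tree(examples, target_feature, features, importance):
    """ID3-style recursive tree building, mirroring the pseudocode above.

    examples:       list of dicts mapping feature names to (categorical) values
    target_feature: the feature we want to predict
    features:       the features still available for splitting
    importance:     a function importance(feature, examples) scoring candidate splits
    """
    labels = [ex[target_feature] for ex in examples]
    if len(set(labels)) == 1:
        return labels[0]                              # all examples agree: leaf node
    if not features:
        return Counter(labels).most_common(1)[0][0]   # no features left: majority label
    best = max(features, key=lambda f: importance(f, examples))
    tree = {best: {}}
    for value in set(ex[best] for ex in examples):
        subset = [ex for ex in examples if ex[best] == value]
        subfeatures = [f for f in features if f != best]
        tree[best][value] = build_tree(subset, target_feature, subfeatures, importance)
    return tree
```

Here a leaf is represented simply by its label, and an internal node by a nested dictionary keyed on the chosen feature and its values.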
We start by creating a node, which becomes a leaf node either if it classifies all examples correctly (no reason to split), or if there are no more features left (not possible to split). Otherwise, we find the most important feature by calling Importance on each remaining feature, and proceed to make all possible splits on it. Now, the magic happens in the Importance function. How can we quantify which feature is best to discriminate on? We have already met, in Section 23.1 and Section 18.5, a useful definition from information theory, which is the Shannon entropy:
\[ H(f) \coloneqq-\sum_{i} f_i\ \log_2 f_i \qquad\text{\color[RGB]{119,119,119}\small(with \(0\cdot\log 0 \coloneqq 0\))} \]
where the \(f_i\) are the relative frequencies of each class. If we stick to the simple example of our target feature taking the values “yes” or “no”, we can write out the summation like so:
\[ H = - f_{\mathrm{yes}} \log_2 f_{\mathrm{yes}} - f_{\mathrm{no}} \log_2 f_{\mathrm{no}} \]
Let us compute the entropy for two different cases, to see how it works. In the first case, we have 10 examples: 6 corresponding to “yes”, and 4 corresponding to “no”. The entropy is then
\[ H(6\;\text{yes}, 4\;\text{no}) = - (6/10) \log_2 (6/10) - (4/10) \log_2 (4/10) = 0.97 \]
In the second case, we still have 10 examples, but nearly all of the same class: 9 examples are “yes”, and 1 is “no”:
\[ H(9\;\text{yes}, 1\;\text{no}) = - (9/10) \log_2 (9/10) - (1/10) \log_2 (1/10) = 0.47 \]
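These numbers are quick to verify in code. A minimal check in Python:

```python
from math import log2

def entropy(counts):
    """Shannon entropy of a class distribution, given as a list of counts."""
    total = sum(counts)
    freqs = [c / total for c in counts if c > 0]   # skip empty classes: 0*log2(0) := 0
    return -sum(f * log2(f) for f in freqs)

print(entropy([6, 4]))   # 0.970...
print(entropy([9, 1]))   # 0.468...
```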
Interpreting the entropy as a measure of impurity in the set of examples, we can guess (or compute, using \(0\cdot \log_2 0 = 0\)) that the lowest possible entropy occurs for a set where all examples are of the same class. When doing classification, this is of course what we aim for – separating all examples into those corresponding to “yes” and those corresponding to “no”. A way of selecting the most important feature is then to choose the one for which we expect the largest reduction in entropy from splitting on it. This is called the information gain, and is generally defined as
\[ Gain(A, S) \coloneqq H(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|} H(S_v) \,, \tag{25.1}\]
where \(A\) is the feature under consideration, and \(Values(A)\) are all the possible values that \(A\) can take. Further, \(S\) is the set of examples, and \(S_v\) is the subset containing the examples where \(A\) has the value \(v\). In the binary yes/no case it looks a little simpler. Using the feature Sunny as \(A\), we get:
\[ Gain(\texttt{Sunny}, S) = H(S) - \left(\frac{|S_{\texttt{Sunny=yes}}|}{|S|} H(S_{\texttt{Sunny=yes}}) + \frac{|S_{\texttt{Sunny=no}}|}{|S|} H(S_{\texttt{Sunny=no}})\right) \,. \]
This equation can be read as “the gain equals the original entropy before splitting on Sunny, minus the weighted entropy after splitting”, which is what we were after. One thing to note about equation 25.1: while it allows for splitting on an arbitrary number of values, we typically want to split in two, resulting in binary trees. Non-binary trees tend to overfit quickly, which is why few of the successors to the ID3 algorithm allow them. The extreme case would be a feature that is continuous instead of categorical. For a continuous feature it is unlikely that the data will contain values that are identical – probably many values are similar, but not identical to, say, ten digits of precision. ID3 would potentially split such a feature into as many branches as there are data points, which is maximal overfitting. Binary trees can of course overfit too (we will get back to this shortly), but first, let us introduce a similar algorithm that can deal with both continuous inputs and continuous outputs.
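Before moving on, let us make equation 25.1 concrete with a small Python sketch of the information-gain computation, using a hypothetical beach dataset – the feature names follow the tree above, but the actual rows are made up for illustration:

```python
from collections import Counter
from math import log2

def entropy_of(labels):
    """Shannon entropy of a list of class labels."""
    total = len(labels)
    return -sum((c / total) * log2(c / total) for c in Counter(labels).values())

def information_gain(feature, examples, target="Beach"):
    """Gain(A, S) from equation 25.1: entropy before the split minus the
    weighted entropy of the subsets after the split."""
    labels = [ex[target] for ex in examples]
    gain = entropy_of(labels)
    for value in set(ex[feature] for ex in examples):
        subset = [ex[target] for ex in examples if ex[feature] == value]
        gain -= len(subset) / len(examples) * entropy_of(subset)
    return gain

# Hypothetical toy data for the beach example (rows invented for illustration).
data = [
    {"Sunny": "yes", "Workday": "no",  "Beach": "yes"},
    {"Sunny": "yes", "Workday": "yes", "Beach": "no"},
    {"Sunny": "no",  "Workday": "no",  "Beach": "no"},
    {"Sunny": "no",  "Workday": "yes", "Beach": "no"},
    {"Sunny": "yes", "Workday": "no",  "Beach": "yes"},
    {"Sunny": "no",  "Workday": "no",  "Beach": "no"},
]
for f in ["Sunny", "Workday"]:
    print(f, round(information_gain(f, data), 3))   # Sunny ≈ 0.459, Workday ≈ 0.252
```

With this data, Sunny gives the larger gain, which is why it should be split on first. The same function can also be passed as the importance argument to the build_tree sketch earlier in this section.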
25.2 Decision trees for regression (in addition to classification)
The problem of continuous features can be solved by requiring only binary splits, and then searching for the optimal threshold value at which to split. Each node will then ask “is the value of feature \(A\) larger or smaller than the threshold \(x\)?” Finding the best threshold involves going through the values in the data and computing the expected information gain. One might initially think that we need to consider all possible values that the data could take, but luckily we only need to consider the values the data does take, since the expected information gain only changes at the points where a data point moves from one branch to the other.
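A sketch of this threshold search, reusing the entropy-based gain from the previous section on made-up temperature data:

```python
from collections import Counter
from math import log2

def entropy_of(labels):
    total = len(labels)
    return -sum((c / total) * log2(c / total) for c in Counter(labels).values())

def best_threshold(values, labels):
    """Find the binary split threshold on a continuous feature that maximises
    information gain. Only midpoints between consecutive sorted values need to
    be checked, since the gain only changes when a data point switches side."""
    pairs = sorted(zip(values, labels))
    base = entropy_of(labels)
    best_gain, best_t = -1.0, None
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue                                  # same value: no new split point
        t = (pairs[i - 1][0] + pairs[i][0]) / 2       # candidate threshold: midpoint
        left = [lab for v, lab in pairs if v <= t]
        right = [lab for v, lab in pairs if v > t]
        gain = base - (len(left) / len(pairs)) * entropy_of(left) \
                    - (len(right) / len(pairs)) * entropy_of(right)
        if gain > best_gain:
            best_gain, best_t = gain, t
    return best_t, best_gain

# Example: temperature in °C vs. whether we went to the beach (invented numbers).
temps  = [12, 15, 18, 21, 24, 27, 30]
labels = ["no", "no", "no", "yes", "yes", "yes", "yes"]
print(best_threshold(temps, labels))   # (19.5, 0.985...)
```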
A second thing we would like to solve is to allow not only categorical outputs (i.e. classification), but also continuous ones (i.e. regression). Looking again at the pseudocode for the ID3 algorithm, we see that once we have “used up” all the features, the label assigned to the final branch is the majority label among the remaining examples. For continuous target values the fix is relatively simple – we instead take the average of the target values in the examples (and, correspondingly, judge splits by how much they reduce the variance around that average, rather than by the entropy). These two solutions form the basis for the CART (Classification and Regression Trees) algorithm, which creates decision trees for any combination of categorical and continuous inputs and outputs.
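In practice one rarely codes this from scratch; scikit-learn's decision trees, for example, implement an optimised version of CART. A minimal regression sketch, with made-up data:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Toy regression problem: a noisy sine curve (invented data for illustration).
rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 6, size=80)).reshape(-1, 1)
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=80)

# A CART-style regression tree: binary threshold splits on the continuous input,
# with each leaf predicting the mean target value of its training examples.
tree = DecisionTreeRegressor(max_depth=3)
tree.fit(X, y)
print(tree.predict([[1.5], [4.5]]))
```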
25.3 Preventing overfitting
Again from the pseudocode of the ID3 algorithm, we see that the basic rules for building a tree will keep making splits until we have perfect classification, or until no more features are available. Perfect classification surely sounds good, but is in practice rarely attainable, and these rules will typically create too many splits and thereby overfit to the training data.
Pruning and hyperparameter choices
The common approach to avoid this is in fact to just let it happen – and then afterwards, go in and remove the branches that give the least improvement in prediction. Sticking with the biologically inspired jargon, we call this pruning. We will not go into the details of this, but leave it as an illustration of how well-defined machine learning methods often need improvised heuristics to work. Other tweaks that are used in parallel include
- requiring a minimum number of training events in each end node, and
- enforcing a maximum depth, i.e. allowing only a certain number of subsequent splits; both are illustrated in the sketch below.
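In scikit-learn these tweaks appear as hyperparameters of the tree classes. The sketch below uses max_depth, min_samples_leaf and the cost-complexity pruning parameter ccp_alpha, with values chosen arbitrarily for illustration – in practice one would tune them:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Restrict tree growth up front (maximum depth, minimum examples per leaf),
# and apply cost-complexity pruning (ccp_alpha) after the tree has been grown.
clf = DecisionTreeClassifier(max_depth=4, min_samples_leaf=10, ccp_alpha=0.01,
                             random_state=0)
clf.fit(X_train, y_train)
print("train accuracy:", clf.score(X_train, y_train))
print("test accuracy: ", clf.score(X_test, y_test))
```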
Random forests
Training several decision trees on similar data tends to end up looking like figure 24.3 (c): they are sensitive to small variations in the data and hence have high variance. But while each individual tree can be far off the target, the average is still good, since the variations often cancel out. This can be the case for many types of machine learning methods. We can improve performance by defining a new ensemble model \(f(\mathbf{x})\) composed of several separate models \(f_m(\mathbf{x})\),
\[ f(\mathbf{x}) = \sum_{m=1}^M \frac{1}{M} f_m(\mathbf{x}) \,, \]
where each \(f_m\) is trained on a randomly selected subset of the total data. The \(f_m\) models will still be highly correlated, since their training subsets overlap, so one can in addition train them on randomly selected subsets of the features, in order to reduce this correlation. In the case of decision trees, the resulting ensemble is (obviously) called a random forest.
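Random forests are available off the shelf, for example in scikit-learn. A small sketch, where the dataset and settings are just for illustration:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

X, y = load_breast_cancer(return_X_y=True)

# Each of the 200 trees is trained on a bootstrap sample of the data, and each
# split considers only a random subset of the features (here sqrt(n_features)),
# which reduces the correlation between the trees whose predictions are averaged.
forest = RandomForestClassifier(n_estimators=200, max_features="sqrt", random_state=0)
print(cross_val_score(forest, X, y, cv=5).mean())
```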
Boosting
A final trick for decision trees is boosting. This is an important method that all high-performing tree-ensemble implementations use, but it also relies on a fair amount of tedious math, so we will mostly gloss over it. The point is that we can create the ensemble iteratively by adding new trees one at a time, where each new addition tries to improve on the prediction made by the existing trees.
Starting with a single tree \(f_1\), its predictions might be good, but not perfect. So when evaluating it on the training data \(\mathbf{x}\), there will be a difference between the targets \(\mathbf{y}\) and the predictions \(f_1(\mathbf{x})\), and this difference we call the residuals \(\mathbf{r}_1\):
\[ \mathbf{r}_1 = \mathbf{y} - f_1(\mathbf{x}) \]
or if we re-write:
\[ \mathbf{y} = f_1(\mathbf{x}) + \mathbf{r}_1 \]
We want the residuals \(\mathbf{r}_1\) to be as small as possible, but with a single tree, there is only so much we can do. The magic is to train a second tree \(f_2\), not to predict \(\mathbf{y}\) again, but to predict the residuals \(\mathbf{r}_1\). Then we get
\[ \mathbf{y} = f_1(\mathbf{x}) + f_2(\mathbf{x}) + \mathbf{r}_2 \]
with \(\mathbf{r}_2\) (hopefully) smaller than \(\mathbf{r}_1\), resulting in an improved prediction. This stagewise additive modelling can be repeated until we see no further improvement, resulting in an ensemble model that typically outperforms a standard random forest. Several variants of the boosting algorithm exist, some of which are discussed in chapter 16.4 of the Murphy book. There, the explanation of the different boosting variants is based on arguments about loss functions. We have not started looking into loss functions yet, but this will be a topic in the next chapter. It can be useful to come back to the Murphy book after having gone through the next chapter’s material.
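Before leaving boosting, the stagewise idea itself is simple enough to sketch directly, here with shallow scikit-learn regression trees fitted to successive residuals on made-up data (no extra tricks, just the plain additive scheme described above):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Invented regression data: a noisy sine curve.
rng = np.random.default_rng(1)
X = np.sort(rng.uniform(0, 6, size=100)).reshape(-1, 1)
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=100)

trees, residual = [], y.copy()
print("mean |residual| before boosting:", np.abs(residual).mean())
for m in range(10):
    tree = DecisionTreeRegressor(max_depth=2)   # a shallow "weak learner"
    tree.fit(X, residual)                       # fit the current residuals, not y itself
    residual = residual - tree.predict(X)       # r_m = r_{m-1} - f_m(x)
    trees.append(tree)
print("mean |residual| after 10 trees:  ", np.abs(residual).mean())

def ensemble_predict(X_new):
    """The boosted ensemble's prediction is the sum of all stagewise trees."""
    return sum(tree.predict(X_new) for tree in trees)
```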
25.4 Software frameworks
We are about to start the programming exercises, so let us briefly discuss our options for the implementation. While coding a decision tree algorithm from scratch is not too technically challenging, it is an ineffective way to learn about their behaviour in practice, and we also do not have the time to do so. Hence we will use ready-made frameworks.
For decent-performing and easy-to-use frameworks, we suggest scikit-learn for Python and tidymodels for R. There are, however, plenty of alternatives, and you are free to choose whichever you like.
For state-of-the-art boosted decision trees, we suggest XGBoost, which regularly ranks among the best tree-based algorithms in the machine learning competitions at Kaggle, and which is available for both Python and R.
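As a rough sketch of what using XGBoost looks like from Python, via its scikit-learn-style interface (the data and parameter values are made up for illustration):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor   # use XGBClassifier for classification problems

# Invented regression data, as before.
rng = np.random.default_rng(2)
X = rng.uniform(0, 6, size=(500, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=500)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Boosted trees with a few of the usual knobs: the number of trees, their depth,
# and the learning rate controlling how much each new tree contributes.
model = XGBRegressor(n_estimators=300, max_depth=3, learning_rate=0.1)
model.fit(X_train, y_train)
pred = model.predict(X_test)
print("test MSE:", np.mean((pred - y_test) ** 2))
```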