Decision trees are one of the most intuitive machine learning algorithms, used for both classification and regression. After reading, you’ll know how to implement a decision tree classifier entirely from scratch.

This is the fifth of many upcoming from-scratch articles, so stay tuned to the blog if you want to learn more. The links to the previous articles are located at the end of this piece.

The article is structured as follows:

- Introduction to Decision Trees
- Math Behind Decision Trees
- Recursion Crash Course
- From-Scratch Implementation
- Model Evaluation
- Comparison with Scikit-Learn
- Conclusion

You can download the corresponding notebook here.

## Introduction to Decision Trees

Decision trees are a non-parametric model used for both regression and classification tasks. The from-scratch implementation will take you some time to fully understand, but the intuition behind the algorithm is quite simple.

Decision trees are constructed from only two elements – nodes and branches. We’ll discuss different types of nodes in a bit. If you decide to follow along, the term **recursion** shouldn’t feel like a foreign language, as the algorithm is based on this concept. You’ll get a crash course in recursion in a couple of minutes, so don’t sweat it if you’re a bit rusty on the topic.

Let’s take a look at an example decision tree first:

As you can see, there are multiple types of nodes:

- **Root node** – the node at the top of the tree. It contains the feature that best splits the data (a single feature that, on its own, classifies the target variable most accurately)
- **Decision nodes** – nodes where the variables are evaluated. These nodes have arrows pointing both to them and away from them
- **Leaf nodes** – final nodes at which the prediction is made

Depending on the dataset size (both in rows and columns), there can be thousands or even millions of ways to arrange the nodes and their conditions. **So, how do we determine the root node?**

### How to determine the root node

In a nutshell, we need to check how every input feature classifies the target variable independently. If none of the features alone is 100% correct in the classification, we can consider these features **impure**.

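To decide which of the impure features is the least impure, we can use the **Entropy** metric. We’ll discuss the formula and the calculations later, but for now remember that, for a two-class target, entropy ranges from 0 (best, a pure node) to 1 (worst, an even 50:50 mix). With more classes, the maximum grows to log2 of the number of classes (about 1.58 for the three Iris classes used later).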

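The feature whose split leaves the lowest weighted entropy in the resulting child nodes (equivalently, the highest information gain, discussed next) is then used as the root node.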

### Training process

To begin training the decision tree classifier, we have to determine the root node. That part has already been discussed.

Then, for every candidate split, the **Information gain** metric is calculated. Put simply, it measures how much a split reduces entropy: the parent node’s entropy minus the weighted average of the child nodes’ entropies. We’ll discuss the formula and calculations later, but please remember that the higher the gain is, the better the decision split is.

The algorithm then performs a greedy search: it goes over all input features and their unique values, calculates the information gain for every combination, and saves the best split feature and threshold for every node.

In this way, the tree is built recursively. The recursion process could go on forever, so we’ll have to specify some exit conditions manually. The most common ones are maximum depth and minimum samples at the node. Both will be discussed later upon implementation.

### Prediction process

Once the tree is built, we can make predictions for unseen data by recursively traversing the tree. At each node, the traversal direction (left or right) is chosen by comparing the input’s value for the node’s feature with the learned threshold.

Once a leaf node is reached, its stored value (the most common class among the training samples that ended up in that leaf) is returned.

And that’s it for the basic theory and intuition behind decision trees. Let’s talk about the math behind the algorithm in the next section.

## Math Behind Decision Trees

Decision trees represent much more of a coding challenge than a mathematical one. You’ll only have to implement two formulas for the learning part – entropy and information gain.

Let’s start with **entropy**. As mentioned earlier, it measures the impurity of a node. For a two-class target, its value ranges from 0 (pure) to 1 (maximally impure).

Here’s the formula for entropy:
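The formula image is missing from this copy. The standard Shannon entropy definition is shown below, assuming base-2 logarithms (which is what gives the 0 to 1 range for two classes). Here $s$ is the vector of class labels, $c$ is the number of classes, and $p_i$ is the proportion of samples in $s$ that belong to class $i$:

$$E(s) = -\sum_{i=1}^{c} p_i \log_2 p_i$$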

As you can see, it’s a relatively simple equation, so let’s see it in action. Imagine you want to calculate the purity of the following vector:

To summarize, zeros and ones are the class labels with the following counts:

The entropy calculation is as simple as it can be from this point (rounded to five decimal places):

The result of 0.88 indicates the split is nowhere near pure. Let’s repeat the calculation in Python next. The following code implements the `entropy(s)` formula and calculates it on the same vector:
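The original snippet and the vector it was run on aren’t preserved here. Below is a minimal sketch of one way to write `entropy(s)` with NumPy, run on a hypothetical vector with seven zeros and three ones. Any 70:30 class split gives the same entropy of roughly 0.88:

```python
import numpy as np


def entropy(s):
    """Shannon entropy (base 2) of a vector of class labels."""
    # Count how often each distinct class label appears
    _, counts = np.unique(np.asarray(s), return_counts=True)
    probs = counts / counts.sum()
    # Only labels that actually occur are counted, so log2(0) never happens
    return float(-np.sum(probs * np.log2(probs)))


# Hypothetical vector: 7 zeros and 3 ones (a 70:30 split)
s = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
print(round(entropy(s), 5))  # 0.88129
```

Counting labels with `np.unique(..., return_counts=True)` means the function works for any number of classes, not just zeros and ones.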

The results are shown in the following image:

As you can see, the results are identical, indicating the formula was implemented correctly.

Let’s take a look at the **information gain** next. It measures how much a split reduces entropy: the parent node’s entropy minus the size-weighted average entropy of the child nodes. The higher the information gain value, the better the decision split is.

Information gain can be calculated with the following formula:
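The formula image didn’t survive extraction either. The usual definition for a binary split is below, where $E(\cdot)$ is the entropy of a label vector, $P$ is the parent node’s label vector, $L$ and $R$ are the left and right children’s label vectors, and $|\cdot|$ counts samples:

$$IG(P, L, R) = E(P) - \left( \frac{|L|}{|P|}\,E(L) + \frac{|R|}{|P|}\,E(R) \right)$$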

Let’s take a look at an example split and calculate the information gain:

As you can see, the entropy values were calculated beforehand, so we don’t have to waste time on them. Calculating information gain is now a trivial process:

Let’s implement it in Python next. The following code snippet implements the `information_gain()` function and calculates it for the previously discussed split:
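As with entropy, the original split isn’t reproduced here. The sketch below implements `information_gain()` as the parent’s entropy minus the size-weighted entropies of the two children, and applies it to a hypothetical 10-sample split. The `entropy()` helper is repeated so the block runs on its own:

```python
import numpy as np


def entropy(s):
    """Shannon entropy (base 2) of a vector of class labels."""
    _, counts = np.unique(np.asarray(s), return_counts=True)
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log2(probs)))


def information_gain(parent, left_child, right_child):
    """Parent entropy minus the size-weighted entropy of both children."""
    weight_left = len(left_child) / len(parent)
    weight_right = len(right_child) / len(parent)
    return entropy(parent) - (
        weight_left * entropy(left_child) + weight_right * entropy(right_child)
    )


# Hypothetical split: 10 samples (7 zeros, 3 ones) divided into 6 + 4
parent = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
left_child = [0, 0, 0, 0, 0, 0]  # pure, entropy 0
right_child = [0, 1, 1, 1]       # entropy about 0.81128
print(round(information_gain(parent, left_child, right_child), 5))  # 0.55678
```

A split that separates the classes perfectly gets an information gain equal to the parent’s entropy, since both children then have zero entropy.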

The results are shown in the following image:

As you can see, the values match.

And that’s all there is to the math behind decision trees. I’ll repeat – this algorithm is much more challenging to implement in code than to understand mathematically. That’s why you’ll need an additional primer on recursion – coming up next.

## Recursion Crash Course

Much of the decision tree implementation boils down to recursion. This section provides only a sneak peek at recursive functions and is by no means a complete guide to the topic. If the term is new to you, it’s worth studying it in more depth before tackling decision trees.

Put simply, a recursive function is a function that calls itself. We don’t want this process going on indefinitely, so the function needs an exit condition, which is usually written at the top of the function.

Let’s take a look at the simplest example possible – a recursive function that returns a factorial of an integer:
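The original code isn’t shown in this copy; a minimal version might look like this. The exit condition `x <= 1` is an assumption that also covers `0! = 1`:

```python
def factorial(x):
    # Exit condition: stop recursing once we reach 1 (or 0)
    if x <= 1:
        return 1
    # Recursive case: n! = n * (n - 1)!
    return x * factorial(x - 1)


print(factorial(5))  # 120
```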

The results are shown in the following image:

As you can see, the function keeps calling itself until the number reaches 1. That’s the exit condition of our function.

Recursion is needed in decision tree classifiers to build additional nodes until some exit condition is met. That’s why it’s crucial to understand this concept.

Up next, we’ll implement the classifier. It will require around 200 lines of code (minus the docstrings and comments), so brace yourself.

## From-Scratch Implementation

We’ll need two classes:

- `Node` – implements a single node of a decision tree
- `DecisionTree` – implements the algorithm

Let’s start with the `Node` class. It is here to store the data about the feature, threshold, data going left and right, information gain, and the leaf node value. All are initially set to `None`. The root and decision nodes will contain values for everything besides the leaf node value, and the leaf node will contain the opposite.

Here’s the code for the class:
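The original listing isn’t reproduced here, so below is a minimal sketch of such a class. The attribute names (`feature`, `threshold`, `data_left`, `data_right`, `gain`, `value`) are assumptions chosen to match the description above:

```python
class Node:
    """A single node of a decision tree.

    Root and decision nodes use feature, threshold, data_left,
    data_right and gain; leaf nodes only use value.
    """

    def __init__(self, feature=None, threshold=None, data_left=None,
                 data_right=None, gain=None, value=None):
        self.feature = feature        # index of the column used to split
        self.threshold = threshold    # samples with feature <= threshold go left
        self.data_left = data_left    # left child Node
        self.data_right = data_right  # right child Node
        self.gain = gain              # information gain of this split
        self.value = value            # predicted class (leaf nodes only)
```

A leaf would be created as `Node(value=1)`, leaving every split-related attribute at `None`.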

That was the easy part. Let’s implement the classifier next. It will contain a bunch of methods, all of which are discussed below:

- `__init__()` – the constructor, holds values for `min_samples_split` and `max_depth`. These are hyperparameters: the first one specifies the minimum number of samples required to split a node, and the second one specifies the maximum depth of the tree. Both are used as exit conditions in the recursive function
- `_entropy(s)` – calculates the impurity of an input vector `s`
- `_information_gain(parent, left_child, right_child)` – calculates the information gain of a split between a parent and two children
- `_best_split(X, y)` – calculates the best splitting parameters for input features `X` and a target variable `y`. It does so by iterating over every column in `X` and every threshold value in every column to find the optimal split using information gain
- `_build(X, y, depth)` – recursively builds a decision tree until the stopping criteria (the hyperparameters set in the constructor) are met
- `fit(X, y)` – calls the `_build()` function and stores the built tree in an instance attribute
- `_predict(x)` – traverses the tree to classify a single instance
- `predict(X)` – applies the `_predict()` function to every instance in matrix `X`

It’s a lot – no arguing there. Take your time to understand every line from the code snippet below. It is well-documented, so the comments should help a bit:
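The roughly 200-line original isn’t preserved in this copy. What follows is a condensed sketch under a few assumptions: the defaults `min_samples_split=2` and `max_depth=5` are illustrative, depth is counted from 0 at the root, samples go left when their feature value is `<=` the threshold, the fitted tree is stored in a `root` attribute, and a compact `Node` class is included so the block runs on its own. The toy data at the bottom is hypothetical:

```python
from collections import Counter

import numpy as np


class Node:
    # Compact copy of the Node class so this block runs on its own
    def __init__(self, feature=None, threshold=None, data_left=None,
                 data_right=None, gain=None, value=None):
        self.feature, self.threshold = feature, threshold
        self.data_left, self.data_right = data_left, data_right
        self.gain, self.value = gain, value


class DecisionTree:
    def __init__(self, min_samples_split=2, max_depth=5):
        # Hyperparameters used as exit conditions in _build()
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.root = None  # set by fit()

    @staticmethod
    def _entropy(s):
        _, counts = np.unique(s, return_counts=True)
        probs = counts / counts.sum()
        return -np.sum(probs * np.log2(probs))

    def _information_gain(self, parent, left_child, right_child):
        w_left = len(left_child) / len(parent)
        w_right = len(right_child) / len(parent)
        return self._entropy(parent) - (
            w_left * self._entropy(left_child) + w_right * self._entropy(right_child)
        )

    def _best_split(self, X, y):
        best = {"gain": -1.0}
        # Greedy search over every column and every unique value in it
        for f_idx in range(X.shape[1]):
            for threshold in np.unique(X[:, f_idx]):
                mask = X[:, f_idx] <= threshold
                # Skip splits that would leave one side empty
                if mask.all() or not mask.any():
                    continue
                gain = self._information_gain(y, y[mask], y[~mask])
                if gain > best["gain"]:
                    best = {"feature": f_idx, "threshold": threshold,
                            "gain": gain, "mask": mask}
        return best

    def _build(self, X, y, depth=0):
        # Exit conditions: too few samples or maximum depth reached
        if len(y) >= self.min_samples_split and depth < self.max_depth:
            best = self._best_split(X, y)
            # Only split when it actually lowers impurity
            if best["gain"] > 0:
                mask = best["mask"]
                left = self._build(X[mask], y[mask], depth + 1)
                right = self._build(X[~mask], y[~mask], depth + 1)
                return Node(best["feature"], best["threshold"], left, right, best["gain"])
        # Leaf node: most common class among the samples that reached it
        return Node(value=Counter(y.tolist()).most_common(1)[0][0])

    def fit(self, X, y):
        self.root = self._build(np.asarray(X), np.asarray(y))
        return self

    def _predict(self, x, tree=None):
        tree = self.root if tree is None else tree
        # A leaf stores a class label; decision nodes store None
        if tree.value is not None:
            return tree.value
        # Go left when the feature value is at or below the learned threshold
        if x[tree.feature] <= tree.threshold:
            return self._predict(x, tree.data_left)
        return self._predict(x, tree.data_right)

    def predict(self, X):
        return np.array([self._predict(x) for x in np.asarray(X)])


# Hypothetical toy data: two features, two classes
X_toy = np.array([[1.0, 5.0], [2.0, 4.0], [3.0, 3.0],
                  [6.0, 2.0], [7.0, 1.0], [8.0, 0.5]])
y_toy = np.array([0, 0, 0, 1, 1, 1])
clf = DecisionTree().fit(X_toy, y_toy)
print(clf.predict([[2.5, 3.5], [7.5, 0.8]]))  # [0 1]
```

Splitting only when the best gain is strictly positive is a design choice of this sketch: it turns pure nodes into leaves right away instead of relying solely on the two hyperparameters to stop the recursion.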

You’re not expected to understand every line of code in one sitting. Give it time, go over the code line by line and try to reason why things work. It’s not that difficult once you understand the basic intuition behind the algorithm.

## Model Evaluation

Let’s test our classifier next. We’ll use the Iris dataset from Scikit-Learn. The following code snippet loads the dataset and separates it into features (`X`) and the target (`y`):
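A minimal sketch of that step, assuming Scikit-Learn’s `load_iris()` loader (the original snippet isn’t shown):

```python
from sklearn.datasets import load_iris

iris = load_iris()
X = iris["data"]    # 150 x 4 feature matrix
y = iris["target"]  # class labels 0, 1 and 2
print(X.shape, y.shape)  # (150, 4) (150,)
```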

Let’s split the dataset into training and testing portions next. The following code snippet does just that, in an 80:20 ratio:
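One way to do the 80:20 split is Scikit-Learn’s `train_test_split()`. The `random_state=42` value is an arbitrary assumption, so the exact test set, and therefore the scores later on, may differ from the article’s:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
# test_size=0.2 keeps 20% of the rows for testing
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print(X_train.shape, X_test.shape)  # (120, 4) (30, 4)
```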

And now let’s do the training. The code snippet below trains the model with default hyperparameters and makes predictions on the test set:

Let’s take a look at the generated predictions (`preds`):

And now at the actual class labels (`y_test`):

As you can see, both are identical, meaning the classifier got every test sample right. You can evaluate the performance further if you want. The code below prints the accuracy score on the test set:
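Scikit-Learn’s `accuracy_score(y_test, preds)` is the usual way to print this score. The sketch below spells out the same calculation in plain NumPy on hypothetical labels, so it runs without the rest of the pipeline:

```python
import numpy as np


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))


# Hypothetical labels, just to show the calculation
print(accuracy([0, 1, 2, 2], [0, 1, 2, 1]))  # 0.75
```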

As expected, a value of `1.0` gets printed. Don’t let this fool you – the Iris dataset is incredibly easy to classify correctly, especially if you get a good “random” test set. Still, let’s compare our classifier to the one built into Scikit-Learn.

## Comparison with Scikit-Learn

We want to know if our model is any good, so let’s compare it with something we know works well: the `DecisionTreeClassifier` class from Scikit-Learn.

You can use the following snippet to import the model class, train the model, make predictions, and print the accuracy score:
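The original snippet isn’t preserved; here’s a sketch under the same assumed 80:20 split with `random_state=42`. Passing `criterion="entropy"` is an assumption made so the comparison uses the same impurity measure as the from-scratch model (Scikit-Learn’s default is `"gini"`):

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Entropy-based splits, to mirror the information gain approach above
sk_model = DecisionTreeClassifier(criterion="entropy", random_state=42)
sk_model.fit(X_train, y_train)
sk_preds = sk_model.predict(X_test)
print(accuracy_score(y_test, sk_preds))
```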

As you would expect, we get a perfect accuracy score of `1.0`.

And that’s all for today. Let’s wrap things up in the next section.

## Conclusion

This was one of the most challenging articles I have ever written. It took around a week to get everything right and to make the code as understandable as possible. Naturally, it will take you at least a couple of readings to fully understand the topic. Feel free to explore additional resources, as they will further advance your understanding.

You now know how to implement the Decision tree classifier algorithm from scratch. *Does that mean you should ditch the de facto standard machine learning libraries?* No, not at all. Let me elaborate.

Just because you can write something from scratch doesn’t mean you should. Still, knowing every detail of how algorithms work is a valuable skill and can help you stand out from every other *fit and predict* data scientist.

Thanks for reading, and please stay tuned to the blog if you’re interested in more from-scratch machine learning articles.

## Learn More

- Master Machine Learning: Simple Linear Regression From Scratch With Python
- Master Machine Learning: Multiple Linear Regression From Scratch With Python
- Master Machine Learning: Logistic Regression From Scratch With Python
- Master Machine Learning: K Nearest Neighbors From Scratch With Python
