We already know a single decision tree can work surprisingly well. The idea of constructing a forest from individual trees seems like the natural next step.
Today you’ll learn how the Random Forest classifier works and implement it from scratch in Python. This is the sixth of many upcoming from-scratch articles, so stay tuned to the blog if you want to learn more. The links to the previous articles are located at the end of this piece.
The article is structured as follows:
- Introduction to Random Forest
- Math Behind Random Forest
- From-Scratch Implementation
- Model Evaluation
- Comparison with Scikit-Learn
You can download the corresponding notebook here.
Introduction to Random Forest
Just like decision trees, random forests are a non-parametric model used for both regression and classification tasks. If you understood the previous article on decision trees, you’ll have no issues understanding this one.
That article is a prerequisite for this one, since a random forest is nothing more than a collection of decision trees.
The entire random forest algorithm is built on top of weak learners (decision trees), giving you the analogy of using trees to make a forest. The term “random” indicates that each decision tree is built with a random subset of data.
Here’s an excellent image comparing decision trees and random forests:
As simple as that.
The random forest algorithm is based on bagging (bootstrap aggregating), a method that combines multiple learning models to improve performance (higher accuracy or some other metric).
In a nutshell:
- N subsets are made from the original dataset
- N decision trees are built from the subsets
- A prediction is made with every trained tree, and a final prediction is returned as a majority vote
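To make these steps concrete, here's a small NumPy sketch. The helper names bootstrap_indices and majority_vote are hypothetical, and the three rows of hard-coded predictions stand in for trained trees. Each "model" gets one of five labels wrong, yet the majority vote gets all five right, which is the whole point of combining models.

```python
from collections import Counter

import numpy as np


def bootstrap_indices(n_rows, n_subsets, seed=0):
    """Return n_subsets arrays of row indices, each drawn with replacement."""
    rng = np.random.default_rng(seed)
    return [rng.choice(n_rows, size=n_rows, replace=True) for _ in range(n_subsets)]


def majority_vote(predictions):
    """predictions: one row per model, one column per sample."""
    predictions = np.asarray(predictions)
    # Most common label in each column wins
    return np.array([Counter(col.tolist()).most_common(1)[0][0] for col in predictions.T])


y_true = np.array([0, 1, 1, 0, 1])
preds = np.array([
    [0, 1, 1, 0, 0],  # model A: wrong on the last sample
    [0, 0, 1, 0, 1],  # model B: wrong on the second sample
    [1, 1, 1, 0, 1],  # model C: wrong on the first sample
])

print([float(np.mean(p == y_true)) for p in preds])  # [0.8, 0.8, 0.8]
print(majority_vote(preds))                          # [0 1 1 0 1]
print(bootstrap_indices(5, 2))                       # two resampled index arrays
```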
Here’s a diagram to drive these points home:
Let’s go over the math behind the algorithm next.
Math Behind Random Forest
Good news – no math today!
The math behind random forests is the same as with decision trees. You only need to implement two formulas – entropy and information gain.
If these sound like a foreign language, please refer to the previous article. Both concepts are discussed in great detail there.
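For quick reference, here's a minimal sketch of both formulas, assuming class labels stored in a NumPy array. Entropy is computed in base 2 as the negative sum of p * log2(p) over the class proportions, and information gain is the parent's entropy minus the size-weighted entropy of the two children. The function names are illustrative; they mirror the _entropy() and _information_gain() methods of the decision tree.

```python
import numpy as np


def entropy(s):
    """Shannon entropy (base 2) of a vector of class labels."""
    _, counts = np.unique(s, return_counts=True)
    p = counts / counts.sum()  # class proportions
    return float(-np.sum(p * np.log2(p)))


def information_gain(parent, left_child, right_child):
    """Parent entropy minus the size-weighted entropy of both children."""
    w_left = len(left_child) / len(parent)
    w_right = len(right_child) / len(parent)
    return entropy(parent) - (w_left * entropy(left_child) + w_right * entropy(right_child))


parent = np.array([0, 0, 1, 1])
print(entropy(parent))                                   # 1.0
print(information_gain(parent, parent[:2], parent[2:]))  # 1.0 (a perfect split)
```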
The rest of the article assumes you’re familiar with the inner workings of decision trees, as it is required to build the algorithm from scratch.
From-Scratch Implementation
We’ll need three classes this time:
- Node – implements a single node of a decision tree
- DecisionTree – implements a single decision tree
- RandomForest – implements our ensemble algorithm
The first two classes are identical to those from the previous article, so feel free to skip ahead if you already have them written.
Let’s start with the Node class. It stores the data about the feature, threshold, data going left and right, information gain, and the leaf node value. All are initially set to None. The root and decision nodes contain values for everything except the leaf node value, while leaf nodes contain only the leaf node value.
Here’s the code for the class (alongside the library imports):
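The original snippet isn't reproduced here, so below is a minimal sketch of such a class. The attribute names (feature, threshold, data_left, data_right, gain, value) are assumptions chosen to match the description above. The library imports (NumPy and collections.Counter) are only needed by the tree and forest classes, so they are left out of this block.

```python
class Node:
    """A single tree node: decision nodes store the split, leaf nodes store a value."""

    def __init__(self, feature=None, threshold=None, data_left=None,
                 data_right=None, gain=None, value=None):
        self.feature = feature        # index of the feature used for the split
        self.threshold = threshold    # split threshold for that feature
        self.data_left = data_left    # child Node for rows with feature <= threshold
        self.data_right = data_right  # child Node for rows with feature > threshold
        self.gain = gain              # information gain of the split
        self.value = value            # predicted class (leaf nodes only)


# A decision node whose two children are leaves
root = Node(feature=0, threshold=2.5, data_left=Node(value=0),
            data_right=Node(value=1), gain=0.9)
```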
Let’s implement the decision tree classifier next. It will contain a bunch of methods, all of which are discussed below:
- __init__() – the constructor, holds the values of two hyperparameters: the minimum number of samples required to split a node, and max_depth, the maximum depth of a tree. Both serve as exit conditions in the recursive functions
- _entropy(s) – calculates the impurity of an input vector s
- _information_gain(parent, left_child, right_child) – calculates the information gain of a split between a parent and two children
- _best_split(X, y) – calculates the best splitting parameters for input features X and a target variable y. It does so by iterating over every column in X and every threshold value in every column to find the optimal split using information gain
- _build(X, y, depth) – recursively builds a decision tree until the stopping criteria (the hyperparameters from the constructor) are met
- fit(X, y) – calls the _build() function and stores the built tree as an instance attribute
- _predict(x) – traverses the tree to classify a single instance
- predict(X) – applies the _predict() function to every instance (row) in the input matrix
Yes, it’s a lot, but you should already feel comfortable with this. Here’s the code snippet for a single decision tree:
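Since the original snippet isn't reproduced here, below is a minimal sketch of a class with the methods listed above, not the author's exact code. It repeats a compact Node class so it runs on its own, considers every unique value in a column as a candidate threshold (rows with feature <= threshold go left), and stores the majority class as a leaf's value. The defaults min_samples_split=2 and max_depth=5 are assumptions.

```python
from collections import Counter

import numpy as np


class Node:
    # Compact copy of the Node class so this block runs on its own
    def __init__(self, feature=None, threshold=None, data_left=None,
                 data_right=None, gain=None, value=None):
        self.feature, self.threshold = feature, threshold
        self.data_left, self.data_right = data_left, data_right
        self.gain, self.value = gain, value


class DecisionTree:
    def __init__(self, min_samples_split=2, max_depth=5):
        self.min_samples_split = min_samples_split  # fewer rows than this -> leaf
        self.max_depth = max_depth                  # reaching this depth -> leaf
        self.root = None

    @staticmethod
    def _entropy(s):
        _, counts = np.unique(s, return_counts=True)
        p = counts / counts.sum()
        return -np.sum(p * np.log2(p))

    def _information_gain(self, parent, left_child, right_child):
        w_left = len(left_child) / len(parent)
        w_right = len(right_child) / len(parent)
        return self._entropy(parent) - (w_left * self._entropy(left_child)
                                        + w_right * self._entropy(right_child))

    def _best_split(self, X, y):
        best = {"gain": 0.0}
        for f_idx in range(X.shape[1]):               # every column
            for threshold in np.unique(X[:, f_idx]):  # every distinct value
                left = X[:, f_idx] <= threshold
                if left.all() or not left.any():      # skip splits with an empty child
                    continue
                gain = self._information_gain(y, y[left], y[~left])
                if gain > best["gain"]:
                    best = {"feature": f_idx, "threshold": threshold,
                            "gain": gain, "left": left}
        return best

    def _build(self, X, y, depth=0):
        if len(y) >= self.min_samples_split and depth < self.max_depth:
            best = self._best_split(X, y)
            if best["gain"] > 0:  # a useful split exists
                left = best["left"]
                return Node(feature=best["feature"], threshold=best["threshold"],
                            data_left=self._build(X[left], y[left], depth + 1),
                            data_right=self._build(X[~left], y[~left], depth + 1),
                            gain=best["gain"])
        # Stopping criteria met: the leaf holds the most common class
        return Node(value=Counter(y.tolist()).most_common(1)[0][0])

    def fit(self, X, y):
        self.root = self._build(np.asarray(X), np.asarray(y))
        return self

    def _predict(self, x, node):
        if node.value is not None:  # reached a leaf
            return node.value
        branch = node.data_left if x[node.feature] <= node.threshold else node.data_right
        return self._predict(x, branch)

    def predict(self, X):
        return np.array([self._predict(x, self.root) for x in np.asarray(X)])


X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([0, 0, 1, 1])
tree = DecisionTree().fit(X, y)
print(tree.predict([[1.5], [3.5]]))  # [0 1]
```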
Finally, let’s build the forest. This class is built on top of a single decision tree and has the following methods:
- __init__() – the constructor, holds hyperparameter values for the number of trees in the forest, minimum samples split, and maximum depth. It also holds the individually trained decision trees once the model is trained
- _sample(X, y) – applies bootstrap sampling (drawing rows with replacement) to the input features and target
- fit(X, y) – trains the classifier model
- predict(X) – makes predictions with each individual decision tree and then applies majority voting for the final prediction
Code-wise it’s a much simpler class than a decision tree. Here’s the entire snippet:
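Here's a minimal sketch of the RandomForest class described above. To keep the block short and runnable on its own, it uses Scikit-Learn's DecisionTreeClassifier (with criterion="entropy") as a stand-in for the from-scratch DecisionTree; if you have that class written, return DecisionTree(min_samples_split=..., max_depth=...) from _make_tree instead. The default of 25 trees and the random_state argument are assumptions for illustration.

```python
from collections import Counter

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


class RandomForest:
    def __init__(self, num_trees=25, min_samples_split=2, max_depth=5, random_state=None):
        self.num_trees = num_trees                  # number of trees in the forest
        self.min_samples_split = min_samples_split  # passed to every tree
        self.max_depth = max_depth                  # passed to every tree
        self.rng = np.random.default_rng(random_state)
        self.decision_trees = []                    # filled by fit()

    def _make_tree(self):
        # Stand-in for the from-scratch DecisionTree: Scikit-Learn's tree
        # with entropy-based splits
        return DecisionTreeClassifier(criterion="entropy",
                                      min_samples_split=self.min_samples_split,
                                      max_depth=self.max_depth)

    def _sample(self, X, y):
        # Bootstrap sample: draw as many rows as X has, with replacement
        idx = self.rng.choice(X.shape[0], size=X.shape[0], replace=True)
        return X[idx], y[idx]

    def fit(self, X, y):
        X, y = np.asarray(X), np.asarray(y)
        self.decision_trees = []
        for _ in range(self.num_trees):
            X_sample, y_sample = self._sample(X, y)
            tree = self._make_tree()
            tree.fit(X_sample, y_sample)
            self.decision_trees.append(tree)
        return self

    def predict(self, X):
        X = np.asarray(X)
        # One row per tree, one column per instance
        tree_preds = np.array([tree.predict(X) for tree in self.decision_trees])
        # Majority vote down each column
        return np.array([Counter(col.tolist()).most_common(1)[0][0] for col in tree_preds.T])


X, y = load_iris(return_X_y=True)
forest = RandomForest(num_trees=10, random_state=0).fit(X, y)
print(forest.predict(X[:3]))
```

Each tree sees a different bootstrap sample, so the trees disagree on some instances, and the majority vote in predict() smooths those disagreements out.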
You might not understand everything fully in one sitting, but this won’t be too much of a challenge if you understood decision trees.
Let’s train and evaluate our model next.
Model Evaluation
We’ll use the Iris dataset from Scikit-Learn. The following code snippet loads the dataset and separates it into features (X) and the target:
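A minimal version of that step might look like this, using Scikit-Learn's load_iris(); the variable names iris, X, and y are assumptions.

```python
from sklearn.datasets import load_iris

iris = load_iris()
X = iris["data"]    # feature matrix: 150 flowers x 4 measurements
y = iris["target"]  # class labels 0, 1, 2
print(X.shape, y.shape)  # (150, 4) (150,)
```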
Let’s split the dataset into training and testing portions next. The following code snippet does just that, in an 80:20 ratio:
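Here's a sketch of the 80:20 split with Scikit-Learn's train_test_split(). The data is reloaded so the block runs on its own, and random_state=42 is an arbitrary seed chosen for reproducibility.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
# test_size=0.2 keeps 20% of the rows for testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(X_train.shape, X_test.shape)  # (120, 4) (30, 4)
```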
And now let’s do the training. The code snippet below trains the model with default hyperparameters and makes predictions on the test set:
Let’s take a look at the generated predictions:
And now at the actual class labels:
As you can see, both are identical, indicating a perfectly accurate classifier. You can further evaluate the performance if you want. The code below prints the accuracy score on the test set:
If you were to run the above code, the value of 1.0 would get printed, meaning every test instance was classified correctly. The Iris dataset is incredibly easy to classify, so don’t let this fool you.
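Accuracy is simply the fraction of predictions that match the true labels. Here's a minimal sketch of that computation with a hypothetical accuracy() helper and made-up label vectors; Scikit-Learn's accuracy_score() from sklearn.metrics returns the same number.

```python
import numpy as np


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))


print(accuracy([0, 1, 2, 2], [0, 1, 2, 2]))  # 1.0
print(accuracy([0, 1, 2, 2], [0, 1, 1, 2]))  # 0.75
```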
Let’s compare our classifier to the one built into Scikit-Learn next.
Comparison with Scikit-Learn
We want to know if our model is any good, so let’s compare it with something we know works well: the RandomForestClassifier class from Scikit-Learn.
You can use the following snippet to import the model class, train the model, make predictions, and print the accuracy score:
As you would expect, we get a perfect accuracy score of 1.0.
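For reference, here's a self-contained sketch of that comparison. It reloads and splits Iris the same 80:20 way and uses default RandomForestClassifier hyperparameters; fixing random_state=42 for both the split and the model is an assumption for reproducibility. The exact score depends on the split, so it isn't hard-coded here.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

model = RandomForestClassifier(random_state=42)  # default hyperparameters, fixed seed
model.fit(X_train, y_train)
preds = model.predict(X_test)
print(accuracy_score(y_test, preds))
```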
And that’s all for today. Let’s wrap things up in the next section.
Conclusion
And there you have it – how to build a forest from the trees. It’s easier than you would think, especially if you consider that random forests are among the top-performing machine learning algorithms today.
You now know how to implement the Random Forest classifier algorithm from scratch. Does that mean you should ditch the de facto standard machine learning libraries? No, not at all. Let me elaborate.
Just because you can write something from scratch doesn’t mean you should. Still, knowing every detail of how algorithms work is a valuable skill and can help you stand out from every other fit-and-predict data scientist.
Thanks for reading, and please stay tuned to the blog if you’re interested in more machine learning from scratch articles.
- Master Machine Learning: Simple Linear Regression From Scratch With Python
- Master Machine Learning: Multiple Linear Regression From Scratch With Python
- Master Machine Learning: Logistic Regression From Scratch With Python
- Master Machine Learning: K Nearest Neighbors From Scratch With Python
- Master Machine Learning: Decision Trees From Scratch With Python