Black-box models are a thing of the past – even with deep learning. You can use SHAP to interpret the predictions of deep learning models, and it requires only a couple of lines of code. Today you’ll learn how on the well-known MNIST dataset.
Convolutional neural networks can be tough to understand. A network learns the optimal feature extractors (kernels) from the image. These features are useful to detect any patterns that help the network to classify images correctly.
Your brain isn’t that much different. It also uses a series of patterns to recognize objects in front of you. For example, what makes a number zero a zero? It’s a round to oval outlined shape with nothing inside. That’s a general pattern the kernels behind convolutional layers try to learn.
If you want to represent your model’s interpretations visually, look no further than SHAP (SHapley Additive exPlanations) – a game-theoretic approach to explaining the output of any machine learning model. You can refer to this article for a complete beginner’s guide.
The article is structured as follows:

- Defining the model architecture
- Training the model
- Interpreting the model
You can download the corresponding Notebook here.
Defining the model architecture
You’ll use PyTorch to train a simple handwritten digit classifier. It’s a go-to Python library for deep learning, both in research and in business. If you haven’t used PyTorch before but have some Python experience, it will feel natural.
Before defining the model architecture, you’ll have to import a couple of libraries. Most of these are related to PyTorch, and `shap` will be used later:
The model architecture is simple and borrowed from the official documentation. Feel free to declare your own architecture, but this one is good enough for our needs:
The following section shows you how to train the model.
Training the model
Let’s start by declaring a couple of variables:
- `batch_size` – how many images are shown to the model at once
- `num_epochs` – number of complete passes through the training dataset
- `device` – specifies whether training runs on the CPU or GPU. Use `cpu` if you don’t have a CUDA-compatible GPU
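As a sketch, the declarations could look like this (the batch size and epoch count below are placeholder values, not necessarily the article’s exact ones):

```python
import torch

batch_size = 64   # how many images are shown to the model at once
num_epochs = 5    # complete passes through the training dataset

# Fall back to the CPU when no CUDA-compatible GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```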
Next, you’ll declare a couple of functions – `train()` and `test()`. These will be used to train and evaluate the model on separate subsets and print the intermediate results.
The entire code snippet is shown below:
Next, you can download the datasets with the `torchvision.datasets` module. The datasets are then loaded, transformed (conversion to tensor and normalization), and organized into batches:
And now you have everything ready for model training. Here’s how to instantiate the model and train it for the previously declared number of epochs:
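A compact, self-contained sketch of that loop – note that the model and data here are simplified stand-ins for the `Net` class and the MNIST loaders described above, so the snippet runs on its own:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Stand-ins for the pieces declared earlier in the article
device = torch.device("cpu")
num_epochs = 2


class Net(nn.Module):
    """Simplified stand-in for the article's CNN architecture."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return F.log_softmax(self.fc(torch.flatten(x, 1)), dim=1)


# Random tensors in place of the real MNIST train_loader
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,))),
    batch_size=16,
)

model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(images), labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch + 1}/{num_epochs} - last batch loss: {loss.item():.4f}")
```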
You’ll see the intermediate results printed out during the training phase. Here’s how they look on my machine:
Keep in mind – the actual values may differ slightly on your machine, but you should land north of 95% accuracy on the test set.
Next step – interpretations with SHAP!
Interpreting the model
Prediction interpretation is now as simple as writing a couple of lines of code. The following snippet loads in a batch of random images from the test set and interprets predictions for five of them:
After executing the above code snippet, you’ll see the following image:
Input images are displayed on the left, and the interpretations for every class on the right. Anything colored in red increases the model output (the model is more confident in the classification), while everything colored blue decreases it.
That’s how SHAP explanations work with convolutional neural networks. Let’s wrap things up in the next section.
Today you’ve learned how to create a basic convolutional neural network model for classifying handwritten digits with PyTorch. You’ve also learned how to explain the predictions made by the model.
Bringing this interpretation skillset to your domain is now as simple as changing the dataset and model architecture. The explanation code should be identical, or require only minimal changes to accommodate a different dataset.
Thanks for reading.
- SHAP: How to Interpret Machine Learning Models with Python
- LIME: How to Interpret Machine Learning Models with Python
- LIME vs. SHAP: Which is Better for Explaining Machine Learning Models?
- 3 Essential Ways to Calculate Feature Importance in Python
- Are The New M1 Macbooks Any Good for Data Science? Let’s Find Out