Issue #13 - Using Surrogate Models for Interpretation
This Week’s Tutorial
Using a surrogate model is a powerful technique for explaining machine learning models to your business stakeholders.
Surrogate models "summarize" complex models to make them explainable.
For example, take the decision tree model from the 3rd tutorial in this series:
The above decision tree model is far too complex for stakeholders (and me 🤣) to interpret directly. However, the model is a good fit for the data:
Estimated average model accuracy is 81.3%.
The model's accuracy is estimated to range from 79.2% to 83.4%.
The model scored 81.4% accuracy on the test dataset.
Complex models like this are common in real-world machine learning, and because they often predict well, you'll want to use them. However, you often can't use them unless your business stakeholders approve.
This is where surrogate models can be helpful:
The complex model's job is to predict the test dataset as accurately as possible.
The surrogate model's job is to summarize the complex model so that you can explain what's going on to your business stakeholders and get approval.
The purpose of a surrogate model isn't to accurately predict the labels on the test dataset - that's the complex model's job.
The surrogate model's job is to summarize, in simpler terms, how the complex model makes its predictions for the test dataset.
A surrogate model accomplishes this by learning how to predict the complex model's predictions for the test dataset.
This is a bit abstract, so the code in this tutorial will demonstrate how to create a surrogate decision tree model to summarize the above complex decision tree.
As discussed earlier in this series, a decision tree model that is three layers deep is a good candidate for a surrogate model because limiting explanations to three factors (or fewer) makes them more believable.
NOTE - I will assume you're familiar with the code from the 3rd tutorial in this series.
The following code will train a decision tree that is three layers deep using the test dataset and the complex model's predictions (i.e., test_preds):
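A minimal, self-contained sketch of this step with scikit-learn follows. It is not the tutorial's original code: synthetic data from `make_classification` stands in for the census data, and an unpruned `DecisionTreeClassifier` stands in for the complex model. Only the name `test_preds` and the depth of three come from the text above.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic data stands in for the tutorial's census data
X, y = make_classification(n_samples=2000, n_features=8, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Stand-in "complex" model: a fully grown (unpruned) decision tree
complex_model = DecisionTreeClassifier(random_state=42)
complex_model.fit(X_train, y_train)
test_preds = complex_model.predict(X_test)

# Surrogate: a three-layer tree fit on the test features, using the complex
# model's predictions (test_preds) as labels instead of the true labels (y_test)
surrogate = DecisionTreeClassifier(max_depth=3, random_state=42)
surrogate.fit(X_test, test_preds)
```

The key detail is the last line: the surrogate learns to mimic the complex model's predictions, not the ground truth.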
With the surrogate model trained, the next step is to assess how well it summarizes the complex model. The more accurately the surrogate predicts the complex model's predictions, the better the summary:
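This agreement (often called fidelity) can be measured with scikit-learn's `accuracy_score` by treating the complex model's predictions as the reference labels. A sketch under the same assumptions as before (synthetic data, an unpruned tree as the complex model), so the numbers will not match the tutorial's:

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic data and an unpruned tree stand in for the tutorial's setup
X, y = make_classification(n_samples=2000, n_features=8, random_state=42)
X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.25, random_state=42)
complex_model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
test_preds = complex_model.predict(X_test)
surrogate = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_test, test_preds)

# Fidelity: how often the surrogate agrees with the complex model
surrogate_preds = surrogate.predict(X_test)
fidelity = accuracy_score(test_preds, surrogate_preds)
print(f"Surrogate fidelity: {fidelity:.2%}")
```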
The surrogate model is 87.32% accurate in predicting the complex model's predictions. The surrogate model is a good summary!
However, to gain even more insight into how well the surrogate model summarizes the complex model, it's a good idea to check out the confusion matrix:
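A sketch of this check with scikit-learn's `confusion_matrix`, again assuming synthetic data and an unpruned tree as the complex model. Passing `test_preds` as the first argument puts the complex model's predictions on the rows (the "True label" axis), and the surrogate's predictions on the columns:

```python
from sklearn.datasets import make_classification
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic data and an unpruned tree stand in for the tutorial's setup
X, y = make_classification(n_samples=2000, n_features=8, random_state=42)
X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.25, random_state=42)
complex_model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
test_preds = complex_model.predict(X_test)
surrogate = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_test, test_preds)

# Rows = complex model's predictions, columns = surrogate's predictions
cm = confusion_matrix(test_preds, surrogate.predict(X_test))
print(cm)
```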
In the above confusion matrix, the complex model's predictions are the True labels and the surrogate model's predictions are the Predicted labels.
From the confusion matrix, we can quickly calculate the following:
The surrogate model is 83.35% accurate in predicting the complex model's <=50K predictions.
The surrogate model is 90.88% accurate in predicting the complex model's >50K predictions.
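Each of these percentages is a row-wise rate: the diagonal count divided by its row total, where each row holds one of the complex model's predicted classes. A small pure-Python sketch; the helper name `per_class_agreement` and the counts are hypothetical, for illustration only, and are not the tutorial's results:

```python
def per_class_agreement(cm):
    """Hypothetical helper: for each row of a square confusion matrix, return
    the fraction of that row's cases that fall on the diagonal."""
    rates = []
    for i, row in enumerate(cm):
        total = sum(row)
        rates.append(row[i] / total if total else float("nan"))
    return rates


# Hypothetical counts: rows = complex model's predictions (<=50K, >50K),
# columns = surrogate model's predictions (<=50K, >50K)
toy_cm = [[80, 20],
          [10, 90]]
rates = per_class_agreement(toy_cm)
print(rates)  # [0.8, 0.9]
```

The same function also accepts the NumPy array returned by scikit-learn's `confusion_matrix`.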
The confusion matrix shows that the surrogate model provides a reasonable summary of the complex model - especially for high-income earners.
Now that the surrogate model is confirmed as a good summary, it's time to visualize it:
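One way to inspect the surrogate without plotting is scikit-learn's `export_text`, which prints each split and leaf prediction as indented text; `plot_tree` draws the same tree graphically with matplotlib. This sketch assumes synthetic data and hypothetical feature names such as `feature_0`:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

# Assumption: synthetic data and an unpruned tree stand in for the tutorial's setup
X, y = make_classification(n_samples=2000, n_features=8, random_state=42)
X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.25, random_state=42)
complex_model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
test_preds = complex_model.predict(X_test)
surrogate = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_test, test_preds)

# Hypothetical feature names; real code would use the dataset's column names
feature_names = [f"feature_{i}" for i in range(X.shape[1])]
report = export_text(surrogate, feature_names=feature_names)
print(report)
```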
We can now use the surrogate model visualization to provide explanations to business stakeholders. Consider the following path through the tree with the leaf node highlighted:
The highlighted path provides the following explanation in natural language:
"US citizens that do not have a marital status of married civilian spouse and have capital gains greater than $7,055.50 are predicted to have incomes greater than $50,000 a year."
Going through the remaining leaf nodes provides the following natural language explanations:
"US citizens that do not have a marital status of married civilian spouse and have capital gains less than or equal to $7,055.50 are predicted to have incomes less than or equal to $50,000 per year."
"US citizens that have a marital status of married civilian spouse, are less than or equal to 29.5 years old, and do not have a bachelor's degree are predicted to have incomes less than or equal to $50,000 a year."
"US citizens that have a marital status of married civilian spouse, are less than or equal to 29.5 years old, and have a bachelor's degree are predicted to have incomes greater than $50,000 a year."
"US citizens that have a marital status of married civilian spouse, are older than 29.5 years, and do not have the occupation of Handlers-cleaners are predicted to have incomes greater than $50,000 a year."
"US citizens that have a marital status of married civilian spouse, are older than 29.5 years, and have the occupation of Handlers-cleaners are predicted to have incomes less than or equal to $50,000 a year."
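Collecting the conditions along each root-to-leaf path can be partly automated by walking the fitted tree's `tree_` arrays. A minimal sketch under the same synthetic-data assumptions; the helper `leaf_rules` is hypothetical, and its output (conditions like `feature_3 > 0.12`) still needs a human to rephrase it in business language:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: synthetic data and an unpruned tree stand in for the tutorial's setup
X, y = make_classification(n_samples=2000, n_features=8, random_state=42)
X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.25, random_state=42)
complex_model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
test_preds = complex_model.predict(X_test)
surrogate = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_test, test_preds)


def leaf_rules(model, feature_names):
    """Hypothetical helper: one 'IF ... THEN ...' string per leaf of a fitted tree."""
    t = model.tree_
    rules = []

    def walk(node, conditions):
        left, right = t.children_left[node], t.children_right[node]
        if left == right:  # both children are -1 at a leaf
            predicted = model.classes_[t.value[node][0].argmax()]
            rules.append(f"IF {' AND '.join(conditions)} THEN predict {predicted}")
            return
        name, threshold = feature_names[t.feature[node]], t.threshold[node]
        # scikit-learn sends samples with feature <= threshold to the left child
        walk(left, conditions + [f"{name} <= {threshold:.2f}"])
        walk(right, conditions + [f"{name} > {threshold:.2f}"])

    walk(0, [])
    return rules


feature_names = [f"feature_{i}" for i in range(X.shape[1])]
rules = leaf_rules(surrogate, feature_names)
for rule in rules:
    print(rule)
```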
Note that a surrogate model is, by definition, an approximation, so it can never provide a completely faithful interpretation of a complex model.
So, if you use a surrogate model, you must be transparent with your business stakeholders that it only summarizes the complex model.
Despite this limitation, surrogate models are a powerful tool for interpreting complex ML models.
This Week’s Book
A professional looking to learn data analysis for the first time recently asked me for a book recommendation. I recommended the following book because it is excellent and affordable:
That's it for this week.
Stay tuned for next week's newsletter, the start of a new tutorial series on a valuable machine learning technique for any DIY data scientist - hierarchical clustering.
Stay healthy and happy data sleuthing!
Dave Langer
Whenever you're ready, there are 4 ways I can help you:
1 - Are you new to data analysis? My Visual Analysis with Python online course will teach you the fundamentals you need - fast. No complex math required, and it works with Python in Excel!
2 - Cluster Analysis with Python: Most of the world's data is unlabeled and can't be used for predictive models. This is where my self-paced online course teaches you how to extract insights from your unlabeled data.
3 - Introduction to Machine Learning: This self-paced online course teaches you how to build predictive models like regression trees and the mighty random forest using Python. Offered in partnership with TDWI, use code LANGER to save 20%.
4 - Is machine learning right for your business, but you don't know where to start? Check out my Machine Learning Accelerator.