
Issue #11 - Interpreting ML Models with Feature Importance

This Week’s Tutorial

If you want to write code to follow along, you can get the datasets from the DIY Data Scientist GitHub and learn about decision trees with my free ML crash courses.

This week's tutorial will use the Adult Census dataset and will demonstrate a hypothetical scenario of explaining a ML model using feature importance:

Based on profiling the data, the following features will be used in the model:
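The original feature list isn't reproduced here, so here's a minimal sketch of selecting profiled features from the dataset. The tiny DataFrame and the exact feature list are assumptions based on the features named later in this tutorial; the real data comes from the DIY Data Scientist GitHub:

```python
import pandas as pd

# Tiny hypothetical stand-in for the Adult Census data; column names are
# assumptions based on the features discussed later in the tutorial.
adult = pd.DataFrame({
    "Age": [39, 50, 38],
    "Education": ["Bachelors", "Bachelors", "HS-grad"],
    "MaritalStatus": ["Never-married", "Married-civ-spouse", "Divorced"],
    "CapitalGain": [2174, 0, 0],
    "CapitalLoss": [0, 0, 0],
    "Income": ["<=50K", ">50K", "<=50K"],
})

# Keep only the features chosen during data profiling
features = ["Age", "Education", "MaritalStatus", "CapitalGain", "CapitalLoss"]
X = adult[features]
y = adult["Income"]
```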

Want to know more about data profiling for ML? Check out the tutorial series in the newsletter back issues.

The data needs to be prepped (e.g., one-hot encoded) for training the ML model:
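A minimal sketch of the one-hot encoding step, using pandas' `get_dummies` (the column names below are illustrative):

```python
import pandas as pd

X = pd.DataFrame({
    "Age": [39, 50, 38],
    "MaritalStatus": ["Never-married", "Married-civ-spouse", "Divorced"],
})

# One-hot encode string features; numeric columns pass through unchanged
X_encoded = pd.get_dummies(X)
print(list(X_encoded.columns))
```

Note how the encoded column names combine the feature and its value (e.g., MaritalStatus_Married-civ-spouse) - these names show up again later in the feature importance results.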

And the labels need to be encoded because they are strings:
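A sketch of encoding the string labels with scikit-learn's LabelEncoder:

```python
from sklearn.preprocessing import LabelEncoder

labels = ["<=50K", ">50K", "<=50K"]

# Map the string labels to integers (classes are sorted alphabetically)
encoder = LabelEncoder()
y_encoded = encoder.fit_transform(labels)
print(y_encoded)  # [0 1 0]
```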

To make this tutorial a bit more realistic, I'm going to use a tuned DecisionTreeClassifier from Python's scikit-learn library:
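A sketch of what the tuned tree might look like. The hyperparameter values and the stand-in training data here are illustrative, not the actual tuned values from the tutorial:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Stand-in training data; in the tutorial this is the prepped Adult Census data
X_train, y_train = make_classification(n_samples=1000, n_features=8, random_state=42)

# Hyperparameter values are illustrative, not the newsletter's actual tuned values
model = DecisionTreeClassifier(
    max_depth=10,
    min_samples_leaf=25,
    random_state=42,
)
model.fit(X_train, y_train)
```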

Using cross-validation, the above tree is estimated as having the following predictive performance on new, unseen data (e.g., data that would be encountered in a production deployment):

  • 81.3% accuracy on average.

  • A range of 79.22% to 83.38% accuracy 95% of the time.
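Estimates like the above can be produced with scikit-learn's cross_val_score. This is a sketch on stand-in data, with the 95% range approximated assuming roughly normal fold-to-fold variation:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Stand-in data and model; in the tutorial these are the prepped training
# data and the tuned tree
X, y = make_classification(n_samples=1000, n_features=8, random_state=42)
model = DecisionTreeClassifier(max_depth=10, random_state=42)

# 10-fold cross-validation accuracy scores
scores = cross_val_score(model, X, y, cv=10, scoring="accuracy")

mean_acc = scores.mean()
# Approximate 95% range assuming roughly normal variation across folds
low, high = mean_acc - 1.96 * scores.std(), mean_acc + 1.96 * scores.std()
```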

Let's assume that these estimates are within business requirements and you decide to proceed to final model testing.

BTW - My Introduction to Machine Learning online course offered in partnership with TDWI will teach you how to tune and estimate model predictive performance using decision trees and random forests.

First step in final testing - loading and preparing the test data:

With the test dataset prepped, time to score the model:
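The steps above can be sketched with tiny hypothetical frames. The key detail is aligning the test data's one-hot columns with the columns the model was trained on (categories missing from the test data become all-zero dummy columns):

```python
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

# Tiny hypothetical frames standing in for the real train/test data
train_raw = pd.DataFrame({"Age": [39, 50, 38, 45],
                          "MaritalStatus": ["Never-married", "Married-civ-spouse",
                                            "Divorced", "Married-civ-spouse"]})
y_train = [0, 1, 0, 1]
test_raw = pd.DataFrame({"Age": [28], "MaritalStatus": ["Divorced"]})
y_test = [0]

X_train = pd.get_dummies(train_raw)
# Reindex so the test data has exactly the training columns; categories that
# don't appear in the test data become all-zero dummy columns
X_test = pd.get_dummies(test_raw).reindex(columns=X_train.columns, fill_value=0)

model = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
test_accuracy = accuracy_score(y_test, model.predict(X_test))
```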

The tuned DecisionTreeClassifier scored 81.42% accuracy which is smack in the middle of the estimates! Woohoo!

At this stage, you have a model that meets business requirements for a production deployment (e.g., using the model's interpretations to gain insights into improving business processes).

The default way to interpret a decision tree is to visualize the model:
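A sketch of visualizing the tree with scikit-learn's plot_tree, again on stand-in data (the tutorial's plot uses the tuned tree and the real feature names):

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Stand-in model; in the tutorial this is the tuned tree from earlier
X, y = make_classification(n_samples=500, n_features=8, random_state=42)
model = DecisionTreeClassifier(max_depth=10, random_state=42).fit(X, y)

fig, ax = plt.subplots(figsize=(20, 10))
plot_tree(model, filled=True, class_names=["<=50K", ">50K"], ax=ax)
fig.savefig("tree.png")
```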

Whoa! That's a big decision tree!

Can you imagine walking your business stakeholders through the decision tree? In my experience, it's a non-starter for all the reasons covered in the first tutorial in this series.

Complex models are the norm in DIY data science due to the premium on predictive performance.

This is a perfect example of why, in practice, we need more ML model interpretability tools than just looking at logistic/linear regression and decision tree models.

Typically, you won't get much buy-in (maybe none) for your ML models without briefing your business stakeholders on the model.

Given that the decision tree is too complex for direct interpretation with stakeholders, we need to take a step back. We already know that:

  1. The final test accuracy meets business requirements.

  2. The final test accuracy is within the estimated ranges, which is a powerful indication (but not a guarantee) that the model will perform within business requirements on new data.

The next thing we can look at is how well the model's predictions performed at the next level of detail. This is where a confusion matrix can help interpret the model:
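The counts below come straight from the tutorial's test results (in code, scikit-learn's `confusion_matrix` function builds this matrix from the test labels and the model's predictions). Here's how the per-label accuracies discussed next fall out of it:

```python
import numpy as np

# Counts from the tutorial's test results; rows are actual labels,
# columns are predicted labels, ordered [<=50K, >50K]
cm = np.array([[2910,  790],
               [ 585, 3115]])

# Correct predictions sit on the diagonal
per_class_accuracy = cm.diagonal() / cm.sum(axis=1)  # [<=50K, >50K] accuracy
overall_accuracy = cm.diagonal().sum() / cm.sum()    # matches the 81.42% test score
```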

The confusion matrix provides you additional visibility into the model's predictions beyond what overall accuracy can tell you.

Each row of the confusion matrix shows the model's predictive performance for each label (i.e., <=50K vs. >50K) in the test dataset.

Correct predictions run along the diagonal, starting from the top left and going to the bottom right. Using the confusion matrix, you can easily calculate the following:

  • The accuracy for the <=50K labels is:

    • 2910 / (2910 + 790) = 78.65%

  • The accuracy for the >50K labels is:

    • 3115 / (3115 + 585) = 84.19%

The model is about 5.5% more accurate in predicting incomes greater than $50,000/year than incomes less than that.

Notice how the confusion matrix already provides an interpretation that business stakeholders can easily understand?

NOTE - I typically don't show confusion matrices to my business stakeholders. I've found it's usually more confusing than helpful. I normally summarize confusion matrix findings as bullets in a PowerPoint slide.

It's rare in practice for a group of stakeholders to accept only the following as the model interpretation:

  • Estimated average model accuracy is 81.3%.

  • However, the model's accuracy will range from 79.2% to 83.4%.

  • The model scored 81.4% accuracy on the test dataset.

  • The model is 5.5% more accurate for high-income earners than for low-income earners on the test dataset.

In my experience, business stakeholders want their explanation in terms of the features that drive the model's predictions.

Enter permutation feature importance.

Permutation feature importance is a powerful tool for interpreting ML models because it works with any algorithm: logistic regression, decision trees, random forests, etc.

Here's the base intuition of how it works:

  1. ML models find features with the most information to make the best predictions possible given the training data.

  2. If you remove these features, the model's predictive performance declines.

Since we have a trained model, we can't really remove the features, but we can do something else that comes close.

We can randomize (i.e., permute) the feature values and then see how much the predictive performance declines as a result. The features with the most information (i.e., the most important) will result in the largest declines in predictive performance.

The intuitive way to think about this is in the extreme worst case scenario. If a feature is permuted and the predictive performance doesn't decline at all, that feature has no information useful for making predictions!

Just in case you're wondering, permutation feature importance works on copies of the original data.

Permutation feature importance takes a trained model and a dataset, then permutes each feature one at a time (the other features remain unchanged) and records the resulting decline in predictive performance.

For example, the DecisionTreeClassifier was trained using eight features. Here's how permutation feature importance works with this model:

  1. The Age feature is permuted and then predictions are made using the model with the other test dataset (un-permuted) features.

  2. The accuracy in the predictions with permuted Age values is compared to the accuracy with the original Age values (the decline is calculated and stored).

  3. This process is then repeated with the one-hot encoded Education features.

  4. Then with the one-hot encoded MaritalStatus features...

So on and so forth - until all the features have been permuted and all the accuracy declines are calculated and stored.
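The procedure above can be sketched from scratch. This is a toy demonstration, not the tutorial's actual code: the `permutation_declines` helper and the two-feature dataset (where "signal" fully determines the label and "noise" is random) are hypothetical, built only to show the mechanics:

```python
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(42)

def permutation_declines(model, X, y, metric):
    """Permute one feature at a time on a copy of the data and
    record the drop in the metric versus the baseline."""
    baseline = metric(y, model.predict(X))
    declines = {}
    for col in X.columns:
        X_perm = X.copy()                                   # work on a copy
        X_perm[col] = rng.permutation(X_perm[col].values)   # shuffle one feature
        declines[col] = baseline - metric(y, model.predict(X_perm))
    return declines

# Hypothetical demo: "signal" fully determines the label, "noise" is random
y = rng.integers(0, 2, 200)
X = pd.DataFrame({"signal": y, "noise": rng.normal(size=200)})
model = DecisionTreeClassifier(random_state=42).fit(X, y)

declines = permutation_declines(model, X, y, accuracy_score)
# Permuting "signal" hurts accuracy a lot; permuting "noise" barely matters
```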

The main downside of permutation feature importance is that it is compute-intensive (i.e., it can take a long time to run). However, the model insights provided are worth the wait.

Here's the code using scikit-learn in Python:
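A sketch of the call using scikit-learn's `permutation_importance` function. The stand-in data and model are hypothetical; in the tutorial, the call uses the tuned tree with the prepped test dataset, and n_repeats=25 matches the 25 repeats mentioned in the results:

```python
import numpy as np
import pandas as pd
from sklearn.inspection import permutation_importance
from sklearn.tree import DecisionTreeClassifier

# Stand-in data and model; in the tutorial these are the tuned tree and
# the prepped test dataset
rng = np.random.default_rng(42)
y = rng.integers(0, 2, 300)
X = pd.DataFrame({"signal": y + rng.normal(scale=0.1, size=300),
                  "noise": rng.normal(size=300)})
model = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

# Keep n_jobs within what your machine can handle
result = permutation_importance(model, X, y, scoring="accuracy",
                                n_repeats=25, n_jobs=2, random_state=42)

# Summarize the mean accuracy declines as a DataFrame, most important first
importances = (pd.DataFrame({"Feature": X.columns,
                             "Importance": result.importances_mean})
               .sort_values("Importance", ascending=False)
               .reset_index(drop=True))
print(importances)
```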

A couple of things to note about the code above.

First, the n_repeats parameter determines how many times the entire process is repeated. Results are averaged across the repeats. The more repeats, the more accurate the results.

How high should you set n_repeats? As high as you can tolerate, because higher values take longer to run. Generally speaking, set n_repeats to at least 10.

Second, the n_jobs parameter allows the process to run in parallel. The higher the value, the faster it will run.

NOTE - Be careful with the n_jobs parameter! Ensure your computer has enough CPU and memory to handle the n_jobs value you set.

The result of the above code is a DataFrame listing the features in decreasing order of importance:

The MaritalStatus_Married-civ-spouse feature was found to be the most important.

Permuting this feature, on average across 25 repeats, reduced the model's predictive performance by a whopping 11.59% with the test dataset!

It's worth noting that you can tell that this is a one-hot encoded categorical feature based on the name. Simply randomizing the True/False (or 1/0) values is enough to impact the model's predictive performance significantly.

The next two most important features are Age and CapitalGain, which reduce model accuracy by 4.9% and 4.5%, respectively.

After that, the accuracy impacts drop off dramatically. For example, permuting the CapitalLoss feature only reduces predictive performance by 0.77%.

Using this information, you can provide additional interpretation:

  • Estimated average model accuracy is 81.3%.

  • However, the model's accuracy will range from 79.2% to 83.4%.

  • The model scored 81.4% accuracy on the test dataset.

  • The model is 5.5% more accurate for high-income earners than for low-income earners on the test dataset.

  • The most important features for making predictions are whether someone is a married civilian, their age, and the amount of any capital gains they have.

One of the great things about permutation feature importance is that it automatically considers any feature interactions. However, this is hidden because it's implicit in the importance scores.

For example, the interaction between Age and CapitalGain is implicitly included in each feature's importance score.

While this is a great start, it's often not enough for business stakeholders to accept your model interpretations/explanations.

We need more tools.

This Week’s Book

This is one of my all-time favorite books on machine learning. It was the textbook that introduced me to ML while earning my master's in computer science:

This book focuses on the most useful algorithms for DIY data science. It's code agnostic (i.e., it uses pseudocode), and the math requirements are not excessive. Overall, a great book.

That's it for this week.

Stay tuned for next week's newsletter, where I will teach how to use data visualizations to interpret important feature interactions.

Stay healthy and happy data sleuthing!

Dave Langer

Whenever you're ready, there are 4 ways I can help you:

1 - Are you new to data analysis? My Visual Analysis with Python online course will teach you the fundamentals you need - fast. No complex math required, and it works with Python in Excel!

2 - Cluster Analysis with Python: Most of the world's data is unlabeled and can't be used for predictive models. This is where my self-paced online course teaches you how to extract insights from your unlabeled data.

3 - Introduction to Machine Learning: This self-paced online course teaches you how to build predictive models like regression trees and the mighty random forest using Python. Offered in partnership with TDWI, use code LANGER to save 20% off.

4 - Think machine learning might be right for your business, but don't know where to start? Check out my Machine Learning Accelerator.