Feature importance (variable importance) describes which features are relevant. It can help with a better understanding of the solved problem and can sometimes lead to model improvements through feature selection. In this post, I will present 3 ways (with code examples) to compute feature importance for the Random Forest algorithm from the scikit-learn package (in Python).
You will learn how to compute and plot:
- the built-in Random Forest feature importance,
- the permutation-based importance,
- SHAP values.

The Random Forest algorithm has built-in feature importance which can be computed in two ways:
- Gini importance (or mean decrease in impurity), computed from the Random Forest structure itself; this is the default in scikit-learn,
- mean decrease in accuracy, computed as the drop in model accuracy when a feature's values are permuted on out-of-bag samples.
I will show how to compute feature importance for the Random Forest with the scikit-learn package and the Boston dataset (a house price regression task).
# Let's load the packages
import numpy as np
import pandas as pd
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
import shap
from matplotlib import pyplot as plt
plt.rcParams.update({'figure.figsize': (12.0, 8.0)})
plt.rcParams.update({'font.size': 14})
Load the data set and split it for training and testing. (Note that load_boston was deprecated and later removed in scikit-learn 1.2; on newer versions, substitute another regression dataset, for example fetch_california_housing.)
boston = load_boston()
X = pd.DataFrame(boston.data, columns=boston.feature_names)
y = boston.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=12)
Fit the Random Forest regressor with 100 decision trees:
rf = RandomForestRegressor(n_estimators=100)
rf.fit(X_train, y_train)
To get the feature importances from the Random Forest model, use the feature_importances_ attribute:
rf.feature_importances_
array([0.04054781, 0.00149293, 0.00576977, 0.00071805, 0.02944643,
0.25261155, 0.01969354, 0.05781783, 0.0050257 , 0.01615872,
0.01066154, 0.01185997, 0.54819617])
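To pair these numbers with the feature names, they can be wrapped in a pandas Series (a minimal sketch using the objects defined above):
# Map the raw importances to feature names and sort them
importances = pd.Series(rf.feature_importances_, index=boston.feature_names)
print(importances.sort_values(ascending=False))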
Let’s plot the importances (a chart is easier to interpret than raw values).
plt.barh(boston.feature_names, rf.feature_importances_)
To have an even better chart, let’s sort the features and plot again:
sorted_idx = rf.feature_importances_.argsort()
plt.barh(boston.feature_names[sorted_idx], rf.feature_importances_[sorted_idx])
plt.xlabel("Random Forest Feature Importance")
The permutation-based importance can be used to overcome drawbacks of the default feature importance computed with mean impurity decrease. It is implemented in scikit-learn as the permutation_importance method. As arguments, it requires a trained model (which can be any model compatible with the scikit-learn API) and the validation (test) data. The method randomly shuffles each feature and computes the change in the model’s performance. The features which impact the performance the most are the most important ones.
The permutation importance can be easily computed:
perm_importance = permutation_importance(rf, X_test, y_test)
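By default, permutation_importance shuffles each feature n_repeats=5 times. For more stable and reproducible estimates, the number of repeats and the random seed can be set explicitly (a sketch; the exact values are a matter of choice):
# More repeats smooth out the estimates; random_state fixes the shuffles
perm_importance = permutation_importance(rf, X_test, y_test, n_repeats=30, random_state=12)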
To plot the importance:
sorted_idx = perm_importance.importances_mean.argsort()
plt.barh(boston.feature_names[sorted_idx], perm_importance.importances_mean[sorted_idx])
plt.xlabel("Permutation Importance")
The permutation-based importance is computationally expensive. It can also have problems with highly correlated features: it may report them as unimportant, because shuffling one feature still leaves its correlated partner available to the model.
The SHAP interpretation can be used (it is model-agnostic) to compute the feature importances from the Random Forest. It uses Shapley values from game theory to estimate how each feature contributes to the prediction. It can be easily installed (pip install shap) and used with the scikit-learn Random Forest:
explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(X_test)
To plot the feature importance as a horizontal bar plot, use the summary_plot method:
shap.summary_plot(shap_values, X_test, plot_type="bar")
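The bar plot ranks features by the mean absolute SHAP value. The same ranking can be reproduced by hand (a sketch, assuming shap_values is the plain array returned above):
# Mean |SHAP value| per feature -- the quantity behind the bar plot
shap_importance = pd.Series(np.abs(shap_values).mean(axis=0), index=X_test.columns)
print(shap_importance.sort_values(ascending=False))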
The feature importance can be plotted with more detail, showing each feature’s values:
shap.summary_plot(shap_values, X_test)
Computing feature importance with SHAP can be computationally expensive. However, it can provide more information, such as decision plots or dependence plots.
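For example, a dependence plot shows how the value of a single feature relates to its SHAP values (a sketch using the LSTAT feature from this dataset):
# SHAP dependence plot for one feature; shap colors the points by an automatically chosen interaction feature
shap.dependence_plot("LSTAT", shap_values, X_test)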
The 3 ways to compute feature importance for the scikit-learn Random Forest were presented:
- built-in feature importance,
- permutation-based importance,
- importance computed with SHAP values.
In my opinion, it is always good to check all methods and compare the results.