understanding multinomial regression with partial dependence plots

some intuition for multinomial regression’s initially intimidating functional form

October 23, 2018

Motivation

This post assumes you are familiar with logistic regression and that you have just fit your first or second multinomial logistic regression model. While the coefficients in a multinomial regression do have an interpretation, that interpretation is relative to a base class, which is often not the most useful way to think about the model. Partial dependence plots are an alternative way to understand multinomial regression, and in fact can be used to understand any predictive model. This post explains what partial dependence plots are and how to create them using R.

Data

I’ll use the built-in iris dataset for this post. If you’ve already seen the iris dataset a hundred times, I apologize. Our goal will be to predict the Species of an iris flower based on four numerical measurements of the flower: Sepal.Length, Sepal.Width, Petal.Length and Petal.Width. There are 150 measurements and three species of iris: setosa, versicolor and virginica.

library(tidyverse)
library(skimr)

data <- as_tibble(iris)
glimpse(data)
Rows: 150
Columns: 5
$ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.…
$ Sepal.Width  <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.…
$ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.…
$ Petal.Width  <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.…
$ Species      <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa, s…

The multinomial logistic regression model

Recall that the probability of an event \(y = 1\) given data \(x \in \mathbb R^p\) in a logistic regression model is:

\[ P(y = 1|x) = {1 \over 1 + \exp(-\beta^T x)} \] where \(\beta \in \mathbb R^p\) is a coefficient vector. Multinomial logistic regression generalizes this relation by assuming that we have \(y \in \{1, 2, ..., K\}\). Then we have coefficient vectors \(\beta_1, ..., \beta_{K-1}\) such that

\[ P(y = k|x) = {\exp(\beta_k^T x) \over 1 + \sum_{j=1}^{K - 1} \exp(\beta_j^T x)} \quad \text{for } k = 1, ..., K - 1 \]

and

\[ P(y = K|x) = {1 \over 1 + \sum_{j=1}^{K - 1} \exp(\beta_j^T x)} \]

There are only \(K-1\) coefficient vectors in order to prevent overparameterization¹. The purpose here isn’t to describe the model in any meaningful detail, but rather to remind you of what it looks like. I strongly encourage you to read this fantastic derivation of multinomial logistic regression, which follows the work that led to McFadden’s Nobel Prize in Economics in 2000.
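
To make the functional form a little more concrete, here’s a tiny hand-rolled sketch (not from the original post; the numbers are made up purely for illustration) that plugs a \((K-1) \times p\) coefficient matrix into the formulas above:

# K = 3 classes and p = 2 predictors; beta and x are made-up values
beta <- rbind(class_1 = c(0.5, -1.2),   # beta_1
              class_2 = c(-0.3, 0.8))   # beta_2
x <- c(1.5, 2.0)

scores <- exp(drop(beta %*% x))        # exp(beta_k^T x) for k = 1, ..., K - 1
denom  <- 1 + sum(scores)
probs  <- c(scores / denom,            # P(y = k | x) for k = 1, ..., K - 1
            class_3 = 1 / denom)       # P(y = K | x), the base class
sum(probs)                             # always equals 1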

If you’d like to interpret the coefficients, I recommend reading the Stata page, but I won’t rehash that here. Instead we’ll explore partial dependence plots as a way of understanding the fitted model.

Partial dependence plots

Partial dependence plots are a way to understand the marginal effect of a variable \(x_s\) on the response. The gist goes like this:

  1. Pick some interesting grid of points in the \(x_s\) dimension
    • Typically the observed values of \(x_s\) in the training set
  2. For each point \(x\) in the grid:
    • Replace every value of \(x_s\) in the training set with \(x\)
    • Calculate the average prediction (class probabilities in our case)

More formally, suppose that we have a data set \(X = [x_s \, x_c] \in \mathbb R^{n \times p}\) where \(x_s\) is a matrix of variables we want to know the partial dependencies for and \(x_c\) is a matrix of the remaining predictors. Suppose we estimate some fit \(\hat f\).

Then \(\hat f_s (x)\), the partial dependence of \(\hat f\) at \(x\) (here \(x\) lives in the same space as \(x_s\)), is defined as:

\[\hat f_s(x) = {1 \over n} \sum_{i=1}^n \hat f(x, x_{c_i})\]

This says: hold \(x\) constant for the variables of interest and take the average prediction over all observed combinations of the other variables in the training set. So we need to pick the variables of interest, and also a region of the space that \(x_s\) lives in that we care about. Be careful extrapolating the marginal mean of \(\hat f(x)\) outside of this region!

Here’s an example implementation in R. We start by fitting a multinomial regression to the iris dataset.

library(nnet)

fit <- multinom(Species ~ ., data, trace = FALSE)
fit
Call:
multinom(formula = Species ~ ., data = data, trace = FALSE)

Coefficients:
           (Intercept) Sepal.Length Sepal.Width Petal.Length Petal.Width
versicolor    18.69037    -5.458424   -8.707401     14.24477   -3.097684
virginica    -23.83628    -7.923634  -15.370769     23.65978   15.135301

Residual Deviance: 11.89973 
AIC: 31.89973 
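
As an aside (this isn’t in the original post), if you do want the coefficient-level interpretation mentioned earlier, exponentiating the coefficients gives relative risk ratios versus the base class, which is setosa here:

# relative risk ratios: the multiplicative change in P(class) / P(setosa)
# for a one unit increase in each predictor
exp(coef(fit))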

Next we pick the feature we’re interested in estimating partial dependencies for:

var <- quo(Sepal.Length)

Now we can split the dataset into this predictor and other predictors:

x_s <- select(data, !!var)   # grid where we want partial dependencies
x_c <- select(data, -!!var)  # other predictors

Then we create a data frame of all combinations of these two datasets (crossing() uses only the distinct rows of each input)²:

# if the training dataset is large, use a subsample of x_c instead
grid <- crossing(x_s, x_c)
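
This full cross join is already about 5,000 rows for our 150 observations, so it can blow up quickly. One option, sketched here but not used in the rest of the post, is to cross against a random subsample of x_c instead:

# sketch only: subsample the other predictors to keep the grid manageable
x_c_sub  <- slice_sample(x_c, n = 50)   # sample_n(x_c, 50) on older dplyr
grid_sub <- crossing(x_s, x_c_sub)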

We want to know the predictions of \(\hat f\) at each point on this grid. I define a helper in the spirit of broom::augment() for this:

library(broom)

augment.multinom <- function(object, newdata) {
  newdata <- as_tibble(newdata)
  # predicted class probabilities: one column per species
  class_probs <- predict(object, newdata, type = "prob")
  # return the grid with the class probabilities appended as new columns
  bind_cols(newdata, as_tibble(class_probs))
}

au <- augment(fit, grid)
au
# A tibble: 5,005 × 8
   Sepal.Length Sepal.Width Petal.Length Petal…¹ Spe…²   setosa versi…³ virgin…⁴
          <dbl>       <dbl>        <dbl>   <dbl> <fct>    <dbl>   <dbl>    <dbl>
 1          4.3         2            3.5     1   vers… 2.15e-11 1.00e+0 2.34e- 7
 2          4.3         2.2          4       1   vers… 9.89e-14 1.00e+0 6.84e- 6
 3          4.3         2.2          4.5     1.5 vers… 4.76e-17 1.27e-1 8.73e- 1
 4          4.3         2.2          5       1.5 virg… 3.96e-22 1.31e-3 9.99e- 1
 5          4.3         2.3          1.3     0.3 seto… 9.99e- 1 7.32e-4 6.71e-26
 6          4.3         2.3          3.3     1   vers… 5.06e- 9 1.00e+0 4.82e- 9
 7          4.3         2.3          4       1.3 vers… 5.98e-13 9.99e-1 8.33e- 4
 8          4.3         2.3          4.4     1.3 vers… 1.94e-15 9.65e-1 3.48e- 2
 9          4.3         2.4          3.3     1   vers… 1.21e- 8 1.00e+0 2.48e- 9
10          4.3         2.4          3.7     1   vers… 4.05e-11 1.00e+0 1.07e- 7
# … with 4,995 more rows, and abbreviated variable names ¹​Petal.Width,
#   ²​Species, ³​versicolor, ⁴​virginica

Now we have the predictions and we marginalize by taking the average for each point in \(x_s\):

pd <- au %>%
  gather(class, prob, setosa, versicolor, virginica) %>% 
  group_by(class, !!var) %>%
  summarize(marginal_prob = mean(prob))
pd
# A tibble: 105 × 3
# Groups:   class [3]
   class  Sepal.Length marginal_prob
   <chr>         <dbl>         <dbl>
 1 setosa          4.3         0.322
 2 setosa          4.4         0.322
 3 setosa          4.5         0.322
 4 setosa          4.6         0.322
 5 setosa          4.7         0.322
 6 setosa          4.8         0.322
 7 setosa          4.9         0.322
 8 setosa          5           0.322
 9 setosa          5.1         0.322
10 setosa          5.2         0.322
# … with 95 more rows

We can visualize this as well:

pd %>%
  ggplot(aes(!!var, marginal_prob, color = class)) +
  geom_line(size = 1) +
  scale_color_viridis_d() +
  labs(title = paste("Partial dependence plot for", quo_name(var)),
       y = "Average class probability across all other predictors",
       x = quo_name(var)) +
  theme_classic()

I won’t show it here, but these values agree exactly with the implementation in the pdp package, which is a good sanity check on our code.
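
If you’d like to run that comparison yourself, here’s a sketch of how I’d wire pdp up to the multinom fit; the pred.fun below is my own assumption for getting probability-scale output that matches our manual calculation for the setosa class:

library(pdp)

# sketch: partial dependence of the setosa probability on Sepal.Length via pdp;
# pred.fun returns the average setosa probability over newdata
pd_pdp <- partial(
  fit,
  pred.var  = "Sepal.Length",
  pred.grid = distinct(x_s),   # use the observed values, like our manual grid
  pred.fun  = function(object, newdata) {
    mean(predict(object, newdata, type = "probs")[, "setosa"])
  },
  train = data
)

head(pd_pdp)   # compare with filter(pd, class == "setosa")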

Partial dependence plots for all the predictors at once

In practice it’s useful to look at partial dependence plots for all of the predictors at once. We can do this by wrapping the code we’ve written so far into a helper function and then mapping over all the predictors.

partial_dependence <- function(predictor) {
  
  var <- ensym(predictor)
  x_s <- select(data, !!var)
  x_c <- select(data, -!!var)
  grid <- crossing(x_s, x_c)

  augment(fit, grid) %>% 
    gather(class, prob, setosa, versicolor, virginica) %>% 
    group_by(class, !!var) %>%
    summarize(marginal_prob = mean(prob))
}

all_dependencies <- colnames(iris)[1:4] %>% 
  map_dfr(partial_dependence) %>% 
  gather(feature, feature_value, -class, -marginal_prob) %>% 
  na.omit()

all_dependencies
# A tibble: 369 × 4
# Groups:   class [3]
   class  marginal_prob feature      feature_value
   <chr>          <dbl> <chr>                <dbl>
 1 setosa         0.322 Sepal.Length           4.3
 2 setosa         0.322 Sepal.Length           4.4
 3 setosa         0.322 Sepal.Length           4.5
 4 setosa         0.322 Sepal.Length           4.6
 5 setosa         0.322 Sepal.Length           4.7
 6 setosa         0.322 Sepal.Length           4.8
 7 setosa         0.322 Sepal.Length           4.9
 8 setosa         0.322 Sepal.Length           5  
 9 setosa         0.322 Sepal.Length           5.1
10 setosa         0.322 Sepal.Length           5.2
# … with 359 more rows

Then we can plot everything at once!

all_dependencies %>% 
  ggplot(aes(feature_value, marginal_prob, color = class)) +
  geom_line(size = 1) +
  facet_wrap(vars(feature), scales = "free_x") +
  scale_color_viridis_d() +
  labs(title = "Partial dependence plots for all features",
       y = "Marginal probability of class",
       x = "Value of feature") +
  theme_classic()

Here we see that Sepal.Length and Sepal.Width don’t influence the class probabilities much on average, but that Petal.Length and Petal.Width do.

Takeaways

Partial dependence plots are a useful tool for understanding the marginal behavior of models. The plots are especially helpful when telling a story about what your model means. In this post, I’ve only worked with continuous predictors, but you can calculate partial dependencies for categorical predictors as well, although you’ll probably want to plot them slightly differently. Additionally, it’s natural to consider the partial dependencies of a model when \(x_s\) is multidimensional, in which case you can visualize marginal response surfaces.
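
To give a flavor of the categorical case mentioned above, here’s a hedged sketch (none of this appears in the analysis above): discretize Petal.Width into an artificial factor, refit, and plot the marginal class probabilities per level with bars instead of lines.

# sketch only: an artificial categorical predictor made by binning Petal.Width
data_cat <- data %>%
  mutate(Petal.Size = cut(Petal.Width, 3, labels = c("small", "medium", "large"))) %>%
  select(-Petal.Width)

fit_cat <- multinom(Species ~ ., data_cat, trace = FALSE)

x_s_cat  <- distinct(data_cat, Petal.Size)            # grid = the factor levels
x_c_cat  <- select(data_cat, -Petal.Size, -Species)   # other predictors
grid_cat <- crossing(x_s_cat, x_c_cat)

augment(fit_cat, grid_cat) %>%
  gather(class, prob, setosa, versicolor, virginica) %>%
  group_by(class, Petal.Size) %>%
  summarize(marginal_prob = mean(prob)) %>%
  ggplot(aes(Petal.Size, marginal_prob, fill = class)) +
  geom_col(position = "dodge") +
  scale_fill_viridis_d() +
  labs(title = "Partial dependence on a binned Petal.Width (sketch)",
       y = "Average class probability across all other predictors") +
  theme_classic()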

I recommend using the pdp package to calculate partial dependencies in practice, and refer you to Christoph Molnar’s excellent book on interpretable machine learning for additional reading.

Footnotes

  1. Some machine learning courses present multinomial regression using a \(K \times p\) coefficient matrix, but then estimate the coefficients with some sort of penalty. The penalty is necessary because the \(K \times p\) parameterization is not identifiable: adding the same vector to every \(\beta_k\) leaves the class probabilities, and hence the likelihood, unchanged, so there is no unique unpenalized maximum. Statisticians are typically more interested in unbiased estimators and present the \((K-1) \times p\) parameterization.↩︎

  2. At one point I began wondering how to get a small but representative subset of \(x_c\), which led me down the rabbit hole of sampling from convex sets (for this problem I was imagining using the convex hulls of \(x_c\)). There’s an interesting observation that you can use the Dirichlet distribution for this in an old R-help thread. Then I stumbled across hit and run samplers, which are intuitively satisfying, and finally the walkr package and the more sophisticated methods it implements. I imagine sampling like this is just a hard problem in high dimensions, but if anybody can show me how to convert the convex hull of a dataset calculated using chull() into a format suitable for walkr, please email me!↩︎
