## Motivation

This post assumes you are familiar with logistic regression and that you just fit your first or second multinomial logistic regression model. While the coefficients in a multinomial regression do have an interpretation, that interpretation is relative to a base class, which may not be the comparison you care about. Partial dependence plots are an alternative way to understand multinomial regression, and in fact can be used to understand any predictive model. This post explains what partial dependence plots are and how to create them using R.

## Data

I’ll use the built-in `iris` dataset for this post. If you’ve already seen the iris dataset a hundred times, I apologize. Our goal will be to predict the `Species` of an iris flower based on four numerical measures of the flower: `Sepal.Length`, `Sepal.Width`, `Petal.Length` and `Petal.Width`. There are 150 measurements and three species of iris: `setosa`, `versicolor` and `virginica`.

```
library(tidyverse)
library(skimr)
data <- as_tibble(iris)
glimpse(data)
```

```
## Observations: 150
## Variables: 5
## $ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9,...
## $ Sepal.Width <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1,...
## $ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5,...
## $ Petal.Width <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1,...
## $ Species <fct> setosa, setosa, setosa, setosa, setosa, setosa, s...
```
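The counts above are easy to confirm in base R. It is also worth counting the distinct values of each predictor: the partial dependence tables later in this post use the observed values as the grid, so they have one row per class per distinct value (3 × 35 = 105 rows for `Sepal.Length` alone). This is a small sketch; `n_unique` is just an illustrative name.

```
# iris ships with base R (datasets package), so no extra packages are needed
dim(iris)            # 150 rows, 5 columns
table(iris$Species)  # 50 flowers per species

# Number of distinct observed values of each numeric predictor
n_unique <- sapply(iris[, 1:4], function(col) length(unique(col)))
n_unique
sum(n_unique) * 3    # grid points across all predictors, times 3 classes
```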

## The multinomial logistic regression model

Recall that the probability of an event \(y = 1\) given data \(x \in \mathbb R^p\) in a logistic regression model is:

\[ P(y = 1|x) = {1 \over 1 + \exp(-\beta^T x)} \]

where \(\beta \in \mathbb R^p\) is a coefficient vector. Multinomial logistic regression generalizes this relation by assuming that we have \(y \in \{1, 2, ..., K\}\). Then we have coefficient vectors \(\beta_1, ..., \beta_{K-1}\) such that, for \(k = 1, ..., K - 1\),

\[ P(y = k|x) = {\exp(\beta_k^T x) \over 1 + \sum_{j=1}^{K - 1} \exp(\beta_j^T x)} \]

and

\[ P(y = K|x) = {1 \over 1 + \sum_{j=1}^{K - 1} \exp(\beta_j^T x)} \]
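The two formulas translate directly into code. The sketch below is a base-R illustration (the helper name `multinomial_probs` and the toy coefficients are made up for this example): stack \(\beta_1, ..., \beta_{K-1}\) as the rows of a matrix, let the baseline class \(K\) contribute \(\exp(0) = 1\) to the denominator, and normalize. With \(K = 2\) it reduces to the logistic formula.

```
# Hypothetical helper: class probabilities under the (K-1) x p parameterization,
# with class K as the baseline whose coefficient vector is fixed at zero.
#   B: (K-1) x p matrix whose rows are beta_1, ..., beta_{K-1}
#   X: n x p design matrix (include a column of 1s for an intercept)
multinomial_probs <- function(B, X) {
  eta <- X %*% t(B)          # n x (K-1) matrix of beta_k^T x
  num <- cbind(exp(eta), 1)  # numerators; the baseline class K contributes exp(0) = 1
  num / rowSums(num)         # denominator: 1 + sum_j exp(beta_j^T x)
}

# Toy example: K = 3 classes, p = 2 (intercept plus one feature)
B <- rbind(c(0.5, -1.0),     # beta_1
           c(-0.2, 0.8))     # beta_2
X <- cbind(1, c(-1, 0, 2))   # three observations
P <- multinomial_probs(B, X)
rowSums(P)                   # each row sums to 1

# With K = 2 this is ordinary logistic regression
b <- c(0.3, 1.2)
p_logit <- multinomial_probs(matrix(b, nrow = 1), X)[, 1]
all.equal(p_logit, drop(1 / (1 + exp(-X %*% b))))   # TRUE
```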

There are only \(K-1\) coefficient vectors in order to prevent overparameterization^{1}. The purpose here isn’t to describe the model in any meaningful detail, but rather to remind you of what it looks like. I strongly encourage you to read this fantastic derivation of multinomial logistic regression, which follows the work that led to McFadden’s Nobel Prize in economics in 2000.
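To see why only \(K-1\) coefficient vectors are needed, here is a small base-R check with made-up coefficients (the helper name `softmax_probs` is hypothetical). With one coefficient vector per class, adding the same vector \(c\) to every \(\beta_k\) multiplies each numerator and the denominator by \(\exp(c^T x)\), so the probabilities do not change and the coefficients cannot be pinned down from data. Fixing the baseline class's coefficients at zero removes that freedom.

```
# Hypothetical helper: softmax with one coefficient vector per class (B is K x p)
softmax_probs <- function(B, X) {
  eta <- X %*% t(B)                # n x K linear predictors
  eta <- eta - apply(eta, 1, max)  # subtract each row's max for numerical stability
  num <- exp(eta)
  num / rowSums(num)
}

set.seed(1)
B_full <- matrix(rnorm(3 * 2), nrow = 3)  # K = 3 classes, p = 2
X <- cbind(1, rnorm(5))                   # five observations with an intercept

# Adding the same vector to every row of B leaves the probabilities unchanged
c_shift <- c(10, -4)
B_shifted <- sweep(B_full, 2, c_shift, "+")
all.equal(softmax_probs(B_full, X), softmax_probs(B_shifted, X))  # TRUE

# Shifting by -beta_K zeros out the last row: the (K-1) x p parameterization
B_base <- sweep(B_full, 2, B_full[3, ], "-")
B_base[3, ]                               # 0 0
```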

If you’d like to interpret the coefficients, I recommend reading the Stata page, but I won’t rehash that here. Instead we’ll explore partial dependence plots as a way of understanding the fitted model.

## Partial dependence plots

Partial dependence plots are a way to understand the marginal effect of a variable \(x_s\) on the response. The gist goes like this:

- Pick some interesting grid of points in the \(x_s\) dimension
  - Typically the observed values of \(x_s\) in the training set
- For each point \(x\) in the grid:
  - Replace every value of \(x_s\) in the training set with \(x\)
  - Calculate the average response (class probabilities in our case)
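The recipe above can be sketched in base R with a loop over grid points instead of one big combined data frame. This is a minimal sketch, assuming a model whose `predict()` method returns class probabilities (true for `nnet::multinom` with `type = "probs"`); the helper name `partial_dependence_base` is hypothetical. Looping keeps only \(n\) rows in memory at a time, at the cost of one `predict()` call per grid point.

```
library(nnet)  # multinom(); nnet is a recommended package that ships with R

fit <- multinom(Species ~ ., data = iris, trace = FALSE)

# Hypothetical helper: partial dependence of the predicted class probabilities
# on the column `var`, evaluated at each value in `grid`
partial_dependence_base <- function(fit, data, var,
                                    grid = sort(unique(data[[var]]))) {
  avg_probs <- sapply(grid, function(value) {
    modified <- data
    modified[[var]] <- value                         # replace every x_s with the grid point
    probs <- predict(fit, modified, type = "probs")  # n x K matrix of class probabilities
    colMeans(probs)                                  # average response over the training rows
  })
  # sapply() returns a K x length(grid) matrix; transpose to one row per grid point
  data.frame(value = grid, t(avg_probs), check.names = FALSE)
}

pd_base <- partial_dependence_base(fit, iris, "Sepal.Length")
head(pd_base)
```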

More formally, suppose that we have a data set \(X = [x_s \, x_c] \in \mathbb R^{n \times p}\) where \(x_s\) is a matrix of variables we want to know the partial dependencies for and \(x_c\) is a matrix of the remaining predictors. Suppose we estimate some fit \(\hat f\).

Then \(\hat f_s (x)\), the partial dependence of \(\hat f\) *at* \(x\) (here \(x\) lives in the same space as \(x_s\)), is defined as:

\[\hat f_s(x) = {1 \over n} \sum_{i=1}^n \hat f(x, x_{c_i})\]

This says: hold \(x\) constant for the variables of interest and average the predictions over the observed values of the other variables in the training set. So we need to pick the variables of interest, and also the region of the \(x_s\) space we are interested in. Be careful extrapolating \(\hat f_s(x)\) outside of this region!
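One way to follow that advice is to build the grid only over the bulk of the observed \(x_s\) values rather than their full range. The sketch below is one hypothetical choice (an evenly spaced grid between the 5th and 95th percentiles; the helper name and defaults are assumptions, not part of the original code). It only guards against extrapolating in \(x_s\) itself: if \(x_s\) is correlated with \(x_c\), some combinations \((x, x_{c_i})\) can still lie far from any training point.

```
# Hypothetical helper: evenly spaced grid over the central part of the
# observed distribution of x_s (default: 5th to 95th percentile)
trimmed_grid <- function(x, n_points = 20, probs = c(0.05, 0.95)) {
  limits <- quantile(x, probs = probs, names = FALSE)
  seq(limits[1], limits[2], length.out = n_points)
}

g <- trimmed_grid(iris$Sepal.Length)
range(iris$Sepal.Length)  # full observed range
range(g)                  # narrower range covered by the grid
```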

Here’s an example implementation in R. We start by fitting a multinomial regression to the `iris` dataset.

```
library(nnet)
fit <- multinom(Species ~ ., data, trace = FALSE)
fit
```

```
## Call:
## multinom(formula = Species ~ ., data = data, trace = FALSE)
##
## Coefficients:
## (Intercept) Sepal.Length Sepal.Width Petal.Length Petal.Width
## versicolor 18.69037 -5.458424 -8.707401 14.24477 -3.097684
## virginica -23.83628 -7.923634 -15.370769 23.65978 15.135301
##
## Residual Deviance: 11.89973
## AIC: 31.89973
```
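It can be reassuring to connect this output back to the formulas in the previous section. Note that `nnet::multinom()` uses the first factor level, `setosa`, as the baseline (which is why the coefficient table only has rows for `versicolor` and `virginica`), whereas the formulas used class \(K\). This base-R sketch rebuilds the predicted probabilities from `coef(fit)` by giving the baseline a linear predictor of zero, and compares them with `predict()`.

```
library(nnet)

fit <- multinom(Species ~ ., data = iris, trace = FALSE)

B <- coef(fit)  # (K-1) x p: rows versicolor and virginica; setosa is the baseline
X <- model.matrix(~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, data = iris)
X <- X[, colnames(B)]  # make sure the columns line up with the coefficients

# Linear predictors, with 0 for the baseline class
eta <- cbind(setosa = 0, X %*% t(B))
eta <- eta - apply(eta, 1, max)  # subtract the row maximum to avoid overflow in exp()
manual <- exp(eta) / rowSums(exp(eta))

max(abs(manual - predict(fit, iris, type = "probs")))  # numerically zero
```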

Next we pick the feature we’re interested in estimating partial dependencies for:

```
var <- quo(Sepal.Length)
```

Now we can split the dataset into this predictor and other predictors:

```
x_s <- select(data, !!var) # grid where we want partial dependencies
x_c <- select(data, -!!var) # other predictors
```

Then we create a dataframe of all combinations of these datasets^{2}:

```
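# if the training dataset is large, use a subsample of x_c instead
# expand_grid() keeps every row of x_c; recent versions of crossing()
# de-duplicate and sort their inputs, which would change the average
grid <- expand_grid(x_s, x_c)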
```
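The same combined data frame can be built in base R with `merge(by = NULL)`, which returns the Cartesian product of two data frames. One detail is worth checking in whatever tool you use: the average in the definition of \(\hat f_s\) runs over all \(n\) training rows, so duplicated rows of \(x_c\) must be kept, whereas duplicates in \(x_s\) are harmless because they repeat every \(x_c\) row equally. The subsample at the end follows the code comment above; the sample size of 50 and the seed are arbitrary choices for illustration.

```
x_s <- iris["Sepal.Length"]                        # the variable of interest
x_c <- iris[setdiff(names(iris), "Sepal.Length")]  # everything else (Species is carried along)

# Cartesian product that keeps duplicated rows: every row of x_s paired with every row of x_c
grid_full <- merge(x_s, x_c, by = NULL)
nrow(grid_full)  # 150 * 150 = 22500

# For a large training set: distinct grid values crossed with a random subsample of x_c
set.seed(42)
x_c_sub <- x_c[sample(nrow(x_c), 50), ]
grid_sub <- merge(unique(x_s), x_c_sub, by = NULL)
nrow(grid_sub)   # number of distinct Sepal.Length values * 50
```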

We want to know the predictions of \(\hat f\) at each point on this grid. I define a helper in the spirit of `broom::augment()` for this:

```
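library(broom)

# augment() method for multinom fits: append the predicted class
# probabilities (one column per class) to newdata
augment.multinom <- function(object, newdata) {
  newdata <- as_tibble(newdata)
  class_probs <- predict(object, newdata, type = "probs")
  bind_cols(newdata, as_tibble(class_probs))
}
au <- augment(fit, grid)
au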
```

```
## # A tibble: 22,500 x 8
## Sepal.Length Sepal.Width Petal.Length Petal.Width Species setosa
## <dbl> <dbl> <dbl> <dbl> <fct> <dbl>
## 1 5.1 3.5 1.4 0.2 setosa 1.000
## 2 5.1 3 1.4 0.2 setosa 1.000
## 3 5.1 3.2 1.3 0.2 setosa 1.000
## 4 5.1 3.1 1.5 0.2 setosa 1.000
## 5 5.1 3.6 1.4 0.2 setosa 1.000
## 6 5.1 3.9 1.7 0.4 setosa 1.000
## 7 5.1 3.4 1.4 0.3 setosa 1.000
## 8 5.1 3.4 1.5 0.2 setosa 1.000
## 9 5.1 2.9 1.4 0.2 setosa 1.000
## 10 5.1 3.1 1.5 0.1 setosa 1.000
## # ... with 22,490 more rows, and 2 more variables: versicolor <dbl>,
## # virginica <dbl>
```

Now we have the predictions and we marginalize by taking the average for each point in \(x_s\):

```
pd <- au %>%
  gather(class, prob, setosa, versicolor, virginica) %>%
  group_by(class, !!var) %>%
  summarize(marginal_prob = mean(prob))
pd
```

```
## # A tibble: 105 x 3
## # Groups: class [?]
## class Sepal.Length marginal_prob
## <chr> <dbl> <dbl>
## 1 setosa 4.3 0.333
## 2 setosa 4.4 0.333
## 3 setosa 4.5 0.333
## 4 setosa 4.6 0.333
## 5 setosa 4.7 0.333
## 6 setosa 4.8 0.333
## 7 setosa 4.9 0.333
## 8 setosa 5 0.333
## 9 setosa 5.1 0.333
## 10 setosa 5.2 0.333
## # ... with 95 more rows
```
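For comparison, the marginalization step can also be done in base R on the wide matrix of predictions, with `aggregate()` playing the role of `group_by()` plus `summarize()`. This sketch (base R plus `nnet`, independent of the tidyverse code) checks two things: at each grid value the three marginal probabilities sum to 1, since an average of probability vectors is still a probability vector, and the group mean from the cross join matches \(\hat f_s(x)\) computed directly from its definition at one point, \(x = 5\).

```
library(nnet)

fit <- multinom(Species ~ ., data = iris, trace = FALSE)

# Every observed Sepal.Length paired with every row of the other variables
grid <- merge(iris["Sepal.Length"], iris[setdiff(names(iris), "Sepal.Length")], by = NULL)
probs <- as.data.frame(predict(fit, grid, type = "probs"))

# Average each class probability within each value of Sepal.Length
pd_wide <- aggregate(probs, by = list(Sepal.Length = grid$Sepal.Length), FUN = mean)
head(pd_wide)

# Averages of probability vectors are still probability vectors
range(rowSums(pd_wide[, c("setosa", "versicolor", "virginica")]))

# Direct evaluation of the definition at x = 5: set Sepal.Length to 5 everywhere and average
direct <- colMeans(predict(fit, transform(iris, Sepal.Length = 5), type = "probs"))
from_grid <- unlist(pd_wide[pd_wide$Sepal.Length == 5, c("setosa", "versicolor", "virginica")])
all.equal(direct, from_grid)  # TRUE
```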

We can visualize this as well:

```
pd %>%
  ggplot(aes(!!var, marginal_prob, color = class)) +
  geom_line(size = 1) +
  labs(title = paste("Partial dependence plot for", quo_name(var)),
       y = "Average class probability across all other predictors",
       x = quo_name(var)) +
  theme_bw() +
  scale_color_viridis_d()
```

I won’t show it here, but these values agree exactly with the implementation in the `pdp` package, which is a good sanity check on our code.

## Partial dependence plots for all the predictors at once

In practice it’s useful to look at partial dependence plots for all of the predictors at once. We can do this by wrapping the code we’ve written so far into a helper function and then mapping over all the predictors.

```
partial_dependence <- function(predictor) {
  var <- ensym(predictor)
  x_s <- select(data, !!var)
  x_c <- select(data, -!!var)
  grid <- expand_grid(x_s, x_c)  # keep every row of x_c (crossing() may de-duplicate)
  augment(fit, grid) %>%
    gather(class, prob, setosa, versicolor, virginica) %>%
    group_by(class, !!var) %>%
    summarize(marginal_prob = mean(prob))
}
all_dependencies <- colnames(iris)[1:4] %>%
  map_dfr(partial_dependence) %>%
  gather(feature, feature_value, -class, -marginal_prob) %>%
  na.omit()
all_dependencies
```

```
## # A tibble: 369 x 4
## # Groups: class [3]
## class marginal_prob feature feature_value
## <chr> <dbl> <chr> <dbl>
## 1 setosa 0.333 Sepal.Length 4.3
## 2 setosa 0.333 Sepal.Length 4.4
## 3 setosa 0.333 Sepal.Length 4.5
## 4 setosa 0.333 Sepal.Length 4.6
## 5 setosa 0.333 Sepal.Length 4.7
## 6 setosa 0.333 Sepal.Length 4.8
## 7 setosa 0.333 Sepal.Length 4.9
## 8 setosa 0.333 Sepal.Length 5
## 9 setosa 0.333 Sepal.Length 5.1
## 10 setosa 0.333 Sepal.Length 5.2
## # ... with 359 more rows
```
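The same all-predictors table can be sketched in base R, which is also a handy cross-check on the row count: one row per class per distinct value of each predictor. The helper name `pd_long` is hypothetical, and it loops over grid points rather than building the cross join.

```
library(nnet)

fit <- multinom(Species ~ ., data = iris, trace = FALSE)
classes <- levels(iris$Species)

# Hypothetical helper: long-format partial dependence for one predictor
pd_long <- function(var) {
  rows <- lapply(sort(unique(iris[[var]])), function(value) {
    modified <- iris
    modified[[var]] <- value
    avg <- colMeans(predict(fit, modified, type = "probs"))
    data.frame(class = classes, marginal_prob = unname(avg[classes]),
               feature = var, feature_value = value)
  })
  do.call(rbind, rows)
}

all_pd <- do.call(rbind, lapply(names(iris)[1:4], pd_long))
nrow(all_pd)
head(all_pd)
```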

Then we can plot everything at once!

```
all_dependencies %>%
  ggplot(aes(feature_value, marginal_prob, color = class)) +
  geom_line(size = 1) +
  facet_wrap(vars(feature), scales = "free_x") +
  scale_color_viridis_d() +
  labs(title = "Partial dependence plots for all features",
       y = "Marginal probability of class",
       x = "Value of feature") +
  theme_bw()
```

Here we see that `Sepal.Length` and `Sepal.Width` don’t influence class probabilities much on average, but `Petal.Length` and `Petal.Width` do.
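To put a rough number on that visual impression, one option (a heuristic assumed for this sketch, not something computed elsewhere in this post) is the range of each partial dependence curve: how far the average class probability moves as the predictor sweeps over its observed values. The helper name `pd_spread` is hypothetical. Keep in mind that a flat curve can also hide effects that cancel out on average, for example when a predictor pushes probabilities in opposite directions for different observations.

```
library(nnet)

fit <- multinom(Species ~ ., data = iris, trace = FALSE)

# Hypothetical summary: range of each class's partial dependence curve for `var`
pd_spread <- function(var) {
  avg <- t(sapply(sort(unique(iris[[var]])), function(value) {
    modified <- iris
    modified[[var]] <- value
    colMeans(predict(fit, modified, type = "probs"))
  }))
  apply(avg, 2, function(p) max(p) - min(p))  # one number per class
}

spread <- t(sapply(names(iris)[1:4], pd_spread))  # rows: predictors, columns: classes
round(spread, 2)
```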

## Takeaways

Partial dependence plots are a useful tool to understand the marginal behavior of models. The plots are especially helpful when telling a story about what your model means. In this post, I’ve only worked with continuous predictors, but you can calculate partial dependencies for categorical predictors as well, although you’ll probably want to plot them slightly differently. Additionally, it’s natural to consider the partial dependencies of a model when \(x_s\) is multidimensional, in which case you can visualize marginal response surfaces.
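For a two-dimensional \(x_s\), the same recipe applies with both columns overwritten at each grid point. This is a minimal base-R sketch for the two petal measurements, assuming a 10 × 10 evenly spaced grid over their observed ranges (the grid size is arbitrary). Corners of such a grid, such as short but wide petals, may never occur in the data, so the extrapolation caveat applies with extra force here.

```
library(nnet)

fit <- multinom(Species ~ ., data = iris, trace = FALSE)

# 10 x 10 grid over the observed ranges of the two petal measurements
pl <- seq(min(iris$Petal.Length), max(iris$Petal.Length), length.out = 10)
pw <- seq(min(iris$Petal.Width), max(iris$Petal.Width), length.out = 10)
grid2 <- expand.grid(Petal.Length = pl, Petal.Width = pw)

# For each (Petal.Length, Petal.Width) pair, overwrite both columns and average
surface <- t(apply(grid2, 1, function(point) {
  modified <- iris
  modified$Petal.Length <- point[["Petal.Length"]]
  modified$Petal.Width <- point[["Petal.Width"]]
  colMeans(predict(fit, modified, type = "probs"))
}))
surface <- cbind(grid2, surface)

# Marginal response surface for virginica as a matrix: rows follow pl, columns follow pw
virginica_surface <- matrix(surface$virginica, nrow = length(pl))
contour(pl, pw, virginica_surface, xlab = "Petal.Length", ylab = "Petal.Width")
```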

I recommend using the `pdp` package to calculate partial dependencies in practice, and refer you to Christoph Molnar’s excellent book on interpretable machine learning for additional reading.

Some machine learning courses present multinomial regression using a \(K \times p\) coefficient matrix, but then estimate the coefficients with some sort of penalty. The penalty is needed because the \(K \times p\) parameterization is not identifiable: adding the same vector \(c\) to every \(\beta_k\) leaves the class probabilities, and hence the likelihood, unchanged, so the maximum likelihood estimate is not unique. Statisticians typically prefer to avoid the bias that a penalty introduces and present the \((K-1) \times p\) parameterization instead.↩

At one point I began wondering how to get a small but representative subset of \(x_c\), which led me down the rabbit hole of sampling from convex sets (for this problem I was imagining using the convex hull of \(x_c\)). There’s an interesting observation in an old R-help thread that you can use the Dirichlet distribution for this. Then I stumbled across hit-and-run samplers, which are intuitively satisfying, and finally the `walkr` package and the more sophisticated methods it implements. I imagine this kind of sampling is just a hard problem in high dimensions, but if anybody can show me how to convert the convex hull of a dataset calculated using `chull()` into a format suitable for `walkr`, please email me!↩