implementing the super learner with tidymodels
Apr 13, 2019
Alex Hayes
17 minute read

Summary

In this post I demonstrate how to implement the Super Learner using tidymodels infrastructure. The Super Learner is an ensembling strategy that relies on cross-validation to determine how to combine predictions from many models. tidymodels provides low-level predictive modeling infrastructure that makes the implementation rather slick. The goal of this post is to show how you can use this infrastructure to build new methods with consistent, tidy behavior. You’ll get the most out of this post if you’ve used rsample, recipes and parsnip before and are comfortable working with list-columns.

How do I fit the super learner?

The Super Learner is an ensembling strategy with nice optimality properties. It’s also not too terrible to implement:

  1. Fit a library of predictive models \(f_1, ..., f_n\) on the full data set
  2. Get heldout predictions from \(f_1, ..., f_n\) using k-fold cross-validation
  3. Train a metalearner on the heldout predictions

Then when you want to predict on new data, you first run the data through \(f_1, ..., f_n\), then take these predictions and send them through the metalearner.
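Before building the real thing, the three steps (plus the two-stage prediction) can be sketched in R-flavored pseudocode. All function names here are hypothetical placeholders, not actual tidymodels API:

```r
# hypothetical sketch, not runnable -- fit_library(), cv_predict(),
# and predict_library() are placeholder names

full_fits <- fit_library(models, data)            # step 1: fit on full data
heldout   <- cv_predict(models, data, k = 10)     # step 2: k-fold heldout predictions
meta      <- fit(meta_model, outcome ~ ., heldout)  # step 3: train the metalearner

# prediction: run new data through the base learners, then the metalearner
predict(meta, predict_library(full_fits, new_data))
```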

I’ll walk through this step by step in R code, and then we’ll wrap it up into a slightly more reusable function.

Warning: development versions of packages in use

This post relies on code that only exists in the development versions of several packages. In particular, we make use of the new tidyr pivoting functionality.

Further, several tidymodels packages are undergoing rapid development and are not particularly stable. Writing this post has taken much longer than originally anticipated because I keep coming across bugs. Also note that the dials API is likely to change in the near future.

I’ll keep the code in this post up to date as package development continues, but don’t be surprised if things break.

Now that you have been thoroughly warned, I believe you should be able to get everything you need with:

install.packages("tidymodels")
install.packages("furrr")

devtools::install_github("tidyverse/tidyr")
devtools::install_github("tidymodels/parsnip")

Implementation

You’ll want to load the requisite packages with:

library(tidymodels)
library(tidyr)

library(furrr)

# use `plan(sequential)` to effectively convert all
# subsequent `future_map*` calls to `map*` calls.
# this will result in sequential execution of
# embarrassingly parallel model fitting procedures,
# but may prevent R from getting angry at parallelism

plan(multicore)  

set.seed(27)  # the one true seed

We’ll build a super learner on the iris dataset, which is built into R. iris looks like:

data <- as_tibble(iris)
data
## # A tibble: 150 x 5
##    Sepal.Length Sepal.Width Petal.Length Petal.Width Species
##           <dbl>       <dbl>        <dbl>       <dbl> <fct>  
##  1          5.1         3.5          1.4         0.2 setosa 
##  2          4.9         3            1.4         0.2 setosa 
##  3          4.7         3.2          1.3         0.2 setosa 
##  4          4.6         3.1          1.5         0.2 setosa 
##  5          5           3.6          1.4         0.2 setosa 
##  6          5.4         3.9          1.7         0.4 setosa 
##  7          4.6         3.4          1.4         0.3 setosa 
##  8          5           3.4          1.5         0.2 setosa 
##  9          4.4         2.9          1.4         0.2 setosa 
## 10          4.9         3.1          1.5         0.1 setosa 
## # ... with 140 more rows

We want to predict Species based on Sepal.Length, Sepal.Width, Petal.Length and Petal.Width. While this data isn’t terribly exciting, multiclass classification is the most general case to deal with. Our code should just work for binary classification, and will require only minor modifications for regression problems.
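For the record, the regression modifications are small: parsnip returns a single .pred column for regression (rather than one .pred_* column per class), and the metalearner becomes a regression model. A hypothetical sketch of the changed pieces, where `fit` and `new_data` stand in for the fitted model and baked data used later in this post:

```r
# regression sketch, assuming a fitted parsnip model `fit` and
# baked `new_data` -- predictions come back as a single .pred column
preds <- predict(fit, new_data, type = "numeric")

# and a regression metalearner, e.g. ridge regression
meta_spec <- linear_reg(penalty = 0.01, mixture = 0) %>%
  set_engine("glmnet")
```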

Step 1: Fitting the library of predictive models

First we need to fit a library of predictive models on the full data set. We’ll use parsnip to specify the models, and dials to specify hyperparameter grids. Both parsnip and dials get loaded when you call library(tidymodels).

For now we record the model we want to use. I’m going to fit C5.0 classification trees, where each tree has different hyperparameters:

model <- decision_tree(mode = "classification") %>%
  set_engine("C5.0")

model
## Decision Tree Model Specification (classification)
## 
## Computational engine: C5.0

If you look at ?decision_tree, you’ll see that we need to specify two hyperparameters, min_n and tree_depth, for the C5.0 engine. To do this we’ll create a random hyperparameter grid using dials.

# the dials API is the most unstable out of all
# packages in this post at the moment. the
# following uses dials 0.0.2

hp_grid <- grid_random(
  min_n %>% range_set(c(2, 20)),
  tree_depth,
  size = 10
)

hp_grid
## # A tibble: 10 x 2
##    min_n tree_depth
##    <int>      <int>
##  1    20         10
##  2     3         13
##  3    18          9
##  4     8         12
##  5     6         14
##  6     9         13
##  7     3         15
##  8     2          3
##  9     4         10
## 10     5          2

Now we create a tibble with a list-column of completed model specifications (C5.0 trees where we’ve specified the hyperparameter values). It’ll be useful to keep track of precisely which tree we’re working with, so we also add a model_id column:

spec_df <- tibble(spec = merge(model, hp_grid)) %>% 
  mutate(model_id = row_number())

spec_df
## # A tibble: 10 x 2
##    spec      model_id
##    <list>       <int>
##  1 <spec[+]>        1
##  2 <spec[+]>        2
##  3 <spec[+]>        3
##  4 <spec[+]>        4
##  5 <spec[+]>        5
##  6 <spec[+]>        6
##  7 <spec[+]>        7
##  8 <spec[+]>        8
##  9 <spec[+]>        9
## 10 <spec[+]>       10

Now that we’ve specified our library of models, we’ll describe the data design we’d like to use with a recipe. For giggles, we’ll use the first two principal components:

recipe <- data %>% 
  recipe(Species ~ .) %>% 
  step_pca(all_predictors(), num_comp = 2)

recipe
## Data Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor          4
## 
## Operations:
## 
## PCA extraction with all_predictors()

Now we can wrap up the first step and fit each of these trees on the full dataset. Here I use furrr::future_map() to do this in parallel, taking advantage of the embarrassingly parallel nature of model fitting.

prepped <- prep(recipe, training = data)

x <- juice(prepped, all_predictors())
y <- juice(prepped, all_outcomes())

full_fits <- spec_df %>% 
  mutate(fit = future_map(spec, fit_xy, x, y))

full_fits
## # A tibble: 10 x 3
##    spec      model_id fit     
##    <list>       <int> <list>  
##  1 <spec[+]>        1 <fit[+]>
##  2 <spec[+]>        2 <fit[+]>
##  3 <spec[+]>        3 <fit[+]>
##  4 <spec[+]>        4 <fit[+]>
##  5 <spec[+]>        5 <fit[+]>
##  6 <spec[+]>        6 <fit[+]>
##  7 <spec[+]>        7 <fit[+]>
##  8 <spec[+]>        8 <fit[+]>
##  9 <spec[+]>        9 <fit[+]>
## 10 <spec[+]>       10 <fit[+]>

Step 2: Getting holdout predictions

We’ll use rsample to generate the resampled datasets for 10-fold cross-validation, like so:

folds <- vfold_cv(data, v = 10)

We will want to fit a model on each fold, which is a mapping operation like before. We define a helper that will fit one of our trees (defined by a parsnip model specification) on a given fold, and pass the data in the form of a trained recipe object, which we call prepped:

fit_on_fold <- function(spec, prepped) {
  
  x <- juice(prepped, all_predictors())
  y <- juice(prepped, all_outcomes())
  
  fit_xy(spec, x, y)
}

Now we create a tibble containing all combinations of the cross-validation resamples and all the tree specifications:

crossed <- crossing(folds, spec_df)
crossed
## # A tibble: 100 x 4
##    splits           id     spec      model_id
##    <list>           <chr>  <list>       <int>
##  1 <split [135/15]> Fold01 <spec[+]>        1
##  2 <split [135/15]> Fold01 <spec[+]>        2
##  3 <split [135/15]> Fold01 <spec[+]>        3
##  4 <split [135/15]> Fold01 <spec[+]>        4
##  5 <split [135/15]> Fold01 <spec[+]>        5
##  6 <split [135/15]> Fold01 <spec[+]>        6
##  7 <split [135/15]> Fold01 <spec[+]>        7
##  8 <split [135/15]> Fold01 <spec[+]>        8
##  9 <split [135/15]> Fold01 <spec[+]>        9
## 10 <split [135/15]> Fold01 <spec[+]>       10
## # ... with 90 more rows

The fitting procedure is then the longest part of the whole process, and looks like:

cv_fits <- crossed %>%
  mutate(
    prepped = future_map(splits, prepper, recipe),
    fit = future_map2(spec, prepped, fit_on_fold)
  )

Now that we have our fits, we need to get holdout predictions. Recall that we trained each fit on the analysis() set, but we want to get holdout predictions using the assessment() set. There are a lot of moving pieces here, so we define a prediction helper function that includes the original row number of each prediction:

predict_helper <- function(fit, new_data, recipe) {
  
  # new_data can either be an rsample::rsplit object
  # or a data frame of genuinely new data
  
  if (inherits(new_data, "rsplit")) {
    obs <- as.integer(new_data, data = "assessment")
    
    # never forget to bake when predicting with recipes!
    new_data <- bake(recipe, assessment(new_data))
  } else {
    obs <- 1:nrow(new_data)
    new_data <- bake(recipe, new_data)
  }
  
  # to generalize this code to a regression super
  # learner, you'd set `type = "numeric"` here
  
  predict(fit, new_data, type = "prob") %>% 
    mutate(obs = obs)
}

Now we use our helper to get predictions for each fold, for each hyperparameter combination. The preds column will be a list-column, so we unnest() to take a look.

holdout_preds <- cv_fits %>% 
  mutate(
    preds = future_pmap(list(fit, splits, prepped), predict_helper)
  )

holdout_preds %>% 
  unnest(preds)
## # A tibble: 1,500 x 6
##    id     model_id .pred_setosa .pred_versicolor .pred_virginica   obs
##    <chr>     <int>        <dbl>            <dbl>           <dbl> <int>
##  1 Fold01        1      0.986            0.00693         0.00709     6
##  2 Fold01        1      0.986            0.00693         0.00709    14
##  3 Fold01        1      0.986            0.00693         0.00709    28
##  4 Fold01        1      0.986            0.00693         0.00709    44
##  5 Fold01        1      0.00811          0.984           0.00794    60
##  6 Fold01        1      0.00811          0.984           0.00794    61
##  7 Fold01        1      0.00811          0.984           0.00794    75
##  8 Fold01        1      0.00811          0.984           0.00794    76
##  9 Fold01        1      0.00811          0.984           0.00794    87
## 10 Fold01        1      0.00811          0.984           0.00794    92
## # ... with 1,490 more rows

Now we have to shape this into something we can train a metalearner on, which means we need:

  • 1 row per original observation
  • 1 column per base learner per outcome category

Getting data into this kind of tidy format is exactly what tidyr excels at. Here we need to go from a long format to a wide format, which will often be the case when working with models in list-columns1.

The new pivot_wider() function solves our reshaping problem exactly, once we realize that:

  • The row number of each observation in the original dataset is in the obs column
  • The .pred_* columns contain the values of interest
  • The model_id column identifies what the names of the new columns should be.

We’re going to need to use this operation over and over again, so we’ll put it into a function.

spread_nested_predictions <- function(data) {
  data %>% 
    unnest(preds) %>% 
    pivot_wider(
      id_cols = obs,
      names_from = model_id,
      values_from = contains(".pred")
    )
}

holdout_preds <- spread_nested_predictions(holdout_preds)
holdout_preds
## # A tibble: 150 x 31
##      obs .pred_setosa_1 .pred_setosa_2 .pred_setosa_3 .pred_setosa_4
##    <int>          <dbl>          <dbl>          <dbl>          <dbl>
##  1     6        0.986          0.986          0.986          0.986  
##  2    14        0.986          0.986          0.986          0.986  
##  3    28        0.986          0.986          0.986          0.986  
##  4    44        0.986          0.986          0.986          0.986  
##  5    60        0.00811        0.00811        0.00811        0.00811
##  6    61        0.00811        0.00811        0.00811        0.00811
##  7    75        0.00811        0.00811        0.00811        0.00811
##  8    76        0.00811        0.00811        0.00811        0.00811
##  9    87        0.00811        0.00811        0.00811        0.00811
## 10    92        0.00811        0.00811        0.00811        0.00811
## # ... with 140 more rows, and 26 more variables: .pred_setosa_5 <dbl>,
## #   .pred_setosa_6 <dbl>, .pred_setosa_7 <dbl>, .pred_setosa_8 <dbl>,
## #   .pred_setosa_9 <dbl>, .pred_setosa_10 <dbl>, .pred_versicolor_1 <dbl>,
## #   .pred_versicolor_2 <dbl>, .pred_versicolor_3 <dbl>,
## #   .pred_versicolor_4 <dbl>, .pred_versicolor_5 <dbl>,
## #   .pred_versicolor_6 <dbl>, .pred_versicolor_7 <dbl>,
## #   .pred_versicolor_8 <dbl>, .pred_versicolor_9 <dbl>,
## #   .pred_versicolor_10 <dbl>, .pred_virginica_1 <dbl>,
## #   .pred_virginica_2 <dbl>, .pred_virginica_3 <dbl>,
## #   .pred_virginica_4 <dbl>, .pred_virginica_5 <dbl>,
## #   .pred_virginica_6 <dbl>, .pred_virginica_7 <dbl>,
## #   .pred_virginica_8 <dbl>, .pred_virginica_9 <dbl>,
## #   .pred_virginica_10 <dbl>

We’re almost ready to fit the metalearning model on top of these predictions, but first we need to join the predictions back to the original dataset using obs to recover the labels!

meta_train <- data %>% 
  mutate(obs = row_number()) %>% 
  right_join(holdout_preds, by = "obs") %>% 
  select(Species, contains(".pred"))

meta_train
## # A tibble: 150 x 31
##    Species .pred_setosa_1 .pred_setosa_2 .pred_setosa_3 .pred_setosa_4
##    <fct>            <dbl>          <dbl>          <dbl>          <dbl>
##  1 setosa         0.986          0.986          0.986          0.986  
##  2 setosa         0.986          0.986          0.986          0.986  
##  3 setosa         0.986          0.986          0.986          0.986  
##  4 setosa         0.986          0.986          0.986          0.986  
##  5 versic~        0.00811        0.00811        0.00811        0.00811
##  6 versic~        0.00811        0.00811        0.00811        0.00811
##  7 versic~        0.00811        0.00811        0.00811        0.00811
##  8 versic~        0.00811        0.00811        0.00811        0.00811
##  9 versic~        0.00811        0.00811        0.00811        0.00811
## 10 versic~        0.00811        0.00811        0.00811        0.00811
## # ... with 140 more rows, and 26 more variables: .pred_setosa_5 <dbl>,
## #   .pred_setosa_6 <dbl>, .pred_setosa_7 <dbl>, .pred_setosa_8 <dbl>,
## #   .pred_setosa_9 <dbl>, .pred_setosa_10 <dbl>, .pred_versicolor_1 <dbl>,
## #   .pred_versicolor_2 <dbl>, .pred_versicolor_3 <dbl>,
## #   .pred_versicolor_4 <dbl>, .pred_versicolor_5 <dbl>,
## #   .pred_versicolor_6 <dbl>, .pred_versicolor_7 <dbl>,
## #   .pred_versicolor_8 <dbl>, .pred_versicolor_9 <dbl>,
## #   .pred_versicolor_10 <dbl>, .pred_virginica_1 <dbl>,
## #   .pred_virginica_2 <dbl>, .pred_virginica_3 <dbl>,
## #   .pred_virginica_4 <dbl>, .pred_virginica_5 <dbl>,
## #   .pred_virginica_6 <dbl>, .pred_virginica_7 <dbl>,
## #   .pred_virginica_8 <dbl>, .pred_virginica_9 <dbl>,
## #   .pred_virginica_10 <dbl>

Step 3: Fit the metalearner

I’m going to use multinomial regression as the metalearner. You can use any metalearner that does multiclass classification here, but I’m going with something simple because I don’t want to obscure the logic with an additional hyperparameter search.

# these settings correspond to multinomial regression
# with a small ridge penalty. the ridge penalty makes
# sure this doesn't explode when the number of columns
# of heldout predictions is greater than the number of
# observations in the original data set
#
# in practice, you'll probably want to avoid base learner
# libraries that large due to difficulties estimating
# the relative performance of the base learners

metalearner <- multinom_reg(penalty = 0.01, mixture = 0) %>% 
  set_engine("glmnet") %>% 
  fit(Species ~ ., meta_train)

metalearner
## parsnip model object
## 
## 
## Call:  glmnet::glmnet(x = as.matrix(x), y = y, family = "multinomial",      alpha = ~0, lambda = ~0.01) 
## 
##      Df   %Dev Lambda
## [1,] 30 0.8685   0.01

That’s it! We’ve fit the super learner! Just like the training process, prediction proceeds in two separate stages:

new_data <- head(iris)

# run the new data through the library of base learners first

base_preds <- full_fits %>% 
  mutate(preds = future_map(fit, predict_helper, new_data, prepped)) %>% 
  spread_nested_predictions()

# then through the metalearner

predict(metalearner, base_preds, type = "prob")
## # A tibble: 6 x 3
##   .pred_setosa .pred_versicolor .pred_virginica
##          <dbl>            <dbl>           <dbl>
## 1        0.978           0.0201         0.00197
## 2        0.978           0.0201         0.00197
## 3        0.978           0.0201         0.00197
## 4        0.978           0.0201         0.00197
## 5        0.978           0.0201         0.00197
## 6        0.978           0.0201         0.00197

Putting it all together

Now we can take all the code we’ve written up and encapsulate it into a single function (still relying on the helper functions we defined above).

Note that this is a reference implementation, and in practice I recommend following the tidymodels recommendations when implementing new methods. Luckily, we do end up inheriting a fair amount of nice consistency from parsnip itself.

#' Fit the super learner!
#'
#' @param library A data frame with a column `spec` containing
#'   complete `parsnip` model specifications for the base learners 
#'   and a column `model_id`.
#' @param recipe An untrained `recipe` specifying data design
#' @param meta_spec A single `parsnip` model specification
#'   for the metalearner.
#' @param data The dataset to fit the super learner on.
#'
#' @return A list with class `"super_learner"` and three elements:
#'
#'   - `full_fits`: A tibble with list-column `fit` of fit
#'     base learners as parsnip `model_fit` objects
#'
#'   - `metalearner`: The metalearner as a single parsnip
#'     `model_fit` object
#'
#'   - `recipe`: A trained version of the original recipe
#'
super_learner <- function(library, recipe, meta_spec, data) {
  
  folds <- vfold_cv(data, v = 5)
  
  cv_fits <- crossing(folds, library) %>%
    mutate(
      prepped = future_map(splits, prepper, recipe),
      fit = future_pmap(list(spec, prepped), fit_on_fold)
    )
  
  prepped <- prep(recipe, training = data)
  
  x <- juice(prepped, all_predictors())
  y <- juice(prepped, all_outcomes())
  
  full_fits <- library %>% 
    mutate(fit = future_map(spec, fit_xy, x, y))
  
  holdout_preds <- cv_fits %>% 
    mutate(
      preds = future_pmap(list(fit, splits, prepped), predict_helper)
    ) %>% 
    spread_nested_predictions() %>% 
    arrange(obs) %>%  # align rows with `y`, which is in original row order
    select(-obs)
  
  metalearner <- fit_xy(meta_spec, holdout_preds, y)
  
  sl <- list(full_fits = full_fits, metalearner = metalearner, recipe = prepped)
  class(sl) <- "super_learner"
  sl
}

We also write an S3 predict method:

predict.super_learner <- function(x, new_data, type = c("class", "prob")) {
  
  type <- rlang::arg_match(type)
  
  new_preds <- x$full_fits %>% 
    mutate(preds = future_map(fit, predict_helper, new_data, x$recipe)) %>% 
    spread_nested_predictions() %>% 
    select(-obs)
    
  predict(x$metalearner, new_preds, type = type)
}

Our helpers do assume that we’re working on a classification problem, but other than this we pretty much only rely on the parsnip API. This means we can mix and match parts to our heart’s desire and things should still work. For example, we can build off the parsnip classification vignette, which starts like so:

data("credit_data")  # shipped with the recipes package

data_split <- credit_data %>% 
  na.omit() %>% 
  initial_split(strata = "Status", p = 0.75)

credit_train <- training(data_split)
credit_test  <- testing(data_split)

credit_recipe <- recipe(Status ~ ., data = credit_train) %>%
  step_center(all_numeric()) %>%
  step_scale(all_numeric())

But now let’s fit a Super Learner based on a stack of MARS fits instead of a neural net. You could also mix in other arbitrary models2. First we take a moment to set up the specification:

credit_model <- mars(mode = "classification", prune_method = "backward") %>% 
  set_engine("earth")

credit_hp_grid <- grid_random(
  num_terms %>% range_set(c(1, 30)),
  prod_degree,
  size = 5
)

credit_library <- tibble(spec = merge(credit_model, credit_hp_grid)) %>% 
  mutate(model_id = row_number())

credit_meta <- multinom_reg(penalty = 0, mixture = 1) %>% 
  set_engine("glmnet")

Now we do the actual fitting and take a quick coffee break:

credit_sl <- super_learner(
  credit_library,
  credit_recipe,
  credit_meta,
  credit_train
)

Since we inherit the tidymodels predict() conventions, getting a holdout ROC curve is as easy as:

pred <- predict(credit_sl, credit_test, type = "prob")

pred %>% 
  bind_cols(credit_test) %>% 
  roc_curve(Status, .pred_bad) %>%
  autoplot()

Wrap up

That’s it! We’ve fit a clever ensembling technique in a few lines of code! Hopefully the concepts are clear and you can start to play with ensembling on your own. I should note that this post uses a ton of different tidymodels abstractions, which can be intimidating. The goal here is to demonstrate how to integrate all of the various components into a big picture. If you aren’t familiar with the individual tidymodels packages, my impression is that the best way to gain familiarity is by gradually working through the various tidymodels vignettes.

In practice, it is a bit of a relief to be done with this post. I’ve been playing around with implementing the Super Learner since summer 2017, but each time I gave it a shot things got messy much faster than I anticipated and I kicked the task down the line. Only recently have the tools to make the Super Learner implementation so pleasant come to life3. Thanks Max and Davis!

If you want to use the Super Learner in practice, I believe the sl3 package is the most actively developed. There’s also Eric Polley’s classic SuperLearner package, which may be more full featured than sl3 at the moment. Also be sure to check out h2o::automl(), which makes stacking about as painless as can be if you just need results!

References

If you’re new to the Super Learner, I recommend starting with LeDell (2015a). Section 2.2 of LeDell (2015b) is similar but goes into more detail. Laan, Polley, and Hubbard (2007) is the original Super Learner paper and contains the proof of the oracle property, an optimality result. Polley and Laan (2010) discusses the Super Learner from a more applied point of view, with some simulations demonstrating performance. Laan and Rose (2018) is a comprehensive reference on both the Super Learner and TMLE. The Super Learner papers and book are targeted at a research audience with a high level of mathematical background, and are not easy reading. Wolpert (1992) is another often cited paper on stacking that is more approachable.

Laan, Mark J. van der, Eric C Polley, and Alan E. Hubbard. 2007. “Super Learner.” https://biostats.bepress.com/cgi/viewcontent.cgi?article=1226&context=ucbbiostat.

Laan, Mark J. van der, and Sherri Rose. 2018. Targeted Learning in Data Science. Springer Series in Statistics. Springer International Publishing. https://doi.org/10.1007/978-3-319-65304-4.

LeDell, Erin. 2015a. “Intro to Practical Ensemble Learning.” https://www.stat.berkeley.edu/~ledell/docs/dlab_ensembles.pdf.

———. 2015b. “Scalable Ensemble Learning and Computationally Efficient Variance Estimation.” PhD thesis. https://www.stat.berkeley.edu/~ledell/papers/ledell-phd-thesis.pdf.

Polley, Eric C, and Mark J. van der Laan. 2010. “Super Learner in Prediction.” https://biostats.bepress.com/ucbbiostat/paper266/.

Wolpert, David H. 1992. “Stacked Generalization.” http://www.machine-learning.martinsewell.com/ensembles/stacking/Wolpert1992.pdf.


  1. Note that multiclass prediction is hardest to deal with because we have multiple prediction columns. For binary classification and regression, we’d only have a single column containing predictions, making the tidying easier.

  2. For some reason, I got poor test performance when I tried this, and I’m not sure why. I’ve asked Erin Ledell on Twitter and will update if I get a response, hopefully including a more full example here.

  3. The newness of these tools also means that some of them aren’t entirely stable, however, and I found some bugs while writing this post.


