a tidymodels workflow in python using list columns in pandas dataframes

python

workflow

tidymodels

Author

Alex Hayes

Published

2020-08-25

This summer I worked on my first substantial research project in Python. I’ve used Python for a number of small projects before, but this was the first time that it was important for me to have an efficient workflow for working with many models at once. In R, I spend almost all of my time using a ‘many models’ workflow that leverages list-columns in tibbles and a hefty amount of tidyverse manipulation. It’s taken me a while to find a similar workflow in Python that I like, and so I’m documenting it here for myself and other parties who might benefit.

This post demonstrates how to organize models into dataframes for exploratory data analysis, and discusses why you might want to do this. In a future blog post I will show how to extend the basic workflow I present here to handle sample splitting, custom estimators, and parallel processing.

Interlude for Pythonistas: many models workflows

The many models workflow is an extension of the ‘split-apply-combine’ workflow, a largely functional approach to data manipulation implemented in dplyr and pandas. The essential ideas of ‘split-apply-combine’ are articulated nicely in Hadley Wickham’s short and easy to read paper, which I strongly encourage you to read. I assume you are comfortable with this general idea, which largely facilitates natural ways to compute descriptive statistics from data, and have some experience applying these concepts in either dplyr, pandas, or SQL.

Several years ago, this idea evolved: what if, instead of computing descriptive statistics, we wanted to compute more complicated estimands, while leveraging the power of grouped operations? Hadley’s solution, which has proven very fruitful and served as the underlying idea for tidymodels, tidyverts, and several other modeling frameworks, is to put model objects themselves into dataframes. Hadley presented on this idea, and also wrote about it in the many models chapter of R for Data Science and it has turned out to be quite fruitful.

Why is putting model objects in dataframes a good idea?

The simple answer is that it keeps information about your models from drifting apart, as it tends to do otherwise.

My exploratory modeling often ends up looking something like this:

I want to compare models across a range of hyperparameter values. Often there are several distinct hyperparameters to consider at once^{1}.

I want to look at many different properties of the model. As a contrived example, I might want to look at AIC, BIC, \(R^2\) and RMSE for all my models^{2}.

I want to quick create plots to compare these models, and am probably using a plotting libraries expects data in data frames

Anyway it turns out that dataframes handle this use case super well, provided we have some helpers. The overall workflow will look like this:

Specifying models to fit: We organize a dataframe so that each row of the dataframe corresponds to one model we want to fit. For example, a single row of the dataframe might correspond to a single set of hyperparameter values, or a subset of the data.

Iterative model fitting: We create a new column in the dataframes that holds fit models.

Iterative estimate extraction: We extract information we want from the fits into new data frame columns, and then manipulate this data with our favorite dataframe or plotting libraries.

Note that steps (2) and (3) require iteration over many models. In functional languages, map-style operations are a natural (and easily parallelizable!) way to do this; in Python we can use list-comprehensions.

Another innovation in this workflow came from standardizing step (3), where we extract information from the models into a new column of the dataframe. A big issue that we can run into here is that when we extract information from the model object, it can have an inconvenient type that is hard to put in a dataframe column. This may seem esoteric but it turns out to matter a lot more than you’d expect.

David Robinson in his broom package proposed a solution that is increasingly the standard within the R community^{3}. The idea is to create special getter methods for model objects that always return information in consistently formatted data frames. For each model object, we get a data frame with information. Since there are many model objects, we end up with a column of dataframes, which we then flatten.

There has been a lot of pushback around the idea of a column of dataframes with the Python community, largely on the basis that it is not a performant data structure^{4}. This misses the point. The compute time in these workflows comes from model fitting, and afterwards we just want to keep track of things^{5}.

Anyway, there’s a lot more to say about the workflow conceptually, but for now, it is time to show some code and see how things work in practice. For this post, I am going to recreate the analysis that Hadley does in his linked presentation above, which makes use of the gapminder data. The gapminder data consists of life expectancies for 142 counties, reported every 5 years from 1952 to 2007.

Setting the scene

import matplotlib.pyplot as pltimport pandas as pdimport seaborn as snsimport statsmodels.formula.api as smf# you can download the gapminder csv file from my google drive# https://drive.google.com/file/d/1pAgIKdsZPwYteQ2rfq14WrUtS23Tg4Lu/gapminder = pd.read_csv('gapminder.csv')gapminder

country

continent

year

lifeExp

pop

gdpPercap

0

Afghanistan

Asia

1952

28.801

8425333

779.445314

1

Afghanistan

Asia

1957

30.332

9240934

820.853030

2

Afghanistan

Asia

1962

31.997

10267083

853.100710

3

Afghanistan

Asia

1967

34.020

11537966

836.197138

4

Afghanistan

Asia

1972

36.088

13079460

739.981106

...

...

...

...

...

...

...

1699

Zimbabwe

Africa

1987

62.351

9216418

706.157306

1700

Zimbabwe

Africa

1992

60.377

10704340

693.420786

1701

Zimbabwe

Africa

1997

46.809

11404948

792.449960

1702

Zimbabwe

Africa

2002

39.989

11926563

672.038623

1703

Zimbabwe

Africa

2007

43.487

12311143

469.709298

1704 rows × 6 columns

To get a feel for the data, lets plot the life expectancy of each country over time, facetting by continent^{6}.

Following Hadley’s presentation, suppose we would like to summarize the trend for each country by fitting a linear regression to the data from each country. So we have a correspondence 1 model ~ data from 1 country, and want to set up our data frame so that each row corresponds to data from a single country.

groupby() plus a list-comprehension handles this nicely, leveraging the fact that gapminder.groupby('country') is an iterable. In R, you could also use group_by() for this step, or additionally nest() or rowwise(), two tidyverse specification abstractions.

models = pd.DataFrame({# this works because grouped dataframes in pandas are iterable# and because you can pretty much treat series objects like# they are lists'data': [data for _, data in gapminder.groupby('country')],})models.index = [country for country, _ in gapminder.groupby('country')]# the downside of putting weird stuff into pandas dataframes is that# the dataframes print poorlymodels

data

Afghanistan

country continent year lifeExp ...

Albania

country continent year lifeExp pop ...

Algeria

country continent year lifeExp pop...

Angola

country continent year lifeExp pop ...

Argentina

country continent year lifeExp p...

...

...

Vietnam

country continent year lifeExp p...

West Bank and Gaza

country continent year life...

Yemen, Rep.

country continent year lifeExp ...

Zambia

country continent year lifeExp po...

Zimbabwe

country continent year lifeExp ...

142 rows × 1 columns

Step 2: iterative model fitting

Now we need to do the actual model fitting. My preferred approach is to use list-comprehensions.

def country_model(df):return smf.ols('lifeExp ~ year', data=df).fit()models['fit'] = [ country_model(data)for _, data in gapminder.groupby('country')]

One compelling advantage of this (effectively) functional approach to iteration over list-columns of models is that most of these computations are embarrassingly parallel, and map()-like operations are often very easy to parallelize.

An alternative approach here is to use DataFrame.apply(). However, I have found the Series.apply() and DataFrame.apply() methods to be hard to reason about when used together with list-columns, and so I recommend avoiding them.

Now that we’ve fit all of our models, we can extract information from them. Here I’ll define some helper functions very much in the spirit of broom. When you don’t own the model classes you’re using, you pretty much have to write extractor functions to do this; see Michael Chow’s excellent blog post showing how to handle this in an elegant way.

Even if you do own the model objects you’re using, I recommend extractor functions over class methods. This is because, during EDA, you typically fit some expensive models once, and then repeatedly investigate them–you can modify an extractor function and use it right away, but if you modify a method for a model class, you’ll have to refit all the model objects. This leads to slow iteration.

# ripped directly from michael chow's blog post!!# go read his stuff it's very cool!def tidy(fit):from statsmodels.iolib.summary import summary_params_frame tidied = summary_params_frame(fit).reset_index() rename_cols = {'index': 'term', 'coef': 'estimate', 'std err': 'std_err','t': 'statistic', 'P>|t|': 'p_value','Conf. Int. Low': 'conf_int_low', 'Conf. Int. Upp.': 'conf_int_high' }return tidied.rename(columns = rename_cols)def glance(fit):return pd.DataFrame({'aic': fit.aic,'bic': fit.bic,'ess': fit.ess, # explained sum of squares'centered_tss': fit.centered_tss,'fvalue': fit.fvalue,'f_pvalue': fit.f_pvalue,'nobs': fit.nobs,'rsquared': fit.rsquared,'rsquared_adj': fit.rsquared_adj }, index=[0])# note that augment() takes 2 inputs, whereas tidy() and glance() take 1def augment(fit, data): df = data.copy()iflen(df) != fit.nobs:raiseValueError("`data` does not have same number of observations as in training data.") df['fitted'] = fit.fittedvalues df['resid'] = fit.residreturn df

We sanity check the helper functions by seeing if they work on a single model object before working with the entire list-column of models.

tidy(models.fit[0])

term

estimate

std_err

statistic

p_value

conf_int_low

conf_int_high

0

Intercept

-507.534272

40.484162

-12.536613

1.934055e-07

-597.738606

-417.329937

1

year

0.275329

0.020451

13.462890

9.835213e-08

0.229761

0.320896

glance(models.fit[0])

aic

bic

ess

centered_tss

fvalue

f_pvalue

nobs

rsquared

rsquared_adj

0

40.69387

41.663683

271.006011

285.958116

181.24941

9.835213e-08

12.0

0.947712

0.942483

augment() actually takes two inputs, where one input is the model object, and the other is the training data used to fit that model object.

augment(models.fit[0], models.data[0])

country

continent

year

lifeExp

pop

gdpPercap

fitted

resid

0

Afghanistan

Asia

1952

28.801

8425333

779.445314

29.907295

-1.106295

1

Afghanistan

Asia

1957

30.332

9240934

820.853030

31.283938

-0.951938

2

Afghanistan

Asia

1962

31.997

10267083

853.100710

32.660582

-0.663582

3

Afghanistan

Asia

1967

34.020

11537966

836.197138

34.037225

-0.017225

4

Afghanistan

Asia

1972

36.088

13079460

739.981106

35.413868

0.674132

5

Afghanistan

Asia

1977

38.438

14880372

786.113360

36.790512

1.647488

6

Afghanistan

Asia

1982

39.854

12881816

978.011439

38.167155

1.686845

7

Afghanistan

Asia

1987

40.822

13867957

852.395945

39.543798

1.278202

8

Afghanistan

Asia

1992

41.674

16317921

649.341395

40.920442

0.753558

9

Afghanistan

Asia

1997

41.763

22227415

635.341351

42.297085

-0.534085

10

Afghanistan

Asia

2002

42.129

25268405

726.734055

43.673728

-1.544728

11

Afghanistan

Asia

2007

43.828

31889923

974.580338

45.050372

-1.222372

Now we are ready to iterate over the list-column of models. Again, we leverage list-comprehensions. For tidy() and glance() these comprehension are straightforward, but for augment(), which will consume elements from two columns at once, we will need to do something a little fancier. In R we could use purrr::map2() or purrr::pmap(), but the Pythonic idiom here is to use zip() together with tuple unpacking.

models['tidied'] = [tidy(fit) for fit in models.fit]models['glanced'] = [glance(fit) for fit in models.fit]models['augmented'] = [ augment(fit, data)for fit, data inzip(models.fit, models.data)]

Note for R users: In R the calls to tidy(), etc, would typically live inside a mutate() call. The pandas equivalent is assign, but pandas doesn’t leverage non-standard evaluation and I typically don’t use assign() unless I really want to leverage method chaining for some reason. Normally I save method chaining for data manipulation once I have a flat dataframe, and otherwise I operate entirely via list comprehensions.

Anyway, the print method garbles the hell out of our results but whatever.

models

data

fit

tidied

glanced

augmented

Afghanistan

country continent year lifeExp ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic...

aic bic ess centered_t...

country continent year lifeExp ...

Albania

country continent year lifeExp pop ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic...

aic bic ess centered_...

country continent year lifeExp pop ...

Algeria

country continent year lifeExp pop...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic...

aic bic ess centered_...

country continent year lifeExp pop...

Angola

country continent year lifeExp pop ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic...

aic bic ess centered_t...

country continent year lifeExp pop ...

Argentina

country continent year lifeExp p...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic ...

aic bic ess centered_ts...

country continent year lifeExp p...

...

...

...

...

...

...

Vietnam

country continent year lifeExp p...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

aic bic ess centered...

country continent year lifeExp p...

West Bank and Gaza

country continent year life...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

aic bic ess centered_...

country continent year life...

Yemen, Rep.

country continent year lifeExp ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

aic bic ess centered...

country continent year lifeExp ...

Zambia

country continent year lifeExp po...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

aic bic ess centered_t...

country continent year lifeExp po...

Zimbabwe

country continent year lifeExp ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

aic bic ess centered_ts...

country continent year lifeExp ...

142 rows × 5 columns

Flattening the dataframe

Our final step before we can recreate Hadley’s plots is to flatten or “unnest” the dataframe. pandas has no unnest method, but the following has served me well so far to unnest a single column. This will not play well with dataframes with MultiIndexes, which I recommend avoiding.

glance_results = unnest(models, 'glanced')# equivalentlyglance_results = ( models .pipe(unnest, 'glanced')# little bit of extra cleanup .reset_index() .rename(columns={'index': 'country'}))glance_results

country

data

fit

tidied

augmented

aic

bic

ess

centered_tss

fvalue

f_pvalue

nobs

rsquared

rsquared_adj

0

Afghanistan

country continent year lifeExp ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic...

country continent year lifeExp ...

40.693870

41.663683

271.006011

285.958116

181.249410

9.835213e-08

12.0

0.947712

0.942483

1

Albania

country continent year lifeExp pop ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic...

country continent year lifeExp pop ...

52.298071

53.267884

400.445959

439.771289

101.829014

1.462763e-06

12.0

0.910578

0.901636

2

Algeria

country continent year lifeExp pop...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic...

country continent year lifeExp pop...

42.584427

43.554240

1158.583855

1176.087314

661.917086

1.808143e-10

12.0

0.985117

0.983629

3

Angola

country continent year lifeExp pop ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic...

country continent year lifeExp pop ...

44.061857

45.031670

156.667858

176.464605

79.138182

4.593498e-06

12.0

0.887815

0.876596

4

Argentina

country continent year lifeExp p...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statistic ...

country continent year lifeExp p...

6.347866

7.317679

191.937384

192.791819

2246.366349

4.215567e-13

12.0

0.995568

0.995125

...

...

...

...

...

...

...

...

...

...

...

...

...

...

...

137

Vietnam

country continent year lifeExp p...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

country continent year lifeExp p...

42.414079

43.383892

1612.565329

1629.822069

934.455357

3.289412e-11

12.0

0.989412

0.988353

138

West Bank and Gaza

country continent year life...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

country continent year life...

52.287427

53.257240

1291.726331

1331.016795

328.763323

5.585089e-09

12.0

0.970481

0.967529

139

Yemen, Rep.

country continent year lifeExp ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

country continent year lifeExp ...

46.932773

47.902586

1310.527555

1335.675109

521.135193

5.868274e-10

12.0

0.981172

0.979290

140

Zambia

country continent year lifeExp po...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

country continent year lifeExp po...

72.117172

73.086985

13.053046

218.145441

0.636447

4.435318e-01

12.0

0.059836

-0.034180

141

Zimbabwe

country continent year lifeExp ...

<statsmodels.regression.linear_model.Regressio...

term estimate std_err statisti...

country continent year lifeExp ...

83.262706

84.232520

30.934127

550.116442

0.595824

4.580290e-01

12.0

0.056232

-0.038145

142 rows × 14 columns

Now we could ask a question like “what countries seem to have the most linear trends in life expectancy?” and use R-squared as a measure of this.

Okay so this plot is awful but I don’t have the patience at the moment to make it better. We could also look at residuals for individual fits to inspect them for any patterns that might indicate systematic error.

It would be nice to add smooths by continent as Hadley does but again I don’t have the patience or masochistic urge to figure out how to do that. In any case, there are some clear trends in the residuals, especially for Africa, which suggest that some further modeling is a good idea.

The end

So that’s the basic idea behind the many models workflow. Note that we’ve been working at a fairly low-level of abstraction. This means you have a lot of control over what happens (good for research and EDA), but have to write a lot of code. If you’re just fitting prediction models and the only thing you want to do is compare risk estimates, you can save time and effort by using sklearn’s GridSearchCV object.

One final note: in Hadley’s gapminder example we iterate over disjoint data sets. In practice I do this extremely rarely. Much more often I find myself iterating over (train, test) pairs, or hyperparameters, or both at once. This hyperparameter optimization over many CV-folds workflow is more complex than the simple example here, but still fits nicely into the many models workflow I’ve described here. I’ll demonstrate how to do that in a followup post.

Footnotes

It often makes sense to store hyperparameters in dicts in Python. This means you can’t easily store model objects in dicts, because model_store[hyperparameters] = my_model_object gets all fussy because the hyperparameter dictionary isn’t hashable.↩︎

One natural approach here would be to have a list of model objects, a list of AIC values, a list of BIC values, etc. Now you run into an indexing issue where you have to figure out which index corresponds to a given model and keeping track of a bunch of maps like this. A natural solution is to say force all the indexes to match up – everything with index 0 should correspond to the first model. Congratulations, you’ve just invented a data frame.↩︎

Another big selling point of dataframes is vectorized operations on the columns. However, you can’t vectorize operations on model objects, and this leads to code that has to perform vectorize explicitly rather than letting dataframe libraries handle it implicitly. There is a definitely a learning curve for new users here but I’m pretty convinced it’s worth it.↩︎

Another complaint that you might have as a Pythonista is that you should construct a custom model object to handle all this model tracking. I’m actually pretty curious about this and would love to hear ideas. My suspition is that the big challenge is coming up with a sufficiently general framework to accommodate widely varying use cases.↩︎

Aside: I have found plotting in Python to be a largely unpleasant experience. At this point, I’ve pretty much settled on seaborn as my go-to plotting library, and I use sns.FacetGrid for almost everything, even when I don’t need facetting, because enough of my ggplot2 intuition carries over that I can mostly get things done.↩︎