Introducing {workboots}

Generate bootstrap prediction intervals from a tidymodels workflow!

Sometimes we want a range of possible outcomes around each prediction and opt for a model that can generate a prediction interval, like a linear model. Other times we just care about point predictions and may opt for a more powerful model, like XGBoost. But what if we want the best of both worlds: a range of predictions while still using a powerful model? That's where {workboots} comes to the rescue! {workboots} uses bootstrap resampling to train many models; together, those models generate a range of outcomes for each prediction, regardless of model type.

Installation

Version 0.1.0 of {workboots} is available on CRAN. Since the package is still in early development, however, I'd recommend installing the development version from GitHub:

# install from CRAN
install.packages("workboots")

# or install the development version
devtools::install_github("markjrieke/workboots")

Usage

{workboots} builds on top of the {tidymodels} suite of packages and is intended to be used in conjunction with a tidymodels workflow. Teaching how to use {tidymodels} is beyond the scope of this post, but some helpful resources are linked at the bottom for further exploration.

We’ll walk through two examples that show the benefit of the package: estimating a linear model’s prediction interval and generating a prediction interval for a boosted tree model.

Estimating a prediction interval

Let’s get started with a model we know can generate a prediction interval: a basic linear model. In this example, we’ll use the Ames housing dataset to predict a home’s price based on its square footage.

library(tidymodels)

# setup our data
data("ames")
ames_mod <- ames %>% select(First_Flr_SF, Sale_Price)

# relationship between square footage and price
ames_mod %>%
  ggplot(aes(x = First_Flr_SF, y = Sale_Price)) +
  geom_point(alpha = 0.25) +
  scale_y_continuous(labels = scales::dollar_format(), trans = "log10") +
  scale_x_continuous(labels = scales::comma_format(), trans = "log10") +
  labs(title = "Relationship between Square Feet and Sale Price",
       subtitle = "Linear relationship between the log transforms of square footage and price",
       x = NULL,
       y = NULL)

We can use a linear model to predict the log transform of Sale_Price based on the log transform of First_Flr_SF. In this example, we'll train a linear model, then plot our predictions on a holdout set along with the model's 95% prediction interval.

# log transform
ames_mod <- 
  ames_mod %>%
  mutate(across(everything(), log10))

# split into train/test data
set.seed(918)
ames_split <- initial_split(ames_mod)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)

# train a linear model
set.seed(314)
mod <- lm(Sale_Price ~ First_Flr_SF, data = ames_train)

# predict on new data with a prediction interval
ames_preds <-
  mod %>%
  predict(ames_test, interval = "prediction") %>%
  as_tibble()

# plot!
ames_preds %>%
  
  # re-scale predictions to match the original dataset's scale
  bind_cols(ames_test) %>%
  mutate(across(everything(), ~10^.x)) %>%
  
  # add geoms
  ggplot(aes(x = First_Flr_SF)) +
  geom_point(aes(y = Sale_Price),
             alpha = 0.25) +
  geom_line(aes(y = fit),
            size = 1) +
  geom_ribbon(aes(ymin = lwr,
                  ymax = upr),
              alpha = 0.25) +
  scale_y_continuous(labels = scales::dollar_format(), trans = "log10") +
  scale_x_continuous(labels = scales::comma_format(), trans = "log10") +
  labs(title = "Linear Model of Sale Price predicted by Square Footage",
       subtitle = "Shaded area represents the 95% prediction interval",
       x = NULL,
       y = NULL) 

With {workboots}, we can approximate the linear model’s prediction interval by passing a workflow built on a linear model to the function predict_boots().

library(tidymodels)
library(workboots)

# setup a workflow with a linear model
ames_wf <-
  workflow() %>%
  add_recipe(recipe(Sale_Price ~ First_Flr_SF, data = ames_train)) %>%
  add_model(linear_reg())

# generate bootstrap predictions on ames_test
set.seed(713)
ames_preds_boot <-
  ames_wf %>%
  predict_boots(
    n = 2000,
    training_data = ames_train,
    new_data = ames_test
  )

predict_boots() works by creating 2000 bootstrap resamples of the training data, fitting a linear model to each resample, then generating 2000 predictions for each home’s price in the holdout set. We can then use summarise_predictions() to generate upper and lower intervals for each prediction.
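
Under the hood, the pattern is simple enough to sketch by hand. Below is a minimal hand-rolled version of the idea using {rsample} and lm(); this is a conceptual sketch, not the package's exact internals. Note the resampled residual, which adds the observation-level noise a prediction interval needs on top of model uncertainty.

# a hand-rolled sketch of the bootstrap prediction-interval idea
# ({rsample} and {purrr} come loaded with {tidymodels})
set.seed(808)
boots <- bootstraps(ames_train, times = 200)

boot_preds <-
  boots %>%
  mutate(
    # fit a model to each bootstrap resample...
    model = map(splits, ~lm(Sale_Price ~ First_Flr_SF, data = analysis(.x))),
    
    # ...then predict on a new observation, adding a resampled residual
    .pred = map_dbl(
      model,
      ~predict(.x, newdata = ames_test[1, ]) + sample(residuals(.x), size = 1)
    )
  )

# the middle 95% of the bootstrap predictions approximates a 95% prediction
# interval (on the log10 scale) for the first home in the test set
quantile(boot_preds$.pred, probs = c(0.025, 0.5, 0.975))

summarise_predictions() performs that last quantile step across each row's 2,000 predictions at once: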

ames_preds_boot %>%
  summarise_predictions()
## # A tibble: 733 x 5
##    rowid .preds               .pred_lower .pred .pred_upper
##    <int> <list>                     <dbl> <dbl>       <dbl>
##  1     1 <tibble [2,000 x 2]>        5.17  5.44        5.71
##  2     2 <tibble [2,000 x 2]>        4.98  5.27        5.55
##  3     3 <tibble [2,000 x 2]>        4.97  5.25        5.52
##  4     4 <tibble [2,000 x 2]>        5.12  5.40        5.67
##  5     5 <tibble [2,000 x 2]>        5.15  5.44        5.71
##  6     6 <tibble [2,000 x 2]>        4.93  5.21        5.49
##  7     7 <tibble [2,000 x 2]>        4.67  4.94        5.22
##  8     8 <tibble [2,000 x 2]>        4.85  5.13        5.40
##  9     9 <tibble [2,000 x 2]>        4.87  5.14        5.41
## 10    10 <tibble [2,000 x 2]>        5.14  5.41        5.69
## # ... with 723 more rows
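
Each row's .preds column holds the full set of 2,000 bootstrap predictions, so we can also inspect the distribution behind any single interval. A quick sketch; the inner column name used below (model.pred) is an assumption about the nested tibble's format, so check names() on your version first.

# plot the bootstrap prediction distribution for the first home;
# "model.pred" is an assumed column name -- verify with names()
ames_preds_boot %>%
  filter(rowid == 1) %>%
  unnest(.preds) %>%
  ggplot(aes(x = model.pred)) +
  geom_histogram(bins = 50) +
  labs(title = "Bootstrap prediction distribution for a single home",
       x = "Predicted Sale_Price (log10 scale)",
       y = NULL)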

By overlaying the intervals on top of one another, we can see that the prediction interval generated by predict_boots() is a good approximation of the theoretical interval generated by lm().

ames_preds_boot %>%
  summarise_predictions() %>%
  bind_cols(ames_preds) %>%
  bind_cols(ames_test) %>%
  mutate(across(c(.pred_lower:Sale_Price), ~10^.x)) %>%
  ggplot(aes(x = First_Flr_SF)) +
  geom_point(aes(y = Sale_Price),
             alpha = 0.25) +
  geom_line(aes(y = fit),
            size = 1) +
  geom_ribbon(aes(ymin = lwr,
                  ymax = upr),
              alpha = 0.25) +
  geom_point(aes(y = .pred),
             color = "blue",
             alpha = 0.25) +
  geom_errorbar(aes(ymin = .pred_lower,
                    ymax = .pred_upper),
                color = "blue",
                alpha = 0.25,
                width = 0.0125) +
  scale_y_continuous(labels = scales::dollar_format(), trans = "log10") +
  scale_x_continuous(labels = scales::comma_format(), trans = "log10") +
  labs(title = "Linear Model of Sale Price predicted by Square Footage",
       subtitle = "Bootstrap prediction interval closely matches theoretical prediction interval",
       x = NULL,
       y = NULL)

Both lm() and summarise_predictions() default to a 95% prediction interval, but we can generate other intervals by passing different values to the conf parameter:

ames_preds_boot %>%
  
  # generate 95% prediction interval
  summarise_predictions(conf = 0.95) %>%
  rename(.pred_lower_95 = .pred_lower,
         .pred_upper_95 = .pred_upper) %>%
  select(-.pred) %>%
  
  # generate 80% prediction interval
  summarise_predictions(conf = 0.80) %>%
  rename(.pred_lower_80 = .pred_lower,
         .pred_upper_80 = .pred_upper) %>%
  bind_cols(ames_test) %>%
  mutate(across(c(.pred_lower_95:Sale_Price), ~10^.x)) %>%
  
  # plot!
  ggplot(aes(x = First_Flr_SF)) +
  geom_point(aes(y = Sale_Price),
             alpha = 0.25) +
  geom_line(aes(y = .pred),
            size = 1,
            color = "blue") +
  geom_ribbon(aes(ymin = .pred_lower_95,
                  ymax = .pred_upper_95),
              alpha = 0.25,
              fill = "blue") +
  geom_ribbon(aes(ymin = .pred_lower_80,
                  ymax = .pred_upper_80),
              alpha = 0.25,
              fill = "blue") +
  scale_y_continuous(labels = scales::dollar_format(), trans = "log10") +
  scale_x_continuous(labels = scales::comma_format(), trans = "log10") +
  labs(title = "Linear Model of Sale Price predicted by Square Footage",
       subtitle = "Predictions alongside 95% and 80% bootstrap prediction interval",
       x = NULL,
       y = NULL)

As this example shows, {workboots} approximates the linear prediction interval pretty well! On its own, though, that isn't very useful, since we can generate a prediction interval directly from a linear model. The real benefit of {workboots} comes from generating prediction intervals from any model!

Bootstrap prediction intervals with non-linear models

XGBoost is one of my favorite models. Up until now, however, in situations that called for a prediction interval, I've had to opt for a simpler model. With {workboots}, that's no longer an issue! In this example, we'll use XGBoost and {workboots} to generate a prediction interval for each penguin's weight in the Palmer Penguins dataset.

To get started, let’s build a workflow and train an individual model.

# load and prep data
data("penguins")

penguins <-
  penguins %>%
  drop_na()

# split data into training and testing sets
set.seed(123)
penguins_split <- initial_split(penguins)
penguins_test <- testing(penguins_split)
penguins_train <- training(penguins_split)

# create a workflow
penguins_wf <-
  workflow() %>%
  
  # add preprocessing steps
  add_recipe(
    recipe(body_mass_g ~ ., data = penguins_train) %>%
      step_dummy(all_nominal_predictors()) 
  ) %>%
  
  # add xgboost model spec
  add_model(
    boost_tree(mode = "regression")
  )

# fit to training data & predict on test data
set.seed(234)
penguins_preds <-
  penguins_wf %>%
  fit(penguins_train) %>%
  predict(penguins_test)

As mentioned above, XGBoost models can only generate point predictions.

penguins_preds %>%
  bind_cols(penguins_test) %>%
  ggplot(aes(x = body_mass_g,
             y = .pred)) +
  geom_point() +
  geom_abline(linetype = "dashed",
              color = "gray") +
  labs(title = "XGBoost Model of Penguin Weight",
       subtitle = "Individual model can only output individual predictions")

With {workboots}, however, we can generate a prediction interval from our XGBoost model for each penguin’s weight!

# create 2000 models from bootstrap resamples and make predictions on the test set
set.seed(345)
penguins_preds_boot <-
  penguins_wf %>%
  predict_boots(
    n = 2000,
    training_data = penguins_train,
    new_data = penguins_test
  )

penguins_preds_boot %>%
  summarise_predictions()
## # A tibble: 84 x 5
##    rowid .preds               .pred_lower .pred .pred_upper
##    <int> <list>                     <dbl> <dbl>       <dbl>
##  1     1 <tibble [2,000 x 2]>       2788. 3470.       4136.
##  2     2 <tibble [2,000 x 2]>       2838. 3534.       4231.
##  3     3 <tibble [2,000 x 2]>       2942. 3598.       4301.
##  4     4 <tibble [2,000 x 2]>       3354. 4158.       4889.
##  5     5 <tibble [2,000 x 2]>       3186. 3870.       4500.
##  6     6 <tibble [2,000 x 2]>       2884. 3519.       4208.
##  7     7 <tibble [2,000 x 2]>       2790. 3434.       4094.
##  8     8 <tibble [2,000 x 2]>       3394. 4071.       4772.
##  9     9 <tibble [2,000 x 2]>       2812. 3447.       4096.
## 10    10 <tibble [2,000 x 2]>       2744. 3404.       4063.
## # ... with 74 more rows

How does our bootstrap model perform?

penguins_preds_boot %>%
  summarise_predictions() %>%
  bind_cols(penguins_test) %>%
  ggplot(aes(x = body_mass_g,
             y = .pred,
             ymin = .pred_lower,
             ymax = .pred_upper)) +
  geom_abline(linetype = "dashed",
              color = "gray") +
  geom_errorbar(alpha = 0.5,
                color = "blue") +
  geom_point(alpha = 0.5,
             color = "blue") +
  labs(title = "XGBoost Model of Penguin Weight",
       subtitle = "Bootstrap models can generate prediction intervals")
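
As a quick sanity check, we can also compute the empirical coverage of the intervals: the share of test-set penguins whose true weight falls inside their interval. For a 95% interval, we'd hope to see roughly 95%.

# share of test penguins whose actual weight falls within the interval
penguins_preds_boot %>%
  summarise_predictions() %>%
  bind_cols(penguins_test) %>%
  summarise(coverage = mean(body_mass_g >= .pred_lower & body_mass_g <= .pred_upper))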

This particular model could likely use some tuning for better performance, but the important takeaway is that we were able to generate a full prediction distribution from an XGBoost model! This method works with other regression models as well: just create a workflow, then let {workboots} take care of the rest.
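
For example, here's a sketch of swapping the XGBoost spec for a random forest; this assumes the {ranger} package is installed, and everything else about the workflow and the predict_boots() call stays the same.

# swap the model spec in the existing workflow
penguins_rf_wf <-
  penguins_wf %>%
  update_model(
    rand_forest(mode = "regression") %>%
      set_engine("ranger")
  )

# bootstrap predictions work exactly as before
set.seed(456)
penguins_preds_rf <-
  penguins_rf_wf %>%
  predict_boots(
    n = 2000,
    training_data = penguins_train,
    new_data = penguins_test
  )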

Tidymodels Resources

If you're new to {tidymodels}, these are good places to start:

tidymodels.org, the official documentation and getting-started tutorials
Tidy Modeling with R (tmwr.org), a free online book by Max Kuhn and Julia Silge
