Increasing efficiency of add_epred_draws for large number of predictors? #314

Open
@petermacp

Description

I am exploring a workflow for using add_epred_draws() and doing subsequent calculations when the number of potential predictor combinations is very large. I appreciate the answer might be "just use a smaller prediction matrix" 🤣, but I am interested to see what is possible.

Following along with this example from Andrew Heiss, I have constructed a brms regression model that fits well:

library(tidyverse)
library(brms)
library(tidybayes)
library(duckdb)

m3 <- brm(
  bf(choice_alt ~ 0 + (duration + numtabs + reduction +
       passon + adverseeffects + followup + cost + oo) *
       (diseaserisk + age_z + sex + reading + tb_contact) + 
       (1 | ID | pid)),
  data = m_data,
  family = categorical(refcat = "0"),
  prior = c(
    prior(normal(0, 3), class = b, dpar = mu1),
    prior(normal(0, 3), class = b, dpar = mu2),
    prior(normal(0, 3), class = b, dpar = mu3),
    prior(exponential(1), class = sd, dpar = mu1),
    prior(exponential(1), class = sd, dpar = mu2),
    prior(exponential(1), class = sd, dpar = mu3),
    prior(lkj(1), class = cor)
  ),
  chains = 4, cores = 4, iter = 2000, seed = 1234,
  backend = "cmdstanr", threads = threading(2)
)

We make a very large prediction matrix of all combinations of the predictors, comprising 5 million rows (just for this example):

nd3_matrix <-
  expand.grid(
    diseaserisk = unique(m_data$diseaserisk),
    duration = unique(m_data$duration),
    numtabs = unique(m_data$numtabs),
    reduction = unique(m_data$reduction),
    passon = unique(m_data$passon),
    adverseeffects = unique(m_data$adverseeffects),
    followup = unique(m_data$followup),
    cost = unique(m_data$cost),
    oo = unique(m_data$oo),
    sex = unique(m_data$sex),
    reading = unique(m_data$reading),
    tb_contact = unique(m_data$tb_contact),
    age_z = c(-1.5,-1,-0.5,0,0.5,1,1.5,2)
  )

Of course, when we pass nd3_matrix to the newdata argument of add_epred_draws(), we very quickly run out of memory.
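
A rough back-of-envelope estimate (just a sketch, counting only the 8-byte .epred column and assuming the brms defaults of 4,000 post-warmup draws plus the 4 outcome categories) shows why:

n_rows  <- nrow(nd3_matrix)   # ~5 million predictor combinations
n_draws <- 4000               # 4 chains x 1,000 post-warmup iterations
n_cats  <- 4                  # outcome levels, including refcat "0"

# add_epred_draws() returns long format: one row per combination x draw x category
long_rows <- n_rows * n_draws * n_cats
long_rows * 8 / 1e9
#> roughly 640 GB for the .epred column alone, before any index columns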

So instead, I wondered if it would be possible to split nd3_matrix into more manageable chunks, and write the predictions to a database to allow more efficient post-processing, like this:

# Create a duckdb database (on disk) to hold the predictions
con <- dbConnect(duckdb::duckdb(dbdir = "preds_m3.duckdb"))

# Make a function that writes the predictions for a single `nest_id` to the database

write_preds <- function(nest_value, ndraws){
  
  # Rebuild the nested grid, keep only the rows for this nest_id,
  # and compute expected predictions for that chunk only
  temp <- nd3_matrix %>%
    group_nest(sex, reading, tb_contact, age_z) %>%
    mutate(nest_id = row_number()) %>%
    unnest(data) %>%
    ungroup() %>%
    filter(nest_id == nest_value) %>%
    add_epred_draws(object = m3, re_formula = NA, ndraws = ndraws) %>%
    # keep the reference category and convert to the probability of any non-reference choice
    filter(.category == "0") %>% 
    mutate(.epred = 1 - .epred) 
  
  # Write this chunk to its own table in the duckdb database
  dbWriteTable(con, paste0("processed_data", "_", nest_value), temp, append = TRUE)
  
}

# Now run for all groups, and save the results to the database
nest_ids <- nd3_matrix %>%
  group_nest(sex, reading, tb_contact, age_z) %>%
  mutate(nest_id = row_number()) %>%
  unnest(data) %>%
  ungroup() %>%
  distinct(nest_id) %>%
  pull(nest_id)

walk(nest_ids, .f = \(x) write_preds(nest_value = x, ndraws = 100)) # still only 100 draws... but

# Disconnect from the DuckDB database
dbDisconnect(con)

This seems to work (in a reasonable hour or so), and lets me work with the predictions via dbplyr in a way that would not otherwise be possible. But I am wondering if there are any ways to further optimise this, and potentially increase the number of draws that is feasible?
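
For example, this is the kind of post-processing I have in mind (just a minimal sketch; processed_data_1 is one of the per-nest tables written above, and the grouping columns are only illustrative):

library(dbplyr)

con <- dbConnect(duckdb::duckdb(dbdir = "preds_m3.duckdb"))

# Lazily reference one per-nest table and aggregate inside duckdb,
# so only the summarised result is pulled into R
tbl(con, "processed_data_1") %>%
  group_by(duration, cost) %>%
  summarise(mean_epred = mean(.epred, na.rm = TRUE)) %>%
  collect()

dbDisconnect(con)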

Thanks!
