Conversation

@vincentarelbundock (Owner) commented Aug 31, 2025

@arcruz0

I’ve been thinking about autodiff a lot recently and have become convinced that a tight integration with marginaleffects is highly desirable. The gains are substantial and the implementation straightforward.

So here’s my proposal:

  1. Create a marginaleffectsAD package in Python to host all the JAX functions.
  2. Create a helper function in the R version of marginaleffects that inspects a marginaleffects call, dispatches to the appropriate function from the Python package, and falls back to finite differences whenever necessary.

There are three main benefits to this approach, the first being the most important:

  1. The same JAX functions can be re-used in both the R and Python versions of marginaleffects.
  2. We need to host much less boilerplate code in R.
  3. This allows tight integration with an autodiff() function, directly in the R package, without having to rely on hacks like global options.

The Python interface would look something like this:

import marginaleffectsAD as mad
mad.logit.predictions.jacobian_byG(beta, X, groups, num_groups)

In addition to your existing functions, I added support for Probit and Poisson, as well as an initial implementation of comparisons() and avg_comparisons() to compute ATE / G-computation.
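To make this concrete, here is a rough sketch of how such a grouped Jacobian could be written with JAX. This is illustrative only: the call signature mirrors the example above, but the helper name predictions_byG and the body are hypothetical, not the actual marginaleffectsAD implementation.

import jax
import jax.numpy as jnp

def predictions_byG(beta, X, groups, num_groups):
    # logit predictions averaged within each group (e.g., by treatment arm)
    p = jax.nn.sigmoid(X @ beta)
    sums = jax.ops.segment_sum(p, groups, num_segments=num_groups)
    counts = jax.ops.segment_sum(jnp.ones_like(p), groups, num_segments=num_groups)
    return sums / counts

# Jacobian of the group-averaged predictions with respect to beta,
# which is exactly what the delta method needs for standard errors
jacobian_byG = jax.jacobian(predictions_byG, argnums=0)

The ATE / G-computation case in comparisons() would look similar: average the predictions under two counterfactual model matrices, take the difference, and differentiate with respect to beta.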

I sent you an invitation as a contributor to the marginaleffectsAD Python package repo, in case you would like to move core development efforts there.

In this PR, I implement a prototype of this idea on the R side. The code below shows that the workflow could be very nice for users.

library(marginaleffects)

# install python dependencies
autodiff(install = TRUE)

# activate autodiff
autodiff(TRUE)

mod <- glm(carb ~ ., data = mtcars, family = poisson)

avg_comparisons(mod)
JAX is fast!


 Term Contrast Estimate Std. Error      z Pr(>|z|)   S   2.5 %  97.5 %
 am      1 - 0 -0.51569    1.43986 -0.358   0.7202 0.5 -3.3378 2.30639
 cyl     +1     0.23779    0.72932  0.326   0.7444 0.4 -1.1916 1.66723
 disp    +1    -0.01349    0.00806 -1.674   0.0942 3.4 -0.0293 0.00231
 drat    +1     0.78415    1.40614  0.558   0.5771 0.8 -1.9718 3.54014
 gear    +1     0.84879    1.14651  0.740   0.4591 1.1 -1.3983 3.09590
 hp      +1     0.00467    0.01143  0.408   0.6830 0.5 -0.0177 0.02706
 mpg     +1    -0.05458    0.15711 -0.347   0.7283 0.5 -0.3625 0.25334
 qsec    +1    -0.32770    0.45950 -0.713   0.4757 1.1 -1.2283 0.57290
 vs      1 - 0 -0.60648    1.33246 -0.455   0.6490 0.6 -3.2180 2.00509
 wt      +1     2.20869    1.81349  1.218   0.2233 2.2 -1.3457 5.76306

Type: response
# accuracy
auto <- function() {
    autodiff(TRUE)
    avg_comparisons(mod) |> suppressMessages()
}

finite <- function() {
    autodiff(FALSE)
    avg_comparisons(mod)
}

# results are within numerical tolerance
a <- auto()
f <- finite()
all.equal(a$estimate, f$estimate)
[1] TRUE
all.equal(a$std.error, f$std.error)
[1] "Mean relative difference: 6.694612e-07"
# benchmark
library(microbenchmark)
microbenchmark(
    auto(),
    finite()
)
Warning in microbenchmark(auto(), finite()): less accurate nanosecond times to
avoid potential integer overflows

Unit: milliseconds
     expr      min       lq     mean   median       uq      max neval cld
   auto() 22.23348 22.85422 24.75146 23.29194 24.91074 85.21522   100  a 
 finite() 58.03616 60.51901 62.56966 61.90844 63.32104 77.45339   100   b

@vincentarelbundock (Owner, Author)

I would be really interested in people's experience with this.

  • Is the user interface easy and intuitive?
  • Did you have problems setting it up?
  • Did you run into unexpected errors?
  • Are the warnings/fallbacks clear and informative?

The feature set is currently limited to lm and glm (logit, probit, and poisson), and we do not yet support arguments like hypothesis and wts. But the speed is pretty amazing. In models with many parameters, I get a 5-15x speedup.

Tagging people who seem to like bleeding edge stuff: @strengejacke @mattansb @andrewheiss @saudiwin

# Install the dev version of `marginaleffects`
remotes::install_github("vincentarelbundock/marginaleffects@autodiff")

library(marginaleffects)
library(microbenchmark)

# Install Python autodiff dependencies
autodiff(install = TRUE)

# Activate autodiff
autodiff(TRUE)

# Download data and fit a large model
dat <- get_dataset("airbnb")
mod <- glm(TV ~ ., data = dat, family = binomial)

# Average Predictions
finite <- function() {
    autodiff(FALSE)
    predictions(mod, type = "response")
}

auto <- function() {
    autodiff(TRUE)
    predictions(mod, type = "response")
}

microbenchmark(finite(), auto(), times = 5)

# Average Treatment Effect
finite <- function() {
    autodiff(FALSE)
    avg_comparisons(mod, variables = "Heating")
}

auto <- function() {
    autodiff(TRUE)
    avg_comparisons(mod, variables = "Heating")
}

microbenchmark(finite(), auto(), times = 5)

@strengejacke (Contributor)

What's autodiff? Does it require Python?
(if the answer to the 2nd question is "yes", I'm out ;-))

@vincentarelbundock (Owner, Author)

Autodiff gives you faster and more accurate derivatives, which is what we need for standard errors.

It requires an installation of Python on your machine, but you shouldn't have to interact with Python at all.

autodiff(install=TRUE) should do everything for you from R, using the reticulate package.
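If you're curious about the mechanics, here is a toy sketch (made-up numbers, not package code) of how an autodiff gradient feeds into a delta-method standard error:

import jax
import jax.numpy as jnp

# toy stand-ins for a fitted logit model
X = jnp.array([[1.0, 0.2], [1.0, -1.5], [1.0, 0.7]])  # model matrix
beta = jnp.array([0.1, 0.5])                          # coefficients
vcov = jnp.array([[0.04, 0.0], [0.0, 0.09]])          # coefficient covariance

def avg_prediction(b):
    # average fitted probability
    return jnp.mean(jax.nn.sigmoid(X @ b))

# exact gradient via autodiff (finite differences only approximate this)
g = jax.grad(avg_prediction)(beta)

# delta-method standard error of the average prediction
se = jnp.sqrt(g @ vcov @ g)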

@strengejacke (Contributor)

Ah ok. But I don't have Python installed... 😬

@t-kalinowski (Contributor)

You don't need to install Python - reticulate bootstraps everything it needs on its own.

@vincentarelbundock (Owner, Author)

@t-kalinowski I think I said it to you in person, but that thing is magic.

@saudiwin commented Sep 1, 2025

So, obviously I don't have much to add in terms of autodiff expertise, but the move makes a lot of sense to me. Upgrading the differentiation engine to something state-of-the-art that is robust to many parameters will make it useful for many more kinds of applications in science & industry. Maybe even deep learning? Not sure they want marginaleffects, but it's just a thought.

The obvious lacuna is, of course, support for ordered beta regression 😁

It's a standard GLM so it shouldn't be too hard, right?

@saudiwin commented Sep 1, 2025

And re: reticulate, yes, that is really cool, though at the same time a Python installation is tech debt. But as long as it's an optional part of the package, I don't think it's a problem.

There will always be someone running some weird hacked version of Linux in a smart washing machine who won't be able to install it, and CRAN will punish you if you make it a default & they can't install it, etc.

@mattansb (Contributor) commented Sep 2, 2025

This is very cool!

I haven't tried it out yet, but is it currently limited only to lm and glm classes? Or more generally these types of models (e.g., rms::ols())?

(@t-kalinowski I also have to chime in on the praise - reticulate is awesome. I keep telling people that because of it, R is even better than Python in Python! 😉)

@vincentarelbundock (Owner, Author)

@mattansb

I haven’t tried it out yet, but is it currently limited only to lm and glm classes? Or more generally these types of models (e.g., rms::ols())?

All we need to do is write a predict() function in JAX to make predictions based on the model.matrix and vector of coefficients. rms::ols() and rms::lrm() are thus easy to support (except for things like penalty).
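As a rough sketch (hypothetical code, not what marginaleffectsAD actually contains), the JAX side for a binary lrm() model could be as small as:

import jax
import jax.numpy as jnp

def predict_lrm(beta, X):
    # fitted probabilities from the model matrix and coefficient vector,
    # which is all the R side needs to hand over
    return jax.nn.sigmoid(X @ beta)

# Jacobian of every fitted value with respect to beta, used for the
# delta-method standard errors compared below
predict_lrm_jacobian = jax.jacobian(predict_lrm, argnums=0)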

Here we get a 13x speedup.

library(microbenchmark)
library(marginaleffects)
library(rms)
dat <- get_dataset("airbnb")
mod <- lrm(TV ~ ., data = dat)

finite <- function() {
  autodiff(FALSE)
  predictions(mod, type = "fitted")
}

auto <- function() {
  autodiff(TRUE)
  predictions(mod, type = "fitted")
}

p1 <- auto()
p2 <- finite()

all.equal(p1$estimate, p2$estimate)

    [1] TRUE

all.equal(p1$std.error, p2$std.error, tol = 1e-6)

    [1] TRUE

microbenchmark(finite(), auto(), times = 5)

    Unit: milliseconds
         expr       min        lq      mean    median        uq       max neval cld
     finite() 2590.1808 2596.2159 2604.5910 2598.2415 2602.0976 2636.2194     5  a 
       auto()  186.3906  241.0184  253.8272  253.9089  270.5217  317.2962     5   b

@mattansb (Contributor) commented Sep 2, 2025

Wild!

@strengejacke (Contributor)

Just to summarize:

  1. I install the reticulate package
  2. I install marginaleffects from this PR
  3. I run autodiff(TRUE)

And then, whenever possible, marginaleffects uses JAX for the Jacobian internally, which is much faster (and sometimes also more accurate)?

@strengejacke (Contributor) commented Sep 2, 2025

Related to this comment: https://bsky.app/profile/bbolker.bsky.social/post/3lxpwy3nsb222

Since glmmTMB is based on TMB, is it still worth having dedicated support for glmmTMB and autodiff in marginaleffects, or do you not expect large benefits in terms of speed?

@vincentarelbundock (Owner, Author)

Since glmmTMB is based on TMB, is it still worth having dedicated support for glmmTMB and autodiff in marginaleffects, or do you not expect large benefits in terms of speed?

I don't expect to be able to support glmmTMB at all. (But maybe I'm wrong!)

@vincentarelbundock (Owner, Author) commented Sep 2, 2025

You also need to call autodiff(install = TRUE) and then autodiff(TRUE); autodiff() is a new exported marginaleffects function.

@vincentarelbundock merged commit ee2ea3a into main on Sep 4, 2025
10 checks passed
@vincentarelbundock deleted the autodiff branch on September 6, 2025
@teecrow commented Sep 16, 2025

As a suggestion for the documentation/help file for autodiff(): it could be useful to clarify whether autodiff(autodiff=TRUE) needs to be run once ever or once per session. (It's clear that running it a second time after restarting the session does not trigger the same Python installs as the first time I ran it, but it's not clear to me whether autodiff will now be used by default for supported models in every new session or not.)

@vincentarelbundock (Owner, Author)

Thanks.

TBH, I think the current documentation is explicit enough, because it already says that TRUE enables and FALSE disables.

If you want to submit a PR with improved wording, I'll be happy to review it.

@teecrow commented Sep 16, 2025

What I mean is: once enabled, is it enabled forever? Or does it need to be enabled for every new R session? (Forgive me if this is obvious to those more R-savvy; I consider myself only intermediate in my understanding of these things.)

If the former, the wording could be:

autodiff only needs to be enabled once: it will persist across sessions.

If the latter, the wording could be:

autodiff needs to be enabled once at the start of each new R session.

@vincentarelbundock (Owner, Author)

ah, I see the potential for confusion now. It should not persist across sessions, so the second wording would be correct.

@andrewheiss (Contributor)

(like how tinytex::install_tinytex() only ever has to be run once on a computer vs. autodiff(TRUE), which has to be run once per session)

@vincentarelbundock (Owner, Author)

Improved the docs here. Thanks for the suggestion!

67f1edc
