autodiff()
#1580
Conversation
|
I would be really interested in people's experience with this. The feature set is currently limited to …

Tagging people who seem to like bleeding-edge stuff: @strengejacke @mattansb @andrewheiss @saudiwin

# Install the dev version of `marginaleffects`
remotes::install_github("vincentarelbundock/marginaleffects@autodiff")
library(marginaleffects)
library(microbenchmark)
# Install Python autodiff dependencies
autodiff(install = TRUE)
# Activate autodiff
autodiff(TRUE)
# Download data and fit a large model
dat <- get_dataset("airbnb")
mod <- glm(TV ~ ., data = dat, family = binomial)
# Average Predictions
finite <- function() {
autodiff(FALSE)
predictions(mod, type = "response")
}
auto <- function() {
autodiff(TRUE)
predictions(mod, type = "response")
}
microbenchmark(finite(), auto(), times = 5)
# Average Treatment Effect
finite <- function() {
autodiff(FALSE)
avg_comparisons(mod, variables = "Heating")
}
auto <- function() {
autodiff(TRUE)
avg_comparisons(mod, variables = "Heating")
}
microbenchmark(finite(), auto(), times = 5)
|
|
What's autodiff? Does it require Python? |
|
Autodiff gives you faster and more accurate derivatives, which is what we need for standard errors. It requires a Python installation on your machine, but you shouldn't have to interact with Python at all.
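To make that concrete, here is a minimal stand-alone sketch (toy data, not the package's internals) of how the Jacobian feeds the delta method for prediction standard errors. The Jacobian below is built by finite differences, which is exactly the step that autodiff replaces with exact derivatives.

# Toy logit model
set.seed(1)
d <- data.frame(x = rnorm(200))
d$y <- rbinom(200, 1, plogis(0.5 * d$x))
m <- glm(y ~ x, data = d, family = binomial)

beta <- coef(m)
V <- vcov(m)
X <- model.matrix(m)
pred <- function(b) as.numeric(plogis(X %*% b))  # predictions as a function of the coefficients

# Finite-difference Jacobian of the predictions with respect to the coefficients
eps <- 1e-7
J <- sapply(seq_along(beta), function(j) {
  b <- beta
  b[j] <- b[j] + eps
  (pred(b) - pred(beta)) / eps
})

# Delta method: Var(g(beta)) is approximately J %*% V %*% t(J)
se <- sqrt(diag(J %*% V %*% t(J)))
head(se)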
|
|
Ah ok. But I don't have python installed... 😬 |
|
You don't need to install Python - reticulate bootstraps everything it needs on its own. |
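If you want to see what it set up, a quick way to inspect the bootstrapped interpreter from R is with plain reticulate calls (nothing marginaleffects-specific):

library(reticulate)
py_available(initialize = FALSE)  # has a Python interpreter been initialized in this session?
py_config()                       # which interpreter / environment reticulate will use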
|
@t-kalinowski I think I said it to you in person, but that thing is magic. |
|
So, obviously I don't have much to add in terms of autodiff expertise, but the move makes a lot of sense to me. Upgrading the differentiation engine to something state-of-the-art that is robust to many parameters will make it useful for many more kinds of applications in science and industry. Maybe even deep learning? Not sure they want marginaleffects, but it's just a thought. The obvious lacuna is, of course, support for ordered beta regression 😁 It's a standard GLM so it shouldn't be too hard, right? |
|
And re: reticulate, yes, that is really cool; at the same time, a Python installation is obviously tech debt. But as long as it's an optional part of the package, I don't think it's a problem. There will always be someone running some weird hacked version of Linux on a smart washing machine who won't be able to install it, and CRAN will punish you if you make it a default & they can't install it, etc. |
|
This is very cool! I haven't tried it out yet, but is it currently limited only to lm and glm classes? Or, more generally, to these types of models (e.g., …)?

(@t-kalinowski I also have to chime in on the praise: reticulate is awesome. I keep telling people that, because of it, R is even better than Python in Python! 😉) |
All we need to do is write a …

Here we get a 13x speedup.

library(microbenchmark)
library(marginaleffects)
library(rms)
dat <- get_dataset("airbnb")
mod <- lrm(TV ~ ., data = dat)
finite <- function() {
autodiff(FALSE)
predictions(mod, type = "fitted")
}
auto <- function() {
autodiff(TRUE)
predictions(mod, type = "fitted")
}
p1 <- auto()
p2 <- finite()
all.equal(p1$estimate, p2$estimate)
[1] TRUE
all.equal(p1$std.error, p2$std.error, tol = 1e-6)
[1] TRUE
microbenchmark(finite(), auto(), times = 5)
Unit: milliseconds
expr min lq mean median uq max neval cld
finite() 2590.1808 2596.2159 2604.5910 2598.2415 2602.0976 2636.2194 5 a
auto() 186.3906 241.0184 253.8272 253.9089 270.5217 317.2962 5 b |
|
Wild! |
|
Just to summarize:
And then, whenever possible, marginaleffects uses JAX for the Jacobian internally, which is much faster (and sometimes also more accurate)? |
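As a rough illustration of the accuracy point (plain base R, nothing to do with the package internals): the derivative of the logistic CDF is known in closed form, so we can compare it against a one-sided finite-difference approximation.

# Exact derivative of plogis() is dlogis(); compare with finite differences
x <- c(-2, 0, 2)
exact <- dlogis(x)
eps <- 1e-7
fd <- (plogis(x + eps) - plogis(x)) / eps
cbind(exact, fd, abs_error = abs(exact - fd))
# Autodiff returns the exact value (up to floating point); finite differences
# always carry a truncation error that depends on the step size.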
|
Related to this comment: https://bsky.app/profile/bbolker.bsky.social/post/3lxpwy3nsb222 Since glmmTMB is based on TMB, is it still worth having dedicated support for glmmTMB and …
I don't expect to be able to support …
|
Also need to call … |
|
As a suggestion for the documentation/help file for `autodiff()`: … |
|
Thanks. TBH, I think the current documentation is explicit enough, because it says that TRUE enables and FALSE disables. If you want to put in a PR with improved wording, I'll be happy to review. |
|
What I mean is: once enabled, is it enabled forever? Or does it need to be enabled for every new R session? (Forgive me if this is obvious to those more R-savvy; I consider myself only intermediate in my understanding of these things.) If the former, the wording could be: …
If the latter, the wording could be: …
|
|
ah, I see the potential for confusion now. It should not persist across sessions, so the second wording would be correct. |
|
(like how …) |
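If someone does want it enabled in every session, one option (a hypothetical convenience, assuming the dev branch shown above is installed) is to call it from ~/.Rprofile:

# Add to ~/.Rprofile; wrapped in tryCatch so startup still succeeds if
# marginaleffects or its Python dependencies are missing.
tryCatch(
  marginaleffects::autodiff(TRUE),
  error = function(e) invisible(NULL)
)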
|
Improved the docs here. Thanks for the suggestion! |
@arcruz0
I've been thinking about autodiff a lot recently and have become convinced that a tight integration with `marginaleffects` is highly desirable. The gains are substantial and the implementation straightforward.

So here's my proposal:

- `marginaleffectsAD` package in Python to host all the JAX functions.
- `marginaleffects` for `R` that looks at a `marginaleffects` call and dispatches to an appropriate function from the Python package, and falls back to finite difference whenever necessary.

There are three main benefits to this approach, the first being the most important:

- … `R` and `Python` versions of `marginaleffects`.
- … `R`.
- … `autodiff()` function, directly in the `R` package, without having to rely on hacks like global options.

The Python interface would look something like this:

…

In addition to your existing functions, I added support for Probit and Poisson, as well as an initial implementation of `comparisons()` and `avg_comparisons()` to compute ATE / G-computation.

I sent you an invitation as a contributor to the `marginaleffectsAD` Python package repo, in case you would like to move core development efforts there.

In this PR, I implement a prototype of this idea on the `R` side. The code below shows that the workflow could be very nice for users.