This repository was archived by the owner on Sep 9, 2025. It is now read-only.

arcruz0/marginaleffectsJAX


Note

marginaleffects now includes experimental support for everything in this package, and more! 🚀

Check out the documentation.

marginaleffectsJAX

A JAX backend for marginaleffects. Under construction!

Installation

install.packages("remotes") # if `remotes` is not installed
remotes::install_github("arcruz0/marginaleffectsJAX")

Benchmarks (very preliminary)

Setting: lm() models with $N$ observations, an outcome $y \sim \text{N}(0,1)$, and $K$ regressors, half of which are $x_k \sim \text{N}(0,1)$ and half $x_k \sim \text{Bernoulli}(0.5)$. The by = "var" predictions group by one of the Bernoulli regressors. Times are medians over 10 replications; see benchmarks/benchmark_predictions.R for the code.
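The data-generating process above can be sketched in base R as follows. This is a minimal illustration, not the repository's benchmarks/benchmark_predictions.R: the helper `simulate_benchmark_data()` and the column names x1, …, xK are hypothetical, and the timing loop uses base `predict()` only as a stand-in for the benchmarked `marginaleffects` calls.

```r
# Hypothetical helper: simulates the benchmark setting described above.
simulate_benchmark_data <- function(N, K, seed = 1) {
  stopifnot(K %% 2 == 0)                               # half continuous, half binary
  set.seed(seed)
  half <- K / 2
  cont <- matrix(rnorm(N * half), nrow = N)            # x_k ~ N(0, 1)
  bin  <- matrix(rbinom(N * half, 1, 0.5), nrow = N)   # x_k ~ Bernoulli(0.5)
  colnames(cont) <- paste0("x", seq_len(half))
  colnames(bin)  <- paste0("x", half + seq_len(half))
  data.frame(y = rnorm(N), cont, bin)                  # y ~ N(0, 1), unrelated to x
}

dat <- simulate_benchmark_data(N = 1000, K = 4)
mod <- lm(y ~ ., data = dat)

# A by = "var" benchmark would group by a binary column, e.g. by = "x3".
# Median elapsed time over 10 replications (base predict() as a stand-in):
times <- replicate(10, system.time(predict(mod, se.fit = TRUE))[["elapsed"]])
median(times)
```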

Note: Loading marginaleffectsJAX takes a few seconds, and the first call to a function takes longer than subsequent calls; see the plot for maximum times.

Usage (very preliminary)

library(marginaleffects)
library(marginaleffectsJAX)
enable_JAX_backend()
#> JAX is now a backend for `marginaleffects`. Run `disable_JAX_backend()` to disable.

mod <- lm(mpg ~ hp + am, mtcars)

predictions(mod) |> head()
#> 
#>  Estimate Std. Error    z Pr(>|z|)     S 2.5 % 97.5 %
#>      25.4      0.818 31.0   <0.001 700.5  23.8   27.0
#>      25.4      0.818 31.0   <0.001 700.5  23.8   27.0
#>      26.4      0.850 31.1   <0.001 701.1  24.7   28.1
#>      20.1      0.775 25.9   <0.001 490.0  18.6   21.6
#>      16.3      0.677 24.0   <0.001 421.6  15.0   17.6
#>      20.4      0.796 25.6   <0.001 478.6  18.8   22.0
#> 
#> Type: response

predictions(mod, by = TRUE)
#> 
#>  Estimate Std. Error    z Pr(>|z|)   S 2.5 % 97.5 %
#>      20.1      0.514 39.1   <0.001 Inf  19.1   21.1
#> 
#> Type: response

predictions(mod, by = "am")
#> 
#>  am Estimate Std. Error    z Pr(>|z|)     S 2.5 % 97.5 %
#>   0     17.1      0.667 25.7   <0.001 481.2  15.8   18.5
#>   1     24.4      0.807 30.2   <0.001 664.5  22.8   26.0
#> 
#> Type: response
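As a cross-check, the estimates and standard errors above can be reproduced for an lm model in base R. This is a sketch, not the package's implementation: `marginaleffects` computes standard errors with the delta method, and because lm predictions are linear in the coefficients, the Jacobian of the unit-level predictions is the design matrix itself, averaged over all rows for by = TRUE or over each group for by = "am".

```r
mod <- lm(mpg ~ hp + am, data = mtcars)
X <- model.matrix(mod)   # design matrix: intercept, hp, am
b <- coef(mod)
V <- vcov(mod)

# Unit-level: predictions X %*% b; delta-method SEs are sqrt(diag(X V X')).
unit_est <- drop(X %*% b)
unit_se  <- sqrt(rowSums((X %*% V) * X))

# by = TRUE: mean of the unit-level predictions; the gradient is colMeans(X).
g_all   <- colMeans(X)
avg_est <- sum(g_all * b)
avg_se  <- sqrt(drop(t(g_all) %*% V %*% g_all))

# by = "am": the same averaging, done within each group of am.
by_am <- t(sapply(split(seq_len(nrow(X)), mtcars$am), function(rows) {
  g <- colMeans(X[rows, , drop = FALSE])
  c(estimate = sum(g * b), std.error = sqrt(drop(t(g) %*% V %*% g)))
}))

head(cbind(estimate = unit_est, std.error = unit_se))
c(estimate = avg_est, std.error = avg_se)
by_am
```

For OLS with an intercept, the by = TRUE estimate is the sample mean of mpg, and because am is a regressor, the by = "am" estimates equal the group means of mpg.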

Supported marginaleffects calls (very preliminary)

Only models of class lm are supported.

| Functionality | Example call(s) | Supported? |
|---|---|---|
| Predictions: unit-level | `predictions(mod)` | ✅ |
| Predictions: aggregate | `predictions(mod, by = TRUE)`, `avg_predictions(mod)` | ✅ |
| Predictions: marginal | `predictions(mod, by = "var")`, `avg_predictions(mod, by = "var")`, `plot_predictions(mod, by = "var")` | ✅ |
| Predictions: custom grid | `predictions(mod, newdata = datagrid(...))` | ✅ |
| Predictions: counterfactual grid | `predictions(mod, newdata = datagrid(..., grid_type = "counterfactual"))` | ✅ |
| Predictions: mean-or-mode grid | `predictions(mod, newdata = "mean")` | ✅ |
| Predictions: balanced grid | `predictions(mod, newdata = "balanced")` | ✅ |
| Predictions: weighted | `predictions(mod, by = "var", wts = "wvar")` | 🔜 |
| Predictions: summed | `predictions(mod, by = ..., byfun = sum)` | 🔜 |
| Comparisons | `comparisons(mod, ...)` | 🔜 |
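The grid types in the table correspond to simple data manipulations. The base R sketch below uses mtcars from the usage example and is not the package's code. A counterfactual grid over am = 0 and am = 1 replicates every row once per value. `newdata = "mean"` predicts at a single row with numeric variables at their means; am is stored as numeric in mtcars, so it is averaged too.

```r
mod <- lm(mpg ~ hp + am, data = mtcars)

# Counterfactual grid: the full dataset, once with am = 0 and once with am = 1.
cf <- do.call(rbind, lapply(c(0, 1), function(v) {
  d <- mtcars
  d$am <- v          # set am to the same value for every row
  d
}))
cf$pred <- predict(mod, newdata = cf)
cf_avg  <- tapply(cf$pred, cf$am, mean)   # average counterfactual prediction per am
cf_avg

# Mean grid: one row with each numeric regressor at its sample mean.
mean_row  <- data.frame(hp = mean(mtcars$hp), am = mean(mtcars$am))
mean_pred <- predict(mod, newdata = mean_row)
mean_pred
```

Because this model is linear with no interactions, the gap between the two counterfactual averages equals the am coefficient, and the mean-grid prediction equals `mean(mtcars$mpg)`, since an OLS fit with an intercept passes through the sample means.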
