Skip to content

Commit 4a49dfc

Browse files
inferences cleanup
1 parent 1051037 commit 4a49dfc

File tree

8 files changed

+89
-44
lines changed

8 files changed

+89
-44
lines changed

R/inferences.R

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,12 @@ inferences <- function(
137137
# dummy mfx for `estimator` and no marginaleffects object
138138
if (!inherits(mfx, "marginaleffects_internal")) {
139139
mfx <- new_marginaleffects_internal(NULL, call("predictions"))
140+
if (isTRUE(checkmate::check_data_frame(x))) {
141+
mfx@modeldata <- x
142+
mfx@modeldata_available <- TRUE
143+
} else {
144+
mfx@modeldata_available <- FALSE
145+
}
140146
}
141147

142148
# supported classes
@@ -174,13 +180,18 @@ inferences <- function(
174180
)
175181

176182
if (is.null(data_train)) {
177-
if (!isTRUE(mfx@modeldata_available) && method != "conformal_split") {
183+
if (!isTRUE(mfx@modeldata_available)) {
178184
checkmate::assert_data_frame(data_train, null.ok = FALSE)
179185
} else {
180186
data_train <- mfx@modeldata
181187
}
182188
}
183189

190+
if (is.null(data_test)) {
191+
checkmate::assert_data_frame(mfx@newdata)
192+
data_test <- mfx@newdata
193+
}
194+
184195
if (inherits(mfx@model, c("Learner", "model_fit", "workflow"))) {
185196
if (method == "simulation") {
186197
msg <- "Simulation-based inference is not supported for this class."
@@ -218,6 +229,7 @@ inferences <- function(
218229
conf_level = conf_level,
219230
conf_type = conf_type,
220231
estimator = estimator,
232+
data_train = data_train,
221233
mfx = mfx,
222234
...)
223235
} else if (method == "fwb") {
@@ -236,6 +248,7 @@ inferences <- function(
236248
conf_level = conf_level,
237249
conf_type = conf_type,
238250
estimator = estimator,
251+
data_train = data_train,
239252
mfx = mfx,
240253
...)
241254
} else if (method == "simulation") {
@@ -246,12 +259,11 @@ inferences <- function(
246259
mfx = mfx,
247260
...)
248261
} else if (isTRUE(grepl("conformal", method))) {
249-
data_test <- sanity_inferences_conformal(
262+
sanity_inferences_conformal(
250263
mfx = mfx,
251264
score = conformal_score,
252265
method = method,
253266
data_calib = data_calib,
254-
data_test = data_test,
255267
R = R
256268
)
257269

R/inferences_boot.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
inferences_boot <- function(x, R = 1000, conf_level = 0.95, conf_type = "perc", estimator = NULL, mfx = NULL, ...) {
1+
inferences_boot <- function(x, R = 1000, conf_level = 0.95, conf_type = "perc", estimator = NULL, data_train = NULL, mfx = NULL, ...) {
22
insight::check_if_installed("boot")
33

44
out <- x
@@ -17,7 +17,7 @@ inferences_boot <- function(x, R = 1000, conf_level = 0.95, conf_type = "perc",
1717
}
1818
}
1919

20-
args <- list("data" = mfx@modeldata, "statistic" = bootfun, R = R)
20+
args <- list("data" = data_train, "statistic" = bootfun, R = R)
2121
args <- c(args, list(...))
2222
B <- do.call(boot::boot, args)
2323

R/inferences_rsample.R

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
inferences_rsample <- function(x, R = 1000, conf_level = 0.95, conf_type = "perc", estimator = NULL, mfx = NULL, ...) {
1+
inferences_rsample <- function(x, R = 1000, conf_level = 0.95, conf_type = "perc", estimator = NULL, data_train = NULL, mfx = NULL, ...) {
22
insight::check_if_installed("rsample")
33

44
out <- x
55

6-
# Get modeldata from mfx object
7-
modeldata <- mfx@modeldata
8-
96
if (!is.null(estimator)) {
107
bootfun <- function(split, ...) {
118
d <- rsample::analysis(split)
@@ -38,15 +35,7 @@ inferences_rsample <- function(x, R = 1000, conf_level = 0.95, conf_type = "perc
3835

3936
args <- list("apparent" = TRUE)
4037
args[["times"]] <- R
41-
42-
# Sometimes modeldata is empty (ex: `tidymodels`)
43-
if (nrow(modeldata) > 0) {
44-
args[["data"]] <- modeldata
45-
} else if (nrow(mfx@modeldata) > 0) {
46-
args[["data"]] <- mfx@modeldata
47-
} else {
48-
args[["data"]] <- mfx@newdata
49-
}
38+
args[["data"]] <- data_train
5039

5140
args <- c(args, list(...))
5241
if ("group" %in% ...names()) {

R/refit.R

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,37 @@ refit.marginaleffects <- function(object, data = NULL, newdata = NULL, vcov = NU
3030

3131
model <- mfx@model
3232

33+
fit_again <- function(model, data, call_model = mfx@call_model) {
  # Refit `model` on `data` and return the newly fitted model object.
  #
  # Arguments:
  #   model:      fitted model object to refit.
  #   data:       data frame that replaces the data the model was fitted on.
  #   call_model: call used by the fallback strategy. Defaults to the call
  #               stored in the enclosing `mfx` object; the default is
  #               evaluated lazily, so it is only touched when
  #               `stats::update()` fails.
  #
  # Strategy:
  #   1. Try `stats::update(model, data = data)`.
  #   2. If that fails (or yields NULL), swap the `data` argument of
  #      `call_model` and re-evaluate the call.
  #
  # Errors:
  #   Stops with a message starting with "Failed to refit" when neither
  #   strategy works. The underlying error message is appended so the user
  #   can see why the refit failed instead of a bare generic message.

  # Keep the condition object rather than discarding it: it is reported below
  # when the fallback is not available either.
  refit <- tryCatch(
    stats::update(model, data = data),
    error = function(e) e
  )
  if (!is.null(refit) && !inherits(refit, "error")) {
    return(refit)
  }

  # Fallback is only possible when the stored call has a `data` argument.
  has_data_arg <- is.call(call_model) && ("data" %in% names(call_model))
  if (!has_data_arg) {
    msg <- "Failed to refit model: no update method available"
    if (inherits(refit, "error")) {
      msg <- paste0(msg, " (", conditionMessage(refit), ")")
    }
    stop(msg, call. = FALSE)
  }

  call_new <- call_model
  call_new$data <- data
  tryCatch(
    eval(call_new),
    error = function(e) {
      stop(
        "Failed to refit the model: ", conditionMessage(e),
        call. = FALSE
      )
    }
  )
}
54+
3355
# Step 1: Refit model if data is supplied
3456
if (!is.null(data)) {
3557
# For workflows, tidymodels provides its own fit.workflow method
3658
if (inherits(model, "workflow")) {
3759
model <- generics::fit(model, data = data)
60+
} else if (inherits(model, "model_fit")) {
61+
model <- fit_again(model[["fit"]], data = data)
3862
} else {
39-
# Try stats::update first
40-
model <- tryCatch(
41-
stats::update(model, data = data),
42-
error = function(e) NULL
43-
)
44-
45-
# Fallback: modify call and re-evaluate
46-
if (is.null(model)) {
47-
if (is.call(mfx@call_model) && "data" %in% names(mfx@call_model)) {
48-
call_new <- mfx@call_model
49-
call_new$data <- data
50-
model <- eval(call_new)
51-
} else {
52-
stop("Failed to refit model: no update method available")
53-
}
54-
}
63+
model <- fit_again(model, data = data)
5564
}
5665
}
5766

R/sanity_inferences.R

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
sanity_inferences_conformal <- function(mfx, score, method, data_calib, data_test, R) {
1+
sanity_inferences_conformal <- function(mfx, score, method, data_calib, R) {
22
checkmate::assert_choice(
33
score,
44
choices = c("residual_abs", "residual_sq", "softmax")
@@ -56,11 +56,6 @@ sanity_inferences_conformal <- function(mfx, score, method, data_calib, data_tes
5656
}
5757
}
5858

59-
if (is.null(data_test)) {
60-
checkmate::assert_data_frame(mfx@newdata)
61-
data_test <- mfx@newdata
62-
}
63-
6459
if (method %in% c("conformal_split", "conformal_quantile")) {
6560
checkmate::assert_data_frame(data_calib, null.ok = FALSE)
6661
}
@@ -74,7 +69,7 @@ sanity_inferences_conformal <- function(mfx, score, method, data_calib, data_tes
7469
stop_sprintf(msg)
7570
}
7671

77-
return(data_test)
72+
return(invisible(NULL))
7873
}
7974

8075

inst/tinytest/test-inferences_rsample.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ x <- mod |>
4040
components("inferences") |>
4141
suppressWarnings()
4242
expect_inherits(x, "bootstraps")
43-
nd <<- datagrid(Sepal.Length = range, model = mod)
43+
nd <- datagrid(Sepal.Length = range, model = mod)
4444
x <- mod |>
4545
comparisons(variables = "Sepal.Width", newdata = nd) |>
4646
inferences(method = "rsample", R = R) |>
@@ -82,7 +82,7 @@ model <- coxph(
8282
Surv(dtime, death) ~ hormon * factor(grade) + ns(age, df = 2),
8383
data = rotterdam
8484
)
85-
nd <<- datagrid(
85+
nd <- datagrid(
8686
hormon = unique,
8787
grade = unique,
8888
dtime = seq(36, 7043, length.out = 25),

inst/tinytest/test-pkg-fixest.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ dt <- mtcars
149149
dt$cyl <- factor(dt$cyl)
150150
fit1 <- suppressMessages(feols(mpg ~ 0 | carb | vs ~ am, data = dt))
151151
fit2 <- suppressMessages(feols(mpg ~ cyl | carb | vs ~ am, data = dt))
152-
fit3 <- suppressMessages(feols(mpg ~ 0 | carb | vs:cyl ~ am:cyl, data = dt))
152+
fit3 <- suppressWarnings(feols(mpg ~ 0 | carb | vs:cyl ~ am:cyl, data = dt))
153153
mfx1 <- slopes(fit1)
154154
mfx2 <- slopes(fit2)
155155
mfx3 <- slopes(fit3)

inst/tinytest/test-pkg-tidymodels.R

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,17 @@ expect_true(nrow(mfx) > 0)
3838

3939

4040
# conformal
41+
set.seed(48103)
4142
dat <- get_dataset("penguins", "palmerpenguins") |> na.omit()
43+
dat <- dat[sample(1:nrow(dat), nrow(dat)), ] # shuffle species
4244
mod <- set_engine(linear_reg(), "lm") |>
4345
fit(body_mass_g ~ bill_length_mm + flipper_length_mm + species,
4446
data = na.omit(dat))
4547
p <- predictions(mod, newdata = dat[1:100, ]) |>
4648
inferences(
4749
R = 3,
4850
method = "conformal_cv+",
51+
data_train = dat[1:100, ],
4952
data_calib = dat[101:nrow(dat), ]
5053
)
5154
expect_inherits(p, "predictions")
@@ -65,6 +68,7 @@ expect_true("std.error" %in% colnames(p))
6568
p <- predictions(mod, newdata = bikes[1:200, ], vcov = FALSE) |>
6669
inferences(
6770
method = "conformal_split",
71+
data_train = bikes[1:200, ],
6872
data_calib = bikes[201:nrow(bikes), ])
6973
expect_inherits(p, "predictions")
7074

@@ -123,3 +127,39 @@ lr_fit <- lr_wf |>
123127
mfx <- slopes(lr_fit, newdata = my_data, variable = "x")
124128
expect_equivalent(mfx$x, my_data$x)
125129
expect_equivalent(mfx$y, my_data$y)
130+
131+
132+
# Bootstrap
133+
set.seed(48103)
134+
nobs <- 50
135+
wf <- workflow() |>
136+
add_model(boost_tree(mode = "regression")) |>
137+
add_recipe(
138+
recipe(Sepal.Length ~ ., data = iris) |>
139+
# 1. Convert character predictors to factors (if any)
140+
step_string2factor(all_nominal_predictors()) |>
141+
# 2. Dummy-code all nominal predictors
142+
step_dummy(all_nominal_predictors())
143+
) |>
144+
fit(iris)
145+
mfx1 <- comparisons(wf, newdata = iris, variable = "Sepal.Width", vcov = FALSE)
146+
mfx2 <- inferences(mfx1, R = 100, method = "rsample", data_train = iris) |>
147+
suppressWarnings()
148+
expect_false("conf.low" %in% colnames(mfx1))
149+
expect_true("conf.low" %in% colnames(mfx2))
150+
151+
152+
# Bootstrap for some supported models but not all
153+
z <- boost_tree("regression") |>
154+
fit(hp ~ ., data = mtcars)
155+
comparisons(z, variables = "mpg", newdata = mtcars, vcov = FALSE) |>
156+
inferences(R = 10, method = "boot", data_train = mtcars) |>
157+
suppressWarnings() |>
158+
expect_error(pattern = "Failed to refit")
159+
z <- linear_reg() |>
160+
fit(hp ~ ., data = mtcars)
161+
cmp <- comparisons(z, variables = "mpg", newdata = mtcars, vcov = FALSE) |>
162+
inferences(R = 10, method = "boot", data_train = mtcars) |>
163+
suppressWarnings()
164+
expect_inherits(cmp, "comparisons")
165+
expect_true("conf.low" %in% colnames(cmp))

0 commit comments

Comments
 (0)