Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ S3method(df.residual,slopes)
S3method(get_autodiff_args,glm)
S3method(get_autodiff_args,ivreg)
S3method(get_autodiff_args,lm)
S3method(get_autodiff_args,lrm)
S3method(get_autodiff_args,ols)
S3method(get_coef,afex_aov)
S3method(get_coef,betareg)
S3method(get_coef,bracl)
Expand Down
59 changes: 54 additions & 5 deletions R/autodiff.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ jax_align_group_J <- function(jac_fun, mfx, original, estimates, X, X_hi, X_lo)
if (is.character(mfx@by)) {
# predictions with by=character: use user-specified columns from newdata
bycols <- mfx@by
groups_data <- subset(mfx@newdata, select = bycols)
if (inherits(mfx@newdata, "data.table")) {
groups_data <- mfx@newdata[, ..bycols]
} else {
groups_data <- mfx@newdata[, bycols, drop = FALSE]
}
} else if (!is.null(original)) {
# comparisons with by=TRUE: use term/contrast from original
bycols <- intersect(c("term", "contrast"), colnames(original))
Expand All @@ -76,7 +80,7 @@ jax_align_group_J <- function(jac_fun, mfx, original, estimates, X, X_hi, X_lo)
if (!is.null(X_lo)) X_lo <- X_lo[idx, , drop = FALSE]
# Create group IDs
groups_combined <- apply(groups_data, 1, function(x) paste0(x, collapse = "_"))
groups <- as.integer(as.factor(groups_combined)) - 1L
groups <- as.integer(factor(groups_combined, levels = unique(groups_combined))) - 1L
num_groups <- max(groups) + 1L
} else {
groups <- num_groups <- NULL
Expand Down Expand Up @@ -132,6 +136,12 @@ jax_predictions <- function(mfx, vcov_matrix, ...) {
# Get coefficients
coefs <- get_coef(mfx@model, ...)

# Check for NA coefficients (e.g., from aliased terms)
if (anyNA(coefs)) {
autodiff_warning("models with NA coefficients (aliased terms)")
return(NULL)
}

# Determine aggregation function
if (isFALSE(mfx@by)) {
fun_name <- "predictions"
Expand All @@ -148,6 +158,11 @@ jax_predictions <- function(mfx, vcov_matrix, ...) {
groups <- group_result$groups
num_groups <- group_result$num_groups
X <- group_result$X

# If groups couldn't be created, fall back to finite differences
if (is.null(groups) || is.null(num_groups)) {
return(NULL)
}
}

# Select autodiff function
Expand All @@ -173,10 +188,18 @@ jax_predictions <- function(mfx, vcov_matrix, ...) {
result <- do.call(eval_fun_with_numpy_arrays, c(list(FUN = FUN), args))

# Convert to R objects
J <- as.matrix(result[["jacobian"]])

# Ensure jacobian is (n_predictions x n_coefs), transpose if needed
# Only transpose if we have (n_coefs x 1) instead of (1 x n_coefs)
if (nrow(J) == length(coefs) && ncol(J) == 1) {
J <- t(J)
}

out <- list(
estimate = as.vector(result[["estimate"]]),
std.error = as.vector(result[["std_error"]]),
jacobian = as.matrix(result[["jacobian"]])
jacobian = J
)

# Add column names to jacobian
Expand Down Expand Up @@ -210,7 +233,14 @@ jax_comparisons <- function(mfx, vcov_matrix, hi, lo, original, ...) {
}

if (!is.character(mfx@comparison) || !mfx@comparison %in% c("difference", "ratio")) {
autodiff_warning(sprintf("`comparison='%s'` (only 'difference' and 'ratio' supported)", mfx@comparison))
comp_str <- if (is.character(mfx@comparison)) mfx@comparison else "custom function"
autodiff_warning(sprintf("`comparison='%s'` (only 'difference' and 'ratio' supported)", comp_str))
return(NULL)
}

# Ratio comparisons with by=TRUE compute ratio-then-average instead of average-then-ratio
if (isTRUE(mfx@by) && mfx@comparison == "ratio") {
autodiff_warning("`comparison='ratio'` with `by=TRUE` (averaging order differs from finite differences)")
return(NULL)
}

Expand All @@ -227,6 +257,12 @@ jax_comparisons <- function(mfx, vcov_matrix, hi, lo, original, ...) {
# Get coefficients
coefs <- get_coef(mfx@model, ...)

# Check for NA coefficients (e.g., from aliased terms)
if (anyNA(coefs)) {
autodiff_warning("models with NA coefficients (aliased terms)")
return(NULL)
}

# Map comparison type
comparison_type <- switch(mfx@comparison,
difference = mAD$comparisons$ComparisonType$DIFFERENCE,
Expand All @@ -248,6 +284,11 @@ jax_comparisons <- function(mfx, vcov_matrix, hi, lo, original, ...) {
num_groups <- group_result$num_groups
X_hi <- group_result$X_hi
X_lo <- group_result$X_lo

# If groups couldn't be created, fall back to finite differences
if (is.null(groups) || is.null(num_groups)) {
return(NULL)
}
}

# Select autodiff function
Expand Down Expand Up @@ -275,10 +316,18 @@ jax_comparisons <- function(mfx, vcov_matrix, hi, lo, original, ...) {
result <- do.call(eval_fun_with_numpy_arrays, c(list(FUN = FUN), args))

# Convert to R
J <- as.matrix(result[["jacobian"]])

# Ensure jacobian is (n_comparisons x n_coefs), transpose if needed
# Only transpose if we have (n_coefs x 1) instead of (1 x n_coefs)
if (nrow(J) == length(coefs) && ncol(J) == 1) {
J <- t(J)
}

out <- list(
estimate = as.vector(result[["estimate"]]),
std.error = as.vector(result[["std_error"]]),
jacobian = as.matrix(result[["jacobian"]])
jacobian = J
)

if (!is.null(names(coefs)) && ncol(out$jacobian) == length(coefs)) {
Expand Down
15 changes: 13 additions & 2 deletions R/methods_aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ get_autodiff_args.lm <- function(model, mfx) {
return(NULL)
}

if (!is.null(mfx@wts)) {
autodiff_warning("the `wts` argument")
return(NULL)
}

# Check type support
if (!mfx@type %in% c("response", "link", "invlink(link)")) {
autodiff_warning(sprintf("`type='%s'`", mfx@type))
Expand Down Expand Up @@ -157,15 +162,21 @@ get_autodiff_args.glm <- function(model, mfx) {
return(NULL)
}

if (!is.null(mfx@wts)) {
autodiff_warning("the `wts` argument")
return(NULL)
}

# Check type support
if (!mfx@type %in% c("response", "link", "invlink(link)")) {
autodiff_warning(sprintf("`type='%s'`", mfx@type))
return(NULL)
}

# Check comparison type for comparisons function
if (mfx@calling_function == "comparisons" && !mfx@comparison %in% c("difference", "ratio")) {
autodiff_warning("other functions than `predictions()` or `comparisons()`, with `comparisons='difference'` or `'ratio'`")
if (mfx@calling_function == "comparisons" && (!is.character(mfx@comparison) || !mfx@comparison %in% c("difference", "ratio"))) {
comp_str <- if (is.character(mfx@comparison)) mfx@comparison else "custom function"
autodiff_warning(sprintf("`comparison='%s'` (only 'difference' and 'ratio' supported)", comp_str))
return(NULL)
}

Expand Down
5 changes: 5 additions & 0 deletions R/methods_ivreg.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ get_autodiff_args.ivreg <- function(model, mfx) {
return(NULL)
}

if (!is.null(mfx@wts)) {
autodiff_warning("the `wts` argument")
return(NULL)
}

# If all checks pass, return supported arguments
out <- list(model_type = "linear")
return(out)
Expand Down
129 changes: 66 additions & 63 deletions R/methods_rms.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,66 +77,69 @@ get_predict.ols <- get_predict.rms



#' @keywords internal
#' @export
get_autodiff_args.ols <- function(model, mfx) {
# no inheritance! Important to avoid breaking other models
if (!class(model)[1] == "ols") {
return(NULL)
}

if (!is.null(model$offset)) {
autodiff_warning("models with offsets")
return(NULL)
}

if (!is.null(model$penalty)) {
autodiff_warning("models with offsets")
return(NULL)
}

# Check type support
if (!mfx@type %in% c("lp")) {
autodiff_warning(sprintf("`type='%s'`", mfx@type))
return(NULL)
}

# If all checks pass, return supported arguments
out <- list(model_type = "linear")
return(out)
}


#' @keywords internal
#' @export
get_autodiff_args.lrm <- function(model, mfx) {
# no inheritance! Important to avoid breaking other models
if (!class(model)[1] == "lrm") {
return(NULL)
}

if (!is.null(model$offset)) {
autodiff_warning("models with offsets")
return(NULL)
}

if (!is.null(model$penalty)) {
autodiff_warning("models with offsets")
return(NULL)
}

# Check type support
if (!mfx@type %in% c("fitted", "lp")) {
autodiff_warning(sprintf("`type='%s'`", mfx@type))
return(NULL)
}

# If all checks pass, return supported arguments
mAD <- settings_get("mAD")
out <- list(
model_type = "glm",
family_type = mAD$glm$families$Family$BINOMIAL,
link_type = mAD$glm$families$Link$LOGIT
)
return(out)
}
### AUTODIFF:
### DO NOT DO THIS BECAUSE get_model_matrix() doesn't work with these models

# #' @keywords internal
# #' @export
# get_autodiff_args.ols <- function(model, mfx) {
# # no inheritance! Important to avoid breaking other models
# if (!class(model)[1] == "ols") {
# return(NULL)
# }
#
# if (!is.null(model$offset)) {
# autodiff_warning("models with offsets")
# return(NULL)
# }
#
# if (!is.null(model$penalty)) {
# autodiff_warning("models with offsets")
# return(NULL)
# }
#
# # Check type support
# if (!mfx@type %in% c("lp")) {
# autodiff_warning(sprintf("`type='%s'`", mfx@type))
# return(NULL)
# }
#
# # If all checks pass, return supported arguments
# out <- list(model_type = "linear")
# return(out)
# }
#
#
# #' @keywords internal
# #' @export
# get_autodiff_args.lrm <- function(model, mfx) {
# # no inheritance! Important to avoid breaking other models
# if (!class(model)[1] == "lrm") {
# return(NULL)
# }
#
# if (!is.null(model$offset)) {
# autodiff_warning("models with offsets")
# return(NULL)
# }
#
# if (!is.null(model$penalty)) {
# autodiff_warning("models with offsets")
# return(NULL)
# }
#
# # Check type support
# if (!mfx@type %in% c("fitted", "lp")) {
# autodiff_warning(sprintf("`type='%s'`", mfx@type))
# return(NULL)
# }
#
# # If all checks pass, return supported arguments
# mAD <- settings_get("mAD")
# out <- list(
# model_type = "glm",
# family_type = mAD$glm$families$Family$BINOMIAL,
# link_type = mAD$glm$families$Link$LOGIT
# )
# return(out)
# }
29 changes: 28 additions & 1 deletion R/predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,20 @@ predictions <- function(
mfx@vcov_model <- get_vcov(mfx@model, vcov = vcov, type = if (link_to_response) "link" else mfx@type, ...)

if (isTRUE(checkmate::check_matrix(mfx@vcov_model))) {
# Filter out padding rows before autodiff when using by=character
# Padding is only needed for model.matrix construction, which is already done
if (is.character(mfx@by) && "rowid" %in% colnames(mfx@newdata)) {
idx <- mfx@newdata$rowid > 0
if (!all(idx)) {
# Filter newdata and model matrix
mfx@newdata <- mfx@newdata[idx, , drop = FALSE]
# Also filter the model matrix attribute
MM <- attr(mfx@newdata, "marginaleffects_model_matrix")
if (!is.null(MM)) {
attr(mfx@newdata, "marginaleffects_model_matrix") <- MM[idx, , drop = FALSE]
}
}
}
# Try autodiff
autodiff_result <- jax_predictions(
mfx = mfx,
Expand All @@ -334,6 +348,8 @@ predictions <- function(
if ("rowidcf" %in% colnames(mfx@newdata)) {
tmp[["rowidcf"]] <- mfx@newdata[["rowidcf"]]
}
# Unpad: remove padding rows added for missing factor levels
tmp <- unpad(tmp, draws = NULL)$out
} else {
# Aggregated results (by=TRUE or by=character)
tmp <- data.frame(
Expand All @@ -346,7 +362,18 @@ predictions <- function(
if (is.character(mfx@by)) {
# Extract unique group combinations
bycols <- mfx@by
group_data <- unique(mfx@newdata[, bycols, drop = FALSE])
if (inherits(mfx@newdata, "data.table")) {
group_data <- unique(mfx@newdata[, ..bycols])
data.table::setorderv(group_data, bycols)
group_data <- as.data.frame(group_data)
} else {
group_data <- unique(mfx@newdata[, bycols, drop = FALSE])
if (nrow(group_data) > 1) {
ord <- do.call(order, group_data)
group_data <- group_data[ord, , drop = FALSE]
}
}
rownames(group_data) <- NULL
tmp <- cbind(group_data, tmp)
}
}
Expand Down
Loading