Additional SAGE estimators #48

@jemus42

Description

Currently the only approach is Monte Carlo integration, which is not particularly efficient.

Current setup

  • SAGE base class (the hierarchy is sketched in code after this list)

    • .evaluate_coalitions_batch() method expands test data with reference data (a subset of the test data) and averages predictions per observation
    • (reference data == "background data" in the Python sage package's MarginalImputer implementation)
    • Abstract .marginalize_features() method (line 767) for subclasses to implement (marginalization can also be conditional)
  • MarginalSAGE subclass:

    • Simple .marginalize_features() implementation (lines 844-850) that replaces masked features with reference values
    • Uses base class .evaluate_coalitions_batch() (expansion + averaging)
  • ConditionalSAGE subclass:

    • Overrides .evaluate_coalitions_batch() to sample from conditional distribution (e.g. ARF)
    • No need to expand with reference data; doing so would just add noise (this was a bug, now fixed)
    • Maybe useful: sample multiple observations from the conditional sampler and average? (Not possible with KnockoffSampler for now)
      • (check with Marvin / Jan so we don't use ARF dumbly)
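
A rough, self-contained R6 sketch of this hierarchy for orientation. The field names, the regression-style $response access, and the method bodies are my assumptions for illustration; the actual implementation differs in detail:

# Sketch of the current hierarchy; bodies are illustrative stand-ins,
# assuming plain data.frames and a regression learner.
SAGE = R6::R6Class("SAGE",
  public = list(
    learner = NULL,
    reference_data = NULL,
    initialize = function(learner, reference_data) {
      self$learner = learner
      self$reference_data = reference_data
    },
    # Expand each test observation with the reference data, marginalize the
    # out-of-coalition features, predict, and average per test observation.
    evaluate_coalitions_batch = function(coalition, test_data) {
      n = nrow(test_data)
      m = nrow(self$reference_data)
      idx = rep(seq_len(n), each = m)
      expanded = test_data[idx, , drop = FALSE]
      masked = setdiff(colnames(test_data), coalition)
      expanded = self$marginalize_features(expanded, masked)
      preds = self$learner$predict_newdata(expanded)$response
      tapply(preds, idx, mean)
    },
    # Abstract: subclasses decide *how* masked features are marginalized.
    marginalize_features = function(data, features) stop("abstract method")
  )
)

MarginalSAGE = R6::R6Class("MarginalSAGE",
  inherit = SAGE,
  public = list(
    # Replace masked features with the (cycled) reference values.
    marginalize_features = function(data, features) {
      ref_idx = rep_len(seq_len(nrow(self$reference_data)), nrow(data))
      data[features] = self$reference_data[ref_idx, features, drop = FALSE]
      data
    }
  )
)

# ConditionalSAGE instead overrides evaluate_coalitions_batch() itself: it
# samples masked features from a conditional sampler (e.g. ARF) and skips
# the reference-data expansion entirely.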

What now?

In the interest of efficiency and future maintainability, there are some things to consider:

Imputer classes?

  • sage has "imputer" classes (see above), and while I object to the name, they are probably what we also want, as they encapsulate (sketched after this list):
    • A trained learner (something to $predict_newdata() with)
    • A sampler (e.g. ConditionalSampler for us, like ARF, also Knockoffs)
      • (but not MarginalSampler, because the "reference data expansion" approach is used there for efficiency)
    • (I used to assume "Imputers" were the analogue of our FeatureSampler, but the important bit is that an Imputer can make predictions)
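
For illustration, a minimal sketch of what such a class could look like for us. ConditionalImputer, $sample(), and predict_masked() are made-up names, not sage's API or ours:

# Hypothetical: an Imputer couples a trained learner with a sampler, so an
# estimation procedure only ever needs imputer$predict_masked().
ConditionalImputer = R6::R6Class("ConditionalImputer",
  public = list(
    learner = NULL,  # trained learner, i.e. something with $predict_newdata()
    sampler = NULL,  # e.g. a ConditionalSampler such as ARF, or knockoffs
    initialize = function(learner, sampler) {
      self$learner = learner
      self$sampler = sampler
    },
    # Predict on `data` with `masked_features` replaced by sampled values.
    predict_masked = function(data, masked_features) {
      data[masked_features] = self$sampler$sample(data, masked_features)
      self$learner$predict_newdata(data)$response
    }
  )
)

The marginal case would keep bypassing this and use the reference-data expansion directly, per the bullet above.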

Estimation procedures (Permutation, Kernel SAGE, "naive"?)

  • There is also the estimation procedure; currently we only support the "permutation" approach:
    • shuffle the feature names, then evaluate coalitions of the first j features in the vector:
task = mlr3::tsk("mtcars")
num_perms = 5
all_perms = replicate(num_perms, sample(task$feature_names), simplify = FALSE)

for (permutation_idx in seq_len(num_perms)) {
  for (j in seq_len(task$n_features)) {
    coalition = all_perms[[permutation_idx]][1:j]
    cli::cli_inform("Evaluating coalition {.val {coalition}}")
    # predict on test data with features outside the coalition
    # marginalized / replaced with sampled values
  }
}
  • Not yet implemented but much needed for speed: Kernel SAGE (see also Kernel SHAP; the coalition-size weights are sketched after this list)
  • Nice to have: a "naive" estimator that, given a task with few features, exhaustively evaluates all coalitions (also sketched below)
    • Would be nice for "deterministic" results, especially for our simulation data with <= 5 features
  • -> "Estimator" class or overridable private methods?
    • Needs access to Imputer
    • the learner is trained during resampling initially, so the trained learner needs to be passed around and isn't (for now) known before the SAGE object is instantiated
  • There's also the causal structure learning approach, which I may need to look into
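
For the naive estimator, enumerating all 2^d coalitions is a few lines of base R (runnable as-is; combn() handles the empty coalition via k = 0):

feature_names = mlr3::tsk("mtcars")$feature_names[1:4]  # pretend a 4-feature task
all_coalitions = unlist(
  lapply(0:length(feature_names), function(k) combn(feature_names, k, simplify = FALSE)),
  recursive = FALSE
)
length(all_coalitions)  # 2^4 = 16, including the empty and full coalitions

For Kernel SAGE, the standard Kernel SHAP weighting over coalition sizes would be the starting point (formula from Lundberg & Lee, 2017; whether we sample sizes from it or solve the weighted least-squares problem another way is the open design question):

d = 10                                     # number of features
s = seq_len(d - 1)                         # coalition sizes 1 .. d-1
w = (d - 1) / (choose(d, s) * s * (d - s)) # Kernel SHAP weight per size
size_probs = w / sum(w)                    # sampling distribution over sizes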

Nice to haves / misc

  • Speed: to use learner$predict_newdata_fast() we can't use GraphLearners, so we need to check for that beforehand and maybe even warn or message? (a minimal check is sketched below)
  • Convergence tracking and SE calculation for SAGE values are nice but add complexity
    • also, these SEs are not suitable for inference and are only calculated in the first resampling iteration
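
A minimal sketch of the GraphLearner check (supports_fast_predict() is a hypothetical helper name):

supports_fast_predict = function(learner) {
  # GraphLearners (mlr3pipelines) can't use $predict_newdata_fast(),
  # so detect them up front and tell the user we fall back.
  if (inherits(learner, "GraphLearner")) {
    cli::cli_inform("GraphLearner detected; falling back to $predict_newdata().")
    return(FALSE)
  }
  TRUE
}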
