Hi all,
I’m the author of #454.
It’s been a while since our last discussion, but I’d like to revisit this and get your input before I start contributing code.
Context
I was experimenting with DiCE in a heterogeneous graph learning setting, where I wanted to see which node features need to change to achieve a desired probability. My problem involved 3 classes.
Current limitations I ran into
The dice-torch class appears to support only binary classification, treating one outcome as “bad” (prob→0) and the other as “good” (prob→1). This prevents straightforward use in multiclass problems.
User-specified probability thresholds are silently ignored and defaulted to 0.75.
The current setup makes it hard to:
- target changes in probabilities (e.g., increase a class probability by +7%),
- or return “best-effort” counterfactuals when thresholds are unattainable.
Proposal
- Generalize the treatment of outcomes from a scalar probability to a vector-based format.
- Most classifiers already output softmax vectors.
- If we instead optimize the probability at a target index of the output vector (with the index provided at construction), we can support multiclass problems directly.
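To make the idea concrete, here is a minimal sketch (not DiCE's actual API; the function name and signature are hypothetical) of how the existing hinge-style probability loss could generalize from a scalar to a target index of the softmax vector:

```python
import torch

def multiclass_hinge_loss(probs: torch.Tensor, target_class: int,
                          threshold: float = 0.75) -> torch.Tensor:
    """Penalize counterfactuals whose probability for `target_class`
    falls below `threshold`; the loss is zero once the threshold is met.

    probs: softmax output of shape (batch, num_classes).
    """
    target_prob = probs[..., target_class]
    return torch.clamp(threshold - target_prob, min=0.0).mean()
```

The binary case falls out as `num_classes == 2` with `target_class` 0 or 1, so existing behavior would be preserved.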
Add options to:
- target examples that gave the largest probability for the desired class (vs. closest to threshold),
- store and return best-effort counterfactuals when thresholds cannot be met,
- support arbitrary probability targets, including relative ones (e.g., “increase the probability of class X by 7%”).
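The best-effort option could be as simple as tracking the closest counterfactual seen during optimization. A minimal sketch, again with hypothetical names rather than DiCE's real API:

```python
import torch

class BestEffortTracker:
    """Keep the counterfactual that achieved the highest probability
    for the desired class, even if the threshold is never reached."""

    def __init__(self, target_class: int, threshold: float):
        self.target_class = target_class
        self.threshold = threshold
        self.best_prob = -1.0
        self.best_cf = None

    def update(self, cf: torch.Tensor, probs: torch.Tensor) -> bool:
        """Record `cf` if it beats the best probability seen so far.
        Returns True once the threshold has been met."""
        p = probs[self.target_class].item()
        if p > self.best_prob:
            self.best_prob = p
            self.best_cf = cf.detach().clone()
        return p >= self.threshold
```

If the optimization loop ends without `update` ever returning True, `best_cf` is returned as the best-effort counterfactual instead of failing outright.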
Benefits
- Extends DiCE beyond binary-only settings into general multiclass classification.
- Enables use in probabilistic workflows (e.g., Monte Carlo simulations), where shifts in probability distributions are more meaningful than whether a single threshold is crossed.
- Reuses most of DiCE’s existing machinery with minimal conceptual changes.
If this direction makes sense to you, I’d be happy to start drafting test cases and working on a prototype implementation.
Thanks,
Konstantin