Skip to content

Commit

Permalink
Fixes to Lime Model Construction and API Docs (#525)
Browse files Browse the repository at this point in the history
Summary:
Minor fixes to LIME API documentation and updating so model construction occurs in constructor rather than in signature.

Pull Request resolved: #525

Reviewed By: miguelmartin75

Differential Revision: D24930693

Pulled By: vivekmig

fbshipit-source-id: 68465f69f26c7fa815f07d1928ec7057e85fbb3a
  • Loading branch information
vivekmig authored and facebook-github-bot committed Nov 13, 2020
1 parent a1f07de commit 83a7e3e
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 72 deletions.
16 changes: 8 additions & 8 deletions captum/attr/_core/kernel_shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,27 +130,27 @@ def attribute( # type: ignore
Baselines can be provided as:
- a single tensor, if inputs is a single tensor, with
exactly the same dimensions as inputs or the first
dimension is one and the remaining dimensions match
with inputs.
exactly the same dimensions as inputs or the first
dimension is one and the remaining dimensions match
with inputs.
- a single scalar, if inputs is a single tensor, which will
be broadcasted for each input value in input tensor.
be broadcasted for each input value in input tensor.
- a tuple of tensors or scalars, the baseline corresponding
to each tensor in the inputs' tuple can be:
to each tensor in the inputs' tuple can be:
- either a tensor with matching dimensions to
- either a tensor with matching dimensions to
corresponding tensor in the inputs' tuple
or the first dimension is one and the remaining
dimensions match with the corresponding
input tensor.
- or a scalar, corresponding to a tensor in the
- or a scalar, corresponding to a tensor in the
inputs' tuple. This scalar value is broadcasted
for corresponding input tensor.
In the cases when `baselines` is not provided, we internally
use zero corresponding to each input tensor.
use zero scalar corresponding to each input tensor.
Default: None
target (int, tuple, tensor or list, optional): Output indices for
which surrogate model is trained
Expand Down
146 changes: 82 additions & 64 deletions captum/attr/_core/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,11 @@ def __init__(
interpretable_model (Model): Model object to train interpretable model.
A Model object provides a `fit` method to train the model,
given a dataloader, with batches containing three tensors:
interpretable_inputs: Tensor
[2D num_samples x num_interp_features],
expected_outputs: Tensor [1D num_samples],
weights: Tensor [1D num_samples]
- interpretable_inputs: Tensor
[2D num_samples x num_interp_features],
- expected_outputs: Tensor [1D num_samples],
- weights: Tensor [1D num_samples]
The model object must also provide a `representation` method to
access the appropriate coefficients or representation of the
Expand All @@ -113,16 +114,16 @@ def __init__(
The expected signature of this callable is:
similarity_func(
original_input: Tensor or tuple of Tensors,
perturbed_input: Tensor or tuple of Tensors,
perturbed_interpretable_input:
Tensor [2D 1 x num_interp_features],
**kwargs: Any
) -> float or Tensor containing float scalar
>>> similarity_func(
>>> original_input: Tensor or tuple of Tensors,
>>> perturbed_input: Tensor or tuple of Tensors,
>>> perturbed_interpretable_input:
>>> Tensor [2D 1 x num_interp_features],
>>> **kwargs: Any
>>> ) -> float or Tensor containing float scalar
perturbed_input and original_input will be the same type and
contain tensors of the same shape (regardless of whether
contain tensors of the same shape (regardless of whether or not
the sampling function returns inputs in the interpretable
space). original_input is the same as the input provided
when calling attribute.
Expand All @@ -139,10 +140,10 @@ def __init__(
The expected signature of this callable is:
perturb_func(
original_input: Tensor or tuple of Tensors,
**kwargs: Any
) -> Tensor or tuple of Tensors
>>> perturb_func(
>>> original_input: Tensor or tuple of Tensors,
>>> **kwargs: Any
>>> ) -> Tensor or tuple of Tensors
All kwargs passed to the attribute method are
provided as keyword arguments (kwargs) to this callable.
Expand Down Expand Up @@ -175,11 +176,11 @@ def __init__(
The expected signature of this callable is:
from_interp_rep_transform(
curr_sample: Tensor [2D 1 x num_interp_features]
original_input: Tensor or Tuple of Tensors,
**kwargs: Any
) -> Tensor or tuple of Tensors
>>> from_interp_rep_transform(
>>> curr_sample: Tensor [2D 1 x num_interp_features]
>>> original_input: Tensor or Tuple of Tensors,
>>> **kwargs: Any
>>> ) -> Tensor or tuple of Tensors
Returned sampled input should match the type of original_input
and corresponding tensor shapes.
Expand All @@ -197,11 +198,11 @@ def __init__(
The expected signature of this callable is:
to_interp_rep_transform(
curr_sample: Tensor or Tuple of Tensors,
original_input: Tensor or Tuple of Tensors,
**kwargs: Any
) -> Tensor [2D 1 x num_interp_features]
>>> to_interp_rep_transform(
>>> curr_sample: Tensor or Tuple of Tensors,
>>> original_input: Tensor or Tuple of Tensors,
>>> **kwargs: Any
>>> ) -> Tensor [2D 1 x num_interp_features]
curr_sample will match the type of original_input
and corresponding tensor shapes.
Expand Down Expand Up @@ -640,9 +641,9 @@ class Lime(LimeBase):
def __init__(
self,
forward_func: Callable,
train_interpretable_model_func: Model = SkLearnLasso(alpha=1.0),
similarity_func: Callable = get_exp_kernel_similarity_function(),
perturb_func: Callable = default_perturb_func,
interpretable_model: Optional[Model] = None,
similarity_func: Optional[Callable] = None,
perturb_func: Optional[Callable] = None,
) -> None:
r"""
Expand All @@ -664,10 +665,11 @@ def __init__(
Alternatively, a custom model object must provide a `fit` method to
train the model, given a dataloader, with batches containing
three tensors:
interpretable_inputs: Tensor
[2D num_samples x num_interp_features],
expected_outputs: Tensor [1D num_samples],
weights: Tensor [1D num_samples]
- interpretable_inputs: Tensor
[2D num_samples x num_interp_features],
- expected_outputs: Tensor [1D num_samples],
- weights: Tensor [1D num_samples]
The model object must also provide a `representation` method to
access the appropriate coefficients or representation of the
Expand All @@ -676,52 +678,72 @@ def __init__(
Note that calling fit multiple times should retrain the
interpretable model, each attribution call reuses
the same given interpretable model object.
similarity_func (callable): Function which takes a single sample
similarity_func (optional, callable): Function which takes a single sample
along with its corresponding interpretable representation
and returns the weight of the interpretable sample for
training the interpretable model.
This is often referred to as a similarity kernel.
This argument is optional and defaults to a function which
applies an exponential kernel to the consine distance between
the original input and perturbed input, with a kernel width
of 1.0.
A similarity function applying an exponential
kernel to cosine / euclidean distances can be constructed
using the provided get_exp_kernel_similarity_function in
captum.attr._core.lime.
Alternately, a custom callable can also be provided.
The expected signature of this callable is:
similarity_func(
original_input: Tensor or tuple of Tensors,
perturbed_input: Tensor or tuple of Tensors,
perturbed_interpretable_input:
Tensor [2D 1 x num_interp_features],
**kwargs: Any
) -> float or Tensor containing float scalar
>>> def similarity_func(
>>> original_input: Tensor or tuple of Tensors,
>>> perturbed_input: Tensor or tuple of Tensors,
>>> perturbed_interpretable_input:
>>> Tensor [2D 1 x num_interp_features],
>>> **kwargs: Any
>>> ) -> float or Tensor containing float scalar
perturbed_input and original_input will be the same type and
contain tensors of the same shape, with original_input
being the same as the input provided when calling attribute.
kwargs includes baselines, feature_mask, num_interp_features
(integer, determined from feature mask), and
alpha (for Lasso regression).
perturb_func (callable): Function which returns a single
(integer, determined from feature mask).
perturb_func (optional, callable): Function which returns a single
sampled input, which is a binary vector of length
num_interp_features. The default function returns
num_interp_features.
This function is optional, the default function returns
a binary vector where each element is selected
independently and uniformly at random. Custom
logic for selecting sampled binary vectors can
be implemented by providing a function with the
following expected signature:
perturb_func(
original_input: Tensor or tuple of Tensors,
**kwargs: Any
) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
>>> perturb_func(
>>> original_input: Tensor or tuple of Tensors,
>>> **kwargs: Any
>>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
kwargs includes baselines, feature_mask, num_interp_features
(integer, determined from feature mask), and
alpha (for Lasso regression).
(integer, determined from feature mask).
"""
if interpretable_model is None:
interpretable_model = SkLearnLasso(alpha=1.0)

if similarity_func is None:
similarity_func = get_exp_kernel_similarity_function()

if perturb_func is None:
perturb_func = default_perturb_func

LimeBase.__init__(
self,
forward_func,
train_interpretable_model_func,
interpretable_model,
similarity_func,
perturb_func,
True,
Expand Down Expand Up @@ -788,27 +810,27 @@ def attribute( # type: ignore
Baselines can be provided as:
- a single tensor, if inputs is a single tensor, with
exactly the same dimensions as inputs or the first
dimension is one and the remaining dimensions match
with inputs.
exactly the same dimensions as inputs or the first
dimension is one and the remaining dimensions match
with inputs.
- a single scalar, if inputs is a single tensor, which will
be broadcasted for each input value in input tensor.
be broadcasted for each input value in input tensor.
- a tuple of tensors or scalars, the baseline corresponding
to each tensor in the inputs' tuple can be:
to each tensor in the inputs' tuple can be:
- either a tensor with matching dimensions to
- either a tensor with matching dimensions to
corresponding tensor in the inputs' tuple
or the first dimension is one and the remaining
dimensions match with the corresponding
input tensor.
- or a scalar, corresponding to a tensor in the
- or a scalar, corresponding to a tensor in the
inputs' tuple. This scalar value is broadcasted
for corresponding input tensor.
In the cases when `baselines` is not provided, we internally
use zero corresponding to each input tensor.
use zero scalar corresponding to each input tensor.
Default: None
target (int, tuple, tensor or list, optional): Output indices for
which surrogate model is trained
Expand Down Expand Up @@ -886,10 +908,6 @@ def attribute( # type: ignore
If the forward function returns a single scalar per batch,
perturbations_per_eval must be set to 1.
Default: 1
alpha (float, optional): Alpha used for training interpretable surrogate
model in Lasso Regression. This parameter is used only
if using default interpretable model trainer (Lasso).
Default: 1.0
return_input_shape (bool, optional): Determines whether the returned
tensor(s) only contain the coefficients for each interp-
retable feature from the trained surrogate model, or
Expand Down
2 changes: 2 additions & 0 deletions sphinx/source/lime.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ Lime
:members:
.. autoclass:: captum.attr.Lime
:members:

.. autofunction:: captum.attr._core.lime.get_exp_kernel_similarity_function

0 comments on commit 83a7e3e

Please sign in to comment.