Commit 83a7e3e

vivekmig authored and facebook-github-bot committed
Fixes to Lime Model Construction and API Docs (#525)
Summary: Minor fixes to LIME API documentation and updating so model construction occurs in constructor rather than in signature.
Pull Request resolved: #525
Reviewed By: miguelmartin75
Differential Revision: D24930693
Pulled By: vivekmig
fbshipit-source-id: 68465f69f26c7fa815f07d1928ec7057e85fbb3a
1 parent a1f07de commit 83a7e3e
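
The summary does not say why construction moved out of the signature. One plausible reason (an assumption, not stated in the commit) is that Python evaluates default argument values once, when the function is defined, so a default such as `SkLearnLasso(alpha=1.0)` would be a single object shared by every instance that omits the argument. The sketch below illustrates that behaviour with hypothetical stand-in classes (`Accumulator`, `SharedDefault`, `FreshDefault`), none of which are Captum classes.

```python
from typing import List, Optional


class Accumulator:
    """Hypothetical stand-in for a stateful model object (not a Captum class)."""

    def __init__(self) -> None:
        self.seen: List[int] = []


class SharedDefault:
    # The default is built once, when the `def` statement runs, so every
    # instance that omits `model` receives the very same object.
    def __init__(self, model: Accumulator = Accumulator()) -> None:
        self.model = model


class FreshDefault:
    # `None` acts as a sentinel; the default is built per instance,
    # inside the constructor body.
    def __init__(self, model: Optional[Accumulator] = None) -> None:
        if model is None:
            model = Accumulator()
        self.model = model


a, b = SharedDefault(), SharedDefault()
a.model.seen.append(1)
shared_leak = b.model.seen  # state written through `a` is visible through `b`

c, d = FreshDefault(), FreshDefault()
c.model.seen.append(1)
fresh_leak = d.model.seen  # `d` owns a separate, untouched model
```

With the `None` sentinel, the default object is also only constructed when it is actually needed, rather than at import time.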

3 files changed: +92 -72 lines changed

captum/attr/_core/kernel_shap.py

Lines changed: 8 additions & 8 deletions
@@ -130,27 +130,27 @@ def attribute( # type: ignore
130130
Baselines can be provided as:
131131
132132
- a single tensor, if inputs is a single tensor, with
133-
exactly the same dimensions as inputs or the first
134-
dimension is one and the remaining dimensions match
135-
with inputs.
133+
exactly the same dimensions as inputs or the first
134+
dimension is one and the remaining dimensions match
135+
with inputs.
136136
137137
- a single scalar, if inputs is a single tensor, which will
138-
be broadcasted for each input value in input tensor.
138+
be broadcasted for each input value in input tensor.
139139
140140
- a tuple of tensors or scalars, the baseline corresponding
141-
to each tensor in the inputs' tuple can be:
141+
to each tensor in the inputs' tuple can be:
142142
143-
- either a tensor with matching dimensions to
143+
- either a tensor with matching dimensions to
144144
corresponding tensor in the inputs' tuple
145145
or the first dimension is one and the remaining
146146
dimensions match with the corresponding
147147
input tensor.
148148
149-
- or a scalar, corresponding to a tensor in the
149+
- or a scalar, corresponding to a tensor in the
150150
inputs' tuple. This scalar value is broadcasted
151151
for corresponding input tensor.
152152
In the cases when `baselines` is not provided, we internally
153-
use zero corresponding to each input tensor.
153+
use zero scalar corresponding to each input tensor.
154154
Default: None
155155
target (int, tuple, tensor or list, optional): Output indices for
156156
which surrogate model is trained
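
The baseline rules in the docstring above can be read as a shape check on each input tensor. The following is a minimal sketch of that reading, working on plain shape tuples rather than tensors; `baseline_is_compatible` is a hypothetical helper written for illustration and is not part of Captum.

```python
from typing import Sequence, Union

Shape = Sequence[int]


def baseline_is_compatible(
    input_shape: Shape, baseline: Union[Shape, int, float]
) -> bool:
    """Return True if `baseline` fits one input tensor under the documented rules.

    `baseline` is either a scalar (accepted for any input, since it is
    broadcast to every element) or the shape of a baseline tensor.
    """
    if isinstance(baseline, (int, float)):
        return True
    baseline_shape = tuple(baseline)
    target_shape = tuple(input_shape)
    if len(baseline_shape) != len(target_shape):
        return False
    if baseline_shape == target_shape:
        return True
    # First dimension of one, remaining dimensions equal:
    # the single baseline example is shared across the batch.
    return baseline_shape[0] == 1 and baseline_shape[1:] == target_shape[1:]
```

For a tuple of inputs, the same check would be applied pairwise to each input and its corresponding baseline.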

captum/attr/_core/lime.py

Lines changed: 82 additions & 64 deletions
@@ -88,10 +88,11 @@ def __init__(
8888
interpretable_model (Model): Model object to train interpretable model.
8989
A Model object provides a `fit` method to train the model,
9090
given a dataloader, with batches containing three tensors:
91-
interpretable_inputs: Tensor
92-
[2D num_samples x num_interp_features],
93-
expected_outputs: Tensor [1D num_samples],
94-
weights: Tensor [1D num_samples]
91+
92+
- interpretable_inputs: Tensor
93+
[2D num_samples x num_interp_features],
94+
- expected_outputs: Tensor [1D num_samples],
95+
- weights: Tensor [1D num_samples]
9596
9697
The model object must also provide a `representation` method to
9798
access the appropriate coefficients or representation of the
@@ -113,16 +114,16 @@ def __init__(
113114
114115
The expected signature of this callable is:
115116
116-
similarity_func(
117-
original_input: Tensor or tuple of Tensors,
118-
perturbed_input: Tensor or tuple of Tensors,
119-
perturbed_interpretable_input:
120-
Tensor [2D 1 x num_interp_features],
121-
**kwargs: Any
122-
) -> float or Tensor containing float scalar
117+
>>> similarity_func(
118+
>>> original_input: Tensor or tuple of Tensors,
119+
>>> perturbed_input: Tensor or tuple of Tensors,
120+
>>> perturbed_interpretable_input:
121+
>>> Tensor [2D 1 x num_interp_features],
122+
>>> **kwargs: Any
123+
>>> ) -> float or Tensor containing float scalar
123124
124125
perturbed_input and original_input will be the same type and
125-
contain tensors of the same shape (regardless of whether
126+
contain tensors of the same shape (regardless of whether or not
126127
the sampling function returns inputs in the interpretable
127128
space). original_input is the same as the input provided
128129
when calling attribute.
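
A callable following the `similarity_func` signature above might look like the sketch below. It uses flat lists of floats as stand-ins for tensors so that it runs without PyTorch, and it applies an exponential kernel to the cosine distance. The Gaussian-style form `exp(-d**2 / (2 * width**2))` and the helper names `cosine_distance` and `make_exp_kernel` are assumptions made for illustration, not taken from this diff.

```python
import math
from typing import Any, Callable, Sequence


def cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    """Return 1 - cosine similarity of two flat vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 1.0  # convention chosen here for an all-zero vector
    return 1.0 - dot / (norm_u * norm_v)


def make_exp_kernel(kernel_width: float = 1.0) -> Callable[..., float]:
    """Build a similarity function with the three-argument signature."""

    def similarity_func(
        original_input: Sequence[float],
        perturbed_input: Sequence[float],
        perturbed_interpretable_input: Any,  # unused by this kernel
        **kwargs: Any,
    ) -> float:
        distance = cosine_distance(original_input, perturbed_input)
        # Weight decays from 1.0 as the perturbed sample moves away.
        return math.exp(-(distance ** 2) / (2 * kernel_width ** 2))

    return similarity_func


sim = make_exp_kernel(kernel_width=1.0)
weight_same = sim([1.0, 2.0], [1.0, 2.0], None)
weight_orthogonal = sim([1.0, 0.0], [0.0, 1.0], None, baselines=None)
```

Samples close to the original input receive weights near 1.0, so they count more when the interpretable model is fit.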
@@ -139,10 +140,10 @@ def __init__(
139140
140141
The expected signature of this callable is:
141142
142-
perturb_func(
143-
original_input: Tensor or tuple of Tensors,
144-
**kwargs: Any
145-
) -> Tensor or tuple of Tensors
143+
>>> perturb_func(
144+
>>> original_input: Tensor or tuple of Tensors,
145+
>>> **kwargs: Any
146+
>>> ) -> Tensor or tuple of Tensors
146147
147148
All kwargs passed to the attribute method are
148149
provided as keyword arguments (kwargs) to this callable.
@@ -175,11 +176,11 @@ def __init__(
175176
176177
The expected signature of this callable is:
177178
178-
from_interp_rep_transform(
179-
curr_sample: Tensor [2D 1 x num_interp_features]
180-
original_input: Tensor or Tuple of Tensors,
181-
**kwargs: Any
182-
) -> Tensor or tuple of Tensors
179+
>>> from_interp_rep_transform(
180+
>>> curr_sample: Tensor [2D 1 x num_interp_features]
181+
>>> original_input: Tensor or Tuple of Tensors,
182+
>>> **kwargs: Any
183+
>>> ) -> Tensor or tuple of Tensors
183184
184185
Returned sampled input should match the type of original_input
185186
and corresponding tensor shapes.
@@ -197,11 +198,11 @@ def __init__(
197198
198199
The expected signature of this callable is:
199200
200-
to_interp_rep_transform(
201-
curr_sample: Tensor or Tuple of Tensors,
202-
original_input: Tensor or Tuple of Tensors,
203-
**kwargs: Any
204-
) -> Tensor [2D 1 x num_interp_features]
201+
>>> to_interp_rep_transform(
202+
>>> curr_sample: Tensor or Tuple of Tensors,
203+
>>> original_input: Tensor or Tuple of Tensors,
204+
>>> **kwargs: Any
205+
>>> ) -> Tensor [2D 1 x num_interp_features]
205206
206207
curr_sample will match the type of original_input
207208
and corresponding tensor shapes.
@@ -640,9 +641,9 @@ class Lime(LimeBase):
     def __init__(
         self,
         forward_func: Callable,
-        train_interpretable_model_func: Model = SkLearnLasso(alpha=1.0),
-        similarity_func: Callable = get_exp_kernel_similarity_function(),
-        perturb_func: Callable = default_perturb_func,
+        interpretable_model: Optional[Model] = None,
+        similarity_func: Optional[Callable] = None,
+        perturb_func: Optional[Callable] = None,
     ) -> None:
         r"""
@@ -664,10 +665,11 @@ def __init__(
664665
Alternatively, a custom model object must provide a `fit` method to
665666
train the model, given a dataloader, with batches containing
666667
three tensors:
667-
interpretable_inputs: Tensor
668-
[2D num_samples x num_interp_features],
669-
expected_outputs: Tensor [1D num_samples],
670-
weights: Tensor [1D num_samples]
668+
669+
- interpretable_inputs: Tensor
670+
[2D num_samples x num_interp_features],
671+
- expected_outputs: Tensor [1D num_samples],
672+
- weights: Tensor [1D num_samples]
671673
672674
The model object must also provide a `representation` method to
673675
access the appropriate coefficients or representation of the
@@ -676,52 +678,72 @@ def __init__(
676678
Note that calling fit multiple times should retrain the
677679
interpretable model, each attribution call reuses
678680
the same given interpretable model object.
679-
similarity_func (callable): Function which takes a single sample
681+
similarity_func (optional, callable): Function which takes a single sample
680682
along with its corresponding interpretable representation
681683
and returns the weight of the interpretable sample for
682684
training the interpretable model.
683685
This is often referred to as a similarity kernel.
684686
687+
This argument is optional and defaults to a function which
688+
applies an exponential kernel to the cosine distance between
689+
the original input and perturbed input, with a kernel width
690+
of 1.0.
691+
692+
A similarity function applying an exponential
693+
kernel to cosine / euclidean distances can be constructed
694+
using the provided get_exp_kernel_similarity_function in
695+
captum.attr._core.lime.
696+
697+
Alternatively, a custom callable can also be provided.
685698
The expected signature of this callable is:
686699
687-
similarity_func(
688-
original_input: Tensor or tuple of Tensors,
689-
perturbed_input: Tensor or tuple of Tensors,
690-
perturbed_interpretable_input:
691-
Tensor [2D 1 x num_interp_features],
692-
**kwargs: Any
693-
) -> float or Tensor containing float scalar
700+
>>> def similarity_func(
701+
>>> original_input: Tensor or tuple of Tensors,
702+
>>> perturbed_input: Tensor or tuple of Tensors,
703+
>>> perturbed_interpretable_input:
704+
>>> Tensor [2D 1 x num_interp_features],
705+
>>> **kwargs: Any
706+
>>> ) -> float or Tensor containing float scalar
694707
695708
perturbed_input and original_input will be the same type and
696709
contain tensors of the same shape, with original_input
697710
being the same as the input provided when calling attribute.
698711
699712
kwargs includes baselines, feature_mask, num_interp_features
700-
(integer, determined from feature mask), and
701-
alpha (for Lasso regression).
702-
perturb_func (callable): Function which returns a single
713+
(integer, determined from feature mask).
714+
perturb_func (optional, callable): Function which returns a single
703715
sampled input, which is a binary vector of length
704-
num_interp_features. The default function returns
716+
num_interp_features.
717+
718+
This function is optional; the default function returns
705719
a binary vector where each element is selected
706720
independently and uniformly at random. Custom
707721
logic for selecting sampled binary vectors can
708722
be implemented by providing a function with the
709723
following expected signature:
710724
711-
perturb_func(
712-
original_input: Tensor or tuple of Tensors,
713-
**kwargs: Any
714-
) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
725+
>>> perturb_func(
726+
>>> original_input: Tensor or tuple of Tensors,
727+
>>> **kwargs: Any
728+
>>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
715729
716730
kwargs includes baselines, feature_mask, num_interp_features
717-
(integer, determined from feature mask), and
718-
alpha (for Lasso regression).
731+
(integer, determined from feature mask).
719732
         """
+        if interpretable_model is None:
+            interpretable_model = SkLearnLasso(alpha=1.0)
+
+        if similarity_func is None:
+            similarity_func = get_exp_kernel_similarity_function()
+
+        if perturb_func is None:
+            perturb_func = default_perturb_func
+
         LimeBase.__init__(
             self,
             forward_func,
-            train_interpretable_model_func,
+            interpretable_model,
             similarity_func,
             perturb_func,
             True,
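
The `perturb_func` docstring in this hunk describes a default that draws each element of a binary vector independently and uniformly at random. A custom callable with the same signature could be sketched as follows. It returns a nested list of shape 1 x `num_interp_features` as a stand-in for a tensor so that it runs without PyTorch; the `keep_prob` and `rng` keyword arguments are hypothetical extras introduced for this example, not Captum parameters.

```python
import random
from typing import Any, List


def bernoulli_perturb_func(original_input: Any, **kwargs: Any) -> List[List[int]]:
    """Sample one binary row of length `num_interp_features`.

    `num_interp_features` is documented as being passed in via kwargs.
    `keep_prob` and `rng` are hypothetical knobs for this sketch only.
    """
    num_interp_features = kwargs["num_interp_features"]
    keep_prob = kwargs.get("keep_prob", 0.5)
    rng = kwargs.get("rng", random)
    # Each feature is kept (1) or dropped (0) independently.
    row = [1 if rng.random() < keep_prob else 0 for _ in range(num_interp_features)]
    return [row]  # one sample, shape 1 x num_interp_features


sample = bernoulli_perturb_func(None, num_interp_features=5, rng=random.Random(0))
all_kept = bernoulli_perturb_func(None, num_interp_features=3, keep_prob=1.0)
none_kept = bernoulli_perturb_func(None, num_interp_features=3, keep_prob=0.0)
```

With `keep_prob=0.5` this matches the uniform behaviour the docstring describes. A real implementation would return a tensor with the same layout.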
@@ -788,27 +810,27 @@ def attribute( # type: ignore
788810
Baselines can be provided as:
789811
790812
- a single tensor, if inputs is a single tensor, with
791-
exactly the same dimensions as inputs or the first
792-
dimension is one and the remaining dimensions match
793-
with inputs.
813+
exactly the same dimensions as inputs or the first
814+
dimension is one and the remaining dimensions match
815+
with inputs.
794816
795817
- a single scalar, if inputs is a single tensor, which will
796-
be broadcasted for each input value in input tensor.
818+
be broadcasted for each input value in input tensor.
797819
798820
- a tuple of tensors or scalars, the baseline corresponding
799-
to each tensor in the inputs' tuple can be:
821+
to each tensor in the inputs' tuple can be:
800822
801-
- either a tensor with matching dimensions to
823+
- either a tensor with matching dimensions to
802824
corresponding tensor in the inputs' tuple
803825
or the first dimension is one and the remaining
804826
dimensions match with the corresponding
805827
input tensor.
806828
807-
- or a scalar, corresponding to a tensor in the
829+
- or a scalar, corresponding to a tensor in the
808830
inputs' tuple. This scalar value is broadcasted
809831
for corresponding input tensor.
810832
In the cases when `baselines` is not provided, we internally
811-
use zero corresponding to each input tensor.
833+
use zero scalar corresponding to each input tensor.
812834
Default: None
813835
target (int, tuple, tensor or list, optional): Output indices for
814836
which surrogate model is trained
@@ -886,10 +908,6 @@ def attribute( # type: ignore
886908
If the forward function returns a single scalar per batch,
887909
perturbations_per_eval must be set to 1.
888910
Default: 1
889-
alpha (float, optional): Alpha used for training interpretable surrogate
890-
model in Lasso Regression. This parameter is used only
891-
if using default interpretable model trainer (Lasso).
892-
Default: 1.0
893911
return_input_shape (bool, optional): Determines whether the returned
894912
tensor(s) only contain the coefficients for each interp-
895913
retable feature from the trained surrogate model, or

sphinx/source/lime.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ Lime
     :members:
 .. autoclass:: captum.attr.Lime
     :members:
+
+.. autofunction:: captum.attr._core.lime.get_exp_kernel_similarity_function
