@@ -88,10 +88,11 @@ def __init__(
  interpretable_model (Model): Model object to train interpretable model.
  A Model object provides a `fit` method to train the model,
  given a dataloader, with batches containing three tensors:
- interpretable_inputs: Tensor
- [2D num_samples x num_interp_features],
- expected_outputs: Tensor [1D num_samples],
- weights: Tensor [1D num_samples]
+
+ - interpretable_inputs: Tensor
+   [2D num_samples x num_interp_features],
+ - expected_outputs: Tensor [1D num_samples],
+ - weights: Tensor [1D num_samples]

  The model object must also provide a `representation` method to
  access the appropriate coefficients or representation of the
@@ -113,16 +114,16 @@ def __init__(

  The expected signature of this callable is:

- similarity_func(
-     original_input: Tensor or tuple of Tensors,
-     perturbed_input: Tensor or tuple of Tensors,
-     perturbed_interpretable_input:
-         Tensor [2D 1 x num_interp_features],
-     **kwargs: Any
- ) -> float or Tensor containing float scalar
+ >>> similarity_func(
+ >>>     original_input: Tensor or tuple of Tensors,
+ >>>     perturbed_input: Tensor or tuple of Tensors,
+ >>>     perturbed_interpretable_input:
+ >>>         Tensor [2D 1 x num_interp_features],
+ >>>     **kwargs: Any
+ >>> ) -> float or Tensor containing float scalar

  perturbed_input and original_input will be the same type and
- contain tensors of the same shape (regardless of whether
+ contain tensors of the same shape (regardless of whether or not
  the sampling function returns inputs in the interpretable
  space). original_input is the same as the input provided
  when calling attribute.
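For reviewers who want a concrete picture of the `similarity_func` contract in this hunk, here is a minimal sketch of an exponential kernel over cosine distance. Plain Python lists stand in for tensors so the block runs without PyTorch. The names `cosine_distance` and `make_exp_kernel_similarity`, and the kernel form exp(-d^2 / (2 * width^2)), are assumptions made for illustration and are not taken from this diff.

```python
import math
from typing import Any, Callable, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Return 1 - cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # Cosine similarity is undefined for a zero vector; this sketch
        # treats that case as distance 1.0 (an assumption).
        return 1.0
    return 1.0 - dot / (norm_a * norm_b)


def make_exp_kernel_similarity(kernel_width: float = 1.0) -> Callable[..., float]:
    """Build a callable matching the documented similarity_func signature."""

    def similarity_func(
        original_input: Sequence[float],
        perturbed_input: Sequence[float],
        perturbed_interpretable_input: Sequence[Sequence[float]],
        **kwargs: Any,
    ) -> float:
        # Only the two input-space samples are used here; the interpretable
        # row and any extra kwargs are accepted but ignored.
        distance = cosine_distance(original_input, perturbed_input)
        return math.exp(-(distance ** 2) / (2 * kernel_width ** 2))

    return similarity_func
```

A perturbed sample identical to the original gets a weight close to 1, and the weight decays as the cosine distance grows; a larger `kernel_width` slows that decay.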
@@ -139,10 +140,10 @@ def __init__(

  The expected signature of this callable is:

- perturb_func(
-     original_input: Tensor or tuple of Tensors,
-     **kwargs: Any
- ) -> Tensor or tuple of Tensors
+ >>> perturb_func(
+ >>>     original_input: Tensor or tuple of Tensors,
+ >>>     **kwargs: Any
+ >>> ) -> Tensor or tuple of Tensors

  All kwargs passed to the attribute method are
  provided as keyword arguments (kwargs) to this callable.
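To make the `perturb_func` signature above concrete, the following sketch perturbs the original input with Gaussian noise. A flat list of floats stands in for a tensor, and the keyword names `noise_std` and `rng` are hypothetical: they are simply extra kwargs that a caller would have to pass through `attribute`, not names defined by this diff.

```python
import random
from typing import Any, List


def gaussian_perturb_func(original_input: List[float], **kwargs: Any) -> List[float]:
    """Return one noisy copy of original_input with the same length."""
    # Both kwargs are hypothetical names used only by this sketch.
    noise_std = kwargs.get("noise_std", 0.1)
    rng = kwargs.get("rng", random)
    return [x + rng.gauss(0.0, noise_std) for x in original_input]
```

Passing a seeded `random.Random` as `rng` makes the sampling reproducible, which can be convenient when comparing attribution runs.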
@@ -175,11 +176,11 @@ def __init__(

  The expected signature of this callable is:

- from_interp_rep_transform(
-     curr_sample: Tensor [2D 1 x num_interp_features]
-     original_input: Tensor or Tuple of Tensors,
-     **kwargs: Any
- ) -> Tensor or tuple of Tensors
+ >>> from_interp_rep_transform(
+ >>>     curr_sample: Tensor [2D 1 x num_interp_features]
+ >>>     original_input: Tensor or Tuple of Tensors,
+ >>>     **kwargs: Any
+ >>> ) -> Tensor or tuple of Tensors

  Returned sampled input should match the type of original_input
  and corresponding tensor shapes.
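As a rough illustration of what a `from_interp_rep_transform` callable might do, the sketch below maps a binary interpretable row back to the input space. It assumes a flat list of floats in place of a tensor, a `feature_mask` kwarg that assigns each input element to an interpretable feature index, and a scalar `baselines` kwarg; those shapes are assumptions of this example, not requirements stated in the diff.

```python
from typing import Any, List, Sequence


def from_interp_rep_transform(
    curr_sample: Sequence[Sequence[int]],
    original_input: Sequence[float],
    **kwargs: Any,
) -> List[float]:
    """Keep input elements whose feature is switched on, else use the baseline."""
    feature_mask = kwargs["feature_mask"]    # one feature index per input element
    baseline = kwargs.get("baselines", 0.0)  # scalar baseline (assumed)
    row = curr_sample[0]                     # the single 1 x num_interp_features row
    return [
        x if row[feature_idx] == 1 else baseline
        for x, feature_idx in zip(original_input, feature_mask)
    ]
```

The returned list has the same length as `original_input`, mirroring the requirement that the sampled input match the original's type and shape.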
@@ -197,11 +198,11 @@ def __init__(

  The expected signature of this callable is:

- to_interp_rep_transform(
-     curr_sample: Tensor or Tuple of Tensors,
-     original_input: Tensor or Tuple of Tensors,
-     **kwargs: Any
- ) -> Tensor [2D 1 x num_interp_features]
+ >>> to_interp_rep_transform(
+ >>>     curr_sample: Tensor or Tuple of Tensors,
+ >>>     original_input: Tensor or Tuple of Tensors,
+ >>>     **kwargs: Any
+ >>> ) -> Tensor [2D 1 x num_interp_features]

  curr_sample will match the type of original_input
  and corresponding tensor shapes.
@@ -640,9 +641,9 @@ class Lime(LimeBase):
  def __init__(
      self,
      forward_func: Callable,
-     train_interpretable_model_func: Model = SkLearnLasso(alpha=1.0),
-     similarity_func: Callable = get_exp_kernel_similarity_function(),
-     perturb_func: Callable = default_perturb_func,
+     interpretable_model: Optional[Model] = None,
+     similarity_func: Optional[Callable] = None,
+     perturb_func: Optional[Callable] = None,
  ) -> None:
      r"""

@@ -664,10 +665,11 @@ def __init__(
  Alternatively, a custom model object must provide a `fit` method to
  train the model, given a dataloader, with batches containing
  three tensors:
- interpretable_inputs: Tensor
- [2D num_samples x num_interp_features],
- expected_outputs: Tensor [1D num_samples],
- weights: Tensor [1D num_samples]
+
+ - interpretable_inputs: Tensor
+   [2D num_samples x num_interp_features],
+ - expected_outputs: Tensor [1D num_samples],
+ - weights: Tensor [1D num_samples]

  The model object must also provide a `representation` method to
  access the appropriate coefficients or representation of the
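The `fit` / `representation` contract described in this hunk can be sketched with a toy surrogate. The class below is hypothetical (`MeanEffectModel` is not part of the library): it uses nested lists instead of tensors, any iterable of batches instead of a dataloader, and a deliberately simple estimator (weighted mean output with a feature on, minus weighted mean with it off) rather than a regression.

```python
from typing import Any, Iterable, List, Sequence, Tuple

# One batch: (interpretable_inputs, expected_outputs, weights)
Batch = Tuple[Sequence[Sequence[float]], Sequence[float], Sequence[float]]


class MeanEffectModel:
    """Toy surrogate exposing the fit / representation interface."""

    def __init__(self) -> None:
        self._coefs: List[float] = []

    def fit(self, train_data: Iterable[Batch], **kwargs: Any) -> None:
        # Accumulators are rebuilt on every call, so refitting retrains
        # from scratch instead of updating earlier coefficients.
        on_sum: List[float] = []
        on_w: List[float] = []
        off_sum: List[float] = []
        off_w: List[float] = []
        for interp_inputs, expected_outputs, weights in train_data:
            for row, y, w in zip(interp_inputs, expected_outputs, weights):
                if not on_sum:
                    n = len(row)
                    on_sum, on_w = [0.0] * n, [0.0] * n
                    off_sum, off_w = [0.0] * n, [0.0] * n
                for j, bit in enumerate(row):
                    if bit:
                        on_sum[j] += w * y
                        on_w[j] += w
                    else:
                        off_sum[j] += w * y
                        off_w[j] += w
        self._coefs = [
            (s1 / w1 if w1 else 0.0) - (s0 / w0 if w0 else 0.0)
            for s1, w1, s0, w0 in zip(on_sum, on_w, off_sum, off_w)
        ]

    def representation(self) -> List[float]:
        # One coefficient per interpretable feature.
        return list(self._coefs)
```

Because the same model object is reused across attribution calls, resetting state inside `fit` (as done here) is what keeps one call's coefficients from leaking into the next.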
@@ -676,52 +678,72 @@ def __init__(
  Note that calling fit multiple times should retrain the
  interpretable model, each attribution call reuses
  the same given interpretable model object.
- similarity_func (callable): Function which takes a single sample
+ similarity_func (optional, callable): Function which takes a single sample
  along with its corresponding interpretable representation
  and returns the weight of the interpretable sample for
  training the interpretable model.
  This is often referred to as a similarity kernel.

+ This argument is optional and defaults to a function which
+ applies an exponential kernel to the cosine distance between
+ the original input and perturbed input, with a kernel width
+ of 1.0.
+
+ A similarity function applying an exponential
+ kernel to cosine / euclidean distances can be constructed
+ using the provided get_exp_kernel_similarity_function in
+ captum.attr._core.lime.
+
+ Alternatively, a custom callable can also be provided.
  The expected signature of this callable is:

- similarity_func(
-     original_input: Tensor or tuple of Tensors,
-     perturbed_input: Tensor or tuple of Tensors,
-     perturbed_interpretable_input:
-         Tensor [2D 1 x num_interp_features],
-     **kwargs: Any
- ) -> float or Tensor containing float scalar
+ >>> def similarity_func(
+ >>>     original_input: Tensor or tuple of Tensors,
+ >>>     perturbed_input: Tensor or tuple of Tensors,
+ >>>     perturbed_interpretable_input:
+ >>>         Tensor [2D 1 x num_interp_features],
+ >>>     **kwargs: Any
+ >>> ) -> float or Tensor containing float scalar

  perturbed_input and original_input will be the same type and
  contain tensors of the same shape, with original_input
  being the same as the input provided when calling attribute.

  kwargs includes baselines, feature_mask, num_interp_features
- (integer, determined from feature mask), and
- alpha (for Lasso regression).
- perturb_func (callable): Function which returns a single
+ (integer, determined from feature mask).
+ perturb_func (optional, callable): Function which returns a single
  sampled input, which is a binary vector of length
- num_interp_features. The default function returns
+ num_interp_features.
+
+ This function is optional; the default function returns
  a binary vector where each element is selected
  independently and uniformly at random. Custom
  logic for selecting sampled binary vectors can
  be implemented by providing a function with the
  following expected signature:

- perturb_func(
-     original_input: Tensor or tuple of Tensors,
-     **kwargs: Any
- ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
+ >>> perturb_func(
+ >>>     original_input: Tensor or tuple of Tensors,
+ >>>     **kwargs: Any
+ >>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]

  kwargs includes baselines, feature_mask, num_interp_features
- (integer, determined from feature mask), and
- alpha (for Lasso regression).
+ (integer, determined from feature mask).

  """
+ if interpretable_model is None:
+     interpretable_model = SkLearnLasso(alpha=1.0)
+
+ if similarity_func is None:
+     similarity_func = get_exp_kernel_similarity_function()
+
+ if perturb_func is None:
+     perturb_func = default_perturb_func
+
  LimeBase.__init__(
      self,
      forward_func,
-     train_interpretable_model_func,
+     interpretable_model,
      similarity_func,
      perturb_func,
      True,
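One plausible motivation for moving the defaults from the signature into the body (an assumption; the diff itself does not state the reason) is that Python evaluates default argument values once, when the `def` statement runs. A default model object in the signature would therefore be shared by every instance that relies on it. The sketch below shows the difference with hypothetical stand-in classes (`ToyModel`, `SharedDefault`, `FreshDefault`), none of which appear in the library.

```python
from typing import Optional


class ToyModel:
    """Stand-in for a trainable surrogate model that holds state."""

    def __init__(self) -> None:
        self.fit_calls = 0


class SharedDefault:
    # ToyModel() here is created once, at function definition time,
    # and reused for every call that omits the argument.
    def __init__(self, model: ToyModel = ToyModel()) -> None:
        self.model = model


class FreshDefault:
    # None is a sentinel; a new ToyModel is built on each call instead.
    def __init__(self, model: Optional[ToyModel] = None) -> None:
        if model is None:
            model = ToyModel()
        self.model = model


shared_a, shared_b = SharedDefault(), SharedDefault()
fresh_a, fresh_b = FreshDefault(), FreshDefault()

shared_a.model.fit_calls += 1  # also visible through shared_b.model
fresh_a.model.fit_calls += 1   # leaves fresh_b.model untouched
```

With the sentinel pattern, two attribution objects built with default arguments no longer train the same surrogate model object, and the default is only constructed when it is actually needed.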
@@ -788,27 +810,27 @@ def attribute( # type: ignore
  Baselines can be provided as:

  - a single tensor, if inputs is a single tensor, with
- exactly the same dimensions as inputs or the first
- dimension is one and the remaining dimensions match
- with inputs.
+ exactly the same dimensions as inputs or the first
+ dimension is one and the remaining dimensions match
+ with inputs.

  - a single scalar, if inputs is a single tensor, which will
- be broadcasted for each input value in input tensor.
+ be broadcasted for each input value in input tensor.

  - a tuple of tensors or scalars, the baseline corresponding
- to each tensor in the inputs' tuple can be:
+ to each tensor in the inputs' tuple can be:

- - either a tensor with matching dimensions to
+ - either a tensor with matching dimensions to
  corresponding tensor in the inputs' tuple
  or the first dimension is one and the remaining
  dimensions match with the corresponding
  input tensor.

- - or a scalar, corresponding to a tensor in the
+ - or a scalar, corresponding to a tensor in the
  inputs' tuple. This scalar value is broadcasted
  for corresponding input tensor.
  In the cases when `baselines` is not provided, we internally
- use zero corresponding to each input tensor.
+ use zero scalar corresponding to each input tensor.
  Default: None
  target (int, tuple, tensor or list, optional): Output indices for
  which surrogate model is trained
@@ -886,10 +908,6 @@ def attribute( # type: ignore
  If the forward function returns a single scalar per batch,
  perturbations_per_eval must be set to 1.
  Default: 1
- alpha (float, optional): Alpha used for training interpretable surrogate
- model in Lasso Regression. This parameter is used only
- if using default interpretable model trainer (Lasso).
- Default: 1.0
  return_input_shape (bool, optional): Determines whether the returned
  tensor(s) only contain the coefficients for each interp-
  retable feature from the trained surrogate model, or