# integrated_gradients.py
import copy
import logging
import string
import warnings
from enum import Enum
from typing import Callable, List, Optional, Tuple, Union, cast
import numpy as np
import tensorflow as tf
from alibi.api.defaults import DEFAULT_DATA_INTGRAD, DEFAULT_META_INTGRAD
from alibi.api.interfaces import Explainer, Explanation
from alibi.utils.approximation_methods import approximation_parameters
logger = logging.getLogger(__name__)
_valid_output_shape_type: List = [tuple, list]
def _compute_convergence_delta(model: Union[tf.keras.models.Model],
input_dtypes: List[tf.DType],
attributions: List[np.ndarray],
start_point: Union[List[np.ndarray], np.ndarray],
end_point: Union[List[np.ndarray], np.ndarray],
forward_kwargs: Optional[dict],
target: Optional[np.ndarray],
_is_list: bool) -> np.ndarray:
"""
Computes convergence deltas for each data point. Convergence delta measures how close the sum of all attributions
is to the difference between the model output at the baseline and the model output at the data point.
Parameters
----------
model
`tensorflow` model.
input_dtypes
List with data types of the inputs.
attributions
Attributions assigned by the integrated gradients method to each feature.
start_point
Baselines.
end_point
Data points.
forward_kwargs
Input keyword args.
target
Target for which the gradients are calculated for classification models.
_is_list
Whether the model's input is a `list` (multiple inputs) or a `np.ndarray` (single input).
Returns
-------
Convergence deltas for each data point.
"""
if forward_kwargs is None:
forward_kwargs = {}
if _is_list:
start_point = [tf.convert_to_tensor(start_point[k], dtype=input_dtypes[k]) for k in range(len(input_dtypes))]
end_point = [tf.convert_to_tensor(end_point[k], dtype=input_dtypes[k]) for k in range(len(input_dtypes))]
else:
start_point = tf.convert_to_tensor(start_point)
end_point = tf.convert_to_tensor(end_point)
def _sum_rows(inp):
input_str = string.ascii_lowercase[1: len(inp.shape)]
if isinstance(inp, tf.Tensor):
sums = tf.einsum('a{}->a'.format(input_str), inp).numpy()
elif isinstance(inp, np.ndarray):
sums = np.einsum('a{}->a'.format(input_str), inp)
else:
raise NotImplementedError('input must be a tensorflow tensor or a numpy array')
return sums
start_out = _run_forward(model, start_point, target, forward_kwargs=forward_kwargs)
end_out = _run_forward(model, end_point, target, forward_kwargs=forward_kwargs)
if (len(model.output_shape) == 1 or model.output_shape[-1] == 1) and target is not None:
target_tensor = tf.cast(target, dtype=start_out.dtype)
target_tensor = tf.reshape(1 - target_tensor, [len(target), 1])
sign = 2 * target_tensor - 1
start_out = target_tensor + sign * start_out
end_out = target_tensor + sign * end_out
start_out_sum = _sum_rows(start_out)
end_out_sum = _sum_rows(end_out)
attr_sum = np.zeros(start_out_sum.shape)
for j in range(len(attributions)):
attrs_sum_j = _sum_rows(attributions[j])
attr_sum += attrs_sum_j
_deltas = attr_sum - (end_out_sum - start_out_sum)
return _deltas
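# Illustrative sketch (not part of the original module and never called by it): the
# quantity computed above is the "completeness" gap of integrated gradients. For a
# purely linear model f(x) = x @ w the exact attributions are (x - baseline) * w, so
# the delta is zero up to numerical error.
def _example_convergence_delta_linear_model():
    w = np.array([1.0, 2.0, -3.0])
    x = np.array([0.5, 1.0, 2.0])
    baseline = np.zeros_like(x)

    def f(inp):
        return float(inp @ w)  # scalar-output "model"

    attributions = (x - baseline) * w  # exact IG attributions for a linear map
    delta = attributions.sum() - (f(x) - f(baseline))
    return delta  # ~0.0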
def _select_target(preds: tf.Tensor,
targets: Union[None, tf.Tensor, np.ndarray, list]) -> tf.Tensor:
"""
Select the predictions corresponding to the targets if targets is not ``None``.
Parameters
----------
preds
Predictions before selection.
targets
Targets to select.
Returns
-------
Selected predictions.
"""
# Check for `None` before converting: `tf.convert_to_tensor(None)` would raise an
# unrelated error and make the intended message below unreachable.
if targets is None:
    raise ValueError("target cannot be `None` if `model` output dimensions > 1")
if not isinstance(targets, tf.Tensor):
    targets = tf.convert_to_tensor(targets)
if isinstance(preds, tf.Tensor):
    preds = tf.gather_nd(preds, tf.expand_dims(targets, axis=1), batch_dims=1)
else:
    raise NotImplementedError
return preds
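# Illustrative sketch (not part of the original module): what the
# `tf.gather_nd(..., batch_dims=1)` call above does - for each row of `preds` it
# selects the column indexed by the corresponding target.
def _example_select_target():
    preds = tf.constant([[0.1, 0.7, 0.2],
                         [0.8, 0.1, 0.1]])
    targets = tf.constant([1, 0])
    return _select_target(preds, targets)  # tf.Tensor of shape (2,): [0.7, 0.8]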
def _run_forward(model: Union[tf.keras.models.Model],
x: Union[List[tf.Tensor], List[np.ndarray], tf.Tensor, np.ndarray],
target: Union[None, tf.Tensor, np.ndarray, list],
forward_kwargs: Optional[dict] = None) -> tf.Tensor:
"""
Returns the output of the model. If the target is not ``None``, only the output for the selected
target is returned.
Parameters
----------
model
`tensorflow` model.
x
Input data point.
target
Target for which the gradients are calculated for classification models.
forward_kwargs
Input keyword args.
Returns
-------
Model output or model output after target selection for classification models.
"""
if forward_kwargs is None:
forward_kwargs = {}
preds = model(x, **forward_kwargs)
if len(model.output_shape) > 1 and model.output_shape[-1] > 1:
preds = _select_target(preds, target)
return preds
def _run_forward_from_layer(model: tf.keras.models.Model,
layer: tf.keras.layers.Layer,
orig_call: Callable,
orig_dummy_input: Union[list, np.ndarray],
x: tf.Tensor,
target: Union[None, tf.Tensor, np.ndarray, list],
forward_kwargs: Optional[dict] = None,
run_from_layer_inputs: bool = False,
select_target: bool = True) -> tf.Tensor:
"""
Function currently unused.
Executes a forward call from an internal layer of the model to the model output.
Parameters
----------
model
`tensorflow` model.
layer
Starting layer for the forward call.
orig_call
Original `call` method of the layer.
orig_dummy_input
Dummy input needed to initiate the model forward call. The number of instances in the dummy input must
be the same as the number of instances in `x`. The dummy input values play no role in the evaluation
as the layer's status is overwritten during the forward call.
x
Layer's inputs. The layer's status is overwritten with `x` during the forward call.
target
Target for the output position to be returned.
forward_kwargs
Input keyword args. It must be a dict with `numpy` arrays as values. If it's not ``None``,
the first dimension of the arrays must correspond to the number of instances in `x` and orig_dummy_input.
run_from_layer_inputs
If ``True``, the forward pass starts from the layer's inputs, if ``False`` it starts from the layer's outputs.
select_target
Whether to return predictions for selected targets or return predictions for all targets.
Returns
-------
Model's predictions for the given target.
"""
def feed_layer(layer):
"""
Overwrites the intermediate layer status with the precomputed values `x`.
"""
def decorator(func):
def wrapper(*args, **kwargs):
# Store the result and inputs of `layer.call` internally.
if run_from_layer_inputs:
layer.inp = x
layer.result = func(*x, **kwargs)
else:
layer.inp = args
layer.result = x
# Return the result to continue with the forward pass.
return layer.result
return wrapper
layer.call = decorator(layer.call)
feed_layer(layer)
if forward_kwargs is None:
forward_kwargs = {}
preds = model(orig_dummy_input, **forward_kwargs)
delattr(layer, 'inp')
delattr(layer, 'result')
layer.call = orig_call
if select_target and len(model.output_shape) > 1 and model.output_shape[-1] > 1:
preds = _select_target(preds, target)
return preds
def _run_forward_to_layer(model: tf.keras.models.Model,
layer: tf.keras.layers.Layer,
orig_call: Callable,
x: Union[List[np.ndarray], np.ndarray],
forward_kwargs: Optional[dict] = None,
run_to_layer_inputs: bool = False) -> tf.Tensor:
"""
Executes a forward call from the model input to an internal layer output.
Parameters
----------
model
`tensorflow` model.
layer
Layer at which the forward call stops and whose inputs or outputs are returned.
orig_call
Original `call` method of the layer.
x
Model's inputs.
forward_kwargs
Input keyword args.
run_to_layer_inputs
If ``True``, the layer's inputs are returned. If ``False``, the layer's outputs are returned.
Returns
-------
Output of the given layer.
"""
if forward_kwargs is None:
forward_kwargs = {}
def take_layer(layer):
"""
Stores the layer's inputs and outputs internally on the layer object.
"""
def decorator(func):
def wrapper(*args, **kwargs):
# Store the result of `layer.call` internally.
layer.inp = args
layer.result = func(*args, **kwargs)
# Return the result to continue with the forward pass.
return layer.result
return wrapper
layer.call = decorator(layer.call)
# inp = tf.zeros((x.shape[0], ) + model.input_shape[1:])
take_layer(layer)
_ = model(x, **forward_kwargs)
layer_inp = layer.inp
layer_out = layer.result
delattr(layer, 'inp')
delattr(layer, 'result')
layer.call = orig_call
if run_to_layer_inputs:
return layer_inp
else:
return layer_out
def _forward_input_baseline(X: Union[List[np.ndarray], np.ndarray],
bls: Union[List[np.ndarray], np.ndarray],
model: tf.keras.Model,
layer: tf.keras.layers.Layer,
orig_call: Callable,
forward_kwargs: Optional[dict] = None,
forward_to_inputs: bool = False) -> Tuple[Union[list, tf.Tensor], Union[list, tf.Tensor]]:
"""
Forwards inputs and baselines to the layer's inputs or outputs.
Parameters
----------
X
Input data points.
bls
Baselines.
model
`tensorflow` model.
layer
Layer to whose inputs or outputs `X` and `bls` are forwarded.
orig_call
Original `call` method of layer.
forward_kwargs
Input keyword args.
forward_to_inputs
If ``True``, `X` and bls are forwarded to the layer's input. If ``False``, they are forwarded to
the layer's outputs.
Returns
-------
Forwarded inputs and baselines as `numpy` arrays.
"""
if forward_kwargs is None:
forward_kwargs = {}
if layer is not None:
X_layer = _run_forward_to_layer(model,
layer,
orig_call,
X,
forward_kwargs=forward_kwargs,
run_to_layer_inputs=forward_to_inputs)
bls_layer = _run_forward_to_layer(model,
layer,
orig_call,
bls,
forward_kwargs=forward_kwargs,
run_to_layer_inputs=forward_to_inputs)
if isinstance(X_layer, tuple):
X_layer = list(X_layer)
if isinstance(bls_layer, tuple):
bls_layer = list(bls_layer)
return X_layer, bls_layer
else:
return X, bls
def _gradients_input(model: Union[tf.keras.models.Model],
x: List[tf.Tensor],
target: Union[None, tf.Tensor],
forward_kwargs: Optional[dict] = None) -> List[tf.Tensor]:
"""
Calculates the gradients of the target class output (or the output if the output dimension is equal to 1)
with respect to each input feature.
Parameters
----------
model
`tensorflow` model.
x
Input data point.
target
Target for which the gradients are calculated if the output dimension is higher than 1.
forward_kwargs
Input keyword args.
Returns
-------
Gradients for each input feature.
"""
if forward_kwargs is None:
forward_kwargs = {}
with tf.GradientTape() as tape:
tape.watch(x)
preds = _run_forward(model, x, target, forward_kwargs=forward_kwargs)
grads = tape.gradient(preds, x)
# If certain inputs don't impact the target, the gradient is None, but we need to return a tensor
if isinstance(x, list):
for idx, grad in enumerate(grads):
if grad is None:
grads[idx] = tf.convert_to_tensor(np.zeros(x[idx].shape), dtype=x[idx].dtype)
return grads
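# Illustrative sketch (not part of the original module): integrated gradients
# evaluates gradients like the ones computed above at interpolated points
# x_alpha = baseline + alpha * (x - baseline). `model` is assumed here to be a
# scalar-output `tf.keras` model, so no target is needed.
def _example_path_gradients(model, x, baseline, alphas):
    grads = []
    for alpha in alphas:
        x_alpha = tf.convert_to_tensor(baseline + alpha * (x - baseline))
        grads.append(_gradients_input(model, x_alpha, target=None))
    return grads  # one gradient tensor per interpolation step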
def _gradients_layer(model: Union[tf.keras.models.Model],
layer: Union[tf.keras.layers.Layer],
orig_call: Callable,
orig_dummy_input: Union[list, np.ndarray],
x: Union[List[tf.Tensor], tf.Tensor],
target: Union[None, tf.Tensor],
forward_kwargs: Optional[dict] = None,
compute_layer_inputs_gradients: bool = False) -> tf.Tensor:
"""
Calculates the gradients of the target class output (or the output if the output dimension is equal to 1)
with respect to each element of `layer`.
Parameters
----------
model
`tensorflow` model.
layer
Layer of the model with respect to which the gradients are calculated.
orig_call
Original `call` method of the layer. This is necessary since the call method is modified by the function
in order to make the layer output visible to the `GradientTape`.
x
Input data point.
target
Target for which the gradients are calculated if the output dimension is higher than 1.
forward_kwargs
Input keyword args.
compute_layer_inputs_gradients
If ``True``, gradients are computed with respect to the layer's inputs.
If ``False``, they are computed with respect to the layer's outputs.
Returns
-------
Gradients for each element of layer.
"""
def watch_layer(layer, tape):
"""
Make an intermediate hidden `layer` watchable by the `tape`.
After calling this function, you can obtain the gradient with
respect to the output of the `layer` by calling:
grads = tape.gradient(..., layer.result)
"""
def decorator(func):
def wrapper(*args, **kwargs):
# Store the result and the input of `layer.call` internally.
if compute_layer_inputs_gradients:
layer.inp = x
layer.result = func(*x, **kwargs)
# From this point onwards, watch this tensor.
tape.watch(layer.inp)
else:
layer.inp = args
layer.result = x
# From this point onwards, watch this tensor.
tape.watch(layer.result)
# Return the result to continue with the forward pass.
return layer.result
return wrapper
layer.call = decorator(layer.call)
# Repeating the dummy input needed to initiate the model's forward call in order to ensure that
# the number of dummy instances is the same as the number of real instances.
# This is necessary in case `forward_kwargs` is not None. In that case, the model forward call would crash
# if the number of instances in `orig_dummy_input` is different from the number of instances in `forward_kwargs`.
# The number of instances in `forward_kwargs` is the same as the number of instances in `x` by construction.
if isinstance(orig_dummy_input, list):
if isinstance(x, list):
orig_dummy_input = [np.repeat(inp, x[0].shape[0], axis=0) for inp in orig_dummy_input]
else:
orig_dummy_input = [np.repeat(inp, x.shape[0], axis=0) for inp in orig_dummy_input]
else:
if isinstance(x, list):
orig_dummy_input = np.repeat(orig_dummy_input, x[0].shape[0], axis=0)
else:
orig_dummy_input = np.repeat(orig_dummy_input, x.shape[0], axis=0)
if forward_kwargs is None:
forward_kwargs = {}
# Calculating the gradients with respect to the layer.
with tf.GradientTape() as tape:
watch_layer(layer, tape)
preds = _run_forward(model, orig_dummy_input, target, forward_kwargs=forward_kwargs)
if compute_layer_inputs_gradients:
grads = tape.gradient(preds, layer.inp)
else:
grads = tape.gradient(preds, layer.result)
# If certain inputs don't impact the target, the gradient is None, but we need to return a tensor
if isinstance(x, list):
for idx, grad in enumerate(grads):
if grad is None:
grads[idx] = tf.convert_to_tensor(np.zeros(x[idx].shape), dtype=x[idx].dtype)
delattr(layer, 'inp')
delattr(layer, 'result')
layer.call = orig_call
return grads
def _format_baseline(X: np.ndarray,
baselines: Union[None, int, float, np.ndarray]) -> np.ndarray:
"""
Formats baselines to return a `numpy` array.
Parameters
----------
X
Input data points.
baselines
Baselines.
Returns
-------
Formatted baselines as a `numpy` array.
"""
if baselines is None:
bls = np.zeros(X.shape).astype(X.dtype)
elif isinstance(baselines, int) or isinstance(baselines, float):
bls = np.full(X.shape, baselines).astype(X.dtype)
elif isinstance(baselines, np.ndarray):
bls = baselines.astype(X.dtype)
else:
raise ValueError(f"baselines must be `int`, `float`, `np.ndarray` or `None`. Found {type(baselines)}")
return bls
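# Illustrative sketch (not part of the original module): the three accepted baseline
# formats all resolve to an array with the same shape and dtype as the input.
def _example_format_baseline():
    X = np.random.rand(2, 3).astype(np.float32)
    zeros = _format_baseline(X, None)               # all-zeros baseline
    constant = _format_baseline(X, 0.5)             # constant 0.5 baseline
    custom = _format_baseline(X, np.ones_like(X))   # user-provided baseline
    return zeros, constant, custom                  # each has shape (2, 3), dtype float32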
def _format_target(target: Union[None, int, list, np.ndarray],
nb_samples: int) -> Union[None, np.ndarray]:
"""
Formats target to return a `np.ndarray`.
Parameters
----------
target
Original target.
nb_samples
Number of samples in the batch.
Returns
-------
Formatted target as a `np.ndarray`.
"""
if target is not None:
if isinstance(target, int):
target = np.array([target for _ in range(nb_samples)])
elif isinstance(target, list):
target = np.array(target)
elif isinstance(target, np.ndarray):
pass
else:
raise NotImplementedError
return target
def _check_target(output_shape: Tuple,
target: Optional[np.ndarray],
nb_samples: int) -> None:
"""
Parameters
----------
output_shape
Output shape of the tensorflow model
target
Target formatted as np array target.
nb_samples
Number of samples in the batch.
Returns
-------
None
"""
if target is not None:
if not np.issubdtype(target.dtype, np.integer):
raise ValueError("Targets must be integers")
if target.shape[0] != nb_samples:
raise ValueError(f"First dimension in target must be the same as nb of samples. "
f"Found target first dimension: {target.shape[0]}; nb of samples: {nb_samples}")
if len(target.shape) > 2:
raise ValueError("Target must be a rank-1 or a rank-2 tensor. If target is a rank-2 tensor, "
"each column contains the index of the corresponding dimension "
"in the model's output tensor.")
if len(output_shape) == 1:
# in case of a squashed output, the rank of the model's output tensor (out_rank) considers the batch dimension
out_rank, target_rank = 1, len(target.shape)
tmax, tmin = target.max(axis=0), target.min(axis=0)
if tmax > 1:
raise ValueError(f"Target value {tmax} out of range for output shape {output_shape} ")
# for all other cases, batch dimension is not considered in the out_rank
elif len(output_shape) == 2:
out_rank, target_rank = 1, len(target.shape)
tmax, tmin = target.max(axis=0), target.min(axis=0)
if (output_shape[-1] > 1 and (tmax >= output_shape[-1]).any()) or (output_shape[-1] == 1 and tmax > 1):
raise ValueError(f"Target value {tmax} out of range for output shape {output_shape} ")
else:
out_rank, target_rank = len(output_shape[1:]), target.shape[-1]
tmax, tmin = target.max(axis=0), target.min(axis=0)
if (tmax >= output_shape[1:]).any():
raise ValueError(f"Target value {tmax} out of range for output shape {output_shape} ")
if (tmin < 0).any():
raise ValueError(f"Negative value {tmin} for target. Targets represent positional "
f"arguments and cannot be negative")
if out_rank != target_rank:
raise ValueError(f"The last dimension of target must match the rank of the model's output tensor. "
f"Found target last dimension: {target_rank}; model's output rank: {out_rank}")
def _get_target_from_target_fn(target_fn: Callable,
model: tf.keras.Model,
X: Union[np.ndarray, List[np.ndarray]],
forward_kwargs: Optional[dict] = None) -> np.ndarray:
"""
Generate a target vector by using the `target_fn` to pick out a
scalar dimension from the predictions.
Parameters
----------
target_fn
Target function.
model
Model.
X
Data to be explained.
forward_kwargs
Any additional kwargs needed for the model forward pass.
Returns
-------
Integer array of dimension `(N, )`.
"""
if forward_kwargs is None:
preds = model(X)
else:
preds = model(X, **forward_kwargs)
# raise a warning if the predictions are scalar valued already
# TODO: in the future we want to support outputs that are >2D at which point this check should change
if preds.shape[-1] == 1:
msg = "Predictions from the model are scalar valued but `target_fn` was passed. `target_fn` is not necessary" \
"when predictions are scalar valued already. Using `target_fn` here may result in unexpected behaviour."
warnings.warn(msg)
target = target_fn(preds)
expected_shape = (target.shape[0],)
if target.shape != expected_shape:
# TODO: in the future we want to support outputs that are >2D at which point this check should change
msg = f"`target_fn` returned an array of shape {target.shape} but expected an array of shape {expected_shape}."
raise ValueError(msg) # TODO: raise a more specific error type?
return target.astype(int)
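# Illustrative sketch (not part of the original module): a typical `target_fn`
# selects the model's own predicted class for every instance, so that attributions
# explain the argmax prediction. `model` and `X` are placeholders here.
def _example_target_from_argmax(model, X):
    from functools import partial
    target_fn = partial(np.argmax, axis=1)
    return _get_target_from_target_fn(target_fn, model, X)  # integer array of shape (N,)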
def _sum_integral_terms(step_sizes: list,
grads: Union[tf.Tensor, np.ndarray]) -> Union[tf.Tensor, np.ndarray]:
"""
Sums the terms in the path integral with weights `step_sizes`.
Parameters
----------
step_sizes
Weights in the path integral sum.
grads
Gradients to sum for each feature.
Returns
-------
Sums of the gradients along the chosen path.
"""
input_str = string.ascii_lowercase[1: len(grads.shape)]
if isinstance(grads, tf.Tensor):
step_sizes = tf.convert_to_tensor(step_sizes)
einstr = 'a,a{}->{}'.format(input_str, input_str)
sums = tf.einsum(einstr, step_sizes, grads).numpy()
elif isinstance(grads, np.ndarray):
einstr = 'a,a{}->{}'.format(input_str, input_str)
sums = np.einsum(einstr, step_sizes, grads)
else:
raise NotImplementedError('input must be a tensorflow tensor or a numpy array')
return sums
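# Illustrative sketch (not part of the original module): the einsum above is a
# weighted sum over the path (step) dimension. For a (n_steps, nb_samples) array of
# gradients it is equivalent to the explicit loop below.
def _example_sum_integral_terms():
    step_sizes = [0.25, 0.5, 0.25]                         # quadrature weights
    grads = np.arange(6, dtype=np.float64).reshape(3, 2)   # 3 steps, 2 samples
    via_einsum = _sum_integral_terms(step_sizes, grads)
    via_loop = sum(w * g for w, g in zip(step_sizes, grads))
    return via_einsum, via_loop  # both equal [2.0, 3.25]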
def _calculate_sum_int(batches: List[List[tf.Tensor]],
model: Union[tf.keras.Model],
target: Optional[np.ndarray],
target_paths: np.ndarray,
n_steps: int,
nb_samples: int,
step_sizes: List[float],
j: int) -> Union[tf.Tensor, np.ndarray]:
"""
Calculates the sum of all the terms in the integral from a list of batch gradients.
Parameters
----------
batches
List of batch gradients.
model
`tf.keras` or `keras` model.
target
List of targets.
target_paths
Targets for each path in the integral.
n_steps
Number of steps in the integral.
nb_samples
Total number of samples.
step_sizes
Step sizes used to calculate the integral.
j
Iterates through list of inputs or list of layers.
Returns
-------
Sums of the gradients along the chosen path.
"""
grads = tf.concat(batches[j], 0)
shape = grads.shape[1:]
if isinstance(shape, tf.TensorShape):
shape = tuple(shape.as_list())
# invert sign of gradients for target 0 examples if classifier returns only positive class probability
if (len(model.output_shape) == 1 or model.output_shape[-1] == 1) and target is not None:
sign = 2 * target_paths - 1
grads = np.array([s * g for s, g in zip(sign, grads)])
grads = tf.reshape(grads, (n_steps, nb_samples) + shape)
# sum integral terms and scale attributions
sum_int = _sum_integral_terms(step_sizes, grads.numpy())
return sum_int
def _validate_output(model: tf.keras.Model,
target: Optional[np.ndarray]) -> None:
"""
Validates the model's output type and raises an error if the output type is not supported.
Parameters
----------
model
`Keras` model for which the output is validated.
target
Targets for which gradients are calculated.
"""
if not model.output_shape or not any(isinstance(model.output_shape, t) for t in _valid_output_shape_type):
raise NotImplementedError(f"The model output_shape attribute must be in {_valid_output_shape_type}. "
f"Found model.output_shape: {model.output_shape}")
if (len(model.output_shape) == 1
or model.output_shape[-1] == 1) \
and target is None:
logger.warning("It looks like you are passing a model with a scalar output and target is set to `None`."
"If your model is a regression model this will produce correct attributions. If your model "
"is a classification model, targets for each datapoint must be defined. "
"Not defining the target may lead to incorrect values for the attributions."
"Targets can be either the true classes or the classes predicted by the model.")
class LayerState(str, Enum):
UNSPECIFIED = 'unspecified'
NON_SERIALIZABLE = 'non-serializable'
CALLABLE = 'callable'
class IntegratedGradients(Explainer):
def __init__(self,
model: tf.keras.Model,
layer: Optional[
Union[
Callable[[tf.keras.Model], tf.keras.layers.Layer],
tf.keras.layers.Layer
]
] = None,
target_fn: Optional[Callable] = None,
method: str = "gausslegendre",
n_steps: int = 50,
internal_batch_size: int = 100
) -> None:
"""
An implementation of the integrated gradients method for `tensorflow` models.
For details of the method see the original paper: https://arxiv.org/abs/1703.01365 .
Parameters
----------
model
`tensorflow` model.
layer
A layer or a function having as parameter the model and returning a layer with respect to which the
gradients are calculated. If not provided, the gradients are calculated with respect to the input.
To guarantee saving and loading of the explainer, the layer has to be specified as a callable which
returns a layer given the model. E.g. ``lambda model: model.layers[0].embeddings``.
target_fn
A scalar function that is applied to the predictions of the model.
This can be used to specify which scalar output the attributions should be calculated for.
This can be particularly useful if the desired output is not known before calling the model
(e.g. explaining the `argmax` output for a probabilistic classifier, in this case we could pass
``target_fn=partial(np.argmax, axis=1)``).
method
Method for the integral approximation. Methods available:
``"riemann_left"``, ``"riemann_right"``, ``"riemann_middle"``, ``"riemann_trapezoid"``, ``"gausslegendre"``.
n_steps
Number of steps in the path integral approximation from the baseline to the input instance.
internal_batch_size
Batch size for the internal batching.
"""
super().__init__(meta=copy.deepcopy(DEFAULT_META_INTGRAD))
params = locals()
remove = ['self', 'model', '__class__', 'layer']
params = {k: v for k, v in params.items() if k not in remove}
self.model = model
if self.model.inputs is None:
self._has_inputs = False
else:
self._has_inputs = True
if layer is None:
self.orig_call: Optional[Callable] = None
self.layer = None
layer_meta: Union[int, str] = LayerState.UNSPECIFIED.value
elif isinstance(layer, tf.keras.layers.Layer):
self.orig_call = layer.call
self.layer = layer
try:
layer_meta = model.layers.index(layer)
except ValueError:
layer_meta = LayerState.NON_SERIALIZABLE.value
logger.warning('Layer not in the list of `model.layers`. Passing the layer directly would not '
'permit the serialization of the explainer. This is due to nested layers. To permit '
'the serialization of the explainer, provide the layer as a callable which returns '
'the layer given the model.')
elif callable(layer):
self.layer = layer(self.model)
self.orig_call = self.layer.call
self.callable_layer = layer
layer_meta = LayerState.CALLABLE.value
else:
raise TypeError(f'Unsupported layer type. Received {type(layer)}.')
params['layer'] = layer_meta
self.meta['params'].update(params)
self.n_steps = n_steps
self.method = method
self.internal_batch_size = internal_batch_size
self._is_list: Optional[bool] = None
self._is_np: Optional[bool] = None
self.orig_dummy_input: Optional[Union[list, np.ndarray]] = None
self.target_fn = target_fn
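# A minimal usage sketch (illustrative only; `model`, `X_batch` and
# `predicted_classes` are placeholders, not defined in this module):
#
#     ig = IntegratedGradients(model, method="gausslegendre", n_steps=50)
#     explanation = ig.explain(X_batch, target=predicted_classes)
#     attributions = explanation.data['attributions']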
def explain(self,
X: Union[np.ndarray, List[np.ndarray]],
forward_kwargs: Optional[dict] = None,
baselines: Optional[Union[int, float, np.ndarray, List[int], List[float], List[np.ndarray]]] = None,
target: Optional[Union[int, list, np.ndarray]] = None,
attribute_to_layer_inputs: bool = False) -> Explanation:
"""Calculates the attributions for each input feature or element of layer and
returns an Explanation object.
Parameters
----------
X
Instances for which the integrated gradients attributions are computed.
forward_kwargs
Input keyword args. If it's not ``None``, it must be a dict with `numpy` arrays as values.
The first dimension of the arrays must correspond to the number of examples.
It will be repeated for each of `n_steps` along the integrated path.
The attributions are not computed with respect to these arguments.
baselines
Baselines (starting point of the path integral) for each instance.
If the passed value is an `np.ndarray`, it must have the same shape as `X`.
If not provided, all feature values for the baselines are set to 0.
target
Defines which element of the model output is considered to compute the gradients.
Target can be a numpy array, a list or a numeric value.
Numeric values are only valid if the model's output is a rank-n tensor
with n <= 2 (regression and classification models).
If a numeric value is passed, the gradients are calculated for
the same element of the output for all data points.
For regression models whose output is a scalar, target should not be provided.
For classification models `target` can be either the true classes or the classes predicted by the model.
It must be provided for classification models and regression models whose output is a vector.
If the model's output is a rank-n tensor with n > 2,
the target must be a rank-2 numpy array or a list of lists (a matrix) with dimensions `nb_samples` x (n-1).
attribute_to_layer_inputs
In case of layers gradients, controls whether the gradients are computed for the layer's inputs or
outputs. If ``True``, gradients are computed for the layer's inputs, if ``False`` for the layer's outputs.
Returns
-------
explanation
`Explanation` object including `meta` and `data` attributes with integrated gradients attributions
for each feature. See usage at `IG examples`_ for details.
.. _IG examples:
https://docs.seldon.io/projects/alibi/en/stable/methods/IntegratedGradients.html
"""
# target handling logic
if self.target_fn and target is not None:
msg = 'Both `target_fn` and `target` were provided. Only one of these should be provided.'
raise ValueError(msg)
if self.target_fn:
target = _get_target_from_target_fn(self.target_fn, self.model, X, forward_kwargs)
self._is_list = isinstance(X, list)
self._is_np = isinstance(X, np.ndarray)
if forward_kwargs is None:
forward_kwargs = {}
if self._is_list:
X = cast(List[np.ndarray], X) # help mypy out
self.orig_dummy_input = [np.zeros((1,) + xx.shape[1:], dtype=xx.dtype) for xx in X]
nb_samples = len(X[0])
input_dtypes = [xx.dtype for xx in X]
# Formatting baselines in case of models with multiple inputs
if baselines is None:
baselines = [None for _ in range(len(X))] # type: ignore
else:
if not isinstance(baselines, list):
raise ValueError(f"If the input X is a list, baseline can only be `None` or "
f"a list of the same length of X. Found baselines type {type(baselines)}")
else:
if len(X) != len(baselines):
raise ValueError(f"Length of 'X' must match length of 'baselines'. "
f"Found len(X): {len(X)}, len(baselines): {len(baselines)}")
if max([len(x) for x in X]) != min([len(x) for x in X]):
raise ValueError("First dimension must be egual for all inputs")
for i in range(len(X)):
x, baseline = X[i], baselines[i] # type: ignore
# format and check baselines
baseline = _format_baseline(x, baseline)
baselines[i] = baseline # type: ignore
elif self._is_np:
X = cast(np.ndarray, X) # help mypy out
self.orig_dummy_input = np.zeros((1,) + X.shape[1:], dtype=X.dtype)
nb_samples = len(X)
input_dtypes = [X.dtype]
# Formatting baselines for models with a single input
baselines = _format_baseline(X, baselines) # type: ignore # TODO: validate/narrow baselines type
else:
raise ValueError("Input must be a np.ndarray or a list of np.ndarray")
# defining integral method
step_sizes_func, alphas_func = approximation_parameters(self.method)
step_sizes, alphas = step_sizes_func(self.n_steps), alphas_func(self.n_steps)
target = _format_target(target, nb_samples)
if self._is_list:
X = cast(List[np.ndarray], X) # help mypy out
# Attributions calculation in case of multiple inputs
if not self._has_inputs:
# Inferring model's inputs from data points for models with no explicit inputs
# (typically subclassed models)
inputs = [tf.keras.Input(shape=xx.shape[1:], dtype=xx.dtype) for xx in X]
self.model(inputs, **forward_kwargs)