5
5
import itertools
6
6
import math
7
7
import warnings
8
- from typing import Callable , cast , Iterable , Optional , Sequence , Tuple , Union
8
+ from typing import Callable , cast , Iterable , List , Optional , Sequence , Tuple , Union
9
9
10
10
import torch
11
11
from captum ._utils .common import (
20
20
_is_tuple ,
21
21
_run_forward ,
22
22
)
23
+ from captum ._utils .exceptions import ShapleyValueFutureError
23
24
from captum ._utils .progress import progress
24
25
from captum ._utils .typing import BaselineType , TargetType , TensorOrTupleOfTensorsGeneric
25
26
from captum .attr ._utils .attribution import PerturbationAttribution
29
30
_tensorize_baseline ,
30
31
)
31
32
from captum .log import log_usage
32
- from torch import dtype , Tensor
33
+ from torch import dtype , Size , Tensor
34
+ from torch .futures import collect_all , Future
33
35
34
36
35
37
def _all_perm_generator (num_features : int , num_samples : int ) -> Iterable [Sequence [int ]]:
@@ -394,7 +396,6 @@ def attribute(
394
396
)
395
397
if show_progress :
396
398
attr_progress .update ()
397
-
398
399
if agg_output_mode :
399
400
eval_diff = modified_eval - prev_results
400
401
prev_results = modified_eval
@@ -438,7 +439,6 @@ def attribute(
438
439
439
440
# (*output_shape, *input_feature_shape)
440
441
total_attrib [j ] += cur_attr
441
-
442
442
if show_progress :
443
443
attr_progress .close ()
444
444
@@ -452,14 +452,298 @@ def attribute(
452
452
# `Tuple[Tensor, ...]`.
453
453
return formatted_attr
454
454
455
- # pyre-fixme[24] Generic type `Callable` expects 2 type parameters.
456
- def attribute_future (self ) -> Callable :
455
+ def attribute_future (
456
+ self ,
457
+ inputs : TensorOrTupleOfTensorsGeneric ,
458
+ baselines : BaselineType = None ,
459
+ target : TargetType = None ,
460
+ additional_forward_args : Optional [Tuple [object , ...]] = None ,
461
+ feature_mask : Union [None , TensorOrTupleOfTensorsGeneric ] = None ,
462
+ n_samples : int = 25 ,
463
+ perturbations_per_eval : int = 1 ,
464
+ show_progress : bool = False ,
465
+ ) -> Future [TensorOrTupleOfTensorsGeneric ]:
457
466
r"""
458
467
This method is not implemented for ShapleyValueSampling.
459
468
"""
460
- raise NotImplementedError (
461
- "attribute_future is not implemented for ShapleyValueSampling"
469
+ is_inputs_tuple = _is_tuple (inputs )
470
+ inputs_tuple , baselines = _format_input_baseline (inputs , baselines )
471
+ additional_forward_args = _format_additional_forward_args (
472
+ additional_forward_args
462
473
)
474
+ formatted_feature_mask = _format_feature_mask (feature_mask , inputs_tuple )
475
+ reshaped_feature_mask = _shape_feature_mask (
476
+ formatted_feature_mask , inputs_tuple
477
+ )
478
+
479
+ assert (
480
+ isinstance (perturbations_per_eval , int ) and perturbations_per_eval >= 1
481
+ ), "Ablations per evaluation must be at least 1."
482
+
483
+ with torch .no_grad ():
484
+ baselines = _tensorize_baseline (inputs_tuple , baselines )
485
+ num_examples = inputs_tuple [0 ].shape [0 ]
486
+
487
+ total_features = _get_max_feature_index (reshaped_feature_mask ) + 1
488
+
489
+ if show_progress :
490
+ attr_progress = progress (
491
+ desc = f"{ self .get_name ()} attribution" ,
492
+ total = self ._get_n_evaluations (
493
+ total_features , n_samples , perturbations_per_eval
494
+ )
495
+ + 1 , # add 1 for the initial eval
496
+ )
497
+ attr_progress .update (0 )
498
+
499
+ initial_eval = self ._strict_run_forward_future (
500
+ self .forward_func , baselines , target , additional_forward_args
501
+ )
502
+
503
+ if show_progress :
504
+ attr_progress .update ()
505
+
506
+ prev_result_tuple = initial_eval .then (
507
+ lambda initial_eval = initial_eval : self ._initial_eval_to_prev_results_tuple (
508
+ initial_eval ,
509
+ num_examples ,
510
+ perturbations_per_eval ,
511
+ reshaped_feature_mask ,
512
+ inputs_tuple ,
513
+ )
514
+ )
515
+
516
+ iter_count = 0
517
+ # Iterate for number of samples, generate a permutation of the features
518
+ # and evalute the incremental increase for each feature.
519
+ for feature_permutation in self .permutation_generator (
520
+ total_features , n_samples
521
+ ):
522
+ prev_result_tuple = prev_result_tuple .then (
523
+ lambda prev_result_tuple = prev_result_tuple : self ._set_prev_results_to_initial_eval (
524
+ prev_result_tuple
525
+ )
526
+ )
527
+
528
+ iter_count += 1
529
+ for (
530
+ current_inputs ,
531
+ current_add_args ,
532
+ current_target ,
533
+ current_masks ,
534
+ ) in self ._perturbation_generator (
535
+ inputs_tuple ,
536
+ additional_forward_args ,
537
+ target ,
538
+ baselines ,
539
+ reshaped_feature_mask ,
540
+ feature_permutation ,
541
+ perturbations_per_eval ,
542
+ ):
543
+ if sum (torch .sum (mask ).item () for mask in current_masks ) == 0 :
544
+ warnings .warn (
545
+ "Feature mask is missing some integers between 0 and "
546
+ "num_features, for optimal performance, make sure each"
547
+ " consecutive integer corresponds to a feature." ,
548
+ stacklevel = 1 ,
549
+ )
550
+ # modified_eval dimensions: 1D tensor with length
551
+ # equal to #num_examples * #features in batch
552
+ modified_eval = self ._strict_run_forward_future (
553
+ self .forward_func ,
554
+ current_inputs ,
555
+ current_target ,
556
+ current_add_args ,
557
+ )
558
+ if show_progress :
559
+ attr_progress .update ()
560
+
561
+ assert isinstance (modified_eval , torch .Future ), (
562
+ "when using futures method, modified_eval should have "
563
+ f"Future type rather than { type (modified_eval )} "
564
+ )
565
+ eval_futs = collect_all ([prev_result_tuple , modified_eval ])
566
+ prev_result_tuple = eval_futs .then (
567
+ lambda eval_futs = eval_futs , num_examples = num_examples , inputs_tuple = inputs_tuple , current_masks = current_masks : self ._eval_fut_to_prev_results_tuple (
568
+ eval_futs , num_examples , inputs_tuple , current_masks
569
+ )
570
+ )
571
+
572
+ if show_progress :
573
+ attr_progress .close ()
574
+
575
+ # Divide total attributions by number of random permutations and return
576
+ # formatted attributions.
577
+ formatted_attr = prev_result_tuple .then (
578
+ lambda prev_result_tuple = prev_result_tuple : self ._prev_result_tuple_to_formatted_attr (
579
+ prev_result_tuple , iter_count , is_inputs_tuple
580
+ )
581
+ )
582
+ # pyre-fixme[7]: Expected `TensorOrTupleOfTensorsGeneric` but got
583
+ # `Tuple[Tensor, ...]`.
584
+ return formatted_attr
585
+
586
+ def _initial_eval_to_prev_results_tuple (
587
+ self ,
588
+ initial_eval : Future [Tensor ],
589
+ num_examples : int ,
590
+ perturbations_per_eval : int ,
591
+ reshaped_feature_mask : TensorOrTupleOfTensorsGeneric ,
592
+ inputs_tuple : Tuple [Tensor , ...],
593
+ ) -> Tuple [Tensor , Tensor , Size , List [Tensor ], bool ]:
594
+ """Since the initial eval is a Future, it is easier to bundle the prev_result, agg_output_mode, output_shape, and total_attrib together
595
+ as Shapley Value Feature Attributions are being calculated"""
596
+ try :
597
+ initial_eval_processed = initial_eval .value ()
598
+ prev_result = initial_eval_processed
599
+ if not isinstance (initial_eval_processed , Tensor ):
600
+ raise AssertionError (
601
+ "initial_eval_to_processed_initial_eval_fut: "
602
+ "initial_eval should be a Tensor"
603
+ )
604
+ agg_output_mode = _find_output_mode_and_verify (
605
+ initial_eval_processed ,
606
+ num_examples ,
607
+ perturbations_per_eval ,
608
+ reshaped_feature_mask ,
609
+ allow_multi_outputs = True ,
610
+ )
611
+ output_shape = initial_eval_processed .shape
612
+ total_attrib : List [Tensor ] = [
613
+ torch .zeros (
614
+ tuple (output_shape ) + tuple (input .shape [1 :]),
615
+ dtype = torch .float ,
616
+ device = inputs_tuple [0 ].device ,
617
+ )
618
+ for input in inputs_tuple
619
+ ]
620
+ result = (
621
+ initial_eval_processed ,
622
+ prev_result ,
623
+ output_shape ,
624
+ total_attrib ,
625
+ agg_output_mode ,
626
+ )
627
+ except ShapleyValueFutureError as e :
628
+ raise ShapleyValueFutureError (
629
+ "_initial_eval_to_prev_results_tuple func failed"
630
+ ) from e
631
+ return result
632
+
633
+ def _set_prev_results_to_initial_eval (
634
+ self ,
635
+ processed_initial_eval : Future [Tuple [Tensor , Tensor , Size , List [Tensor ], bool ]],
636
+ ) -> Tuple [Tensor , Tensor , Size , List [Tensor ], bool ]:
637
+ """At the beginning of each feature permutation, the prev_results is reset to the initial eval, and this method helps set that up"""
638
+ (initial_eval , prev_results , output_shape , total_attrib , agg_output_mode ) = (
639
+ processed_initial_eval .value ()
640
+ )
641
+ prev_results = initial_eval
642
+ return (initial_eval , prev_results , output_shape , total_attrib , agg_output_mode )
643
+
644
+ def _eval_fut_to_prev_results_tuple (
645
+ self ,
646
+ eval_futs : Future [
647
+ List [
648
+ Union [
649
+ Future [Tuple [Tensor , Tensor , Size , List [Tensor ], bool ]],
650
+ Future [Tensor ],
651
+ ]
652
+ ]
653
+ ],
654
+ num_examples : int ,
655
+ inputs_tuple : Tuple [Tensor , ...],
656
+ current_masks : Tuple [Tensor , ...],
657
+ ) -> Tuple [Tensor , Tensor , Size , List [Tensor ], bool ]:
658
+ """Helper method responsible for calculating eval differences between the modified eval and prev_results Tensor and storing them in total_attrib. Returns prev_results_tuple with modified total_attrib and prev_results"""
659
+ prev_results_tuple = eval_futs .value ()[0 ].value ()
660
+ modified_eval = eval_futs .value ()[1 ].value ()
661
+ if not isinstance (modified_eval , Tensor ) or not isinstance (
662
+ prev_results_tuple , tuple
663
+ ):
664
+ raise ShapleyValueFutureError (
665
+ "_eval_fut_to_prev_results_tuple func failed due to type mismatch"
666
+ )
667
+ (
668
+ initial_eval ,
669
+ prev_results ,
670
+ output_shape ,
671
+ total_attrib ,
672
+ agg_output_mode ,
673
+ ) = prev_results_tuple
674
+ if agg_output_mode :
675
+ eval_diff = modified_eval - prev_results
676
+ prev_results = modified_eval
677
+ else :
678
+ # when perturb_per_eval > 1, every num_examples stands for
679
+ # one perturb. Since the perturbs are from a consecutive
680
+ # perumuation, each diff of a perturb is its eval minus
681
+ # the eval of the previous perturb
682
+
683
+ all_eval = torch .cat ((prev_results , modified_eval ), dim = 0 )
684
+ eval_diff = all_eval [num_examples :] - all_eval [:- num_examples ]
685
+ prev_results = all_eval [- num_examples :]
686
+
687
+ for j in range (len (total_attrib )):
688
+ # format eval_diff to shape
689
+ # (n_perturb, *output_shape, 1,.. 1)
690
+ # where n_perturb may not be perturb_per_eval
691
+ # Append n_input_feature dim of 1 to make the tensor
692
+ # have the same dim as the mask tensor.
693
+ formatted_eval_diff = eval_diff .reshape (
694
+ (- 1 ,) + tuple (output_shape ) + (len (inputs_tuple [j ].shape ) - 1 ) * (1 ,)
695
+ )
696
+
697
+ # mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
698
+ # reshape to
699
+ # (
700
+ # n_perturb,
701
+ # *broadcastable_to_output_shape
702
+ # *broadcastable_to_input_feature_shape
703
+ # )
704
+ cur_mask = current_masks [j ]
705
+ cur_mask = cur_mask .reshape (
706
+ tuple (cur_mask .shape [:2 ])
707
+ + (len (output_shape ) - 1 ) * (1 ,)
708
+ + tuple (cur_mask .shape [2 :])
709
+ )
710
+
711
+ # aggregate n_perturb
712
+ cur_attr = (formatted_eval_diff * cur_mask .float ()).sum (dim = 0 )
713
+ # (*output_shape, *input_feature_shape)
714
+ total_attrib [j ] += cur_attr
715
+
716
+ result = (
717
+ initial_eval ,
718
+ prev_results ,
719
+ output_shape ,
720
+ total_attrib ,
721
+ agg_output_mode ,
722
+ )
723
+ return result
724
+
725
+ def _prev_result_tuple_to_formatted_attr (
726
+ self ,
727
+ prev_result_tuple : Future [
728
+ Tuple [Tensor , Tensor , Tuple [int ], List [Tensor ], bool ]
729
+ ],
730
+ iter_count : int ,
731
+ is_inputs_tuple : bool ,
732
+ ) -> Union [Tensor , Tuple [Tensor , ...]]:
733
+ """Helper method to format total_attrib, which is a list of tensors, into formatted attributions, which are either a single tensor or a tuple of tensors"""
734
+
735
+ (
736
+ _ ,
737
+ _ ,
738
+ _ ,
739
+ total_attrib ,
740
+ _ ,
741
+ ) = prev_result_tuple .value ()
742
+ attrib = tuple (
743
+ tensor_attrib_total / iter_count for tensor_attrib_total in total_attrib
744
+ )
745
+ formatted_attr = _format_output (is_inputs_tuple , attrib )
746
+ return formatted_attr
463
747
464
748
def _perturbation_generator (
465
749
self ,
@@ -574,6 +858,37 @@ def _strict_run_forward(self, *args, **kwargs) -> Tensor:
574
858
# ref: https://github.com/pytorch/pytorch/pull/21215
575
859
return torch .tensor ([forward_output ], dtype = cast (dtype , output_type ))
576
860
861
# pyre-fixme[2]: Parameter must be annotated.
def _strict_run_forward_future(self, *args, **kwargs) -> Future[Tensor]:
    """
    A temp wrapper for global _run_forward util to force forward output
    type assertion & conversion, but takes into account the Future tensor type
    """

    def process_strict_run_forward(fut: Future[Tensor]) -> Tensor:
        # Normalize the resolved forward output to a non-scalar Tensor.
        output = fut.value()
        if isinstance(output, Tensor):
            # format scalar to shape (1) so we can always assume non-empty output_shape
            if not output.shape:
                output = output.reshape(1)
            return output
        output_type = type(output)
        assert output_type is int or output_type is float, (
            "the return of forward_func must be a Future of tensor, int, or float,"
            f" received: {output_type}"
        )
        # use python built-in type as torch dtype (int -> int64,
        # float -> float64); ref: https://github.com/pytorch/pytorch/pull/21215
        output = torch.tensor([output], dtype=cast(dtype, output_type))
        return output

    forward_output = _run_forward(*args, **kwargs)
    assert isinstance(forward_output, torch.Future), (
        "The return type of forward_func must be a Future"
        f" received: {type(forward_output)}"
    )

    return_output = forward_output.then(process_strict_run_forward)
    return return_output
892
578
893
class ShapleyValues (ShapleyValueSampling ):
579
894
"""
0 commit comments