 )
 from captum._utils.exceptions import FeatureAblationFutureError
 from captum._utils.progress import progress, SimpleProgress
-from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
+from captum._utils.typing import (
+    BaselineTupleType,
+    BaselineType,
+    TargetType,
+    TensorOrTupleOfTensorsGeneric,
+)
 from captum.attr._utils.attribution import PerturbationAttribution
 from captum.attr._utils.common import _format_input_baseline
 from captum.log import log_usage
@@ -353,10 +358,12 @@ def attribute(
                 formatted_feature_mask,
                 attr_progress,
                 flattened_initial_eval,
+                initial_eval,
                 n_outputs,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
                 **kwargs,
             )
         else:
@@ -470,10 +477,12 @@ def _attribute_with_cross_tensor_feature_masks(
         formatted_feature_mask: Tuple[Tensor, ...],
         attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         n_outputs: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
@@ -482,17 +491,78 @@ def _attribute_with_cross_tensor_feature_masks(
                 if feature_idx.item() not in feature_idx_to_tensor_idx:
                     feature_idx_to_tensor_idx[feature_idx.item()] = []
                 feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
+
+        additional_args_repeated: object
+        if perturbations_per_eval > 1:
+            # Repeat features and additional args for batch size.
+            all_features_repeated = tuple(
+                torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
+                for j in range(len(formatted_inputs))
+            )
+            additional_args_repeated = (
+                _expand_additional_forward_args(
+                    formatted_additional_forward_args, perturbations_per_eval
+                )
+                if formatted_additional_forward_args is not None
+                else None
+            )
+            target_repeated = _expand_target(target, perturbations_per_eval)
+        else:
+            all_features_repeated = formatted_inputs
+            additional_args_repeated = formatted_additional_forward_args
+            target_repeated = target
+        num_examples = formatted_inputs[0].shape[0]
+
+        current_additional_args: object
+        if isinstance(baselines, tuple):
+            reshaped = False
+            reshaped_baselines: list[Union[Tensor, int, float]] = []
+            for baseline in baselines:
+                if isinstance(baseline, Tensor):
+                    reshaped = True
+                    reshaped_baselines.append(
+                        baseline.reshape((1,) + tuple(baseline.shape))
+                    )
+                else:
+                    reshaped_baselines.append(baseline)
+            baselines = tuple(reshaped_baselines) if reshaped else baselines
+        for i in range(0, len(all_feature_idxs), perturbations_per_eval):
+            current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
+            current_num_ablated_features = min(
+                perturbations_per_eval, len(current_feature_idxs)
+            )
+
+            # Store appropriate inputs and additional args based on batch size.
+            if current_num_ablated_features != perturbations_per_eval:
+                current_additional_args = (
+                    _expand_additional_forward_args(
+                        formatted_additional_forward_args, current_num_ablated_features
+                    )
+                    if formatted_additional_forward_args is not None
+                    else None
+                )
+                current_target = _expand_target(target, current_num_ablated_features)
+                expanded_inputs = tuple(
+                    feature_repeated[0 : current_num_ablated_features * num_examples]
+                    for feature_repeated in all_features_repeated
+                )
+            else:
+                current_additional_args = additional_args_repeated
+                current_target = target_repeated
+                expanded_inputs = all_features_repeated
+
+            current_inputs, current_masks = (
+                self._construct_ablated_input_across_tensors(
+                    expanded_inputs,
+                    formatted_feature_mask,
+                    baselines,
+                    current_feature_idxs,
+                    feature_idx_to_tensor_idx,
+                    current_num_ablated_features,
+                )
+            )
 
-        for (
-            current_inputs,
-            current_mask,
-        ) in self._ablation_generator(
-            formatted_inputs,
-            baselines,
-            formatted_feature_mask,
-            feature_idx_to_tensor_idx,
-            **kwargs,
-        ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
             # shape:
             # agg mode: (*initial_eval.shape)
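The added loop batches up to perturbations_per_eval feature ablations into a single forward call by repeating the input batch once per ablated feature in the current chunk. As an illustrative, standalone sketch of that batching idea (simplified, hypothetical names; not part of the diff):

import torch

def batched_ablation_chunks(inputs, feature_idxs, perturbations_per_eval):
    """Yield (expanded_inputs, feature_chunk) so that each forward call can
    evaluate up to `perturbations_per_eval` ablations at once."""
    for start in range(0, len(feature_idxs), perturbations_per_eval):
        chunk = feature_idxs[start : start + perturbations_per_eval]
        # Repeat each input tensor once per feature in this chunk; row block j
        # of the expanded batch will later be ablated with feature chunk[j].
        expanded = tuple(torch.cat([inp] * len(chunk), dim=0) for inp in inputs)
        yield expanded, chunk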
@@ -501,8 +571,8 @@ def _attribute_with_cross_tensor_feature_masks(
             modified_eval = _run_forward(
                 self.forward_func,
                 current_inputs,
-                target,
-                formatted_additional_forward_args,
+                current_target,
+                current_additional_args,
             )
 
             if attr_progress is not None:
@@ -515,75 +585,65 @@ def _attribute_with_cross_tensor_feature_masks(
 
             total_attrib, weights = self._process_ablated_out_full(
                 modified_eval,
-                current_mask,
+                current_masks,
                 flattened_initial_eval,
-                formatted_inputs,
+                initial_eval,
+                current_inputs,
                 n_outputs,
+                num_examples,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
             )
         return total_attrib, weights
 
-    def _ablation_generator(
-        self,
-        inputs: Tuple[Tensor, ...],
-        baselines: BaselineType,
-        input_mask: Tuple[Tensor, ...],
-        feature_idx_to_tensor_idx: Dict[int, List[int]],
-        **kwargs: Any,
-    ) -> Generator[
-        Tuple[
-            Tuple[Tensor, ...],
-            Tuple[Optional[Tensor], ...],
-        ],
-        None,
-        None,
-    ]:
-        if isinstance(baselines, torch.Tensor):
-            baselines = baselines.reshape((1,) + tuple(baselines.shape))
-
-        # Process one feature per time, rather than processing every input tensor
-        for feature_idx in feature_idx_to_tensor_idx.keys():
-            ablated_inputs, current_masks = (
-                self._construct_ablated_input_across_tensors(
-                    inputs,
-                    input_mask,
-                    baselines,
-                    feature_idx,
-                    feature_idx_to_tensor_idx[feature_idx],
-                )
-            )
-            yield ablated_inputs, current_masks
-
     def _construct_ablated_input_across_tensors(
         self,
         inputs: Tuple[Tensor, ...],
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
-        feature_idx: int,
-        tensor_idxs: List[int],
+        feature_idxs: List[int],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
+        current_num_ablated_features: int,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
-
         ablated_inputs = []
         current_masks: List[Optional[Tensor]] = []
+        tensor_idxs = {
+            tensor_idx
+            for sublist in (
+                feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
+            )
+            for tensor_idx in sublist
+        }
+
         for i, input_tensor in enumerate(inputs):
             if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
                 current_masks.append(None)
                 continue
-            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+            tensor_mask = []
+            ablated_input = input_tensor.clone()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
-            if isinstance(baseline, torch.Tensor):
-                baseline = baseline.reshape(
-                    (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape)
+            for j, feature_idx in enumerate(feature_idxs):
+                original_input_size = (
+                    input_tensor.shape[0] // current_num_ablated_features
                 )
-            assert baseline is not None, "baseline must be provided"
-            ablated_input = (
-                input_tensor * (1 - tensor_mask).to(input_tensor.dtype)
-            ) + (baseline * tensor_mask.to(input_tensor.dtype))
+                start_idx = j * original_input_size
+                end_idx = (j + 1) * original_input_size
+
+                mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+                if mask.ndim == 0:
+                    mask = mask.reshape((1,) * input_tensor.dim())
+                tensor_mask.append(mask)
+
+                assert baseline is not None, "baseline must be provided"
+                ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
+                    1 - mask
+                ) + (baseline * mask.to(input_tensor.dtype))
+            current_masks.append(torch.stack(tensor_mask, dim=0))
             ablated_inputs.append(ablated_input)
-            current_masks.append(tensor_mask)
+
         return tuple(ablated_inputs), tuple(current_masks)
 
     def _initial_eval_to_processed_initial_eval_fut(
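In the rewritten _construct_ablated_input_across_tensors, row block j (of size num_examples) of the expanded batch is ablated with the j-th feature of the current chunk, and the per-feature masks are stacked along a new leading dimension. A simplified single-tensor sketch of that per-block masking (hypothetical helper, not part of the diff; assumes the mask broadcasts over the example dimension):

import torch

def ablate_blocks(expanded_input, feature_mask, baseline, feature_idxs):
    """expanded_input: (len(feature_idxs) * num_examples, *feature_dims).
    Replace, in each row block, the region matching that block's feature id."""
    num_examples = expanded_input.shape[0] // len(feature_idxs)
    ablated = expanded_input.clone()
    block_masks = []
    for j, feature_idx in enumerate(feature_idxs):
        mask = (feature_mask == feature_idx).long()   # 0/1 mask for this feature
        start, end = j * num_examples, (j + 1) * num_examples
        ablated[start:end] = expanded_input[start:end] * (1 - mask) + baseline * mask
        block_masks.append(mask)
    return ablated, torch.stack(block_masks, dim=0)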
@@ -784,7 +844,7 @@ def _attribute_progress_setup(
             formatted_inputs, feature_mask, **kwargs
         )
         total_forwards = (
-            int(sum(feature_counts))
+            math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
             if enable_cross_tensor_attribution
             else sum(
                 math.ceil(count / perturbations_per_eval) for count in feature_counts
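With cross-tensor attribution, the progress total now counts batched forward passes rather than individual features. For example, assuming hypothetical counts of 10 features in total and perturbations_per_eval=4:

import math

feature_counts = [6, 4]          # hypothetical per-tensor feature counts
perturbations_per_eval = 4
# ceil(10 / 4) = 3 batched forwards instead of 10 single-feature forwards
total_forwards = math.ceil(int(sum(feature_counts)) / perturbations_per_eval)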
@@ -1194,13 +1254,46 @@ def _process_ablated_out_full(
         modified_eval: Tensor,
         current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
+        num_examples: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         modified_eval = self._parse_forward_out(modified_eval)
+        # If perturbations_per_eval > 1, the output shape must grow with the
+        # input batch and must not be aggregated.
+        current_batch_size = inputs[0].shape[0]
+
+        # Number of perturbations, which is not the same as
+        # perturbations_per_eval when there are not enough features to perturb.
+        n_perturb = current_batch_size / num_examples
+        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+
+            current_output_shape = modified_eval.shape
+
+            # Use initial_eval as the forward output for perturbations_per_eval = 1.
+            initial_output_shape = initial_eval.shape
+
+            assert (
+                # check that the output is not a scalar
+                current_output_shape
+                and initial_output_shape
+                # check that the output grows in the same ratio, i.e., is not aggregated
+                and current_output_shape[0] == n_perturb * initial_output_shape[0]
+            ), (
+                "When perturbations_per_eval > 1, forward_func's output "
+                "should be a tensor whose 1st dim grows with the input "
+                f"batch size: when input batch size is {num_examples}, "
+                f"the output shape is {initial_output_shape}; "
+                f"when input batch size is {current_batch_size}, "
+                f"the output shape is {current_output_shape}"
+            )
+
+            self._is_output_shape_valid = True
 
         # reshape the leading dim for n_feature_perturbed
         # flatten each feature's eval outputs into 1D of (n_outputs)
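The new one-time check requires forward_func to return one row per example when perturbations_per_eval > 1, so that the leading output dimension scales with the expanded batch and can be split back into per-perturbation slices. A hedged illustration with a toy forward function (not from the PR) of what satisfies the check:

import torch

batch = torch.rand(2, 3)                  # num_examples = 2
expanded = torch.cat([batch] * 4, dim=0)  # perturbations_per_eval = 4 -> batch of 8

def per_example_loss(x):
    # Passes the check: output dim 0 grows with the input batch.
    return (x ** 2).sum(dim=1)

def aggregated_loss(x):
    # Would fail the check: a scalar cannot be split per perturbation.
    return (x ** 2).sum()

assert per_example_loss(expanded).shape[0] == 4 * per_example_loss(batch).shape[0]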
@@ -1209,9 +1302,6 @@ def _process_ablated_out_full(
         eval_diff = flattened_initial_eval - modified_eval
         eval_diff_shape = eval_diff.shape
 
-        # append the shape of one input example
-        # to make it broadcastable to mask
-
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
                 if mask is not None:
@@ -1224,6 +1314,7 @@ def _process_ablated_out_full(
                 )
                 eval_diff = eval_diff.to(total_attrib[i].device)
                 total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
+
         return total_attrib, weights
 
     def _fut_tuple_to_accumulate_fut_list(
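Taken together, the changes let the cross-tensor ablation path evaluate several features per forward call. A hedged end-to-end usage sketch with a toy model (the enable_cross_tensor_attribution flag is assumed to be supported by this version of FeatureAblation.attribute):

import torch
from captum.attr import FeatureAblation

model = torch.nn.Linear(6, 2)
inputs = torch.rand(4, 6)

ablator = FeatureAblation(model)
# With perturbations_per_eval=3, up to 3 feature ablations share one forward call.
attr = ablator.attribute(
    inputs,
    target=0,
    perturbations_per_eval=3,
    enable_cross_tensor_attribution=True,  # assumed flag for the cross-tensor path
)
print(attr.shape)  # (4, 6): one attribution value per input feature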