@@ -353,10 +353,12 @@ def attribute(
                 formatted_feature_mask,
                 attr_progress,
                 flattened_initial_eval,
+                initial_eval,
                 n_outputs,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
                 **kwargs,
             )
         else:
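The two arguments added above thread new state into the cross-tensor ablation path: initial_eval, the unflattened forward output on the unperturbed batch (used further down to validate output shapes), and perturbations_per_eval, the number of feature ablations to batch into a single forward pass. Both are consumed in the hunks below.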
@@ -470,10 +472,12 @@ def _attribute_with_cross_tensor_feature_masks(
         formatted_feature_mask: Tuple[Tensor, ...],
         attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         n_outputs: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
@@ -482,17 +486,78 @@ def _attribute_with_cross_tensor_feature_masks(
                 if feature_idx.item() not in feature_idx_to_tensor_idx:
                     feature_idx_to_tensor_idx[feature_idx.item()] = []
                 feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
+
+        additional_args_repeated: object
+        if perturbations_per_eval > 1:
+            # Repeat features and additional args for batch size.
+            all_features_repeated = tuple(
+                torch.cat([formatted_inputs[j]] * perturbations_per_eval, dim=0)
+                for j in range(len(formatted_inputs))
+            )
+            additional_args_repeated = (
+                _expand_additional_forward_args(
+                    formatted_additional_forward_args, perturbations_per_eval
+                )
+                if formatted_additional_forward_args is not None
+                else None
+            )
+            target_repeated = _expand_target(target, perturbations_per_eval)
+        else:
+            all_features_repeated = formatted_inputs
+            additional_args_repeated = formatted_additional_forward_args
+            target_repeated = target
+        num_examples = formatted_inputs[0].shape[0]
+
+        current_additional_args: object
+        if isinstance(baselines, tuple):
+            reshaped = False
+            reshaped_baselines: list[BaselineType] = []
+            for baseline in baselines:
+                if isinstance(baseline, Tensor):
+                    reshaped = True
+                    reshaped_baselines.append(
+                        baseline.reshape((1,) + tuple(baseline.shape))
+                    )
+                else:
+                    reshaped_baselines.append(baseline)
+            baselines = tuple(reshaped_baselines) if reshaped else baselines
+        for i in range(0, len(all_feature_idxs), perturbations_per_eval):
+            current_feature_idxs = all_feature_idxs[i : i + perturbations_per_eval]
+            current_num_ablated_features = min(
+                perturbations_per_eval, len(current_feature_idxs)
+            )
+
+            # Store appropriate inputs and additional args based on batch size.
+            if current_num_ablated_features != perturbations_per_eval:
+                current_additional_args = (
+                    _expand_additional_forward_args(
+                        formatted_additional_forward_args, current_num_ablated_features
+                    )
+                    if formatted_additional_forward_args is not None
+                    else None
+                )
+                current_target = _expand_target(target, current_num_ablated_features)
+                expanded_inputs = tuple(
+                    feature_repeated[0 : current_num_ablated_features * num_examples]
+                    for feature_repeated in all_features_repeated
+                )
+            else:
+                current_additional_args = additional_args_repeated
+                current_target = target_repeated
+                expanded_inputs = all_features_repeated
+
+            current_inputs, current_masks = (
+                self._construct_ablated_input_across_tensors(
+                    expanded_inputs,
+                    formatted_feature_mask,
+                    baselines,
+                    current_feature_idxs,
+                    feature_idx_to_tensor_idx,
+                    current_num_ablated_features,
+                )
+            )
 
-        for (
-            current_inputs,
-            current_mask,
-        ) in self._ablation_generator(
-            formatted_inputs,
-            baselines,
-            formatted_feature_mask,
-            feature_idx_to_tensor_idx,
-            **kwargs,
-        ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
             # shape:
             # agg mode: (*initial_eval.shape)
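The loop above is the core of the change: instead of ablating one feature per forward call, up to perturbations_per_eval features are processed at once by tiling the input batch and ablating a different feature in each copy. A minimal, self-contained sketch of that tiling idea (not Captum's API; all names here are illustrative):

    import torch

    def batched_ablation_sketch(
        x: torch.Tensor, mask: torch.Tensor, feature_idxs: list, baseline: float = 0.0
    ) -> torch.Tensor:
        n = x.shape[0]                         # original batch size
        k = len(feature_idxs)                  # ablations batched into this eval
        tiled = torch.cat([x] * k, dim=0)      # shape (k * n, *x.shape[1:])
        for j, fidx in enumerate(feature_idxs):
            m = (mask == fidx).to(x.dtype)     # 1 where feature fidx lives
            # each slice of the tiled batch gets a different feature ablated
            tiled[j * n : (j + 1) * n] = x * (1 - m) + baseline * m
        return tiled

    x = torch.arange(12.0).reshape(2, 6)       # batch of 2 examples
    mask = torch.tensor([0, 0, 1, 1, 2, 2])    # 3 feature groups
    batch = batched_ablation_sketch(x, mask, [0, 1, 2])
    print(batch.shape)                         # torch.Size([6, 6]): one forward, 3 ablations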
@@ -501,8 +566,8 @@ def _attribute_with_cross_tensor_feature_masks(
             modified_eval = _run_forward(
                 self.forward_func,
                 current_inputs,
-                target,
-                formatted_additional_forward_args,
+                current_target,
+                current_additional_args,
             )
 
             if attr_progress is not None:
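Because each forward call now sees a batch tiled current_num_ablated_features times, the target and additional forward args must be tiled to match, which is why _run_forward switches to the expanded variants. A toy illustration of the repeat semantics assumed here (Captum's internal _expand_target handles the general cases):

    target = [1, 0]                  # one target index per original example
    perturbations_per_eval = 3
    # the tiled batch holds 3 copies of the 2 examples, so targets repeat in step
    expanded_target = target * perturbations_per_eval   # [1, 0, 1, 0, 1, 0]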
@@ -515,75 +580,65 @@ def _attribute_with_cross_tensor_feature_masks(
 
             total_attrib, weights = self._process_ablated_out_full(
                 modified_eval,
-                current_mask,
+                current_masks,
                 flattened_initial_eval,
-                formatted_inputs,
+                initial_eval,
+                current_inputs,
                 n_outputs,
+                num_examples,
                 total_attrib,
                 weights,
                 attrib_type,
+                perturbations_per_eval,
             )
         return total_attrib, weights
 
-    def _ablation_generator(
-        self,
-        inputs: Tuple[Tensor, ...],
-        baselines: BaselineType,
-        input_mask: Tuple[Tensor, ...],
-        feature_idx_to_tensor_idx: Dict[int, List[int]],
-        **kwargs: Any,
-    ) -> Generator[
-        Tuple[
-            Tuple[Tensor, ...],
-            Tuple[Optional[Tensor], ...],
-        ],
-        None,
-        None,
-    ]:
-        if isinstance(baselines, torch.Tensor):
-            baselines = baselines.reshape((1,) + tuple(baselines.shape))
-
-        # Process one feature per time, rather than processing every input tensor
-        for feature_idx in feature_idx_to_tensor_idx.keys():
-            ablated_inputs, current_masks = (
-                self._construct_ablated_input_across_tensors(
-                    inputs,
-                    input_mask,
-                    baselines,
-                    feature_idx,
-                    feature_idx_to_tensor_idx[feature_idx],
-                )
-            )
-            yield ablated_inputs, current_masks
-
     def _construct_ablated_input_across_tensors(
         self,
         inputs: Tuple[Tensor, ...],
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
-        feature_idx: int,
-        tensor_idxs: List[int],
+        feature_idxs: List[int],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
+        current_num_ablated_features: int,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
-
         ablated_inputs = []
         current_masks: List[Optional[Tensor]] = []
+        tensor_idxs = {
+            tensor_idx
+            for sublist in (
+                feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
+            )
+            for tensor_idx in sublist
+        }
+
         for i, input_tensor in enumerate(inputs):
             if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
                 current_masks.append(None)
                 continue
-            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+            tensor_mask = []
+            ablated_input = input_tensor.clone()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
-            if isinstance(baseline, torch.Tensor):
-                baseline = baseline.reshape(
-                    (1,) * (input_tensor.dim() - baseline.dim()) + tuple(baseline.shape)
+            for j, feature_idx in enumerate(feature_idxs):
+                original_input_size = (
+                    input_tensor.shape[0] // current_num_ablated_features
                 )
-            assert baseline is not None, "baseline must be provided"
-            ablated_input = (
-                input_tensor * (1 - tensor_mask).to(input_tensor.dtype)
-            ) + (baseline * tensor_mask.to(input_tensor.dtype))
+                start_idx = j * original_input_size
+                end_idx = (j + 1) * original_input_size
+
+                mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
+                if mask.ndim == 0:
+                    mask = mask.reshape((1,) * input_tensor.dim())
+                tensor_mask.append(mask)
+
+                assert baseline is not None, "baseline must be provided"
+                ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
+                    1 - mask
+                ) + (baseline * mask.to(input_tensor.dtype))
+            current_masks.append(torch.stack(tensor_mask, dim=0))
             ablated_inputs.append(ablated_input)
-            current_masks.append(tensor_mask)
+
         return tuple(ablated_inputs), tuple(current_masks)
 
     def _initial_eval_to_processed_initial_eval_fut(
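In the rewritten _construct_ablated_input_across_tensors, a feature id can map to several input tensors, so the method first collects every tensor index touched by any feature in the current batch; untouched tensors pass through with a None mask. The set comprehension, unrolled on hypothetical data:

    # hypothetical mapping: feature id -> input tensors it appears in
    feature_idx_to_tensor_idx = {0: [0], 1: [0, 1], 2: [1]}
    feature_idxs = [0, 2]            # features ablated in the current batch
    tensor_idxs = {
        tensor_idx
        for sublist in (feature_idx_to_tensor_idx[f] for f in feature_idxs)
        for tensor_idx in sublist
    }
    print(tensor_idxs)               # {0, 1}: tensor 1 is touched via feature 2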
@@ -784,7 +839,7 @@ def _attribute_progress_setup(
             formatted_inputs, feature_mask, **kwargs
         )
         total_forwards = (
-            int(sum(feature_counts))
+            math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
             if enable_cross_tensor_attribution
             else sum(
                 math.ceil(count / perturbations_per_eval) for count in feature_counts
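With batching, the progress bar's forward count shrinks accordingly: the cross-tensor branch now divides the total feature count by perturbations_per_eval and rounds up. Worked numbers:

    import math

    feature_counts = [4, 3]          # hypothetical features per input tensor
    perturbations_per_eval = 3
    # cross-tensor mode ablates each of the 7 feature groups once, 3 per forward
    total_forwards = math.ceil(int(sum(feature_counts)) / perturbations_per_eval)
    assert total_forwards == 3       # ceil(7 / 3)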
@@ -1194,13 +1249,46 @@ def _process_ablated_out_full(
         modified_eval: Tensor,
         current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
+        initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
+        num_examples: int,
         total_attrib: List[Tensor],
         weights: List[Tensor],
         attrib_type: dtype,
+        perturbations_per_eval: int,
     ) -> Tuple[List[Tensor], List[Tensor]]:
         modified_eval = self._parse_forward_out(modified_eval)
+        # if perturbations_per_eval > 1, the output shape must grow with
+        # input and not be aggregated
+        current_batch_size = inputs[0].shape[0]
+
+        # number of perturbations, which is not the same as
+        # perturbations_per_eval when there are not enough features to perturb
+        n_perturb = current_batch_size / num_examples
+        if perturbations_per_eval > 1 and not self._is_output_shape_valid:
+
+            current_output_shape = modified_eval.shape
+
+            # use initial_eval as the forward of perturbations_per_eval = 1
+            initial_output_shape = initial_eval.shape
+
+            assert (
+                # check that the output is not a scalar
+                current_output_shape
+                and initial_output_shape
+                # check that the output grows in the same ratio, i.e., not aggregated
+                and current_output_shape[0] == n_perturb * initial_output_shape[0]
+            ), (
+                "When perturbations_per_eval > 1, forward_func's output "
+                "should be a tensor whose 1st dim grows with the input "
+                f"batch size: when input batch size is {num_examples}, "
+                f"the output shape is {initial_output_shape}; "
+                f"when input batch size is {current_batch_size}, "
+                f"the output shape is {current_output_shape}"
+            )
+
+            self._is_output_shape_valid = True
 
         # reshape the leading dim for n_feature_perturbed
         # flatten each feature's eval outputs into 1D of (n_outputs)
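The new assertion guards a requirement that only appears once batching is on: forward_func's output must scale with the input batch rather than aggregate over it, since each slice of the output belongs to a different ablated feature. The self._is_output_shape_valid flag makes the check run only on the first batched evaluation. The invariant, with made-up shapes:

    num_examples, n_perturb = 4, 3
    initial_output_shape = (num_examples, 10)              # forward on the plain batch
    current_output_shape = (n_perturb * num_examples, 10)  # forward on the tiled batch
    # holds for per-example outputs; a scalar (aggregated) output would fail
    assert current_output_shape[0] == n_perturb * initial_output_shape[0]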
@@ -1209,9 +1297,6 @@ def _process_ablated_out_full(
         eval_diff = flattened_initial_eval - modified_eval
         eval_diff_shape = eval_diff.shape
 
-        # append the shape of one input example
-        # to make it broadcastable to mask
-
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
                 if mask is not None:
@@ -1224,6 +1309,7 @@ def _process_ablated_out_full(
             )
             eval_diff = eval_diff.to(total_attrib[i].device)
             total_attrib[i] += (eval_diff * mask.to(attrib_type)).sum(dim=0)
+
         return total_attrib, weights
 
     def _fut_tuple_to_accumulate_fut_list(
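Taken together, these hunks let callers batch cross-tensor ablations. A usage sketch against Captum's public FeatureAblation; depending on the release, the cross-tensor code path may additionally be gated by the enable_cross_tensor_attribution flag referenced in _attribute_progress_setup above, and the call below assumes a build where shared feature ids route there:

    import torch
    from captum.attr import FeatureAblation

    def forward_func(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # per-example scalar score, so the output's first dim grows with the
        # batch, as the new assertion in _process_ablated_out_full requires
        return x1.sum(dim=1) + x2.sum(dim=1)

    ablator = FeatureAblation(forward_func)
    x1, x2 = torch.rand(8, 5), torch.rand(8, 3)
    # feature ids may span tensors: ids 0 and 2 cover columns in both x1 and x2
    mask1 = torch.tensor([[0, 0, 1, 1, 2]])
    mask2 = torch.tensor([[0, 2, 2]])
    attr1, attr2 = ablator.attribute(
        (x1, x2),
        feature_mask=(mask1, mask2),
        perturbations_per_eval=3,  # batch 3 feature ablations per forward pass
    )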