3
3
# pyre-strict
4
4
5
5
import math
6
- from typing import Any , Callable , cast , Generator , List , Optional , Tuple , TypeVar , Union
6
+ from typing import (
7
+ Any ,
8
+ Callable ,
9
+ cast ,
10
+ Dict ,
11
+ Generator ,
12
+ List ,
13
+ Optional ,
14
+ Tuple ,
15
+ TypeVar ,
16
+ Union ,
17
+ )
7
18
8
19
import torch
9
20
from captum ._utils .common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
465
476
attrib_type : dtype ,
466
477
** kwargs : Any ,
467
478
) -> Tuple [List [Tensor ], List [Tensor ]]:
479
+ feature_idx_to_tensor_idx : Dict [int , List [int ]] = {}
480
+ for i , mask in enumerate (formatted_feature_mask ):
481
+ for feature_idx in torch .unique (mask ):
482
+ if feature_idx .item () not in feature_idx_to_tensor_idx :
483
+ feature_idx_to_tensor_idx [feature_idx .item ()] = []
484
+ feature_idx_to_tensor_idx [feature_idx .item ()].append (i )
485
+
468
486
for (
469
487
current_inputs ,
470
488
current_mask ,
471
489
) in self ._ablation_generator (
472
490
formatted_inputs ,
473
491
baselines ,
474
492
formatted_feature_mask ,
493
+ feature_idx_to_tensor_idx ,
475
494
** kwargs ,
476
495
):
477
496
# modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,27 +530,28 @@ def _ablation_generator(
511
530
inputs : Tuple [Tensor , ...],
512
531
baselines : BaselineType ,
513
532
input_mask : Tuple [Tensor , ...],
533
+ feature_idx_to_tensor_idx : Dict [int , List [int ]],
514
534
** kwargs : Any ,
515
535
) -> Generator [
516
536
Tuple [
517
537
Tuple [Tensor , ...],
518
- Tuple [Tensor , ...],
538
+ Tuple [Optional [ Tensor ] , ...],
519
539
],
520
540
None ,
521
541
None ,
522
542
]:
523
- unique_feature_ids = torch .unique (
524
- torch .cat ([mask .flatten () for mask in input_mask ])
525
- ).tolist ()
526
-
527
543
if isinstance (baselines , torch .Tensor ):
528
544
baselines = baselines .reshape ((1 ,) + tuple (baselines .shape ))
529
545
530
546
# Process one feature per time, rather than processing every input tensor
531
- for feature_idx in unique_feature_ids :
547
+ for feature_idx in feature_idx_to_tensor_idx . keys () :
532
548
ablated_inputs , current_masks = (
533
549
self ._construct_ablated_input_across_tensors (
534
- inputs , input_mask , baselines , feature_idx
550
+ inputs ,
551
+ input_mask ,
552
+ baselines ,
553
+ feature_idx ,
554
+ feature_idx_to_tensor_idx [feature_idx ],
535
555
)
536
556
)
537
557
yield ablated_inputs , current_masks
@@ -542,18 +562,17 @@ def _construct_ablated_input_across_tensors(
542
562
input_mask : Tuple [Tensor , ...],
543
563
baselines : BaselineType ,
544
564
feature_idx : int ,
545
- ) -> Tuple [Tuple [Tensor , ...], Tuple [Tensor , ...]]:
565
+ tensor_idxs : List [int ],
566
+ ) -> Tuple [Tuple [Tensor , ...], Tuple [Optional [Tensor ], ...]]:
546
567
547
568
ablated_inputs = []
548
- current_masks = []
569
+ current_masks : List [ Optional [ Tensor ]] = []
549
570
for i , input_tensor in enumerate (inputs ):
550
- mask = input_mask [i ]
551
- tensor_mask = mask == feature_idx
552
- if not tensor_mask .any ():
571
+ if i not in tensor_idxs :
553
572
ablated_inputs .append (input_tensor )
554
- current_masks .append (torch . zeros_like ( tensor_mask ) )
573
+ current_masks .append (None )
555
574
continue
556
- tensor_mask = tensor_mask .to (input_tensor .device ).long ()
575
+ tensor_mask = ( input_mask [ i ] == feature_idx ) .to (input_tensor .device ).long ()
557
576
baseline = baselines [i ] if isinstance (baselines , tuple ) else baselines
558
577
if isinstance (baseline , torch .Tensor ):
559
578
baseline = baseline .reshape (
@@ -1173,7 +1192,7 @@ def _process_ablated_out(
1173
1192
def _process_ablated_out_full (
1174
1193
self ,
1175
1194
modified_eval : Tensor ,
1176
- current_mask : Tuple [Tensor , ...],
1195
+ current_mask : Tuple [Optional [ Tensor ] , ...],
1177
1196
flattened_initial_eval : Tensor ,
1178
1197
inputs : TensorOrTupleOfTensorsGeneric ,
1179
1198
n_outputs : int ,
@@ -1195,9 +1214,10 @@ def _process_ablated_out_full(
1195
1214
1196
1215
if self .use_weights :
1197
1216
for weight , mask in zip (weights , current_mask ):
1198
- weight += mask .float ().sum (dim = 0 )
1217
+ if mask is not None :
1218
+ weight += mask .float ().sum (dim = 0 )
1199
1219
for i , mask in enumerate (current_mask ):
1200
- if inputs [i ].numel () == 0 :
1220
+ if mask is None or inputs [i ].numel () == 0 :
1201
1221
continue
1202
1222
eval_diff = eval_diff .reshape (
1203
1223
eval_diff_shape + (inputs [i ].dim () - 1 ) * (1 ,)
0 commit comments