 # pyre-strict
 
 import math
 
-from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
         attrib_type: dtype,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx.keys():
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+
         for (
             current_inputs,
             current_mask,
         ) in self._ablation_generator(
             formatted_inputs,
             baselines,
             formatted_feature_mask,
+            feature_idx_to_tensor_idx,
             **kwargs,
         ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
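For orientation, this hunk precomputes, for every feature id that appears in the feature masks, the list of input tensors that actually contain that feature, so later ablation steps can skip tensors the feature never touches. Below is a minimal standalone sketch of that lookup; `build_feature_idx_to_tensor_idx` and the toy masks are illustrative only, not part of Captum's API.

```python
from typing import Dict, List, Tuple

import torch
from torch import Tensor


def build_feature_idx_to_tensor_idx(
    feature_mask: Tuple[Tensor, ...],
) -> Dict[int, List[int]]:
    """Map each feature id to the indices of the mask tensors that contain it."""
    mapping: Dict[int, List[int]] = {}
    for tensor_idx, mask in enumerate(feature_mask):
        for feature_idx in torch.unique(mask):
            mapping.setdefault(int(feature_idx.item()), []).append(tensor_idx)
    return mapping


# Feature 0 lives only in the first mask, feature 2 only in the second,
# feature 1 in both.
masks = (torch.tensor([[0, 1]]), torch.tensor([[1, 2]]))
print(build_feature_idx_to_tensor_idx(masks))  # {0: [0], 1: [0, 1], 2: [1]}
```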
@@ -511,11 +530,12 @@ def _ablation_generator(
         inputs: Tuple[Tensor, ...],
         baselines: BaselineType,
         input_mask: Tuple[Tensor, ...],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
         **kwargs: Any,
     ) -> Generator[
         Tuple[
             Tuple[Tensor, ...],
-            Tuple[Tensor, ...],
+            Tuple[Optional[Tensor], ...],
         ],
         None,
         None,
@@ -531,7 +551,11 @@ def _ablation_generator(
         for feature_idx in unique_feature_ids:
             ablated_inputs, current_masks = (
                 self._construct_ablated_input_across_tensors(
-                    inputs, input_mask, baselines, feature_idx
+                    inputs,
+                    input_mask,
+                    baselines,
+                    feature_idx,
+                    feature_idx_to_tensor_idx[feature_idx],
                 )
             )
             yield ablated_inputs, current_masks
@@ -542,18 +566,17 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
 
         ablated_inputs = []
-        current_masks = []
+        current_masks: List[Optional[Tensor]] = []
         for i, input_tensor in enumerate(inputs):
-            mask = input_mask[i]
-            tensor_mask = mask == feature_idx
-            if not tensor_mask.any():
+            if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
-                current_masks.append(torch.zeros_like(tensor_mask))
+                current_masks.append(None)
                 continue
-            tensor_mask = tensor_mask.to(input_tensor.device).long()
+            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
             if isinstance(baseline, torch.Tensor):
                 baseline = baseline.reshape(
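Taken together with the previous hunks, the per-tensor logic now short-circuits on the precomputed `tensor_idxs`: tensors that never contain the feature are passed through untouched with a `None` mask instead of an all-zeros placeholder. The sketch below illustrates this under simplifying assumptions (scalar baselines, and the usual `input * (1 - mask) + baseline * mask` blend); `ablate_one_feature` is an illustrative name, not Captum's method.

```python
from typing import List, Optional, Sequence, Tuple

import torch
from torch import Tensor


def ablate_one_feature(
    inputs: Tuple[Tensor, ...],
    input_mask: Tuple[Tensor, ...],
    baselines: Sequence[float],
    feature_idx: int,
    tensor_idxs: List[int],
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
    ablated_inputs: List[Tensor] = []
    current_masks: List[Optional[Tensor]] = []
    for i, input_tensor in enumerate(inputs):
        if i not in tensor_idxs:
            # The feature never occurs in this tensor: no ablation, no mask.
            ablated_inputs.append(input_tensor)
            current_masks.append(None)
            continue
        tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
        ablated = (
            input_tensor * (1 - tensor_mask).to(input_tensor.dtype)
            + baselines[i] * tensor_mask.to(input_tensor.dtype)
        )
        ablated_inputs.append(ablated)
        current_masks.append(tensor_mask)
    return tuple(ablated_inputs), tuple(current_masks)


inputs = (torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]]))
mask = (torch.tensor([[0, 1]]), torch.tensor([[2, 2]]))
# Ablate feature 1, which only appears in the first tensor.
ablated, masks_out = ablate_one_feature(inputs, mask, (0.0, 0.0), 1, [0])
print(ablated[0])    # tensor([[1., 0.]]) -- position of feature 1 set to baseline
print(masks_out[1])  # None -- second tensor untouched
```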
@@ -1173,7 +1196,7 @@ def _process_ablated_out(
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
-        current_mask: Tuple[Tensor, ...],
+        current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
@@ -1195,9 +1218,10 @@ def _process_ablated_out_full(
 
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
-                weight += mask.float().sum(dim=0)
+                if mask is not None:
+                    weight += mask.float().sum(dim=0)
         for i, mask in enumerate(current_mask):
-            if inputs[i].numel() == 0:
+            if mask is None or inputs[i].numel() == 0:
                 continue
             eval_diff = eval_diff.reshape(
                 eval_diff_shape + (inputs[i].dim() - 1) * (1,)
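Downstream aggregation then has to tolerate the `None` masks, which this hunk does by guarding the weight update and skipping masked-out tensors in the attribution loop. A tiny sketch of the guarded accumulation, with `accumulate_weights` as a hypothetical helper rather than Captum's API:

```python
from typing import List, Optional, Tuple

import torch
from torch import Tensor


def accumulate_weights(
    weights: List[Tensor],
    current_mask: Tuple[Optional[Tensor], ...],
) -> None:
    # Tensors the ablated feature never touched carry a None mask and
    # contribute nothing to the normalization weights.
    for weight, mask in zip(weights, current_mask):
        if mask is not None:
            weight += mask.float().sum(dim=0)


weights = [torch.zeros(2), torch.zeros(2)]
current_mask = (torch.tensor([[1, 0]]), None)
accumulate_weights(weights, current_mask)
print(weights)  # [tensor([1., 0.]), tensor([0., 0.])]
```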