3
3
# pyre-strict
4
4
5
5
import math
6
- from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union
6
+ from typing import (
7
+ Any,
8
+ Callable,
9
+ cast,
10
+ Dict,
11
+ Generator,
12
+ List,
13
+ Optional,
14
+ Tuple,
15
+ TypeVar,
16
+ Union,
17
+ )
7
18
8
19
import torch
9
20
from captum._utils.common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
465
476
attrib_type: dtype,
466
477
**kwargs: Any,
467
478
) -> Tuple[List[Tensor], List[Tensor]]:
479
+ feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
480
+ for i, mask in enumerate(formatted_feature_mask):
481
+ for feature_idx in torch.unique(mask):
482
+ if feature_idx.item() not in feature_idx_to_tensor_idx.keys():
483
+ feature_idx_to_tensor_idx[feature_idx.item()] = []
484
+ feature_idx_to_tensor_idx[feature_idx.item()].append(i)
485
+
468
486
for (
469
487
current_inputs,
470
488
current_mask,
471
489
) in self._ablation_generator(
472
490
formatted_inputs,
473
491
baselines,
474
492
formatted_feature_mask,
493
+ feature_idx_to_tensor_idx,
475
494
**kwargs,
476
495
):
477
496
# modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,11 +530,12 @@ def _ablation_generator(
511
530
inputs: Tuple[Tensor, ...],
512
531
baselines: BaselineType,
513
532
input_mask: Tuple[Tensor, ...],
533
+ feature_idx_to_tensor_idx: Dict[int, List[int]],
514
534
**kwargs: Any,
515
535
) -> Generator[
516
536
Tuple[
517
537
Tuple[Tensor, ...],
518
- Tuple[Tensor, ...],
538
+ Tuple[Optional[ Tensor] , ...],
519
539
],
520
540
None,
521
541
None,
@@ -531,7 +551,11 @@ def _ablation_generator(
531
551
for feature_idx in unique_feature_ids:
532
552
ablated_inputs, current_masks = (
533
553
self._construct_ablated_input_across_tensors(
534
- inputs, input_mask, baselines, feature_idx
554
+ inputs,
555
+ input_mask,
556
+ baselines,
557
+ feature_idx,
558
+ feature_idx_to_tensor_idx[feature_idx],
535
559
)
536
560
)
537
561
yield ablated_inputs, current_masks
@@ -542,18 +566,17 @@ def _construct_ablated_input_across_tensors(
542
566
input_mask: Tuple[Tensor, ...],
543
567
baselines: BaselineType,
544
568
feature_idx: int,
545
- ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
569
+ tensor_idxs: List[int],
570
+ ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
546
571
547
572
ablated_inputs = []
548
- current_masks = []
573
+ current_masks: List[Optional[Tensor]] = []
549
574
for i, input_tensor in enumerate(inputs):
550
- mask = input_mask[i]
551
- tensor_mask = mask == feature_idx
552
- if not tensor_mask.any():
575
+ if i not in tensor_idxs:
553
576
ablated_inputs.append(input_tensor)
554
- current_masks.append(torch.zeros_like(tensor_mask) )
577
+ current_masks.append(None )
555
578
continue
556
- tensor_mask = tensor_mask .to(input_tensor.device).long()
579
+ tensor_mask = (input_mask[i] == feature_idx) .to(input_tensor.device).long()
557
580
baseline = baselines[i] if isinstance(baselines, tuple) else baselines
558
581
if isinstance(baseline, torch.Tensor):
559
582
baseline = baseline.reshape(
@@ -1173,7 +1196,7 @@ def _process_ablated_out(
1173
1196
def _process_ablated_out_full(
1174
1197
self,
1175
1198
modified_eval: Tensor,
1176
- current_mask: Tuple[Tensor, ...],
1199
+ current_mask: Tuple[Optional[ Tensor] , ...],
1177
1200
flattened_initial_eval: Tensor,
1178
1201
inputs: TensorOrTupleOfTensorsGeneric,
1179
1202
n_outputs: int,
@@ -1195,9 +1218,10 @@ def _process_ablated_out_full(
1195
1218
1196
1219
if self.use_weights:
1197
1220
for weight, mask in zip(weights, current_mask):
1198
- weight += mask.float().sum(dim=0)
1221
+ if mask is not None:
1222
+ weight += mask.float().sum(dim=0)
1199
1223
for i, mask in enumerate(current_mask):
1200
- if inputs[i].numel() == 0:
1224
+ if mask is None or inputs[i].numel() == 0:
1201
1225
continue
1202
1226
eval_diff = eval_diff.reshape(
1203
1227
eval_diff_shape + (inputs[i].dim() - 1) * (1,)
0 commit comments