Skip to content

Commit 218636f

Browse files
authored
Merge branch 'master' into lime
2 parents 68138ed + b917b2a commit 218636f

File tree

2 files changed

+51
-26
lines changed

2 files changed

+51
-26
lines changed

captum/attr/_core/feature_ablation.py

+38-18
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@
33
# pyre-strict
44

55
import math
6-
from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union
6+
from typing import (
7+
Any,
8+
Callable,
9+
cast,
10+
Dict,
11+
Generator,
12+
List,
13+
Optional,
14+
Tuple,
15+
TypeVar,
16+
Union,
17+
)
718

819
import torch
920
from captum._utils.common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
465476
attrib_type: dtype,
466477
**kwargs: Any,
467478
) -> Tuple[List[Tensor], List[Tensor]]:
479+
feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
480+
for i, mask in enumerate(formatted_feature_mask):
481+
for feature_idx in torch.unique(mask):
482+
if feature_idx.item() not in feature_idx_to_tensor_idx:
483+
feature_idx_to_tensor_idx[feature_idx.item()] = []
484+
feature_idx_to_tensor_idx[feature_idx.item()].append(i)
485+
468486
for (
469487
current_inputs,
470488
current_mask,
471489
) in self._ablation_generator(
472490
formatted_inputs,
473491
baselines,
474492
formatted_feature_mask,
493+
feature_idx_to_tensor_idx,
475494
**kwargs,
476495
):
477496
# modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,27 +530,28 @@ def _ablation_generator(
511530
inputs: Tuple[Tensor, ...],
512531
baselines: BaselineType,
513532
input_mask: Tuple[Tensor, ...],
533+
feature_idx_to_tensor_idx: Dict[int, List[int]],
514534
**kwargs: Any,
515535
) -> Generator[
516536
Tuple[
517537
Tuple[Tensor, ...],
518-
Tuple[Tensor, ...],
538+
Tuple[Optional[Tensor], ...],
519539
],
520540
None,
521541
None,
522542
]:
523-
unique_feature_ids = torch.unique(
524-
torch.cat([mask.flatten() for mask in input_mask])
525-
).tolist()
526-
527543
if isinstance(baselines, torch.Tensor):
528544
baselines = baselines.reshape((1,) + tuple(baselines.shape))
529545

530546
# Process one feature per time, rather than processing every input tensor
531-
for feature_idx in unique_feature_ids:
547+
for feature_idx in feature_idx_to_tensor_idx.keys():
532548
ablated_inputs, current_masks = (
533549
self._construct_ablated_input_across_tensors(
534-
inputs, input_mask, baselines, feature_idx
550+
inputs,
551+
input_mask,
552+
baselines,
553+
feature_idx,
554+
feature_idx_to_tensor_idx[feature_idx],
535555
)
536556
)
537557
yield ablated_inputs, current_masks
@@ -542,18 +562,17 @@ def _construct_ablated_input_across_tensors(
542562
input_mask: Tuple[Tensor, ...],
543563
baselines: BaselineType,
544564
feature_idx: int,
545-
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
565+
tensor_idxs: List[int],
566+
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
546567

547568
ablated_inputs = []
548-
current_masks = []
569+
current_masks: List[Optional[Tensor]] = []
549570
for i, input_tensor in enumerate(inputs):
550-
mask = input_mask[i]
551-
tensor_mask = mask == feature_idx
552-
if not tensor_mask.any():
571+
if i not in tensor_idxs:
553572
ablated_inputs.append(input_tensor)
554-
current_masks.append(torch.zeros_like(tensor_mask))
573+
current_masks.append(None)
555574
continue
556-
tensor_mask = tensor_mask.to(input_tensor.device).long()
575+
tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
557576
baseline = baselines[i] if isinstance(baselines, tuple) else baselines
558577
if isinstance(baseline, torch.Tensor):
559578
baseline = baseline.reshape(
@@ -1173,7 +1192,7 @@ def _process_ablated_out(
11731192
def _process_ablated_out_full(
11741193
self,
11751194
modified_eval: Tensor,
1176-
current_mask: Tuple[Tensor, ...],
1195+
current_mask: Tuple[Optional[Tensor], ...],
11771196
flattened_initial_eval: Tensor,
11781197
inputs: TensorOrTupleOfTensorsGeneric,
11791198
n_outputs: int,
@@ -1195,9 +1214,10 @@ def _process_ablated_out_full(
11951214

11961215
if self.use_weights:
11971216
for weight, mask in zip(weights, current_mask):
1198-
weight += mask.float().sum(dim=0)
1217+
if mask is not None:
1218+
weight += mask.float().sum(dim=0)
11991219
for i, mask in enumerate(current_mask):
1200-
if inputs[i].numel() == 0:
1220+
if mask is None or inputs[i].numel() == 0:
12011221
continue
12021222
eval_diff = eval_diff.reshape(
12031223
eval_diff_shape + (inputs[i].dim() - 1) * (1,)

captum/attr/_core/feature_permutation.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4-
from typing import Any, Callable, Optional, Tuple, Union
4+
from typing import Any, Callable, List, Optional, Tuple, Union
55

66
import torch
77
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -26,15 +26,15 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
2626

2727

2828
def _permute_features_across_tensors(
29-
inputs: Tuple[Tensor, ...], feature_masks: Tuple[Tensor, ...]
29+
inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
3030
) -> Tuple[Tensor, ...]:
3131
"""
3232
Permutes features across multiple input tensors using the corresponding
3333
feature masks.
3434
"""
3535
permuted_outputs = []
3636
for input_tensor, feature_mask in zip(inputs, feature_masks):
37-
if not feature_mask.any():
37+
if feature_mask is None or not feature_mask.any():
3838
permuted_outputs.append(input_tensor)
3939
continue
4040
n = input_tensor.size(0)
@@ -103,7 +103,7 @@ def __init__(
103103
forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]],
104104
perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature,
105105
perm_func_cross_tensor: Callable[
106-
[Tuple[Tensor, ...], Tuple[Tensor, ...]], Tuple[Tensor, ...]
106+
[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]], Tuple[Tensor, ...]
107107
] = _permute_features_across_tensors,
108108
) -> None:
109109
r"""
@@ -392,9 +392,14 @@ def _construct_ablated_input_across_tensors(
392392
input_mask: Tuple[Tensor, ...],
393393
baselines: BaselineType,
394394
feature_idx: int,
395-
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
396-
feature_masks = tuple(
397-
(mask == feature_idx).to(inputs[0].device) for mask in input_mask
398-
)
395+
tensor_idxs: List[int],
396+
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
397+
current_masks: List[Optional[Tensor]] = []
398+
for i, mask in enumerate(input_mask):
399+
if i in tensor_idxs:
400+
current_masks.append((mask == feature_idx).to(inputs[0].device))
401+
else:
402+
current_masks.append(None)
403+
feature_masks = tuple(current_masks)
399404
permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
400405
return permuted_outputs, feature_masks

0 commit comments

Comments
 (0)